{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b3746c1",
   "metadata": {},
   "source": [
    "# Non-Conjugate Priors"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d89a712c",
   "metadata": {},
   "source": [
    "# GPU\n",
    "\n",
    "This lecture was built using a machine with the latest CUDA and CUDANN frameworks installed with access to a GPU.\n",
    "\n",
    "To run this lecture on [Google Colab](https://colab.research.google.com/), click on the “play” icon top right, select Colab, and set the runtime environment to include a GPU.\n",
    "\n",
    "To run this lecture on your own machine, you need to install the software listed following this notice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "620e326a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "!pip install numpyro pyro-ppl torch jax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2095810e",
   "metadata": {},
   "source": [
    "This lecture is a sequel to the [quantecon lecture](https://python.quantecon.org/prob_meaning.html).\n",
    "\n",
    "That lecture offers a Bayesian interpretation of probability in a setting in which the likelihood function and the prior distribution\n",
    "over parameters just happened to form a **conjugate** pair in which\n",
    "\n",
    "- application of Bayes’ Law produces a posterior distribution that has the same functional form as the prior  \n",
    "\n",
    "\n",
    "Having a likelihood and prior that  are conjugate can simplify calculation of a posterior, faciltating  analytical or nearly analytical calculations.\n",
    "\n",
    "But in many situations  the likelihood and prior need not form a conjugate pair.\n",
    "\n",
    "- after all, a person’s prior is his or her own business and would take a form conjugate to a likelihood only by remote coincidence  \n",
    "\n",
    "\n",
    "In these situations, computing a posterior can become very challenging.\n",
    "\n",
    "In this lecture, we illustrate how modern Bayesians confront non-conjugate priors  by using  Monte Carlo techniques that involve\n",
    "\n",
    "- first  cleverly forming a Markov chain whose invariant distribution is the posterior distribution we want  \n",
    "- simulating the Markov chain until it has converged and then sampling from the invariant distribution to approximate the posterior  \n",
    "\n",
    "\n",
    "We shall illustrate the approach by deploying two powerful Python modules that implement this approach as well as another closely related one to\n",
    "be described below.\n",
    "\n",
    "The two Python modules are\n",
    "\n",
    "- `numpyro`  \n",
    "- `pymc4`  \n",
    "\n",
    "\n",
    "As usual, we begin by importing some Python code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28bdd0ff",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import binom\n",
    "import scipy.stats as st\n",
    "import torch\n",
    "\n",
    "# jax\n",
    "import jax.numpy as jnp\n",
    "from jax import lax, random\n",
    "\n",
    "# pyro\n",
    "import pyro\n",
    "from pyro import distributions as dist\n",
    "import pyro.distributions.constraints as constraints\n",
    "from pyro.infer import MCMC, NUTS, SVI, ELBO, Trace_ELBO\n",
    "from pyro.optim import Adam\n",
    "\n",
    "# numpyro\n",
    "import numpyro\n",
    "from numpyro import distributions as ndist\n",
    "import numpyro.distributions.constraints as nconstraints\n",
    "from numpyro.infer import MCMC as nMCMC\n",
    "from numpyro.infer import NUTS as nNUTS\n",
    "from numpyro.infer import SVI as nSVI\n",
    "from numpyro.infer import ELBO as nELBO\n",
    "from numpyro.infer import Trace_ELBO as nTrace_ELBO\n",
    "from numpyro.optim import Adam as nAdam"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7d6726f",
   "metadata": {},
   "source": [
    "## Unleashing MCMC on a  Binomial Likelihood\n",
    "\n",
    "This lecture begins with the binomial example in the [quantecon lecture](https://python.quantecon.org/prob_meaning.html).\n",
    "\n",
    "That lecture computed a posterior\n",
    "\n",
    "- analytically via choosing the conjugate priors,  \n",
    "\n",
    "\n",
    "This lecture instead computes posteriors\n",
    "\n",
    "- numerically by sampling from the posterior distribution through MCMC methods, and  \n",
    "- using a variational inference (VI) approximation.  \n",
    "\n",
    "\n",
    "We use both the packages `pyro` and `numpyro` with assistance from  `jax` to approximate a  posterior distribution\n",
    "\n",
    "We use several alternative prior distributions\n",
    "\n",
    "We  compare computed posteriors  with ones associated with a conjugate prior as described in  [the quantecon lecture](https://python.quantecon.org/prob_meaning.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebda3ecf",
   "metadata": {},
   "source": [
    "### Analytical Posterior\n",
    "\n",
    "Assume that the random variable $ X\\sim Binom\\left(n,\\theta\\right) $.\n",
    "\n",
    "This defines a likelihood function\n",
    "\n",
    "$$\n",
    "L\\left(Y\\vert\\theta\\right) = \\textrm{Prob}(X =  k | \\theta) =\n",
    "\\left(\\frac{n!}{k! (n-k)!} \\right) \\theta^k (1-\\theta)^{n-k}\n",
    "$$\n",
    "\n",
    "where $ Y=k $ is an observed data point.\n",
    "\n",
    "We view  $ \\theta $ as a random variable for which we assign a prior distribution having density $ f(\\theta) $.\n",
    "\n",
    "We will try alternative priors later, but for now, suppose the prior is distributed as $ \\theta\\sim Beta\\left(\\alpha,\\beta\\right) $, i.e.,\n",
    "\n",
    "$$\n",
    "f(\\theta) = \\textrm{Prob}(\\theta) = \\frac{\\theta^{\\alpha - 1} (1 - \\theta)^{\\beta - 1}}{B(\\alpha, \\beta)}\n",
    "$$\n",
    "\n",
    "We choose this as our prior for now because  we know that a conjugate prior for the binomial likelihood function is a beta distribution.\n",
    "\n",
    "After observing  $ k $ successes among $ N $ sample observations, the posterior  probability distributionof  $ \\theta $ is\n",
    "\n",
    "$$\n",
    "\\textrm{Prob}(\\theta|k) = \\frac{\\textrm{Prob}(\\theta,k)}{\\textrm{Prob}(k)}=\\frac{\\textrm{Prob}(k|\\theta)\\textrm{Prob}(\\theta)}{\\textrm{Prob}(k)}=\\frac{\\textrm{Prob}(k|\\theta) \\textrm{Prob}(\\theta)}{\\int_0^1 \\textrm{Prob}(k|\\theta)\\textrm{Prob}(\\theta) d\\theta}\n",
    "$$\n",
    "\n",
    "$$\n",
    "=\\frac{{N \\choose k} (1 - \\theta)^{N-k} \\theta^k \\frac{\\theta^{\\alpha - 1} (1 - \\theta)^{\\beta - 1}}{B(\\alpha, \\beta)}}{\\int_0^1 {N \\choose k} (1 - \\theta)^{N-k} \\theta^k\\frac{\\theta^{\\alpha - 1} (1 - \\theta)^{\\beta - 1}}{B(\\alpha, \\beta)} d\\theta}\n",
    "$$\n",
    "\n",
    "$$\n",
    "=\\frac{(1 -\\theta)^{\\beta+N-k-1} \\theta^{\\alpha+k-1}}{\\int_0^1 (1 - \\theta)^{\\beta+N-k-1} \\theta^{\\alpha+k-1} d\\theta} .\n",
    "$$\n",
    "\n",
    "Thus,\n",
    "\n",
    "$$\n",
    "\\textrm{Prob}(\\theta|k) \\sim {Beta}(\\alpha + k, \\beta+N-k)\n",
    "$$\n",
    "\n",
    "The analytical posterior for a given conjugate beta prior is coded in the following Python code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb2197c2",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "def simulate_draw(theta, n):\n",
    "    \"\"\"\n",
    "    Draws a Bernoulli sample of size n with probability P(Y=1) = theta\n",
    "    \"\"\"\n",
    "    rand_draw = np.random.rand(n)\n",
    "    draw = (rand_draw < theta).astype(int)\n",
    "    return draw\n",
    "\n",
    "\n",
    "def analytical_beta_posterior(data, alpha0, beta0):\n",
    "    \"\"\"\n",
    "    Computes analytically the posterior distribution with beta prior parametrized by (alpha, beta)\n",
    "    given # num observations\n",
    "\n",
    "    Parameters\n",
    "    ---------\n",
    "    num : int.\n",
    "        the number of observations after which we calculate the posterior\n",
    "    alpha0, beta0 : float.\n",
    "        the parameters for the beta distribution as a prior\n",
    "\n",
    "    Returns\n",
    "    ---------\n",
    "    The posterior beta distribution\n",
    "    \"\"\"\n",
    "    num = len(data)\n",
    "    up_num = data.sum()\n",
    "    down_num = num - up_num\n",
    "    return st.beta(alpha0 + up_num, beta0 + down_num)"
   ]
  },
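  {
   "cell_type": "markdown",
   "id": "a1b2c3d4",
   "metadata": {},
   "source": [
    "As a quick sanity check (a usage sketch we add here, not part of the original lecture code), we can simulate coin flips with `simulate_draw` and inspect the analytical posterior implied by a $ Beta(5,5) $ prior; with $ 100 $ observations, the posterior mean should land near the true $ \\theta $."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d5",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# a quick usage sketch: simulate data and inspect the analytical posterior\n",
    "np.random.seed(0)                       # fix the draws for reproducibility\n",
    "data = simulate_draw(theta=0.8, n=100)  # 100 Bernoulli(0.8) observations\n",
    "post = analytical_beta_posterior(data, alpha0=5, beta0=5)\n",
    "print('posterior mean:', post.mean())\n",
    "print('posterior std: ', post.std())"
   ]
  },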
  {
   "cell_type": "markdown",
   "id": "c5b4bd57",
   "metadata": {},
   "source": [
    "### Two Ways to Approximate Posteriors\n",
    "\n",
    "Suppose that we don’t have a conjugate prior.\n",
    "\n",
    "Then  we  can’t compute posteriors analytically.\n",
    "\n",
    "Instead,  we use computational tools to approximate the posterior distribution for a set of alternative prior distributions using both `Pyro` and `Numpyro` packages in Python.\n",
    "\n",
    "We first use the **Markov Chain Monte Carlo** (MCMC) algorithm .\n",
    "\n",
    "We implement the NUTS sampler to sample from the posterior.\n",
    "\n",
    "In that way we construct a sampling distribution that approximates the  posterior.\n",
    "\n",
    "After doing that we deply another procedure called  **Variational Inference** (VI).\n",
    "\n",
    "In particular, we implement Stochastic Variational Inference (SVI) machinery in both `Pyro` and `Numpyro`.\n",
    "\n",
    "The MCMC algorithm  supposedly generates a more accurate approximation since in principle it directly samples from the posterior distribution.\n",
    "\n",
    "But it  can be computationally expensive, especially when dimension is large.\n",
    "\n",
    "A VI approach can be  cheaper, but it is likely to produce an inferior approximation to the posterior, for the simple reason that it requires guessing a parametric **guide functional form** that we use to approximate a posterior.\n",
    "\n",
    "This guide function is likely  at best to be an imperfect approximation.\n",
    "\n",
    "By paying the cost of restricting the putative posterior to have a restricted functional form,\n",
    "the problem of approximating a posteriors is transformed to a well-posed optimization problem that seeks parameters of the putative posterior  that minimize\n",
    "a Kullback-Leibler (KL) divergence between true posterior and the putatitive posterior  distribution.\n",
    "\n",
    "- minimizing the KL divergence is  equivalent with  maximizing a criterion called  the **Evidence Lower Bound** (ELBO), as we shall verify soon.  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b8dda07",
   "metadata": {},
   "source": [
    "## Prior Distributions\n",
    "\n",
    "In order to be able to apply MCMC sampling or VI, `Pyro` and `Numpyro` require  that a prior distribution satisfy special properties:\n",
    "\n",
    "- we must be able sample from it;  \n",
    "- we must be able to  compute the log pdf  pointwise;  \n",
    "- the pdf must be  differentiable with respect to the parameters.  \n",
    "\n",
    "\n",
    "We’ll want to define a distribution `class`.\n",
    "\n",
    "We  will use the following priors:\n",
    "\n",
    "- a uniform distribution on $ [\\underline \\theta, \\overline \\theta] $, where $ 0 \\leq \\underline \\theta < \\overline \\theta \\leq 1 $.  \n",
    "- a truncated log-normal distribution with support on $ [0,1] $ with parameters $ (\\mu,\\sigma) $.  \n",
    "  - To implement this, let $ Z\\sim Normal(\\mu,\\sigma) $ and $ \\tilde{Z} $ be truncated normal with support $ [\\log(0),\\log(1)] $, then $ \\exp(Z) $ has a log normal distribution with bounded support $ [0,1] $. This can be easily coded since `Numpyro` has a built-in truncated normal distribution, and `Torch` provides a `TransformedDistribution` class that includes an exponential transformation.  \n",
    "  - Alternatively, we can use a rejection sampling strategy by assigning the probability rate to $ 0 $ outside the bounds and rescaling accepted samples, i.e., realizations that are within the bounds, by the total probability computed via CDF of the original distribution. This can be implemented by defining a truncated distribution class with `pyro`’s `dist.Rejector` class.  \n",
    "  - We implement both methods in the below section and verify that they  produce the same result.  \n",
    "- a shifted von Mises distribution that has support confined to $ [0,1] $ with parameter $ (\\mu,\\kappa) $.  \n",
    "  - Let $ X\\sim vonMises(0,\\kappa) $. We know that $ X $ has bounded support $ [-\\pi, \\pi] $. We can define a shifted von Mises random variable $ \\tilde{X}=a+bX $ where $ a=0.5, b=1/(2 \\pi) $ so that $ \\tilde{X} $ is supported on $ [0,1] $.  \n",
    "  - This can be implemented using `Torch`’s `TransformedDistribution` class  with its `AffineTransform` method.  \n",
    "  - If instead, we want the prior to be von-Mises distributed with center $ \\mu=0.5 $, we can choose a high concentration level $ \\kappa $ so that most mass is located between $ 0 $ and $ 1 $. Then we can truncate the distribution using the above strategy. This can be implemented using  `pyro`’s `dist.Rejector` class. We choose $ \\kappa > 40 $ in this case.  \n",
    "- a truncated Laplace distribution.  \n",
    "  - We also considered a truncated Laplace distribution because its density comes in a piece-wise non-smooth form and has a distinctive spiked shape.  \n",
    "  - The truncated Laplace can be created using `Numpyro`’s `TruncatedDistribution` class.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0fe01f5",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# used by Numpyro\n",
    "def TruncatedLogNormal_trans(loc, scale):\n",
    "    \"\"\"\n",
    "    Obtains the truncated log normal distribution using numpyro's TruncatedNormal and ExpTransform\n",
    "    \"\"\"\n",
    "    base_dist = ndist.TruncatedNormal(low=jnp.log(0), high=jnp.log(1), loc=loc, scale=scale)\n",
    "    return ndist.TransformedDistribution(\n",
    "        base_dist,ndist.transforms.ExpTransform()\n",
    "        )\n",
    "\n",
    "def ShiftedVonMises(kappa):\n",
    "    \"\"\"\n",
    "    Obtains the shifted von Mises distribution using AffineTransform\n",
    "    \"\"\"\n",
    "    base_dist = ndist.VonMises(0, kappa)\n",
    "    return ndist.TransformedDistribution(\n",
    "        base_dist, ndist.transforms.AffineTransform(loc=0.5, scale=1/(2*jnp.pi))\n",
    "        )\n",
    "\n",
    "def TruncatedLaplace(loc, scale):\n",
    "    \"\"\"\n",
    "    Obtains the truncated Laplace distribution on [0,1]\n",
    "    \"\"\"\n",
    "    base_dist = ndist.Laplace(loc, scale)\n",
    "    return ndist.TruncatedDistribution(\n",
    "        base_dist, low=0.0, high=1.0\n",
    "    )\n",
    "\n",
    "# used by Pyro\n",
    "class TruncatedLogNormal(dist.Rejector):\n",
    "    \"\"\"\n",
    "    Define a TruncatedLogNormal distribution through rejection sampling in Pyro\n",
    "    \"\"\"\n",
    "    def __init__(self, loc, scale_0, upp=1):\n",
    "        self.upp = upp\n",
    "        propose = dist.LogNormal(loc, scale_0)\n",
    "\n",
    "        def log_prob_accept(x):\n",
    "            return (x < upp).type_as(x).log()\n",
    "\n",
    "        log_scale = dist.LogNormal(loc, scale_0).cdf(torch.as_tensor(upp)).log()\n",
    "        super(TruncatedLogNormal, self).__init__(propose, log_prob_accept, log_scale)\n",
    "\n",
    "    @constraints.dependent_property\n",
    "    def support(self):\n",
    "        return constraints.interval(0, self.upp)\n",
    "\n",
    "\n",
    "class TruncatedvonMises(dist.Rejector):\n",
    "    \"\"\"\n",
    "    Define a TruncatedvonMises distribution through rejection sampling in Pyro\n",
    "    \"\"\"\n",
    "    def __init__(self, kappa, mu=0.5, low=0.0, upp=1.0):\n",
    "        self.low, self.upp = low, upp\n",
    "        propose = dist.VonMises(mu, kappa)\n",
    "\n",
    "        def log_prob_accept(x):\n",
    "            return ((x > low) & (x < upp)).type_as(x).log()\n",
    "\n",
    "        log_scale = torch.log(\n",
    "            torch.tensor(\n",
    "                st.vonmises(kappa=kappa, loc=mu).cdf(upp)\n",
    "                - st.vonmises(kappa=kappa, loc=mu).cdf(low))\n",
    "        )\n",
    "        super(TruncatedvonMises, self).__init__(propose, log_prob_accept, log_scale)\n",
    "\n",
    "    @constraints.dependent_property\n",
    "    def support(self):\n",
    "        return constraints.interval(self.low, self.upp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20fa89e8",
   "metadata": {},
   "source": [
    "### Variational Inference\n",
    "\n",
    "Instead of directly sampling from the posterior,  the **variational inference**  methodw approximates an unknown posterior distribution with  a family of tractable distributions/densities.\n",
    "\n",
    "It then seeks to minimizes a measure of statistical discrepancy between the approximating and  true posteriors.\n",
    "\n",
    "Thus variational inference (VI)  approximates a posterior by solving  a  minimization problem.\n",
    "\n",
    "Let the latent parameter/variable that we want to infer  be $ \\theta $.\n",
    "\n",
    "Let the  prior be  $ p(\\theta) $ and the likelihood be $ p\\left(Y\\vert\\theta\\right) $.\n",
    "\n",
    "We want  $ p\\left(\\theta\\vert Y\\right) $.\n",
    "\n",
    "Bayes’ rule implies\n",
    "\n",
    "$$\n",
    "p\\left(\\theta\\vert Y\\right)=\\frac{p\\left(Y,\\theta\\right)}{p\\left(Y\\right)}=\\frac{p\\left(Y\\vert\\theta\\right)p\\left(\\theta\\right)}{p\\left(Y\\right)}\n",
    "$$\n",
    "\n",
    "where\n",
    "\n",
    "\n",
    "<a id='equation-eq-intchallenge'></a>\n",
    "$$\n",
    "p\\left(Y\\right)=\\int d\\theta p\\left(Y\\mid\\theta\\right)p\\left(Y\\right). \\tag{46.1}\n",
    "$$\n",
    "\n",
    "The integral on the right side of [(46.1)](#equation-eq-intchallenge)  is typically difficult to compute.\n",
    "\n",
    "Consider a  **guide distribution** $ q_{\\phi}(\\theta) $ parameterized by $ \\phi $ that we’ll use to approximate the posterior.\n",
    "\n",
    "We choose  parameters $ \\phi $ of the guide distribution to minimize a Kullback-Leibler (KL)  divergence between the approximate posterior $ q_{\\phi}(\\theta) $ and  the posterior:\n",
    "\n",
    "$$\n",
    "D_{KL}(q(\\theta;\\phi)\\;\\|\\;p(\\theta\\mid Y)) \\equiv -\\int d\\theta q(\\theta;\\phi)\\log\\frac{p(\\theta\\mid Y)}{q(\\theta;\\phi)}\n",
    "$$\n",
    "\n",
    "Thus, we want a **variational distribution** $ q $ that solves\n",
    "\n",
    "$$\n",
    "\\min_{\\phi}\\quad D_{KL}(q(\\theta;\\phi)\\;\\|\\;p(\\theta\\mid Y))\n",
    "$$\n",
    "\n",
    "Note that\n",
    "\n",
    "$$\n",
    "\\begin{aligned}D_{KL}(q(\\theta;\\phi)\\;\\|\\;p(\\theta\\mid Y)) & =-\\int d\\theta q(\\theta;\\phi)\\log\\frac{P(\\theta\\mid Y)}{q(\\theta;\\phi)}\\\\\n",
    " & =-\\int d\\theta q(\\theta)\\log\\frac{\\frac{p(\\theta,Y)}{p(Y)}}{q(\\theta)}\\\\\n",
    " & =-\\int d\\theta q(\\theta)\\log\\frac{p(\\theta,Y)}{p(\\theta)q(Y)}\\\\\n",
    " & =-\\int d\\theta q(\\theta)\\left[\\log\\frac{p(\\theta,Y)}{q(\\theta)}-\\log p(Y)\\right]\\\\\n",
    " & =-\\int d\\theta q(\\theta)\\log\\frac{p(\\theta,Y)}{q(\\theta)}+\\int d\\theta q(\\theta)\\log p(Y)\\\\\n",
    " & =-\\int d\\theta q(\\theta)\\log\\frac{p(\\theta,Y)}{q(\\theta)}+\\log p(Y)\\\\\n",
    "\\log p(Y)&=D_{KL}(q(\\theta;\\phi)\\;\\|\\;p(\\theta\\mid Y))+\\int d\\theta q_{\\phi}(\\theta)\\log\\frac{p(\\theta,Y)}{q_{\\phi}(\\theta)}\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "For  observed data $ Y $, $ p(\\theta,Y) $ is a constant, so minimizing KL divergence is equivalent to maximizing\n",
    "\n",
    "\n",
    "<a id='equation-eq-elbo'></a>\n",
    "$$\n",
    "ELBO\\equiv\\int d\\theta q_{\\phi}(\\theta)\\log\\frac{p(\\theta,Y)}{q_{\\phi}(\\theta)}=\\mathbb{E}_{q_{\\phi}(\\theta)}\\left[\\log p(\\theta,Y)-\\log q_{\\phi}(\\theta)\\right] \\tag{46.2}\n",
    "$$\n",
    "\n",
    "Formula [(46.2)](#equation-eq-elbo) is called  the evidence lower bound (ELBO).\n",
    "\n",
    "A standard optimization routine can used to search for the optimal $ \\phi $ in our parametrized distribution $ q_{\\phi}(\\theta) $.\n",
    "\n",
    "The parameterized  distribution $ q_{\\phi}(\\theta) $ is called the **variational distribution**.\n",
    "\n",
    "We can implement Stochastic Variational Inference (SVI) in Pyro and Numpyro using the `Adam` gradient descent algorithm to approximate posterior.\n",
    "\n",
    "We use  two sets of variational distributions: Beta and TruncatedNormal with support $ [0,1] $\n",
    "\n",
    "- Learnable parameters for the Beta distribution are (alpha, beta), both of which are positive.  \n",
    "- Learnable parameters for the Truncated Normal distribution are (loc, scale).  \n",
    "\n",
    "\n",
    ">**Note**\n",
    ">\n",
    ">We restrict the truncated Normal parameter ‘loc’ to be in the interval $ [0,1] $"
   ]
  },
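  {
   "cell_type": "markdown",
   "id": "b2c3d4e5",
   "metadata": {},
   "source": [
    "Before building the samplers, we can check the identity $ \\log p(Y)=D_{KL}+ELBO $ numerically. The following sketch (our addition, not part of the original lecture code) uses the conjugate Beta-Binomial case, in which $ p(Y) $ is available in closed form, together with an arbitrary Beta variational distribution; simple quadrature confirms that the KL divergence and the ELBO sum to the log evidence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2c3d4e6",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# a numerical sketch of log p(Y) = KL + ELBO in the Beta-Binomial case\n",
    "from scipy.special import comb, betaln\n",
    "from scipy.integrate import quad\n",
    "\n",
    "alpha0, beta0, N, k = 2.0, 3.0, 10, 7\n",
    "\n",
    "# closed-form log evidence: log C(N,k) + log B(alpha0+k, beta0+N-k) - log B(alpha0, beta0)\n",
    "log_pY = np.log(comb(N, k)) + betaln(alpha0 + k, beta0 + N - k) - betaln(alpha0, beta0)\n",
    "\n",
    "q = st.beta(3.0, 2.0)                           # an arbitrary variational distribution\n",
    "posterior = st.beta(alpha0 + k, beta0 + N - k)  # the exact posterior\n",
    "\n",
    "def log_joint(theta):\n",
    "    # log p(theta, Y) = log p(Y|theta) + log p(theta)\n",
    "    return (np.log(comb(N, k)) + k * np.log(theta) + (N - k) * np.log(1 - theta)\n",
    "            + st.beta(alpha0, beta0).logpdf(theta))\n",
    "\n",
    "eps = 1e-8  # stay away from the endpoints, where the logs diverge\n",
    "elbo, _ = quad(lambda t: q.pdf(t) * (log_joint(t) - q.logpdf(t)), eps, 1 - eps)\n",
    "kl, _ = quad(lambda t: q.pdf(t) * (q.logpdf(t) - posterior.logpdf(t)), eps, 1 - eps)\n",
    "\n",
    "print('log p(Y)  =', log_pY)\n",
    "print('ELBO + KL =', elbo + kl)  # should match log p(Y)"
   ]
  },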
  {
   "cell_type": "markdown",
   "id": "507e9be9",
   "metadata": {},
   "source": [
    "## Implementation\n",
    "\n",
    "We have constructed a Python class `BaysianInference` that requires the following arguments to be initialized:\n",
    "\n",
    "- `param`: a tuple/scalar of parameters dependent on distribution types  \n",
    "- `name_dist`: a string that specifies distribution names  \n",
    "\n",
    "\n",
    "The (`param`, `name_dist`) pair includes:\n",
    "\n",
    "- (‘beta’, alpha, beta)  \n",
    "- (‘uniform’, upper_bound, lower_bound)  \n",
    "- (‘lognormal’, loc, scale)  \n",
    "  - Note: This is the truncated log normal.  \n",
    "- (‘vonMises’, kappa), where kappa denotes concentration parameter, and center location is set to $ 0.5 $.  \n",
    "  - Note: When using `Pyro`, this is the truncated version of the original vonMises distribution;  \n",
    "  - Note: When using `Numpyro`, this is the **shifted** distribution.  \n",
    "- (‘laplace’, loc, scale)  \n",
    "  - Note: This is the truncated Laplace  \n",
    "\n",
    "\n",
    "The class `BaysianInference` has several key methods :\n",
    "\n",
    "- `sample_prior`:  \n",
    "  - This can be used to draw a single sample from the given prior distribution.  \n",
    "- `show_prior`:  \n",
    "  - Plots the approximate prior distribution by repeatedly drawing samples and fitting a kernal density curve.  \n",
    "- `MCMC_sampling`:  \n",
    "  - INPUT: (data, num_samples, num_warmup=1000)  \n",
    "  - Take a `np.array` data and generate MCMC sampling of posterior of size `num_samples`.  \n",
    "- `SVI_run`:  \n",
    "  - INPUT: (data, guide_dist, n_steps=10000)  \n",
    "  - guide_dist = ‘normal’ - use a **truncated** normal distribution as the parametrized guide  \n",
    "  - guide_dist = ‘beta’ - use a beta distribution as the parametrized guide  \n",
    "  - RETURN: (params, losses) - the learned parameters in a `dict` and the vector of loss at each step.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccff50b0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "class BayesianInference:\n",
    "    def __init__(self, param, name_dist, solver):\n",
    "        \"\"\"\n",
    "        Parameters\n",
    "        ---------\n",
    "        param : tuple.\n",
    "            a tuple object that contains all relevant parameters for the distribution\n",
    "        dist : str.\n",
    "            name of the distribution - 'beta', 'uniform', 'lognormal', 'vonMises', 'tent'\n",
    "        solver : str.\n",
    "            either pyro or numpyro\n",
    "        \"\"\"\n",
    "        self.param = param\n",
    "        self.name_dist = name_dist\n",
    "        self.solver = solver\n",
    "\n",
    "        # jax requires explicit PRNG state to be passed\n",
    "        self.rng_key = random.PRNGKey(0)\n",
    "\n",
    "\n",
    "    def sample_prior(self):\n",
    "        \"\"\"\n",
    "        Define the prior distribution to sample from in Pyro/Numpyro models.\n",
    "        \"\"\"\n",
    "        if self.name_dist=='beta':\n",
    "            # unpack parameters\n",
    "            alpha0, beta0 = self.param\n",
    "            if self.solver=='pyro':\n",
    "                sample = pyro.sample('theta', dist.Beta(alpha0, beta0))\n",
    "            else:\n",
    "                sample = numpyro.sample('theta', ndist.Beta(alpha0, beta0), rng_key=self.rng_key)\n",
    "\n",
    "        elif self.name_dist=='uniform':\n",
    "            # unpack parameters\n",
    "            lb, ub = self.param\n",
    "            if self.solver=='pyro':\n",
    "                sample = pyro.sample('theta', dist.Uniform(lb, ub))\n",
    "            else:\n",
    "                sample = numpyro.sample('theta', ndist.Uniform(lb, ub), rng_key=self.rng_key)\n",
    "\n",
    "        elif self.name_dist=='lognormal':\n",
    "            # unpack parameters\n",
    "            loc, scale = self.param\n",
    "            if self.solver=='pyro':\n",
    "                sample = pyro.sample('theta', TruncatedLogNormal(loc, scale))\n",
    "            else:\n",
    "                sample = numpyro.sample('theta', TruncatedLogNormal_trans(loc, scale), rng_key=self.rng_key)\n",
    "\n",
    "        elif self.name_dist=='vonMises':\n",
    "            # unpack parameters\n",
    "            kappa = self.param\n",
    "            if self.solver=='pyro':\n",
    "                sample = pyro.sample('theta', TruncatedvonMises(kappa))\n",
    "            else:\n",
    "                sample = numpyro.sample('theta', ShiftedVonMises(kappa), rng_key=self.rng_key)\n",
    "\n",
    "        elif self.name_dist=='laplace':\n",
    "            # unpack parameters\n",
    "            loc, scale = self.param\n",
    "            if self.solver=='pyro':\n",
    "                print(\"WARNING: Please use Numpyro for truncated Laplace.\")\n",
    "                sample = None\n",
    "            else:\n",
    "                sample = numpyro.sample('theta', TruncatedLaplace(loc, scale), rng_key=self.rng_key)\n",
    "\n",
    "        return sample\n",
    "\n",
    "\n",
    "    def show_prior(self, size=1e5, bins=20, disp_plot=1):\n",
    "        \"\"\"\n",
    "        Visualizes prior distribution by sampling from prior and plots the approximated sampling distribution\n",
    "        \"\"\"\n",
    "        self.bins = bins\n",
    "\n",
    "        if self.solver=='pyro':\n",
    "            with pyro.plate('show_prior', size=size):\n",
    "                sample = self.sample_prior()\n",
    "            # to numpy\n",
    "            sample_array = sample.numpy()\n",
    "\n",
    "        elif self.solver=='numpyro':\n",
    "            with numpyro.plate('show_prior', size=size):\n",
    "                sample = self.sample_prior()\n",
    "            # to numpy\n",
    "            sample_array=jnp.asarray(sample)\n",
    "\n",
    "        # plot histogram and kernel density\n",
    "        if disp_plot==1:\n",
    "            sns.displot(sample_array, kde=True, stat='density', bins=bins, height=5, aspect=1.5)\n",
    "            plt.xlim(0, 1)\n",
    "            plt.show()\n",
    "        else:\n",
    "            return sample_array\n",
    "\n",
    "\n",
    "    def model(self, data):\n",
    "        \"\"\"\n",
    "        Define the probabilistic model by specifying prior, conditional likelihood, and data conditioning\n",
    "        \"\"\"\n",
    "        if not torch.is_tensor(data):\n",
    "            data = torch.tensor(data)\n",
    "        # set prior\n",
    "        theta = self.sample_prior()\n",
    "\n",
    "        # sample from conditional likelihood\n",
    "        if self.solver=='pyro':\n",
    "            output = pyro.sample('obs', dist.Binomial(len(data), theta), obs=torch.sum(data))\n",
    "        else:\n",
    "            # Note: numpyro.sample() requires obs=np.ndarray\n",
    "            output = numpyro.sample('obs', ndist.Binomial(len(data), theta), obs=torch.sum(data).numpy())\n",
    "        return output\n",
    "\n",
    "\n",
    "    def MCMC_sampling(self, data, num_samples, num_warmup=1000):\n",
    "        \"\"\"\n",
    "        Computes numerically the posterior distribution with beta prior parametrized by (alpha0, beta0)\n",
    "        given data using MCMC\n",
    "        \"\"\"\n",
    "        # use pyro\n",
    "        if self.solver=='pyro':\n",
    "            # tensorize\n",
    "            data = torch.tensor(data)\n",
    "            nuts_kernel = NUTS(self.model)\n",
    "            mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=num_warmup, disable_progbar=True)\n",
    "            mcmc.run(data)\n",
    "\n",
    "        # use numpyro\n",
    "        elif self.solver=='numpyro':\n",
    "            data = np.array(data, dtype=float)\n",
    "            nuts_kernel = nNUTS(self.model)\n",
    "            mcmc = nMCMC(nuts_kernel, num_samples=num_samples, num_warmup=num_warmup, progress_bar=False)\n",
    "            mcmc.run(self.rng_key, data=data)\n",
    "\n",
    "        # collect samples\n",
    "        samples = mcmc.get_samples()['theta']\n",
    "        return samples\n",
    "\n",
    "\n",
    "    def beta_guide(self, data):\n",
    "        \"\"\"\n",
    "        Defines the candidate parametrized variational distribution that we train to approximate posterior with Pyro/Numpyro\n",
    "        Here we use parameterized beta\n",
    "        \"\"\"\n",
    "        if self.solver=='pyro':\n",
    "            alpha_q = pyro.param('alpha_q', torch.tensor(0.5),\n",
    "                            constraint=constraints.positive)\n",
    "            beta_q = pyro.param('beta_q', torch.tensor(0.5),\n",
    "                            constraint=constraints.positive)\n",
    "            pyro.sample('theta', dist.Beta(alpha_q, beta_q))\n",
    "\n",
    "        else:\n",
    "            alpha_q = numpyro.param('alpha_q', 10,\n",
    "                            constraint=nconstraints.positive)\n",
    "            beta_q = numpyro.param('beta_q', 10,\n",
    "                            constraint=nconstraints.positive)\n",
    "\n",
    "            numpyro.sample('theta', ndist.Beta(alpha_q, beta_q))\n",
    "\n",
    "\n",
    "    def truncnormal_guide(self, data):\n",
    "        \"\"\"\n",
    "        Defines the candidate parametrized variational distribution that we train to approximate posterior with Pyro/Numpyro\n",
    "        Here we use truncated normal on [0,1]\n",
    "        \"\"\"\n",
    "        loc = numpyro.param('loc', 0.5,\n",
    "                        constraint=nconstraints.interval(0.0, 1.0))\n",
    "        scale = numpyro.param('scale', 1,\n",
    "                        constraint=nconstraints.positive)\n",
    "        numpyro.sample('theta', ndist.TruncatedNormal(loc, scale, low=0.0, high=1.0))\n",
    "\n",
    "\n",
    "    def SVI_init(self, guide_dist, lr=0.0005):\n",
    "        \"\"\"\n",
    "        Initiate SVI training mode with Adam optimizer\n",
    "        NOTE: truncnormal_guide can only be used with numpyro solver\n",
    "        \"\"\"\n",
    "        adam_params = {\"lr\": lr}\n",
    "\n",
    "        if guide_dist=='beta':\n",
    "            if self.solver=='pyro':\n",
    "                optimizer = Adam(adam_params)\n",
    "                svi = SVI(self.model, self.beta_guide, optimizer, loss=Trace_ELBO())\n",
    "\n",
    "            elif self.solver=='numpyro':\n",
    "                optimizer = nAdam(step_size=lr)\n",
    "                svi = nSVI(self.model, self.beta_guide, optimizer, loss=nTrace_ELBO())\n",
    "\n",
    "        elif guide_dist=='normal':\n",
    "            # only allow numpyro\n",
    "            if self.solver=='pyro':\n",
    "                print(\"WARNING: Please use Numpyro with TruncatedNormal guide\")\n",
    "                svi = None\n",
    "\n",
    "            elif self.solver=='numpyro':\n",
    "                optimizer = nAdam(step_size=lr)\n",
    "                svi = nSVI(self.model, self.truncnormal_guide, optimizer, loss=nTrace_ELBO())\n",
    "        else:\n",
    "            print(\"WARNING: Please input either 'beta' or 'normal'\")\n",
    "            svi = None\n",
    "\n",
    "        return svi\n",
    "\n",
    "    def SVI_run(self, data, guide_dist, n_steps=10000):\n",
    "        \"\"\"\n",
    "        Runs SVI and returns optimized parameters and losses\n",
    "\n",
    "        Returns\n",
    "        --------\n",
    "        params : the learned parameters for guide\n",
    "        losses : a vector of loss at each step\n",
    "        \"\"\"\n",
    "\n",
    "        # initiate SVI\n",
    "        svi = self.SVI_init(guide_dist=guide_dist)\n",
    "\n",
    "        # do gradient steps\n",
    "        if self.solver=='pyro':\n",
    "             # tensorize data\n",
    "            if not torch.is_tensor(data):\n",
    "                data = torch.tensor(data)\n",
    "            # store loss vector\n",
    "            losses = np.zeros(n_steps)\n",
    "            for step in range(n_steps):\n",
    "                losses[step] = svi.step(data)\n",
    "\n",
    "            # pyro only supports beta VI distribution\n",
    "            params = {\n",
    "                'alpha_q': pyro.param('alpha_q').item(),\n",
    "                'beta_q': pyro.param('beta_q').item()\n",
    "                }\n",
    "\n",
    "        elif self.solver=='numpyro':\n",
    "            data = np.array(data, dtype=float)\n",
    "            result = svi.run(self.rng_key, n_steps, data, progress_bar=False)\n",
    "            params = dict(\n",
    "                (key, np.asarray(value)) for key, value in result.params.items()\n",
    "                )\n",
    "            losses = np.asarray(result.losses)\n",
    "\n",
    "        return params, losses"
   ]
  },
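  {
   "cell_type": "markdown",
   "id": "c3d4e5f6",
   "metadata": {},
   "source": [
    "As a minimal usage sketch (our addition), we can instantiate the class and draw a single realization from a $ Beta(5,5) $ prior before turning to the plotting machinery."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d4e5f7",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# a minimal usage sketch: draw one realization from a Beta(5,5) prior\n",
    "demo = BayesianInference(param=(5, 5), name_dist='beta', solver='numpyro')\n",
    "print('one draw from the prior:', demo.sample_prior())"
   ]
  },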
  {
   "cell_type": "markdown",
   "id": "d896bc2d",
   "metadata": {},
   "source": [
    "## Alternative Prior Distributions\n",
    "\n",
    "Let’s see how well our sampling algorithm does in approximating\n",
    "\n",
    "- a log normal distribution  \n",
    "- a uniform distribution  \n",
    "\n",
    "\n",
    "To examine our alternative  prior distributions, we’ll plot  approximate prior distributions below by calling the `show_prior` method.\n",
    "\n",
    "We verify that the rejection sampling strategy under `Pyro` produces the same log normal distribution as the truncated normal transformation under `Numpyro`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db51038d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# truncated log normal\n",
    "exampleLN = BayesianInference(param=(0,2), name_dist='lognormal', solver='numpyro')\n",
    "exampleLN.show_prior(size=100000,bins=20)\n",
    "\n",
    "# truncated uniform\n",
    "exampleUN = BayesianInference(param=(0.1,0.8), name_dist='uniform', solver='numpyro')\n",
    "exampleUN.show_prior(size=100000,bins=20)"
   ]
  },
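  {
   "cell_type": "markdown",
   "id": "d4e5f6a7",
   "metadata": {},
   "source": [
    "The following cell (a check we add here; the lecture asserts the equivalence but does not display it) draws from the `Pyro` rejection-sampling version of the truncated log normal and compares sample moments with the `Numpyro` transformed version plotted above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e5f6a8",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# compare the two truncated log normal constructions by sampling from both\n",
    "sampleLN_np = exampleLN.show_prior(size=100000, disp_plot=0)\n",
    "\n",
    "exampleLN_pyro = BayesianInference(param=(0,2), name_dist='lognormal', solver='pyro')\n",
    "sampleLN_py = exampleLN_pyro.show_prior(size=100000, disp_plot=0)\n",
    "\n",
    "print('numpyro mean/std:', sampleLN_np.mean(), sampleLN_np.std())\n",
    "print('pyro    mean/std:', sampleLN_py.mean(), sampleLN_py.std())"
   ]
  },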
  {
   "cell_type": "markdown",
   "id": "a9c84ec3",
   "metadata": {},
   "source": [
    "The above graphs show that sampling seems to work well with both distributions.\n",
    "\n",
    "Now let’s see how well things work with a couple of von Mises distributions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53180381",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# shifted von Mises\n",
    "exampleVM = BayesianInference(param=10, name_dist='vonMises', solver='numpyro')\n",
    "exampleVM.show_prior(size=100000,bins=20)\n",
    "\n",
    "# truncated von Mises\n",
    "exampleVM_trunc = BayesianInference(param=20, name_dist='vonMises', solver='pyro')\n",
    "exampleVM_trunc.show_prior(size=100000,bins=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f571ee1e",
   "metadata": {},
   "source": [
    "These graphs look good too.\n",
    "\n",
    "Now let’s try with a Laplace distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c323067",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# truncated Laplace\n",
    "exampleLP = BayesianInference(param=(0.5,0.05), name_dist='laplace', solver='numpyro')\n",
    "exampleLP.show_prior(size=100000,bins=40)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cb37052",
   "metadata": {},
   "source": [
    "Having assured ourselves that our sampler seems to do a good job, let’s put it to work in using MCMC to compute posterior probabilities."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12504cb1",
   "metadata": {},
   "source": [
    "## Posteriors Via MCMC and VI\n",
    "\n",
    "We construct a class  `BayesianInferencePlot` to implement MCMC or VI algorithms and plot multiple posteriors for different updating data sizes and different  possible prior.\n",
    "\n",
    "This class takes as inputs the true data generating parameter ‘theta’, a list of updating data sizes for multiple posterior plotting, and a defined and parametrized `BayesianInference` class.\n",
    "\n",
    "It has two key methods:\n",
    "\n",
    "- `BayesianInferencePlot.MCMC_plot()` takes wanted MCMC sample size as input and plot the output posteriors  together with the prior defined in `BayesianInference` class.  \n",
    "- `BayesianInferencePlot.SVI_plot()` takes wanted VI distribution class (‘beta’ or ‘normal’) as input and plot the posteriors together with the prior.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79076cf5",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "class BayesianInferencePlot:\n",
    "    \"\"\"\n",
    "    Easily implement the MCMC and VI inference for a given instance of BayesianInference class and\n",
    "    plot the prior together with multiple posteriors\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    theta : float.\n",
    "        the true DGP parameter\n",
    "    N_list : list.\n",
    "        a list of sample size\n",
    "    BayesianInferenceClass : class.\n",
    "        a class initiated using BayesianInference()\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, theta, N_list, BayesianInferenceClass, binwidth=0.02):\n",
    "        \"\"\"\n",
    "        Enter Parameters for data generation and plotting\n",
    "        \"\"\"\n",
    "        self.theta = theta\n",
    "        self.N_list = N_list\n",
    "        self.BayesianInferenceClass = BayesianInferenceClass\n",
    "\n",
    "        # plotting parameters\n",
    "        self.binwidth = binwidth\n",
    "        self.linewidth=0.05\n",
    "        self.colorlist = sns.color_palette(n_colors=len(N_list))\n",
    "\n",
    "        # data generation\n",
    "        N_max = max(N_list)\n",
    "        self.data = simulate_draw(theta, N_max)\n",
    "\n",
    "\n",
    "    def MCMC_plot(self, num_samples, num_warmup=1000):\n",
    "        \"\"\"\n",
    "        Parameters as in MCMC_sampling except that data is already defined\n",
    "        \"\"\"\n",
    "        fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "        # plot prior\n",
    "        prior_sample = self.BayesianInferenceClass.show_prior(disp_plot=0)\n",
    "        sns.histplot(\n",
    "            data=prior_sample, kde=True, stat='density',\n",
    "            binwidth=self.binwidth,\n",
    "            color='#4C4E52',\n",
    "            linewidth=self.linewidth,\n",
    "            alpha=0.1,\n",
    "            ax=ax,\n",
    "            label='Prior Distribution'\n",
    "            )\n",
    "\n",
    "        # plot posteriors\n",
    "        for id, n in enumerate(self.N_list):\n",
    "            samples = self.BayesianInferenceClass.MCMC_sampling(\n",
    "                self.data[:n], num_samples, num_warmup\n",
    "            )\n",
    "            sns.histplot(\n",
    "                samples, kde=True, stat='density',\n",
    "                binwidth=self.binwidth,\n",
    "                linewidth=self.linewidth,\n",
    "                alpha=0.2,\n",
    "                color=self.colorlist[id-1],\n",
    "                label=f'Posterior with $n={n}$'\n",
    "                )\n",
    "        ax.legend()\n",
    "        ax.set_title('MCMC Sampling density of Posterior Distributions', fontsize=15)\n",
    "        plt.xlim(0, 1)\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "    def SVI_fitting(self, guide_dist, params):\n",
    "        \"\"\"\n",
    "        Fit the beta/truncnormal curve using parameters trained by SVI.\n",
    "        I create plot using PDF given by scipy.stats distributions since torch.dist do not have embedded PDF methods.\n",
    "        \"\"\"\n",
    "        # create x axis\n",
    "        xaxis = np.linspace(0,1,1000)\n",
    "        if guide_dist=='beta':\n",
    "            y = st.beta.pdf(xaxis, a=params['alpha_q'], b=params['beta_q'])\n",
    "\n",
    "        elif guide_dist=='normal':\n",
    "\n",
    "            # rescale upper/lower bound. See Scipy's truncnorm doc\n",
    "            lower, upper = (0, 1)\n",
    "            loc, scale = params['loc'], params['scale']\n",
    "            a, b = (lower - loc) / scale, (upper - loc) / scale\n",
    "\n",
    "            y = st.truncnorm.pdf(xaxis, a=a, b=b, loc=params['loc'], scale=params['scale'])\n",
    "        return (xaxis, y)\n",
    "\n",
    "\n",
    "    def SVI_plot(self, guide_dist, n_steps=2000):\n",
    "        \"\"\"\n",
    "        Parameters as in SVI_run except that data is already defined\n",
    "        \"\"\"\n",
    "        fig, ax = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "        # plot prior\n",
    "        prior_sample = self.BayesianInferenceClass.show_prior(disp_plot=0)\n",
    "        sns.histplot(\n",
    "            data=prior_sample, kde=True, stat='density',\n",
    "            binwidth=self.binwidth,\n",
    "            color='#4C4E52',\n",
    "            linewidth=self.linewidth,\n",
    "            alpha=0.1,\n",
    "            ax=ax,\n",
    "            label='Prior Distribution'\n",
    "            )\n",
    "\n",
    "        # plot posteriors\n",
    "        for id, n in enumerate(self.N_list):\n",
    "            (params, losses) = self.BayesianInferenceClass.SVI_run(self.data[:n], guide_dist, n_steps)\n",
    "            x, y = self.SVI_fitting(guide_dist, params)\n",
    "            ax.plot(x, y,\n",
    "                alpha=1,\n",
    "                color=self.colorlist[id-1],\n",
    "                label=f'Posterior with $n={n}$'\n",
    "                )\n",
    "        ax.legend()\n",
    "        ax.set_title(f'SVI density of Posterior Distributions with {guide_dist} guide', fontsize=15)\n",
    "        plt.xlim(0, 1)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32b5bbfa",
   "metadata": {},
   "source": [
    "Let’s set some parameters that we’ll use in all of the examples  below.\n",
    "\n",
    "To save computer time at first, notice that  we’ll set `MCMC_num_samples = 2000` and `SVI_num_steps = 5000`.\n",
    "\n",
    "(Later, to increase accuracy of approximations, we’ll want to increase these.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1accc06",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "num_list = [5,10,50,100,1000]\n",
    "MCMC_num_samples = 2000\n",
    "SVI_num_steps = 5000\n",
    "\n",
    "# theta is the data generating process\n",
    "true_theta = 0.8"
   ]
  },
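  {
   "cell_type": "markdown",
   "id": "e5f6a7b8",
   "metadata": {},
   "source": [
    "Before plotting, let’s peek at a simulated dataset from this data generating process (a small sketch we add here, reusing the `simulate_draw` function defined above). With `true_theta = 0.8`, roughly 80 percent of the draws should equal 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5f6a7b9",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# peek at a simulated sample from the true data generating process\n",
    "sample_data = simulate_draw(true_theta, max(num_list))\n",
    "print('first 10 draws:', sample_data[:10])\n",
    "print('sample mean   :', sample_data.mean())  # should be near true_theta = 0.8"
   ]
  },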
  {
   "cell_type": "markdown",
   "id": "06a53fbb",
   "metadata": {},
   "source": [
    "### Beta Prior and Posteriors:\n",
    "\n",
    "Let’s compare outcomes when we use a Beta prior.\n",
    "\n",
    "For the same Beta prior, we shall\n",
    "\n",
    "- compute posteriors analytically  \n",
    "- compute posteriors using MCMC  via  `Pyro` and `Numpyro`.  \n",
    "- compute posteriors using  VI via  `Pyro` and `Numpyro`.  \n",
    "\n",
    "\n",
    "Let’s start with the analytical method that we described in this quantecon lecture [https://python.quantecon.org/prob_meaning.html](https://python.quantecon.org/prob_meaning.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a594390f",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# First examine Beta priors\n",
    "BETA_pyro = BayesianInference(param=(5,5), name_dist='beta', solver='pyro')\n",
    "BETA_numpyro = BayesianInference(param=(5,5), name_dist='beta', solver='numpyro')\n",
    "\n",
    "BETA_pyro_plot = BayesianInferencePlot(true_theta, num_list, BETA_pyro)\n",
    "BETA_numpyro_plot = BayesianInferencePlot(true_theta, num_list, BETA_numpyro)\n",
    "\n",
    "\n",
    "# plot analytical Beta prior and posteriors\n",
    "xaxis = np.linspace(0,1,1000)\n",
    "y_prior = st.beta.pdf(xaxis, 5, 5)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "# plot analytical beta prior\n",
    "ax.plot(xaxis, y_prior, label='Analytical Beta Prior', color='#4C4E52')\n",
    "\n",
    "data, colorlist, N_list = BETA_pyro_plot.data, BETA_pyro_plot.colorlist, BETA_pyro_plot.N_list\n",
    "# plot analytical beta posteriors\n",
    "for id, n in enumerate(N_list):\n",
    "    func = analytical_beta_posterior(data[:n], alpha0=5, beta0=5)\n",
    "    y_posterior = func.pdf(xaxis)\n",
    "    ax.plot(\n",
    "        xaxis, y_posterior, color=colorlist[id-1], label=f'Analytical Beta Posterior with $n={n}$')\n",
    "ax.legend()\n",
    "ax.set_title('Analytical Beta Prior and Posterior', fontsize=15)\n",
    "plt.xlim(0, 1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff9bbfaa",
   "metadata": {},
   "source": [
    "Now let’s use MCMC while still using a beta prior.\n",
    "\n",
    "We’ll do this for both MCMC and VI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e067699",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "BayesianInferencePlot(true_theta, num_list, BETA_pyro).MCMC_plot(num_samples=MCMC_num_samples)\n",
    "BayesianInferencePlot(true_theta, num_list, BETA_numpyro).SVI_plot(guide_dist='beta', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d02fc4f",
   "metadata": {},
   "source": [
    "Here the MCMC approximation looks good.\n",
    "\n",
    "But the VI approximation doesn’t look so good.\n",
    "\n",
    "- even though we use the  beta distribution as our guide, the VI approximated posterior distributions do not closely resemble the posteriors that we had just computed analytically.  \n",
    "\n",
    "\n",
    "(Here, our initial parameter for Beta guide is (0.5, 0.5).)\n",
    "\n",
    "But if we increase the number of  steps from 5000 to 10000 in VI as we now shall do, we’ll get VI-approximated   posteriors\n",
    "will be  more accurate, as we shall see next.\n",
    "\n",
    "(Increasing the step size increases computational time though)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29ad3af0",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "BayesianInferencePlot(true_theta, num_list, BETA_numpyro).SVI_plot(guide_dist='beta', n_steps=100000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da7c4cb9",
   "metadata": {},
   "source": [
    "## Non-conjugate Prior Distributions\n",
    "\n",
    "Having assured ourselves that our MCMC and VI methods can work well when we have  conjugate prior and so can also compute analytically, we\n",
    "next proceed to situations in which our  prior  is not a beta distribution, so we don’t have a conjugate prior.\n",
    "\n",
    "So we will have non-conjugate priors and are cast into situations in which we can’t calculate posteriors analytically."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b732a35",
   "metadata": {},
   "source": [
    "### MCMC\n",
    "\n",
    "First, we implement and display  MCMC.\n",
    "\n",
    "We first initialize the `BayesianInference` classes and then can directly call `BayesianInferencePlot` to plot both MCMC and SVI approximating posteriors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5954c68e",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Initialize BayesianInference classes\n",
    "# try uniform\n",
    "STD_UNIFORM_pyro = BayesianInference(param=(0,1), name_dist='uniform', solver='pyro')\n",
    "UNIFORM_numpyro = BayesianInference(param=(0.2,0.7), name_dist='uniform', solver='numpyro')\n",
    "\n",
    "# try truncated lognormal\n",
    "LOGNORMAL_numpyro = BayesianInference(param=(0,2), name_dist='lognormal', solver='numpyro')\n",
    "LOGNORMAL_pyro = BayesianInference(param=(0,2), name_dist='lognormal', solver='pyro')\n",
    "\n",
    "# try von Mises\n",
    "# shifted von Mises\n",
    "VONMISES_numpyro = BayesianInference(param=10, name_dist='vonMises', solver='numpyro')\n",
    "# truncated von Mises\n",
    "VONMISES_pyro = BayesianInference(param=40, name_dist='vonMises', solver='pyro')\n",
    "\n",
    "# try laplace\n",
    "LAPLACE_numpyro = BayesianInference(param=(0.5, 0.07), name_dist='laplace', solver='numpyro')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35d7dad4",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Uniform\n",
    "example_CLASS = STD_UNIFORM_pyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(num_samples=MCMC_num_samples)\n",
    "\n",
    "example_CLASS = UNIFORM_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(num_samples=MCMC_num_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6b03c8a",
   "metadata": {},
   "source": [
    "In the situation depicted above, we have assumed a  $ Uniform(\\underline{\\theta}, \\overline{\\theta}) $ prior that puts zero probability   outside a bounded support that excludes the true value.\n",
    "\n",
    "Consequently,  the posterior cannot put positive probability above $ \\overline{\\theta} $ or below $ \\underline{\\theta} $.\n",
    "\n",
    "Note how when  the true data-generating $ \\theta $ is located at $ 0.8 $ as it is here,  when $ n $ gets large, the posterior  concentrate on the upper bound of the support of the prior,  $ 0.7 $ here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "067a9a7a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Log Normal\n",
    "example_CLASS = LOGNORMAL_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(num_samples=MCMC_num_samples)\n",
    "\n",
    "example_CLASS = LOGNORMAL_pyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(num_samples=MCMC_num_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8c9ae76",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Von Mises\n",
    "example_CLASS = VONMISES_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "print('\\nNOTE: Shifted von Mises')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(num_samples=MCMC_num_samples)\n",
    "\n",
    "example_CLASS = VONMISES_pyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "print('\\nNOTE: Truncated von Mises')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(num_samples=MCMC_num_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0bf7c3d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Laplace\n",
    "example_CLASS = LAPLACE_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).MCMC_plot(num_samples=MCMC_num_samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a956e4d5",
   "metadata": {},
   "source": [
    "To get more accuracy we will now increase the number of steps for Variational Inference (VI)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be3f8a3c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "SVI_num_steps = 50000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bbbcb24e",
   "metadata": {},
   "source": [
    "#### VI with a  Truncated Normal Guide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb821fba",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Uniform\n",
    "example_CLASS = BayesianInference(param=(0,1), name_dist='uniform', solver='numpyro')\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='normal', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecbdecfb",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Log Normal\n",
    "example_CLASS = LOGNORMAL_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='normal', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff134e6a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Von Mises\n",
    "example_CLASS = VONMISES_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "print('\\nNB: Shifted von Mises')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='normal', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5c10363",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Laplace\n",
    "example_CLASS = LAPLACE_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='normal', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f65c96f7",
   "metadata": {},
   "source": [
    "#### Variational Inference with a  Beta Guide Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d538e2d",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Uniform\n",
    "example_CLASS = STD_UNIFORM_pyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='beta', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0118046a",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Log Normal\n",
    "example_CLASS = LOGNORMAL_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='beta', n_steps=SVI_num_steps)\n",
    "\n",
    "example_CLASS = LOGNORMAL_pyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='beta', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09f10c3c",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Von Mises\n",
    "example_CLASS = VONMISES_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "print('\\nNB: Shifted von Mises')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='beta', n_steps=SVI_num_steps)\n",
    "\n",
    "example_CLASS = VONMISES_pyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "print('\\nNB: Truncated von Mises')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='beta', n_steps=SVI_num_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a6fd135",
   "metadata": {
    "hide-output": false
   },
   "outputs": [],
   "source": [
    "# Laplace\n",
    "example_CLASS = LAPLACE_numpyro\n",
    "print(f'=======INFO=======\\nParameters: {example_CLASS.param}\\nPrior Dist: {example_CLASS.name_dist}\\nSolver: {example_CLASS.solver}')\n",
    "BayesianInferencePlot(true_theta, num_list, example_CLASS).SVI_plot(guide_dist='beta', n_steps=SVI_num_steps)"
   ]
  }
 ],
 "metadata": {
  "date": 1750302292.0343344,
  "filename": "bayes_nonconj.md",
  "kernelspec": {
   "display_name": "Python",
   "language": "python3",
   "name": "python3"
  },
  "title": "Non-Conjugate Priors"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}