51. Posterior Distributions for AR(1) Parameters#

We’ll begin with some Python imports.

!pip install arviz pymc numpyro jax

import arviz as az
import pymc as pmc
import numpyro
from numpyro import distributions as dist

import numpy as np
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

%matplotlib inline

import logging
logging.basicConfig()
logger = logging.getLogger('pymc')
logger.setLevel(logging.CRITICAL)

This lecture uses Bayesian methods offered by pymc and numpyro to make statistical inferences about two parameters of a univariate first-order autoregression.

The model is a good laboratory for illustrating consequences of alternative ways of modeling the distribution of the initial \(y_0\):

  • As a fixed number

  • As a random variable drawn from the stationary distribution of the \(\{y_t\}\) stochastic process

The first component of the statistical model is

(51.1)#\[ y_{t+1} = \rho y_t + \sigma_x \epsilon_{t+1}, \quad t \geq 0 \]

where the scalars \(\rho\) and \(\sigma_x\) satisfy \(|\rho| < 1\) and \(\sigma_x > 0\); \(\{\epsilon_{t+1}\}\) is a sequence of i.i.d. normal random variables with mean \(0\) and variance \(1\).

The second component of the statistical model is

(51.2)#\[ y_0 \sim {\cal N}(\mu_0, \sigma_0^2) \]

Consider a sample \(\{y_t\}_{t=0}^T\) governed by this statistical model.

The model implies that the likelihood function of \(\{y_t\}_{t=0}^T\) can be factored:

\[ f(y_T, y_{T-1}, \ldots, y_0) = f(y_T| y_{T-1}) f(y_{T-1}| y_{T-2}) \cdots f(y_1 | y_0 ) f(y_0) \]

where we use \(f\) to denote a generic probability density.

The statistical model (51.1)-(51.2) implies

\[\begin{split} \begin{aligned} y_t | y_{t-1} & \sim {\mathcal N}(\rho y_{t-1}, \sigma_x^2) \\ y_0 & \sim {\mathcal N}(\mu_0, \sigma_0^2) \end{aligned} \end{split}\]
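
As a cross-check on this factorization, here is a minimal sketch of the implied log likelihood written with scipy.stats (the helper name ar1_loglike is ours, not part of the lecture's code):

import numpy as np
from scipy.stats import norm

def ar1_loglike(y, rho, sigma_x, mu_0, sigma_0):
    # log f(y_0)
    ll = norm.logpdf(y[0], loc=mu_0, scale=sigma_0)
    # sum over t of log f(y_t | y_{t-1})
    ll += norm.logpdf(y[1:], loc=rho * y[:-1], scale=sigma_x).sum()
    return ll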

We want to study how inferences about the unknown parameters \((\rho, \sigma_x)\) depend on what is assumed about the parameters \(\mu_0, \sigma_0\) of the distribution of \(y_0\).

Below, we study two widely used alternative assumptions:

  • \((\mu_0,\sigma_0) = (y_0, 0)\), which means that \(y_0\) is drawn from the degenerate distribution \({\mathcal N}(y_0, 0)\) that puts probability one on the observed value; in effect, we are conditioning on an observed initial value.

  • \(\mu_0,\sigma_0\) are functions of \(\rho, \sigma_x\) because \(y_0\) is drawn from the stationary distribution implied by \(\rho, \sigma_x\).

Note: We do not treat a third possible case in which \(\mu_0,\sigma_0\) are free parameters to be estimated.

Unknown parameters are \(\rho, \sigma_x\).

We have independent prior probability distributions for \(\rho, \sigma_x\) and want to compute a posterior probability distribution after observing a sample \(\{y_{t}\}_{t=0}^T\).

The notebook uses pymc and numpyro to compute a posterior distribution of \(\rho, \sigma_x\). Both libraries provide NUTS samplers, which we use to generate chains of draws from the posterior.

NUTS is a form of Markov chain Monte Carlo (MCMC) algorithm that bypasses random-walk behaviour and converges to a target distribution more quickly. Besides speed, this allows complex models to be fitted without specialised knowledge of the theory underlying the fitting methods.

Thus, we explore consequences of making these alternative assumptions about the distribution of \(y_0\):

  • A first procedure is to condition on whatever value of \(y_0\) is observed. This amounts to assuming that the probability distribution of the random variable \(y_0\) is a Dirac delta function that puts probability one on the observed value of \(y_0\).

  • A second procedure assumes that \(y_0\) is drawn from the stationary distribution of the process described by (51.1), so that \(y_0 \sim {\cal N} \left(0, {\sigma_x^2 \over 1-\rho^2} \right)\), as derived just below this list.
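
To see where this stationary variance comes from, suppose \(y_t\) already has the stationary distribution, so that \(Var(y_{t+1}) = Var(y_t) = \sigma_y^2\). Taking variances on both sides of (51.1) gives

\[ \sigma_y^2 = \rho^2 \sigma_y^2 + \sigma_x^2 \quad \Longrightarrow \quad \sigma_y^2 = \frac{\sigma_x^2}{1-\rho^2} \]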

When the initial value \(y_0\) is far out in a tail of the stationary distribution, conditioning on an initial value gives a posterior that is more accurate in a sense that we’ll explain.

Basically, when \(y_0\) happens to be in a tail of the stationary distribution and we don’t condition on \(y_0\), the likelihood function for \(\{y_t\}_{t=0}^T\) adjusts the posterior distribution of the parameter pair \(\rho, \sigma_x \) to make the observed value of \(y_0\) more likely than it really is under the stationary distribution, thereby adversely twisting the posterior in short samples.

An example below shows how not conditioning on \(y_0\) adversely shifts the posterior probability distribution of \(\rho\) toward larger values.

We begin by solving a direct problem that simulates an AR(1) process.

How we select the initial value \(y_0\) matters.

  • If we think \(y_0\) is drawn from the stationary distribution \({\mathcal N}(0, \frac{\sigma_x^{2}}{1-\rho^2})\), then it is a good idea to use this distribution as \(f(y_0)\). Why? Because \(y_0\) contains information about \(\rho, \sigma_x\).

  • If we suspect that \(y_0\) is far in the tails of the stationary distribution – so that variation in early observations in the sample has a significant transient component – it is better to condition on \(y_0\) by setting \(f(y_0) = 1\), as the factorization below makes precise.
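
Setting \(f(y_0) = 1\) amounts to using the conditional likelihood

\[ f(y_T, y_{T-1}, \ldots, y_1 | y_0) = f(y_T | y_{T-1}) f(y_{T-1} | y_{T-2}) \cdots f(y_1 | y_0) \]

which drops the \(f(y_0)\) factor from the factored likelihood displayed earlier.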

To illustrate the issue, we’ll begin by choosing an initial \(y_0\) that is far out in a tail of the stationary distribution.

def ar1_simulate(rho, sigma, y0, T):

    # Allocate space and draw epsilons
    y = np.empty(T)
    eps = np.random.normal(0.,sigma,T)

    # Initial condition and step forward
    y[0] = y0
    for t in range(1, T):
        y[t] = rho*y[t-1] + eps[t]

    return y

sigma =  1.
rho = 0.5
T = 50

np.random.seed(145353452)
y = ar1_simulate(rho, sigma, 10, T)
plt.plot(y)
plt.tight_layout()
(Figure: simulated AR(1) sample path of length \(T = 50\) starting from \(y_0 = 10\).)

Now we shall use Bayes’ law to construct a posterior distribution, conditioning on the initial value of \(y_0\).

(Later we’ll assume that \(y_0\) is drawn from the stationary distribution, but not now.)

First we’ll use pymc.

51.1. PyMC Implementation#

For a normal distribution in pymc, the variance can be specified either through the standard deviation \(\sigma\) or through the precision \(\tau\), where \(var = 1/\tau = \sigma^{2}\).
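
For example, the following two specifications describe the same distribution (a small sketch; the variable names x1 and x2 are ours):

with pmc.Model():
    x1 = pmc.Normal('x1', mu=0., sigma=2.)   # parameterized by standard deviation
    x2 = pmc.Normal('x2', mu=0., tau=0.25)   # same distribution, since tau = 1/sigma**2 = 1/4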

AR1_model = pmc.Model()

with AR1_model:

    # Start with priors
    rho = pmc.Uniform('rho', lower=-1., upper=1.) # Assume stable rho
    sigma = pmc.HalfNormal('sigma', sigma = np.sqrt(10))

    # Expected value of y at the next period (rho * y)
    yhat = rho * y[:-1]

    # Likelihood of the actual realization
    y_like = pmc.Normal('y_obs', mu=yhat, sigma=sigma, observed=y[1:])

pmc.sample by default uses the NUTS sampler to generate samples, as shown in the cell below:

with AR1_model:
    trace = pmc.sample(50000, tune=10000, return_inferencedata=True)
100.00% [240000/240000 01:24<00:00 Sampling 4 chains, 0 divergences]
with AR1_model:
    az.plot_trace(trace, figsize=(17,6))
(Figure: trace plots and posterior densities for \(\rho\) and \(\sigma\).)

Evidently, the posteriors aren’t centered on the true values of \(.5, 1\) that we used to generate the data.

This is a symptom of the classic Hurwicz bias for first order autoregressive processes (see Leonid Hurwicz [Hur50]).

The Hurwicz bias is worse the smaller is the sample (see [OW69]).

Be that as it may, here is more information about the posterior.

with AR1_model:
    summary = az.summary(trace, round_to=4)

summary
         mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd     ess_bulk     ess_tail  r_hat
rho    0.5364  0.0712  0.4035   0.6713     0.0002   0.0001  176745.5969  123630.9264    1.0
sigma  1.0105  0.1067  0.8214   1.2135     0.0003   0.0002  164066.3034  136011.9523    1.0
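
If you prefer to work with the raw posterior draws rather than the summary table, they live on the InferenceData object as xarray arrays; for instance (a usage sketch, not part of the original cell):

rho_draws = trace.posterior['rho'].values.flatten()
print(rho_draws.mean(), np.percentile(rho_draws, [3, 97]))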

Now we shall compute a posterior distribution after seeing the same data but instead assuming that \(y_0\) is drawn from the stationary distribution.

This means that

\[ y_0 \sim N \left(0, \frac{\sigma_x^{2}}{1 - \rho^{2}} \right) \]

We alter the code as follows:

AR1_model_y0 = pmc.Model()

with AR1_model_y0:

    # Start with priors
    rho = pmc.Uniform('rho', lower=-1., upper=1.) # Assume stable rho
    sigma = pmc.HalfNormal('sigma', sigma=np.sqrt(10))

    # Standard deviation of ergodic y
    y_sd = sigma / np.sqrt(1 - rho**2)

    # yhat
    yhat = rho * y[:-1]
    y_data = pmc.Normal('y_obs', mu=yhat, sigma=sigma, observed=y[1:])
    y0_data = pmc.Normal('y0_obs', mu=0., sigma=y_sd, observed=y[0])
with AR1_model_y0:
    trace_y0 = pmc.sample(50000, tune=10000, return_inferencedata=True)

# Grey vertical lines are the cases of divergence
100.00% [240000/240000 01:30<00:00 Sampling 4 chains, 63 divergences]
with AR1_model_y0:
    az.plot_trace(trace_y0, figsize=(17,6))
(Figure: trace plots and posterior densities for \(\rho\) and \(\sigma\); grey vertical lines mark divergences.)
with AR1_model_y0:
    summary_y0 = az.summary(trace_y0, round_to=4)

summary_y0
         mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd     ess_bulk    ess_tail  r_hat
rho    0.8757  0.0812  0.7315   0.9985     0.0002   0.0002  105591.9624  79498.6237    1.0
sigma  1.4045  0.1470  1.1401   1.6833     0.0004   0.0003  116488.9532  99802.8118    1.0

Please note how the posterior for \(\rho\) has shifted toward larger values relative to the posterior we computed when conditioning on \(y_0\).
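
One way to see the shift directly is to overlay the two posteriors for \(\rho\) (a sketch using arviz's plot_density; the labels are ours):

az.plot_density([trace, trace_y0],
                data_labels=['conditioned on y0', 'stationary y0'],
                var_names=['rho'],
                shade=0.1)
plt.show()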

Think about why this happens.

Hint

It is connected to how Bayes’ Law (conditional probability) solves an inverse problem by putting high probability on parameter values that make observations more likely.
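
To make the mechanism concrete, compare the log density that the stationary distribution assigns to the initial observation \(y_0 = 10\) under the true parameters and under a pair with larger \(\rho\) and \(\sigma\) (a hypothetical numerical check; the helper log_f_y0 is ours):

import numpy as np
from scipy.stats import norm

def log_f_y0(y0, rho, sigma):
    # log density of y0 under the stationary distribution N(0, sigma^2 / (1 - rho^2))
    sd = sigma / np.sqrt(1 - rho**2)
    return norm.logpdf(y0, loc=0., scale=sd)

print(log_f_y0(10., 0.5, 1.0))   # true parameters: y0 = 10 is about 8.7 stationary standard deviations out
print(log_f_y0(10., 0.9, 1.4))   # larger rho and sigma make y0 far less improbable

Pushing \(\rho\) toward \(1\) and raising \(\sigma\) inflates the stationary variance, so the first observation stops looking like an outlier.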

We’ll return to this issue after we use numpyro to compute posteriors under our two alternative assumptions about the distribution of \(y_0\).

We’ll now repeat the calculations using numpyro.

51.2. Numpyro Implementation#

def plot_posterior(sample):
    """
    Plot trace and histogram
    """
    # Convert to NumPy arrays
    rhos = np.array(sample['rho'])
    sigmas = np.array(sample['sigma'])

    fig, axs = plt.subplots(2, 2, figsize=(17, 6))
    # Plot trace
    axs[0, 0].plot(rhos)   # rho
    axs[1, 0].plot(sigmas) # sigma

    # Plot posterior
    axs[0, 1].hist(rhos, bins=50, density=True, alpha=0.7)
    axs[0, 1].set_xlim([0, 1])
    axs[1, 1].hist(sigmas, bins=50, density=True, alpha=0.7)

    axs[0, 0].set_title("rho")
    axs[0, 1].set_title("rho")
    axs[1, 0].set_title("sigma")
    axs[1, 1].set_title("sigma")
    plt.show()
def AR1_model(data):
    # set prior
    rho = numpyro.sample('rho', dist.Uniform(low=-1., high=1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(scale=np.sqrt(10)))

    # Expected value of y at the next period (rho * y)
    yhat = rho * data[:-1]

    # Likelihood of the actual realization.
    y_data = numpyro.sample('y_obs', dist.Normal(loc=yhat, scale=sigma), obs=data[1:])
# Make jnp array
y = jnp.array(y)

# Set NUTS kernel
NUTS_kernel = numpyro.infer.NUTS(AR1_model)

# Run MCMC
mcmc = numpyro.infer.MCMC(NUTS_kernel, num_samples=50000, num_warmup=10000, progress_bar=False)
mcmc.run(rng_key=random.PRNGKey(1), data=y)
plot_posterior(mcmc.get_samples())
(Figure: numpyro trace plots and posterior histograms for \(\rho\) and \(\sigma\).)
mcmc.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       rho      0.54      0.07      0.54      0.42      0.65  31893.01      1.00
     sigma      1.01      0.11      1.00      0.84      1.18  32715.75      1.00

Number of divergences: 0

Next, we again compute the posterior under the assumption that \(y_0\) is drawn from the stationary distribution, so that

\[ y_0 \sim N \left(0, \frac{\sigma_x^{2}}{1 - \rho^{2}} \right) \]

Here’s the new code to achieve this.

def AR1_model_y0(data):
    # Set prior
    rho = numpyro.sample('rho', dist.Uniform(low=-1., high=1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(scale=np.sqrt(10)))

    # Standard deviation of ergodic y
    y_sd = sigma / jnp.sqrt(1 - rho**2)

    # Expected value of y at the next period (rho * y)
    yhat = rho * data[:-1]

    # Likelihood of the actual realization.
    y_data = numpyro.sample('y_obs', dist.Normal(loc=yhat, scale=sigma), obs=data[1:])
    y0_data = numpyro.sample('y0_obs', dist.Normal(loc=0., scale=y_sd), obs=data[0])
# y is already a jnp array from the run above

# Set NUTS kernel
NUTS_kernel = numpyro.infer.NUTS(AR1_model_y0)

# Run MCMC
mcmc2 = numpyro.infer.MCMC(NUTS_kernel, num_samples=50000, num_warmup=10000, progress_bar=False)
mcmc2.run(rng_key=random.PRNGKey(1), data=y)
plot_posterior(mcmc2.get_samples())
(Figure: numpyro trace plots and posterior histograms for \(\rho\) and \(\sigma\) under the stationary-\(y_0\) assumption.)
mcmc2.print_summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
       rho      0.88      0.08      0.89      0.76      1.00  27184.82      1.00
     sigma      1.40      0.15      1.39      1.16      1.64  25417.89      1.00

Number of divergences: 0

Look what happened to the posterior!

It has moved far from the true values of the parameters used to generate the data because of how Bayes’ Law (i.e., conditional probability) leads numpyro to explain what it interprets as “explosive” observations early in the sample.

Bayes’ Law is able to generate a plausible likelihood for the first observation by driving \(\rho \rightarrow 1\) and \(\sigma \uparrow\) in order to raise the variance of the stationary distribution.
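
In symbols, the log density of \(y_0\) under the stationary distribution is

\[ \log f(y_0 | \rho, \sigma_x) = -\frac{1}{2} \log \left( 2\pi \frac{\sigma_x^2}{1-\rho^2} \right) - \frac{(1-\rho^2) y_0^2}{2 \sigma_x^2} \]

For a fixed large \(|y_0|\), the second term dominates and becomes less negative as \(\rho \rightarrow 1\) or as \(\sigma_x\) rises, which is exactly the direction in which the posterior moved.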

Our example illustrates the importance of what you assume about the distribution of initial conditions.