16. MCMC diagnostics

Data set download


[2]:
import os

import numpy as np
import pandas as pd

import cmdstanpy
import arviz as az

import bebi103
import iqplot

import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()

In previous lessons, you have seen that we can sample out of arbitrary probability distributions, most notably posterior probability distributions in the context of Bayesian inference, using Markov chain Monte Carlo. However, there are a few questions we need to answer to make sure our MCMC samplers are in fact sampling the target distribution.

  1. Have we achieved stationarity? That is, have the chains run long enough that they are sampling the target distribution, rather than still wandering toward it from their starting points?

  2. Can the chains access all areas of parameter space?

  3. Have we taken enough samples to get a good picture of the posterior?

There are diagnostic checks we can do to address these questions, and these checks are the topic of this lesson.

The data set

As we set out to learn about MCMC diagnostics, we will again use the data set from Singer, et al. consisting of mRNA transcript counts in cells from single-molecule FISH experiments. We'll start by loading in the data set. We will work with the Rest data, which I will go ahead and pull out as a NumPy array. I'll make a quick plot of the ECDF as a reminder of what the data look like.

[3]:
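# data_path is assumed to be defined earlier (e.g., in a setup cell) as the
# directory that contains singer_transcript_counts.csv
df = pd.read_csv(os.path.join(data_path, "singer_transcript_counts.csv"), comment="#")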
n = df["Rest"].values

bokeh.io.show(iqplot.ecdf(n, q="transcript count"))

The model

As we have previously discussed, the transcript counts are Negative Binomially distributed under a model for bursty gene expression. We built the following generative model.

\begin{align} &\log_{10} \alpha \sim \text{Norm}(0, 1), \\[1em] &\log_{10} b \sim \text{Norm}(2, 1), \\[1em] &\beta = 1/b,\\[1em] &n_i \sim \text{NegBinom}(\alpha, \beta) \;\forall i. \end{align}

Here, \(\alpha\) is the frequency of bursts in gene expression and \(b\) is the size of the bursts. We do a change of variables to convert \(b\) to \(\beta\), as required for parametrization with Stan. The Stan code for this model is

data {
  int<lower=0> N;
  array[N] int<lower=0> n;
}


parameters {
  real log10_alpha;
  real log10_b;
}


transformed parameters {
  real alpha = 10^log10_alpha;
  real b = 10^log10_b;
  real beta_ = 1.0 / b;
}


model {
  // Priors
  log10_alpha ~ normal(0, 1);
  log10_b ~ normal(2, 1);

  // Likelihood
  n ~ neg_binomial(alpha, beta_);
}
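To make the generative model concrete, here is a minimal NumPy sketch, not part of the lesson's workflow, that draws a synthetic data set from it. The function names sample_counts() and prior_predictive() are hypothetical. The sketch assumes that Stan's neg_binomial(alpha, beta) has mean alpha/beta and converts to NumPy's negative_binomial(n, p) parametrization, whose mean is n(1 - p)/p, by setting n = alpha and p = beta/(1 + beta).

```python
import numpy as np


def sample_counts(alpha, b, size, rng):
    """Draw transcript counts from NegBinom(alpha, beta) with beta = 1/b.

    Stan's neg_binomial(alpha, beta) has mean alpha / beta. NumPy's
    negative_binomial(n, p) has mean n (1 - p) / p, so we set n = alpha
    and p = beta / (1 + beta) to match.
    """
    beta = 1.0 / b
    p = beta / (1.0 + beta)
    return rng.negative_binomial(alpha, p, size=size)


def prior_predictive(n_cells, rng):
    """Draw one synthetic data set of n_cells counts from the full model."""
    log10_alpha = rng.normal(0.0, 1.0)  # prior on log10 burst frequency
    log10_b = rng.normal(2.0, 1.0)      # prior on log10 burst size
    alpha = 10.0 ** log10_alpha
    b = 10.0 ** log10_b
    return sample_counts(alpha, b, n_cells, rng)


rng = np.random.default_rng(3252)
synthetic = prior_predictive(100, rng)
```

Under this parametrization the mean count is αb and the variance is αb(1 + b), so large burst sizes give strongly overdispersed counts.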

With the Stan code above saved in the file smfish.stan, we compile the model so we have it ready for use.

[4]:
sm = cmdstanpy.CmdStanModel(stan_file='smfish.stan')
INFO:cmdstanpy:compiling stan file /Users/bois/Dropbox/git/bebi103_course/2022/b/content/lessons/16/smfish.stan to exe file /Users/bois/Dropbox/git/bebi103_course/2022/b/content/lessons/16/smfish
INFO:cmdstanpy:compiled model executable: /Users/bois/Dropbox/git/bebi103_course/2022/b/content/lessons/16/smfish

Let’s get some samples to work with. We will seed the random number generator for reproducibility purposes.

[5]:
with bebi103.stan.disable_logging():
    samples = sm.sample(data=dict(N=len(n), n=n), seed=3252)
samples = az.from_cmdstanpy(posterior=samples)

Diagnostics for any MCMC sampler

We will first investigate diagnostics that apply to any MCMC sampler, not just Hamiltonian Monte Carlo samplers like the one Stan uses.

The Gelman-Rubin R-hat statistic

The Gelman-Rubin R-hat statistic is a useful metric to determine whether we have achieved stationarity with our chains. The idea is that we run multiple chains in parallel (at least four). For a given parameter, we then compute the variance of the samples between the chains and the variance of the samples within the chains, and compare the two. The result is the Gelman-Rubin R-hat statistic, usually denoted \(\hat{R}\), and we compute \(\hat{R}\) for each parameter. In its basic form, for chains of \(N\) draws each,

\begin{align} \hat{R} = \sqrt{\frac{\hat{V}}{W}}, \qquad \hat{V} = \frac{N-1}{N}\,W + \frac{1}{N}\,B, \end{align}

where \(W\) is the average of the within-chain variances and \(B\) is \(N\) times the variance of the chain means.

The value of \(\hat{R}\) approaches unity if the chains are properly sampling the target distribution, because the chains should be identical in their sampling of the posterior if they have all reached the limiting distribution. As a rule of thumb, recommended by Vehtari, et al., the value of \(\hat{R}\) should be less than 1.01. There are more details involved in the calculation of \(\hat{R}\), and you may read about them in the Vehtari, et al. paper.
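The basic Gelman-Rubin calculation can be sketched in a few lines of NumPy, using the classic form \(\hat{R} = \sqrt{\hat{V}/W}\) with \(\hat{V} = \frac{N-1}{N}W + \frac{1}{N}B\), where \(W\) is the mean within-chain variance and \(B\) is \(N\) times the variance of the chain means. This is a minimal sketch under simplifying assumptions: gelman_rubin() is a hypothetical helper, it handles one scalar parameter stored as a (chains, draws) array, and it implements the split-chain variant (each chain cut in half, which also catches drift within a chain) but not the rank normalization that ArviZ applies, so its numbers can differ slightly from az.rhat().

```python
import numpy as np


def gelman_rubin(chains, split=True):
    """Basic Gelman-Rubin R-hat for one scalar parameter.

    chains: array of shape (n_chains, n_draws).
    If split is True, each chain is cut in half first (split R-hat).
    Rank normalization, as done by ArviZ, is omitted.
    """
    chains = np.asarray(chains, dtype=float)
    if split:
        half = chains.shape[1] // 2
        # Treat first and second halves as separate chains (drop middle draw if odd)
        chains = np.concatenate([chains[:, :half], chains[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    chain_means = chains.mean(axis=1)
    # W: average of the within-chain variances
    W = chains.var(axis=1, ddof=1).mean()
    # B: n times the variance of the chain means
    B = n * chain_means.var(ddof=1)
    # Pooled estimate of the posterior variance
    var_plus = (n - 1) / n * W + B / n
    return np.sqrt(var_plus / W)
```

For example, gelman_rubin(samples.posterior["alpha"].values) can be compared with the ArviZ value reported for alpha. Splitting matters: if every chain drifts in the same way, the chain means agree and the unsplit \(\hat{R}\) stays near 1, while the split version exposes the drift.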

ArviZ automatically computes \(\hat{R}\) using state-of-the-art rank normalization techniques (published in Vehtari, et al.).

[6]:
az.rhat(samples)
[6]:
<xarray.Dataset>
Dimensions:      ()
Data variables:
    log10_alpha  float64 1.005
    log10_b      float64 1.006
    alpha        float64 1.005
    b            float64 1.006
    beta_        float64 1.006

We see that \(\hat{R}\) is about 1.005 to 1.006 for every parameter, both the two sampled parameters and the three transformed ones, satisfying the rule of thumb of \(\hat{R} < 1.01\).

If we want to see a quick summary of the results of MCMC, including mean parameter values, we can use az.summary().

[7]:
az.summary(samples)
[7]:
                 mean       sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
log10_alpha     0.651    0.039    0.576    0.723      0.001    0.001     825.0     936.0   1.01
log10_b         1.224    0.041    1.150    1.302      0.001    0.001     851.0     903.0   1.01
alpha           4.490    0.400    3.746    5.251      0.014    0.010     825.0     936.0   1.01
b              16.838    1.603   14.046   19.953      0.055    0.039     851.0     903.0   1.01
beta_           0.060    0.006    0.050    0.071      0.000    0.000     851.0     903.0   1.01

We will discuss what some of the other statistics in the summary mean momentarily.

To see an example where the chains have not converged, we will sample again, but allow the chains only seven warm-up steps.

[8]:
with bebi103.stan.disable_logging():
    samples_limited_warmup = sm.sample(
        data=dict(N=len(n), n=n), iter_warmup=7, iter_sampling=1000, seed=3252
    )
samples_limited_warmup = az.from_cmdstanpy(samples_limited_warmup)

az.summary(samples_limited_warmup)

[8]:
                 mean       sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
log10_alpha     0.333    0.246   -0.029    0.685      0.119    0.091       4.0       4.0   3.35
log10_b         1.378    0.935    0.079    2.706      0.465    0.356       4.0       4.0   8.79
alpha           2.492    1.274    0.935    4.840      0.619    0.472       4.0       4.0   4.75
b             139.580  212.999    1.199  508.306    105.927   81.093       4.0       4.0   8.79
beta_           0.232    0.348    0.002    0.834      0.173    0.133       4.0       4.0   8.79

Now the \(\hat{R}\) values are large; the chains have not converged to the limiting distribution. Note also that the mean values of \(\alpha\) and \(b\) differ from those of the properly warmed-up sampler. This emphasizes that warm-up is crucial for the performance of the sampler. If you see \(\hat{R}\) values that are too large, you may be able to fix the problem by having the chains take more warm-up steps.

We can also see the poor mixing of the chains by looking at the trace plot.

[9]:
bokeh.io.show(
    bebi103.viz.trace(
        samples_limited_warmup,
        parameters=["alpha", "b"],
        line_kwargs=dict(line_width=2),
    )
)

This is pathological; three of the chains are essentially not moving, and the fourth is moving very poorly. This means that most proposed steps are being rejected.
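One way to put a number on what the trace plot shows is to count, for each chain, how often the sampler actually moved to a new value. Here is a minimal sketch with a hypothetical helper, fraction_moved(), assuming the draws for one parameter are in a (chains, draws) NumPy array. For a continuous parameter, a draw that exactly repeats the previous one means the chain stayed put on that iteration.

```python
import numpy as np


def fraction_moved(draws):
    """Fraction of transitions in each chain where the value changed.

    draws: array of shape (n_chains, n_draws) for one parameter.
    A zero difference between successive draws means the chain did not
    move on that iteration.
    """
    draws = np.asarray(draws)
    return np.mean(np.diff(draws, axis=1) != 0, axis=1)
```

Applied to the poorly warmed-up run, one could call fraction_moved(samples_limited_warmup.posterior["alpha"].values); chains that are stuck should give values near zero.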

As is the case with all diagnostic metrics, there are caveats. You can read about them for \(\hat{R}\) in the Vehtari, et al. paper and in the Stan manual.

Effective sample size

Recall that MCMC samplers do not draw independent samples from the target distribution. Rather, successive samples are correlated. Ideally, though, we would draw independent samples, so we would like an estimate of the number of effectively independent samples we have drawn. This is referred to either as the effective sample size (ESS) or the number of effective samples (\(n_\mathrm{eff}\)).
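To make the idea concrete, here is a minimal single-chain sketch. It estimates the integrated autocorrelation time \(\tau = 1 + 2\sum_{t\ge 1}\rho_t\), where \(\rho_t\) is the autocorrelation at lag \(t\), and reports \(N/\tau\) as the ESS for \(N\) draws. The helper names are hypothetical, and this is a simplification: the multi-chain, rank-normalized estimator used by ArviZ differs in its details. The infinite sum is truncated with Geyer's initial positive sequence, which sums autocorrelations in consecutive pairs until a pair sum turns negative.

```python
import numpy as np


def autocorrelation(x):
    """Sample autocorrelation function of a 1D series, computed with an FFT."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    xc = x - x.mean()
    # Zero-pad to length 2n so the circular FFT gives the linear autocovariance
    f = np.fft.rfft(xc, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f), n=2 * n)[:n] / n
    return acov / acov[0]


def ess_single_chain(x):
    """ESS of one chain, n / tau, with tau = 1 + 2 * sum_{t >= 1} rho_t.

    Autocorrelations are added in pairs (rho_{2k} + rho_{2k+1}) until a
    pair sum becomes negative (Geyer's initial positive sequence).
    """
    rho = autocorrelation(x)
    n = len(rho)
    tau = -1.0  # tau = -1 + 2 * sum of pair sums, since rho_0 = 1
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau
```

For a chain of independent draws, \(\tau \approx 1\) and the ESS is close to \(N\); for an AR(1) series with coefficient \(\phi\), the ESS is about \(N(1-\phi)/(1+\phi)\), so strong autocorrelation shrinks it dramatically.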

ArviZ computes ESS according to the prescription laid out in the Vehtari, et al. paper using az.ess(). In the summary, this is given in the ess_bulk column.

[10]:
az.summary(samples)
[10]:
                 mean       sd   hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
log10_alpha     0.651    0.039    0.576    0.723      0.001    0.001     825.0     936.0   1.01
log10_b         1.224    0.041    1.150    1.302      0.001    0.001     851.0     903.0   1.01
alpha           4.490    0.400    3.746    5.251      0.014    0.010     825.0     936.0   1.01
b              16.838    1.603   14.046   19.953      0.055    0.039     851.0     903.0   1.01
beta_           0.060    0.006    0.050    0.071      0.000    0.000     851.0     903.0   1.01

We took a total of 4000 draws (1000 on each of four chains) and got a bulk ESS of about 825 to 850. This is a reasonable number; as a rule of thumb, according to Vehtari, et al., you should have ESS > 400.

We will not consider ess_mean or ess_sd, which are ways of computing ESS used in the past, but we will consider ess_tail, referred to as tail-ESS. Again, I will not go into the details of how it is calculated, but it is the effective sample size when considering the more extreme values of the posterior (by default, the 5% and 95% quantiles). Note that this is not the number of samples that landed in the tails, but rather a measure of how many effectively independent samples we have for estimating those tail quantiles. Again, we want tail-ESS to be greater than 400 as a rule of thumb, and we have accomplished this here.
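Tail-ESS can be sketched in the same spirit. As described by Vehtari, et al., it is the smaller of the ESS values for the 5% and 95% quantiles, each computed as the ESS of the indicator series \(I(\theta \le q)\), which records whether each draw falls below the quantile \(q\). The sketch below uses hypothetical helper names, works on a single chain, and omits the multi-chain machinery ArviZ uses.

```python
import numpy as np


def _ess(x):
    """Single-chain ESS, n / tau, with tau truncated by Geyer's initial positive sequence."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    xc = x - x.mean()
    # FFT-based autocorrelation, zero-padded to avoid circular wrap-around
    f = np.fft.rfft(xc, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f), n=2 * n)[:n] / n
    rho = acov / acov[0]
    tau = -1.0
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau


def tail_ess(x, probs=(0.05, 0.95)):
    """Minimum ESS over the indicator series I(x <= q) for the given quantiles."""
    x = np.asarray(x, dtype=float)
    return min(_ess((x <= np.quantile(x, prob)).astype(float)) for prob in probs)
```

For independent draws the tail-ESS is close to the number of draws, while a strongly autocorrelated chain visits the tails in clumps and its tail-ESS drops accordingly.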

Bear in mind that the ESS calculation is approximate and subject to error. There are, as usual, other caveats, which are discussed in the Vehtari, et al. paper and the Stan manual.

Monte Carlo standard error

The Monte Carlo standard errors (MCSE) are reported as mcse_mean and mcse_sd. They estimate the standard error of the posterior mean and of the posterior standard deviation as computed from the MCMC samples; that is, they tell us how accurately our finite set of samples pins down those two summaries. In practice, if the MCSE of the mean is much less than the standard deviation of the samples themselves (that is, the mcse_mean column is much less than the sd column), we have taken plenty of samples. The main reason to use the MCSE is if we have a particularly strong interest in a very precise estimate of the posterior mean.
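A useful approximation connects the MCSE of the mean to quantities already in the summary: it is roughly the posterior standard deviation divided by the square root of the effective sample size. The sketch below checks this using the sd and ess_bulk values reported for alpha and b. This is an approximation under our assumption that bulk ESS is a good stand-in; ArviZ computes mcse_mean with its own ESS estimate for the mean, so the agreement need not be exact in general.

```python
import numpy as np

# Posterior standard deviations and bulk ESS copied from the az.summary() table
sd = {"alpha": 0.400, "b": 1.603}
ess_bulk = {"alpha": 825.0, "b": 851.0}

# Approximate MCSE of the posterior mean: sd / sqrt(ESS)
mcse_mean = {name: sd[name] / np.sqrt(ess_bulk[name]) for name in sd}

for name in sd:
    ratio = mcse_mean[name] / sd[name]
    print(f"{name}: mcse_mean ~ {mcse_mean[name]:.3f}, mcse_mean / sd ~ {ratio:.3f}")
```

These come out to about 0.014 and 0.055, in line with the mcse_mean column, and the MCSE is only a few percent of the sd in both cases.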

I was hesitant to even discuss this here, since I agree with Gelman, “For Bayesian inference, I don’t think it’s generally necessary or appropriate to report Monte Carlo standard errors of posterior means and quantiles…”

Diagnostics for HMC

Both \(\hat{R}\) and ESS are useful diagnostics for any MCMC sampler, but Hamiltonian Monte Carlo offers other diagnostics to help ensure that the sampling is going as it should. It is important to note that these diagnostics are a feature of HMC, not a bug. By that I mean that the absence of these diagnostics, particularly divergences, from other sampling methods means that it is harder to ensure that they are sampling properly. The ability to check that it is working properly makes HMC all the more powerful.

Divergences

Hamiltonian Monte Carlo enables large step sizes by taking into account the shape of the target distribution and tracing trajectories along it. (This is of course a very loose description. You should read Michael Betancourt's wonderful introduction to HMC to get a more complete picture.) When a trajectory encounters a region of parameter space where the posterior (target) distribution has high curvature, the numerical integration of the trajectory can become unstable, and the trajectory veers sharply away from where it should go. Stan detects these events because the total energy along the trajectory, which should be approximately conserved, changes by a large amount, and it registers them as divergences. A given Monte Carlo step ends in a divergence if this happens. A divergence does not necessarily mean that there is a problem with the samples, but there is a good chance that there is.
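The link between curvature and divergences shows up in a toy example. HMC traces trajectories with the leapfrog integrator, which nearly conserves the total energy \(H = U + K\) when the step size is small compared with the length scale of the target. The sketch below, with hypothetical helpers leapfrog() and energy_error(), uses a one-dimensional Gaussian target with standard deviation sigma, so high curvature corresponds to small sigma. With a fixed step size, the energy error stays tiny for a wide target but explodes for a narrow one. In Stan, an energy error that large would be flagged as a divergence; we assume here a threshold of 1000, which is our understanding of Stan's default.

```python
import numpy as np


def leapfrog(q, p, grad_U, eps, n_steps):
    """Leapfrog integration of Hamilton's equations with unit mass."""
    p = p - 0.5 * eps * grad_U(q)  # initial half step in momentum
    for i in range(n_steps):
        q = q + eps * p  # full step in position
        if i < n_steps - 1:
            p = p - eps * grad_U(q)  # full step in momentum
    p = p - 0.5 * eps * grad_U(q)  # final half step in momentum
    return q, p


def energy_error(sigma, eps, n_steps, p0=1.0):
    """Absolute change in H = U + K after one leapfrog trajectory.

    The target is a 1D Gaussian with standard deviation sigma, so
    U(q) = q**2 / (2 sigma**2) and K(p) = p**2 / 2. The trajectory
    starts one standard deviation from the mode.
    """
    def grad_U(q):
        return q / sigma**2

    def hamiltonian(q, p):
        return q**2 / (2.0 * sigma**2) + 0.5 * p**2

    q0 = sigma
    q, p = leapfrog(q0, p0, grad_U, eps, n_steps)
    return abs(hamiltonian(q, p) - hamiltonian(q0, p0))


for sigma in (1.0, 0.01):
    err = energy_error(sigma, eps=0.1, n_steps=10)
    print(f"sigma = {sigma}: energy error = {err:.3g}")
```

For this Gaussian target, the leapfrog integrator is stable only when the step size is below 2 sigma, which is why a step size that works well in a broad part of a posterior can fail in a narrow, highly curved region.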

Stan keeps track of divergences and reports them. In ArviZ InferenceData objects, they are stored in the sample_stats attribute. Let’s look first at our good samples where we properly warmed up the sampler.

[11]:
samples.sample_stats.diverging
[11]:
<xarray.DataArray 'diverging' (chain: 4, draw: 1000)>
array([[False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False],
       [False, False, False, ..., False, False, False]])
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999

We can check how many divergences we had by summing them.

[12]:
int(np.sum(samples.sample_stats.diverging))
[12]:
0

So, the properly warmed up sampler had no divergences. Let’s look at the improperly warmed-up sampler.

[13]:
int(np.sum(samples_limited_warmup.sample_stats.diverging))
[13]:
3507

Yikes! All kinds of divergences there. This is symptomatic of a sampler in trouble.

We will talk more about divergences in future lessons when we deal with distributions that are inherently difficult to sample, regardless of whether or not we warmed up properly.

Tree depth

The explanation of this diagnostic is a little computer-sciencey, so you can skip to the last sentence of this section if the CS terms are unfamiliar to you.

The HMC algorithm used by Stan uses [recursion](https://en.wikipedia.org/wiki/Recursion_(computer_science)). In practice, when doing recursive calculations, you need to put a bound on how deep the recursion can go, i.e., you need to cap the tree depth, lest you get stack overflow. Stan therefore has a limit on tree depth, the default of which is 10. If this tree depth is hit while trying to take a sample, the sampling is not wrong, but it is less efficient, because the trajectory is cut off before the sampler would otherwise have stopped it. For this reason, Stan reports the tree depth for each sample. These values are also included in the sample_stats.

[14]:
samples.sample_stats.treedepth
[14]:
<xarray.DataArray 'treedepth' (chain: 4, draw: 1000)>
array([[2, 2, 1, ..., 2, 2, 2],
       [3, 2, 2, ..., 2, 1, 2],
       [4, 1, 4, ..., 3, 2, 3],
       [2, 3, 2, ..., 3, 2, 1]])
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999

We can check how many draws hit a tree depth of 10.

[15]:
int(np.sum(samples.sample_stats.treedepth == 10))
[15]:
0

So, in this case, we never hit the maximum tree depth. When we do hit it often, the sampler typically becomes less efficient and the ESS decreases.
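To see where saturation happens, it can help to break the count down by chain. Here is a minimal sketch with a hypothetical helper, treedepth_saturation(), assuming the tree depths are in a (chains, draws) integer array such as samples.sample_stats.treedepth.values. It also reports the worst-case number of leapfrog steps per iteration, assuming (as we understand Stan's implementation) that a tree of depth d uses at most 2^d - 1 leapfrog steps.

```python
import numpy as np


def treedepth_saturation(treedepth, max_treedepth=10):
    """Per-chain fraction of iterations that reached max_treedepth.

    treedepth: integer array of shape (n_chains, n_draws).
    Also returns the assumed worst-case number of leapfrog steps per
    iteration, 2**max_treedepth - 1.
    """
    treedepth = np.asarray(treedepth)
    frac_saturated = np.mean(treedepth >= max_treedepth, axis=1)
    max_leapfrog_steps = 2**max_treedepth - 1
    return frac_saturated, max_leapfrog_steps
```

Calling treedepth_saturation(samples.sample_stats.treedepth.values) on our good samples would give zero for every chain, consistent with the count above.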

E-BFMI

The energy Bayesian fraction of missing information, or E-BFMI, is another metric that is specific to HMC samplers. Loosely speaking (again), it is a measure of how effective the sampler is at taking long steps. Some details are given in the Betancourt paper on HMC; we will not go into them here, but as a rule of thumb, values below 0.3 can be indicative of inefficient sampling.
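The E-BFMI is estimated separately for each chain from the sequence of Hamiltonian energies \(E_n\) recorded at each draw. Roughly, it compares how much the energy changes between successive draws with its overall spread,

\begin{align} \text{E-BFMI} \approx \frac{\sum_{n=1}^{N}(E_n - E_{n-1})^2}{\sum_{n=0}^{N}(E_n - \bar{E})^2}. \end{align}

Here is a minimal NumPy sketch with a hypothetical helper, e_bfmi(), assuming the energies are in a (chains, draws) array such as samples.sample_stats.energy.values. It uses means rather than sums in the numerator and denominator, which changes the result only by a factor close to one.

```python
import numpy as np


def e_bfmi(energy):
    """Estimated E-BFMI for each chain.

    energy: array of shape (n_chains, n_draws) of Hamiltonian energies.
    Numerator: mean squared change in energy between successive draws.
    Denominator: variance of the energies within the chain.
    """
    energy = np.asarray(energy, dtype=float)
    numer = np.mean(np.diff(energy, axis=1) ** 2, axis=1)
    denom = np.var(energy, axis=1)
    return numer / denom
```

When successive energies barely change relative to their overall spread, the ratio is small, signaling that the sampler struggles to explore the energy distribution; independent energies give a value near 2.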

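Stan records the Hamiltonian energy at each iteration, and the E-BFMI is computed from these energies. The energies are also stored in the sample_stats.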

[16]:
samples.sample_stats.energy
[16]:
<xarray.DataArray 'energy' (chain: 4, draw: 1000)>
array([[1379.2 , 1379.07, 1378.82, ..., 1380.36, 1379.6 , 1380.07],
       [1380.48, 1379.76, 1380.22, ..., 1378.31, 1378.01, 1378.68],
       [1377.95, 1377.72, 1377.83, ..., 1379.7 , 1377.75, 1377.8 ],
       [1379.9 , 1379.01, 1377.64, ..., 1380.55, 1379.25, 1377.59]])
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999

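The energy values themselves are not the E-BFMI; the E-BFMI is computed for each chain from the sequence of energies. ArviZ's az.bfmi() function does this calculation, returning one value per chain, so let's check whether any chain has an E-BFMI below 0.3.

[17]:
int(np.sum(az.bfmi(samples) < 0.3))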
[17]:
0

Nope! We’re in good shape.

Quickly checking the diagnostics

I wrote a function, based on work by Michael Betancourt, to quickly check these diagnostics for a set of samples. It is available in the bebi103.stan submodule.

[18]:
bebi103.stan.check_all_diagnostics(samples)
Effective sample size looks reasonable for all parameters.

Rhat looks reasonable for all parameters.

0 of 4000 (0.0%) iterations ended with a divergence.

0 of 4000 (0.0%) iterations saturated the maximum tree depth of 10.

E-BFMI indicated no pathological behavior.
[18]:
0

This is a quick check you can do to make sure everything is in order after obtaining samples. But it is very important to note that passing all of these diagnostic checks does not ensure that you achieved effective sampling. And perhaps even more importantly, getting effective sampling certainly does not guarantee that your model is a good one. Nonetheless, good, identifiable models tend to pass the diagnostic checks more often than poor ones.

[19]:
bebi103.stan.clean_cmdstan()

Computing environment

[20]:
%load_ext watermark
%watermark -v -p numpy,pandas,cmdstanpy,arviz,bokeh,bebi103,jupyterlab
print("cmdstan   :", bebi103.stan.cmdstan_version())
Python implementation: CPython
Python version       : 3.9.7
IPython version      : 7.29.0

numpy     : 1.20.3
pandas    : 1.3.5
cmdstanpy : 1.0.0
arviz     : 0.11.4
bokeh     : 2.3.3
bebi103   : 0.1.10
jupyterlab: 3.2.1

cmdstan   : 2.28.2