10. Mixture models and label switching with MCMC



[1]:
# Colab setup ------------------
import os, sys, subprocess
if "google.colab" in sys.modules:
    cmd = "pip install --upgrade iqplot colorcet bebi103 arviz cmdstanpy watermark"
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    import cmdstanpy; cmdstanpy.install_cmdstan()
    data_path = "https://s3.amazonaws.com/bebi103.caltech.edu/data/"
else:
    data_path = "../data/"
# ------------------------------

import numpy as np
import scipy.stats as st
import pandas as pd

import cmdstanpy
import arviz as az

import iqplot
import bebi103

import holoviews as hv
hv.extension('bokeh')

import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()

We continue with our analysis of the smFISH data, but this time for the Rex1 gene. Here, we saw clear bimodality in the data.

[2]:
# Load DataFrame and get counts
df = pd.read_csv(os.path.join(data_path, "singer_transcript_counts.csv"), comment="#")
n = df["Rex1"].values

bokeh.io.show(
    iqplot.ecdf(n, x_axis_label="mRNA count", frame_height=150, frame_width=200)
)

Mixture models

Since the Negative Binomial distribution is unimodal, what could be the story here? It is quite possible we are seeing two different states of cells, one with one level of bursty expression of the gene of interest, and another with a different level of bursty expression. This could mean the cells are differentiating. So, we would expect the number of mRNA transcripts to be distributed according to a linear combination of two negative binomial distributions. We can write out the PMF as

\begin{align} f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) &= w\,\frac{\Gamma(n + \alpha_1)}{n!\,\Gamma(\alpha_1)}\,\left(\frac{\beta_1}{1+\beta_1}\right)^{\alpha_1}\left(\frac{1}{1+\beta_1}\right)^{n} \\[1em] &\;\;\;\;+ (1-w) \,\frac{\Gamma(n + \alpha_2)}{n!\,\Gamma(\alpha_2)}\,\left(\frac{\beta_2}{1+\beta_2}\right)^{\alpha_2}\left(\frac{1}{1+\beta_2}\right)^{n} , \end{align}

where \(w\) is the probability that the burst size and frequency are determined by \(1/\beta_1\) and \(\alpha_1\). Such a model, in which a variable is distributed as a linear combination of distributions, is called a mixture model.
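As a sanity check on the parametrization, the mixture PMF above can be evaluated term by term and compared with SciPy. This sketch assumes, as the plotting code later in this lesson does, that \(\text{NegBinom}(\alpha, \beta)\) corresponds to scipy.stats.nbinom with n = \(\alpha\) and p = \(\beta/(1+\beta)\). The function name mixture_pmf and the parameter values are illustrative choices, not fitted values.

```python
import numpy as np
import scipy.special
import scipy.stats as st


def mixture_pmf(n_vals, alpha1, alpha2, beta1, beta2, w):
    """Two-component Negative Binomial mixture PMF, written out term by term."""
    n_vals = np.asarray(n_vals)

    def nbinom_pmf(alpha, beta):
        # Gamma(n + alpha) / (n! Gamma(alpha)) (beta/(1+beta))^alpha (1/(1+beta))^n,
        # evaluated on the log scale so the Gamma functions do not overflow
        log_coeff = (
            scipy.special.gammaln(n_vals + alpha)
            - scipy.special.gammaln(n_vals + 1)
            - scipy.special.gammaln(alpha)
        )
        log_pmf = log_coeff + alpha * np.log(beta / (1 + beta)) - n_vals * np.log1p(beta)
        return np.exp(log_pmf)

    return w * nbinom_pmf(alpha1, beta1) + (1 - w) * nbinom_pmf(alpha2, beta2)


# Arbitrary illustrative parameter values (not fitted to any data)
theta_demo = dict(alpha1=5.0, alpha2=2.5, beta1=0.03, beta2=0.2, w=0.8)
n_grid = np.arange(300)

manual = mixture_pmf(n_grid, **theta_demo)

# Same mixture via scipy, using p = beta / (1 + beta) for each component
p1 = theta_demo["beta1"] / (1 + theta_demo["beta1"])
p2 = theta_demo["beta2"] / (1 + theta_demo["beta2"])
via_scipy = theta_demo["w"] * st.nbinom.pmf(n_grid, theta_demo["alpha1"], p1) + (
    1 - theta_demo["w"]
) * st.nbinom.pmf(n_grid, theta_demo["alpha2"], p2)

print(np.allclose(manual, via_scipy))
```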

For this case, we can write this likelihood more concisely as

\begin{align} n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2)\;\;\forall i. \end{align}

We have to specify priors on \(\alpha_1\), \(\beta_1\), \(\alpha_2\), \(\beta_2\), and \(w\). We can retain the same priors from the previous lesson for the \(\alpha\)’s and \(\beta\)’s, and we will assume a Uniform prior for \(w\). We then have the following model.

\begin{align} &\log_{10} \alpha_i \sim \text{Norm}(0,1) \text{ for } i \in \{1, 2\}, \\[1em] &\log_{10} b_i \sim \text{Norm}(2, 1) \text{ for } i \in \{1, 2\}, \\[1em] &\beta_i = 1/b_i,\\[1em] &w \sim \text{Beta}(1, 1), \\[1em] &n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2)\;\;\forall i. \end{align}

Note that since the prior for \(w\) is Uniform (Beta(1, 1)), we do not need to explicitly specify its prior; it suffices to enforce the constraint \(0 \le w \le 1\).
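One way to make this generative model concrete is to simulate from it. The sketch below draws parameters from the priors and then one synthetic data set of counts from the mixture. The function name draw_prior_predictive, the seed, and the choice of 100 cells are illustrative assumptions. It relies on NumPy's negative_binomial(n, p) counting failures before n successes, so that n = \(\alpha\) and p = \(\beta/(1+\beta)\), the same mapping used with scipy.stats.nbinom later in this lesson.

```python
import numpy as np


def draw_prior_predictive(N, rng):
    """Draw parameters from the priors, then N counts from the two-component mixture."""
    # log10 alpha_i ~ Norm(0, 1) and log10 b_i ~ Norm(2, 1), for i = 1, 2
    alpha = 10.0 ** rng.normal(0.0, 1.0, size=2)
    b = 10.0 ** rng.normal(2.0, 1.0, size=2)
    beta_ = 1.0 / b

    # w ~ Beta(1, 1), i.e., Uniform on [0, 1]
    w = rng.beta(1.0, 1.0)

    # Each cell is in component 0 with probability w, otherwise in component 1
    component = np.where(rng.uniform(size=N) < w, 0, 1)

    # NegBinom(alpha, beta) <-> NumPy's negative_binomial(alpha, beta / (1 + beta))
    p = beta_ / (1.0 + beta_)
    counts = rng.negative_binomial(alpha[component], p[component])
    return counts, alpha, b, w


rng_prior = np.random.default_rng(3252)
counts_draw, alpha_draw, b_draw, w_draw = draw_prior_predictive(100, rng_prior)
print(counts_draw[:10], w_draw)
```

Repeating this many times and plotting the ECDFs of the simulated counts is one way to check whether the priors produce plausible data.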

Coding up a mixture model

Coding up a mixture model raises a few considerations, and it also introduces some new Stan syntax. Importantly, under the hood, Stan, like almost all samplers, works with the log posterior when sampling out of the posterior. A mixture model presents a particular challenge for computing the log likelihood (which is one of the summands of the log posterior). Consider the log likelihood in the present example.

\begin{align} \ln f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) = \ln(w\,a_1 + (1-w)a_2), \end{align}

where

\begin{align} a_i = \frac{\Gamma(n + \alpha_i)}{n!\,\Gamma(\alpha_i)}\,\left(\frac{\beta_i}{1+\beta_i}\right)^{\alpha_i}\left(\frac{1}{1+\beta_i}\right)^{n}. \end{align}

While the logarithm of a product conveniently splits into a sum of logarithms, we cannot split the logarithm of a sum. If we compute the sum directly, \(a_1\) and \(a_2\) can underflow to zero for some parameter values, and the logarithm then evaluates to \(-\infty\). To compute this in a more numerically stable way, we need to use the log-sum-exp trick. Fortunately, Stan has a built-in function, log_mix, that computes the contribution of a mixture to the log posterior. To update the posterior with log_mix, we add its result to target. In Stan, the keyword target is a special variable that holds the running sum of the contributions to the log posterior. When you make statements like theta ~ normal(0.0, 1.0), Stan adds the appropriate terms to target under the hood. For mixture models, we need to add to target explicitly. More generally, you can add any terms to target, and Stan considers these terms as part of the log posterior. The Stan code below implements this for the mixture model.

data {
  int<lower=0> N;
  int<lower=0> n[N];
}


parameters {
  vector<lower=0>[2] alpha;
  vector<lower=0>[2] b;
  real<lower=0, upper=1> w;
}


transformed parameters {
  vector[2] beta_ = 1.0 ./ b;
}


model {
  // Priors
  alpha ~ lognormal(0.0, 2.3);
  b ~ lognormal(4.6, 2.3);
  w ~ beta(1.0, 1.0);

  // Likelihood
  for (n_val in n) {
    target += log_mix(
      w,
      neg_binomial_lpmf(n_val | alpha[1], beta_[1]),
      neg_binomial_lpmf(n_val | alpha[2], beta_[2])
    );
  }
}
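To see what log_mix does, here is a Python sketch of the two-component case. Writing \(\ell_1 = \ln w + \ln a_1\) and \(\ell_2 = \ln(1-w) + \ln a_2\), the log-sum-exp trick evaluates \(\ln(e^{\ell_1} + e^{\ell_2})\) as \(m + \ln(e^{\ell_1 - m} + e^{\ell_2 - m})\) with \(m = \max(\ell_1, \ell_2)\), so the largest exponentiated term is exactly 1 and nothing underflows to zero. The sketch delegates that step to scipy.special.logsumexp. The Python functions log_mix and mixture_log_likelihood are illustrative stand-ins that mirror the Stan model block; they are not Stan's own code.

```python
import numpy as np
import scipy.special
import scipy.stats as st


def log_mix(w, lp1, lp2):
    """Compute log(w * exp(lp1) + (1 - w) * exp(lp2)) without underflow."""
    # Stack the two weighted log terms and combine them with log-sum-exp
    terms = np.stack([np.log(w) + lp1, np.log1p(-w) + lp2])
    return scipy.special.logsumexp(terms, axis=0)


def mixture_log_likelihood(counts, alpha, beta_, w):
    """Sum of log_mix contributions over all counts, as in the Stan model block."""
    # NegBinom(alpha, beta) corresponds to scipy's nbinom(alpha, beta / (1 + beta))
    lp1 = st.nbinom.logpmf(counts, alpha[0], beta_[0] / (1 + beta_[0]))
    lp2 = st.nbinom.logpmf(counts, alpha[1], beta_[1] / (1 + beta_[1]))
    return np.sum(log_mix(w, lp1, lp2))


# When both log terms are very negative, the naive sum underflows to 0 ...
lp1_demo, lp2_demo = -800.0, -805.0
with np.errstate(divide="ignore"):
    naive = np.log(0.5 * np.exp(lp1_demo) + 0.5 * np.exp(lp2_demo))

# ... but the log-sum-exp version stays finite
stable = log_mix(0.5, lp1_demo, lp2_demo)
print(naive, stable)
```

The naive expression gives \(-\infty\), while log_mix returns a finite value just below \(-800\).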

In addition to the log_mix function, there is some more new syntax.

- To add the contribution of the Negative Binomial PMF to the log posterior, we use neg_binomial_lpmf. Every discrete Stan distribution has a function <distribution>_lpmf that gives the value of the log PMF, and every continuous distribution has a similar function <distribution>_lpdf. The arguments of the function are written as you would write them on paper, with a bar (|) signifying conditioning on the parameters.
- In the above, I have specified alpha, b, and beta_ as vectors. A vector is an array that behaves like a column vector, which means you can do matrix operations with it. It also allows you to do elementwise operations. Notice that I have written

vector[2] beta_ = 1.0 ./ b;

This means that beta_ is a 2-vector and that each element is the reciprocal of the corresponding element in b. The ./ operator accomplishes this (note the dot in front of the slash). In general, preceding an operator with a . indicates that the operation should be done elementwise.
- Note also that Stan is smart enough to know that if I give alpha a prior, it should apply that prior independently to each element of alpha. The same is of course true for b.
- To specify a lower and upper bound, we use the <lower=0, upper=1> syntax.
- Note the for loop construction. for (n_val in n) means iterate over each entry in n, yielding its value as n_val within the loop. The contents of the loop are enclosed in braces. We could equivalently have written the for loop as

// Likelihood
for (i in 1:N) {
  target += log_mix(
    w,
    neg_binomial_lpmf(n[i] | alpha[1], beta_[1]),
    neg_binomial_lpmf(n[i] | alpha[2], beta_[2])
  );
}

Here, we are looping over integers starting at 1 and ending at \(N\), inclusive. You can think of the 1:N syntax as being similar to Python's range() function. However, unlike Python's range(), Stan's range is inclusive of the end value.
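For readers coming from Python, here is a short (non-Stan) illustration of the two points above: an inclusive Stan range 1:N corresponds to range(1, N + 1), and because Stan arrays are indexed from 1, Stan's n[i] corresponds to n[i - 1] in Python. The toy counts are made up for illustration.

```python
# Stan's 1:N includes N; Python's range excludes its stop value
N_demo = 5
stan_style_indices = list(range(1, N_demo + 1))  # 1, 2, ..., N, like Stan's 1:N
print(stan_style_indices)

# Stan arrays are 1-indexed, so Stan's n[i] is Python's n[i - 1]
toy_list = [3, 0, 12, 7, 41]
by_index = [toy_list[i - 1] for i in stan_style_indices]  # like for (i in 1:N)
by_value = [n_val for n_val in toy_list]  # like for (n_val in n)
print(by_index == by_value)
```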

Now that we have our model set up, let’s compile and sample from it! I will use the seed kwarg to set the seed for the random number generator to ensure that I always get the same result for illustrative purposes.

[3]:
sm = cmdstanpy.CmdStanModel(stan_file='mixture.stan')

data = {'n': n, 'N': len(n)}

samples = sm.sample(
    data=data,
    seed=3252,
    chains=4,
    iter_sampling=1000,
)

samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:compiling stan program, exe file: /Users/bois/Dropbox/git/bebi103_course/2021/b/content/lessons/10/mixture
INFO:cmdstanpy:compiler options: stanc_options=None, cpp_options=None
INFO:cmdstanpy:compiled model file: /Users/bois/Dropbox/git/bebi103_course/2021/b/content/lessons/10/mixture
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 1
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:finish chain 3
INFO:cmdstanpy:finish chain 4

Parsing the output

Let’s look at the results.

[4]:
samples
[4]:
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:      (alpha_dim_0: 2, b_dim_0: 2, beta__dim_0: 2, chain: 4, draw: 1000)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * alpha_dim_0  (alpha_dim_0) int64 0 1
        * b_dim_0      (b_dim_0) int64 0 1
        * beta__dim_0  (beta__dim_0) int64 0 1
      Data variables:
          alpha        (chain, draw, alpha_dim_0) float64 5.07 1.545 ... 5.413 2.415
          b            (chain, draw, b_dim_0) float64 32.26 10.05 32.27 ... 28.9 5.952
          w            (chain, draw) float64 0.8312 0.8276 0.8249 ... 0.8299 0.7957
          beta_        (chain, draw, beta__dim_0) float64 0.031 0.09949 ... 0.168
      Attributes:
          created_at:                 2021-01-20T23:49:21.358416
          arviz_version:              0.11.0
          inference_library:          cmdstanpy
          inference_library_version:  0.9.67

    • <xarray.Dataset>
      Dimensions:      (chain: 4, draw: 1000)
      Coordinates:
        * chain        (chain) int64 0 1 2 3
        * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
      Data variables:
          lp           (chain, draw) float64 -1.6e+03 -1.598e+03 ... -1.6e+03
          accept_stat  (chain, draw) float64 0.9941 0.9491 0.6852 ... 0.6901 0.9339
          stepsize     (chain, draw) float64 0.1438 0.1438 0.1438 ... 0.1303 0.1303
          treedepth    (chain, draw) int64 5 4 2 4 4 5 4 3 4 4 ... 4 5 3 4 4 4 3 4 2 5
          n_leapfrog   (chain, draw) int64 31 15 3 19 31 39 31 ... 31 27 31 7 23 3 31
          diverging    (chain, draw) bool False False False ... False False False
          energy       (chain, draw) float64 1.602e+03 1.601e+03 ... 1.601e+03
      Attributes:
          created_at:                 2021-01-20T23:49:21.362240
          arviz_version:              0.11.0
          inference_library:          cmdstanpy
          inference_library_version:  0.9.67

The arviz.InferenceData object has two xarray Datasets, posterior and sample_stats. We will work extensively with sample_stats in future lessons (and find that they are crucial for checking your sampling!), but for now we will focus on the posterior object that has the samples, which is the first Dataset in the output above.

Note now that the posterior output is a bit more complex. This is because the parameters alpha, b, and beta_ are vector-valued. The entries in the vectors are respectively indexed by indexes alpha_dim_0, b_dim_0, and beta__dim_0. The samples of these vector valued parameters are then three dimensional arrays. The first two dimensions are the chain and draw, like we have already seen, but the third dimension specifies which element in the vector.

Samples are selected using xarray data selection, which you can read about in the xarray docs. As an example, to access sample number 478 from chain 2, do the following.

[5]:
samples.posterior.loc[dict(chain=2, draw=478)]
[5]:
<xarray.Dataset>
Dimensions:      (alpha_dim_0: 2, b_dim_0: 2, beta__dim_0: 2)
Coordinates:
    chain        int64 2
    draw         int64 478
  * alpha_dim_0  (alpha_dim_0) int64 0 1
  * b_dim_0      (b_dim_0) int64 0 1
  * beta__dim_0  (beta__dim_0) int64 0 1
Data variables:
    alpha        (alpha_dim_0) float64 4.498 8.293
    b            (b_dim_0) float64 35.77 1.554
    w            float64 0.8408
    beta_        (beta__dim_0) float64 0.02795 0.6433
Attributes:
    created_at:                 2021-01-20T23:49:21.358416
    arviz_version:              0.11.0
    inference_library:          cmdstanpy
    inference_library_version:  0.9.67

If you wanted only the value of \(\alpha_2\) from that sample, you would do (noting the zero-based indexing of xarrays):

[6]:
samples.posterior["alpha"].loc[dict(chain=2, draw=478, alpha_dim_0=1)]
[6]:
<xarray.DataArray 'alpha' ()>
array(8.29259)
Coordinates:
    chain        int64 2
    draw         int64 478
    alpha_dim_0  int64 1

If you want it as a scalar, you can use the values attribute and convert the resulting zero-dimensional NumPy array to a float.

[7]:
float(samples.posterior["alpha"].loc[dict(chain=2, draw=478, alpha_dim_0=1)].values)
[7]:
8.29259
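xarray also offers a keyword form of label-based selection, .sel(), and a .item() method that returns a Python scalar directly. The sketch below builds a small toy Dataset with the same dimension names as samples.posterior, filled with random numbers (an assumption purely for illustration, not real samples), and checks that the two selection styles agree.

```python
import numpy as np
import xarray as xr

# Toy stand-in for samples.posterior: random numbers, not MCMC samples
rng_toy = np.random.default_rng(0)
toy_posterior = xr.Dataset(
    {
        "alpha": (("chain", "draw", "alpha_dim_0"), rng_toy.random((4, 1000, 2))),
        "w": (("chain", "draw"), rng_toy.random((4, 1000))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(1000), "alpha_dim_0": np.arange(2)},
)

# Label-based selection with .loc and a dict ...
via_loc = toy_posterior["alpha"].loc[dict(chain=2, draw=478, alpha_dim_0=1)]

# ... is equivalent to keyword selection with .sel()
via_sel = toy_posterior["alpha"].sel(chain=2, draw=478, alpha_dim_0=1)

# .item() converts a zero-dimensional DataArray to a Python float
value = via_sel.item()
print(type(value), value == float(via_loc.values))
```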

While xarray data types are quite powerful, it is often more convenient to work with the more familiar Pandas DataFrames. The .to_dataframe() method of xarrays is quite useful for this purpose.

[8]:
samples.posterior.to_dataframe().head()
[8]:
                                                alpha        b         w     beta_
alpha_dim_0 b_dim_0 beta__dim_0 chain draw
0           0       0           0     0      5.07007  32.2602  0.831249  0.030998
                                      1      5.04235  32.2721  0.827551  0.030987
                                      2      5.07235  32.0265  0.824924  0.031224
                                      3      5.04161  32.2483  0.835578  0.031009
                                      4      4.27491  38.1043  0.868500  0.026244

As we might expect, the indexes of the xarray become indexes of the data frame. This can be cumbersome to work with when plotting because we typically wish to plot marginal distributions, e.g., of element 1 of alpha. A more convenient form might be a data frame where each element of a vector- (or matrix-) valued parameter is a column. This is accomplished using the bebi103.stan.arviz_to_dataframe() function.

[9]:
df_mcmc = bebi103.stan.arviz_to_dataframe(samples)

df_mcmc.head()
[9]:
   alpha[0]  alpha[1]     b[0]      b[1]         w  beta_[0]  beta_[1]  chain__  draw__  diverging__
0   5.07007   1.54548  32.2602  10.05120  0.831249  0.030998  0.099491        0       0        False
1   5.04235   2.95352  32.2721   4.74096  0.827551  0.030987  0.210928        0       1        False
2   5.07235   2.75516  32.0265   4.16415  0.824924  0.031224  0.240145        0       2        False
3   5.04161   2.89813  32.2483   6.03983  0.835578  0.031009  0.165568        0       3        False
4   4.27491   4.73762  38.1043   2.62758  0.868500  0.026244  0.380578        0       4        False

This data frame is easier to work with for plotting. Note that there are added columns for the chain number ('chain__') and draw ('draw__'). These have double underscores in their names for easier reference. This function also includes some diagnostics from the sample_stats group of the ArviZ InferenceData in the data frame, in this case the diverging__ column, and we will discuss the meaning of those in future lessons.
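To make the reshaping concrete, here is a minimal sketch of the idea behind this conversion using plain NumPy and pandas. It is not bebi103's implementation: the function name flatten_samples is hypothetical, it takes a dict of arrays shaped (chain, draw, ...) rather than an InferenceData object, it handles only scalar and vector-valued parameters, and it omits the diagnostic columns such as diverging__.

```python
import numpy as np
import pandas as pd


def flatten_samples(arrays):
    """Turn {name: array of shape (chain, draw, ...)} into one column per scalar.

    Vector-valued parameters get one column per element, named like "alpha[0]".
    """
    n_chains, n_draws = next(iter(arrays.values())).shape[:2]
    columns = {}
    for name, arr in arrays.items():
        if arr.ndim == 2:
            # Scalar parameter: a single column
            columns[name] = arr.ravel()
        elif arr.ndim == 3:
            # Vector parameter: one column per element
            for j in range(arr.shape[2]):
                columns[f"{name}[{j}]"] = arr[:, :, j].ravel()
        else:
            raise NotImplementedError("Only scalar and vector parameters are handled.")

    df = pd.DataFrame(columns)
    # ravel() is chain-major: all draws of chain 0, then all draws of chain 1, ...
    df["chain__"] = np.repeat(np.arange(n_chains), n_draws)
    df["draw__"] = np.tile(np.arange(n_draws), n_chains)
    return df


# Toy arrays with made-up values, shaped (chain=2, draw=3, ...)
toy = {
    "alpha": np.arange(12, dtype=float).reshape(2, 3, 2),
    "w": np.linspace(0.1, 0.6, 6).reshape(2, 3),
}
df_toy = flatten_samples(toy)
print(df_toy)
```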

Plotting the samples

To plot our samples, it is convenient to make scatter plots of the samples for each pair of parameters; that is, we plot every marginal posterior that involves two parameters. We can also plot the marginal posteriors of single parameters. This is conveniently done using a corner plot, as implemented in bebi103.viz.corner().

[10]:
# Parameters we want to plot with pretty names for display
pars = [("alpha[0]", "α₁"), ("alpha[1]", "α₂"), ("b[0]", "b₁"), ("b[1]", "b₂"), "w"]

bokeh.io.show(bebi103.viz.corner(samples, parameters=pars))

This looks peculiar. We see a strong bimodality in the posterior. The cause of this can be revealed if we color the glyphs by the chain ID, accomplished using the color_by_chain kwarg. We only need to show one plot, so I will show \(w\) vs. \(\alpha_2\).

[11]:
bokeh.io.show(
    bebi103.viz.corner(
        samples, parameters=[("alpha[1]", "α₂"), "w"], color_by_chain=True
    )
)

We see that three of the chains (colored blue, red, and green) are centered around w ≈ 0.8, while the other (colored orange) is around w ≈ 0.2. Note that these two values of w sum to unity. We have just uncovered a nonidentifiable model. A nonidentifiable model is a model for which we cannot unambiguously determine the parameter values. That is, two or more parameter sets are observationally equivalent.

Label switching

There are many reasons why a model may be nonidentifiable. In this case, we are seeing a manifestation of label switching (see section 5 of the Stan Manual). Before launching into label switching in this particular case, I emphasize that this is certainly not the only way models can be nonidentifiable, and I am presenting this as a case study on how to deal with one particular kind of nonidentifiability. If you are seeing nonidentifiability in your model, do not automatically assume that it is because of label switching. I am also going through this case study to demonstrate that even in fairly simple models (in this case, two cell populations, each characterized by their burst size and frequency), devilish problems can arise when trying to do inference. You need to be vigilant.

In this mixture model, it is arbitrary which \((\alpha, b)\) pair we label as \((\alpha_1, b_1)\) or \((\alpha_2, b_2)\). We can switch the labels, and also change \(w\) to \(1-w\), and we have exactly the same posterior probability. To demonstrate that this is the case, I will generate the same grid of plots as above, but switch the appropriate labels and convert \(w\) to \(1-w\) for every \(w > 0.5\). (Note that this will not in general work, especially if the different modes from label switching overlap, and is not a good idea for analysis; I'm just doing it here to illustrate how label switching leads to nonidentifiability.)
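The symmetry can also be checked directly on the likelihood. Because both components have identical priors, the prior is unchanged by swapping the labels, so a symmetric likelihood implies a symmetric posterior. The sketch below uses arbitrary parameter values and made-up counts (not the Rex1 data), and the function name mixture_log_lik is illustrative.

```python
import numpy as np
import scipy.stats as st


def mixture_log_lik(counts, alpha, b, w):
    """Log-likelihood of the two-component mixture, parametrized by alpha and b = 1/beta."""
    beta_ = 1.0 / np.asarray(b)
    p = beta_ / (1.0 + beta_)  # scipy's nbinom success probability
    lp1 = np.log(w) + st.nbinom.logpmf(counts, alpha[0], p[0])
    lp2 = np.log1p(-w) + st.nbinom.logpmf(counts, alpha[1], p[1])
    # np.logaddexp(x, y) computes log(exp(x) + exp(y)) stably
    return np.sum(np.logaddexp(lp1, lp2))


# Arbitrary parameter values and made-up counts
toy_counts = np.array([2, 7, 15, 40, 120, 180, 260])
alpha_ex = np.array([5.0, 2.5])
b_ex = np.array([32.0, 5.0])
w_ex = 0.8

original = mixture_log_lik(toy_counts, alpha_ex, b_ex, w_ex)

# Swap the component labels and replace w with 1 - w
switched = mixture_log_lik(toy_counts, alpha_ex[::-1], b_ex[::-1], 1.0 - w_ex)
print(np.isclose(original, switched))
```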

[12]:
# Perform the label switch
params = ["alpha[0]", "alpha[1]", "b[0]", "b[1]", "w"]
switch = df_mcmc.loc[df_mcmc["w"] > 0.5, params]
switch = switch.rename(
    columns={
        "b[0]": "b[1]",
        "b[1]": "b[0]",
        "alpha[1]": "alpha[0]",
        "alpha[0]": "alpha[1]",
    }
)
switch["w"] = 1 - switch["w"]

df_switch = pd.concat([df_mcmc.loc[df_mcmc["w"] < 0.5, params], switch], sort=True)

# Make corner plot
bokeh.io.show(bebi103.viz.corner(df_switch, parameters=pars))

We see that if we fix the label switching, the posterior is indeed unimodal. So, making an identifiable model in this case means that we have to deal with the label switching problem. There are many approaches to doing this, and you can see a very detailed discussion about dealing with label switching and other problems associated with mixture models in this blog post by Michael Betancourt.

In looking at the corner plot, we see a very strong, non-Normal correlation between \(\alpha_1\) and \(b_1\) and also between \(\alpha_2\) and \(b_2\). There is clearly some structure in the posterior that would be impossible to discover using the MAP alone.

Initializing walkers

We will take a strategy suggested in earlier versions of the Stan manual. We noted earlier that each chain tended to stay on a single mode. We will instead initialize all of the chains near the same mode. This, like any of the fixes for mixture models, does not guarantee that we get good sampling (and we will discuss more diagnostics for good sampling in a future tutorial), but in practice it can work quite well, as it will here.

To determine where to start the chains, we will select one chain and start all of the samplers at its mean. First, we compute the parameter means for the chain with index 1 (the second chain, since chain indices start at zero).

[13]:
# Compute mean of parameters for chain 1
param_means = df_mcmc.loc[df_mcmc['chain__']==1, params].mean()

# Take a look
param_means
[13]:
alpha[0]     3.260401
alpha[1]     5.192185
b[0]         6.070271
b[1]        32.019363
w            0.168397
dtype: float64

Now that we have the means for chain 1, we can pass them into Stan's sampler. An easy way to do that is to pass a dictionary of starting points using the inits kwarg of sm.sample(). Note that vector-valued parameters (such as alpha and b) need to be specified as vectors, for which a Python list suffices.

Before constructing this, I pause to note that by default Stan chooses starting values for the chains by drawing random numbers uniformly on the interval [-2, 2] on the unconstrained scale. That is, constrained parameters are transformed to be unconstrained, their starting values are drawn from this interval, and these are then transformed back to the constrained scale. If the posterior has very little probability mass in the resulting region of parameter space, warmup may take longer, so it is sometimes advisable to start the chains at other points. The method of initializing the chains we are using here is meant to deal with the label switching problem, but I am also using it to demonstrate how to provide starting points for chains.
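For the parameters in this model, here is what that default implies on the constrained scale. This sketch assumes the standard transforms Stan documents for these bounds: exp for <lower=0> parameters and the inverse logit for <lower=0, upper=1> parameters. Note that one of the b values in the chain means above is about 32, well outside the default initial range.

```python
import numpy as np
import scipy.special

# Default initialization: uniform on [-2, 2] for each unconstrained parameter
u_lo, u_hi = -2.0, 2.0

# <lower=0> parameters (alpha, b) are mapped back to the constrained scale with exp(u)
pos_range = (np.exp(u_lo), np.exp(u_hi))

# <lower=0, upper=1> parameters (w) are mapped back with the inverse logit
unit_range = (scipy.special.expit(u_lo), scipy.special.expit(u_hi))

print(f"alpha, b start in [{pos_range[0]:.3f}, {pos_range[1]:.3f}]")
print(f"w starts in [{unit_range[0]:.3f}, {unit_range[1]:.3f}]")
```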

[14]:
inits = {
    "alpha": [param_means["alpha[0]"], param_means["alpha[1]"]],
    "b": [param_means["b[0]"], param_means["b[1]"]],
    "w": param_means["w"],
}

samples = sm.sample(data=data, inits=inits)
samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:finish chain 4
INFO:cmdstanpy:finish chain 1
INFO:cmdstanpy:finish chain 3

Let’s look at the corner plot again to see how this worked.

[15]:
bokeh.io.show(bebi103.viz.corner(samples, parameters=pars))

We are back to a single mode in the posterior; by starting all chains near one mode, we have worked around the label-switching nonidentifiability. We can do our ad hoc model assessment by plotting theoretical CDFs, computed from posterior samples, over the ECDF of the data.

[16]:
# Make ECDF
p = iqplot.ecdf(data=n, x_axis_label="mRNA count")

# x-values and samples to use in plot
x = np.arange(int(1.05 * n.max()))
alpha0s = samples.posterior["alpha"].values[:, :, 0].flatten()[::40]
alpha1s = samples.posterior["alpha"].values[:, :, 1].flatten()[::40]
beta0s = samples.posterior["beta_"].values[:, :, 0].flatten()[::40]
beta1s = samples.posterior["beta_"].values[:, :, 1].flatten()[::40]
ws = samples.posterior["w"].values.flatten()[::40]

for alpha0, alpha1, beta0, beta1, w in zip(alpha0s, alpha1s, beta0s, beta1s, ws):
    y = w * st.nbinom.cdf(x, alpha0, beta0 / (1 + beta0))
    y += (1 - w) * st.nbinom.cdf(x, alpha1, beta1 / (1 + beta1))
    x_plot, y_plot = bebi103.viz.cdf_to_staircase(x, y)
    p.line(x_plot, y_plot, line_width=0.5, color="orange", level="underlay")

bokeh.io.show(p)

The mixture model performs much better than the single Negative-Binomial model.

Conclusions

You have learned how to use Stan to sample out of a posterior distribution using MCMC. I hope it is evident how convenient and powerful this is. I also hope you now appreciate how fragile statistical modeling can be, as you saw with the nonidentifiability caused by label switching.

We have looked at some visualizations of MCMC results in this lesson, and in coming lessons, we will take a closer look at how to visualize and report MCMC results.

[17]:
bebi103.stan.clean_cmdstan()

Computing environment

[18]:
%load_ext watermark
%watermark -v -p numpy,scipy,pandas,cmdstanpy,arviz,bokeh,iqplot,holoviews,bebi103,jupyterlab
print("cmdstan   :", bebi103.stan.cmdstan_version())
Python implementation: CPython
Python version       : 3.8.5
IPython version      : 7.19.0

numpy     : 1.19.2
scipy     : 1.5.2
pandas    : 1.2.0
cmdstanpy : 0.9.67
arviz     : 0.11.0
bokeh     : 2.2.3
iqplot    : 0.2.0
holoviews : 1.14.0
bebi103   : 0.1.2
jupyterlab: 2.2.6

cmdstan   : 2.25.0