Mixture models and label switching with MCMC

You should be sure to update your bebi103 package prior to using this notebook using

pip install --upgrade bebi103

[1]:
import numpy as np
import scipy.stats as st
import pandas as pd

import cmdstanpy
import arviz as az

import bokeh_catplot
import bebi103

import holoviews as hv
hv.extension('bokeh')

import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()

We continue with our analysis of the smFISH data, but this time for the Rex1 gene, for which we saw clear bimodality in the data.

[2]:
# Load DataFrame and get counts
df = pd.read_csv("../data/singer_transcript_counts.csv", comment="#")
n = df["Rex1"].values

bokeh.io.show(
    bokeh_catplot.ecdf(n, x_axis_label="mRNA count", frame_height=150, frame_width=200)
)

Mixture models

Since the Negative Binomial distribution is unimodal, what could be the story here? It is quite possible we are seeing two different states of cells: one with one level of bursty expression of the gene of interest, and another with a different level of bursty expression. This could mean the cells are differentiating. So, we would expect the number of mRNA transcripts to be distributed according to a linear combination of Negative Binomial distributions. We can write out the PMF as

\begin{align} f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) &= w\,\frac{\Gamma(n + \alpha_1)}{n!\,\Gamma(\alpha_1)}\,\left(\frac{\beta_1}{1+\beta_1}\right)^{\alpha_1}\left(\frac{1}{1+\beta_1}\right)^{n} \\[1em] &\;\;\;\;+ (1-w) \,\frac{\Gamma(n + \alpha_2)}{n!\,\Gamma(\alpha_2)}\,\left(\frac{\beta_2}{1+\beta_2}\right)^{\alpha_2}\left(\frac{1}{1+\beta_2}\right)^{n} , \end{align}

where \(w\) is the probability that the burst size and frequency are determined by \(1/\beta_1\) and \(\alpha_1\). Such a model, in which a variable is distributed as a linear combination of distributions, is called a mixture model.
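To make the weighted-sum structure concrete, here is a small sketch that evaluates this PMF with scipy.stats.nbinom using made-up parameter values (not fit results). Note that SciPy's parametrization of the Negative Binomial uses success probability \(\beta/(1+\beta)\), matching the form above.

# Made-up parameter values, just to illustrate the mixture PMF
alpha1, beta1 = 3.0, 0.2
alpha2, beta2 = 5.0, 0.03
w_illus = 0.2

n_arr = np.arange(1000)

# Weighted sum of two Negative Binomial PMFs (SciPy: p = beta / (1 + beta))
pmf = w_illus * st.nbinom.pmf(n_arr, alpha1, beta1 / (1 + beta1))
pmf += (1 - w_illus) * st.nbinom.pmf(n_arr, alpha2, beta2 / (1 + beta2))

# pmf.sum() is essentially 1; the mixture is still a normalized PMF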

For this case, we can write this likelihood more concisely as

\begin{align} n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2). \end{align}

We have to specify priors on \(\alpha_1\), \(\beta_1\), \(\alpha_2\), \(\beta_2\), and \(w\). We can retain the same priors for the \(\alpha\)’s and \(\beta\)’s, and we will assume a Uniform prior for \(w\). We then have the following model.

\begin{align} &\alpha_i \sim \text{LogNorm}(0,2) \text{ for } i \in [1, 2] \\[1em] &b_i \sim \text{LogNorm}(2, 3) \text{ for } i \in [1, 2], \\[1em] &\beta_i = 1/b_i,\\[1em] &w \sim \text{Beta}(1, 1), \\[1em] &n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2). \end{align}

Coding up a mixture model

There are a few considerations for coding up a mixture model, and they introduce some new Stan syntax. Importantly, Stan, like almost all samplers, works with the log posterior under the hood when sampling out of the posterior. Dealing with a mixture model presents a unique challenge for computing the log likelihood (which is one of the summands of the log posterior). Consider the contribution of a single measurement \(n\) to the log likelihood in the present example.

\begin{align} \ln f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) = \ln(w\,a_1 + (1-w)a_2), \end{align}

where

\begin{align} a_i = \frac{\Gamma(n + \alpha_i)}{n!\,\Gamma(\alpha_i)}\,\left(\frac{\beta_i}{1+\beta_i}\right)^{\alpha_i}\left(\frac{1}{1+\beta_i}\right)^{n}. \end{align}

While the logarithm of a product conveniently splits into a sum of logarithms, we cannot split the logarithm of a sum. If we compute the sum directly, we will get serious underflow errors for parameter values for which the terms \(a_1\) or \(a_2\) are small. To compute this in a numerically stable way, we need to use the log-sum-exp trick. Fortunately, Stan has a built-in function to compute the contribution of a mixture to the log posterior, the log_mix function.

To update the posterior with log_mix, we need to add to target. In Stan, the keyword target is a special variable that holds the running sum of the contributions to the log posterior. When you make statements like alpha ~ lognormal(0.0, 2.0), Stan is adding the appropriate terms to target under the hood. In the case of mixture models, we need to add to target explicitly. More generally, you can add any terms to target, and Stan considers these terms as part of the log posterior.
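To see the numerical issue concretely, here is a small NumPy/SciPy sketch, using made-up parameter values and a deliberately extreme count chosen only to force underflow, mirroring what log_mix computes under the hood.

# Made-up parameters and a deliberately large count so that both terms underflow
alpha1, beta1 = 3.0, 0.2
alpha2, beta2 = 5.0, 0.03
w_illus = 0.2
n_val = 30000

# Log PMFs of the two mixture components (SciPy: p = beta / (1 + beta))
log_a1 = st.nbinom.logpmf(n_val, alpha1, beta1 / (1 + beta1))
log_a2 = st.nbinom.logpmf(n_val, alpha2, beta2 / (1 + beta2))

# Naive evaluation: both exponentials underflow to zero, giving -inf
naive = np.log(w_illus * np.exp(log_a1) + (1 - w_illus) * np.exp(log_a2))

# Log-sum-exp: ln(w a1 + (1-w) a2) = logaddexp(ln w + ln a1, ln(1-w) + ln a2)
stable = np.logaddexp(np.log(w_illus) + log_a1, np.log(1 - w_illus) + log_a2)

The naive calculation gives -inf (after a warning about taking the log of zero), while the log-sum-exp version gives the correct, finite log probability. The Stan code below implements this for the mixture model.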

data {
  int<lower=0> N;
  int<lower=0> n[N];
}


parameters {
  vector<lower=0>[2] alpha;
  vector<lower=0>[2] b;
  real<lower=0, upper=1> w;
}


transformed parameters {
  vector[2] beta_ = 1.0 ./ b;
}


model {
  // Priors
  alpha ~ lognormal(0.0, 2.0);
  b ~ lognormal(2.0, 3.0);
  w ~ beta(1.0, 1.0);

  // Likelihood
  for (n_val in n) {
    target += log_mix(
      w,
      neg_binomial_lpmf(n_val | alpha[1], beta_[1]),
      neg_binomial_lpmf(n_val | alpha[2], beta_[2])
    );
  }
}

In addition to the log_mix function, there is some more new syntax.

- To add the contribution of the Negative Binomial PMF to the log posterior, we use neg_binomial_lpmf. Every discrete Stan distribution has a function <distribution>_lpmf that gives the value of the log PMF, and every continuous distribution has a similar function <distribution>_lpdf. The arguments of the function are as you would write them on paper, with a bar (|) signifying the conditioning on the parameters.

- In the above, I have specified alpha, b, and beta_ as vectors. A vector is an array that behaves like a column vector, which means you can do matrix operations with it. It also allows you to do element-wise operations. Notice that I have written

vector[2] beta_ = 1.0 ./ b;

This means that beta_ is a 2-vector and that each element is given by the inverse of the corresponding element of b. The ./ operator accomplishes this (note the dot in front of the slash). In general, preceding an operator with a . indicates that the operation should be done elementwise.

- Note also that Stan is smart enough to know that if I give alpha a prior, it needs to assign that prior independently to each element of alpha. The same is of course true for b.

- To specify both a lower and an upper bound, we use the <lower=0, upper=1> syntax.

- Note the for loop construction. for (n_val in n) means iterate over each entry of n, yielding its value as n_val within the loop. The contents of the loop are enclosed in braces. Note that we could equivalently have written the for loop as

// Likelihood
for (i in 1:N) {
  target += log_mix(
    w,
    neg_binomial_lpmf(n[i] | alpha[1], beta_[1]),
    neg_binomial_lpmf(n[i] | alpha[2], beta_[2])
  );
}

Here, we are looping over integers starting at 1 and ending at \(N\), inclusive. You can think of the 1:N syntax as being something like Python's range() function. However, unlike Python's range(), Stan's range is inclusive of the end value.
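As a purely illustrative Python analogue, the integers visited by Stan's 1:N are those of range(1, N + 1).

# The same integers Stan's 1:N loops over (Stan arrays are 1-indexed)
N_demo = 5
list(range(1, N_demo + 1))    # [1, 2, 3, 4, 5]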

Now that we have our model set up, let’s compile and sample from it! I will use the seed kwarg to set the seed for the random number generator to ensure that I always get the same result for illustrative purposes.

[3]:
sm = cmdstanpy.CmdStanModel(stan_file='mixture_1.stan')

data = {'n': n, 'N': len(n)}

samples = sm.sample(
    data=data,
    seed=523921,
    chains=4,
    sampling_iters=1000,
)

samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:stan to c++ (/Users/bois/Dropbox/git/bebi103_course/2020/b/content/lessons/lesson_04/mixture_1.hpp)
INFO:cmdstanpy:compiling c++
INFO:cmdstanpy:compiled model file: /Users/bois/Dropbox/git/bebi103_course/2020/b/content/lessons/lesson_04/mixture_1
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:finish chain 1
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 3
INFO:cmdstanpy:finish chain 4

Parsing the output

Let’s look at the results.

[4]:
samples.posterior
[4]:
<xarray.Dataset>
Dimensions:      (alpha_dim_0: 2, b_dim_0: 2, beta__dim_0: 2, chain: 4, draw: 1000)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * alpha_dim_0  (alpha_dim_0) int64 0 1
  * b_dim_0      (b_dim_0) int64 0 1
  * beta__dim_0  (beta__dim_0) int64 0 1
Data variables:
    alpha        (chain, draw, alpha_dim_0) float64 4.225 2.492 ... 4.347 4.685
    b            (chain, draw, b_dim_0) float64 37.08 4.773 ... 2.975 35.66
    w            (chain, draw) float64 0.7938 0.8895 0.8653 ... 0.1723 0.1705
    beta_        (chain, draw, beta__dim_0) float64 0.02697 0.2095 ... 0.02805
Attributes:
    created_at:                 2020-01-24T22:17:01.223954
    inference_library:          cmdstanpy
    inference_library_version:  0.8.0

Note now that the output is a bit more complex. This is because the parameters alpha, b, and beta_ are vector valued. They are indexed by the indexes alpha_dim_0, b_dim_0, and beta__dim_0, respectively. The samples of these vector-valued parameters are then three-dimensional arrays. The first two dimensions are the chain and draw, as we have already seen, and the third dimension specifies which element of the vector.
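As a quick check of this structure, we can look at the shape of the underlying array (a sketch; the shape follows from four chains of 1000 draws each and a 2-vector parameter).

# Shape is (chains, draws, vector dimension), here (4, 1000, 2)
samples.posterior["alpha"].values.shape

# For example, samples of alpha[0] from chain 0
alpha0_chain0 = samples.posterior["alpha"].values[0, :, 0]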

Since we will want to make plots of these samples using HoloViews, we should convert from xarray to a Pandas data frame.

[5]:
df_mcmc = samples.posterior.to_dataframe()

# Take a look
df_mcmc.head()
[5]:
                                              alpha        b         w     beta_
alpha_dim_0 b_dim_0 beta__dim_0 chain draw
0           0       0           0     0    4.22468  37.0786  0.793771  0.026970
                                      1    3.96298  37.9174  0.889474  0.026373
                                      2    4.21675  40.1750  0.865307  0.024891
                                      3    5.15633  32.3177  0.829723  0.030943
                                      4    4.88598  33.0659  0.824991  0.030243

As we might expect, the indexes of the xarray become indexes of the data frame. This can be cumbersome to work with when plotting because we typically wish to plot marginal distributions, e.g., of element 1 of alpha. A more convenient form might be a data frame where each element of a vector- (or matrix-) valued parameter is a column. This is accomplished using the bebi103.stan.posterior_to_dataframe() function.

[6]:
df_mcmc = bebi103.stan.posterior_to_dataframe(samples)

df_mcmc.head()
[6]:
   alpha[0]  alpha[1]     b[0]     b[1]         w  beta_[0]  beta_[1]  chain__  draw__  divergent__
0   4.22468   2.49172  37.0786  4.77340  0.793771  0.026970  0.209494        0       0        False
1   3.96298   5.74581  37.9174  1.96575  0.889474  0.026373  0.508711        0       1        False
2   4.21675   3.27079  40.1750  3.96473  0.865307  0.024891  0.252224        0       2        False
3   5.15633   4.67289  32.3177  3.02366  0.829723  0.030943  0.330725        0       3        False
4   4.88598   2.34684  33.0659  6.23863  0.824991  0.030243  0.160292        0       4        False

This data frame is easier to work with for plotting. Note that there are added columns for the chain number ('chain__') and draw ('draw__'). These have double underscores in their names for easier reference. We will not deal with the divergent__ column now, but will use that in forthcoming lessons.

Plotting the samples

To plot our samples, it is convenient to make scatter plots of the samples for each pair of parameters. That is, we plot the marginal posterior of each pair of parameters. We can also plot the marginal posterior of each single parameter. This is conveniently done using a corner plot, as implemented in bebi103.viz.corner().

[7]:
# Parameters we want to plot
pars = ["alpha[0]", "alpha[1]", "b[0]", "b[1]", "w"]

bokeh.io.show(bebi103.viz.corner(samples, pars=pars, plot_width=125))

This looks peculiar. We see a strong bimodality in the posterior. The cause of this can be revealed if we color the glyphs by the chain ID; the \(\alpha_2\) versus \(w\) plot is particularly telling.

[8]:
bokeh.io.show(
    bebi103.viz.corner(samples, pars=pars, plot_width=125, color_by_chain=True)
)

We see that three of the chains (colored blue, orange, and green) are centered around w = 0.8, while the other (colored red) is around w = 0.2. Note that these two values of w sum to unity. We have just uncovered a nonidentifiable model. A nonidentifiable model is a model for which we cannot unambiguously determine the parameter values. That is, two or more parameter sets are observationally equivalent.
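In the present case, we can write the observationally equivalent parameter sets explicitly. Swapping the component labels while replacing \(w\) with \(1-w\) leaves the likelihood unchanged,

\begin{align} f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) = f(n\mid \alpha_2, \alpha_1, \beta_2, \beta_1, 1-w), \end{align}

as you can verify directly from the PMF we wrote above.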

Label switching

There are many reasons why a model may be nonidentifiable. In this case, we are seeing a manifestation of label switching (see section 5 of the Stan Manual). Before launching into label switching in this particular case, I emphasize that this is certainly not the only way models can be nonidentifiable, and I am presenting this as a case study on how to deal with a particular kind of nonidentifiability. If you are seeing nonidentifiability in your model, do not automatically assume that it is because of label switching. I am also going through this case study to demonstrate that even in fairly simple models (in this case, two cell populations, each characterized by their burst size and frequency), devilish problems can arise when trying to do inference. You need to be vigilant.

In this mixture model, it is arbitrary which \((\alpha, b)\) pair we label as \((\alpha_1, b_1)\) or \((\alpha_2, b_2)\). We can switch the labels, and also change \(w\) to \(1-w\), and we have exactly the same posterior probability. To demonstrate that this is the case, I will generate the same grid of plots as above, but swap the labels and convert \(w\) to \(1-w\) for every sample with \(w > 0.5\). (Note that this will not work in general, especially if the different modes from label switching overlap, and it is not a good idea for analysis; I am just doing it here to illustrate how label switching leads to nonidentifiability.)

[9]:
# Perform the label switch
switch = df_mcmc.loc[df_mcmc["w"] > 0.5, pars]
switch = switch.rename(
    columns={
        "b[0]": "b[1]",
        "b[1]": "b[0]",
        "alpha[1]": "alpha[0]",
        "alpha[0]": "alpha[1]",
    }
)
switch["w"] = 1 - switch["w"]

df_switch = pd.concat([df_mcmc.loc[df_mcmc["w"] < 0.5, pars], switch], sort=True)

# Make corner plot
bokeh.io.show(bebi103.viz.corner(df_switch, pars=pars, plot_width=125))

We see that if we fix the label switching, the posterior is indeed unimodal. So, making an identifiable model in this case means that we have to deal with the label switching problem. There are many approaches to doing this, and you can see a very detailed discussion about dealing with label switching and other problems associated with mixture models here.

We will take a strategy suggested in earlier versions of the Stan manual. We noted earlier that the chains tended to stay on a single mode. We will therefore initialize the chains to start near only one of the modes. This, like any of the fixes for mixture models, does not guarantee good sampling (we will discuss more diagnostics of sampling quality in a future lesson), but in practice it can work quite well, as it will here.

To determine where to start the chains, we will select one of the chains and start the samplers at its parameter means. First, we need to compute the parameter means for that chain.

[10]:
# Compute mean of parameters for chain 1
params = ['alpha[0]', 'alpha[1]', 'b[0]', 'b[1]', 'w']
param_means = df_mcmc.loc[df_mcmc['chain__']==1, params].mean()

# Take a look
param_means
[10]:
alpha[0]     5.192274
alpha[1]     3.713972
b[0]        31.938437
b[1]         5.646036
w            0.833541
dtype: float64

Now that we have the means, we can use them as starting points for Stan's sampler. An easy way to do that is to pass a dictionary of starting points using the inits kwarg of sm.sample(). Note that vector-valued parameters (such as alpha and b) need to be specified as vectors, which means you can use a Python list.

Before constructing this, I pause to note that by default Stan chooses starting values for the chains by drawing random numbers on the interval [-2, 2]. For constrained parameters, the draw is made on the unconstrained scale and then transformed to satisfy the constraint. If the posterior has very little probability mass in the region where the chains start, warmup may take longer, so it is sometimes advisable to specify your own starting points.
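As a rough sketch of what that default means for a positive parameter: for a parameter declared with <lower=0>, the unconstraining transform is a logarithm, so the default initializations land in exp([-2, 2]).

# Range of default initial values for a <lower=0> parameter
np.exp([-2.0, 2.0])    # approximately (0.14, 7.4)

The posterior for b[0] is centered near 30, well outside this range, which is another reason explicitly specifying starting points can help here.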

[11]:
inits = {
    "alpha": [param_means["alpha[0]"], param_means["alpha[1]"]],
    "b": [param_means["b[0]"], param_means["b[1]"]],
    "w": param_means["w"],
}

samples = sm.sample(data=data, inits=inits)
samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:finish chain 1
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 3
INFO:cmdstanpy:finish chain 4

Let’s look at the corner plot again to see how this worked.

[12]:
# Make corner plot
bokeh.io.show(
    bebi103.viz.corner(
        samples, pars=pars, plot_width=125, xtick_label_orientation=np.pi / 4
    )
)

We again have a return to a single mode in the posterior, and we have fixed the identifiability problem. We can do our ad hoc model assessment by plotting theoretical CDFs computed from posterior samples on top of the measured ECDF.

[13]:
# Make ECDF
p = bokeh_catplot.ecdf(data=n, x_axis_label='mRNA count')

# x-values and samples to use in plot
x = np.arange(426)
alpha0s = samples.posterior['alpha'].values[:,:,0].flatten()[::40]
alpha1s = samples.posterior['alpha'].values[:,:,1].flatten()[::40]
beta0s = samples.posterior['beta_'].values[:,:,0].flatten()[::40]
beta1s = samples.posterior['beta_'].values[:,:,1].flatten()[::40]
ws = samples.posterior['w'].values.flatten()[::40]

for alpha0, alpha1, beta0, beta1, w in zip(alpha0s, alpha1s, beta0s, beta1s, ws):
    y = w * st.nbinom.cdf(x, alpha0, beta0/(1+beta0))
    y += (1-w) * st.nbinom.cdf(x, alpha1, beta1/(1+beta1))
    x_plot, y_plot = bebi103.viz.cdf_to_staircase(x, y)
    p.line(x_plot, y_plot, line_width=0.5, color='orange', level='underlay')

bokeh.io.show(p)

The mixture model performs much better than the single Negative-Binomial model.

Conclusions

You have learned how to use Stan to sample out of a posterior distribution using MCMC. I hope it is evident how convenient and powerful this is. I also hope you have an understanding of how fragile statistical modeling can be, as you saw with a label switching-based nonidentifiability.

We have looked at some visualizations of MCMC results in this lesson, and in coming lessons, we will take a closer look at how to visualize and report MCMC results.

[14]:
bebi103.stan.clean_cmdstan()

Computing environment

[15]:
%load_ext watermark
%watermark -v -p numpy,scipy,pandas,cmdstanpy,arviz,bokeh,bokeh_catplot,holoviews,bebi103,jupyterlab
CPython 3.7.6
IPython 7.11.1

numpy 1.18.1
scipy 1.3.1
pandas 0.24.2
cmdstanpy 0.8.0
arviz 0.6.1
bokeh 1.4.0
bokeh_catplot 0.1.7
holoviews 1.12.7
bebi103 0.0.50
jupyterlab 1.2.5