10. Mixture models and label switching with MCMC¶
[1]:
# Colab setup ------------------
import os, sys, subprocess
if "google.colab" in sys.modules:
    cmd = "pip install --upgrade iqplot colorcet bebi103 arviz cmdstanpy watermark"
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    import cmdstanpy; cmdstanpy.install_cmdstan()
    data_path = "https://s3.amazonaws.com/bebi103.caltech.edu/data/"
else:
    data_path = "../data/"
# ------------------------------

import numpy as np
import scipy.stats as st
import pandas as pd

import cmdstanpy
import arviz as az

import iqplot
import bebi103

import holoviews as hv
hv.extension('bokeh')

import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()
We continue with our analysis of the smFISH data, but this time for the Rex1 gene, for which we saw clear bimodality in the data.
[2]:
# Load DataFrame and get counts
df = pd.read_csv(os.path.join(data_path, "singer_transcript_counts.csv"), comment="#")
n = df["Rex1"].values

bokeh.io.show(
    iqplot.ecdf(n, x_axis_label="mRNA count", frame_height=150, frame_width=200)
)
Mixture models¶
Since the Negative Binomial distribution is unimodal, what could be the story here? It is quite possible we are seeing two different states of cells, one with one level of bursty expression of the gene of interest, and another with a different level of bursty expression. This could mean the cells are differentiating. So, we would expect the number of mRNA transcripts to be distributed according to a linear combination of two negative binomial distributions. We can write out the PMF as
\begin{align} f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) &= w\,\frac{\Gamma(n + \alpha_1)}{n!\,\Gamma(\alpha_1)}\,\left(\frac{\beta_1}{1+\beta_1}\right)^{\alpha_1}\left(\frac{1}{1+\beta_1}\right)^{n} \\[1em] &\;\;\;\;+ (1-w) \,\frac{\Gamma(n + \alpha_2)}{n!\,\Gamma(\alpha_2)}\,\left(\frac{\beta_2}{1+\beta_2}\right)^{\alpha_2}\left(\frac{1}{1+\beta_2}\right)^{n} , \end{align}
where \(w\) is the probability that the burst size and frequency are determined by \(1/\beta_1\) and \(\alpha_1\). Such a model, in which a variable is distributed as a linear combination of distributions, is called a mixture model.
For this case, we can write this likelihood more concisely as
\begin{align} n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2)\;\;\forall i. \end{align}
We have to specify priors on \(\alpha_1\), \(\beta_1\), \(\alpha_2\), \(\beta_2\), and \(w\). We can retain the same priors from the previous lesson for the \(\alpha\)’s and \(b\)’s, and we will assume a Uniform prior for \(w\). We then have the following model.
\begin{align} &\log_{10} \alpha_i \sim \text{Norm}(0,1) \text{ for } i \in [1, 2] \\[1em] &\log_{10} b_i \sim \text{Norm}(2, 1) \text{ for } i \in [1, 2], \\[1em] &\beta_i = 1/b_i,\\[1em] &w \sim \text{Beta}(1, 1), \\[1em] &n_i \sim w \, \text{NegBinom}(\alpha_1, \beta_1) + (1-w)\,\text{NegBinom}(\alpha_2, \beta_2)\;\;\forall i. \end{align}
Note that since the prior for \(w\) is Uniform (Beta(1, 1)), we do not strictly need to state its prior explicitly; it suffices to enforce the constraint \(0 \le w \le 1\).
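To make the mixture concrete before we code it up in Stan, here is a minimal Python sketch that evaluates the PMF above with scipy.stats. The parameter values are made up purely for illustration; they are not fits to the Rex1 data.

import numpy as np
import scipy.stats as st

# Hypothetical parameter values for the two cell states (illustration only)
alpha1, b1 = 5.0, 30.0
alpha2, b2 = 2.0, 6.0
beta1, beta2 = 1 / b1, 1 / b2
w = 0.8

n_vals = np.arange(1000)

# Mixture PMF: a weighted sum of the two Negative Binomial PMFs.
# scipy's nbinom is parametrized as nbinom(alpha, beta / (1 + beta)).
pmf = (
    w * st.nbinom.pmf(n_vals, alpha1, beta1 / (1 + beta1))
    + (1 - w) * st.nbinom.pmf(n_vals, alpha2, beta2 / (1 + beta2))
)

# The mixture of two normalized PMFs is itself normalized
pmf.sum()  # ≈ 1.0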
Coding up a mixture model¶
There are a few considerations for coding up a mixture model that also introduce some new Stan syntax. Importantly, under the hood, Stan uses the log posterior when sampling out of the posterior, as do almost all samplers. Dealing with a mixture model presents a unique challenge for computing the log likelihood (which is one of the summands of the log posterior). Consider the contribution of a single measurement \(n\) to the log likelihood in the present example.
\begin{align} \ln f(n\mid \alpha_1, \alpha_2, \beta_1, \beta_2, w) = \ln(w\,a_1 + (1-w)a_2), \end{align}
where
\begin{align} a_i = \frac{\Gamma(n + \alpha_i)}{n!\,\Gamma(\alpha_i)}\,\left(\frac{\beta_i}{1+\beta_i}\right)^{\alpha_i}\left(\frac{1}{1+\beta_i}\right)^{n}. \end{align}
While the logarithm of a product is conveniently split, we cannot split the logarithm of a sum. If we compute the sum directly, we will get serious underflow errors for parameter values for which the terms \(a_1\) or \(a_2\) are small. To compute this in a more numerically stable way, we need to use the log-sum-exp trick. Fortunately, Stan has a built-in function to compute the contribution of a mixture to the log posterior, the log_mix function. To update the posterior with this log_mix function, we need to add to target. In Stan, the keyword target is a special variable that holds the running sum of the contributions to the log posterior. When you make statements like theta ~ normal(0.0, 1.0), Stan is adding the appropriate terms to target under the hood. In the case of mixture models, we need to add to target explicitly. More generally, you can add any terms to target, and Stan considers these terms as part of the log posterior. The Stan code below implements this for the mixture model.
data {
  int<lower=0> N;
  int<lower=0> n[N];
}


parameters {
  vector<lower=0>[2] alpha;
  vector<lower=0>[2] b;
  real<lower=0, upper=1> w;
}


transformed parameters {
  vector[2] beta_ = 1.0 ./ b;
}


model {
  // Priors
  // (The Lognormal parameters are the natural-log equivalents of the
  // log10-Normal priors above, since ln 10 ≈ 2.3 and 2 ln 10 ≈ 4.6.)
  alpha ~ lognormal(0.0, 2.3);
  b ~ lognormal(4.6, 2.3);
  w ~ beta(1.0, 1.0);

  // Likelihood
  for (n_val in n) {
    target += log_mix(
      w,
      neg_binomial_lpmf(n_val | alpha[1], beta_[1]),
      neg_binomial_lpmf(n_val | alpha[2], beta_[2])
    );
  }
}
In addition to the log_mix function, there is some more new syntax.

- To add the contribution of the Negative Binomial PMF to the log posterior, we use neg_binomial_lpmf. Every discrete Stan distribution has a function <distribution>_lpmf that gives the value of the log PMF, and every continuous distribution has a corresponding <distribution>_lpdf. The arguments of the function are as you would write them on paper, with a bar (|) signifying the conditioning on the parameters.
- In the above, I have specified alpha, b, and beta_ as vectors. A vector is an array that behaves like a column vector, which means you can do matrix operations with it. It also allows you to do element-wise operations. Notice that I have written

  vector[2] beta_ = 1.0 ./ b;

  This means that beta_ is a 2-vector and that each element is given by the inverse of the corresponding element of b. The ./ operator accomplishes this (note the dot in front of the slash). In general, preceding an operator with a . indicates that the operation should be done elementwise.
- Note also that Stan is smart enough to know that if I give alpha a prior, it assigns it independently to each element of alpha. The same is of course true for b.
- To specify a lower and upper bound, we use the <lower=0, upper=1> syntax.
- Note the for loop construction. for (n_val in n) means iterate over each entry in n, yielding its value as n_val within the loop. The contents of the loop are enclosed in braces. Note that we could equivalently have written the for loop as
  // Likelihood
  for (i in 1:N) {
    target += log_mix(
      w,
      neg_binomial_lpmf(n[i] | alpha[1], beta_[1]),
      neg_binomial_lpmf(n[i] | alpha[2], beta_[2])
    );
  }
Here, we are looping over integers starting at 1 and ending at \(N\), inclusive. You can think of the 1:N syntax as being kind of like Python’s range() function. However, unlike Python’s range(), Stan’s range is inclusive of the end value.
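As an aside, we can mimic the arithmetic of log_mix in Python with scipy.special.logsumexp to see why the log-sum-exp trick matters. This is just a sketch; the function defined here only mirrors what Stan's log_mix computes, and the numbers are made up.

import numpy as np
from scipy.special import logsumexp

def log_mix_py(w, lp1, lp2):
    """log(w * exp(lp1) + (1 - w) * exp(lp2)), computed stably."""
    return logsumexp([np.log(w) + lp1, np.log1p(-w) + lp2])

# Even when both log PMFs are so negative that the PMFs themselves
# underflow to zero, the log of the mixture is still computed correctly.
log_mix_py(0.8, -800.0, -750.0)                        # ≈ -751.6
np.log(0.8 * np.exp(-800.0) + 0.2 * np.exp(-750.0))    # -inf from underflow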
Now that we have our model set up, let’s compile and sample from it! I will use the seed
kwarg to set the seed for the random number generator to ensure that I always get the same result for illustrative purposes.
[3]:
sm = cmdstanpy.CmdStanModel(stan_file='mixture.stan')

data = {'n': n, 'N': len(n)}

samples = sm.sample(
    data=data,
    seed=3252,
    chains=4,
    iter_sampling=1000,
)

samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:compiling stan program, exe file: /Users/bois/Dropbox/git/bebi103_course/2021/b/content/lessons/10/mixture
INFO:cmdstanpy:compiler options: stanc_options=None, cpp_options=None
INFO:cmdstanpy:compiled model file: /Users/bois/Dropbox/git/bebi103_course/2021/b/content/lessons/10/mixture
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 1
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:finish chain 3
INFO:cmdstanpy:finish chain 4
Parsing the output¶
Let’s look at the results.
[4]:
samples
[4]:
Inference data with groups:
	> posterior
	> sample_stats

posterior:

<xarray.Dataset>
Dimensions:      (alpha_dim_0: 2, b_dim_0: 2, beta__dim_0: 2, chain: 4, draw: 1000)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * alpha_dim_0  (alpha_dim_0) int64 0 1
  * b_dim_0      (b_dim_0) int64 0 1
  * beta__dim_0  (beta__dim_0) int64 0 1
Data variables:
    alpha        (chain, draw, alpha_dim_0) float64 5.07 1.545 ... 5.413 2.415
    b            (chain, draw, b_dim_0) float64 32.26 10.05 32.27 ... 28.9 5.952
    w            (chain, draw) float64 0.8312 0.8276 0.8249 ... 0.8299 0.7957
    beta_        (chain, draw, beta__dim_0) float64 0.031 0.09949 ... 0.168
Attributes:
    created_at:                 2021-01-20T23:49:21.358416
    arviz_version:              0.11.0
    inference_library:          cmdstanpy
    inference_library_version:  0.9.67

sample_stats:

<xarray.Dataset>
Dimensions:      (chain: 4, draw: 1000)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    lp           (chain, draw) float64 -1.6e+03 -1.598e+03 ... -1.6e+03
    accept_stat  (chain, draw) float64 0.9941 0.9491 0.6852 ... 0.6901 0.9339
    stepsize     (chain, draw) float64 0.1438 0.1438 0.1438 ... 0.1303 0.1303
    treedepth    (chain, draw) int64 5 4 2 4 4 5 4 3 4 4 ... 4 5 3 4 4 4 3 4 2 5
    n_leapfrog   (chain, draw) int64 31 15 3 19 31 39 31 ... 31 27 31 7 23 3 31
    diverging    (chain, draw) bool False False False ... False False False
    energy       (chain, draw) float64 1.602e+03 1.601e+03 ... 1.601e+03
Attributes:
    created_at:                 2021-01-20T23:49:21.362240
    arviz_version:              0.11.0
    inference_library:          cmdstanpy
    inference_library_version:  0.9.67
The arviz.InferenceData object has two xarray Datasets, posterior and sample_stats. We will work extensively with sample_stats in future lessons (and find that it is crucial for checking your sampling!), but for now we will focus on the posterior group, which holds the samples. You can see its contents in the posterior output above.
Note now that the posterior output is a bit more complex. This is because the parameters alpha, b, and beta_ are vector-valued. The entries in the vectors are indexed by the coordinates alpha_dim_0, b_dim_0, and beta__dim_0, respectively. The samples of these vector-valued parameters are therefore three-dimensional arrays. The first two dimensions are chain and draw, as we have already seen, and the third dimension specifies which element of the vector.
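As a quick check of this structure, we can inspect one of the vector-valued parameters directly (a minimal sketch; the shapes correspond to the output above).

# alpha is stored as a (chain, draw, element) array
samples.posterior["alpha"].dims   # ('chain', 'draw', 'alpha_dim_0')
samples.posterior["alpha"].shape  # (4, 1000, 2)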
Samples are selected using xarray data selection, which you can read about in the xarray docs. As an example, to access sample number 478 from chain 2, do the following.
[5]:
samples.posterior.loc[dict(chain=2, draw=478)]
[5]:
<xarray.Dataset>
Dimensions:      (alpha_dim_0: 2, b_dim_0: 2, beta__dim_0: 2)
Coordinates:
    chain        int64 2
    draw         int64 478
  * alpha_dim_0  (alpha_dim_0) int64 0 1
  * b_dim_0      (b_dim_0) int64 0 1
  * beta__dim_0  (beta__dim_0) int64 0 1
Data variables:
    alpha        (alpha_dim_0) float64 4.498 8.293
    b            (b_dim_0) float64 35.77 1.554
    w            float64 0.8408
    beta_        (beta__dim_0) float64 0.02795 0.6433
Attributes:
    created_at:                 2021-01-20T23:49:21.358416
    arviz_version:              0.11.0
    inference_library:          cmdstanpy
    inference_library_version:  0.9.67
If you wanted only the value of \(\alpha_2\) from that sample, you would do (noting the zero-based indexing of xarrays):
[6]:
samples.posterior["alpha"].loc[dict(chain=2, draw=478, alpha_dim_0=1)]
[6]:
<xarray.DataArray 'alpha' ()>
array(8.29259)
Coordinates:
    chain        int64 2
    draw         int64 478
    alpha_dim_0  int64 1
If you wanted it as a scalar, you would use the values
attribute and convert the resulting zero-dimension Numpy array to a float
.
[7]:
float(samples.posterior["alpha"].loc[dict(chain=2, draw=478, alpha_dim_0=1)].values)
[7]:
8.29259
While xarray data types are quite powerful, it is often more convenient to work with the more familiar Pandas DataFrames. The .to_dataframe()
method of xarrays is quite useful for this purpose.
[8]:
samples.posterior.to_dataframe().head()
[8]:
                                              alpha        b         w     beta_
alpha_dim_0 b_dim_0 beta__dim_0 chain draw
0           0       0           0     0    5.07007  32.2602  0.831249  0.030998
                                      1    5.04235  32.2721  0.827551  0.030987
                                      2    5.07235  32.0265  0.824924  0.031224
                                      3    5.04161  32.2483  0.835578  0.031009
                                      4    4.27491  38.1043  0.868500  0.026244
As we might expect, the indexes of the xarray become indexes of the data frame. This can be cumbersome to work with when plotting because we typically wish to plot marginal distributions, e.g., of element 1 of alpha
. A more convenient form might be a data frame where each element of a vector- (or matrix-) valued parameter is a column. This is accomplished using the bebi103.stan.arviz_to_dataframe()
function.
[9]:
df_mcmc = bebi103.stan.arviz_to_dataframe(samples)
df_mcmc.head()
[9]:
   alpha[0]  alpha[1]     b[0]      b[1]         w  beta_[0]  beta_[1]  chain__  draw__  diverging__
0   5.07007   1.54548  32.2602  10.05120  0.831249  0.030998  0.099491        0       0        False
1   5.04235   2.95352  32.2721   4.74096  0.827551  0.030987  0.210928        0       1        False
2   5.07235   2.75516  32.0265   4.16415  0.824924  0.031224  0.240145        0       2        False
3   5.04161   2.89813  32.2483   6.03983  0.835578  0.031009  0.165568        0       3        False
4   4.27491   4.73762  38.1043   2.62758  0.868500  0.026244  0.380578        0       4        False
This data frame is easier to work with for plotting. Note that there are added columns for the chain number ('chain__'
) and draw ('draw__'
). These have double underscores in their names for easier reference. This function also includes some diagnostics from the sample_stats
group of the ArviZ InferenceData in the data frame, in this case the diverging__
column, and we will discuss the meaning of those in future lessons.
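As a small example of using these columns (a sketch, not part of the analysis that follows), we can count divergent draws per chain.

# Number of divergent transitions in each chain
df_mcmc.groupby("chain__")["diverging__"].sum()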
Plotting the samples¶
To plot our samples, it is convenient to make scatter plots of samples for each pair of parameters. This means we plot all marginalized posteriors that contain two parameters. We can also plot marginalized posteriors where we only have one variable. This is conveniently done using a corner plot, as implemented in bebi103.viz.corner()
.
[10]:
# Parameters we want to plot with pretty names for display
pars = [("alpha[0]", "α₁"), ("alpha[1]", "α₂"), ("b[0]", "b₁"), ("b[1]", "b₂"), "w"]
bokeh.io.show(bebi103.viz.corner(samples, parameters=pars))
This looks peculiar. We see a strong bimodality in the posterior. The cause of this can be revealed if we color the glyphs by the chain ID, accomplished using the color_by_chain kwarg. We only need to show one plot, so I will show \(w\) vs. \(\alpha_2\).
[11]:
bokeh.io.show(
    bebi103.viz.corner(
        samples, parameters=[("alpha[1]", "α₂"), "w"], color_by_chain=True
    )
)
We see that three of the chains (colored blue, red, and green) are centered around w ≈ 0.8, while the other (colored orange) is around w ≈ 0.2. Note that these two values of w sum to unity. We have just uncovered a nonidentifiable model. A nonidentifiable model is a model for which we cannot unambiguously determine the parameter values. That is, two or more parameter sets are observationally equivalent.
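We can check this numerically with the tidy data frame from above; this quick sketch computes the posterior mean of \(w\) chain by chain.

# Posterior mean of w for each chain: three chains sit near w ≈ 0.8,
# while one sits near w ≈ 0.2
df_mcmc.groupby("chain__")["w"].mean()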
Label switching¶
There are many reasons why a model may be nonidentifiable. In this case, we are seeing a manifestation of label switching (see section 5 of the Stan Manual). Before launching into label switching in this particular case, I emphasize that this is certainly not the only way models can be nonidentifiable, and I am presenting this as a case study on how to deal with a particular kind of nonidentifiability. If you are seeing nonidentifiability in your model, do not automatically assume that it is because of label switching. I also am going through this case study to demonstrate that even in fairly simple models (in this case, two cell populations, each characterized by their burst size and frequency), devilish problems can arise when trying to do inference. You need to be vigilant.
In this mixture model, it is arbitrary which \((\alpha, b)\) pair we label as \((\alpha_1, b_1)\) or \((\alpha_2, b_2)\). We can switch the labels, and also change \(w\) to \(1-w\), and we have exactly the same posterior probability. To demonstrate that this is the case, I will generate the same grid of plots as above, but switch the appropriate labels and convert \(w\) to \(1-w\) for every sample with \(w > 0.5\). (Note that this will not in general work, especially if the different modes from label switching overlap, and it is not a good idea for analysis; I’m just doing it here to illustrate how label switching leads to nonidentifiability.)
[12]:
# Perform the label switch
params = ["alpha[0]", "alpha[1]", "b[0]", "b[1]", "w"]

switch = df_mcmc.loc[df_mcmc["w"] > 0.5, params]
switch = switch.rename(
    columns={
        "b[0]": "b[1]",
        "b[1]": "b[0]",
        "alpha[1]": "alpha[0]",
        "alpha[0]": "alpha[1]",
    }
)
switch["w"] = 1 - switch["w"]

df_switch = pd.concat([df_mcmc.loc[df_mcmc["w"] < 0.5, params], switch], sort=True)

# Make corner plot
bokeh.io.show(bebi103.viz.corner(df_switch, parameters=pars))
We see that if we fix the label switching, the posterior is indeed unimodal. So, making an identifiable model in this case means that we have to deal with the label switching problem. There are many approaches to doing this, and you can see a very detailed discussion about dealing with label switching and other problems associated with mixture models in this blog post by Michael Betancourt.
In looking at the corner plot, we see a very strong, non-Normal correlation between \(\alpha_1\) and \(b_1\), and also between \(\alpha_2\) and \(b_2\). There is clearly some structure in the posterior that would be impossible to discover using the MAP alone.
Initializing walkers¶
We will take a strategy suggested in earlier versions of the Stan manual. We noted earlier that the chains tended to stay on a single mode. We will initialize the chains to all start near the same mode. This, like any of the fixes for mixture models, does not guarantee that we get good sampling (and we will discuss more diagnostics for good sampling in a future lesson), but in practice it can work quite well, as it will here.
To determine where to start the chains, we will select one of the chains from the sampling we just did and start the new chains at its parameter means. First, we need to compute the parameter means for chain 1.
[13]:
# Compute mean of parameters for chain 1
param_means = df_mcmc.loc[df_mcmc['chain__']==1, params].mean()
# Take a look
param_means
[13]:
alpha[0] 3.260401
alpha[1] 5.192185
b[0] 6.070271
b[1] 32.019363
w 0.168397
dtype: float64
Now that we have the means for chain 1, we can pass them into Stan’s sampler. An easy way to do that is to pass a dictionary of starting points using the inits kwarg of sm.sample(). Note that vector-valued parameters (such as alpha and b) need to be specified as vectors, which means you can use a Python list.
Before constructing this, I pause to note that by default Stan chooses starting values for the chains by drawing random numbers uniformly on the interval [-2, 2] on the unconstrained scale; constrained parameters are then transformed back to satisfy their constraints. If the posterior has very little probability mass in the region covered by these draws, warmup may take longer. It is sometimes advisable to start the chains at different starting points. The method of initializing the chains we are using here is meant to deal with the label switching problem, but I am using it also to demonstrate how to provide starting points for chains.
[14]:
inits = {
    "alpha": [param_means["alpha[0]"], param_means["alpha[1]"]],
    "b": [param_means["b[0]"], param_means["b[1]"]],
    "w": param_means["w"],
}

samples = sm.sample(data=data, inits=inits)

samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:finish chain 4
INFO:cmdstanpy:finish chain 1
INFO:cmdstanpy:finish chain 3
Let’s look at the corner plot again to see how this worked.
[15]:
bokeh.io.show(bebi103.viz.corner(samples, parameters=pars))
We again have a return to a single mode in the posterior and we have fixed the identifiability problem. We can do our ad hoc model assessment by plotting the theoretical CDFs.
[16]:
# Make ECDF
p = iqplot.ecdf(data=n, x_axis_label="mRNA count")

# x-values and samples to use in plot
x = np.arange(int(1.05 * n.max()))
alpha0s = samples.posterior["alpha"].values[:, :, 0].flatten()[::40]
alpha1s = samples.posterior["alpha"].values[:, :, 1].flatten()[::40]
beta0s = samples.posterior["beta_"].values[:, :, 0].flatten()[::40]
beta1s = samples.posterior["beta_"].values[:, :, 1].flatten()[::40]
ws = samples.posterior["w"].values.flatten()[::40]

for alpha0, alpha1, beta0, beta1, w in zip(alpha0s, alpha1s, beta0s, beta1s, ws):
    y = w * st.nbinom.cdf(x, alpha0, beta0 / (1 + beta0))
    y += (1 - w) * st.nbinom.cdf(x, alpha1, beta1 / (1 + beta1))
    x_plot, y_plot = bebi103.viz.cdf_to_staircase(x, y)
    p.line(x_plot, y_plot, line_width=0.5, color="orange", level="underlay")

bokeh.io.show(p)
The mixture model performs much better than the single Negative-Binomial model.
Conclusions¶
You have learned how to use Stan to sample out of a posterior distribution using MCMC. I hope it is evident how convenient and powerful this is. I also hope you have an understanding of how fragile statistical modeling can be, as you saw with the nonidentifiability caused by label switching.
We have looked at some visualizations of MCMC results in this lesson, and in coming lessons, we will take a closer look at how to visualize and report MCMC results.
[17]:
bebi103.stan.clean_cmdstan()
Computing environment¶
[18]:
%load_ext watermark
%watermark -v -p numpy,scipy,pandas,cmdstanpy,arviz,bokeh,iqplot,holoviews,bebi103,jupyterlab
print("cmdstan :", bebi103.stan.cmdstan_version())
Python implementation: CPython
Python version : 3.8.5
IPython version : 7.19.0
numpy : 1.19.2
scipy : 1.5.2
pandas : 1.2.0
cmdstanpy : 0.9.67
arviz : 0.11.0
bokeh : 2.2.3
iqplot : 0.2.0
holoviews : 1.14.0
bebi103 : 0.1.2
jupyterlab: 2.2.6
cmdstan : 2.25.0