Parameter estimation with Markov chain Monte Carlo¶
[1]:
# Colab setup ------------------
import os, sys, subprocess
if "google.colab" in sys.modules:
    cmd = "pip install --upgrade iqplot colorcet bebi103 arviz cmdstanpy watermark"
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    import cmdstanpy; cmdstanpy.install_cmdstan()
    data_path = "https://s3.amazonaws.com/bebi103.caltech.edu/data/"
else:
    data_path = "../data/"
# ------------------------------
import numpy as np
import scipy.stats as st
import pandas as pd
import cmdstanpy
import arviz as az
import iqplot
import bebi103
import holoviews as hv
hv.extension('bokeh')
bebi103.hv.set_defaults()
import bokeh.io
import bokeh.plotting
bokeh.io.output_notebook()
In this lesson, we will learn how to use Markov chain Monte Carlo to do parameter estimation. To get the basic idea behind MCMC, imagine for a moment that we can draw samples out of the posterior distribution. This means that the probability of choosing given values of a set of parameters is proportional to the posterior probability of that set of values. If we drew many many such samples, we could reconstruct the posterior from the samples, e.g., by making histograms. That’s a big thing to imagine: that we can draw properly weighted samples. But, it turns out that we can! That is what MCMC allows us to do.
We discussed some theory behind this seemingly miraculous capability in lecture. For this lesson, we will just use the fact that we can do the sampling to learn about posterior distributions in the context of parameter estimation.
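To make the idea of reconstructing a distribution from samples concrete, here is a small illustration that does not involve MCMC at all: we draw samples from a known distribution and compare their ECDF to the true CDF. This is just a sketch; the Gamma distribution, its parameters, and the sample size are arbitrary choices for illustration.

# Illustration only: reconstruct a known distribution from samples (no MCMC here)
rg = np.random.default_rng(3252)

# Pretend these draws came out of a "posterior"; here it is just a Gamma distribution
samples_demo = rg.gamma(5.0, 2.0, size=10_000)

# The ECDF of the samples closely tracks the true CDF (orange curve)
p = iqplot.ecdf(samples_demo, x_axis_label="x", frame_height=200, frame_width=300)
x = np.linspace(0, 40, 400)
p.line(x, st.gamma.cdf(x, 5.0, scale=2.0), color="orange", line_width=2)

bokeh.io.show(p)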
Stan: Our MCMC engine¶
We will use Stan as our main engine for performing MCMC, and will use one of its Python interfaces, CmdStanPy. Stan is the state of the art for MCMC. Importantly, it is also a probabilistic programming language, which allows us to more easily specify Bayesian generative models. The Stan documentation will be very useful for you.
The data set¶
The data come from the Elowitz lab, published in Singer et al., Dynamic Heterogeneity and DNA Methylation in Embryonic Stem Cells, Molec. Cell, 55, 319-331, 2014, available here. In the following paragraphs, I repeat the description of the data set and EDA from last term:
In this paper, the authors investigated cell populations of embryonic stem cells using RNA single molecule fluorescence in situ hybridization (smFISH), a technique that enables them to count the number of mRNA transcripts in a cell for a given gene. They were able to measure four different genes in the same cells. So, for one experiment, they get the counts of four different genes in a collection of cells.
The authors focused on genes that code for pluripotency-associated regulators to study cell differentiation. Indeed, differing gene expression levels are a hallmark of differentiated cells. The authors do not just look at counts in a given cell at a given time. The temporal nature of gene expression is also important. While the authors do not directly look at temporal data using smFISH (since the technique requires fixing the cells), they did look at time lapse fluorescence movies of other regulators. We will not focus on these experiments here, but will discuss how the distribution of mRNA counts acquired via smFISH can serve to provide some insight about the dynamics of gene expression.
The data set we are analyzing now comes from an experiment where smFISH was performed in 279 cells for the genes rex1, rest, nanog, and prdm14. The data set may be downloaded at https://s3.amazonaws.com/bebi103.caltech.edu/data/singer_transcript_counts.csv.
ECDFs of mRNA counts¶
We will do a quick EDA to get a feel for the data set by generating ECDFs for the mRNA counts for each of the four genes.
[2]:
# Load DataFrame
df = pd.read_csv(os.path.join(data_path, 'singer_transcript_counts.csv'), comment='#')
genes = ["Nanog", "Prdm14", "Rest", "Rex1"]
plots = [
    iqplot.ecdf(
        data=df[gene].values,
        q=gene,
        x_axis_label="mRNA count",
        title=gene,
        frame_height=150,
        frame_width=200,
    )
    for gene in genes
]
bokeh.io.show(bokeh.layouts.gridplot(plots, ncols=2))
Note the difference in the \(x\)-axis scales. Clearly, prdm14 has far fewer mRNA copies than the other genes. The presence of two inflection points in the Rex1 ECDF implies bimodality.
Building a generative model¶
As we discussed in a lesson last term, we can model the transcript counts, which result from bursty gene expression, as being Negative Binomially distributed. For a given gene, the likelihood for the counts is
\begin{align} n_i \mid \alpha, b \sim \text{NegBinom}(\alpha, 1/b) \;\forall i, \end{align}
where \(\alpha\) is the burst frequency (higher \(\alpha\) means gene expression comes on more frequently) and \(b\) is the burst size, i.e., the typical number of transcripts made per burst. We have therefore identified the two parameters we need to estimate, \(\alpha\) and \(b\).
Because the Negative Binomial distribution is often parametrized in terms of \(\alpha\) and \(\beta= 1/b\), we can alternatively state our likelihood as
\begin{align} &\beta = 1/b,\\[1em] &n_i \mid \alpha, \beta \sim \text{NegBinom}(\alpha, \beta)\;\; \forall i. \end{align}
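As a quick numerical check of this parametrization (an aside of mine, not part of the original lesson), we can verify with SciPy that a Negative Binomial with parameters \(\alpha\) and \(\beta = 1/b\) has mean \(\alpha b\). SciPy parametrizes the distribution by the number of failures \(\alpha\) and success probability \(\beta/(1+\beta)\), the same conversion we will use later when plotting theoretical CDFs.

# Illustrative values for a quick sanity check of the parametrization
alpha = 5.0
b = 20.0
beta = 1 / b

# SciPy's nbinom takes (number of failures, success probability)
mean = st.nbinom.mean(alpha, beta / (1 + beta))

print(mean, alpha * b)  # both are 100.0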
Given that we have a Negative Binomial likelihood, we are left to specify priors for the burst size \(b\) and the burst frequency \(\alpha\).
Order-of-magnitude tricks¶
I find that codifying prior knowledge often involves making order-of-magnitude estimates of biological/biophysical quantities. To do this, I use an age-old trick that is best demonstrated by example. Let’s say someone tells me about a new bacterium and I have to guess how long a single cell of that species is.
To make my guess, I start absurdly low. Certainly, the cell is bigger than of order a nanometer, since that’s the diameter of a strand of DNA. I would bet a year’s salary on it. I would also bet a year’s salary, without flinching, that it is bigger than 10 nm. How about 100 nm? Well, I’m pretty sure that bacteria tend not to be smaller than 100 nm, but I don’t think I’d bet a year’s salary on it. I feel uneasy enough about that bet that I won’t make it. So, I put 100 nm as the lower end of my guess.
Now, let’s consider absurdly large sizes. I would bet a year’s salary that it is less than a meter long. How about 10 cm? That’s still gigantic, and I would bet a year’s salary that it’s smaller than that. How about 1 cm? Still gigantic. How about 1 mm? Well, this is still huge, but there is tremendous diversity among bacteria. I know there are eukaryotic cells this big (for example a Xenopus egg), so, even though I strongly suspect that bacterium would be smaller than 1 mm, I wouldn’t bet a year’s salary. So, 1 mm is my upper bound.
If we were coming up with an order-of-magnitude estimate, we would take the geometric mean of the high and low boundaries. In this case, we would get \(\sqrt{10^{-7}\cdot 10^{-3}} \text{ m} = 10^{-5}\text{ m} =\) 10 µm, which, perhaps not surprisingly, is within an order of magnitude of “typical” bacterial size, for example of E. coli.
Notice that this type of order-of-magnitude estimate operates on a logarithmic scale. We estimate between \(10^{-7}\) and \(10^{-3}\) meters. So, for encoding a prior for the parameter, it is convenient to come up with a prior for the base-ten logarithm of the parameter instead (ignoring the mathematical absurdity of taking logarithms of quantities with units), and then transform the variable. In the bacterial size example, I could use a Normal distribution where 95% of the probability mass lies between \(-7\) and \(-3\). The width of my range of reasonable values from the bet-a-year’s-salary approach is 4 log units, so if I choose a Normal distribution centered at \(-5\) with a scale parameter of 1, I capture this prior information. So, my prior for the bacterial length \(\ell\) is
\begin{align} &\log_{10} \ell \sim \text{Norm}(-5, 1),\\[1em] &\ell = 10^{\log_{10}\ell}. \end{align}
Equivalently, since \(\ln 10 \approx 2.3\), we can write this as \(\ell \sim \text{LogNorm}(-2.3\cdot 5, 2.3)\).
To summarize the procedure:
Start at absurdly low values for the parameter and work your way up to a value that you would be hesitant to bet a year’s salary on. This is your low estimate.
Start at absurdly high values for the parameter and work your way down to a value that you would be hesitant to bet a year’s salary on. This is your high estimate.
Determine the center of your two estimates on a logarithmic scale. This is the location parameter \(\mu_{10}\) for the Normal prior.
Take the difference between the high and low estimates on the logarithmic scale and divide it by four. The result is the scale parameter \(\sigma_{10}\) for the Normal prior.
The prior for the base-ten logarithm of the parameter is then \(\text{Norm}(\mu_{10}, \sigma_{10})\). Equivalently, we can say that the parameter is distributed as \(\text{LogNorm}(2.3\mu_{10}, 2.3\sigma_{10})\).
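As a sketch of how this procedure might look in code (the helper function below is my own, not from the lesson), we can convert the low and high bet-a-year’s-salary estimates into the parameters of the Normal prior on the base-ten logarithm and check that about 95% of the probability mass falls between the two estimates.

def log10_normal_prior_params(low, high):
    """Hypothetical helper: location and scale of a Normal prior on the
    base-ten logarithm of a parameter, given low and high estimates on
    the natural scale."""
    mu_10 = (np.log10(low) + np.log10(high)) / 2
    sigma_10 = (np.log10(high) - np.log10(low)) / 4

    return mu_10, sigma_10


# Bacterial length example: low of 100 nm, high of 1 mm, working in meters
mu_10, sigma_10 = log10_normal_prior_params(1e-7, 1e-3)
print(mu_10, sigma_10)  # -5.0, 1.0

# Roughly 95% of the probability mass lies between the low and high estimates
mass = st.norm.cdf(-3, mu_10, sigma_10) - st.norm.cdf(-7, mu_10, sigma_10)
print(mass)  # ≈ 0.954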
Priors for burst size and inter-burst time¶
Let’s apply this technique to get our priors for the burst size and inter-burst times. I would expect the time between bursts to be longer than a second, since it takes time for the transcriptional machinery to assemble. I would expect it to be shorter than a few hours, since an organism would need to adapt its gene expression based on environmental changes on that time scale or faster. The time between bursts needs to be in units of RNA lifetimes, and bacterial RNA lifetimes are of order minutes. So, the range of values of \(\alpha\) is \(10^{-2}\) to \(10^2\), leading to a prior of
\begin{align} \log_{10} \alpha \sim \text{Norm}(0, 1). \end{align}
I would expect the burst size to depend on promoter strength and/or strength of transcriptional activators. I could imagine anywhere from a few to a few thousand transcripts per burst, giving a range of \(10^0\) to \(10^4\), and a prior of
\begin{align} \log_{10} b \sim \text{Norm}(2, 1). \end{align}
We then have the following model.
\begin{align} &\log_{10} \alpha \sim \text{Norm}(0, 1),\\[1em] &\log_{10} b \sim \text{Norm}(2, 1),\\[1em] &\beta = 1/b,\\[1em] &n_i \sim \text{NegBinom}(\alpha, \beta) \;\forall i. \end{align}
Sampling the posterior¶
To draw samples out of the posterior, we need to use some new Stan syntax. Here is the Stan code we will use.
data {
  int<lower=0> N;
  int<lower=0> n[N];
}


parameters {
  real log10_alpha;
  real log10_b;
}


transformed parameters {
  real alpha = 10^log10_alpha;
  real b = 10^log10_b;
  real beta_ = 1.0 / b;
}


model {
  // Priors
  log10_alpha ~ normal(0, 1);
  log10_b ~ normal(2, 1);

  // Likelihood
  n ~ neg_binomial(alpha, beta_);
}
The data block contains the counts \(n\) of the mRNA transcripts. There are \(N\) cells that are measured. Most data blocks look like this. There is an integer parameter that specifies the size of the data set, and then the data set is given as an array. Note that we specified a lower bound on the data (as we will do on the parameters) using the <lower=0> syntax.

The parameters block tells us what the parameters of the posterior are. In this case, we wish to sample out of the posterior \(g(\alpha, b \mid \mathbf{n})\), where \(\mathbf{n}\) is the set of transcript counts for the gene. So, the two parameters are \(\alpha\) and \(b\). However, since defining the prior was more easily done in terms of logarithms, we specify \(\log_{10} \alpha\) and \(\log_{10} b\) as the parameters.

The transformed parameters block allows you to do any transformation of the parameters you are sampling for convenience. In this case, Stan’s Negative Binomial distribution is parametrized by \(\beta = 1/b\), so we make the transformation from b to beta_. Notice that I have called this variable beta_ and not beta. I did this because beta is one of Stan’s distributions, and you should avoid naming a variable after a word that is already in the Stan language. The other transformation we need to make involves converting the logarithms to the actual parameter values.

Finally, the model block is where the model is specified. The syntax of the model block is almost identical to that of the hand-written model.
Now that we have specified our model, we can compile it.
[3]:
sm = cmdstanpy.CmdStanModel(stan_file='smfish.stan')
INFO:cmdstanpy:compiling stan program, exe file: /Users/bois/Dropbox/git/bebi103_course/2021/b/content/lessons/09/smfish
INFO:cmdstanpy:compiler options: stanc_options=None, cpp_options=None
INFO:cmdstanpy:compiled model file: /Users/bois/Dropbox/git/bebi103_course/2021/b/content/lessons/09/smfish
With our compiled model, we just need to specify the data and let Stan’s sampler do the work! When using CmdStanPy, the data have to be passed in as a dictionary whose keys correspond to the variable names declared in the data block of the Stan program and whose values are Numpy arrays with the appropriate data type. For this calculation, we will use the data set for the rest gene.
[4]:
# Construct data dict, making sure data are ints
data = dict(N=len(df), n=df["Rest"].values.astype(int))

# Sample using Stan
samples = sm.sample(
    data=data,
    chains=4,
    iter_sampling=1000,
)
# Convert to ArviZ InferenceData instance
samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:finish chain 3
INFO:cmdstanpy:finish chain 4
INFO:cmdstanpy:finish chain 1
Let’s take a quick look at the samples.
[5]:
samples.posterior
[5]:
<xarray.Dataset>
Dimensions:      (chain: 4, draw: 1000)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    log10_alpha  (chain, draw) float64 0.5799 0.5904 0.6064 ... 0.6483 0.7014
    log10_b      (chain, draw) float64 1.294 1.284 1.275 ... 1.213 1.216 1.19
    alpha        (chain, draw) float64 3.801 3.894 4.04 ... 4.821 4.45 5.028
    b            (chain, draw) float64 19.7 19.24 18.86 ... 16.33 16.43 15.48
    beta_        (chain, draw) float64 0.05076 0.05197 ... 0.06087 0.06461
Attributes:
    created_at:                  2021-01-20T23:47:40.574137
    arviz_version:               0.11.0
    inference_library:           cmdstanpy
    inference_library_version:   0.9.67
As we have already seen, the samples are indexed by chain and draw. Parameters represented in the parameters and transformed parameters blocks are reported.
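If you prefer to work with the samples as NumPy arrays rather than a data frame, you can pull them directly out of the posterior group; the array shape reflects the chain and draw indexing. This is just a usage sketch; the variable names match the Stan program above.

# Samples of alpha as a (chains, draws) array
alpha_samples = samples.posterior["alpha"].values
print(alpha_samples.shape)  # (4, 1000)

# Pool chains and draws together into one flat array of samples
alpha_flat = alpha_samples.flatten()
print(alpha_flat.shape)  # (4000,)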
Plots of the samples¶
There are many ways of looking at the samples. In this case, since we have two parameters of interest, the burst frequency and burst size, we can plot the samples as a scatter plot to get the approximate density. For this kind of plot, HoloViews expects a Pandas data frame (or similar object). We can convert an xarray object into a data frame using the to_dataframe() method.
[6]:
df_mcmc = samples.posterior.to_dataframe()
# Take a look
df_mcmc.head()
[6]:
            log10_alpha  log10_b    alpha        b     beta_
chain draw
0     0        0.579914  1.29444  3.80114  19.6987  0.050765
      1        0.590407  1.28422  3.89409  19.2409  0.051973
      2        0.606418  1.27549  4.04034  18.8577  0.053029
      3        0.616111  1.26819  4.13153  18.5433  0.053928
      4        0.597194  1.27039  3.95543  18.6377  0.053655
The indexes from the xarray become indexes for the data frame and the parameter names are the columns. We can now use HoloViews to make our scatter plot. I will use transparency to help visualize the density of points.
[7]:
hv.Points(
    data=df_mcmc,
    kdims=[('alpha', 'α'), 'b']
).opts(
    alpha=0.2,
    size=2,
)
We see very strong correlation between \(\alpha\) and \(b\). This does not necessarily mean that they depend on each other. Rather, it means that our degree of belief about their values depends on both in a correlated way. The measurements we made cannot effectively separate the effects of \(\alpha\) and \(b\) on the transcript counts.
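To put a rough number on that correlation (a quick check of my own, not part of the original analysis), we can compute the sample correlation coefficient between the \(\alpha\) and \(b\) samples.

# Sample correlation between alpha and b, pooling all chains
corr = np.corrcoef(df_mcmc["alpha"].values, df_mcmc["b"].values)[0, 1]
print(corr)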
Marginalizing the posterior¶
We can also plot the marginalized posterior distributions. Remember that the marginalized distributions properly take into account the effects of the other variable, including the strong correlation I just mentioned. To obtain the marginalized distribution, we simply ignore the samples of the parameters we are marginalizing out. It is convenient to look at the marginalized distributions as ECDFs.
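Because marginalization amounts to ignoring the samples of the other parameters, marginal summaries come straight from the columns of the samples. Here is a sketch computing the median and a central 95% credible interval for each parameter; percentile-based intervals are one common choice, and ArviZ provides alternatives such as highest-density intervals via az.hdi.

# Marginal medians and central 95% credible intervals from the samples
for param in ["alpha", "b"]:
    median = np.median(df_mcmc[param].values)
    low, high = np.percentile(df_mcmc[param].values, [2.5, 97.5])
    print(f"{param}: median = {median:.2f}, 95% cred. int. = [{low:.2f}, {high:.2f}]")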
[8]:
plots = [
    iqplot.ecdf(df_mcmc, q=param, plot_height=200, plot_width=250)
    for param in ["alpha", "b"]
]
bokeh.io.show(bokeh.layouts.gridplot(plots, ncols=2))
Alternatively, we can visualize the marginalized posterior PDFs as histograms. Because we have such a large number of samples, binning bias from histograms is less of a concern.
[9]:
plots = [
    iqplot.histogram(df_mcmc, q=param, plot_height=200, plot_width=250, rug=False)
    for param in ["alpha", "b"]
]
bokeh.io.show(bokeh.layouts.gridplot(plots, ncols=2))
Analysis for all genes¶
We can do the same analysis for all genes. To do so, we input the data set for each gene into the sampler and make our plot of the posterior. When we do the sampling, to avoid clutter on our screen, we can disable the logging that CmdStanPy emits by using the bebi103.stan.disable_logging() context manager.
[10]:
plots = []

for gene in df.columns:
    data = dict(N=len(df), n=df[gene].values.astype(int))

    with bebi103.stan.disable_logging():
        samples = sm.sample(data=data, chains=4, iter_sampling=1000)

    samples = az.from_cmdstanpy(posterior=samples)
    df_mcmc = samples.posterior.to_dataframe()

    plots.append(
        hv.Points(data=df_mcmc, kdims=[("alpha", "α"), "b"], label=gene).opts(
            alpha=0.05, axiswise=True, frame_height=200, frame_width=200, size=2
        )
    )

hv.Layout(plots).cols(2)
Note that this single Negative Binomial model probably does not describe the Rex1 data, as can be seen from the ECDF of the measurements. Nonetheless, we can still assume the model is true and compute (i.e., sample) the posterior as if the model were true. That is always what we are doing when we perform parameter estimation. That said, we should seek a more apt model for Rex1.
Display of “best fit”¶
After performing an MCMC calculation to access the posterior, we often want to visualize, for example, the ECDF of the measurements along with ECDFs predicted from the model. We will discuss methods for doing this in coming lessons when we discuss posterior predictive checks. For now, we will plot theoretical CDFs for parameter sets drawn from the posterior. First, we’ll grab the posterior samples again.
[11]:
# Re-obtain samples for rest
data = dict(N=len(df), n=df["Rest"].values.astype(int))

samples = sm.sample(
    data=data,
    chains=4,
    iter_sampling=1000,
)
samples = az.from_cmdstanpy(posterior=samples)
INFO:cmdstanpy:start chain 1
INFO:cmdstanpy:start chain 2
INFO:cmdstanpy:start chain 3
INFO:cmdstanpy:start chain 4
INFO:cmdstanpy:finish chain 1
INFO:cmdstanpy:finish chain 4
INFO:cmdstanpy:finish chain 2
INFO:cmdstanpy:finish chain 3
And now we’ll get a plot of the ECDF.
[12]:
p = iqplot.ecdf(data=df['Rest'].values, x_axis_label='mRNA count')
We’ll generate new CDFs for 100 sets of parameter values drawn from the posterior.
[13]:
# x-values and samples to use in plot
x = np.arange(251)
alphas = samples.posterior['alpha'].values.flatten()[::40]
betas = samples.posterior['beta_'].values.flatten()[::40]

for alpha, beta in zip(alphas, betas):
    y = st.nbinom.cdf(x, alpha, beta/(1+beta))
    x_plot, y_plot = bebi103.viz.cdf_to_staircase(x, y)
    p.line(x_plot, y_plot, line_width=0.5, color='orange', level='underlay')
bokeh.io.show(p)
The measured CDF seems to be within reason given the model. This is not true for the Rex1 gene, however.
[14]:
# Obtain samples for Rex1
data = dict(N=len(df), n=df["Rex1"].values.astype(int))

with bebi103.stan.disable_logging():
    samples = sm.sample(
        data=data,
        chains=4,
        iter_sampling=1000,
    )

samples = az.from_cmdstanpy(posterior=samples)

# Make ECDF
p = iqplot.ecdf(data=df['Rex1'].values, x_axis_label='mRNA count')

# x-values and samples to use in plot
x = np.arange(426)
alphas = samples.posterior['alpha'].values.flatten()[::40]
betas = samples.posterior['beta_'].values.flatten()[::40]

for alpha, beta in zip(alphas, betas):
    y = st.nbinom.cdf(x, alpha, beta/(1+beta))
    x_plot, y_plot = bebi103.viz.cdf_to_staircase(x, y)
    p.line(x_plot, y_plot, line_width=0.5, color='orange', level='underlay')
bokeh.io.show(p)
The lesson here is that getting nicely identifiable parameter estimates does not mean that the model is good. As I mentioned before, we will do more careful assessment of this when we do posterior predictive checks.
[15]:
bebi103.stan.clean_cmdstan()
Computing environment¶
[16]:
%load_ext watermark
%watermark -v -p numpy,scipy,pandas,cmdstanpy,arviz,bokeh,holoviews,iqplot,bebi103,jupyterlab
print("cmdstan :", bebi103.stan.cmdstan_version())
Python implementation: CPython
Python version : 3.8.5
IPython version : 7.19.0
numpy : 1.19.2
scipy : 1.5.2
pandas : 1.2.0
cmdstanpy : 0.9.67
arviz : 0.11.0
bokeh : 2.2.3
holoviews : 1.14.0
iqplot : 0.2.0
bebi103 : 0.1.2
jupyterlab: 2.2.6
cmdstan : 2.25.0