Display of MCMC samples
```python
# Colab setup ------------------
import os, sys, subprocess
if "google.colab" in sys.modules:
    cmd = "pip install --upgrade iqplot colorcet bebi103 arviz cmdstanpy watermark"
    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    import cmdstanpy; cmdstanpy.install_cmdstan()
    data_path = "https://s3.amazonaws.com/bebi103.caltech.edu/data/"
else:
    data_path = "../data/"
# ------------------------------

import numpy as np
import pandas as pd

import cmdstanpy
import arviz as az

import iqplot
import bebi103

import holoviews as hv
hv.extension('bokeh')
bebi103.hv.set_defaults()

import bokeh.io
bokeh.io.output_notebook()
```
In a previous lesson, we learned how to use Stan to sample out of posterior distributions specified by a generative model. In this lecture, we will learn about some techniques to visualize the results. As our first working example, we will consider again the conserved tubulin spindle size model, discussed at length previously, using the data set from Good, et al., Science, 342, 856-860, 2013.
Here is a quick look at the data set.
```python
df = pd.read_csv(os.path.join(data_path, 'good_invitro_droplet_data.csv'), comment='#')

hv.Scatter(
    data=df,
    kdims='Droplet Diameter (um)',
    vdims='Spindle Length (um)',
)
```
Visualization with ArviZ
ArviZ has many handy visualizations of MCMC results, and it is a very promising package for quickly generating plots. It currently still has some drawbacks, however. First, some plots are not rendered properly with Bokeh, though this may well be fixed in future releases. Second, some of the aesthetic choices in the plots would not be my first choice. I will point out other specific drawbacks of the ArviZ visualizations as we go along.
For each of the visualizations below, I show how to construct the visualization using ArviZ, but also using the bebi103 package.
The model and samples
We will briefly repeat the sampling of the model as we did in a previous lesson. As a reminder, the model is
\begin{align} &\log_{10} \phi \sim \text{Norm}(1.5, 0.75),\\[1em] &\gamma \sim \text{Beta}(1.1, 1.1), \\[1em] &\sigma \sim \text{HalfNorm}(10),\\[1em] &\mu_i = \frac{\gamma d_i}{\left(1+(\gamma d_i/\phi)^3\right)^{\frac{1}{3}}}, \\[1em] &l_i \mid d_i, \gamma, \phi, \sigma \sim \text{Norm}(\mu_i, \sigma) \;\forall i. \end{align}
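To get a feel for the theoretical curve \(\mu_i\), here is a minimal NumPy transliteration of that formula, mirroring the ell_theor function in the Stan code that follows. The Python function and the parameter values used with it are illustrative assumptions, not part of the original analysis. The sketch shows the two limits built into the model: for small droplets \(\mu \approx \gamma d\), and for large droplets \(\mu\) saturates at \(\phi\).

```python
import numpy as np


def ell_theor(d, phi, gamma):
    """Theoretical spindle length mu = gamma*d / (1 + (gamma*d/phi)**3)**(1/3)."""
    d = np.asarray(d, dtype=float)
    return gamma * d / (1.0 + (gamma * d / phi) ** 3) ** (1.0 / 3.0)


# Illustrative (assumed) parameter values: phi in µm, gamma dimensionless
phi, gamma = 38.0, 0.86

small = ell_theor(0.1, phi, gamma)    # tiny droplet: close to gamma * d = 0.086
large = ell_theor(1.0e4, phi, gamma)  # huge droplet: close to phi = 38
curve = ell_theor(np.linspace(1.0, 200.0, 50), phi, gamma)
```

Because the curve increases monotonically toward \(\phi\), the spindle length is always below both \(\gamma d\) and \(\phi\).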
We used the following Stan code to implement this model.
```stan
functions {
  real ell_theor(real d, real phi, real gamma_) {
    real denom_ratio = (gamma_ * d / phi)^3;
    return gamma_ * d / (1 + denom_ratio)^(1.0 / 3.0);
  }
}


data {
  int N;
  real d[N];
  real ell[N];
}


parameters {
  real log10_phi;
  real gamma_;
  real<lower=0> sigma;
}


transformed parameters {
  real phi = 10^log10_phi;
}


model {
  log10_phi ~ normal(1.5, 0.75);
  gamma_ ~ beta(1.1, 1.1);
  sigma ~ normal(0.0, 10.0);

  for (i in 1:N) {
    ell[i] ~ normal(ell_theor(d[i], phi, gamma_), sigma);
  }
}
```
Let’s obtain some samples so we can start to look at visualizations.
```python
sm = cmdstanpy.CmdStanModel(stan_file='spindle.stan')

data = dict(
    N=len(df),
    d=df["Droplet Diameter (um)"].values,
    ell=df["Spindle Length (um)"].values,
)

samples = sm.sample(
    data=data,
    chains=4,
    iter_sampling=1000,
)

samples = az.from_cmdstanpy(posterior=samples)
```
```python
samples
```

```
posterior
<xarray.Dataset>
Dimensions:    (chain: 4, draw: 1000)
Coordinates:
  * chain      (chain) int64 0 1 2 3
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
Data variables:
    log10_phi  (chain, draw) float64 1.576 1.575 1.584 ... 1.583 1.585 1.584
    gamma_     (chain, draw) float64 0.8943 0.8976 0.8849 ... 0.8516 0.8477
    sigma      (chain, draw) float64 3.698 3.704 3.85 3.844 ... 3.784 3.81 3.875
    phi        (chain, draw) float64 37.7 37.58 38.36 ... 38.26 38.47 38.36
Attributes:
    created_at:                 2021-01-27T17:56:50.145317
    arviz_version:              0.11.0
    inference_library:          cmdstanpy
    inference_library_version:  0.9.67

sample_stats
<xarray.Dataset>
Dimensions:      (chain: 4, draw: 1000)
Coordinates:
  * chain        (chain) int64 0 1 2 3
  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    lp           (chain, draw) float64 -1.222e+03 -1.223e+03 ... -1.221e+03
    accept_stat  (chain, draw) float64 0.9816 0.9992 0.7384 ... 0.9869 0.9281
    stepsize     (chain, draw) float64 0.3326 0.3326 0.3326 ... 0.4397 0.4397
    treedepth    (chain, draw) int64 2 3 3 1 3 2 3 3 2 2 ... 2 3 3 2 2 4 2 3 1 3
    n_leapfrog   (chain, draw) int64 3 7 7 1 11 3 7 7 7 ... 15 7 7 3 15 3 7 3 7
    diverging    (chain, draw) bool False False False ... False False False
    energy       (chain, draw) float64 1.223e+03 1.223e+03 ... 1.222e+03
Attributes:
    created_at:                 2021-01-27T17:56:50.153855
    arviz_version:              0.11.0
    inference_library:          cmdstanpy
    inference_library_version:  0.9.67
```
Examining traces
The first type of visualization we will explore is useful for diagnosing potential problems with the sampler.
Trace plots
Trace plots are a common way of visualizing the trajectory a sampler takes through parameter space. We plot the value of the parameter versus step number. You can make these plots using bebi103.viz.trace(). By default, the traces are colored according to chain number. In the sampling we just did, we used Stan’s default of four chains.
Trace plots with ArviZ
To make plots with ArviZ, the basic syntax is az.plot_***(samples, backend='bokeh'), where *** is the kind of plot you want. The backend="bokeh" kwarg indicates that you want the plot rendered with Bokeh. Let’s make a trace plot.
```python
az.plot_trace(samples, backend="bokeh");
```
ArviZ generates two plots for each variable. The plots on the right show the trace of the sampler: the x-axis is the iteration number of the sampler and the y-axis is the value of the parameter at that step.
ArviZ also plots the marginalized posterior distribution of each parameter on the left. My issue with these plots is that they use kernel density estimation (KDE) for the posterior plot. KDE involves selecting a bandwidth, which introduces an adjustable parameter that affects how the plotted distribution looks. I prefer simply to plot the ECDFs of the samples to visualize the marginal distributions, as I show below.
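The bandwidth dependence is easy to see with a hand-rolled Gaussian KDE. This is a minimal NumPy sketch on synthetic samples (it is not ArviZ's own KDE implementation): the same samples give visibly different density curves for a narrow and a wide bandwidth, whereas an ECDF of the samples has no such tuning parameter.

```python
import numpy as np


def gaussian_kde(samples, x, bandwidth):
    """Evaluate a Gaussian kernel density estimate of `samples` at points `x`."""
    samples = np.asarray(samples, dtype=float)
    x = np.asarray(x, dtype=float)
    # One Gaussian kernel of width `bandwidth` centered on each sample
    z = (x[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z**2) / np.sqrt(2.0 * np.pi)
    return kernels.sum(axis=1) / (len(samples) * bandwidth)


rng = np.random.default_rng(seed=3252)

# Synthetic stand-in for MCMC samples of a single parameter
samples = rng.normal(38.0, 0.4, size=4000)
x = np.linspace(36.0, 40.0, 401)

narrow = gaussian_kde(samples, x, bandwidth=0.02)
wide = gaussian_kde(samples, x, bandwidth=0.5)

# Same samples, different bandwidths: the estimated densities differ noticeably
print(np.max(np.abs(narrow - wide)))
```

Both curves are valid density estimates (each integrates to about one), yet the narrow one is jagged and the wide one is over-smoothed; the choice of bandwidth, not the data, decides which picture you see.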
Trace plots with bebi103
To make trace plots with the bebi103 module, you can do the following.
```python
bokeh.io.show(
    bebi103.viz.trace(samples)
)
```
Interpretation of trace plots
These trace plots look pretty good; the sampler is bouncing around a central value. Trace plots are very commonly used, which is why I present them here, but they are not particularly useful. This is not just my opinion. Here is what Dan Simpson has to say about trace plots.
Parallel coordinate plots
As Dan Simpson pointed out in his talk I linked to above, plots that visualize diagnostics effectively are better. We will talk more about diagnostics of MCMC in later lessons, but for now, I will display one of the diagnostic plots Dan showed in his talk. Here, we put the parameter names on the x-axis, and we plot the values of the parameters on the y-axis. Each line in the plot represents a single sample of the set of parameters.
Parallel coordinate plots with ArviZ
```python
az.plot_parallel(samples, backend='bokeh');
```
To compare the parameters on a common scale, so that we can better see how they vary with respect to each other, we rescale the samples of each parameter using their minimum and maximum values. We also leave out log10_phi, since it is a deterministic transformation of phi.
```python
az.plot_parallel(
    samples, var_names=["phi", "gamma_", "sigma"], backend="bokeh", norm_method="minmax"
);
```
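The "minmax" normalization used in the call above rescales each parameter's samples to the interval [0, 1] via \((x - \min x)/(\max x - \min x)\), so that parameters on very different scales (φ in tens of µm, γ near one) can share a single vertical axis. Here is a minimal NumPy sketch of the idea on a made-up three-sample array; it illustrates the transformation and is not ArviZ's or bebi103's code.

```python
import numpy as np


def minmax_scale(samples):
    """Rescale each column (parameter) of an (n_samples, n_params) array to [0, 1]."""
    samples = np.asarray(samples, dtype=float)
    lo = samples.min(axis=0)
    hi = samples.max(axis=0)
    return (samples - lo) / (hi - lo)


# Made-up samples: columns are phi, gamma, sigma, on very different scales
samples = np.array([
    [37.7, 0.894, 3.70],
    [38.4, 0.885, 3.85],
    [38.0, 0.864, 3.54],
])

scaled = minmax_scale(samples)
print(scaled)
```

Each row of scaled is one line in a parallel coordinate plot; after rescaling, every column runs from exactly 0 to exactly 1.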
Parallel coordinate plots with bebi103
```python
bokeh.io.show(
    bebi103.viz.parcoord(
        samples,
        transformation="minmax",
        parameters=["phi", "gamma_", "sigma"],
    )
)
```
Interpretation of parallel coordinate plots
Normally, samples with problems are plotted in another color to help diagnose the problems. In this particular set of samples, there were no problems, so everything looks fine.
Interestingly, the “neck” between phi and gamma_ is indicative of anticorrelation: when φ is high, γ is low, and vice versa.
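A single number that summarizes such an anticorrelation is the Pearson correlation coefficient between the φ and γ samples. In this sketch the samples are synthetic, generated with a built-in correlation of -0.8 as a stand-in for (φ, γ), so the value is illustrative only; with real samples you would pass the two columns of samples to np.corrcoef in the same way.

```python
import numpy as np

rng = np.random.default_rng(seed=42)
n = 4000

# Synthetic anticorrelated pair standing in for (phi, gamma) samples.
# By construction, corr(phi, gamma) = -0.8.
z1 = rng.standard_normal(n)
z2 = rng.standard_normal(n)
phi = 38.0 + 0.3 * z1
gamma = 0.86 - 0.015 * (0.8 * z1 + 0.6 * z2)

# Off-diagonal entry of the 2x2 correlation matrix
rho = np.corrcoef(phi, gamma)[0, 1]
print(rho)
```

A negative rho is the numerical counterpart of the “neck” in the parallel coordinate plot.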
Plots of marginalized distributions
While we cannot in general plot a multidimensional distribution, we can plot marginalized distributions, either of one parameter or of two.
Plotting marginalized distributions of one parameter
There are three main options for plotting marginalized distributions of a single parameter.
1. A kernel density estimate of the marginalized posterior.
2. A histogram of the marginalized posterior.
3. An ECDF of the samples out of the posterior for a particular parameter. This approximates the CDF of the marginalized distribution.
I prefer (3), though (2) is a good option as well. For continuous parameters, ArviZ offers only a KDE plot.
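Option (3) requires no tuning at all: the ECDF evaluated at each sorted sample is just the fraction of samples less than or equal to it. Here is a minimal NumPy sketch (the function name ecdf is my own, not iqplot's API), including a check against the known CDF of a Uniform(0, 1) distribution, F(x) = x.

```python
import numpy as np


def ecdf(samples):
    """Return sorted values and ECDF heights (fraction of samples <= each value)."""
    x = np.sort(np.asarray(samples, dtype=float))
    y = np.arange(1, len(x) + 1) / len(x)
    return x, y


x, y = ecdf([3.8, 3.6, 3.7, 3.9])
# x -> [3.6, 3.7, 3.8, 3.9], y -> [0.25, 0.5, 0.75, 1.0]

# With many samples, the ECDF approaches the true CDF (here F(x) = x)
rng = np.random.default_rng(seed=1)
u_x, u_y = ecdf(rng.uniform(size=10_000))
max_gap = np.max(np.abs(u_y - u_x))
print(max_gap)
```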
Plotting marginalized distributions with ArviZ
We already saw that ArviZ can give these plots along with trace plots. We can also plot them directly using az.plot_density(). When using this function, the distributions are truncated at the bounds of the highest posterior density (HPD) interval. For a 95% credible interval, the HPD interval is the shortest interval that contains 95% of the posterior probability. By default, az.plot_density() truncates the plot of the KDE of the PDF at the bounds of the 94% HPD interval.
```python
az.plot_density(samples, backend='bokeh');
```
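Computing an HPD interval from samples is straightforward when the marginal distribution is unimodal: sort the samples and, among all windows containing the required number of them, take the narrowest. The following is a minimal NumPy sketch under that unimodality assumption; hpd_interval is a hypothetical helper, not ArviZ's function. For a skewed distribution such as an exponential, the HPD interval hugs the mode at zero, unlike an equal-tailed interval.

```python
import numpy as np


def hpd_interval(samples, mass=0.94):
    """Shortest interval containing a fraction `mass` of the samples.

    Assumes a unimodal marginal distribution.
    """
    x = np.sort(np.asarray(samples, dtype=float))
    n = len(x)
    n_in = int(np.ceil(mass * n))  # number of samples the interval must contain
    # Widths of all candidate intervals [x[i], x[i + n_in - 1]]
    widths = x[n_in - 1:] - x[: n - n_in + 1]
    i = int(np.argmin(widths))
    return x[i], x[i + n_in - 1]


# Toy example: the narrowest window holding 5 of these 10 values is [0, 4]
lo, hi = hpd_interval([0, 1, 2, 3, 4, 5, 20, 40, 60, 80], mass=0.5)

# Skewed example: for Exponential(1), the 94% HPD is [0, -ln(0.06)], about [0, 2.81]
rng = np.random.default_rng(seed=8)
exp_lo, exp_hi = hpd_interval(rng.exponential(size=100_000), mass=0.94)
print(exp_lo, exp_hi)
```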
Plotting marginalized distributions with iqplot
To get samples out of the marginalized posterior for a single parameter, we simply ignore the values of the other parameters. We can then use iqplot to make histograms or ECDFs.
To do so, we first need to convert the posterior samples to a tidy data frame.
```python
# Convert to data frame
df_mcmc = bebi103.stan.arviz_to_dataframe(samples)

# Take a look
df_mcmc.head()
```
| | log10_phi | gamma_ | sigma | phi | chain__ | draw__ | diverging__ |
|---|---|---|---|---|---|---|---|
| 0 | 1.57633 | 0.894319 | 3.69845 | 37.6987 | 0 | 0 | False |
| 1 | 1.57491 | 0.897639 | 3.70376 | 37.5758 | 0 | 1 | False |
| 2 | 1.58390 | 0.884862 | 3.85035 | 38.3623 | 0 | 2 | False |
| 3 | 1.58353 | 0.884912 | 3.84393 | 38.3288 | 0 | 3 | False |
| 4 | 1.58195 | 0.858369 | 3.58953 | 38.1898 | 0 | 4 | False |
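Conceptually, a conversion like the one above flattens each (chain, draw) array of samples into a column and records which chain and draw each row came from. Here is a minimal pandas sketch of that idea for a made-up posterior with 2 chains of 3 draws; to_tidy is a hypothetical helper for illustration, not bebi103's implementation of arviz_to_dataframe. Once the samples are in this form, marginalizing onto one parameter is just selecting its column.

```python
import numpy as np
import pandas as pd


def to_tidy(posterior, diverging):
    """Flatten {name: (n_chains, n_draws) array} into one row per draw."""
    n_chains, n_draws = np.asarray(diverging).shape
    df = pd.DataFrame(
        {name: np.asarray(arr).ravel() for name, arr in posterior.items()}
    )
    # ravel() is row-major, so the chain index varies slowest and the draw index fastest
    df["chain__"] = np.repeat(np.arange(n_chains), n_draws)
    df["draw__"] = np.tile(np.arange(n_draws), n_chains)
    df["diverging__"] = np.asarray(diverging).ravel()
    return df


# Made-up posterior samples: 2 chains, 3 draws each
posterior = {
    "phi": np.array([[37.7, 37.6, 38.4], [38.0, 38.4, 38.3]]),
    "gamma_": np.array([[0.894, 0.898, 0.885], [0.864, 0.859, 0.863]]),
}
diverging = np.zeros((2, 3), dtype=bool)

df_tidy = to_tidy(posterior, diverging)

# Samples from the marginal posterior of phi: just ignore the other columns
phi_marginal = df_tidy["phi"]
print(df_tidy)
```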
Now, we can use iqplot to make histograms.
```python
hists = [
    iqplot.histogram(
        df_mcmc, q=param, density=True, rug=False, frame_height=150
    )
    for param in ["phi", "gamma_", "sigma"]
]

bokeh.io.show(bokeh.layouts.gridplot(hists, ncols=1))
```
A better option, in my opinion, is to make ECDFs of the samples.
```python
ecdfs = [
    iqplot.ecdf(
        df_mcmc, q=param, style="staircase", frame_height=150
    )
    for param in ["phi", "gamma_", "sigma"]
]

bokeh.io.show(bokeh.layouts.gridplot(ecdfs, ncols=1))
```
Marginal posteriors of two parameters
We can also plot two-dimensional marginal distributions (of most interest here are the parameters \(\phi\) and \(\gamma\)). The simplest way to plot these is simply to plot each sample as a point, possibly with some transparency.
```python
hv.Points(
    df_mcmc,
    kdims=[('phi', 'ϕ [µm]'), ('gamma_', 'γ')],
).opts(
    alpha=0.1,
    size=2,
)
```
Pair plots with ArviZ
We might like to do this for all pairs of parameters. ArviZ enables this to be done conveniently.
```python
az.plot_pair(
    samples, var_names=["phi", "gamma_", "sigma"], backend="bokeh", figsize=(8, 8)
);
```
ArviZ also offers two-dimensional binning of each pair in the form of a hexbin plot. (This does not render properly with the Bokeh backend, at least in the version used here, so we use the Matplotlib backend.)
```python
az.plot_pair(
    samples,
    var_names=["phi", "gamma_", "sigma"],
    kind="hexbin",
    backend="matplotlib",
    figsize=(8, 8),
);
```
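A hexbin plot counts how many samples fall into each two-dimensional bin. The same idea with square bins instead of hexagonal ones can be sketched with NumPy's np.histogram2d; note that this swaps in rectangular binning, and the samples are synthetic (an assumed bivariate normal with correlation -0.8 standing in for φ and γ).

```python
import numpy as np

rng = np.random.default_rng(seed=7)

# Synthetic correlated samples standing in for (phi, gamma);
# covariance chosen so that sd(phi) = 0.3, sd(gamma) = 0.015, correlation = -0.8
cov = [[0.09, -0.0036], [-0.0036, 0.000225]]
samples = rng.multivariate_normal([38.0, 0.86], cov, size=4000)

# Counts of samples in each of 20 x 20 rectangular bins
counts, phi_edges, gamma_edges = np.histogram2d(
    samples[:, 0], samples[:, 1], bins=20
)

# Normalize counts to the fraction of samples in each bin
fractions = counts / counts.sum()
print(counts.max(), fractions.max())
```

As with one-dimensional histograms, the picture depends on the bin size, which is one reason to also look at the raw points.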
Corner plots
It would be convenient to plot all visualizable marginal distributions (that is, the one- and two-dimensional distributions) at once. We can do that with a corner plot, implemented in the bebi103 package.
```python
bokeh.io.show(
    bebi103.viz.corner(
        samples,
        parameters=[("phi", "ϕ [µm]"), ("gamma_", "γ"), ("sigma", "σ [µm]")],
        xtick_label_orientation=np.pi / 4,
    )
)
```
This is a nice way to summarize the posterior and is useful for visualizing how various parameters covary. We can also make a corner plot with the one-parameter marginalized posteriors represented as ECDFs.
```python
bokeh.io.show(
    bebi103.viz.corner(
        samples,
        parameters=[("phi", "ϕ [µm]"), ("gamma_", "γ"), ("sigma", "σ [µm]")],
        plot_ecdf=True,
        xtick_label_orientation=np.pi / 4,
    )
)
```
This last plot is my preferred way of displaying results. All displayable marginal posteriors are plotted and laid out in a logical way.
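A conventional corner-plot arrangement works like this: parameter i labels both row i and column i, the diagonal panels hold the one-parameter marginals, and each panel below the diagonal holds the two-parameter marginal of its column parameter (x-axis) and row parameter (y-axis), so every pair appears exactly once. Here is a small pure-Python sketch of that bookkeeping; corner_layout is a hypothetical helper, not part of bebi103.

```python
from itertools import combinations


def corner_layout(params):
    """Map (row, col) grid positions of a corner plot to the marginal each panel shows."""
    layout = {}
    for i, p in enumerate(params):
        layout[(i, i)] = ("1D", p)  # diagonal: one-parameter marginal
    # Each unordered pair appears once, below the diagonal (row index > column index)
    for (j, p_x), (i, p_y) in combinations(enumerate(params), 2):
        layout[(i, j)] = ("2D", p_x, p_y)  # x-axis: column parameter, y-axis: row parameter
    return layout


layout = corner_layout(["phi", "gamma_", "sigma"])
for (row, col), panel in sorted(layout.items()):
    print(row, col, panel)
```

For n parameters this fills n + n(n - 1)/2 of the n² grid cells; with three parameters, that is 6 of 9 panels, leaving the upper triangle empty.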
```python
bebi103.stan.clean_cmdstan()
```
Computing environment
```python
%load_ext watermark
%watermark -v -p numpy,pandas,cmdstanpy,arviz,bokeh,holoviews,iqplot,bebi103,jupyterlab
print("cmdstan :", bebi103.stan.cmdstan_version())
```

```
CPython 3.8.5
IPython 7.19.0

numpy 1.19.2
pandas 1.2.1
cmdstanpy 0.9.67
arviz 0.11.0
bokeh 2.2.3
holoviews 1.14.0
iqplot 0.2.0
bebi103 0.1.2
jupyterlab 2.2.6
cmdstan : 2.26.0
```