class: center, middle, inverse, title-slide

# Introduction to Bayesian Inference
## Foundational Computational Biology II
### Kieran Campbell
### Lunenfeld Tanenbaum Research Institute & University of Toronto
### 2021-04-23 (updated: 2022-04-04)

---
class: inverse

# What we'll cover

1. Re-introduction to Bayes' rule
2. Markov Chain Monte Carlo: Metropolis-Hastings
3. Variational methods
4. An introduction to Stan

---

# Conventions

* We'll have `\(N\)` samples indexed `\(n = 1, \ldots, N\)`

--

* When an output is multidimensional, we'll refer to output dimensions as `\(p=1,\ldots,P\)`

--

* Our data for sample `\(n\)` is `\(y_{n}\)` (possibly high-dimensional)

--

* Our parameter of interest is `\(\theta\)` (possibly high-dimensional)

---

# Linking science to statistics

Many scientific questions are fundamentally posed as

`$$p(\theta | y_1, \ldots, y_N)$$`

--

## Examples

* `\(y_1, \ldots, y_N\)` expression measurements across replicates, `\(\theta\)` "true" expression

--

* `\(y_1, \ldots, y_N\)` ctDNA measurements across `\(N\)` cases, `\(\theta\)` "true" ctDNA abundance in cases

--

### May also want to compare cases to controls

`\(\theta_1\)` ctDNA abundance in cases, `\(\theta_2\)` ctDNA abundance in controls

`$$p(\theta_1 > \theta_2 | y_1, \ldots, y_N)$$`

---

# Rederiving Bayes' rule

So how do we get `\(p(\theta | y)\)`?

--

Remember that for any R.V.s `\(A,B\)`

`$$p(A,B) = p(A|B)p(B) = p(B|A)p(A)$$`

--

Substituting `\(B=\theta\)` and `\(A=y\)` and dividing, we get **Bayes' rule**:

`$$p(\theta|y) = \frac{p(y|\theta)p(\theta)}{p(y)}$$`

--

What is this strange `\(p(y)\)`?

--

Discrete `\(\theta\)`: `\(p(y) = \sum_{\text{All possible values of } \theta}p(y|\theta)p(\theta)\)`

--

Continuous `\(\theta\)`: `\(p(y) = \int_{\text{Space of } \theta} p(y|\theta)p(\theta) \mathrm{d} \theta\)`

---

# A concrete example (of Bayes' rule)

Let `\(\theta = 1\)` if a patient has diabetes, 0 otherwise

--

Let `\(y\)` be the fasting blood sugar level of a patient in mg/dL

--

Let `\(p(y|\theta=1) = \mathcal{N}(y|135,10)\)` and `\(p(y|\theta=0) = \mathcal{N}(y|80,15)\)`

--

Let `\(p(\theta=1) = 0.073\)`

--

A patient comes in with `\(y=106\)` mg/dL. What is the probability they have diabetes?

--

`$$p(\theta=1|y) = \frac{p(y|\theta=1)p(\theta=1)}{p(y|\theta=1)p(\theta=1) + p(y|\theta=0)p(\theta=0)}$$`

--

`$$=\frac{\mathcal{N}(106|135,10) \times 0.073}{\mathcal{N}(106|135,10) \times 0.073 + \mathcal{N}(106|80,15) \times (1-0.073)}$$`

--


```r
a <- dnorm(106, 135, 10) * 0.073
b <- dnorm(106, 80, 15) * (1 - 0.073)
print(a / (a + b))
```

```
## [1] 0.007854316
```

---

# Credible intervals

Armed with the posterior, we can compute some cool quantities:

> What's the posterior mean of `\(\theta\)`?

`$$\mathbb{E}_{p(\mathbf{\theta} | \mathbf{Y})} [ \mathbf{\theta} ] = \int \mathbf{\theta} p(\mathbf{\theta} | \mathbf{Y}) \mathrm{d} \theta$$`

> What's the probability `\(\theta\)` falls in some region?

`$$p(\theta \in [a,b] | \mathbf{Y}) = \int_a^b p(\mathbf{\theta} | \mathbf{Y}) \mathrm{d} \theta$$`

--

_This is what you've always wanted a confidence interval to be_

.footnote[
Note: everything is always conditioned on your model!
]
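
---

# Posterior summaries: a toy computation

A minimal sketch (not from the original slides) of the two integrals above for a one-dimensional `\(\theta\)`, evaluated on a grid. Everything here is hypothetical: a `\(\mathcal{N}(0,1)\)` prior, a `\(\mathcal{N}(\theta,1)\)` likelihood, and simulated data `y`:


```r
set.seed(1)
y <- rnorm(10, mean = 1)  # hypothetical data

theta_grid <- seq(-4, 4, length.out = 1e4)

# unnormalized log-posterior at each grid point
log_post <- sapply(theta_grid, function(th) {
  sum(dnorm(y, th, 1, log = TRUE)) + dnorm(th, 0, 1, log = TRUE)
})
post <- exp(log_post - max(log_post))
post <- post / sum(post)  # normalize over the grid

sum(theta_grid * post)                            # posterior mean
sum(post[theta_grid >= 0.5 & theta_grid <= 1.5])  # P(0.5 <= theta <= 1.5 | y)
```

Grids work fine in one dimension; the next slide is about why this breaks down as `\(\theta\)` grows.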

---

# Bayesian inference

Easy, right?

--

Wrong... as soon as `\(\mathbf{\theta}\)` becomes moderately high-dimensional, Bayesian inference becomes _hard_. Why?

--

`$$p(\theta|y) = \frac{\color{darkgreen}{p(y|\theta)}\color{darkblue}{p(\theta)}}{\color{darkred}{\int_{\text{Space of } \theta} p(y|\theta)p(\theta) \mathrm{d} \theta}}$$`

`\(\color{darkgreen}{p(y|\theta)}\)` <svg viewBox="0 0 448 512" xmlns="http://www.w3.org/2000/svg" style="height:1em;fill:currentColor;position:relative;display:inline-block;top:.1em;"><path d="M190.5 66.9l22.2-22.2c9.4-9.4 24.6-9.4 33.9 0L441 239c9.4 9.4 9.4 24.6 0 33.9L246.6 467.3c-9.4 9.4-24.6 9.4-33.9 0l-22.2-22.2c-9.5-9.5-9.3-25 .4-34.3L311.4 296H24c-13.3 0-24-10.7-24-24v-32c0-13.3 10.7-24 24-24h287.4L190.9 101.2c-9.8-9.3-10-24.8-.4-34.3z"></path></svg> easy to evaluate

`\(\color{darkblue}{p(\theta)}\)` <svg viewBox="0 0 448 512" xmlns="http://www.w3.org/2000/svg" style="height:1em;fill:currentColor;position:relative;display:inline-block;top:.1em;"><path d="M190.5 66.9l22.2-22.2c9.4-9.4 24.6-9.4 33.9 0L441 239c9.4 9.4 9.4 24.6 0 33.9L246.6 467.3c-9.4 9.4-24.6 9.4-33.9 0l-22.2-22.2c-9.5-9.5-9.3-25 .4-34.3L311.4 296H24c-13.3 0-24-10.7-24-24v-32c0-13.3 10.7-24 24-24h287.4L190.9 101.2c-9.8-9.3-10-24.8-.4-34.3z"></path></svg> easy to evaluate

`\(\color{darkred}{\int_{\text{Space of } \theta} p(y|\theta)p(\theta) \mathrm{d} \theta}\)` <svg viewBox="0 0 448 512" xmlns="http://www.w3.org/2000/svg" style="height:1em;fill:currentColor;position:relative;display:inline-block;top:.1em;"><path d="M190.5 66.9l22.2-22.2c9.4-9.4 24.6-9.4 33.9 0L441 239c9.4 9.4 9.4 24.6 0 33.9L246.6 467.3c-9.4 9.4-24.6 9.4-33.9 0l-22.2-22.2c-9.5-9.5-9.3-25 .4-34.3L311.4 296H24c-13.3 0-24-10.7-24-24v-32c0-13.3 10.7-24 24-24h287.4L190.9 101.2c-9.8-9.3-10-24.8-.4-34.3z"></path></svg> hard to evaluate

---

# The what of Bayesian inference

Most of Bayesian inference comes down to smart ways to approximate `\(p(\theta|y)\)` by only evaluating `\(\color{darkgreen}{p(y|\theta)}\)` and `\(\color{darkblue}{p(\theta)}\)`

--

Two main approaches (we'll cover here):

--

1. _Sampling methods_ (e.g. MCMC): draw samples `\(\theta^s \sim p(\theta|y)\)`

--

2. _Variational methods_: come up with an easy-to-evaluate distribution `\(q(\theta)\)` and make it as similar as possible to `\(p(\theta|y)\)` while only evaluating `\(\color{darkgreen}{p(y|\theta)}\)` and `\(\color{darkblue}{p(\theta)}\)`

---

# Gibbs sampling

Suppose `\(\mathbf{\theta}\)` is high-ish dimensional, i.e.

`$$\mathbf{\theta} = [\theta_1, \theta_2, \ldots, \theta_P ]$$`

--

Gibbs sampling (and many other approaches) approximates the distribution `\(p(\mathbf{\theta} | \mathbf{Y})\)` by returning a set of samples

`$$\mathbf{\theta}^{(1)}, \mathbf{\theta}^{(2)}, \ldots, \mathbf{\theta}^{(S)}$$`

for `\(S\)` samples

--

Remember here each `\(\mathbf{\theta}^{(s)}\)` is `\(P\)`-dimensional

--

Using these samples we can take empirical expectations of interesting quantities, e.g.:

> What's the posterior mean of `\(\mathbf{\theta}\)`?

`$$\mathbb{E}_{p(\mathbf{\theta} | \mathbf{Y})} [ \mathbf{\theta} ] = \int \mathbf{\theta} p(\mathbf{\theta} | \mathbf{Y}) \mathrm{d} \theta \approx \frac{1}{S} \sum_{s=1}^S \mathbf{\theta}^{(s)}$$`

---

# Gibbs sampling (II)

So how does Gibbs sampling sample `\(\mathbf{\theta}^{(s)}\)` from `\(p(\mathbf{\theta} | \mathbf{Y})\)`?

--

Remember `\(\mathbf{\theta}\)` is `\(P\)`-dimensional

--

Relies on knowing the conditionals `\(p(\theta_p | \mathbf{\theta}_{-p}, \mathbf{Y})\)`, where `\(\mathbf{\theta}_{-p}\)` is all elements of `\(\mathbf{\theta}\)` other than the `\(p\)`-th

--

Gibbs sampling proceeds via the following:

1. Initialize `\(\mathbf{\theta}\)`
2. For each `\(p \in 1, \ldots, P\)` sample `\(\theta_p \sim p(\theta_p | \mathbf{\theta}_{-p}, \mathbf{Y})\)`
3. Repeat this whole process a number of times

Then `\(\mathbf{\theta}^{(1)}, \mathbf{\theta}^{(2)}, \ldots, \mathbf{\theta}^{(S)}\)` approximate `\(p(\mathbf{\theta} | \mathbf{Y})\)`

--

<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" style="height:1em;fill:currentColor;position:relative;display:inline-block;top:.1em;"><path d="M256 8C119.033 8 8 119.033 8 256s111.033 248 248 248 248-111.033 248-248S392.967 8 256 8zm0 48c110.532 0 200 89.451 200 200 0 110.532-89.451 200-200 200-110.532 0-200-89.451-200-200 0-110.532 89.451-200 200-200m140.204 130.267l-22.536-22.718c-4.667-4.705-12.265-4.736-16.97-.068L215.346 303.697l-59.792-60.277c-4.667-4.705-12.265-4.736-16.97-.069l-22.719 22.536c-4.705 4.667-4.736 12.265-.068 16.971l90.781 91.516c4.667 4.705 12.265 4.736 16.97.068l172.589-171.204c4.704-4.668 4.734-12.266.067-16.971z"></path></svg> efficient - no samples wasted

<svg viewBox="0 0 512 512" xmlns="http://www.w3.org/2000/svg" style="height:1em;fill:currentColor;position:relative;display:inline-block;top:.1em;"><path d="M256 8C119 8 8 119 8 256s111 248 248 248 248-111 248-248S393 8 256 8zm0 448c-110.5 0-200-89.5-200-200S145.5 56 256 56s200 89.5 200 200-89.5 200-200 200zm101.8-262.2L295.6 256l62.2 62.2c4.7 4.7 4.7 12.3 0 17l-22.6 22.6c-4.7 4.7-12.3 4.7-17 0L256 295.6l-62.2 62.2c-4.7 4.7-12.3 4.7-17 0l-22.6-22.6c-4.7-4.7-4.7-12.3 0-17l62.2-62.2-62.2-62.2c-4.7-4.7-4.7-12.3 0-17l22.6-22.6c4.7-4.7 12.3-4.7 17 0l62.2 62.2 62.2-62.2c4.7-4.7 12.3-4.7 17 0l22.6 22.6c4.7 4.7 4.7 12.3 0 17z"></path></svg> can get stuck in local optima
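
---

# Gibbs sampling in code

A toy sketch (not from the original slides): Gibbs sampling for a hypothetical bivariate normal target with correlation `rho`, where both full conditionals are known normal distributions:


```r
set.seed(1)
rho <- 0.8  # hypothetical correlation
S <- 5000

theta <- matrix(0, nrow = S, ncol = 2)
for (s in 2:S) {
  # theta_1 | theta_2 ~ N(rho * theta_2, 1 - rho^2)
  theta[s, 1] <- rnorm(1, rho * theta[s - 1, 2], sqrt(1 - rho^2))
  # theta_2 | theta_1 ~ N(rho * theta_1, 1 - rho^2)
  theta[s, 2] <- rnorm(1, rho * theta[s, 1], sqrt(1 - rho^2))
}

colMeans(theta)   # both approximately 0
cor(theta)[1, 2]  # approximately rho
```

Each sweep updates one coordinate at a time from its full conditional - exactly the loop on the previous slide.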

---

# Metropolis-Hastings MCMC

Gibbs requires knowing `\(p(\theta_p | \mathbf{\theta}_{-p}, \mathbf{Y})\)`

Metropolis-Hastings only requires `\(p(\mathbf{Y},\theta) = \color{darkgreen}{p(\mathbf{Y}|\theta)}\color{darkblue}{p(\theta)}\)`

--

## Transition kernel

Need a way to hop around parameter space

--

Introduce `\(Q(\theta'|\theta)\)`, interpreted as

> What's the probability I could move to `\(\theta'\)` given I'm at `\(\theta\)`?

Common choice: `\(Q(\theta'|\theta) = \mathcal{N}(\theta' | \theta, \lambda)\)` for some `\(\lambda\)`, i.e. a ball around `\(\theta\)` with variance `\(\lambda\)`

---

# Metropolis-Hastings MCMC (II)

Metropolis-Hastings then works as follows:

At `\(t=1\)` initialize `\(\theta^{(t)}\)`

--

Then repeatedly:

1. Propose a new `\(\theta^{(t+1)} \sim Q(\cdot | \theta^{(t)})\)`

--

2. Compute the _acceptance ratio_
`$$\alpha = \min\left(1, \frac{p( \theta^{(t+1)} | \mathbf{Y}) Q(\theta^{(t)} | \theta^{(t+1)}) }{p( \theta^{(t)} | \mathbf{Y}) Q(\theta^{(t+1)} | \theta^{(t)})}\right)$$`

--

3. Sample some `\(u \sim \text{Unif}(0,1)\)`

--

4. If `\(u < \alpha\)` add `\(\theta^{(t+1)}\)` to my set of samples of `\(\theta\)`, otherwise add `\(\theta^{(t)}\)`
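
---

# Metropolis-Hastings in code

A minimal sketch (not from the original slides) of the four steps above for a one-dimensional `\(\theta\)`, assuming a hypothetical normal likelihood and prior. The random-walk proposal is symmetric, so the `\(Q\)` terms cancel in `\(\alpha\)`:


```r
set.seed(1)
y <- rnorm(20, mean = 2)  # hypothetical data

# log p(Y, theta) = log p(Y | theta) + log p(theta): all MH needs
log_joint <- function(th) {
  sum(dnorm(y, th, 1, log = TRUE)) + dnorm(th, 0, 10, log = TRUE)
}

S <- 5000
lambda <- 0.5  # proposal variance
theta <- numeric(S)
for (t in 1:(S - 1)) {
  proposal <- rnorm(1, theta[t], sqrt(lambda))  # draw from Q(. | theta^(t))
  alpha <- min(1, exp(log_joint(proposal) - log_joint(theta[t])))
  theta[t + 1] <- if (runif(1) < alpha) proposal else theta[t]
}

mean(theta[-(1:1000)])  # posterior mean, discarding early samples
```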

---

# Metropolis-Hastings MCMC (III)

`$$\alpha = \min\left(1, \frac{p( \theta^{(t+1)} | \mathbf{Y}) Q(\theta^{(t)} | \theta^{(t+1)}) }{p( \theta^{(t)} | \mathbf{Y}) Q(\theta^{(t+1)} | \theta^{(t)})}\right)$$`

> Hey! You told us we didn't need to know `\(p( \theta | \mathbf{Y})\)`! What's the point of this?

--

`$$\alpha = \min\left(1, \frac{p( \mathbf{Y} | \theta^{(t+1)}) p(\theta^{(t+1)}) \color{darkred}{/ p(\mathbf{Y})} \, Q(\theta^{(t)} | \theta^{(t+1)}) }{p( \mathbf{Y} | \theta^{(t)}) p(\theta^{(t)}) \color{darkred}{/ p(\mathbf{Y})} \, Q(\theta^{(t+1)} | \theta^{(t)})}\right)$$`

The intractable `\(\color{darkred}{p(\mathbf{Y})}\)` cancels!

--

As you collect more and more samples of `\(\theta\)`, they approximate `\(p( \theta | \mathbf{Y})\)`

--

* Often include some _burn-in_ (discard early samples)
* Often sample multiple chains (distinct initializations)

---

# MCMC visualized

.center[
<img src="bayesian-figs/mh.gif" width=90%>
]

.footnote[
Figure from https://mbjoseph.github.io/posts/2018-12-25-animating-the-metropolis-algorithm/

More reading: https://stephens999.github.io/fiveMinuteStats/MH_intro.html
]

---

# Variational Inference

We know `\(p(\mathbf{\theta} | \mathbf{Y})\)` is hard to evaluate

Let's come up with some simpler distribution `\(q(\theta | \lambda)\)`<sup>1</sup> and adjust `\(\lambda\)` to make

`$$q(\theta | \lambda) \approx p(\mathbf{\theta} | \mathbf{Y})$$`

all without directly evaluating `\(p(\mathbf{\theta} | \mathbf{Y})\)`

You:

> Gasp! Magic!

.footnote[
[1] _Not_ the same as the proposal distribution in MCMC
]

---

# Variational Inference (II)

To make `\(q(\theta | \lambda) \approx p(\mathbf{\theta} | \mathbf{Y})\)` we need some idea of a distance<sup>1</sup> between `\(q\)` and `\(p\)`

.footnote[
[1] Not actually a distance
]

--

Common choice is the Kullback–Leibler (KL) divergence, defined as

`$$\text{KL}(q||p) = \mathbb{E}_{q(\theta|\lambda)} [\log q(\theta|\lambda) - \log p(\theta | \mathbf{Y})]$$`

`\(\text{KL}(q||p)=0\)` when (and only when) `\(p\)` and `\(q\)` are the same distribution

Idea is to adjust `\(\lambda\)` to minimize `\(\text{KL}(q||p)\)`

--

> This still depends on `\(p(\theta | \mathbf{Y})\)`!!! What's the point?

---

# Variational Inference (III)

`$$\begin{aligned} \text{KL}(q||p) & = \mathbb{E}_{q(\theta|\lambda)} [\log q(\theta|\lambda) - \log p(\theta | \mathbf{Y})] \\ & = \mathbb{E}_{q(\theta|\lambda)} [\log q(\theta|\lambda) - \log p(\mathbf{Y} | \theta) - \log p(\theta) + \log p(\mathbf{Y})] \end{aligned}$$`

--

`\(\mathbb{E}_{q(\theta|\lambda)} [ \log p(\mathbf{Y}) ]\)` is a constant with respect to `\(\lambda\)`, so our original minimization problem is the same as minimizing

`$$\mathbb{E}_{q(\theta|\lambda)} [\log q(\theta|\lambda) - \log p(\mathbf{Y} | \theta) - \log p(\theta)]$$`

We only need to evaluate `\(p(\mathbf{Y} | \theta)\)` and `\(p(\theta)\)`!

--

## Mean field approximations

Common choice is `\(q(\theta|\lambda) = \mathcal{N}(\theta | \mu_\lambda, \sigma_\lambda^2)\)`

--

If `\(\theta\)` is `\(P\)`-dimensional, make a _mean field approximation_

`$$q(\theta | \lambda) = \prod_{p=1}^P q_p(\theta_p | \lambda_p)$$`
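
---

# Variational inference in code

A minimal sketch (not from the original slides) of this objective for a one-dimensional `\(\theta\)`, reusing the hypothetical normal model from the MCMC example: fit `\(q(\theta|\lambda) = \mathcal{N}(\theta | \mu_\lambda, \sigma_\lambda^2)\)` by minimizing a Monte Carlo estimate of `\(\mathbb{E}_{q}[\log q - \log p(\mathbf{Y}|\theta) - \log p(\theta)]\)`:


```r
set.seed(1)
y <- rnorm(20, mean = 2)  # hypothetical data
eps <- rnorm(500)         # fixed base draws; theta = mu + sigma * eps ~ q

# Monte Carlo estimate of E_q[log q - log p(Y | theta) - log p(theta)],
# up to a constant in lambda; E_q[log q] = -log(sigma) + const for normal q
objective <- function(lambda) {
  mu <- lambda[1]
  sigma <- exp(lambda[2])  # optimize log(sigma) so sigma stays positive
  theta <- mu + sigma * eps
  log_joint <- sapply(theta, function(th) {
    sum(dnorm(y, th, 1, log = TRUE)) + dnorm(th, 0, 10, log = TRUE)
  })
  -log(sigma) - mean(log_joint)
}

fit <- optim(c(0, 0), objective)
c(mu = fit$par[1], sigma = exp(fit$par[2]))  # fitted q
```

With a one-dimensional `\(\theta\)` there is nothing to factorize; for `\(P\)`-dimensional `\(\theta\)` the mean field version uses one `\((\mu_p, \sigma_p)\)` pair per dimension.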

---

# Variational Bayes: miscellaneous observations

* Often much faster than MCMC across a range of problems

--

* No guarantee (unlike MCMC<sup>1</sup>) you end up with `\(p(\theta | \mathbf{Y})\)`

.footnote[
[1] Asymptotically
]

--

* Much research into making `\(q\)` more flexible

--

* If we have sample-specific `\(z_n\)`, a popular choice is `\(q(z_n | y_n) = \text{neural\_network}(y_n)\)`

--

* Blei, Kucukelbir, and McAuliffe [BKM17] is an excellent overview of the topic

---

# Probabilistic programming languages

## "Normal" programming languages


```python
x = capture_user_input()

if x == "somevalue":
    # do things
```

* Variables are deterministic*
* Executing the program <svg viewBox="0 0 448 512" xmlns="http://www.w3.org/2000/svg" style="height:1em;fill:currentColor;position:relative;display:inline-block;top:.1em;"><path d="M190.5 66.9l22.2-22.2c9.4-9.4 24.6-9.4 33.9 0L441 239c9.4 9.4 9.4 24.6 0 33.9L246.6 467.3c-9.4 9.4-24.6 9.4-33.9 0l-22.2-22.2c-9.5-9.5-9.3-25 .4-34.3L311.4 296H24c-13.3 0-24-10.7-24-24v-32c0-13.3 10.7-24 24-24h287.4L190.9 101.2c-9.8-9.3-10-24.8-.4-34.3z"></path></svg> running an instruction set

--

## Probabilistic programming languages

* Variables are random
* Executing the program <svg viewBox="0 0 448 512" xmlns="http://www.w3.org/2000/svg" style="height:1em;fill:currentColor;position:relative;display:inline-block;top:.1em;"><path d="M190.5 66.9l22.2-22.2c9.4-9.4 24.6-9.4 33.9 0L441 239c9.4 9.4 9.4 24.6 0 33.9L246.6 467.3c-9.4 9.4-24.6 9.4-33.9 0l-22.2-22.2c-9.5-9.5-9.3-25 .4-34.3L311.4 296H24c-13.3 0-24-10.7-24-24v-32c0-13.3 10.7-24 24-24h287.4L190.9 101.2c-9.8-9.3-10-24.8-.4-34.3z"></path></svg> probabilistic inference

--

PPLs essentially automate statistical inference

---

# Introduction to Stan

Stan (https://mc-stan.org/) is a PPL for performing

* MCMC
* Variational Bayes
* (Penalized) maximum likelihood estimation

--

## Practicalities

* Compiled language, looks a bit like C++
* Write your models in a `.stan` file
* Interfaces to R and Python (and Matlab, and Julia, and...)

---

# Anatomy of a Stan program

3 main "blocks" in a Stan file

--

## 1. Data

All the _fixed_ variables (i.e. data) get declared here

--

## 2. Parameters

All the variables you want to perform inference over

--

## 3. Model

How the data + parameters come together in a statistical model

---

# An example

Suppose `\(y_n\)` is tumour volume and `\(x_n\)` is SNP status (0 or 1)

--

`$$y_n \sim \text{LogNormal}(\beta_1 + \beta_2 x_n, \sigma)$$`

and we put some funky priors

`$$\beta_1, \beta_2 \sim \mathcal{N}(0,1)$$`

`$$\sigma \sim \text{Gamma}(2, 0.1)$$`

--

Real object of inference here is `\(\beta_2\)` (the effect of SNP status on tumour volume)

Ideally want to target `\(p(\beta_2 | \mathbf{x}, \mathbf{y})\)`, ignoring everything else!
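
---

# Simulating from the model

A quick sketch (not from the original slides) simulating data from the model above, with hypothetical "true" parameter values; this also provides the `N`, `x`, and `y` handed to Stan shortly:


```r
set.seed(1)
N <- 100
x <- rbinom(N, 1, 0.5)  # SNP status, 0 or 1
beta <- c(0.5, 1.0)     # hypothetical intercept and SNP effect
sigma <- 0.4

y <- rlnorm(N, meanlog = beta[1] + beta[2] * x, sdlog = sigma)
```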

---

# The model


```stan
data {
  int<lower = 1> N;                // number of samples
  real<lower = 0> y[N];            // tumour volume
  int<lower = 0, upper = 1> x[N];  // SNP status
}

parameters {
  real beta[2];
  real<lower = 0> sigma;
}

model {
  beta ~ normal(0, 1);
  sigma ~ gamma(2, 0.1);

  for(n in 1:N) y[n] ~ lognormal(beta[1] + beta[2] * x[n], sigma);
}
```

---

# Inference


```r
library(rstan)

# 'code' holds the Stan program on the previous slide as a string
data <- list(N = N, x = x, y = y)
fit <- stan(model_code = code, data = data)
```

---

# Output of Stan fit


```r
plot(fit)
```

```
## ci_level: 0.8 (80% intervals)
```

```
## outer_level: 0.95 (95% intervals)
```

![](22-bayesian-inference_files/figure-html/unnamed-chunk-7-1.png)<!-- -->

---

# Traceplots


```r
traceplot(fit)
```

![](22-bayesian-inference_files/figure-html/unnamed-chunk-8-1.png)<!-- -->

---

# Joint posterior


```r
pairs(fit, pars = c("beta", "sigma"))
```

![](22-bayesian-inference_files/figure-html/unnamed-chunk-9-1.png)<!-- -->

---

# Many other probabilistic programming languages

1. PyMC3 - https://docs.pymc.io/
2. Edward - http://edwardlib.org/
3. Pyro - https://pyro.ai/
4. NumPyro - http://num.pyro.ai/en/stable/

.center[
<img src="https://en.meming.world/images/en/c/cd/How_Do_You_Do%2C_Fellow_Kids%3F.jpg" width=60%>
]

---

# References

These slides: [camlab.ca/teaching](https://www.camlab.ca/teaching)

Blei, D. M., A. Kucukelbir, and J. D. McAuliffe (2017). "Variational inference: A review for statisticians". In: _Journal of the American Statistical Association_ 112.518, pp. 859-877.