Model Preliminaries

1. Variational Inference is used to approximate posterior densities for Bayesian models as an alternative strategy to MCMC.

2. Given a (joint) model $p(X, Z)$ with latent variables $Z = z_1, \ldots, z_m$ and observations $X = x_1, \ldots, x_n$, we are often interested in computing the posterior $p(Z|X)$, the probability of our latent variables given the data.

3. Often the exact posterior $p(Z|X)$ is intractable to calculate. Thus we aim to find a distribution $q$, from a family of convenient distributions $Q$ over the latent variables, that minimizes the KL divergence to the exact posterior.

$$q^*(Z) = \underset{q(Z)\in Q}{\mathrm{argmin}}\; KL(q(Z)\,||\,p(Z|X))$$

4. The KL divergence is not tractable either, because it requires computing $p(X)$. However, the marginal log likelihood can be decomposed as follows, so maximising the ELBO is equivalent to minimising the KL divergence.

$$\log p(X) = ELBO(q) + KL(q(Z)\,||\,p(Z|X))$$
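To see why this makes the two objectives equivalent, expand the KL term; since $\log p(X)$ is a constant with respect to $q$, increasing the ELBO must decrease the KL:

\begin{align} KL(q(Z)\,||\,p(Z|X)) &= \mathbb{E}_q[\log q(Z)] - \mathbb{E}_q[\log p(Z|X)] \\ &= \mathbb{E}_q[\log q(Z)] - \mathbb{E}_q[\log p(X, Z)] + \log p(X) \\ &= -ELBO(q) + \log p(X) \end{align}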

Coordinate Ascent Variational Inference

• The mean field variational family describes a family of distributions where the latent variables are mutually independent and $q$ factorizes such that

$$q(Z) = \prod_{j=1}^m q_j(z_j)$$

• A general expression for the optimal factor $q_j^*(z_j)$ is given by: $$\log q_j^*(z_j) = \mathbb{E}_{i\neq j}[\log p(X, Z)] + C$$ where the expectation is taken under the other variational factors $q_i$, $i \neq j$.

• Because each $q_j^*(z_j)$ is computed with respect to the other factors $q_i$, $i\neq j$, each update is well defined (it does not depend on $q_j$ itself). But since each update depends on factors that may themselves have just been updated, we must cycle through the factors repeatedly until convergence.

• Note when considering $q_j^* (z_j)$ we only need to consider terms that have some dependence on $z_j$ because other terms can be considered constants that do not affect $q_j^*$.

Pseudocode for CAVI

(from David Blei)

Joint distribution $p(X, Z)$
Variational density $q(Z) = \prod_{j=1}^{m}q_j(z_j)$
Variational factors $q_j(z_j)$
$ELBO = \mathbb{E}[\log p(X, Z)] - \mathbb{E}[\log q(Z)]$

while the ELBO has not converged
for $j \in \{1, \ldots, m\}$
Set $\log q_j(z_j) = \mathbb{E}_{i\neq j}[\log p(X, Z)] + C$
end
compute ELBO
end

Concrete Examples in Gory Detail

Univariate Gaussian

1. Specify the generative model (note that the prior on $\mu$ shares the variance $\sigma^2$ with the likelihood, which is why the joint below contains $p(\mu|\sigma^2)$): \begin{align} X \sim \mathcal{N}(\mu, \sigma^2) \\ \mu \sim \mathcal{N}(\mu_0, \sigma^2) \\ \sigma^2 \sim InvGamma(\alpha_0, \beta_0) \\ \end{align}

2. The joint distribution is: \begin{align} p(X, Z) &= p(X, \mu, \sigma^2) \\ &= p(X | \mu, \sigma^2) p(\mu | \sigma^2) p(\sigma^2) \end{align}

3. The variational density is: \begin{align} q(Z) &= \prod_{j=1}^m q_j(z_j) \\ q(Z) &= q(\mu, \sigma^2) \\ &= q_\mu(\mu) q_{\sigma^2}(\sigma^2) \\ \end{align}

4. ELBO to check convergence. Note the dependence of $\mu$ on $\sigma^2$ under $p$ but their independence under $q$. \begin{align} ELBO(q) &= \mathbb{E}[\log p(X, Z)] - \mathbb{E}[\log q(Z)] \\ &= \mathbb{E}[\log p(X|Z) + \log p(Z)] - \mathbb{E}[\log q(Z)] \\ &= \mathbb{E}[\log p(X|\mu, \sigma^2) + \log p(\mu|\sigma^2) + \log p(\sigma^2)] - \mathbb{E}[\log q(\mu) + \log q(\sigma^2)] \end{align}

5. Optimal factor equation for iterative updates: $$\log q_j^*(z_j) = \mathbb{E}_{i\neq j}[\log p(X, Z)] + C$$

Expectation $\mathbb{E}$ is evaluated wrt $\sigma^2$, and all terms that do not depend on $\mu$ are collapsed into the constant $C$. \begin{align} \log q^*(\mu) &= \mathbb{E}[\log p(X|\mu, \sigma^2) + \log p(\mu|\sigma^2) + \log p(\sigma^2)] + C \\ &= \mathbb{E}[\log p(X|\mu, \sigma^2) + \log p(\mu|\sigma^2)] + C \end{align}

Recall the pdf of a normal distribution: \begin{align} p(x_i|\mu, \sigma^2) &= \frac{1}{\sigma\sqrt{2\pi}} \exp\left(-\frac{(x_i - \mu)^2}{2\sigma^2}\right) \\ p(X|\mu, \sigma^2) &= \left(\frac{1}{\sigma\sqrt{2\pi}}\right)^n \exp\left(-\frac{\sum_{i=1}^n(x_i-\mu)^2}{2\sigma^2}\right) \\ \log p(X|\mu, \sigma^2) &= n\log\left(\frac{1}{\sigma\sqrt{2\pi}}\right) - \frac{1}{2\sigma^2}\sum_{i=1}^n(x_i-\mu)^2 \\ &= -\frac{1}{2\sigma^2}\sum_{i=1}^n(x_i-\mu)^2 + C \end{align}

Expanding $\log p(\mu|\sigma^2)$ in a similar fashion, the expression for $\log q^*(\mu)$ becomes:

\begin{align} \log q^*(\mu) &= \mathbb{E}\left[-\frac{1}{2\sigma^2}(\mu-\mu_0)^2 - \frac{1}{2\sigma^2}\sum_{i=1}^n(x_i-\mu)^2\right] + C \\ &= -\mathbb{E}\left[\frac{1}{2\sigma^2}\left((\mu-\mu_0)^2 + \sum_{i=1}^n(x_i-\mu)^2\right)\right] + C \end{align}

Expanding terms and after completing the square around $\mu$, we find that a Gaussian pops out for $q^*(\mu)$:
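Writing out the quadratic in $\mu$ explicitly:

\begin{align} (\mu-\mu_0)^2 + \sum_{i=1}^n(x_i-\mu)^2 &= (n+1)\mu^2 - 2(n\bar{x} + \mu_0)\mu + C' \\ &= (n+1)\left(\mu - \frac{n\bar{x} + \mu_0}{n+1}\right)^2 + C'' \end{align}

where $C'$ and $C''$ collect terms that do not involve $\mu$.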

\begin{align} \log q^*(\mu) &= -\frac{n+1}{2}\,\mathbb{E}\left[\frac{1}{\sigma^2}\right]\left(\mu - \frac{n\bar{x} + \mu_0}{n+1}\right)^2 + C \\ \mu &\sim \mathcal{N}(\mu_n, \sigma_n^2) \\ \mu_n &= \frac{n\bar{x} + \mu_0}{n+1} \\ \sigma_n^2 &= \frac{1}{(n+1)\,\mathbb{E}[1/\sigma^2]} \end{align}

where $\mathbb{E}[1/\sigma^2]$ is taken under the current $q(\sigma^2)$.

Now for $\log q^*(\sigma^2)$, the expectation $\mathbb{E}$ is evaluated wrt $\mu$, and all terms that do not depend on $\sigma^2$ are collapsed into the constant $C$.

Working from the pdf of the inverse gamma distribution and collapsing terms that don't depend on $\sigma^2$:

\begin{align} p(\sigma^2; \alpha_0, \beta_0) &= \frac{\beta_0^{\alpha_0}}{\Gamma(\alpha_0)} (\sigma^2)^{-\alpha_0-1} \exp\left(-\frac{\beta_0}{\sigma^2}\right) \\ \log p(\sigma^2; \alpha_0, \beta_0) &= \alpha_0 \log(\beta_0) - \log\Gamma(\alpha_0) + (-\alpha_0 - 1)\log(\sigma^2) - \frac{\beta_0}{\sigma^2} \\ &= (-\alpha_0 - 1)\log(\sigma^2) - \frac{\beta_0}{\sigma^2} + C \end{align}

Putting $\log q^*(\sigma^2)$ together, we get:

\begin{align} \log q^*(\sigma^2) &= \mathbb{E}[\log p(X|\mu, \sigma^2) + \log p(\mu|\sigma^2) + \log p(\sigma^2)] + C \\ &= (-\alpha_0 - 1)\log(\sigma^2) - \frac{\beta_0}{\sigma^2} - \frac{n}{2}\log(2\pi\sigma^2) - \frac{1}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\mathbb{E}\left[(\mu-\mu_0)^2 + \sum_{i=1}^n(x_i-\mu)^2\right] + C \end{align}

An inverse gamma pops out for $q^*(\sigma^2)$, where all expectations $\mathbb{E}$ are evaluated wrt $\mu$:

\begin{align} \log q^*(\sigma^2) &= -\left(\alpha_0 + \frac{n+1}{2} + 1\right)\log(\sigma^2) - \frac{1}{\sigma^2}\left(\beta_0 + \frac{1}{2}\mathbb{E}\left[(\mu-\mu_0)^2 + \sum_{i=1}^n(x_i-\mu)^2\right]\right) + C \\ \sigma^2 &\sim InvGamma(\alpha_n, \beta_n) \\ \alpha_n &= \alpha_0 + \frac{n+1}{2} \\ \beta_n &= \beta_0 + \frac{1}{2}\mathbb{E}\left[(\mu-\mu_0)^2 + \sum_{i=1}^n(x_i-\mu)^2\right] \\ &= \beta_0 + \frac{1}{2}\left(\sum_{i=1}^n x_i^2 + \mu_0^2 - 2\mu_n(n\bar{x} + \mu_0) + (n+1)(\mu_n^2 + \sigma_n^2)\right) \end{align}

using $\mathbb{E}[\mu] = \mu_n$ and $\mathbb{E}[\mu^2] = \mu_n^2 + \sigma_n^2$ under $q(\mu)$.

Note the expectations under $q$ for the normal and inverse gamma distributions: \begin{align} \mathbb{E}[\mu] &= \mu_n \\ \mathbb{E}[\sigma^2] &= \frac{\beta_n}{\alpha_n - 1} \\ \mathbb{E}[1/\sigma^2] &= \frac{\alpha_n}{\beta_n} \end{align}

The algorithm converges after 3 iterations in this simple example. We can see the coordinate ascent structure directly: first $q(\mu)$ is fitted to the data, then $q(\sigma^2)$.
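The coordinate updates derived above can be sketched as a short NumPy program (a minimal illustration rather than production code; the function name, defaults, and toy data are my own):

```python
import numpy as np

def cavi_univariate_gaussian(x, mu0=0.0, alpha0=1.0, beta0=1.0, n_iter=20):
    """CAVI for X ~ N(mu, sigma^2) with priors mu ~ N(mu0, sigma^2)
    and sigma^2 ~ InvGamma(alpha0, beta0)."""
    n = len(x)
    xbar = x.mean()
    sum_x2 = np.sum(x ** 2)

    # mu_n and alpha_n are available in closed form and never change
    mu_n = (n * xbar + mu0) / (n + 1)
    alpha_n = alpha0 + (n + 1) / 2
    beta_n = beta0  # initial guess; beta_n and s2_n are iterated

    for _ in range(n_iter):
        # update q(mu) variance using E[1/sigma^2] = alpha_n / beta_n
        s2_n = 1.0 / ((n + 1) * alpha_n / beta_n)
        # update q(sigma^2) rate using E[mu] = mu_n, E[mu^2] = mu_n^2 + s2_n
        beta_n = beta0 + 0.5 * (sum_x2 + mu0 ** 2
                                - 2 * mu_n * (n * xbar + mu0)
                                + (n + 1) * (mu_n ** 2 + s2_n))
    return mu_n, s2_n, alpha_n, beta_n

rng = np.random.default_rng(0)
x = rng.normal(5.0, 2.0, size=200)  # toy data: true mu=5, sigma^2=4
mu_n, s2_n, alpha_n, beta_n = cavi_univariate_gaussian(x)
# posterior mean of sigma^2 under q is beta_n / (alpha_n - 1)
```

With 200 points the variational mean $\mu_n$ sits very close to the sample mean, and $\beta_n/(\alpha_n - 1)$ recovers a variance near the true value of 4.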

In general, Ghahramani & Beal showed that for conjugate-exponential models (where the complete conditionals lie in the exponential family) we can derive exact coordinate updates for the latent variables (which we have done here by brute-force algebra). Also note that because we have estimates of $\mu_n$, $\sigma_n^2$, $\alpha_n$ and $\beta_n$, we have an entire (approximate) posterior distribution over the parameters, not just point estimates of $\mu$.

Multivariate Gaussian Mixture Model

1. Specify the generative model: \begin{align} x_i &\sim \mathcal{N}(\mu_{z_i}, \Sigma_{z_i}) \\ z_i &\sim Multinomial(\boldsymbol{\phi}) \\ \mu_k &\sim \mathcal{N}(\mu_0, \Sigma_k) \\ \Sigma_k &\sim \mathcal{W}^{-1}(\Sigma_0, \nu_0) \\ \boldsymbol{\phi} &\sim Dir(\alpha_0) \end{align}

2. The joint distribution, where $Z'$ denotes all latent variables, is: \begin{align} p(X, Z') &= p(X, Z, \boldsymbol{\phi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}) \\ &= p(X|Z, \boldsymbol{\mu}, \boldsymbol{\Sigma})p(Z|\boldsymbol{\phi})p(\boldsymbol{\phi})p(\boldsymbol{\mu}|\boldsymbol{\Sigma})p(\boldsymbol{\Sigma}) \end{align}

3. The variational density is: \begin{align} q(Z') &= \prod_{j=1}^m q_j(z_j) \\ &= q(Z)q(\boldsymbol{\phi})q(\boldsymbol{\mu})q(\boldsymbol{\Sigma}) \end{align}

4. ELBO to check convergence, now over all latent variables $Z' = (Z, \boldsymbol{\phi}, \boldsymbol{\mu}, \boldsymbol{\Sigma})$: \begin{align} ELBO(q) &= \mathbb{E}[\log p(X, Z')] - \mathbb{E}[\log q(Z')] \\ &= \mathbb{E}[\log p(X|Z, \boldsymbol{\mu}, \boldsymbol{\Sigma}) + \log p(Z|\boldsymbol{\phi}) + \log p(\boldsymbol{\phi}) + \log p(\boldsymbol{\mu}|\boldsymbol{\Sigma}) + \log p(\boldsymbol{\Sigma})] - \mathbb{E}[\log q(Z) + \log q(\boldsymbol{\phi}) + \log q(\boldsymbol{\mu}) + \log q(\boldsymbol{\Sigma})] \end{align}

$$\log q_j(z_j) = \mathbb{E}_{i\neq j}[\log p(X, Z')] + C$$

For $\log q^*(Z)$, the expectation $\mathbb{E}$ is taken wrt $\boldsymbol{\phi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}$.

$$\log q^*(Z) = \mathbb{E}_{\boldsymbol{\phi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}}[\log p(X|Z, \boldsymbol{\mu}, \boldsymbol{\Sigma}) + \log p(Z|\boldsymbol{\phi})] + C$$

\begin{align} \log q^*(z_{ik}) &= \mathbb{E}_{\boldsymbol{\phi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}}[\log p(x_i|\mu_k, \Sigma_k) + \log \phi_k] + C \end{align}

Recall the pdf and log pdf of the multivariate normal in $d$ dimensions: \begin{align} p(x_i|\mu_k, \Sigma_k) &= (2\pi)^{-d/2} \det(\Sigma_k)^{-1/2} \exp\left(-\frac{1}{2}(x_i - \mu_k)^T\Sigma_k^{-1}(x_i - \mu_k)\right) \\ \log p(x_i|\mu_k, \Sigma_k) &= -\frac{d}{2}\log(2\pi) - \frac{1}{2}\log\det(\Sigma_k) - \frac{1}{2}(x_i - \mu_k)^T\Sigma_k^{-1}(x_i - \mu_k) \end{align}
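As a sanity check, this log pdf can be evaluated by hand and compared against `scipy.stats.multivariate_normal` (a quick sketch; the 2-d example values are arbitrary):

```python
import numpy as np
from scipy.stats import multivariate_normal

def mvn_logpdf(x, mu, Sigma):
    """Log density of N(mu, Sigma) at x, written out term by term."""
    d = len(mu)
    diff = x - mu
    _, logdet = np.linalg.slogdet(Sigma)        # numerically stable log det
    quad = diff @ np.linalg.solve(Sigma, diff)  # (x-mu)^T Sigma^{-1} (x-mu)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + quad)

x = np.array([1.0, -0.5])
mu = np.array([0.0, 0.0])
Sigma = np.array([[2.0, 0.3], [0.3, 1.0]])
manual = mvn_logpdf(x, mu, Sigma)
reference = multivariate_normal(mean=mu, cov=Sigma).logpdf(x)
```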

Absorbing terms that do not depend on $z_{ik}$ into the additive constant:

$$\log q^*(z_{ik}) = \mathbb{E}_{\boldsymbol{\phi}, \boldsymbol{\mu}, \boldsymbol{\Sigma}}\left[-\frac{1}{2}\log\det(\Sigma_k) - \frac{1}{2}(x_i - \mu_k)^T\Sigma_k^{-1}(x_i - \mu_k) + \log(\phi_k)\right] + C$$

Since $q(z_i)$ is a distribution over the $K$ components, it must be normalised. Writing $\phi_{ik} = q(z_{ik} = 1)$ for the responsibilities, the update is the softmax $$\phi_{ik} = \frac{\exp(\log q^*(z_{ik}))}{\sum_{j}\exp(\log q^*(z_{ij}))}$$
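In code this normalisation is usually done with the log-sum-exp trick for numerical stability (a hypothetical sketch; `log_q` holds the unnormalised $\log q^*(z_{ik})$ values as an $(n, K)$ array):

```python
import numpy as np

def responsibilities(log_q):
    """Normalise unnormalised log-probabilities of shape (n, K) into
    responsibilities phi_{ik} using the log-sum-exp trick."""
    m = log_q.max(axis=1, keepdims=True)  # subtract row max to avoid overflow
    p = np.exp(log_q - m)
    return p / p.sum(axis=1, keepdims=True)

log_q = np.array([[1000.0, 1001.0], [-3.0, -3.0]])
phi = responsibilities(log_q)  # rows sum to 1; no overflow despite huge values
```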

For $q^*(\phi)$, expectation $\mathbb{E}$ is taken wrt $Z, \boldsymbol{\mu}, \boldsymbol{\Sigma}$.

$$\log q^*(\boldsymbol{\phi}) = \mathbb{E}_{Z, \boldsymbol{\mu}, \boldsymbol{\Sigma}}[\log p(Z|\boldsymbol{\phi}) + \log p(\boldsymbol{\phi})] + C$$

The multinomial $p(Z|\boldsymbol{\phi})$, with $z_{ik} \in \{0, 1\}$ indicating the component assignment, is: \begin{align} p(Z|\boldsymbol{\phi}) &= \prod_{k=1}^K \prod_{i=1}^n \phi_k^{z_{ik}} \\ \log p(Z|\boldsymbol{\phi}) &= \sum_{k=1}^K \sum_{i=1}^n z_{ik} \log(\phi_k) \end{align}

Recall the Dirichlet distribution $\boldsymbol{\phi} \sim Dir(\alpha_0)$, where $B$ is the normalizing constant:

\begin{align} p(\boldsymbol{\phi}) &= \frac{1}{B(\alpha_0)} \prod_{k=1}^K \phi_k^{\alpha_0-1} \\ \log p(\boldsymbol{\phi}) &= -\log(B(\alpha_0)) + (\alpha_0-1)\sum_{k=1}^K\log(\phi_k) \end{align}

Putting things together, and using $\mathbb{E}[z_{ik}] = \phi_{ik}$ and $N_k = \sum_{i=1}^n \phi_{ik}$:

\begin{align} \log q^*(\boldsymbol{\phi}) &= (\alpha_0 - 1)\sum_{k=1}^K \log(\phi_k) + \sum_{k=1}^K \sum_{i=1}^n \phi_{ik}\log(\phi_k) + C \\ &= (\alpha_0 - 1)\sum_{k=1}^K \log(\phi_k) + \sum_{k=1}^K N_k\log(\phi_k) + C \\ &= \sum_{k=1}^K (\alpha_0 + N_k - 1)\log(\phi_k) + C \end{align}

which is a Dirichlet distribution, $q^*(\boldsymbol{\phi}) = Dir(\boldsymbol{\phi}|\boldsymbol{\alpha}_n)$, with each component $\alpha_{nk} = \alpha_0 + N_k$, where $N_k = \sum_{i=1}^n \phi_{ik}$ is the expected number of points assigned to component $k$.
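Given the responsibilities, this update is one line of code (a hypothetical sketch continuing the notation above, with `phi` an $(n, K)$ array of responsibilities):

```python
import numpy as np

def dirichlet_update(phi, alpha0):
    """Update the Dirichlet parameters: alpha_nk = alpha0 + N_k,
    where N_k = sum_i phi_ik is the expected count for component k."""
    N_k = phi.sum(axis=0)
    return alpha0 + N_k

phi = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])  # toy responsibilities
alpha_n = dirichlet_update(phi, alpha0=1.0)  # -> array([2.6, 2.4])
```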

References

Variational Inference: A Review for Statisticians (Blei, Kucukelbir & McAuliffe)
Graphical Models and Variational Methods (Ghahramani & Beal)
Bishop, Pattern Recognition and Machine Learning, Approximate Inference (pp. 461-473)