Mixture Models and the EM Algorithm#

Unit II was all about supervised learning problems. Both encoding and decoding involve predicting one signal from another: neural firing rates given stimuli, or behavioral outputs given neural activity. Now we turn our attention to unsupervised learning problems. Given measurements of neural activity or behavior, we aim to find latent variables — clusters, factors, etc. — that can explain the data.

Unsupervised learning has many applications in neuroscience. First and foremost, it’s a means of dimensionality reduction. Latent variables offer a compressed representation of high-dimensional neural spike trains or behavioral videos. Low dimensional representations can aid in visualization as well as hypothesis generation.

We encountered a few examples of unsupervised learning in Unit I. Spike sorting and calcium deconvolution were both unsupervised problems: the spike times and templates were a low dimensional representation of a high dimensional voltage trace or video.

Unsupervised learning is all about constraints. We need to set some boundaries on what types of latent variables we are looking for. I like to think of constraints in terms of the five D’s

  • Dimensionality: how many latent clusters, factors, etc.?

  • Domain: are the latent variables discrete, continuous, bounded, sparse, etc.?

  • Dynamics: how do the latent variables change over time?

  • Dependencies: how do the latent variables relate to the observed data?

  • Distribution: do we have prior knowledge about the probability distribution of the latent variables?

These five D's are certainly not an exhaustive list of possible constraints, but they cover many bases. Thinking in terms of constraints allows us to organize many commonly used models.

Gaussian mixture models#

Let’s start with a simple and canonical example: the Gaussian mixture model (GMM).

Let \(\mathbf{x}_t \in \mathbb{R}^D\) denote the \(t\)-th observation. Let \(z_t \in \{1, \ldots, K\}\) denote the discrete latent state (aka cluster assignment) of that data point.

Assume each observation is independently drawn according to the following model:

\[\begin{split} \begin{align*} z_t &\sim \mathrm{Cat}(\boldsymbol{\pi}) \\ \mathbf{x}_t &\sim \mathcal{N}(\boldsymbol{\mu}_{z_t}, \boldsymbol{\Sigma}_{z_t}) \end{align*} \end{split}\]

The parameters include the prior probabilities of the cluster assignments, \(\boldsymbol{\pi} \in \Delta_K\), and the conditional mean and covariance of each cluster, \(\boldsymbol{\mu}_k\) and \(\boldsymbol{\Sigma}_k\). Let \(\boldsymbol{\theta} = \{\boldsymbol{\pi}, \{\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\}_{k=1}^K\}\) denote the set of all parameters.

The joint probability is,

\[ \begin{align*} p(\mathbf{z}_{1:T}, \mathbf{x}_{1:T}; \boldsymbol{\theta}) &= \prod_{t=1}^T \mathrm{Cat}(z_t; \boldsymbol{\pi}) \, \mathcal{N}(\mathbf{x}_t; \boldsymbol{\mu}_{z_t}, \boldsymbol{\Sigma}_{z_t}) \end{align*} \]
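To make the generative model concrete, here is a minimal NumPy sketch that samples data from a GMM. The sizes \(K\), \(D\), \(T\) and the particular cluster parameters are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Illustrative sizes and parameters: K clusters in D dimensions, T observations.
K, D, T = 3, 2, 500
pi = np.array([0.5, 0.3, 0.2])                          # prior cluster probabilities
mus = np.array([[0.0, 0.0], [4.0, 4.0], [-4.0, 4.0]])   # cluster means
Sigmas = np.stack([np.eye(D) for _ in range(K)])        # cluster covariances

# z_t ~ Cat(pi), then x_t ~ N(mu_{z_t}, Sigma_{z_t}).
zs = rng.choice(K, size=T, p=pi)
xs = np.stack([rng.multivariate_normal(mus[z], Sigmas[z]) for z in zs])
```

We will reuse this synthetic dataset in the sketches below.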

MAP estimation and K-means#

In Chapter 2, we fit this type of model using MAP estimation,

\[ \mathbf{z}_{1:T}^\star, \boldsymbol{\theta}^\star = \mathrm{arg}\,\mathrm{max}_{\mathbf{z}_{1:T}, \boldsymbol{\theta}} \; p(\mathbf{z}_{1:T}, \mathbf{x}_{1:T}; \boldsymbol{\theta}) \]

We solved for the MAP estimate by coordinate ascent. For the GMM, the coordinate updates are,

  1. For each data point \(t=1,\ldots, T\):

    \[ z_t = \mathrm{arg}\,\max_k \; \pi_k \, \mathcal{N}(\mathbf{x}_t \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \]
  2. For each cluster \(k=1,\ldots, K\):

    \[\begin{split} \begin{align*} T_k &= \sum_{t=1}^T \mathbb{I}[z_t = k] \\ \boldsymbol{\mu}_k &= \frac{1}{T_k} \sum_{t=1}^T \mathbb{I}[z_t = k] \, \mathbf{x}_t \\ \boldsymbol{\Sigma}_k &= \frac{1}{T_k} \sum_{t=1}^T \mathbb{I}[z_t = k] \, (\mathbf{x}_t - \boldsymbol{\mu}_k) (\mathbf{x}_t - \boldsymbol{\mu}_k)^\top \end{align*} \end{split}\]

In words, we set each cluster's mean and covariance equal to the sample mean and covariance of the data points assigned to that cluster. If we were to fix \(\boldsymbol{\Sigma}_k = \mathbf{I}\) for all clusters, this algorithm would be equivalent to K-means clustering.
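Here is one way this coordinate ascent might look in code; a rough sketch rather than a reference implementation. The function name, the initialization scheme, and the small ridge added to the covariances are our own choices, and we also update \(\boldsymbol{\pi}\) via \(\pi_k = T_k / T\), which follows from maximizing over \(\boldsymbol{\theta}\).

```python
import numpy as np
from scipy.stats import multivariate_normal

def map_coordinate_ascent(xs, K, n_iters=50, seed=0):
    """Hard-assignment coordinate ascent for a GMM (K-means-like)."""
    rng = np.random.default_rng(seed)
    T, D = xs.shape
    pi = np.ones(K) / K
    mus = xs[rng.choice(T, size=K, replace=False)].copy()   # init means at random data points
    Sigmas = np.stack([np.cov(xs.T) for _ in range(K)])

    for _ in range(n_iters):
        # Step 1: assign each point to its most probable cluster.
        log_probs = np.stack(
            [np.log(pi[k]) + multivariate_normal.logpdf(xs, mus[k], Sigmas[k])
             for k in range(K)], axis=1)                     # shape (T, K)
        zs = np.argmax(log_probs, axis=1)

        # Step 2: refit each cluster to the points assigned to it.
        # (Empty clusters simply keep their previous parameters in this sketch.)
        counts = np.bincount(zs, minlength=K)
        pi = counts / T
        for k in range(K):
            if counts[k] > 0:
                xk = xs[zs == k]
                mus[k] = xk.mean(axis=0)
                Sigmas[k] = np.cov(xk.T, bias=True) + 1e-6 * np.eye(D)
    return zs, pi, mus, Sigmas
```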

Expectation-Maximization (EM)#

MAP estimation gives us a point estimate of the latent variables and parameters. However, point estimates can lead to an overly optimistic view of the model. Specifically, MAP estimation finds the single best assignment of the latent variables, which may not reflect how well samples from the model match the observed data.

Ideally, we would like to find parameters that maximize the marginal likelihood, aka the model evidence,

\[ p(\mathbf{x}_{1:T}; \boldsymbol{\theta}) = \int p(\mathbf{x}_{1:T}, \mathbf{z}_{1:T}; \boldsymbol{\theta}) \, \mathrm{d} z_1 \cdots \mathrm{d} z_T. \]

Note

For discrete latent variable models, the integral becomes a sum.
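For the GMM, that sum over cluster assignments can be computed stably in log space with a log-sum-exp. A small sketch, reusing the sampled data and parameters from the snippet above:

```python
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

# log p(x_{1:T}; theta) = sum_t log sum_k pi_k N(x_t; mu_k, Sigma_k).
log_joint = np.stack(
    [np.log(pi[k]) + multivariate_normal.logpdf(xs, mus[k], Sigmas[k])
     for k in range(K)], axis=1)                  # (T, K): log p(x_t, z_t = k; theta)
log_marginal_lik = logsumexp(log_joint, axis=1).sum()
```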

Once we have those parameters, we can compute the posterior distribution over latent variables,

\[ p(\mathbf{z}_{1:T} \mid \mathbf{x}_{1:T}; \boldsymbol{\theta}). \]

The expectation-maximization (EM) algorithm does exactly that: it finds a local maximum of the marginal likelihood via an iterative algorithm very similar to the MAP estimation algorithm above. The key difference is that instead of using hard assignments of data points to the most likely cluster, we use soft assignments of data points based on their posterior probabilities.

  1. For each data point \(t=1,\ldots, T\) and cluster \(k=1,\ldots, K\), set:

    \[ q_{t,k} = p(z_t = k \mid \mathbf{x}_t; \boldsymbol{\theta}) = \frac{\pi_k \, \mathcal{N}(\mathbf{x}_t \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)}{\sum_{j=1}^K \pi_j \, \mathcal{N}(\mathbf{x}_t \mid \boldsymbol{\mu}_j, \boldsymbol{\Sigma}_j)} \]
  2. For each cluster \(k=1,\ldots, K\), update the parameters as follows:

    \[\begin{split} \begin{align*} T_k &= \sum_{t=1}^T q_{t,k} \\ \boldsymbol{\mu}_k &= \frac{1}{T_k} \sum_{t=1}^T q_{t,k} \, \mathbf{x}_t \\ \boldsymbol{\Sigma}_k &= \frac{1}{T_k} \sum_{t=1}^T q_{t,k} \, (\mathbf{x}_t - \boldsymbol{\mu}_k) (\mathbf{x}_t - \boldsymbol{\mu}_k)^\top \end{align*} \end{split}\]

The posterior probabilities \(q_{t,k}\) are sometimes called the responsibilities. When we update the parameters, we take a weighted average of all data points based on these responsibilities.
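These two steps translate almost line for line into code. A sketch, again with our own choices for the function name, the log-space E-step, and the small covariance ridge; we also refresh \(\boldsymbol{\pi}\) via \(\pi_k = T_k / T\), which follows from the same derivation.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def em_step(xs, pi, mus, Sigmas):
    """One EM iteration for a GMM: soft responsibilities, then weighted updates."""
    T, D = xs.shape
    K = len(pi)

    # E-step: responsibilities q_{t,k}, computed in log space for numerical stability.
    log_joint = np.stack(
        [np.log(pi[k]) + multivariate_normal.logpdf(xs, mus[k], Sigmas[k])
         for k in range(K)], axis=1)                       # (T, K)
    log_marginal = logsumexp(log_joint, axis=1)            # log p(x_t; theta)
    q = np.exp(log_joint - log_marginal[:, None])          # rows sum to one

    # M-step: responsibility-weighted means and covariances.
    Tk = q.sum(axis=0)                                     # effective counts T_k
    pi = Tk / T
    mus = (q.T @ xs) / Tk[:, None]
    Sigmas = np.stack(
        [(q[:, k, None] * (xs - mus[k])).T @ (xs - mus[k]) / Tk[k] + 1e-6 * np.eye(D)
         for k in range(K)])
    return pi, mus, Sigmas, log_marginal.sum()
```

Returning the log marginal likelihood alongside the updated parameters makes it easy to monitor convergence.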

The Evidence Lower Bound (ELBO)#

Why does EM work? We can view it as coordinate ascent on the evidence lower bound (ELBO). First, let’s rewrite the log marginal likelihood as the log of an expectation,

\[\begin{split} \begin{align*} \log p(\mathbf{x}; \boldsymbol{\theta}) &= \log \int p(\mathbf{x}, \mathbf{z}; \boldsymbol{\theta}) \, \mathrm{d}\mathbf{z} \\ &= \log \int \frac{q(\mathbf{z})}{q(\mathbf{z})} p(\mathbf{x}, \mathbf{z}; \boldsymbol{\theta}) \, \mathrm{d}\mathbf{z} \\ &= \log \mathbb{E}_{q(\mathbf{z})} \left[ \frac{p(\mathbf{x}, \mathbf{z}; \boldsymbol{\theta})}{q(\mathbf{z})} \right] \end{align*} \end{split}\]

Note

We dropped the subscripts on \(\mathbf{x}_{1:T}\) and \(\mathbf{z}_{1:T}\) to be more concise.

This equality holds for any distribution \(q(\mathbf{z})\) as long as \(p\) is absolutely continuous with respect to \(q\). We call \(q(\mathbf{z})\) the variational posterior.

Since log is a concave function, Jensen’s inequality says that swapping the order of the log and the expectation gives a lower bound on the log marginal likelihood,

\[ \begin{align*} \log p(\mathbf{x}; \boldsymbol{\theta}) \geq \mathbb{E}_{q(\mathbf{z})} \left[ \log p(\mathbf{x}, \mathbf{z}; \boldsymbol{\theta}) - \log q(\mathbf{z}) \right] \triangleq \mathcal{L}[q, \boldsymbol{\theta}] \end{align*} \]

Jensen’s inequality

Jensen’s inequality relates convex functions of expectations to expectations of convex functions. If \(\mathbf{x}\) is a random variable and \(f\) is a convex function,

\[ f(\mathbb{E}[\mathbf{x}]) \leq \mathbb{E}[f(\mathbf{x})] \]

If \(f\) is a concave function, the inequality is reversed.

The functional \(\mathcal{L}[q, \boldsymbol{\theta}]\) is called the evidence lower bound, a.k.a. the ELBO.
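A quick numerical check of the bound for a single data point: for any distribution \(q\) over \(z\), the ELBO is at most the log marginal likelihood, and it is tight when \(q\) is the posterior (a fact we derive below). The toy parameters here are arbitrary.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

# A toy two-component GMM and a single observation x.
pi_toy = np.array([0.6, 0.4])
mus_toy = np.array([[0.0, 0.0], [3.0, 3.0]])
Sigmas_toy = np.stack([np.eye(2), np.eye(2)])
x = np.array([1.0, 1.0])

# log p(x, z = k; theta) for each k, and the log marginal via log-sum-exp.
log_joint = np.array([np.log(pi_toy[k]) + multivariate_normal.logpdf(x, mus_toy[k], Sigmas_toy[k])
                      for k in range(2)])
log_marginal = logsumexp(log_joint)

def elbo(q):
    """E_q[log p(x, z; theta) - log q(z)] for a distribution q over z."""
    return np.sum(q * (log_joint - np.log(q)))

q_arbitrary = np.array([0.9, 0.1])
q_posterior = np.exp(log_joint - log_marginal)

print(elbo(q_arbitrary) <= log_marginal)             # True: Jensen gives a lower bound
print(np.isclose(elbo(q_posterior), log_marginal))   # True: the bound is tight at the posterior
```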

EM and coordinate ascent#

We can view the EM algorithm as coordinate ascent on the ELBO.

  1. M-step: Update the parameters,

    \[\begin{split} \begin{align*} \boldsymbol{\theta} &= \mathrm{arg}\,\mathrm{max}_{\boldsymbol{\theta}} \; \mathcal{L}[q, \boldsymbol{\theta}] \\ &= \mathrm{arg}\,\mathrm{max}_{\boldsymbol{\theta}} \; \mathbb{E}_{q(\mathbf{z})}[ \log p(\mathbf{x}, \mathbf{z}; \boldsymbol{\theta})]. \end{align*} \end{split}\]

    That is, maximize the expected log joint probability.

  2. E-step: Update the variational posterior by setting it equal to the posterior,

    \[ q(\mathbf{z}) = p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta}). \]

    To see why this maximizes the ELBO for fixed \(\boldsymbol{\theta}\), note that

    \[\begin{split} \begin{align*} q(\mathbf{z}) &= \mathrm{arg}\,\mathrm{max}_q \; \mathcal{L}[q, \boldsymbol{\theta}] \\ &= \mathrm{arg}\,\mathrm{max}_q \; \mathbb{E}_{q(\mathbf{z})} \left[ \log \frac{p(\mathbf{x}, \mathbf{z}; \boldsymbol{\theta})}{q(\mathbf{z})} \right] \\ &= \mathrm{arg}\,\mathrm{max}_q \; \log p(\mathbf{x}; \boldsymbol{\theta}) - \mathrm{KL}\left(q(\mathbf{z}) \, \| \, p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta}) \right) \\ &= \mathrm{arg}\,\mathrm{min}_q \; \mathrm{KL}\left(q(\mathbf{z}) \, \| \, p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta}) \right) \\ &= p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta}) \end{align*} \end{split}\]

    where \(\mathrm{KL}(q \| p)\) denotes the Kullback-Leibler divergence from \(q\) to \(p\). The divergence is non-negative and zero iff \(q = p\). Thus, maximizing the ELBO wrt \(q\) amounts to setting the variational posterior equal to the true posterior.

Note

The ELBO is tight after the E-step. Substituting \(q(\mathbf{z}) = p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta})\) into the ELBO,

\[\begin{split} \begin{align*} \mathcal{L}[q, \boldsymbol{\theta}] &= \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta})} \left[ \log \frac{p(\mathbf{x}, \mathbf{z}; \boldsymbol{\theta})}{p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta})} \right] \\ &= \mathbb{E}_{p(\mathbf{z} \mid \mathbf{x}; \boldsymbol{\theta})} \left[ \log p(\mathbf{x}; \boldsymbol{\theta}) \right] \\ &= \log p(\mathbf{x}; \boldsymbol{\theta}) \end{align*} \end{split}\]

Iterating between the E- and M-steps converges to a local optimum of the ELBO. Since after each E-step the ELBO is tight, the local optimum of the ELBO must also be a local optimum of the log marginal likelihood.
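As a sanity check, we can run the em_step sketch from above on the synthetic dataset xs and confirm that the log marginal likelihood never decreases across iterations (up to numerical error). The initialization below is again an arbitrary choice.

```python
# Reuses xs, K, T from the sampling sketch and em_step from the EM sketch above.
rng = np.random.default_rng(seed=2)
pi_hat = np.ones(K) / K
mus_hat = xs[rng.choice(T, size=K, replace=False)].copy()
Sigmas_hat = np.stack([np.cov(xs.T) for _ in range(K)])

lls = []
for _ in range(25):
    pi_hat, mus_hat, Sigmas_hat, ll = em_step(xs, pi_hat, mus_hat, Sigmas_hat)
    lls.append(ll)

print(all(b >= a - 1e-6 for a, b in zip(lls, lls[1:])))   # should print True
```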

Exponential family mixtures#

So far we’ve focused on Gaussian mixture models. Now let’s consider the more general case of exponential family mixture models for which,

\[ p(\mathbf{x}_n \mid z_n = k; \boldsymbol{\eta}) = h(\mathbf{x}_n) \exp \left\{\langle t(\mathbf{x}_n), \boldsymbol{\eta}_k \rangle - A(\boldsymbol{\eta}_k) \right\} \]

where \(\boldsymbol{\eta} = \{\boldsymbol{\eta}_k\}_{k=1}^K\) is the set of natural parameters, \(t(\mathbf{x}_n)\) are the sufficient statistics, and \(A(\boldsymbol{\eta}_k)\) is the log normalizer. See Chapter 9 for an introduction to exponential family distributions.

Warning

We have switched to indexing data points by \(n\) instead of \(t\) to avoid confusion with the sufficient statistic \(t\). Furthermore, we are assuming that all \(K\) components of the mixture model belong to the same exponential family (and hence have the same sufficient statistics, log normalizer, base measure, etc.).

Under this model, the E-step reduces to computing the responsibilities,

\[ q_{n,k} = p(z_n = k \mid \mathbf{x}_n; \boldsymbol{\eta}) = \frac{\pi_k \, \exp \left\{\langle t(\mathbf{x}_n), \boldsymbol{\eta}_k \rangle - A(\boldsymbol{\eta}_k) \right\}}{{\sum_{j=1}^K \pi_j \, \exp \left\{\langle t(\mathbf{x}_n), \boldsymbol{\eta}_j \rangle - A(\boldsymbol{\eta}_j) \right\}}} \]

As a function of the parameters \(\boldsymbol{\eta}_k\), the ELBO is,

\[\begin{split} \begin{align*} \mathcal{L}(q, \boldsymbol{\eta}) &= \mathbb{E}_{q(\mathbf{z})}[\log p(\mathbf{x}, \mathbf{z}; \boldsymbol{\eta})] + c\\ &= \sum_{n=1}^N q_{n,k} \, \log p(\mathbf{x}_n \mid z_n = k; \boldsymbol{\eta}) + c \\ &= \sum_{n=1}^N q_{n,k} \left[ \langle t(\mathbf{x}_n), \boldsymbol{\eta}_k \rangle - A(\boldsymbol{\eta}_k) \right] + c \\ &= \langle \mathbf{t}_k, \boldsymbol{\eta}_k \rangle - N_k A(\boldsymbol{\eta}_k) + c \end{align*} \end{split}\]

where

\[\begin{split} \begin{align*} \mathbf{t}_k &= \sum_{n=1}^N q_{n,k} \, t(\mathbf{x}_n) \\ N_k &= \sum_{n=1}^N q_{n,k}. \end{align*} \end{split}\]

Taking derivatives and setting to zero yields,

\[ \boldsymbol{\eta}_k = [A']^{-1}(\mathbf{t}_k / N_k) \]

Recall that derivatives of the log normalizer yield expected sufficient statistics, \(A'(\boldsymbol{\eta}_k) = \mathbb{E}[t(\mathbf{x})]\). Assuming the exponential family is minimal, this mapping is invertible, and the optimal natural parameters are those for which the expected sufficient statistics match the weighted average sufficient statistics, \(\mathbf{t}_k / N_k\).
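As a concrete example (ours, not from the text), consider a mixture of Poisson distributions, e.g. for spike counts. Here \(t(x) = x\), \(\eta = \log \lambda\), and \(A(\eta) = e^{\eta}\), so \([A']^{-1}(s) = \log s\) and the M-step reduces to \(\lambda_k = \mathbf{t}_k / N_k\), a responsibility-weighted mean.

```python
import numpy as np
from scipy.special import gammaln, logsumexp

def poisson_mixture_em_step(xs, pi, lams):
    """One EM step for a mixture of Poissons. xs: integer counts, shape (N,)."""
    # E-step: responsibilities q_{n,k} from the Poisson log pmfs.
    log_joint = (np.log(pi) + xs[:, None] * np.log(lams) - lams
                 - gammaln(xs[:, None] + 1))                       # (N, K)
    q = np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))

    # M-step: expected sufficient statistics t_k and counts N_k.
    Nk = q.sum(axis=0)
    tk = q.T @ xs
    pi = Nk / len(xs)
    lams = tk / Nk       # i.e. eta_k = [A']^{-1}(t_k / N_k) = log(t_k / N_k)
    return pi, lams
```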

Normalized sufficient statistics

Note that the M-step is invariant to rescaling the ELBO. We could have multiplied \(\mathcal{L}(q, \boldsymbol{\eta})\) by a positive constant and the optimal \(\boldsymbol{\eta}_k\) would remain the same. For example, we could normalize the statistics by the size of the dataset,

\[\begin{split} \begin{align*} \overline{\mathbf{t}}_k &= \frac{\mathbf{t}_k}{N} \\ \overline{N}_k &= \frac{N_k}{N} \end{align*} \end{split}\]

Then \(\boldsymbol{\eta}_k = [A']^{-1}(\mathbf{t}_k / N_k) = [A']^{-1}(\overline{\mathbf{t}}_k / \overline{N}_k)\) is unchanged. Working with normalized sufficient statistics like these can be more numerically stable, especially for very large datasets.

Stochastic EM#

Finally, the EM algorithm presented above is a batch algorithm. The M-step involves aggregating sufficient statistics from the entire dataset. For very large datasets, it is more efficient to do an M-step after each mini-batch of data. That is how stochastic EM works.

The key idea is to maintain a running estimate of the sufficient statistics. Assume we have \(N\) data points equally divided into \(M\) mini-batches, each of size \(\tfrac{N}{M}\). Let \(\overline{\mathbf{t}}_k^{(m)}\) and \(\overline{N}_k^{(m)}\) be the normalized statistics and responsibilities computed on the \(m\)-th mini-batch. (They are normalized by summing over the mini-batch and then dividing by its size, \(\tfrac{N}{M}\).)

In stochastic EM, we keep a running estimate of the normalized statistics. After processing each mini-batch, we update the running estimate by taking a convex combination of the previous estimate and the estimate from the current mini-batch.

\[\begin{split} \begin{align*} \overline{\mathbf{t}}_k \leftarrow (1-\alpha) \overline{\mathbf{t}}_k + \alpha \overline{\mathbf{t}}_k^{(m)} \\ \overline{N}_k \leftarrow (1-\alpha) \overline{N}_k + \alpha \overline{N}_k^{(m)} \end{align*} \end{split}\]

where \(\alpha \in [0, 1]\) is the step size. After each mini-batch, we update our parameters by setting,

\[ \boldsymbol{\eta}_k = [A']^{-1}(\overline{\mathbf{t}}_k / \overline{N}_k). \]

To ensure convergence, the step size follows a schedule: we start with \(\alpha = 1\) and slowly decay it toward zero after each mini-batch.

One sweep through the entire set of \(M\) mini-batches — i.e., a sweep through the whole dataset — is called an epoch. Since stochastic EM updates the parameters once per mini-batch rather than once per epoch (as in standard EM), the parameter estimates often converge in far fewer epochs.
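A generic sketch of the loop is below. The two hooks, minibatch_stats (the E-step on one mini-batch, returning normalized statistics) and params_from_stats (the M-step, applying \([A']^{-1}\)), are hypothetical stand-ins for whatever mixture family is being fit, and the decay schedule is just one reasonable choice.

```python
import numpy as np

def stochastic_em(xs, params, minibatch_stats, params_from_stats,
                  n_epochs=10, batch_size=256, seed=0):
    """Stochastic EM: running normalized statistics, parameter update after each mini-batch."""
    rng = np.random.default_rng(seed)
    t_bar, N_bar = None, None
    step = 0

    for epoch in range(n_epochs):
        n_batches = max(1, len(xs) // batch_size)
        for idx in np.array_split(rng.permutation(len(xs)), n_batches):
            # E-step on this mini-batch: normalized statistics and counts.
            t_new, N_new = minibatch_stats(xs[idx], params)

            # Convex combination with the running estimate; alpha starts at 1 and decays.
            alpha = (step + 1) ** (-0.6)
            if t_bar is None:
                t_bar, N_bar = t_new, N_new
            else:
                t_bar = (1 - alpha) * t_bar + alpha * t_new
                N_bar = (1 - alpha) * N_bar + alpha * N_new

            # M-step: refresh the parameters from the running statistics.
            params = params_from_stats(t_bar, N_bar)
            step += 1
    return params
```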

Stochastic EM for GMM#

Let’s consider the special case of the Gaussian mixture model and derive the stochastic EM algorithm. The Gaussian probability can be written in exponential family form as,

\[\begin{split} \begin{align*} p(\mathbf{x}_n \mid \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) &= (2 \pi)^{-\frac{D}{2}} |\boldsymbol{\Sigma}_k|^{-\frac{1}{2}} \exp \left\{-\frac{1}{2} (\mathbf{x}_n - \boldsymbol{\mu}_k)^\top \boldsymbol{\Sigma}_k^{-1} (\mathbf{x}_n - \boldsymbol{\mu}_k) \right\} \\ &= (2 \pi)^{-\frac{D}{2}} \exp \left\{\langle \mathbf{x}_n, \boldsymbol{\Sigma}_k^{-1} \boldsymbol{\mu}_k \rangle + \langle \mathbf{x}_n \mathbf{x}_n^\top, -\tfrac{1}{2} \boldsymbol{\Sigma}_k^{-1} \rangle -\frac{1}{2} \boldsymbol{\mu}_k^\top \boldsymbol{\Sigma}_k^{-1} \boldsymbol{\mu}_k -\frac{1}{2} \log |\boldsymbol{\Sigma}_k| \right\} \\ &= h(\mathbf{x}_n) \exp \left\{\langle t_1(\mathbf{x}_n), \boldsymbol{\Sigma}_k^{-1} \boldsymbol{\mu}_k \rangle + \langle t_2(\mathbf{x}_n), -\tfrac{1}{2} \boldsymbol{\Sigma}_k^{-1} \rangle - A(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \right\} \end{align*} \end{split}\]

where the sufficient statistics are

\[ \begin{align*} t_1(\mathbf{x}_n) &= \mathbf{x}_n & t_2(\mathbf{x}_n) &= \mathbf{x}_n \mathbf{x}_n^\top \end{align*} \]

and the log normalizer is

\[ A(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) = \frac{1}{2} \boldsymbol{\mu}_k^\top \boldsymbol{\Sigma}_k^{-1} \boldsymbol{\mu}_k + \frac{1}{2} \log |\boldsymbol{\Sigma}_k| \]

Warning

Natural parameters

We could write this log probability in terms of its natural parameters \(\boldsymbol{\eta}_{k,1} = \boldsymbol{\Sigma}_k^{-1} \boldsymbol{\mu}_k\) and \(\boldsymbol{\eta}_{k,2} = -\tfrac{1}{2} \boldsymbol{\Sigma}_k^{-1}\), but since we ultimately want the standard parameters anyway, let's just leave it in this form.

In the M-step of stochastic EM, we need to maximize the ELBO with respect to \(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k\). The ELBO is a weighted sum of log probabilities, which is of the form,

\[ \begin{align*} \mathcal{L}(q, \boldsymbol{\theta}) &= \langle \overline{\mathbf{t}}_{k,1}, \boldsymbol{\Sigma}_k^{-1} \boldsymbol{\mu}_k \rangle + \langle \overline{\mathbf{t}}_{k,2}, -\tfrac{1}{2} \boldsymbol{\Sigma}_k^{-1} \rangle - \overline{N}_k A(\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \end{align*} \]

where

\[ \begin{align*} \overline{\mathbf{t}}_{k,1} &= \frac{1}{N} \sum_{n=1}^N q_{n,k} \, \mathbf{x}_n & \overline{\mathbf{t}}_{k,2} &= \frac{1}{N} \sum_{n=1}^N q_{n,k} \, \mathbf{x}_n \mathbf{x}_n^\top & \overline{N}_{k} &= \frac{1}{N} \sum_{n=1}^N q_{n,k} \end{align*} \]

Taking the derivative with respect to \(\boldsymbol{\mu}_k\),

\[ \nabla_{\boldsymbol{\mu}_k} \mathcal{L}(q, \boldsymbol{\theta}) = \boldsymbol{\Sigma}_k^{-1} \overline{\mathbf{t}}_{k,1} - \overline{N}_{k} \boldsymbol{\Sigma}_k^{-1} \boldsymbol{\mu}_k \]

Setting to zero and solving for the mean yields,

\[ \boldsymbol{\mu}_k = \overline{\mathbf{t}}_{k,1} / \overline{N}_{k}. \]

Substituting this into the ELBO yields,

\[ \begin{align*} \mathcal{L}(q, \boldsymbol{\theta}) &= \frac{1}{2} \langle\overline{\mathbf{t}}_{k,1} \overline{\mathbf{t}}_{k,1}^\top / \overline{N}_{k} - \overline{\mathbf{t}}_{k,2}, \boldsymbol{\Sigma}_k^{-1} \rangle - \frac{1}{2} \overline{N}_k \log |\boldsymbol{\Sigma}_k | \end{align*} \]

Taking the derivative with respect to \(\boldsymbol{\Sigma}_k^{-1}\) yields,

\[ \nabla_{\boldsymbol{\Sigma}_k^{-1}} \mathcal{L}(q, \boldsymbol{\theta}) = \frac{1}{2} \left( \frac{\overline{\mathbf{t}}_{k,1} \overline{\mathbf{t}}_{k,1}^\top}{\overline{N}_{k}} - \overline{\mathbf{t}}_{k,2} \right) + \frac{1}{2} \overline{N}_k \boldsymbol{\Sigma}_k \]

Then solving for \(\boldsymbol{\Sigma}_k\) gives,

\[ \boldsymbol{\Sigma}_k = \frac{1}{\overline{N}_k} \left( \overline{\mathbf{t}}_{k,2} - \frac{\overline{\mathbf{t}}_{k,1} \overline{\mathbf{t}}_{k,1}^\top}{\overline{N}_{k}} \right). \]
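Putting the pieces together for the Gaussian case, here is a small sketch of the two maps needed by the stochastic EM loop above: computing normalized statistics from a mini-batch and its responsibilities, and recovering \((\boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)\) from the running statistics. The array shapes, function names, and the small ridge are our own conventions.

```python
import numpy as np

def gaussian_minibatch_stats(xs_batch, q):
    """Normalized statistics from one mini-batch.
    xs_batch: (B, D) data; q: (B, K) responsibilities. Sums are divided by the batch size B."""
    B = len(xs_batch)
    N_bar = q.sum(axis=0) / B                                        # (K,)
    t1_bar = (q.T @ xs_batch) / B                                    # (K, D)
    t2_bar = np.einsum("nk,ni,nj->kij", q, xs_batch, xs_batch) / B   # (K, D, D)
    return t1_bar, t2_bar, N_bar

def gaussian_params_from_stats(t1_bar, t2_bar, N_bar, eps=1e-6):
    """Recover (mu_k, Sigma_k) for all clusters from running normalized statistics."""
    mus = t1_bar / N_bar[:, None]
    Sigmas = t2_bar / N_bar[:, None, None] - mus[:, :, None] * mus[:, None, :]
    Sigmas = Sigmas + eps * np.eye(t1_bar.shape[1])                  # small ridge for stability
    return mus, Sigmas
```

To wire these into the stochastic_em sketch, the two statistics \(\overline{\mathbf{t}}_{k,1}\) and \(\overline{\mathbf{t}}_{k,2}\) would simply be carried together (e.g., as a tuple) wherever that sketch passes a single statistic.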

Conclusion#

This chapter introduced unsupervised learning methods using latent variables.

  • We started with the Gaussian mixture model and recalled the coordinate ascent algorithm for MAP estimation, which is very similar to K-means.

  • Then we introduced the expectation-maximization algorithm and saw that for GMMs, it nicely parallels the MAP estimation algorithm. However, it yields parameters that maximize the marginal likelihood instead of the joint.

  • We derived EM as coordinate ascent on the evidence lower bound (ELBO). This view will generalize to more complex models in subsequent chapters.

  • Finally, we considered the more general class of exponential family mixture models and derived a stochastic EM algorithm that operates on running estimates of the expected sufficient statistics.

Next time, we’ll consider sequential latent variable models called hidden Markov models (HMMs).