Mixture Models and the EM Algorithm#

Unit II was all about supervised learning problems. Both encoding and decoding involve predicting one signal from another: neural firing rates given stimuli, or behavioral outputs given neural activity. Now we turn our attention to unsupervised learning problems. Given measurements of neural activity or behavior, we aim to find latent variables (clusters, factors, etc.) that can explain the data.

Unsupervised learning has many applications in neuroscience. First and foremost, it’s a means of dimensionality reduction. Latent variables offer a compressed representation of high-dimensional neural spike trains or behavioral videos. Low dimensional representations can aid in visualization as well as hypothesis generation.

We encountered a few examples of unsupervised learning in Unit I. Spike sorting and calcium deconvolution were both unsupervised problems: the spike times and templates were a low dimensional representation of a high dimensional voltage trace or video.

Unsupervised learning is all about constraints. We need to set some boundaries on what types of latent variables we are looking for. I like to think of constraints in terms of the five D’s:

  • Dimensionality: how many latent clusters, factors, etc.?

  • Domain: are the latent variables discrete, continuous, bounded, sparse, etc.?

  • Dynamics: how do the latent variables change over time?

  • Dependencies: how do the latent variables relate to the observed data?

  • Distribution: do we have prior knowledge about the variables’ probability?

This is certainly not an exhaustive list of the types of constraints, but it covers many bases. Thinking in terms of constraints allows us to organize many commonly used models.

Gaussian mixture models#

Let’s start with a simple and canonical example: the Gaussian mixture model (GMM).

Let $x_t \in \mathbb{R}^D$ denote the $t$-th observation. Let $z_t \in \{1, \ldots, K\}$ denote the discrete latent state (aka cluster assignment) of that data point.

Assume each observation is independently drawn according to the following model:

$$
z_t \sim \mathrm{Cat}(\pi), \qquad x_t \sim \mathcal{N}(\mu_{z_t}, \Sigma_{z_t})
$$

The parameters include the prior probabilities of the cluster assignments, $\pi \in \Delta_K$, and the conditional mean and covariance of each cluster, $\mu_k, \Sigma_k$. Let $\theta = \{\pi, \{\mu_k, \Sigma_k\}_{k=1}^K\}$ denote the set of all parameters.

The joint probability is,

$$
p(z_{1:T}, x_{1:T}; \theta) = \prod_{t=1}^T \mathrm{Cat}(z_t; \pi) \, \mathcal{N}(x_t; \mu_{z_t}, \Sigma_{z_t})
$$
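To make the generative model concrete, here is a minimal NumPy sketch of sampling from a GMM. The function name `sample_gmm` and the particular parameter values are illustrative assumptions, not part of the model definition.

```python
import numpy as np

def sample_gmm(rng, pi, mus, Sigmas, T):
    """Draw T points: z_t ~ Cat(pi), then x_t ~ N(mu_{z_t}, Sigma_{z_t})."""
    K, D = mus.shape
    z = rng.choice(K, size=T, p=pi)          # cluster assignments
    x = np.empty((T, D))
    for k in range(K):
        idx = np.flatnonzero(z == k)         # points assigned to cluster k
        x[idx] = rng.multivariate_normal(mus[k], Sigmas[k], size=len(idx))
    return z, x

# Illustrative parameters (assumed): K = 3 clusters in D = 2 dimensions.
rng = np.random.default_rng(0)
pi = np.array([0.5, 0.3, 0.2])
mus = np.array([[0.0, 0.0], [4.0, 4.0], [-4.0, 4.0]])
Sigmas = np.stack([np.eye(2)] * 3)
z, x = sample_gmm(rng, pi, mus, Sigmas, T=500)
```

Because each pair $(z_t, x_t)$ is drawn independently, the joint probability factorizes over $t$ exactly as in the product above.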

MAP estimation and K-means#

In Chapter 2, we fit this type of model using MAP estimation,

$$
z_{1:T}, \theta = \arg\max \; p(z_{1:T}, x_{1:T}; \theta)
$$

We solved for the MAP estimate by coordinate ascent. For the GMM, the coordinate updates are,

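  1. For each data point $t = 1, \ldots, T$:

    $$z_t = \arg\max_k \; \pi_k \, \mathcal{N}(x_t \mid \mu_k, \Sigma_k)$$
  2. For each cluster $k = 1, \ldots, K$:

    $$\begin{aligned}
    T_k &= \sum_{t=1}^T \mathbb{I}[z_t = k] \\
    \mu_k &= \frac{1}{T_k} \sum_{t=1}^T \mathbb{I}[z_t = k] \, x_t \\
    \Sigma_k &= \frac{1}{T_k} \sum_{t=1}^T \mathbb{I}[z_t = k] \, (x_t - \mu_k)(x_t - \mu_k)^\top
    \end{aligned}$$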

In words, we set the cluster means and covariances equal to the sample means and covariances of the data points assigned to that cluster. If we were to fix $\Sigma_k = I$ for all clusters (and keep the cluster probabilities $\pi_k$ uniform), this algorithm would be equivalent to K-means clustering.
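The two coordinate updates above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions: the helper names `log_gauss` and `hard_step`, the small `jitter` added to keep covariances invertible, and the choice to leave empty clusters unchanged are all additions for this example. As in the updates above, $\pi$ is held fixed.

```python
import numpy as np

def log_gauss(x, mu, Sigma):
    """Log N(x_t | mu, Sigma) for each row of x (shape T x D)."""
    D = x.shape[1]
    L = np.linalg.cholesky(Sigma)
    r = np.linalg.solve(L, (x - mu).T)            # whitened residuals, D x T
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (D * np.log(2 * np.pi) + logdet + np.sum(r**2, axis=0))

def hard_step(x, pi, mus, Sigmas, jitter=1e-6):
    """One round of coordinate ascent with hard assignments."""
    K, D = mus.shape
    # Step 1: z_t = argmax_k pi_k N(x_t | mu_k, Sigma_k), computed in log space.
    scores = np.stack(
        [np.log(pi[k]) + log_gauss(x, mus[k], Sigmas[k]) for k in range(K)], axis=1
    )
    z = np.argmax(scores, axis=1)
    # Step 2: sample mean and covariance of the points assigned to each cluster.
    mus, Sigmas = mus.copy(), Sigmas.copy()
    for k in range(K):
        xk = x[z == k]
        if len(xk) == 0:
            continue                               # assumption: skip empty clusters
        mus[k] = xk.mean(axis=0)
        diff = xk - mus[k]
        Sigmas[k] = diff.T @ diff / len(xk) + jitter * np.eye(D)
    return z, mus, Sigmas

# Toy data (assumed): two well-separated blobs.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(size=(100, 2)) + [-5.0, 0.0],
                    rng.normal(size=(100, 2)) + [5.0, 0.0]])
pi = np.array([0.5, 0.5])
mus = np.array([[-1.0, 0.0], [1.0, 0.0]])
Sigmas = np.stack([np.eye(2), np.eye(2)])
for _ in range(10):
    z, mus, Sigmas = hard_step(x, pi, mus, Sigmas)
```

Working with log probabilities in step 1 avoids underflow when a point lies far from a cluster mean.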

Expectation-Maximization (EM)#

MAP estimation gives us a point estimate of the latent variables and parameters. However, point estimates can lead to an overly optimistic view of the model. Specifically, MAP estimation finds the single best assignment of the latent variables, which may not reflect how well samples from the model match the observed data.

Ideally, we would like to find parameters that maximize the marginal likelihood, aka the model evidence,

$$
p(x_{1:T}; \theta) = \int p(x_{1:T}, z_{1:T}; \theta) \, \mathrm{d}z_1 \cdots \mathrm{d}z_T.
$$

Note

For discrete latent variable models, the integral becomes a sum.

Once we have those parameters, we can compute the posterior distribution over latent variables,

$$
p(z_{1:T} \mid x_{1:T}; \theta).
$$

The expectation-maximization (EM) algorithm does exactly that: it finds a local maximum of the marginal likelihood via an iterative algorithm very similar to the MAP estimation algorithm above. The key difference is that instead of using hard assignments of data points to the most likely cluster, we use soft assignments of data points based on their posterior probabilities.

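  1. For each data point $t = 1, \ldots, T$ and cluster $k = 1, \ldots, K$, set:

    $$q_{t,k} = p(z_t = k \mid x_t; \theta) = \frac{\pi_k \, \mathcal{N}(x_t \mid \mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j \, \mathcal{N}(x_t \mid \mu_j, \Sigma_j)}$$
  2. For each cluster $k = 1, \ldots, K$, update the parameters as follows:

    $$\begin{aligned}
    T_k &= \sum_{t=1}^T q_{t,k} \\
    \mu_k &= \frac{1}{T_k} \sum_{t=1}^T q_{t,k} \, x_t \\
    \Sigma_k &= \frac{1}{T_k} \sum_{t=1}^T q_{t,k} \, (x_t - \mu_k)(x_t - \mu_k)^\top
    \end{aligned}$$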

The posterior probabilities $q_{t,k}$ are sometimes called the responsibilities. When we update the parameters, we take a weighted average of all data points based on these responsibilities.
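Here is a minimal NumPy sketch of the two EM steps for a GMM. The function names and toy data are assumptions for illustration. The sketch also updates the mixing weights via $\pi_k = T_k / T$, which is the standard maximum likelihood update but is not among the steps listed above.

```python
import numpy as np

def log_gauss(x, mu, Sigma):
    """Log N(x_t | mu, Sigma) for each row of x (shape T x D)."""
    D = x.shape[1]
    L = np.linalg.cholesky(Sigma)
    r = np.linalg.solve(L, (x - mu).T)
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (D * np.log(2 * np.pi) + logdet + np.sum(r**2, axis=0))

def e_step(x, pi, mus, Sigmas):
    """Return responsibilities q[t, k] and the log marginal likelihood."""
    K = len(pi)
    log_joint = np.stack(
        [np.log(pi[k]) + log_gauss(x, mus[k], Sigmas[k]) for k in range(K)], axis=1
    )
    m = log_joint.max(axis=1, keepdims=True)       # log-sum-exp for stability
    log_norm = m + np.log(np.exp(log_joint - m).sum(axis=1, keepdims=True))
    return np.exp(log_joint - log_norm), log_norm.sum()

def m_step(x, q):
    """Responsibility-weighted means and covariances."""
    Tk = q.sum(axis=0)                             # T_k = sum_t q_{t,k}
    mus = (q.T @ x) / Tk[:, None]
    Sigmas = np.stack([
        ((x - mus[k]).T * q[:, k]) @ (x - mus[k]) / Tk[k] for k in range(len(Tk))
    ])
    pi = Tk / Tk.sum()                             # assumption: also update pi
    return pi, mus, Sigmas

# Toy data (assumed): two overlapping blobs.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(size=(150, 2)) + [-2.0, 0.0],
                    rng.normal(size=(150, 2)) + [2.0, 0.0]])
pi = np.array([0.5, 0.5])
mus = np.array([[-1.0, 0.0], [1.0, 0.0]])
Sigmas = np.stack([np.eye(2), np.eye(2)])
lls = []
for _ in range(30):
    q, ll = e_step(x, pi, mus, Sigmas)
    lls.append(ll)                                 # should never decrease
    pi, mus, Sigmas = m_step(x, q)
```

Tracking `lls` is a useful sanity check: up to floating point error, the log marginal likelihood should be non-decreasing from one iteration to the next.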

The Evidence Lower Bound (ELBO)#

Why does EM work? We can view it as coordinate ascent on the evidence lower bound (ELBO). First, let’s rewrite the log marginal likelihood as the log of an expectation,

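$$
\log p(x; \theta) = \log \int p(x, z; \theta) \, \mathrm{d}z = \log \int \frac{q(z)}{q(z)} \, p(x, z; \theta) \, \mathrm{d}z = \log \mathbb{E}_{q(z)}\left[ \frac{p(x, z; \theta)}{q(z)} \right]
$$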

Note

We dropped the subscripts on $x_{1:T}$ and $z_{1:T}$ to be more concise.

This equality holds for any distribution $q(z)$ as long as $p$ is absolutely continuous with respect to $q$. We call $q(z)$ the variational posterior.

Since log is a concave function, Jensen’s inequality says that swapping the order of the log and the expectation gives a lower bound on the log marginal likelihood,

$$
\log p(x; \theta) \geq \mathbb{E}_{q(z)}\left[ \log p(x, z; \theta) - \log q(z) \right] \triangleq \mathcal{L}[q, \theta]
$$

Jensen’s inequality

Jensen’s inequality relates convex functions of expectations to expectations of convex functions. If $x$ is a random variable and $f$ is a convex function,

$$
f(\mathbb{E}[x]) \leq \mathbb{E}[f(x)]
$$

If $f$ is a concave function, the inequality is reversed.

The functional $\mathcal{L}[q, \theta]$ is called the evidence lower bound, a.k.a. the ELBO.

EM and coordinate ascent#

We can view the EM algorithm as coordinate ascent on the ELBO.

  1. M-step: Update the parameters,

    $$\theta = \arg\max_\theta \; \mathcal{L}[q, \theta] = \arg\max_\theta \; \mathbb{E}_{q(z)}\left[ \log p(x, z; \theta) \right].$$

    That is, maximize the expected log joint probability.

  2. E-step: Update the variational posterior by setting it equal to the posterior,

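    $$q(z) = p(z \mid x; \theta).$$

    To see why this maximizes the ELBO for fixed $\theta$, note that

    $$\begin{aligned}
    q(z) &= \arg\max_q \; \mathcal{L}[q, \theta] \\
    &= \arg\max_q \; \mathbb{E}_{q(z)}\left[ \log \frac{p(x, z; \theta)}{q(z)} \right] \\
    &= \arg\min_q \; \mathrm{KL}\big(q(z) \,\|\, p(z \mid x; \theta)\big) \\
    &= p(z \mid x; \theta)
    \end{aligned}$$

    where $\mathrm{KL}(q \,\|\, p)$ denotes the Kullback-Leibler divergence from $q$ to $p$. The divergence is non-negative and zero iff $q = p$. Thus, maximizing the ELBO wrt $q$ amounts to setting the variational posterior equal to the true posterior.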
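The step from the ELBO to the KL divergence in the E-step can be spelled out by factoring the joint as $p(x, z; \theta) = p(z \mid x; \theta) \, p(x; \theta)$. The following short derivation is a sketch of that intermediate step, using only quantities already defined:

$$
\begin{aligned}
\mathcal{L}[q, \theta]
&= \mathbb{E}_{q(z)}\left[ \log p(z \mid x; \theta) + \log p(x; \theta) - \log q(z) \right] \\
&= \log p(x; \theta) - \mathbb{E}_{q(z)}\left[ \log \frac{q(z)}{p(z \mid x; \theta)} \right] \\
&= \log p(x; \theta) - \mathrm{KL}\big(q(z) \,\|\, p(z \mid x; \theta)\big)
\end{aligned}
$$

Since $\log p(x; \theta)$ does not depend on $q$, maximizing the ELBO over $q$ is the same as minimizing the KL divergence, and the gap between the log marginal likelihood and the ELBO is exactly that divergence.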

Note

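The ELBO is tight after the E-step. Substituting $q(z) = p(z \mid x; \theta)$ into the ELBO,

$$
\mathcal{L}[q, \theta] = \mathbb{E}_{p(z \mid x; \theta)}\left[ \log \frac{p(x, z; \theta)}{p(z \mid x; \theta)} \right] = \mathbb{E}_{p(z \mid x; \theta)}\left[ \log p(x; \theta) \right] = \log p(x; \theta)
$$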

Iterating between the E- and M-steps converges to a local optimum of the ELBO. Since after each E-step the ELBO is tight, the local optimum of the ELBO must also be a local optimum of the log marginal likelihood.
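The bound and its tightness can be checked numerically on a small example. The sketch below assumes a one-dimensional, two-component Gaussian mixture with made-up parameters. It evaluates the ELBO for an arbitrary $q$ and for the exact posterior, and compares both to the log marginal likelihood.

```python
import numpy as np

def log_joint(x, pi, mus, sigmas):
    """log p(x_t, z_t = k) for a 1-D Gaussian mixture; shape (T, K)."""
    x = x[:, None]
    return (np.log(pi) - 0.5 * np.log(2 * np.pi * sigmas**2)
            - 0.5 * (x - mus) ** 2 / sigmas**2)

def elbo(q, lj):
    """E_q[log p(x, z) - log q(z)], summed over independent data points.

    Assumes every entry of q is strictly positive.
    """
    return np.sum(q * (lj - np.log(q)))

# Illustrative data and parameters (assumed).
rng = np.random.default_rng(0)
x = rng.normal(scale=2.0, size=20)
pi = np.array([0.4, 0.6])
mus = np.array([-1.0, 1.0])
sigmas = np.array([1.0, 1.5])

lj = log_joint(x, pi, mus, sigmas)
m = lj.max(axis=1, keepdims=True)
log_px = m + np.log(np.exp(lj - m).sum(axis=1, keepdims=True))  # log p(x_t)
log_evidence = log_px.sum()

posterior = np.exp(lj - log_px)                    # exact p(z_t | x_t)
q_random = rng.dirichlet(np.ones(2), size=len(x))  # an arbitrary q

gap_random = log_evidence - elbo(q_random, lj)     # equals KL(q || posterior) >= 0
gap_posterior = log_evidence - elbo(posterior, lj) # approximately zero
```

The gap for the arbitrary $q$ is the KL divergence to the posterior, so it is non-negative, while the gap after the E-step vanishes up to rounding error.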

Exponential family mixtures#

So far we’ve focused on Gaussian mixture models. Now let’s consider the more general case of exponential family mixture models for which,

$$
p(x_n \mid z_n = k; \eta) = h(x_n) \exp\left\{ \langle t(x_n), \eta_k \rangle - A(\eta_k) \right\}
$$

where $\eta = \{\eta_k\}_{k=1}^K$ is the set of natural parameters, $t(x_n)$ are the sufficient statistics, and $A(\eta_k)$ is the log normalizer. Recall Chapter 9 for an introduction to exponential family distributions.

Warning

We have switched to indexing data points by n instead of t to avoid confusion with the sufficient statistic t. Furthermore, we are assuming that all K components of the mixture model belong to the same exponential family (and hence have the same sufficient statistics, log normalizer, base measure, etc.).

Under this model, the E-step reduces to computing the responsibilities,

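$$
q_{n,k} = p(z_n = k \mid x_n; \eta) = \frac{\pi_k \exp\left\{ \langle t(x_n), \eta_k \rangle - A(\eta_k) \right\}}{\sum_{j=1}^K \pi_j \exp\left\{ \langle t(x_n), \eta_j \rangle - A(\eta_j) \right\}}
$$

As a function of the parameters $\eta_k$, the ELBO is,

$$
\begin{aligned}
\mathcal{L}(q, \eta) &= \mathbb{E}_{q(z)}\left[ \log p(x, z; \eta) \right] + c \\
&= \sum_{n=1}^N q_{n,k} \log p(x_n \mid z_n = k; \eta) + c \\
&= \sum_{n=1}^N q_{n,k} \left[ \langle t(x_n), \eta_k \rangle - A(\eta_k) \right] + c \\
&= \langle t_k, \eta_k \rangle - N_k A(\eta_k) + c
\end{aligned}
$$

where

$$
t_k = \sum_{n=1}^N q_{n,k} \, t(x_n), \qquad N_k = \sum_{n=1}^N q_{n,k}.
$$

Taking derivatives and setting to zero yields,

$$
\eta_k = [\nabla A]^{-1}(t_k / N_k)
$$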

Recall that the log normalizer $A$ is a cumulant generating function: its derivatives yield cumulants of the sufficient statistics, and in particular its gradient yields the expected sufficient statistics, $\nabla A(\eta_k) = \mathbb{E}[t(x)]$. Assuming the exponential family is minimal, the gradient mapping $\nabla A$ is invertible on its range, and the optimal natural parameters are those for which the expected sufficient statistics match the weighted average sufficient statistics, $t_k / N_k$.
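As a concrete instance, consider a mixture of Poisson distributions (an example chosen here for illustration, not one used elsewhere in this chapter). For the Poisson, $t(x) = x$, $\eta = \log \lambda$, and $A(\eta) = e^\eta$, so $\nabla A(\eta) = e^\eta$ and $[\nabla A]^{-1}$ is the logarithm. A minimal sketch, with $\pi$ held fixed and assumed function names:

```python
import numpy as np

def e_step(x, pi, eta):
    """Responsibilities q[n, k] for a Poisson mixture."""
    # log pi_k + <t(x_n), eta_k> - A(eta_k); the base measure h(x_n) cancels.
    logits = np.log(pi) + x[:, None] * eta - np.exp(eta)
    logits -= logits.max(axis=1, keepdims=True)    # stabilize before exponentiating
    q = np.exp(logits)
    return q / q.sum(axis=1, keepdims=True)

def m_step(x, q):
    """eta_k = [grad A]^{-1}(t_k / N_k), which is log(t_k / N_k) for the Poisson."""
    t_k = q.T @ x                                  # sum_n q_{n,k} t(x_n)
    N_k = q.sum(axis=0)                            # sum_n q_{n,k}
    return np.log(t_k / N_k)

# Toy data (assumed): counts drawn from rates 2 and 20.
rng = np.random.default_rng(0)
x = np.concatenate([rng.poisson(2.0, 300), rng.poisson(20.0, 300)]).astype(float)
pi = np.array([0.5, 0.5])
eta = np.log(np.array([1.0, 10.0]))                # initial natural parameters
for _ in range(50):
    q = e_step(x, pi, eta)
    eta = m_step(x, q)
rates = np.exp(eta)                                # back to the mean parameters
```

After the M-step, the expected sufficient statistic of each component, $e^{\eta_k}$, equals the responsibility-weighted average count $t_k / N_k$, which is the moment matching condition described above.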

Normalized sufficient statistics

Note that the M-step is invariant to rescaling the ELBO. We could have multiplied $\mathcal{L}(q, \eta)$ by a positive constant and the optimal $\eta_k$ would remain the same. For example, we could normalize the statistics by the size of the dataset,

$$
\bar{t}_k = \frac{t_k}{N}, \qquad \bar{N}_k = \frac{N_k}{N}
$$

Then $\eta_k = [\nabla A]^{-1}(\bar{t}_k / \bar{N}_k) = [\nabla A]^{-1}(t_k / N_k)$ is unchanged. Working with normalized sufficient statistics like these can be more numerically stable, especially for very large datasets.

Stochastic EM#

Finally, the EM algorithm presented above is a batch algorithm. The M-step involves aggregating sufficient statistics from the entire dataset. For very large datasets, it is more efficient to do an M-step after each mini-batch of data. That is how stochastic EM works.

The key idea is to maintain a running estimate of the sufficient statistics. Assume we have $N$ data points equally divided into $M$ mini-batches, each of size $N/M$. Let $\bar{t}_k^{(m)}$ and $\bar{N}_k^{(m)}$ be the normalized statistics and responsibilities computed on the $m$-th mini-batch. (They are normalized by summing over the mini-batch and then dividing by its size, $N/M$.)

In stochastic EM, we keep a running estimate of the normalized statistics. After processing each mini-batch, we update the running estimate by taking a convex combination of the previous estimate and the estimate from the current mini-batch.

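$$
\begin{aligned}
\bar{t}_k &\leftarrow (1 - \alpha) \, \bar{t}_k + \alpha \, \bar{t}_k^{(m)} \\
\bar{N}_k &\leftarrow (1 - \alpha) \, \bar{N}_k + \alpha \, \bar{N}_k^{(m)}
\end{aligned}
$$

where $\alpha \in [0, 1]$ is the step size. After each mini-batch, we update our parameters by setting,

$$
\eta_k = [\nabla A]^{-1}(\bar{t}_k / \bar{N}_k).
$$

To ensure convergence, we decay the step size according to a schedule, starting with $\alpha = 1$ and slowly decaying it toward zero after each mini-batch.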

One sweep through the entire set of M mini-batches — i.e., a sweep through the whole dataset — is called an epoch. Since stochastic EM updates parameters once per mini-batch as opposed to once per epoch, as in standard EM, the parameter estimates often converge in many fewer epochs.
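The following is a minimal sketch of stochastic EM, again using a Poisson mixture as an assumed example, where $[\nabla A]^{-1}$ is the logarithm. The step size schedule $\alpha_i = i^{-\kappa}$ with $\kappa = 0.6$ is one common choice and is an assumption here, as are the function names, batch size, and toy data. The mixing weights $\pi$ are held fixed.

```python
import numpy as np

def responsibilities(x, pi, eta):
    """q[n, k] for a Poisson mixture with natural parameters eta = log(rate)."""
    logits = np.log(pi) + x[:, None] * eta - np.exp(eta)
    logits -= logits.max(axis=1, keepdims=True)
    q = np.exp(logits)
    return q / q.sum(axis=1, keepdims=True)

def stochastic_em(x, pi, eta, batch_size=50, n_epochs=5, kappa=0.6, seed=0):
    rng = np.random.default_rng(seed)
    K = len(pi)
    t_bar, N_bar = np.zeros(K), np.zeros(K)        # running normalized statistics
    step = 0
    for _ in range(n_epochs):
        perm = rng.permutation(len(x))             # reshuffle once per epoch
        for start in range(0, len(x), batch_size):
            xb = x[perm[start:start + batch_size]]
            q = responsibilities(xb, pi, eta)      # E-step on the mini-batch
            t_m = q.T @ xb / len(xb)               # normalized mini-batch statistics
            N_m = q.sum(axis=0) / len(xb)
            step += 1
            alpha = step ** (-kappa)               # alpha = 1 on the first mini-batch
            t_bar = (1 - alpha) * t_bar + alpha * t_m
            N_bar = (1 - alpha) * N_bar + alpha * N_m
            eta = np.log(t_bar / N_bar)            # M-step from running statistics
    return eta

# Toy data (assumed): counts drawn from rates 2 and 20.
rng = np.random.default_rng(1)
x = np.concatenate([rng.poisson(2.0, 300), rng.poisson(20.0, 300)]).astype(float)
pi = np.array([0.5, 0.5])
eta_hat = stochastic_em(x, pi, eta=np.log(np.array([1.0, 10.0])))
rates_hat = np.exp(eta_hat)
```

Because $\alpha = 1$ on the first mini-batch, the zero initialization of the running statistics is overwritten immediately and never affects the estimates.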

Stochastic EM for GMM#

Let’s consider the special case of the Gaussian mixture model and derive the stochastic EM algorithm. The Gaussian probability can be written in exponential family form as,

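$$
\begin{aligned}
p(x_n \mid \mu_k, \Sigma_k)
&= (2\pi)^{-\frac{D}{2}} |\Sigma_k|^{-\frac{1}{2}} \exp\left\{ -\tfrac{1}{2} (x_n - \mu_k)^\top \Sigma_k^{-1} (x_n - \mu_k) \right\} \\
&= (2\pi)^{-\frac{D}{2}} \exp\left\{ \langle x_n, \Sigma_k^{-1} \mu_k \rangle + \langle x_n x_n^\top, -\tfrac{1}{2} \Sigma_k^{-1} \rangle - \tfrac{1}{2} \mu_k^\top \Sigma_k^{-1} \mu_k - \tfrac{1}{2} \log |\Sigma_k| \right\} \\
&= h(x_n) \exp\left\{ \langle t_1(x_n), \Sigma_k^{-1} \mu_k \rangle + \langle t_2(x_n), -\tfrac{1}{2} \Sigma_k^{-1} \rangle - A(\mu_k, \Sigma_k) \right\}
\end{aligned}
$$

where the sufficient statistics are

$$
t_1(x_n) = x_n, \qquad t_2(x_n) = x_n x_n^\top
$$

and the log normalizer is

$$
A(\mu_k, \Sigma_k) = \tfrac{1}{2} \mu_k^\top \Sigma_k^{-1} \mu_k + \tfrac{1}{2} \log |\Sigma_k|
$$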

Warning

Natural parameters

We could write this log probability in terms of its natural parameters $\eta_{k,1} = \Sigma_k^{-1} \mu_k$ and $\eta_{k,2} = -\tfrac{1}{2} \Sigma_k^{-1}$, but since we ultimately want the standard parameters anyway, let’s just leave it in this form.

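In the M-step of stochastic EM, we need to maximize the ELBO with respect to $\mu_k, \Sigma_k$. The ELBO is a weighted sum of log probabilities, which is of the form,

$$
\mathcal{L}(q, \theta) = \langle \bar{t}_{k,1}, \Sigma_k^{-1} \mu_k \rangle + \langle \bar{t}_{k,2}, -\tfrac{1}{2} \Sigma_k^{-1} \rangle - \bar{N}_k A(\mu_k, \Sigma_k)
$$

where

$$
\bar{t}_{k,1} = \frac{1}{N} \sum_{n=1}^N q_{n,k} \, x_n, \qquad
\bar{t}_{k,2} = \frac{1}{N} \sum_{n=1}^N q_{n,k} \, x_n x_n^\top, \qquad
\bar{N}_k = \frac{1}{N} \sum_{n=1}^N q_{n,k}
$$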

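Taking the derivative with respect to $\mu_k$,

$$
\nabla_{\mu_k} \mathcal{L}(q, \theta) = \Sigma_k^{-1} \bar{t}_{k,1} - \bar{N}_k \Sigma_k^{-1} \mu_k
$$

Setting to zero and solving for the mean yields,

$$
\mu_k = \bar{t}_{k,1} / \bar{N}_k.
$$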

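Substituting this into the ELBO yields,

$$
\mathcal{L}(q, \theta) = \tfrac{1}{2} \left\langle \bar{t}_{k,1} \bar{t}_{k,1}^\top / \bar{N}_k - \bar{t}_{k,2}, \; \Sigma_k^{-1} \right\rangle - \tfrac{1}{2} \bar{N}_k \log |\Sigma_k|
$$

Taking the derivative with respect to $\Sigma_k^{-1}$ yields,

$$
\nabla_{\Sigma_k^{-1}} \mathcal{L}(q, \theta) = \tfrac{1}{2} \left( \frac{\bar{t}_{k,1} \bar{t}_{k,1}^\top}{\bar{N}_k} - \bar{t}_{k,2} \right) + \tfrac{1}{2} \bar{N}_k \Sigma_k
$$

Then solving for $\Sigma_k$ gives,

$$
\Sigma_k = \frac{1}{\bar{N}_k} \left( \bar{t}_{k,2} - \frac{\bar{t}_{k,1} \bar{t}_{k,1}^\top}{\bar{N}_k} \right).
$$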
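The closed-form updates for the mean and covariance can be sketched as a function of the normalized statistics. The function name `gmm_m_step_from_stats` and the random test inputs are assumptions for illustration. In stochastic EM, the three arguments would be the running estimates rather than full-batch statistics.

```python
import numpy as np

def gmm_m_step_from_stats(t1, t2, Nk):
    """M-step from normalized statistics.

    t1: (K, D) weighted sums of x_n; t2: (K, D, D) weighted sums of x_n x_n^T;
    Nk: (K,) weighted counts. All three share the same normalization.
    """
    mus = t1 / Nk[:, None]                                   # mu_k = t1_k / N_k
    outer = np.einsum('ki,kj->kij', t1, t1) / Nk[:, None, None]
    Sigmas = (t2 - outer) / Nk[:, None, None]                # Sigma_k from above
    return mus, Sigmas

# Random inputs (assumed) to exercise the function.
rng = np.random.default_rng(0)
N, D, K = 200, 2, 3
x = rng.normal(size=(N, D))
q = rng.dirichlet(np.ones(K), size=N)                        # stand-in responsibilities

t1 = q.T @ x / N
t2 = np.einsum('nk,ni,nj->kij', q, x, x) / N
Nk = q.sum(axis=0) / N
mus, Sigmas = gmm_m_step_from_stats(t1, t2, Nk)
```

These updates agree with the responsibility-weighted mean and covariance from the batch EM algorithm, since $\frac{1}{T_k} \sum_t q_{t,k} (x_t - \mu_k)(x_t - \mu_k)^\top = \frac{1}{T_k} \sum_t q_{t,k} x_t x_t^\top - \mu_k \mu_k^\top$.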

Conclusion#

This chapter introduced unsupervised learning methods using latent variables.

  • We started with the Gaussian mixture model and recalled the coordinate ascent algorithm for MAP estimation, which is very similar to K-means.

  • Then we introduced the expectation-maximization algorithm and saw that for GMMs, it nicely parallels the MAP estimation algorithm. However, it yields parameters that maximize the marginal likelihood instead of the joint.

  • We derived EM as coordinate ascent on the evidence lower bound (ELBO). This view will generalize to more complex models in subsequent chapters.

  • Finally, we considered the more general class of exponential family mixture models and derived a stochastic EM algorithm that operates on running estimates of the expected sufficient statistics.

Next time, we’ll consider sequential latent variable models called hidden Markov models (HMMs).