Mixture Models and EM#
We have a lot of tools in our toolkit now. We learned about exponential family distributions, which form the building blocks of more complex models. We learned about Bayesian inference algorithms, like MCMC and VI, which allow us to infer the posterior distribution of parameters in complex models. In this section of the course, we’ll put those tools to use developing latent variable models for complex datasets. We’ll start with the simplest latent variable model — mixture models.
Motivation#
Why do we use mixture models? Sometimes it is reasonable to model our data as a superposition of distinct clusters. For example, when working with single cell RNA sequencing data, different cell types may give rise to different patterns of gene expression.
From Kiselev et al. [KAH19].
Or in a computer vision application, perhaps pixels can be assigned to foreground and background.
From https://ai.stanford.edu/~syyeung/cvweb/tutorial3.html
Sometimes we just want to model a complex, multimodal distribution. Kernel density estimates are essentially mixture models for this type of task!
From Scikit-learn KDE Demo.
This chapter is about models and algorithms for fitting mixture models like these.
Mixture Models#
Let,
\(N\) denote the number of data points
\(K\) denote the number of mixture components (i.e., clusters)
\(\mbx_n \in \reals^D\) denote the \(n\)-th data point
\(z_n \in \{1, \ldots, K\}\) be a latent variable denoting the cluster assignment of the \(n\)-th data point
The model is parameterized by,
\(\mbtheta_k\), the natural parameters of cluster \(k\).
\(\mbpi \in \Delta_{K-1}\), the cluster proportions (probabilities).
The generative model is as follows
Sample the assignment of each data point:
\[\begin{align*} z_n &\iid{\sim} \mathrm{Cat}(\mbpi) \quad \text{for } n = 1, \ldots, N \end{align*}\]
Sample data points given their assignments:
\[\begin{align*} \mbx_n &\sim p(\mbx \mid \mbtheta_{z_n}) \quad \text{for } n = 1, \ldots, N \end{align*} \]
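To make the generative process concrete, here is a minimal NumPy sketch (the function and variable names are ours, not from any course code), assuming identity-covariance Gaussian components as in the example below.

```python
import numpy as np

def sample_mixture(N, pi, thetas, seed=0):
    """Sample N points from a mixture of identity-covariance Gaussians.

    pi:     (K,) cluster proportions
    thetas: (K, D) cluster means
    """
    rng = np.random.default_rng(seed)
    K, D = thetas.shape
    z = rng.choice(K, size=N, p=pi)              # z_n ~ Cat(pi)
    x = thetas[z] + rng.standard_normal((N, D))  # x_n ~ N(theta_{z_n}, I)
    return x, z

# Example: three clusters in 2D
pi = np.array([0.5, 0.3, 0.2])
thetas = np.array([[0.0, 0.0], [5.0, 5.0], [-5.0, 5.0]])
X, Z = sample_mixture(500, pi, thetas)
```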
Joint distribution#
The joint distribution of the data and latent variables is,
\[\begin{align*} p(\{\mbx_n, z_n\}_{n=1}^N \mid \mbtheta, \mbpi) &= \prod_{n=1}^N \mathrm{Cat}(z_n \mid \mbpi) \, p(\mbx_n \mid \mbtheta_{z_n}) = \prod_{n=1}^N \prod_{k=1}^K \left[ \pi_k \, p(\mbx_n \mid \mbtheta_k) \right]^{\bbI[z_n = k]}. \end{align*}\]
Exponential family mixture models#
Assume an exponential family likelihood of the form,
\[\begin{align*} p(\mbx \mid \mbtheta_k) &= h(\mbx) \exp \left\{ \mbtheta_k^\top t(\mbx) - A(\mbtheta_k) \right\}, \end{align*}\]
where \(t(\mbx)\) are the sufficient statistics and \(A(\mbtheta_k)\) is the log normalizer.
Example: Gaussian Mixture Model (GMM)
Assume the conditional distribution of \(\mbx_n\) is a Gaussian with mean \(\mbtheta_k \in \reals^D\) and identity covariance:
\[\begin{align*} \mbx_n \mid z_n = k &\sim \mathrm{N}(\mbtheta_k, \mbI). \end{align*}\]
In exponential family form, the sufficient statistics of the Gaussian are \(t(\mbx) = \mbx\). Since we are assuming identity covariance for now, the natural parameters are simply \(\mbeta_k = \mbtheta_k\).
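To see why, we can write the identity-covariance Gaussian density in exponential family form,
\[\begin{align*} \mathrm{N}(\mbx \mid \mbtheta_k, \mbI) &= (2\pi)^{-D/2} \exp \left\{ -\tfrac{1}{2} \|\mbx - \mbtheta_k\|_2^2 \right\} \\ &= \underbrace{(2\pi)^{-D/2} e^{-\frac{1}{2} \mbx^\top \mbx}}_{h(\mbx)} \exp \Big\{ \mbtheta_k^\top \underbrace{\mbx}_{t(\mbx)} - \underbrace{\tfrac{1}{2} \mbtheta_k^\top \mbtheta_k}_{A(\mbtheta_k)} \Big\}. \end{align*}\]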
Two Inference Algorithms#
Let’s stick with the Gaussian mixture model for now. Suppose we observe data points \(\{\mbx_n\}_{n=1}^N\) and want to infer the assignments \(\{z_n\}_{n=1}^N\) and estimate the means \(\{\mbtheta_k\}_{k=1}^K\). Here are two intuitive algorithms.
MAP Inference and K-Means#
Suppose we knew the cluster assignments, \(z_n\). Then it would be straightforward to estimate the cluster means: we could simply use the maximum likelihood estimate, \(\hat{\mbtheta}_{\mathsf{MLE}} = \frac{1}{N_k} \sum_{n: z_n=k} \mbx_n\), where \(N_k = \sum_n \mathbb{I}[z_n =k]\) is the number of data points in cluster \(k\).
Likewise, if we knew the cluster means, it would be straightforward to compute the maximum a posteriori (MAP) estimate of \(z_n\): we would assign each data point to the nearest cluster. If we alternate these two steps, we obtain the K-Means algorithm:
Algorithm 4 (K-Means)
Repeat until convergence:
For each \(n=1,\ldots, N\), fix the means \(\mbtheta\) and set,
\[\begin{align*} z_n &= \hat{z}_{n,\mathsf{MAP}} \\ &= \arg \max_{k \in \{1,\ldots,K\}} \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI) \\ &= \arg \min_{k \in \{1,\ldots, K\}} \|\mbx_n - \mbtheta_k\|_2 \end{align*}\]
For each \(k=1,\ldots,K\), fix all assignments \(\mbz\) and set,
\[\begin{align*} \mbtheta_k &= \hat{\mbtheta}_{\mathsf{MLE}} \\ &= \frac{1}{N_k} \sum_{n=1}^N \bbI[z_n=k] \mbx_n \end{align*}\]
Question
What does this algorithm implicitly assume about the cluster probabilities \(\mbpi\)? How would you modify this algorithm to incorporate and estimate \(\mbpi\)?
Connection between K-Means and MAP estimation
Note that if we put an improper uniform prior on \(\mbtheta_k\), we could think of this entire algorithm as MAP estimation of \(\mbtheta\) and \(\mbz\) via coordinate ascent!
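Here is a minimal NumPy sketch of Algorithm 4; the function name and the initialization scheme (random data points as the initial means) are our choices, not prescribed by the algorithm.

```python
import numpy as np

def kmeans(X, K, n_iters=100, seed=0):
    """K-Means: alternate MAP assignments and MLE means (a sketch of Algorithm 4)."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    thetas = X[rng.choice(N, size=K, replace=False)].copy()  # init means at random data points
    for _ in range(n_iters):
        # Assignment step: z_n = argmin_k ||x_n - theta_k||_2
        dists = np.linalg.norm(X[:, None, :] - thetas[None, :, :], axis=-1)  # (N, K)
        z = dists.argmin(axis=1)
        # Update step: theta_k = mean of the points currently assigned to cluster k
        for k in range(K):
            if np.any(z == k):
                thetas[k] = X[z == k].mean(axis=0)
    return z, thetas
```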
Maximum Likelihood Estimation via EM#
K-Means made hard assignments of data points to clusters in each iteration. That sounds a little extreme — do you really want to attribute a datapoint to a single class when it is right in the middle of two clusters? What if we used soft assignments instead?
Algorithm 5 (EM for a GMM)
Repeat until convergence:
For each data point \(n\) and component \(k\), compute the responsibility:
\[\begin{align*} \omega_{nk} = \frac{\pi_k \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI)}{\sum_{j=1}^K \pi_j \mathrm{N}(\mbx_n \mid \mbtheta_j, \mbI)} \end{align*}\]
For each component \(k\), update the mean:
\[\begin{align*} \mbtheta_k^\star &= \frac{1}{N_k} \sum_{n=1}^N \omega_{nk} \mbx_n \end{align*}\]
where \(N_k = \sum_{n=1}^N \omega_{nk}\) is the expected number of data points assigned to cluster \(k\).
This is the Expectation-Maximization (EM) algorithm. As we will show, EM monotonically increases the marginal likelihood of the data, converging to a (local) maximum.
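Here is a minimal NumPy/SciPy sketch of Algorithm 5, holding the proportions \(\mbpi\) fixed and uniform (the names and initialization are ours, not the course reference implementation).

```python
import numpy as np
from scipy.special import logsumexp

def em_gmm(X, K, n_iters=100, seed=0):
    """EM for a GMM with identity covariance and fixed, uniform mixing weights."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    log_pi = np.full(K, -np.log(K))
    thetas = X[rng.choice(N, size=K, replace=False)].copy()
    for _ in range(n_iters):
        # E-step: responsibilities omega_{nk} proportional to pi_k N(x_n | theta_k, I)
        sq_dists = ((X[:, None, :] - thetas[None, :, :]) ** 2).sum(axis=-1)  # (N, K)
        log_omega = log_pi - 0.5 * sq_dists
        log_omega -= logsumexp(log_omega, axis=1, keepdims=True)
        omega = np.exp(log_omega)
        # M-step: theta_k = (sum_n omega_{nk} x_n) / (sum_n omega_{nk})
        Nk = omega.sum(axis=0)
        thetas = (omega.T @ X) / Nk[:, None]
    return omega, thetas
```

Compared to the K-Means sketch above, the hard argmin over components is replaced by a softmax, and the means are weighted averages rather than per-cluster sample means.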
Theoretical Motivation#
Rather than maximizing the joint probability, EM is maximizing the marginal probability,
\[\begin{align*} p(\{\mbx_n\}_{n=1}^N; \mbtheta, \mbpi) &= \prod_{n=1}^N \sum_{k=1}^K \pi_k \, p(\mbx_n \mid \mbtheta_k), \end{align*}\]
in which the latent assignments have been summed out.
For discrete mixtures (with small enough \(K\)) we can evaluate the log marginal probability. We can usually evaluate its gradient too, so we could just do gradient ascent to find \(\mbtheta^*\). However, EM updates are available in closed form, require no step size tuning, and often converge faster in practice.
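For the identity-covariance GMM, one way to evaluate the log marginal probability is with a log-sum-exp over the \(K\) components; a sketch (names ours):

```python
import numpy as np
from scipy.special import logsumexp

def gmm_log_marginal(X, pi, thetas):
    """log p(X; theta, pi) = sum_n log sum_k pi_k N(x_n | theta_k, I)."""
    D = X.shape[1]
    sq_dists = ((X[:, None, :] - thetas[None, :, :]) ** 2).sum(axis=-1)    # (N, K)
    log_joint = np.log(pi) - 0.5 * D * np.log(2 * np.pi) - 0.5 * sq_dists  # log pi_k + log N(x_n | theta_k, I)
    return logsumexp(log_joint, axis=1).sum()
```

Running gradient ascent on this quantity (e.g., with an autodiff library) is the alternative alluded to above.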
Evidence Lower Bound (ELBO)#
The key idea is to obtain a lower bound on the marginal probability. First, rewrite the log marginal as an expectation with respect to an arbitrary distribution over the latent variable,
\[\begin{align*} \log p(\mbx_n; \mbtheta) &= \log \sum_{z_n=1}^K p(\mbx_n, z_n; \mbtheta) = \log \sum_{z_n=1}^K q(z_n) \, \frac{p(\mbx_n, z_n; \mbtheta)}{q(z_n)}, \end{align*}\]
where \(q(z_n)\) is any distribution on \(z_n \in \{1,\ldots,K\}\) such that \(q(z_n)\) is absolutely continuous w.r.t. \(p(\mbx_n, z_n; \mbtheta)\).
Jensen’s Inequality
Jensen’s inequality states that,
\[\begin{align*} f(\mathbb{E}[X]) &\geq \mathbb{E}[f(X)], \end{align*}\]
if \(f\) is a concave function, with equality iff \(f\) is linear.
Applied to the log marginal probability, Jensen’s inequality yields,
\[\begin{align*} \log p(\{\mbx_n\}_{n=1}^N; \mbtheta) &\geq \sum_{n=1}^N \sum_{z_n=1}^K q_n(z_n) \log \frac{p(\mbx_n, z_n; \mbtheta)}{q_n(z_n)} \triangleq \mathcal{L}(\mbq, \mbtheta), \end{align*}\]
where \(\mbq = (q_1, \ldots, q_N)\) is a tuple of distributions, one for each latent variable \(z_n\).
This is called the evidence lower bound, or ELBO for short. It is a function of \(\mbtheta\) and a functional of \(\mbq\), since each \(q_n\) is a probability mass function (or, more generally, a density). We can think of EM as coordinate ascent on the ELBO, alternating between updating the parameters \(\mbtheta\) and the distributions \(\mbq\).
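To make the ELBO concrete for the identity-covariance GMM, here is a sketch that evaluates it given responsibilities \(\omega_{nk}\) standing in for \(q_n\) (names ours):

```python
import numpy as np

def gmm_elbo(X, omega, pi, thetas):
    """ELBO L(q, theta) = sum_n E_{q_n}[log p(x_n, z_n; theta) - log q_n(z_n)]."""
    D = X.shape[1]
    sq_dists = ((X[:, None, :] - thetas[None, :, :]) ** 2).sum(axis=-1)
    log_joint = np.log(pi) - 0.5 * D * np.log(2 * np.pi) - 0.5 * sq_dists  # log p(x_n, z_n = k; theta)
    return np.sum(omega * (log_joint - np.log(np.clip(omega, 1e-300, None))))
```

By Jensen, this quantity never exceeds the log marginal likelihood computed in the earlier sketch.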
M-step: Gaussian case#
Suppose we fix \(\mbq\). Since each \(z_n\) is a discrete latent variable, \(q_n\) must be a probability mass function. Let it be denoted by,
\[\begin{align*} q_n(z_n = k) &= \omega_{nk} \quad \text{for } k = 1, \ldots, K. \end{align*}\]
(These will be the responsibilities from before.)
Now, recall our basic model, \(\mbx_n \sim \mathrm{N}(\mbtheta_{z_n}, \mbI)\). Holding \(\mbq\) fixed, the ELBO as a function of \(\mbtheta\) is,
\[\begin{align*} \mathcal{L}(\mbtheta) &= \sum_{n=1}^N \sum_{k=1}^K \omega_{nk} \left[ \log \pi_k + \log \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI) \right] + \text{const}. \end{align*}\]
Zooming in on just \(\mbtheta_k\),
\[\begin{align*} \mathcal{L}(\mbtheta_k) &= \sum_{n=1}^N \omega_{nk} \log \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI) + \text{const} \\ &= -\frac{1}{2} \sum_{n=1}^N \omega_{nk} \|\mbx_n - \mbtheta_k\|_2^2 + \text{const} \\ &= N_k \left( \mbtheta_k^\top \bar{\mbx}_k - \tfrac{1}{2} \mbtheta_k^\top \mbtheta_k \right) + \text{const}, \end{align*}\]
where
\[\begin{align*} N_k &= \sum_{n=1}^N \omega_{nk}, & \bar{\mbx}_k &= \frac{1}{N_k} \sum_{n=1}^N \omega_{nk} \mbx_n. \end{align*}\]
Taking derivatives and setting to zero yields,
\[\begin{align*} \nabla_{\mbtheta_k} \mathcal{L} = N_k \left( \bar{\mbx}_k - \mbtheta_k \right) = 0 \quad \Longrightarrow \quad \mbtheta_k^\star = \frac{1}{N_k} \sum_{n=1}^N \omega_{nk} \mbx_n. \end{align*}\]
These are the same as the EM updates shown above!
E-step: Gaussian case#
As a function of \(q_n\), for discrete Gaussian mixtures with identity covariance,
\[\begin{align*} \mathcal{L}(\mbq, \mbtheta) &= \sum_{k=1}^K \omega_{nk} \left[ \log \pi_k + \log \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI) - \log \omega_{nk} \right] + \text{const}, \end{align*}\]
where \(\mbpi = [\pi_1, \ldots, \pi_K]^\top\) is the vector of prior cluster probabilities.
We also have two constraints: \(\omega_{nk} \geq 0\) and \(\sum_k \omega_{nk} = 1\). Let’s ignore the non-negativity constraint for now (it will automatically be satisfied anyway) and write the Lagrangian with the simplex constraint,
\[\begin{align*} J(\omega_{n1}, \ldots, \omega_{nK}, \lambda) &= \sum_{k=1}^K \omega_{nk} \left[ \log \pi_k + \log \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI) - \log \omega_{nk} \right] + \lambda \left( 1 - \sum_{k=1}^K \omega_{nk} \right). \end{align*}\]
Taking the partial derivative wrt \(\omega_{nk}\) and setting to zero yields,
\[\begin{align*} \frac{\partial J}{\partial \omega_{nk}} &= \log \pi_k + \log \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI) - \log \omega_{nk} - 1 - \lambda = 0 \quad \Longrightarrow \quad \omega_{nk} \propto \pi_k \, \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI). \end{align*}\]
Enforcing the simplex constraint yields,
\[\begin{align*} \omega_{nk} &= \frac{\pi_k \, \mathrm{N}(\mbx_n \mid \mbtheta_k, \mbI)}{\sum_{j=1}^K \pi_j \, \mathrm{N}(\mbx_n \mid \mbtheta_j, \mbI)}, \end{align*}\]
just like above.
Note that
\[\begin{align*} \omega_{nk} &= \frac{\pi_k \, p(\mbx_n \mid \mbtheta_k)}{\sum_{j=1}^K \pi_j \, p(\mbx_n \mid \mbtheta_j)} = p(z_n = k \mid \mbx_n, \mbtheta, \mbpi). \end{align*}\]
That is, the responsibilities equal the posterior probabilities!
The ELBO is tight after the E-step#
Equivalently, \(q_n\) equals the posterior, \(p(z_n \mid \mbx_n, \mbtheta)\). At that point, the ELBO simplifies to,
\[\begin{align*} \mathcal{L}(\mbq, \mbtheta) &= \sum_{n=1}^N \sum_{z_n=1}^K p(z_n \mid \mbx_n, \mbtheta) \log \frac{p(\mbx_n, z_n; \mbtheta)}{p(z_n \mid \mbx_n, \mbtheta)} = \sum_{n=1}^N \log p(\mbx_n; \mbtheta). \end{align*}\]
That is, the bound is tight: the ELBO equals the log marginal likelihood.
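As a quick, self-contained numerical check (the toy data and parameters below are arbitrary), setting each \(q_n\) to the posterior makes the ELBO equal the log marginal likelihood:

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 2)) + np.array([3.0, 0.0])              # toy data
pi, thetas = np.array([0.4, 0.6]), np.array([[0.0, 0.0], [3.0, 0.0]])

# log p(x_n, z_n = k; theta) for the identity-covariance GMM
D = X.shape[1]
sq_dists = ((X[:, None, :] - thetas[None, :, :]) ** 2).sum(axis=-1)
log_joint = np.log(pi) - 0.5 * D * np.log(2 * np.pi) - 0.5 * sq_dists

log_marginal = logsumexp(log_joint, axis=1).sum()

# E-step: set q_n to the posterior; the ELBO is then tight
omega = np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))
elbo = np.sum(omega * (log_joint - np.log(omega)))
print(np.isclose(elbo, log_marginal))  # True
```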
EM as a minorize-maximize (MM) algorithm
Note that the ELBO is tight after the E-step!
We can view the EM algorithm as a minorize-maximize (MM) algorithm: the E-step constructs a lower bound on the log marginal likelihood (the ELBO, which is tight at the current parameters), and the M-step maximizes that bound.
M-step: General Case#
Now let’s consider the general Bayesian mixture with exponential family likelihoods.
While we’re at it, let’s also add conjugate priors \(p(\mbtheta) = \prod_k p(\mbtheta_k)\) with pseudo-counts \(\nu\) and pseudo-observations \(\mbchi\), so that \(\log p(\mbtheta_k) = \mbchi^\top \mbtheta_k - \nu A(\mbtheta_k) + \text{const}\). As a function of \(\mbtheta\),
\[\begin{align*} \mathcal{L}(\mbtheta) &= \sum_{k=1}^K \log p(\mbtheta_k) + \sum_{n=1}^N \sum_{k=1}^K \omega_{nk} \left[ \mbtheta_k^\top t(\mbx_n) - A(\mbtheta_k) \right] + \text{const}. \end{align*}\]
Zooming in on just \(\mbtheta_k\),
\[\begin{align*} \mathcal{L}(\mbtheta_k) &= \mbchi^\top \mbtheta_k - \nu A(\mbtheta_k) + \sum_{n=1}^N \omega_{nk} \left[ \mbtheta_k^\top t(\mbx_n) - A(\mbtheta_k) \right] + \text{const} \\ &= (\mbchi_k')^\top \mbtheta_k - \nu_k' A(\mbtheta_k) + \text{const}, \end{align*}\]
where
\[\begin{align*} \mbchi_k' &= \mbchi + \sum_{n=1}^N \omega_{nk} \, t(\mbx_n), & \nu_k' &= \nu + \sum_{n=1}^N \omega_{nk}. \end{align*}\]
Taking derivatives and setting to zero yields,
\[\begin{align*} \nabla_{\mbtheta_k} \mathcal{L} &= \mbchi_k' - \nu_k' \nabla A(\mbtheta_k) = 0 \quad \Longrightarrow \quad \nabla A(\mbtheta_k^*) = \frac{\mbchi_k'}{\nu_k'}. \end{align*}\]
Recall that \((\nabla A)^{-1}: \cM \to \Omega\) is a mapping from mean parameters to natural parameters (and the inverse exists for minimal exponential families). Thus, the generic M-step above amounts to finding the natural parameters \(\mbtheta_k^*\) that yield the expected sufficient statistics \(\mbchi_{k}' / \nu_k'\) by inverting the gradient mapping.
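For example, for the identity-covariance Gaussian components above, \(t(\mbx) = \mbx\) and \(A(\mbtheta_k) = \tfrac{1}{2} \mbtheta_k^\top \mbtheta_k\), so \(\nabla A(\mbtheta_k) = \mbtheta_k\) and the M-step reduces to
\[\begin{align*} \mbtheta_k^* &= \frac{\mbchi_k'}{\nu_k'} = \frac{\mbchi + \sum_{n=1}^N \omega_{nk} \, \mbx_n}{\nu + \sum_{n=1}^N \omega_{nk}}, \end{align*}\]
which recovers the weighted-mean update from the Gaussian M-step when the prior is uninformative (\(\nu \to 0\), \(\mbchi \to \mathbf{0}\)).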
E-step: General Case#
In our first pass, we assumed \(q_n\) was a finite pmf. More generally, \(q_n\) will be a probability density function, and optimizing over functions usually requires the calculus of variations. (Ugh!)
However, note that we can write the ELBO in a slightly different form,
\[\begin{align*} \mathcal{L}(\mbq, \mbtheta) &= \log p(\mbtheta) + \sum_{n=1}^N \left[ \log p(\mbx_n \mid \mbtheta) - \KL{q_n(z_n)}{p(z_n \mid \mbx_n, \mbtheta)} \right], \end{align*}\]
where \(\KL{\cdot}{\cdot}\) denotes the Kullback-Leibler divergence. (Note that we included the prior on \(\mbtheta\) since we are treating it as a random variable with a prior in this general case.)
Recall, the KL divergence is defined as,
\[\begin{align*} \KL{q(z)}{p(z)} &= \int q(z) \log \frac{q(z)}{p(z)} \, \mathrm{d}z = \mathbb{E}_{q(z)}\left[ \log \frac{q(z)}{p(z)} \right]. \end{align*}\]
It gives a notion of how similar two distributions are, but it is not a metric! (It is not symmetric.) Still, it has some intuitive properties:
It is non-negative, \(\KL{q(z)}{p(z)} \geq 0\).
It equals zero iff the distributions are the same, \(\KL{q(z)}{p(z)} = 0 \iff q(z) = p(z)\) almost everywhere.
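To see the asymmetry concretely, compare the two directions for a pair of Bernoulli distributions (in nats):
\[\begin{align*} \KL{\mathrm{Bern}(0.5)}{\mathrm{Bern}(0.9)} &= 0.5 \log \tfrac{0.5}{0.9} + 0.5 \log \tfrac{0.5}{0.1} \approx 0.51, \\ \KL{\mathrm{Bern}(0.9)}{\mathrm{Bern}(0.5)} &= 0.9 \log \tfrac{0.9}{0.5} + 0.1 \log \tfrac{0.1}{0.5} \approx 0.37. \end{align*}\]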
Maximizing the ELBO wrt \(q_n\) amounts to minimizing the KL divergence to the posterior \(p(z_n \mid \mbx_n, \mbtheta)\),
\[\begin{align*} q_n^\star &= \arg \min_{q_n} \KL{q_n(z_n)}{p(z_n \mid \mbx_n, \mbtheta)}. \end{align*}\]
As we said, the KL is minimized when \(q_n(z_n) = p(z_n \mid \mbx_n, \mbtheta)\), so the optimal update is,
\[\begin{align*} q_n(z_n) &\leftarrow p(z_n \mid \mbx_n, \mbtheta), \end{align*}\]
just like we found above.
Conclusion#
Mixture models are basic building blocks of statistics, and our first encounter with discrete latent variable models (LVMs). (Where have we seen continuous LVMs so far?) Mixture models have widespread uses in both density estimation (e.g., kernel density estimators) and data science (e.g., clustering). Next, we’ll talk about how to extend mixture models to cases where the cluster assignments are correlated in time.