HW4: Bayesian Mixture Models#


Name:

Names of any collaborators:


Background#

In this homework assignment we will investigate image segmentation, specifically separating the background of an image from its foreground. To do so, you’ll fit Bayesian mixtures of Gaussians using the expectation-maximization (EM) algorithm.

The figure below shows the original input image and the resulting segmentations into background and foreground. By the end of this assignment, you will have implemented the algorithm to achieve this segmentation.

Reference on image segmentation: https://en.wikipedia.org/wiki/Image_segmentation

Model#

We will use a simple mixture model to cluster the pixels (with the number of clusters \(K = 2\) in our image segmentation problem). The likelihood is a mixture of Gaussian distributions.

\[\begin{split} \begin{align*} x_n \mid z_n, \{\mu_k, \Sigma_k\}_{k=1}^K &\sim \mathcal{N}(\mu_{z_n}, \Sigma_{z_n}) \\ z_n \mid \pi &\sim \text{Categorical}(\pi) \end{align*} \end{split}\]

where \(x_n \in \mathbb{R}^D\) is distributed according to a Gaussian distribution with mean \(\mu_k\) and covariance \(\Sigma_k\) for its corresponding cluster \(z_n = k\), and \(z_n\) is distributed as a categorical with parameter \(\pi\). We will represent each image as a set of \(N\) pixels, \(\{x_n\}_{n=1}^N\), each in \(D=3\) dimensional space, since there are three color channels (red, green, and blue).
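Equivalently, marginalizing out the cluster assignment \(z_n\) gives the mixture density for each pixel,

\[ p(x_n \mid \pi, \{\mu_k, \Sigma_k\}_{k=1}^K) = \sum_{k=1}^K \pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k) \]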

We specify the following priors on \(\mu_k\), \(\Sigma_k\), and \(\pi\).

  • Assume a normal-inverse-Wishart (NIW) prior for each cluster mean and covariance:

\[ \begin{align*} p(\mu_k, \Sigma_k) &= \mathrm{IW}(\Sigma_k \mid \Sigma_0, \nu_0) \, \mathcal{N}(\mu_k \mid \mu_0, \kappa_0^{-1} \Sigma_k) \end{align*} \]

Here \(\Sigma_0, \nu_0, \mu_0, \kappa_0\) are hyperparameters.

  • We place a symmetric Dirichlet prior on the mixing proportions, \(\pi\):

\[ p(\pi \mid \alpha) = \text{Dirichlet}(\pi \mid \alpha 1_K) \]

where \(1_K\) is an all-ones vector of length \(K\) and \(\alpha\) is a hyperparameter.
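As a concrete (optional) illustration of the generative process, the sketch below samples synthetic data from this model in PyTorch. For brevity it fixes the cluster covariances to scaled identity matrices rather than drawing them from the inverse-Wishart prior, and the particular values of \(K\), \(N\), and the hyperparameters are arbitrary.

import torch
from torch.distributions import Dirichlet, Categorical, MultivariateNormal

# Illustrative sketch only: sample synthetic "pixels" from the model with
# K = 2 clusters in D = 3 dimensions. The covariances are fixed to scaled
# identity matrices instead of being drawn from the inverse-Wishart prior.
torch.manual_seed(0)
K, D, N = 2, 3, 500
alpha, kappa0 = 1.0, 1.0
mu0 = torch.zeros(D)

pi = Dirichlet(alpha * torch.ones(K)).sample()           # mixing weights
Sigmas = 0.01 * torch.eye(D).repeat(K, 1, 1)             # fixed covariances
mus = MultivariateNormal(mu0, Sigmas / kappa0).sample()  # mu_k | Sigma_k
z = Categorical(pi).sample((N,))                         # cluster labels
X = MultivariateNormal(mus[z], Sigmas[z]).sample()       # pixels, (N, D)
print(X.shape)  # torch.Size([500, 3])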

Problem 1 [math]: EM calculations#

In this problem, you will derive the EM procedure for our Bayesian model. For notational simplicity, let

\[ \theta = (\{\mu_k, \Sigma_k\}_{k=1}^K, \pi) \]

be the tuple of parameters we wish to estimate via EM. Let \(\theta^{(i)}\) be the parameter value at iteration \(i\). Recall the EM procedure is given by two steps:

  • Expectation step (E-step): Compute

\[ \begin{align*} q_n(z_n) &= p(z_n \mid x_n, \theta^{(i)}) \end{align*} \]
  • Maximization step (M-step): Find new parameters

\[ \begin{align*} \theta^{(i+1)} = \underset{\theta}{\operatorname{argmax}} \mathbb{E}_{q} [\log p(\mathbf{X}, \mathbf{Z}, \theta)] \end{align*} \]

Your implementation in Problem 2 will only be correct if these derivations are, so we highly recommend taking the time to double-check them.

Problem 1a: Derive the posterior distribution for \(q_n(z_n) = p(z_n | x_n, \theta)\).#


Your answer here


Problem 1b: Derive the expected log probability#

Show that

\[\begin{split} \begin{align*} \mathbb{E}_q\left[ \log p(X, Z, \theta) \right] &= \underbrace{\sum\limits_{k=1}^K \left[ \sum\limits_{n=1}^N \left[ \omega_{nk} \log \mathcal{N}(x_n \mid \mu_{k}, \Sigma_k) \right] + \log p(\mu_k, \Sigma_k) \right]}_{\mathcal{L}_1(\mu, \Sigma)} \\&\qquad + \underbrace{\sum\limits_{k=1}^K \left[ \sum_{n=1}^N \left[\omega_{nk} \log \pi_k \right] + (\alpha_k-1) \log \pi_k \right]}_{\mathcal{L}_2(\pi)} + C \end{align*} \end{split}\]

for some constant \(C\), where \(\omega_{nk} = q_n(z_n =k)\), and where \(\mathcal{L}_1, \mathcal{L}_2\) represent the terms in the expected log probability that depend on \(\{\mu_k, \Sigma_k\}_{k=1}^K\) and \(\pi\), respectively.


Your answer here


Problem 1c: Expand \(\mathcal{L}_1\) in exponential family form.#

Show that \(\log p(x_n\mid z_n=k, \mu_k, \Sigma_k)\) and \(\log p(\mu_k, \Sigma_k)\) can be represented as the following:

\[\begin{split} \begin{align*} \log p(x_n\mid z_n=k, \mu_k, \Sigma_k) &= t(x_n)^\top \eta_k - A(\eta_k) + c \\ \log p(\mu_k, \Sigma_k) &= \phi^\top \eta_k - \nu A(\eta_k) + c' \end{align*} \end{split}\]

for some constants \(c\), \(c'\), functions \(t\), \(A\) (explicitly find these), and hyperparameters \(\phi\), \(\nu\) (explicitly find these), where

\[\begin{split} \begin{align*} \eta_k &:= \left(-\frac{1}{2}\log|\Sigma_k|, -\frac{1}{2}\Sigma_k^{-1}, \Sigma_k^{-1} \mu_k, -\frac{1}{2} \mu_k^\top \Sigma_k^{-1} \mu_k \right) \\ \end{align*} \end{split}\]

Here, the inner product between elements \(a, b\) of the same form as \(\eta_k\) is defined to be

\[ \langle a, b \rangle := a_1 b_1 + \mathrm{Tr}(a_2 b_2) + a_3^\top b_3 + a_4 b_4 \]

Deduce that \(\mathcal{L}_1\) can be written as

\[\begin{split} \begin{align*} \mathcal{L}_1(\mu, \Sigma) &= \sum\limits_{k=1}^K \left[ \sum\limits_{n=1}^N \left[ \omega_{nk} (t(x_n)^\top \eta_k - A(\eta_k)) \right] + \phi^\top \eta_k - \nu A(\eta_k) \right] + c \\ &= \sum\limits_{k=1}^K \left[ \phi_{k}^\top \eta_k - \nu_{k} A(\eta_k) \right] + c \end{align*} \end{split}\]

with

\[\begin{split} \begin{align*} \phi_{k} &= \phi + \sum\limits_{n=1}^N \omega_{n,k} t(x_n) \\ \nu_{k} &= \nu + \sum\limits_{n=1}^N \omega_{n,k} \\ \omega_{n,k} &= q_n(z_n=k) \end{align*} \end{split}\]

Conclude that each summand of \(\mathcal{L}_1\) is the log-pdf (up to a constant) of some Normal-Inverse-Wishart (NIW) distribution of \((\mu_k, \Sigma_k)\).

Problem 1d: Maximize \(\mathcal{L}_1\).#

Find the mode of an NIW distribution for \((\mu, \Sigma)\) with parameters \((\Sigma_0, \nu_0, \kappa_0, \mu_0)\). Use this result and (c) to find the closed-form solution for maximizing \(\mathcal{L}_1\) w.r.t. \(\mu_k, \Sigma_k\).


Your answer here


Problem 1e: Maximize \(\mathcal{L}_2\).#

Find the maximizing solution \(\pi^*\) of \(\mathcal{L}_2\).


Your answer here


Problem 2 [code]: Implement EM for the Gaussian mixture model#

We have provided starter code below. First, you need to fill it in with your own implementation of the EM algorithm. This entails writing three functions:

  1. log_probability, which computes the log probability \(\log p(X, \theta)\)

  2. e_step, which computes the posteriors \(q_n(z_n)\) for each data point, fixing the current parameters.

  3. m_step, which returns new parameters, fixing the current posteriors.

Then, you will test your code on a simple example, using the code we have provided.

You may not rely on external implementations such as those offered by TensorFlow or scikit-learn.

Setup#

import torch
from torch.distributions import MultivariateNormal, Categorical, Dirichlet

from tqdm.auto import trange
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms

Helpers#

We have provided a helper function to compute the inverse Wishart log probability since this is not one of the standard distributions in torch.distributions.

def invwishart_log_prob(Sigma, nu0, Sigma0):
    """
    Helper function to compute the inverse Wishart log probability, since it's
    not provided in torch.distributions.

    Args:

    Sigma:      (..., D, D) batch of covariance matrices
    nu0:        scalar degree of freedom of inverse Wishart distribution
    Sigma0:     (D, D) scale matrix for inverse Wishart distribution

    Returns:

    lp:         (...,) a batch of log probabilities
    """
    D = Sigma.shape[-1]
    assert Sigma.shape[-2:] == (D, D)
    assert Sigma0.shape[-2:] == (D, D)
    nu0 = torch.tensor(nu0)

    lp = -(nu0 + D + 1) / 2 * torch.logdet(Sigma)
    lp -= torch.linalg.solve(Sigma, Sigma0)\
        .diagonal(dim1=-1, dim2=-2).sum(axis=-1) / 2

    # log normalizing constant
    lp += nu0 / 2 * torch.logdet(Sigma0)
    lp -= nu0 * D / 2 * torch.log(torch.tensor(2.0))
    lp -= torch.special.multigammaln(nu0 / 2, D)
    return lp
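As an optional sanity check of this helper (not part of your implementation, and assuming scipy happens to be installed), you can compare it against scipy.stats.invwishart on a random positive-definite matrix:

# Optional sanity check of the helper above. Assumes scipy is available;
# it is not required for the assignment itself.
from scipy.stats import invwishart

D, nu0 = 3, 5.0
A = torch.randn(D, D)
Sigma = A @ A.T + D * torch.eye(D)   # a positive definite test matrix
Sigma0 = torch.eye(D)

lp_torch = invwishart_log_prob(Sigma, nu0, Sigma0)
lp_scipy = invwishart.logpdf(Sigma.numpy(), df=nu0, scale=Sigma0.numpy())
print(lp_torch.item(), lp_scipy)     # should agree up to round-off error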

Problem 2a: Implement the log_probability function.#

def log_probability(X, mus, Sigmas, pi,
                    alpha, mu0, kappa0, nu0, Sigma0):
    """
    Compute the log probability \log p(X, \theta), summing over the discrete
    cluster assignments.

    Hint: You may use the invwishart_log_prob function above.
    Hint: You may also want to use torch.logsumexp to do the sum over z.

    Args:
    - X:        (N, D) tensor of data points
    - mus:      (K, D) tensor of cluster means
    - Sigmas:   (K, D, D) tensor of cluster covariances
    - pi:       (K,) tensor of cluster weights
    - alpha:    (K,) concentration of the Dirichlet prior
    - mu0:      (D,) tensor with the prior mean
    - kappa0:   scalar prior precision
    - nu0:      scalar prior degrees of freedom
    - Sigma0:   (D, D) tensor of prior scale of the covariance

    Returns:
    - lp:       scalar log probability of the data and parameters, summing over
                the discrete latent variables
    """
    lp = 0

    ###
    # Your code here.
    ##
    return lp
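One more hint (an illustration, not a solution): torch.distributions broadcasts batch dimensions, so a MultivariateNormal built from the (K, D) means and (K, D, D) covariances can evaluate all \(N \times K\) Gaussian log densities in a single call. The _demo tensors below are placeholders.

# Illustration of batched Gaussian log densities (not a full solution).
# A MultivariateNormal built from K means and K covariances has batch
# shape (K,), so evaluating its log_prob at X[:, None, :] broadcasts to
# an (N, K) matrix of log N(x_n | mu_k, Sigma_k) values.
# (MultivariateNormal was imported in the Setup cell above.)
N_demo, K_demo, D_demo = 5, 2, 3
X_demo = torch.randn(N_demo, D_demo)
mus_demo = torch.randn(K_demo, D_demo)
Sigmas_demo = torch.eye(D_demo).repeat(K_demo, 1, 1)
mvn_demo = MultivariateNormal(mus_demo, Sigmas_demo)
log_dens_demo = mvn_demo.log_prob(X_demo[:, None, :])
print(log_dens_demo.shape)  # torch.Size([5, 2])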

Problem 2b: Implement the e_step function#

def e_step(X, mus, Sigmas, pi):
    """
    Perform one E step to compute the posterior 

        q_n(z_n) = p(z_n | x_n, \theta)

    for each data point. 

    Args:
    - X:        (N, D) tensor of data points
    - mus:      (K, D) tensor of cluster means
    - Sigmas:   (K, D, D) tensor of cluster covariances
    - pi:       (K,) tensor of cluster weights

    Returns:
    - Q:        (N, K) tensor of responsibilities; i.e. posterior probabilities. 
                Each row should be non-negative and sum to one
    """
    N, D = X.shape
    K, _ = mus.shape
    q = torch.zeros((N, K))

    ###
    # Your code here.
    ##
    return q
    

Problem 2c: Implement the m_step function#

def m_step(X, q, alpha, mu0, kappa0, nu0, Sigma0):
    """
    Perform one M-step to find new parameters given the current posterior
    and hyperparameters.

    Args:
    - X:        (N, D) data matrix
    - q:        (N, K) responsibilities; i.e. posterior probabilities
    - alpha:    (K,) concentration of the Dirichlet prior
    - mu0:      (D,) tensor with the prior mean
    - kappa0:   scalar prior precision
    - nu0:      scalar prior degrees of freedom
    - Sigma0:   (D, D) tensor of prior scale of the covariance

    Returns:
    - mus:      (K, D) new means for each cluster
    - Sigmas:   (K, D, D) new covariances for each cluster
    - pi:       (K,) new cluster probabilities
    """
    N, D = X.shape
    _, K = q.shape

    ### 
    # Your code here.
    ##
    return mus, Sigmas, pi

EM function [given]#

We’ve provided an em function to run EM on a given dataset with the specified hyperparameters.

def em(X, 
       K=2, 
       n_iter=100, 
       alpha=torch.ones(2),
       mu0=torch.zeros(3),
       kappa0=1.0,
       nu0=4.0,
       Sigma0=torch.eye(3)):
    """
    EM algorithm.

    Args:
    - X: Matrix of size (N, D). Each row of X stores one data point
    - K: the desired number of clusters in the model. Default: 2
    - n_iter: number of iterations of EM. Default: 100
    - alpha: (K,) prior concentration of cluster probabilities
    - mu0, kappa0, nu0, Sigma0: parameters of normal-inverse-Wishart prior.
        Their shapes must be consistent with D, the data dimension.
        
    Returns:
    - lps: log probabilities of the data and parameters over EM iterations
    - mus: cluster means
    - Sigmas: cluster covariances
    - pi: cluster assignment probabilities
    - q: posterior probability of Z | X, mus, Sigmas, pi with final params.
    """
    N, D = X.shape
    assert alpha.shape == (K,)
    assert mu0.shape == (D,)
    assert Sigma0.shape == (D, D)
    hypers = (alpha, mu0, kappa0, nu0, Sigma0)

    # Initialize cluster parameters
    pi = alpha / torch.sum(alpha)
    mus = X[Categorical(logits=torch.zeros(N)).sample((K,))]
    Sigmas = Sigma0.repeat(K, 1, 1)

    # Initialize log prob outputs
    lps = []

    # Run EM
    for _ in trange(n_iter):
        q = e_step(X, mus, Sigmas, pi)
        lps.append(log_probability(X, mus, Sigmas, pi, *hypers))
        mus, Sigmas, pi = m_step(X, q, *hypers)
        
    # Run one last E-step to tighten the bound
    q = e_step(X, mus, Sigmas, pi)
    lps.append(log_probability(X, mus, Sigmas, pi, *hypers))

    return torch.tensor(lps), mus, Sigmas, pi, q

Test your implementation on a toy dataset#

Test your implementation on a synthetic dataset.

For example, the ground truth could be two clusters with means \([5,5]\) and \([8,8]\) and identity covariance matrices, respectively. You could generate \(100\) points from each cluster.

Whichever example you choose, be sure to specify it and show that your implementation roughly recovers the ground truth by displaying the cluster means/covariances.

def confidence_ellipse(mean, cov, ax, n_std=3.0, facecolor='none', **kwargs):
    """
    Modified from: https://matplotlib.org/3.5.0/gallery/\
        statistics/confidence_ellipse.html
    Create a plot of the covariance confidence ellipse of *x* and *y*.

    Parameters
    ----------
    mean: vector-like, shape (n,)
        Mean vector.
        
    cov : matrix-like, shape (n, n)
        Covariance matrix.

    ax : matplotlib.axes.Axes
        The axes object to draw the ellipse into.

    n_std : float
        The number of standard deviations to determine the ellipse's radiuses.

    **kwargs
        Forwarded to `~matplotlib.patches.Ellipse`

    Returns
    -------
    matplotlib.patches.Ellipse
    """
    # compute the 2D covariance ellipse
    pearson = cov[0, 1] / torch.sqrt(cov[0, 0] * cov[1, 1])
    ell_radius_x = torch.sqrt(1 + pearson)
    ell_radius_y = torch.sqrt(1 - pearson)
    ellipse = Ellipse((0, 0), 
                      width=ell_radius_x * 2, 
                      height=ell_radius_y * 2,
                      facecolor=facecolor, 
                      **kwargs)

    # Compute the standard deviations (square roots of the variances) and
    # scale by the given number of standard deviations
    scale = torch.sqrt(torch.diag(cov)) * n_std
    
    # Transform the ellipse by rotating, scaling, and translating
    transf = transforms.Affine2D() \
        .rotate_deg(45) \
        .scale(*scale) \
        .translate(*mean)
    ellipse.set_transform(transf + ax.transData)

    # Add the patch to the axis
    return ax.add_patch(ellipse)
def test_toy(seed=305+ord('c'),
             n_test=200,
             mus=torch.Tensor([[5,5], [8,8]]),
             covs=torch.eye(2).repeat(2,1,1),
             K=2,
             n_iter=300,
             ):
    K, D = mus.shape
    assert covs.shape == (K, D, D)
    
    # Generate n_test random data points from each of K classes and combine
    torch.manual_seed(seed)
    X = MultivariateNormal(mus, covs).sample((n_test,)).reshape(-1, D)
    
    # Run the EM algorithm
    em_results = em(X, K=K, n_iter=n_iter,
                    alpha=torch.ones(K),
                    mu0=torch.zeros(D),
                    kappa0=1.0,
                    nu0=3.0,
                    Sigma0=torch.eye(D))
    
    # Return data and results
    return (X, *em_results)
K = 2
X, lps, means, covs, probs, q = test_toy(K=K)

# display the results  
for k in range(K):
    print("Cluster ", k, ":")
    print("\t mu:    ", means[k,:])
    print("\t Sigma: ", covs[k,:,:])
    print("\t probs: ", probs[k])
    print("")

# Plot the log probabilities over EM iterations
plt.figure()
plt.plot(lps[1:])
plt.xlabel("EM iteration")
plt.ylabel("log probability")

# create a second figure to plot the clustered data
fig, ax = plt.subplots(figsize=(6, 6))

# plot scatter 
ax.scatter(X[:,0], X[:,1], c=torch.argmax(q, 1), marker='.')

for i in range(K):
    # plot means as red dots
    ax.scatter(means[i, 0], means[i, 1], c='red')

    # plot covariance ellipses
    confidence_ellipse(means[i, :], covs[i], ax, n_std=1,
                       edgecolor='red', linestyle=':')
    confidence_ellipse(means[i, :], covs[i], ax, n_std=2,
                       edgecolor='red', linestyle=':')

Problem 3 [short answer]: Perform image segmentation#

Now that you have implemented the EM algorithm, you are ready to perform image segmentation! All you have to do for this part is run the code we’ve provided below to test your EM implementation on a couple of image segmentation problems, and then answer the discussion questions below.

First, we’ll download some test images.

# First, download the files from the github page
!wget -nc https://raw.githubusercontent.com/slinderman/stats305c/main/assignments/hw4/images/fox.png
!wget -nc https://raw.githubusercontent.com/slinderman/stats305c/main/assignments/hw4/images/cow.png
!wget -nc https://raw.githubusercontent.com/slinderman/stats305c/main/assignments/hw4/images/owl.png
!wget -nc https://raw.githubusercontent.com/slinderman/stats305c/main/assignments/hw4/images/zebra.png
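If wget is not available in your environment (e.g., a local Windows setup), the optional snippet below fetches the same files from Python; it is just a convenience alternative to the cell above.

# Optional alternative to wget for environments without it.
import os
import urllib.request

base = ("https://raw.githubusercontent.com/slinderman/stats305c/"
        "main/assignments/hw4/images/")
for name in ["fox.png", "cow.png", "owl.png", "zebra.png"]:
    if not os.path.exists(name):
        urllib.request.urlretrieve(base + name, name)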

Next, we’ve written some helper functions to run your EM code to segment the images, print summaries of the results, and make some nice plots.

def load_image(filename):
    image = plt.imread(filename + ".png")[:, :, :3]
    plt.imshow(image)

    # get height, width and number of channels
    H, W, C = image.shape
    X = image.copy().astype(float)

    # reshape into pixels, each has 3 channels (RGB)
    X = X.reshape((H * W, C)) 
    return image, torch.Tensor(X)

def save_segmentation(image, assignments, filename=None):
    import numpy as np
    # Infer the number of clusters from the assignments instead of relying
    # on a global K
    K = int(assignments.max()) + 1
    fig, axs = plt.subplots(1, K + 1, figsize=(4 * (K + 1), 4))
    axs[0].imshow(image)
    axs[0].set_axis_off()
    axs[0].set_title("original image")
    
    for k in range(K):
        im = image.copy()
        im[assignments != k] = np.nan
        axs[k+1].imshow(im)
        axs[k+1].set_axis_off()
        axs[k+1].set_title("component {}".format(k))
    
    if filename is not None:
        plt.savefig(filename)

def run_segmentation(filename, 
                     K=2, 
                     seed=305 + ord('c'),
                     n_iter=100,
                     alpha=100):
    # Load the specified image
    image, X = load_image(filename)

    # Run EM on a GMM with K classes
    torch.manual_seed(seed)
    lps, means, covs, probs, q = em(X, K=K, n_iter=n_iter,
                                    alpha=alpha * torch.ones(K))
    assignments = torch.argmax(q, axis=1).reshape(image.shape[:2])

    # Print the results
    print(filename + " results:")
    for k in range(K):
        print("Cluster ", k, ":")
        print("\t mu:    ", means[k,:])
        print("\t Sigma: ", covs[k,:,:])
        print("\t probs: ", probs[k])
        print("")

    # Plot the log probability over iterations
    plt.figure()
    plt.plot(lps[1:])
    plt.xlabel("EM iteration")
    plt.ylabel("log probability")

    # Save 
    save_segmentation(image, assignments, filename=filename + "_seg.png")

Finally, run the segmentation for each image#

Please run all of these cells! It should only take a few seconds for each cell to complete. E.g. our reference implementation takes 21 seconds for fox, 4 seconds for cow, 2 seconds for owl, and 12 seconds for zebra.

run_segmentation("fox")
run_segmentation("cow")
run_segmentation("owl")
run_segmentation("zebra")

Problem 3a: Multiple restarts#

Explain why you might need multiple restarts for EM to obtain the best results.


Your answer here.


Problem 3b: Model improvements#

How could you extend this model (e.g., by building in more prior information about images) to improve the background segmentations?


Your answer here.


Submission Instructions#

Formatting: check that your code does not exceed 80 characters in line width. If you’re working in Colab, you can set Tools → Settings → Editor → Vertical ruler column to 80 to see when you’ve exceeded the limit.

Download your notebook in .ipynb format and use the following command to convert it to PDF:

jupyter nbconvert --to pdf hw4_yourname.ipynb

Dependencies:

  • nbconvert: If you’re using Anaconda for package management,

conda install -c anaconda nbconvert

Upload your .pdf file to Gradescope.