HW3: Continuous Latent Variable Models#


Name:

Collaborators:


This homework explores continuous latent variable models like PCA and factor analysis. We will work with a synthetic dataset (MNIST digits) where we artificially mask out some pixels. Then we’ll see how well we can reconstruct the images by performing Bayesian inference in a factor analysis model with missing data.

This application may seem a bit contrived – who cares about MNIST digits? – but it has real-world applications. For example, Markowitz et al (2018) used this technique to find a low-dimensional embedding of images of partially occluded mice.

Along the way, we’ll build some intuition for PCA, hone our Gibbs sampling skills, and, as a bonus, learn about a multivariate Gaussian distribution for matrices called the matrix normal distribution.

Setup#

import torch
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms

from torch.distributions import Gamma, Normal, Bernoulli, MultivariateNormal, \
    TransformedDistribution
from torch.distributions.transforms import PowerTransform

from tqdm.auto import trange

import matplotlib.pyplot as plt
from matplotlib.cm import Blues
import seaborn as sns
sns.set_context("notebook")
class ScaledInvChiSq(TransformedDistribution):
    r"""
    Implementation of the scaled inverse \chi^2 distribution defined in class.
    We will implement it as a transformation of a gamma distribution.
    """
    def __init__(self, dof, scale):
        base = Gamma(dof / 2, dof * scale / 2)
        transforms = [PowerTransform(-1)]
        TransformedDistribution.__init__(self, base, transforms)
        self.dof = dof
        self.scale = scale
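
As a sanity check on this construction, the sketch below repeats the same transformation using only the Python standard library rather than the `ScaledInvChiSq` class itself. If g follows a gamma distribution with shape ν/2 and rate νs²/2, then 1/g follows the scaled inverse χ² distribution, whose mean is νs²/(ν − 2) for ν > 2. The helper name `sample_scaled_inv_chisq` and the parameter values are illustrative assumptions, not part of the assignment.

```python
import random


def sample_scaled_inv_chisq(dof, scale, rng):
    """Draw one sample by inverting a gamma draw (hypothetical helper)."""
    shape = dof / 2.0
    rate = dof * scale / 2.0
    # random.gammavariate takes (shape, scale), where scale = 1 / rate.
    g = rng.gammavariate(shape, 1.0 / rate)
    return 1.0 / g


rng = random.Random(305)
dof_toy, scale_toy = 10.0, 2.0  # made-up values for illustration
draws = [sample_scaled_inv_chisq(dof_toy, scale_toy, rng)
         for _ in range(200_000)]

mc_mean = sum(draws) / len(draws)
exact_mean = dof_toy * scale_toy / (dof_toy - 2.0)  # valid only for dof > 2
print(mc_mean, exact_mean)
```

The Monte Carlo mean should land close to the closed-form mean, which is a quick way to confirm that the second argument of the gamma base distribution is being treated as a rate rather than a scale.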

Download the MNIST dataset#

# Download MNIST training data and convert to float32
# Only use a subset of the images
N = 10000
X3d_true = MNIST(root='.', train=True, transform=None, download=True).data
X3d_true = X3d_true.type(torch.float32)
X3d_true = X3d_true[:N]
_, H, W = X3d_true.shape

# Add some noise to the images so they are not strictly integers
# Otherwise we get weird numerical bugs in the Gibbs sampling code!
torch.manual_seed(305)
X3d_true += Normal(0, 3).sample(X3d_true.shape)

Write simple functions to mask off some of the data#

We’ll make three types of masks:

  • Lines through the center of the image

  • Circles of random radius

  • Speckle, where each pixel is missing at random

def random_line_mask(num_samples, 
                     mask_size=(28, 28), 
                     lw=2):
    """
    Make a mask from a line through the center of the image.

    Args:
        num_samples: number of masks to generate
        mask_size: pixels by pixels
        lw: line width in pixels

    Returns:
        masks: (num_samples,) + mask_size array of binary masks

    """
    # Sample random orientations for each line
    us = Normal(0, 1).sample((num_samples, 2))
    us /= torch.norm(us, dim=1, keepdim=True)

    # Get distance of each xy coordinate to the line
    # this is the norm of (x, y) - (xp, yp) where (xp, yp)
    # is the projection onto the line
    # indexing="ij" matches the legacy default and avoids a deprecation warning
    X, Y = torch.meshgrid(torch.arange(mask_size[0]), 
                          torch.arange(mask_size[1]),
                          indexing="ij")
    xy = torch.column_stack([X.ravel(), Y.ravel()])
    xy = xy - torch.tensor(mask_size) / 2.0
    
    # Project onto the line
    # xpyp.shape == (num_samples, num_points, 2)
    xpyp = (us @ xy.T).unsqueeze(2) * us.unsqueeze(1)  
    dist = torch.norm(xy - xpyp, dim=2)

    # Make masks based on a distance threshold
    return (dist < lw).reshape((num_samples,) + mask_size)
    

def random_circle_mask(num_samples, 
                       mask_size=(28, 28),
                       std_origin=3.0,
                       mean_radius=3.0,
                       df_radius=7.0):
    """
    Sample random circular masks.

    Args:
        num_samples: number of masks to generate
        mask_size: mask size in pixels
        std_origin: standard deviation of the origin in pixels
        mean_radius: mean radius of the circular masks
        df_radius: degrees of freedom of a chi^2 distribution on radii.

    Returns:
        masks: (num_samples,) + mask_size array of binary masks
    """
    centers = Normal(0, std_origin).sample((num_samples, 2))
    radii = 0.1 + Gamma(df_radius / 2,
                      df_radius / mean_radius / 2.0).sample((num_samples,))

    # Determine whether each point is inside the corresponding circle
    # indexing="ij" matches the legacy default and avoids a deprecation warning
    X, Y = torch.meshgrid(torch.arange(mask_size[0]),
                          torch.arange(mask_size[1]),
                          indexing="ij")
    X = X - mask_size[0] / 2.0
    Y = Y - mask_size[1] / 2.0
    xy = torch.column_stack([X.ravel(), Y.ravel()]) # (num_points, 2)
    dist = torch.norm(centers.unsqueeze(1) - xy, dim = 2)
    return  (dist < radii.unsqueeze(1)).reshape((num_samples,) + mask_size)


def random_speckle_mask(num_samples,
                        mask_size=(28, 28),
                        p_missing=0.1):
    """
    Sample a random speckle mask where each pixel is missing with equal 
    probability.

    Args:
        num_samples: number of masks to sample
        mask_size: mask size in pixels
        p_missing: probability that a pixel is missing

    Returns:
        masks: (num_samples,) + mask_size binary array
    """
    masks = Bernoulli(p_missing).sample((num_samples,) + mask_size)
    return masks.type(torch.BoolTensor)

Make masks and apply them to each data point#

# Make masks for each data point
torch.manual_seed(305)
line_masks = random_line_mask(N // 3)
circ_masks = random_circle_mask(N // 3)
spck_masks = random_speckle_mask(N - len(line_masks) - len(circ_masks))
mask3d = torch.cat([line_masks, circ_masks, spck_masks])[torch.randperm(N)]

# Make the training data by substituting 255 (the max value of a uint8) 
# for each missing pixel
X3d = torch.clone(X3d_true)
X3d[mask3d] = 255.0

Plot the masks and the masked data#

# Plot a few masks
fig, axs = plt.subplots(5, 5, figsize=(8, 8))
for i in range(5):
    for j in range(5):
        axs[i, j].imshow(mask3d[i * 5 + j], interpolation="none")
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
fig.suptitle("Random Masks")
# Plot a few masked data points
fig, axs = plt.subplots(5, 5, figsize=(8, 8))
for i in range(5):
    for j in range(5):
        axs[i, j].imshow(X3d[i * 5 + j], interpolation="none")
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
fig.suptitle("Masked Data")

Flatten the data and masks into 2D tensors#

The masked data is now stored in the tensor X3d, which has shape (N, 28, 28) with N = 10000. We will flatten the tensor into X, which has shape (N, 784), and consider each row to be a vector-valued observation. We’ll do the same for the masks.

X_true = X3d_true.reshape((N, -1))
X = X3d.reshape((N, -1))
mask = mask3d.reshape((N, -1))

Note: From here on out, you should only need X and mask in your code. X_true is reserved for validation purposes.

Part 1: Principal Components Analysis and the SVD#

Problem 1a [Code]: Run PCA directly on the masked data#

In this problem, you’ll investigate what happens if you run PCA on X directly.

Implement PCA by taking the SVD of the centered and rescaled data matrix. Plot the first 25 principal components.

def pca(X):
    """
    Compute the principal components and the fraction of variance explained 
    using the SVD of the scaled and centered data matrix. 

    Args:
        X: a shape (N, D) tensor

    Returns:
        pcs: a shape (D, D) tensor whose columns are the full set of D principal
            components. This matrix should be orthogonal.

        var_explained: a shape (D,) tensor whose entries are the fraction of
            variance explained by each corresponding principal component.
    """
    ## 
    # Your code below.
    #
    ##
    return pcs, var_explained

We have provided some code below to run your code and plot the results.

def plot_pca(pcs, var_explained):
    """
    Helper function to plot the principal components and the variance explained,
    aka scree plot.
    """
    # Plot the first 25 principal components
    fig, axs = plt.subplots(5, 5, figsize=(8, 8))
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(pcs[:, i * 5 + j].reshape((28, 28)), 
                            interpolation="none")
            axs[i, j].set_xticks([])
            axs[i, j].set_yticks([])
            axs[i, j].set_title("PC {}".format(i * 5 + j + 1))
    plt.tight_layout()

    # Make the scree plot
    plt.figure()
    plt.plot(torch.cumsum(var_explained, dim=0))
    plt.xlabel("Number of PCs")
    plt.xlim(0, 784)
    plt.ylabel("Fraction of Variance Explained")
    plt.ylim(0, 1)
    plt.grid(True)
# Plot the pca results for X, the flattened, masked data
plot_pca(*pca(X))
# Compare the results to PCA on X_true, the flattened true data
plot_pca(*pca(X_true))

Problem 1b [Short Answer]: Why does PCA on the masked data need so many more components?#

PCA needs far fewer components to reach 90% variance explained on the real data (X_true) than it does on the masked data (X). Intuitively, why is that?


Your answer here.


Part 2: Gibbs Sampling for Factor Analysis with Missing Data#

Now we will try to fit a continuous latent variable model to the masked data by treating the masked pixels as missing data. As in lecture, we will assume a conjugate prior of the form,

\[\begin{split} \begin{align*} \sigma_d^2 &\sim \chi^{-2}(\nu_0, \sigma_0^2) \\ \mathbf{w}_d &\sim \mathcal{N}(\mathbf{0}, \tfrac{\sigma_d^2}{\kappa_0} \mathbf{I}) \\ \mu_d &\sim \mathcal{N}(0, \tfrac{\sigma_d^2}{\lambda_0}) \end{align*} \end{split}\]

The only thing we’ve added is a prior on the mean, which we previously assumed to be fixed at zero.

Given the parameters, the distribution on latent variables and data is,

\[\begin{split} \begin{align*} \mathbf{z}_n &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ \mathbf{x}_n &\sim \mathcal{N}(\mathbf{W} \mathbf{z}_n + \boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2)) \end{align*} \end{split}\]

where \(\mathbf{W} \in \mathbb{R}^{D \times M}\) is a matrix with rows \(\mathbf{w}_d\), \(\boldsymbol{\mu} = [\mu_1, \ldots, \mu_D]^\top\) and \(\boldsymbol{\sigma}^2 = [\sigma_1^2, \ldots, \sigma_D^2]^\top\).
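
To make the generative process concrete, here is a minimal sketch that samples from this model using only the Python standard library. It is not part of the Gibbs sampler you are asked to write; it only draws \(\mathbf{z}_n\) and then \(\mathbf{x}_n\) given fixed parameters. The function name `sample_factor_analysis` and all `_toy` values are made up for illustration.

```python
import random


def sample_factor_analysis(W, mu, sigmasq, num_samples, rng):
    """Draw (Z, X) with z ~ N(0, I) and x ~ N(W z + mu, diag(sigmasq))."""
    D, M = len(W), len(W[0])
    Z, X = [], []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(M)]
        x = [sum(W[d][m] * z[m] for m in range(M)) + mu[d]
             + rng.gauss(0.0, sigmasq[d] ** 0.5)
             for d in range(D)]
        Z.append(z)
        X.append(x)
    return Z, X


# Illustrative (made-up) parameters with D = 3 and M = 2
W_toy = [[1.0, 0.5], [-0.5, 2.0], [0.0, 1.0]]
mu_toy = [0.0, 1.0, -1.0]
sigmasq_toy = [0.1, 0.2, 0.3]

rng = random.Random(305)
Z_toy, X_toy = sample_factor_analysis(W_toy, mu_toy, sigmasq_toy, 20_000, rng)

# Marginally, E[x_d] = mu_d and Var[x_d] = ||w_d||^2 + sigma_d^2
n_toy = len(X_toy)
mean_x = [sum(x[d] for x in X_toy) / n_toy for d in range(3)]
var_x = [sum((x[d] - mean_x[d]) ** 2 for x in X_toy) / n_toy
         for d in range(3)]
print(mean_x)
print(var_x)
```

Integrating out \(\mathbf{z}_n\) gives a marginal variance of \(\|\mathbf{w}_d\|^2 + \sigma_d^2\) for coordinate \(d\), so the empirical variances should be near 1.35, 4.45, and 1.3 for these toy parameters.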

The graphical model (omitting the hyperparameters) looks like this:

Factor Analysis with Missing Data Graphical Model

Here, the \(d\)-th coordinate is missing from the \(n\)-th data point. On other data points, other subsets of coordinates may be missing.

To formalize the problem, let

\[\begin{split} \begin{align*} \mathbf{X}_{\mathsf{obs}} &= \{x_{n,d}: x_{n,d} \text{ is observed}\} \\ \mathbf{X}_{\mathsf{miss}} &= \{x_{n,d}: x_{n,d} \text{ is missing}\} \end{align*} \end{split}\]

denote the observed and missing data, respectively.

Our goal is to infer the posterior distribution over the parameters, latent variables, and missing data given only the observed data and hyperparameters,

\[ \begin{align*} p(\mathbf{W}, \boldsymbol{\mu}, \boldsymbol{\sigma^2}, \mathbf{Z}, \mathbf{X}_{\mathsf{miss}} \mid \mathbf{X}_{\mathsf{obs}}, \boldsymbol{\eta}), \end{align*} \]

where \(\boldsymbol{\eta} = (\nu_0, \sigma_0^2, \kappa_0, \lambda_0)\) are the hyperparameters.

To do so, we will implement a Gibbs sampling algorithm that alternates between updating the parameters \(\mathbf{W}\), \(\boldsymbol{\mu}\), and \(\boldsymbol{\sigma}^2\) and the latent variables \(\mathbf{z}_n\) for each data point, and then we’ll add one more step: sampling new values for the missing data \(\mathbf{X}_{\mathsf{miss}}\) from their conditional distribution. With samples of \(\mathbf{X}_{\mathsf{miss}}\), we can, for example, approximate the posterior distribution over the masked regions of the image.

Problem 2a [Math]: Derive the complete conditional distributions for the Gibbs sampler#

Specifically, derive closed form expressions for the following conditional distributions:

  • \(p(\mathbf{w}_d \mid \{\mathbf{w}_i\}_{i \neq d}, \boldsymbol{\mu}, \boldsymbol{\sigma}^2, \mathbf{Z}, \mathbf{X}_{\mathsf{miss}}, \mathbf{X}_{\mathsf{obs}}, \boldsymbol{\eta})\)

  • \(p(\mu_d \mid \{\mu_i\}_{i \neq d}, \mathbf{W}, \boldsymbol{\sigma}^2, \mathbf{Z}, \mathbf{X}_{\mathsf{miss}}, \mathbf{X}_{\mathsf{obs}}, \boldsymbol{\eta})\)

  • \(p(\sigma_d^2 \mid \{\sigma_i^2\}_{i \neq d}, \mathbf{W}, \boldsymbol{\mu}, \mathbf{Z}, \mathbf{X}_{\mathsf{miss}}, \mathbf{X}_{\mathsf{obs}}, \boldsymbol{\eta})\)

  • \(p(\mathbf{z}_n \mid \mathbf{W}, \boldsymbol{\mu}, \boldsymbol{\sigma}^2, \{\mathbf{z}_i\}_{i\neq n}, \mathbf{X}_{\mathsf{miss}}, \mathbf{X}_{\mathsf{obs}}, \boldsymbol{\eta})\)

  • \(p(x_{n,d} \mid \mathbf{W}, \boldsymbol{\mu}, \boldsymbol{\sigma}^2, \mathbf{Z}, \mathbf{X}_{\mathsf{obs}}, \boldsymbol{\eta})\) for each missing entry \(x_{n,d}\)

Hint: Your expressions may not depend on all of the conditioned upon variables.


Your answer here.


Problem 2b [Short answer]: Which Gibbs steps can be performed in parallel?#

As in Assignment 2, some of these updates can be performed in parallel using a blocked Gibbs update. Which ones?


Your answer here.


Problem 2c [Code]: Implement the Gibbs sampler#

Finish the functions below to implement the updates you derived above. We have provided some function headers to help you organize your solutions.

def log_probability(X, Z, W, mu, sigmasq, nu0, sigmasq0, kappa0, lambda0):
    """
    Evaluate the log joint probability of the _complete_ data and all the 
    latent variables and parameters.

    Args:
        X: shape (N,D) tensor with the complete data (current samples of the 
            missing data are filled in)
        Z: shape (N,M) tensor with the latent variables
        W: shape (D,M) tensor of weights
        mu: shape (D,) tensor with the mean parameter
        sigmasq: shape (D,) tensor with the variance parameters
        nu0, sigmasq0: scalar hyperparameters for the prior on variance
        kappa0: scalar hyperparameter for the prior on weights
        lambda0: scalar hyperparameter for the prior on mean
    """
    ###
    # Your code here.
    #
    # Hint: Take advantage of Pytorch distributions' support for broadcasting
    # to evaluate many log probabilities at once.
    ##
    return lp


def gibbs_sample_latents(W, mu, sigmasq, X):
    """
    Sample new latent variables Z given the parameters and the complete 
    data.

    Args:
        W: shape (D,M) tensor of weights
        mu: shape (D,) tensor with the mean 
        sigmasq: shape (D,) tensor with variance parameters
        X: shape (N,D) tensor with the complete data (current samples of the 
            missing data are filled in)

    Returns:
        Z: shape (N,M) tensor with latent variables sampled from their 
            conditional
    """
    ###
    # Your code here.
    # 
    # Hint: use the MultivariateNormal distribution object and take advantage
    # of its broadcasting capabilities to sample the rows of Z in parallel.
    #
    # Hint: `torch.linalg.solve(J, h.unsqueeze(2))` will broadcast a solve of
    # a shape (M, M) tensor `J` with a shape (N, M) tensor `h`. It gives a 
    # tensor of shape (N, M, 1). If you're not careful with broadcasting, you 
    # can run out of memory and crash the kernel.
    ##
    return Z


def gibbs_sample_weights(mu, sigmasq, Z, X, kappa0):
    """
    Sample new weights W given the other parameters, latent variables, and 
    hyperparameters.

    Args:
        mu: shape (D,) tensor with the mean parameter
        sigmasq: shape (D,) tensor with the variance parameters
        Z: shape (N,M) tensor with the latent variables
        X: shape (N,D) tensor with the complete data (current samples of the 
            missing data are filled in)
        kappa0: scalar hyperparameter for the prior on weights

    Returns:
        W: shape (D,M) tensor of weights sampled from its conditional
    """
    ###
    # Your code here.
    # 
    # Hint: you can use the MultivariateNormal distribution object and take 
    # advantage of its broadcasting capabilities to sample many rows of W in 
    # parallel.
    ##
    return W


def gibbs_sample_mean(W, sigmasq, Z, X, lambda0):
    """
    Sample a new mean mu given the other parameters, latent variables, and 
    hyperparameters.

    Args:
        W: shape (D,M) tensor of weights
        sigmasq: shape (D,) tensor with the variance parameters
        Z: shape (N,M) tensor with the latent variables
        X: shape (N,D) tensor with the complete data (current samples of the 
            missing data are filled in)
        lambda0: scalar hyperparameter for the prior on mean

    Returns:
        mu: shape (D,) tensor with the mean sampled from its conditional
    """
    ###
    # Your code here.
    #
    ##
    return mu


def gibbs_sample_variance(W, mu, Z, X, nu0, sigmasq0, kappa0, lambda0):
    """
    Sample new variances sigmasq given the other parameters, latent 
    variables, and hyperparameters.

    Args:
        W: shape (D,M) tensor of weights
        mu: shape (D,) tensor with the mean 
        Z: shape (N,M) tensor with the latent variables
        X: shape (N,D) tensor with the complete data (current samples of the 
            missing data are filled in)
        nu0, sigmasq0: scalar hyperparameters for the prior on variance
        kappa0: scalar hyperparameter for the prior on weights
        lambda0: scalar hyperparameter for the prior on mean

    Returns:
        sigmasq: shape (D,) tensor with variance sampled from its conditional
    """
    ###
    # Your code here.
    # 
    # Hint: You may use the ScaledInvChiSq distribution provided above. It
    # also supports broadcasting.
    ##
    return sigmasq


def gibbs_sample_missing_data(W, mu, sigmasq, Z, X, mask):
    """
    Sample new values for the missing entries of X given the parameters and 
    latent variables.

    Args:
        W: shape (D,M) tensor of weights
        mu: shape (D,) tensor with the mean 
        sigmasq: shape (D,) tensor with variance parameters
        Z: shape (N,M) tensor with the latent variables
        X: shape (N,D) tensor with the complete data (current samples of the 
            missing data are filled in)
        mask: shape (N,D) boolean tensor where 1 (True) specifies that the 
            corresponding entry in X is missing and needs to be resampled.

    Returns:
        X: shape (N,D) tensor which is the same as the given X in entries where
            mask == 0 (False), but which has new values sampled from their 
            conditional distribution in entries where mask == 1 (True).
    """
    ###
    # Your code here.
    # 
    # Hint: Pytorch supports the same sorts of indexing tricks as numpy. 
    # See: https://pytorch.org/cppdocs/notes/tensor_indexing.html
    # For example, you can use `X[mask] = vals` to set only the entries where 
    # the boolean mask is 1 (True). In this expression, `vals` is a 1d tensor
    # whose length equals the number of missing values, 
    # i.e. `len(vals) = mask.sum()`. 
    ##
    return X
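
The boolean-indexing hint in `gibbs_sample_missing_data` can be tried out on a toy tensor first. The sketch below assumes PyTorch is installed (as in the Setup cell) and uses made-up `_toy` names and values; it only demonstrates the indexing mechanics, not how to sample the missing values.

```python
import torch

# Toy data: 3 rows, 2 columns; True marks a missing entry (placeholder 255).
X_toy = torch.tensor([[1.0, 255.0],
                      [3.0, 4.0],
                      [255.0, 8.0]])
mask_toy = torch.tensor([[False, True],
                         [False, False],
                         [True, False]])

# Boolean indexing returns a 1d tensor with one entry per True in the mask,
# ordered row by row.
picked = X_toy[mask_toy]
print(picked.shape, int(mask_toy.sum()))

# Assign new values only where the mask is True; other entries are untouched.
vals_toy = torch.tensor([-1.0, -2.0])
X_filled = X_toy.clone()
X_filled[mask_toy] = vals_toy
print(X_filled)
```

Cloning before assignment keeps the original tensor intact, which mirrors how the provided `gibbs` function works on a copy of the data.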

Run the Gibbs Sampler [Provided]#

We have provided a simple function to run your Gibbs sampling code on the masked data from above. Collecting 200 Gibbs samples takes about 5 minutes with my implementation (on a Colab notebook, not using the GPU).

def gibbs(X, 
          mask, 
          M=50,
          nu0=1.1, 
          sigmasq0=10., 
          kappa0=0.01, 
          lambda0=0.01, 
          N_samples=200):
    """
    Run the Gibbs sampler.

    Args:

        X: shape (N,D) tensor with the masked data. Entries where mask is True
            hold a placeholder value and are re-initialized below.
        mask: shape (N,D) boolean tensor where 1 (True) specifies that the 
            corresponding entry in X is missing and needs to be resampled.
        M: the dimension of the continuous latent variables
        nu0, sigmasq0: scalar hyperparameters for the prior on variance
        kappa0: scalar hyperparameter for the prior on weights
        lambda0: scalar hyperparameter for the prior on mean
        N_samples:  number of Gibbs iterations to run
    
    Returns:

    Dictionary with samples of X_miss, Z, W, mu, and sigmasq, and the log 
    joint probability (lps) at each iteration.
    """
    N, D = X.shape

    # We will be updating X in place each time we sample missing data.
    # Rather than overwriting the data that's passed in, we'll make a clone 
    # and update that instead.
    X = torch.clone(X)

    # Similarly, all the missing data is currently set to 255 (the high value).
    # Let's initialize the missing data with the mean of the observed data.
    fmask = mask.type(torch.float32)
    N_obs = torch.sum(1 - fmask, dim=0)
    X_mean = torch.sum(X * (1 - fmask), dim=0) / N_obs
    X[mask] = X_mean.repeat(N, 1)[mask]

    # Initialize the mean \mu to the sample mean and the variance \sigmasq to
    # the sample variance of the observed data. Initialize the weights and the 
    # latent variables randomly.
    mu = X_mean
    sigmasq = torch.sum((X - X_mean)**2 * (1 - fmask), dim=0) / N_obs
    W = Normal(0, 1).sample((D, M))
    Z = Normal(0, 1).sample((N, M))

    # Compute the initial log probability
    lp = log_probability(X, Z, W, mu, sigmasq, nu0, sigmasq0, kappa0, lambda0)
    
    # Initialize the output
    samples = [(torch.clone(X[mask]), Z, W, mu, sigmasq, lp)]

    # Run the Gibbs sampler
    for itr in trange(N_samples - 1):
        # Cycle through each update 
        Z = gibbs_sample_latents(W, mu, sigmasq, X)
        W = gibbs_sample_weights(mu, sigmasq, Z, X, kappa0)
        mu = gibbs_sample_mean(W, sigmasq, Z, X, lambda0)
        sigmasq = gibbs_sample_variance(W, mu, Z, X, 
                                        nu0, sigmasq0, kappa0, lambda0)
        X = gibbs_sample_missing_data(W, mu, sigmasq, Z, X, mask)

        # Compute the log probability
        lp = log_probability(X, Z, W, mu, sigmasq, 
                             nu0, sigmasq0, kappa0, lambda0)
                
        # Update the sample list
        samples.append((torch.clone(X[mask]), Z, W, mu, sigmasq, lp))

    # Combine the output into a dictionary with a cool python zip trick
    samples_dict = dict()
    keys = ["X_miss", "Z", "W", "mu", "sigmasq", "lps"]
    values = zip(*samples)
    for key, value in zip(keys, values):
        samples_dict[key] = torch.stack(value)

    return samples_dict
# This takes about 5-6 min with my code. For debugging purposes, you may want
# to reduce N_samples, but please reset it to 200 for your final analysis.
N_samples = 200
samples = gibbs(X, mask, M=50, N_samples=N_samples)

Plot your results [Provided]#

The code below generates the following plots:

  • Trace of the log joint probability

  • The first 25 data points with their missing values filled in with the average of \(\mathbf{X}_{\mathsf{miss}}\) from the last half of the Gibbs samples.

  • 25 factors from the final Gibbs sample arranged into a 5x5 grid where each factor is shown as a 28x28 pixel image.

  • The root mean squared error of the reconstructed image over iterations.

  • Plot of the mean \(\boldsymbol{\mu}\) averaged over the last half of the Gibbs samples, shown as a 28x28 pixel image

  • Plot of the variance \(\boldsymbol{\sigma}^2\) averaged over the last half of the Gibbs samples, shown as a 28x28 pixel image

offset = 5
plt.plot(torch.arange(offset, N_samples), samples["lps"][offset:])
plt.xlabel("Iteration")
plt.ylabel("Log Joint Probability")
# Plot the masked and reconstructed data, using the mean of X_miss samples
X_miss = samples['X_miss'][N_samples//2:].mean(dim=0)
X_recon = torch.clone(X)
X_recon[mask] = X_miss

# Plot a few masked data points
fig, axs = plt.subplots(5, 5, figsize=(16, 8))
for i in range(5):
    for j in range(5):
        im = torch.column_stack([X[i * 5 + j].reshape(28, 28),
                                 X_recon[i * 5 + j].reshape(28, 28)])
        axs[i, j].imshow(im, interpolation="none", vmin=0, vmax=255)
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
fig.suptitle("Masked and Reconstructed Data")
# Plot the reconstruction error across Gibbs iterations
rmse = torch.sqrt(((samples['X_miss'] - X_true[mask])**2).mean(axis=1))
plt.plot(rmse)
plt.xlabel("Iteration")
plt.ylabel("RMSE")
# Plot the first 25 factors (columns of W) from the final Gibbs sample
W = samples['W'][-1]
fig, axs = plt.subplots(5, 5, figsize=(8, 8))
for i in range(5):
    for j in range(5):
        axs[i, j].imshow(W[:, i * 5 + j].reshape((28, 28)), 
                        interpolation="none")
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
        axs[i, j].set_title("Factor {}".format(i * 5 + j + 1))
plt.tight_layout()
# Plot the posterior mean of $\mu$
plt.imshow(samples["mu"][N_samples//2:].mean(dim=0).reshape(28, 28))
plt.xticks([])
plt.yticks([])
plt.title("Mean Image")
plt.colorbar()
# Plot the posterior mean of $\sigma^2$
plt.imshow(samples["sigmasq"][N_samples//2:].mean(0).reshape(28, 28))
plt.xticks([])
plt.yticks([])
plt.title("Per-Pixel Variance")
plt.colorbar()

Problem 2d [Short answer]: Discussion#

Were you surprised at how well (or poorly) you were able to reconstruct the masked images using factor analysis? Could you imagine alternative approaches that might perform better, and why?


Your answer here


Bonus: The matrix normal distribution#

In the model above, we put a prior on the weights \(\mathbf{W} \in \mathbb{R}^{D \times M}\) by assuming each row to be an independent multivariate normal vector,

\[ \begin{align*} p(\mathbf{W}) &= \prod_{d=1}^D \mathcal{N}(\mathbf{w}_d \mid \mathbf{0}, \tfrac{\sigma_d^2}{\kappa_0} \mathbf{I}). \end{align*} \]

However, in class we noted that it’s a bit strange to put a prior on the rows when it’s the columns (i.e. the principal components) that we really care about.

For this bonus problem, we’ll derive a matrix normal prior distribution instead. The matrix normal is a distribution on matrices \(\mathbf{W} \in \mathbb{R}^{D \times M}\) with three parameters: a mean \(\mathbf{M} \in \mathbb{R}^{D \times M}\), a positive definite covariance among the rows \(\mathbf{\Sigma}_r \in \mathbb{R}_{\succ 0}^{D \times D}\), and a positive definite covariance among the columns \(\mathbf{\Sigma}_c \in \mathbb{R}_{\succ 0}^{M \times M}\).

The matrix normal distribution is equivalent to a multivariate normal distribution on the vectorized (aka flattened or raveled) matrix where the covariance matrix obeys a special, Kronecker-factored form. Specifically,

\[ \begin{align*} \mathbf{W} \sim \mathcal{MN}(\mathbf{M}, \mathbf{\Sigma}_r, \mathbf{\Sigma}_c) \iff \mathrm{vec}(\mathbf{W}) \sim \mathcal{N}(\mathrm{vec}(\mathbf{M}), \mathbf{\Sigma}_r \otimes \mathbf{\Sigma}_c), \end{align*} \]

where \(\mathrm{vec}(\cdot)\) is the vectorization operation that ravels a matrix into a vector (here in row-major, i.e. C order) and \(\otimes\) denotes the Kronecker product.

For example, suppose

\[\begin{split} \begin{align*} \mathbf{M} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}. \end{align*} \end{split}\]

Then

\[ \begin{align*} \mathrm{vec}\left( \mathbf{M} \right) = [1, 2, 3, 4, 5, 6]^\top. \end{align*} \]

The vectorized matrix is the concatenation of its rows.

To illustrate the Kronecker product, suppose

\[\begin{split} \begin{align*} \mathbf{\Sigma}_r = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix}, \quad \mathbf{\Sigma}_c = \begin{bmatrix} 1 & -1 \\ -1 & 2 \end{bmatrix} \end{align*} \end{split}\]

Then,

\[\begin{split} \begin{align*} \mathbf{\Sigma}_r \otimes \mathbf{\Sigma}_c = \begin{bmatrix} 1 & -1 & 0 & 0 & 0 & 0\\ -1 & 2 & 0 & 0 & 0 & 0\\ 0 & 0 & 1 & -1 & 0 & 0 \\ 0 & 0 & -1 & 2 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & -1 \\ 0 & 0 & 0 & 0 & -1 & 2 \end{bmatrix} \end{align*} \end{split}\]

Since \(\mathbf{\Sigma}_r\) is the identity matrix, each row \(\mathbf{w}_d \in \mathbb{R}^2\) is an independent multivariate normal random variable with covariance \(\mathbf{\Sigma}_c\). With this example in mind, we now see that the prior we used in Part 2 was really a special case of the matrix normal distribution with \(\mathbf{M} = \mathbf{0}\), \(\mathbf{\Sigma}_r = \mathrm{diag}([\sigma_1^2, \ldots, \sigma_D^2])\), and \(\mathbf{\Sigma}_c = \kappa_0^{-1} \mathbf{I}\).
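
The worked example above can be reproduced in a few lines of plain Python. The helpers `kron` and `vec` below are illustrative stand-ins written for this sketch (in the notebook you would more likely use `torch.kron` and `reshape(-1)`).

```python
def kron(A, B):
    """Kronecker product of two matrices given as nested lists."""
    return [[a * b for a in row_a for b in row_b]
            for row_a in A for row_b in B]


def vec(A):
    """Row-major (C order) vectorization: concatenate the rows."""
    return [a for row in A for a in row]


Sigma_r_ex = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
Sigma_c_ex = [[1, -1], [-1, 2]]
M_ex = [[1, 2], [3, 4], [5, 6]]

K_ex = kron(Sigma_r_ex, Sigma_c_ex)
for row in K_ex:
    print(row)          # block diagonal, with Sigma_c repeated three times
print(vec(M_ex))        # [1, 2, 3, 4, 5, 6]
```

Because \(\mathbf{\Sigma}_r\) is the identity here, the Kronecker product is block diagonal, which is exactly the statement that the rows are independent.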

We can derive the matrix normal density by starting from the multivariate normal density on the vectorized matrix,

\[\begin{split} \begin{align*} p(\mathbf{W} \mid \mathbf{M}, \mathbf{\Sigma}_r, \mathbf{\Sigma}_c) &= (2 \pi)^{-\frac{DM}{2}} |\mathbf{\Sigma}_r \otimes \mathbf{\Sigma}_c |^{-\frac{1}{2}} \exp \left\{ -\frac{1}{2} \mathrm{vec}(\mathbf{W} - \mathbf{M})^\top (\mathbf{\Sigma}_r \otimes \mathbf{\Sigma}_c)^{-1} \mathrm{vec}(\mathbf{W} - \mathbf{M}) \right\} \\ &= (2 \pi)^{-\frac{DM}{2}} |\mathbf{\Sigma}_r|^{-\frac{M}{2}} |\mathbf{\Sigma}_c|^{-\frac{D}{2}} \exp \left\{ -\frac{1}{2} \mathrm{vec}(\mathbf{W} - \mathbf{M})^\top (\mathbf{\Sigma}_r^{-1} \otimes \mathbf{\Sigma}_c^{-1}) \mathrm{vec}(\mathbf{W} - \mathbf{M}) \right\} \\ &= (2 \pi)^{-\frac{DM}{2}} |\mathbf{\Sigma}_r|^{-\frac{M}{2}} |\mathbf{\Sigma}_c|^{-\frac{D}{2}} \exp \left\{ -\frac{1}{2} \mathrm{vec}(\mathbf{W} - \mathbf{M})^\top \mathrm{vec}(\mathbf{\Sigma}_r^{-1}(\mathbf{W} - \mathbf{M}) \mathbf{\Sigma}_c^{-1}) \right\} \\ &= (2 \pi)^{-\frac{DM}{2}} |\mathbf{\Sigma}_r|^{-\frac{M}{2}} |\mathbf{\Sigma}_c|^{-\frac{D}{2}} \exp \left\{ -\frac{1}{2}\mathrm{Tr} \left[ \mathbf{\Sigma}_c^{-1} (\mathbf{W} - \mathbf{M})^\top \mathbf{\Sigma}_r^{-1} (\mathbf{W} - \mathbf{M}) \right] \right\} \\ &\propto \exp \left\{ -\frac{1}{2}\mathrm{Tr} \left[ \mathbf{\Sigma}_c^{-1} \mathbf{W}^\top \mathbf{\Sigma}_r^{-1} \mathbf{W} \right] + \mathrm{Tr} \left[\mathbf{\Sigma}_c^{-1} \mathbf{M}^\top \mathbf{\Sigma}_r^{-1} \mathbf{W} \right] \right\} \end{align*} \end{split}\]

Note: the definitions given here are appropriate for Python/PyTorch, where vectorization is performed in row-major order. This is in contrast to the definition on Wikipedia, which assumes column-major order, as in Matlab or R. The only difference is that the order of the Kronecker product is flipped.
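
The step in the derivation that turns the Kronecker-factored quadratic form into a trace can be checked numerically. The sketch below, in plain Python with row-major vectorization, compares \(\mathrm{vec}(\mathbf{A})^\top (\mathbf{\Sigma}_r^{-1} \otimes \mathbf{\Sigma}_c^{-1}) \mathrm{vec}(\mathbf{A})\) against \(\mathrm{Tr}[\mathbf{\Sigma}_c^{-1} \mathbf{A}^\top \mathbf{\Sigma}_r^{-1} \mathbf{A}]\), where \(\mathbf{A}\) stands in for \(\mathbf{W} - \mathbf{M}\). All helper names and numerical values are assumptions chosen for illustration.

```python
def kron(A, B):
    """Kronecker product of two matrices given as nested lists."""
    return [[a * b for a in row_a for b in row_b]
            for row_a in A for row_b in B]


def vec(A):
    """Row-major vectorization: concatenate the rows."""
    return [a for row in A for a in row]


def matmul(A, B):
    """Matrix product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def trace(A):
    return sum(A[i][i] for i in range(len(A)))


# Sigma_r = diag(1, 2, 4) and Sigma_c = [[1, -1], [-1, 2]] (determinant 1),
# so both inverses can be written down by hand.
Sigma_r_inv = [[1.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.25]]
Sigma_c_inv = [[2.0, 1.0], [1.0, 1.0]]
A_ex = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # plays the role of W - M

v = vec(A_ex)
K_inv = kron(Sigma_r_inv, Sigma_c_inv)
quad_form = sum(v[i] * K_inv[i][j] * v[j]
                for i in range(len(v)) for j in range(len(v)))

trace_form = trace(matmul(matmul(Sigma_c_inv, transpose(A_ex)),
                          matmul(Sigma_r_inv, A_ex)))
print(quad_form, trace_form)
```

The two numbers agree, which is a useful check on the ordering of the Kronecker factors under row-major vectorization.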

Bonus Problem [Math]: Derive the conditional distribution of the factor analysis weights under a matrix normal prior#

Now consider the factor analysis model,

\[\begin{split} \begin{align*} \mathbf{z}_n &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ \mathbf{x}_n &\sim \mathcal{N}(\mathbf{W} \mathbf{z}_n + \boldsymbol{\mu}, \mathrm{diag}(\boldsymbol{\sigma}^2)) \end{align*} \end{split}\]

where \(\boldsymbol{\mu} = [\mu_1, \ldots, \mu_D]^\top\) and \(\boldsymbol{\sigma}^2 = [\sigma_1^2, \ldots, \sigma_D^2]^\top\).

Suppose we put the following matrix normal prior on the weights and variances,

\[\begin{split} \begin{align*} \sigma_d^2 &\sim \chi^{-2}(\nu_0, \sigma_0^2) \\ \boldsymbol{\mu} &\sim \mathcal{N}(\mathbf{0}, \mathrm{diag}(\boldsymbol{\sigma}^2) / \lambda_0) \\ \mathbf{W} &\sim \mathcal{MN}(\mathbf{0}, \mathrm{diag}(\boldsymbol{\sigma}^2), \mathbf{\Sigma}_c) \end{align*} \end{split}\]

where \(\mathbf{\Sigma}_c\) is the prior covariance among the columns.

Derive the complete conditional distribution of the weights,

\[ \begin{align*} p(\mathbf{W} \mid \{\mathbf{z}_n, \mathbf{x}_n\}_{n=1}^N, \boldsymbol{\mu}, \boldsymbol{\sigma}^2, \mathbf{\Sigma}_c) \end{align*} \]

Finally, let \(\mathbf{\Sigma}_c^{-1} \to \mathbf{0}\). What does the conditional mean of \(\mathbf{W}\) converge to? Does this expression look familiar?


Your answer here.


Submission Instructions#

Formatting: check that your code does not exceed 80 characters in line width. You can set Tools → Settings → Editor → Vertical ruler column to 80 to see when you’ve exceeded the limit.
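
If you would rather check the line-width limit programmatically, a small script along the following lines may help. It is only a sketch: the function names are hypothetical, and it assumes the usual .ipynb JSON layout in which each cell has a `cell_type` and a `source` field.

```python
import json


def long_code_lines(notebook, limit=80):
    """Return (cell_index, line_number, length) for over-long code lines.

    `notebook` is the dict obtained by loading an .ipynb file as JSON.
    """
    hits = []
    for c, cell in enumerate(notebook.get("cells", [])):
        if cell.get("cell_type") != "code":
            continue
        source = cell.get("source", [])
        # The source may be stored as a list of lines or as a single string.
        text = "".join(source) if isinstance(source, list) else source
        for i, line in enumerate(text.splitlines(), start=1):
            if len(line) > limit:
                hits.append((c, i, len(line)))
    return hits


def long_code_lines_in_file(path, limit=80):
    """Convenience wrapper that reads the notebook from disk."""
    with open(path, encoding="utf-8") as f:
        return long_code_lines(json.load(f), limit)


# Tiny in-memory notebook for demonstration
demo = {"cells": [
    {"cell_type": "markdown", "source": ["x" * 100]},
    {"cell_type": "code", "source": ["a = 1\n", "b = '" + "y" * 90 + "'\n"]},
]}
print(long_code_lines(demo))  # [(1, 2, 96)]
```

Markdown cells are skipped on purpose, since the limit applies to code only.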

Download your notebook in .ipynb format and remove the Open in Colab button. Then run the following command to convert to a PDF:

jupyter nbconvert --to pdf <yourname>_hw3.ipynb

Installing nbconvert:

If you’re using Anaconda for package management,

conda install -c anaconda nbconvert

If you can’t get nbconvert to work, you may print to PDF using your browser, but please make sure that none of your code, text, or math is cut off.

Upload your .pdf files to Gradescope.