HW6: Neural Networks and VAEs


Name:

Collaborators:


In this homework assignment, we will explore automatic differentiation, neural networks, and amortized variational inference.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, Bernoulli, Uniform
from torch.distributions.kl import kl_divergence
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.utils import make_grid
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms

import numpy as np
import math
from tqdm.notebook import tqdm

torch.manual_seed(305)

import matplotlib.pyplot as plt
from matplotlib.cm import Blues
import seaborn as sns
sns.set_context("notebook")

Problem 1: Optimizing a Quadratic Objective via Gradient Descent

We’ll start off by optimizing a simple objective using gradient descent. We will compute the required gradients using PyTorch’s automatic differentiation capabilities.

Consider the function \(f: \mathbb{R}^D \to \mathbb{R}\) given by:

\[ f(\mathbf{x}) = \mathbf{x}^\top \mathbf{A} \mathbf{x} \]

where \(\mathbf{A} \in \mathbb{R}^{D \times D}\) is a fixed positive definite matrix. Because \(\mathbf{A}\) is positive definite, \(f(\mathbf{x}) > 0\) for every \(\mathbf{x} \neq \mathbf{0}\), so the unique global minimizer of \(f\) is \(\mathbf{x}^* = \mathbf{0}\). We will try to recover this known global minimizer using gradient descent. We note that although the objective is seemingly simple, a stochastic version of this objective has been used as a model for neural network loss surfaces (e.g. see this paper).

We will simplify the objective further by assuming that \(\mathbf{A}\) is diagonal, so \(\mathbf{A} = \text{diag}(\mathbf{a})\) for some \(\mathbf{a} \in \mathbb{R}_{++}^D\).
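As a quick sanity check of this simplification, the sketch below confirms numerically that the quadratic form with a diagonal matrix reduces to a weighted sum of squares. The values of `a` and `x` are arbitrary and chosen only for illustration.

```python
import torch

# Illustrative values only; any positive `a` and any `x` would do.
a = torch.tensor([0.5, 2.0, 3.0])
x = torch.tensor([1.0, -1.0, 2.0])

A = torch.diag(a)                 # A = diag(a)
quad_form = x @ A @ x             # x^T A x
weighted_sum = (a * x**2).sum()   # sum_i a_i x_i^2

print(quad_form.item(), weighted_sum.item())
```

The weighted-sum form never builds the \(D \times D\) matrix, which is why the objective can be evaluated elementwise.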

Problem 1a: Gradient Descent with a Well-Conditioned Objective

Recall the gradient descent update rule for a fixed step size \(\alpha\) is:

\[\mathbf{x}^{(k + 1)} \gets \mathbf{x}^{(k)} - \alpha \nabla f(\mathbf{x}^{(k)})\]
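For this objective the gradient is available in closed form, \(\nabla f(\mathbf{x}) = 2\, \mathbf{a} \odot \mathbf{x}\), which gives a convenient check on autograd. The sketch below takes a single update by hand. It uses `torch.autograd.grad` directly rather than a `torch.optim` optimizer, so it illustrates the update rule and is not the loop requested in this problem.

```python
import torch

a = torch.tensor([0.2, 0.2])
x = torch.tensor([10.0, 10.0], requires_grad=True)
alpha = 0.8

loss = (a * x**2).sum()
(grad,) = torch.autograd.grad(loss, x)   # autograd gradient of f at x

analytic_grad = 2 * a * x.detach()       # closed form: 2 * a * x
x_next = x.detach() - alpha * grad       # one gradient descent update

print(grad, analytic_grad, x_next)
```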

Using PyTorch’s automatic differentiation system, complete the function below implementing an optimization loop to minimize the objective.

def f(x, a):
    """ Evaluates the objective.

    Args:
        x: (D,) tensor 
        a: (D,) tensor of positive values
    Returns:
        value: (,) tensor
    """
    return (a * x**2).sum()

def run_optimization(x, a, optimizer, num_iters):
    """ Runs an optimization algorithm on the objective for num_iters.

    Args:
        x: Starting position for optimization. nn.Parameter of shape (D,)
        a: parameter defining curvature of objective. tensor of shape (D,)
        optimizer: a torch.optim Optimizer
        num_iters: number of iterations to run optimization for

    Returns:
        xs: value of x at each iteration. tensor of shape (num_iters + 1, D)
        losses: value of objective at each iteration. tensor of shape (num_iters,)
    """
    losses = []
    # we have to make a copy of the data so that elements of list do not 
    # reference same location in memory
    xs = [x.data.clone()]

    for it in range(num_iters):

        ###
        # YOUR CODE HERE
        # 1. zero out the gradient
        # 2. compute the loss
        # 3. compute the gradient of the loss
        # 4. take one optimizer step (this will update x in place)
        ...
        loss = ...
        ...
        ##
        # store a plain float so the list does not hold autograd graphs
        losses.append(loss.item())
        xs.append(x.data.clone()) 

    # Return the stacked losses and xs
    losses = torch.tensor(losses)
    xs = torch.vstack(xs)
    return xs, losses

We will now run gradient descent on the objective in \(D = 2\) dimensions, with \(\mathbf{a} = (0.2, 0.2)\). We will use 50 iterations with a learning rate of \(\alpha = 0.8\), starting from \(\mathbf{x}^{(0)} = (10, 10)\).

Note that we initialize x using nn.Parameter, which tells PyTorch to compute gradients with respect to x. PyTorch’s optimizers expect an iterable of Parameters as their first argument, so you must pass in [x] rather than x itself.
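The general pattern of wrapping a tensor in `nn.Parameter` and handing it to an optimizer inside a list can be sketched on a toy problem. The parameter `w`, the loss, and the learning rate below are hypothetical and unrelated to the assignment's objective.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy problem, unrelated to the assignment's objective.
w = nn.Parameter(torch.tensor([3.0, -1.0]))
print(w.requires_grad)            # nn.Parameter tracks gradients by default

opt = optim.SGD([w], lr=0.25)     # note the list: an iterable of Parameters

opt.zero_grad()                   # clear any stale gradient
loss = ((w - 1.0) ** 2).sum()     # toy loss with minimizer w = (1, 1)
loss.backward()                   # populates w.grad with 2 * (w - 1)
opt.step()                        # in-place update: w <- w - lr * w.grad

print(w.data)
```

Because the optimizer holds a reference to `w`, the call to `step()` modifies `w` in place; nothing needs to be reassigned.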

D = 2
x = nn.Parameter(10 * torch.ones(D))
a = 0.2 * torch.ones(D)

### 
# YOUR CODE HERE
optimizer = ...
xs, losses = run_optimization(...)
##

Let’s plot the loss curve. As you can see, the objective approaches the optimal value of zero quite quickly.

plt.figure()
plt.plot(losses)
plt.xlabel("Gradient Descent Iteration")
plt.ylabel("Objective Value")

Next, let’s visualize the trajectory of the gradient descent iterates \(\mathbf{x}^{(k)}\).

def visualize_gradient_descent(xs, a, width=12, grid_size=200):
    """ Visualizes gradient descent when iterates are two-dimensional."""
    
    def batch_f(X):
        """ Evaluates the objective at each row of X.

        Args: 
            X: (N, D) tensor
        Returns:
            values: (N,) tensor
        """
        return (X**2 * a).sum(dim=1)

    # Evaluate the objective on a (grid_size x grid_size) grid
    x1s = np.linspace(-width, width, grid_size)
    x2s = np.linspace(-width, width, grid_size)
    X1, X2 = np.meshgrid(x1s, x2s)
    points = np.transpose([np.tile(x1s, len(x2s)), np.repeat(x2s, len(x1s))])
    points = torch.tensor(points, dtype=torch.float)

    Z = batch_f(points).reshape(grid_size, grid_size)

    plt.figure(figsize=(10, 7))
    plt.contour(X1, X2, Z, 20)

    # Overlay the trajectory of the iterates on the contours
    plt.plot(xs[:,0], xs[:,1])
    plt.plot(xs[:,0], xs[:,1], '*', label="Iterates")

    plt.xlabel('$x_1$', fontsize=15)
    plt.ylabel('$x_2$', fontsize=15)
    plt.legend()
    plt.show()

visualize_gradient_descent(xs, a)

Problem 1b: Gradient Descent with an Ill-Conditioned Objective

Next, let’s see how gradient descent performs on an ill-conditioned objective, where the amount of curvature in each direction varies. For our objective, we can control the amount of curvature using \(\mathbf{a}\). Repeat part (a), this time using \(\mathbf{a} = (0.05, 1.2)\) while holding all other hyperparameters constant.
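One standard way to quantify how ill-conditioned a quadratic is, not defined in this assignment but widely used, is the condition number: the ratio of the largest to the smallest curvature. For \(\mathbf{A} = \text{diag}(\mathbf{a})\) this is simply \(\max_i a_i / \min_i a_i\). A minimal sketch, with `condition_number` a hypothetical helper name:

```python
def condition_number(a):
    """Ratio of largest to smallest curvature for A = diag(a)."""
    return max(a) / min(a)

kappa_a = condition_number([0.2, 0.2])    # part (a): equal curvature
kappa_b = condition_number([0.05, 1.2])   # part (b): unequal curvature
print(kappa_a, kappa_b)
```

A condition number of one means the contours are circles; larger values correspond to increasingly elongated ellipses.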

x = nn.Parameter(10 * torch.ones(D))
a = torch.tensor([0.05, 1.2])

###
# YOUR CODE HERE
optimizer = ...
xs, losses = run_optimization(...)
##
plt.figure()
plt.plot(losses)
plt.xlabel("Gradient Descent Iteration")
plt.ylabel("Objective Value")
visualize_gradient_descent(xs, a)

Explain why the gradient descent iterates oscillate around the line \(x_2 = 0\) and suggest one change that could be made to eliminate this behavior, without changing the objective itself.


Your answer here.


Problem 1c: Optimizing a High-Dimensional Objective

Now let’s tackle a more challenging, higher-dimensional problem. We’ll use the same objective, but this time use \(\mathbf{a} \in \mathbb{R}^{10}\) where:

\[a_i = 10^{-2 + \frac{4(i - 1)}{9}}\]

for \(i \in \{1, \dots, 10\}\). This means the curvature of the dimensions ranges from \(0.01\) to \(100\).
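The formula above can be evaluated directly to confirm the stated range. This is a plain-Python sketch of the definition, separate from the tensor used in the notebook.

```python
D = 10
# a_i = 10^(-2 + 4 (i - 1) / 9) for i = 1, ..., 10
a = [10 ** (-2 + 4 * (i - 1) / 9) for i in range(1, D + 1)]

print(a[0], a[-1])        # smallest and largest curvature
print(a[-1] / a[0])       # ratio of largest to smallest curvature
```

The entries are evenly spaced on a log scale, so consecutive curvatures differ by the constant factor \(10^{4/9}\).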

Experiment with different optimizers and hyperparameter settings on this problem. Start from the initial point \(\mathbf{x}^{(0)} = (10, \dots, 10)\) and run your chosen optimizer for \(1000\) iterations. You can see a complete list of optimizers PyTorch has implemented here. Find an optimizer/hyperparameter regime that achieves a final loss of less than \(0.01\).

x = nn.Parameter(10 * torch.ones(10))
# 10 values evenly spaced on a log scale from 1e-2 to 1e2 (float32, like x)
a = torch.logspace(-2, 2, steps=10)

### 
# YOUR CODE HERE
optimizer = ...
xs, losses = run_optimization(...)
##
plt.figure()
plt.plot(losses)
plt.xlabel("Optimization Iteration")
plt.ylabel("Objective Value")
print("Final loss: {:.10f}".format(losses[-1].item()))

Problem 2: Neural Network Classification

Next, we will use a neural network to solve a classification problem for which the data is not linearly separable. We will implement the network as a PyTorch nn.Module and train it using gradient descent, computing gradients using automatic differentiation.

First, we create and visualize a two-dimensional dataset where each point is labeled as positive (1) or negative (0). As seen below, the positive and negative points are not linearly separable.

torch.manual_seed(305)

def make_dataset(num_points):
    radius = 5

    def sample_annulus(inner_radius, outer_radius, num_points):
        r = Uniform(inner_radius, outer_radius).sample((num_points,))
        angle = Uniform(0, 2 * math.pi).sample((num_points,))
        x = r * torch.cos(angle)
        y = r * torch.sin(angle)
        data = torch.vstack([x, y]).T
        return data
    
    # Generate positive examples (labeled 1)
    data_1 = sample_annulus(0, 0.5 * radius, num_points // 2)
    labels_1 = torch.ones(num_points // 2)
        
    # Generate negative examples (labeled 0).
    data_0 = sample_annulus(0.7 * radius, radius, num_points // 2)
    labels_0 = torch.zeros(num_points // 2)
        
    data = torch.vstack([data_0, data_1])
    labels = torch.concat([labels_0, labels_1])

    return data, labels
    
num_data = 500
data, labels = make_dataset(num_data)

# Note: the first half of `data` holds the negative examples, so red
# indicates a label of 0 (outer ring) and blue a label of 1 (inner disk)
plt.scatter(data[:num_data//2, 0], data[:num_data//2, 1], color='red') 
plt.scatter(data[num_data//2:, 0], data[num_data//2:, 1], color='blue') 

Problem 2a: The Maximum Likelihood Objective

We will try to classify this data using a neural network. We posit the following statistical model for the labels \(y \in \{0, 1\}\) given features \(\mathbf{x} \in \mathbb{R}^2\):

\[ y \mid \mathbf{x} \sim \text{Bern}(\sigma(\text{NN}_{\boldsymbol{\theta}}(\mathbf{x})))\]

Here, \(\text{NN}_{\boldsymbol{\theta}}: \mathbb{R}^2 \to \mathbb{R}\) denotes a neural network with parameters \(\boldsymbol{\theta} \in \mathbb{R}^P\) mapping datapoints \(\mathbf{x}\) to \(\text{logit}(\mathbb{P}(y = 1 \mid \mathbf{x}))\). Recall the logit function is given by

\[\text{logit}(p) = \log \left( \frac{p}{1-p} \right)\]

and its inverse is the sigmoid:

\[\sigma(x) = \frac{1}{1 + \exp (-x)}\]
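The inverse relationship between the two functions can be checked numerically. The sketch below uses only the standard library, and the probabilities it loops over are arbitrary illustrative values.

```python
import math

def sigmoid(x):
    """sigma(x) = 1 / (1 + exp(-x))"""
    return 1.0 / (1.0 + math.exp(-x))

def logit(p):
    """logit(p) = log(p / (1 - p)), defined for 0 < p < 1"""
    return math.log(p / (1.0 - p))

for p in (0.1, 0.5, 0.9):
    # Applying sigmoid after logit should recover p
    print(p, logit(p), sigmoid(logit(p)))
```

A logit of zero corresponds to a probability of one half, which is why the sign of the network output determines the predicted class.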

We estimate the parameters \(\boldsymbol{\theta}\) using maximum likelihood. Show that for a dataset \(\{(\mathbf{x}_n, y_n)\}_{n=1}^N\) the negative log-likelihood objective, rescaled by the number of datapoints \(N\), may be written as:

\[ L(\boldsymbol{\theta}) = \frac{1}{N} \sum_{n = 1}^N \left[ -y_n \log \sigma(\text{NN}_{\boldsymbol{\theta}}(\mathbf{x}_n)) - (1 - y_n) \log (1 - \sigma(\text{NN}_{\boldsymbol{\theta}}(\mathbf{x}_n))) \right] = \frac{1}{N} \sum_{n = 1}^N \ell(y_n, \text{NN}_{\boldsymbol{\theta}}(\mathbf{x}_n)) \]

where \(\ell(y, x) = -y \log \sigma(x) - (1 - y) \log(1 - \sigma(x))\).


Your answer here.


Problem 2b: Define the Neural Network

We will use a neural network with two hidden layers, the first of which has three hidden units and the second of which has five hidden units. The equations defining the output \(z \in \mathbb{R}\) of our neural network given an input \(\mathbf{x} \in \mathbb{R}^2\) are:

\[\begin{split} \begin{align*} \mathbf{h}_1 &= \text{ReLU}(\mathbf{W}_1 \mathbf{x} + \mathbf{b}_1) \\ \mathbf{h}_2 &= \text{ReLU}(\mathbf{W}_2 \mathbf{h}_1 + \mathbf{b}_2) \\ z &= \mathbf{w}_3^\top \mathbf{h}_2 + b_3 \end{align*} \end{split}\]

The parameters of the network are \(\boldsymbol{\theta} = (\mathbf{W}_1, \mathbf{b}_1, \mathbf{W}_2, \mathbf{b}_2, \mathbf{w}_3, b_3)\) where \(\mathbf{W}_1 \in \mathbb{R}^{3 \times 2}, \mathbf{b}_1 \in \mathbb{R}^3, \mathbf{W}_2 \in \mathbb{R}^{5 \times 3}, \mathbf{b}_2 \in \mathbb{R}^5, \mathbf{w}_3 \in \mathbb{R}^5\), and \(b_3 \in \mathbb{R}\).
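To connect these shapes to PyTorch, recall that `nn.Linear(in_features, out_features)` stores its weight with shape `(out_features, in_features)`. The sketch below checks this for a single layer with the dimensions of \(\mathbf{W}_1\) and counts the total number of parameters \(P\) implied by the shapes listed above. It is a shape check only, not the network definition.

```python
import torch.nn as nn

# nn.Linear(in_features, out_features) stores its weight as (out, in),
# matching W_1 in R^{3 x 2} for a map from R^2 to R^3.
layer = nn.Linear(in_features=2, out_features=3)
print(layer.weight.shape, layer.bias.shape)

# Total parameter count P implied by the stated shapes:
# (W_1, b_1) + (W_2, b_2) + (w_3, b_3)
P = (3 * 2 + 3) + (5 * 3 + 5) + (5 + 1)
print(P)
```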

Implement this network as a PyTorch nn.Module using nn.Linear and F.relu.

class SimpleNet(nn.Module):

    def __init__(self):
        super().__init__()
        ###
        # YOUR CODE HERE
        ##

    def forward(self, x):
        """ Implements the forward pass of the network.

        Args:
            x: torch.tensor of shape (N, 2)
        Returns:
            logits: torch.tensor of shape (N,) containing the logits
        """
        ###
        # YOUR CODE HERE
        logits = ...
        ##
        return logits

Let’s visualize the predictions of an untrained network. As we can see, the network does not succeed at classifying the points without training.

def visualize_predictions(net):
    num_points = 200
    x1s = np.linspace(-6.0, 6.0, num_points)
    x2s = np.linspace(-6.0, 6.0, num_points)
    X1, X2 = np.meshgrid(x1s, x2s)

    points = np.transpose([np.tile(x1s, len(x2s)), np.repeat(x2s, len(x1s))])
    points = torch.tensor(points, dtype=torch.float)
    with torch.no_grad():
        probs = torch.sigmoid(net(points)).reshape(num_points, num_points)

    # pass the colormap by name; plt.cm.get_cmap is removed in recent matplotlib
    plt.pcolormesh(X1, X2, probs, cmap='YlGn', vmin=0, vmax=1)
    plt.colorbar()
    plt.scatter(data[:num_data//2, 0], data[:num_data//2, 1], color='red') 
    plt.scatter(data[num_data//2:, 0], data[num_data//2:, 1], color='blue') 
    plt.title("Output Probabilities")

torch.manual_seed(305)
model = SimpleNet()
visualize_predictions(model)

Problem 2c: Train the Network

We will now find the parameters of our network by maximizing the likelihood, or equivalently minimizing the negative log-likelihood. We will use full-batch gradient descent. That is, we will use the gradient \(\nabla L(\boldsymbol{\theta})\) itself to update the parameters rather than a stochastic estimate of \(\nabla L(\boldsymbol{\theta})\).

Use the SGD optimizer from torch.optim with a learning rate of \(1\) and no momentum for 1000 iterations. Note that the function \(\ell\) from above is implemented in PyTorch as nn.BCEWithLogitsLoss.
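The correspondence between \(\ell\) and `nn.BCEWithLogitsLoss` can be verified on stand-in data. In the sketch below the logits and labels are random placeholders, not outputs of the network in this problem.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8)                       # stand-in for network outputs
targets = torch.randint(0, 2, (8,)).float()   # stand-in binary labels

# ell(y, x) = -y log sigma(x) - (1 - y) log(1 - sigma(x)), averaged over n
probs = torch.sigmoid(logits)
manual = (-targets * torch.log(probs)
          - (1 - targets) * torch.log(1 - probs)).mean()

builtin = nn.BCEWithLogitsLoss()(logits, targets)  # default reduction="mean"
print(manual.item(), builtin.item())
```

The built-in loss works on logits directly, which avoids taking the logarithm of a probability that has underflowed to zero. Note that it expects floating-point targets with the same shape as the logits.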

num_steps = 1000
losses = []

###
# YOUR CODE HERE
loss_fn = ...
optimizer = ...
for it in tqdm(range(num_steps)):
    ...
    loss = ...
    ...
    losses.append(loss.item())
##
plt.figure()
plt.plot(losses)
plt.xlabel("Gradient Descent Iteration")
plt.ylabel("Objective Value")

Let’s visualize the predictions of our trained network. We see that the network has learned to separate the positive and negative examples.

visualize_predictions(model)

Problem 3: Amortized Variational Inference

In this problem, we will train a variational autoencoder for the MNIST dataset of handwritten digits. First, let’s download this dataset using PyTorch’s datasets module and visualize some of the digits. We will use a binarized version of the dataset in which each pixel value is either 0 or 1.

# Download MNIST dataset and create dataloaders. 
def binarize(imgs, integer=False):
    threshold = 127 if integer else 0.5
    imgs = imgs.clone()
    imgs[imgs < threshold] = 0.
    imgs[imgs >= threshold] = 1.
    return imgs

train_dataset = datasets.MNIST(root="data", train=True, download=True,
                         transform=transforms.ToTensor())

test_dataset = datasets.MNIST(root='data', train=False, download=True,
                             transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print("Number of points in dataset: {0}".format(train_dataset.data.shape[0]))
print("Number of batches per epoch: {0}".format(len(train_loader)))
# Visualize some digits in the dataset.
imgs, _ = next(iter(train_loader))
imgs = binarize(imgs)
fig, axes = plt.subplots(1, 6, figsize=(14, 14))
fig.tight_layout()
for i, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(imgs[i].squeeze(), alpha=0.8, cmap='gray')

Problem 3a: Decoder Network

We represent a \(28 \times 28\) image \(\mathbf{x}\) as a flattened \(784\) dimensional vector of binary values, i.e. \(\mathbf{x} \in \{0, 1\}^{784}\). We specify our generative model as:

\[ \mathbf{z} \sim \mathcal{N}(0, I), \quad \mathbf{x} \mid \mathbf{z} \sim \text{Bern}(\sigma(D_{\boldsymbol{\theta}}(\mathbf{z})))\]

Here, \(D_{\boldsymbol{\theta}}: \mathbb{R}^2 \to \mathbb{R}^{784}\) is a neural network with parameters \(\boldsymbol{\theta}\) and \(\mathbf{z} \in \mathbb{R}^{2}\) is a two-dimensional latent variable. We use only two dimensions so that the latent space can be easily visualized later, but using a higher dimensional latent variable would give a more flexible generative model.

We will parametrize \(D_{\boldsymbol{\theta}}\) as a fully connected neural network with two hidden layers and ReLU activations. We use 256 units in the first hidden layer and 512 in the second. Note that as in Problem 2, the network maps to the logits of the Bernoulli distribution and not the probabilities themselves. Implement this decoder network in PyTorch below.
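Since the decoder outputs logits, the `Bernoulli` distribution should be constructed with the `logits` keyword rather than the default `probs`. The sketch below uses three illustrative logits in place of a real decoder output to show what the resulting distribution exposes.

```python
import torch
from torch.distributions import Bernoulli

logits = torch.tensor([-2.0, 0.0, 3.0])   # illustrative logits
dist = Bernoulli(logits=logits)           # keyword matters: logits, not probs

x = torch.tensor([0.0, 1.0, 1.0])         # illustrative binary pixels
mean = dist.mean                          # equals sigmoid(logits)
log_prob = dist.log_prob(x)               # one log-probability per pixel

print(mean, log_prob)
```

Because the pixels are conditionally independent given \(\mathbf{z}\), summing `log_prob` over the pixel dimension gives \(\log p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mathbf{z})\) for each image.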

# Define decoder architecture
class Decoder(nn.Module):
    """ Neural network defining p(x | z) """

    def __init__(self, data_dim, latent_dim, hidden_dims=[256, 512]):
        super().__init__()
        self.data_dim = data_dim

        ###
        # YOUR CODE HERE
        ## 
        
    def forward(self, z):
        """ Returns Bernoulli conditional distribution of p(x | z), parametrized
        by logits.
        Args:
            z: (N, latent_dim) torch.tensor
        Returns:
            Bernoulli distribution with a batch of (N, data_dim) logits
        """
        ###
        # YOUR CODE HERE
        ## 
        return Bernoulli(...)

Problem 3b: Encoder Network

We will estimate the parameters of the generative model by maximizing the Evidence Lower Bound (ELBO). As the exact posterior \(p(\mathbf{z} \mid \mathbf{x})\) is unknown, we will use an approximate, amortized posterior \(q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x}) = \mathcal{N}(\mathbf{z} \mid \mu_{\boldsymbol{\phi}}(\mathbf{x}), \text{diag}(\sigma^2_{\boldsymbol{\phi}}(\mathbf{x})))\). We let \(\left(\mu_{\boldsymbol{\phi}}(\mathbf{x}), \log \sigma^2_{\boldsymbol{\phi}}(\mathbf{x}) \right) = E_{\boldsymbol{\phi}}(\mathbf{x})\) where \(E_{\boldsymbol{\phi}}: \mathbb{R}^{784} \to \mathbb{R}^2 \times \mathbb{R}^2\) is a neural network with parameters \(\boldsymbol{\phi}\).

As above, we parametrize \(E_{\boldsymbol{\phi}}\) as a neural network with two layers of hidden units and ReLU activations. We use 512 hidden units in the first layer and 256 in the second. Then we let \(\mu_{\boldsymbol{\phi}}\) and \(\log \sigma^2_{\boldsymbol{\phi}}\) be affine functions of the hidden layer activations. Implement the encoder \(E_{\boldsymbol{\phi}}\) in the code below.
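One detail worth keeping in mind is that `torch.distributions.Normal` is parametrized by a standard deviation, whereas the encoder outputs a log-variance. The conversion is \(\sigma = \exp(\tfrac{1}{2} \log \sigma^2)\). A minimal sketch with illustrative values standing in for encoder outputs:

```python
import torch
from torch.distributions import Normal

mu = torch.tensor([[0.5, -1.0]])         # illustrative mean, shape (1, 2)
log_var = torch.tensor([[0.0, -2.0]])    # illustrative log-variance

std = torch.exp(0.5 * log_var)           # sigma = exp(log(sigma^2) / 2)
q = Normal(mu, std)                      # Normal takes a std, not a variance

print(q.mean, q.stddev, q.variance)
```

Predicting the log-variance lets the network output any real number while guaranteeing a strictly positive standard deviation.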

# Define encoder architecture
class Encoder(nn.Module):
    """ Neural network defining q(z | x). """

    def __init__(self, data_dim, latent_dim, hidden_dims=[512, 256]):
        super().__init__()

        ###
        # YOUR CODE HERE
        ##

    def forward(self, x):
        """ Returns Normal conditional distribution for q(z | x), with mean and
        log-variance output by a neural network.

        Args:
            x: (N, data_dim) torch.tensor
        Returns:
            Normal distribution with a batch of (N, latent_dim) means and standard deviations
        """
        ###
        # YOUR CODE HERE
        ##
        return Normal(...)

ELBO Derivation [given]

As a function of \(\boldsymbol{\theta}\) and \(\boldsymbol{\phi}\), we can write the ELBO for a single datapoint \(\mathbf{x}\) as:

\[\mathcal{L}(\mathbf{x}, \boldsymbol{\theta}, \boldsymbol{\phi}) = \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x})} \left[ \log p_{\boldsymbol{\theta}}(\mathbf{x}, \mathbf{z}) - \log q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x}) \right]\]

We can obtain a lower bound of the log-likelihood for an entire dataset \(\{\mathbf{x}^{(n)} \}_{n=1}^N\), rescaled by the number of datapoints \(N\), as:

\[\mathcal{L}(\boldsymbol{\theta}, \boldsymbol{\phi}) = \frac{1}{N} \sum_{n=1}^N \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z}^{(n)} \mid \mathbf{x}^{(n)})} \left[ \log p_{\boldsymbol{\theta}}(\mathbf{x}^{(n)}, \mathbf{z}^{(n)}) - \log q_{\boldsymbol{\phi}}(\mathbf{z}^{(n)} \mid \mathbf{x}^{(n)}) \right]\]

We can rewrite the per-datapoint ELBO as:

\[\begin{split} \begin{align*} \mathcal{L}(\mathbf{x}, \boldsymbol{\theta}, \boldsymbol{\phi}) &= \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x})} \left[ \log p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mathbf{z}) \right] - \text{KL}\left( q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x}) \mid\mid p(\mathbf{z})\right) \\ &= \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \left[ \log p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mu_{\boldsymbol{\phi}}(\mathbf{x}) + \boldsymbol{\epsilon} \odot \sigma_{\boldsymbol{\phi}}(\mathbf{x})) \right] - \text{KL}\left( q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x}) \mid\mid p(\mathbf{z})\right) \end{align*} \end{split}\]

This allows us to obtain an unbiased estimate of the per-datapoint ELBO by first sampling \(\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\), then computing:

\[\hat{\mathcal{L}}(\mathbf{x}, \boldsymbol{\theta}, \boldsymbol{\phi}) = \log p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mu_{\boldsymbol{\phi}}(\mathbf{x}) + \boldsymbol{\epsilon} \odot \sigma_{\boldsymbol{\phi}}(\mathbf{x})) - \text{KL}\left( q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x}) \mid\mid p(\mathbf{z})\right)\]

This is known as the reparametrization trick, and it will allow us to straightforwardly use automatic differentiation to obtain the gradient of \(\hat{\mathcal{L}}\) with respect to \(\boldsymbol{\phi}\).
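In `torch.distributions`, the difference between a reparametrized and a non-reparametrized draw shows up as the difference between `rsample` and `sample`. The sketch below uses illustrative variational parameters to show that only `rsample` keeps a differentiable path back to the parameters.

```python
import torch
from torch.distributions import Normal

mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_sigma = torch.tensor([0.0, -1.0], requires_grad=True)
q = Normal(mu, log_sigma.exp())

z_plain = q.sample()     # drawn without tracking gradients
z_reparam = q.rsample()  # computed as mu + sigma * eps with eps ~ N(0, I)

print(z_plain.requires_grad, z_reparam.requires_grad)

# d z / d mu = 1 for each coordinate, so the summed gradient is all ones.
z_reparam.sum().backward()
print(mu.grad)
```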

Given a minibatch \(\{\mathbf{x}^{(b)} \}_{b=1}^B\) sampled uniformly from the entire dataset, we can simulate independent normal variates \(\boldsymbol{\epsilon}^{(b)}\) to form an unbiased estimator of the ELBO for the entire dataset:

\[\hat{\mathcal{L}}(\boldsymbol{\theta}, \boldsymbol{\phi}) = \frac{1}{B} \sum_{b = 1}^B \left[ \log p_{\boldsymbol{\theta}}(\mathbf{x}^{(b)} \mid \mu_{\boldsymbol{\phi}}(\mathbf{x}^{(b)}) + \boldsymbol{\epsilon}^{(b)} \odot \sigma_{\boldsymbol{\phi}}(\mathbf{x}^{(b)})) - \text{KL}\left( q_{\boldsymbol{\phi}}(\mathbf{z}^{(b)} \mid \mathbf{x}^{(b)}) \mid\mid p(\mathbf{z}^{(b)})\right) \right] \]

Problem 3c: Implement the ELBO

Using our derivations above, implement the estimator of the ELBO \(\hat{\mathcal{L}}(\boldsymbol{\theta}, \boldsymbol{\phi})\). We assume sampling of the minibatch x is done outside of the function, but you must sample the noise variables \(\boldsymbol{\epsilon}\) within the elbo function. You should use the kl_divergence function imported above to analytically compute the KL divergence between the Gaussian distributions \(q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x})\) and \(p(\mathbf{z})\). Make sure you call rsample, not sample, on the Distribution object so that the reparametrization trick is applied.
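For reference, `kl_divergence` applied to two `Normal` objects returns one value per coordinate, so a diagonal Gaussian requires a sum over the latent dimension. The sketch below uses illustrative values of \(\mu\) and \(\sigma\) and compares the result against the well-known closed form \(\tfrac{1}{2} \sum_d (\mu_d^2 + \sigma_d^2 - 1 - \log \sigma_d^2)\) for a standard normal prior.

```python
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence

mu = torch.tensor([[0.5, -1.0]])      # illustrative values, shape (1, 2)
sigma = torch.tensor([[1.5, 0.3]])

q = Normal(mu, sigma)
p = Normal(torch.zeros_like(mu), torch.ones_like(sigma))

kl_per_dim = kl_divergence(q, p)      # shape (1, 2): one term per coordinate
kl = kl_per_dim.sum(dim=-1)           # diagonal Gaussian: sum over coordinates

closed_form = 0.5 * (mu**2 + sigma**2 - 1
                     - torch.log(sigma**2)).sum(dim=-1)
print(kl, closed_form)
```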

def elbo(x, encoder, decoder):
    """ Computes a stochastic estimate of the rescaled evidence lower bound

    Args:
        x: (N, data_dim) torch.tensor
        encoder: an Encoder
        decoder: a Decoder
    Returns:
        elbo: a (,) torch.tensor containing the estimate of the ELBO
    """
    ###
    # YOUR CODE HERE
    elbo = ...
    ##
    return elbo

Implement the Training Loop [given]

Using our Encoder and Decoder definitions, as well as the elbo function, we have provided training code below. This code uses the Adam optimizer, a sophisticated optimization algorithm which uses the history of past gradients to rescale gradients before applying an update.

We train for 20 epochs (an “epoch” refers to a complete pass through the dataset). Our implementation takes 10 minutes to run and achieves a training ELBO of \(-135\) and a test ELBO of \(-138\).
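To put the number of epochs in terms of optimizer steps, the arithmetic can be sketched as follows. It assumes the standard MNIST training split of 60,000 images and the batch size of 64 used when `train_loader` was created.

```python
import math

num_train = 60_000   # size of the MNIST training split
batch_size = 64      # matches train_loader above
num_epochs = 20

# The DataLoader keeps the final partial batch by default.
batches_per_epoch = math.ceil(num_train / batch_size)
total_steps = batches_per_epoch * num_epochs
last_batch_size = num_train - (batches_per_epoch - 1) * batch_size

print(batches_per_epoch, total_steps, last_batch_size)
```

The final batch of each epoch is smaller than the rest, which is why the training loop weights each batch's ELBO by `len(x)` before dividing by the dataset size.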

encoder = Encoder(data_dim=784, latent_dim=2)
decoder = Decoder(data_dim=784, latent_dim=2)
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()),
                       lr=3e-4)

num_epochs = 20

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    train_elbo = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = binarize(x.reshape(x.shape[0], -1))
        optimizer.zero_grad()

        loss = -elbo(x, encoder, decoder)  
        loss.backward()
        train_elbo -= loss.item() * len(x)
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tELBO: {:.6f}'.format(
                epoch, batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), -loss.item()))
            
    encoder.eval()
    decoder.eval()
    test_elbo = 0
    with torch.no_grad():
        for x, _ in test_loader:
            x = binarize(x.reshape(x.shape[0], -1))
            test_elbo += elbo(x, encoder, decoder).item() * len(x)
            
    train_elbo /= len(train_loader.dataset)
    test_elbo /= len(test_loader.dataset)
    
    print('====> Epoch: {} Average ELBO: {:.4f} Test ELBO: {:.4f}'.format(epoch,
                                                                          train_elbo,
                                                                          test_elbo))

Problem 3d: Visualize samples from the trained model

In addition to the ELBO, we can sample from the trained model to assess its performance. Use the code below to generate an \(8 \times 8\) grid of sampled digits from the model. Note that we follow the common practice of using the mean of \(p_{\boldsymbol{\theta}}(\mathbf{x} \mid \mathbf{z})\) rather than resampling from this distribution when visualizing samples. Critique these samples. What aspects of the data distribution does the model seem to have trouble learning?

# Visualize sampled digits from our model
decoder.eval()

num_samples = 64
with torch.no_grad():
    z = torch.randn(num_samples, 2)
    expected_xs = decoder.forward(z).mean 
    expected_xs = expected_xs.reshape(-1, 28, 28).unsqueeze(1)

# Plot the expected_xs as a grid of images
expected_xs_grid = make_grid(expected_xs, nrow=8)
plt.figure(figsize=(10,10))
plt.axis('off')
plt.imshow(expected_xs_grid.permute(1, 2, 0), vmin=0., vmax=1.)
plt.show()

Your answer here.


Problem 3e: Visualize the Latent Embeddings

Given \(\mathbf{x}\), we can interpret the mean of the approximate posterior \(\mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x})}[\mathbf{z}]\) as a lower dimensional representation of \(\mathbf{x}\). In the code below, we find the mean of the approximate posterior for each of the datapoints in the dataset and then visualize these means with a scatter plot. We color each point according to the label of the encoded digit. What do you notice? Are there classes with significant overlap, and are these classes which are visually similar? Is there a class which has clear separation from the others, and if so, why do you think this is?

Note that we did not provide any information about the class label to either the generative model or the approximate posterior!

# Compute the mean of the latents given the data
encoder.eval()
with torch.no_grad():
    means = []
    ys = []
    for x, y in train_loader:        
        x = binarize(x.reshape(x.shape[0], -1))
        mean = encoder.forward(x).mean
        means.append(mean)
        ys.append(y)

means = torch.vstack(means)
ys = torch.hstack(ys)

# Plot the first two dimensions of the latents
fig, ax = plt.subplots(figsize=(8, 8))
for i in range(10):
    means_i = means[ys == i]
    ax.scatter(means_i[:, 0], means_i[:, 1], label=str(i))

ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
ax.legend()

Your answer here.


Problem 3f: Interpolation in the Latent Space

Another desideratum for a latent variable model is smooth interpolation in the latent space. For example, if we linearly interpolate between a latent \(\mathbf{z}_{start}\) corresponding to a \(7\) and a latent \(\mathbf{z}_{end}\) corresponding to a \(1\), we should observe the decodings of the interpolations smoothly change from a \(7\) to a \(1\).

In the code below, we sample \(8\) different starting latent variables and \(8\) different ending latent variables from the prior, linearly interpolate between them for \(10\) points, then plot the decodings. Does our model smoothly change between decoded digits? Are there digit pairs it was more successful interpolating between?
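The interpolation itself is the map \(\mathbf{z}(t) = \mathbf{z}_{start} + t\,(\mathbf{z}_{end} - \mathbf{z}_{start})\) for \(t \in [0, 1]\). The sketch below is a vectorized variant of that formula using broadcasting; the notebook code uses an explicit loop instead, and the names `z_start`, `z_end`, and `path` are illustrative.

```python
import torch

torch.manual_seed(0)
z_start = torch.randn(8, 2)
z_end = torch.randn(8, 2)

ts = torch.linspace(0, 1, 10)   # 10 interpolation weights from 0 to 1

# Broadcast to shape (10, 8, 2):
# path[k, i] = z_start[i] + ts[k] * (z_end[i] - z_start[i])
path = z_start + ts[:, None, None] * (z_end - z_start)

print(path.shape)
```

The first slice of `path` reproduces the starting latents and the last slice reproduces the ending latents, with evenly spaced points in between.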

# Interpolate between 8 randomly chosen start and end points
latent_starts = torch.randn(8, 2)
latent_ends = torch.randn(8, 2)

means = []
for t in torch.linspace(0, 1, 10):
    z = latent_starts + t * (latent_ends - latent_starts)
    with torch.no_grad():
        means.append(decoder.forward(z).mean.reshape(-1, 28, 28).unsqueeze(0))

means_tensor = torch.vstack(means).permute(1, 0, 2, 3).reshape(-1, 28, 28).unsqueeze(1)
sample_grid = make_grid(means_tensor, nrow=10)

plt.figure(figsize=(10,10))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()

Your answer here.


Problem 4: Reflections

Problem 4a

Discuss one reason why we use amortized variational inference rather than optimizing per-datapoint variational parameters \(\mu^{(n)}, (\boldsymbol{\sigma}^2)^{(n)}\) (so \(q(\mathbf{z}^{(n)} \mid \mathbf{x}^{(n)}) = \mathcal{N}(\mathbf{z}^{(n)} \mid \mu^{(n)}, \text{diag}((\boldsymbol{\sigma}^2)^{(n)}))\)).


Your answer here.


Problem 4b

Describe one way you could improve the variational autoencoder, either by changing the encoder or decoder network structure or by changing the model itself, and why you think your proposed change would help.


Your answer here.


Problem 4c

Suppose rather than using a Gaussian prior on \(\mathbf{z}\), we used \(\mathbf{z} \overset{ind}{\sim} \text{Bern}(0.5)\). We can modify \(E_{\boldsymbol{\phi}}\) to output logits for a multivariate Bernoulli approximate posterior: \(q_{\boldsymbol{\phi}}(\mathbf{z} \mid \mathbf{x}) = \text{Bern}(\mathbf{z}; \sigma(E_{\boldsymbol{\phi}}(\mathbf{x})))\). Where would our optimization procedure break down in this case?


Your answer here.


Submission Instructions

Formatting: check that your code does not exceed 80 characters in line width. You can set Tools → Settings → Editor → Vertical ruler column to 80 to see when you’ve exceeded the limit.

Download your notebook in .ipynb format and remove the Open in Colab button. Then run the following command to convert to a PDF:

jupyter nbconvert --to pdf <yourname>_hw6.ipynb

Installing nbconvert:

If you’re using Anaconda for package management,

conda install -c anaconda nbconvert

Upload your .pdf files to Gradescope.