HW7: Autoregressive HMMs


Name:

Names of any collaborators:


In this lab we’ll develop hidden Markov models, specifically Gaussian autoregressive hidden Markov models, to analyze depth videos of freely behaving mice. We’ll implement the model developed by Wiltschko et al. (2015) and extended in Markowitz et al. (2018). Figure 1 of Wiltschko et al. is reproduced above.

References

Markowitz, J. E., Gillis, W. F., Beron, C. C., Neufeld, S. Q., Robertson, K., Bhagat, N. D., … & Sabatini, B. L. (2018). The striatum organizes 3D behavior via moment-to-moment action selection. Cell, 174(1), 44-58.

Wiltschko, A. B., Johnson, M. J., Iurilli, G., Peterson, R. E., Katon, J. M., Pashkovski, S. L., … & Datta, S. R. (2015). Mapping sub-second structure in mouse behavior. Neuron, 88(6), 1121-1135.

Environment Setup

%%capture
!pip install pynwb
!wget -nc https://raw.githubusercontent.com/slinderman/stats305c/main/assignments/hw7/helpers.py
!wget -nc https://www.dropbox.com/s/564wzasu1w7iogh/moseq_data.zip
!unzip -n moseq_data.zip
# First, import necessary libraries.
import torch
from torch.distributions import MultivariateNormal, Categorical
import torch.nn.functional as F

from dataclasses import dataclass
from tqdm.auto import trange
from google.colab import files

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context("notebook")

# We've written a few helpers for plotting, etc.
import helpers

Part 1: Implement the forward-backward algorithm

First, implement the forward-backward algorithm for computing the posterior distribution on latent states of a hidden Markov model, \(q(z) = p(z \mid x, \Theta)\). Specifically, this algorithm will return a \(T \times K\) matrix whose \((t, k)\) entry is the posterior probability \(q(z_t = k)\).

Problem 1a [Code]: Implement the forward pass

As we derived in class, the forward pass recursively computes the normalized forward messages \(\tilde{\alpha}_t\) and the marginal log likelihood \(\log p(x \mid \Theta) = \sum_{t} \log A_t\), where \(A_t = p(x_t \mid x_{1:t-1}, \Theta)\) is the normalizing constant computed at step \(t\).

Notes:

  • This function takes in the log likelihoods, \(\log \ell_{tk}\), so you’ll have to exponentiate in the forward pass

  • You need to be careful when exponentiating, though. If the log likelihoods are very negative, they’ll all be essentially zero when exponentiated, and you’ll run into a divide-by-zero error when you compute the normalized forward message. Conversely, if they’re large positive numbers, the exponentials will overflow and you’ll get NaNs in your calculations.

  • To avoid numerical issues, subtract \(\max_k (\log \ell_{tk})\) prior to exponentiating. It won’t affect the normalized messages, but you will have to account for it in your computation of the marginal likelihood.
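To see why the shift in the last note is harmless, here is a minimal stand-alone sketch in plain Python (not the torch implementation the assignment asks for) of a single normalization step. The helper name `normalize_in_log_space` is hypothetical: it exponentiates shifted log values, normalizes them, and adds the subtracted max back into the log of the normalizer.

```python
import math

def normalize_in_log_space(log_values):
    """Normalize exp(log_values) without overflow or underflow.

    Returns (probs, log_norm) where probs[k] = exp(log_values[k]) / Z
    and log_norm = log Z, computed by shifting by the max first.
    """
    shift = max(log_values)
    # After the shift the largest entry is exp(0) = 1, so the sum is >= 1.
    shifted = [math.exp(v - shift) for v in log_values]
    total = sum(shifted)
    probs = [s / total for s in shifted]
    # Add the shift back so log_norm is the log of the *unshifted* sum.
    log_norm = math.log(total) + shift
    return probs, log_norm

# Naively exponentiating these would underflow to 0.0 for every entry.
probs, log_norm = normalize_in_log_space([-1000.0, -1001.0, -1002.0])
print(probs)     # approximately [0.665, 0.245, 0.090]
print(log_norm)  # approximately -999.59
```

In the forward pass you would apply the same shift to \(\log \ell_{tk}\) only, as the note suggests, multiply the shifted likelihoods by the predicted state probabilities before normalizing, and then add the shift back when accumulating \(\log A_t\).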

def forward_pass(initial_probs, transition_matrix, log_likes):
    """
    Perform the (normalized) forward pass of the HMM.

    Parameters
    ----------
    initial_probs: $\pi$, the initial state probabilities. Length K, sums to 1.
    transition_matrix: $P$, a KxK transition matrix. Rows sum to 1.
    log_likes: $\log \ell_{t,k}$, a TxK matrix of _log_ likelihoods.

    Returns
    -------
    alphas: TxK matrix with _normalized_ forward messages $\tilde{\alpha}_{t,k}$
    marginal_ll: Scalar marginal log likelihood $\log p(x | \Theta)$
    """
    ##
    # YOUR CODE HERE
    alphas = ...
    marginal_ll = ...
    #
    ##
    return alphas, marginal_ll

Problem 1b [Code]: Implement the backward pass

Recursively compute the backward messages \(\beta_t\). Again, normalize to avoid underflow, and be careful when you exponentiate the log likelihoods. The same trick of subtracting the max before exponentiating will work here too.
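For reference, here is one way to write the normalized backward recursion. It uses the shifted likelihoods \(\tilde{\ell}_{t,k} = \exp(\log \ell_{t,k} - \max_j \log \ell_{t,j})\). Each \(\tilde{\beta}_t\) matters only up to a constant, so neither the shift nor the normalization changes the posterior marginals, which are renormalized when the two passes are combined.

\[ \begin{align} \tilde{\beta}_{T,k} &= \tfrac{1}{K}, \\ \beta_{t,k} &= \sum_{j=1}^K P_{kj} \, \tilde{\ell}_{t+1,j} \, \tilde{\beta}_{t+1,j}, \qquad \tilde{\beta}_{t,k} = \frac{\beta_{t,k}}{\sum_{k'=1}^K \beta_{t,k'}}, \qquad t = T-1, \ldots, 1. \end{align} \]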

def backward_pass(transition_matrix, log_likes):
    """
    Perform the (normalized) backward pass of the HMM.

    Parameters
    ----------
    transition_matrix: $P$, a KxK transition matrix. Rows sum to 1.
    log_likes: $\log \ell_{t,k}$, a TxK matrix of _log_ likelihoods.

    Returns
    -------
    betas: TxK matrix with _normalized_ backward messages $\tilde{\beta}_{t,k}$
    """
    ##
    # YOUR CODE HERE
    betas = ...
    #
    ##

    return betas

Problem 1c [Code]: Combine the forward and backward passes

Compute the posterior marginal probabilities. We call these the expected_states because \(q(z_t = k) = \mathbb{E}_{q(z)}[\mathbb{I}[z_t = k]]\). To compute them, combine the forward messages, backward messages, and the likelihoods, then normalize. Again, be careful when exponentiating the likelihoods.
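When debugging `forward_backward`, it can help to compare against exact posterior marginals on a tiny problem. The plain-Python sketch below uses a hypothetical `brute_force_marginals` helper. It enumerates all \(K^T\) state paths, so it is only usable for very small \(T\) and \(K\). In exchange, it computes \(q(z_t = k)\) and \(\log p(x \mid \Theta)\) directly from the HMM joint distribution, with no message passing.

```python
import itertools
import math

def brute_force_marginals(initial_probs, transition_matrix, log_likes):
    """Exact q(z_t = k) by summing over all K**T state paths.

    Only feasible for tiny T and K; intended as a test oracle.
    Returns (marginals, marginal_ll) with marginals a T x K nested list.
    """
    T, K = len(log_likes), len(initial_probs)
    marginals = [[0.0] * K for _ in range(T)]
    total = 0.0
    for path in itertools.product(range(K), repeat=T):
        # Joint probability p(z_{1:T} = path, x_{1:T}).
        prob = initial_probs[path[0]] * math.exp(log_likes[0][path[0]])
        for t in range(1, T):
            prob *= transition_matrix[path[t - 1]][path[t]]
            prob *= math.exp(log_likes[t][path[t]])
        total += prob
        for t, k in enumerate(path):
            marginals[t][k] += prob
    marginals = [[m / total for m in row] for row in marginals]
    return marginals, math.log(total)

pi = [0.5, 0.5]
P = [[0.9, 0.1], [0.1, 0.9]]
log_likes = [[0.0, -2.0], [-1.0, -1.0], [-2.0, 0.0]]
marginals, ll = brute_force_marginals(pi, P, log_likes)
```

To compare, pass the same inputs as tensors (e.g. `torch.tensor(log_likes)`) to your `forward_backward` and check that `expected_states` and `marginal_ll` match up to floating-point error.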

@dataclass
class HMMPosterior:
    expected_states: torch.Tensor
    marginal_ll: float


def forward_backward(initial_probs, transition_matrix, log_likes):
    """
    Run the forward and backward passes and then combine them to compute
    the posterior probabilities q(z_t=k).

    Parameters
    ----------
    initial_probs: $\pi$, the initial state probabilities. Length K, sums to 1.
    transition_matrix: $P$, a KxK transition matrix. Rows sum to 1.
    log_likes: $\log \ell_{t,k}$, a TxK matrix of _log_ likelihoods.

    Returns
    -------
    posterior: an HMMPosterior object
    """
    ##
    # YOUR CODE HERE
    expected_states = ...
    marginal_ll = ...
    #
    ##
    
    # Package the results into a HMMPosterior
    return HMMPosterior(expected_states=expected_states,
                        marginal_ll=marginal_ll)

Time it on some more realistic sizes

It should take about 3 seconds for a \(T=36000\) time series with \(K=50\) states.

%timeit forward_backward(*helpers.random_args(36000, 50))

Part 2: Gaussian HMM

First we’ll implement a hidden Markov model (HMM) with Gaussian observations. This is the same model we studied in class,

\[ \begin{align} p(x, z; \Theta) &= \mathrm{Cat}(z_1; \pi) \prod_{t=2}^{T} \mathrm{Cat}(z_t; P_{z_{t-1}}) \prod_{t=1}^T \mathcal{N}(x_t; \mu_{z_t}, \Sigma_{z_t}) \end{align} \]

with parameters \(\Theta = \pi, P, \{\mu_k, \Sigma_k\}_{k=1}^K\). The observed datapoints are \(x_t \in \mathbb{R}^{D}\) and the latent states are \(z_t \in \{1,\ldots, K\}\).

Problem 2a [Code]: Complete the following GaussianHMM class

Finish the code below to implement a GaussianHMM object. Specifically, complete the following functions:

  • sample: to simulate from the joint distribution \(p(z_{1:T}, x_{1:T})\).

  • e_step: to compute the posterior expectations and marginal likelihood using the forward_backward function you wrote in Part 1.

  • m_step: to update the parameters by maximizing the expected log joint probability under the posterior from e_step.

  • fit: to run the EM algorithm.
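The `sample` method is ordinary ancestral sampling. Draw \(z_1\) from \(\pi\), then alternate between drawing an emission given the current state and drawing the next state from the corresponding row of \(P\). Below is a minimal plain-Python sketch with scalar (\(D=1\)) emissions and a hypothetical helper name. In the torch class you would instead draw from `Categorical` and `MultivariateNormal`, which are already imported above.

```python
import random

def sample_gaussian_hmm_1d(initial_probs, transition_matrix, means, stds,
                           num_timesteps, seed=0):
    """Ancestral sampling of a scalar Gaussian HMM.

    z_1 ~ Cat(pi), x_t ~ N(means[z_t], stds[z_t]**2), z_{t+1} ~ Cat(P[z_t]).
    Returns (states, emissions) as Python lists.
    """
    rng = random.Random(seed)
    state_ids = list(range(len(initial_probs)))
    z = rng.choices(state_ids, weights=initial_probs)[0]
    states, emissions = [], []
    for t in range(num_timesteps):
        states.append(z)
        emissions.append(rng.gauss(means[z], stds[z]))  # emission given z_t
        if t < num_timesteps - 1:
            # Next state is drawn from row z of the transition matrix.
            z = rng.choices(state_ids, weights=transition_matrix[z])[0]
    return states, emissions

states, emissions = sample_gaussian_hmm_1d(
    initial_probs=[0.5, 0.5],
    transition_matrix=[[0.95, 0.05], [0.05, 0.95]],
    means=[-1.0, 1.0], stds=[0.25, 0.25], num_timesteps=10)
```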

Notes:

  • Recall that in Homework 4 you derived the M-step for a Gaussian mixture model with a normal-inverse Wishart prior distribution. You can reuse the same calculations for the M-step of the Gaussian HMM. Here, we are assuming an improper uniform prior on the parameters \((\mu_k, \Sigma_k)\), but you can think of that as a normal-inverse-Wishart prior with parameters \(\mu_0=0\), \(\kappa_0=0\), \(\Sigma_0=0\), and \(\nu_0=-(D+2)\).

  • For numerical stability, in the M-step you may need to add a small amount to the diagonal of \(\Sigma_k\) and explicitly make it symmetric; e.g. after solving for the optimal covariance do,

Sigma = 0.5 * (Sigma + Sigma.T) + 1e-4 * torch.eye(self.data_dim)

You can think of this as a very weak NIW prior.

  • We will keep the initial distribution and transition matrix fixed in this code!
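Concretely, the emission update for state \(k\) is a weighted mean and covariance with weights \(w_t = q(z_t = k)\): \(\mu_k = \sum_t w_t x_t / \sum_t w_t\) and \(\Sigma_k = \sum_t w_t (x_t - \mu_k)(x_t - \mu_k)^\top / \sum_t w_t\). The plain-Python sketch below computes both and then applies the symmetrize-plus-jitter step from the notes above. The helper name is hypothetical, and it handles one state at a time rather than vectorizing over \(K\) as a torch version would.

```python
def weighted_gaussian_mle(data, weights, jitter=1e-4):
    """Weighted mean and covariance of D-dimensional points (nested lists).

    weights[t] plays the role of q(z_t = k) for a single state k.
    The covariance is symmetrized and `jitter` is added to its diagonal.
    """
    total = sum(weights)
    dim = len(data[0])
    mean = [sum(w * x[d] for w, x in zip(weights, data)) / total
            for d in range(dim)]
    cov = [[sum(w * (x[i] - mean[i]) * (x[j] - mean[j])
                for w, x in zip(weights, data)) / total
            for j in range(dim)]
           for i in range(dim)]
    # Same idea as: 0.5 * (Sigma + Sigma.T) + 1e-4 * torch.eye(D)
    cov = [[0.5 * (cov[i][j] + cov[j][i]) + (jitter if i == j else 0.0)
            for j in range(dim)]
           for i in range(dim)]
    return mean, cov

data = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0]]
weights = [0.9, 0.1, 0.1, 0.9]   # e.g. one column of expected_states
mean_k, cov_k = weighted_gaussian_mle(data, weights)
```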

class GaussianHMM:
    """Simple implementation of a Gaussian HMM.
    """
    def __init__(self, num_states, data_dim):
        self.num_states = num_states
        self.data_dim = data_dim

        # Initialize the HMM parameters
        self.initial_probs = torch.ones(num_states) / num_states
        self.transition_matrix = \
            0.9 * torch.eye(num_states) + \
            0.1 * torch.ones((num_states, num_states)) / num_states
        self.emission_means = torch.randn(num_states, data_dim)
        self.emission_covs = torch.eye(data_dim).repeat(num_states, 1, 1)

    def sample(self, num_timesteps, seed=0):
        """Sample the HMM
        """
        # Set random seed
        torch.manual_seed(seed)

        # Initialize outputs
        states = torch.full((num_timesteps,), -1, dtype=int)
        data = torch.zeros((num_timesteps, self.data_dim))

        ## 
        # YOUR CODE HERE
        states[0] = ...                     # Sample the initial state
        for t in range(num_timesteps):
            data[t] = ...                   # Sample emission
            if t < num_timesteps - 1:
                states[t+1] = ...           # Sample next state
        #
        ##

        return states, data

    def e_step(self, data):
        """Run the forward-backward algorithm and return the posterior distribution
        over latent states given the data and the current model parameters.
        """
        ##
        # YOUR CODE HERE
        posterior = ...
        #
        ##
        return posterior

    def m_step(self, data, posterior):
        """Perform one m-step to update the emission means and covariance given the
        data and the posteriors output by the forward-backward algorithm.
        
        NOTE: We will keep the initial distribution and transition matrix fixed!
        """
        ##
        # YOUR CODE HERE
        self.emission_means = ...
        self.emission_covs = ...
        #
        ##

    def fit(self, data, num_iters=100):
        """Estimate the parameters of the HMM with expectation-maximization (EM).
        """
        # Initialize the posterior randomly
        expected_states = torch.rand(len(data), self.num_states)
        expected_states /= expected_states.sum(dim=1, keepdim=True)
        posterior = HMMPosterior(expected_states=expected_states,
                                 marginal_ll=-torch.inf)
        
        # Track the marginal log likelihood of the data over EM iterations
        lls = []

        # Main loop of the EM algorithm
        for itr in trange(num_iters):
            ###
            # YOUR CODE HERE

            # E step: compute the posterior distribution given current parameters
            posterior = ...

            # Track the log likelihood
            lls.append(...)

            # M step: update model parameters under the current posterior
            ...
            #
            ##

            
        # Convert lls to a tensor and return
        lls = torch.tensor(lls)
        return lls, posterior
        

Sample synthetic data from the model

# Make a "true" HMM
num_states = 5
data_dim = 2
true_hmm = GaussianHMM(num_states, data_dim)

# Override the emission distribution
true_hmm.emission_means = torch.column_stack([
    torch.cos(torch.linspace(0, 2 * torch.pi, num_states+1))[:-1],
    torch.sin(torch.linspace(0, 2 * torch.pi, num_states+1))[:-1]
])
true_hmm.emission_covs = 0.25**2 * torch.eye(data_dim).repeat(num_states, 1, 1)

# Sample the model
num_timesteps = 200
states, emissions = true_hmm.sample(num_timesteps, seed=305+ord('c'))

# Plot the true states (background), the data (solid), and the true means (dotted)
lim = 1.05 * abs(emissions).max()
plt.figure(figsize=(8, 6))
plt.imshow(states[None,:],
           aspect="auto",
           interpolation="none",
           cmap=helpers.cmap,
           vmin=0,
           vmax=len(helpers.colors)-1,
           extent=(0, num_timesteps, -lim, (data_dim)*lim))

means = true_hmm.emission_means[states]
for d in range(data_dim):
    plt.plot(emissions[:,d] + lim * d, '-k')
    plt.plot(means[:,d] + lim * d, ':k')

plt.xlim(0, num_timesteps)
plt.xlabel("time")
plt.yticks(lim * torch.arange(data_dim), ["$x_{}$".format(d+1) for d in range(data_dim)])

plt.title("Simulated data from an HMM")

plt.tight_layout()

Fit the Gaussian HMM to synthetic data

# Build the HMM and fit it with EM
hmm = GaussianHMM(num_states, data_dim)
lls, posterior = hmm.fit(emissions)

# Plot the log likelihoods. They should go up.
plt.plot(lls)
plt.xlabel("iteration")
plt.ylabel("marginal log lkhd")
plt.grid(True)
# Plot the true and inferred states
fig, axs = plt.subplots(2, 1, sharex=True)
axs[0].imshow(states[None,:],
              aspect="auto",
              interpolation="none",
              cmap=helpers.cmap,
              vmin=0, vmax=len(helpers.colors)-1)
axs[0].set_yticks([])
axs[0].set_title("true states")

axs[1].imshow(posterior.expected_states.T,
              aspect="auto",
              interpolation="none",
              cmap="Greys",
              vmin=0, vmax=1)
axs[1].set_yticks(torch.arange(num_states))
axs[1].set_ylabel("state")
axs[1].set_xlabel("time")
axs[1].set_title("expected states")

plt.tight_layout()

Problem 2b [Code]: Cross validation

Fit HMMs with varying numbers of discrete states, \(K\), and compare them on held-out test data. For each \(K\), fit an HMM multiple times from different initial conditions to guard against local optima in the EM fits. Plot the held-out likelihoods as a function of \(K\).

##
# Your code here

#
##

Problem 2c [Short Answer]: Initialization

The HMM doesn’t always find the true latent states. Sometimes it merges the red and blue states, for example. Running multiple restarts with random initializations sometimes works, but not always. Can you think of smarter initialization strategies?


Your answer here.


Part 3: Autoregressive HMMs

Autoregressive hidden Markov models (ARHMMs) replace the Gaussian observations with an AR model:

\[ \begin{align} p(x, z \mid \Theta) &= \mathrm{Cat}(z_1 \mid \pi) \prod_{t=2}^{T} \mathrm{Cat}(z_t \mid P_{z_{t-1}}) \prod_{t=1}^T p(x_t \mid x_{1:t-1}, z_t) \end{align} \]

The model is “autoregressive” because \(x_t\) depends not only on \(z_t\) but on \(x_{1:t-1}\) as well. The precise form of this dependence varies; here we will consider linear Gaussian dependencies on only the most recent \(L\) timesteps:

\[ \begin{align} p(x_t \mid x_{1:t-1}, z_t) &= \mathcal{N}\left(x_t \mid \sum_{l=1}^L A_{z_t,l} x_{t-l} + b_{z_t}, Q_{z_t} \right) \qquad \text{for } t > L \end{align} \]

To complete the model, assume

\[ \begin{align} p(x_t \mid x_{1:t-1}, z_t) &= \mathcal{N}\left(x_t \mid 0, I \right) \qquad \text{for } t \leq L \end{align} \]

The new parameters are \(\Theta = \pi, P, \{\{A_{k,l}\}_{l=1}^L, b_{k}, Q_k\}_{k=1}^K\). They include weights \(A_{k,l} \in \mathbb{R}^{D \times D}\) for each of the \(K\) states and the \(L\) lags, a bias vector \(b_k \in \mathbb{R}^D\), and a noise covariance \(Q_k \in \mathbb{R}^{D \times D}\) for each state.

Note that we can write this as a simple linear regression,

\[ \begin{align} p(x_t \mid x_{1:t-1}, z_t) &= \mathcal{N}\left(x_t \mid W_{z_t} \phi_t , Q_{z_t} \right) \end{align} \]

where \(\phi_t = (x_{t-1}, \ldots, x_{t-L}, 1) \in \mathbb{R}^{LD +1}\) is a vector of covariates (aka features) that includes the past \(L\) time steps along with a 1 for the bias term.

\[ \begin{align} W_k = \begin{bmatrix} A_{k,1} & A_{k,2} & \ldots & A_{k,L} & b_k \end{bmatrix} \in \mathbb{R}^{D \times (LD + 1)} \end{align} \]

is a block matrix of the autoregressive weights and the bias.

Note that the covariates are fixed functions of the data so we can precompute them, if we know the number of lags \(L\).
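Since \(\phi_t\) depends only on the data, one option is to build all of the covariates once before running EM. The plain-Python sketch below uses hypothetical helper names and follows the ordering \((x_{t-1}, \ldots, x_{t-L}, 1)\) defined above. It returns covariates only for the time steps \(t > L\) (1-indexed) that use the autoregressive likelihood; the first \(L\) steps use \(\mathcal{N}(0, I)\) instead.

```python
def make_covariates(data, num_lags):
    """Stack phi_t = (x_{t-1}, ..., x_{t-L}, 1) for each 0-indexed t >= L.

    data is a T x D nested list; the result is a (T - L) x (L*D + 1) list.
    """
    covariates = []
    for t in range(num_lags, len(data)):
        phi = []
        for lag in range(1, num_lags + 1):
            phi.extend(data[t - lag])   # most recent lag first
        phi.append(1.0)                 # constant feature for the bias b_k
        covariates.append(phi)
    return covariates

def predict_mean(weights, phi):
    """Compute W_k phi_t for a D x (L*D + 1) weight matrix given as rows."""
    return [sum(w * f for w, f in zip(row, phi)) for row in weights]

data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
phis = make_covariates(data, num_lags=1)  # [[1.0, 2.0, 1.0], [3.0, 4.0, 1.0]]
```

In torch with \(L=1\), the same matrix can be formed in one line, e.g. `torch.column_stack([data[:-1], torch.ones(len(data) - 1)])`, paired with the targets `data[1:]`.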

Problem 3a [Math]: Derive the natural parameters and sufficient statistics for a linear regression

Expand the expected log likelihood of a linear regression model in terms of \(W_k\) and \(Q_k\),

\[ \begin{align} \mathbb{E}_{q(z)}\left[ \sum_{t=1}^T \mathbb{I}[z_t=k] \cdot \log \mathcal{N}(x_t \mid W_k \phi_t, Q_k) \right]. \end{align} \]

Write it as a sum of inner products between natural parameters (i.e. functions of \(W_k\) and \(Q_k\)) and expected sufficient statistics (i.e. functions of \(q\), \(x\) and \(\phi\)).


Your answer here


Problem 3b [Math]: Solve for the optimal linear regression parameters given expected sufficient statistics

Solve for \(W_k^\star, Q_k^\star\) that maximize the objective above in terms of the expected sufficient statistics.


Your answer here


Problem 3c [Code]: Implement an Autoregressive HMM

Now complete the code below to implement an AR-HMM.

Note: This code assumes \(L=1\).

class AutoregressiveHMM:
    """Simple implementation of an Autoregressive HMM.
    """
    def __init__(self, num_states, data_dim):
        self.num_states = num_states
        self.data_dim = data_dim

        # Initialize the HMM parameters
        self.initial_probs = torch.ones(num_states) / num_states
        self.initial_mean = torch.zeros(data_dim)
        self.initial_cov = torch.eye(data_dim)
        self.transition_matrix = \
            0.9 * torch.eye(num_states) + \
            0.1 * torch.ones((num_states, num_states)) / num_states
        self.emission_dynamics = torch.randn(num_states, data_dim, data_dim)
        self.emission_bias = torch.randn(num_states, data_dim)
        self.emission_cov = torch.eye(data_dim).repeat(num_states, 1, 1)

    def sample(self, num_timesteps, seed=0):
        """Sample the HMM
        """
        # Set random seed
        torch.manual_seed(seed)

        # Initialize outputs
        states = torch.full((num_timesteps,), -1, dtype=int)
        data = torch.zeros((num_timesteps, self.data_dim))

        ## 
        # YOUR CODE HERE
        states[0] = ...                     # Sample the initial state
        data[0] = ...                       # Sample the initial emission
        for t in range(1, num_timesteps):
            states[t] = ...                 # Sample state
            data[t] = ...                   # Sample emission
        #
        ##
        return states, data

    def e_step(self, data):
        """Perform one e-step to compute the posterior over the latent states
        given the data.
        """
        ###
        # YOUR CODE HERE
        posterior = ...
        #
        ##
        return posterior

    def m_step(self, data, posterior):
        """Perform one m-step to update the emission dynamics, biases, and
        covariances given the data and the posteriors output by the
        forward-backward algorithm.
        """
        ##
        # YOUR CODE HERE
        self.emission_dynamics = ...
        self.emission_bias = ...
        self.emission_cov = ...
        #
        ##

    def fit(self, data, num_iters=100):
        """Estimate the parameters of the HMM with expectation-maximization (EM).
        """
        # Initialize the posterior randomly
        # Softmax over states (dim=1) so each row is a distribution over z_t
        expected_states = F.softmax(
            torch.randn(len(data), self.num_states), dim=1)
        posterior = HMMPosterior(expected_states=expected_states,
                                 marginal_ll=-torch.inf)
        
        # Track the marginal log likelihood of the data over EM iterations
        lls = []

        # Main loop of the EM algorithm
        for itr in trange(num_iters):
            ###
            # YOUR CODE HERE
            # E step: compute the posterior given the current parameters
            posterior = ...

            # Track the log likelihood
            lls.append(...)

            # M step: update model parameters under the current posterior
            ...
            #
            ##

        # Convert lls to a tensor and return
        lls = torch.tensor(lls)
        return lls, posterior        

Sample synthetic data from the model

# Make observation distributions
num_states = 5
data_dim = 2

# Initialize the transition matrix to proceed in a cycle
transition_probs = (torch.arange(num_states)**10).type(torch.float)
transition_probs /= transition_probs.sum()
transition_matrix = torch.zeros((num_states, num_states))
for k, p in enumerate(transition_probs.flip(0)):
    transition_matrix += torch.roll(p * torch.eye(num_states), k, dims=1)

# Initialize the AR dynamics to spiral toward points
rotation_matrix = \
    lambda theta: torch.tensor([[torch.cos(theta), -torch.sin(theta)],
                                [torch.sin(theta),  torch.cos(theta)]])
theta = torch.tensor(-torch.pi / 25)
dynamics = 0.8 * rotation_matrix(theta).repeat(num_states, 1, 1)
bias = torch.column_stack([torch.cos(torch.linspace(0, 2*torch.pi, num_states+1)[:-1]), 
                        torch.sin(torch.linspace(0, 2*torch.pi, num_states+1)[:-1])])
covs = torch.tile(0.001 * torch.eye(data_dim), (num_states, 1, 1))

# Compute the stationary points
stationary_points = torch.linalg.solve(torch.eye(data_dim) - dynamics, bias)

# Construct an ARHMM and overwrite the emission parameters
true_arhmm = AutoregressiveHMM(num_states, data_dim)
true_arhmm.transition_matrix = transition_matrix
true_arhmm.emission_dynamics = dynamics
true_arhmm.emission_bias = bias
true_arhmm.emission_cov = covs

# Plot the true ARHMM dynamics for each of the 5 states
helpers.plot_dynamics(true_arhmm)
# Sample from the true ARHMM
states, data = true_arhmm.sample(10000, seed=305+ord('c'))

# Plot the data
for k in range(num_states):
    plt.plot(*data[states==k].T, 'o', color=helpers.colors[k],
         alpha=0.75, markersize=3)
    
plt.plot(*data.T, '-k', lw=0.5, alpha=0.2)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.gca().set_aspect("equal")
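The stationary points computed above come from setting \(x = A_k x + b_k\), which gives \(x_k^\star = (I - A_k)^{-1} b_k\). Each \(A_k\) here is a rotation scaled by 0.8, so at every noiseless step the deviation \(x_t - x_k^\star\) is rotated by \(\theta\) and shrunk by a factor of 0.8. That is why the trajectories within a state spiral inward. The plain-Python sketch below checks this numerically for a single 2-D state; the helper names are hypothetical.

```python
import math

def stationary_point(A, b):
    """Solve (I - A) x* = b for a 2x2 dynamics matrix A (Cramer's rule)."""
    m00, m01 = 1.0 - A[0][0], -A[0][1]
    m10, m11 = -A[1][0], 1.0 - A[1][1]
    det = m00 * m11 - m01 * m10
    return [(m11 * b[0] - m01 * b[1]) / det,
            (m00 * b[1] - m10 * b[0]) / det]

def step(A, b, x):
    """One noiseless AR(1) step: x_{t+1} = A x_t + b."""
    return [A[0][0] * x[0] + A[0][1] * x[1] + b[0],
            A[1][0] * x[0] + A[1][1] * x[1] + b[1]]

# Same form as the synthetic example: A = 0.8 R(theta), theta = -pi / 25.
theta = -math.pi / 25
A = [[0.8 * math.cos(theta), -0.8 * math.sin(theta)],
     [0.8 * math.sin(theta),  0.8 * math.cos(theta)]]
b = [1.0, 0.0]        # bias of the first state (angle 0 on the unit circle)
x_star = stationary_point(A, b)

x = [0.0, 0.0]
for _ in range(200):
    # The deviation x - x_star is rotated by theta and scaled by 0.8.
    x = step(A, b, x)
```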

Fit an ARHMM to the synthetic data

# Construct another ARHMM and fit it with EM
arhmm = AutoregressiveHMM(num_states, data_dim)
lls, posterior = arhmm.fit(data, num_iters=25)

# Plot the log likelihoods. They should go up.
plt.plot(lls)
plt.xlabel("iteration")
plt.ylabel("marginal log lkhd")
plt.grid(True)
# Plot the true and inferred states
fig, axs = plt.subplots(2, 1, sharex=True)
axs[0].imshow(states[None,:],
              aspect="auto",
              interpolation="none",
              cmap=helpers.cmap,
              vmin=0, vmax=len(helpers.colors)-1)
axs[0].set_xlim(0, 1000)
axs[0].set_yticks([])
axs[0].set_title("true states")

axs[1].imshow(posterior.expected_states.T,
              aspect="auto",
              interpolation="none",
              cmap="Greys",
              vmin=0, vmax=1)
axs[1].set_xlim(0, 1000)
axs[1].set_yticks(torch.arange(num_states))
axs[1].set_ylabel("state")
axs[1].set_xlabel("time")
axs[1].set_title("expected states")
plt.tight_layout()
# Plot the learned dynamics
helpers.plot_dynamics(arhmm)

As with the Gaussian HMM, you may find that the ARHMM doesn’t perfectly learn the true underlying states.

Part 4: Fit the ARHMM to mouse videos

Now we’ll load in some real data from depth video recordings of freely moving mice. This data is from the Datta Lab at Harvard Medical School. The references are given at the top of this notebook.

The video frames, even after cropping, are still 80x80 pixels. That’s a 6400-dimensional observation. In practice, the frames can be adequately reconstructed with far fewer principal components. As few as ten PCs do a pretty good job of capturing the mouse’s posture.

The Datta lab has already computed the principal components and included them in the NWB files. We’ll extract them, along with other relevant information like the centroid position and heading angle of the mouse, which we’ll use for making “crowd” movies below. Finally, they also included labels from MoSeq, an autoregressive (AR) HMM. You built an ARHMM in Part 3, and now you’ll use it to infer similar discrete latent state sequences yourself!

# Load one session of data
data_dim = 10
train_dataset, test_dataset = helpers.load_dataset(indices=[0], num_pcs=data_dim)
train_data = train_dataset[0]
test_data = test_dataset[0]

You should now have a train_dataset and a test_dataset loaded in memory. Each dataset is a list of dictionaries, one for each mouse. Each dictionary contains a few keys. The most important is the data key, which contains the standardized principal component time series. For the test dataset, we also included the frames key, which has the original 80x80 images. We’ll use these to create the movies of each inferred state.

Note: Keeping the data in memory is costly but convenient. You shouldn’t run out of memory in this lab, but if you ever did, a better solution might be to write the preprocessed data (e.g. with the standardized PC trajectories) back to the NWB files and reload those files as necessary during fitting.

Plot a slice of data

In the background, we’re showing the labels that were given to us from MoSeq, an autoregressive hidden Markov model.

helpers.plot_data_and_states(
    train_data, train_data["labels"],
    title="data and given discrete states")

Fit it!

With my implementation, this takes about 5 minutes to run on Colab.

# Build the HMM
num_states = 25
data_dim = 10
arhmm = AutoregressiveHMM(num_states, data_dim)

# Fit it!
lls, posterior = arhmm.fit(
    torch.tensor(train_data["data"], dtype=torch.float32), num_iters=50)

plt.plot(lls, label="train")
plt.xlabel("iteration")
plt.ylabel("marginal log lkhd")
plt.grid(True)
plt.legend()

Plot the data and the inferred states

We’ll make the same plot as above, but using our inferred states instead. Hopefully, the states switch along with changes in the data.

Note: We’re showing the state with the highest marginal probability, \(z_t^\star = \mathrm{arg} \, \mathrm{max}_k \; q(z_t = k)\). This is different from the most likely state path, \(z_{1:T}^\star = \mathrm{arg}\,\mathrm{max} \; q(z)\). We could compute the latter with the Viterbi algorithm, which is similar to the forward-backward algorithm you implemented above.
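For comparison, the Viterbi algorithm replaces the sums in the forward pass with maximizations and keeps back-pointers to recover the single best path. The plain-Python sketch below works in log space to avoid underflow. It is an illustration with hypothetical names and nested lists, not an optimized torch implementation.

```python
import math

def viterbi(initial_probs, transition_matrix, log_likes):
    """Most likely state path argmax_z p(z | x) via max-product in log space.

    initial_probs: length-K list; transition_matrix: K x K nested list;
    log_likes: T x K nested list of log likelihoods. Returns T states.
    """
    def safe_log(p):
        return math.log(p) if p > 0 else -math.inf

    T, K = len(log_likes), len(initial_probs)
    log_P = [[safe_log(p) for p in row] for row in transition_matrix]
    # scores[k]: best log joint of any path ending in state k at time t
    scores = [safe_log(initial_probs[k]) + log_likes[0][k] for k in range(K)]
    backpointers = []
    for t in range(1, T):
        best_prev = [max(range(K), key=lambda j: scores[j] + log_P[j][k])
                     for k in range(K)]
        scores = [scores[best_prev[k]] + log_P[best_prev[k]][k]
                  + log_likes[t][k] for k in range(K)]
        backpointers.append(best_prev)
    # Trace back from the best final state.
    path = [max(range(K), key=lambda k: scores[k])]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    return path[::-1]

path = viterbi([0.5, 0.5], [[0.9, 0.1], [0.1, 0.9]],
               [[0.0, -5.0], [0.0, -5.0], [-5.0, 0.0], [-5.0, 0.0]])
# path == [0, 0, 1, 1]
```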

arhmm_states = posterior.expected_states.argmax(1)
helpers.plot_data_and_states(train_data, arhmm_states)

Plot the state usage histogram

The state usage histogram shows how often each discrete state was used under the posterior distribution. You’ll probably see a long tail of states with non-trivial usage (hundreds of frames), all the way out to the last of the 25 states. That suggests the model is using all of its available capacity, and we could probably crank the number of states up even further for this model.
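The histogram counts each frame toward its most probable state (the argmax and bincount calls). Summing the posterior probabilities instead gives the expected usage \(\sum_t q(z_t = k)\), which can rank states differently when the posterior is uncertain. Below is a plain-Python sketch of both, with a hypothetical helper name.

```python
def state_usage(expected_states):
    """Summarize a T x K posterior (nested lists) three ways.

    hard[k]: frames whose most probable state is k (argmax counts)
    soft[k]: sum_t q(z_t = k), the expected number of frames in state k
    order:   state indices sorted by hard usage, most used first
    """
    num_states = len(expected_states[0])
    hard = [0] * num_states
    soft = [0.0] * num_states
    for row in expected_states:
        hard[max(range(num_states), key=lambda k: row[k])] += 1
        for k, q in enumerate(row):
            soft[k] += q
    order = sorted(range(num_states), key=lambda k: hard[k], reverse=True)
    return hard, soft, order

q = [[0.6, 0.4], [0.6, 0.4], [0.1, 0.9]]
hard, soft, order = state_usage(q)
# hard == [2, 1], but soft is about [1.3, 1.7]: the rankings disagree.
```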

# Sort states by usage
arhmm_states = posterior.expected_states.argmax(1)
arhmm_usage = torch.bincount(arhmm_states, minlength=num_states)
arhmm_order = torch.argsort(arhmm_usage).flip(0)

plt.bar(torch.arange(num_states), arhmm_usage[arhmm_order])
plt.xlabel("state index [ordered]")
plt.ylabel("num frames")
plt.title("histogram of inferred state usage")

Plot some “crowd” movies

test_posterior = arhmm.e_step(
    torch.tensor(test_data["data"], dtype=torch.float32))
helpers.play(helpers.make_crowd_movie(
    int(arhmm_order[0]), [test_data], [test_posterior]))
helpers.play(helpers.make_crowd_movie(
    int(arhmm_order[1]), [test_data], [test_posterior]))
helpers.play(helpers.make_crowd_movie(
    int(arhmm_order[2]), [test_data], [test_posterior]))
helpers.play(helpers.make_crowd_movie(
    int(arhmm_order[3]), [test_data], [test_posterior]))
helpers.play(helpers.make_crowd_movie(
    int(arhmm_order[4]), [test_data], [test_posterior]))
helpers.play(helpers.make_crowd_movie(
    int(arhmm_order[20]), [test_data], [test_posterior]))

Download crowd movies for all states

# Make "crowd" movies for each state and save them to disk
# Then you can download them and play them offline
for i in trange(num_states):
    helpers.play(helpers.make_crowd_movie(
        int(arhmm_order[i]), [test_data], [test_posterior]),
        filename="arhmm_crowd_{}.mp4".format(i), show=False)

# Zip the movies up    
!zip arhmm_crowd_movies.zip arhmm_crowd_*.mp4

# Download the files as a zip
files.download("arhmm_crowd_movies.zip")

Problem 4a [Short Answer]: Discussion

Now that you’ve completed the analysis, discuss your findings in one or two paragraphs. Some questions to consider (though you need not answer all) are:

  • Did any interesting states pop out in your crowd movies?

  • Are the less frequently used states interesting or are they just noise?

  • It took a few minutes to fit data from a single mouse with ~50,000 frames of video. In practice, we have data from dozens of mice and millions of frames of video. What approaches might you take to speed up the fitting procedure?

  • Aside from runtime, what other challenges might you encounter when fitting the same model to multiple mice? What could you do to address those challenges?

  • The ARHMM finds reasonable looking discrete states (“syllables”) but it’s surely not a perfect model. What changes could you make to better model mouse behavior?


Your answer here.


Formatting: check that your code does not exceed 80 characters in line width. If you’re working in Colab, you can set Tools → Settings → Editor → Vertical ruler column to 80 to see when you’ve exceeded the limit.

Download your notebook in .ipynb format and use the following commands to convert it to PDF:

jupyter nbconvert --to pdf hw7_yourname.ipynb

Dependencies:

  • nbconvert: If you’re using Anaconda for package management,

conda install -c anaconda nbconvert

Upload your .pdf file to Gradescope.