The Multivariate Normal Distribution

import torch
from torch.distributions import Normal, MultivariateNormal, Wishart
torch.manual_seed(305)

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context("notebook")

The Generative Story

Start with a vector of standard normal random variates, $z = [z_1, \ldots, z_D]^\top$ where $z_d \sim \mathcal{N}(0, 1)$.

This is a D-dimensional random variable, but not a very interesting one. All the coordinates are independent! The joint density is,

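$$
\begin{aligned}
p(z) &= \prod_{d=1}^D \mathcal{N}(z_d \mid 0, 1) \\
&= \prod_{d=1}^D \frac{1}{\sqrt{2\pi}} e^{-\frac{1}{2} z_d^2} \\
&= (2\pi)^{-\frac{D}{2}} \exp\left\{ -\frac{1}{2} \sum_{d=1}^D z_d^2 \right\} \\
&= (2\pi)^{-\frac{D}{2}} \exp\left\{ -\frac{1}{2} z^\top z \right\}
\end{aligned}
$$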
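As a quick numerical check of the derivation above, the sketch below compares three ways of evaluating the log density: summing independent `Normal` log probabilities, evaluating the closed form directly, and calling `MultivariateNormal` with an identity covariance. The dimension `D = 3` and the random test points are arbitrary choices for illustration, not values from the notebook.

```python
import math
import torch
from torch.distributions import Normal, MultivariateNormal

torch.manual_seed(0)
D = 3                      # arbitrary dimension, chosen for illustration
z = torch.randn(5, D)      # five arbitrary test points

# Product of independent standard normals (a sum in log space)
log_p_product = Normal(0.0, 1.0).log_prob(z).sum(dim=-1)

# Closed form: -(D/2) log(2 pi) - (1/2) z^T z
log_p_closed = -0.5 * D * math.log(2 * math.pi) - 0.5 * (z * z).sum(dim=-1)

# Multivariate normal with zero mean and identity covariance
log_p_mvn = MultivariateNormal(torch.zeros(D), torch.eye(D)).log_prob(z)

print(torch.allclose(log_p_product, log_p_closed, atol=1e-5))
print(torch.allclose(log_p_product, log_p_mvn, atol=1e-5))
```

All three expressions should agree up to floating point error.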

What do the contours of this joint density look like in D=2 dimensions?

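# Pass indexing explicitly; "ij" matches the current default and avoids the deprecation warning
z1, z2 = torch.meshgrid(torch.linspace(-4, 4, 50),
                        torch.linspace(-4, 4, 50),
                        indexing="ij")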

logpdf = Normal(0, 1).log_prob(z1) + Normal(0, 1).log_prob(z2)

plt.contour(z1, z2, logpdf)
plt.xlabel("$z_1$")
plt.ylabel("$z_2$")
plt.gca().set_aspect(1)
[Figure: contours of the standard normal log density over $z_1$ and $z_2$.]

We can obtain more interesting joint distributions by transforming this random vector.

For example, let $U$ be an orthogonal $D \times D$ matrix and $\Lambda = \mathrm{diag}([\lambda_1, \ldots, \lambda_D])$ with $\lambda_d > 0$. Define the linearly transformed random variable $x = U \Lambda^{\frac{1}{2}} z$.

# Make an orthogonal 2 x 2 matrix
theta1 = torch.tensor(torch.pi / 4)
u1 = torch.tensor([torch.cos(theta1), torch.sin(theta1)])
theta2 = theta1 + torch.pi / 2
u2 = torch.tensor([torch.cos(theta2), torch.sin(theta2)])
U = torch.column_stack([u1, u2])

# Choose two eigenvalues
lmbda1 = torch.tensor(2.0**2)
lmbda2 = torch.tensor(0.5**2)
Lmbda = torch.tensor([lmbda1, lmbda2])

# pick one isocontour in z space and transform it to x coordinates
zs = torch.column_stack([
    torch.cos(torch.linspace(0, 2 * torch.pi, 50)),
    torch.sin(torch.linspace(0, 2 * torch.pi, 50))
])
xs = (zs * torch.sqrt(Lmbda)) @ U.T

# Plot the new basis
plt.plot(xs[:, 0], xs[:, 1], ':')
plt.arrow(0, 0, torch.sqrt(lmbda1) * u1[0], torch.sqrt(lmbda1) * u1[1], 
          color='k', head_width=0.2)
plt.text(torch.sqrt(lmbda1) * u1[0], torch.sqrt(lmbda1) * u1[1] + .4, 
         r"$\sqrt{\lambda_1} u_1$")  # raw string avoids invalid escape sequences

plt.arrow(0, 0, torch.sqrt(lmbda2) * u2[0], torch.sqrt(lmbda2) * u2[1], 
          color='k', head_width=0.2)
plt.text(torch.sqrt(lmbda2) * u2[0] - 1.5, torch.sqrt(lmbda2) * u2[1], 
         r"$\sqrt{\lambda_2} u_2$")  # raw string avoids invalid escape sequences

plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.gca().set_aspect(1)
[Figure: one isocontour in $x$ coordinates, with arrows for the scaled basis vectors $\sqrt{\lambda_1} u_1$ and $\sqrt{\lambda_2} u_2$.]
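Since $z$ has identity covariance, the transformed variable $x = U \Lambda^{\frac{1}{2}} z$ has covariance $U \Lambda U^\top$. The sketch below checks this empirically with the same rotation angle and eigenvalues as the cell above. The sample size and the seed are arbitrary assumptions made for illustration.

```python
import math
import torch

torch.manual_seed(0)

# Same rotation (pi / 4) and eigenvalues (2^2, 0.5^2) as in the cell above
theta = math.pi / 4
U = torch.tensor([[math.cos(theta), -math.sin(theta)],
                  [math.sin(theta),  math.cos(theta)]])
Lmbda = torch.tensor([4.0, 0.25])

# Draw standard normal vectors (one per row) and transform them
z = torch.randn(200_000, 2)
x = (z * torch.sqrt(Lmbda)) @ U.T

# Covariance implied by the transformation vs. the empirical second moment
Sigma = U @ torch.diag(Lmbda) @ U.T
Sigma_hat = x.T @ x / x.shape[0]

print(Sigma)
print(Sigma_hat)
```

The two matrices should agree up to Monte Carlo error, which shrinks as the number of samples grows.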

Write a function to visualize covariance matrices

def plot_cov(Sigma, mu=None, ax=None, **kwargs):
    """
    Simple function to visualize a covariance matrix.
    """
    # Set the mean to zero if not given
    D = Sigma.shape[-1]
    mu = torch.zeros(D,) if mu is None else mu

    # Compute the eigendecomposition
    Lmbda, U = torch.linalg.eigh(Sigma)

    # Find one isocontour
    zs = torch.column_stack([
        torch.cos(torch.linspace(0, 2 * torch.pi, 50)),
        torch.sin(torch.linspace(0, 2 * torch.pi, 50))
    ])
    xs = (zs * torch.sqrt(Lmbda)) @ U.T

    # plot the isocontour
    ax = plt.axes(aspect=1) if ax is None else ax
    ax.plot(mu[0] + xs[:, 0], mu[1] + xs[:, 1], **kwargs)
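
The curve drawn by `plot_cov` is the set of points at unit Mahalanobis distance from the mean, that is, points satisfying $x^\top \Sigma^{-1} x = 1$ when the mean is zero. The following sketch repeats the eigendecomposition step from the function and checks that property numerically. The covariance matrix used here is a made-up example, not one from the notebook.

```python
import math
import torch

# Hypothetical covariance matrix, chosen only for illustration
Sigma = torch.tensor([[2.0, 0.8],
                      [0.8, 1.0]])

# Same construction as in plot_cov: map the unit circle through U Lambda^{1/2}
Lmbda, U = torch.linalg.eigh(Sigma)
angles = torch.linspace(0, 2 * math.pi, 50)
zs = torch.column_stack([torch.cos(angles), torch.sin(angles)])
xs = (zs * torch.sqrt(Lmbda)) @ U.T

# Squared Mahalanobis distance x^T Sigma^{-1} x of every contour point
mahalanobis_sq = torch.einsum("ni,ij,nj->n", xs, torch.linalg.inv(Sigma), xs)

print(mahalanobis_sq.min(), mahalanobis_sq.max())
```

Both printed values should be close to 1, which confirms that the ellipse is a one-standard-deviation contour.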

Draw many samples from a Wishart distribution and plot their inverses

nu0 = torch.tensor(4.)
Lmbda0 = torch.eye(2) / nu0
Lmbdas = Wishart(nu0, covariance_matrix=Lmbda0).sample(sample_shape=(5, 5))

fig, axs = plt.subplots(5, 5, figsize=(8, 8), sharex=True, sharey=True)
for i in range(5):
    for j in range(5):
        plot_cov(torch.inverse(Lmbdas[i, j]), ax=axs[i, j])
        plot_cov(torch.eye(2), ax=axs[i, j], color='k')
        axs[i, j].set_xlim(-4, 4)
        axs[i, j].set_ylim(-4, 4)
        axs[i, j].set_aspect(1)

for i in range(5):
    axs[i, 0].set_ylabel("$x_2$")
    axs[-1, i].set_xlabel("$x_1$")
[Figure: 5 × 5 grid of covariance ellipses (inverses of sampled Wishart precisions), each overlaid with the identity covariance in black; axes $x_1$ and $x_2$.]
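The scale matrix above is set to $\Lambda_0 = I / \nu_0$ because the mean of a Wishart distribution $\mathrm{W}(\nu_0, \Lambda_0)$ is $\nu_0 \Lambda_0$, so the sampled precisions are centered on the identity. The sketch below compares that analytic mean to a Monte Carlo average. The sample size and seed are arbitrary assumptions, and PyTorch may emit "Singular sample detected" warnings at this low number of degrees of freedom.

```python
import torch
from torch.distributions import Wishart

torch.manual_seed(0)

# Same hyperparameters as the cell above
nu0 = torch.tensor(4.0)
Lmbda0 = torch.eye(2) / nu0
prior = Wishart(nu0, covariance_matrix=Lmbda0)

# Analytic mean of W(nu0, Lmbda0) is nu0 * Lmbda0
analytic_mean = nu0 * Lmbda0

# Monte Carlo estimate of the mean precision
samples = prior.sample((20_000,))
mc_mean = samples.mean(dim=0)

print(analytic_mean)
print(mc_mean)
```

Note that the mean of the inverted samples (the covariances) is not the identity in general, which is one reason the plotted ellipses vary so widely around the black reference circle.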

Bayesian inference with unknown precision

The Wishart distribution is a conjugate prior for the precision matrix of a multivariate normal distribution with known mean.

$$
\Lambda \sim \mathrm{W}(\nu_0, \Lambda_0), \qquad x_n \sim \mathcal{N}(\mu, \Lambda^{-1}).
$$

Then, letting $\eta = (\mu, \nu_0, \Lambda_0)$,

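$$
\begin{aligned}
p(\Lambda \mid x, \eta)
&\propto \mathrm{W}(\Lambda \mid \nu_0, \Lambda_0) \prod_{n=1}^N \mathcal{N}(x_n \mid \mu, \Lambda^{-1}) \\
&\propto |\Lambda|^{\frac{\nu_0 - D - 1}{2}} e^{-\frac{1}{2} \mathrm{Tr}(\Lambda_0^{-1} \Lambda)} \prod_{n=1}^N |\Lambda|^{\frac{1}{2}} e^{-\frac{1}{2} (x_n - \mu)^\top \Lambda (x_n - \mu)} \\
&\propto |\Lambda|^{\frac{\nu_0 + N - D - 1}{2}} \exp\left\{ -\frac{1}{2} \mathrm{Tr}\left( \left[ \Lambda_0^{-1} + \sum_{n=1}^N (x_n - \mu)(x_n - \mu)^\top \right] \Lambda \right) \right\}
\end{aligned}
$$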

We recognize this as yet another Wishart distribution,

$$
p(\Lambda \mid x, \eta) \propto \mathrm{W}(\Lambda \mid \nu_N, \Lambda_N),
$$

where

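$$
\begin{aligned}
\nu_N &= \nu_0 + N \\
\Lambda_N &= \left[ \Lambda_0^{-1} + \sum_{n=1}^N (x_n - \mu)(x_n - \mu)^\top \right]^{-1}
\end{aligned}
$$
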
# Sample data from a Gaussian with identity covariance (and precision)
torch.manual_seed(305)
N = 10
x = MultivariateNormal(torch.zeros(2), torch.eye(2)).sample((N,))
# Set a weak prior
nu_0 = torch.tensor(1.0)
Lmbda_0 = torch.eye(2)

# Compute the posterior distribution over the precision under a Wishart prior, 
# assuming the mean is known to be zero
nu_N = nu_0 + N 
Lmbda_N = torch.inverse(torch.inverse(Lmbda_0) + x.T @ x)
posterior = Wishart(nu_N, Lmbda_N)

# Plot posterior samples of the *covariance* (i.e. inverse precision)
precision_samples = posterior.sample((10,))
covariance_samples = torch.inverse(precision_samples)

ax = plt.axes(aspect=1)
for Sigma in covariance_samples:
    plot_cov(Sigma, ax=ax, color='r', alpha=0.5)
plt.plot(x[:, 0], x[:, 1], 'ko', markersize=6)

plt.xlim(-4, 4)
plt.ylim(-4, 4)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
[Figure: data points (black) with posterior covariance samples (red ellipses); axes $x_1$ and $x_2$.]
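One way to sanity-check the conjugate update is to verify that the log prior plus the log likelihood differs from the log posterior only by a constant that does not depend on $\Lambda$ (that constant is the log marginal likelihood). The sketch below does this at a few arbitrary positive definite test matrices. It assumes `nu_0 = 3` rather than the notebook's `nu_0 = 1`, because PyTorch's `Wishart` requires more than $D - 1$ degrees of freedom to construct the prior object, and it uses double precision to keep the comparison tight.

```python
import torch
from torch.distributions import MultivariateNormal, Wishart

torch.manual_seed(0)
dtype = torch.float64
D, N = 2, 10

# Synthetic data with a known mean of zero
mu = torch.zeros(D, dtype=dtype)
x = MultivariateNormal(mu, torch.eye(D, dtype=dtype)).sample((N,))

# Assumed prior hyperparameters (nu_0 > D - 1 so the prior can be built)
nu_0 = torch.tensor(3.0, dtype=dtype)
Lmbda_0 = torch.eye(D, dtype=dtype)

# Conjugate update from the equations above
nu_N = nu_0 + N
resid = x - mu
Lmbda_N = torch.linalg.inv(torch.linalg.inv(Lmbda_0) + resid.T @ resid)
Lmbda_N = 0.5 * (Lmbda_N + Lmbda_N.T)  # symmetrize against round-off

prior = Wishart(nu_0, covariance_matrix=Lmbda_0)
posterior = Wishart(nu_N, covariance_matrix=Lmbda_N)

def log_joint(Lmbda):
    """Log prior plus log likelihood at a given precision matrix."""
    lik = MultivariateNormal(mu, precision_matrix=Lmbda).log_prob(x).sum()
    return prior.log_prob(Lmbda) + lik

# Arbitrary positive definite test precisions: A A^T + I
A = torch.randn(5, D, D, dtype=dtype)
test_precisions = A @ A.mT + torch.eye(D, dtype=dtype)

# The gap should be the same at every test point
gaps = torch.stack([log_joint(L) - posterior.log_prob(L) for L in test_precisions])
print(gaps)
```

If the update were wrong, for example if the scatter matrix were added to $\Lambda_0$ instead of $\Lambda_0^{-1}$, the printed gaps would differ from one test matrix to the next.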

Bayesian Inference with Unknown Mean and Precision

What if both the mean and the precision are unknown? Then a normal-Wishart prior is conjugate with the multivariate normal likelihood. We say

$$
\mu, \Lambda \sim \mathrm{NW}(\mu_0, \kappa_0, \nu_0, \Lambda_0)
$$

if

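$$
\begin{aligned}
\Lambda &\sim \mathrm{W}(\nu_0, \Lambda_0) \\
\mu \mid \Lambda &\sim \mathcal{N}(\mu_0, (\kappa_0 \Lambda)^{-1})
\end{aligned}
$$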

Under a multivariate Gaussian likelihood,

$$
p(X \mid \mu, \Lambda, \eta) = \prod_{n=1}^N \mathcal{N}(x_n \mid \mu, \Lambda^{-1}).
$$

Then, the posterior on the parameters is another normal-Wishart distribution with parameters

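$$
\begin{aligned}
\nu_N &= \nu_0 + N \\
\kappa_N &= \kappa_0 + N \\
\Lambda_N &= \left[ \Lambda_0^{-1} + \kappa_0 \mu_0 \mu_0^\top + \sum_{n=1}^N x_n x_n^\top - \kappa_N \mu_N \mu_N^\top \right]^{-1} \\
\mu_N &= \frac{1}{\kappa_N} \left( \kappa_0 \mu_0 + \sum_{n=1}^N x_n \right)
\end{aligned}
$$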

Next, we’ll compute the posterior of μ and Λ and visualize it.

# Set a weak prior
nu_0 = torch.tensor(1.0)
Lmbda_0 = torch.eye(2)
kappa_0 = torch.tensor(1.0)
mu_0 = torch.zeros((2,))


# Compute the posterior distribution over the mean and precision under a 
# normal-Wishart prior with the hyperparameters above.
nu_N = nu_0 + N 
kappa_N = kappa_0 + N
mu_N = 1/kappa_N * (kappa_0 * mu_0 + x.sum(axis=0))
Lmbda_N = torch.inverse(torch.inverse(Lmbda_0) 
                        + kappa_0 * torch.outer(mu_0, mu_0) 
                        + x.T @ x
                        - kappa_N * torch.outer(mu_N, mu_N))

# Sample the posterior
posterior_Lambda = Wishart(nu_N, Lmbda_N)
Lambda_samples = posterior_Lambda.sample((100,))  # use the normal-Wishart posterior, not the earlier known-mean one
Sigma_samples = torch.inverse(Lambda_samples)

posterior_mu = MultivariateNormal(mu_N, 
                                  precision_matrix=Lambda_samples * kappa_N)
mu_samples = posterior_mu.sample()
# Plot posterior samples of the Gaussian parameters
ax = plt.axes(aspect=1)
for i, (mu, Sigma) in enumerate(zip(mu_samples[:10], Sigma_samples[:10])):
    # Center each ellipse on its paired mean sample; raw strings avoid invalid escapes
    plot_cov(Sigma, mu=mu, ax=ax, color='r', alpha=0.5, 
             label=r'$\Sigma$ samples' if i == 0 else None)

    plt.plot(mu[0], mu[1], 'ro', markersize=3, mec='k', 
             label=r'$\mu$ samples' if i == 0 else None)

plt.plot(x[:, 0], x[:, 1], 'ko', markersize=6, label='data')

plt.xlim(-4, 4)
plt.ylim(-4, 4)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.legend()

[Figure: data, posterior $\mu$ samples, and posterior $\Sigma$ samples; axes $x_1$ and $x_2$.]
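The expression inside the inverse for $\Lambda_N$ can look opaque. Expanding $\mu_N$ shows that it equals the centered scatter matrix plus a term that penalizes disagreement between the sample mean $\bar{x}$ and the prior mean $\mu_0$:

$$
\kappa_0 \mu_0 \mu_0^\top + \sum_{n=1}^N x_n x_n^\top - \kappa_N \mu_N \mu_N^\top
= \sum_{n=1}^N (x_n - \bar{x})(x_n - \bar{x})^\top + \frac{\kappa_0 N}{\kappa_N} (\bar{x} - \mu_0)(\bar{x} - \mu_0)^\top.
$$

The sketch below checks this identity numerically. The data and the nonzero prior mean are made-up values chosen so that the check is not trivial.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64
N, D = 10, 2
x = torch.randn(N, D, dtype=dtype)  # synthetic data for illustration

# Hypothetical hyperparameters (nonzero prior mean makes the check meaningful)
kappa_0 = 1.0
mu_0 = torch.tensor([0.5, -1.0], dtype=dtype)

# Posterior quantities from the equations above
kappa_N = kappa_0 + N
x_bar = x.mean(dim=0)
mu_N = (kappa_0 * mu_0 + x.sum(dim=0)) / kappa_N

# Form used in the text
scatter_text = (kappa_0 * torch.outer(mu_0, mu_0)
                + x.T @ x
                - kappa_N * torch.outer(mu_N, mu_N))

# Equivalent form: centered scatter plus a prior-mean mismatch term
centered = x - x_bar
scatter_alt = (centered.T @ centered
               + (kappa_0 * N / kappa_N) * torch.outer(x_bar - mu_0, x_bar - mu_0))

print(torch.allclose(scatter_text, scatter_alt))
```

The second form makes it clear that the added matrix is positive semidefinite, so $\Lambda_N^{-1}$ is positive definite whenever $\Lambda_0^{-1}$ is.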

Recap

This notebook introduced the MultivariateNormal and Wishart distributions. We gave examples of how you can visualize two-dimensional covariance matrices (and inverse precision matrices) as ellipses in $\mathbb{R}^2$. This is a helpful way of visualizing prior distributions over matrices, like the Wishart distribution.

A couple of notes:

  • PyTorch does not currently have an inverse Wishart distribution object. In theory, you could implement one as a TransformedDistribution, but you would need to also write a MatrixInverseTransform object. None of this is particularly hard, but it was more than we chose to do for this demo.

  • In practice, many practitioners suggest using the LKJCholesky distribution as a prior on correlation matrices (technically, their Cholesky factors), combined with a separate prior on the per-dimension scales to obtain a covariance matrix. It doesn’t have as simple a closed form update as the Wishart or inverse Wishart, but it is amenable to other inference techniques that we’ll discuss later in the course.
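
As a rough illustration of the second note, the sketch below draws a Cholesky factor from PyTorch's `LKJCholesky`, which is defined over lower Cholesky factors of correlation matrices, and rescales it into a covariance matrix. The per-dimension standard deviations in `sigma` are hypothetical fixed values. In a full model they would typically get their own prior, which is an assumption on our part and not something this notebook specifies.

```python
import torch
from torch.distributions import LKJCholesky

torch.manual_seed(0)
D = 2

# Sample the lower Cholesky factor of a D x D correlation matrix
L_corr = LKJCholesky(D, concentration=1.0).sample()
corr = L_corr @ L_corr.T  # unit diagonal by construction

# Hypothetical per-dimension standard deviations (fixed here for illustration)
sigma = torch.tensor([2.0, 0.5])

# Covariance = diag(sigma) @ correlation @ diag(sigma)
Sigma = torch.diag(sigma) @ corr @ torch.diag(sigma)

print(corr)
print(Sigma)
```

The resulting `Sigma` could be passed to the `plot_cov` function defined earlier to visualize draws from this prior in the same way as the Wishart samples.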