Decoding Models#
If encoding models approximate the conditional probability of neural activity \(y\) given covariates \(x\), then decoding models do the reverse: they predict covariates given neural activity. For example, we could decode arm movements from neural activity in motor cortex to drive a prosthetic limb. Or we could decode an animal’s position by measuring the activity of place cells in its hippocampus. Or we could infer what a person is seeing based on measurements of neural activity in visual cortex. All of these are examples of decoding an external signal given neural measurements.
At the end of the day, it’s all just regression. With enough data, we can train a predictive model to output \(p(x \mid y)\). However, sometimes we can be a bit more clever. With a good encoding model and a prior distribution over covariates, we can build Bayesian decoders to estimate the posterior distribution of covariates \(x\) given neural activity \(y\). If we don’t have all the ingredients for a Bayesian decoder — if we have a good prior but a poor encoding model — we can build structured prediction models that capture key constraints on the outputs. This chapter will discuss both of these approaches.
Bayesian decoders#
Let \(\mathbf{y}_t \in \mathbb{R}^N\) denote the neural measurement at time \(t\), and let \(\mathbf{x}_t \in \mathbb{R}^D\) be the covariate of interest. For example, \(\mathbf{y}_t \in \mathbb{N}_0^N\) could be a vector of spike counts and \(\mathbf{x}_t \in \mathbb{R}^2\) the position of an animal’s hand.
Suppose we have a good estimate of the prior distribution on hand positions \(p(\mathbf{x}_t)\). Likewise, suppose we have fit an encoding model \(p(\mathbf{y}_t \mid \mathbf{x}_t)\) to predict neural activity as a function of hand position. Then we can use Bayes' rule to infer the posterior distribution over positions given neural activity,
\[
p(\mathbf{x}_t \mid \mathbf{y}_t) = \frac{p(\mathbf{y}_t \mid \mathbf{x}_t) \, p(\mathbf{x}_t)}{p(\mathbf{y}_t)} \propto p(\mathbf{y}_t \mid \mathbf{x}_t) \, p(\mathbf{x}_t).
\]
The encoder is the likelihood and the decoder is the posterior.
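To make this concrete, here is a minimal sketch of grid-based Bayesian decoding. It assumes a hypothetical one-dimensional track with Poisson place cells; the tuning curves, firing rates, and grid resolution are all made up for illustration.

```python
import numpy as np

# Hypothetical example: decode a 1D position from Poisson place-cell counts
# by evaluating Bayes' rule on a grid of candidate positions.
grid = np.linspace(0.0, 1.0, 200)                    # candidate positions x
centers = np.linspace(0.0, 1.0, 20)                  # place field centers, one per neuron
rates = 10.0 * np.exp(-0.5 * (grid[:, None] - centers[None, :])**2 / 0.05**2) + 0.1
# rates[i, n] = expected spike count of neuron n when the animal is at grid[i]

log_prior = np.full(len(grid), -np.log(len(grid)))   # uniform prior over the grid

def decode(y):
    """Return the posterior p(x | y) over grid positions for a count vector y."""
    # Poisson log likelihood, dropping the y!-term (constant in x)
    log_lik = (y[None, :] * np.log(rates) - rates).sum(axis=1)
    log_post = log_prior + log_lik
    log_post -= log_post.max()                       # for numerical stability
    post = np.exp(log_post)
    return post / post.sum()

# Simulate counts from a true position and decode them
true_x = 0.63
true_rates = 10.0 * np.exp(-0.5 * (true_x - centers)**2 / 0.05**2) + 0.1
y = np.random.poisson(true_rates)
posterior = decode(y)
print("posterior mean:", (grid * posterior).sum())
```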
Simple model#
Consider a Gaussian prior with mean \(\boldsymbol{\mu}_0\) and covariance \(\boldsymbol{\Sigma}_0\),
\[
p(\mathbf{x}_t) = \mathcal{N}(\mathbf{x}_t \mid \boldsymbol{\mu}_0, \boldsymbol{\Sigma}_0).
\]
Assume time bins are independent; i.e., \(p(\mathbf{x}_{1:T}) = \prod_{t=1}^T p(\mathbf{x}_t)\).
Now assume a linear Gaussian encoding model,
\[
p(\mathbf{y}_t \mid \mathbf{x}_t) = \mathcal{N}(\mathbf{y}_t \mid \mathbf{C} \mathbf{x}_t + \mathbf{d}, \, \mathbf{R}),
\]
with
emission matrix \(\mathbf{C} \in \mathbb{R}^{N \times D}\)
emission bias \(\mathbf{d} \in \mathbb{R}^{N}\)
emission covariance \(\mathbf{R} \in \mathbb{R}_{++}^{N \times N}\) (a positive definite matrix)
There are a million things wrong with this model. To name a few,
It assumes hand positions are independent across time
It assumes spike counts are independent across time given the covariates
It assumes a continuous (Gaussian) model for spike counts, which are discrete random variables
It assumes the noise covariance is static, whereas the variance of spike counts tends to grow with the mean
Still, it is mathematically convenient to work with. Let's derive the posterior distribution,
\[
p(\mathbf{x}_t \mid \mathbf{y}_t) \propto p(\mathbf{x}_t) \, p(\mathbf{y}_t \mid \mathbf{x}_t) \propto \exp \left\{ -\tfrac{1}{2} \mathbf{x}_t^\top \mathbf{J}_t \mathbf{x}_t + \mathbf{h}_t^\top \mathbf{x}_t \right\},
\]
where
\[
\mathbf{J}_t = \boldsymbol{\Sigma}_0^{-1} + \mathbf{C}^\top \mathbf{R}^{-1} \mathbf{C}, \qquad
\mathbf{h}_t = \boldsymbol{\Sigma}_0^{-1} \boldsymbol{\mu}_0 + \mathbf{C}^\top \mathbf{R}^{-1} (\mathbf{y}_t - \mathbf{d}).
\]
The posterior is proportional to an exponentiated quadratic, so it too is a Gaussian distribution! \(\mathbf{J}_t\) and \(\mathbf{h}_t\) are its natural parameters. (Refer back to the chapter on GLMs and exponential family distributions!) To convert the natural parameters back into mean parameters, we complete the square,
\[
p(\mathbf{x}_t \mid \mathbf{y}_t) = \mathcal{N}\big(\mathbf{x}_t \mid \mathbf{J}_t^{-1} \mathbf{h}_t, \, \mathbf{J}_t^{-1}\big).
\]
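As a sanity check, this closed-form posterior is easy to compute numerically. Below is a minimal sketch with randomly generated parameters; the names mirror the notation above, and the specific dimensions and parameter values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
D, N = 2, 50                                  # latent dimension, number of neurons

# Prior p(x_t) = N(mu0, Sigma0) and encoder p(y_t | x_t) = N(C x_t + d, R)
mu0, Sigma0 = np.zeros(D), np.eye(D)
C = rng.normal(size=(N, D))
d = rng.normal(size=N)
R = np.eye(N)

def gaussian_decode(y):
    """Posterior mean and covariance of x_t given a single observation y_t."""
    J = np.linalg.inv(Sigma0) + C.T @ np.linalg.inv(R) @ C                   # posterior precision
    h = np.linalg.inv(Sigma0) @ mu0 + C.T @ np.linalg.inv(R) @ (y - d)       # precision-weighted mean
    Sigma = np.linalg.inv(J)
    return Sigma @ h, Sigma

# Simulate one time bin and decode it
x_true = rng.normal(size=D)
y = C @ x_true + d + rng.multivariate_normal(np.zeros(N), R)
mean, cov = gaussian_decode(y)
```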
Information Form#
The “Information Form” of the Gaussian distribution
The multivariate normal (aka Gaussian) distribution is typically written in terms of its mean \(\boldsymbol{\mu}\) and covariance \(\boldsymbol{\Sigma}\). However, for Bayesian inference, it is often easier to work with its information form,
\[
\mathcal{N}(\mathbf{x} \mid \boldsymbol{\mu}, \boldsymbol{\Sigma}) \propto \exp \left\{ -\tfrac{1}{2} \mathbf{x}^\top \mathbf{J} \mathbf{x} + \mathbf{h}^\top \mathbf{x} \right\},
\]
where \(\mathbf{J} = \boldsymbol{\Sigma}^{-1}\) is the precision matrix and \(\mathbf{h} = \boldsymbol{\Sigma}^{-1} \boldsymbol{\mu}\) is the precision-weighted mean. The mapping between standard parameters and information parameters is bijective. The inverse is \(\boldsymbol{\Sigma} = \mathbf{J}^{-1}\) and \(\boldsymbol{\mu} = \mathbf{J}^{-1} \mathbf{h}\).
The precision matrix (really, its determinant) is a measure of information content. Multivariate normal distributions with small covariance have high precision: there is little uncertainty about \(\mathbf{x}\).
The information form is nice because conditioning is easy. Conditioning amounts to multiplying two Gaussian densities together, and in information form that simply means adding their linear (\(\mathbf{h}\)) and quadratic (\(\mathbf{J}\)) coefficients. By contrast, marginalization is easy in the standard form. To obtain the marginal distribution of a subset of coordinates of a multivariate normal random vector, we simply extract the corresponding blocks of the mean vector and the covariance matrix.
The information form is very closely related to the natural exponential family form of the Gaussian distribution, whose sufficient statistics are \(\mathbf{x}\) and \(\mathbf{x} \mathbf{x}^\top\).
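A small numerical illustration of why the information form is convenient: multiplying two Gaussian densities amounts to adding their \((\mathbf{J}, \mathbf{h})\) parameters. The helper functions below are illustrative, not part of any particular library.

```python
import numpy as np

# Multiplying two Gaussian densities in information form just adds their
# natural parameters (J, h); converting back recovers the usual mean/covariance.
def to_info(mu, Sigma):
    J = np.linalg.inv(Sigma)
    return J, J @ mu

def from_info(J, h):
    Sigma = np.linalg.inv(J)
    return Sigma @ h, Sigma

J1, h1 = to_info(np.array([0.0, 0.0]), np.eye(2))          # prior
J2, h2 = to_info(np.array([1.0, 2.0]), 0.5 * np.eye(2))    # likelihood term
mu_post, Sigma_post = from_info(J1 + J2, h1 + h2)          # product of the two Gaussians
```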
Linear dynamical system prior#
The prior was arguably the “weakest link” of the simple model above. Hand positions are obviously not independent across time. An easy way to improve upon the simple model is to incorporate temporal dependencies into the prior,
\[
p(\mathbf{x}_{1:T}) = p(\mathbf{x}_1) \prod_{t=2}^T p(\mathbf{x}_t \mid \mathbf{x}_{t-1}), \qquad
p(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t \mid \mathbf{A} \mathbf{x}_{t-1}, \, \mathbf{Q}).
\]
This is a vector autoregressive model with first-order dependencies (\(\mathbf{x}_t\) only depends on \(\mathbf{x}_{t-1}\)). Sometimes it is abbreviated as a VAR(1) model. It is also called a linear dynamical system.
The model is parameterized by a,
dynamics matrix \(\mathbf{A} \in \mathbb{R}^{D \times D}\)
dynamics covariance \(\mathbf{Q} \in \mathbb{R}_{++}^{D \times D}\) (a positive definite covariance matrix)
We could optionally include a dynamics bias term, like the emission bias above.
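To build intuition for what this prior encodes, here is a short sketch that samples hand-position trajectories from a hypothetical VAR(1) model with slowly rotating dynamics; the particular `A` and `Q` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
T, D = 100, 2

# Hypothetical VAR(1) prior: a slow, slightly damped rotation plus Gaussian innovations.
theta = 0.1
A = 0.99 * np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])      # dynamics matrix
Q = 0.01 * np.eye(D)                                        # dynamics covariance

x = np.zeros((T, D))
x[0] = rng.multivariate_normal(np.zeros(D), np.eye(D))      # assumed initial-state distribution
for t in range(1, T):
    x[t] = A @ x[t - 1] + rng.multivariate_normal(np.zeros(D), Q)
```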
Computing the posterior#
The key point is that the prior is still a linear Gaussian model. As such, the posterior distribution is as well:
\[
p(\mathbf{x}_{1:T} \mid \mathbf{y}_{1:T}) = \mathcal{N}\big(\mathrm{vec}(\mathbf{x}_{1:T}) \mid \boldsymbol{\mu}, \boldsymbol{\Sigma}\big),
\]
where \(\mathrm{vec}(\mathbf{x}_{1:T}) = (\mathbf{x}_1^\top, \ldots, \mathbf{x}_T^\top)^\top \in \mathbb{R}^{TD}\) are the vectorized hand positions.
The mean parameters are more easily derived from the natural parameters,
\[
p(\mathbf{x}_{1:T} \mid \mathbf{y}_{1:T}) \propto \exp \left\{ -\tfrac{1}{2} \, \mathrm{vec}(\mathbf{x}_{1:T})^\top \mathbf{J} \, \mathrm{vec}(\mathbf{x}_{1:T}) + \mathbf{h}^\top \mathrm{vec}(\mathbf{x}_{1:T}) \right\},
\]
where \(\mathbf{J} \in \mathbb{R}^{TD \times TD}\) is a block tridiagonal precision matrix and \(\mathbf{h} \in \mathbb{R}^{TD}\) is the corresponding precision-weighted mean:
The diagonal blocks of the precision matrix are \(\mathbf{J}_{tt} = \mathbf{Q}^{-1} + \mathbf{A}^\top \mathbf{Q}^{-1} \mathbf{A} + \mathbf{C}^\top \mathbf{R}^{-1} \mathbf{C}\) (except for \(\mathbf{J}_{11}\) and \(\mathbf{J}_{TT}\)).
The lower diagonal blocks are \(\mathbf{J}_{t,t-1} = - \mathbf{Q}^{-1} \mathbf{A}\)
The precision-weighted mean blocks are \(\mathbf{h}_{t} = \mathbf{C}^\top \mathbf{R}^{-1} (\mathbf{y}_t - \mathbf{d})\)
To get the mean parameters, set \(\boldsymbol{\Sigma} = \mathbf{J}^{-1}\) and \(\boldsymbol{\mu} = \mathbf{J}^{-1} \mathbf{h}\).
Solving for the mean parameters
We’ve shown that the posterior is Gaussian and derived its natural parameters. Converting them into mean parameters requires solving for \(\mathbf{J}^{-1}\). Naively, this would take \(\mathcal{O}(T^3 D^3)\) time — cubic in the length of the time series! However, note that the precision matrix is block tridiagonal. This is a special case of a sparse banded matrix, so the system can be solved in only linear time, \(\mathcal{O}(T D^3)\), using sparse matrix solvers. In particular, solving this sparse system of equations is equivalent to running a Kalman smoother — a canonical inference algorithm for linear dynamical systems.
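Here is one way this might look in code: a sketch that assembles the block tridiagonal precision and precision-weighted mean from the block formulas above and hands them to a generic sparse solver. The initial-state parameters `Sigma1` and `mu1` are an assumed parameterization of \(p(\mathbf{x}_1)\), which the text does not pin down; a Kalman smoother would return the same posterior mean.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve

def lds_posterior_mean(Y, A, Q, C, d, R, Sigma1, mu1):
    """Posterior mean of x_{1:T} under the linear Gaussian model, via the
    block tridiagonal precision matrix and a sparse solve."""
    T, N = Y.shape
    D = A.shape[0]
    Qi, Ri, S1i = np.linalg.inv(Q), np.linalg.inv(R), np.linalg.inv(Sigma1)

    # Diagonal and off-diagonal blocks of the precision matrix J
    diag = []
    for t in range(T):
        Jtt = C.T @ Ri @ C
        Jtt = Jtt + (S1i if t == 0 else Qi)   # prior precision on x_t (J_11 uses the initial-state prior)
        if t < T - 1:
            Jtt = Jtt + A.T @ Qi @ A          # contribution from p(x_{t+1} | x_t); absent for J_TT
        diag.append(Jtt)
    off = -Qi @ A                             # J_{t,t-1} blocks

    J = sp.lil_matrix((T * D, T * D))
    h = np.zeros(T * D)
    for t in range(T):
        J[t*D:(t+1)*D, t*D:(t+1)*D] = diag[t]
        if t > 0:
            J[t*D:(t+1)*D, (t-1)*D:t*D] = off
            J[(t-1)*D:t*D, t*D:(t+1)*D] = off.T
        h[t*D:(t+1)*D] = C.T @ Ri @ (Y[t] - d)
    h[:D] += S1i @ mu1                        # the first block also picks up the assumed prior mean term

    # spsolve exploits the sparse banded structure; a Kalman smoother gives the same answer.
    return spsolve(J.tocsc(), h).reshape(T, D)
```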
Beyond Gaussian encoders#
When working with spike counts \(\mathbf{y}_t \in \mathbb{N}_0^N\), the Gaussian encoding model is less than ideal. A more reasonable model would have a discrete distribution over counts, like the Poisson GLM,
\[
p(\mathbf{y}_t \mid \mathbf{x}_t) = \prod_{n=1}^N \mathrm{Po}\big(y_{t,n} \mid f(\mathbf{c}_n^\top \mathbf{x}_t + d_n)\big),
\]
where \(\mathbf{c}_n^\top\) is the \(n\)-th row of the emission matrix \(\mathbf{C}\), \(d_n\) is the corresponding emission bias, and \(f: \mathbb{R} \mapsto \mathbb{R}_+\) is a rectifying nonlinearity.
The posterior distribution under this model is no longer Gaussian, but it's common to approximate it as one. For example, the Laplace approximation is a Gaussian approximation centered on the posterior mode (i.e., the MAP estimate),
\[
p(\mathbf{x}_{1:T} \mid \mathbf{y}_{1:T}) \approx \mathcal{N}\big(\mathrm{vec}(\mathbf{x}_{1:T}) \mid \boldsymbol{\mu}, \boldsymbol{\Sigma}\big),
\]
where
\[
\boldsymbol{\mu} = \arg\min_{\mathbf{x}_{1:T}} \mathcal{L}(\mathbf{x}_{1:T}), \qquad
\boldsymbol{\Sigma} = \big[ \nabla^2 \mathcal{L}(\boldsymbol{\mu}) \big]^{-1}, \qquad
\mathcal{L}(\mathbf{x}_{1:T}) = -\log p(\mathbf{x}_{1:T}, \mathbf{y}_{1:T}).
\]
For GLM encoders like the one above, the negative log joint is convex and the MAP estimate, \(\boldsymbol{\mu}\), can be found efficiently. Likewise, the Hessian, \(\nabla^2 \mathcal{L}(\mathbf{x}_{1:T})\), is block tridiagonal (just like the precision was before), so the relevant blocks of \(\boldsymbol{\Sigma}\) can be found efficiently as well.
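As a sketch of how the MAP estimate might be computed, the code below writes down the negative log joint for a Poisson GLM encoder with the VAR(1) prior and minimizes it with an off-the-shelf optimizer. The softplus choice of \(f\), the parameter names, and the use of a generic L-BFGS solver (rather than a structured Newton method that exploits the block tridiagonal Hessian) are all assumptions made for illustration.

```python
import numpy as np
from scipy.optimize import minimize

def softplus(a):
    return np.log1p(np.exp(a))        # one choice of rectifying nonlinearity f

def neg_log_joint(x_flat, Y, A, Q, C, d, Sigma1, mu1):
    """-log p(x_{1:T}, y_{1:T}) under a softplus-Poisson encoder and a VAR(1)
    prior, up to constants in x."""
    T, N = Y.shape
    D = A.shape[0]
    X = x_flat.reshape(T, D)
    Qi, S1i = np.linalg.inv(Q), np.linalg.inv(Sigma1)

    # VAR(1) prior terms
    dx0 = X[0] - mu1
    nll = 0.5 * dx0 @ S1i @ dx0
    dX = X[1:] - X[:-1] @ A.T
    nll += 0.5 * np.einsum("td,de,te->", dX, Qi, dX)

    # Poisson GLM likelihood terms
    rates = softplus(X @ C.T + d) + 1e-8    # small floor for numerical stability
    nll += np.sum(rates - Y * np.log(rates))
    return nll

def laplace_map(Y, A, Q, C, d, Sigma1, mu1):
    """MAP estimate of x_{1:T}; the Laplace covariance would then come from the
    (block tridiagonal) Hessian evaluated at this point."""
    T, D = Y.shape[0], A.shape[0]
    res = minimize(neg_log_joint, np.zeros(T * D),
                   args=(Y, A, Q, C, d, Sigma1, mu1), method="L-BFGS-B")
    return res.x.reshape(T, D)
```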
Exercise
Show that the Laplace approximation is equal to the Gaussian posterior from above when the likelihood is the linear Gaussian model from before.
Structured prediction#
If we’re going to make a Gaussian approximation anyway, why not relax the encoding model even further? Let’s keep the prior on \(\mathbf{x}_{1:T}\) since it adds reasonable temporal dependencies, but allow for more general dependencies on the data as follows,
\[
p(\mathbf{x}_{1:T} \mid \mathbf{y}_{1:T}) \propto p(\mathbf{x}_{1:T}) \prod_{t=1}^T \phi(\mathbf{x}_t ; \mathbf{y}_{1:T}),
\]
where \(\phi(\mathbf{x}_t ; \mathbf{y}_{1:T})\) is a potential function. Since the prior is a linear Gaussian model, if we constrain the log potentials to be quadratic functions of \(\mathbf{x}_t\), then the posterior will still be Gaussian. Specifically, assume,
\[
\phi(\mathbf{x}_t ; \mathbf{y}_{1:T}) = \exp \left\{ -\tfrac{1}{2} \mathbf{x}_t^\top \mathbf{J}_t(\mathbf{y}_{1:T}) \, \mathbf{x}_t + \mathbf{h}_t(\mathbf{y}_{1:T})^\top \mathbf{x}_t \right\},
\]
where \(\mathbf{J}_t(\mathbf{y}_{1:T}): \mathbb{R}^{T \times N} \mapsto \mathbb{R}_{++}^{D \times D}\) outputs a positive definite precision matrix and \(\mathbf{h}_t(\mathbf{y}_{1:T}): \mathbb{R}^{T \times N} \mapsto \mathbb{R}^D\) outputs a precision-weighted mean.
The key idea is that the posterior will still be a Gaussian distribution with a block tridiagonal precision matrix, just like above. Therefore, our efficient algorithms like Kalman smoothing will still apply, even in this more general setting.
This model is called a conditional random field (CRF), and it is used for structured prediction. The catch is that we need to learn the parameters of the \(\mathbf{J}_t\) and \(\mathbf{h}_t\) functions. In practice, these could be implemented by convolutional neural networks that share parameters across time. We can estimate the parameters by stochastic gradient descent on the negative log probability of the true covariates under the model, using training trials where both \(\mathbf{x}_{1:T}\) and \(\mathbf{y}_{1:T}\) are observed.
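To make the idea concrete, here is a deliberately simple sketch of data-dependent potentials: a linear readout of \(\mathbf{y}_t\) with a diagonal, softplus-positive precision. In practice these maps would be neural networks (e.g., convolutions over \(\mathbf{y}_{1:T}\)) trained by SGD; the names and shapes below are illustrative assumptions, not a standard API.

```python
import numpy as np

def softplus(a):
    return np.log1p(np.exp(a))

def crf_potentials(Y, W_J, b_J, W_h, b_h):
    """Map the data to per-timestep natural parameters (J_t, h_t).
    Here: a linear readout of y_t with a diagonal, softplus-positive precision.
    Example shapes: Y is (T, N); W_J and W_h are (D, N); b_J and b_h are (D,)."""
    Js = [np.diag(softplus(W_J @ y + b_J)) for y in Y]   # positive definite by construction
    hs = [W_h @ y + b_h for y in Y]
    return Js, hs
```

These \((\mathbf{J}_t, \mathbf{h}_t)\) blocks simply take the place of the \(\mathbf{C}^\top \mathbf{R}^{-1} \mathbf{C}\) and \(\mathbf{C}^\top \mathbf{R}^{-1} (\mathbf{y}_t - \mathbf{d})\) terms in the posterior precision and precision-weighted mean above; the block tridiagonal solve is unchanged.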
Conclusion#
Encoding and decoding are two sides of the same coin.
We can treat decoding as a simple regression problem, but sometimes we can leverage prior information about \(\mathbf{x}\) and an encoder \(p(\mathbf{y} \mid \mathbf{x})\).
Bayes’ rule tells us how to combine the prior and likelihood to derive the posterior distribution. That is how we construct Bayesian decoders.
However, the posterior rarely has a simple closed form, so we need to make approximations.
Structured decoders take this idea one step further, allowing us to learn features of the data that can be combined with the structured prior distribution to obtain flexible decoders.