Simple Spike Sorting#

*Figure: Neuropixels recording, reproduced from Fig. 6d.*

With that background, we come to our first neural data analysis problem: spike sorting. The black vertical stripes in Fig. 6d (reproduced above) are extracellular action potentials (EAPs) measured across adjacent recording channels on the Neuropixels probe, each arising from a spike of a nearby neuron.

The spike sorting problem is to identify the spikes in the multi-channel voltage recording and assign those spikes to individual neurons based on the spike waveform and the channels that were activated.

A simple probabilistic model#

I know we just picked on spherical cows, but there really is value in starting with simplified models. To get started on spike sorting, let’s consider a zoomed out view of a Neuropixels recording, like that shown in Fig. 6d. Specifically, let’s imagine downsampling the 30 kHz time series to 500 Hz. Then we can represent the multi-channel voltage recording as a matrix \(\mathbf{X} \in \mathbb{R}^{N \times T}\) where:

  • \(N\) is the number of channels

  • \(T\) is the number of 2 ms time bins

  • \(x_{n,t}\) is the average voltage on channel \(n\) during time bin \(t\).

At this temporal resolution, each spike is typically contained within a single time bin.
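
To make the downsampling step concrete, here is a minimal NumPy sketch that averages a 30 kHz multi-channel recording into 2 ms bins to form the matrix \(\mathbf{X}\). The array names and sizes are placeholders (random data standing in for a real recording), not part of any actual pipeline.

```python
import numpy as np

fs_raw, fs_ds = 30_000, 500           # original and downsampled rates (Hz)
bin_size = fs_raw // fs_ds            # 60 samples per 2 ms bin

rng = np.random.default_rng(0)
raw = rng.normal(size=(384, 30_000))  # stand-in for one second of 30 kHz data

T = raw.shape[1] // bin_size          # number of 2 ms time bins
X = raw[:, :T * bin_size].reshape(raw.shape[0], T, bin_size).mean(axis=2)
print(X.shape)                        # (384, 500): the N x T matrix described above
```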

Assumptions#

To model this data, we will make a few assumptions:

  1. Assume there are \(K\) neurons in the vicinity of the probe. When the \(k\)-th neuron spikes, its EAP produces a signature waveform on the channels. We model the waveform as a vector, \(\mathbf{w}_k = (w_{k,1}, \ldots, w_{k,N}) \in \mathbb{R}^N\), where \(w_{k,n}\) represents the average magnitude of the EAP produced on channel \(n\) each time neuron \(k\) spikes.

  2. Let \(\mathbf{a}_k = (a_{k,1}, \ldots, a_{k,T}) \in \mathbb{R}_+^T\) denote the time series of spike amplitudes for neuron \(k\). Since neurons spike only a few times a second, the amplitude is typically zero, but when the neuron does spike it has a positive amplitude.

  3. If two neurons fire at the same time, their waveforms add together in the measured voltage.

  4. The voltage recordings are noisy, so in addition to the spike waveforms we also have independent Gaussian noise \(\epsilon_{n,t} \sim \mathcal{N}(0, \sigma^2)\) for each channel \(n\) and time bin \(t\).

Warning

Note that the amplitudes are non-negative real numbers (\(a_{k,t} \in \mathbb{R}_+\)). We do not allow spikes with negative amplitude.

The Gaussian distribution#

The Gaussian Distribution

We denote a Gaussian (aka normal) random variable \(x \in \mathbb{R}\) by,

\[ x \sim \mathcal{N}(\mu, \sigma^2), \]

where \(\mu = \mathbb{E}[x]\) is the mean and \(\sigma^2 = \mathbb{V}[x]\) is the variance. The Gaussian pdf is,

\[ \mathcal{N}(x; \mu, \sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}} \exp \left\{ -\frac{1}{2 \sigma^2}(x - \mu)^2\right\}. \]

The Gaussian distribution has many important properties. For example, linear transformations of \(x\) are also Gaussian:

\[ x \sim \mathcal{N}(\mu, \sigma^2) \Rightarrow ax + b \sim \mathcal{N}(a \mu + b, a^2 \sigma^2). \]

We will cover more nice properties of the Gaussian distribution as the course goes on.

With these assumptions, we model the measured voltage as a sum of spike waveforms, weighted by the amplitudes, with Gaussian noise:

\[ x_{n,t} \sim \mathcal{N} \left( \sum_{k=1}^K w_{k,n} a_{k,t}, \sigma^2 \right). \]

Matrix factorization perspective#

Here’s another way to view the model. Combine the waveforms and amplitudes into matrices

\[\begin{split} \begin{align*} \mathbf{W} &= \begin{bmatrix} | & & | \\ \mathbf{w}_1 & \ldots & \mathbf{w}_K \\ | & & | \end{bmatrix} \in \mathbb{R}^{N \times K}, & \mathbf{A} &= \begin{bmatrix} | & & | \\ \mathbf{a}_1 & \ldots & \mathbf{a}_K \\ | & & | \end{bmatrix} \in \mathbb{R}^{T \times K}. \end{align*} \end{split}\]

and let \(\mathbf{E} \in \mathbb{R}^{N \times T}\) denote the noise matrix with entries \(\epsilon_{n,t}\).

Our model is that \(\mathbf{X} = \mathbf{W} \mathbf{A}^\top + \mathbf{E}\). That is, the matrix of voltage measurements is modeled as the product of the waveform and amplitude matrices, \(\mathbf{W} \mathbf{A}^\top = \sum_{k=1}^K \mathbf{w}_k \mathbf{a}_k^\top\), plus noise.

This is called a matrix factorization model, since the data matrix is modeled as the product of two factor matrices (plus noise). In particular, since the amplitudes are constrained to be non-negative and the waveforms are unconstrained, this is a semi-non-negative matrix factorization model [Ding et al., 2008].
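
To make the generative model concrete, here is a small NumPy sketch that simulates data from it with toy dimensions. The sparsity pattern of the amplitudes (mostly zero, occasionally positive) is an ad hoc choice for illustration, and the waveforms are normalized to unit norm in anticipation of the constraint introduced below.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, K = 32, 1000, 3                     # channels, time bins, neurons (toy sizes)
sigma = 0.1                               # noise standard deviation

# Unit-norm waveforms (columns of W); see the norm constraint below.
W = rng.normal(size=(N, K))
W /= np.linalg.norm(W, axis=0, keepdims=True)

# Sparse, non-negative amplitudes: mostly zero, occasionally positive.
A = np.where(rng.random((T, K)) < 0.01,
             rng.exponential(scale=5.0, size=(T, K)),
             0.0)

# X = W A^T + E
X = W @ A.T + sigma * rng.normal(size=(N, T))
print(X.shape)                            # (32, 1000)
```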

Accounting for scale invariance#

Notice that the model is invariant to rescaling. We could multiply the amplitudes \(\mathbf{a}_k\) by any positive constant \(c > 0\) and they would still be non-negative. As long as we multiply the corresponding waveforms by \(c^{-1}\), the product \(\mathbf{W} \mathbf{A}^\top\) remains unchanged.

We can remove this degree of freedom by constraining the waveforms to have unit norm; i.e., enforcing \(\|\mathbf{w}_k\| = 1\) for all \(k\). One way to do this is by giving the waveforms a uniform prior on the unit sphere,

\[ \mathbf{w}_k \sim \mathrm{Unif}(\mathbb{S}_{N-1}). \]

Notation

We denote the unit hypersphere in \(N\) dimensions by

\[ \mathbb{S}_{N-1} = \left\{ \mathbf{u}: \mathbf{u} \in \mathbb{R}^N \text{ and } \|\mathbf{u}\|_2 = 1 \right\} \]

It is an (\(N-1\))-dimensional manifold embedded in \(\mathbb{R}^N\).

The uniform distribution

We denote a random variable \(\mathbf{x} \in \mathbb{X}\) that follows the uniform distribution by

\[ \mathbf{x} \sim \mathrm{Unif}(\mathbb{X}) \]

where \(\mathbb{X}\) is the support. It has density

\[ \mathrm{Unif}(\mathbf{x}; \mathbb{X}) = \frac{1}{|\mathbb{X}|} \cdot \mathbb{I}[\mathbf{x} \in \mathbb{X}] \]

where \(|\mathbb{X}|\) is the volume of \(\mathbb{X}\).
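
A standard way to draw from \(\mathrm{Unif}(\mathbb{S}_{N-1})\) is to sample a standard Gaussian vector and normalize it, using the rotational symmetry of the spherical Gaussian. A minimal sketch:

```python
import numpy as np

def sample_unit_sphere(N, rng):
    """Draw w ~ Unif(S_{N-1}) by normalizing a standard Gaussian vector,
    exploiting the rotational symmetry of the spherical Gaussian."""
    g = rng.normal(size=N)
    return g / np.linalg.norm(g)

rng = np.random.default_rng(0)
w = sample_unit_sphere(32, rng)
print(np.linalg.norm(w))   # 1.0 (up to floating point error)
```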

Prior on amplitudes#

To complete the model, we place an exponential prior on amplitudes,

\[ a_{k,t} \sim \mathrm{Exp}(\lambda) \]

where \(\lambda\) is the inverse-scale (aka rate) parameter.

The Exponential distribution

We denote an exponential random variable \(x \in \mathbb{R}_+\) by,

\[ x \sim \mathrm{Exp}(\lambda) \]

where \(\lambda\) is the inverse scale or rate parameter. It has mean \(\mathbb{E}[x] = \lambda^{-1}\) and variance \(\mathbb{V}[x] = \lambda^{-2}\). Its pdf is,

\[ \mathrm{Exp}(x; \lambda) = \lambda e^{-\lambda x}. \]

Exercise

Compare the gamma pdf from the last chapter to the exponential pdf above. Show that the exponential distribution is a special case of the gamma distribution.

We treat the waveforms and the noise variance as parameters; i.e. we don’t put priors on them.

The joint probability#

Finally, we can write the joint probability,

\[\begin{split} \begin{align*} p(\mathbf{X}, \mathbf{W}, \mathbf{A}) &= p(\mathbf{X} \mid \mathbf{W}, \mathbf{A}) \, p(\mathbf{W}) \, p(\mathbf{A}) \\ &= \left[ \prod_{n=1}^N \prod_{t=1}^T \mathcal{N} \left(x_{n,t} \mid \sum_{k=1}^K w_{k,n} a_{k,t}, \sigma^2 \right) \right] \\ &\qquad \times \left[ \prod_{k=1}^K \mathrm{Unif}(\mathbf{w}_k; \mathbb{S}_{N-1}) \right] \times \left[ \prod_{k=1}^K \prod_{t=1}^T \mathrm{Exp}(a_{k,t}; \lambda) \right]. \end{align*} \end{split}\]

(As before, we suppressed the dependence on the parameters; i.e., \(\sigma^2\) and \(\lambda\).)
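
For reference, here is a sketch of a hypothetical helper that evaluates this log joint probability in NumPy, dropping the constant contributed by the uniform prior on the sphere. It assumes the columns of \(\mathbf{W}\) have unit norm and the entries of \(\mathbf{A}\) are non-negative.

```python
import numpy as np

def log_joint(X, W, A, sigma2, lam):
    """Log joint probability of the semi-NMF spike sorting model (a sketch;
    assumes unit-norm columns of W and non-negative entries of A)."""
    N, T = X.shape
    R = X - W @ A.T                                    # residual under the full model
    ll = (-0.5 * N * T * np.log(2 * np.pi * sigma2)
          - 0.5 * np.sum(R ** 2) / sigma2)             # Gaussian likelihood
    lp_A = np.sum(np.log(lam) - lam * A)               # exponential prior on amplitudes
    # The uniform prior on the sphere contributes a constant, which we drop.
    return ll + lp_A
```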

Fitting the model#

Like last time, we will “fit” the model by performing maximum a posteriori (MAP) estimation with coordinate ascent. Specifically, our algorithm will be:

  • repeat until convergence:

    • for \(k=1,\ldots,K\):

      • Set \(\mathbf{w}_k = \text{arg max} \; p(\mathbf{X}, \mathbf{W}, \mathbf{A})\) holding all else fixed

      • Set \(\mathbf{a}_k = \text{arg max} \; p(\mathbf{X}, \mathbf{W}, \mathbf{A})\) holding all else fixed

Updating the waveforms#

First, consider optimizing the waveforms. Maximizing the joint probability wrt \(\mathbf{w}_k\) is equivalent to maximizing the log joint probability, which is

\[\begin{split} \begin{align*} \log p(\mathbf{X}, \mathbf{W}, \mathbf{A}) &= \sum_{n=1}^N \sum_{t=1}^T \log \mathcal{N}\left(x_{n,t} \mid \sum_{k=1}^K w_{k,n} a_{k,t}, \sigma^2 \right) + c \\ &= -\frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{t=1}^T \left(x_{n,t} - \sum_{k=1}^K w_{k,n} a_{k,t}\right)^2 + c' \\ &= -\frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{t=1}^T \left(r_{n,t} - w_{k,n} a_{k,t} \right)^2 + c' \\ \end{align*} \end{split}\]

where

\[ r_{n,t} = x_{n,t} - \sum_{j \neq k} w_{j,n} a_{j,t} \]

is the residual, and \(c\) and \(c'\) are constants wrt \(\mathbf{w}_k\).

The solution will become clearer if we write it in vector form. Let \(\mathbf{r}_t = (r_{1,t}, \ldots, r_{N,t})\) denote the vector of residuals at time \(t\). Then we can get rid of the sum over channels and write the log joint probability as,

\[\begin{split} \begin{align*} \log p(\mathbf{X}, \mathbf{W}, \mathbf{A}) &= -\frac{1}{2\sigma^2} \sum_{t=1}^T (\mathbf{r}_{t} - \mathbf{w}_{k} a_{k,t})^\top (\mathbf{r}_{t} - \mathbf{w}_{k} a_{k,t}) + c' \\ &= \sum_{t=1}^T \log \mathcal{N}(\mathbf{r}_t; \mathbf{w}_k a_{k,t}, \sigma^2 \mathbf{I}) + c' \\ \end{align*} \end{split}\]

where we have taken this opportunity to introduce the multivariate normal distribution.

The multivariate normal distribution

The multivariate normal distribution is the multi-dimensional generalization of the scalar Gaussian/normal distribution introduced above. We denote a multivariate normal random variable \(\mathbf{x} \in \mathbb{R}^D\) by,

\[ \mathbf{x} \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma}). \]

Its density is,

\[ \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \boldsymbol{\Sigma}) = (2 \pi)^{-\frac{D}{2}} |\boldsymbol{\Sigma}|^{-\frac{1}{2}} \exp \left\{ -\frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu}) \right\}. \]

When \(\boldsymbol{\Sigma} = \sigma^2 \mathbf{I}\), we call it a spherical Gaussian distribution. Then its density reduces to a product of scalar Gaussian densities, implying that the entries of \(\mathbf{x}\) are independent Gaussian random variables.

\[ \begin{align*} \mathcal{N}(\mathbf{x}; \boldsymbol{\mu}, \sigma^2 \mathbf{I}) &= \prod_{d=1}^D \mathcal{N}(x_d \mid \mu_d, \sigma^2). \end{align*} \]

Exercise

Prove the equality between the spherical Gaussian densities and the product of scalar Gaussian densities.
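
The exercise asks for a proof; as a complementary sanity check, here is a small numerical comparison of the two sides using SciPy (toy dimensions, not a proof).

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

rng = np.random.default_rng(0)
D, sigma2 = 5, 0.5
mu = rng.normal(size=D)
x = rng.normal(size=D)

# Spherical multivariate normal log density...
lhs = multivariate_normal.logpdf(x, mean=mu, cov=sigma2 * np.eye(D))
# ...versus the sum of scalar Gaussian log densities.
rhs = norm.logpdf(x, loc=mu, scale=np.sqrt(sigma2)).sum()
print(np.allclose(lhs, rhs))   # True
```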

Finishing the optimization#

To complete the optimization wrt \(\mathbf{w}_k\), let’s further expand the multivariate density and drop terms that don’t depend on the variable of interest,

\[ \begin{align*} \log p(\mathbf{X}, \mathbf{W}, \mathbf{A}) &= \frac{1}{\sigma^2} \sum_{t=1}^T \left(\mathbf{r}_{t}^\top \mathbf{w}_{k} a_{k,t} - \frac{a_{k,t}^2}{2} \mathbf{w}_k^\top \mathbf{w}_k \right) + c'' \end{align*} \]

Remember this is a constrained optimization problem: \(\mathbf{w}_k\) must be a normalized vector. Note that \(\mathbf{w}_k^\top \mathbf{w}_k = 1\) for all \(\mathbf{w}_k \in \mathbb{S}_{N-1}\), so the second term in parentheses is actually a constant for all valid \(\mathbf{w}_k\).

Thus, maximizing with respect to \(\mathbf{w}_k\) amounts to solving the following constrained optimization problem,

\[\begin{split} \begin{align*} \mathbf{w}_k^\star &= \text{arg}\, \max_{\mathbf{w}_{k} \in \mathbb{S}_{N-1}} \left( \sum_{t=1}^T a_{k,t} \mathbf{r}_{t} \right)^\top \mathbf{w}_{k} \\ &= \text{arg}\, \max_{\mathbf{w}_{k} \in \mathbb{S}_{N-1}} \left\langle \sum_{t=1}^T a_{k,t} \mathbf{r}_t, \, \mathbf{w}_{k} \right \rangle \\ &= \text{arg}\, \max_{\mathbf{w}_{k} \in \mathbb{S}_{N-1}} \left\langle \mathbf{R} \mathbf{a}_k, \, \mathbf{w}_{k} \right \rangle \end{align*} \end{split}\]

where \(\mathbf{R} \in \mathbb{R}^{N \times T}\) is the matrix of residuals with columns \([\mathbf{r}_1, \ldots, \mathbf{r}_T]\).

Maximizing this linear objective subject to a unit-norm constraint has a well-known solution, which is to make \(\mathbf{w}_k\) parallel to the other vector in the inner product:

\[ \mathbf{w}_k^\star \propto \mathbf{R} \mathbf{a}_k. \]

Updating the amplitudes#

Now let’s derive updates for the amplitudes. As a function of \(a_{k,t}\), the log joint probability is,

\[ \begin{align*} \log p(\mathbf{X}, \mathbf{W}, \mathbf{A}) &= \frac{\mathbf{r}_{t}^\top \mathbf{w}_{k} a_{k,t}}{\sigma^2} - \frac{a_{k,t}^2}{2 \sigma^2} - \lambda a_{k,t} + \log \mathbb{I}[a_{k,t} \geq 0] + c \end{align*} \]

Maximizing wrt \(a_{k,t}\) is simply maximizing a quadratic objective over the non-negative reals,

Quadratic optimization with non-negativity constraints

To solve the constrained optimization problem, \(\max_{x \geq 0} f(x)\) with

\[ f(x) = -\frac{\alpha}{2} x^2 + \beta x + \gamma, \]

and \(\alpha > 0\), note that the objective is concave and its unconstrained maximizer is \(\beta / \alpha\). If \(\beta / \alpha < 0\), then \(f(x)\) is decreasing on \([0, \infty)\), so the constrained optimum is at \(x = 0\). In either case, the constrained maximizer is,

\[ x^\star = \max \left\{0, \beta/\alpha \right\} \]

and the corresponding optimal value is,

\[\begin{split} f(x^\star) = \begin{cases} \gamma + \frac{\beta^2}{2 \alpha} & \text{if } \beta \geq 0 \\ \gamma & \text{if } \beta < 0 \end{cases} \end{split}\]
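
As a quick check of this rule, here is a tiny helper that implements it and compares against a grid search. This is only an illustration of the boxed result, not part of the spike sorting code.

```python
import numpy as np

def argmax_nonneg_quadratic(alpha, beta):
    """Maximize f(x) = -alpha/2 * x^2 + beta * x + gamma over x >= 0, with alpha > 0.
    The maximizer does not depend on gamma."""
    return max(0.0, beta / alpha)

# Compare against a grid search for a couple of (alpha, beta) settings.
for alpha, beta in [(2.0, 3.0), (1.0, -0.7)]:
    xs = np.linspace(0.0, 10.0, 100_001)
    f = -0.5 * alpha * xs**2 + beta * xs
    assert abs(xs[np.argmax(f)] - argmax_nonneg_quadratic(alpha, beta)) < 1e-3
```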

From the box above, we have,

\[ a_{k,t}^\star = \max \left\{0, \, \sigma^2 \left(\frac{\mathbf{r}_{t}^\top \mathbf{w}_{k}}{\sigma^2} - \lambda \right)\right\} = \max\left\{0, \, \mathbf{r}_{t}^\top \mathbf{w}_{k} - \lambda \sigma^2 \right\} \]

The first term, \(\mathbf{r}_{t}^\top \mathbf{w}_{k}\), is the projection of the residual onto the waveform for neuron \(k\). The product of the hyperparameters \(\lambda\) and \(\sigma^2\) defines the threshold that the projection must exceed for the optimal amplitude to be positive; i.e., to register a spike.

In vector form, with the maximum taken elementwise, this simplifies to

\[ \mathbf{a}_k^\star = \max \{0, \, \mathbf{R}^\top \mathbf{w}_k - \lambda \sigma^2 \} \]

Shrinkage

Note that even if the projection \(\mathbf{r}_{t}^\top \mathbf{w}_{k}\) exceeds the threshold, the optimal amplitude is still “shrunk” by \(\lambda \sigma^2\). This is a common feature of \(\ell_1\) optimization problems; i.e., optimization problems with regularization on the \(\ell_1\) norm of the solution. The exponential prior on amplitudes induces such an optimization problem, and indeed there is a close correspondence between MAP estimation in Bayesian models with exponential priors and maximum likelihood estimation with \(\ell_1\) regularization, like the LASSO problem [Hastie et al., 2009].

The final algorithm#

Now that we have derived both updates, we can revise our final algorithm slightly:

  • repeat until convergence:

    • for \(k=1,\ldots,K\):

      • Compute the residual \(\mathbf{R} = \mathbf{X} - \sum_{j \neq k} \mathbf{w}_j \mathbf{a}_j^\top\)

      • Set \(\mathbf{w}_k \propto \mathbf{R} \mathbf{a}_k\)

      • Set \(\mathbf{a}_k = \max \{0, \, \mathbf{R}^\top \mathbf{w}_k - \lambda \sigma^2 \}\)

Note

You don’t have to recompute the residual from scratch each iteration; just apply a rank-one update after each waveform and amplitude update.
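
Putting the pieces together, here is a minimal NumPy sketch of the coordinate ascent loop, including the rank-one residual updates suggested in the note. The initialization (random unit-norm waveforms, zero amplitudes) and the fixed iteration count are simplistic choices for illustration; a practical implementation would monitor the log joint probability for convergence.

```python
import numpy as np

def map_spike_sort(X, K, sigma2=1.0, lam=1.0, num_iters=50, seed=0):
    """Coordinate ascent for MAP estimation in the semi-NMF spike sorting model.

    A sketch: random unit-norm waveform initialization, zero amplitudes,
    and a fixed number of iterations rather than a convergence check.
    """
    rng = np.random.default_rng(seed)
    N, T = X.shape

    W = rng.normal(size=(N, K))
    W /= np.linalg.norm(W, axis=0, keepdims=True)   # unit-norm columns
    A = np.zeros((T, K))

    # Maintain the residual R = X - W A^T and keep it current with
    # rank-one corrections instead of recomputing it from scratch.
    R = X - W @ A.T

    for _ in range(num_iters):
        for k in range(K):
            # Add neuron k's contribution back so R = X - sum_{j != k} w_j a_j^T.
            R += np.outer(W[:, k], A[:, k])

            # Waveform update: w_k proportional to R a_k, projected onto the sphere.
            wk = R @ A[:, k]
            norm = np.linalg.norm(wk)
            if norm > 0:
                W[:, k] = wk / norm

            # Amplitude update: thresholded projection of the residual.
            A[:, k] = np.maximum(0.0, R.T @ W[:, k] - lam * sigma2)

            # Remove neuron k's updated contribution from the residual.
            R -= np.outer(W[:, k], A[:, k])

    return W, A
```

On the simulated data from the matrix factorization section above, one might call, for example, `W_hat, A_hat = map_spike_sort(X, K=3, sigma2=0.01, lam=1.0)` and compare the recovered waveforms and amplitudes to the true ones (up to permutation of the neurons).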

Conclusion#

This chapter introduced the spike sorting problem for electrophysiological (“ephys”) recordings with modern tools like Neuropixels. The same strategy applies to other ephys methods, like dense multielectrode arrays.

In this first pass, we framed the problem as one of semi-non-negative matrix factorization, or semi-NMF. We took a Bayesian approach, with priors on the waveforms and amplitudes. We derived a coordinate ascent algorithm for MAP estimation and found that the updates have pleasingly simple form.

However, we made an undesirable assumption in deriving this algorithm: we started by downsampling our voltage traces so that each spike fits within a single time bin. What is the point of recording at 30 kHz if we are going to average it into 2 ms bins!? In the next chapter, we will relax this assumption and show how to do spike sorting with a convolutional matrix factorization model.

Further Reading#

Modeling via matrix factorization is one of those great ideas that has been discovered over and over again. In the probabilistic machine learning literature, some nice references include:

Many models are really matrix factorization in disguise. For example, latent Dirichlet allocation [Blei et al., 2003] is quite similar to NMF.

I don’t know of a reference for our specific model with norm constraints and exponential priors for spike sorting. However, it falls nicely within this broader landscape of matrix factorization methods, and it leads naturally to the model people actually use, as we’ll see next.