Spike Sorting by Deconvolution#

The last chapter framed spike sorting as a matrix factorization problem — specifically, a semi-nonnegative matrix factorization (semi-NMF) problem. However, that model only makes sense under the simplifying assumption that the raw voltage is downsampled to \(\sim\) 500Hz. Otherwise, spike waveforms would be spread over many time bins.

In this chapter we’ll relax that assumption and develop a more realistic model using convolutional matrix factorization. This model is inspired by Kilosort, a state-of-the-art spike sorting algorithm [Pachitariu et al., 2023].

Convolutional model#

Like before, let \(x_{n,t}\) denote the voltage on channel \(n\) at time sample \(t\), but now consider the data at its native resolution of around 30 kHz (i.e. the voltage is sampled every \(\sim 0.03\) ms). At this sampling frequency, a spike waveform typically spans 60-90 time steps.

Let \(\mathbf{W}_k \in \mathbb{R}^{N \times D}\) denote the waveform of neuron \(k\). In this model, each waveform is a matrix, where \(D\) denotes the number of time steps that a spike waveform persists in the voltage recording.

Let \(w_{k,n,d}\) denote the entries of waveform \(\mathbf{W}_k\), and let \(\mathbf{W} = \{\mathbf{W}_k\}_{k=1}^K\) be shorthand for the set of waveforms for all \(K\) neurons. Let \(\mathbf{A} \in \mathbb{R}_+^{T \times K}\) denote the matrix of spike amplitudes, as before. Now, \(a_{k,t} = 1\) denotes the start of a unit-amplitude spike with waveform \(\mathbf{W}_k\) at time \(t\).

The new model’s likelihood is,

\[\begin{split} \begin{align*} p(\mathbf{X} \mid \mathbf{W}, \mathbf{A}) &= \prod_{n=1}^N \prod_{t=1}^T \mathcal{N} \left( x_{n,t} \, \bigg| \, \sum_{k=1}^K \sum_{d=1}^D a_{k, t-d} \, w_{k,n,d}, \sigma^2 \right) \\ &= \prod_{n=1}^N \prod_{t=1}^T \mathcal{N} \left( x_{n,t} \, \bigg| \, \sum_{k=1}^K [\mathbf{a}_{k} \circledast \mathbf{w}_{k,n}]_t, \sigma^2 \right) \\ &= \prod_{t=1}^T \mathcal{N} \left( \mathbf{x}_{t} \, \bigg| \, \sum_{k=1}^K [\mathbf{a}_{k} \circledast \mathbf{W}_{k}]_{:,t}, \sigma^2 \mathbf{I} \right) \end{align*} \end{split}\]

where \(\circledast\) denotes the discrete time convolution.
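
To make this generative model concrete, here is a minimal NumPy sketch that samples data from the likelihood above. The sizes and array names (`N`, `T`, `D`, `K`, `W`, `A`, `X`) are made up for illustration.

```python
import numpy as np

# Toy sizes, chosen only for illustration.
N, T, D, K = 4, 1000, 60, 2      # channels, samples, waveform length, neurons
sigma = 0.1
rng = np.random.default_rng(0)

# Random unit-Frobenius-norm waveforms and sparse, nonnegative amplitudes.
W = rng.standard_normal((K, N, D))
W /= np.linalg.norm(W, axis=(1, 2), keepdims=True)      # ||W_k||_F = 1
A = np.zeros((K, T))
for k in range(K):
    spike_times = rng.choice(T - D, size=5, replace=False)
    A[k, spike_times] = rng.exponential(5.0, size=5)

# x_{n,t} = sum_k [a_k (conv) w_{k,n}]_t + Gaussian noise.
X = sigma * rng.standard_normal((N, T))
for k in range(K):
    for n in range(N):
        X[n] += np.convolve(A[k], W[k, n], mode="full")[:T]
```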

Convolution and cross-correlation#

The discrete time convolution of a signal \(\mathbf{x}\) with a filter \(\mathbf{f}\) is defined as

\[ [\mathbf{x} \circledast \mathbf{f}]_t = \sum_{d = -\infty}^{\infty} x_{t - d} f_d. \]

Of course, in practice we’re dealing with finite-length vectors \(\mathbf{x} \in \mathbb{R}^T\) and filters \(\mathbf{f} \in \mathbb{R}^D\), so we need to decide how to deal with boundary effects. One possibility is to pad \(\mathbf{x}\) with \(D-1\) zeros; another is to return only the “valid” section of the convolution. Yet another is to assume the signal is periodic, so that the convolution wraps around when the index \(t-d\) is negative. That is called a circular convolution.
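
For instance, NumPy's convolve exposes the first two choices (and a centered variant) through its mode argument, and a circular convolution can be computed with the FFT. A small sketch:

```python
import numpy as np

x = np.arange(10.0)                  # signal of length T = 10
f = np.array([1.0, -2.0, 1.0])       # filter of length D = 3

full  = np.convolve(x, f, mode="full")    # length T + D - 1: zero-pads x
valid = np.convolve(x, f, mode="valid")   # length T - D + 1: no padding
same  = np.convolve(x, f, mode="same")    # length T: centered slice of "full"

# Circular convolution: multiply DFTs, treating the signal as periodic.
circ = np.real(np.fft.ifft(np.fft.fft(x) * np.fft.fft(f, n=len(x))))
```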

The cross-correlation is

\[ [\mathbf{x} \star \mathbf{f}]_t = \sum_{d = -\infty}^{\infty} x_{t + d} f_d. \]

Thus, convolution is equivalent to cross-correlation with a reversed filter.
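
We can check this identity numerically with NumPy (a quick sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(100)   # signal
f = rng.standard_normal(7)     # filter

# Cross-correlation with f equals convolution with the reversed filter.
assert np.allclose(np.correlate(x, f, mode="full"),
                   np.convolve(x, f[::-1], mode="full"))
```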

Warning

Unfortunately, the definition of cross-correlation is not unique; our definition is consistent with NumPy's correlate function, but it's what Wikipedia would call \([f \star x]_t\) instead (note the order is swapped).

To make matters more confusing, the “convolution” operation performed by most neural network libraries is actually a cross-correlation (with Wikipedia’s semantics).

Since cross-correlations (convolutions in machine learning parlance) are such fundamental building blocks of modern neural networks, libraries like PyTorch have flexible APIs for performing a variety of types of convolutions. For example, with torch.nn.functional.conv1d you can cross-correlate a bank of 1D signals \(\mathbf{X} \in \mathbb{R}^{N \times T}\) with a bank of filters \(\mathbf{F} \in \mathbb{R}^{N \times D}\) by varying the number of in_channels and out_channels.

Notation

The notation for convolutions with multiple input/output channels is less standardized. We will let \(\mathbf{Y} = \mathbf{x} \star \mathbf{F}\) denote the cross-correlation of a signal \(\mathbf{x} \in \mathbb{R}^T\) with a bank of filters \(\mathbf{F} \in \mathbb{R}^{N \times D}\), which yields a bank of outputs \(\mathbf{Y} \in \mathbb{R}^{N \times T'}\) (the length depends on the padding strategy).
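
As a small illustration of this notation, the following sketch computes \(\mathbf{Y} = \mathbf{x} \star \mathbf{F}\) with torch.nn.functional.conv1d by treating the signal as a single input channel and the \(N\) filters as output channels. Variable names and sizes are illustrative.

```python
import torch
import torch.nn.functional as F

T, N, D = 1000, 4, 61
x = torch.randn(T)            # signal
filt = torch.randn(N, D)      # bank of N filters ("F" in the text)

# conv1d performs a cross-correlation: Y[n, t] = sum_d x[t + d] * filt[n, d].
Y = F.conv1d(x.view(1, 1, T), filt.view(N, 1, D)).squeeze(0)   # shape (N, T - D + 1)
```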

Scale invariance through the waveform prior#

Just like before, there is a scale invariance between \(\mathbf{a}_k\) and \(\mathbf{W}_k\). Last chapter, we constrained the waveform vector \(\mathbf{w}_k \in \mathbb{R}^N\) to have unit Euclidean (\(\ell_2\)) norm. Now that the waveforms are matrices \(\mathbf{W}_k \in \mathbb{R}^{N \times D}\), the natural generalization is to constrain the Frobenius norm of the waveforms,

\[ \|\mathbf{W}_k\|_{\mathrm{F}} = 1. \]

The Frobenius norm and the SVD#

The Frobenius norm can be rewritten in many ways.

  1. It is equal to the Euclidean (\(\ell_2\)) norm of the vectorized matrix,

    \[ \|\mathbf{W}_k\|_{\mathrm{F}} = \|\mathrm{vec}(\mathbf{W}_k)\|_2 \]
  2. It is the norm induced by the Frobenius inner product of a matrix with itself,

    \[ \|\mathbf{W}_k\|_{\mathrm{F}} = \sqrt{\langle \mathbf{W}_k, \mathbf{W}_k \rangle_{\mathrm{F}} } \]

    where

    \[ \langle \mathbf{A}, \mathbf{B} \rangle_{\mathrm{F}} = \mathrm{Tr}(\mathbf{A}^\top \mathbf{B}) \]
  3. It is the Euclidean norm of the vector of singular values of the matrix. Let \(\mathbf{W}_k = \mathbf{U}_k \mathbf{S}_k \mathbf{V}_k^\top\) where \(\mathbf{U}_k\) and \(\mathbf{V}_k\) are semi-orthogonal matrices, and where \(\mathbf{S}_k = \mathrm{diag}(\mathbf{s}_k)\) is the diagonal matrix of singular values, \(\mathbf{s}_k = (s_{k,1}, \ldots, s_{k,R})\). Then,

    \[\begin{split} \begin{align*} \|\mathbf{W}_k\|_{\mathrm{F}} &= \sqrt{\mathrm{Tr}(\mathbf{W}_k^\top \mathbf{W}_k)} \\ &= \sqrt{\mathrm{Tr}(\mathbf{V}_k \mathbf{S}_k \mathbf{U}_k^\top \mathbf{U}_k \mathbf{S}_k \mathbf{V}_k^\top)} \\ &= \sqrt{\mathrm{Tr}(\mathbf{V}_k^\top \mathbf{V}_k \mathbf{S}_k \mathbf{U}_k^\top \mathbf{U}_k \mathbf{S}_k )} \\ &= \sqrt{\mathrm{Tr}(\mathbf{S}_k^2)} \\ &= \|\mathbf{s}_k\|_2. \end{align*} \end{split}\]
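
As a quick numerical sanity check on the three equivalences above (a sketch using NumPy):

```python
import numpy as np

W = np.random.default_rng(0).standard_normal((5, 8))

frob  = np.linalg.norm(W, ord="fro")                            # Frobenius norm
vec   = np.linalg.norm(W.ravel(), ord=2)                        # l2 norm of vec(W)
inner = np.sqrt(np.trace(W.T @ W))                              # via the Frobenius inner product
svals = np.linalg.norm(np.linalg.svd(W, compute_uv=False))      # l2 norm of the singular values

assert np.allclose([vec, inner, svals], frob)
```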

Singular Value Decomposition (SVD)

Recall that the (compact) singular value decomposition (SVD) of a real valued matrix \(\mathbf{W} \in \mathbb{R}^{N \times D}\) is a factorization of the form,

\[ \mathbf{W} = \mathbf{U} \mathbf{S} \mathbf{V}^\top \]

where \(\mathbf{U} \in \mathbb{R}^{N \times R}\) and \(\mathbf{V} \in \mathbb{R}^{D \times R}\) with \(R \leq \min\{N, D\}\) are real semi-orthogonal matrices (\(\mathbf{U}^\top \mathbf{U} = \mathbf{I}\) and \(\mathbf{V}^\top \mathbf{V} = \mathbf{I}\)). The diagonal matrix \(\mathbf{S} = \mathrm{diag}(\mathbf{s})\) contains the singular values \(\mathbf{s} = (s_1, \ldots, s_R)\). The number of nonzero singular values \(R\) is the rank of the matrix \(\mathbf{W}\).

Equivalently, the SVD can be written as a sum of outer products,

\[ \mathbf{W} = \sum_{r=1}^R s_r \mathbf{u}_r \mathbf{v}_r^\top. \]
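
For example, this short sketch rebuilds a matrix from its compact SVD as a sum of rank-one outer products:

```python
import numpy as np

W = np.random.default_rng(0).standard_normal((6, 10))
U, s, Vt = np.linalg.svd(W, full_matrices=False)   # compact SVD: U (6, R), s (R,), Vt (R, 10)

# W = sum_r s_r * u_r v_r^T
W_rebuilt = sum(s[r] * np.outer(U[:, r], Vt[r]) for r in range(len(s)))
assert np.allclose(W, W_rebuilt)
```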

Constraining the rank of the waveform matrices#

Thinking of the Frobenius norm constraint in terms of a constraint on the singular values leads to a natural extension. Rather than just constraining the norm, constrain the rank of the waveform matrices as well.

There are at least two reasons why this is sensible:

  1. Manually identified spike waveforms are well approximated as outer product of a spatial footprint \(\mathbf{u}_k \in \mathbb{S}_{N-1}\) and a temporal profile \(\mathbf{v}_k \in \mathbb{S}_{D-1}\),

    \[ \mathbf{W}_k \approx \mathbf{u}_k \mathbf{v}_k^\top. \]

    Note that this is a rank \(R=1\) matrix.

  2. Constraining the waveform rank can dramatically reduce the number of free parameters, which is good from a statistical estimation standpoint. For example, if we constrain the waveforms to be rank 1, then they have only \(\mathcal{O}(N + D)\) free parameters, in contrast to the \(\mathcal{O}(ND)\) free parameters of the full-rank model.

    Exercise

    I used big-O notation because the norm constraints remove additional degrees of freedom. How many degrees of freedom do the rank-1 and full rank models truly have?

We will constrain the waveforms to be rank \(R\) via a uniform prior

\[\mathbf{W}_k \sim \mathrm{Unif}(\mathbb{S}_R^{N,D})\]

where

\[ \mathbb{S}_R^{N,D} = \left\{\mathbf{W}: \mathbf{W} \in \mathbb{R}^{N \times D}, \mathrm{rank}(\mathbf{W}) = R, \|\mathbf{W}\|_{\mathrm{F}} = 1 \right\} \]

is the set of unit-norm, rank-\(R\) matrices in \(\mathbb{R}^{N \times D}\).

When \(R=1\), these matrices can be expressed as \(\mathbf{W}_k = \mathbf{u}_k \mathbf{v}_k^\top\), where \(\mathbf{u}_k \in \mathbb{S}_{N-1}\) and \(\mathbf{v}_k \in \mathbb{S}_{D-1}\).

We will use the same exponential prior on the amplitudes as in the previous chapter.

Maximum a posteriori (MAP) estimation#

Like before, we will fit the model by using coordinate ascent to maximize the posterior probability, which is proportional to the joint probability. Again, that will entail updating the amplitudes given the waveforms, and then the waveforms given the amplitudes. When updating the parameters for neuron \(k\), the solutions will depend on the residual,

\[ \mathbf{R} = \mathbf{X} - \sum_{j \neq k} \mathbf{a}_j \circledast \mathbf{W}_j. \]

where \(\mathbf{R} \in \mathbb{R}^{N \times T}\) has columns \(\mathbf{r}_{t}\) and entries \(r_{n,t}\).
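
In code, the residual might be computed along the following lines. This is a hedged sketch rather than Kilosort's implementation; the helper `residual` and its array layout are assumptions for illustration.

```python
import numpy as np

def residual(X, A, W, k):
    """Residual for neuron k: R = X - sum_{j != k} a_j (conv) W_j.

    X: (N, T) data, A: (K, T) amplitudes, W: (K, N, D) waveforms.
    """
    N, T = X.shape
    R = X.copy()
    for j in range(A.shape[0]):
        if j == k:
            continue
        for n in range(N):
            R[n] -= np.convolve(A[j], W[j, n], mode="full")[:T]
    return R
```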

Warning

Technically, we should refer to the residual matrix as \(\mathbf{R}_k\) since it is the residual when updating that neuron, but the notation gets a bit cumbersome, and it will be clear from context.

Optimizing the amplitudes#

As a function of the amplitudes \(\mathbf{a}_k\) for neuron \(k\), the log joint probability is,

\[\begin{split} \begin{align*} \log p(\mathbf{X}, \mathbf{A}, \mathbf{W}) &= \sum_{t=1}^T \sum_{n=1}^N \log \mathcal{N}\left(r_{n,t} \,\bigg|\, [\mathbf{a}_k \circledast \mathbf{w}_{k,n}]_{t}, \sigma^2 \right) + \sum_{t=1}^T \log \mathrm{Exp}(a_{k,t}; \lambda) + c \\ &= -\frac{1}{2\sigma^2} \| \mathbf{R} - \mathbf{a}_k \circledast \mathbf{W}_k \|_{\mathrm{F}}^2 - \sum_{t=1}^T \lambda a_{k,t} + c' \\ &= \underbrace{-\frac{1}{2\sigma^2} \| \mathbf{a}_k \circledast \mathbf{W}_k \|_{\mathrm{F}}^2}_{\mathcal{L}_2(\mathbf{a}_k)} + \underbrace{\frac{1}{\sigma^2} \langle \mathbf{R}, \mathbf{a}_k \circledast \mathbf{W}_k \rangle_{\mathrm{F}}}_{\mathcal{L}_1(\mathbf{a}_k)} - \sum_{t=1}^T \lambda a_{k,t} + c''. \end{align*} \end{split}\]

The linear term#

Let's start by unpacking the linear term,

\[\begin{split} \begin{align*} \mathcal{L}_1(\mathbf{a}_k) &= \frac{1}{\sigma^2} \langle \mathbf{R}, \mathbf{a}_k \circledast \mathbf{W}_k \rangle \\ &= \frac{1}{\sigma^2} \sum_{t=1}^T \sum_{n=1}^N r_{n,t} [\mathbf{a}_k \circledast \mathbf{w}_{k,n}]_t \\ &= \frac{1}{\sigma^2} \sum_{t=1}^T \sum_{n=1}^N \sum_{d=1}^D a_{k,t-d} r_{n,t} w_{k,n,d} \\ &= \frac{1}{\sigma^2} \sum_{t=1}^T a_{k,t} \sum_{n=1}^N \sum_{d=1}^D r_{n,t+d} w_{k,n,d} \\ &= \frac{1}{\sigma^2} \sum_{t=1}^T a_{k,t} [\mathbf{R} \star \mathbf{W}_k]_t \end{align*} \end{split}\]

where \(\mathbf{R} \star \mathbf{W}_k\) denotes a 2D cross-correlation, which maps \(\mathbb{R}^{N \times T} \times \mathbb{R}^{N \times D} \mapsto \mathbb{R}^{T}\) (with appropriate padding).

Note

In this particular case the “signal” \(\mathbf{R} \in \mathbb{R}^{N \times T}\) and the “filter” \(\mathbf{W}_k \in \mathbb{R}^{N \times D}\) have the same number of rows. We can implement this 2D cross-correlation using PyTorch’s 1-D convolutions by taking advantage of the in- and out-channels, as we’ll see in lab.
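
For instance, one way to compute \([\mathbf{R} \star \mathbf{W}_k]_t\) with torch.nn.functional.conv1d is to treat the \(N\) channels of the residual as input channels and use a single output channel. A sketch with illustrative variable names:

```python
import torch
import torch.nn.functional as F

N, T, D = 8, 5000, 61
R = torch.randn(N, T)      # residual (the "signal")
Wk = torch.randn(N, D)     # waveform of neuron k (the "filter")

# One output channel, N input channels: score[t] = sum_n sum_d R[n, t + d] * Wk[n, d].
score = F.conv1d(R.unsqueeze(0), Wk.unsqueeze(0)).squeeze()    # shape (T - D + 1,)
```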

The quadratic term#

Now unpack the quadratic term,

\[\begin{split} \begin{align*} \mathcal{L}_2(\mathbf{a}_k) &= -\frac{1}{2\sigma^2} \| \mathbf{a}_k \circledast \mathbf{W}_k \|_{\mathrm{F}}^2 \\ &= -\frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{t=1}^T [\mathbf{a}_k \circledast \mathbf{w}_{k,n}]_{t}^2 \\ &= -\frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{t=1}^T \left[\sum_{d=1}^D a_{k,t-d} w_{k,n,d} \right]^2 \\ &= -\frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{t=1}^T \left[\sum_{d=1}^D a_{k,t-d}^2 w_{k,n,d}^2 + 2 \sum_{d=1}^D \sum_{d'=1}^{d-1} a_{k,t-d} a_{k,t-d'} w_{k,n,d} w_{k,n,d'} \right] \\ \end{align*} \end{split}\]

The second term has interactions between \(a_{k,t}\) and \(a_{k,t'}\) whenever \(|t-t'|<D\), which makes the coordinate ascent update for the vector \(\mathbf{a}_k\) hard!

However, remember that the waveform width \(D\) is roughly the duration of one spike. Thus, it is highly unlikely for two spikes to occur within \(D\) timesteps of each other. We will assume that the nonzero entries in \(\mathbf{a}_k\) are separated by at least \(D\) timesteps.

Under this assumption, the quadratic term reduces to,

\[\begin{split} \begin{align*} \mathcal{L}_2(\mathbf{a}_k) &= -\frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{t=1}^T \sum_{d=1}^D a_{k,t-d}^2 w_{k,n,d}^2\\ &= -\frac{1}{2\sigma^2} \sum_{n=1}^N \sum_{t=1}^T \sum_{d=1}^D a_{k,t}^2 w_{k,n,d}^2\\ &= -\frac{1}{2\sigma^2} \sum_{t=1}^T a_{k,t}^2 \|\mathbf{W}_k\|_{\mathrm{F}}^2 \\ &= -\frac{1}{2\sigma^2} \sum_{t=1}^T a_{k,t}^2, \end{align*} \end{split}\]

just like in the previous chapter.

Finishing the optimization#

We have once again reduced the coordinate update for the amplitudes to solving a bunch of independent, scalar, quadratic optimization problems subject to non-negativity constraints. For \(a_{k,t}\), the problem reduces to,

\[ \begin{align*} a_{k,t}^\star &= \text{arg} \, \max_{a_{k,t} \in \mathbb{R}_+} \; f(a_{k,t}) = -\frac{\alpha}{2} a_{k,t}^2 + \beta a_{k,t} \end{align*} \]

where

\[\begin{split} \begin{align*} \alpha &= \frac{1}{\sigma^2} \\ \beta &= \frac{[\mathbf{R} \star \mathbf{W}_k]_t}{\sigma^2} - \lambda. \end{align*} \end{split}\]

The solution is,

\[ a_{k,t}^\star = \max \left\{ 0, \, [\mathbf{R} \star \mathbf{W}_k]_t - \lambda \sigma^2 \right\}. \]

Warning

Note that this solution does not guarantee that the resulting nonzero amplitudes will be separated by at least \(D\) time steps! In practice, we can enforce this constraint via the following heuristic: after solving for the optimal amplitudes, use the scipy.signal.find_peaks function to keep only a subset of nonzero amplitudes that are separated by a distance of \(D\).
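
Putting the closed-form update and the heuristic together, a possible sketch looks like the following. The function `update_amplitudes` and its arguments are hypothetical, not part of any library's API.

```python
import numpy as np
from scipy.signal import find_peaks

def update_amplitudes(score, lam, sigma2, D):
    """Amplitude update followed by the peak-separation heuristic.

    score[t] is the cross-correlation [R * W_k]_t computed above.
    """
    a = np.maximum(0.0, score - lam * sigma2)     # a* = max{0, [R * W_k]_t - lambda * sigma^2}
    peaks, _ = find_peaks(a, distance=D)          # keep nonzero amplitudes >= D samples apart
    a_sparse = np.zeros_like(a)
    a_sparse[peaks] = a[peaks]
    return a_sparse
```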

Optimizing the waveforms#

As a function of the waveform \(\mathbf{W}_k\) for neuron \(k\), the log joint probability is,

\[ \begin{align*} \log p(\mathbf{X}, \mathbf{A}, \mathbf{W}) &= \frac{1}{\sigma^2} \langle \mathbf{R}, \mathbf{a}_k \circledast \mathbf{W}_k \rangle_{\mathrm{F}} + c' \end{align*} \]

We can simplify this expression a bit by introducing notation for windows of the residual matrix. Let,

\[\begin{split} \mathbf{R}_{t} = \begin{bmatrix} r_{1,t+1} & \ldots & r_{1,t+D} \\ \vdots & & \vdots \\ r_{N,t+1} & \ldots & r_{N,t+D} \end{bmatrix}. \end{split}\]

(In code, this is a slice of the residual matrix R[:,t:t+D].)

Once again assuming that the nonzero amplitudes are separated by at least \(D\) time steps, the log probability above simplifies to,

\[\begin{split} \begin{align*} \log p(\mathbf{X}, \mathbf{A}, \mathbf{W}) &= \frac{1}{\sigma^2} \sum_{t=1}^T \langle a_{k,t} \mathbf{R}_t, \mathbf{W}_k \rangle_{\mathrm{F}} + c' \\ &= \frac{1}{\sigma^2} \left \langle \sum_{t=1}^T a_{k,t} \mathbf{R}_t, \mathbf{W}_k \right \rangle_{\mathrm{F}} + c' \end{align*} \end{split}\]

This is analogous to the norm-constrained optimization problem for vector waveforms from the previous chapter!

Solving the optimization#

We want to maximize this log joint probability over the space of low-rank, unit-norm matrices \(\mathbb{S}_R^{N,D}\),

\[ \mathbf{W}_k^\star = \text{arg} \, \max_{\mathbf{W}_k \in \mathbb{S}_R^{N,D}} \left \langle \sum_{t=1}^T a_{k,t} \mathbf{R}_t, \mathbf{W}_k \right \rangle \]

Such optimization problems come up frequently when dealing with low-rank approximations.

Recall that when we had vector waveforms in the previous chapter, the solution was to set the waveform proportional to the weighted sum of residuals (the other vector in the inner product). Here, the solution is to set the waveform matrix “proportional to” the weighted sum of residual matrices by taking its SVD and renormalizing the singular values.

Let \(\mathbf{U} \mathbf{S} \mathbf{V}^\top\), where \(\mathbf{S} = \mathrm{diag}(\mathbf{s})\), be the SVD of the matrix \(\sum_{t=1}^T a_{k,t} \mathbf{R}_t\). Furthermore, assume the singular values \(\mathbf{s}= (s_1, \ldots, s_{\min \{N,D\}})\) are sorted in descending order. The optimal waveform update is,

\[ \mathbf{W}_k^\star = \sum_{r=1}^R \bar{s}_r \mathbf{u}_r \mathbf{v}_r^\top \]

where

\[ \bar{s}_r = \frac{s_r}{\sqrt{\sum_{r'=1}^R s_{r'}^2}}. \]
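
A sketch of this waveform update, with the residual `R`, the amplitudes `a_k`, and the helper name `update_waveform` as illustrative placeholders:

```python
import numpy as np

def update_waveform(R, a_k, D, rank):
    """SVD of the amplitude-weighted sum of residual windows, truncated to
    `rank` components and renormalized to unit Frobenius norm."""
    N, T = R.shape
    M = np.zeros((N, D))
    for t in np.flatnonzero(a_k):
        if t + D > T:
            continue                                  # skip spikes too close to the edge
        M += a_k[t] * R[:, t:t + D]                   # accumulate a_{k,t} * R_t
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    s_bar = s[:rank] / np.linalg.norm(s[:rank])       # renormalize the top singular values
    return (U[:, :rank] * s_bar) @ Vt[:rank]          # sum_r s_bar_r * u_r v_r^T
```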

More efficient computation#

Recall that a key term in the amplitude updates was the cross-correlation of the residual and the waveforms. We can compute that more efficiently by leveraging the fact that the waveforms are low rank,

\[\begin{split} \begin{align*} [\mathbf{R} \star \mathbf{W}_k]_t &= \sum_{n=1}^N \sum_{d=1}^D r_{n,t+d} w_{k,n,d} \\ &= \sum_{d=1}^D \mathbf{r}_{t+d}^\top \mathbf{w}_{k,:,d} \\ &= \sum_{d=1}^D \mathbf{r}_{t+d}^\top \mathbf{U}_k \mathbf{S}_k \mathbf{v}_{k,:,d} \\ &= \sum_{d=1}^D (\mathbf{U}_k^\top \mathbf{r}_{t+d})^\top [\mathbf{S}_k \mathbf{V}_k^\top]_{:,d} \\ &= [(\mathbf{U}_k^\top \mathbf{R}) \star (\mathbf{S}_k \mathbf{V}_k^\top)]_t \end{align*} \end{split}\]

Note that \(\mathbf{U}_k^\top \mathbf{r}_t\) is a projection of the residual onto the \(R\)-dimensional subspace spanned by the columns of \(\mathbf{U}_k\). This equality shows that we can perform the cross-correlation between residual and waveform in this lower dimensional space instead. In particular, when \(R=1\), it reduces to a 1-dimensional cross-correlation of the projected residual \(\mathbf{u}_k^\top \mathbf{R}\) and the waveform’s temporal profile \(\mathbf{v}_k\). This can yield a huge performance boost!
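
The following sketch checks the rank-1 case numerically: the direct 2D cross-correlation agrees with a 1D cross-correlation of the projected residual. Sizes and names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, D = 32, 10_000, 61
R = rng.standard_normal((N, T))                       # residual
u = rng.standard_normal(N); u /= np.linalg.norm(u)    # spatial footprint
v = rng.standard_normal(D); v /= np.linalg.norm(v)    # temporal profile
Wk = np.outer(u, v)                                   # rank-1 waveform

# Direct 2D cross-correlation: sum_n sum_d R[n, t + d] * Wk[n, d].
direct = np.array([np.sum(R[:, t:t + D] * Wk) for t in range(T - D + 1)])

# Low-rank shortcut: project onto u, then do a 1D cross-correlation with v.
fast = np.correlate(u @ R, v, mode="valid")

assert np.allclose(direct, fast)
```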

Preprocessing#

In practice, the raw voltage recordings are lightly preprocessed to create the matrix \(\mathbf{X}\).

  1. Sometimes electrical recordings have artifacts from the environment in which the recording is performed or from nearby electronics. One step toward reducing these artifacts is common average referencing, where we first subtract the mean across time, then subtract the median across channels.

  2. Spikes and the resulting EAPs are only a few milliseconds long. Real voltage recordings also contain slower signals like local field potentials (LFPs), which have time scales of 3 ms to 500 ms. Since we are interested in spikes, we typically bandpass filter each channel \(\mathbf{x}_{n} = (x_{n,1}, \ldots, x_{n,T})\) to focus on content in the [300 Hz, 2000 Hz] frequency range; i.e. signals that vary over 0.5 to 3 ms.

  3. In many recordings, especially those from freely moving animals, the electrode may move slightly over time. This movement results in drift of the spike waveforms. State-of-the-art spike sorting software like Kilosort [Pachitariu et al., 2023] tries to correct for drift in preprocessing.

  4. Since the channels are so closely spaced, noise tends to be correlated across them. The convolutional semi-NMF model described above assumes conditionally independent noise, however, so it is common to whiten the data before analysis (see the sketch after this list). After bandpass filtering, the data should be mean zero, so the empirical covariance is \(\hat{\mathbf{C}} = \frac{1}{T} \sum_{t=1}^T \mathbf{x}_t \mathbf{x}_t^\top\). Whiten the data by left-multiplying by the inverse square root of the covariance matrix, \(\mathbf{X} \leftarrow \hat{\mathbf{C}}^{-\frac{1}{2}} \mathbf{X}\).

  5. The MAP estimation problem is nonconvex, and the solution found by coordinate ascent will depend on the initialization procedure. There is no right answer for how to initialize the templates. An approach used by Kilosort (which this chapter is based on) is to initialize with a library of “universal” templates.
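
A rough sketch of steps 1, 2, and 4 of this pipeline. The filter order, band edges, and eigendecomposition-based whitening are illustrative choices, not a prescription.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def preprocess(X, fs=30_000.0, band=(300.0, 2000.0)):
    """X: (N, T) float array of raw voltages; returns a preprocessed copy."""
    # 1. Common average referencing.
    X = X - X.mean(axis=1, keepdims=True)           # subtract each channel's mean over time
    X = X - np.median(X, axis=0, keepdims=True)     # subtract the median across channels

    # 2. Bandpass filter each channel to the spike band.
    sos = butter(3, band, btype="bandpass", fs=fs, output="sos")
    X = sosfiltfilt(sos, X, axis=1)

    # 4. Whiten: left-multiply by the inverse square root of the channel covariance.
    C = (X @ X.T) / X.shape[1]
    evals, evecs = np.linalg.eigh(C)
    C_inv_sqrt = evecs @ np.diag(1.0 / np.sqrt(evals)) @ evecs.T
    return C_inv_sqrt @ X
```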

Post processing#

The waveforms extracted by convolutional semi-NMF usually still need a bit of post-processing. For example, sometimes the model finds two waveforms for the same neuron. Alternatively, it may assign two neurons to the same waveform if their spike timing is highly correlated.

These types of errors can be addressed with a post-processing step to split or merge clusters. This step is not unique to spike sorting — it’s a common postprocessing step in many unsupervised clustering analyses. For spike sorting, we can bring extra domain knowledge to bear on the problem. For example, we expect the spatial footprints to be localized, and we expect the temporal profiles to have a single downward deflection. Modern libraries incorporate checks like these into the postprocessing steps.

Conclusion#

This chapter extended the previous one by allowing waveforms that extend in time. The generalized model is a form of convolutional semi-NMF. Along the way, we picked up some new skills:

  • Convolution and Cross-Correlation: the building blocks of many machine learning models, and we’ll return to them multiple times in this course.

  • Frobenius norm and inner product: basic tools for dealing with matrix-valued variables

  • Singular value decomposition: a crucial matrix factorization with lots of applications in low-rank approximation.

  • More MAP estimation! By now, you're quite familiar with framing estimation problems as maximizing the log probability and then deriving coordinate ascent algorithms. Here, the coordinate updates were particularly interesting, as they involved making some simplifying assumptions and optimizing over manifolds of low-rank matrices.

Next time, we’ll consider an analogous problem for working with calcium imaging data. Many of the models and tools we’ve developed will transfer.

Further reading#

The algorithm presented in this chapter is similar to Kilosort [Pachitariu et al., 2016, Steinmetz et al., 2021]. The exact algorithms employed by Kilosort change from version to version, but the convolutional generative model is central. Complete details of Kilosort and the differences from one version to the next can be found in Pachitariu et al. [2023].

Of course, there are other spike sorting algorithms and implementations as well, like YASS [Lee et al., 2020] and MountainSort [Chung et al., 2017]. Each has its own unique aspects, and it is interesting to compare and contrast different methods. For Neuropixels users, Kilosort appears to be the go-to method.

Is spike sorting really necessary though? For some questions of interest, like population decoding or state space analysis, it may not be. For example, Deng et al. [2015] showed improved performance using a “clusterless” decoding approach that uses extracted spike waveforms but does not try to assign them neuron labels. Similarly, Trautmann et al. [2019] showed that you can identify low dimensional states and dynamics from electrophysiological recordings without spike sorting. However, if your scientific objectives involve understanding individual neurons’ coding properties, then spike sorting is a necessary step.