
Deep SSMs and Linear Attention

Two tensions define modern sequence modeling. On the training side, recurrent models like RNNs process tokens one at a time, leaving GPU parallelism underutilized on long sequences. On the inference side, Transformers compute all pairwise token interactions, incurring $O(N^2)$ time and memory that becomes prohibitive for long contexts. Can we get the best of both worlds — parallel training and efficient inference?

This chapter shows that the answer is yes, for a broad family of models that can be written as linear recurrences. A linear recurrence is simultaneously a convolution (fast parallel training via FFT or parallel scan) and a fixed-size hidden-state update (fast $O(1)$-memory inference). Structured SSMs exploit this duality with carefully designed dynamics matrices; linear attention recovers it by approximating the softmax kernel.

Deep State Space Models

Linear Recurrent Neural Networks

A linear RNN evolves a hidden state $\mathbf{h}_t \in \mathbb{R}^p$ via

$$\mathbf{h}_t = \mathbf{A} \mathbf{h}_{t-1} + \mathbf{B} \mathbf{x}_t, \qquad \mathbf{y}_t = \mathbf{C} \mathbf{h}_t + \mathbf{d},$$

with $\mathbf{h}_0 = \mathbf{0}$. Because linear operators compose, unrolling the recurrence gives

$$\mathbf{y}_t = \mathbf{C} \sum_{s=0}^{t-1} \mathbf{A}^s \mathbf{B} \mathbf{x}_{t-s} + \mathbf{d} = [\mathbf{K} \circledast \mathbf{x}]_t,$$

where the SSM kernel is $\mathbf{K} = [\mathbf{C}\mathbf{B} \;\; \mathbf{C}\mathbf{A}\mathbf{B} \;\; \mathbf{C}\mathbf{A}^2\mathbf{B} \;\; \cdots]$.

This recurrence–convolution duality is the key computational property:

- Training: with fixed $\mathbf{A}$, the output is the convolution $\mathbf{K} \circledast \mathbf{x}$, so all timesteps can be computed in parallel (e.g., via the FFT) rather than sequentially.
- Inference: the recurrent form carries only the fixed-size state $\mathbf{h}_t$, so generating each new token costs $O(1)$ memory.
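As a minimal sketch of the duality, the NumPy snippet below computes the same output both ways on random data (dimensions and matrices are illustrative; the offset $\mathbf{d}$ is omitted).

```python
import numpy as np

rng = np.random.default_rng(0)
p, T = 4, 16                                                # state size, sequence length (illustrative)
A = 0.9 * np.eye(p) + 0.01 * rng.standard_normal((p, p))    # stable dynamics matrix
B = rng.standard_normal((p, 1))
C = rng.standard_normal((1, p))
x = rng.standard_normal(T)

# Recurrent form: h_t = A h_{t-1} + B x_t,  y_t = C h_t
h = np.zeros(p)
y_rec = np.zeros(T)
for t in range(T):
    h = A @ h + B[:, 0] * x[t]
    y_rec[t] = C[0] @ h

# Convolutional form: y_t = sum_s K_s x_{t-s} with kernel K_s = C A^s B
K = np.array([(C @ np.linalg.matrix_power(A, s) @ B)[0, 0] for s in range(T)])
y_conv = np.array([np.dot(K[:t + 1], x[t::-1]) for t in range(T)])

assert np.allclose(y_rec, y_conv)   # the two views agree
```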

When the dynamics $\mathbf{A}_t$ are input-dependent (time-varying), the fixed kernel no longer applies. Training then requires a parallel scan over the sequence — an $O(\log T)$-depth divide-and-conquer algorithm that exploits the associativity of affine function composition. See Parallelizing Nonlinear RNNs for a detailed treatment of the parallel scan and its extensions to nonlinear recurrences.

S4: A Continuous-Time Deep SSM

S4 (Gu et al., 2022) is a deep SSM that derives its discrete-time recurrence from a continuous-time dynamical system, giving a principled way to initialize $A$ so that the model captures long-range dependencies. A linear time-invariant (LTI) system evolves a hidden state $h(t) \in \mathbb{R}^p$ via:

$$h'(t) = A\, h(t) + B\, x(t), \qquad y(t) = C^\top h(t).$$

Discretizing with step size $\Delta$ (zero-order hold) yields the recurrent form:

$$h_t = \bar{A}\, h_{t-1} + \bar{B}\, x_t, \qquad y_t = C^\top h_t,$$

where $\bar{A} = e^{\Delta A}$ and $\bar{B} = (\bar{A} - I) A^{-1} B$. Because $\bar{A}$ is fixed, the output is a convolution with the SSM kernel $\bar{K} = (C^\top \bar{B},\; C^\top \bar{A} \bar{B},\; C^\top \bar{A}^2 \bar{B}, \ldots)$:

$$y_t = (\bar{K} * x)_t.$$
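A minimal sketch of this pipeline, assuming NumPy and SciPy and a randomly chosen (roughly stable) continuous-time system: discretize with zero-order hold, materialize the truncated kernel naively, and apply it by FFT convolution. (S4 itself never builds the kernel this way; see below.)

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(0)
p, N, dt = 8, 256, 0.1                                 # state size, sequence length, step size (illustrative)
A = -np.eye(p) + 0.1 * rng.standard_normal((p, p))     # roughly stable continuous-time dynamics
B = rng.standard_normal((p, 1))
C = rng.standard_normal((1, p))
x = rng.standard_normal(N)

# Zero-order-hold discretization
Abar = expm(dt * A)
Bbar = np.linalg.solve(A, Abar - np.eye(p)) @ B        # A^{-1} (Abar - I) B

# Truncated SSM kernel: Kbar_s = C Abar^s Bbar for s = 0, ..., N-1
K = np.zeros(N)
M = np.eye(p)
for s in range(N):
    K[s] = (C @ M @ Bbar)[0, 0]
    M = Abar @ M

# Causal convolution y = Kbar * x via FFT (zero-pad to avoid circular wrap-around)
L = 2 * N
y = np.fft.irfft(np.fft.rfft(K, L) * np.fft.rfft(x, L), L)[:N]
```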

The remaining challenge is computing $\bar{K}$ efficiently. Computing $e^{\Delta A}$ naively is $O(p^3)$. S4 addresses this by initializing $A$ using the HiPPO framework (Gu et al., 2020), which constructs $A$ as a structured matrix designed to optimally memorize the input history via polynomial projections. The HiPPO-LegS matrix has entries:

$$A_{nk} = \begin{cases} -(2n+1)^{1/2}(2k+1)^{1/2} & n > k \\ -(n+1) & n = k \\ 0 & n < k \end{cases}$$

S4 exploits the fact that HiPPO-LegS is a normal plus low-rank (NPLR) matrix to compute the SSM kernel in $O(p + N \log N)$ time via the fast Fourier transform.
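The HiPPO-LegS matrix itself follows directly from the piecewise formula above; a small helper (a sketch, not the full S4 initialization, which also sets $B$ and applies the NPLR decomposition):

```python
import numpy as np

def hippo_legs(p: int) -> np.ndarray:
    """HiPPO-LegS state matrix, built entrywise from the piecewise formula."""
    A = np.zeros((p, p))
    for n in range(p):
        for k in range(p):
            if n > k:
                A[n, k] = -np.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n, k] = -(n + 1)
            # n < k: entry stays zero
    return A

A = hippo_legs(8)   # lower triangular: zero above the diagonal
```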

S5: Diagonal SSMs with Parallel Scans

S5 (Smith et al., 2023) simplifies S4 by diagonalizing the state matrix: $\Lambda = P^{-1} A P$ where $\Lambda$ is diagonal. With $\tilde{h}_t = P^{-1} h_t$, the recurrence decouples into $p$ independent scalar recurrences:

$$\tilde{h}_{t,i} = \lambda_i \tilde{h}_{t-1,i} + \tilde{B}_i x_t, \qquad y_t = \mathrm{Re}(C^\top \tilde{h}_t).$$

Each mode decays independently at rate $|\lambda_i|$; stability requires $|\lambda_i| < 1$. S5 computes all $N$ hidden states in parallel using an associative parallel scan, reducing training time from $O(Np)$ sequential steps to $O(p \log N)$ parallel steps.
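The scan works because each step is an affine map $h \mapsto a \odot h + b$, and composing two such maps is again affine. The sketch below (NumPy, illustrative sizes) defines that associative combine and folds it sequentially to confirm it reproduces the recurrence; a real implementation applies the same operator in a balanced tree (e.g., `jax.lax.associative_scan`) for $O(\log N)$ depth.

```python
import numpy as np

rng = np.random.default_rng(0)
p, N = 4, 32
lam = 0.9 * np.exp(1j * rng.uniform(0, np.pi, p))   # diagonal modes with |lambda| < 1
Bx = rng.standard_normal((N, p)) + 0j               # precomputed B~ x_t terms (illustrative)

def combine(e1, e2):
    # e = (a, b) represents the affine map h -> a * h + b; combine applies e1 then e2.
    a1, b1 = e1
    a2, b2 = e2
    return a2 * a1, a2 * b1 + b2

# Prefix-combine the per-step elements; after step t, the b-component equals h_t.
h_scan = np.zeros((N, p), dtype=complex)
acc = (np.ones(p, dtype=complex), np.zeros(p, dtype=complex))   # identity element
for t in range(N):
    acc = combine(acc, (lam, Bx[t]))
    h_scan[t] = acc[1]

# Reference: the plain sequential recurrence
h = np.zeros(p, dtype=complex)
for t in range(N):
    h = lam * h + Bx[t]
    assert np.allclose(h, h_scan[t])
```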

Mamba: Selective SSMs

A fundamental limitation of LTI SSMs is that the transition matrices $\bar{A}$ and $\bar{B}$ are input-independent: every token is processed identically regardless of content. This prevents the model from selectively retaining or discarding information based on context.

Mamba (Gu & Dao, 2023) introduces a selection mechanism by making $B_t$, $C_t$, and the discretization step $\Delta_t$ functions of the input $x_t$:

$$h_t = \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t, \qquad y_t = C_t^\top h_t,$$

where $\bar{A}_t = e^{\Delta_t A}$, $\bar{B}_t = (\bar{A}_t - I) A^{-1} B_t$, and $\Delta_t, B_t, C_t$ are computed from $x_t$ via small linear projections.

Because the parameters now vary with $t$, the convolutional view no longer applies — the system is time-varying, and $y_t$ depends on all of $x_1, \ldots, x_t$ in a non-linear way. Training requires a selective scan: a hardware-aware parallel algorithm that exploits the structure of the recurrence without materializing intermediate states in memory.
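A minimal sequential reference of a selective SSM, assuming a diagonal $A$ and randomly initialized projection weights (`W_delta`, `W_B`, `W_C` are illustrative stand-ins for Mamba's learned projections); the actual Mamba layer adds gating, per-channel step sizes, and the fused parallel scan.

```python
import numpy as np

rng = np.random.default_rng(0)
d, p, N = 16, 8, 64                          # input dim, state size, sequence length (illustrative)
x = rng.standard_normal((N, d))

A = -np.exp(rng.standard_normal(p))          # diagonal, negative (stable) continuous-time dynamics
W_delta = rng.standard_normal(d) / np.sqrt(d)
W_B = rng.standard_normal((p, d)) / np.sqrt(d)
W_C = rng.standard_normal((p, d)) / np.sqrt(d)

def softplus(z):
    return np.log1p(np.exp(z))

h = np.zeros((p, d))                         # one state vector per input channel
y = np.zeros((N, d))
for t in range(N):
    delta_t = softplus(x[t] @ W_delta)       # input-dependent step size
    B_t = W_B @ x[t]                         # input-dependent input projection
    C_t = W_C @ x[t]                         # input-dependent readout
    Abar_t = np.exp(delta_t * A)             # diagonal ZOH: elementwise exponential
    Bbar_t = (Abar_t - 1.0) / A * B_t        # (Abar_t - I) A^{-1} B_t, elementwise for diagonal A
    h = Abar_t[:, None] * h + Bbar_t[:, None] * x[t][None, :]
    y[t] = C_t @ h
```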

The selection mechanism gives Mamba attention-like content-dependent routing while preserving the $O(Np)$ cost of a recurrent model.

Linear Attention

The link between deep SSMs and Transformers (see the Transformers chapter) may be closer than it appears. Both Mamba and softmax attention compute content-dependent outputs — Mamba via input-driven gates, attention via pairwise dot products. The key difference is architectural: SSMs maintain a fixed-size hidden state updated recurrently, while attention computes all pairwise interactions simultaneously, with no fixed-size bottleneck but at $O(N^2)$ cost. Linearizing the attention kernel bridges the two, recovering a recurrent model from attention. Given queries $Q \in \mathbb{R}^{N \times d}$, keys $K \in \mathbb{R}^{N \times d}$, and values $V \in \mathbb{R}^{N \times d}$, the output of causal (autoregressive) softmax attention at position $i$ is:

$$y_i = \frac{\sum_{j=1}^{i} \exp\!\left(q_i^\top k_j / \sqrt{d}\right) v_j}{\sum_{j=1}^{i} \exp\!\left(q_i^\top k_j / \sqrt{d}\right)}.$$

In matrix form (with the causal mask $M_{ij} = \mathbb{I}[j \leq i]$):

$$Y = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}} + \log M\right) V,$$

where $\log M$ sets future entries to $-\infty$. Computing this requires materializing the $N \times N$ attention matrix, costing $O(N^2 d)$ time and $O(N^2)$ memory.
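For reference, the masked matrix form is a few lines of NumPy (shapes are illustrative); the $N \times N$ score matrix is exactly the object linear attention avoids.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

scores = Q @ K.T / np.sqrt(d)                      # (N, N) pairwise logits
scores[np.triu_indices(N, k=1)] = -np.inf          # causal mask: no attending to the future
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)     # row-wise softmax
Y = weights @ V                                    # (N, d) outputs; O(N^2 d) time, O(N^2) memory
```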

Katharopoulos et al. (2020) observe that the softmax is a kernel function, $\exp(q^\top k / \sqrt{d}) = \kappa(q, k)$, and that replacing it with a kernel with explicit feature maps, $\kappa(q, k) = \phi(q)^\top \phi(k)$, unlocks a dramatic simplification. The causal attention output becomes:

$$y_i = \frac{\phi(q_i)^\top \sum_{j \leq i} \phi(k_j) v_j^\top}{\phi(q_i)^\top \sum_{j \leq i} \phi(k_j)}.$$

The key is that $\phi(q_i)$ acts on the accumulated outer products, which can be built up incrementally as a linear recurrence. Defining the hidden state and normalizer:

$$S_t = S_{t-1} + \phi(k_t) v_t^\top \in \mathbb{R}^{r \times d}, \qquad z_t = z_{t-1} + \phi(k_t) \in \mathbb{R}^r,$$

the output at each step is $y_t = S_t^\top \phi(q_t) / (z_t^\top \phi(q_t))$. This is a linear RNN with hidden state $S_t$: $O(Nrd)$ time and $O(rd)$ memory, with no attention matrix. A simple feature map that keeps values positive is $\phi(x) = \mathrm{elu}(x) + 1$ elementwise, giving $r = d$.
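A sketch of the recurrent form with $\phi(x) = \mathrm{elu}(x) + 1$, checked against the explicit causally-masked kernel attention it linearizes (note this matches the kernelized formula above, not softmax attention):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 6, 4
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

def phi(x):
    # elu(x) + 1: x + 1 for x > 0, exp(x) otherwise; elementwise positive feature map
    return np.where(x > 0, x + 1.0, np.exp(x))

S = np.zeros((d, d))                              # running sum of phi(k_t) v_t^T
z = np.zeros(d)                                   # running sum of phi(k_t)
Y_rec = np.zeros((N, d))
for t in range(N):
    S += np.outer(phi(K[t]), V[t])
    z += phi(K[t])
    Y_rec[t] = (S.T @ phi(Q[t])) / (z @ phi(Q[t]))

# Check against the explicit causally-masked kernel attention
A = phi(Q) @ phi(K).T                             # kernel scores phi(q_i)^T phi(k_j)
A = np.tril(A)                                    # causal mask
Y_full = (A @ V) / A.sum(axis=-1, keepdims=True)
assert np.allclose(Y_rec, Y_full)
```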

Test-Time Regression

Linear attention can be reread as solving a regression problem online. At each step, the model has observed a sequence of key-value pairs $(k_1, v_1), \ldots, (k_t, v_t)$, which form a training set, and must predict a value at a new query $q_t$, the test point. The hidden state $S_t$ is the current regression solution, updated incrementally as new pairs arrive.

Concretely, consider fitting a linear map $W : \mathbb{R}^d \to \mathbb{R}^d$ to the accumulated pairs via ridge regression:

$$W_t = \arg\min_W \sum_{s \leq t} \|W k_s - v_s\|^2 + \lambda \|W\|_F^2.$$

The closed-form solution is $W_t = \left(\sum_{s \leq t} v_s k_s^\top\right) \left(\sum_{s \leq t} k_s k_s^\top + \lambda I\right)^{-1}$, and the prediction at $q_t$ is $y_t = W_t q_t$ — kernel ridge regression with a linear kernel. The models above are all approximations to this solution: linear attention, for instance, keeps only the numerator $\sum_{s \leq t} v_s k_s^\top$ and replaces the inverse Gram matrix with the scalar normalizer $z_t^\top \phi(q_t)$.
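The exact online solution is easy to write down from running sufficient statistics, at the cost of a $d \times d$ inverse per step (a sketch with illustrative sizes; this per-step inverse is precisely what the models below avoid):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, lam = 32, 8, 1e-2
K, V, Q = (rng.standard_normal((N, d)) for _ in range(3))

# Running sufficient statistics for ridge regression over the prefix (k_s, v_s), s <= t
VKt = np.zeros((d, d))     # sum_s v_s k_s^T
KKt = np.zeros((d, d))     # sum_s k_s k_s^T
Y = np.zeros((N, d))
for t in range(N):
    VKt += np.outer(V[t], K[t])
    KKt += np.outer(K[t], K[t])
    W_t = VKt @ np.linalg.inv(KKt + lam * np.eye(d))   # closed-form ridge solution at time t
    Y[t] = W_t @ Q[t]                                  # prediction at the query q_t
```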

The remaining question is how to approximate the full ridge regression solution cheaply and online.

DeltaNet: Gradient Descent for Linear Regression

Rather than accumulating all key-value pairs, DeltaNet (Yang et al., 2024) maintains a weight matrix $S_t$ and updates it by taking one step of gradient descent on the per-token least-squares loss $\ell_t(S) = \tfrac{1}{2}\|S k_t - v_t\|^2$. With step size $\beta_t$, the gradient step gives the delta rule:

$$S_t = S_{t-1} - \beta_t (S_{t-1} k_t - v_t) k_t^\top = S_{t-1}(I - \beta_t k_t k_t^\top) + \beta_t v_t k_t^\top.$$

This is a rank-1 correction: the old memory’s prediction $S_{t-1} k_t$ is erased and the new target $v_t$ is written in, both weighted by $k_t$. Unlike linear attention ($S_t = S_{t-1} + \phi(k_t) v_t^\top$, no forgetting), DeltaNet selectively overwrites the memory associated with $k_t$. The output is $y_t = S_t q_t$.

With keys and queries $\ell_2$-normalized and $\beta_t$ a learned sigmoid gate, DeltaNet is a gated linear RNN with content-dependent forgetting — more expressive than linear attention, cheaper than softmax attention ($O(Nd^2)$ vs. $O(N^2 d)$). Yang et al. (2024) derive an efficient parallel training algorithm based on the WY representation of Householder products.
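A sequential reference of the delta rule with $\ell_2$-normalized keys and queries and a sigmoid gate for $\beta_t$ (gate logits here are random placeholders; the parallel WY-based training algorithm is not shown):

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 32, 8
K, V, Q = (rng.standard_normal((N, d)) for _ in range(3))
beta_logits = rng.standard_normal(N)                 # per-token gate logits (illustrative)

def l2norm(x):
    return x / np.linalg.norm(x)

S = np.zeros((d, d))                                 # associative memory mapping keys to values
Y = np.zeros((N, d))
for t in range(N):
    k_t, q_t = l2norm(K[t]), l2norm(Q[t])
    beta_t = 1.0 / (1.0 + np.exp(-beta_logits[t]))   # sigmoid gate in (0, 1)
    # Delta rule: erase the old prediction for k_t, write in the new value v_t
    S = S - beta_t * np.outer(S @ k_t - V[t], k_t)
    Y[t] = S @ q_t
```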

Test-Time Training

Test-time training (TTT; Yu et al., 2024) takes the regression perspective to its logical conclusion: the hidden state $W_t$ is the weights of a small nonlinear model $f_{W_t} : \mathbb{R}^d \to \mathbb{R}^d$ (e.g., a two-layer MLP), updated online by gradient descent on the reconstruction loss:

$$\ell_t(W) = \tfrac{1}{2}\|f_W(k_t) - v_t\|^2.$$

One gradient step gives $W_t = W_{t-1} - \eta \nabla_W \ell_t(W_{t-1})$, and the output is $y_t = f_{W_t}(q_t)$. When $f_W$ is the linear map $f_W(x) = Wx$, this reduces exactly to DeltaNet with $\beta_t = \eta$ — so DeltaNet is TTT with the simplest possible hidden model. For nonlinear $f_W$, TTT can take multiple gradient steps on mini-batches of context tokens, creating a meta-learned inner loop that trades compute for expressivity of the hidden state.
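A sketch where the hidden state is the weights of a tiny two-layer ReLU MLP, updated with one manual gradient step per token; the architecture, sizes, and single-step schedule are illustrative rather than the paper's exact parameterization.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, m, eta = 32, 8, 16, 0.1
K, V, Q = (rng.standard_normal((N, d)) for _ in range(3))

# Hidden state = weights of a two-layer MLP  f_W(x) = W2 @ relu(W1 @ x)
W1 = rng.standard_normal((m, d)) / np.sqrt(d)
W2 = rng.standard_normal((d, m)) / np.sqrt(m)

Y = np.zeros((N, d))
for t in range(N):
    # Forward pass and per-token reconstruction loss 0.5 * ||f_W(k_t) - v_t||^2
    a = W1 @ K[t]
    h = np.maximum(a, 0.0)
    err = W2 @ h - V[t]
    # Manual backprop: one gradient step on (W1, W2)
    gW2 = np.outer(err, h)
    gW1 = np.outer((W2.T @ err) * (a > 0), K[t])
    W1 -= eta * gW1
    W2 -= eta * gW2
    # Output: evaluate the updated model at the query
    Y[t] = W2 @ np.maximum(W1 @ Q[t], 0.0)
```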

Conclusion

The table below organizes the models in this chapter along two axes: whether the dynamics are input-dependent (selective), and whether the hidden state update rule involves forgetting (vs. simple accumulation).

| Model | Hidden state update | Input-dependent? | Forgetting? |
|---|---|---|---|
| Linear RNN / S4 (Gu et al., 2022) | $h_t = \bar{A}\, h_{t-1} + \bar{B}\, x_t$ | No | No (fixed decay) |
| S5 (Smith et al., 2023) | $\tilde{h}_{t,i} = \lambda_i \tilde{h}_{t-1,i} + \tilde{B}_i x_t$ | No | No (diagonal decay) |
| Mamba (Gu & Dao, 2023) | $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$ | Yes | Yes (gated decay) |
| Linear attention (Katharopoulos et al., 2020) | $S_t = S_{t-1} + \phi(k_t) v_t^\top$ | Yes | No ($A = I$) |
| DeltaNet (Yang et al., 2024) | $S_t = S_{t-1}(I - \beta_t k_t k_t^\top) + \beta_t v_t k_t^\top$ | Yes | Yes (selective overwrite) |
| TTT (linear) (Yu et al., 2024) | $W_t = W_{t-1} - \eta (W_{t-1}k_t - v_t)k_t^\top$ | Yes | Yes (gradient descent) |

All achieve $O(N)$ inference complexity. The dominant pattern — a linear recurrence admitting parallel training via associative scans or convolutions — unifies classical signal processing, online learning, and modern sequence modeling.

References
  1. Gu, A., Goel, K., & Ré, C. (2022). Efficiently Modeling Long Sequences with Structured State Spaces. International Conference on Learning Representations.
  2. Gu, A., Dao, T., Ermon, S., Rudra, A., & Ré, C. (2020). HiPPO: Recurrent Memory with Optimal Polynomial Projections. Advances in Neural Information Processing Systems, 33, 1474–1487.
  3. Smith, J. T. H., Warrington, A., & Linderman, S. W. (2023). Simplified State Space Layers for Sequence Modeling. International Conference on Learning Representations.
  4. Gu, A., & Dao, T. (2023). Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv Preprint arXiv:2312.00752.
  5. Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020). Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. International Conference on Machine Learning, 5156–5165.
  6. Yang, S., Wang, B., Shen, Y., Peng, H., & Kim, Y. (2024). Parallelizing Linear Transformers with the Delta Rule over Sequence Length. Advances in Neural Information Processing Systems, 37.
  7. Yu, Y., Guo, S., Kautz, J., Alvarez, J. M., Anandkumar, A., Xu, K., & Molchanov, P. (2024). Learning to (Learn at Test Time): RNNs with Expressive Hidden States. Advances in Neural Information Processing Systems, 37.