
Parallelizing Nonlinear RNNs

Linear RNNs and structured SSMs are fast to train because their recurrences compose as affine functions — a property that admits an $O(\log T)$-depth parallel scan. Nonlinear RNNs, like the vanilla RNN and the GRU, do not enjoy this property: each state depends nonlinearly on the previous one, so evaluation seems inherently sequential. This chapter first introduces the parallel scan algorithm, then shows how iterative linearization extends it to nonlinear recurrences — reducing each iteration to a linear dynamical system (LDS) solve that can be parallelized with the same primitive.

The Parallel Scan

Consider a linear RNN with time-varying dynamics,

$$\mbx_t = \mbA_t \mbx_{t-1} + \mbb_t, \qquad t = 1, \ldots, T,$$

where $\mbx_t \in \reals^D$, $\mbA_t \in \reals^{D \times D}$, and $\mbb_t \in \reals^D$. Naively, computing the full trajectory $\mbx_{1:T}$ requires $T$ sequential steps, since each $\mbx_t$ depends on $\mbx_{t-1}$. The parallel scan (Blelloch, 1990), also called the associative scan, reduces this to $O(\log T)$ depth by exploiting the structure of affine composition.

Associativity and Closure

Each step of the recurrence is an affine map ft(x)=Atx+btf_t(\mbx) = \mbA_t \mbx + \mbb_t. Two consecutive steps compose as another affine map:

$$f_j \circ f_i(\mbx) = \mbA_j(\mbA_i \mbx + \mbb_i) + \mbb_j = (\mbA_j \mbA_i)\,\mbx + (\mbA_j \mbb_i + \mbb_j).$$

Representing each step as a pair $(\mbA_t, \mbb_t)$, the composition rule is:

$$(\mbA_j, \mbb_j) \otimes (\mbA_i, \mbb_i) = (\mbA_j \mbA_i,\; \mbA_j \mbb_i + \mbb_j).$$

This operation is associative — $(f_k \circ f_j) \circ f_i = f_k \circ (f_j \circ f_i)$ — and closed — the composition of two affine maps is again an affine map. These two properties are exactly what the parallel scan requires.

Divide-and-Conquer

Given $T$ affine maps, we want all cumulative compositions $f_{t:1} = f_t \circ \cdots \circ f_1$ for $t = 1, \ldots, T$. The parallel scan computes these in two phases:

  1. Up-sweep (reduce): combine adjacent pairs of elements up a balanced binary tree, producing the composition of progressively larger contiguous blocks.

  2. Down-sweep (distribute): traverse the tree back down, passing each node's prefix to its children, so that every position ends up with the composition of all elements that precede it.

Together both phases run in $O(\log T)$ depth with $O(T)$ processors, yielding the full trajectory $\mbx_{1:T}$ in $O(T)$ total work — the same as sequential evaluation, but parallelized across the time dimension.
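
To make this concrete, here is a minimal sketch of the affine scan in JAX using `jax.lax.associative_scan`. The function names and the `(A, b)` element layout are our own choices, not a prescribed API:

```python
import jax
import jax.numpy as jnp

def affine_combine(elem_i, elem_j):
    """Compose two affine maps: elem_j after elem_i, per the rule above."""
    A_i, b_i = elem_i
    A_j, b_j = elem_j
    A = A_j @ A_i                                        # (..., D, D)
    b = jnp.einsum('...ij,...j->...i', A_j, b_i) + b_j   # (..., D)
    return A, b

def parallel_linear_rnn(A, b, x0):
    """All states of x_t = A_t x_{t-1} + b_t via one associative scan.

    A: (T, D, D), b: (T, D), x0: (D,). Returns x_{1:T}, shape (T, D).
    """
    # Cumulative compositions f_{t:1} = (A_{t:1}, b_{t:1}) in O(log T) depth.
    A_cum, b_cum = jax.lax.associative_scan(affine_combine, (A, b))
    # Apply each cumulative map to the initial state: x_t = A_{t:1} x_0 + b_{t:1}.
    return jnp.einsum('tij,j->ti', A_cum, x0) + b_cum
```

A quick sanity check is to compare the output against a sequential rollout (e.g., with `jax.lax.scan`); the two should agree up to floating-point error.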

Sequential Evaluation as Root-Finding

Let $\mbx_t \in \reals^D$ denote the hidden state at time $t$, and let $f_{t+1}$ denote the (possibly nonlinear, possibly input-dependent) transition function. Sequential evaluation of the recurrence,

$$\mbx_{t+1} = f_{t+1}(\mbx_t), \qquad t = 0, \ldots, T-1,$$

requires $O(T)$ serial steps because $\mbx_{t+1}$ directly depends on $\mbx_t$.

We can reframe the problem: the correct trajectory $\mbx_{1:T}^*$ is the unique solution to the system of equations,

$$\mbx_{t+1} - f_{t+1}(\mbx_t) = 0, \qquad \forall t \in \{0, \ldots, T-1\}.$$

This is a system of $T$ coupled nonlinear equations in $TD$ unknowns. Fixed-point iteration is a natural approach: start from an initial guess $\mbx_{1:T}^{(0)}$ (e.g., all zeros) and iteratively refine it by solving a simpler surrogate problem.
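
To ground the two views in code, the sketch below (continuing with the imports above) assumes a transition of the form `f(x, u)` with per-step inputs `u_t`; that signature is our convention, and any time-varying transition fits the same pattern:

```python
def sequential_rollout(f, us, x0):
    """Baseline: T serial steps of x_{t+1} = f(x_t, u_{t+1})."""
    def step(x, u):
        x_next = f(x, u)
        return x_next, x_next
    _, xs = jax.lax.scan(step, x0, us)
    return xs                                              # (T, D)

def residual(f, us, xs, x0):
    """r_t = x_t - f(x_{t-1}, u_t); zero exactly at the true trajectory."""
    x_in = jnp.concatenate([x0[None], xs[:-1]], axis=0)    # (x_0, ..., x_{T-1})
    return xs - jax.vmap(f)(x_in, us)
```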

A Unifying Framework

The key insight of Gonzalez et al. (2026) is that four prominent fixed-point methods — Newton, quasi-Newton, Picard, and Jacobi — all reduce to iterative evaluation of a linear dynamical system. Specifically, each iteration takes the common form,

$$\mbx_{t+1}^{(i+1)} = f_{t+1}\!\left(\mbx_t^{(i)}\right) + \widetilde{A}_{t+1}\!\left(\mbx_t^{(i+1)} - \mbx_t^{(i)}\right),$$

where $\widetilde{A}_{t+1} \in \reals^{D \times D}$ is an approximate Jacobian of the dynamics. This is a linear recursion in the unknown $\mbx_{t+1}^{(i+1)}$, driven by the bias $\mbb_{t+1} = f_{t+1}(\mbx_t^{(i)}) - \widetilde{A}_{t+1} \mbx_t^{(i)}$, which only depends on the previous iterate. Since it is an LDS, it can be evaluated via a parallel scan in $O(\log T)$ depth.

The four methods differ only in how they choose $\widetilde{A}_{t+1}$:

| Method | $\widetilde{A}_{t+1}$ | Cost per iteration | Parallelization |
|---|---|---|---|
| Newton | $\frac{\partial f_{t+1}}{\partial \mbx_t}(\mbx_t^{(i)})$ — full Jacobian | $O(TD^3)$ | Parallel scan (dense) |
| Quasi-Newton | $\mathrm{diag}\!\left[\frac{\partial f_{t+1}}{\partial \mbx_t}(\mbx_t^{(i)})\right]$ — diagonal | $O(TD)$ | Parallel scan (elementwise) |
| Picard | $I_D$ — identity | $O(TD)$ | Prefix sum |
| Jacobi | $0$ — zero | $O(TD)$ | Embarrassingly parallel |

All four methods are guaranteed to converge to the correct trajectory $\mbx_{1:T}^*$ in at most $T$ iterations (Gonzalez et al., 2026).
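
In code, the four methods then share one outer loop and differ only in the `step` they pass to it. Below is a minimal driver of our own, capped at $T$ iterations since convergence is guaranteed by then:

```python
def fixed_point_solve(step, xs_init, max_iters, tol=1e-8):
    """Iterate a linearized update until the trajectory stops changing.

    Plain Python loop (fine outside jit); max_iters = T always suffices.
    """
    xs = xs_init
    for _ in range(max_iters):
        xs_new = step(xs)
        if jnp.max(jnp.abs(xs_new - xs)) < tol:
            return xs_new
        xs = xs_new
    return xs
```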

Four Root-Finding Methods

Newton Iterations

Newton’s method for root-finding linearizes the residual $r(\mbx_{1:T}) = [\mbx_1 - f_1(\mbx_0), \ldots, \mbx_T - f_T(\mbx_{T-1})]$ using its full Jacobian. Applied to sequential evaluation, each Newton step is the first-order Taylor expansion of the recurrence around the current iterate:

$$\mbx_{t+1}^{(i+1)} = f_{t+1}\!\left(\mbx_t^{(i)}\right) + \frac{\partial f_{t+1}}{\partial \mbx_t}\!\left(\mbx_t^{(i)}\right)\!\left(\mbx_t^{(i+1)} - \mbx_t^{(i)}\right).$$

This is the DEER algorithm of Lim et al. (2024). Because the transition matrix $\widetilde{A}_{t+1}$ is the full $D \times D$ Jacobian, each iteration requires $O(TD^2)$ memory and $O(TD^3)$ work for the matrix–matrix multiplications in the parallel scan. For large state dimensions, this is prohibitive.

When $f_{t+1}$ is a linear function of $\mbx_t$, the Jacobian is exact and Newton converges in a single iteration — recovering the standard parallel scan for LDSs as a special case.
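
Here is a sketch in the spirit of one DEER iteration, under the `f(x, u)` convention from above and reusing `parallel_linear_rnn` from the scan section; `jax.jacfwd` computes all per-step Jacobians in parallel. The function name is ours:

```python
def newton_step(f, us, x0, xs_prev):
    """One Newton iteration: linearize around xs_prev, then solve the LDS."""
    # States each transition is linearized around: (x_0, x_1^(i), ..., x_{T-1}^(i)).
    x_in = jnp.concatenate([x0[None], xs_prev[:-1]], axis=0)
    # Full Jacobians A_{t+1} = df/dx at (x_t^(i), u_{t+1}), all t in parallel.
    A = jax.vmap(jax.jacfwd(f, argnums=0))(x_in, us)       # (T, D, D)
    fx = jax.vmap(f)(x_in, us)                             # (T, D)
    # Bias b_{t+1} = f(x_t^(i), u_{t+1}) - A_{t+1} x_t^(i).
    b = fx - jnp.einsum('tij,tj->ti', A, x_in)
    return parallel_linear_rnn(A, b, x0)                   # O(log T) depth
```

Passing `lambda xs: newton_step(f, us, x0, xs)` to the driver above completes the method.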

Quasi-Newton Iterations

To reduce the cost of Newton iterations, Gonzalez et al. (2024) replace the full Jacobian with its diagonal:

$$\mbx_{t+1}^{(i+1)} = f_{t+1}\!\left(\mbx_t^{(i)}\right) + \mathrm{diag}\!\left[\frac{\partial f_{t+1}}{\partial \mbx_t}\!\left(\mbx_t^{(i)}\right)\right]\!\left(\mbx_t^{(i+1)} - \mbx_t^{(i)}\right).$$

This is the ELK (Evaluating Levenberg–Marquardt via Kalman) algorithm. With a diagonal transition matrix, each parallel scan step is an elementwise vector multiplication, reducing the cost to $O(TD)$. The diagonal of the Jacobian can often be computed in closed form for common RNN architectures (e.g., GRUs), or estimated stochastically using the Hutchinson estimator in a single function call.
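
Below is a sketch of the diagonal variant, again under our `f(x, u)` convention. The diagonal is estimated with a single-sample Hutchinson probe via `jax.jvp`; a closed-form diagonal, where available, would slot in the same way:

```python
def quasi_newton_step(f, us, x0, xs_prev, key):
    """One diagonal (ELK-style) iteration with a stochastic Jacobian diagonal."""
    x_in = jnp.concatenate([x0[None], xs_prev[:-1]], axis=0)
    fx = jax.vmap(f)(x_in, us)

    def jac_diag(x, u, k):
        # Single-sample Hutchinson estimate: diag(J) ~ v * (J v), v Rademacher.
        v = jax.random.rademacher(k, x.shape, dtype=x.dtype)
        _, jv = jax.jvp(lambda xx: f(xx, u), (x,), (v,))
        return v * jv

    keys = jax.random.split(key, x_in.shape[0])
    a = jax.vmap(jac_diag)(x_in, us, keys)                 # (T, D) diagonal
    b = fx - a * x_in

    def diag_combine(e_i, e_j):
        # Same composition rule as the dense scan, now elementwise.
        a_i, b_i = e_i
        a_j, b_j = e_j
        return a_j * a_i, a_j * b_i + b_j

    a_cum, b_cum = jax.lax.associative_scan(diag_combine, (a, b))
    return a_cum * x0 + b_cum
```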

Picard Iterations

Picard iterations set $\widetilde{A}_{t+1} = I_D$, approximating the Jacobian by the identity matrix. The update simplifies to a prefix sum:

$$\mbx_{t+1}^{(i+1)} = \mbx_0 + \sum_{s=0}^{t} f_{s+1}\!\left(\mbx_s^{(i)}\right) \Delta,$$

where $\Delta$ is the discretization step (for ODE-based dynamics). Picard iterations require only vector additions, making them the cheapest per-iteration method. Shih et al. (2023) used them to parallelize sampling in diffusion models.

The identity approximation is faithful when the true Jacobian $\partial f_{t+1}/\partial \mbx_t \approx I_D$, which holds for dynamics with small step sizes (e.g., discretized ODEs with fine time steps).
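
A sketch under the ODE-style assumption $\mbx_{t+1} = \mbx_t + f(\mbx_t, u_{t+1})\,\Delta$ (our reading of the formula above), where the prefix sum is just `jnp.cumsum`:

```python
def picard_step(f, us, x0, xs_prev, dt):
    """One Picard iteration for x_{t+1} = x_t + f(x_t, u_{t+1}) * dt."""
    x_in = jnp.concatenate([x0[None], xs_prev[:-1]], axis=0)
    drift = jax.vmap(f)(x_in, us)                 # vector field along the guess
    return x0 + jnp.cumsum(drift * dt, axis=0)    # prefix sum over time
```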

Jacobi Iterations

Jacobi iterations set $\widetilde{A}_{t+1} = 0$, giving the simplest possible update:

$$\mbx_{t+1}^{(i+1)} = f_{t+1}\!\left(\mbx_t^{(i)}\right).$$

Each element of the new trajectory is computed independently — no dependencies between time steps at all. This is embarrassingly parallel: all $T$ states can be updated simultaneously in a single kernel call. Song et al. (2021) used Jacobi iterations to accelerate feedforward computation in deep networks.

The zero-Jacobian approximation is faithful when consecutive states evolve nearly independently, i.e., when the true Jacobian $\partial f_{t+1}/\partial \mbx_t \approx 0$.
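
The corresponding sketch is a single batched function call:

```python
def jacobi_step(f, us, x0, xs_prev):
    """One Jacobi iteration: every transition applied to the previous iterate."""
    x_in = jnp.concatenate([x0[None], xs_prev[:-1]], axis=0)
    return jax.vmap(f)(x_in, us)    # all T updates in one parallel call
```

Note that each Jacobi sweep makes at least one more prefix of the trajectory exact — $\mbx_1$ after one sweep, $\mbx_2$ after two, and so on — which is one way to see the at-most-$T$-iterations guarantee.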

Convergence Analysis

Gonzalez et al. (2026) derive a unified convergence bound for all four methods. Let $\mbe^{(i)} = \mbx_{1:T}^{(i)} - \mbx_{1:T}^*$ denote the error at iteration $i$, and let $\widetilde{J}$ and $J$ denote the block-bidiagonal approximate and true Jacobians of the residual. Then:

$$\|\mbe^{(i+1)}\|_2 \leq \left\|\widetilde{J}^{-1}\right\|_2 \cdot \left\|\widetilde{J} - J\right\|_2 \cdot \|\mbe^{(i)}\|_2 + O\!\left(\|\mbe^{(i)}\|_2^2\right).$$

As the error approaches zero, the asymptotic linear convergence rate is:

$$\gamma = \left\|\widetilde{J}^{-1}\left(\widetilde{J} - J\right)\right\|_2.$$

Convergence is fast ($\gamma \ll 1$) when two conditions hold simultaneously:

  1. Small approximation error: $\|\widetilde{J} - J\|_2$ is small, i.e., $\widetilde{A}_{t+1}$ is a faithful approximation of the true Jacobian $\partial f_{t+1}/\partial \mbx_t$.

  2. Stable LDS: $\|\widetilde{J}^{-1}\|_2$ is small, i.e., the linearized system with transition matrices $\widetilde{A}_{t+1}$ is stable (spectral norms well below one).

The two conditions trade off against each other: a more faithful $\widetilde{A}_{t+1}$ (as in Newton) shrinks the approximation error but can produce an unstable linearized system when the underlying dynamics are unstable, while cruder choices (Picard, Jacobi) keep the linearized system well conditioned at the cost of a larger approximation error.
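
On small problems, $\gamma$ can be computed directly by materializing the block-bidiagonal Jacobians. The sketch below is our own construction, with `A_true[t]` and `A_approx[t]` holding $\partial f_{t+1}/\partial \mbx_t$ and $\widetilde{A}_{t+1}$ at the converged trajectory:

```python
def convergence_rate(A_true, A_approx):
    """gamma = ||J~^{-1} (J~ - J)||_2 for block-bidiagonal residual Jacobians.

    A_true, A_approx: (T, D, D). Index 0 multiplies the fixed x_0, so it
    never appears in the residual Jacobian over the unknowns x_{1:T}.
    """
    T, D, _ = A_true.shape

    def block_bidiag(As):
        # Identity diagonal blocks; -A blocks on the subdiagonal.
        J = jnp.eye(T * D)
        for t in range(1, T):
            J = J.at[t * D:(t + 1) * D, (t - 1) * D:t * D].set(-As[t])
        return J

    J, J_tilde = block_bidiag(A_true), block_bidiag(A_approx)
    return jnp.linalg.norm(jnp.linalg.solve(J_tilde, J_tilde - J), ord=2)
```

Setting `A_approx = A_true` gives $\gamma = 0$ (exact linearization), while `A_approx = jnp.zeros_like(A_true)` gives the Jacobi rate.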

Conclusion

The parallel scan turns any associative, closed operation into an $O(\log T)$-depth computation. For nonlinear RNNs, iterative linearization reduces each fixed-point iteration to an LDS — making the parallel scan applicable even when the original dynamics are nonlinear. All four methods are guaranteed to converge in at most $T$ iterations; the rate depends on how faithfully $\widetilde{A}_{t+1}$ approximates the true Jacobian and on the stability of the induced LDS.

| Method | $\widetilde{A}_{t+1}$ | Cost/iter | Converges fast when |
|---|---|---|---|
| Newton | Full Jacobian | $O(TD^3)$ | Jacobian is dense |
| Quasi-Newton | Diag. Jacobian | $O(TD)$ | Diagonal of Jacobian dominates |
| Picard | $I_D$ | $O(TD)$ | Dynamics $\approx$ identity shift |
| Jacobi | $0$ | $O(TD)$ | States evolve nearly independently |
References
  1. Blelloch, G. E. (1990). Prefix sums and their applications (Technical Report CMU-CS-90-190). School of Computer Science, Carnegie Mellon University.
  2. Gonzalez, X., Buchanan, E. K., Lee, H. D., Liu, J. W., Wang, K. A., Zoltowski, D. M., Kozachkov, L., Ré, C., & Linderman, S. W. (2026). A Unifying Framework for Parallelizing Sequential Models with Linear Dynamical Systems. Transactions on Machine Learning Research. https://openreview.net/forum?id=fw6GgAIGur
  3. Lim, Y. H., Zhu, Q., Selfridge, J., & Kasim, M. F. (2024). Parallelizing Non-Linear Sequential Models Over the Sequence Length. International Conference on Learning Representations (ICLR).
  4. Gonzalez, X., Buchanan, E. K., Zoltowski, D. M., & Linderman, S. W. (2024). ELK: Evaluating Levenberg–Marquardt via Kalman. arXiv Preprint.
  5. Shih, A., Belkhale, S., Ermon, S., Sadigh, D., & Anari, N. (2023). Parallel Sampling of Diffusion Models. Advances in Neural Information Processing Systems (NeurIPS).
  6. Song, Y., Meng, C., Liao, R., & Ermon, S. (2021). Accelerating Feedforward Computation via Parallel Nonlinear Equation Solving. International Conference on Machine Learning (ICML).