Recurrent Neural Networks#
Let’s turn our attention back to models for sequential data for the next few lectures. With all the excitement around large language modeling, this is a timely topic.
Autoregressive Models#
Consider a sequence of observations \(\mbx_{1:T} = (\mbx_1, \ldots, \mbx_T)\) with each \(\mbx_t \in \reals^D\). We can always factor a joint distribution over observations into a product of conditionals using the chain rule,
This is called an autoregressive model since the conditional of \(\mbx_t\) depends only on previous observations \(\mbx_{1}, \ldots, \mbx_{t-1}\).
Autoregressive models are well-suited to sequential modeling since they make forward generation or forecasting easy. As long as we have access to the conditionals, we can sample forward indefinitely.
The question is, how should we parameterize these conditional distributions? It looks challenging since each one takes in a variable-length history of previous observations.
Recurrent Neural Networks#
Recurrent Neural Networks (RNNs) are autoregressive models in which the conditional distributions are functions of a finite-dimensional hidden state, \(\mbh_t \in \reals^K\),
The hidden state is updated with each new observation as,
Defining Property of an RNN
The conditional distribution over the next observation is a function of the hidden state. The hidden state is updated recursively, and its size is fixed regardless of the sequence length.
Vanilla RNNs#
The standard, “vanilla” RNN consists of a linear-nonlinear state update. For example,
where
\(\mbW \in \reals^{K \times K}\) are the dynamics weights,
\(\mbB \in \reals^{K \times D}\) are the input weights,
\(\tanh(\cdot)\) is the hyperbolic tangent function.
Hyperbolic Tangent and the Logistic Function
The hyperbolic tangent function is typically written as,
We can rewrite it as,
Thus, we see that the hyperbolic tangent is simply a scaled and shifted logistic function.
The “read-out” of a vanilla RNN is typically a simple linear or generalized linear model, depending on the type of observations. For example,
where
\(\mbC \in \reals^{D \times K}\) are the read-out weights and
\(\mbd \in \reals^D\) is the bias.
Let \(\mbtheta = (\mbW, \mbB, \mbC, \mbd)\) denote the set of parameters.
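To make the recursion concrete, here is a minimal JAX sketch of the forward pass, assuming the update \(\mbh_t = \tanh(\mbW \mbh_{t-1} + \mbB \mbx_t)\) and the linear read-out above. The function and variable names (`rnn_forward`, `params`) are ours, and this is a sketch rather than a reference implementation.

```python
import jax
import jax.numpy as jnp

def rnn_forward(params, xs, h0):
    """Run a vanilla RNN over inputs xs (shape (T, D)) from initial state h0 (shape (K,)).
    params = (W, B, C, d) with shapes (K, K), (K, D), (D, K), and (D,)."""
    W, B, C, d = params

    def step(h_prev, x):
        h = jnp.tanh(W @ h_prev + B @ x)   # linear-nonlinear state update
        y = C @ h + d                      # linear read-out
        return h, (h, y)

    _, (hs, ys) = jax.lax.scan(step, h0, xs)
    return hs, ys

# Example with random parameters, just to show the shapes.
K, D, T = 5, 3, 20
keys = jax.random.split(jax.random.PRNGKey(0), 4)
params = (0.9 * jax.random.normal(keys[0], (K, K)) / jnp.sqrt(K),   # W
          jax.random.normal(keys[1], (K, D)),                       # B
          jax.random.normal(keys[2], (D, K)),                       # C
          jnp.zeros(D))                                             # d
xs = jax.random.normal(keys[3], (T, D))
hs, ys = rnn_forward(params, xs, jnp.zeros(K))
```

The `jax.lax.scan` primitive carries the hidden state forward one step at a time, mirroring the recursive definition above.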
Theoretical Neuroscience#
From a machine learning perspective, RNNs are useful function approximators for sequential data. From a neuroscience perspective, they have a long history as theoretical models of neural computation.
In such models, the hidden state \(\mbh_t \in \reals^K\) corresponds to the relative firing rates of \(K\) neurons. With a hyperbolic tangent nonlinearity, \(\mbh_t \in (-1, 1)^K\), but negative firing rates don’t make sense. Instead, we think of \(\mbh_t\) as an additive offset to a baseline firing rate.
The dynamics weights correspond to synaptic connections, with positive weights as excitatory synapses and negative weights as inhibitory. When a presynaptic neuron spikes, it induces an electrical current in postsynaptic neurons. Under this interpretation, the activation \(\mbW \mbh_t\) corresponds to the input current generated by inputs from other neurons.
As a neuron receives input current, its voltage steadily increases until it reaches a threshold, at which point the voltage spikes and the neuron fires an action potential. These spikes induce currents on downstream neurons, as described above. After a cell fires, there is a short refractory period before the neuron can spike again. Thus, there is an upper bound on firing rates, which the hyperbolic tangent is meant to capture.
Backpropagation Through Time#
Artificial RNNs are trained using stochastic gradient descent (SGD) to minimize the negative log likelihood,
In practice, you would simply use automatic differentiation to compute the gradients needed to minimize this loss, but we can gain some insight by working them out manually. With some algebra, we can show that the Jacobian of the loss with respect to the dynamics weights (the other parameters are similar) is,
We need the Jacobians of the loss with respect to the hidden states. These can be computed recursively by backpropagation through time (BPTT),
For a vanilla RNN, the Jacobian of the next state with respect to the current state is,
since \(\frac{\dif}{\dif a}\tanh(a) = 1 - \tanh(a)^2\).
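As a quick sanity check, the analytic Jacobian, \(\mathrm{diag}(1 - \mbh_{t+1}^2)\, \mbW\), can be compared against automatic differentiation. This is a sketch with arbitrary random values, not part of the derivation:

```python
import jax
import jax.numpy as jnp

K, D = 5, 3
keys = jax.random.split(jax.random.PRNGKey(1), 4)
W = jax.random.normal(keys[0], (K, K)) / jnp.sqrt(K)
B = jax.random.normal(keys[1], (K, D))
h = jax.random.normal(keys[2], (K,))
x = jax.random.normal(keys[3], (D,))

# One step of the vanilla RNN dynamics as a function of the current state.
f = lambda h: jnp.tanh(W @ h + B @ x)

# Analytic Jacobian: diag(1 - h_next**2) @ W, using tanh'(a) = 1 - tanh(a)**2.
h_next = f(h)
J_analytic = (1.0 - h_next**2)[:, None] * W

# Compare against automatic differentiation.
J_auto = jax.jacrev(f)(h)
assert jnp.allclose(J_analytic, J_auto, atol=1e-5)
```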
BPTT is a linear dynamical system
The “state” of the BPTT recursions is the Jacobian, or equivalently its transpose, the gradient \(\mbs_t \triangleq \left(\frac{\partial \cL(\mbtheta)}{\partial \mbh_t}\right)^\top\). This state obeys a linear dynamical system,
where \(\mbA_t = \left(\frac{\partial \mbh_{t+1}}{\partial \mbh_t} \right)^\top\) and \(\mbb_t = - \left(\frac{\partial \log p(\mbx_t \mid g(\mbh_t; \mbtheta))}{\partial \mbh_t} \right)^\top\).
Biological Plausibility
Backpropagation is the most effective way we know of to train artificial RNNs, so it’s reasonable to think that the brain might be using a similar learning algorithm. Unfortunately, it’s not clear how the backpropagation through time algorithm could be implemented by a neural circuit. The multiplication by \(\mbA_t\) in the gradient recursions amounts to passing information backward across synapses, and canonical synaptic models don’t have mechanisms to do this. Recent years have seen a substantial amount of research into biologically plausible mechanisms of backpropagation.
Vanishing Gradients#
When the Jacobians \(\mbA_t\) have eigenvalues with magnitude much less than one, we run into the problem of vanishing gradients. Consider the case of a linear RNN in which the \(\tanh\) is replaced with the identity: then the Jacobians are \(\mbA_t = \mbW^\top\) for all time steps. If the eigenvalues of \(\mbW\) are much smaller than one in magnitude, the gradients will decay to zero exponentially quickly, absent strong inputs \(\mbb_t\).
Vanishing gradients are especially problematic when \(\mbx_t\) depends on much earlier observations, \(\mbx_s\) for \(s \ll t\). In that case, the hidden state must propagate information about \(\mbx_s\) for many time steps during the forward pass, and likewise, the gradient must pass information backward many timesteps during the backward pass. If the weights have small eigenvalues, those gradients will decay and the learning signal will fail to propagate.
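Here is a small numerical illustration of this effect; the spectral norm of 0.9 and the sequence length are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

# Rescale random weights so the spectral norm of W is 0.9, which also bounds
# the eigenvalue magnitudes by 0.9.
K, T = 10, 200
W = jax.random.normal(jax.random.PRNGKey(0), (K, K))
W = 0.9 * W / jnp.linalg.norm(W, ord=2)

# In the linear case, backpropagation multiplies by A_t = W^T at every step,
# so a gradient injected at time T shrinks at least as fast as 0.9**t after
# flowing back t steps; after 200 steps it is smaller by a factor of ~1e-9.
s = jnp.ones(K)
norms = [float(jnp.linalg.norm(s))]
for _ in range(T):
    s = W.T @ s
    norms.append(float(jnp.linalg.norm(s)))
```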
Gated RNNs#
One way to combat the vanishing gradient problem is by modifying the RNN architecture. Architectures like long short-term memory (LSTM) networks achieve this via gated units.
An LSTM has internal (aka “cell”) states \(\mbc_t \in \reals^K\) and hidden states \(\mbh_t \in \reals^K\). The internal states follow conditionally linear dynamics,
where
The bounded entries \(f_{t,k} \in [0,1]\) ensure stability. When \(f_{t,k} \approx 1\), the state is propagated, and when \(f_{t,k} \approx 0\), the state is forgotten. Thus, \(\mbf_t = (f_{t,1}, \ldots, f_{t,K})^\top \in [0,1]^K\) are called the forget gates, and they are parameterized by the matrices \(\mbW^{(f)}\) and \(\mbB^{(f)}\).
The affine term is determined by,
The vector \(\mbg_t \in [0,1]^K\) plays the role of an input gate, and the inputs themselves are given by \(\mbi_t \in [0,1]^K\).
Finally, the hidden states \(\mbh_t\) are gated functions of the internal state passed through a nonlinearity,
where \(\mbo_t \in [0,1]^K\) are the output gates. As in a vanilla RNN, the final prediction depends on a (generalized) linear function of the hidden state, \(g(\mbh_t; \mbtheta)\).
We can think of an LSTM as an RNN that operates on an extended state \((\mbc_t, \mbh_t) \in \reals_+^K \times [-1, 1]^K\). The forget gates allow the eigenvalues of \(\mbF_t\) to be close to one, so cell states can be propagated for long periods of time on the forward pass, and gradients can be backpropagated without vanishing on the backward pass.
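For concreteness, here is a sketch of a single LSTM step in JAX, following the conventions above in which the forget gates \(\mbf_t\), input gates \(\mbg_t\), inputs \(\mbi_t\), and output gates \(\mbo_t\) are each logistic functions of the previous hidden state and the current input. The parameter names and the omission of bias terms are our simplifications, so the exact parameterization may differ from other presentations.

```python
import jax
import jax.numpy as jnp

def lstm_step(params, c_prev, h_prev, x):
    """One step of an LSTM with cell state c and hidden state h."""
    (Wf, Bf), (Wg, Bg), (Wi, Bi), (Wo, Bo) = params
    f = jax.nn.sigmoid(Wf @ h_prev + Bf @ x)   # forget gate, in [0, 1]^K
    g = jax.nn.sigmoid(Wg @ h_prev + Bg @ x)   # input gate, in [0, 1]^K
    i = jax.nn.sigmoid(Wi @ h_prev + Bi @ x)   # inputs, in [0, 1]^K
    o = jax.nn.sigmoid(Wo @ h_prev + Bo @ x)   # output gate, in [0, 1]^K

    c = f * c_prev + g * i                     # conditionally linear cell update
    h = o * jnp.tanh(c)                        # gated, squashed hidden state
    return c, h
```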
There are many variants of gated RNNs. Besides the LSTM, the most commonly used is the gated recurrent unit (GRU), which has a slightly simplified architecture. See Goodfellow et al. [GBC16] (ch. 10) for more detail.
Other Variations and Uses of RNNs#
We motivated RNNs from an autoregressive modeling perspective, but they are useful in other sequential data settings as well. For example, suppose we want to predict the sentiment of a review \(y \in \reals\) given a variable-length sequence of input words \(\mbx_{1:T}\). We can use an RNN to summarize the input sequence in terms of a hidden state for prediction,
Sequence to Sequence Models#
Sometimes we want to map one sequence \(\mbx_{1:T}\) to another sequence \(\mby_{1:T'}\). The sequences may be of different length; e.g., when we want to translate a sentence from English to French. Again, we can train an encoder RNN to produce a hidden state \(\mbh_T\) that then becomes the initial condition for a decoder RNN that generates the output sequence. Formally,
where \(\mbh_T\) is the output of an RNN that processed \(\mbx_{1:T}\), and \(\mbh'_t\) is the state of a decoder RNN that runs over the output sequence \(\mby_{1:T'}\).
Bidirectional RNNs#
In the example above, one challenge is that the hidden state \(\mbh_T\) obtained by processing \(\mbx_{1:T}\) may not adequately represent early inputs like \(\mbx_1\). For these purposes, you can use a bidirectional RNN that runs one recursion forward over \(\mbx_1, \ldots, \mbx_T\) and another backward over \(\mbx_T, \ldots, \mbx_1\) to produce two hidden states at each time \(t\). These combined states can then be passed into the decoder.
Deep RNNs#
As with deep neural networks that stack layer upon layer, we can stack RNN upon RNN to construct a deeper model. In such models, the outputs of one layer, \(g(\mbh_t^{(i)}; \mbtheta^{(i)})\), become the inputs to the next layer, \(\mbx_t^{(i+1)}\), as sketched below. Then we can backpropagate gradients through the entire stack to train the model.
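As a rough sketch of this stacking, reusing the `rnn_forward` function from the vanilla RNN example above (the `deep_rnn_forward` name and the per-layer parameter list are ours):

```python
import jax.numpy as jnp

def deep_rnn_forward(layer_params, xs):
    """Stack vanilla RNN layers: the read-outs of layer i become the inputs to layer i+1.
    layer_params is a list of (W, B, C, d) tuples, one per layer; each layer's
    read-out dimension must match the next layer's input dimension."""
    inputs = xs
    for params in layer_params:
        K = params[0].shape[0]
        _, inputs = rnn_forward(params, inputs, jnp.zeros(K))  # read-outs feed the next layer
    return inputs  # outputs of the final layer
```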
HMMs Are RNNs Too!#
It turns out, we’ve already seen an RNN in this class! We presented Hidden Markov Models (HMMs) as latent variable models with hidden states \(\mbz_{1:T}\) and conditionally independent observations \(\mbx_{1:T}\), but we can also view them as an autoregressive model in which
where \(\overline{\mbalpha}_t \in \Delta_{K-1}\) are the normalized forward messages from the forward-backward algorithm. They follow a simple recursion,
where \(\mbl_t \in \reals^K\) is the vector of likelihoods with entries \(l_{t,k} = p(\mbx_t \mid z_t=k)\), and \(\mbP \in \reals^{K \times K}\) is the transition matrix.
Categorical HMMs#
Consider an HMM with categorical emissions,
where \(\mbx_t\) is a one-hot encoding of a variable that takes values in \(\{1,\ldots,V\}\), and \(\mbc_k \in \Delta_{V-1}\) for \(k=1,\ldots,K\) are pmfs. Define the matrix of likelihoods \(\mbC \in \reals^{V \times K}\) to have columns \(\mbc_k\),
The HMM parameters are \(\mbtheta = (\mbP, \mbC)\). (Assume the initial distribution is fixed, for simplicity.)
For a categorical HMM, we can write the likelihoods as \(\mbl_t = \mbC^\top \mbx_t\) so that the forward recursions simplify to,
Likewise, the autoregressive distributions reduce to,
Framed this way, a categorical HMM can be seen as a simple recurrent neural network!
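Here is a minimal sketch of that view in JAX: the normalized forward message (more precisely, the one-step predictive distribution over \(z_t\)) plays the role of the hidden state, and the per-step normalizers give the autoregressive conditionals. The function name `hmm_filter` and the exact input conventions are ours.

```python
import jax
import jax.numpy as jnp

def hmm_filter(P, C, pi, xs):
    """Run the categorical HMM as an RNN on the normalized forward messages.

    P:  (K, K) transition matrix (rows are distributions)
    C:  (V, K) emission matrix (columns are the pmfs c_k)
    pi: (K,)   fixed initial distribution
    xs: (T, V) one-hot observations
    Returns the predictive log-likelihoods log p(x_t | x_{1:t-1}) and the messages.
    """
    def step(pred, x):
        # pred = p(z_t | x_{1:t-1}) plays the role of the hidden state.
        lik = C.T @ x                       # l_t = C^T x_t, entries p(x_t | z_t = k)
        norm = jnp.dot(lik, pred)           # p(x_t | x_{1:t-1})
        alpha = lik * pred / norm           # normalized forward message
        return P.T @ alpha, (jnp.log(norm), alpha)

    _, (log_norms, alphas) = jax.lax.scan(step, pi, xs)
    return log_norms, alphas

# log p(x_{1:T}) is the sum of the predictive log-likelihoods:
# log_like = hmm_filter(P, C, pi, xs)[0].sum()
```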
A Cool Trick for Computing the Gradients#
This formulation suggests that we could estimate the parameters of an HMM by directly maximizing the log likelihood,
With automatic differentiation at our disposal, this sounds like it might be a lot easier!
Let’s pursue this idea a little further. First, we’d prefer to do unconstrained optimization, so let’s parameterize the model in terms of \(\log \mbP\) and \(\log \mbC\). When we need the constrained versions, we will just apply the softmax to obtain simplex vectors:
Softmax is translation invariant
Note that the softmax operation is translation invariant,
for any constant \(a \in \reals\). Thus, we will call our optimization variables \(\log \mbC\) and \(\log \mbP\), but they are not actually the log of matrices with simplex columns or rows; they are unconstrained parameters that become properly normalized via the softmax.
To maximize the likelihood with gradient ascent, we need the Jacobians, \(\frac{\partial \cL(\mbtheta)}{\partial \log \mbP}\) and \(\frac{\partial \cL(\mbtheta)}{\partial \log \mbC}\). For the RNNs above, we computed them using backpropagation through time, but here we can use an even cooler trick.
First, note that the posterior distribution in an HMM can be written as an exponential family,
where the log marginal likelihood \(A(\mbtheta) = \log p(\mbx_{1:T}; \mbtheta)\) is the log normalizer.
Recall that for exponential family distributions, gradients of the log normalizer yield expected sufficient statistics. Thus,
and
The gradients are essentially the posterior marginals and pairwise marginals we computed in the EM algorithm!
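A sketch of this idea, building on the `hmm_filter` function above (the random logits and data are purely illustrative): we apply the softmax inside the log-likelihood and let `jax.grad` handle the backward pass. Up to the centering term introduced by the softmax reparameterization, the resulting gradients collect the expected transition and emission counts under the posterior, i.e., the same marginals the E-step of EM computes.

```python
import jax
import jax.numpy as jnp

def log_likelihood(log_P, log_C, pi, xs):
    # Map unconstrained logits to simplex-constrained parameters, then filter.
    P = jax.nn.softmax(log_P, axis=1)   # rows of P are transition distributions
    C = jax.nn.softmax(log_C, axis=0)   # columns of C are emission pmfs
    return hmm_filter(P, C, pi, xs)[0].sum()

# Illustrative random logits and data.
K, V, T = 3, 5, 100
keys = jax.random.split(jax.random.PRNGKey(0), 3)
log_P = jax.random.normal(keys[0], (K, K))
log_C = jax.random.normal(keys[1], (V, K))
pi = jnp.ones(K) / K
xs = jax.nn.one_hot(jax.random.randint(keys[2], (T,), 0, V), V)

# Gradients with respect to the logits; up to the softmax's centering term,
# these are the expected transition and emission counts under the posterior.
grad_log_P, grad_log_C = jax.grad(log_likelihood, argnums=(0, 1))(log_P, log_C, pi, xs)
```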
In the M-step of EM, we solve for the parameters that satisfy the constrained optimality conditions, whereas in SGD we just take a step in the direction of the gradient. EM tends to converge much faster in practice for this reason.
Linear RNNs#
Consider the special case of linear RNNs,
with read-out \(\mby_t = \mbC \mbh_t + \mbd\).
Since the composition of linear maps is itself linear, we can write the entire input-output map as a linear function,
with the convention that \(\mbh_0 = 0\).
We recognize this as a convolution,
with kernel,
Computationally, representing the RNN as a convolution leads to efficient implementations, since modern deep learning libraries like PyTorch and JAX have highly optimized convolution routines. Likewise, sampling the model is straightforward and efficient using the RNN formulation.
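The following sketch checks this equivalence numerically for a small example, materializing the kernel naively with explicit matrix powers. The function names are ours, and a practical implementation would use an FFT-based or structured routine rather than the nested sums below.

```python
import jax
import jax.numpy as jnp

def linear_rnn_recurrent(A, B, C, d, xs):
    """Evaluate y_t = C h_t + d with h_t = A h_{t-1} + B x_t and h_0 = 0."""
    def step(h, x):
        h = A @ h + B @ x
        return h, C @ h + d
    _, ys = jax.lax.scan(step, jnp.zeros(A.shape[0]), xs)
    return ys

def linear_rnn_convolution(A, B, C, d, xs):
    """Same map, written as a convolution with kernel k_tau = C A^tau B."""
    T = xs.shape[0]
    # Stack the kernel matrices k_0, ..., k_{T-1}. (Materializing all the matrix
    # powers is the naive approach; structured parameterizations avoid this.)
    powers = [jnp.eye(A.shape[0])]
    for _ in range(T - 1):
        powers.append(A @ powers[-1])
    kernel = jnp.stack([C @ Ap @ B for Ap in powers])
    # y_t = d + sum_{s <= t} k_{t-s} x_s
    return jnp.stack([
        d + sum(kernel[t - s] @ xs[s] for s in range(t + 1)) for t in range(T)
    ])

# The two implementations agree up to floating point error.
K, D, T = 4, 2, 10
keys = jax.random.split(jax.random.PRNGKey(4), 4)
A = 0.8 * jax.random.normal(keys[0], (K, K)) / jnp.sqrt(K)
B = jax.random.normal(keys[1], (K, D))
C = jax.random.normal(keys[2], (D, K))
d = jax.random.normal(keys[3], (D,))
xs = jax.random.normal(jax.random.PRNGKey(5), (T, D))
assert jnp.allclose(linear_rnn_recurrent(A, B, C, d, xs),
                    linear_rnn_convolution(A, B, C, d, xs), atol=1e-4)
```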
While linear RNNs may seem too simple to capture complex sequential dependencies, it turns out that stacks of linear layers with simple nonlinearities in between (a deep linear state space model [GGR21]) can be highly expressive!
Note
Note that the kernel involves matrix powers \(\mbA^t\), which generally cost cubic time in the hidden state dimension to compute. Gu et al. [GGR21] derived clever algorithms for efficiently computing the kernel and evaluating the convolution for certain structured classes of dynamics matrices.
Input-Dependent Dynamics with Parallel Scan#
One limitation of the convolutional formulation above is that it requires the dynamics matrix \(\mbA\) to be the same at all time-steps. Smith et al. [SWL23] showed an alternative way to evaluate the linear RNN that relaxes this constraint.
Consider two consecutive state updates, now with time-dependent dynamics matrices \(\mbA_t\) and affine terms \(\mbb_t = \mbB_t \mbx_t\),
where \(\mbA_{t-2:t} \triangleq \mbA_t \mbA_{t-1}\) and \(\mbb_{t-2:t} \triangleq \mbA_t \mbb_{t-1} + \mbb_t\).
This is just another affine map, but now it takes \(\mbh_{t-2}\) to \(\mbh_t\). In parallel, we can compute the maps from time \(t\) to \(t+2\), from \(t+2\) to \(t+4\), and so on.
In the next iteration, we can combine these linear maps to obtain \((\mbA_{t-4:t}, \mbb_{t-4:t})\), and so on.
After \(\log_2 T\) iterations, we obtain a map from \(\mbh_0\) to \(\mbh_T\). With \(\cO(\log T)\) more work, we can obtain maps from \(\mbh_0\) to all intermediate times \(t\) as well.
This algorithm is called a parallel scan or binary associative scan, since it is based on a binary associative operator, \(\circ\), of the form,
With \(T\) parallel processors, it requires only \(\cO(\log T)\) time and \(\cO(T)\) memory.
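Here is a sketch of this parallel evaluation using `jax.lax.associative_scan`, which implements exactly this kind of binary associative scan. The `affine_op` and `parallel_linear_rnn` names are ours, and we assume \(\mbh_0 = 0\) so that the accumulated offsets are the hidden states themselves.

```python
import jax
import jax.numpy as jnp

def affine_op(elem_i, elem_j):
    """Compose affine maps: elem_i = (A_i, b_i) acts first, elem_j acts second.
    The leading axis is a batch/time axis, so use batched matrix products."""
    A_i, b_i = elem_i
    A_j, b_j = elem_j
    return A_j @ A_i, jnp.einsum('...ij,...j->...i', A_j, b_i) + b_j

def parallel_linear_rnn(As, bs):
    """Evaluate h_t = A_t h_{t-1} + b_t for t = 1..T with h_0 = 0.
    As: (T, K, K), bs: (T, K). Returns all states hs: (T, K).
    With h_0 = 0, the cumulative affine map applied to h_0 is just its offset,
    so the scanned offsets are the hidden states."""
    _, hs = jax.lax.associative_scan(affine_op, (As, bs))
    return hs

# Sanity check against the sequential recursion.
K, T = 4, 16
keys = jax.random.split(jax.random.PRNGKey(6), 2)
As = 0.9 * jax.random.normal(keys[0], (T, K, K)) / jnp.sqrt(K)
bs = jax.random.normal(keys[1], (T, K))
_, hs_seq = jax.lax.scan(lambda h, Ab: (Ab[0] @ h + Ab[1],) * 2, jnp.zeros(K), (As, bs))
assert jnp.allclose(parallel_linear_rnn(As, bs), hs_seq, atol=1e-4)
```

On hardware with many parallel processors, `jax.lax.associative_scan` evaluates this in \(\cO(\log T)\) depth, in contrast to the \(T\) sequential steps of `jax.lax.scan`.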
Note
Again, notice that each update requires matrix multiplication, which is typically cubic in the state dimension. Smith et al. [SWL23] proposed to work with complex diagonal matrices instead, which keeps the time and memory costs in check.
Parallelizing Nonlinear RNNs#
The parallel scan relied on the linearity of the RNN dynamics. Can the same trick be applied to nonlinear RNNs?
It turns out that, in many cases, we can indeed speed up the evaluation of nonlinear RNNs using a similar trick!
Consider an RNN with nonlinear dynamics function \(f(\mbh_t)\). Form a first-order Taylor approximation around an initial guess, \(\mbh_t^{(0)}\),
where \(J_f(\mbh_t^{(0)}) = \frac{\partial f}{\partial \mbh_t}\bigg|_{\mbh_t = \mbh_t^{(0)}}\) is the Jacobian of \(f\) evaluated at \(\mbh_t^{(0)}\).
The Taylor approximations yield a linear RNN with time-varying dynamics, which can be evaluated with a parallel scan to obtain a new sequence of latent states, \((\mbh_1^{(1)}, \ldots, \mbh_T^{(1)})\), which can be used as the guess for the next iteration.
Repeating this process until convergence is equivalent to the Gauss-Newton method for minimizing the sum of squares loss function,
This idea was proposed by Lim et al. (2023) under the name DEER. Gonzalez et al. (2024) explained the connection to the Gauss-Newton method and proposed an extension called ELK.
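Here is a rough sketch of one such iteration, reusing the `parallel_linear_rnn` helper from the previous sketch. The `deer_iteration` name, the treatment of the initial state, and the dynamics signature \(f(\mbh_{t-1}, \mbx_t)\) are our assumptions rather than the exact formulation in those papers.

```python
import jax
import jax.numpy as jnp

def deer_iteration(f, h_guess, h0, xs):
    """One linearize-and-solve update for the fixed point h_t = f(h_{t-1}, x_t).

    f:       dynamics function, f(h, x) -> next state
    h_guess: (T, K) current guess for h_{1:T}
    h0:      (K,)  known initial state
    xs:      (T, D) inputs
    Returns the refreshed guess for h_{1:T}.
    """
    # States we linearize around: (h_0, h_1_guess, ..., h_{T-1}_guess).
    h_prev = jnp.concatenate([h0[None], h_guess[:-1]], axis=0)

    # Time-varying linearization: h_t ~= A_t h_{t-1} + b_t.
    As = jax.vmap(jax.jacrev(f, argnums=0))(h_prev, xs)                  # (T, K, K)
    bs = jax.vmap(f)(h_prev, xs) - jnp.einsum('tij,tj->ti', As, h_prev)  # (T, K)

    # The linearized system is solved exactly by the parallel scan sketched
    # above; since that helper assumes h_0 = 0, fold the known h_0 into the
    # first offset.
    bs = bs.at[0].add(As[0] @ h0)
    return parallel_linear_rnn(As, bs)
```

Iterating `h_guess = deer_iteration(f, h_guess, h0, xs)` until the guess stops changing recovers the trajectory of the nonlinear RNN; when the iteration converges in few steps, the total parallel cost can be much lower than that of a sequential forward pass.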
Conclusion#
Recurrent neural networks are foundational models for sequential data. They’re useful machine learning models, and they’re standard models in theoretical neuroscience. While Transformers are the star of modern large language models, RNNs are making a comeback as efficient techniques for capturing long-range dependencies in sequential data.