
Recurrent Neural Networks

Let’s turn our attention back to models for sequential data for the next few lectures. With all the excitement around large language modeling, this is a timely topic.

Autoregressive Models

Consider a sequence of observations $\mbx_{1:T} = (\mbx_1, \ldots, \mbx_T)$ with each $\mbx_t \in \reals^D$. We can always factor a joint distribution over observations into a product of conditionals using the chain rule,

\begin{align*} p(\mbx_{1:T}) &= p(\mbx_1) \prod_{t=2}^T p(\mbx_t \mid \mbx_{1:t-1}). \end{align*}

This is called an autoregressive model since the conditional of $\mbx_t$ depends only on previous observations $\mbx_1, \ldots, \mbx_{t-1}$.

Autoregressive models are well-suited to sequential modeling since they make forward generation or forecasting easy. As long as we have access to the conditionals, we can sample forward indefinitely.

The question is, how should we parameterize these conditional distributions? It looks challenging since each one takes in a variable-length history of previous observations.

Recurrent Neural Networks

Recurrent Neural Networks (RNNs) are autoregressive models in which the conditional distributions are functions of a finite-dimensional hidden state, $\mbh_t \in \reals^K$,

\begin{align*} p(\mbx_t \mid \mbx_{1:t-1}) &= p(\mbx_t \mid g(\mbh_t; \mbtheta)). \end{align*}

The hidden state is updated with each new observation as,

\begin{align*} \mbh_{t+1} &= f(\mbh_t, \mbx_t; \mbtheta). \end{align*}

Vanilla RNNs

The standard, “vanilla” RNN consists of a linear-nonlinear state update. For example,

\begin{align*} f(\mbh_t, \mbx_t; \mbtheta) &= \tanh \left(\mbW \mbh_t + \mbB \mbx_t \right), \end{align*}

where $\mbW \in \reals^{K \times K}$ are the recurrent (dynamics) weights and $\mbB \in \reals^{K \times D}$ are the input weights.

The “read-out” of a vanilla RNN is typically a simple linear or generalized linear model, depending on the type of observations. For example,

\begin{align*} g(\mbh_t; \mbtheta) &= \mbC \mbh_t + \mbd, \end{align*}

where $\mbC \in \reals^{D \times K}$ are the read-out weights and $\mbd \in \reals^D$ is a bias.

Let $\mbtheta = (\mbW, \mbB, \mbC, \mbd)$ denote the set of parameters.
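As a concrete instantiation, here is a minimal numpy sketch of the forward pass and loss. It assumes Gaussian emissions with identity covariance (so the read-out gives the mean); that is one simple choice, made here only for illustration:

```python
import numpy as np

def rnn_nll(x, W, B, C, d, h0=None):
    """Negative log likelihood of a vanilla RNN with Gaussian emissions.

    x : (T, D) observations; W : (K, K); B : (K, D); C : (D, K); d : (D,).
    Assumes p(x_t | h_t) = N(x_t | C h_t + d, I), one simple choice.
    """
    T, D = x.shape
    K = W.shape[0]
    h = np.zeros(K) if h0 is None else h0
    nll = 0.0
    for t in range(T):
        mu = C @ h + d                      # read-out g(h_t) = C h_t + d
        nll += 0.5 * np.sum((x[t] - mu) ** 2) + 0.5 * D * np.log(2 * np.pi)
        h = np.tanh(W @ h + B @ x[t])       # state update f(h_t, x_t)
    return nll

rng = np.random.default_rng(0)
K, D, T = 4, 2, 10
W, B = 0.5 * rng.standard_normal((K, K)), rng.standard_normal((K, D))
C, d = rng.standard_normal((D, K)), np.zeros(D)
x = rng.standard_normal((T, D))
print(rnn_nll(x, W, B, C, d))
```

All weights here are random placeholders; in practice they would be trained by gradient descent on this loss.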

Theoretical Neuroscience

From a machine learning perspective, RNNs are useful function approximators for sequential data. From a neuroscience perspective, they have a long history as theoretical models of neural computation.

In such models, the hidden state $\mbh_t \in \reals^K$ corresponds to the relative firing rates of $K$ neurons. With a hyperbolic tangent nonlinearity, $\mbh_t \in (-1, 1)^K$, and negative rates don’t make sense. Instead, we think of $\mbh_t$ as an additive offset to a baseline firing rate.

The dynamics weights correspond to synaptic connections, with positive weights as excitatory synapses and negative weights as inhibitory. When a presynaptic neuron spikes, it induces an electrical current in postsynaptic neurons. Under this interpretation, the activation $\mbW \mbh_t$ corresponds to the input current generated by inputs from other neurons.

As a neuron receives input current, its voltage steadily increases until it reaches a threshold, at which point the voltage spikes and the neuron fires an action potential. These spikes induce currents on downstream neurons, as described above. After a cell fires, there is a short refractory period before the neuron can spike again. Thus, there is an upper bound on firing rates, which the hyperbolic tangent is meant to capture.

Backpropagation Through Time

Artificial RNNs are trained using stochastic gradient descent (SGD) to minimize the negative log likelihood,

\begin{align*} \cL(\mbtheta) &= - \sum_{t=1}^T \log p(\mbx_t \mid \mbx_{1:t-1}) \\ &= - \sum_{t=1}^T \log p(\mbx_t \mid g(\mbh_t; \mbtheta)) \\ &= - \sum_{t=1}^T \log p(\mbx_t \mid g(f(\mbh_{t-1}, \mbx_{t-1}; \mbtheta); \mbtheta)) \\ &= - \sum_{t=1}^T \log p(\mbx_t \mid g(f(\cdots f(\mbh_{1}, \mbx_{1}; \mbtheta) \cdots, \mbx_{t-1}; \mbtheta); \mbtheta)). \end{align*}

In practice, you would simply use automatic differentiation to compute the necessary gradients to minimize this loss, but we can gain some insight by working them out manually. With some algebra, we can show that the Jacobian of the loss with respect to the dynamics weights (the other parameters are similar) is,

\begin{align*} \frac{\partial \cL(\mbtheta)}{\partial \mbW} &= \sum_{t=1}^T \frac{\partial \cL(\mbtheta)}{\partial \mbh_t} \frac{\partial \mbh_t}{\partial \mbW} \end{align*}

We need the Jacobian of the loss with respect to the hidden states. These can be computed recursively by backpropagation through time (BPTT),

\begin{align*} \frac{\partial \cL(\mbtheta)}{\partial \mbh_t} &= \frac{\partial \cL(\mbtheta)}{\partial \mbh_{t+1}} \frac{\partial \mbh_{t+1}}{\partial \mbh_t} - \frac{\partial \log p(\mbx_t \mid g(\mbh_t; \mbtheta))}{\partial \mbh_t} \end{align*}

For a vanilla RNN, the Jacobian of the next state with respect to the current state is,

\begin{align*} \frac{\partial \mbh_{t+1}}{\partial \mbh_t} &= \diag \left(1 - \mbh_{t+1}^2 \right) \mbW \end{align*}

since $\frac{\dif}{\dif a}\tanh(a) = 1 - \tanh(a)^2$.
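We can sanity-check this Jacobian against finite differences; a quick numpy check with arbitrary weights:

```python
import numpy as np

rng = np.random.default_rng(1)
K, D = 3, 2
W = rng.standard_normal((K, K))
B = rng.standard_normal((K, D))
h, x = rng.standard_normal(K), rng.standard_normal(D)

f = lambda h: np.tanh(W @ h + B @ x)
h_next = f(h)
J_analytic = np.diag(1 - h_next ** 2) @ W      # diag(1 - h_{t+1}^2) W

# Finite-difference Jacobian, one column per perturbed coordinate of h.
eps = 1e-6
J_numeric = np.column_stack([
    (f(h + eps * e) - f(h - eps * e)) / (2 * eps)
    for e in np.eye(K)
])
print(np.max(np.abs(J_analytic - J_numeric)))  # tiny finite-difference error
```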

Vanishing Gradients

When the Jacobians $\mbA_t \equiv \frac{\partial \mbh_{t+1}}{\partial \mbh_t}$ have small eigenvalues ($\ll 1$), we run into the problem of vanishing gradients. Consider the case of a linear RNN in which the $\tanh$ is replaced with the identity: then the Jacobians are $\mbA_t = \mbW$ for all time steps. If the eigenvalues of $\mbW$ are much less than one in magnitude, the gradients will decay to zero exponentially quickly.

Vanishing gradients are especially problematic when $\mbx_t$ depends on much earlier observations, $\mbx_s$ for $s \ll t$. In that case, the hidden state must propagate information about $\mbx_s$ for many time steps during the forward pass, and likewise, the gradient must pass information backward many time steps during the backward pass. If the weights have small eigenvalues, those gradients will decay and the learning signal will fail to propagate.
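The decay is easy to see numerically: repeatedly multiplying a gradient vector by the (transposed) Jacobian of a linear RNN shrinks it exponentially when the eigenvalues are below one. A toy numpy demonstration with made-up matrices:

```python
import numpy as np

rng = np.random.default_rng(2)
K = 10
Q = np.linalg.qr(rng.standard_normal((K, K)))[0]   # random orthogonal basis
W_small = Q @ np.diag(np.full(K, 0.5)) @ Q.T       # all eigenvalues 0.5
W_unit  = Q @ np.diag(np.full(K, 1.0)) @ Q.T       # all eigenvalues 1.0

v = rng.standard_normal(K)                         # gradient at the last step
for W, name in [(W_small, "eigs=0.5"), (W_unit, "eigs=1.0")]:
    g = v.copy()
    for _ in range(50):       # backpropagate through 50 time steps
        g = W.T @ g           # multiply by the transposed linear Jacobian
    print(name, np.linalg.norm(g))
```

With eigenvalues 0.5 the norm collapses by a factor of $0.5^{50} \approx 10^{-15}$, while eigenvalues of one preserve it; this is the motivation for the gated architectures below.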

Gated RNNs

One way to combat the vanishing gradient problem is by modifying the RNN architecture. Architectures like long short-term memory (LSTM) networks achieve this via gated units.

An LSTM has internal (aka “cell”) states $\mbc_t \in \reals^K$ and hidden states $\mbh_t \in \reals^K$. The internal states follow conditionally linear dynamics,

\begin{align*} \mbc_{t} &= \mbF_t \mbc_{t-1} + \mbb_t \end{align*}

where

\begin{align*} \mbF_t &= \diag(f_{t,1}, \ldots, f_{t,K}) \\ \mbf_t &= \sigma(\mbW^{(f)} \mbh_{t-1} + \mbB^{(f)} \mbx_{t-1}), \end{align*}

with the logistic function $\sigma$ applied elementwise. The bounded entries $f_{t,k} \in [0,1]$ ensure stability. When $f_{t,k} \approx 1$, the state is propagated, and when $f_{t,k} \approx 0$, the state is forgotten. Thus, $\mbf_t = (f_{t,1}, \ldots, f_{t,K})^\top \in [0,1]^K$ are called the forget gates, and they are parameterized by the matrices $\mbW^{(f)}$ and $\mbB^{(f)}$.

The affine term is determined by,

\begin{align*} \mbb_t &= \mbg_t \odot \mbi_t \\ \mbg_t &= \sigma(\mbW^{(g)} \mbh_{t-1} + \mbB^{(g)} \mbx_{t-1}) \\ \mbi_t &= \sigma(\mbW^{(i)} \mbh_{t-1} + \mbB^{(i)} \mbx_{t-1}) \end{align*}

The vector $\mbg_t \in [0,1]^K$ plays the role of an input gate, and the inputs themselves are given by $\mbi_t \in [0,1]^K$.

Finally, the hidden states ht\mbh_t are gated functions of the internal state passed through a nonlinearity,

\begin{align*} \mbh_t &= \mbo_t \odot \tanh(\mbc_t) \\ \mbo_t &= \sigma(\mbW^{(o)} \mbh_{t-1} + \mbB^{(o)} \mbx_{t-1}) \end{align*}

where $\mbo_t \in [0,1]^K$ are the output gates. As in a vanilla RNN, the final prediction depends on a (generalized) linear function of the hidden state, $g(\mbh_t; \mbtheta)$.
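One step of this LSTM can be sketched in a few lines of numpy. The sketch follows the equations above exactly (standard implementations also add bias vectors, omitted here to match); all weights are random placeholders:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def lstm_step(c, h, x, params):
    """One step of the LSTM described above (a sketch, no bias terms)."""
    f = sigmoid(params["Wf"] @ h + params["Bf"] @ x)   # forget gates
    g = sigmoid(params["Wg"] @ h + params["Bg"] @ x)   # input gates
    i = sigmoid(params["Wi"] @ h + params["Bi"] @ x)   # inputs
    o = sigmoid(params["Wo"] @ h + params["Bo"] @ x)   # output gates
    c_new = f * c + g * i            # conditionally linear cell dynamics
    h_new = o * np.tanh(c_new)       # gated hidden state
    return c_new, h_new

rng = np.random.default_rng(3)
K, D = 4, 3
params = {f"W{n}": rng.standard_normal((K, K)) for n in "fgio"}
params.update({f"B{n}": rng.standard_normal((K, D)) for n in "fgio"})
c, h = np.zeros(K), np.zeros(K)
for x in rng.standard_normal((5, D)):
    c, h = lstm_step(c, h, x, params)
print(c, h)
```

Note that with sigmoid inputs and $\mbc_0 = \mathbf{0}$, the cell state stays nonnegative and the hidden state stays in $(-1, 1)$, matching the extended-state picture below.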

We can think of an LSTM as an RNN that operates on an extended state $(\mbc_t, \mbh_t) \in \reals_+^K \times [-1, 1]^K$. The forget gates allow the eigenvalues of $\mbF_t$ to be close to one, allowing cell states to be propagated for long periods of time on the forward pass, and gradients to be backpropagated without vanishing on the backward pass.

There are many variants of gated RNNs. Besides the LSTM, the most commonly used is the gated recurrent unit (GRU), which has a slightly simplified architecture. See Goodfellow et al. (2016) (ch. 10) for more detail.

Other Variations and Uses of RNNs

We motivated RNNs from an autoregressive modeling perspective, but they are useful in other sequential data settings as well. For example, suppose we want to predict the sentiment of a review $y \in \reals$ given a variable-length sequence of input words $\mbx_{1:T}$. We can use an RNN to summarize the input sequence in terms of its final hidden state for prediction,

\begin{align*} p(y \mid \mbx_{1:T}) &= p(y \mid g(\mbh_T; \mbtheta)). \end{align*}

Sequence to Sequence Models

Sometimes we want to map one sequence $\mbx_{1:T}$ to another sequence $\mby_{1:T'}$. The sequences may be of different lengths; e.g., when we want to translate a sentence from English to French. Again, we can train an encoder RNN to produce a hidden state $\mbh_T$ that then becomes the initial condition for a decoder RNN that generates the output sequence. Formally,

\begin{align*} p(\mby_{1:T'} \mid \mbx_{1:T}) &= \prod_{t=1}^{T'} p(\mby_{t} \mid \mby_{1:t-1}, \mbx_{1:T}) \\ &= \prod_{t=1}^{T'} p(\mby_{t} \mid \mbh'_t, \mbh_T) \end{align*}

where $\mbh_T$ is the final state of an encoder RNN that processed $\mbx_{1:T}$, and $\mbh'_t$ is the state of a decoder RNN that runs over $\mby_{1:t-1}$.

Bidirectional RNNs

In the example above, one challenge is that the hidden state $\mbh_T$ obtained by processing $\mbx_{1:T}$ may not adequately represent early inputs like $\mbx_1$. For these purposes, you can use a bidirectional RNN that runs one recursion forward over $\mbx_1, \ldots, \mbx_T$ and another backward over $\mbx_T, \ldots, \mbx_1$ to produce two hidden states at each time $t$. These combined states can then be passed into the decoder.
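A minimal numpy sketch of this idea, using the generic vanilla-RNN recursion with made-up dimensions and random weights:

```python
import numpy as np

def run_rnn(xs, W, B, h0):
    """Return the hidden state at every step of a vanilla RNN."""
    hs, h = [], h0
    for x in xs:
        h = np.tanh(W @ h + B @ x)
        hs.append(h)
    return np.stack(hs)

rng = np.random.default_rng(4)
K, D, T = 4, 3, 6
Wf, Bf = 0.5 * rng.standard_normal((K, K)), rng.standard_normal((K, D))
Wb, Bb = 0.5 * rng.standard_normal((K, K)), rng.standard_normal((K, D))
x = rng.standard_normal((T, D))

h_fwd = run_rnn(x, Wf, Bf, np.zeros(K))              # x_1, ..., x_T
h_bwd = run_rnn(x[::-1], Wb, Bb, np.zeros(K))[::-1]  # x_T, ..., x_1
h_bi = np.concatenate([h_fwd, h_bwd], axis=1)        # two states per time t
print(h_bi.shape)  # (6, 8)
```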

Deep RNNs

As with deep neural networks that stack layer upon layer, we can stack RNN upon RNN to construct a deeper model. In such models, the outputs of one layer, $g(\mbh_t^{(i)}; \mbtheta^{(i)})$, become the inputs to the next layer, $\mbx_t^{(i+1)}$. Then we can backpropagate gradients through the entire stack to train the model.

HMMs Are RNNs Too!

It turns out, we’ve already seen an RNN in this class! We presented Hidden Markov Models (HMMs) as latent variable models with hidden states $\mbz_{1:T}$ and conditionally independent observations $\mbx_{1:T}$, but we can also view them as an autoregressive model in which

\begin{align*} p(\mbx_t \mid \mbx_{1:t-1}) &= \sum_{k=1}^K p(z_t=k \mid \mbx_{1:t-1}) \, p(\mbx_t \mid z_t=k) \\ &= \sum_{k=1}^K \overline{\alpha}_{t,k} \, p(\mbx_t \mid z_t=k) \end{align*}

where $\overline{\mbalpha}_t \in \Delta_{K-1}$ are the normalized forward messages from the forward-backward algorithm. They follow a simple recursion,

\begin{align*} \overline{\mbalpha}_{t+1} &= \mbP^\top \left( \frac{\overline{\mbalpha}_t \odot \mbl_t}{\overline{\mbalpha}_t^\top \mbl_t} \right) \end{align*}

where $\mbl_t \in \reals^K$ is the vector of likelihoods with entries $l_{t,k} = p(\mbx_t \mid z_t=k)$ and $\mbP \in \reals^{K \times K}$ is the transition matrix.

Categorical HMMs

Consider an HMM with categorical emissions,

\begin{align*} p(\mbx_t \mid z_t) &= \mathrm{Cat}(\mbx_t \mid \mbc_{z_t}), \end{align*}

where $\mbx_t$ is a one-hot encoding of a variable that takes values in $\{1,\ldots,V\}$, and $\mbc_k \in \Delta_{V-1}$ for $k=1,\ldots,K$ are pmfs. Define the matrix of likelihoods $\mbC \in \reals^{V \times K}$ to have columns $\mbc_k$,

\begin{align*} \mbC &= \begin{bmatrix} | & & | \\ \mbc_1 & \cdots & \mbc_K \\ | & & | \end{bmatrix}. \end{align*}

The HMM parameters are $\mbtheta = (\mbP, \mbC)$. (Assume the initial distribution is fixed, for simplicity.)

For a categorical HMM, we can write the likelihoods as $\mbl_t = \mbC^\top \mbx_t$, so that the forward recursions simplify to,

\begin{align*} \overline{\mbalpha}_{t+1} &= \mbP^\top \left( \frac{\overline{\mbalpha}_t \odot \mbC^\top \mbx_t}{\overline{\mbalpha}_t^\top \mbC^\top \mbx_t} \right) \\ &= f(\overline{\mbalpha}_t, \mbx_t; \mbtheta) \end{align*}

Likewise, the autoregressive distributions reduce to,

\begin{align*} p(\mbx_t \mid \mbx_{1:t-1}) &= \mathrm{Cat}(\mbx_t \mid \mbC \overline{\mbalpha}_t) \\ &= \mathrm{Cat}(\mbx_t \mid g(\overline{\mbalpha}_t; \mbtheta)). \end{align*}

Framed this way, a categorical HMM can be seen as a simple recurrent neural network!
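To make this concrete, here is a small numpy sketch of the categorical HMM forward pass written as the RNN above; the transition and emission values are made up for illustration, and observations are stored as integer indices rather than one-hot vectors:

```python
import numpy as np

def hmm_forward_rnn(x, P, C, alpha1):
    """Log likelihood of a categorical HMM via the RNN view.

    x : (T,) observations as integers in {0, ..., V-1} (index form of the
        one-hot x_t); P : (K, K) transition matrix; C : (V, K) emission
        matrix with simplex columns; alpha1 : (K,) prior p(z_1).
    """
    alpha, ll = alpha1, 0.0
    for xt in x:
        pred = C @ alpha                             # g(alpha_t) = C alpha_t
        ll += np.log(pred[xt])                       # log p(x_t | x_{1:t-1})
        lik = C[xt]                                  # l_t = C^T x_t (one-hot x_t)
        alpha = P.T @ (alpha * lik / (alpha @ lik))  # f(alpha_t, x_t; theta)
    return ll

P = np.array([[0.9, 0.1], [0.2, 0.8]])
C = np.array([[0.7, 0.1], [0.2, 0.2], [0.1, 0.7]])  # columns sum to one
x = np.array([0, 0, 2, 2, 1])
print(hmm_forward_rnn(x, P, C, np.array([0.5, 0.5])))
```

The loop is exactly a recurrent network: a hidden state ($\overline{\mbalpha}_t$), a state update $f$, and a read-out $g$.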

A Cool Trick for Computing the Gradients

This formulation suggests that we could estimate the parameters of an HMM by directly maximizing the log likelihood,

\begin{align*} \cL(\mbtheta) &= \log p(\mbx_{1:T}; \mbtheta) = \sum_{t=1}^T \log p(\mbx_t \mid \mbx_{1:t-1}; \mbtheta). \end{align*}

With automatic differentiation at our disposal, this sounds like it might be a lot easier!

Let’s pursue this idea a little further. First, we’d prefer to do unconstrained optimization, so let’s parameterize the model in terms of $\log \mbP$ and $\log \mbC$. When we need the constrained versions, we will just apply the softmax to obtain simplex vectors:

\begin{align*} \mathrm{softmax}(\log \mbc_k) &= \left( \frac{e^{\log c_{k,1}}}{\sum_{v=1}^V e^{\log c_{k,v}}}, \ldots, \frac{e^{\log c_{k,V}}}{\sum_{v=1}^V e^{\log c_{k,v}}} \right)^\top \end{align*}

To maximize the likelihood with gradient ascent, we need the Jacobians $\frac{\partial \cL(\mbtheta)}{\partial \log \mbP}$ and $\frac{\partial \cL(\mbtheta)}{\partial \log \mbC}$. For the RNNs above, we computed them using backpropagation through time, but here we can use an even cooler trick.

First, note that the posterior distribution in an HMM can be written as an exponential family,

\begin{align*} p(\mbz_{1:T} \mid \mbx_{1:T}; \mbtheta) &= \exp \left\{ \log p(\mbz_{1:T}, \mbx_{1:T}; \mbtheta) - \log p(\mbx_{1:T}; \mbtheta) \right\} \\ &= \exp \left\{ \log p(z_1; \mbpi_0) + \sum_{t=2}^T \log p(z_t \mid z_{t-1}; \mbP) + \sum_{t=1}^T \log p(\mbx_t \mid z_{t}; \mbC) - \log p(\mbx_{1:T}; \mbtheta) \right\} \\ &= \exp \bigg\{ \sum_{k=1}^K \langle \bbI[z_1 = k], \log \pi_{0,k} \rangle \\ &\hspace{6em} + \sum_{t=2}^T \sum_{i=1}^K \sum_{j=1}^K \langle \bbI[z_{t-1}=i \wedge z_{t} = j], \log P_{i,j} \rangle \\ &\hspace{12em} + \sum_{t=1}^T \sum_{v=1}^V \sum_{k=1}^K \langle \bbI[x_t = v \wedge z_t = k], \log C_{v,k} \rangle \\ &\hspace{18em} - A(\mbtheta) \bigg\} \end{align*}

where the log marginal likelihood $A(\mbtheta) = \log p(\mbx_{1:T}; \mbtheta)$ is the log normalizer.

Recall that for exponential family distributions, gradients of the log normalizer yield expected sufficient statistics. Thus,

\begin{align*} \frac{\partial A(\mbtheta)}{\partial \log C_{v,k}} &= \sum_{t=1}^T \bbI[x_{t} = v] \cdot \bbE_{p(\mbz_{1:T} \mid \mbx_{1:T}; \mbtheta)} \left[\bbI[z_{t}=k] \right] \end{align*}

and

\begin{align*} \frac{\partial A(\mbtheta)}{\partial \log P_{i,j}} &= \sum_{t=2}^T \bbE_{p(\mbz_{1:T} \mid \mbx_{1:T}; \mbtheta)} \left[ \bbI[z_{t-1}=i \wedge z_{t} = j] \right] \end{align*}

The gradients are essentially the posterior marginals and pairwise marginals we computed in the EM algorithm!
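We can check this identity numerically on a tiny HMM, using brute-force enumeration over state sequences in place of forward-backward. All numbers here are made up, and `joint_prob` and `A` are helper names for this sketch; note that the derivative treats each entry of $\log \mbC$ as a free parameter, evaluated at a point where the columns are normalized:

```python
import numpy as np
from itertools import product

# Tiny HMM: K=2 states, V=2 symbols, T=3 observations (made-up numbers).
K, V, T = 2, 2, 3
x = [0, 1, 1]
pi0 = np.array([0.6, 0.4])
P = np.array([[0.7, 0.3], [0.4, 0.6]])   # rows are pmfs over z_{t+1}
C = np.array([[0.8, 0.3], [0.2, 0.7]])   # columns are pmfs over x_t

def joint_prob(z, logP, logC):
    """p(z_{1:T}, x_{1:T}) with log P, log C treated as free parameters."""
    lp = np.log(pi0[z[0]]) + logC[x[0], z[0]]
    for t in range(1, T):
        lp += logP[z[t-1], z[t]] + logC[x[t], z[t]]
    return np.exp(lp)

def A(logP, logC):
    """A(theta) = log p(x_{1:T}) by brute-force enumeration over z_{1:T}."""
    return np.log(sum(joint_prob(z, logP, logC)
                      for z in product(range(K), repeat=T)))

# Finite-difference gradient of A with respect to one entry log C_{v,k}.
v, k, eps = 1, 0, 1e-6
E = np.zeros((V, K)); E[v, k] = eps
grad_fd = (A(np.log(P), np.log(C) + E) - A(np.log(P), np.log(C) - E)) / (2 * eps)

# Expected sufficient statistic: sum_t I[x_t = v] * p(z_t = k | x_{1:T}).
Z = np.exp(A(np.log(P), np.log(C)))
stat = sum(joint_prob(z, np.log(P), np.log(C)) / Z
           for z in product(range(K), repeat=T)
           for t in range(T) if x[t] == v and z[t] == k)
print(grad_fd, stat)  # the two quantities agree
```

In practice the expectation would come from the forward-backward algorithm rather than enumeration; the point is only that the gradient and the expected sufficient statistic coincide.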

In the M-step of EM, we solve for the parameters that satisfy the constrained optimality conditions, whereas in SGD we just take a step in the direction of the gradient. EM tends to converge much faster in practice for this reason.

Conclusion

Recurrent neural networks are foundational models for sequential data — both as machine learning tools and as theoretical models of neural computation. The key insight is that any autoregressive model can be parameterized through a fixed-dimensional hidden state updated recursively. Gated architectures like LSTMs address the vanishing-gradient problem by allowing cell states to be propagated across many time steps.

References
  1. Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press. https://www.deeplearningbook.org