Discrete Denoising Diffusion Models#
Denoising diffusion models (DDMs) are currently the state-of-the-art approach for image generation, but can they be used for generating discrete data like language and protein sequences? Last week, we derived the basic principles of DDMs for continuous-valued data. Here, we will show how those ideas carry over to discrete data, and how continuous-time diffusion and the score function in the reverse diffusion SDE translate to the discrete setting.
Discrete-Time, Discrete-State DDMs#
We will start in discrete time and develop a DDM for discrete-valued data.
Setup#
Let \(x_0 \in \cX\) denote a data point.
Let \(|\cX| = S < \infty\) denote the vocabulary size.
Let \(q_0(x_0)\) denote the data distribution.
Noising process#
Use a Markov chain that gradually converts \(x_0\) to \(x_T \sim q_T(x_T)\), which is pure “noise,”
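Concretely, the noising trajectory is generated by one-step transition distributions and factors according to the Markov property,

\[\begin{align*}
q(x_{1:T} \mid x_0) &= \prod_{t=0}^{T-1} q_{t+1|t}(x_{t+1} \mid x_t).
\end{align*}\]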
For example, \(q_T(x_T) = \mathrm{Unif}_{\cX}(x_T)\) could be achieved by,
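resampling the token uniformly at random with some probability \(\beta_t \in [0,1]\) at each step (this kernel is one illustrative choice; the schedule \(\beta_t\) is an assumption),

\[\begin{align*}
q_{t+1|t}(x_{t+1} \mid x_t) &= (1 - \beta_t)\, \delta_{x_t}(x_{t+1}) + \frac{\beta_t}{S}.
\end{align*}\]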
Masking diffusion#
One of the most effective noising processes is the masking diffusion which introduces a \(\mathsf{MASK}\) token that is an absorbing state of the Markov process,
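One illustrative parameterization (with an assumed per-step masking probability \(\beta_t\)) keeps a token in place unless it is masked and never unmasks it,

\[\begin{align*}
q_{t+1|t}(x_{t+1} \mid x_t) &=
\begin{cases}
\delta_{\mathsf{MASK}}(x_{t+1}) & \text{if } x_t = \mathsf{MASK}, \\
(1 - \beta_t)\, \delta_{x_t}(x_{t+1}) + \beta_t\, \delta_{\mathsf{MASK}}(x_{t+1}) & \text{otherwise},
\end{cases}
\end{align*}\]

so that \(x_T = \mathsf{MASK}\) with high probability once the cumulative masking probability approaches one.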
Reverse process#
The reverse of the noising process is also a Markov chain! (We can derive this from the graphical model.) It factors as,
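\[\begin{align*}
q(x_{0:T}) &= q_T(x_T) \prod_{t=0}^{T-1} q_{t|t+1}(x_t \mid x_{t+1}).
\end{align*}\]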
The reverse transition probabilities can be obtained via Bayes’ rule,
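\[\begin{align*}
q_{t|t+1}(x_t \mid x_{t+1})
&= \frac{q_{t+1|t}(x_{t+1} \mid x_t)\, q_t(x_t)}{q_{t+1}(x_{t+1})}
= \frac{q_{t+1|t}(x_{t+1} \mid x_t)\, q_t(x_t)}{\sum_{x_t'} q_{t+1|t}(x_{t+1} \mid x_t')\, q_t(x_t')}.
\end{align*}\]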
Alternatively, we can express the reverse transition probabilities in terms of the denoising distributions \(q_{0|t+1}(x_0 \mid x_{t+1})\) as follows,
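Marginalizing over the clean data point and applying the chain rule,

\[\begin{align*}
q_{t|t+1}(x_t \mid x_{t+1})
&= \sum_{x_0} q(x_t, x_0 \mid x_{t+1}) \\
&= \sum_{x_0} q(x_t \mid x_{t+1}, x_0)\, q_{0|t+1}(x_0 \mid x_{t+1}).
\end{align*}\]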
Explanation
In the last line we used the fact that,
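\[\begin{align*}
q(x_t, x_0 \mid x_{t+1}) &= q(x_t \mid x_{t+1}, x_0)\, q_{0|t+1}(x_0 \mid x_{t+1}),
\end{align*}\]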
which follows from the chain rule.
Approximating the Reverse Process#
Problem: We know everything in the reverse transition probability except the denoising distribution \(q_{0|t+1}(x_0 \mid x_{t+1})\).
Solution: Learn it! Parameterize the reverse transition probability as,
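\[\begin{align*}
p_{t|t+1}(x_t \mid x_{t+1}; \theta) &= \sum_{x_0} q(x_t \mid x_{t+1}, x_0)\, p_{0|t+1}(x_0 \mid x_{t+1}; \theta),
\end{align*}\]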
where \(p_{0|t+1}(x_0 \mid x_{t+1}; \theta)\) is a learned, approximate denoising distribution
We can then sample from the approximate reverse process one step at a time from \(T\) down to \(0\),
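\[\begin{align*}
x_T &\sim q_T(x_T), \\
x_t &\sim p_{t|t+1}(x_t \mid x_{t+1}; \theta) \quad \text{for } t = T-1, \ldots, 0,
\end{align*}\]

finally returning the generated sample \(x_0\).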
The Evidence Lower Bound#
We will estimate the model parameters \(\theta\) by maximizing the ELBO, which is a sum over data points of,
where
is the ELBO for the last term in the sum.
Note that we are again using Rao-Blackwellization to write the ELBO in terms of expectations over fewer random variables for each term in the sum.
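For concreteness, one common way to organize the per-data-point objective is

\[\begin{align*}
\mathcal{L}(\theta; x_0)
&= \mathbb{E}_{q_{1|0}}\big[\log p_{0|1}(x_0 \mid x_1; \theta)\big]
- \sum_{t=1}^{T-1} \mathbb{E}_{q_{t+1|0}}\Big[\mathrm{KL}\big(q(x_t \mid x_{t+1}, x_0) \,\big\|\, p_{t|t+1}(x_t \mid x_{t+1}; \theta)\big)\Big] \\
&\quad - \mathrm{KL}\big(q_{T|0}(x_T \mid x_0) \,\big\|\, q_T(x_T)\big),
\end{align*}\]

i.e., a reconstruction term, a sum of KL divergences between the interpolating distribution and the learned reverse transitions, and a prior-matching term that does not depend on \(\theta\).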
Important
We choose \(q\) such that the marginal distribution \(q_{t|0}(x_t)\) and interpolating distribution \(q(x_t \mid x_{t+1}, x_0)\) are available in closed form!
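For example, for the masking diffusion above (with the assumed per-step masking probabilities \(\beta_s\)), the marginal takes a simple two-point form,

\[\begin{align*}
q_{t|0}(x_t \mid x_0) &= \alpha_t\, \delta_{x_0}(x_t) + (1 - \alpha_t)\, \delta_{\mathsf{MASK}}(x_t),
\qquad \alpha_t = \prod_{s=0}^{t-1} (1 - \beta_s),
\end{align*}\]

i.e., by time \(t\) the token is either still intact or has been masked.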
Continuous-Time Markov Chains#
We ended our discussion of continuous-state DDMs by noting that in the continuous-time limit the noising and reverse processes are SDEs. For discrete-state DDMs, the continuous-time limit involves Continuous-Time Markov Chains (CTMCs), which are closely related to Poisson processes!
Properties of CTMCs#
A CTMC is a stochastic process \(\{x_t : t\in[0,T]\}\) taking values in a finite state space \(\cX\) such that:
Sample paths \(t \mapsto x_t\) are right-continuous and have finitely many jumps.
The Markov property holds:
\[\begin{align*} \Pr(x_{t+\Delta t} = j \mid \{x_s: s\leq t\}) &= \Pr(x_{t+\Delta t}=j \mid x_t) \end{align*}\]
Transition Probabilities#
CTMCs are uniquely characterized by the transition distributions \(q_{t|s}(x_t=j \mid x_s=i)\) for \(t \geq s\).
These must satisfy the Chapman-Kolmogorov Equations
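\[\begin{align*}
q_{t|s}(x_t = j \mid x_s = i) &= \sum_{k \in \cX} q_{t|u}(x_t = j \mid x_u = k)\, q_{u|s}(x_u = k \mid x_s = i)
\end{align*}\]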
for \(s \leq u \leq t\).
Rate matrices#
Equivalently, we can identify a CTMC by its rate matrices
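which can be defined via the short-time behavior of the transition probabilities,

\[\begin{align*}
R_s(i \to j) &= \lim_{\Delta t \downarrow 0} \frac{q_{s + \Delta t \mid s}(x_{s+\Delta t} = j \mid x_s = i) - \delta_{ij}}{\Delta t},
\end{align*}\]

equivalently, \(q_{s + \Delta t \mid s}(j \mid i) = \delta_{ij} + R_s(i \to j)\, \Delta t + o(\Delta t)\).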
Intuitively, \(R_s(i \to j)\) is the instantaneous rate of probability flow from \(i\) to \(j\) at time \(s\).
Properties of rate matrices:
\(R_s(i \to j) \geq 0\) for \(i \neq j\). (Probability flow must be outward.)
\(\sum_j R_s(i \to j) = 0\) (Probability is conserved.)
Define \(R_s(i) = \sum_{j\neq i} R_s(i \to j)\) to be the total outward probability flow.
A homogeneous CTMC has a fixed rate matrix \(R_s \equiv R\) for all times \(s \in [0,T]\).
Gillespie’s Algorithm#
Consider a homogeneous CTMC with rate matrix \(R\). We can simulate a draw from the CTMC using Gillespie’s Algorithm (a short code sketch follows the steps below).
Initialize \(x_0 \sim \pi_0\) and set \(t_0 = 0\), \(i=0\).
While \(t_i < T\)
Draw the waiting time \(\Delta_i \sim \mathrm{Exp}(R(x_i))\)
If \(t_i + \Delta_i > T\), return \(\{(t_j, x_j)\}_{j=0}^i\)
Else, set \(t_{i+1} \leftarrow t_i + \Delta_i\) and draw the next state
\[\begin{align*} x_{i+1} \sim \mathrm{Cat}\left(\left[\frac{R(x_i \to j)}{R(x_i)} \right]_{j \neq x_i} \right) \end{align*}\]
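Here is a minimal NumPy sketch of these steps. The function name `gillespie` and the convention that \(R\) is an \(S \times S\) array with nonnegative off-diagonal entries and rows summing to zero are illustrative assumptions.

```python
import numpy as np

def gillespie(R, pi0, T, rng=None):
    """Simulate one path of a homogeneous CTMC with rate matrix R on [0, T].

    R:   (S, S) array; R[i, j] = R(i -> j) for j != i, rows sum to zero.
    pi0: (S,) initial distribution.
    Returns the jump times and states [(t_0, x_0), (t_1, x_1), ...].
    """
    rng = np.random.default_rng() if rng is None else rng
    S = R.shape[0]
    x = int(rng.choice(S, p=pi0))
    t = 0.0
    path = [(t, x)]
    while t < T:
        total_rate = -R[x, x]                    # R(x) = sum_{j != x} R(x -> j)
        if total_rate <= 0:                      # absorbing state: no further jumps
            break
        dt = rng.exponential(1.0 / total_rate)   # waiting time ~ Exp(R(x))
        if t + dt > T:                           # next jump falls outside [0, T]
            break
        t = t + dt
        probs = np.clip(R[x], 0.0, None)         # outward rates R(x -> j)
        probs[x] = 0.0
        probs = probs / probs.sum()              # next state drawn proportional to rates
        x = int(rng.choice(S, p=probs))
        path.append((t, x))
    return path
```

For example, `gillespie(R, pi0, T=1.0)` returns the jump skeleton \(\{(t_i, x_i)\}\) of a single draw.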
Does this look familiar?
Connection to Poisson processes#
We can cast a CTMC as a marked Poisson process with times \(t_i\) and marks \(x_i \in \cX\).
The process follows a conditional intensity function that depends on the history \(\cH_t\). In particular, the history contains the current state \(x_t\) (since the state path is right continuous),
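\[\begin{align*}
\lambda(t, j \mid \cH_t) &= R_t(x_t \to j) \quad \text{for } j \neq x_t,
\end{align*}\]

i.e., the intensity of an event with mark \(j\) equals the rate of jumping from the current state to \(j\) (written here with a possibly time-varying rate matrix \(R_t\)).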
Gillespie’s algorithm is using the Poisson superposition and Poisson thinning properties we discussed last time! To sample the waiting time, we use superposition: the events of all marks together form a point process with total intensity \(\lambda(t \mid \cH_t) = \sum_x \lambda(t, x \mid \cH_t)\), which determines the distribution of the time of the next event. Once we have sampled the time, we use Poisson thinning to sample the next state (i.e., the mark).
Rao and Teh [RT13] used this construction to develop a very clever Gibbs sampling algorithm for CTMCs!
CTMCs in Continuous-Time Discrete DDMs#
The reversal of a CTMC is another CTMC, and [CBDB+22] showed how to parameterize the reverse process of a discrete-state, continuous-time DDM in terms of the backward rates.
It turns out the backward rate is,
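\[\begin{align*}
\overleftarrow{R}_t(i \to j) &= R_t(j \to i)\, \frac{q_t(x_t = j)}{q_t(x_t = i)} \quad \text{for } j \neq i
\end{align*}\]

(with \(\overleftarrow{R}_t\) denoting the rate matrix of the reverse-time process),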
where the density ratio \(\frac{q_t(x_t=j)}{q_t(x_t=i)}\) can be seen as the analog of the score function for a discrete distribution.
Sampling the backward process is tricky because the reverse rate is inhomogeneous, and Gillespie’s algorithm for inhomogeneous processes requires integrating rate matrices. Instead, [CBDB+22] propose to use a technique called tau-leaping to approximately sample the backward process. Then they use corrector steps to correct for the errors in the approximation. In recent work, we show how to develop more informative correctors for discrete diffusion with masking processes [ZSML24].
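For intuition, here is a minimal sketch of a single approximate reverse step in the Euler style, closely related in spirit to tau-leaping but not the exact scheme of [CBDB+22]: hold the reverse rates fixed over a small step \(\tau\) and sample the next state from the resulting first-order transition probabilities. The function name and the assumption that we can evaluate the reverse rates out of the current state are illustrative.

```python
import numpy as np

def reverse_euler_step(x, rates_out, tau, rng):
    """One approximate reverse-diffusion step for a single discrete variable.

    x:         current state, an integer in {0, ..., S-1}.
    rates_out: length-S array of (estimated) reverse rates out of x; the entry
               at index x is ignored.
    tau:       step size, assumed small enough that tau * rates_out.sum() < 1.
    """
    probs = tau * np.clip(rates_out, 0.0, None)   # P(jump to j) ~ tau * rate(x -> j)
    probs[x] = 0.0                                # no self-jump in the rate term
    probs[x] = max(1.0 - probs.sum(), 0.0)        # probability of staying put
    probs = probs / probs.sum()                   # renormalize in case tau was too large
    return int(rng.choice(len(probs), p=probs))
```

A full sampler would start from \(x_T \sim q_T\) and apply this step repeatedly as \(t\) decreases, recomputing the reverse rates from the learned model at each step; corrector steps, as discussed above, can then reduce the discretization error.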
Conclusion#
Discrete DDMs are a nice way to wrap up this course! They combine old and new: Poisson processes and CTMCs, as well as modern deep generative models.
These models have recently been used for language modeling, and sampling from them can be much faster than from an autoregressive model like a Transformer since many words can be generated in parallel. See Inception AI!