Discrete diffusion model

From HandWiki

In machine learning, discrete diffusion models are a class of diffusion models, which themselves are a class of latent variable generative models. Each discrete diffusion model consists of two major components: the forward jump diffusion process and the reverse jump diffusion process. The goal of diffusion modeling is, given a dataset and a forward process, to learn a model for the reverse process, such that the reverse process can generate new elements distributed similarly to the original dataset. A trained discrete diffusion model can be sampled in many ways, which trade off computational efficiency against sample quality: in general, higher-quality samples can be obtained, but at the price of higher computational cost.

In standard diffusion modeling, the diffusion process takes place over a continuous state space such as $\mathbb{R}^n$. In discrete diffusion modeling, it instead takes place over a discrete set $S$. A discrete set is simply a set where one cannot speak of "infinitesimally close" points. Points can be more or less separated from each other, but the separation is always a finite number. In particular, this means the standard framework of continuous diffusion does not apply, since it uses Gaussian noise, which is continuous. Nevertheless, an analogous theory can be produced.

Discrete diffusion is usually used for language modeling.[1][2] In practice, the state space S is not only discrete, but finite, so this is what we will assume from now on.

Continuous time Markov process

In the case of a continuous state space, during the forward diffusion process, at each step $t\to t+dt$ we mix in an infinitesimal amount of Gaussian noise: $dx_t = -\tfrac{1}{2}\beta(t)\,x_t\,dt + \sqrt{\beta(t)}\,dW_t$. This changes the probability density function by a convolution with the density of a Gaussian, followed by a scaling.

In the case of a discrete state space, the Gaussian noise must be replaced by a noise that takes values over a finite set. For example, if the noise is the uniform distribution over $S$, then the probability distribution at time $t+dt$ satisfies
$$q_{t+dt}(x) = (1-dt)\,q_t(x) + dt\left(\frac{1}{|S|}\sum_{y\in S} q_t(y)\right)$$
More succinctly,
$$\partial_t q_t(x) = -\left(1-\frac{1}{|S|}\right)q_t(x) + \sum_{y\in S,\ y\neq x}\frac{1}{|S|}\,q_t(y)$$
In general, we need not convolve with uniformly distributed noise; an arbitrary noise process may be used. That is, we use an arbitrary matrix $Q_t$ such that
$$\partial_t q_t(y) = \sum_{x\in S} Q_t(y,x)\,q_t(x)$$
where $Q_t$ is called the rate matrix. Any matrix may be used as a rate matrix if its off-diagonal entries are non-negative and each column sums to 0:
$$Q_t(y,x)\ge 0\quad\forall\, y\neq x,\qquad \sum_{y\in S} Q_t(y,x) = 0\quad\forall\, x$$
A continuous-time Markov chain (CTMC) is defined by a continuous function $Q$ that maps any time $t\in[0,T)$ to a rate matrix $Q_t$. Given the function $Q$, time evolution under the CTMC works as follows: given the state $x_t$ at time $t$ and an infinitesimal $dt$, the state at $t+dt$ is $x_{t+dt}$, with
$$\Pr(x_{t+dt}\mid x_t) = \begin{cases} 1 + Q_t(x_{t+dt},x_t)\,dt & \text{if } x_{t+dt} = x_t\\ Q_t(x_{t+dt},x_t)\,dt & \text{otherwise}\end{cases}$$
This implies that the probability distribution evolves according to
$$\partial_t q_t(y) = \sum_{x\in S} Q_t(y,x)\,q_t(x)$$
which is what we previously specified.
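The rate-matrix conditions and the transition rule above can be checked numerically. The following is a minimal NumPy sketch, assuming a small state space indexed $0,\dots,|S|-1$ and a time-independent rate matrix; the helper names `is_rate_matrix`, `step_probabilities` and `sample_step` are hypothetical. Replacing the infinitesimal $dt$ by a finite step gives a first-order (Euler) approximation, which is a valid probability vector only when $dt\,|Q_t(x,x)|\le 1$.

```python
import numpy as np


def is_rate_matrix(Q, tol=1e-12):
    """True if Q has non-negative off-diagonal entries and every column sums to 0."""
    Q = np.asarray(Q, dtype=float)
    off_diagonal = Q[~np.eye(Q.shape[0], dtype=bool)]
    return bool(np.all(off_diagonal >= -tol) and np.allclose(Q.sum(axis=0), 0.0))


def step_probabilities(Q, x, dt):
    """Pr(x_{t+dt} = y | x_t = x) for every y: column x of (I + Q dt)."""
    p = Q[:, x] * dt          # jump probabilities Q(y, x) dt (new array, Q untouched)
    p[x] += 1.0               # staying put: 1 + Q(x, x) dt
    return p


def sample_step(Q, x, dt, rng):
    """Draw x_{t+dt} given x_t = x with a first-order (Euler) step of the CTMC."""
    return int(rng.choice(Q.shape[0], p=step_probabilities(Q, x, dt)))


# Uniform noising on |S| = 4 states: Q = (1/|S|) * ones - I.
S = 4
Q_uniform = np.ones((S, S)) / S - np.eye(S)

rng = np.random.default_rng(0)
x = 2
for _ in range(100):      # 100 steps of size 0.01 simulate the chain up to time t = 1
    x = sample_step(Q_uniform, x, dt=0.01, rng=rng)
```

Each column of `I + Q dt` sums to one because each column of `Q` sums to zero, which is why the two rate-matrix conditions are exactly what makes the transition rule a probability distribution.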

Backward process

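Similarly to the case of continuous diffusion, in discrete diffusion there exists a backward diffusion process with rate matrix $\bar Q_t$:
$$s(x,t)_y := \frac{q_t(y)}{q_t(x)},\qquad \bar Q_t(y,x) := \begin{cases} s(x,t)_y\,Q_t(x,y) & \text{if } y\neq x\\ -\sum_{y':\ y'\neq x}\bar Q_t(y',x) & \text{if } y = x\end{cases}$$
where $s(x,t)_y$ should be interpreted as the discrete score or concrete score, since, abusing notation a bit, the score function is
$$\nabla\ln\rho_t(x) = \frac{1}{dx}\left(\frac{\rho_t(x+dx)}{\rho_t(x)} - 1\right).$$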

If we picture the distribution $q_t$ as a collection of point masses, one per state $x\in S$, then the forward diffusion from time $t$ to $t+dt$ removes $Q_t(x,y)\,q_t(y)\,dt$ from the mass at $y$ and moves it to the mass at $x$, for each pair $x\neq y$. The process is therefore reversed in detail by the CTMC defined by $\bar Q$, since $\bar Q_t(y,x)\,q_t(x) = Q_t(x,y)\,q_t(y)$: the backward flow from $x$ to $y$ equals the forward flow from $y$ to $x$.
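A minimal NumPy sketch of the backward rate matrix, assuming a small state space where $q_t$ is known exactly and strictly positive; `reverse_rate_matrix` is a hypothetical helper. It builds $\bar Q_t$ from $Q_t$ and $q_t$, so the detailed flow reversal can be checked entry by entry.

```python
import numpy as np


def reverse_rate_matrix(Q, q):
    """Backward rate matrix Qbar[y, x] = s(x, t)_y * Q[x, y] for y != x, where
    s(x, t)_y = q[y] / q[x]; the diagonal makes every column sum to zero.
    Assumes q[x] > 0 for every state."""
    q = np.asarray(q, dtype=float)
    score = q[None, :] / q[:, None]        # score[x, y] = q[y] / q[x]
    Qbar = score.T * Q.T                   # Qbar[y, x] = score[x, y] * Q[x, y]
    np.fill_diagonal(Qbar, 0.0)
    np.fill_diagonal(Qbar, -Qbar.sum(axis=0))
    return Qbar


# A random valid forward rate matrix and a random positive distribution q_t.
rng = np.random.default_rng(0)
S = 5
Q = rng.random((S, S))
np.fill_diagonal(Q, 0.0)
np.fill_diagonal(Q, -Q.sum(axis=0))
q = rng.dirichlet(np.ones(S))

Qbar = reverse_rate_matrix(Q, q)
```

In this sketch `Qbar @ q` also equals `-(Q @ q)`: the backward chain changes $q_t$ at the same rate as the forward chain, with the opposite sign, which is what running time backwards requires.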

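Given $\bar Q_t$, if we have a way to sample from $q_t$, then we can sample from $q_{t-dt}$ by first sampling $x_t\sim q_t$, then sampling $x_{t-dt}$ according to
$$\Pr(x_{t-dt}\mid x_t) = \begin{cases} 1 + \bar Q_t(x_{t-dt},x_t)\,dt & \text{if } x_{t-dt} = x_t\\ \bar Q_t(x_{t-dt},x_t)\,dt & \text{otherwise}\end{cases}$$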

Overall plan of score-matching discrete diffusion modeling

Similar to score-matching continuous diffusion, score-matching discrete diffusion is a method for sampling from the initial (data) distribution $q_0$.

If we have a function $s_\theta$ that approximates the true score function, $s_\theta(x,t)_y\approx s(x,t)_y$, then a corresponding $\bar Q^\theta$ can be defined in the same way.

If we also have a base distribution $q_{base}$ that is easy to sample from and approximately equal to the true terminal distribution, $q_{base}\approx q_T$, then we can run the backward CTMC with $\bar Q^\theta$, starting from $q^\theta_T := q_{base}$.

When both approximations are good, the backward CTMC gives $q^\theta_0\approx q_0$. This is the idea of score-matching discrete diffusion modeling.

If $q_{data}$ is sharp, in the sense that $q_{data}(x)\gg q_{data}(x')$ for some $x, x'$, then the score function would diverge as $1/t$ in the limit $t\to 0$. To avoid this in practice, it is common to use early stopping: the backward process is stopped at some time $\delta>0$, and samples are drawn from $q^\theta_\delta$ instead of $q^\theta_0$.
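A small worked example of this divergence, assuming the scaled-time uniform kernel $q_t(y\mid x_0) = 1_{y=x_0}(1-t) + t/|S|$ and a data distribution that is a point mass at one state $x^*$. For any other state $x'$, the score is $s(x',t)_{x^*} = (1-t+t/|S|)/(t/|S|) = 1 + |S|(1-t)/t$, so $t\cdot s$ tends to $|S|$ and the score grows like $|S|/t$; stopping at $\delta$ keeps it near $|S|/\delta$. The function name is hypothetical.

```python
def point_mass_score(t, S):
    """s(x', t)_{x*} = q_t(x*) / q_t(x') for point-mass data at x*, under the
    scaled-time uniform kernel q_t(y | x0) = 1[y = x0] (1 - t) + t / S, 0 < t < 1."""
    q_star = (1.0 - t) + t / S    # probability still on the data point x*
    q_other = t / S               # probability on any other state x'
    return q_star / q_other


S = 10
for t in [1e-1, 1e-2, 1e-3, 1e-4]:
    score = point_mass_score(t, S)
    print(f"t={t:.0e}  score={score:.1f}  t*score={t * score:.4f}")
```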

Tractable forward processes

The theory of CTMC works for any continuous choice of rate matrices Q. However, most choices are computationally expensive and cannot be used in practice.

In the case of continuous diffusion, Gaussian noise is used for the simple reason that a sum of independent Gaussians is still Gaussian. This allows one to sample any $x_t\sim\rho_t$ by sampling a single $x_0\sim\rho_0$ and a single Gaussian noise $z\sim\mathcal{N}(0,I)$, and letting $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sigma_t z$, without needing any $x_s$ for $0<s<t$.

Similarly, the choice of rate matrices should also allow us to "skip forward" without needing any intermediate steps.

The uniform noising process is defined by
$$\partial_t q_t(x) = -\left(1-\frac{1}{|S|}\right)q_t(x) + \sum_{y\in S,\ y\neq x}\frac{1}{|S|}\,q_t(y)$$
To see how to skip forward, note that the uniform noising process is equivalent to the following process: to evolve from time $t$ to time $t+dt$, either leave the state unchanged with probability $1-dt$, or sample a random state uniformly with probability $dt$. Since sampling a uniform random state twice is the same as sampling it once, the only question is whether a random state has ever been sampled. As time goes on, the probability of never having sampled a random state decays exponentially. Therefore,
$$q_t(y\mid x_0) = 1_{y=x_0}\,e^{-t} + \frac{1}{|S|}\left(1-e^{-t}\right)$$
Time may be rescaled as desired. For example, the CTMC defined for $t\in[0,1)$ by
$$q_t(y\mid x_0) = 1_{y=x_0}\,(1-t) + \frac{1}{|S|}\,t$$
is produced by the scaled-time uniform noising process
$$\partial_t q_t(x) = -\frac{1}{1-t}\left(1-\frac{1}{|S|}\right)q_t(x) + \frac{1}{1-t}\sum_{y\in S,\ y\neq x}\frac{1}{|S|}\,q_t(y)$$
The construction works in general for an arbitrary time rescaling and an arbitrary noise distribution on $S$. The idea is that if we have a fixed reference noise distribution $q_{noise}$, then sampling from it twice is the same as sampling from it once. Therefore, the noising process
$$\partial_t q_t(x) = -\sigma'(t)\,q_t(x) + \sigma'(t)\,q_{noise}(x)$$
produces
$$q_t(y\mid x_0) = e^{-\sigma(t)}\,1_{y=x_0} + \left(1-e^{-\sigma(t)}\right)q_{noise}(y).$$
Here, the time-rescaling function $\sigma(t)$ must be strictly increasing, with $\sigma(0)=0$.
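As a numerical sanity check of the closed form, the sketch below integrates the uniform noising equation with many small Euler steps from a point mass at $x_0$ and compares the result with $1_{y=x_0}e^{-t} + (1-e^{-t})/|S|$. It assumes NumPy; the function names are hypothetical, and the last lines check the scaled-time example through the choice $\sigma(t) = -\ln(1-t)$, for which $e^{-\sigma(t)} = 1-t$.

```python
import numpy as np


def uniform_kernel(sigma, x0, S):
    """Closed form q_t(. | x0) = e^{-sigma} 1[y = x0] + (1 - e^{-sigma}) / S,
    where sigma = sigma(t) (sigma = t for the unscaled uniform process)."""
    q = np.full(S, (1.0 - np.exp(-sigma)) / S)
    q[x0] += np.exp(-sigma)
    return q


def euler_uniform(x0, S, t_end, n_steps):
    """Integrate dq/dt = -(1 - 1/S) q(x) + (1/S) sum_{y != x} q(y) from a point mass at x0."""
    q = np.zeros(S)
    q[x0] = 1.0
    dt = t_end / n_steps
    for _ in range(n_steps):
        others = q.sum() - q              # sum over y != x, for every x at once
        q = q + dt * (-(1.0 - 1.0 / S) * q + others / S)
    return q


S, x0, t_end = 5, 3, 2.0
q_closed = uniform_kernel(t_end, x0, S)
q_euler = euler_uniform(x0, S, t_end, n_steps=20_000)
print(np.max(np.abs(q_closed - q_euler)))   # small, and shrinks as n_steps grows

# Scaled time: sigma(t) = -ln(1 - t) gives q_t(y | x0) = 1[y = x0](1 - t) + t / S.
t = 0.3
q_scaled = uniform_kernel(-np.log(1.0 - t), x0, S)
```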

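More generally, if $Q_t = \sigma'(t)\,Q$ with $\sigma(0)=0$, then $q_t(\cdot\mid x_0)$ is the $x_0$-th column of the matrix $e^{\sigma(t)Q}$. If, in addition, $q_{noise}$ is the unique stationary distribution of $Q$ and $\sigma(t)\to\infty$, then $q_t(\cdot\mid x_0)\to q_{noise}$.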

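When $|S|$ is sufficiently large, only two processes are efficient enough for training in practice: the uniform process and the absorbing process:
$$Q_{uniform} = \frac{1}{|S|}\begin{pmatrix} 1-|S| & 1 & \cdots & 1\\ 1 & 1-|S| & \cdots & 1\\ \vdots & \vdots & \ddots & \vdots\\ 1 & 1 & \cdots & 1-|S|\end{pmatrix} = \frac{1}{|S|}\mathbf{1}_{|S|}\mathbf{1}_{|S|}^T - I_{|S|\times|S|},\qquad Q_{absorb} = \begin{pmatrix} -1 & 0 & \cdots & 0 & 0\\ 0 & -1 & \cdots & 0 & 0\\ \vdots & & \ddots & & \vdots\\ 0 & 0 & \cdots & -1 & 0\\ 1 & 1 & \cdots & 1 & 0\end{pmatrix}$$
where the last index of $Q_{absorb}$ is the absorbing state. The uniform process is simply the uniform noising process with a rescalable time, converging to the uniform distribution on $S$.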

The absorbing process has a special absorbing state $s_{absorbing}$ and converges to the point distribution $\delta_{s_{absorbing}}$ on that state. During $[t,t+dt]$, if $x_t\neq s_{absorbing}$, the state transitions to $x_{t+dt} = s_{absorbing}$ with probability $dt$ and stays unchanged with probability $1-dt$; if $x_t = s_{absorbing}$, then $x_{t+dt} = s_{absorbing}$ always. In diffusion language modeling, this special state is usually called [MASK], a name that originated in masked language modeling.
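The two rate matrices, and the statement that $q_t(\cdot\mid x_0)$ is a column of $e^{\sigma(t)Q}$, can be checked on a small state space. The sketch below is a minimal NumPy version under stated assumptions: the last index plays the role of the absorbing state, and `expm` is a hypothetical scaling-and-squaring helper written here only to keep the example dependency-free (a library routine such as `scipy.linalg.expm` would normally be used). For $Q_{absorb}$ the closed-form column is: keep $x_0$ with probability $e^{-\sigma}$, otherwise move to the absorbing state.

```python
import numpy as np


def q_uniform(S):
    """Uniform rate matrix (1/|S|) 1 1^T - I."""
    return np.ones((S, S)) / S - np.eye(S)


def q_absorb(S):
    """Absorbing rate matrix; the last state S-1 plays the role of [MASK]."""
    Q = -np.eye(S)
    Q[S - 1, :] += 1.0        # every state jumps to [MASK] at rate 1; [MASK] never leaves
    return Q


def expm(A, terms=30):
    """Matrix exponential by scaling and squaring plus a truncated Taylor series.
    Adequate for the small, well-conditioned matrices used here only."""
    norm = np.linalg.norm(A, ord=1)
    k = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    B = A / 2.0**k
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for j in range(1, terms):
        term = term @ B / j
        result = result + term
    for _ in range(k):
        result = result @ result
    return result


def uniform_column(sigma, x0, S):
    """Closed form of column x0 of exp(sigma Q_uniform)."""
    col = np.full(S, (1.0 - np.exp(-sigma)) / S)
    col[x0] += np.exp(-sigma)
    return col


def absorb_column(sigma, x0, S):
    """Closed form of column x0 of exp(sigma Q_absorb)."""
    col = np.zeros(S)
    col[x0] += np.exp(-sigma)
    col[S - 1] += 1.0 - np.exp(-sigma)
    return col


S, sigma = 5, 1.3
E_uniform = expm(sigma * q_uniform(S))
E_absorb = expm(sigma * q_absorb(S))
```

For large $\sigma$ every column of $e^{\sigma Q_{absorb}}$ approaches the point mass on the absorbing state, and every column of $e^{\sigma Q_{uniform}}$ approaches the uniform distribution.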

Score matching

A discrete diffusion model is usually a score-matching network, that is, a neural network that takes $x$, $t$, $y$ as input and approximately computes the discrete score function:
$$s_\theta(x,t)_y\approx s(x,t)_y = \frac{q_t(y)}{q_t(x)}$$
where $\theta$ denotes the weights of the network. Once good weights are found, the score network can be used to define the backward diffusion process
$$\bar Q^\theta_t(y,x) := \begin{cases} s_\theta(x,t)_y\,Q_t(x,y) & \text{if } y\neq x\\ -\sum_{y':\ y'\neq x}\bar Q^\theta_t(y',x) & \text{if } y = x\end{cases}$$
and to produce samples approximately distributed as $q_0 := q_{data}$. There are different algorithms for training a score-matching network.

The concrete score matching algorithm minimizes the L2 loss by stochastic gradient descent:
$$L_{CSM}(\theta) := \mathbb{E}_t\left[\mathbb{E}_{x_0\sim q_{data},\ x_t\sim q_{t|0}(\cdot\mid x_0)}\left[\sum_{y:\ y\neq x_t}\left(s_\theta(x_t,t)_y - \frac{q_t(y)}{q_t(x_t)}\right)^2\right]\right]$$
where the outer expectation $\mathbb{E}_t$ means averaging over a randomly sampled time instance. For example, if the forward CTMC is defined for $t\in[0,1)$, a common choice is to sample $t\sim\mathrm{Uniform}([0,1])$ during training.

The true score is a ratio of probabilities and is therefore never negative, but the L2 loss does not prevent $s_\theta$ from becoming negative.

SEDD

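The Score Entropy Discrete Diffusion (SEDD) algorithm[3] minimizes a score entropy loss:
$$L_{SE}(\theta) := \mathbb{E}_t\left[\mathbb{E}_{x_0\sim q_{data},\ x_t\sim q_{t|0}(\cdot\mid x_0)}\left[\sum_{y:\ y\neq x_t} w_{t,x_t,y}\left(s_\theta(x_t,t)_y - \frac{q_t(y)}{q_t(x_t)}\ln s_\theta(x_t,t)_y + K\!\left(\frac{q_t(y)}{q_t(x_t)}\right)\right)\right]\right]$$
where $K(a) := a(\ln a - 1)$ is a normalizing term that does not depend on $\theta$, and $w_{t,x,y}$ is an arbitrary array of positive numbers that can be adjusted as hyperparameters of the training algorithm. In the next section, we will see that setting $w_{t,x_t,y} = Q_t(x_t,y)$ is a z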

Each summand inside the brackets is, up to the factor $q_t(y)/q_t(x_t)$, the Bregman divergence of $-\ln$, the negative logarithm function:
$$s_\theta(x_t,t)_y - \frac{q_t(y)}{q_t(x_t)}\ln s_\theta(x_t,t)_y + K\!\left(\frac{q_t(y)}{q_t(x_t)}\right) = \frac{q_t(y)}{q_t(x_t)}\,D_{-\ln}\!\left(s_\theta(x_t,t)_y,\ \frac{q_t(y)}{q_t(x_t)}\right)$$
Since $-\ln$ appears in the definition of entropy, this explains why $L_{SE}$ is called the "score entropy loss". Since the loss tends to infinity as $s_\theta$ approaches zero, the score entropy loss prevents $s_\theta$ from becoming negative.

Since the Bregman divergence of a strictly convex function is zero only when its two arguments are equal, the score entropy loss attains its minimum value of zero if and only if score matching is perfect: $s_\theta(x,t)_y = q_t(y)/q_t(x)$.
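The identity between the score entropy summand and the Bregman divergence of $-\ln$ can be verified pointwise. A minimal NumPy sketch with hypothetical helper names, for a single true ratio $r = q_t(y)/q_t(x_t)$ and a grid of candidate model scores $s$:

```python
import numpy as np


def K(a):
    """K(a) = a (ln a - 1)."""
    return a * (np.log(a) - 1.0)


def score_entropy_term(s, r):
    """Unweighted summand of L_SE: s - r ln s + K(r), with s = s_theta(x_t, t)_y
    and r = q_t(y) / q_t(x_t)."""
    return s - r * np.log(s) + K(r)


def bregman_neg_log(p, q):
    """Bregman divergence of F(u) = -ln u: F(p) - F(q) - F'(q)(p - q) = p/q - 1 - ln(p/q)."""
    return p / q - 1.0 - np.log(p / q)


r = 0.7                               # a true ratio q_t(y) / q_t(x_t)
s = np.linspace(0.05, 3.0, 60)        # candidate model scores
lhs = score_entropy_term(s, r)
rhs = r * bregman_neg_log(s, r)
```

Each summand is non-negative, vanishes only at $s = r$, and grows without bound as $s\to 0^+$.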

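There are two losses equivalent to SEDD. The implicit score entropy loss is
$$L_{ISE}(\theta) := \mathbb{E}_t\left[\mathbb{E}_{x_0\sim q_{data},\ x_t\sim q_{t|0}(\cdot\mid x_0)}\left[\sum_{y:\ y\neq x_t}\left(w_{t,x_t,y}\,s_\theta(x_t,t)_y - w_{t,y,x_t}\ln s_\theta(y,t)_{x_t}\right)\right]\right]$$
which equals $L_{SE}(\theta) - C$, where $C$ is the expectation of the weighted $K$ terms in $L_{SE}$ and does not depend on $\theta$. Therefore, optimizing $L_{SE}$ is equivalent to optimizing $L_{ISE}$. However, the ISE loss requires evaluating the score network at every neighbor $y$ of $x_t$, that is, up to $|S|-1$ times per sample of $t, x_0, x_t$. This does not scale.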

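The denoising score entropy loss is
$$L_{DSE}(\theta) := \mathbb{E}_t\left[\mathbb{E}_{x_0\sim q_{data},\ x_t\sim q_{t|0}(\cdot\mid x_0)}\left[\sum_{y:\ y\neq x_t} w_{t,x_t,y}\left(s_\theta(x_t,t)_y - \frac{q_{t|0}(y\mid x_0)}{q_{t|0}(x_t\mid x_0)}\ln s_\theta(x_t,t)_y\right)\right]\right]$$
which equals $L_{ISE}$. This can be derived using the identity
$$\mathbb{E}_{x_0\mid x_t}\left[\frac{q_{t|0}(y\mid x_0)}{q_{t|0}(x_t\mid x_0)}\right] = \frac{q_t(y)}{q_t(x_t)}.$$
Since this loss evaluates the score network only once per sample of $t, x_0, x_t$, and uses only the tractable conditional ratio, it does scale.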
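On a state space small enough to enumerate, all three losses can be computed exactly at a fixed time $t$ and compared. The sketch below assumes NumPy, the uniform forward kernel, a random toy data distribution, and arbitrary positive weights and model scores (hypothetical stand-ins, not a trained network). It checks that $L_{ISE} = L_{SE} - C$ and $L_{DSE} = L_{ISE}$, where $C$ collects the $\theta$-independent $K$ terms.

```python
import numpy as np

rng = np.random.default_rng(1)
S, sigma = 4, 0.8

# Uniform forward kernel at a fixed time t: kernel[y, x0] = q_{t|0}(y | x0).
kernel = np.exp(-sigma) * np.eye(S) + (1.0 - np.exp(-sigma)) / S
q0 = rng.dirichlet(np.ones(S))       # toy data distribution (hypothetical)
qt = kernel @ q0                      # marginal q_t
w = rng.random((S, S)) + 0.1          # positive weights w[x, y] = w_{t,x,y}
s = rng.random((S, S)) + 0.1          # positive model scores s[x, y] = s_theta(x, t)_y
off = ~np.eye(S, dtype=bool)          # selects the terms y != x


def K(a):
    return a * (np.log(a) - 1.0)


ratio = qt[None, :] / qt[:, None]     # ratio[x, y] = q_t(y) / q_t(x)

# Score entropy: needs the true ratio, which is unknown in practice.
L_SE = np.sum(qt[:, None] * off * w * (s - ratio * np.log(s) + K(ratio)))
C = np.sum(qt[:, None] * off * w * K(ratio))

# Implicit score entropy: needs s_theta(y, t)_x at every neighbour y of x.
L_ISE = np.sum(qt[:, None] * off * (w * s - w.T * np.log(s.T)))

# Denoising score entropy: uses only the ratio q_{t|0}(y|x0) / q_{t|0}(x|x0).
L_DSE = 0.0
for x0 in range(S):
    cond = kernel[:, x0]                              # q_{t|0}(. | x0)
    cond_ratio = cond[None, :] / cond[:, None]        # [x, y] = q_{t|0}(y|x0) / q_{t|0}(x|x0)
    L_DSE += q0[x0] * np.sum(cond[:, None] * off * w * (s - cond_ratio * np.log(s)))

print(L_SE - C, L_ISE, L_DSE)
```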

Variational inference

The score entropy objectives can be cast into variational inference form. Given a base distribution $q_{base}$ on $S$ and a backward CTMC defined by $\bar Q^\theta_t$ and $q_{base}$, let $q^\theta_0$ denote the resulting model distribution over $x_0$. For a fixed data point $x_0$, the diffusion weighted denoising score entropy (DWDSE) loss is defined as
$$L_{DWDSE}(x_0) := \int_0^T\mathbb{E}_{x_t\sim q_{t|0}(\cdot\mid x_0)}\left[\sum_{y:\ y\neq x_t} Q_t(x_t,y)\left(s_\theta(x_t,t)_y - \frac{q_{t|0}(y\mid x_0)}{q_{t|0}(x_t\mid x_0)}\ln s_\theta(x_t,t)_y + K\!\left(\frac{q_{t|0}(y\mid x_0)}{q_{t|0}(x_t\mid x_0)}\right)\right)\right]dt,$$
where $q_{t|0}(\cdot\mid x_0)$ is the forward CTMC kernel defined by $Q_t$, and $K$ is the function introduced above. In expectation over $x_0\sim q_{data}$, it is minimized when $s_\theta(x_t,t)_y = q_t(y)/q_t(x_t)$ for all $y\neq x_t$ such that $Q_t(x_t,y)>0$. If $Q_t$ is a sparse matrix, the expression for $L_{DWDSE}$ can be simplified accordingly, since most state transitions are impossible.

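For the diffusion and forward probabilities defined above,
$$-\ln q^\theta_0(x_0)\le L_{DWDSE}(x_0) + D_{KL}\!\left(q_{T|0}(\cdot\mid x_0)\,\|\,q_{base}\right),$$
where $D_{KL}$ is the Kullback–Leibler divergence. The KL term does not depend on $\theta$, so minimizing the expectation of $L_{DWDSE}$ over $x_0\sim q_{data}$ minimizes an upper bound on the expected negative log-likelihood $\mathbb{E}_{x_0\sim q_{data}}[-\ln q^\theta_0(x_0)]$; in particular, when $q_{base} = q_{T|0}(\cdot\mid x_0)$, the KL term vanishes.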

Adaption to sequence modeling

The most common application of discrete diffusion is sequence modeling. For these applications, the discrete state space $S$ usually has a particular structure that can and must be exploited. For example, in language modeling, only finitely many different tokens are allowed. The set of allowed tokens is called the vocabulary $\Sigma$, and its size is the vocabulary size $|\Sigma|$, which is always finite. For a given sequence length $n$, the state space is the space of all length-$n$ sequences of tokens, $S = \Sigma^n$, of size $|S| = |\Sigma|^n$.

Forward process

Since the size of the state space grows exponentially with the sequence length, it is too large to be modeled directly. For example, if a sequence has 10 tokens and each token can be chosen from a list of 100 valid tokens, then the full state space has size $100^{10} = 10^{20}$, which is intractable.

Because of this, the standard method is to consider only tokenwise forward processes, i.e. those that factor into independent forward processes over each token. Tokenwise forward processes do not need each token to undergo the same forward process, though in practice, often all tokens undergo the same forward process.

Let the sequence $x_t$ have $n$ tokens, and let $i\in 1{:}n$ index the tokens in the sequence, so that $x_t = (x_{t,1}, x_{t,2},\dots,x_{t,n})$. If the forward process for token $i$ is defined by the rate matrix $Q_{t,i}(y_i,x_i)$, then the rate matrix for the full sequence satisfies
$$Q_t(y,x) = \begin{cases} -\sum_{i\in 1:n}\ \sum_{y_i:\ y_i\neq x_i} Q_{t,i}(y_i,x_i) & \text{if } y = x\\ Q_{t,i}(y_i,x_i) & \text{if } y, x \text{ differ only at index } i\\ 0 & \text{if } y, x \text{ differ at more than one token}\end{cases}$$
Intuitively, the rate matrices simply add, since the probability that two jumps occur during the same infinitesimal slice of time $[t,t+dt]$ is of order $dt^2$, which is negligible compared with the probability that one jump occurs, which is of order $dt$.
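The case-by-case formula is the Kronecker sum of the tokenwise rate matrices. A small NumPy sketch under stated assumptions: two tokens, a vocabulary of three ids, one token using $Q_{uniform}$ and the other $Q_{absorb}$ (with the last id absorbing), and sequences flattened in row-major order. `full_rate_matrix` and `rate_by_cases` are hypothetical helpers that build the matrix both ways and compare them.

```python
import itertools
import numpy as np


def full_rate_matrix(token_rates):
    """Full-sequence rate matrix as the Kronecker sum of per-token rate matrices.
    Sequences are flattened in row-major order, matching np.kron and itertools.product."""
    Q = np.zeros((1, 1))
    for Q_i in token_rates:
        Q = np.kron(Q, np.eye(Q_i.shape[0])) + np.kron(np.eye(Q.shape[0]), Q_i)
    return Q


def rate_by_cases(token_rates, y, x):
    """Q_t(y, x) from the case-by-case definition in the text."""
    diff = [i for i in range(len(x)) if y[i] != x[i]]
    if not diff:
        # Column sums are zero, so Q_i(x_i, x_i) = -sum_{y_i != x_i} Q_i(y_i, x_i).
        return sum(Q_i[x[i], x[i]] for i, Q_i in enumerate(token_rates))
    if len(diff) == 1:
        i = diff[0]
        return token_rates[i][y[i], x[i]]
    return 0.0                                   # two or more tokens differ


V, n = 3, 2
Q_tok_uniform = np.ones((V, V)) / V - np.eye(V)
Q_tok_absorb = -np.eye(V)
Q_tok_absorb[V - 1, :] += 1.0                    # last id is the absorbing state
token_rates = [Q_tok_uniform, Q_tok_absorb]      # tokens may use different processes

Q_full = full_rate_matrix(token_rates)
states = list(itertools.product(range(V), repeat=n))
max_dev = max(abs(Q_full[a, b] - rate_by_cases(token_rates, y, x))
              for a, y in enumerate(states) for b, x in enumerate(states))
```

The Kronecker-sum form makes the "rates add" intuition concrete: each term acts on one token and leaves the others unchanged, so no entry couples two simultaneous token changes.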

In language modeling, the vocabulary size is usually on the order of 100,000, so an arbitrary $|\Sigma|\times|\Sigma|$ rate matrix is too large to fit into memory. The only case in common use is therefore the one where all tokens use exactly the same rate matrix $Q_{tok}$, which is one of the aforementioned $Q_{uniform}$ or $Q_{absorb}$. That is, there exists some function $\sigma$ with $\sigma'(t)>0$ such that $Q_{t,i} = \sigma'(t)\,Q_{tok}$ for all token indices $i\in 1{:}n$ and all times $t\in[0,T)$.

Given this setup, the forward process factors tokenwise:
$$q_{t|0}(x_t\mid x_0) = \exp(\sigma(t)\,Q)_{x_t,x_0} = \prod_{i\in 1:n}\exp(\sigma(t)\,Q_{tok})_{x_{t,i},\,x_{0,i}}$$
where $Q$ is the full-sequence rate matrix built from $Q_{tok}$. Note that the factorization is conditional on $x_0$. Without the conditioning, it fails, because the initial distribution does not in general factor tokenwise as $q_0(x) = \prod_{i\in 1:n} q_{0,i}(x_i)$. Thus, the following do not factorize: the backward process $q_{0|t}(x_0\mid x_t)$, the marginal distribution $q_t(x_t)$, and the score $s(x_t,t)_y$.
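Because the forward kernel factorizes given $x_0$, training data can be noised token by token without ever forming the $|\Sigma|^n\times|\Sigma|^n$ matrix. A minimal sketch, assuming NumPy, the closed-form columns of $\exp(\sigma(t)Q_{tok})$ for the two processes (keep the token with probability $e^{-\sigma(t)}$, otherwise replace it by [MASK] or by a uniformly random token), and a hypothetical convention that the last vocabulary id is [MASK]; `noise_tokens` is a hypothetical helper.

```python
import numpy as np


def noise_tokens(x0, sigma, vocab_size, process, rng, mask_id=None):
    """Sample x_t ~ q_{t|0}(. | x0) token by token, with sigma = sigma(t).
    Each token is independently 'hit' with probability 1 - exp(-sigma); a hit token
    becomes [MASK] (absorbing process) or a uniformly random token (uniform process)."""
    x0 = np.asarray(x0)
    hit = rng.random(x0.shape) < 1.0 - np.exp(-sigma)
    if process == "absorb":
        if mask_id is None:
            raise ValueError("the absorbing process needs a mask_id")
        replacement = np.full(x0.shape, mask_id)
    elif process == "uniform":
        replacement = rng.integers(0, vocab_size, size=x0.shape)
    else:
        raise ValueError(f"unknown process: {process!r}")
    return np.where(hit, replacement, x0)


rng = np.random.default_rng(0)
V, sigma = 100, 0.5
mask_id = V - 1                                   # hypothetical: last id reserved for [MASK]
x0 = rng.integers(0, V - 1, size=200_000)         # clean tokens never equal mask_id
xt_absorb = noise_tokens(x0, sigma, V, "absorb", rng, mask_id=mask_id)
xt_uniform = noise_tokens(x0, sigma, V, "uniform", rng)
```

Under the uniform process a hit token can be redrawn as itself, so the fraction of changed tokens is $(1-e^{-\sigma})(1-1/|\Sigma|)$ rather than $1-e^{-\sigma}$.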

Score function

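In general, the discrete score $s(x_t,t)_y$ is not tokenwise, that is, there does not in general exist a function $f$ such that
$$\frac{q_t(y)}{q_t(x_t)} = f\!\left(\frac{q_{t,1}(y_1)}{q_{t,1}(x_{t,1})},\dots,\frac{q_{t,n}(y_n)}{q_{t,n}(x_{t,n})}\right)$$
where $q_{t,i}$ denotes the marginal distribution of the $i$-th token. Nevertheless, variational inference allows a simplification. Specifically, since $Q_t(x_t,y) = 0$ when $y$ and $x_t$ differ at more than one token, the summation over $y\neq x_t$ in the definition of $L_{DWDSE}$ need only include the sequences $y$ that differ from $x_t$ at exactly one token. That is, training need only minimize the following loss:
$$L_{DWDSE}(\theta) := \int_0^T\mathbb{E}_{x_0\sim q_{data},\ x_t\sim q_{t|0}(\cdot\mid x_0)}\left[\sum_{i\in 1:n}\ \sum_{\hat x_i\in\Sigma,\ \hat x_i\neq x_{t,i}} Q_{t,i}(x_{t,i},\hat x_i)\left(s_\theta(x_t,t,i)_{\hat x_i} - \frac{q_{t|0}(x_t^{i\leftarrow\hat x_i}\mid x_0)}{q_{t|0}(x_t\mid x_0)}\ln s_\theta(x_t,t,i)_{\hat x_i}\right)\right]dt$$
where $x_t^{i\leftarrow\hat x_i}$ denotes the sequence obtained by replacing the $i$-th entry of $x_t$ with $\hat x_i$ (the $\theta$-independent $K$ term has been dropped). Consequently, the score-matching model $s_\theta$ needs to output only $n(|\Sigma|-1)$ scores, one for each single-token modification, instead of $|\Sigma|^n-1$, one for each full-sequence modification. If $Q_{t,i}$ is sparse, the expression can be simplified further; for example, when $Q_{t,i} = \sigma'(t)\,Q_{absorb}$, most single-token transitions are impossible.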
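For the absorbing process the bracketed term of this loss simplifies considerably: at an unmasked position no single-token backward jump is possible, and at a masked position the conditional ratio is nonzero only for $\hat x_i = x_{0,i}$, where it equals $e^{-\sigma(t)}/(1-e^{-\sigma(t)})$. The sketch below evaluates the bracketed term at one time $t$ in two ways, a direct sum over positions and replacements and the simplified masked-only form, and checks that they agree. Assumptions: NumPy; the last id is [MASK]; each term is weighted by the forward rate $Q_{t,i}(x_{t,i},\hat x_i)$ of the jump $\hat x_i\to x_{t,i}$, which is the factor entering the backward rate; `dsigma` stands for $\sigma'(t)$; random positive numbers stand in for the network outputs $s_\theta(x_t,t,i)_{\hat x_i}$; all function names are hypothetical.

```python
import numpy as np


def absorb_kernel(sigma, V):
    """Per-token kernel exp(sigma Q_absorb); id V-1 is [MASK].
    kernel[y, x0] = e^{-sigma} 1[y = x0] + (1 - e^{-sigma}) 1[y = MASK]."""
    P = np.exp(-sigma) * np.eye(V)
    P[V - 1, :] += 1.0 - np.exp(-sigma)
    return P


def absorb_rates(V):
    """Q_absorb: every non-mask id jumps to [MASK] (id V-1) at rate 1."""
    Q = -np.eye(V)
    Q[V - 1, :] += 1.0
    return Q


def integrand_generic(x0, xt, s, sigma, dsigma, V):
    """Sum over positions i and replacements xh != xt[i] of
    dsigma * Q_tok(xt[i], xh) * (s[i, xh] - R * ln s[i, xh]), where
    R = q_{t|0}(xt with position i set to xh | x0) / q_{t|0}(xt | x0)."""
    P, Q = absorb_kernel(sigma, V), absorb_rates(V)
    total = 0.0
    for i in range(len(xt)):
        for xh in range(V):
            rate = dsigma * Q[xt[i], xh]
            if xh == xt[i] or rate == 0.0:
                continue
            R = P[xh, x0[i]] / P[xt[i], x0[i]]   # all other token factors cancel
            total += rate * (s[i, xh] - R * np.log(s[i, xh]))
    return total


def integrand_absorb(x0, xt, s, sigma, dsigma, V):
    """Same quantity after simplifying: only masked positions contribute, and
    R is nonzero only for xh = x0[i], where R = e^{-sigma} / (1 - e^{-sigma})."""
    mask = V - 1
    masked = np.flatnonzero(xt == mask)
    ratio = np.exp(-sigma) / (1.0 - np.exp(-sigma))
    return dsigma * (s[masked, :mask].sum() - ratio * np.log(s[masked, x0[masked]]).sum())


rng = np.random.default_rng(0)
V, n, sigma, dsigma = 6, 8, 0.7, 1.3
x0 = rng.integers(0, V - 1, size=n)                          # clean sequence, no [MASK]
xt = np.where(rng.random(n) < 1 - np.exp(-sigma), V - 1, x0)
s = rng.random((n, V)) + 0.1                                 # stand-in for model scores
```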

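The theoretical minimum of $L_{DWDSE}$ is achieved when
$$s_\theta(x_t,t,i)_{\hat x_i} = s(x_t,t,i)_{\hat x_i} = \frac{q_t(x_t^{i\leftarrow\hat x_i})}{q_t(x_t)}$$
for all $t\in[0,T)$, $i\in 1{:}n$, $x_t$, and $\hat x_i\neq x_{t,i}$ such that $q_t(x_t)>0$ and $Q_{t,i}(x_{t,i},\hat x_i)>0$.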

Backward process

Ideally, if the model learns the score exactly, then it defines a backward diffusion process that exactly reverses the forward diffusion process. Its rate matrix is
$$\bar Q_t(y,x) = \begin{cases} -\sum_{y':\ y'\neq x}\bar Q_t(y',x) & \text{if } y = x\\ s(x,t,i)_{y_i}\,Q_{t,i}(x_i,y_i) & \text{if } y, x \text{ differ only at index } i\\ 0 & \text{if } y, x \text{ differ at more than one token}\end{cases}$$
Intuitively, since the forward process changes only zero or one token at any moment in time, so does the backward process.

However, if the score is not matched exactly, $s_\theta(x_t,t,i)_{\hat x_i}\neq q_t(x_t^{i\leftarrow\hat x_i})/q_t(x_t)$, the result is a score-matching error.

Furthermore, in practice the backward process cannot be performed in continuous time, but only in discrete time. This produces a time-discretization error. The Gillespie algorithm cannot in general perform the backward process exactly, since for fixed $i, x, \hat x_i$, the score $s(x,t,i)_{\hat x_i}$ changes as $t$ changes. That is, the backward CTMC cannot be solved exactly as a backward discrete-time Markov chain. This contrasts with the forward process, where for $Q_{uniform}$ and $Q_{absorb}$ the forward CTMC is exactly solvable as a discrete-time Markov chain. This is similar to how, in continuous diffusion, the forward diffusion is exactly computable at discrete time instances, but the backward process requires integration over continuous time.

Using the Gillespie algorithm or other discrete-time algorithms produces an approximation error analogous to that of the Euler method. It can be reduced by using better stochastic integrators and more integration steps in the backward process.

Similar to how the Gillespie algorithm can be accelerated by tau-leaping, the backward process can be accelerated by changing more than 1 token per discretized time-step.

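Given the noising process $Q_t = \sigma'(t)\,Q$, we have $q_{t|0}(x_t\mid x_0) = \exp(\sigma(t)Q)_{x_t,x_0}$. By Bayes's theorem, this is inverted by the discrete Tweedie's formula (using $q_0 = \exp(-\sigma(t)Q)\,q_t$):
$$q_{0|t}(x_0\mid x_t) = \exp(\sigma(t)Q)_{x_t,x_0}\sum_{y\in S}\exp(-\sigma(t)Q)_{x_0,y}\,\frac{q_t(y)}{q_t(x_t)} = \exp(\sigma(t)Q)_{x_t,x_0}\sum_{y\in S}\exp(-\sigma(t)Q)_{x_0,y}\,s(x_t,t)_y$$
This gives the Tweedie tau-leaping algorithm, which, for each backward time interval $[t,t-\Delta t]$, transitions each token $i\in 1{:}n$ randomly and independently. That is, for each token $i$ individually, $x_{t-\Delta t,i}$ is sampled independently of the other tokens, with the probability of transitioning from $x_{t,i}$ to $x_{t-\Delta t,i}$ equal to
$$\exp(\sigma^{\Delta t}_t Q_{tok})_{x_{t,i},\,x_{t-\Delta t,i}}\sum_{y\in\Sigma}\exp(-\sigma^{\Delta t}_t Q_{tok})_{x_{t-\Delta t,i},\,y}\;s_\theta(x_t,t,i)_y$$
where $\sigma^{\Delta t}_t = \sigma(t)-\sigma(t-\Delta t)$ and $s_\theta(x_t,t,i)_{x_{t,i}} := 1$.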
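A minimal NumPy sketch of one Tweedie tau-leaping step, assuming the absorbing per-token rate matrix on a small vocabulary with the last id as [MASK] (its exponential is $e^{-d}I + (1-e^{-d})N$, where $N$ has ones in the [MASK] row, valid for any real $d$), a score array with `scores[i, x[i]] = 1`, and clipping plus renormalization of any negative probabilities produced by an inexact score; that last step is a practical assumption, not part of the formula. All function names are hypothetical. With the exact score of a single token, the Tweedie probabilities coincide with the Bayes posterior $q_{t-\Delta t|t}$, which the example computes for comparison.

```python
import numpy as np


def absorb_exp(d, V):
    """exp(d Q_absorb) = e^{-d} I + (1 - e^{-d}) N with N[V-1, :] = 1, valid for any real d."""
    E = np.exp(-d) * np.eye(V)
    E[V - 1, :] += 1.0 - np.exp(-d)
    return E


def tweedie_probs(x_tok, score_row, E_fwd, E_bwd):
    """p(x') = exp(d Q)[x_tok, x'] * sum_y exp(-d Q)[x', y] * score_row[y],
    with d = sigma(t) - sigma(t - dt) and score_row[x_tok] = 1."""
    return E_fwd[x_tok, :] * (E_bwd @ score_row)


def tau_leap_step(x, scores, E_fwd, E_bwd, rng):
    """One Tweedie tau-leaping step: every token is resampled independently.
    scores[i, y] approximates s_theta(x_t, t, i)_y. Negative values are clipped
    to zero and each row renormalized (a practical assumption)."""
    new_x = x.copy()
    for i, x_tok in enumerate(x):
        p = np.clip(tweedie_probs(x_tok, scores[i], E_fwd, E_bwd), 0.0, None)
        new_x[i] = rng.choice(len(p), p=p / p.sum())
    return new_x


# Exactness check for a single masked token with the true score.
V = 5
rng = np.random.default_rng(0)
q0 = np.append(rng.dirichlet(np.ones(V - 1)), 0.0)   # clean data never equals [MASK]
sigma_t, sigma_prev = 1.0, 0.6
d = sigma_t - sigma_prev
q_t = absorb_exp(sigma_t, V) @ q0
q_prev = absorb_exp(sigma_prev, V) @ q0
x_tok = V - 1
true_score = q_t / q_t[x_tok]
p_tweedie = tweedie_probs(x_tok, true_score, absorb_exp(d, V), absorb_exp(-d, V))
p_bayes = absorb_exp(d, V)[x_tok, :] * q_prev / q_t[x_tok]

# One step on a short sequence, reusing the single-token score at every position
# (a stand-in for s_theta(x_t, t, i); a real network would see the whole sequence).
x = np.array([V - 1, 2, V - 1, 0])
scores = np.stack([q_t / q_t[tok] for tok in x])
x_new = tau_leap_step(x, scores, absorb_exp(d, V), absorb_exp(-d, V), rng)
```

Under the absorbing process an unmasked token has no backward move available, so in this sketch it is always kept, while a masked token is either kept masked or revealed.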

The lower accuracy of tau-leaping is due to the approximation by tokenwise independence. In general, given $i\neq j$ and $t, x, \hat x_i, \hat x_j$, the score satisfies $s(x,t,j)_{\hat x_j}\neq s(x^{i\leftarrow\hat x_i},t,j)_{\hat x_j}$. Concretely, suppose that in an exact backward simulation exactly two tokens $x_{t,i}, x_{t,j}$ change to $\hat x_i, \hat x_j$ during a backward time interval $[t,t-\Delta t]$. Then it matters whether token $i$ or token $j$ changed first, since that affects the rate of change at the other token.

Conditional backward process

The above framework considers unconditional generation, that is, sampling a full sequence $x_0$ approximately from the data distribution $q_{data}$. Certain tasks require generating part of a sequence while holding other parts of the sequence fixed. For example, in prompt engineering or few-shot learning, the first few sentences are fixed, and only the rest of the sequence can be changed.

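In general, let $I\sqcup J$ be a partition of $1{:}n$; the problem is then to generate $x_{0,I}$ conditional on $x_{0,J}$, that is, to sample $x_{0,I}\sim q_0(\cdot\mid x_{0,J})$. By Bayes's theorem,
$$\frac{q_t(y_{t,I}\mid y_{t,J})}{q_t(x_{t,I}\mid x_{t,J})} = \frac{q_t(y_t)}{q_t(x_t)}$$
when $x_{t,J} = y_{t,J}$. Thus, a score-matching model for unconditional sequence generation is also a score-matching model for the conditional case:
$$s_\theta(x_t,t,i)_{\hat x_i}\approx s(x_t,t,i)_{\hat x_i} = \frac{q_t(x_t^{i\leftarrow\hat x_i})}{q_t(x_t)} = \frac{q_t(x_{t,I}^{i\leftarrow\hat x_i}\mid x_J)}{q_t(x_{t,I}\mid x_J)}$$
for any $t$, $i\in I$, $\hat x_i$, and $x_t$ such that $x_{t,J} = x_J$. Thus, all the previous backward sampling algorithms still work in the conditional setting, simply by fixing $x_{t,J} = x_J$.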
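The prescription of fixing $x_{t,J} = x_J$ can be wrapped around any backward sampler. A minimal sketch, assuming NumPy and a sampler exposed as a function from the current sequence to the next one; `conditional_sample` and `toy_step` are hypothetical, and `toy_step` is only a stand-in for a trained model (it reveals [MASK] tokens at random and knows nothing about the prompt).

```python
import numpy as np


def conditional_sample(step_fn, x_init, fixed_positions, fixed_values, n_steps):
    """Run an unconditional backward sampler while holding x_J fixed.
    step_fn(x) -> x is any backward update (for example a tau-leaping step);
    after every step the positions in J are reset to the given values."""
    x = np.array(x_init, copy=True)
    x[fixed_positions] = fixed_values
    for _ in range(n_steps):
        x = step_fn(x)
        x[fixed_positions] = fixed_values
    return x


MASK = 9                                  # hypothetical: token ids 0..8, id 9 is [MASK]
rng = np.random.default_rng(0)


def toy_step(x):
    """Hypothetical stand-in for a trained sampler: each [MASK] is revealed with
    probability 0.5 as a random token in 0..8."""
    reveal = (x == MASK) & (rng.random(x.shape) < 0.5)
    return np.where(reveal, rng.integers(0, MASK, size=x.shape), x)


prompt_positions, prompt_tokens = [0, 1, 2], [4, 1, 7]      # J and x_J
out = conditional_sample(toy_step, np.full(8, MASK), prompt_positions, prompt_tokens, n_steps=50)
```

Resetting after every step, rather than only at the start, matters for samplers that may propose changes at any position; with tau-leaping, discarding the proposals at positions in $J$ leaves the independent updates at positions in $I$ unaffected.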

Error analysis

In general, the backward process of a discrete diffusion model samples from a probability distribution that differs from $q_{data}$. The discrepancy has the following sources:

  • The mixing error, due to beginning the backward diffusion process at $q_{base}$ rather than $q_T$.
  • The score-matching error, due to using $s_\theta$ rather than $s$ to define the rate matrix of the backward diffusion process.
  • The time-discretization error, due to using discrete rather than continuous time when integrating the backward diffusion process. This is analogous to the error of the Euler method.
  • The independence error, due to using tau-leaping, which erroneously allows more than one token to change at a time in a probabilistically independent way.
  • The early stopping error, due to ending the backward diffusion process at $q^\theta_\delta$ rather than at $q^\theta_0$.

Each error can be decreased or eliminated, usually at the price of increased cost of compute.

  • Mixing error can be decreased by running the diffusion for longer, or eliminated by the analog of the "zero SNR" fix for diffusion in continuous space.[4]
  • Score-matching error can be decreased by better training of the score-matching model.
  • Time-discretization error can be decreased by using smaller time steps, or essentially eliminated by numerical integration combined with the uniformization trick for CTMCs.[5]
  • Independence error can be eliminated by changing only one token at a time during backward diffusion.
  • Early stopping error can be eliminated if the score does not diverge as $t\to 0$. If the score does diverge, the stopping time $\delta$ can be decreased, but the time steps would need to decrease in tandem.

References

  1. Gulrajani, Ishaan; Hashimoto, Tatsunori B. (2023-12-15). "Likelihood-Based Diffusion Language Models". Advances in Neural Information Processing Systems 36: 16693–16715. https://proceedings.neurips.cc/paper_files/paper/2023/hash/35b5c175e139bff5f22a5361270fce87-Abstract-Conference.html.
  2. Campbell, Andrew; Benton, Joe; De Bortoli, Valentin; Rainforth, Thomas; Deligiannidis, George; Doucet, Arnaud (2022-12-06). "A Continuous Time Framework for Discrete Denoising Models". Advances in Neural Information Processing Systems 35: 28266–28279. https://proceedings.neurips.cc/paper_files/paper/2022/hash/b5b528767aa35f5b1a60fe0aaeca0563-Abstract-Conference.html.
  3. Lou, Aaron; Meng, Chenlin; Ermon, Stefano (2024-06-06). "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution". arXiv:2310.16834 [stat.ML].
  4. Lin, Shanchuan; Liu, Bingchen; Li, Jiashi; Yang, Xiao (2024). "Common Diffusion Noise Schedules and Sample Steps Are Flawed". IEEE/CVF Winter Conference on Applications of Computer Vision (WACV), pp. 5404–5411. https://openaccess.thecvf.com/content/WACV2024/html/Lin_Common_Diffusion_Noise_Schedules_and_Sample_Steps_Are_Flawed_WACV_2024_paper.html.
  5. Chen, Hongrui; Ying, Lexing (2024-02-14). "Convergence Analysis of Discrete Diffusion Model: Exact Implementation through Uniformization". arXiv:2402.08095.