MDLM: Simple and Effective Masked Diffusion Language Models

Sahoo et al. · NeurIPS 2024 · arXiv 2406.07524

TL;DR

MDLM derives a clean, principled training objective for masked diffusion language models from first principles, starting from a continuous-time ELBO. It shows that a simple absorbing-state diffusion (tokens → [MASK]) with the right loss weighting achieves strong perplexity results, providing the theoretical foundation that LLaDA later scales up.

Architecture Overview

  • Problem: no principled training objective for discrete diffusion. Prior work (D3PM) had complex objectives with many terms.
  • Insight: the continuous-time limit simplifies everything. Taking the ELBO in continuous time yields a clean closed-form loss.
  • Method: absorbing-state masked diffusion trained with the ELBO loss. The forward process sends tokens to [MASK] (absorbing state = masking); the loss is a weighted cross-entropy on masked positions, with an ELBO-derived weight (1/t under the linear schedule).
  • Result: strong perplexity from a simple implementation (~100 lines of core code), providing the theoretical foundation for LLaDA.

1. Background: Why We Need Better Theory

Discrete diffusion models (like D3PM) existed before MDLM, but they had issues:

  • Complex ELBO: D3PM's loss has many terms that are hard to balance
  • Auxiliary losses: needed extra loss terms to work well in practice
  • Gap from continuous: continuous diffusion (images) had cleaner theory; can we match it for discrete?

MDLM answers: yes. By taking the continuous-time limit of the discrete ELBO, we get an elegant, simple loss.

2. The Forward Process: Absorbing Diffusion

MDLM uses an absorbing-state forward process: each token independently transitions to [MASK] (the absorbing state) at rate β(t):

Forward transition probability

q(x_t = \texttt{[M]} \mid x_0) = 1 - e^{-\int_0^t \beta(s)\,ds} \triangleq 1 - \alpha_t

  • \beta(t): noise rate at time t; controls how fast tokens get masked
  • \alpha_t: survival probability, i.e. the probability that a token is still unmasked at time t
  • e^{-\int_0^t \beta(s)\,ds}: exponential decay (the same math as radioactive decay!)

This is more general than LLaDA's "mask with probability t". By choosing different β(t), we can have non-linear masking schedules. When β(t) = 1/(1-t), we recover LLaDA's linear schedule, where α_t = 1-t.
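As a concrete illustration, here is a minimal sketch (not the authors' code) of the absorbing forward process under the linear schedule; the token ids and `mask_id` are toy placeholders:

```python
import random

def alpha_linear(t):
    # Linear schedule: beta(t) = 1/(1-t)  =>  alpha_t = exp(-integral of beta) = 1 - t
    return 1.0 - t

def forward_mask(x0, t, mask_id, alpha=alpha_linear, rng=random):
    """Absorbing forward process: each token independently stays itself
    with probability alpha_t and transitions to [MASK] otherwise."""
    return [tok if rng.random() < alpha(t) else mask_id for tok in x0]

x0 = [17, 4, 92, 8, 55, 3]                 # toy token ids
xt = forward_mask(x0, t=0.7, mask_id=103)  # ~70% of tokens masked on average
```

Swapping in a different `alpha` gives any other masking schedule; only the survival probability changes.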

3. The Continuous-Time ELBO

The key contribution: MDLM derives a clean ELBO in continuous time. Starting from the standard variational bound:

Continuous-time ELBO

\log p(x_0) \geq -\,\mathbb{E}_{q}\left[\int_0^1 \underbrace{\frac{\beta(t)\,\alpha_t}{1-\alpha_t} \sum_{i:\, x_t^i = \texttt{[M]}} \bigl(-\log p_\theta(x_0^i \mid x_t)\bigr)}_{L_t}\, dt\right]

  • \log p(x_0): log-likelihood of the data; what we want to maximize
  • \geq: Evidence Lower BOund; maximizing the right side pushes up \log p(x_0)
  • \int_0^1 \cdots\, dt: integrate over all timesteps, the continuous analogue of summing over T discrete steps
  • \frac{\beta(t)\,\alpha_t}{1-\alpha_t}: weighting factor, derived from the math rather than hand-tuned; this is what makes the objective "principled"
  • \sum_{i:\, x_t^i = \texttt{[M]}}: sum over masked positions only
  • -\log p_\theta(x_0^i \mid x_t): cross-entropy; how well the model predicts each masked token
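One way to see that the weighting factor is fully determined by the schedule: differentiating \alpha_t = e^{-\int_0^t \beta(s)\,ds} gives \alpha'_t = -\beta(t)\,\alpha_t, so the weight can be written either way:

```latex
\alpha'_t = -\beta(t)\,\alpha_t
\quad\Longrightarrow\quad
\frac{-\alpha'_t}{1-\alpha_t} = \frac{\beta(t)\,\alpha_t}{1-\alpha_t},
\qquad\text{e.g. } \beta(t) = \frac{1}{1-t}
\;\Rightarrow\; \alpha_t = 1-t
\;\Rightarrow\; \text{weight} = \frac{1}{t}.
```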

Practical Training Loss

In practice, we can't compute the integral exactly, so we sample t uniformly and use a single-sample Monte Carlo estimate:

MDLM training loss (Monte Carlo estimate)

\mathcal{L}_{\text{MDLM}} = -\,\mathbb{E}_{t \sim \mathcal{U}(0,1)}\left[\frac{\beta(t)\,\alpha_t}{1-\alpha_t} \sum_{i:\, x_t^i = \texttt{[M]}} \log p_\theta(x_0^i \mid x_t)\right]

Connection to LLaDA: with the linear schedule β(t) = 1/(1-t), we get α_t = 1-t, so the weight is β(t)α_t/(1-α_t) = (1/(1-t)) · (1-t)/t = 1/t. Normalized by sequence length L, this is exactly LLaDA's 1/t-weighted loss: MDLM provides the theoretical justification for LLaDA's training objective.
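A hedged sketch of one training step under the linear schedule (an illustration, not the released MDLM code; `model` stands for any network mapping token ids to per-position logits, and `mask_id`/`eps` are placeholder choices):

```python
import torch
import torch.nn.functional as F

def mdlm_loss(model, x0, mask_id, eps=1e-3):
    """Single Monte Carlo sample of the MDLM objective with
    beta(t) = 1/(1-t) and alpha_t = 1-t, so the ELBO weight is 1/t."""
    B, L = x0.shape
    t = eps + (1 - eps) * torch.rand(B, 1)   # t ~ U(eps, 1); eps avoids divide-by-zero at t=0
    masked = torch.rand(B, L) < t            # mask each token w.p. 1 - alpha_t = t
    xt = torch.where(masked, torch.full_like(x0, mask_id), x0)
    logits = model(xt)                       # (B, L, V)
    ce = F.cross_entropy(logits.transpose(1, 2), x0, reduction="none")  # (B, L)
    # Weighted cross-entropy on masked positions only, with weight 1/t
    return ((ce * masked) / t).sum() / masked.sum().clamp(min=1)
```

For a general schedule, the 1/t factor becomes β(t)α_t/(1-α_t) and the masking probability becomes 1-α_t; everything else is unchanged.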

[Interactive figure: ELBO decomposition across timesteps. Eight bars show the per-timestep loss contributions L_t, with ELBO = -Σ L_t over the steps; hovering reveals each bar's value (total ELBO loss shown: 1.790).]

4. Noise Schedule Design

MDLM explores different noise schedules β(t) and finds that the choice matters significantly:

Log-linear schedule (best performing)

\alpha_t = \frac{1 - t}{1 + (e^{10} - 1) \cdot t} \quad \Rightarrow \quad \text{slow start, fast end}

5. Experiments & Results

Model              Type             BPC (text8)   PPL (OpenWebText)
D3PM (absorbing)   Discrete Diff.   1.45          —
SEDD               Score Entropy    1.39          32.1
MDLM               Masked Diff.     1.36          31.2
GPT-2 (small)      AR               —             29.1

(text8 results are reported in bits per character, the standard metric for that character-level benchmark.)

Key result: MDLM outperforms all prior discrete diffusion models and nearly matches GPT-2 on OpenWebText. The gap to AR models is small enough to suggest that with more scale (which LLaDA later demonstrates), diffusion LMs can be fully competitive.

6. Limitations & Future Work

  • Scale: Only tested up to ~110M parameters. LLaDA later proves it works at 8B.
  • Generation quality: Perplexity is good but unconditional text samples can be incoherent; needs SFT/RLHF.
  • Schedule sensitivity: Performance depends on noise schedule choice; not fully understood why log-linear works best.

7. Connections to Other Work

LLaDA

Scales up MDLM's framework to 8B parameters. Uses MDLM's training objective (with linear schedule). Proves that masked diffusion works at LLM scale.

Fast DLLM

Optimizes MDLM's inference speed using the ELBO decomposition that MDLM derives. The per-step L_t values come directly from MDLM's theory.

Block Diffusion

Uses MDLM's within-block diffusion framework combined with AR between blocks.

D3PM (Austin et al. 2021)

The predecessor. D3PM introduced discrete diffusion with transition matrices. MDLM simplifies D3PM's approach by taking the continuous-time limit and focusing on the absorbing state.

8. Additional Resources