LLaDA: Large Language Diffusion with mAsking

Nie et al. · 2025 · arXiv 2502.09992

TL;DR

LLaDA treats language modeling as a masked diffusion process: the forward process randomly masks tokens, and the reverse process learns to predict all masked tokens simultaneously. Unlike AR models that generate left-to-right, LLaDA can fill in tokens in any order. Scaled to 8B parameters, it matches LLaMA3 on many benchmarks.

Architecture Overview

  • Problem: AR LMs have inherent limitations (unidirectional, sequential, error accumulation)
  • Core idea: masked diffusion for language; masking is the natural "noise" for discrete tokens
      • Forward process: progressively mask tokens with probability t
      • Reverse process: predict all [MASK] positions in parallel
  • Design choice: standard Transformer architecture (like LLaMA) with bidirectional attention (no causal mask) + RoPE + RMSNorm + SwiGLU
  • Training: variable-rate MLM; cross-entropy on masked positions, t ~ Uniform(0, 1)
  • Generation: iterative unmasking over T steps, unmasking highest-confidence tokens first
  • Result: 8B params, competitive with LLaMA3, with bidirectional context and parallel-decoding potential

1. Background & Motivation

Autoregressive (AR) LLMs like GPT generate tokens one-by-one, left to right. This is the dominant paradigm, but it has inherent limitations:

  • Unidirectional: can only condition on left context; it can't "look ahead"
  • Sequential decoding: generating N tokens requires N forward passes; decoding can't be parallelized
  • Error accumulation: mistakes in early tokens propagate to all later tokens
  • Fixed generation order: always left→right, even when later tokens are more certain

Continuous diffusion models (like DALL-E for images) have shown that iterative denoising can be a powerful generative paradigm. But text is discrete: you can't add Gaussian noise to words. LLaDA's key insight: masking IS the natural noise for discrete tokens.

2. Forward Process: Progressive Masking

Given a clean sequence, the forward process independently masks each token with probability t:

Forward transition kernel

q(x_t^i \mid x_0^i) = \begin{cases} x_0^i & \text{with probability } 1 - t \\ \texttt{[MASK]} & \text{with probability } t \end{cases}

  • x_0^i: the original (clean) token at position i
  • x_t^i: the token at position i at noise level t (either the original or [MASK])
  • t ∈ [0, 1]: noise level / timestep; t = 0 means clean, t = 1 means fully masked
  • q(· | ·): the forward transition distribution, i.e. how we add noise

Key properties of this forward process:

  • At t=0, the sequence is fully clean
  • At t=1, every token is [MASK]
  • Masking is independent across positions β€” each token decides on its own
  • Expected number of masked tokens at time t: t × L
Joint forward distribution (all positions independent)

q(x_t \mid x_0) = \prod_{i=1}^{L} q(x_t^i \mid x_0^i)
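
The forward process is simple enough to sketch in a few lines. This is an illustration, not code from the paper; `forward_mask` and the `[M]` placeholder are hypothetical names:

```python
import random

MASK = "[M]"

def forward_mask(tokens, t, seed=None):
    """Forward process: mask each token independently with probability t."""
    rng = random.Random(seed)
    return [MASK if rng.random() < t else tok for tok in tokens]

tokens = "The cat sat on the mat".split()
print(forward_mask(tokens, 0.0))        # t = 0: the sequence stays clean
print(forward_mask(tokens, 0.5, seed=0))  # stochastic: about t*L tokens masked
print(forward_mask(tokens, 1.0))        # t = 1: every token is [M]
```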

Take the sentence "The cat sat on the mat" (L=6 tokens). Each token independently decides whether to mask:

t = 0.0: mask prob = 0% → The cat sat on the mat
Expected masked: 0×6 = 0 tokens
t = 0.25: mask prob = 25% → The cat sat [M] the mat
Expected: 0.25×6 = 1.5 tokens (got 1 here; it's stochastic!)
t = 0.5: mask prob = 50% → The [M] sat [M] [M] mat
Expected: 0.5×6 = 3 tokens
t = 0.75: mask prob = 75% → [M] [M] sat [M] [M] [M]
Expected: 0.75×6 = 4.5 tokens
t = 1.0: mask prob = 100% → [M] [M] [M] [M] [M] [M]
Expected: 1.0×6 = 6 tokens (deterministic)

3. Reverse Process: Learning to Unmask

The reverse process is a neural network that takes a partially masked sequence and predicts the original token at every masked position:

Reverse process: predict clean tokens from noisy input

p_\theta(x_0 \mid x_t) = \prod_{i=1}^{L} p_\theta(x_0^i \mid x_t)

  • p_\theta: neural network (Transformer) with parameters θ
  • x_t: the noisy (partially masked) input sequence
  • x_0: the clean target sequence we want to predict
  • \prod_{i=1}^{L}: independent prediction at each position; this is what enables parallelism

Key insight: The model predicts ALL masked tokens independently in a single forward pass. This means we don't need to generate tokens one by one like in AR models.

4. Training Objective

LLaDA training objective (ELBO-derived)

\mathcal{L}(\theta) = -\mathbb{E}_{t \sim \mathcal{U}(0,1)} \, \mathbb{E}_{x_t \sim q(x_t \mid x_0)} \left[ \frac{1}{t \cdot L} \sum_{i : x_t^i = \texttt{[M]}} \log p_\theta(x_0^i \mid x_t) \right]

  • \mathcal{L}(\theta): the loss function we minimize
  • t ~ U(0, 1): sample a random timestep uniformly from [0, 1], so the model trains at ALL noise levels
  • x_t ~ q(x_t | x_0): apply the forward process, masking each token with probability t
  • \sum_{i : x_t^i = [M]}: sum only over masked positions; no loss is computed on unmasked tokens
  • \log p_\theta(x_0^i | x_t): log probability of predicting the correct original token
  • 1/(t · L): normalization; the expected number of masked tokens is t × L

Sentence: "The cat sat on the mat" (L=6)

Step 1: Sample t = 0.5
Step 2: Mask with prob 0.5 → "The [M] sat [M] [M] mat"
        Masked positions: {2, 4, 5} (cat, on, the)
Step 3: Model predicts at masked positions:
        pos 2: p(cat | x_t) = 0.7 → log(0.7) = -0.357
        pos 4: p(on | x_t) = 0.9 → log(0.9) = -0.105
        pos 5: p(the | x_t) = 0.6 → log(0.6) = -0.511
Step 4: Loss = -1/(0.5 × 6) × (-0.357 - 0.105 - 0.511)
             = -1/3 × (-0.973) = 0.324
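
The arithmetic above can be checked with a short script. A minimal sketch, assuming the model's probabilities at the masked positions are given; `llada_loss` is a hypothetical helper, not the paper's code:

```python
import math

def llada_loss(t, L, masked_probs):
    """Single-sample estimate of the LLaDA objective:
    -1/(t*L) * sum of log-probs over the masked positions."""
    return -sum(math.log(p) for p in masked_probs) / (t * L)

# t = 0.5, L = 6, model probabilities at the 3 masked positions
print(round(llada_loss(0.5, 6, [0.7, 0.9, 0.6]), 3))  # → 0.324
```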

Connection to BERT: LLaDA's training is essentially BERT with a variable masking rate (t ~ Uniform) instead of BERT's fixed 15%. By training across ALL masking rates, the model learns to handle any degree of partial information, from nearly complete to fully masked.

5. Generation: Iterative Unmasking

At inference, LLaDA starts from a fully masked sequence and iteratively unmasks tokens over T steps:

Generation algorithm

\text{For } s = T, T{-}1, \ldots, 1: \quad x_{t_{s-1}} = \text{Unmask}(x_{t_s}, p_\theta, n_s)

  • T: total number of denoising steps (e.g. 10, 50, 100)
  • t_s: noise level at step s; t_T = 1, t_0 = 0, uniformly spaced
  • n_s: number of tokens to unmask at this step
  • Unmask(·): 1) the model predicts all masked tokens; 2) keep the top n_s by confidence; 3) re-mask the rest

Tokens to unmask per step

n_s = \lfloor (t_s - t_{s-1}) \cdot L \rfloor
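
The loop can be sketched as follows. This is an illustration, not the paper's implementation: `model` is a stand-in that returns a (token, confidence) pair per position, and the step sizes use an integer schedule so the per-step counts sum to exactly L:

```python
MASK = "[M]"

def generate(model, L, T):
    """Reverse process: start fully masked, then over T steps commit the
    most confident predictions and keep the rest masked."""
    seq = [MASK] * L
    for k in range(1, T + 1):
        # integer version of n_s = floor((t_s - t_{s-1}) * L): cumulative
        # unmasked count after step k is L*k // T, so the steps sum to L
        n_k = L * k // T - L * (k - 1) // T
        preds = model(seq)  # one parallel forward pass over all positions
        masked = sorted((i for i in range(L) if seq[i] == MASK),
                        key=lambda i: preds[i][1], reverse=True)
        for i in masked[:n_k]:  # commit the n_k most confident predictions
            seq[i] = preds[i][0]
    return seq

# Toy "model" that always predicts the target sentence with fixed confidences
target = "The cat sat on the mat".split()
conf = [0.9, 0.6, 0.8, 0.5, 0.7, 0.95]
toy_model = lambda seq: list(zip(target, conf))
print(generate(toy_model, L=6, T=3))  # → ['The', 'cat', 'sat', 'on', 'the', 'mat']
```

With T = 3 and L = 6, each step unmasks 2 tokens in decreasing order of confidence; T = 1 recovers single-step parallel decoding.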

6. Architecture Details

LLaDA uses a standard Transformer architecture (same as LLaMA), with one crucial difference:

| Component             | LLaMA (AR)             | LLaDA (Diffusion)            |
|-----------------------|------------------------|------------------------------|
| Attention             | Causal (left-to-right) | Bidirectional (full)         |
| Positional encoding   | RoPE                   | RoPE                         |
| Normalization         | RMSNorm                | RMSNorm                      |
| Activation            | SwiGLU                 | SwiGLU                       |
| Timestep conditioning | N/A                    | Implicit (via masking ratio) |

Why no causal mask? AR models use causal masking because they generate left-to-right and shouldn't see future tokens. LLaDA generates all positions simultaneously, so every token needs to see every other token (including the [MASK] tokens) to make coordinated predictions. This is the same as BERT-style bidirectional attention.
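
The difference is just the attention mask. A toy sketch with boolean masks, where entry [i][j] = True means position i may attend to position j (illustrative helper names, not any library's API):

```python
def causal_mask(L):
    """AR-style: position i attends only to positions j <= i."""
    return [[j <= i for j in range(L)] for i in range(L)]

def bidirectional_mask(L):
    """LLaDA/BERT-style: every position attends to every position."""
    return [[True] * L for _ in range(L)]

print(causal_mask(3))         # lower-triangular pattern
print(bidirectional_mask(3))  # all True
```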

7. Supervised Fine-tuning (SFT)

For instruction following, LLaDA only masks the response tokens (not the prompt) during fine-tuning:

SFT: mask response only

\mathcal{L}_{\text{SFT}} = -\mathbb{E}_t \left[ \frac{1}{t \cdot L_r} \sum_{i \in \text{response}} \mathbf{1}[x_t^i = \texttt{[M]}] \cdot \log p_\theta(x_0^i \mid x_{\text{prompt}}, x_{t,\text{response}}) \right]

Here L_r is the response length; the prompt stays clean and serves purely as conditioning.
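
In code, the only change from pre-training is where the noise is applied. A minimal sketch (hypothetical `sft_mask` helper; the prompt passes through untouched):

```python
import random

MASK = "[M]"

def sft_mask(prompt, response, t, seed=None):
    """SFT forward process: prompt tokens stay clean; only response
    tokens are masked, each independently with probability t."""
    rng = random.Random(seed)
    noisy = [MASK if rng.random() < t else tok for tok in response]
    return prompt + noisy

prompt = ["Translate", "to", "English:", "bonjour"]
response = ["hello"]
print(sft_mask(prompt, response, 1.0))  # prompt intact, response fully masked
```

The loss is then computed only on the masked response positions, exactly as in pre-training.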

8. Experiments & Results

LLaDA was trained at two scales: 1.1B and 8B parameters.

| Model  | Type      | Params | MMLU | ARC-C | HellaSwag |
|--------|-----------|--------|------|-------|-----------|
| LLaMA3 | AR        | 8B     | 65.3 | 53.7  | 82.1      |
| LLaDA  | Diffusion | 8B     | 67.0 | 55.6  | 79.8      |
| GPT-2  | AR        | 1.5B   | 32.4 | 33.3  | 71.3      |
| LLaDA  | Diffusion | 1.1B   | 38.9 | 38.2  | 62.1      |

Key takeaway: A diffusion LM can be competitive with AR models at scale. LLaDA 8B slightly outperforms LLaMA3 8B on MMLU and ARC-C, a strong signal that AR is not the only path to powerful LLMs.

9. Limitations & Future Work

  • Inference speed: Multiple denoising steps needed vs. AR's single pass per token. Mitigated by Fast DLLM's adaptive schedule.
  • No KV-cache: Bidirectional attention means standard KV-cache doesn't work. Block Diffusion addresses this with block-level caching.
  • RLHF unexplored: How to do reinforcement learning from human feedback with diffusion LMs is an open question.
  • Long-form generation: Performance on very long sequences (>4K tokens) not yet studied at scale.

10. Connections to Other Work

MDLM

Provides the rigorous continuous-time ELBO theory that underpins LLaDA's training objective. MDLM's loss is essentially the same as LLaDA's, derived from first principles.

Fast DLLM

Addresses LLaDA's main weakness (slow multi-step generation) with adaptive denoising schedules and importance sampling, reducing steps by 3-10x.

Block Diffusion

Combines AR and diffusion at the block level. Can be seen as a generalization of LLaDA where B=L is one extreme (full diffusion = LLaDA) and B=1 is the other (pure AR).

11. Additional Resources