TL;DR
Standard LLM training predicts one token at a time. Multi-Token Prediction (MTP) adds k independent output heads to the same transformer trunk, each trained to predict one future token (t+1 through t+k) simultaneously. The joint loss forces the model to learn higher-level, forward-looking representations. At inference, the k heads double as a built-in speculative decoder, yielding up to 3× throughput gains. Models trained with MTP consistently outperform next-token-only baselines on coding and algorithmic reasoning benchmarks, and reach the same perplexity with less training data.
1. The One-Token-At-A-Time Bottleneck
Every modern autoregressive language model is trained with the same objective: given a sequence of tokens x_1, x_2, …, x_T, maximise the log-likelihood of each token given all preceding tokens. The training loss is the sum of cross-entropy terms over every position:
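Written out, with P_θ denoting the model's conditional distribution, this is the familiar next-token objective:

```latex
\mathcal{L}_{\text{NTP}} = -\sum_{t=1}^{T-1} \log P_{\theta}\!\left(x_{t+1} \mid x_{1:t}\right)
```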
This is elegant and powerful, but it imposes a strict constraint: every forward pass teaches the model about exactly one future token. The representation learned by the transformer at position t only needs to carry enough information to predict x_{t+1}, and nothing further. This local, myopic training signal may limit how much long-range structure the model internalises.
Auto-regressive generation also has a well-known throughput problem: tokens are produced strictly one at a time. Even on the fastest GPUs, the bottleneck is the sequential dependency chain: you cannot start computing x_{t+2} until x_{t+1} is known. Multi-Token Prediction addresses both issues at once.
2. Multi-Token Prediction Architecture
The core idea is deceptively simple: keep the transformer body exactly as-is, but attach k independent output heads instead of one. Each head n (for n = 1, 2, …, k) is responsible for predicting the token that appears n steps into the future, given only the tokens seen up to position t.
The shared trunk is the expensive part: the full transformer with all its layers. Each output head is a lightweight linear projection from the hidden dimension d to the vocabulary size V, plus a bias. Because the trunk is shared, the overhead of adding k heads is modest.
3. k Independent Heads
Each head n has its own weight matrix W_n and bias b_n. Given the hidden state h_t produced by the shared trunk at position t, head n computes a distribution over the vocabulary:
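In symbols (notation follows the surrounding text; softmax is applied over the vocabulary axis):

```latex
p_n\!\left(x_{t+n} \mid x_{1:t}\right) = \operatorname{softmax}\!\left(W_n h_t + b_n\right)
```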
The heads are independent: they do not share parameters with each other (beyond the trunk), and they are trained with separate gradient signals. Crucially, the gradient from head n flows back through the shared trunk, so the trunk receives k training signals simultaneously from every token position.
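A minimal NumPy sketch of this layout follows. The dimensions are toy values, and the `trunk` function is a placeholder projection standing in for the full transformer; nothing here reflects the paper's actual implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab_size, k = 8, 16, 4

# Stand-in for the shared transformer trunk: maps (T, d) token
# embeddings to (T, d) hidden states.
W_trunk = rng.normal(size=(d_model, d_model))

def trunk(x):
    return np.tanh(x @ W_trunk)

# Head n: its own weight matrix W_n of shape (V, d) and bias b_n of shape (V,).
heads = [(rng.normal(size=(vocab_size, d_model)), np.zeros(vocab_size))
         for _ in range(k)]

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mtp_forward(x):
    """One trunk pass, k head read-outs: returns a (k, T, V) array where
    entry [n, t] is head n's distribution over the token n+1 steps ahead."""
    h = trunk(x)  # shared hidden states, computed once
    return np.stack([softmax(h @ W.T + b) for W, b in heads])

probs = mtp_forward(rng.normal(size=(5, d_model)))
print(probs.shape)  # (4, 5, 16): k heads x T positions x V vocabulary
```

Note that the trunk runs once per forward pass; only the cheap linear read-outs are repeated k times.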
Parameter overhead
Each head adds a weight matrix of shape (V × d) plus a bias of shape V. With k heads, the additional parameters are k × (V · d + V). For typical values (d = 4096, V = 32000, k = 4), this is roughly 4 × 131M ≈ 524M extra parameters: significant in absolute terms, but small compared to the tens of billions of parameters in the trunk.
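The arithmetic, using the values quoted above:

```python
# Parameter overhead of k extra output heads.
d, V, k = 4096, 32000, 4
per_head = V * d + V            # weight matrix (V x d) plus bias (V,)
total = k * per_head
print(f"{per_head:,} per head, {total:,} for k = {k} heads")
# 131,104,000 per head, 524,416,000 for k = 4 heads (~4 x 131M = 524M)
```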
4. Training Loss
The training objective is a simple sum of k cross-entropy losses, one per head. At each position t, head n is supervised by the ground-truth token x_{t+n}. Summing over all positions and all heads:
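In symbols, mirroring the next-token loss above:

```latex
\mathcal{L}_{\text{MTP}} = -\sum_{t} \sum_{n=1}^{k} \log P_{\theta}\!\left(x_{t+n} \mid x_{1:t}\right)
```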
Notice that the conditioning context is always the same prefix x_{1:t} for all k heads at position t. Head 1 predicts the very next token (same supervision as standard NTP). Head 2 predicts two steps ahead. And so on. This means the gradient from every head flows into the same hidden state h_t, giving the trunk a far richer learning signal.
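The supervision pattern can be made concrete with a toy loss computation. Here `probs[n, t]` is head n's output distribution at position t (random stand-ins, not real model outputs), and head n at position t is scored against the token n+1 steps ahead:

```python
import numpy as np

rng = np.random.default_rng(0)
k, T, V = 4, 10, 16
probs = rng.dirichlet(np.ones(V), size=(k, T))   # (k, T, V) head outputs
tokens = rng.integers(0, V, size=T + k)          # toy target sequence

loss = 0.0
for t in range(T):
    for n in range(k):
        target = tokens[t + n + 1]               # token n+1 steps past t
        loss -= np.log(probs[n, t, target])      # one cross-entropy term
print(loss)  # sum of k * T positive terms
```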
5. Speculative Decoding Connection
One of the most practically valuable aspects of MTP is that the k prediction heads can be reused at inference time as a speculative decoder, at zero extra training cost.
Speculative decoding works by running a cheap draft model to propose multiple tokens in one shot, then verifying all proposals in a single forward pass of the full model. Since verification is embarrassingly parallel (all proposals are scored simultaneously), this avoids the serial bottleneck of standard generation while keeping output quality identical.
MTP as a draft model
- Draft: run the shared trunk once at position t. Heads 1 through k produce proposed tokens x̂_{t+1}, …, x̂_{t+k}.
- Verify: run a single forward pass on the concatenated sequence x_{1:t}, x̂_{t+1}, …, x̂_{t+k} to obtain the model's true probabilities for each position.
- Accept/reject: accept each proposed token greedily or via the speculative sampling rule below.
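The steps above can be sketched with the simpler greedy variant: accept each proposal only while it matches the verifier's argmax, then take the verifier's own token at the first mismatch. The draft and verify functions here are random stand-ins, not real models:

```python
import numpy as np

rng = np.random.default_rng(1)
k, V = 4, 16

def draft(prefix):
    """Stand-in for one trunk pass: k head distributions, shape (k, V)."""
    return rng.dirichlet(np.ones(V), size=k)

def verify(prefix, proposals):
    """Stand-in for one full forward pass over prefix + proposals."""
    return rng.dirichlet(np.ones(V), size=len(proposals))

prefix = [3, 7, 2]
q = draft(prefix)
proposals = q.argmax(axis=-1).tolist()           # greedy x-hat_{t+1..t+k}
p = verify(prefix, proposals)

accepted = []
for token, p_dist in zip(proposals, p):
    if token != p_dist.argmax():                 # first mismatch: stop,
        accepted.append(int(p_dist.argmax()))    # keep verifier's token
        break
    accepted.append(token)

prefix.extend(accepted)
print(f"accepted {len(accepted)} of {k} proposals in one verifier pass")
```

Every verifier pass thus commits at least one token (the standard serial rate) and up to k when the heads agree with the full model.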
The speculative sampling acceptance rule guarantees that the accepted tokens are distributed exactly as if they had been sampled from the full model's distribution, so output quality is provably unchanged:
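Writing q for the draft (MTP head) distribution and p for the full model's distribution at the same position, the standard rule from the speculative sampling literature is:

```latex
\text{accept } \hat{x} \text{ with probability } \min\!\left(1,\; \frac{p(\hat{x})}{q(\hat{x})}\right);
\qquad
\text{on rejection, resample } x \sim \frac{\max\!\left(0,\; p(x) - q(x)\right)}{\sum_{x'} \max\!\left(0,\; p(x') - q(x')\right)}
```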
When the draft distribution closely matches the verifier (i.e., the MTP heads are accurate), nearly all proposals are accepted and the net throughput approaches k× that of standard generation. In practice, Meta's experiments show roughly a 3× speedup for code generation tasks where the heads are highly accurate.
6. Why Planning Ahead Helps
The intuitive story is compelling: if a model must simultaneously predict the next word and the word five steps ahead, it cannot afford to commit to greedy local choices. It must learn representations that are globally consistent with upcoming structure. This is analogous to how a chess player plans multiple moves ahead instead of optimising each move in isolation.
More formally: the gradient that the shared trunk receives from head n contains information about the local distribution at t+n, not just t+1. When many such gradients are aggregated, the trunk is incentivised to compress longer-range dependencies into h_t. The result is a more expressive, temporally richer hidden state.
7. Results: Coding, Reasoning, and Sample Efficiency
Gloeckle et al. train a series of models (370M, 1.3B, 3B, 7B, and 13B parameters) on code and text corpora, comparing standard NTP against MTP with k = 4. Key findings:
| Benchmark | NTP (baseline) | MTP k=4 | Gain |
|---|---|---|---|
| HumanEval (code) | ~39% | ~46% | +7 pp |
| MBPP (code) | ~52% | ~58% | +6 pp |
| Algorithmic reasoning (byte ops) | baseline | significantly higher | ++ |
| Natural language perplexity | baseline | slightly better | + |
| Inference throughput (code) | 1× | ~3× | 3× |
Approximate numbers from the paper; exact values vary by model size and task.
Sample efficiency is a particularly striking result. When training on a fixed compute budget, MTP models reach a given perplexity threshold with substantially fewer tokens. Alternatively, with the same number of training tokens, MTP models achieve lower perplexity. This suggests the richer gradient signal leads to more efficient use of each training example.