TL;DR
Standard attention materializes the full N×N attention matrix in GPU HBM, requiring O(N²) memory. FlashAttention reorders the computation using tiling: it computes attention in blocks that fit in SRAM (fast on-chip memory), never writing the full N×N matrix to HBM. This achieves exact (not approximate) attention with a 2-4× speedup and O(N) memory.
1. Background: Why Standard Attention Is Slow
Modern GPUs have two levels of memory with very different bandwidths. SRAM (on-chip cache) is fast but tiny (~20 MB). HBM (high-bandwidth memory, i.e., the main GPU RAM) is large (~40 GB on A100) but comparatively slow.
| Memory type | Bandwidth | Size (A100) |
|---|---|---|
| SRAM (on-chip) | ~19 TB/s | ~20 MB |
| HBM (off-chip) | ~2 TB/s | ~40 GB |
Standard attention performs three round-trips through HBM for each sequence. Every intermediate result (the raw score matrix S and the softmax probability matrix P) is written to HBM and read back.
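The three round-trips correspond to the three steps of the standard computation. A sketch for queries Q, keys K, and values V, each of shape N×d (the 1/√d factor is the usual scaled dot-product convention and is assumed here, since the text does not spell it out; the softmax is applied row by row):

```latex
\begin{aligned}
S &= \frac{QK^{\top}}{\sqrt{d}} \in \mathbb{R}^{N \times N} && \text{(write } S \text{ to HBM)} \\
P &= \operatorname{softmax}(S) \in \mathbb{R}^{N \times N} && \text{(read } S \text{, write } P \text{)} \\
O &= PV \in \mathbb{R}^{N \times d} && \text{(read } P \text{, write } O \text{)}
\end{aligned}
```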
For sequence length N = 4096, the attention matrix S has 4096² = 16,777,216 elements. At FP16, that is 32 MB for S alone, already larger than SRAM. The bottleneck is not arithmetic throughput but memory bandwidth.
2. The Tiling Approach
FlashAttention tiles Q, K, V into blocks that each fit in SRAM, computes attention entirely within SRAM for each pair of blocks, and accumulates into the output without ever materializing the full N×N matrix in HBM.
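To make the tiling concrete, the following minimal sketch enumerates the (query block, key block) pairs and counts how many elements one tile needs on-chip. The helper names, the block size of 128, and the head dimension of 64 are illustrative assumptions, not values taken from the text; real kernels choose block sizes per GPU.

```python
def block_ranges(n, block_size):
    """Split indices 0..n-1 into consecutive (start, stop) blocks."""
    return [(start, min(start + block_size, n)) for start in range(0, n, block_size)]


def tile_working_set(block_q, block_k, d):
    """Elements held on-chip while processing one (query block, key block) tile.

    Counts the Q block, the K and V blocks, the score tile, and the output
    accumulator. The per-row softmax statistics are small and left out.
    """
    return block_q * d + 2 * block_k * d + block_q * block_k + block_q * d


n, d = 4096, 64            # sequence length and head dimension (illustrative)
block_q = block_k = 128    # hypothetical block sizes

q_blocks = block_ranges(n, block_q)
k_blocks = block_ranges(n, block_k)
tiles = [(qb, kb) for qb in q_blocks for kb in k_blocks]

print("tiles:", len(tiles))
print("elements per tile:", tile_working_set(block_q, block_k, d))
print("elements in full score matrix:", n * n)
```

With these assumed sizes, one tile needs 49,152 elements (about 96 KB at FP16), compared with 16,777,216 elements for the full score matrix.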
3. The Online Softmax Trick
Softmax normally requires a two-pass algorithm: first compute the maximum value across all scores (for numerical stability), then compute the exponentials and sum. This requires seeing all N values before producing any output, which is incompatible with tiling.
The online softmax trick maintains a running maximum m and a running sum ℓ as blocks arrive, and rescales the accumulated output O each time the maximum estimate is updated.
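One common way to write the update for a single query row is shown below. The "old"/"new" superscripts and the choice to keep an unnormalized accumulator Õ, dividing by ℓ only at the end, are notational assumptions of this sketch. When a new block of scores s_j with value vectors v_j arrives:

```latex
\begin{aligned}
m^{\text{new}} &= \max\!\Bigl(m^{\text{old}},\ \max_{j} s_j\Bigr) \\
\ell^{\text{new}} &= e^{\,m^{\text{old}} - m^{\text{new}}}\,\ell^{\text{old}} + \sum_{j} e^{\,s_j - m^{\text{new}}} \\
\tilde{O}^{\text{new}} &= e^{\,m^{\text{old}} - m^{\text{new}}}\,\tilde{O}^{\text{old}} + \sum_{j} e^{\,s_j - m^{\text{new}}}\, v_j
\end{aligned}
```

The state starts at m = −∞, ℓ = 0, Õ = 0, and the final output row is O = Õ / ℓ.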
Why this is exact: Each update is a mathematically equivalent rescaling of the previous partial result. When all blocks have been processed, O contains exactly the same value as standard attention, not an approximation. The trick only reorders the arithmetic operations.
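The exactness claim can be checked numerically. The sketch below, for a single query row with scalar values, compares a two-pass softmax-weighted sum against a block-by-block version that keeps only a running maximum, a running sum, and an unnormalized accumulator. The function names are hypothetical, and agreement in floating point holds up to rounding rather than bit for bit.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: two-pass softmax over all scores, then a weighted sum."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total


def online_softmax_weighted_sum(scores, values, block_size):
    """Same quantity, consuming the scores one block at a time."""
    m = float("-inf")   # running maximum
    ell = 0.0           # running sum of exponentials
    acc = 0.0           # unnormalized running output
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)   # rescales old partial results (0.0 on the first block)
        p_blk = [math.exp(s - m_new) for s in s_blk]
        ell = scale * ell + sum(p_blk)
        acc = scale * acc + sum(p * v for p, v in zip(p_blk, v_blk))
        m = m_new
    return acc / ell


scores = [1.0, 3.0, -2.0, 5.0, 0.5]
values = [10.0, 20.0, 30.0, 40.0, 50.0]
reference = softmax_weighted_sum(scores, values)
for block_size in (1, 2, 5):
    streamed = online_softmax_weighted_sum(scores, values, block_size)
    print(block_size, abs(streamed - reference))
```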
4. Concrete Example: Tiling with N=4, Block Size=2
Query sequence: [q1, q2, q3, q4], Key sequence: [k1, k2, k3, k4]
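The four-token case can be traced end to end in code. In this minimal sketch, block size 2 splits the queries into {q1, q2} and {q3, q4} and the keys into {k1, k2} and {k3, k4}, giving four 2×2 score tiles. The 2-dimensional vectors, the value matrix V, the function names, and the loop order (query blocks outermost) are all assumptions made for illustration.

```python
import math


def score(q, k):
    """Scaled dot product between one query and one key."""
    return sum(x * y for x, y in zip(q, k)) / math.sqrt(len(q))


def naive_attention(Q, K, V):
    """Builds every full row of scores, as standard attention does."""
    out = []
    for q in Q:
        s = [score(q, k) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        ell = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(K))) / ell
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_size):
    """Visits (query block, key block) tiles; only one tile of scores exists at a time."""
    d_v = len(V[0])
    out = []
    for qs in range(0, len(Q), block_size):            # query blocks
        q_blk = Q[qs:qs + block_size]
        m = [float("-inf")] * len(q_blk)               # running max per query row
        ell = [0.0] * len(q_blk)                       # running sum per query row
        acc = [[0.0] * d_v for _ in q_blk]             # unnormalized output rows
        for ks in range(0, len(K), block_size):        # key/value blocks
            k_blk = K[ks:ks + block_size]
            v_blk = V[ks:ks + block_size]
            for i, q in enumerate(q_blk):
                s = [score(q, k) for k in k_blk]       # one row of the score tile
                m_new = max(m[i], max(s))
                scale = math.exp(m[i] - m_new)         # rescale earlier partial results
                p = [math.exp(x - m_new) for x in s]
                ell[i] = scale * ell[i] + sum(p)
                acc[i] = [scale * acc[i][c] + sum(p[j] * v_blk[j][c] for j in range(len(p)))
                          for c in range(d_v)]
                m[i] = m_new
        out.extend([[a / ell[i] for a in acc[i]] for i in range(len(q_blk))])
    return out


# Hypothetical vectors standing in for q1..q4, k1..k4 and their values v1..v4.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

reference = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block_size=2)
max_diff = max(abs(a - b) for r, t in zip(reference, tiled) for a, b in zip(r, t))
print("max abs difference:", max_diff)
```

The tiled version never holds more than one row of a 2×2 score tile at a time, yet its output matches the reference up to floating-point rounding.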
5. Backward Pass: Recomputation
For backpropagation, standard attention needs to store the N×N attention matrix P to compute gradients. FlashAttention instead stores only the output O and the per-row softmax statistics (m, ℓ), which take O(N) extra memory, and recomputes the attention tiles on the fly during the backward pass.
Trade-off: Recomputation requires additional FLOPs in the backward pass (roughly 2× the FLOPs of the forward pass), but this is cheaper than the HBM bandwidth cost of storing and loading the N×N matrix. On modern hardware, attention is memory-bandwidth-bound, not compute-bound.
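The recomputation step can be sketched as follows: given Q, K, and the saved per-row statistics (m, ℓ), any tile of P can be rebuilt as exp(s − m) / ℓ without a stored copy of P. The function names are hypothetical, the 1/√d scaling is omitted for brevity, and only the reconstruction of P is shown, not the gradient formulas that consume it.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def forward_statistics(Q, K):
    """Per-row (m, ell) pairs: the O(N) extras kept from the forward pass.

    Computed here from full rows for clarity; a tiled forward pass would
    produce the same numbers from its running maximum and running sum.
    """
    stats = []
    for q in Q:
        s = [dot(q, k) for k in K]
        m = max(s)
        stats.append((m, sum(math.exp(x - m) for x in s)))
    return stats


def recompute_p_tile(Q, K, stats, q_range, k_range):
    """Rebuild one tile of P from Q, K and the saved statistics."""
    tile = []
    for i in range(*q_range):
        m, ell = stats[i]
        tile.append([math.exp(dot(Q[i], K[j]) - m) / ell for j in range(*k_range)])
    return tile


# Hypothetical 4-token inputs.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [2.0, 0.0]]

stats = forward_statistics(Q, K)
left = recompute_p_tile(Q, K, stats, (0, 4), (0, 2))    # columns k1..k2
right = recompute_p_tile(Q, K, stats, (0, 4), (2, 4))   # columns k3..k4
for i in range(4):
    print(i, sum(left[i]) + sum(right[i]))              # each row of P sums to about 1
```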
6. Memory Complexity
| What is stored | Standard attention | FlashAttention |
|---|---|---|
| Score matrix S | O(N²) | never stored |
| Probability matrix P | O(N²) | never stored |
| Output O | O(Nd) | O(Nd) |
| Softmax statistics (m, ℓ) | implicit in P | O(N) |
| Total | O(N²) | O(N) |
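The gap in the table can be turned into rough element counts. This sketch counts what is kept for the backward pass per attention head, leaving out Q, K, and V since both methods need them. The head dimension of 64 and the function names are assumptions for illustration.

```python
def stored_elements_standard(n, d):
    """P (N x N) plus the output O (N x d).

    S is also materialized during the forward pass, so peak usage is higher still.
    """
    return n * n + n * d


def stored_elements_flash(n, d):
    """Output O (N x d) plus the two per-row statistics m and ell."""
    return n * d + 2 * n


d = 64  # head dimension (illustrative assumption)
for n in (1024, 2048, 4096, 8192):
    std = stored_elements_standard(n, d)
    fla = stored_elements_flash(n, d)
    print(f"N={n:5d}  standard={std:>11,d}  flash={fla:>9,d}  ratio={std / fla:6.1f}x")
```

Doubling N roughly quadruples the standard count but only doubles the FlashAttention count, which is the O(N²) versus O(N) difference in the last row of the table.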
7. Results
| Method | Seq len 2K | Seq len 4K | Seq len 8K | Memory |
|---|---|---|---|---|
| Standard | 1.0× | 1.0× | OOM | O(N²) |
| FlashAttention | 3.1× | 3.8× | 3.5× | O(N) |
- BERT: 15% end-to-end training speedup, 3× memory reduction
- GPT-2: 3× faster attention operation with identical perplexity
- Long context: Enables 64K context length training, compared to the ~2K practical limit of standard attention on the same hardware
8. Limitations
- CUDA-specific: The original FlashAttention is a hand-written CUDA kernel. It is not automatically available in all deep learning frameworks without explicit integration.
- Hardware-dependent block sizes: Optimal block sizes depend on the specific GPU's SRAM capacity, requiring tuning per hardware target.
- GPU utilization: FlashAttention-1 did not fully utilize GPU compute units. FlashAttention-2 and FlashAttention-3 addressed this with further parallelism and warp-level optimizations.
9. Connections to Other Work
FlashAttention is a drop-in replacement for the O(N²) scaled dot-product attention introduced in the original Transformer paper, "Attention Is All You Need". It computes exactly the same mathematical operation, just with a dramatically better IO pattern.
Frequently combined in practice: LoRA reduces the number of trainable parameters while FlashAttention reduces the memory and time cost of each forward pass. Together they enable fine-tuning large models on limited hardware.
Masked diffusion language models like LLaDA rely on FlashAttention for efficient training at scale, since they perform many forward passes per training step.