FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Dao et al. · NeurIPS 2022 · arXiv 2205.14135

TL;DR

Standard attention materializes the full N×N attention matrix in GPU HBM, requiring O(N²) memory. FlashAttention reorders the computation using tiling: it computes attention in blocks that fit in SRAM (fast on-chip memory), never writing the full N×N matrix to HBM. This achieves exact (not approximate) attention with a 2-4× speedup and O(N) memory.

FlashAttention Method Overview

Problem: Standard attention writes the N×N matrix to HBM; the bottleneck is memory bandwidth.
Key insight: GPU SRAM (fast) is tiny; GPU HBM (slow) is large. Minimize HBM reads/writes.
Tiling: Split Q, K, V into blocks that fit in SRAM (tile size B_r × B_c) and compute the partial softmax in SRAM.
Online softmax: Rescale the running max/sum as new blocks arrive, which enables exact softmax.
Backward: Recompute attention from tiles (no stored N×N gradient).
Result: 3× faster attention, 5-20× memory reduction, longer context. Exact attention (not approximate), O(N) memory instead of O(N²), IO-aware algorithm design.

1. Background: Why Standard Attention Is Slow

Modern GPUs have two levels of memory with very different bandwidths. SRAM (on-chip cache) is fast but tiny (∼20 MB). HBM (high-bandwidth memory, i.e., the main GPU RAM) is large (∼40 GB on A100) but comparatively slow.

| Memory type | Bandwidth | Size (A100) |
| --- | --- | --- |
| SRAM (on-chip) | ~19 TB/s | ~20 MB |
| HBM (off-chip) | ~2 TB/s | ~40 GB |

Standard attention performs three round-trips through HBM for each sequence. Every intermediate result (the raw score matrix S and the softmax probability matrix P) is written to HBM and read back:

Step 1: Load Q, K from HBM → compute S = QKᵀ → write S to HBM
Step 2: Load S from HBM → compute P = softmax(S) → write P to HBM
Step 3: Load P, V from HBM → compute O = PV → write O to HBM
Total HBM reads/writes: O(N²). This is the bottleneck, not compute!
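The three steps above can be sketched in plain Python. This is only an illustration of the data flow, not a GPU implementation: the helper names (`matmul`, `transpose`, `softmax_rows`, `standard_attention`) are made up for this example, and the 1/√d scaling is omitted to match the S = QKᵀ notation used here. The point is that both `s` and `p` exist as full N×N objects between steps.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    """Swap rows and columns of a matrix given as a list of rows."""
    return [list(col) for col in zip(*a)]

def softmax_rows(s):
    """Numerically stable softmax applied to each row."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        total = sum(e)
        out.append([x / total for x in e])
    return out

def standard_attention(q, k, v):
    s = matmul(q, transpose(k))   # Step 1: full N x N score matrix
    p = softmax_rows(s)           # Step 2: full N x N probability matrix
    o = matmul(p, v)              # Step 3: N x d output
    return o, s, p

q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
o, s, p = standard_attention(q, k, v)
print(len(s), "x", len(s[0]), "intermediate matrices held in memory")
```

On a GPU, each of `s` and `p` would be written to HBM and read back, which is the traffic FlashAttention avoids.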

For sequence length N = 4096, the attention matrix S has 4096² = 16,777,216 elements. At FP16 (2 bytes per element), that is 32 MiB for S alone, for a single attention head and a single sequence, which is already larger than SRAM. The bottleneck is not arithmetic throughput but memory bandwidth.
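A rough back-of-the-envelope estimate makes the bandwidth cost concrete. The sketch below assumes the ~2 TB/s HBM figure quoted above and counts four passes over an N×N matrix (write S, read S, write P, read P); the function name and the pass count are illustrative assumptions, and latency, caching, and the smaller Q/K/V/O transfers are ignored.

```python
HBM_BANDWIDTH = 2e12  # bytes per second; assumes the ~2 TB/s A100 figure

def hbm_traffic_seconds(n, bytes_per_element=2, passes=4):
    """Estimated time to move an N x N matrix through HBM `passes` times."""
    return passes * n * n * bytes_per_element / HBM_BANDWIDTH

for n in (1024, 4096, 16384):
    microseconds = hbm_traffic_seconds(n) * 1e6
    print(f"N={n}: about {microseconds:.1f} microseconds per head")
```

The estimate grows quadratically with N, so quadrupling the sequence length multiplies this traffic by sixteen.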

2. The Tiling Approach

FlashAttention tiles Q, K, V into blocks that each fit in SRAM, computes attention entirely within SRAM for each pair of blocks, and accumulates into the output without ever materializing the full N×N matrix in HBM.

Block tile decomposition
$$Q = [Q_1, \ldots, Q_{T_r}], \quad K = [K_1, \ldots, K_{T_c}], \quad V = [V_1, \ldots, V_{T_c}]$$
for i in 1..T_r:
    Load Q_i from HBM to SRAM (once per Q block)
    Initialize O_i = 0, running max m = -∞, running sum ℓ = 0
    for j in 1..T_c:
        Load K_j, V_j from HBM to SRAM
        Compute S_ij = Q_i K_jᵀ (in SRAM)
        Update running max m and sum ℓ (online softmax)
        Accumulate O_i += rescaled(softmax(S_ij) · V_j) (in SRAM)
    Write O_i to HBM (once per Q block)
HBM reads/writes: O(N²d²/M) for SRAM size M and head dimension d, versus O(Nd + N²) for standard attention. Extra memory is O(N); S and P are never materialized in HBM!
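The loop structure above can be sketched in plain Python to check that tiling reproduces ordinary attention. This is a CPU illustration under stated assumptions, not the CUDA kernel: `tiled_attention`, `dot`, `block_r`, and `block_c` are names chosen for this example, "loading into SRAM" is modeled by list slicing, and the 1/√d scaling is omitted to match the notation used here.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def tiled_attention(q, k, v, block_r=2, block_c=2):
    """Attention computed block by block; no N x N matrix is ever built."""
    d = len(v[0])
    out = []
    for i0 in range(0, len(q), block_r):            # outer loop over Q blocks
        q_blk = q[i0:i0 + block_r]                  # "load Q_i into SRAM"
        m = [-math.inf] * len(q_blk)                # running row maxima
        l = [0.0] * len(q_blk)                      # running row sums
        o = [[0.0] * d for _ in q_blk]              # running normalized outputs
        for j0 in range(0, len(k), block_c):        # inner loop over K/V blocks
            k_blk = k[j0:j0 + block_c]              # "load K_j, V_j into SRAM"
            v_blk = v[j0:j0 + block_c]
            for r, q_row in enumerate(q_blk):
                s = [dot(q_row, k_row) for k_row in k_blk]   # one row of S_ij
                m_new = max(m[r], max(s))
                scale = math.exp(m[r] - m_new)      # 0.0 on the first block
                p = [math.exp(x - m_new) for x in s]
                l_new = scale * l[r] + sum(p)
                for c in range(d):
                    pv = sum(p_j * v_row[c] for p_j, v_row in zip(p, v_blk))
                    o[r][c] = (scale * l[r] * o[r][c] + pv) / l_new
                m[r], l[r] = m_new, l_new
        out.extend(o)                               # "write O_i to HBM"
    return out

q = [[0.1, 0.2], [0.3, -0.4], [1.0, 0.5], [-0.7, 0.9]]
k = [[0.5, -0.1], [0.2, 0.8], [-0.3, 0.4], [0.9, 0.9]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(tiled_attention(q, k, v))
```

The only state carried between K/V blocks is `m`, `l`, and `o` for the current Q block, which is why the working set stays small regardless of N.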

3. The Online Softmax Trick

Softmax normally requires a two-pass algorithm: first compute the maximum value across all scores (for numerical stability), then compute the exponentials and sum. This requires seeing all N values before producing any output, which is incompatible with tiling.

The online softmax trick maintains a running maximum m and a running sum ℓ as blocks arrive, and rescales the accumulated output O each time the maximum estimate is updated:

Online max update
$$m_\text{new} = \max(m_\text{old},\; \max(s_\text{block}))$$
Online sum rescale
$$\ell_\text{new} = e^{m_\text{old} - m_\text{new}} \cdot \ell_\text{old} + \sum_j e^{s_{\text{block},j} - m_\text{new}}$$
Online output accumulation
$$O_\text{new} = \frac{e^{m_\text{old} - m_\text{new}} \cdot \ell_\text{old} \cdot O_\text{old} + e^{s_\text{block} - m_\text{new}} \cdot V_\text{block}}{\ell_\text{new}}$$

Why this is exact: Each update is a mathematically equivalent rescaling of the previous partial result. When all blocks have been processed, O contains exactly the same value as standard attention, not an approximation. The trick only reorganizes the order of arithmetic operations.
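The exactness claim can be checked with a short induction for a single query row. The notation here is an assumption introduced for this sketch: A is the set of keys already processed, B is the incoming block, s_k is the score for key k, and v_k is the corresponding row of V. Suppose the running statistics are correct for A, as stated in the comment; the update formulas then make them correct for A ∪ B.

```latex
% Inductive hypothesis (statistics are correct for the keys in A):
%   \ell_\text{old} = \sum_{k \in A} e^{s_k - m_\text{old}}, \qquad
%   O_\text{old} = \frac{1}{\ell_\text{old}} \sum_{k \in A} e^{s_k - m_\text{old}} v_k
\begin{align*}
e^{m_\text{old} - m_\text{new}} \, \ell_\text{old}
  &= \sum_{k \in A} e^{s_k - m_\text{new}}, \\
\ell_\text{new}
  &= \sum_{k \in A} e^{s_k - m_\text{new}} + \sum_{j \in B} e^{s_j - m_\text{new}}
   = \sum_{k \in A \cup B} e^{s_k - m_\text{new}}, \\
O_\text{new}
  &= \frac{\sum_{k \in A} e^{s_k - m_\text{new}} v_k + \sum_{j \in B} e^{s_j - m_\text{new}} v_j}{\ell_\text{new}}
   = \sum_{k \in A \cup B} \frac{e^{s_k - m_\text{new}}}{\ell_\text{new}} \, v_k .
\end{align*}
```

After the last block, A ∪ B covers every key and m_new is the global row maximum, so the final O is the softmax-weighted sum of the rows of V, which is the value standard attention computes.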

4. Concrete Example: Tiling with N=4, Block Size=2

Query sequence: [q1, q2, q3, q4], Key sequence: [k1, k2, k3, k4]

Block 1 (q1,q2 attend to k1,k2):
    Load Q[1:2], K[1:2], V[1:2] into SRAM
    Compute S_11 = Q[1:2] · K[1:2]ᵀ
    Set m1 = rowmax(S_11), ℓ_1 = rowsum(exp(S_11 - m1))
    O[1:2] = exp(S_11 - m1) · V[1:2] / ℓ_1
Block 2 (q1,q2 attend to k3,k4):
    Load K[3:4], V[3:4] into SRAM (Q[1:2] stays)
    Compute S_12 = Q[1:2] · K[3:4]ᵀ
    m2 = max(m1, rowmax(S_12)), ℓ_2 = exp(m1 - m2) · ℓ_1 + rowsum(exp(S_12 - m2))
    O[1:2] = (exp(m1 - m2) · ℓ_1 · O[1:2] + exp(S_12 - m2) · V[3:4]) / ℓ_2
Final O[1:2] = exact softmax attention over all 4 keys!
Never wrote S or P to HBM. Each K, V block was read only once for this Q block.
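The two-block walk-through can be traced numerically for a single query. The scores and values below are made-up numbers chosen for illustration (one row of S for q1, and scalar stand-ins for v1..v4); the variable names mirror m1, ℓ_1, m2, ℓ_2 from the steps above.

```python
import math

scores = [1.0, 2.0, 3.0, 0.0]       # hypothetical scores of q1 against k1..k4
values = [10.0, 20.0, 30.0, 40.0]   # hypothetical scalar stand-ins for v1..v4

# Block 1: keys 1-2
m1 = max(scores[:2])
w1 = [math.exp(s - m1) for s in scores[:2]]
l1 = sum(w1)
o1 = sum(w * v for w, v in zip(w1, values[:2])) / l1

# Block 2: keys 3-4, merged into the running statistics
m2 = max(m1, max(scores[2:]))
w2 = [math.exp(s - m2) for s in scores[2:]]
scale = math.exp(m1 - m2)           # rescales everything computed under m1
l2 = scale * l1 + sum(w2)
o2 = (scale * l1 * o1 + sum(w * v for w, v in zip(w2, values[2:]))) / l2

# Reference: ordinary softmax over all four scores at once
m = max(scores)
w = [math.exp(s - m) for s in scores]
reference = sum(wi * vi for wi, vi in zip(w, values)) / sum(w)

print(o2, reference)
```

The two printed numbers agree up to floating-point rounding, even though the running maximum changed from 2.0 to 3.0 between blocks.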

5. Backward Pass: Recomputation

For backpropagation, standard attention needs to store the N×N attention matrix P to compute gradients. FlashAttention instead stores only the output O and the per-row softmax statistics (m, ℓ), which is O(N) extra values in total, and recomputes the attention tiles on the fly during the backward pass.

Trade-off: Recomputation requires additional FLOPs in the backward pass (roughly 2× the FLOPs of the forward pass), but this is cheaper than the HBM bandwidth cost of storing and loading the N×N matrix. On modern hardware, attention is memory-bandwidth-bound, not compute-bound.
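The recomputation step can be sketched as follows: given Q, K, and the saved per-row statistics, any tile of P can be rebuilt on demand as exp(S_ij - m_i) / ℓ_i. The function names are hypothetical, and for brevity `row_statistics` computes (m, ℓ) directly rather than through the online updates that the real forward pass would use. The gradient computation itself is not shown.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def row_statistics(q, k):
    """Per-row max m and normalizer l: the O(N) state kept after the forward pass."""
    m, l = [], []
    for q_row in q:
        s = [dot(q_row, k_row) for k_row in k]
        m.append(max(s))
        l.append(sum(math.exp(x - m[-1]) for x in s))
    return m, l

def recompute_p_tile(q_blk, k_blk, m_blk, l_blk):
    """Rebuild one tile of P from Q, K, and the saved statistics."""
    return [[math.exp(dot(q_row, k_row) - m_i) / l_i for k_row in k_blk]
            for q_row, m_i, l_i in zip(q_blk, m_blk, l_blk)]

q = [[0.1, 0.2], [0.3, -0.4], [1.0, 0.5], [-0.7, 0.9]]
k = [[0.5, -0.1], [0.2, 0.8], [-0.3, 0.4], [0.9, 0.9]]
m, l = row_statistics(q, k)

# Two tiles covering the first two query rows against all four keys.
left = recompute_p_tile(q[:2], k[:2], m[:2], l[:2])
right = recompute_p_tile(q[:2], k[2:], m[:2], l[:2])
print(sum(left[0]) + sum(right[0]))  # one full row of P; should be close to 1
```

Each tile costs one small matrix product and a handful of exponentials, which is the extra compute traded for not storing P.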

6. Memory Complexity

| What is stored | Standard attention | FlashAttention |
| --- | --- | --- |
| Score matrix S | O(N²) | never stored |
| Probability matrix P | O(N²) | never stored |
| Output O | O(Nd) | O(Nd) |
| Softmax statistics (m, ℓ) | implicit in P | O(N) |
| Total | O(N²) | O(N) |
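To put numbers on the table, the sketch below counts stored elements per attention head for both approaches. The function name is made up for this example, and the head dimension d = 64 is an assumed typical value, not a figure from the table.

```python
def stored_elements(n, d):
    """Element counts per head, following the rows of the table above."""
    standard = 2 * n * n + n * d   # S and P (N x N each) plus O (N x d)
    flash = n * d + 2 * n          # O (N x d) plus m and l (N each)
    return standard, flash

for n in (1024, 4096, 16384):
    standard, flash = stored_elements(n, d=64)
    print(f"N={n}: standard={standard:,} flash={flash:,} ratio={standard / flash:.1f}")
```

Because the standard count grows quadratically and the FlashAttention count linearly, the ratio itself grows roughly in proportion to N.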

7. Results

| Method | Seq len 2K | Seq len 4K | Seq len 8K | Memory (2K) |
| --- | --- | --- | --- | --- |
| Standard | 1.0× | 1.0× | OOM | O(N²) |
| FlashAttention | 3.1× | 3.8× | 3.5× | O(N) |

  • BERT: 15% end-to-end training speedup, 3× memory reduction
  • GPT-2: 3× faster attention operation with identical perplexity
  • Long context: Enables 64K context length training, compared to the ~2K practical limit of standard attention on the same hardware

8. Limitations

  • CUDA-specific: The original FlashAttention is a hand-written CUDA kernel. It is not automatically available in all deep learning frameworks without explicit integration.
  • Hardware-dependent block sizes: Optimal block sizes depend on the specific GPU's SRAM capacity, requiring tuning per hardware target.
  • GPU utilization: FlashAttention-1 did not fully utilize GPU compute units. FlashAttention-2 and FlashAttention-3 addressed this with further parallelism and warp-level optimizations.

9. Connections to Other Work

Attention Is All You Need

FlashAttention is a drop-in replacement for the O(N²) scaled dot-product attention introduced in this paper. It computes exactly the same mathematical operation, just with a dramatically better IO pattern.

LoRA

Frequently combined in practice: LoRA reduces the number of trainable parameters while FlashAttention reduces the memory and time cost of each forward pass. Together they enable fine-tuning large models on limited hardware.

LLaDA

Masked diffusion language models like LLaDA rely on FlashAttention for efficient training at scale, since they perform many forward passes per training step.

10. Additional Resources