Phase II · Week 4 · Day 24 of 70 · 2.5 hours
"Flash Attention's key insight isn't algorithmic cleverness — it's respecting the memory hierarchy. By never materializing the full N×N attention matrix, it turns a memory-bound problem into a compute-bound one."
Standard attention computes the full $N \times N$ score matrix, writes it to HBM, reads it back for softmax, writes the probabilities out again, then reads them for the final matmul with $V$. For sequence length $N = 8192$ and batch $\times$ heads $= 96$, a single fp32 score matrix takes $8192^2 \times 96 \times 4$ bytes $= 24$ GiB of intermediate storage, and that is for just one attention layer. Flash Attention eliminates this entirely by fusing $QK^T$, softmax, and the multiplication by $V$ into a single kernel that tiles over the sequence dimension. The result: 2–4× wall-clock speedup and extra memory reduced from $O(N^2)$ to $O(N)$. It is arguably the most impactful GPU kernel of the transformer era.
Given $Q, K, V \in \mathbb{R}^{N \times d}$ (for one head):
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d}}\right) V$$
import math
import torch

# Standard attention in PyTorch (single head; Q, K, V have shape (N, d))
def standard_attention(Q, K, V):
    d = Q.shape[-1]
    S = Q @ K.T / math.sqrt(d)      # (N, N), materialized in HBM!
    P = torch.softmax(S, dim=-1)    # (N, N), another HBM round trip
    O = P @ V                       # (N, d)
    return O
Standard Attention: three kernels, four N×N transfers through HBM
HBM                          SRAM (on-chip)
┌────────────┐              ┌──────────┐
│ Q (N×d)    │──read──────► │          │
│ K (N×d)    │──read──────► │ compute  │
│            │              │ QK^T     │
│ S (N×N)    │◄──write───── │          │  ① Write S to HBM
│            │              └──────────┘
│ S (N×N)    │──read──────► ┌──────────┐
│            │              │ softmax  │
│ P (N×N)    │◄──write───── │          │  ② Write P to HBM
│            │              └──────────┘
│ P (N×N)    │──read──────► ┌──────────┐
│ V (N×d)    │──read──────► │ P @ V    │
│ O (N×d)    │◄──write───── │          │  ③ Write O to HBM
└────────────┘              └──────────┘
HBM traffic: O(Nd + N²), dominated by the N×N matrices
Memory: O(N²) for S and P
For $N = 4096, d = 64$ in fp32: $S$ and $P$ are $4096^2 \times 4$ bytes $= 64$ MiB each, per head.
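The arithmetic behind these figures is easy to reproduce. The sketch below uses a hypothetical helper, `score_matrix_bytes` (not part of any library), and assumes 4-byte fp32 elements unless told otherwise.

```python
def score_matrix_bytes(n: int, batch_heads: int = 1, bytes_per_elem: int = 4) -> int:
    """Bytes needed to store one N x N score matrix per (batch, head) slice.

    Hypothetical helper for illustration; defaults to fp32 (4 bytes/element).
    """
    return n * n * batch_heads * bytes_per_elem


MIB = 2**20
GIB = 2**30

per_head_mib = score_matrix_bytes(4096) / MIB
all_heads_gib = score_matrix_bytes(8192, batch_heads=96) / GIB
print(f"N=4096, one head:       {per_head_mib:.0f} MiB")
print(f"N=8192, 96 batch*heads: {all_heads_gib:.0f} GiB")
```

Passing `bytes_per_elem=2` gives the fp16 sizes, which are half as large but still grow quadratically in $N$.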
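A numerically stable softmax needs two reductions over each full row: one for the $\max$, one for $\sum e^{x_i - \max}$. That requires the whole row at once, which a tiled kernel never has. Flash Attention instead uses the online softmax algorithm, which updates a running maximum and sum as new blocks arrive.
Let $m$ be the running row maximum, $\ell$ the running sum of exponentials, $O$ the running (normalized) output, and $\mathbf{s}_j$ the scores of block $j$. Given partial results from blocks $1 \ldots j-1$ and a new block $j$: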
$$m_j = \max(m_{j-1}, \max(\mathbf{s}_j))$$
$$\ell_j = e^{m_{j-1} - m_j} \cdot \ell_{j-1} + \sum_i e^{s_{j,i} - m_j}$$
$$O_j = \frac{e^{m_{j-1} - m_j} \cdot \ell_{j-1} \cdot O_{j-1} + e^{\mathbf{s}_j - m_j} \cdot V_j}{\ell_j}$$
When the last block is processed, $O$ contains the exact attention output — no approximation.
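To see why the update is exact, it helps to write out what the running sum represents. A short induction sketch, assuming $\ell_{j-1}$ already equals the sum of exponentials over blocks $1 \ldots j-1$ taken relative to $m_{j-1}$ (with $m_0 = -\infty$ and $\ell_0 = 0$ as the base case):

$$e^{m_{j-1} - m_j} \cdot \ell_{j-1} = e^{m_{j-1} - m_j} \sum_{k<j} \sum_i e^{s_{k,i} - m_{j-1}} = \sum_{k<j} \sum_i e^{s_{k,i} - m_j}$$

Adding the new block's term $\sum_i e^{s_{j,i} - m_j}$ gives $\ell_j = \sum_{k \le j} \sum_i e^{s_{k,i} - m_j}$, which is the softmax denominator over everything seen so far. The same factor $e^{m_{j-1} - m_j}$ re-bases the unnormalized numerator $\ell_{j-1} \cdot O_{j-1}$, so dividing by $\ell_j$ leaves a correctly normalized $O_j$.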
FLASH-ATTENTION-FORWARD(Q, K, V):
    Input:  Q, K, V ∈ ℝ^{N×d}, block sizes Br, Bc
    Output: O ∈ ℝ^{N×d}

    for i = 0 to ⌈N/Br⌉ - 1:                  # Outer loop: Q blocks
        Qi = Q[i·Br : (i+1)·Br]               # Load Q block (Br × d) from HBM
        Oi = zeros(Br, d)                     # Initialize output block
        mi = -∞ · ones(Br)                    # Running row-wise max
        li = zeros(Br)                        # Running row-wise sum

        for j = 0 to ⌈N/Bc⌉ - 1:              # Inner loop: K,V blocks
            Kj = K[j·Bc : (j+1)·Bc]           # Load K block (Bc × d) from HBM
            Vj = V[j·Bc : (j+1)·Bc]           # Load V block (Bc × d) from HBM
            Sij = Qi @ Kj^T / √d              # Compute scores (Br × Bc) in SRAM

            # Online softmax update
            mi_new = max(mi, rowmax(Sij))
            Pij = exp(Sij - mi_new)           # Unnormalized attention (Br × Bc)
            li_new = exp(mi - mi_new) * li + rowsum(Pij)

            # Rescale previous output and accumulate
            Oi = (exp(mi - mi_new) * li / li_new)[:,None] * Oi
                 + (1 / li_new)[:,None] * (Pij @ Vj)
            mi = mi_new
            li = li_new

        O[i·Br : (i+1)·Br] = Oi               # Write output block to HBM
    return O
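The pseudocode above can be checked numerically before any GPU code is written. The following is a minimal CPU sketch in plain PyTorch, not the fused kernel itself: the function name `tiled_attention` and the small block sizes are illustrative assumptions, and ordinary tensors stand in for SRAM-resident tiles.

```python
import math
import torch


def tiled_attention(Q, K, V, Br=16, Bc=16):
    """Blockwise attention with online softmax; never forms the full N x N matrix.

    Illustrative reference only (hypothetical name, runs on CPU or GPU tensors).
    """
    N, d = Q.shape
    O = torch.empty_like(Q)
    for i in range(0, N, Br):                      # outer loop: Q blocks
        Qi = Q[i:i + Br]                           # (br, d); last block may be short
        br = Qi.shape[0]
        Oi = torch.zeros(br, d, dtype=Q.dtype, device=Q.device)
        mi = torch.full((br,), float("-inf"), dtype=Q.dtype, device=Q.device)
        li = torch.zeros(br, dtype=Q.dtype, device=Q.device)
        for j in range(0, N, Bc):                  # inner loop: K/V blocks
            Kj, Vj = K[j:j + Bc], V[j:j + Bc]
            Sij = Qi @ Kj.T / math.sqrt(d)         # (br, bc) scores for this tile
            mi_new = torch.maximum(mi, Sij.max(dim=1).values)
            Pij = torch.exp(Sij - mi_new[:, None]) # unnormalized probabilities
            alpha = torch.exp(mi - mi_new)         # rescales the old statistics
            li_new = alpha * li + Pij.sum(dim=1)
            # Rescale the previous (normalized) output, then add the new tile
            Oi = (alpha * li / li_new)[:, None] * Oi + (Pij @ Vj) / li_new[:, None]
            mi, li = mi_new, li_new
        O[i:i + Br] = Oi
    return O


torch.manual_seed(0)
N, d = 100, 8    # N is deliberately not a multiple of the block size
Q, K, V = (torch.randn(N, d, dtype=torch.float64) for _ in range(3))
reference = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1) @ V
max_err = (tiled_attention(Q, K, V) - reference).abs().max().item()
print(f"max abs error vs. standard attention: {max_err:.2e}")
```

Float64 is used here so that any remaining difference reflects rounding order rather than half-precision error; the result should agree with standard attention to within floating-point tolerance for any choice of `Br` and `Bc`.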
Flash Attention: 1 HBM round trip per Q block
HBM                          SRAM (on-chip, ~20 MB total on A100)
┌────────────┐              ┌────────────────────────────────┐
│ Q block    │──read──────► │ Q_i (Br × d)                   │
│ K block j  │──read──────► │ K_j (Bc × d)                   │
│ V block j  │──read──────► │ V_j (Bc × d)                   │
│            │              │                                │
│            │  (no write!) │ S_ij = Q_i @ K_j^T        SRAM │
│            │  (no write!) │ P_ij = exp(S_ij - m_i)    SRAM │
│            │              │ O_i += P_ij @ V_j         SRAM │
│            │              │ m_i, l_i (running stats)       │
│            │              └────────────────────────────────┘
│ O block    │◄──write───── (only final output)
└────────────┘
HBM traffic: O(N²d² / M) where M = SRAM size
Memory: O(N) extra; only Q, K, V, O are stored (no N×N matrices)
import math

import torch
import triton
import triton.language as tl


@triton.jit
def flash_attention_fwd_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    # Scaling factor (1 / sqrt(d))
    sm_scale,
    # Sequence length
    N,
    # Head dimension (must be a power of two for tl.arange)
    d: tl.constexpr,
    # Strides
    stride_qn, stride_qd,
    stride_kn, stride_kd,
    stride_vn, stride_vd,
    stride_on, stride_od,
    # Block sizes
    BLOCK_M: tl.constexpr,  # Q block size (Br)
    BLOCK_N: tl.constexpr,  # K/V block size (Bc)
    IS_CAUSAL: tl.constexpr,
):
    # Which Q block does this program handle?
    pid_m = tl.program_id(0)

    # Q block row range
    q_range = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    q_mask = q_range < N

    # Load Q block: (BLOCK_M, d)
    q_ptrs = Q_ptr + q_range[:, None] * stride_qn + tl.arange(0, d)[None, :] * stride_qd
    q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)

    # Initialize accumulators (-1e9 stands in for -inf and avoids inf - inf NaNs)
    m_i = tl.full((BLOCK_M,), value=-1e9, dtype=tl.float32)  # row-wise max
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)             # row-wise sum
    o_i = tl.zeros((BLOCK_M, d), dtype=tl.float32)           # output accumulator

    # Determine K/V iteration range: causal rows never look past their own block
    if IS_CAUSAL:
        kv_bound = min((pid_m + 1) * BLOCK_M, N)
    else:
        kv_bound = N

    # Inner loop: iterate over K/V blocks
    for start_n in range(0, kv_bound, BLOCK_N):
        kv_range = start_n + tl.arange(0, BLOCK_N)
        kv_mask = kv_range < N

        # Load K block: (BLOCK_N, d)
        k_ptrs = K_ptr + kv_range[:, None] * stride_kn + tl.arange(0, d)[None, :] * stride_kd
        k = tl.load(k_ptrs, mask=kv_mask[:, None], other=0.0)

        # Compute QK^T: (BLOCK_M, BLOCK_N)
        s = tl.dot(q, tl.trans(k)) * sm_scale

        # Mask padded key columns. Without this, out-of-range keys (loaded as
        # zeros) score 0 and leak into the softmax when N % BLOCK_N != 0.
        s = tl.where(kv_mask[None, :], s, -1e9)

        # Apply causal mask if needed
        if IS_CAUSAL:
            causal_mask = q_range[:, None] >= kv_range[None, :]
            s = tl.where(causal_mask, s, -1e9)

        # --- Online softmax ---
        # New row-wise max
        m_ij = tl.max(s, axis=1)                  # (BLOCK_M,)
        m_new = tl.maximum(m_i, m_ij)             # (BLOCK_M,)

        # Correction factor for previous accumulations
        alpha = tl.exp(m_i - m_new)               # (BLOCK_M,)

        # Exponentiate current block scores
        p = tl.exp(s - m_new[:, None])            # (BLOCK_M, BLOCK_N)

        # Update running sum
        l_new = alpha * l_i + tl.sum(p, axis=1)   # (BLOCK_M,)

        # Load V block: (BLOCK_N, d)
        v_ptrs = V_ptr + kv_range[:, None] * stride_vn + tl.arange(0, d)[None, :] * stride_vd
        v = tl.load(v_ptrs, mask=kv_mask[:, None], other=0.0)

        # Rescale previous output and accumulate. Unlike the pseudocode, o_i is
        # kept unnormalized here; the division by l_i happens once after the loop.
        o_i = alpha[:, None] * o_i + tl.dot(p.to(v.dtype), v)

        # Update state
        m_i = m_new
        l_i = l_new

    # Final normalization
    o_i = o_i / l_i[:, None]

    # Store output block
    o_ptrs = O_ptr + q_range[:, None] * stride_on + tl.arange(0, d)[None, :] * stride_od
    tl.store(o_ptrs, o_i.to(O_ptr.dtype.element_ty), mask=q_mask[:, None])
def flash_attention_triton(
    Q: torch.Tensor,  # (B*H, N, d) or (B, H, N, d)
    K: torch.Tensor,
    V: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Flash Attention forward pass using Triton."""
    # Handle batched input: flatten batch and heads into one leading dimension
    orig_shape = Q.shape
    if Q.ndim == 4:
        B, H, N, d = Q.shape
        Q = Q.reshape(B * H, N, d)
        K = K.reshape(B * H, N, d)
        V = V.reshape(B * H, N, d)
    else:
        B_H, N, d = Q.shape

    O = torch.empty_like(Q)
    sm_scale = 1.0 / math.sqrt(d)
    BLOCK_M = 64
    BLOCK_N = 64

    # One program per Q block, one launch per batch*head slice.
    # Simple but slow: a second grid axis plus batch strides would allow a
    # single launch for all slices.
    grid = (triton.cdiv(N, BLOCK_M),)
    for bh in range(Q.shape[0]):
        flash_attention_fwd_kernel[grid](
            Q[bh], K[bh], V[bh], O[bh],
            sm_scale, N, d,
            Q.stride(-2), Q.stride(-1),
            K.stride(-2), K.stride(-1),
            V.stride(-2), V.stride(-1),
            O.stride(-2), O.stride(-1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
            IS_CAUSAL=causal,
        )
    return O.reshape(orig_shape)
# Correctness check
N, d = 1024, 64
Q = torch.randn(N, d, device='cuda', dtype=torch.float16)
K = torch.randn(N, d, device='cuda', dtype=torch.float16)
V = torch.randn(N, d, device='cuda', dtype=torch.float16)

out_flash = flash_attention_triton(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))
out_ref = torch.nn.functional.scaled_dot_product_attention(
    Q.unsqueeze(0).unsqueeze(0), K.unsqueeze(0).unsqueeze(0), V.unsqueeze(0).unsqueeze(0)
)
print(f"Max error: {(out_flash.squeeze() - out_ref.squeeze()).abs().max():.6f}")
| Algorithm | HBM Memory | HBM I/O |
|---|---|---|
| Standard attention | $O(N^2)$ for $S, P$ | $O(Nd + N^2)$ |
| Flash Attention | $O(N)$ for $Q, K, V, O$ only | $O(N^2 d^2 / M)$ |
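Where $M$ is the SRAM size, counted in elements. For typical head dimensions ($d = 64$ to $128$), $d^2$ is several times smaller than $M$, so $N^2 d^2 / M$ is correspondingly smaller than $N^2$. Flash Attention is also asymptotically IO-optimal: no exact attention algorithm can use asymptotically fewer HBM accesses across the full range of SRAM sizes.
Both algorithms perform asymptotically the same FLOPs: $O(N^2 d)$. Flash Attention does more FLOPs per byte of HBM traffic, that is, it has higher arithmetic intensity, which moves attention from the memory-bound regime toward the compute-bound one on modern GPUs. That is exactly where we want to be.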
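To get a feel for the gap between the two I/O bounds, the sketch below plugs numbers into $Nd + N^2$ (standard) and $N^2 d^2 / M$ (Flash Attention). Constant factors are dropped, so only the trend is meaningful. The helper names and the 96 KiB of usable SRAM are illustrative assumptions, not measured values.

```python
def hbm_accesses_standard(n: int, d: int) -> float:
    """Order-of-magnitude element accesses for standard attention: N*d + N^2."""
    return n * d + n * n


def hbm_accesses_flash(n: int, d: int, sram_elems: int) -> float:
    """Order-of-magnitude element accesses for Flash Attention: N^2 * d^2 / M."""
    return n * n * d * d / sram_elems


# Assumption: ~96 KiB of usable SRAM holding 2-byte fp16 values = 49,152 elements.
M = 96 * 1024 // 2

for n in (1024, 4096, 16384):
    ratio = hbm_accesses_standard(n, 64) / hbm_accesses_flash(n, 64, M)
    print(f"N={n:>6}, d=64: standard / flash ~ {ratio:.1f}x fewer HBM accesses")
```

Under these assumptions the ratio settles near $M / d^2$ for large $N$, which is why a larger head dimension or a smaller SRAM budget narrows the advantage.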
Sequence    Standard      Flash         Speedup
Length      Attention     Attention
──────────────────────────────────────────────
   512       0.12 ms       0.08 ms       1.5×
  1024       0.41 ms       0.18 ms       2.3×
  2048       1.52 ms       0.52 ms       2.9×
  4096       5.89 ms       1.68 ms       3.5×
  8192       23.1 ms       5.92 ms       3.9×
 16384       OOM           21.3 ms       ∞
──────────────────────────────────────────────
For autoregressive models (GPT, LLaMA), position $i$ can only attend to positions $\leq i$:
Causal attention mask (N=8):
K₀ K₁ K₂ K₃ K₄ K₅ K₆ K₇
Q₀ [✓ ✗ ✗ ✗ ✗ ✗ ✗ ✗]
Q₁ [✓ ✓ ✗ ✗ ✗ ✗ ✗ ✗]
Q₂ [✓ ✓ ✓ ✗ ✗ ✗ ✗ ✗]
Q₃ [✓ ✓ ✓ ✓ ✗ ✗ ✗ ✗]
Q₄ [✓ ✓ ✓ ✓ ✓ ✗ ✗ ✗]
Q₅ [✓ ✓ ✓ ✓ ✓ ✓ ✗ ✗]
Q₆ [✓ ✓ ✓ ✓ ✓ ✓ ✓ ✗]
Q₇ [✓ ✓ ✓ ✓ ✓ ✓ ✓ ✓]
In Flash Attention, causal masking provides a nearly free speedup: for Q block $i$, we only iterate over K/V blocks $j \leq i$. With $T$ blocks per side, only $T(T+1)/2$ of the $T^2$ tiles are computed, so the skipped fraction approaches 50% as $T$ grows:
Tiled causal structure (BLOCK_M = BLOCK_N = 2):
        K₀₁   K₂₃   K₄₅   K₆₇
Q₀₁  [ COMP  skip  skip  skip ]   ← only 1 K/V block
Q₂₃  [ COMP  COMP  skip  skip ]   ← 2 K/V blocks
Q₄₅  [ COMP  COMP  COMP  skip ]   ← 3 K/V blocks
Q₆₇  [ COMP  COMP  COMP  COMP ]   ← 4 K/V blocks
COMP = computed, skip = skipped entirely
Within boundary blocks: element-wise causal_mask applied
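The tile counts in the diagram generalize directly. The sketch below uses a hypothetical helper, `causal_tile_counts`, and assumes square tiles (BLOCK_M = BLOCK_N) as in the picture above.

```python
def causal_tile_counts(n: int, block: int) -> tuple[int, int]:
    """Return (computed, total) tile counts for a causal mask with square tiles.

    Hypothetical helper for illustration; assumes BLOCK_M == BLOCK_N == block.
    """
    t = -(-n // block)             # ceil(n / block) tiles per side
    computed = t * (t + 1) // 2    # Q block i visits K/V blocks 0..i
    return computed, t * t


for n, block in ((8, 2), (1024, 64), (8192, 64)):
    computed, total = causal_tile_counts(n, block)
    skipped = 1 - computed / total
    print(f"N={n}, block={block}: {computed}/{total} tiles computed, {skipped:.1%} skipped")
```

For the 4×4 example in the diagram this gives 10 of 16 tiles computed, so the saving is noticeably below 50% at small block counts and only approaches it for long sequences.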
Implement the online softmax algorithm in pure PyTorch (no Triton) and verify it matches torch.softmax:
def online_softmax(x: torch.Tensor, block_size: int = 64) -> torch.Tensor:
    """Compute softmax using online algorithm, processing block_size elements at a time."""
    N = x.shape[-1]
    m = torch.full((x.shape[0],), -1e9, device=x.device)
    l = torch.zeros(x.shape[0], device=x.device)
    for start in range(0, N, block_size):
        end = min(start + block_size, N)
        x_block = x[:, start:end]
        m_block = x_block.max(dim=-1).values
        m_new = torch.maximum(m, m_block)
        # TODO: update l, compute final softmax
        # Hint: track exp(x_i - m_new) contributions
    # TODO: reconstruct full softmax output
Compare your Triton Flash Attention against torch.nn.functional.scaled_dot_product_attention and a naive PyTorch implementation for $N \in \{256, 512, 1024, 2048, 4096\}$. Plot throughput (GFLOP/s) and peak memory.
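Converting a measured time into throughput needs a FLOP count. A common convention, assumed here, counts only the two matmuls ($2N^2 d$ FLOPs each) and ignores the softmax. The helper names below are hypothetical, and the 1 ms timing is a made-up placeholder rather than a measurement.

```python
def attention_flops(n: int, d: int, batch_heads: int = 1) -> int:
    """FLOPs for the two matmuls in non-causal attention (softmax cost ignored)."""
    return 4 * n * n * d * batch_heads    # 2*N*N*d for QK^T plus 2*N*N*d for PV


def gflops_per_second(n: int, d: int, seconds: float, batch_heads: int = 1) -> float:
    """Throughput in GFLOP/s for one forward pass that took `seconds`."""
    return attention_flops(n, d, batch_heads) / seconds / 1e9


# Placeholder timing: pretend one forward pass at N=1024, d=64 took 1 ms.
print(f"{gflops_per_second(1024, 64, 1e-3):.1f} GFLOP/s")
```

On a GPU the timing itself should come from a synchronized measurement, for example `triton.testing.do_bench`, and peak memory can be read with `torch.cuda.max_memory_allocated()` after resetting the counter before each run.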
If you haven't already, add causal masking to the kernel. Verify that:
1. Causal output matches scaled_dot_product_attention(..., is_causal=True)
2. Causal is faster than non-causal (should skip ~50% of inner loop iterations)
3. Future positions have zero influence on the output
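One way to test the third property is to perturb $K$ and $V$ at positions after some index $t$ and confirm that outputs $0 \ldots t$ do not move. The sketch below is a possible harness, not a required solution: `max_future_leak` and `naive_causal_attention` are hypothetical names, and the reference implementation materializes the full mask so it runs on CPU.

```python
import math
import torch


def naive_causal_attention(Q, K, V):
    """Reference causal attention; materializes the full (N, N) score matrix."""
    N, d = Q.shape
    S = Q @ K.T / math.sqrt(d)
    future = torch.triu(torch.ones(N, N, dtype=torch.bool, device=Q.device), diagonal=1)
    S = S.masked_fill(future, float("-inf"))       # hide keys at positions > query
    return torch.softmax(S, dim=-1) @ V


def max_future_leak(attn_fn, N=64, d=16, t=20, device="cpu", dtype=torch.float64):
    """Perturb K and V at positions > t; return the largest change in outputs 0..t."""
    torch.manual_seed(0)
    Q, K, V = (torch.randn(N, d, device=device, dtype=dtype) for _ in range(3))
    base = attn_fn(Q, K, V)
    K2, V2 = K.clone(), V.clone()
    K2[t + 1:] = torch.randn(N - t - 1, d, device=device, dtype=dtype)
    V2[t + 1:] = torch.randn(N - t - 1, d, device=device, dtype=dtype)
    perturbed = attn_fn(Q, K2, V2)
    return (base[: t + 1] - perturbed[: t + 1]).abs().max().item()


leak = max_future_leak(naive_causal_attention)
print(f"max change in outputs 0..t after perturbing the future: {leak:.2e}")
```

To test the Triton kernel, pass a wrapper such as `lambda Q, K, V: flash_attention_triton(Q[None], K[None], V[None], causal=True)[0]` together with `device="cuda"` and `dtype=torch.float16`, and compare the result against a small tolerance rather than exact zero.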
Day 25 steps back from kernel writing to understand the torch.compile pipeline that generates kernels automatically. We'll trace how @torch.compile captures Python bytecode (Dynamo), splits forward and backward graphs (AOTAutograd), and generates Triton kernels (Inductor). Understanding this pipeline connects the Triton skills you've built this week to the compiler infrastructure that ships with PyTorch.