Phase IV · Week 9 · Day 57 of 70 · 2.5 hours
"Serving one LLM request is easy. Serving a thousand per second without going bankrupt is the hard part."
| ← Previous | Next → | 📅 Week | 🔷 Phase | 📚 Curriculum |
|---|---|---|---|---|
| Day 56: Mini-Project — Optimization Pipeline | Day 58: KV Cache Optimization | Week 9: LLM Serving Systems | Phase IV: Inference & Deployment | ML Compilers |
LLMs break many of the assumptions traditional inference systems were built on. A ResNet forward pass is one shot: fixed input, fixed compute, fixed memory. An LLM generates tokens one at a time, each depending on every token before it. This autoregressive loop turns inference from a compute-bound matrix multiply into a memory-bandwidth-bound serial dependency chain. The KV cache that stores past attention state can consume more GPU memory than the model weights themselves. Understanding these challenges is a prerequisite for every optimization in the next two weeks.
Unlike image models that produce output in a single forward pass, LLMs generate text token by token:
```
Autoregressive Generation Loop
═══════════════════════════════════════════════════════════════
Input: "The capital of France is"

Step 1: Forward(all tokens) ──→ logits ──→ sample ──→ "Paris"
Step 2: Forward("Paris")    ──→ logits ──→ sample ──→ "."
Step 3: Forward(".")        ──→ logits ──→ sample ──→ "<eos>"

┌─────────────────────────────────────────────────────────┐
│ Each step MUST wait for the previous token.             │
│ You cannot parallelize across output tokens.            │
│ N output tokens = N sequential forward passes.          │
└─────────────────────────────────────────────────────────┘
```
This creates a serial dependency that fundamentally limits how fast any single request can finish. Generating 512 tokens requires 512 sequential model invocations, and adding hardware parallelism cannot shorten that chain for a single sequence.
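The loop can be sketched in a few lines of Python. This is a minimal illustration, not a real model: `toy_next_token` is a hypothetical stand-in for a full forward pass plus sampling (a lookup keyed on the last token), and `generate` is an illustrative helper that counts how many sequential calls it makes.

```python
def toy_next_token(tokens):
    """Hypothetical stand-in for forward pass + sampling: looks only at the last token."""
    table = {"is": "Paris", "Paris": ".", ".": "<eos>"}
    return table.get(tokens[-1], "<eos>")


def generate(prompt_tokens, next_token_fn, max_new_tokens=16, eos="<eos>"):
    """Autoregressive decoding: each step consumes every token produced so far."""
    tokens = list(prompt_tokens)
    forward_passes = 0
    while forward_passes < max_new_tokens:
        next_tok = next_token_fn(tokens)  # cannot start until the previous step finished
        forward_passes += 1
        tokens.append(next_tok)
        if next_tok == eos:
            break
    return tokens[len(prompt_tokens):], forward_passes


new_tokens, passes = generate("The capital of France is".split(), toy_next_token)
print(new_tokens, passes)  # ['Paris', '.', '<eos>'] 3
```

`forward_passes` always equals the number of generated tokens: the count of sequential model calls grows with output length no matter how fast each individual call is.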
Every LLM request has two distinct phases with entirely different computational profiles:
```
Request Lifecycle
═══════════════════════════════════════════════════════════════
┌──────────────────────┐     ┌──────────────────────────────┐
│ PREFILL              │     │ DECODE                       │
│                      │     │                              │
│ Process all input    │     │ Generate tokens one by one   │
│ tokens in parallel   │────▶│ using cached KV states       │
│                      │     │                              │
│ Compute-bound        │     │ Memory-bandwidth-bound       │
│ High GPU util.       │     │ Low GPU util. (~5-15%)       │
│ O(n²) attention      │     │ O(n) attention per step      │
│ One pass             │     │ T passes for T tokens        │
└──────────────────────┘     └──────────────────────────────┘

Example: LLaMA-2 7B on A100 (80GB)
───────────────────────────────────
Prefill 2048 tokens:  > 90 ms   (compute-bound: 2 × 7B × 2048 ≈ 29 TFLOP at 312 TFLOPS peak)
Decode 1 token:       ~12 ms    (memory-bound: reads all ~14 GB of weights, ≥ 7 ms at 2 TB/s)
Generate 512 tokens:  ~6.1 s    (512 × 12 ms, serial)
```
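A back-of-the-envelope sketch of the lower bounds behind these numbers, assuming the A100 figures used in this lesson (312 TFLOPS dense FP16 peak, ~2 TB/s HBM bandwidth) and the common approximation of 2 FLOPs per parameter per token. The function names are illustrative, and real kernels never reach either peak, so measured times sit above these floors.

```python
PEAK_FLOPS = 312e12  # A100 dense FP16 tensor-core peak, FLOP/s (assumed)
MEM_BW = 2.0e12      # A100 80GB HBM bandwidth, bytes/s (rounded, as in the text)


def prefill_lower_bound_ms(n_params, n_tokens, peak_flops=PEAK_FLOPS):
    """Compute-bound floor: ~2 FLOPs per parameter per prompt token (attention scores ignored)."""
    return 2 * n_params * n_tokens / peak_flops * 1e3


def decode_lower_bound_ms(n_params, bytes_per_param=2, mem_bw=MEM_BW):
    """Memory-bound floor: each decode step streams every weight from HBM at least once."""
    return n_params * bytes_per_param / mem_bw * 1e3


n_params = 7e9  # LLaMA-2 7B, rounded
prefill = prefill_lower_bound_ms(n_params, 2048)
decode = decode_lower_bound_ms(n_params)
print(f"prefill 2048 tokens >= {prefill:.1f} ms")
print(f"decode 1 token      >= {decode:.1f} ms")
print(f"generate 512 tokens >= {512 * decode / 1e3:.2f} s")
```

The two floors capture the asymmetry in the diagram: prefill spreads each weight read across 2048 tokens, while decode pays for a full weight read on every token.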
During prefill, all input tokens are processed in parallel. Every weight matrix multiplies an $n \times d$ activation matrix, and attention scores cost $O(n^2 \cdot d)$, where $n$ is sequence length and $d$ is the model (hidden) dimension. This phase looks like traditional inference: compute-bound, with high GPU utilization.
During decode, each new token requires reading all of the model's weights from GPU memory, yet each weight matrix performs only a matrix-vector multiply (one token per request). The arithmetic intensity drops dramatically. For a single $d \times d$ weight matrix in FP16, counting only the weight bytes read:
$$\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes Accessed}}$$
$$\text{Prefill: } \frac{2 \cdot n \cdot d^2}{d^2 \cdot 2} = n \quad \text{(high, compute-bound)}$$
$$\text{Decode: } \frac{2 \cdot 1 \cdot d^2}{d^2 \cdot 2} = 1 \quad \text{(low, memory-bound)}$$
On an A100, peak dense FP16 throughput (~312 TFLOPS) divided by memory bandwidth (~2 TB/s) gives a ridge point of ≈156 FLOPs/byte: any operation with lower arithmetic intensity is memory-bound. Decode, at ≈1 FLOP/byte, is deeply memory-bound.
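The roofline bound can be checked numerically. This sketch reuses the weight-only arithmetic-intensity formula above (activation bytes ignored) and the same assumed A100 figures; `attainable_flops` is an illustrative helper computing min(peak compute, intensity × bandwidth).

```python
PEAK_FLOPS = 312e12  # A100 dense FP16 tensor-core peak, FLOP/s (assumed)
MEM_BW = 2.0e12      # A100 HBM bandwidth, bytes/s (rounded, as in the text)
RIDGE = PEAK_FLOPS / MEM_BW  # 156 FLOPs/byte


def matmul_intensity(n_tokens, d, bytes_per_param=2):
    """FLOPs/byte of (n x d) @ (d x d), counting only weight bytes, as in the formulas above."""
    flops = 2 * n_tokens * d * d
    weight_bytes = bytes_per_param * d * d
    return flops / weight_bytes


def attainable_flops(intensity, peak=PEAK_FLOPS, bw=MEM_BW):
    """Roofline: capped either by peak compute or by bandwidth x intensity."""
    return min(peak, intensity * bw)


for phase, n in [("decode", 1), ("prefill", 2048)]:
    ai = matmul_intensity(n, d=4096)
    perf = attainable_flops(ai)
    bound = "memory" if ai < RIDGE else "compute"
    print(f"{phase:8s} AI={ai:7.1f}  attainable={perf / 1e12:6.1f} TFLOPS "
          f"({bound}-bound, {perf / PEAK_FLOPS:.1%} of peak)")
```

Under these assumptions a single decode token can reach at most about 2 TFLOPS, while a 2048-token prefill sits on the flat compute roof.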
During prefill, each layer computes Key and Value projections for all input tokens. During decode, we must reuse these cached K/V tensors to avoid recomputation. This KV cache grows with every token generated:
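To see why reusing the cache is safe, here is a small pure-Python sketch of single-head causal attention computed two ways: recomputing K/V for the whole prefix at every step, and projecting only the newest token while appending its K/V to a cache. The inputs are fixed random vectors rather than tokens produced by a model, and `decode_no_cache` / `decode_with_cache` are illustrative names; each returns the per-step outputs plus the number of K/V projections performed.

```python
import math
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all keys seen so far."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def decode_no_cache(xs, Wq, Wk, Wv):
    outs, projections = [], 0
    for t in range(1, len(xs) + 1):
        keys = [matvec(Wk, x) for x in xs[:t]]    # recompute K for the whole prefix
        values = [matvec(Wv, x) for x in xs[:t]]  # recompute V for the whole prefix
        projections += 2 * t
        outs.append(attend(matvec(Wq, xs[t - 1]), keys, values))
    return outs, projections


def decode_with_cache(xs, Wq, Wk, Wv):
    k_cache, v_cache, outs, projections = [], [], [], 0
    for x in xs:
        k_cache.append(matvec(Wk, x))  # project only the newest token...
        v_cache.append(matvec(Wv, x))  # ...and append its K/V to the cache
        projections += 2
        outs.append(attend(matvec(Wq, x), k_cache, v_cache))
    return outs, projections


def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d, n = 8, 6
Wq, Wk, Wv = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
xs = rand_matrix(n, d)

full, p_full = decode_no_cache(xs, Wq, Wk, Wv)
cached, p_cached = decode_with_cache(xs, Wq, Wk, Wv)
max_diff = max(abs(a - b) for fa, ca in zip(full, cached) for a, b in zip(fa, ca))
print(f"max |difference| = {max_diff:.2e}, K/V projections: {p_full} vs {p_cached}")
```

The cached version performs 2n projections instead of n(n+1), at the cost of keeping every past K and V vector in memory, and that stored state grows with every generated token.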
```
KV Cache Memory Per Request
═══════════════════════════════════════════════════════════════
Memory = 2 × n_layers × n_heads × seq_len × head_dim × bytes_per_param
         (the leading 2 accounts for storing both K and V)

LLaMA-2 7B (FP16):
──────────────────
n_layers = 32
n_heads  = 32
head_dim = 128
bytes    = 2 (FP16)

Per token:
  2 × 32 × 32 × 1 × 128 × 2    = 524,288 bytes ≈ 0.5 MB
For a 4096-token sequence:
  2 × 32 × 32 × 4096 × 128 × 2 = 2,147,483,648 bytes ≈ 2 GB

┌─────────────────────────────────────────────────────────┐
│ Model weights (7B × 2 bytes):      14 GB                │
│ KV cache (1 request, 4K ctx):       2 GB                │
│ KV cache (32 concurrent requests): 64 GB  ← PROBLEM     │
│ A100 total memory:                 80 GB                │
└─────────────────────────────────────────────────────────┘
```
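The formula turns directly into a sizing helper. A minimal sketch, assuming FP16 storage and ignoring activations, allocator fragmentation, and framework overhead; `n_kv_heads` equals `n_heads` for standard multi-head attention (as in LLaMA-2 7B), and treating the 14 GB of weights as 14 GiB is a simplification.

```python
GiB = 2**30


def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2, batch=1):
    """Bytes of K and V stored for `batch` sequences of `seq_len` tokens."""
    # Leading 2: one K tensor and one V tensor per layer
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem * batch


def max_concurrent_requests(gpu_bytes, weight_bytes, per_request_bytes):
    """How many full-length requests fit beside the weights (no other overhead)."""
    return (gpu_bytes - weight_bytes) // per_request_bytes


# LLaMA-2 7B: 32 layers, 32 heads, head_dim 128
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
per_request = kv_cache_bytes(32, 32, 128, seq_len=4096)
print(per_token)          # 524288
print(per_request / GiB)  # 2.0
print(max_concurrent_requests(80 * GiB, 14 * GiB, per_request))  # 33
```

Plugging in other configurations, for example LLaMA-2 13B (40 layers, 40 heads, head_dim 128), gives the per-request figures for larger models.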
For larger models, the situation worsens:
| Model | Weights (FP16) | KV cache per request (4K ctx, FP16) | Max concurrent 4K requests (one A100 80GB) |
|---|---|---|---|
| LLaMA-2 7B | 14 GB | 2 GB | ~33 |
| LLaMA-2 13B | 26 GB | ~3.1 GB | ~17 |
| LLaMA-2 70B | 140 GB | ~1.25 GB (GQA, 8 KV heads); 10 GB with full multi-head attention | Needs multi-GPU |
| GPT-3 175B | 350 GB | ~18 GB | Needs 8+ GPUs |
Traditional inference batches requests trivially — stack inputs into a larger tensor. LLM batching is fundamentally harder:
```
The LLM Batching Problem
═══════════════════════════════════════════════════════════════
Static Batching (Naïve):
────────────────────────
Request A: [############..................]  (16 tokens in, done at step 12)
Request B: [##############################]  ( 8 tokens in, done at step 30)
Request C: [####################..........]  (20 tokens in, done at step 20)
            '#' = decoding   '.' = finished, slot still held
▲ The batch runs until its longest request (B) finishes; A and C waste GPU slots waiting.
▲ New requests must wait for the entire batch to finish.

Continuous Batching (Better):
─────────────────────────────
Step  1: [A₁,  B₁,  C₁ ]  ← all active
Step 12: [ _,  B₁₂, C₁₂]  ← A finishes, slot freed
Step 13: [D₁,  B₁₃, C₁₃]  ← D joins immediately!
Step 20: [D₈,  B₂₀,  _ ]  ← C finishes
Step 21: [D₉,  B₂₁, E₁ ]  ← E joins
▲ No idle slots: requests enter and leave independently.
```
The challenges with continuous batching:
- Requests are in different phases: new requests need a full prefill, while existing ones need a single-token decode step.
- The KV cache must be managed per request (allocated on arrival, freed on completion).
- The scheduler must balance prefill latency (time-to-first-token) against decode throughput.
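The scheduling difference can be simulated in a few lines. This toy sketch uses hypothetical `Request`, `static_batching`, and `continuous_batching` helpers and assumes each request's output length is known up front, a fixed limit of 3 slots, and one decode token per active request per step; it ignores prefill cost and KV-cache memory limits entirely. Each function returns the step at which every request finishes.

```python
from collections import deque
from dataclasses import dataclass


@dataclass
class Request:
    name: str
    output_tokens: int  # assumed known in advance (real servers stop on <eos>)
    generated: int = 0


def static_batching(requests, max_batch=3):
    """Each batch holds its slots until its longest request finishes."""
    finish_step, step = {}, 0
    for i in range(0, len(requests), max_batch):
        batch = requests[i:i + max_batch]
        for r in batch:
            finish_step[r.name] = step + r.output_tokens
        step += max(r.output_tokens for r in batch)
    return finish_step


def continuous_batching(requests, max_batch=3):
    """Admit waiting requests into free slots at every decode step."""
    waiting, active = deque(requests), []
    finish_step, step = {}, 0
    while waiting or active:
        while waiting and len(active) < max_batch:
            active.append(waiting.popleft())  # fill freed slots immediately
        step += 1
        for r in active:
            r.generated += 1  # one decode token per active request
            if r.generated == r.output_tokens:
                finish_step[r.name] = step
        active = [r for r in active if r.generated < r.output_tokens]
    return finish_step


def workload():
    return [Request("A", 12), Request("B", 30), Request("C", 20),
            Request("D", 10), Request("E", 5)]


print("static:    ", static_batching(workload()))
print("continuous:", continuous_batching(workload()))
```

With the same A/B/C lengths as the diagram, D and E finish at steps 22 and 25 instead of 40 and 35, because they start as soon as A and C free their slots.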
LLM serving cost is dominated by GPU time. Let's compute the cost of serving LLaMA-2 70B:
```python
# Cost analysis for LLM serving
def serving_cost_analysis():
    """Compute cost per million output tokens for LLM serving."""
    # Hardware assumptions
    a100_cost_per_hour = 3.00  # $/hr per GPU (cloud spot pricing)
    gpus_needed = 4            # 70B needs ~4x A100-80GB

    # Performance (measured, LLaMA-2 70B, FP16, batch=1)
    prefill_tokens_per_sec = 2000  # tokens/sec during prefill
    decode_tokens_per_sec = 35     # tokens/sec during decode

    # Typical request
    input_tokens = 512
    output_tokens = 256

    # Time per request
    prefill_time = input_tokens / prefill_tokens_per_sec  # 0.256 sec
    decode_time = output_tokens / decode_tokens_per_sec   # ~7.31 sec
    total_time = prefill_time + decode_time               # ~7.57 sec

    # Cost per request (no batching): one request occupies all GPUs for total_time
    cost_per_sec = (a100_cost_per_hour * gpus_needed) / 3600  # ~$0.0033/sec
    cost_per_request = cost_per_sec * total_time              # ~$0.025

    # Cost per million output tokens
    cost_per_M_tokens = (cost_per_request / output_tokens) * 1e6

    print(f"Prefill time:       {prefill_time:.3f} sec")
    print(f"Decode time:        {decode_time:.3f} sec")
    print(f"Cost per request:   ${cost_per_request:.4f}")
    print(f"Cost per 1M tokens: ${cost_per_M_tokens:.2f}")
    print(f"Requests/hour:      {3600 / total_time:.0f}")

    # With batching (batch=32)
    print("\n--- With batch=32 ---")
    batch_size = 32
    # Decode throughput scales ~linearly with batch size while memory-bound
    batched_decode_tps = decode_tokens_per_sec * batch_size * 0.7  # 70% efficiency
    # Each request now sees batched_decode_tps / batch_size tokens/sec
    batched_decode_time = output_tokens / (batched_decode_tps / batch_size)
    # The GPUs are shared by batch_size requests during decode, so each request pays
    # only 1/batch_size of that time; compute-bound prefill is still charged in full.
    gpu_seconds_per_request = prefill_time + batched_decode_time / batch_size
    batched_cost = cost_per_sec * gpu_seconds_per_request
    batched_cost_per_M = (batched_cost / output_tokens) * 1e6

    print(f"Batched cost/req:    ${batched_cost:.4f}")
    print(f"Batched cost/1M tok: ${batched_cost_per_M:.2f}")
    print(f"Savings:             {cost_per_M_tokens / batched_cost_per_M:.1f}x cheaper")


serving_cost_analysis()
```
Key insight: Batching is the single most important optimization for cost. A batch of 32 can make serving 10-20× cheaper per token by amortizing the weight-reading cost across requests.
The roofline model reveals why decode is so inefficient:
```
Roofline Model for LLM Inference (A100, FP16)
═══════════════════════════════════════════════════════════════
Performance (TFLOPS, log scale)
     ▲
 312 ┤                     ┌──────────────★─────────── Peak compute (312 TFLOPS)
     │                    ╱┆              Prefill (n = 2048, AI ≈ 2048)
     │                  ╱  ┆
     │                ╱    ┆
     │  Memory      ╱      ┆
     │ bandwidth  ╱        ┆
     │    roof  ╱          ┆
     │        ╱            ┆
     │      ╱              ┆
   2 ┤    ★ Decode (AI ≈ 1)┆
     └────┴────────────────┴───────────────────────▶ Arithmetic intensity
          1               156                       (FLOPs/byte, log scale)
                      Ridge point
                  (compute = memory)
```
Decode can use at most 1/156 ≈ 0.6% of the GPU's peak compute.
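The arithmetic for a single decode step through the attention projections ($W_Q, W_K, W_V, W_O$: four $d \times d$ matrices per layer, FFN omitted), with $d = d_{\text{model}}$ and $L = n_{\text{layers}}$:
$$\text{FLOPs per decode step} = 2 \times d^2 \times L \times 4$$
$$\text{Bytes read} = d^2 \times L \times 4 \times 2 \text{ (FP16)}$$
$$\text{AI} = \frac{2 \times d^2 \times L \times 4}{d^2 \times L \times 4 \times 2} = 1$$
For LLaMA-2 7B ($d=4096, L=32$), these projections cost $2 \times 4096^2 \times 32 \times 4 \approx 4.3$ GFLOP while reading $\approx 4.3$ GB of weights. Counting the FFN and every other weight, one decode step performs roughly $2 \times 7\text{B} \approx 14$ GFLOP against $\approx 14$ GB of weights read: still about 1 FLOP/byte, far below the ridge point of 156. The GPU sits nearly idle while waiting for memory.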
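Batching is what moves decode up the memory-bandwidth roof: a decode step with B concurrent requests multiplies each weight matrix by B vectors while still reading it only once. A rough sketch for one $4096 \times 4096$ FP16 projection, this time also counting the activation bytes read and written, with the ridge point taken as 312 TFLOPS / 2 TB/s from the text; KV-cache reads for attention are ignored, and the helper names are illustrative.

```python
RIDGE = 312e12 / 2.0e12  # assumed A100 FP16 peak / HBM bandwidth = 156 FLOPs/byte


def linear_intensity(batch_tokens, d_in, d_out, bytes_per_elem=2):
    """FLOPs/byte for y = x @ W, x: (batch_tokens, d_in), W: (d_in, d_out), FP16.
    Bytes moved = read W once + read x + write y."""
    flops = 2 * batch_tokens * d_in * d_out
    bytes_moved = bytes_per_elem * (d_in * d_out + batch_tokens * d_in + batch_tokens * d_out)
    return flops / bytes_moved


def min_batch_to_reach_ridge(d, ridge=RIDGE):
    """Smallest decode batch at which a d x d projection stops being memory-bound."""
    b = 1
    while linear_intensity(b, d, d) < ridge:
        b += 1
    return b


for b in (1, 8, 32, 128, 256):
    ai = linear_intensity(b, 4096, 4096)
    print(f"decode batch={b:4d}: AI ≈ {ai:6.1f} FLOPs/byte "
          f"({'memory' if ai < RIDGE else 'compute'}-bound)")
print("batch needed to reach the ridge:", min_batch_to_reach_ridge(4096))
```

Under these assumptions a single projection only becomes compute-bound at roughly 169 concurrent decode tokens, one reason serving systems push batch sizes as high as KV-cache memory allows.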
```python
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimplifiedTransformerBlock(nn.Module):
    """Minimal transformer block for profiling prefill vs decode (no norms, ReLU FFN)."""

    def __init__(self, d_model=4096, n_heads=32):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.ffn_up = nn.Linear(d_model, d_model * 4, bias=False)
        self.ffn_down = nn.Linear(d_model * 4, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        B, S, D = x.shape
        # Project and split heads: (B, S, D) -> (B, n_heads, S, head_dim)
        q = self.q_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        if kv_cache is not None:
            # Decode: append this step's K/V to the cached past along the sequence axis
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
        new_cache = (k, v)
        # Prefill needs a causal mask; a decode query may attend to every cached position
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=kv_cache is None)
        h = x + self.o_proj(attn.transpose(1, 2).reshape(B, S, D))
        return h + self.ffn_down(torch.relu(self.ffn_up(h))), new_cache


@torch.inference_mode()  # no autograd bookkeeping during timing
def profile_phases(device="cuda", prompt_len=512, decode_steps=100, warmup=3):
    """Profile prefill vs decode to show the fundamental asymmetry."""
    model = SimplifiedTransformerBlock(d_model=4096, n_heads=32).to(device).half()

    # Prefill: all prompt tokens in one forward pass
    x_prefill = torch.randn(1, prompt_len, 4096, device=device, dtype=torch.float16)
    for _ in range(warmup):  # keep CUDA/cuBLAS start-up cost out of the timing
        _, kv = model(x_prefill)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):
        _, kv = model(x_prefill)
    torch.cuda.synchronize()
    prefill_ms = (time.perf_counter() - start) / 10 * 1000

    # Decode: one token per forward pass, reusing (and growing) the KV cache
    x_decode = torch.randn(1, 1, 4096, device=device, dtype=torch.float16)
    for _ in range(warmup):
        model(x_decode, kv_cache=kv)
    cache = kv
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(decode_steps):
        _, cache = model(x_decode, kv_cache=cache)
    torch.cuda.synchronize()
    decode_ms = (time.perf_counter() - start) / decode_steps * 1000

    print(f"Prefill {prompt_len} tokens: {prefill_ms:.2f} ms "
          f"({prompt_len / prefill_ms * 1000:.0f} tok/s)")
    print(f"Decode 1 token: {decode_ms:.2f} ms "
          f"({1 / decode_ms * 1000:.0f} tok/s)")
    print(f"Decode is {prefill_ms / decode_ms / prompt_len:.1%} as efficient as prefill per token")


# Uncomment on a CUDA GPU:
# profile_phases()
```
Now that you understand why LLM inference is fundamentally memory-bound and why the KV cache dominates memory, tomorrow we'll attack the cache directly. Day 58: KV Cache Optimization covers Multi-Query Attention (MQA), Grouped-Query Attention (GQA), sliding window attention, and cache compression — techniques that can reduce KV cache memory by 4-8× without meaningful quality loss.