Phase IV · Week 9 · Day 60 of 70 · 2.5 hours
"Generate 5 tokens for the cost of 1 verification — if you guess right. And with a good draft model, you guess right 70-90% of the time."
| ← Previous | Next → | 📅 Week | 🔷 Phase | 📚 Curriculum |
|---|---|---|---|---|
| Day 59: vLLM & PagedAttention | Day 61: LLM Quantization | Week 9: LLM Serving Systems | Phase IV: Inference & Deployment | ML Compilers |
Autoregressive LLM inference has an unbreakable sequential dependency: you need token $t_n$ before you can compute token $t_{n+1}$. Or do you? Speculative decoding sidesteps it: a small, fast draft model proposes several tokens (still one at a time, but cheaply), and the large target model then verifies all of them in a single forward pass. The guarantee is mathematical: the output distribution is identical to the target model's, so there is zero quality loss, and typical decode speedups are 2-3×. The technique is deployed in production at Google, Meta, and OpenAI, and understanding it is critical because it fundamentally changes the latency-throughput equation for LLM serving.
Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) works in three steps:
```
Speculative Decoding: Draft & Verify
═══════════════════════════════════════════════════════════════
Target model: M_target (e.g., LLaMA-2 70B), slow but accurate
Draft model:  M_draft  (e.g., LLaMA-2 7B),  fast but approximate
Context so far: "The capital of France"

Step 1: DRAFT. Generate K=5 tokens with M_draft (cheap, one at a time)
──────────────────────────────────────────────────────
M_draft("The capital of France") → " is"     (q₁)
M_draft("... France is")         → " Paris"  (q₂)
M_draft("... is Paris")          → ","       (q₃)
M_draft("... Paris,")            → " a"      (q₄)
M_draft("... , a")               → " city"   (q₅)

Draft tokens: [" is", " Paris", ",", " a", " city"]
Draft probs:  [q₁,    q₂,       q₃,  q₄,   q₅    ]

Step 2: VERIFY. Score all K+1 positions with M_target (one pass!)
──────────────────────────────────────────────────────────────────
M_target("The capital of France is Paris, a city") → [p₁, p₂, p₃, p₄, p₅, p₆]
One forward pass scores ALL draft tokens simultaneously;
p₆ is the target's distribution for the token after " city".

Step 3: ACCEPT/REJECT. Modified rejection sampling
────────────────────────────────────────────────────
For each draft token i (left to right):
  Accept with probability min(1, p_i(x) / q_i(x))

" is"    → p₁(" is")=0.95,    q₁(" is")=0.90    → p/q > 1, always accepted ✓
" Paris" → p₂(" Paris")=0.85, q₂(" Paris")=0.80 → p/q > 1, always accepted ✓
","      → p₃(",")=0.70,      q₃(",")=0.75      → accept w.p. 0.70/0.75 ≈ 0.93; this draw accepts ✓
" a"     → p₄(" a")=0.15,     q₄(" a")=0.60     → accept w.p. 0.15/0.60 = 0.25; this draw rejects ✗
" city"  → skipped (an earlier token was rejected)

Result: 3 accepted tokens + 1 correction token sampled from the
        residual distribution at position 4
Total:  4 tokens from 5 cheap draft passes + 1 target forward pass!
```
The acceptance criterion ensures the output follows the exact target distribution $p$, not the draft distribution $q$. This is a mathematical guarantee, not an approximation:
For each draft token $x$ sampled from $q(x)$:
$$\text{Accept with probability } \min\left(1, \frac{p(x)}{q(x)}\right)$$
If rejected at position $i$, sample a correction token from the residual distribution:
$$p'(x) = \frac{\max(0, p(x) - q(x))}{\sum_{x'} \max(0, p(x') - q(x'))}$$
Theorem: The resulting token sequence is distributed exactly as if it were sampled from $p$ autoregressively.
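To see the accept/reject rule and the residual distribution working together, here is a minimal pure-Python sketch of one speculative step over a toy four-token vocabulary. The distributions `p` and `q` are made-up numbers, and `speculative_step` is a hypothetical helper, not a library function. The Monte Carlo loop checks that the emitted tokens follow $p$ rather than $q$.

```python
import random


def speculative_step(p, q, rng):
    """One draft-and-verify step at a single position.

    p: target distribution (list of probabilities over the vocabulary)
    q: draft distribution over the same vocabulary
    Returns the emitted token index; its distribution is exactly p.
    """
    vocab = range(len(p))
    # Draft proposes x ~ q
    x = rng.choices(vocab, weights=q)[0]
    # Accept with probability min(1, p(x) / q(x))
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    # Rejected: resample from the residual max(0, p - q);
    # random.choices normalizes the weights for us
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]


# Toy distributions over a 4-token vocabulary (made-up numbers)
p = [0.50, 0.20, 0.20, 0.10]  # target
q = [0.25, 0.50, 0.15, 0.10]  # draft

rng = random.Random(0)
n = 200_000
counts = [0] * len(p)
for _ in range(n):
    counts[speculative_step(p, q, rng)] += 1
freqs = [c / n for c in counts]
print([round(f, 3) for f in freqs])  # close to p, not q
```

Token 1 is the telling case: the draft proposes it half the time, yet it is emitted only about 20% of the time, exactly as the target prescribes.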
```
Why Rejection Sampling Preserves the Target Distribution
═══════════════════════════════════════════════════════════════
Probability that the draft proposes x AND x is accepted:
  P(draft x, accept) = q(x) · min(1, p(x)/q(x))
                     = min(q(x), p(x))

  Case 1: p(x) ≥ q(x) → q(x) · 1         = q(x)
  Case 2: p(x) < q(x) → q(x) · p(x)/q(x) = p(x)
  So P(draft x, accept) = min(p(x), q(x)) for all x

Total acceptance rate:
  α = Σ_x min(p(x), q(x))          ← higher when q ≈ p

On rejection (probability 1 - α), sample from the residual:
  p'(x) = max(0, p(x) - q(x)) / Σ_x' max(0, p(x') - q(x'))

Since min(p, q) + max(0, p - q) = p pointwise, summing over x gives
  Σ_x min(p, q) + Σ_x max(0, p - q) = 1,
so the normalizer Σ_x' max(0, p(x') - q(x')) equals 1 - α.

Combined: P(output x) = min(p(x), q(x)) + (1 - α) · p'(x)
                      = min(p(x), q(x)) + max(0, p(x) - q(x))
                      = p(x)   ✓ Exact target distribution!
```
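The derivation can also be checked exactly rather than by sampling. This sketch uses Python's `fractions.Fraction` to avoid rounding. `output_distribution` is a hypothetical helper that adds the accept path, $\min(p, q)$, to the reject path, $(1-\alpha)\,p'$, for made-up toy distributions.

```python
from fractions import Fraction as Fr


def output_distribution(p, q):
    """Exact P(output x) for one speculative step, following the derivation.

    Assumes p != q (otherwise the residual is all zero and rejection never happens).
    """
    alpha = sum(min(pi, qi) for pi, qi in zip(p, q))  # total acceptance rate
    residual = [max(Fr(0), pi - qi) for pi, qi in zip(p, q)]
    norm = sum(residual)  # equals 1 - alpha
    p_prime = [r / norm for r in residual]  # residual distribution p'
    accept_path = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_path = [(1 - alpha) * pp for pp in p_prime]
    return alpha, norm, [a + r for a, r in zip(accept_path, reject_path)]


p = [Fr(1, 2), Fr(1, 5), Fr(1, 5), Fr(1, 10)]   # target (toy)
q = [Fr(1, 4), Fr(1, 2), Fr(3, 20), Fr(1, 10)]  # draft (toy)

alpha, norm, out = output_distribution(p, q)
print(alpha, norm)  # → 7/10 3/10
print(out == p)     # → True
```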
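The expected number of tokens produced per speculation round with $K$ draft tokens (the accepted draft tokens plus the one correction or bonus token the target always contributes) is:
$$E[\text{tokens per round}] = 1 + \sum_{i=1}^{K} \prod_{j=1}^{i} \alpha_j$$
where $\alpha_j$ is the acceptance rate at position $j$ given that all earlier draft tokens were accepted. With a uniform acceptance rate $\alpha$, the sum is geometric:
$$E[\text{tokens per round}] = \sum_{i=0}^{K} \alpha^i = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$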
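Expected tokens per round for a uniform $\alpha$, from the closed form above:

| Acceptance rate $\alpha$ | K=3 | K=5 | K=7 |
|---|---|---|---|
| 0.5 | 1.9 | 2.0 | 2.0 |
| 0.7 | 2.5 | 2.9 | 3.1 |
| 0.8 | 3.0 | 3.7 | 4.2 |
| 0.9 | 3.4 | 4.7 | 5.7 |

At $\alpha = 0.8$ with $K=5$, you get about 3.7 tokens per target forward pass. That is the ceiling on speedup for decode-bound generation: it treats drafting as free, so the draft model's own cost pulls the real speedup lower.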
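The table can be regenerated from the formulas above. This sketch (the function names are hypothetical) computes the general product form for per-position rates $\alpha_j$ and checks it against the closed form in the uniform case.

```python
def expected_tokens(alphas):
    """E[tokens per round] = 1 + sum_i prod_{j<=i} alpha_j for per-position rates."""
    total, prefix = 1.0, 1.0
    for a in alphas:
        prefix *= a       # P(first i draft tokens all accepted)
        total += prefix
    return total


def expected_tokens_uniform(alpha, K):
    """Closed form (1 - alpha^(K+1)) / (1 - alpha), valid for alpha < 1."""
    return (1 - alpha ** (K + 1)) / (1 - alpha)


print("alpha   K=3   K=5   K=7")
for alpha in [0.5, 0.7, 0.8, 0.9]:
    row = [expected_tokens([alpha] * K) for K in (3, 5, 7)]
    # The product form and the closed form agree for uniform alpha
    assert all(abs(r - expected_tokens_uniform(alpha, K)) < 1e-12
               for r, K in zip(row, (3, 5, 7)))
    print(f"{alpha:<7} " + "  ".join(f"{r:.1f}" for r in row))
```

The product form is the one to reach for when acceptance degrades with depth, e.g. `expected_tokens([0.9, 0.8, 0.6])`.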
```python
import torch
import torch.nn.functional as F


def speculative_decode(
    target_model,
    draft_model,
    input_ids: torch.Tensor,
    max_new_tokens: int = 100,
    K: int = 5,  # draft length per round
    temperature: float = 1.0,
):
    """Speculative decoding with rejection sampling.

    Args:
        target_model: Large model (e.g., 70B), slow but accurate
        draft_model: Small model (e.g., 7B), fast but approximate
        input_ids: [1, seq_len] input token IDs (batch size 1)
        max_new_tokens: maximum tokens to generate
        K: number of draft tokens per speculation round
        temperature: sampling temperature (must be > 0)
    Returns:
        Generated token IDs, [1, seq_len + max_new_tokens]

    Both models are called HF-style (model(ids).logits) and must share a
    tokenizer. For clarity every call re-runs the full prefix; a real
    implementation reuses KV caches and rolls back rejected entries.
    """
    generated = input_ids.clone()
    tokens_generated = 0

    while tokens_generated < max_new_tokens:
        # ─── Step 1: Draft K tokens (sequential, but cheap) ───
        draft_tokens = []
        draft_probs = []  # full draft distributions q_i, each [1, vocab]
        draft_input = generated
        for _ in range(K):
            with torch.no_grad():
                logits = draft_model(draft_input).logits[:, -1, :]
            probs = F.softmax(logits / temperature, dim=-1)
            token = torch.multinomial(probs, num_samples=1)  # [1, 1]
            draft_tokens.append(token)
            draft_probs.append(probs)
            draft_input = torch.cat([draft_input, token], dim=-1)
        draft_tokens = torch.cat(draft_tokens, dim=-1)  # [1, K]

        # ─── Step 2: Verify with target model (single forward pass!) ───
        verify_input = torch.cat([generated, draft_tokens], dim=-1)
        with torch.no_grad():
            target_logits = target_model(verify_input).logits
        # Logits at position t predict token t+1, so draft token i (at
        # position n_input + i) is scored by the logits at n_input + i - 1.
        n_input = generated.shape[1]

        # ─── Step 3: Accept/Reject via rejection sampling ───
        for i in range(K):
            target_probs_i = F.softmax(
                target_logits[:, n_input + i - 1, :] / temperature, dim=-1
            )
            token_id = draft_tokens[0, i]
            draft_prob = draft_probs[i][0, token_id]
            target_prob = target_probs_i[0, token_id]

            # Accept with probability min(1, p/q)
            accept_prob = min(1.0, (target_prob / draft_prob).item())
            if torch.rand(1).item() < accept_prob:
                continue

            # Reject: sample a correction from the residual max(0, p - q)
            residual = torch.clamp(target_probs_i - draft_probs[i], min=0)
            if residual.sum() <= 0:  # only if p == q numerically; fall back to p
                residual = target_probs_i
            residual = residual / residual.sum()
            correction = torch.multinomial(residual, num_samples=1)
            # Keep draft tokens 0..i-1, then the correction token
            generated = torch.cat(
                [generated, draft_tokens[:, :i], correction], dim=-1
            )
            tokens_generated += i + 1
            break
        else:
            # All K draft tokens accepted: sample a bonus token from the
            # target's distribution after the last draft token
            bonus_probs = F.softmax(
                target_logits[:, n_input + K - 1, :] / temperature, dim=-1
            )
            bonus_token = torch.multinomial(bonus_probs, num_samples=1)
            generated = torch.cat([generated, draft_tokens, bonus_token], dim=-1)
            tokens_generated += K + 1

    # The last round may overshoot; trim to max_new_tokens
    return generated[:, : input_ids.shape[1] + max_new_tokens]


# Usage:
# output = speculative_decode(llama_70b, llama_7b, prompt_ids, K=5)
```
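The function above divides by `temperature`, so it cannot run greedily at temperature 0. With greedy decoding, the acceptance test collapses to an exact match: a draft token is kept only if it equals the target's argmax at that position. The sketch below shows just that verification step on plain token-ID lists. `greedy_verify` is a hypothetical helper, and `target_argmax` stands for the K+1 argmax predictions read off the single target pass.

```python
def greedy_verify(draft_tokens, target_argmax):
    """Greedy (temperature 0) verification of one speculation round.

    draft_tokens:  K token IDs proposed by the draft model
    target_argmax: K+1 token IDs, the target's argmax at each draft
                   position plus one position past the last draft token
    Returns the tokens to append: the matching prefix plus one target token.
    """
    assert len(target_argmax) == len(draft_tokens) + 1
    accepted = []
    for drafted, expected in zip(draft_tokens, target_argmax):
        if drafted != expected:
            # First mismatch: the target's own token is the correction
            return accepted + [expected]
        accepted.append(drafted)
    # Every draft token matched: append the target's bonus token
    return accepted + [target_argmax[-1]]


# The draft guessed 5 tokens; the target disagrees at the 4th position
print(greedy_verify([11, 42, 7, 99, 5], [11, 42, 7, 13, 8, 2]))  # → [11, 42, 7, 13]
# Draft fully correct: K accepted + 1 bonus token
print(greedy_verify([11, 42, 7], [11, 42, 7, 3]))  # → [11, 42, 7, 3]
```

Every appended token is the target's argmax given the tokens before it, so the output matches plain greedy decoding with the target token for token.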
The acceptance rate depends critically on how well the draft approximates the target:
```
Draft Model Selection
═══════════════════════════════════════════════════════════════
Strategy 1: Smaller model from same family
──────────────────────────────────────────
  Target: LLaMA-2 70B → Draft: LLaMA-2 7B
  Acceptance rate: ~70-80%
  Pros: Easy setup, shared tokenizer, good alignment
  Cons: Draft model still needs GPU memory

Strategy 2: Quantized version of target
────────────────────────────────────────
  Target: LLaMA-2 70B FP16 → Draft: LLaMA-2 70B INT4
  Acceptance rate: ~85-95% (very high!)
  Pros: Best acceptance rate (same architecture)
  Cons: Draft is still large (just faster per token)

Strategy 3: Distilled lightweight model
────────────────────────────────────────
  Target: LLaMA-2 70B → Draft: Custom 500M model
  Acceptance rate: ~60-70%
  Pros: Very fast drafting, minimal memory
  Cons: Lower acceptance, need training data

Strategy 4: N-gram / retrieval-based
─────────────────────────────────────
  Target: any LLM → Draft: suffix tree on training data
  Acceptance rate: ~40-60% (domain-dependent)
  Pros: Zero GPU cost for drafting
  Cons: Low acceptance for creative/novel text

┌───────────────────────────────────────────────────────┐
│ Key tradeoff: faster draft ⇄ lower acceptance rate    │
│ Sweet spot: draft should be 5-10× faster than target  │
│ with >70% acceptance rate                             │
└───────────────────────────────────────────────────────┘
```
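Strategy 4 needs no neural draft model at all. As one concrete variant, here is a minimal sketch of prompt-lookup drafting, which searches the current context (rather than a suffix tree over training data) for an earlier occurrence of the most recent n-gram and proposes the tokens that followed it. The function name and the `max_ngram`/`K` parameters are illustrative assumptions.

```python
def prompt_lookup_draft(context, K=5, max_ngram=3):
    """Propose up to K draft tokens by matching the trailing n-gram.

    Tries the longest n-gram first (n = max_ngram down to 1) and uses the
    most recent earlier occurrence. Returns [] when nothing matches, in
    which case the caller falls back to ordinary decoding for this step.
    """
    for n in range(min(max_ngram, len(context) - 1), 0, -1):
        suffix = context[-n:]
        # Scan earlier start positions, most recent first; the trailing
        # n-gram itself (start = len(context) - n) is excluded
        for start in range(len(context) - n - 1, -1, -1):
            if context[start:start + n] == suffix:
                return context[start + n:start + n + K]
    return []


# Token IDs: "5 6" appeared earlier, followed by "7 8 9"
ctx = [1, 5, 6, 7, 8, 9, 2, 3, 5, 6]
print(prompt_lookup_draft(ctx, K=3))        # → [7, 8, 9]
print(prompt_lookup_draft([1, 2, 3], K=3))  # → [] (no repeated n-gram)
```

Because drafting is just a list scan, the cost per round is negligible; the price is that acceptance depends entirely on how repetitive the text is (code edits, quoting, structured output).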
Medusa (Cai et al., 2024) eliminates the separate draft model entirely by adding lightweight prediction heads to the target model:
```
Medusa Architecture
═══════════════════════════════════════════════════════════════
Standard LLM:
┌──────────────────────────┐
│   Transformer layers     │
│   (frozen)               │──→ LM Head ──→ token t+1
└──────────────────────────┘

Medusa LLM:
┌──────────────────────────┐
│   Transformer layers     │──→ LM Head   ──→ token t+1 (original)
│   (frozen)               │──→ Medusa H₁ ──→ token t+2 (predicted)
└──────────────────────────┘──→ Medusa H₂ ──→ token t+3 (predicted)
                            ──→ Medusa H₃ ──→ token t+4 (predicted)

Each Medusa head:  one feed-forward layer with a residual connection,
                   followed by its own projection to the vocabulary
Parameters added:  per head, roughly a D×D layer plus a D×vocab
                   projection (D = hidden size)
Training:          fine-tune the heads on the target model's output
                   distribution

Tree-based verification:
────────────────────────
Each head proposes its top-k candidates; their combinations form a
tree (k=2 shown). A head's candidates do not depend on the parent:

                       token t+1 (LM head)
                    ╱                      ╲
t+2 (H₁):          C                        D
                 ╱   ╲                    ╱   ╲
t+3 (H₂):       E     F                  E     F
               ╱ ╲   ╱ ╲                ╱ ╲   ╱ ╲
t+4 (H₃):     G   H G   H              G   H G   H

Verify all paths in ONE forward pass using a tree attention mask!
Accept the longest valid path → 3-5 tokens per step.
```
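The tree attention mask is what lets one forward pass score every path: each candidate token may attend only to itself and its ancestors, so sibling branches never see each other. The sketch below assumes the tree is a flat list of nodes with parent indices in topological order (an illustrative encoding, not Medusa's actual data structure). It builds the tree-to-tree part of the mask and then picks the longest path the target agrees with under greedy verification; both function names are hypothetical.

```python
def tree_attention_mask(parents):
    """mask[i][j] is True iff node j is node i itself or one of its ancestors.

    parents[i] is node i's parent index, or -1 for nodes directly under the
    root. In the real forward pass every node also attends to the whole
    accepted prefix; only the tree-to-tree part of the mask is built here.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:  # walk up the ancestor chain
            mask[i][j] = True
            j = parents[j]
    return mask


def longest_accepted_path(tokens, parents, root_next, target_next):
    """Greedy acceptance: descend while some child matches the target.

    root_next:      target's argmax for the first speculated position
    target_next[i]: target's argmax after node i, from the same masked pass
    Returns the accepted tokens plus one final token chosen by the target.
    """
    accepted, node, expected = [], -1, root_next
    while True:
        child = next((i for i in range(len(tokens))
                      if parents[i] == node and tokens[i] == expected), None)
        if child is None:
            return accepted + [expected]  # correction or bonus token
        accepted.append(tokens[child])
        node, expected = child, target_next[child]


# Two levels, top-2 per head: C=10, D=11 under the root; E=20, F=21 under each
tokens = [10, 11, 20, 21, 20, 21]
parents = [-1, -1, 0, 0, 1, 1]
for row in tree_attention_mask(parents):
    print("".join("1" if m else "." for m in row))

# Target wants D, then F, then token 99 after F
target_next = [20, 21, 0, 0, 0, 99]
print(longest_accepted_path(tokens, parents, 11, target_next))  # → [11, 21, 99]
```

Each row of the printed mask has a 1 only on the node's own chain back to the root, which is exactly the causal mask each path would get if it were verified on its own.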
EAGLE (Li et al., 2024) uses the target model's own hidden states to draft:
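```python
# EAGLE drafting concept (simplified)
import torch
import torch.nn as nn


class EAGLEDraftHead(nn.Module):
    """Draft future tokens from the target model's hidden states.

    Simplified: EAGLE also runs one transformer decoder layer after the fc
    projection and reuses the target's frozen LM head. Here the decoder
    layer is omitted and a standalone lm_head stands in for the target's.
    """

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        # Fuse [hidden state ; token embedding] (2D) back down to D
        self.fc = nn.Linear(hidden_size * 2, hidden_size)
        # Stand-in for the target model's LM head
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_state, embedding):
        """
        hidden_state: [B, D] from the target model's last layer
                      (the features fed to its LM head)
        embedding:    [B, D] embedding of the last accepted token
        """
        # Concatenate hidden state + embedding for richer features
        combined = torch.cat([hidden_state, embedding], dim=-1)
        next_hidden = self.fc(combined)  # predicted next hidden state
        logits = self.lm_head(next_hidden)
        return logits, next_hidden  # next_hidden feeds the next draft step
```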
SpecInfer (Miao et al., 2024) uses multiple draft models and combines their predictions into a speculation tree:
```
SpecInfer: Multi-Draft Tree
═══════════════════════════════════════════════════════════════
Draft Model A (7B):     "is" → "Paris" → "," → "the"
Draft Model B (1.3B):   "is" → "Paris" → "." → "It"
Draft Model C (n-gram): "is" → "the"   → "capital"

Merged speculation tree (shared prefixes merged):

                  "is"
                ╱      ╲
          "Paris"        "the"
          ╱     ╲            ╲
        ","     "."        "capital"
         │       │
       "the"   "It"

Verify the entire tree in one forward pass!
→ Accept the longest branch that matches the target
→ More diverse candidates raise the chance that some branch is accepted
```
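Merging the draft sequences is a prefix-tree (trie) insertion: shared prefixes collapse into one node, so the tree has fewer nodes than the drafts have tokens in total. The sketch below is a hypothetical helper run on the string tokens from the example above. It returns a flat (token, parent index) encoding from which a tree attention mask could be built.

```python
def merge_drafts(drafts):
    """Merge draft token sequences into a token tree.

    Returns (tokens, parents): node i holds tokens[i], and parents[i] is
    its parent's index, or -1 for nodes directly under the root.
    """
    tokens, parents = [], []
    children = {}  # (parent index, token) -> node index
    for seq in drafts:
        parent = -1
        for tok in seq:
            key = (parent, tok)
            if key not in children:  # new branch: add a node
                children[key] = len(tokens)
                tokens.append(tok)
                parents.append(parent)
            parent = children[key]
    return tokens, parents


drafts = [
    ["is", "Paris", ",", "the"],  # draft model A
    ["is", "Paris", ".", "It"],   # draft model B
    ["is", "the", "capital"],     # n-gram drafter C
]
tokens, parents = merge_drafts(drafts)
print(len(tokens), "nodes for", sum(map(len, drafts)), "draft tokens")
for tok, par in zip(tokens, parents):
    print(f"{tok!r:12} parent={tokens[par] if par >= 0 else '<root>'}")
```

Here 11 drafted tokens collapse to 8 tree nodes, and the verification pass pays only for the nodes.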
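```python
# vLLM speculative decoding configuration.
# These argument names match older vLLM releases; newer releases move the
# speculative options into a `speculative_config` dict, so check the vLLM
# docs for the version you run.
from vllm import LLM, SamplingParams


def setup_speculative_vllm():
    """Configure vLLM with speculative decoding enabled."""
    llm = LLM(
        model="meta-llama/Llama-2-70b-hf",
        tensor_parallel_size=4,
        speculative_model="meta-llama/Llama-2-7b-hf",  # draft model
        num_speculative_tokens=5,  # K value
        # Alternative: n-gram (prompt lookup) speculation, no draft model needed
        # speculative_model="[ngram]",
        # ngram_prompt_lookup_max=4,
    )
    return llm


# Usage:
# llm = setup_speculative_vllm()
# outputs = llm.generate(["The capital of France"],
#                        SamplingParams(temperature=0.8, max_tokens=64))
# print(outputs[0].outputs[0].text)
```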
```python
# TensorRT-LLM speculative decoding setup
def trt_llm_speculative_config():
    """TensorRT-LLM config for speculative decoding (conceptual).

    The keys are illustrative and do not mirror TensorRT-LLM's actual
    configuration schema.
    """
    config = {
        "target_model": "llama-2-70b",
        "draft_model": "llama-2-7b",
        "max_draft_len": 5,
        "acceptance_threshold": 0.0,  # pure rejection sampling
        "use_medusa": False,
        # Medusa alternative:
        # "use_medusa": True,
        # "medusa_choices": [[0], [0, 0], [1], [0, 1], [2]],  # tree structure:
        #     each entry is a path of per-head top-k indices
        # "num_medusa_heads": 3,
    }
    return config
```
```
When to Use Speculative Decoding
═══════════════════════════════════════════════════════════════
✓ USE when:
──────────
• Decode latency matters (interactive chat, real-time)
• Target model is large (≥13B) and memory-bound in decode
• Draft model has high acceptance rate (>70%)
• Batch size is small (1-4), where decode is deeply memory-bound
• Content is somewhat predictable (code, structured text)

✗ AVOID when:
─────────────
• Batch size is large (>16): weight reads are already amortized
  across many sequences, leaving little idle compute for verification
• Draft model has low acceptance rate (<50%)
• Target model is small, so decode is already fast
• GPU memory is tight, since the draft model adds memory pressure
• Text is creative/diverse, so low acceptance hurts throughput

┌────────────────────────────────────────────────────────┐
│ Batch Size   Without Spec   With Spec   Improvement    │
│ 1            35 tok/s       105 tok/s   3.0×           │
│ 4            120 tok/s      280 tok/s   2.3×           │
│ 16           400 tok/s      520 tok/s   1.3×           │
│ 64           1200 tok/s     1100 tok/s  0.9× (worse!)  │
└────────────────────────────────────────────────────────┘
```
At large batches, the cost of drafting and verifying outweighs the benefit of fewer target steps.
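Why does the benefit fade with batch size? A back-of-the-envelope roofline model makes it concrete. This sketch assumes a 70B-parameter FP16 target and round hardware numbers (2 TB/s memory bandwidth, 300 TFLOP/s compute, treated as aggregate across GPUs), and it ignores KV-cache traffic and draft cost; none of these figures are measurements. A forward pass is taken to cost whichever is larger, reading the weights once or doing 2 FLOPs per parameter per token, and a verify pass over K+1 tokens for each of B sequences is priced against a normal decode step over B tokens.

```python
PARAMS = 70e9              # target parameters (assumed)
WEIGHT_BYTES = 2 * PARAMS  # FP16 weights
BANDWIDTH = 2e12           # bytes/s (assumed round number)
PEAK_FLOPS = 3e14          # FLOP/s (assumed round number)


def step_time(tokens):
    """Roofline time for one forward pass over `tokens` tokens."""
    memory_s = WEIGHT_BYTES / BANDWIDTH          # weights read once per pass
    compute_s = 2 * PARAMS * tokens / PEAK_FLOPS  # 2 FLOPs/param/token
    return max(memory_s, compute_s)


def verify_cost_ratio(batch, K=5):
    """Cost of a (K+1)-token verify pass relative to a 1-token decode step."""
    return step_time(batch * (K + 1)) / step_time(batch)


for batch in [1, 4, 16, 32, 64]:
    print(f"batch={batch:<3} verify pass costs "
          f"{verify_cost_ratio(batch):.2f}x a decode step")
```

Under these assumptions verification rides along for free (ratio 1.00) up to batch 16, because the pass is dominated by weight reads that decode pays for anyway. Beyond that, B(K+1) tokens push the pass into the compute-bound regime, and every rejected draft token is paid for in real compute.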
```python
import torch
import torch.nn.functional as F


def measure_acceptance_rate(target_logits, draft_logits, temperature=1.0):
    """Expected acceptance rate α = Σ_x min(p(x), q(x)), averaged over positions.

    target_logits, draft_logits: [..., vocab] logits for the same positions.
    """
    target_probs = F.softmax(target_logits / temperature, dim=-1)
    draft_probs = F.softmax(draft_logits / temperature, dim=-1)
    # Per-position acceptance rate: sum over the vocabulary of min(p, q)
    acceptance = torch.min(target_probs, draft_probs).sum(dim=-1)
    return acceptance.mean().item()


def simulate_speculative_speedup(acceptance_rate, K,
                                 draft_time_ms, verify_time_ms):
    """Estimate the expected speedup from speculative decoding.

    Assumes verifying K+1 tokens costs about as much as one normal decode
    step, which holds when decode is memory-bound (small batches).
    """
    alpha = acceptance_rate
    # Expected tokens per round (requires alpha < 1)
    expected_tokens = (1 - alpha**(K + 1)) / (1 - alpha)
    # Time per round = K draft steps + 1 target verification pass
    time_per_round = K * draft_time_ms + verify_time_ms
    # Baseline: 1 token per target forward pass
    baseline_ms_per_token = verify_time_ms
    speculative_ms_per_token = time_per_round / expected_tokens
    speedup = baseline_ms_per_token / speculative_ms_per_token
    print(f"Acceptance rate:     {alpha:.0%}")
    print(f"Draft length K:      {K}")
    print(f"Expected tokens/rnd: {expected_tokens:.1f}")
    print(f"Time per round:      {time_per_round:.1f} ms")
    print(f"Baseline tok/ms:     {1/baseline_ms_per_token:.3f}")
    print(f"Spec tok/ms:         {expected_tokens/time_per_round:.3f}")
    print(f"Speedup:             {speedup:.2f}×")
    return speedup


# Sweep acceptance rate and K
print("=" * 50)
for alpha in [0.6, 0.7, 0.8, 0.9]:
    for K in [3, 5, 7]:
        print(f"\n--- α={alpha}, K={K} ---")
        simulate_speculative_speedup(
            acceptance_rate=alpha, K=K,
            draft_time_ms=1.5,    # 7B draft: ~1.5 ms/token
            verify_time_ms=12.0,  # 70B target: ~12 ms/token
        )
```
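The sweep above fixes K by hand. Writing the draft-to-target cost ratio as $c$ = `draft_time_ms / verify_time_ms`, the same model simplifies to speedup $= \frac{1-\alpha^{K+1}}{(1-\alpha)(Kc+1)}$, which a short search can maximize over K. A sketch reusing the illustrative timings above ($c = 1.5/12 = 0.125$); `best_draft_length` is a hypothetical helper:

```python
def speedup(alpha, K, c):
    """Expected speedup for draft/target cost ratio c (verify ≈ 1 decode step)."""
    expected_tokens = (1 - alpha ** (K + 1)) / (1 - alpha)
    return expected_tokens / (K * c + 1)


def best_draft_length(alpha, c, max_K=16):
    """Return (K, speedup) maximizing the expected speedup over K = 1..max_K."""
    return max(((K, speedup(alpha, K, c)) for K in range(1, max_K + 1)),
               key=lambda pair: pair[1])


c = 1.5 / 12.0  # 7B draft vs 70B target, timings from the sweep above
for alpha in [0.6, 0.7, 0.8, 0.9]:
    K, s = best_draft_length(alpha, c)
    print(f"α={alpha}: best K={K}, speedup≈{s:.2f}×")
```

The optimum grows with α: at α = 0.6 drafting past K = 3 wastes draft time on tokens that are usually rejected, while at α = 0.9 long drafts (K around 9) still pay off. The curve is flat near the optimum, so a fixed K of 4-6 is a reasonable default for α in the 0.7-0.8 range.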
We've covered how to serve LLMs faster and cheaper through PagedAttention, continuous batching, and speculative decoding. But there's one more critical lever: making the model itself smaller and faster through quantization. Day 61: LLM Quantization dives into GPTQ, AWQ, SqueezeLLM, and the precision ladder from FP16 down to INT4 and beyond, where giving up around 0.5% accuracy buys 4× memory savings (FP16 → INT4).