Phase III — LLMs: Training & Alignment | Week 6 | 2.5 hours "Context is all you need — but how much context can you actually use?" — An LLM scaling researcher
Transformers have a fundamental limitation: self-attention costs $O(n^2)$ time, and $O(n^2)$ memory if the score matrix is materialized, in the sequence length $n$.
Context length evolution:
GPT-2 (2019): 1,024 tokens
GPT-3 (2020): 2,048 tokens
GPT-3.5 (2023): 4,096 → 16,384 tokens
GPT-4 (2023): 8,192 → 128,000 tokens
Claude 3 (2024): 200,000 tokens
Gemini 1.5 (2024): 1,000,000+ tokens
Memory for KV cache at 128K context (7B model, FP16, standard multi-head attention):
Layers: 32, Heads: 32, Head dim: 128
KV per layer: 2 (K and V) × 128K tokens × 32 heads × 128 dims × 2 bytes = 2 GB
Total: 32 layers × 2 GB = 64 GB (just for one request's KV cache!)
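The KV-cache arithmetic can be checked in a few lines. This is a minimal sketch that assumes standard multi-head attention (every head stores its own keys and values) and FP16 storage; the function name `kv_cache_bytes` is illustrative, not from any library.

```python
def kv_cache_bytes(num_layers: int, num_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache keys and values for one sequence."""
    # Factor 2: one tensor for K and one for V in every layer
    per_layer = 2 * seq_len * num_heads * head_dim * bytes_per_elem
    return per_layer * num_layers


GIB = 1024 ** 3
seq_len = 128 * 1024  # "128K" tokens

per_layer_gib = kv_cache_bytes(1, 32, 128, seq_len) / GIB
total_gib = kv_cache_bytes(32, 32, 128, seq_len) / GIB
print(f"per layer: {per_layer_gib:.0f} GiB, total: {total_gib:.0f} GiB")
```

The cache grows linearly with sequence length, so halving the context halves the figure.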
Most modern LLMs use RoPE (Rotary Position Embedding):
$$ \text{RoPE}(x_m, m) = \begin{pmatrix} x_m^{(1)} \cos(m\theta_1) - x_m^{(2)} \sin(m\theta_1) \\ x_m^{(1)} \sin(m\theta_1) + x_m^{(2)} \cos(m\theta_1) \\ \vdots \end{pmatrix} $$
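where $\theta_i = 10000^{-2(i-1)/d}$ for $i = 1, \dots, d/2$, $m$ is the position index, and each consecutive pair of dimensions $(x_m^{(2i-1)}, x_m^{(2i)})$ is rotated by the angle $m\theta_i$.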
Key property: The attention score between positions $m$ and $n$ depends only on $m - n$ (relative position), but is computed by rotating the query and key vectors.
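The relative-position property can be verified numerically. The sketch below treats each pair of dimensions as one complex number, so the rotation becomes a multiplication by $e^{i m \theta}$; it uses only the standard library, and the helper names `rope_rotate` and `score` are illustrative.

```python
import cmath
import random


def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive pairs of `vec` by pos * theta_i (pairs as complex numbers)."""
    d = len(vec)
    rotated = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)  # 0-indexed pair i
        z = complex(vec[2 * i], vec[2 * i + 1])
        rotated.append(z * cmath.exp(1j * pos * theta))
    return rotated


def score(q, k, m, n):
    """Real dot product of q rotated to position m and k rotated to position n."""
    q_rot, k_rot = rope_rotate(q, m), rope_rotate(k, n)
    return sum((a * b.conjugate()).real for a, b in zip(q_rot, k_rot))


random.seed(0)
q = [random.gauss(0, 1) for _ in range(8)]
k = [random.gauss(0, 1) for _ in range(8)]

s_near = score(q, k, m=5, n=2)        # offset 3 near the start of the sequence
s_far = score(q, k, m=1005, n=1002)   # same offset 3, shifted by 1000 positions
print(abs(s_near - s_far) < 1e-9)
```

Both scores agree up to floating-point error because only the difference $m - n$ survives in the product of the two rotations.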
To extend a model trained on 4K context to 128K+ without retraining from scratch (a short fine-tune on long sequences is typically still used):
Position interpolation (simple):
$$ \theta'_i = \theta_i / s, \quad s = \frac{L_{\text{target}}}{L_{\text{trained}}} $$
Problem: compresses all frequencies equally, losing fine-grained position information.
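A small numeric sketch of position interpolation, assuming the same 4K to 32K extension used later in this lesson. The helper `rope_angle` is illustrative; it shows that interpolation keeps every angle inside the trained range but shrinks the angular gap between neighbouring tokens by the factor $s$.

```python
def rope_angle(pos, i, dim, base=10000.0, scale=1.0):
    """Rotation angle of pair i at position pos, with theta_i divided by scale."""
    theta = base ** (-2 * i / dim)
    return pos * theta / scale


dim, trained_len, target_len = 64, 4096, 32768
s = target_len / trained_len  # interpolation factor

# Largest angle of the highest-frequency pair (i = 0, theta = 1)
max_extrapolated = rope_angle(target_len - 1, 0, dim)           # never seen in training
max_interpolated = rope_angle(target_len - 1, 0, dim, scale=s)  # squeezed back into range

# Angular gap between two adjacent tokens for the same pair
gap_original = rope_angle(1, 0, dim) - rope_angle(0, 0, dim)
gap_interpolated = rope_angle(1, 0, dim, scale=s) - rope_angle(0, 0, dim, scale=s)

print(s, max_extrapolated, max_interpolated, gap_original, gap_interpolated)
```

The smaller gap is the loss of fine-grained position information described above: adjacent tokens become harder to tell apart in the high-frequency dimensions.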
YaRN (Yet another RoPE extensioN):
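$$ \theta'_i = \begin{cases} \theta_i & \text{if } r_i > \beta \quad \text{(high frequency: keep)} \\ \theta_i / s & \text{if } r_i < \alpha \quad \text{(low frequency: interpolate)} \\ (1-\gamma_i)\frac{\theta_i}{s} + \gamma_i\theta_i & \text{otherwise (smooth blend)} \end{cases} $$
where $\lambda_i = 2\pi/\theta_i$ is the wavelength of dimension $i$, $r_i = L_{\text{trained}}/\lambda_i$ is the number of full rotations it completes over the trained context, $\gamma_i = (r_i - \alpha)/(\beta - \alpha)$, and $\alpha = 1$, $\beta = 32$ are the defaults used in the code below (`beta_slow`, `beta_fast`).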
Intuition:
High-frequency RoPE dimensions → encode nearby positions (keep intact)
Low-frequency RoPE dimensions → encode distant positions (interpolate)
Like a rubber band: stretch only the low-frequency part
→ preserves local attention while extending global reach
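The rubber-band intuition can be made concrete by printing the per-dimension factor $\theta'_i / \theta_i$. This is a sketch under stated assumptions: the ramp is linear in the number of rotations over the trained context, the thresholds 1 and 32 are assumed defaults, and YaRN's attention temperature adjustment is left out. The function name `yarn_scale_factors` is illustrative.

```python
import math


def yarn_scale_factors(dim, trained_len, target_len,
                       base=10000.0, alpha=1.0, beta=32.0):
    """Return theta'_i / theta_i for every RoPE pair i (1.0 = keep, 1/s = interpolate)."""
    s = target_len / trained_len
    factors = []
    for i in range(dim // 2):
        theta = base ** (-2 * i / dim)
        # Full rotations this pair completes across the trained context
        rotations = trained_len * theta / (2 * math.pi)
        # gamma = 1 keeps the frequency, gamma = 0 interpolates fully
        gamma = min(max((rotations - alpha) / (beta - alpha), 0.0), 1.0)
        factors.append((1 - gamma) / s + gamma)
    return factors


factors = yarn_scale_factors(dim=64, trained_len=4096, target_len=32768)
for i in (0, 8, 16, 24, 31):
    print(f"pair {i:2d}: factor {factors[i]:.3f}")
```

Position interpolation would print the same factor $1/s$ for every pair; here only the slow pairs are stretched.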
For very long contexts (1M+ tokens), Ring Attention distributes the sequence across multiple devices:
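Standard attention:
┌──────────────────┐
│ Full Q·K^T·V     │
│ (all on one GPU) │
│ Memory: O(n²)    │
└──────────────────┘
Ring attention (K,V blocks shown at step 0):
  Device 1        Device 2        Device 3
┌──────────┐    ┌──────────┐    ┌──────────┐
│ Q₁       │    │ Q₂       │    │ Q₃       │
│ K₁, V₁   │→→→ │ K₂, V₂   │→→→ │ K₃, V₃   │
└────┬─────┘    └──────────┘    └─────┬────┘
     └───←←← Ring communication ←←←───┘
Each device holds 1/N of the sequence (its own Q, K, V blocks)
K,V blocks rotate around the ring while each Q block stays on its device
Each device computes partial attention against the K,V block it currently holds
→ Memory per device: O(n/N) with blockwise attention (no n × n score matrix is stored)
→ Compute per device: O(n²/N)
→ Enables 1M+ context on multi-GPU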
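The reason partial attention can be combined across ring steps is the running-softmax trick: each device keeps a running maximum, a running denominator, and a running weighted sum, and rescales them when a new K,V block arrives. The sketch below simulates this in a single process for one query vector; it is not a distributed implementation, and the function names are illustrative.

```python
import math


def full_attention(q, keys, values):
    """Reference softmax(q·K^T)·V for one query against all keys at once."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[c] for w, v in zip(weights, values)) / total
            for c in range(len(values[0]))]


def ring_attention(q, kv_blocks):
    """Fold in one K,V block per ring step, keeping only running statistics."""
    running_max, denom = float("-inf"), 0.0
    acc = [0.0] * len(kv_blocks[0][1][0])
    for k_blk, v_blk in kv_blocks:
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        rescale = math.exp(running_max - new_max)  # 0.0 on the first step
        denom *= rescale
        acc = [a * rescale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3], [2.0, -1.0]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 1.0], [2.0, 3.0]]
blocks = [(keys[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]

reference = full_attention(q, keys, values)
# Each device sees the blocks in a different rotation of the ring
results = [ring_attention(q, blocks[s:] + blocks[:s]) for s in range(len(blocks))]
print(reference, results[0])
```

Every rotation order gives the same output as full attention, which is why each device can start from whichever block it holds locally.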
Chain-of-Thought (CoT) prompting unlocks multi-step reasoning:
Standard prompting:
Q: "Roger has 5 balls. He buys 2 cans of 3 balls each. How many balls?"
A: "11" ← Direct answer with no reasoning; on multi-step problems this style is often wrong
Chain-of-thought:
Q: "Roger has 5 balls. He buys 2 cans of 3 balls each. How many balls?"
A: "Roger starts with 5 balls. He buys 2 cans × 3 balls = 6 balls.
Total: 5 + 6 = 11 balls." ← Correct with reasoning trace
OpenAI's o1 popularized test-time compute scaling: instead of making the model bigger, give it more time to think.
Traditional scaling | Test-time compute scaling
More parameters → better | More thinking tokens → better
Fixed compute per token | Variable compute per problem
Cost: training time | Cost: inference time
Limit: GPU memory | Limit: context length + latency
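Key relationship: on reasoning tasks, accuracy grows roughly linearly in the logarithm of test-time compute, so each doubling of thinking tokens buys a similar gain:
$$ \text{Performance} \propto \log(\text{test-time compute}) \quad \text{(on reasoning tasks)} $$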
Standard model: Prompt → Answer (fast, sometimes wrong)
o1-style model: Prompt → [Think][Think][Think]... → Answer (slow, more accurate)
Trade-off:
Easy questions: standard model is efficient (thinking wastes tokens)
Hard questions: o1-style is better (thinking prevents errors)
→ Optimal strategy: route easy questions to fast model,
hard questions to reasoning model
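The routing strategy above can be sketched as a small dispatcher. Everything here is hypothetical: `estimate_difficulty` is a toy heuristic (longer questions with more numbers count as harder) standing in for whatever classifier a real system would use, and the labels `"fast"` and `"reasoning"` are placeholders for actual model endpoints.

```python
def estimate_difficulty(question: str) -> float:
    """Hypothetical heuristic in [0, 1]: more words and more numbers mean harder."""
    tokens = question.split()
    num_numbers = sum(tok.strip(".,?").isdigit() for tok in tokens)
    length_score = min(len(tokens) / 50, 1.0)
    number_score = min(num_numbers / 4, 1.0)
    return 0.5 * length_score + 0.5 * number_score


def route(question: str, threshold: float = 0.5) -> str:
    """Send easy questions to the fast model and hard ones to the reasoning model."""
    return "reasoning" if estimate_difficulty(question) >= threshold else "fast"


easy = "What is 2 + 2?"
hard = ("A robot fleet has 8 robots. Each charges for 20 minutes "
        "every 3 hours. In a 12-hour shift, how much total "
        "charging time across all robots?")

print(route(easy), route(hard))
```

The threshold trades cost against accuracy: lowering it sends more traffic to the slower reasoning model.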
"""
Day 39 Implementation: RoPE scaling and chain-of-thought reasoning.
"""
import math
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# ============================================================
# Part 1: RoPE Mathematics
# ============================================================
def compute_rope_frequencies(
    dim: int, base: float = 10000.0, seq_len: int = 4096,
) -> torch.Tensor:
    """Compute RoPE rotation angles of shape [seq_len, dim/2]."""
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float()
    # Outer product: [seq_len, dim/2]
    angles = torch.outer(positions, freqs)
    return angles
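def apply_rope(x: torch.Tensor, angles: torch.Tensor) -> torch.Tensor:
    """Apply RoPE to input tensor x of shape [seq_len, dim]."""
    d = x.shape[-1]
    # "Rotate-half" layout: dimension i is paired with dimension i + d/2.
    # This matches the adjacent-pair formula up to a fixed permutation of dimensions.
    x1, x2 = x[..., :d//2], x[..., d//2:]
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)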
def position_interpolation(
    dim: int, base: float, trained_len: int, target_len: int,
) -> torch.Tensor:
    """Simple position interpolation for RoPE extension."""
    scale = target_len / trained_len
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs_scaled = freqs / scale  # interpolate all frequencies
    positions = torch.arange(target_len).float()
    return torch.outer(positions, freqs_scaled)
def yarn_frequencies(
    dim: int,
    base: float,
    trained_len: int,
    target_len: int,
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
) -> torch.Tensor:
    """YaRN-style frequency scaling for RoPE extension.

    Only the per-dimension frequency scaling is implemented; YaRN's
    attention temperature adjustment is omitted.
    """
    scale = target_len / trained_len
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Wavelength of each dimension, measured in token positions
    wavelengths = 2 * math.pi / freqs
    # Number of full rotations each dimension completes over the trained context
    rotations = trained_len / wavelengths
    # Ramp: gamma = 1 keeps the original frequency (rotations >= beta_fast),
    # gamma = 0 interpolates fully (rotations <= beta_slow), linear in between.
    # Clamping also avoids the division by zero of an index-based ramp.
    gamma = ((rotations - beta_slow) / (beta_fast - beta_slow)).clamp(0.0, 1.0)
    yarn_freqs = (1 - gamma) * (freqs / scale) + gamma * freqs
    positions = torch.arange(target_len).float()
    return torch.outer(positions, yarn_freqs)
# ============================================================
# Part 2: Attention Distance Analysis
# ============================================================
def analyze_attention_decay(dim: int = 64, max_dist: int = 1000):
    """Show how the RoPE attention score falls off with relative distance.

    Uses identical all-ones query and key vectors so that only the rotation
    changes the score; independent random vectors show no systematic trend.
    """
    freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    # Query stays at position 0; the key is moved to position `dist`
    q = torch.ones(dim)
    k = torch.ones(dim)
    distances = list(range(0, max_dist, 10))
    dot_products = []
    angles_q = torch.zeros(1, dim // 2)
    q_rot = apply_rope(q.unsqueeze(0), angles_q).squeeze(0)
    for dist in distances:
        angles_k = (float(dist) * freqs).unsqueeze(0)
        k_rot = apply_rope(k.unsqueeze(0), angles_k).squeeze(0)
        dot_products.append((q_rot * k_rot).sum().item())
    return distances, dot_products
# ============================================================
# Part 3: Chain-of-Thought Prompting
# ============================================================
def create_cot_prompt(question: str, examples: list[dict]) -> str:
    """Create a chain-of-thought prompt."""
    prompt = "Solve each problem step by step.\n\n"
    for ex in examples:
        prompt += f"Q: {ex['question']}\n"
        prompt += f"A: Let's think step by step.\n{ex['reasoning']}\n"
        prompt += f"The answer is {ex['answer']}.\n\n"
    prompt += f"Q: {question}\nA: Let's think step by step.\n"
    return prompt
COT_EXAMPLES = [
    {
        "question": "A robot picks 3 items per minute. It works for 5 minutes, "
                    "rests for 2, then works for 3 more. How many items?",
        "reasoning": "First work period: 3 items/min × 5 min = 15 items.\n"
                     "Rest period: 0 items.\n"
                     "Second work period: 3 items/min × 3 min = 9 items.\n"
                     "Total: 15 + 9 = 24 items.",
        "answer": "24 items",
    },
    {
        "question": "A warehouse has 4 aisles. Each aisle has 10 shelves. "
                    "Each shelf holds 6 bins. How many bins total?",
        "reasoning": "Bins per aisle: 10 shelves × 6 bins = 60 bins.\n"
                     "Total bins: 4 aisles × 60 bins = 240 bins.",
        "answer": "240 bins",
    },
]
if __name__ == "__main__":
    # RoPE visualization
    print("=" * 60)
    print("RoPE Frequency Analysis")
    print("=" * 60)
    dim = 64
    angles_original = compute_rope_frequencies(dim, seq_len=4096)
    angles_pi = position_interpolation(dim, 10000.0, 4096, 32768)
    angles_yarn = yarn_frequencies(dim, 10000.0, 4096, 32768)
    print(f"Original RoPE: {angles_original.shape} (4K context)")
    print(f"PI-extended: {angles_pi.shape} (32K context)")
    print(f"YaRN-extended: {angles_yarn.shape} (32K context)")
    # Frequency comparison
    freqs_orig = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
    print(f"\nHighest freq: {freqs_orig[0]:.6f} (wavelength: {2*math.pi/freqs_orig[0]:.0f} positions)")
    print(f"Lowest freq: {freqs_orig[-1]:.6f} (wavelength: {2*math.pi/freqs_orig[-1]:.0f} positions)")
    # Chain-of-thought demo
    print("\n" + "=" * 60)
    print("Chain-of-Thought Prompt")
    print("=" * 60)
    test_question = ("A robot fleet has 8 robots. Each charges for 20 minutes "
                     "every 3 hours. In a 12-hour shift, how much total "
                     "charging time across all robots?")
    cot = create_cot_prompt(test_question, COT_EXAMPLES)
    print(cot[:500])
    print(f"...\n[Prompt length: {len(cot)} chars]")
Simulate o1-style reasoning:
1. For a set of math problems, generate answers with different "thinking budgets" (50, 100, 200, 500 tokens of reasoning)
2. Plot accuracy vs. thinking tokens
3. Identify the point of diminishing returns: when does more thinking stop helping?
Long context is critical for robotics: a robot needs to remember its entire mission (thousands of time steps) to make coherent decisions. Ring attention techniques map to distributed robot compute, and chain-of-thought reasoning is how robots will plan complex multi-step warehouse operations — decomposing "sort all packages by destination" into individual pick-place steps.