Phase II: Attention, Transformers & Scaling | Week 4 | 2.5 hours
"The model learns a distribution. Sampling is the art of drawing from it well."
An autoregressive language model gives you $p(x_t | x_{<t})$: a probability distribution over the next token given all the tokens before it.
The question is how to sample. Different strategies produce dramatically different outputs from the same model.
Model output at each step:
  Token   : "the"  "a"    "an"   "one"  "this" ...
  p(token): 0.35   0.20   0.15   0.10   0.08   ...
      │
      ├── Greedy:        always pick "the" (highest prob)
      ├── Top-k (k=3):   sample from {"the", "a", "an"}
      ├── Top-p (p=0.7): sample from {"the", "a", "an"} (cumsum ≥ 0.7)
      └── Temperature:   reshape the distribution, then sample
The simplest strategy: always pick the highest-probability token.
$$x_t = \arg\max_{v} p(v | x_{<t})$$
Pros: Deterministic, fast, never picks a "bad" token.
Cons: Produces repetitive, degenerate text. The most probable sequence overall is rarely the one built from the most probable token at each step. This is the classic greedy algorithm failure.

Before sampling, divide logits by temperature $T$:
$$p_T(v | x_{<t}) = \frac{\exp(z_v / T)}{\sum_{v'} \exp(z_{v'} / T)}$$
where $z_v$ are the raw logits from the model.

Proof sketch that $T \to 0$ gives greedy: Let $v^* = \arg\max_v z_v$. For any other token $v$, the ratio $p_T(v) / p_T(v^*) = \exp((z_v - z_{v^*}) / T)$ goes to 0 as $T \to 0$, because $z_v - z_{v^*} < 0$. The softmax therefore concentrates all mass on $v^*$ (assuming the maximum is unique).

Proof sketch that $T \to \infty$ gives uniform: As $T \to \infty$, $z_v / T \to 0$ for all $v$, so $\exp(z_v / T) \to 1$ for all $v$. The softmax becomes $1/|V|$, the uniform distribution.

Restrict sampling to the $k$ most probable tokens and renormalize the probability mass among them:
$$p_{\text{top-k}}(v) = \begin{cases} \frac{p(v)}{\sum_{v' \in \text{top-k}} p(v')} & \text{if } v \in \text{top-k} \\ 0 & \text{otherwise} \end{cases}$$
Problem: A fixed $k$ is suboptimal. Sometimes the distribution is peaked (only 1-2 good tokens) and sometimes it is flat (many reasonable tokens). A fixed $k$ either cuts good options or keeps bad ones.

Dynamically select the smallest set of tokens whose cumulative probability reaches $p$:
$$\text{top-p}(p) = \min \left\{ k : \sum_{i=1}^{k} p(v_{(i)} | x_{<t}) \geq p \right\}$$
where $v_{(i)}$ is the $i$-th most probable token.
Why top-p often beats top-k: It adapts to the shape of the distribution. For confident predictions, the nucleus is small. For uncertain predictions, the nucleus is large.

A common pathology: models repeat themselves. Repetition penalty modifies logits for tokens that already appeared:
$$z'_v = \begin{cases} z_v / \alpha & \text{if } v \in \text{generated tokens and } z_v > 0 \\ z_v \cdot \alpha & \text{if } v \in \text{generated tokens and } z_v \leq 0 \\ z_v & \text{otherwise} \end{cases}$$
where $\alpha > 1$ is the penalty factor (typical: 1.1–1.3). Dividing a positive logit and multiplying a negative one both lower it, so the penalty always makes a repeated token less likely.

Sometimes you need the output to conform to a specific format, such as JSON, SQL, or code.
Constrained decoding masks out tokens that would violate the grammar at each step.
For E27.1 (Temperature Extremes), demonstrate empirically:
1. $T \to 0$: Generate 5 completions with $T = 0.01$. Show they are all identical (greedy).
2. $T \to \infty$: Generate with $T = 100$. Show it's nearly random gibberish.
3. Find the Goldilocks zone: Try $T = 0.3, 0.5, 0.7, 0.9, 1.0, 1.2$. Which produces the most coherent yet creative text?
For E27.2 (Top-p vs Top-k Analysis), use the same prompt and generate 10 samples each with:
- Top-k = 5, 10, 20, 50
- Top-p = 0.5, 0.7, 0.9, 0.95
Measure:
1. Diversity: How different are the 10 samples from each other?
2. Quality: Subjective ranking of coherence
3. Nucleus size: For top-p, how many tokens are in the nucleus on average?
For E27.3 (Constrained Decoding Challenge), make your model output valid JSON for these tasks:
1. Generate a JSON object with keys "name", "age", "city"
Connection to the thread: Sampling strategies connect directly to robot action generation. In diffusion-based VLAs (like π₀), the denoising process is analogous to temperature-controlled sampling: more denoising steps = more "greedy," fewer steps = more exploration. Understanding these tradeoffs in the language domain prepares you for understanding them in action space.
Greedy output: "The cat sat on the mat. The cat sat on the mat. The cat..."
    ← degenerates into repetition loops
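The repetition loop is easy to reproduce on a toy model. The sketch below uses a hypothetical hand-written bigram table (`BIGRAM` and `greedy_decode` are illustrative names, not part of the lesson's code) and plain Python instead of a trained network. Because argmax is deterministic, the sequence eventually returns to a token it has already visited and then cycles.

```python
# Hypothetical bigram "model": next-token probabilities given the previous token.
BIGRAM = {
    "the": {"cat": 0.5, "mat": 0.3, "dog": 0.2},
    "cat": {"sat": 0.6, "ran": 0.4},
    "sat": {"on": 0.9, "down": 0.1},
    "on": {"the": 0.8, "a": 0.2},
}


def greedy_decode(start, steps):
    """Always follow the most probable next token for `steps` steps."""
    tokens = [start]
    for _ in range(steps):
        dist = BIGRAM[tokens[-1]]
        tokens.append(max(dist, key=dist.get))  # argmax over next tokens
    return tokens


sequence = greedy_decode("the", steps=8)
print(" ".join(sequence))  # the cat sat on the cat sat on the
```

Tokens such as "mat" or "ran" have non-trivial probability but are never produced, which is the diversity that sampling recovers.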
27.3 Temperature Scaling
| Temperature | Effect | Distribution |
|---|---|---|
| $T \to 0$ | Approaches greedy (argmax) | Delta function on top token |
| $T = 1$ | Model's learned distribution | As trained |
| $T > 1$ | Flatter distribution | More random, more creative |
| $T \to \infty$ | Uniform random | Every token equally likely |
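The two limits can be checked numerically. This is a minimal sketch assuming made-up logits for five tokens; `softmax_with_temperature` is an illustrative helper written with the standard library only, not the lesson's PyTorch implementation.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Return softmax(logits / temperature) as a list of floats."""
    if temperature <= 0:
        raise ValueError("temperature must be positive; use argmax for greedy")
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for five tokens (assumption, not model output).
logits = [2.0, 1.0, 0.5, 0.0, -1.0]

for t in (0.01, 1.0, 100.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>6}: " + " ".join(f"{p:.3f}" for p in probs))
```

At T = 0.01 nearly all the mass sits on the first token, and at T = 100 every probability is close to 1/5.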
27.4 Top-k Sampling (Fan et al., 2018)
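As a small illustration of top-k filtering, the sketch below keeps the k most probable tokens of the toy distribution from the start of the lesson and renormalizes them. `top_k_filter` is an illustrative helper over a plain dict, assumed here for clarity; a real implementation works on logit tensors.

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize their probabilities."""
    if k <= 0:
        raise ValueError("k must be positive")
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    total = sum(p for _, p in ranked)
    return {token: p / total for token, p in ranked}


# Toy distribution from the lesson (the remaining mass is on unlisted tokens).
toy = {"the": 0.35, "a": 0.20, "an": 0.15, "one": 0.10, "this": 0.08}

filtered = top_k_filter(toy, k=3)
for token, p in filtered.items():
    print(f"{token!r}: {p:.3f}")
```

With k = 3 the surviving mass is 0.70, so "the" rises from 0.35 to about 0.50 after renormalization.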
27.5 Top-p (Nucleus) Sampling (Holtzman et al., 2020)
Example with p = 0.9:
Token  : "the"  "a"    "an"   "one"  "this" "some" ...
Prob   : 0.35   0.20   0.15   0.10   0.08   0.05   ...
Cumsum : 0.35   0.55   0.70   0.80   0.88   0.93   ← stop here!
                                             ↑
top-p set = first 6 tokens
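The worked example can be reproduced in a few lines. This sketch assumes the six listed probabilities and an illustrative `nucleus` helper that stops at the first token whose cumulative probability reaches p; the second distribution (`peaked`) is made up to show how the nucleus shrinks when the model is confident.

```python
def nucleus(probs, p):
    """Return the smallest prefix of tokens (by descending probability)
    whose cumulative probability is at least p."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    kept, cumulative = [], 0.0
    for token, prob in ranked:
        kept.append(token)
        cumulative += prob
        if cumulative >= p:
            break
    return kept


# Probabilities from the example; the remaining mass is on unlisted tokens.
flat = {"the": 0.35, "a": 0.20, "an": 0.15, "one": 0.10, "this": 0.08, "some": 0.05}
# Hypothetical confident prediction (assumption for illustration).
peaked = {"yes": 0.97, "no": 0.02, "maybe": 0.01}

print(len(nucleus(flat, 0.9)))    # 6
print(len(nucleus(peaked, 0.9)))  # 1
```

The same threshold yields six candidates in one context and one in the other, which is the adaptivity a fixed k cannot provide.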
27.6 Repetition Penalty
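The sign check in the repetition penalty matters: dividing a negative logit by a factor above 1 would raise it. The sketch below applies the rule to a list of made-up logits; `apply_repetition_penalty` is an illustrative stdlib-only helper, not the lesson's tensor implementation.

```python
def apply_repetition_penalty(logits, generated_ids, alpha=1.2):
    """Lower the logits of tokens that were already generated.

    Positive logits are divided by alpha, non-positive ones are multiplied,
    so the penalized logit is never higher than the original.
    """
    if alpha <= 0:
        raise ValueError("alpha must be positive")
    penalized = list(logits)  # copy so the input is left untouched
    for token_id in set(generated_ids):
        if penalized[token_id] > 0:
            penalized[token_id] /= alpha
        else:
            penalized[token_id] *= alpha
    return penalized


# Hypothetical logits: token 0 and token 1 were already generated.
logits = [2.4, -1.2, 0.6]
print(apply_repetition_penalty(logits, generated_ids=[0, 1, 1], alpha=1.2))
```

Token 0 drops from 2.4 to about 2.0, token 1 drops from -1.2 to about -1.44, and token 2 is unchanged because it has not been generated yet.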
27.7 Structured Generation (Constrained Decoding)
Grammar-guided decoding for JSON:
  State: expecting key after "{"
    Valid next tokens: '"' (quote to start key), '}' (close an empty object)
    Invalid tokens: everything else → mask to -∞
  State: expecting value after ":"
    Valid next tokens: '"', '-', digit, '[', '{', 'true', 'false', 'null'
    Invalid tokens: everything else → mask to -∞
Libraries: outlines, guidance, lm-format-enforcer.
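The masking step itself is small. The sketch below assumes a made-up four-token vocabulary and a hand-written set of allowed tokens for the "expecting value" state; `mask_and_normalize` is an illustrative helper and does not reflect the API of any of the libraries listed.

```python
import math


def mask_and_normalize(logits, allowed):
    """Set disallowed logits to -inf, then apply a softmax over what remains."""
    masked = {tok: (z if tok in allowed else -math.inf) for tok, z in logits.items()}
    finite = [z for z in masked.values() if z != -math.inf]
    if not finite:
        raise ValueError("no allowed token in the vocabulary")
    peak = max(finite)
    # math.exp(-inf) is 0.0, so masked tokens get exactly zero probability.
    exps = {tok: math.exp(z - peak) for tok, z in masked.items()}
    total = sum(exps.values())
    return {tok: e / total for tok, e in exps.items()}


# Hypothetical logits; the model prefers "}" even though a value is required.
logits = {'"': 0.5, "7": 1.0, "}": 3.0, ",": 2.0}
allowed_after_colon = {'"', "7"}

probs = mask_and_normalize(logits, allowed_after_colon)
print(probs)
```

Even though "}" had the highest logit, it can no longer be sampled; the probability mass is shared between the two grammatical options.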
Implementation (60 min)
Complete Sampling Implementation
import torch
import torch.nn.functional as F
from typing import Optional


def sample_next_token(
    logits: torch.Tensor,  # (vocab_size,)
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    generated_ids: Optional[torch.Tensor] = None,
) -> int:
    """Sample the next token using various strategies.

    Args:
        logits: Raw logits from the model for the next position
        temperature: Scaling factor (0 = greedy, 1 = normal, >1 = creative)
        top_k: If set, only sample from top-k tokens
        top_p: If set, only sample from the nucleus (smallest set with
            cumulative prob >= p)
        repetition_penalty: Penalize already-generated tokens (>1.0)
        generated_ids: Previously generated token IDs for rep penalty

    Returns:
        Sampled token ID
    """
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    # Work on a copy so the caller's logits are not modified in place
    logits = logits.clone()

    # Step 1: Apply repetition penalty (always lowers the logit)
    if repetition_penalty != 1.0 and generated_ids is not None:
        for token_id in set(generated_ids.tolist()):
            if logits[token_id] > 0:
                logits[token_id] /= repetition_penalty
            else:
                logits[token_id] *= repetition_penalty

    # Step 2: Apply temperature
    if temperature == 0.0:
        # Greedy decoding: no sampling, so top-k/top-p are irrelevant
        return logits.argmax().item()
    logits = logits / temperature

    # Step 3: Apply top-k filtering
    if top_k is not None and top_k > 0:
        top_k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_largest] = float("-inf")

    # Step 4: Apply top-p (nucleus) filtering
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mark tokens whose cumulative probability is above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is kept
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = False
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = float("-inf")

    # Step 5: Sample from the filtered distribution
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
@torch.no_grad()
def generate(
    model,
    prompt_ids: torch.Tensor,  # (seq_len,)
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
) -> torch.Tensor:
    """Generate text autoregressively.

    Assumes model(x) returns logits of shape (batch, seq_len, vocab_size)
    and exposes its maximum context length as model.block_size.
    """
    model.eval()
    generated = prompt_ids.clone()
    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum length (no-op if shorter)
        context = generated[-model.block_size:]
        logits = model(context.unsqueeze(0))  # (1, seq_len, vocab_size)
        logits = logits[0, -1, :]  # (vocab_size,), last position only
        # Sample next token
        next_id = sample_next_token(
            logits.clone(),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            generated_ids=generated,
        )
        # Match dtype and device so torch.cat also works on GPU
        next_token = torch.tensor(
            [next_id], dtype=generated.dtype, device=generated.device
        )
        generated = torch.cat([generated, next_token])
    return generated
Comparing Sampling Strategies
def compare_sampling_strategies(model, tokenizer, prompt: str):
    """Generate with multiple strategies and compare outputs."""
    prompt_ids = torch.tensor(tokenizer.encode(prompt))
    strategies = [
        {"name": "Greedy", "temperature": 0.0},
        {"name": "T=0.5", "temperature": 0.5},
        {"name": "T=1.0", "temperature": 1.0},
        {"name": "T=1.5", "temperature": 1.5},
        {"name": "Top-k=10", "temperature": 1.0, "top_k": 10},
        {"name": "Top-k=50", "temperature": 1.0, "top_k": 50},
        {"name": "Top-p=0.9", "temperature": 1.0, "top_p": 0.9},
        {"name": "Top-p=0.95", "temperature": 1.0, "top_p": 0.95},
        {"name": "Rep penalty", "temperature": 1.0, "top_p": 0.9,
         "repetition_penalty": 1.2},
    ]
    print(f"Prompt: '{prompt}'\n")
    print("=" * 80)
    for strategy in strategies:
        # Split the label from the sampling arguments without mutating the dict
        name = strategy["name"]
        params = {k: v for k, v in strategy.items() if k != "name"}
        ids = generate(model, prompt_ids, max_new_tokens=100, **params)
        text = tokenizer.decode(ids.tolist())
        print(f"\n[{name}]")
        print(text[:300])
        print("-" * 40)
Constrained JSON Decoding
import json

import torch


class JSONConstrainedDecoder:
    """Simple grammar-guided decoder for JSON output.

    The prefix check is a heuristic built on json.loads, not a full
    incremental parser. It is conservative: some valid prefixes (for
    example a partial number such as "1.") are rejected.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.json_so_far = ""
        self._token_strs = None  # decoded vocabulary, filled on first use

    def get_valid_token_mask(self, logits: torch.Tensor) -> torch.Tensor:
        """Return a mask of valid next tokens given current JSON state."""
        vocab_size = logits.size(-1)
        if self._token_strs is None or len(self._token_strs) != vocab_size:
            # Decode each token once instead of on every step
            self._token_strs = [self.tokenizer.decode([i]) for i in range(vocab_size)]
        mask = torch.zeros_like(logits, dtype=torch.bool)
        for token_id, token_str in enumerate(self._token_strs):
            # Check if this token could lead to valid JSON
            if self._could_be_valid_json_prefix(self.json_so_far + token_str):
                mask[token_id] = True
        return mask

    def _could_be_valid_json_prefix(self, s: str) -> bool:
        """Check if string could be a prefix of valid JSON."""
        s = s.strip()
        if not s:
            return True
        try:
            json.loads(s)  # already complete, valid JSON
            return True
        except json.JSONDecodeError as e:
            # The parser ran out of input: fine so far, just incomplete
            if e.pos >= len(s):
                return True
            # Still inside an open string
            if e.msg.startswith("Unterminated string"):
                return True
            # Partial literal ("tr", "nul", ...) or a lone minus sign
            tail = s[e.pos:]
            return tail == "-" or any(
                literal.startswith(tail) for literal in ("true", "false", "null")
            )

    def decode_step(self, logits: torch.Tensor) -> int:
        """Apply JSON constraints and sample."""
        mask = self.get_valid_token_mask(logits)
        if not mask.any():
            raise ValueError("no token keeps the output a valid JSON prefix")
        # Mask invalid tokens without modifying the caller's tensor
        constrained = logits.masked_fill(~mask, float("-inf"))
        # Sample from valid tokens only
        token_id = sample_next_token(constrained, temperature=0.8)
        # Update state
        self.json_so_far += self.tokenizer.decode([token_id])
        return token_id
Exercise (45 min)
E27.1 — Temperature Extremes
E27.2 — Top-p vs Top-k Analysis
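For the diversity measurement, two simple options are distinct-n (the share of unique n-grams across all samples) and the mean pairwise Jaccard distance between the samples' word sets. The sketch below is one possible way to compute them, assuming whitespace tokenization; the function names are illustrative and other diversity metrics would serve equally well.

```python
from itertools import combinations


def distinct_n(samples, n=2):
    """Unique n-grams divided by total n-grams across all samples."""
    ngrams = []
    for text in samples:
        tokens = text.split()
        ngrams.extend(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0


def mean_pairwise_jaccard_distance(samples):
    """Average of 1 - |A ∩ B| / |A ∪ B| over all pairs of samples."""
    word_sets = [set(text.split()) for text in samples]
    distances = []
    for a, b in combinations(word_sets, 2):
        union = a | b
        distances.append(1 - len(a & b) / len(union) if union else 0.0)
    return sum(distances) / len(distances) if distances else 0.0


# Made-up samples for illustration only.
samples = ["the cat sat", "the cat sat", "a dog ran off"]
print(distinct_n(samples, n=2))
print(mean_pairwise_jaccard_distance(samples))
```

Both scores lie between 0 and 1; identical samples score 0 on the Jaccard distance, and samples with no words in common score 1.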
E27.3 — Constrained Decoding Challenge
1. Generate a JSON object with keys "name", "age", "city"
2. Generate a JSON array of 3 numbers
3. Discuss: What are the limitations of token-level grammar constraints?
Key Takeaways
Connection to the Thread
Further Reading