Phase II — Attention, Transformers & Scaling | Week 2 | 2.5 hours "The dot product is nature's similarity measure — but without scaling, it breaks softmax."
Yesterday's Bahdanau attention uses a learned feedforward network to compute alignment scores. This works but is slow — for each (query, key) pair, we run a neural network forward pass.
Luong et al. (2015) proposed a simpler alternative: just take the dot product between query and key vectors.
| Attention type | Score function | Complexity |
|---|---|---|
| Additive (Bahdanau) | $e = v^T \tanh(W_s s + W_h h)$ | $O(d_a)$ per pair, sequential |
| Dot-product (Luong) | $e = s^T h$ | $O(d)$ per pair, parallelizable |
| Scaled dot-product | $e = \frac{s^T h}{\sqrt{d_k}}$ | $O(d)$ per pair, parallelizable |
The dot product is computationally cheaper and can be computed in parallel via matrix multiplication.
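As a rough illustration of the table above, the sketch below scores every (query, key) pair both ways. The sizes are arbitrary, and `W_s`, `W_h`, `v` are randomly initialized stand-ins for the learned parameters in the additive formula, not trained weights.

```python
import torch

torch.manual_seed(0)
n_q, n_k, d, d_a = 5, 7, 64, 32   # illustrative sizes (assumptions)

s = torch.randn(n_q, d)   # queries (decoder states)
h = torch.randn(n_k, d)   # keys (encoder states)

# Additive (Bahdanau) scores: e_ij = v^T tanh(W_s s_i + W_h h_j)
W_s = torch.randn(d_a, d) / d ** 0.5   # random stand-in, not trained
W_h = torch.randn(d_a, d) / d ** 0.5   # random stand-in, not trained
v = torch.randn(d_a)                   # random stand-in, not trained
proj_s = s @ W_s.T                     # (n_q, d_a)
proj_h = h @ W_h.T                     # (n_k, d_a)
# Broadcasting builds an (n_q, n_k, d_a) tensor before reducing with v
additive_scores = torch.tanh(proj_s[:, None, :] + proj_h[None, :, :]) @ v

# Dot-product (Luong) scores: e_ij = s_i^T h_j, a single matrix multiplication
dot_scores = s @ h.T

print(additive_scores.shape, dot_scores.shape)
```

Both variants produce an `(n_q, n_k)` score matrix. The additive one needs an intermediate tensor with an extra `d_a` axis, while the dot-product one is a single matmul.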
Vaswani et al. (2017) formalized attention using three projections:
Given input sequence $X \in \mathbb{R}^{n \times d}$:
$$Q = X W_Q, \quad K = X W_K, \quad V = X W_V$$
where $W_Q, W_K \in \mathbb{R}^{d \times d_k}$ and $W_V \in \mathbb{R}^{d \times d_v}$.
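A minimal shape check for these projections, assuming PyTorch and arbitrary illustrative sizes. The weight matrices here are random placeholders for what would normally be learned parameters.

```python
import torch

torch.manual_seed(0)
n, d, d_k, d_v = 6, 32, 16, 8     # illustrative sizes (assumptions)

X = torch.randn(n, d)             # one row per token

# Random stand-ins for the learned projection matrices
W_Q = torch.randn(d, d_k)
W_K = torch.randn(d, d_k)
W_V = torch.randn(d, d_v)

Q = X @ W_Q                       # (n, d_k)
K = X @ W_K                       # (n, d_k)
V = X @ W_V                       # (n, d_v)

print(Q.shape, K.shape, V.shape)
```

Queries and keys must share the width $d_k$ so that their dot product is defined. The value width $d_v$ is free to differ.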
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$
Step by step:
1. Compute scores: $S = Q K^T$, shape $(n, n)$.
2. Scale: $S \leftarrow S / \sqrt{d_k}$, which prevents softmax saturation.
3. Softmax: $A = \text{softmax}(S)$, applied row-wise so each row sums to 1.
4. Weighted sum: $O = A V$, shape $(n, d_v)$.
In matrix form, this is beautifully parallelizable:
```
Q: (n × d_k) ┐
             ├──→ QK^T: (n × n) ──→ softmax ──→ A: (n × n) ──→ AV: (n × d_v)
K: (n × d_k) ┘                                              ↑
V: (n × d_v) ───────────────────────────────────────────────┘
```
Scaling by $\sqrt{d_k}$ is the critical step that makes dot-product attention work in practice.
Problem: For large $d_k$, the dot products $q \cdot k$ grow in magnitude. If the components of $q$ and $k$ are drawn independently from $\mathcal{N}(0, 1)$, then:
$$\mathbb{E}[q \cdot k] = 0, \quad \text{Var}[q \cdot k] = d_k$$
So the standard deviation of dot products scales as $\sqrt{d_k}$. For $d_k = 64$, typical dot products have magnitude $\sim 8$.
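The variance result can be derived in two lines, under the assumption that all components $q_i$, $k_i$ are mutually independent with zero mean and unit variance. For a single term:

$$\text{Var}[q_i k_i] = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \left(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\right)^2 = 1 \cdot 1 - 0 = 1$$

Because the $d_k$ terms are independent, their variances add:

$$\text{Var}[q \cdot k] = \text{Var}\!\left[\sum_{i=1}^{d_k} q_i k_i\right] = \sum_{i=1}^{d_k} \text{Var}[q_i k_i] = d_k$$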
What happens without scaling: Large values push softmax into regions where gradients are extremely small:
$$\text{softmax}([10, 0, 0]) \approx [0.9999, 0.000045, 0.000045]$$
The gradient $\frac{\partial}{\partial x_i} \text{softmax}(x)_i = p_i(1 - p_i)$ vanishes when $p_i \approx 1$ or $p_i \approx 0$.
Scaling by $\sqrt{d_k}$ restores the variance to $\sim 1$:
$$\text{Var}\!\left[\frac{q \cdot k}{\sqrt{d_k}}\right] = \frac{d_k}{d_k} = 1$$
This keeps softmax in a regime with healthy gradients.
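To make the saturation concrete, the sketch below evaluates the diagonal gradient $p_i(1 - p_i)$ for the example scores before and after scaling. It uses only the standard library, and it assumes $d_k = 64$ purely for illustration, so the scores are divided by 8.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def diag_grads(xs):
    """Diagonal of the softmax Jacobian: p_i * (1 - p_i)."""
    return [p * (1 - p) for p in softmax(xs)]


d_k = 64                                      # illustrative key width (assumption)
raw = [10.0, 0.0, 0.0]                        # unscaled scores
scaled = [x / math.sqrt(d_k) for x in raw]    # divide by sqrt(d_k) = 8

print("p (raw):      ", [f"{p:.5f}" for p in softmax(raw)])
print("grad (raw):   ", [f"{g:.5f}" for g in diag_grads(raw)])
print("p (scaled):   ", [f"{p:.5f}" for p in softmax(scaled)])
print("grad (scaled):", [f"{g:.5f}" for g in diag_grads(scaled)])
```

For the raw scores the largest diagonal gradient is on the order of $10^{-4}$. After scaling it is a few tenths, so the gradient signal is orders of magnitude stronger.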
The elegance of scaled dot-product attention is that it's just two matrix multiplications with a row-wise softmax in between:
$$\underbrace{Q K^T}_{\text{all pairwise similarities}} \rightarrow \underbrace{\text{softmax}}_{\text{normalize}} \rightarrow \underbrace{\times V}_{\text{weighted retrieval}}$$
This means every position attends to every other position simultaneously — no recurrence, no sequential bottleneck. An $n$-token sequence requires $O(n^2 \cdot d)$ compute but has $O(1)$ sequential depth.
| Property | RNN | Attention |
|---|---|---|
| Sequential operations | $O(n)$ | $O(1)$ |
| Total compute | $O(n \cdot d^2)$ | $O(n^2 \cdot d)$ |
| Max path length | $O(n)$ | $O(1)$ |
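For $n < d$ (common for sentence-length inputs, less so for long-context models), attention wins on all three rows. For $n > d$, the $O(n^2 \cdot d)$ term dominates and attention costs more in total compute, though it keeps its $O(1)$ sequential depth and path length.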
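The crossover at $n = d$ can be checked with a few lines of arithmetic. This is a back-of-the-envelope sketch: the two cost functions keep only the leading-order terms from the table and ignore constants, and the width `d = 512` is an arbitrary choice.

```python
def rnn_cost(n: int, d: int) -> int:
    """Leading-order cost of a recurrent layer: n steps of a d x d matrix-vector product."""
    return n * d * d


def attention_cost(n: int, d: int) -> int:
    """Leading-order cost of self-attention: an n x n score matrix over d-dimensional vectors."""
    return n * n * d


d = 512  # illustrative model width (assumption)
for n in (64, 512, 4096):
    ratio = attention_cost(n, d) / rnn_cost(n, d)
    print(f"n={n:5d}  attention/RNN cost ratio = {ratio:.3f}")
```

The ratio simplifies to $n / d$, so attention is cheaper in total compute exactly when the sequence is shorter than the model width.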
```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(
    query: torch.Tensor,                 # (batch, n_q, d_k)
    key: torch.Tensor,                   # (batch, n_k, d_k)
    value: torch.Tensor,                 # (batch, n_k, d_v)
    mask: torch.Tensor | None = None,    # (batch, n_q, n_k) or broadcastable
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute scaled dot-product attention.

    Returns:
        output: (batch, n_q, d_v), weighted value vectors
        weights: (batch, n_q, n_k), attention weights
    """
    d_k = query.size(-1)

    # Step 1: Compute raw scores via matrix multiplication
    # (batch, n_q, d_k) × (batch, d_k, n_k) → (batch, n_q, n_k)
    scores = torch.bmm(query, key.transpose(-2, -1))

    # Step 2: Scale by √d_k
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (e.g., causal mask for autoregressive decoding)
    # Note: a row whose keys are all masked would turn into NaN after softmax
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax over keys dimension
    weights = F.softmax(scores, dim=-1)

    # Step 5: Weighted sum of values
    output = torch.bmm(weights, value)

    return output, weights
```
```python
def test_attention_shapes():
    """Verify output shapes match expectations."""
    batch, n_q, n_k, d_k, d_v = 2, 5, 7, 64, 32

    Q = torch.randn(batch, n_q, d_k)
    K = torch.randn(batch, n_k, d_k)
    V = torch.randn(batch, n_k, d_v)

    output, weights = scaled_dot_product_attention(Q, K, V)

    assert output.shape == (batch, n_q, d_v), f"Expected ({batch}, {n_q}, {d_v}), got {output.shape}"
    assert weights.shape == (batch, n_q, n_k), f"Expected ({batch}, {n_q}, {n_k}), got {weights.shape}"

    # Attention weights should sum to 1 along key dimension
    weight_sums = weights.sum(dim=-1)
    assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), \
        f"Weights don't sum to 1: {weight_sums}"

    print("✓ All shape tests passed")


def test_scaling_effect():
    """Demonstrate why scaling matters."""
    d_k = 512
    Q = torch.randn(1, 1, d_k)
    K = torch.randn(1, 10, d_k)
    V = torch.randn(1, 10, d_k)

    # Without scaling
    scores_unscaled = torch.bmm(Q, K.transpose(-2, -1))
    weights_unscaled = F.softmax(scores_unscaled, dim=-1)

    # With scaling
    scores_scaled = scores_unscaled / math.sqrt(d_k)
    weights_scaled = F.softmax(scores_scaled, dim=-1)

    print(f"d_k = {d_k}")
    print(f"Score std (unscaled): {scores_unscaled.std():.2f}")
    print(f"Score std (scaled): {scores_scaled.std():.2f}")
    print(f"Max weight (unscaled): {weights_unscaled.max():.4f}")
    print(f"Max weight (scaled): {weights_scaled.max():.4f}")
    # torch.special.entr computes -p * log(p) elementwise and returns 0 at p = 0,
    # which avoids the NaN that 0 * log(0) gives when a weight underflows
    print(f"Entropy (unscaled): {torch.special.entr(weights_unscaled).sum():.4f}")
    print(f"Entropy (scaled): {torch.special.entr(weights_scaled).sum():.4f}")


test_attention_shapes()
test_scaling_effect()
```
```python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create lower-triangular mask for autoregressive attention.

    Position i can only attend to positions ≤ i.
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask  # 1 = attend, 0 = block


# Example: 4-token sequence
mask = create_causal_mask(4)
print(mask)
# tensor([[1., 0., 0., 0.],   ← token 0 sees only itself
#         [1., 1., 0., 0.],   ← token 1 sees 0,1
#         [1., 1., 1., 0.],   ← token 2 sees 0,1,2
#         [1., 1., 1., 1.]])  ← token 3 sees everything

Q = K = V = torch.randn(1, 4, 64)
output, weights = scaled_dot_product_attention(Q, K, V, mask=mask.unsqueeze(0))
print(f"Causal attention weights:\n{weights.squeeze()}")

# Upper triangle should be zero: masked scores are -inf, so softmax gives exactly 0
assert torch.all(weights.squeeze(0).triu(diagonal=1) == 0)
```
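PyTorch 2.0 and later also ship a fused `torch.nn.functional.scaled_dot_product_attention`. As a sanity check, the sketch below compares it against the formula from this section on random inputs with a causal mask. The tensor sizes are arbitrary and the tolerance is an assumption that may need loosening on other backends or dtypes.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 64)   # (batch, n, d_k)
k = torch.randn(2, 4, 64)   # (batch, n, d_k)
v = torch.randn(2, 4, 32)   # (batch, n, d_v)

# Reference: softmax(Q K^T / sqrt(d_k)) V with a lower-triangular mask
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(~causal, float("-inf"))
manual = F.softmax(scores, dim=-1) @ v

# Built-in fused version (requires PyTorch 2.0 or later)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))
```

The built-in returns only the output, not the attention weights, so a hand-written version is still useful when the weights are needed for inspection or plotting.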
Given a 3-token sequence with $d_k = 2$:
$$Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix}, \quad K = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0.5 & 0.5 \end{bmatrix}, \quad V = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 0.5 & 0.5 \end{bmatrix}$$
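One way to check a hand computation of scaled dot-product attention on these matrices is the short script below. It is a sketch that assumes the unmasked formula from this section and PyTorch.

```python
import math

import torch

Q = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
K = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
V = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])

d_k = Q.size(-1)
scores = Q @ K.T / math.sqrt(d_k)          # (3, 3) scaled similarities
weights = torch.softmax(scores, dim=-1)    # each row sums to 1
output = weights @ V                       # (3, 2)

print("scores:\n", scores)
print("weights:\n", weights)
print("output:\n", output)
```

A useful spot check: the third query $[1, 1]$ has the same dot product with every key, so its attention weights are uniform and its output is the mean of the value rows, $[0.5, 0.5]$.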
Write code that:
1. Fixes $Q$ and $K$ and varies $d_k$ from 1 to 1024
2. For each $d_k$, computes attention weights with and without scaling
3. Plots the entropy of the attention distribution vs $d_k$
4. Shows that without scaling, entropy collapses (attention becomes one-hot) as $d_k$ grows
Using the same encoder-decoder architecture:
1. Replace Bahdanau attention with scaled dot-product attention
2. Train both on the same data
3. Compare: training speed (wall time per epoch), final BLEU, attention heatmaps
4. Where do they agree? Where do they differ?
The dot product $q \cdot k = \|q\| \|k\| \cos\theta$ measures similarity in the embedding space, combining the angle between the two vectors with their norms. Attention is then a soft information retrieval: each query retrieves a mixture of values, weighted by its normalized similarity to the keys. The model learns to place queries and keys in a space where geometrically close = semantically relevant. This is the same principle behind compression: identify redundancy (similarity), then use it to reconstruct (retrieve values). The transformer will build on this to create an entire architecture out of nothing but learned similarity lookups.