Phase II: Attention, Transformers & Scaling | Week 2 | 2.5 hours

> "The most important development in neural machine translation was allowing the decoder to look back." (Dzmitry Bahdanau, 2014)
Recall from Day 7 the seq2seq model: an encoder reads the source sentence and compresses it into a single fixed-length vector $c$, which the decoder must use to generate the entire target sentence.
```
Source:  x₁ ─→ h₁ ─→ h₂ ─→ h₃ ─→ h₄ ──→ [c]   ← entire sentence in one vector!
                                         │
Target:                                 [c] ─→ s₁ ─→ s₂ ─→ s₃
                                               y₁    y₂    y₃
```
This is the information bottleneck. No matter how long the source sentence is — 5 words or 50 — everything must be squeezed through a vector of dimension $d$. Performance degrades dramatically on long sentences (Cho et al., 2014).
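A quick way to see the bottleneck: with a plain (non-attention) GRU encoder, the context $c$ is the final hidden state, and its shape does not depend on the source length. This sketch uses random embeddings and arbitrary sizes as stand-ins for a real sentence and trained weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, hidden_dim = 16, 32
encoder = nn.GRU(embed_dim, hidden_dim, batch_first=True)  # unidirectional, 1 layer


def fixed_context(src_len: int) -> torch.Tensor:
    """Encode a random stand-in 'sentence' of src_len tokens and return c."""
    x = torch.randn(1, src_len, embed_dim)   # (batch=1, src_len, embed_dim)
    _outputs, h_n = encoder(x)               # h_n: (num_layers, batch, hidden_dim)
    return h_n[-1]                           # (batch, hidden_dim): all the decoder gets


short_c = fixed_context(5)
long_c = fixed_context(50)
print(short_c.shape, long_c.shape)  # both torch.Size([1, 32])
```

Ten times more source tokens, same 32 numbers for the decoder to work from.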
The question: Can the decoder look at all encoder hidden states, not just the last one?
Bahdanau, Cho, and Bengio proposed: let the decoder attend to different parts of the source for each target word.
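Key insight: attention is a soft dictionary lookup. A hard lookup returns the single value whose key matches the query; attention instead scores the query (the decoder state) against every key (the encoder states), turns the scores into weights that sum to 1, and returns the weighted average of the values (here also the encoder states).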
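To make the analogy concrete, here is a minimal sketch contrasting a hard lookup with a soft one. The keys, values and query are made-up numbers, and the similarity score is a plain dot product chosen only for brevity; Bahdanau attention uses the learned additive score defined below.

```python
import torch

# Made-up keys/values/query; in Bahdanau attention keys and values are both the encoder states h_j.
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
values = torch.tensor([[10.0], [20.0], [30.0]])
query = torch.tensor([0.9, 0.1])

# Similarity of the query to each key (a dot product here, purely for brevity).
scores = keys @ query                 # [0.9, 0.1, 1.0]

# Hard lookup: return exactly one value, the best-matching key's.
hard = values[scores.argmax()]

# Soft lookup: weights that sum to 1, then a weighted average of all values.
weights = torch.softmax(scores, dim=0)
soft = weights @ values
print(hard, soft)
```

The hard lookup returns 30 and ignores the other entries; the soft lookup blends all three values, and because the blend is smooth in the scores it can be trained by backpropagation.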
For each decoder step $i$, compute a context vector $c_i$ as a weighted sum of the encoder states $h_1, \dots, h_T$, where $T$ is the source length:
$$c_i = \sum_{j=1}^{T} \alpha_{ij} \, h_j$$
where $\alpha_{ij}$ are attention weights (soft alignment):
$$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T} \exp(e_{ik})}$$
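A small numeric sketch of the two formulas above for one decoder step $i$: the energies are made-up numbers (the next formula shows where they come from) and the encoder states are random. It computes $\alpha_{ij}$ exactly as written, checks it against the library softmax, and forms $c_i$.

```python
import torch

torch.manual_seed(0)
T, d_h = 4, 3
h = torch.randn(T, d_h)                     # encoder states h_1..h_T (random stand-ins)
e_i = torch.tensor([2.0, 0.5, -1.0, 0.0])   # made-up energies e_{i1}..e_{iT}

# alpha_ij = exp(e_ij) / sum_k exp(e_ik), written out as in the formula
alpha_i = torch.exp(e_i) / torch.exp(e_i).sum()
# c_i = sum_j alpha_ij h_j
c_i = (alpha_i.unsqueeze(-1) * h).sum(dim=0)   # (d_h,)

# Same results via the library softmax and a vector-matrix product
assert torch.allclose(alpha_i, torch.softmax(e_i, dim=0), atol=1e-6)
assert torch.allclose(c_i, alpha_i @ h, atol=1e-6)
print(alpha_i.sum())  # 1 up to float rounding, so c_i is a convex combination of the h_j
```

Because the weights are non-negative and sum to 1, each coordinate of $c_i$ lies between the smallest and largest values of that coordinate across the $h_j$.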
The energy function (alignment model) is a learnable feedforward network:
$$e_{ij} = v^T \tanh(W_s \, s_{i-1} + W_h \, h_j)$$
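The energy formula can be checked directly with explicit matrices. In this sketch the sizes and all parameters are random stand-ins (in a real model $W_s$, $W_h$ and $v$ are learned); it computes $e_{ij}$ for one decoder step once with a loop over $j$ and once vectorized, which is how batched implementations avoid the loop.

```python
import torch

torch.manual_seed(0)
T, d_h, d_s, d_a = 5, 6, 4, 8
h = torch.randn(T, d_h)       # encoder states h_1..h_T
s_prev = torch.randn(d_s)     # previous decoder state s_{i-1}
W_s = torch.randn(d_a, d_s)   # random here; learned in a real model
W_h = torch.randn(d_a, d_h)
v = torch.randn(d_a)

# Literal formula: one scalar energy per source position j
e_loop = torch.stack([v @ torch.tanh(W_s @ s_prev + W_h @ h[j]) for j in range(T)])

# Vectorised: the (d_a,) query term broadcasts against the (T, d_a) key terms
e_vec = torch.tanh(s_prev @ W_s.T + h @ W_h.T) @ v   # (T,)

assert torch.allclose(e_loop, e_vec, atol=1e-5)
```

Because $W_h h_j$ does not depend on the decoder step $i$, it can be computed once per sentence and reused; only $W_s s_{i-1}$ has to be recomputed at each step.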
```
                      Encoder hidden states
                 h₁        h₂        h₃        h₄
                 │         │         │         │
                 ▼         ▼         ▼         ▼
s_{i-1} ───→  [energy]  [energy]  [energy]  [energy]
                 │         │         │         │
                 ▼         ▼         ▼         ▼
               e_{i1}    e_{i2}    e_{i3}    e_{i4}
                 │         │         │         │
                 └─────────┴────┬────┴─────────┘
                                │
                             softmax
                                │
                 ┌─────────┬────┴────┬─────────┐
                 ▼         ▼         ▼         ▼
               α_{i1}    α_{i2}    α_{i3}    α_{i4}
                 │         │         │         │
                 ▼         ▼         ▼         ▼
        α_{i1}·h₁ + α_{i2}·h₂ + α_{i3}·h₃ + α_{i4}·h₄ = c_i
```
The energy function uses addition inside the tanh:
$$e_{ij} = v^T \tanh(\underbrace{W_s \, s_{i-1}}_{\text{query term}} + \underbrace{W_h \, h_j}_{\text{key term}})$$
This is why Bahdanau attention is classified as additive attention, in contrast to the dot-product attention we'll see on Day 11.
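Parameters to learn ($d_s$ and $d_h$ are the decoder and encoder state sizes, $d_a$ is the attention dimension):

- $W_s \in \mathbb{R}^{d_a \times d_s}$: projects the decoder state
- $W_h \in \mathbb{R}^{d_a \times d_h}$: projects each encoder state
- $v \in \mathbb{R}^{d_a}$: weight vector that reduces the tanh output to a scalar energy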
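As a quick check on these shapes, the additive scorer has $d_a d_s + d_a d_h + d_a = d_a(d_s + d_h + 1)$ learned scalars. The sketch below evaluates this for example sizes chosen only for illustration ($d_h = 1024$, as for a bidirectional encoder with 512 units per direction) and cross-checks it against three bias-free `nn.Linear` layers, mirroring the `BahdanauAttention` class.

```python
import torch.nn as nn


def additive_attention_params(d_s: int, d_h: int, d_a: int) -> int:
    """Scalar parameters in W_s (d_a x d_s), W_h (d_a x d_h) and v (d_a)."""
    return d_a * d_s + d_a * d_h + d_a


# Example sizes (illustrative only)
d_s, d_h, d_a = 512, 1024, 512
layers = [
    nn.Linear(d_s, d_a, bias=False),  # weight shape (d_a, d_s) = W_s
    nn.Linear(d_h, d_a, bias=False),  # weight shape (d_a, d_h) = W_h
    nn.Linear(d_a, 1, bias=False),    # weight shape (1, d_a) = v^T
]
n_params = sum(p.numel() for layer in layers for p in layer.parameters())
print(n_params, additive_attention_params(d_s, d_h, d_a))  # 786944 786944
```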
In traditional machine translation, alignment is a hard mapping between source and target words. Attention replaces this with a soft, differentiable version:
| Hard alignment | Soft attention |
|---|---|
| Each target word → exactly one source word | Each target word → weighted combination |
| Discrete, non-differentiable | Continuous, differentiable |
| Must be computed separately | Learned end-to-end |
The attention weight $\alpha_{ij}$ tells us: "When generating target word $i$, how much should we look at source word $j$?"
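The differentiability row of the table is easy to check in PyTorch. In this sketch (made-up energies and encoder states), gradients flow from the soft context back to the energies, while a hard choice via `argmax` produces an integer index that autograd cannot differentiate through.

```python
import torch

energies = torch.tensor([0.2, 1.5, -0.3], requires_grad=True)  # made-up e_{i1..i3}
h = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])          # made-up encoder states

# Soft attention: the context is a smooth function of the energies.
alpha = torch.softmax(energies, dim=0)
context = alpha @ h
context.sum().backward()
print(energies.grad)  # non-zero, so the alignment model can be trained by gradient descent

# Hard alignment: commit to a single source position.
j = torch.argmax(energies)   # integer index; argmax has no gradient
hard_context = h[j]          # depends on the energies only through j
print(hard_context.requires_grad)  # False: no gradient path back to the energies
```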
With attention, the decoder GRU update becomes:
$$s_i = \text{GRU}(s_{i-1}, [y_{i-1}; c_i])$$
where $[y_{i-1}; c_i]$ is the concatenation of the previous target embedding and the context vector. The context changes at every step — the decoder sees a different view of the source for each word it generates.
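The update can be written almost verbatim with `nn.GRUCell`, PyTorch's single-step GRU. This is a minimal sketch with arbitrary sizes and random tensors standing in for the embedding of $y_{i-1}$, the context $c_i$ and the state $s_{i-1}$; the decoder class below does the same thing with `nn.GRU` on a length-1 sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, embed_dim, enc_dim, dec_dim = 2, 8, 12, 16
cell = nn.GRUCell(input_size=embed_dim + enc_dim, hidden_size=dec_dim)

y_prev_emb = torch.randn(batch, embed_dim)  # embedding of y_{i-1}
c_i = torch.randn(batch, enc_dim)           # context vector for step i
s_prev = torch.randn(batch, dec_dim)        # s_{i-1}

# s_i = GRU(s_{i-1}, [y_{i-1}; c_i])
s_i = cell(torch.cat([y_prev_emb, c_i], dim=-1), s_prev)
print(s_i.shape)  # torch.Size([2, 16])
```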
```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention mechanism."""

    def __init__(self, encoder_dim: int, decoder_dim: int, attention_dim: int):
        super().__init__()
        self.W_h = nn.Linear(encoder_dim, attention_dim, bias=False)  # W_h: projects encoder states (keys)
        self.W_s = nn.Linear(decoder_dim, attention_dim, bias=False)  # W_s: projects decoder state (query)
        self.v = nn.Linear(attention_dim, 1, bias=False)              # v: maps d_a -> scalar energy

    def forward(
        self,
        decoder_state: torch.Tensor,              # (batch, decoder_dim): s_{i-1}
        encoder_outputs: torch.Tensor,            # (batch, src_len, encoder_dim): h_1..h_T
        src_mask: Optional[torch.Tensor] = None,  # (batch, src_len) bool, True = real token
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Project encoder outputs: (batch, src_len, attn_dim)
        keys = self.W_h(encoder_outputs)
        # Project decoder state: (batch, 1, attn_dim), broadcast across src_len below
        query = self.W_s(decoder_state).unsqueeze(1)
        # Energy: (batch, src_len, 1) → squeeze → (batch, src_len)
        energy = self.v(torch.tanh(keys + query)).squeeze(-1)
        # Padded positions get -inf energy, hence zero attention weight
        if src_mask is not None:
            energy = energy.masked_fill(~src_mask, float("-inf"))
        # Attention weights: (batch, src_len), each row sums to 1
        alpha = F.softmax(energy, dim=-1)
        # Context vector c_i = Σ_j α_ij h_j: (batch, encoder_dim)
        context = torch.bmm(alpha.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, alpha


class Encoder(nn.Module):
    """Bidirectional GRU encoder; its outputs have size 2 * hidden_dim."""

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, src: torch.Tensor):
        # src: (batch, src_len)
        embedded = self.embedding(src)        # (batch, src_len, embed_dim)
        outputs, hidden = self.rnn(embedded)  # outputs: (batch, src_len, 2*hidden_dim)
        # hidden: (2, batch, hidden_dim); hidden[-2] is the final forward state and
        # hidden[-1] the final backward state. Combine them to initialise the decoder.
        hidden = torch.tanh(self.fc(
            torch.cat([hidden[-2], hidden[-1]], dim=-1)
        ))
        return outputs, hidden                # hidden: (batch, hidden_dim)


class AttentionDecoder(nn.Module):
    """One decoding step. encoder_dim should be 2 * the Encoder's hidden_dim."""

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, encoder_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attention = BahdanauAttention(encoder_dim, hidden_dim, hidden_dim)
        self.rnn = nn.GRU(embed_dim + encoder_dim, hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(hidden_dim + encoder_dim + embed_dim, vocab_size)

    def forward(self, token, decoder_state, encoder_outputs, src_mask=None):
        # token: (batch,), a single time step y_{i-1}; decoder_state: (batch, hidden_dim), s_{i-1}
        embedded = self.embedding(token)  # (batch, embed_dim)
        context, alpha = self.attention(decoder_state, encoder_outputs, src_mask)
        # GRU input [y_{i-1}; c_i] as a length-1 sequence: (batch, 1, embed_dim + encoder_dim)
        rnn_input = torch.cat([embedded, context], dim=-1).unsqueeze(1)
        output, hidden = self.rnn(rnn_input, decoder_state.unsqueeze(0))
        # Predict the next token from the new state, the context and the input embedding
        prediction = self.fc_out(
            torch.cat([output.squeeze(1), context, embedded], dim=-1)
        )
        return prediction, hidden.squeeze(0), alpha  # logits: (batch, vocab_size)
```
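During training the decoder is typically run one step at a time with teacher forcing: the ground-truth previous token is fed in rather than the model's own prediction. The sketch below assumes target batches of shape (batch, tgt_len) that start with an `<sos>` id and are padded with a hypothetical `pad_idx`; `decoder_step` can be an `AttentionDecoder` instance, since calling an `nn.Module` runs its `forward` with the matching `(token, state, encoder_outputs)` signature.

```python
import torch
import torch.nn.functional as F


def teacher_forced_loss(decoder_step, init_state, encoder_outputs, tgt, pad_idx=0):
    """Average cross-entropy per non-pad target token under teacher forcing.

    decoder_step: callable (token, state, encoder_outputs) -> (logits, new_state, alpha).
    init_state:   (batch, hidden_dim), e.g. the hidden state returned by the Encoder.
    tgt:          (batch, tgt_len) token ids; position 0 holds <sos>, padding uses pad_idx.
    """
    state = init_state
    total = 0.0
    for t in range(tgt.size(1) - 1):
        # Feed the ground-truth token at position t, predict the token at t + 1
        logits, state, _alpha = decoder_step(tgt[:, t], state, encoder_outputs)
        total = total + F.cross_entropy(
            logits, tgt[:, t + 1], ignore_index=pad_idx, reduction="sum"
        )
    n_tokens = (tgt[:, 1:] != pad_idx).sum().clamp(min=1)  # real target tokens only
    return total / n_tokens
```

Summing per step and dividing by the number of real tokens keeps padded positions from diluting the average.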
```python
import matplotlib.pyplot as plt
import numpy as np


def plot_attention(attention_weights: np.ndarray, src_tokens: list[str], tgt_tokens: list[str]):
    """Plot attention heatmap showing soft alignment between source and target."""
    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attention_weights, cmap='viridis')  # rows = target, columns = source
    fig.colorbar(cax)
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=45, ha='left')
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens)
    ax.xaxis.set_label_position('top')  # matshow draws the x tick labels on top
    ax.set_xlabel('Source')
    ax.set_ylabel('Target')
    ax.set_title('Bahdanau Attention Weights')
    plt.tight_layout()
    plt.savefig('attention_heatmap.png', dpi=150)
    plt.show()


# Example: after training, collect attention weights during inference
# attention_matrix[i, j] = α_{ij} = how much target word i attends to source word j
# Expect a near-diagonal pattern where word order matches, crossed patterns where words are reordered
```
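One way to collect `attention_matrix` is a greedy decoding loop that stores $\alpha_i$ at every step. This is a sketch under assumptions: batch size 1, hypothetical `sos_idx`/`eos_idx` token ids, and a `step_fn` with the same call signature as `AttentionDecoder.forward` (a trained decoder instance works). The returned array has one row per generated token and can be passed to `plot_attention` together with the source tokens and the generated tokens as strings.

```python
import numpy as np
import torch


@torch.no_grad()
def greedy_decode_with_attention(step_fn, init_state, encoder_outputs,
                                 sos_idx, eos_idx, max_len=50):
    """Greedy-decode one sentence and stack the attention weights of every step.

    step_fn: callable (token, state, encoder_outputs) -> (logits, new_state, alpha).
    Returns (token ids, array of shape (len(tokens), src_len)).
    """
    token = torch.tensor([sos_idx], device=encoder_outputs.device)
    state = init_state
    tokens, rows = [], []
    for _ in range(max_len):
        logits, state, alpha = step_fn(token, state, encoder_outputs)
        token = logits.argmax(dim=-1)  # (1,): most likely next token
        tokens.append(token.item())
        rows.append(alpha[0])          # (src_len,): weights used for this token
        if token.item() == eos_idx:
            break
    return tokens, torch.stack(rows).cpu().numpy()
```

Greedy decoding keeps the bookkeeping simple; with beam search each hypothesis would need its own list of attention rows.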
Given encoder states $h_1 = [1, 0]$, $h_2 = [0, 1]$, $h_3 = [1, 1]$ and decoder state $s_0 = [0.5, 0.5]$:
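The exercise's remaining instructions are not reproduced here, so the following is only one possible worked setup: it assumes additive scoring with hypothetical parameters $W_s = W_h = I$ and $v = [1, 1]^\top$, picked purely for illustration. Under these assumptions $h_1$ and $h_2$ receive equal energy by symmetry, and $h_3$ gets the largest energy because $s_0 + h_3 = [1.5, 1.5]$ has the largest entries.

```python
import torch

h = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # h_1, h_2, h_3
s0 = torch.tensor([0.5, 0.5])                           # s_0

# Hypothetical parameters, not given in the exercise
W_s = torch.eye(2)
W_h = torch.eye(2)
v = torch.ones(2)

# e_1j = v^T tanh(W_s s_0 + W_h h_j)
# e_11 = e_12 = tanh(1.5) + tanh(0.5) ≈ 1.367, e_13 = 2 tanh(1.5) ≈ 1.810
e = torch.tanh(s0 @ W_s.T + h @ W_h.T) @ v
alpha = torch.softmax(e, dim=0)
c1 = alpha @ h
print(alpha)  # approximately [0.2811, 0.2811, 0.4378]
print(c1)     # approximately [0.7189, 0.7189]
```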
Train two seq2seq models on Multi30k EN→DE:

1. Vanilla seq2seq (Day 7 style) with final-state-only context
2. Seq2seq with Bahdanau attention
Compare:

- BLEU scores on the test set
- Training curves (loss vs. epoch)
- Performance on short (< 10 tokens) vs. long (> 20 tokens) sentences
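For the short-vs-long comparison, the test pairs need to be split by length before scoring. A minimal sketch, assuming length is counted in source tokens (the exercise does not say whether source or reference length is meant) and that sentences between the two thresholds are set aside in a "medium" bucket:

```python
def bucket_by_length(pairs, short_max=10, long_min=20):
    """Split (src_tokens, ref_tokens) pairs by source length.

    'short': fewer than short_max tokens; 'long': more than long_min tokens;
    everything in between goes to 'medium' and is left out of the comparison.
    """
    buckets = {"short": [], "medium": [], "long": []}
    for src, ref in pairs:
        n = len(src)
        if n < short_max:
            buckets["short"].append((src, ref))
        elif n > long_min:
            buckets["long"].append((src, ref))
        else:
            buckets["medium"].append((src, ref))
    return buckets
```

Each bucket can then be translated by both models and scored separately with the same BLEU tool used for the overall score (for example sacreBLEU's `corpus_bleu`).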
Using your trained attention model:

1. Generate translations for 5 sentences of varying length.
2. Plot attention heatmaps for each.
3. Find examples of monotonic alignment, reordering, and many-to-one alignment.
4. Does the model learn the word-order differences between EN and DE?
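For step 3, candidate examples can be flagged automatically before inspecting the heatmaps by eye. The heuristic below is an assumption, not a standard metric: it takes the argmax source position of each target word as its hard alignment, calls the sentence monotonic if that position never moves backwards, reports source words chosen by several target words, and reports target words that put at least `spread_threshold` (arbitrarily 0.3) of their weight on more than one source word, which is where a German compound covering several English words would show up.

```python
import numpy as np


def alignment_patterns(attn: np.ndarray, spread_threshold: float = 0.3) -> dict:
    """attn[i, j] = alpha_ij for target word i and source word j (rows sum to 1)."""
    hard = attn.argmax(axis=1)                        # most-attended source word per target word
    backward = np.where(np.diff(hard) < 0)[0] + 1     # target positions that jump back in the source
    counts = np.bincount(hard, minlength=attn.shape[1])
    n_strong = (attn >= spread_threshold).sum(axis=1)
    return {
        "hard_alignment": hard.tolist(),
        "monotonic": backward.size == 0,
        "reordered_targets": backward.tolist(),
        "shared_sources": np.where(counts > 1)[0].tolist(),   # one source word, several targets
        "spread_targets": np.where(n_strong > 1)[0].tolist(), # several source words, one target
    }
```

Argmax alignments are a coarse summary of soft attention, so treat the flags as pointers to sentences worth plotting rather than as conclusions.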
Attention solves a compression problem. Fixed-context seq2seq tries to compress the entire source into one vector — a lossy compression that loses information for long sequences. Attention is adaptive decompression: at each step, the decoder selects which parts of the source to decompress. This is the beginning of the transformer's power: selective access to information beats wholesale compression.