Phase IV · Week 8 · Day 52 of 70 · 2.5 hours
"The teacher whispers probabilities across classes; the student hears structure the labels alone could never reveal."
Hard labels say "this is a cat." Soft labels say "this is 90% cat, 5% lynx, 3% tiger, 2% dog." Those secondary probabilities, often called dark knowledge, encode the similarity structure the teacher learned over millions of examples. Knowledge distillation transfers this information from a large, accurate teacher model to a small, deployable student model. DistilBERT (66M parameters) retains about 97% of the language-understanding performance of BERT-base (110M parameters) while being 40% smaller and roughly 60% faster at inference. Progressive distillation can compress a diffusion sampler from roughly 1000 steps down to 4. Alongside pruning and quantization, distillation is the third pillar of model compression, and often the most powerful one, because it compresses what the model has learned rather than only the weights that store it.
Knowledge Distillation — Architecture Overview
══════════════════════════════════════════════════════════════
Input x ─────────────────────┬──────────────────────────┐
│ │
▼ ▼
┌──────────────────┐ ┌──────────────────┐
│ Teacher Model │ │ Student Model │
│ (Large, frozen) │ │ (Small, trainable)│
│ ResNet-152 │ │ ResNet-18 │
│ BERT-large │ │ TinyBERT │
└────────┬─────────┘ └────────┬─────────┘
│ │
▼ ▼
┌────────────────┐ ┌────────────────┐
│ Soft logits │ │ Soft logits │
│ z_t / T │ │ z_s / T │
└────────┬───────┘ └────────┬───────┘
│ │
└──────────┬───────────────┘
│
▼
┌──────────────────┐
│ L_distill = │
│ KL(σ(z_t/T) ‖ │
│ σ(z_s/T)) │
│ × T² │
└──────────────────┘
│
▼
┌──────────────────┐ Hard label y
│ L_total = │◀────────────┐
│ α·L_distill + │ │
│ (1-α)·L_CE(s,y) │ ┌────────┴────┐
└──────────────────┘ │ L_CE = │
│ CE(σ(z_s), y)│
└──────────────┘
The foundational technique from Hinton, Vinyals & Dean (2015):
Softmax with temperature $T$ controls the entropy of the output distribution:
$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
Temperature Effect on Softmax Output
══════════════════════════════════════
Class:          Cat    Lynx   Tiger  Dog    Fish
Logits:         5.2    2.1    1.8    0.3    -1.5
T=1  (sharp):   0.92   0.04   0.03   0.01   0.00
T=5  (soft):    0.37   0.20   0.19   0.14   0.10
T=10 (softer):  0.28   0.21   0.20   0.17   0.14
                ──▲──                ─────▲─────
            Still dominant            Now visible
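The temperature effect can be reproduced in a few lines. The sketch below is a minimal, standard-library illustration using the example logits; the helper names `softmax_with_temperature` and `entropy` are chosen here for illustration and do not come from any library.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    # Subtract the maximum before exponentiating for numerical stability.
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; a higher value means a flatter distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [5.2, 2.1, 1.8, 0.3, -1.5]  # Cat, Lynx, Tiger, Dog, Fish

for T in (1, 5, 10):
    probs = softmax_with_temperature(logits, T)
    rounded = [round(p, 2) for p in probs]
    print(f"T={T:<2} probs={rounded} entropy={entropy(probs):.2f}")
```

The top class stays the most likely at every temperature, while the entropy rises with T, which is what makes the secondary classes visible to the student.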
$$\mathcal{L}_{\text{distill}} = T^2 \cdot \text{KL}\left(\sigma\left(\frac{z_t}{T}\right) \;\bigg\|\; \sigma\left(\frac{z_s}{T}\right)\right)$$
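The $T^2$ factor compensates for gradient magnitudes that shrink as the temperature rises. The exact gradient carries one factor of $1/T$, and for $T$ large relative to the logits (assumed zero-mean, with $N$ classes) the softmax difference contributes a second:
$$\frac{\partial \mathcal{L}_{\text{KL}}}{\partial z_s} = \frac{1}{T}\left(\sigma\left(\frac{z_s}{T}\right) - \sigma\left(\frac{z_t}{T}\right)\right) \approx \frac{z_s - z_t}{N T^2} \quad \Rightarrow \quad T^2 \cdot \mathcal{L}_{\text{KL}} \text{ restores gradient scale}$$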
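A quick numerical check makes the scaling argument concrete. The sketch below uses the closed-form gradient of the KL term with respect to the student logits, which is the difference of the two softened distributions divided by T. The student logits are made-up values for illustration, and the helper names are hypothetical.

```python
import math


def softmax(logits, temperature):
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_grad_wrt_student(teacher_logits, student_logits, temperature):
    """Analytic gradient of KL(softmax(z_t/T) || softmax(z_s/T)) w.r.t. z_s."""
    q_t = softmax(teacher_logits, temperature)
    q_s = softmax(student_logits, temperature)
    return [(s - t) / temperature for s, t in zip(q_s, q_t)]


def l2_norm(vec):
    return math.sqrt(sum(v * v for v in vec))


teacher_logits = [5.2, 2.1, 1.8, 0.3, -1.5]
student_logits = [3.0, 2.5, 0.5, 1.0, -0.5]  # hypothetical student output

for T in (1, 2, 4, 8, 16, 32):
    raw = l2_norm(kl_grad_wrt_student(teacher_logits, student_logits, T))
    print(f"T={T:<3} |grad|={raw:.6f}  T^2*|grad|={T * T * raw:.4f}")
```

In the printed table the raw gradient norm falls quickly as T grows, while the column multiplied by T² settles toward a constant. That constant is what keeps the soft-target term on a comparable scale to the hard-label term when T is changed.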
$$\mathcal{L}_{\text{total}} = \alpha \cdot \mathcal{L}_{\text{distill}}(z_t, z_s; T) + (1 - \alpha) \cdot \mathcal{L}_{\text{CE}}(z_s, y)$$
Typical hyperparameters: $T \in [4, 20]$, $\alpha \in [0.5, 0.9]$.
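To see how the two terms combine for a single example, here is a small worked calculation using only the standard library. It is a sketch under assumptions: the student logits are invented, the function name `distillation_total_loss` is hypothetical, and T=4, α=0.7 are simply values from the typical ranges above.

```python
import math


def log_softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    log_total = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return [s - log_total for s in scaled]


def distillation_total_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.7):
    """Return (L_total, L_distill, L_CE) for one example."""
    log_q_t = log_softmax(teacher_logits, T)
    log_q_s = log_softmax(student_logits, T)
    # KL(teacher || student) on the softened distributions, scaled by T^2.
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_q_t, log_q_s))
    distill = (T ** 2) * kl
    # Hard-label cross-entropy uses the unsoftened student distribution (T=1).
    ce = -log_softmax(student_logits)[label]
    return alpha * distill + (1 - alpha) * ce, distill, ce


teacher_logits = [5.2, 2.1, 1.8, 0.3, -1.5]
student_logits = [3.0, 2.5, 0.5, 1.0, -0.5]  # hypothetical student output

total, distill, ce = distillation_total_loss(student_logits, teacher_logits, label=0)
print(f"L_distill={distill:.4f}  L_CE={ce:.4f}  L_total={total:.4f}")
```

Passing the teacher logits as the student logits drives the distillation term to zero, leaving only the weighted cross-entropy, which is a handy sanity check when wiring up a real training loop.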
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    """Hinton-style loss: alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE."""

    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        # Soft distillation loss (KL divergence on softened distributions).
        # F.kl_div expects log-probabilities as input and probabilities as target.
        soft_teacher = F.softmax(teacher_logits.detach() / self.T, dim=-1)
        soft_student = F.log_softmax(student_logits / self.T, dim=-1)
        distill_loss = F.kl_div(
            soft_student, soft_teacher,
            reduction="batchmean"
        ) * (self.T ** 2)
        # Hard label loss (standard cross-entropy at T=1)
        ce_loss = self.ce_loss(student_logits, labels)
        return self.alpha * distill_loss + (1 - self.alpha) * ce_loss
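# Training loop
# load_pretrained_teacher, create_small_student, train_loader and num_epochs
# are placeholders for your own model factories, data pipeline and schedule.
teacher = load_pretrained_teacher().eval()  # eval mode: fixed BatchNorm/Dropout behavior
student = create_small_student().train()
criterion = DistillationLoss(temperature=4.0, alpha=0.7)
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for images, labels in train_loader:
        # The teacher is frozen, so skip autograd bookkeeping for its forward pass.
        with torch.no_grad():
            teacher_logits = teacher(images)
        student_logits = student(images)
        loss = criterion(student_logits, teacher_logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()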
Instead of (or in addition to) matching output logits, match intermediate representations:
Feature-Based Distillation
══════════════════════════════════════════════════════════
Teacher Student
┌────────────────┐ ┌────────────────┐
│ Layer 1 │ ────hint────────▶ │ Layer 1 │
│ [B,256,56,56] │ projection │ [B,64,56,56] │
├────────────────┤ ├────────────────┤
│ Layer 2 │ ────hint────────▶ │ Layer 2 │
│ [B,512,28,28] │ projection │ [B,128,28,28] │
├────────────────┤ ├────────────────┤
│ Layer 3 │ ────hint────────▶ │ Layer 3 │
│ [B,1024,14,14]│ projection │ [B,256,14,14] │
├────────────────┤ ├────────────────┤
│ Output │ ────KL div───────▶│ Output │
└────────────────┘ └────────────────┘
Projection: 1×1 conv to match teacher dimensions
Loss: L_total = L_logit + β·Σ L_feature(f_t^l, proj(f_s^l))
class FeatureDistillationLoss(nn.Module):
    def __init__(self, teacher_channels, student_channels):
        super().__init__()
        # 1x1 conv projections lift student features to the teacher's channel count
        self.projections = nn.ModuleList([
            nn.Conv2d(sc, tc, kernel_size=1)
            for sc, tc in zip(student_channels, teacher_channels)
        ])

    def forward(self, student_features, teacher_features):
        # Both arguments are lists of [B, C, H, W] maps with matching H and W per layer.
        loss = 0.0
        for proj, sf, tf in zip(
            self.projections, student_features, teacher_features
        ):
            projected = proj(sf)
            # Flatten to [B, C, H*W] and L2-normalize each channel's spatial map
            sf_norm = F.normalize(projected.flatten(2), dim=-1)
            # detach() keeps gradients from flowing into the frozen teacher
            tf_norm = F.normalize(tf.detach().flatten(2), dim=-1)
            loss += F.mse_loss(sf_norm, tf_norm)
        return loss
For transformer models, transfer the attention patterns themselves:
$$\mathcal{L}_{\text{attn}} = \sum_{l=1}^{L} \left\| \frac{A_t^l}{\|A_t^l\|_F} - \frac{A_s^l}{\|A_s^l\|_F} \right\|_F^2$$
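where $A^l \in \mathbb{R}^{H \times N \times N}$ is the attention map at layer $l$, with $H$ heads and $N$ tokens. When the student has fewer layers or heads than the teacher, each student layer is paired with a chosen teacher layer and the maps are averaged over heads before they are compared.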
def attention_transfer_loss(teacher_attentions, student_attentions):
    """
    teacher_attentions: list of [B, H_t, N, N] attention maps
    student_attentions: list of [B, H_s, N, N] attention maps
    """
    loss = 0.0
    # Pair each student layer with every other teacher layer
    layer_mapping = list(zip(
        range(0, len(teacher_attentions), 2),  # Teacher: layers 0, 2, 4, ...
        range(len(student_attentions))         # Student: layers 0, 1, 2, ...
    ))
    for t_idx, s_idx in layer_mapping:
        # Average across heads so H_t and H_s need not match
        t_attn = teacher_attentions[t_idx].detach().mean(dim=1)  # [B, N, N]
        s_attn = student_attentions[s_idx].mean(dim=1)
        # Frobenius-normalize each [N, N] map
        t_norm = t_attn / (t_attn.norm(dim=(-2, -1), keepdim=True) + 1e-8)
        s_norm = s_attn / (s_attn.norm(dim=(-2, -1), keepdim=True) + 1e-8)
        # Squared Frobenius norm of the difference, averaged over the batch
        loss += (s_norm - t_norm).pow(2).sum(dim=(-2, -1)).mean()
    # Averaged over layer pairs; the formula's plain sum differs only by this constant
    return loss / len(layer_mapping)
DistilBERT is one of the most widely used distilled models in NLP and a practical blueprint for efficient deployment:
DistilBERT Architecture Comparison
══════════════════════════════════════════════════════════
BERT-base DistilBERT
┌──────────────┐ ┌──────────────┐
│ 12 layers │ │ 6 layers │ ← 50% layers
│ 768 hidden │ │ 768 hidden │ ← Same width
│ 12 heads │ │ 12 heads │ ← Same heads
│ 110M params │ │ 66M params │ ← 40% smaller
├──────────────┤ ├──────────────┤
│ GLUE: 79.5 │ │ GLUE: 77.0 │ ← 97% of BERT-base score
│ Latency: 1x │ │ Latency: ~0.6x│ ← 60% faster
└──────────────┘ └──────────────┘
Distillation recipe:
┌────────────────────────────────────────────┐
│ 1. Initialize student from every-other │
│ teacher layer (layers 0,2,4,6,8,10) │
│ 2. Triple loss: │
│ - L_ce: soft distillation (T=8) │
│ - L_mlm: masked LM cross-entropy │
│ - L_cos: cosine embedding loss │
│ 3. Train on same data as BERT │
│ 4. No NSP task (removed for simplicity) │
└────────────────────────────────────────────┘
The triple loss combines: (1) the distillation loss $\mathcal{L}_{ce}$ between softened teacher and student distributions at $T=8$ ($\alpha_{ce} = 0.5$), (2) the masked language modeling loss $\mathcal{L}_{mlm}$ on hard labels ($\alpha_{mlm} = 0.33$), and (3) a cosine embedding loss $\mathcal{L}_{cos}$ that aligns the directions of student and teacher hidden states ($\alpha_{cos} = 0.17$).
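Step 1 of the recipe, initializing the student from every other teacher layer, amounts to filtering and renumbering checkpoint keys. The sketch below illustrates the idea on a toy dictionary. The key layout `encoder.layer.<i>.<name>` and the function `select_student_init` are assumptions made for this example, not the naming of any particular checkpoint or library.

```python
import re


def select_student_init(teacher_state, keep_layers=(0, 2, 4, 6, 8, 10),
                        layer_pattern=r"^(encoder\.layer\.)(\d+)(\..+)$"):
    """Build a student state dict from a subset of teacher layers.

    layer_pattern describes a hypothetical key layout; adapt it to the
    real checkpoint before use.
    """
    # Teacher layer index -> consecutive student layer index.
    new_index = {t: s for s, t in enumerate(keep_layers)}
    student_state = {}
    for key, value in teacher_state.items():
        match = re.match(layer_pattern, key)
        if match is None:
            # Non-layer parameters (for example embeddings) are copied unchanged.
            student_state[key] = value
            continue
        prefix, idx, suffix = match.group(1), int(match.group(2)), match.group(3)
        if idx in new_index:
            student_state[f"{prefix}{new_index[idx]}{suffix}"] = value
    return student_state


# Toy 12-layer "checkpoint" with strings standing in for weight tensors.
teacher_state = {"embeddings.word.weight": "emb"}
for i in range(12):
    teacher_state[f"encoder.layer.{i}.attention.weight"] = f"attn{i}"
    teacher_state[f"encoder.layer.{i}.ffn.weight"] = f"ffn{i}"

student_state = select_student_init(teacher_state)
for key in sorted(student_state):
    print(key, "<-", student_state[key])
```

Because the student keeps the teacher's hidden width, the copied tensors fit without any projection, which is one reason the DistilBERT student halves depth rather than width.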
Distillation applied to iterative sampling — halving the number of denoising steps repeatedly:
Progressive Distillation for Diffusion
══════════════════════════════════════════════════════════
Round 1: Teacher uses 1024 steps → Student learns 512 steps
Round 2: Student becomes teacher → New student learns 256 steps
Round 3: 256-step teacher → 128-step student
Rounds 4-8: 128 → 64 → 32 → 16 → 8 → 4 steps
┌─────────┐ distill ┌─────────┐ distill ┌─────────┐
│ 1024 │ ──────────▶ │ 512 │ ──────────▶ │ 256 │ ──▶ ...
│ steps │ merge │ steps │ merge │ steps │
│ (teacher)│ 2 → 1 │(student)│ 2 → 1 │(student)│
└─────────┘ └─────────┘ └─────────┘
Each round: student predicts the result of 2 teacher steps in 1 step
v-prediction parameterization (the target that v_θ(x_t, t) is trained to predict):
v = α_t · ε − σ_t · x_0
Student loss: ‖v_student(x_t, t) − v_teacher_2steps(x_t, t)‖²
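This technique (Salimans & Ho, 2022) produced samplers that need as few as 4 steps with little loss in sample quality on benchmarks such as CIFAR-10 and ImageNet 64×64. Follow-up work extended the same idea to classifier-free guided latent diffusion models such as Stable Diffusion.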
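Two pieces of the procedure are easy to check by hand: the halving schedule and the algebra of the v-parameterization. The sketch below is a scalar toy example, assuming a variance-preserving schedule where α_t² + σ_t² = 1. The function names and the sample values for x_0 and ε are invented for illustration.

```python
import math


def halving_schedule(start_steps=1024, final_steps=4):
    """Sampler step counts after each distillation round, halving every time."""
    schedule = [start_steps]
    while schedule[-1] > final_steps:
        schedule.append(schedule[-1] // 2)
    return schedule


def v_target(x0, eps, alpha_t, sigma_t):
    """v = alpha_t * eps - sigma_t * x0."""
    return alpha_t * eps - sigma_t * x0


def recover_from_v(x_t, v, alpha_t, sigma_t):
    """Invert the v-parameterization, assuming alpha_t^2 + sigma_t^2 = 1."""
    x0 = alpha_t * x_t - sigma_t * v
    eps = sigma_t * x_t + alpha_t * v
    return x0, eps


schedule = halving_schedule()
print("steps per round:", schedule)
print("distillation rounds:", len(schedule) - 1)

# Scalar toy example on a cosine/sine noise schedule (variance preserving).
phi = 0.3 * math.pi / 2
alpha_t, sigma_t = math.cos(phi), math.sin(phi)
x0, eps = 0.8, -1.1                      # made-up clean sample and noise
x_t = alpha_t * x0 + sigma_t * eps       # noisy sample at time t
v = v_target(x0, eps, alpha_t, sigma_t)
x0_hat, eps_hat = recover_from_v(x_t, v, alpha_t, sigma_t)
print(f"x0 recovered: {x0_hat:.6f}  eps recovered: {eps_hat:.6f}")
```

Since both the clean sample and the noise can be recovered from v, a network that predicts v stays well conditioned even when few sampling steps remain and the input is close to pure noise.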
# Teacher = Student (same architecture, different training)
# Step 1: Train model normally
# Step 2: Use the trained model as teacher for a fresh copy
# Does this improve over standard training? Why?
Day 53: TensorRT Optimization — You've compressed and distilled models. Now it's time to make them fly. TensorRT fuses layers, calibrates precision, plans memory, and profiles kernels to extract every last TFLOP from NVIDIA GPUs. You'll learn the full pipeline from ONNX parsing through INT8 calibration to TensorRT-LLM for transformer inference.