Phase V · Week 10 · Day 65 of 70 · 2.5 hours
"DDP replicates the model. FSDP shards everything. The difference is up to 8× memory savings on 8 GPUs, enough to turn 'impossible on 8 GPUs' into 'fits comfortably'."
Data parallelism is the most widely used distributed training strategy. It is conceptually simple, scales near-linearly with GPU count, and requires no changes to the model itself. DDP (DistributedDataParallel) is PyTorch's battle-tested implementation. It synchronizes gradients with bucketed AllReduce operations that overlap with the backward pass. But DDP replicates the entire model on every GPU, so it only works when parameters, gradients, and optimizer states all fit in a single GPU's memory. FSDP (Fully Sharded Data Parallel) is PyTorch's implementation of the ideas in Microsoft's ZeRO (Zero Redundancy Optimizer). It eliminates this redundancy by sharding parameters, gradients, and optimizer states across GPUs. Understanding DDP internals and FSDP's ZeRO stages is critical for training any model beyond roughly 10B parameters.
```text
Naive Data Parallelism
═══════════════════════════════════════════════════════════════
Step 1: Forward pass (independent per GPU)
Step 2: Backward pass (independent per GPU)
Step 3: AllReduce ALL gradients              ← BLOCKING, NO OVERLAP
Step 4: Optimizer step (identical per GPU)

Problem: Step 3 waits for ALL gradients before starting.
         No overlap between compute and communication.
```
```text
DDP: Overlapped Gradient Synchronization
═══════════════════════════════════════════════════════════════
Backward pass computes gradients layer by layer (last → first):

Time ──────────────────────────────────────────────────────►
Backward:   [grad_L32] [grad_L31] [grad_L30]  ...  [grad_L1]
                │          │          │
Buckets:    ┌───▼──────────▼───┐ ┌────▼─────────────┐
            │ Bucket 2         │ │ Bucket 1         │ ...
            │ (25 MB)          │ │ (25 MB)          │
            └────────┬─────────┘ └─────────┬────────┘
                     │                     │
AllReduce:    [AllReduce B2]        [AllReduce B1]
                     ▲                     ▲
                     │                     │
             Starts as soon as   Starts when B1 is full
             bucket B2 is full   (overlaps with earlier
                                  backward compute)

Key: Gradients are AllReduced while backward is STILL RUNNING
     for earlier layers. Communication is almost fully hidden.
```
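A toy timing model makes the benefit of overlap concrete. The sketch below rests on several simplifying assumptions:

- Every layer's backward takes the same time.
- Every layer contributes an equally sized gradient.
- Each bucket holds a fixed number of layers.
- AllReduces run one at a time on a single communication stream.

The numbers are illustrative, not DDP measurements.

```python
def naive_step_time(num_layers, t_bwd, t_comm_per_layer):
    """Backward over all layers, then one blocking AllReduce of every gradient."""
    return num_layers * t_bwd + num_layers * t_comm_per_layer


def overlapped_step_time(num_layers, t_bwd, t_comm_per_layer, layers_per_bucket):
    """Launch each bucket's AllReduce as soon as it fills (one comm stream)."""
    comm_free_at = 0.0  # time at which the communication stream is idle again
    for produced in range(1, num_layers + 1):  # gradients arrive last layer first
        ready_at = produced * t_bwd
        leftover = produced % layers_per_bucket
        if leftover == 0 or produced == num_layers:  # bucket full, or final partial one
            bucket_layers = leftover or layers_per_bucket
            start = max(ready_at, comm_free_at)  # wait for data AND a free stream
            comm_free_at = start + bucket_layers * t_comm_per_layer
    return max(num_layers * t_bwd, comm_free_at)


print(naive_step_time(32, 1.0, 0.5))          # 48.0 (32 compute + 16 comm)
print(overlapped_step_time(32, 1.0, 0.5, 4))  # 34.0: only the last bucket is exposed
```

When communication per bucket is cheaper than the compute that fills the next bucket, only the final bucket's AllReduce adds to the step time. When communication is slower, the stream becomes the bottleneck and overlap helps less.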
```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# Assumes MyTransformer, dataset, criterion, and num_epochs are defined elsewhere.


def train_ddp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Create model and wrap with DDP
    model = MyTransformer().cuda(rank)
    model = DDP(model, device_ids=[rank])
    # DDP constructor:
    # - Broadcasts parameters from rank 0 to all ranks
    # - Registers backward hooks on each parameter
    # - Creates gradient buckets (default 25 MB each)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # DistributedSampler ensures non-overlapping data per GPU
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Shuffle differently each epoch
        for batch in loader:
            inputs, labels = batch[0].cuda(rank), batch[1].cuda(rank)
            optimizer.zero_grad()
            outputs = model(inputs)  # Forward (independent)
            loss = criterion(outputs, labels)
            loss.backward()  # Backward + AllReduce (overlapped!)
            # After backward(): all ranks have identical averaged gradients
            optimizer.step()  # Identical update on all ranks

    dist.destroy_process_group()


if __name__ == "__main__":
    # Single-node launch: one process per visible GPU
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    world_size = torch.cuda.device_count()
    mp.spawn(train_ddp, args=(world_size,), nprocs=world_size)
```
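The sampler's role can be sketched in plain Python. This is a simplified model of `DistributedSampler` with its default `drop_last=False`, not its actual source. It works in three steps:

1. Every rank shuffles with the same seed plus the epoch, which is why `set_epoch` matters.
2. The index list is padded by repeating a few indices until it divides evenly across ranks.
3. Each rank takes a strided slice.

`shard_indices` is a hypothetical helper name.

```python
import random


def shard_indices(dataset_len, num_replicas, rank, epoch=0, seed=0, shuffle=True):
    """Return the sample indices `rank` sees in a given epoch (simplified)."""
    indices = list(range(dataset_len))
    if shuffle:
        # Same seed on every rank -> same permutation everywhere; epoch changes it.
        random.Random(seed + epoch).shuffle(indices)
    # Round up to a multiple of num_replicas so every rank gets the same count.
    total_size = -(-dataset_len // num_replicas) * num_replicas
    indices += indices[: total_size - dataset_len]  # pad by repeating a few samples
    # Strided slice: rank r takes positions r, r + P, r + 2P, ...
    return indices[rank:total_size:num_replicas]


shards = [shard_indices(10, num_replicas=4, rank=r) for r in range(4)]
print([len(s) for s in shards])  # [3, 3, 3, 3]: ceil(10 / 4) samples per rank
```

Because every rank computes the same permutation, the strided slices never overlap except for the few padded duplicates.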
| Parameter | Default | Effect |
|---|---|---|
| `bucket_cap_mb` | 25 (MB) | Larger → better bandwidth utilization, worse overlap |
| `gradient_as_bucket_view` | `False` | Avoid gradient copy to bucket (saves memory) |
| `find_unused_parameters` | `False` | Handle models with conditional branches (slower) |
| `static_graph` | `False` | Enable optimizations for models with fixed structure |
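The effect of `bucket_cap_mb` on granularity can be seen with a simplified greedy bucketing sketch. It makes three assumptions:

- Parameters are packed in reverse registration order, roughly the order their gradients become ready in backward.
- Each bucket is filled until the next parameter would exceed the cap.
- Real DDP details, such as dtype/device grouping, are ignored.

The helper name and the layer sizes are hypothetical.

```python
def assign_buckets(param_bytes, bucket_cap_mb=25):
    """Greedily pack parameters (in reverse order) into buckets of <= bucket_cap_mb.

    A single parameter larger than the cap gets a bucket of its own.
    Returns a list of buckets, each a list of parameter indices.
    """
    cap = bucket_cap_mb * 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    for idx in reversed(range(len(param_bytes))):  # last layer's grads are ready first
        size = param_bytes[idx]
        if current and current_bytes + size > cap:  # next param would overflow: close bucket
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


MB = 1024 * 1024
layer_sizes = [8 * MB] * 32  # hypothetical: 32 parameters of 8 MB each
print(len(assign_buckets(layer_sizes, bucket_cap_mb=25)))   # 11 buckets of <= 3 params
print(len(assign_buckets(layer_sizes, bucket_cap_mb=100)))  # 3 buckets of <= 12 params
```

Fewer, larger buckets mean fewer collective launches but a later start for the first AllReduce. That is the bandwidth-versus-overlap trade-off in the table.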
DDP replicates everything. For a model with $N$ parameters trained with mixed-precision Adam on $P$ GPUs:

```text
DDP Memory Usage per GPU (bytes): REDUNDANT
═══════════════════════════════════════════════════════════════
Per GPU              Total across P GPUs    Waste
──────────────────────────────────────────────────────────
2N  (FP16 weights)   2NP                    (P-1)/P redundant
2N  (FP16 grads)     2NP                    (P-1)/P redundant
4N  (FP32 master)    4NP                    (P-1)/P redundant
4N  (Adam m)         4NP                    (P-1)/P redundant
4N  (Adam v)         4NP                    (P-1)/P redundant
──────────────────────────────────────────────────────────
16N per GPU          16NP total
With 8 GPUs, 7/8 = 87.5% of memory is WASTED on duplicates!
```
Microsoft's ZeRO (Rajbhandari et al., 2019) eliminates redundancy in three progressive stages:
```text
ZeRO Stages: Progressive Sharding
═══════════════════════════════════════════════════════════════
DDP (no ZeRO): Each GPU holds ALL of: Params + Grads + Optimizer
               Memory per GPU: 2N + 2N + 12N = 16N bytes

ZeRO Stage 1: Shard OPTIMIZER states only
────────────────────────────────────────────
GPU 0: [Full Params] [Full Grads] [Optim shard 0]
GPU 1: [Full Params] [Full Grads] [Optim shard 1]
GPU 2: [Full Params] [Full Grads] [Optim shard 2]
GPU 3: [Full Params] [Full Grads] [Optim shard 3]
Memory per GPU: 2N + 2N + 12N/P = 4N + 12N/P
       8 GPUs: 4N + 1.5N = 5.5N (2.9× reduction)

ZeRO Stage 2: Shard OPTIMIZER + GRADIENTS
──────────────────────────────────────────
GPU 0: [Full Params] [Grad shard 0] [Optim shard 0]
GPU 1: [Full Params] [Grad shard 1] [Optim shard 1]
...
Memory per GPU: 2N + 2N/P + 12N/P = 2N + 14N/P
       8 GPUs: 2N + 1.75N = 3.75N (4.3× reduction)

ZeRO Stage 3: Shard EVERYTHING (Params + Grads + Optimizer)
────────────────────────────────────────────────────────────
GPU 0: [Param shard 0] [Grad shard 0] [Optim shard 0]
GPU 1: [Param shard 1] [Grad shard 1] [Optim shard 1]
...
Memory per GPU: 2N/P + 2N/P + 12N/P = 16N/P
       8 GPUs: 2N (8× reduction!)

┌───────────────────────────────────────────────────────┐
│ 7B model: DDP = 112 GB/GPU → ZeRO-3 = 14 GB/GPU (P=8) │
│ Model states only (no activations), across 8 GPUs     │
└───────────────────────────────────────────────────────┘
```
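The per-GPU formulas above are easy to check numerically. This sketch uses the same accounting as the diagram: per parameter, 2 bytes of FP16 weights, 2 bytes of FP16 gradients, and 12 bytes of FP32 master weights plus Adam m and v. It ignores activations, temporary buffers, and allocator fragmentation. The function name is hypothetical.

```python
def model_state_bytes_per_gpu(num_params, num_gpus, stage):
    """Per-GPU bytes for weights + grads + optimizer states (no activations).

    stage 0 = DDP; stages 1/2/3 = ZeRO. Uses 2 + 2 + 12 bytes per parameter.
    """
    n, p = num_params, num_gpus
    weights, grads, optim = 2 * n, 2 * n, 12 * n
    if stage >= 1:
        optim /= p    # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= p    # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= p  # ZeRO-3: also shard parameters
    return weights + grads + optim


for stage in range(4):
    gb = model_state_bytes_per_gpu(7e9, 8, stage) / 1e9
    print(f"stage {stage}: {gb:.2f} GB per GPU for a 7B model on 8 GPUs")
```

With `num_params=1`, the function returns the per-parameter multipliers from the diagram: 16, 5.5, 3.75, and 2.0 bytes.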
ZeRO-3 trades memory for communication: parameters must be gathered before each forward and backward pass. Communication volumes below are approximate per-GPU element counts; ring-based collectives carry an extra factor of $(P-1)/P$.

| Stage | Memory per GPU (bytes) | Communication Volume per GPU |
|---|---|---|
| DDP | $16N$ | $2N$ (AllReduce grads) |
| ZeRO-1 | $4N + \frac{12N}{P}$ | $2N$ (ReduceScatter grads + AllGather updated params) |
| ZeRO-2 | $2N + \frac{14N}{P}$ | $2N$ (ReduceScatter grads + AllGather updated params) |
| ZeRO-3 | $\frac{16N}{P}$ | $3N$ (AllGather params fwd + AllGather bwd + ReduceScatter grads) |
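The $2N$ AllReduce figure follows from implementing AllReduce as a ReduceScatter followed by an AllGather. The pure-Python simulation below models a ring of $P$ ranks, each holding a gradient vector split into $P$ chunks, and counts how many elements each rank sends. It illustrates the textbook ring algorithm, not NCCL's actual implementation (which also uses other algorithms, such as trees).

```python
def ring_allreduce(data):
    """Simulate a ring AllReduce (sum) as ReduceScatter followed by AllGather.

    data[r] is rank r's vector; its length must be divisible by the rank count.
    Returns (final vector on every rank, number of elements each rank sent).
    """
    p, n = len(data), len(data[0])
    c = n // p                                  # elements per chunk
    buf = [list(vec) for vec in data]           # each rank's working copy
    sent = [0] * p

    def span(k):
        return range(k * c, (k + 1) * c)

    # Phase 1, ReduceScatter: rank r sends chunk (r - step) to rank r + 1,
    # which adds it in. Afterwards rank r owns the full sum of chunk (r + 1) % p.
    for step in range(p - 1):
        msgs = [(r, (r - step) % p) for r in range(p)]
        payloads = [[buf[r][i] for i in span(k)] for r, k in msgs]  # all send at once
        for (r, k), payload in zip(msgs, payloads):
            dst = (r + 1) % p
            for i, val in zip(span(k), payload):
                buf[dst][i] += val
            sent[r] += c

    # Phase 2, AllGather: pass completed chunks around the ring, overwriting.
    for step in range(p - 1):
        msgs = [(r, (r + 1 - step) % p) for r in range(p)]
        payloads = [[buf[r][i] for i in span(k)] for r, k in msgs]
        for (r, k), payload in zip(msgs, payloads):
            dst = (r + 1) % p
            for i, val in zip(span(k), payload):
                buf[dst][i] = val
            sent[r] += c

    return buf, sent


grads = [[10 * r + i for i in range(8)] for r in range(4)]  # 4 ranks, N = 8
result, sent = ring_allreduce(grads)
print(result[0])  # element-wise sum across ranks, identical on every rank
print(sent)       # [12, 12, 12, 12] = 2 * (P - 1) / P * N per rank
```

Each phase alone moves $(P-1)/P \cdot N$ elements per rank. That is why ZeRO-2's "ReduceScatter + AllGather" costs the same as DDP's AllReduce, and ZeRO-3's extra AllGather pushes it to roughly $3N$.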
With its default `FULL_SHARD` strategy, FSDP is PyTorch's native implementation of ZeRO-3, with additional optimizations.
```text
FSDP Forward + Backward with Parameter Gathering
═══════════════════════════════════════════════════════════════
State: Each GPU holds 1/P of each FSDP unit's parameters

FORWARD PASS:
─────────────
For each FSDP unit (e.g., transformer layer):
  1. AllGather: Collect full parameters from all GPUs     ← communication
  2. Run forward computation with full parameters         ← compute
  3. Discard non-local parameter shards (free memory!)    ← memory saved
  4. Keep only local shard for backward

BACKWARD PASS:
──────────────
For each FSDP unit (reverse order):
  1. AllGather: Collect full parameters again             ← communication
  2. Compute gradients with full parameters               ← compute
  3. ReduceScatter: Each GPU gets its gradient shard      ← communication
  4. Discard full parameters and full gradients           ← memory saved

Net effect: Between units, each GPU holds only 1/P of params + 1/P of grads
            + 1/P of optimizer state. Full parameters exist only for the
            unit(s) currently being computed.
```
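The shard → gather → compute → reduce-scatter cycle can be checked against DDP on a toy problem. This sketch makes the following simplifying assumptions:

- The model is a hypothetical linear model whose weight is a plain Python list.
- Training uses a squared-error loss and plain SGD.
- AllGather and ReduceScatter are simulated with list operations.

The check confirms that every rank's updated shard matches the corresponding slice of the weight DDP would produce. It models the data flow only, not FSDP's memory management.

```python
def grad_linear(w, x, y):
    """Gradient of 0.5 * (w . x - y) ** 2 with respect to w."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [err * xi for xi in x]


def all_gather(shards):
    """Concatenate every rank's shard into the full parameter vector."""
    return [v for shard in shards for v in shard]


def reduce_scatter_mean(full_grads, p):
    """Average full gradients across ranks; rank r keeps only its own slice."""
    n = len(full_grads[0])
    c = n // p
    mean = [sum(g[i] for g in full_grads) / p for i in range(n)]
    return [mean[r * c:(r + 1) * c] for r in range(p)]


def ddp_step(w, batches, lr):
    """Every rank holds the full w; gradients are AllReduced (averaged)."""
    grads = [grad_linear(w, x, y) for x, y in batches]  # one batch per rank
    avg = [sum(g[i] for g in grads) / len(grads) for i in range(len(w))]
    return [wi - lr * gi for wi, gi in zip(w, avg)]


def fsdp_step(shards, batches, lr):
    """Each rank holds 1/P of w: gather, compute, reduce-scatter, update own shard."""
    p = len(shards)
    full_w = all_gather(shards)  # real FSDP frees this copy after each unit's use
    full_grads = [grad_linear(full_w, x, y) for x, y in batches]
    grad_shards = reduce_scatter_mean(full_grads, p)
    return [[wi - lr * gi for wi, gi in zip(s, g)] for s, g in zip(shards, grad_shards)]


w = [0.5, -1.0, 2.0, 0.25]  # hypothetical 4-parameter model
batches = [([1.0, 2.0, 3.0, 4.0], 1.0), ([0.5, -1.0, 2.0, 0.0], -2.0)]  # 2 ranks
shards = [w[0:2], w[2:4]]
print(all_gather(fsdp_step(shards, batches, lr=0.1)))
print(ddp_step(w, batches, lr=0.1))  # same values
```

Sharding changes where each parameter slice lives and which collective delivers its gradient. It does not change the math of the update.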
```python
from functools import partial

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes LargeTransformer, TransformerBlock, and dataloader are defined elsewhere,
# and that the model's forward returns an object with a .loss attribute.


def train_fsdp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = LargeTransformer(num_layers=32, hidden=4096, heads=32)

    # Auto-wrap policy: wrap each TransformerBlock as an FSDP unit
    wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    )

    # Mixed precision: sharded params keep their original dtype (FP32 here);
    # gathered params and compute use BF16; gradients are reduced in FP32
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,  # Gradient reduction in FP32 for stability
        buffer_dtype=torch.bfloat16,
    )

    model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
        device_id=rank,
        use_orig_params=True,  # Required for torch.compile compatibility
    )

    # Create the optimizer AFTER wrapping so it sees the sharded parameters
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch).loss
        loss.backward()  # ReduceScatter gradients automatically
        optimizer.step()  # Each GPU updates only its shard

    dist.destroy_process_group()
```
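A back-of-the-envelope estimate shows why the auto-wrap policy wraps each `TransformerBlock` separately. The sketch counts parameter memory only, at 2 bytes per parameter. It assumes FSDP materializes full parameters only for the unit currently executing plus optionally one prefetched unit, and it ignores gradients, optimizer state, and activations. The helper name and block sizes are hypothetical (32 blocks totaling 7B parameters).

```python
def peak_param_bytes(unit_params, num_gpus, bytes_per_param=2, prefetched_units=1):
    """Rough peak parameter memory per GPU under FSDP FULL_SHARD.

    unit_params: parameter count of each FSDP unit.
    Every unit is stored sharded; on top of that, full copies of the largest
    (1 + prefetched_units) units may be gathered at the same time.
    """
    sharded = sum(unit_params) * bytes_per_param / num_gpus
    largest = sorted(unit_params, reverse=True)[: 1 + prefetched_units]
    gathered = sum(largest) * bytes_per_param
    return sharded + gathered


blocks = [218_750_000] * 32                             # hypothetical: 7B params total
per_block = peak_param_bytes(blocks, num_gpus=8)        # one FSDP unit per block
one_unit = peak_param_bytes([sum(blocks)], num_gpus=8)  # no auto-wrap: one big unit
print(f"per-block wrapping: {per_block / 1e9:.3f} GB, single unit: {one_unit / 1e9:.3f} GB")
```

Without per-block wrapping, the whole model is one unit. It must then be gathered in full, which erases most of ZeRO-3's parameter savings at peak.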
| Strategy | PyTorch Enum | Equivalent | Memory | Communication |
|---|---|---|---|---|
| Full Shard | `FULL_SHARD` | ZeRO-3 | Best ($16N/P$) | Highest ($3N$) |
| Shard Grad Op | `SHARD_GRAD_OP` | ZeRO-2 | Good ($2N + 14N/P$) | Medium ($2N$) |
| No Shard | `NO_SHARD` | DDP | Worst ($16N$) | Lowest ($2N$) |
| Hybrid Shard | `HYBRID_SHARD` | ZeRO-3 intra-node, DDP inter-node | Balanced | Balanced |
```text
Hybrid Sharding (HYBRID_SHARD)
═══════════════════════════════════════════════════════════════
 Node 0 (8 GPUs, NVLink)         Node 1 (8 GPUs, NVLink)
┌─────────────────────────┐     ┌─────────────────────────┐
│ FSDP shard across       │     │ FSDP shard across       │
│ 8 GPUs (fast NVLink)    │     │ 8 GPUs (fast NVLink)    │
│                         │     │                         │
│ GPU0-GPU7 each hold     │     │ GPU0-GPU7 each hold     │
│ 1/8 of params           │     │ 1/8 of params           │
└────────────┬────────────┘     └────────────┬────────────┘
             │                               │
             └──── DDP AllReduce (slow IB) ──┘
              Only gradient shards, not params!

Benefit: AllGather uses fast NVLink (e.g., 900 GB/s on H100)
         Only gradient AllReduce goes over slower InfiniBand (~50 GB/s)
```
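`HYBRID_SHARD` implies two kinds of groups:

- a sharding group within each node;
- a replication group linking the GPUs that hold the same shard on different nodes.

The helper below only computes which global ranks would fall into each group. It assumes ranks are numbered node by node (`rank = node * gpus_per_node + local_rank`) and does not create real process groups. The function name is hypothetical.

```python
def hybrid_shard_groups(num_nodes, gpus_per_node):
    """Return (shard_groups, replicate_groups) as lists of global ranks."""
    # Intra-node: all GPUs of one node split the parameters between them.
    shard_groups = [
        [node * gpus_per_node + local for local in range(gpus_per_node)]
        for node in range(num_nodes)
    ]
    # Inter-node: GPUs with the same local rank hold the same shard, so only
    # they need to AllReduce their gradient shard across nodes.
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(num_nodes=2, gpus_per_node=8)
print(shard[1])      # [8, 9, ..., 15]: node 1's GPUs share one set of shards
print(replicate[3])  # [3, 11]: same shard on both nodes; gradients AllReduced
```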
FSDP supports per-component precision control:
```text
Mixed Precision Data Flow in FSDP
═══════════════════════════════════════════════════════════════
Storage (sharded):  FP32 parameter shard          ← master copy for the optimizer
        │
        ▼
AllGather:          cast shard to BF16, gather    ← half the bytes on the wire
        │
        ▼
Forward:            compute in BF16               ← fast on Tensor Cores
        │
        ▼
Backward:           compute grads in BF16         ← fast
        │
        ▼
ReduceScatter:      grads cast to FP32, reduced   ← numerical stability!
        │
        ▼
Optimizer:          update the FP32 shard         ← no cast back needed
```
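A single number shows why FP32 accumulation matters, both in gradient reduction and in the optimizer's master weights. The sketch below emulates BF16 in pure Python by rounding a float32 bit pattern to its top 16 bits, using round-to-nearest-even; NaN/Inf handling is omitted. It then repeatedly adds an update of 0.001 to a weight of 1.0. Near 1.0, BF16 values are spaced 2^-7 ≈ 0.0078 apart, so each small update is rounded away.

```python
import struct


def to_fp32(x):
    """Round a Python float to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round to bfloat16: keep the top 16 bits of the float32 pattern (RNE)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


w_bf16, w_fp32 = 1.0, 1.0
for _ in range(1000):
    w_bf16 = to_bf16(w_bf16 + 0.001)  # BF16 weight: every update is rounded away
    w_fp32 = to_fp32(w_fp32 + 0.001)  # FP32 master weight: updates accumulate

print(w_bf16, w_fp32)  # BF16 stays at 1.0; FP32 ends close to 2.0
```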
| Feature | PyTorch FSDP | DeepSpeed ZeRO |
|---|---|---|
| Framework | Native PyTorch | Separate library |
| ZeRO Stages | 2, 3 (+ hybrid) | 1, 2, 3, 3+ (offload) |
| CPU Offload | Yes (basic) | Yes (advanced, NVMe) |
| torch.compile | Yes (with `use_orig_params`) | Limited |
| Activation Checkpointing | Via `checkpoint_wrapper` | Built-in |
| Config | Python API | JSON config + Python |
| Community | Growing | Mature, feature-rich |
FSDP fits best when torch.compile integration is needed or a simpler setup is preferred.

Train GPT-2 XL (~1.5B parameters) on 4 GPUs with both DDP and FSDP. Measure peak GPU memory for each. Verify that FSDP's model-state memory is approximately $16N/P$ vs DDP's $16N$; activations add to both.
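```python
import torch
import torch.distributed as dist

rank = dist.get_rank()  # assumes the process group is already initialized

# Memory measurement snippet
torch.cuda.reset_peak_memory_stats()
# ... training step ...
peak_mb = torch.cuda.max_memory_allocated() / 1024**2
print(f"[Rank {rank}] Peak memory: {peak_mb:.0f} MB")
```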
Compare FULL_SHARD, SHARD_GRAD_OP, and NO_SHARD on the same model. Measure memory, throughput (samples/sec), and verify gradient equivalence across strategies.
Write a complete FSDP training loop for a 7B model with mixed precision, gradient clipping, learning rate warmup, and checkpointing. Use HYBRID_SHARD if multi-node, FULL_SHARD if single-node.
`use_orig_params=True` enables compilation, closing the performance gap with custom solutions.

FSDP handles models that are too large for one GPU's memory. But what about models whose individual layers are so large that a single layer exceeds GPU memory? Tomorrow we explore two answers. Tensor parallelism splits individual matrix operations across GPUs. Pipeline parallelism splits the model depth-wise with micro-batch scheduling. Together with data parallelism, they complete the 3D parallelism picture used to train large models such as LLaMA-3.