Phase I · Week 2 · Day 12 of 70 · 2.5 hours
"The tension between developer experience and compiler optimization is the central drama of modern ML frameworks."
Every ML framework faces the same fundamental tradeoff: eager execution gives you a debugger, print statements, and Python control flow — but the runtime sees only one operation at a time. Graph mode surrenders that flexibility so a compiler can see the full computation, enabling operator fusion, memory planning, and kernel selection that can deliver 2–5× speedups. Understanding how PyTorch navigates this tension — from TorchScript to torch.compile — is prerequisite knowledge for every topic in this curriculum.
In eager mode, every PyTorch operation immediately executes through the dispatcher — a multi-level dispatch table that routes based on device, dtype, autograd state, and more.
import torch
x = torch.randn(4, 4, device='cuda')
y = torch.randn(4, 4, device='cuda')
# Each line dispatches independently — the runtime has no lookahead
z = x + y # dispatch → CUDA kernel: elementwise_add
w = z * 2.0 # dispatch → CUDA kernel: elementwise_mul
out = w.relu() # dispatch → CUDA kernel: relu
What happens per operation:
     Python call: x + y
             │
             ▼
┌─────────────────────────┐
│   PyTorch Dispatcher    │
│  ┌───────────────────┐  │
│  │ Autograd dispatch │──── record grad_fn if requires_grad
│  │ Device dispatch   │──── route to CUDA/CPU/MPS backend
│  │ Dtype dispatch    │──── select float32/float16 kernel
│  │ Layout dispatch   │──── dense vs sparse vs strided
│  └───────────────────┘  │
└─────────────────────────┘
             │
             ▼
     CUDA kernel launch
     (cudaLaunchKernel)
             │
             ▼
Result tensor returned to Python
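The layered routing in the diagram can be sketched as a toy dispatch table in plain Python. This is only an illustration: every name below (KERNELS, register, dispatch, TAPE, the "fake_gpu" device) is hypothetical, and the real dispatcher is implemented in C++ with many more dispatch keys.

```python
# Toy dispatcher: all names are hypothetical stand-ins, not PyTorch APIs.

KERNELS = {}   # (op name, device) -> kernel function


def register(op, device):
    """Decorator that files a kernel under its (op, device) key."""
    def decorator(fn):
        KERNELS[(op, device)] = fn
        return fn
    return decorator


@register("add", "cpu")
def add_cpu(a, b):
    return [x + y for x, y in zip(a, b)]


@register("add", "fake_gpu")
def add_fake_gpu(a, b):
    # Stand-in for a device kernel launch
    return [x + y for x, y in zip(a, b)]


TAPE = []      # stand-in for autograd's record of executed ops


def dispatch(op, device, requires_grad, *args):
    # Level 1: the autograd layer records the op if gradients are needed
    if requires_grad:
        TAPE.append(op)
    # Level 2: the backend layer selects the kernel registered for this device
    kernel = KERNELS[(op, device)]
    return kernel(*args)


out = dispatch("add", "cpu", True, [1.0, 2.0], [3.0, 4.0])
print(out)    # [4.0, 6.0]
print(TAPE)   # ['add']
```

Each call to dispatch handles exactly one op and knows nothing about the next one, which is the property the rest of this section builds on.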
| Property | Behavior |
|---|---|
| Execution | Op-by-op, immediate |
| Shapes | Fully dynamic — can change every call |
| Control flow | Full Python if/else/for/while |
| Debugging | pdb, print(), breakpoints work |
| Performance | No cross-op optimization, kernel launch overhead |
| Memory | No planning — allocator reacts to each request |
Because the runtime sees only one op at a time, it cannot:
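- Fuse operations: x + y and * 2.0 launch separate kernels, each reading/writing global memory
- Plan memory: the caching allocator (caching_allocator) can only react to each request as it arrives

The overhead per op is typically 5–15 μs of Python + dispatcher cost, which dominates for small tensors.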
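To get a feel for when that per-op overhead dominates, here is a rough arithmetic sketch. It rests on several assumptions that are not from the text: a 10 μs overhead (the midpoint of the quoted range), a hypothetical 1 TB/s memory bandwidth, float32 data, and overhead and kernel time that simply add up (real CUDA launches are asynchronous, so the two can overlap).

```python
# Rough model of per-op overhead vs. kernel time for an element-wise add.
# All constants are illustrative assumptions, not measurements.

OVERHEAD_S = 10e-6        # assumed Python + dispatcher cost per op
BANDWIDTH_BPS = 1e12      # hypothetical 1 TB/s device memory bandwidth
BYTES_PER_ELEMENT = 4     # float32


def overhead_fraction(numel):
    """Share of total time spent on overhead for one element-wise add."""
    # The add reads two tensors and writes one: 3 transfers of numel elements
    kernel_s = 3 * numel * BYTES_PER_ELEMENT / BANDWIDTH_BPS
    return OVERHEAD_S / (OVERHEAD_S + kernel_s)


for shape in [(4, 4), (256, 256), (2048, 2048)]:
    numel = shape[0] * shape[1]
    print(shape, f"{overhead_fraction(numel):.1%}")
```

Under these assumptions the 4×4 case is almost entirely overhead, while the 2048×2048 case is dominated by memory traffic.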
A computation graph captures the full sequence of operations before executing any of them. This gives the compiler a global view:
Eager view (one op at a time):      Graph view (full program):
x + y   → launch kernel             ┌───┐  ┌───┐
* 2.0   → launch kernel             │ x │  │ y │
relu()  → launch kernel             └─┬─┘  └─┬─┘
                                      │      │
3 kernels, 3 memory round-trips       └──┬───┘
                                         │ add
                                         ▼
                                      ┌─────┐
                                      │  z  │
                                      └──┬──┘
                                         │ mul(2.0)
                                         ▼
                                      ┌─────┐
                                      │  w  │
                                      └──┬──┘
                                         │ relu
                                         ▼
                                      ┌─────┐
                                      │ out │
                                      └─────┘
                          Compiler can fuse into 1 kernel!
| Optimization | Description | Typical Benefit |
|---|---|---|
| Operator fusion | Merge element-wise ops into one kernel | 2–4× for chains |
| Memory planning | Pre-allocate and reuse buffers | 10–30% memory reduction |
| Shape specialization | Generate kernels for exact shapes | 5–15% compute savings |
| Layout optimization | Choose channels-last when beneficial | 10–20% for convolutions |
| Dead code elimination | Remove unused computations | Varies |
| Constant folding | Evaluate constant expressions at compile time | Varies |
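For memory-bound element-wise chains, the benefit of fusion can be estimated from memory traffic alone. The sketch below counts bytes moved for the add → mul → relu chain used earlier, assuming float32 tensors, a hypothetical 2048×2048 size, and no cache reuse between kernels; it is a back-of-the-envelope model, not a measurement.

```python
# Traffic model for the add -> mul -> relu chain.
# Assumes every kernel reads its inputs from, and writes its output to,
# global memory, with no cache reuse between kernels.

BYTES_PER_ELEMENT = 4          # float32
numel = 2048 * 2048            # hypothetical tensor size

# (tensor reads, tensor writes) per kernel
eager_kernels = {
    "add":  (2, 1),   # reads x and y, writes z
    "mul":  (1, 1),   # reads z, writes w
    "relu": (1, 1),   # reads w, writes out
}
fused_kernel = (2, 1)  # reads x and y, writes out; z and w never hit memory


def traffic_bytes(reads, writes):
    return (reads + writes) * numel * BYTES_PER_ELEMENT


eager_total = sum(traffic_bytes(r, w) for r, w in eager_kernels.values())
fused_total = traffic_bytes(*fused_kernel)

print(f"eager: {eager_total / 2**20:.0f} MiB")           # eager: 112 MiB
print(f"fused: {fused_total / 2**20:.0f} MiB")           # fused: 48 MiB
print(f"traffic ratio: {eager_total / fused_total:.2f}x")  # traffic ratio: 2.33x
```

The 7-to-3 ratio of tensor transfers lands inside the 2–4× range quoted for fused chains, and longer chains push the ratio higher because only the first reads and the last write remain.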
TorchScript (torch.jit)

TorchScript was PyTorch's first production graph mode, offering two capture mechanisms:

torch.jit.trace: Record Execution

Runs the model once with example inputs and records every operation:
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 256)

    def forward(self, x):
        return self.linear(x).relu()

model = SimpleModel().eval()
example_input = torch.randn(1, 784)

# Trace: run once, record the graph
traced = torch.jit.trace(model, example_input)
print(traced.graph)  # IR representation
Limitations of tracing:
- ❌ Data-dependent control flow is baked in (only one branch recorded)
- ❌ Dynamic shapes become fixed constants
- ❌ In-place dictionary/list mutations may be lost
# DANGER: control flow is NOT captured correctly by trace
class ConditionalModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # This branch is baked at trace time!
            return x * 2
        else:
            return x * -1

traced = torch.jit.trace(ConditionalModel(), torch.ones(4))
# traced ALWAYS returns x * 2, even for negative inputs!
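Why tracing bakes in a branch is easier to see in a toy tracer written in plain Python. The sketch below is not torch.jit.trace; TracedValue, trace, and replay are hypothetical names. It records multiplies while running on a concrete example value, but the comparison is evaluated eagerly and leaves no record.

```python
# Toy model of trace-by-execution (NOT torch.jit.trace itself).
# All names here are hypothetical.

class TracedValue:
    """Wraps a concrete number and logs each multiply onto a shared tape."""
    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __mul__(self, k):
        self.tape.append(("mul", k))       # the op is recorded...
        return TracedValue(self.value * k, self.tape)

    def __gt__(self, other):
        return self.value > other          # ...but the comparison is not


def model(x):
    if x > 0:          # decided once, using the example input
        return x * 2
    return x * -1


def trace(fn, example):
    tape = []
    fn(TracedValue(example, tape))
    return tape


def replay(tape, x):
    for op, k in tape:
        if op == "mul":
            x = x * k
    return x


tape = trace(model, 1.0)       # the example input takes the x > 0 branch
print(tape)                    # [('mul', 2)]
print(replay(tape, 3.0))       # 6.0, same as model(3.0)
print(replay(tape, -3.0))      # -6.0, whereas model(-3.0) returns 3.0
```

The tape contains only the multiply by 2, so replaying it on a negative input silently gives the wrong answer, which mirrors the ConditionalModel behavior above.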
torch.jit.script: Parse Python AST

Parses the Python source code and compiles it to TorchScript IR:
@torch.jit.script
def conditional_fn(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:  # Control flow IS captured
        return x * 2
    else:
        return x * -1

# Works correctly for all inputs
print(conditional_fn(torch.ones(4)))   # tensor([2., 2., 2., 2.])
print(conditional_fn(-torch.ones(4)))  # tensor([1., 1., 1., 1.])
Limitations of scripting:
- ❌ Only a subset of Python is supported (no **kwargs, limited containers)
- ❌ Type annotations required everywhere
- ❌ Many third-party libraries cannot be scripted
- ❌ Error messages are often cryptic
torch.fx (introduced in PyTorch 1.8) takes a different approach: symbolic tracing through Python. Instead of running real tensors, it feeds Proxy objects through the model, and those proxies record the call graph:
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(128, 64)
        self.bn = torch.nn.BatchNorm1d(64)

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        return x.relu()

# FX symbolic trace
traced: torch.fx.GraphModule = torch.fx.symbolic_trace(MyModule())
print(traced.graph)
Output (FX IR):
graph():
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_module[target=linear](args = (%x,), kwargs = {})
    %bn : [num_users=1] = call_module[target=bn](args = (%linear,), kwargs = {})
    %relu : [num_users=1] = call_method[target=relu](args = (%bn,), kwargs = {})
    return relu
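The proxy idea itself fits in a few lines of plain Python. The sketch below is a toy, not the torch.fx implementation: its Proxy and symbolic_trace are hypothetical stand-ins that share only the names, and the graph is a simple list of (output, op, args) tuples.

```python
# Toy symbolic tracer. Proxy and symbolic_trace are hypothetical stand-ins,
# not the torch.fx classes of the same name.

class Proxy:
    """Carries no data; every operation on it appends a node to the graph."""
    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def _record(self, op, *args):
        out = Proxy(f"v{len(self.graph)}", self.graph)
        arg_names = tuple(a.name if isinstance(a, Proxy) else a for a in args)
        self.graph.append((out.name, op, arg_names))
        return out

    def __add__(self, other):
        return self._record("add", self, other)

    def __mul__(self, other):
        return self._record("mul", self, other)

    def relu(self):
        return self._record("relu", self)


def symbolic_trace(fn, *arg_names):
    graph = []
    inputs = [Proxy(n, graph) for n in arg_names]
    out = fn(*inputs)
    return graph, out.name


def f(x, y):
    return ((x + y) * 2.0).relu()


graph, output = symbolic_trace(f, "x", "y")
for node in graph:
    print(node)
# ('v0', 'add', ('x', 'y'))
# ('v1', 'mul', ('v0', 2.0))
# ('v2', 'relu', ('v1',))
print("output:", output)   # output: v2
```

Because no real values flow through f, an expression such as `if x.sum() > 0` has nothing concrete to branch on, which is one reason symbolic tracing handles data-dependent control flow poorly.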
The real power of FX is programmatic graph transformation:
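# Example: replace all ReLU with GELU
def replace_relu_with_gelu(gm: torch.fx.GraphModule):
    for node in list(gm.graph.nodes):
        if node.op == 'call_method' and node.target == 'relu':
            # Tensor has no .gelu() method, so call the functional form instead
            with gm.graph.inserting_after(node):
                gelu = gm.graph.call_function(
                    torch.nn.functional.gelu, args=node.args)
            node.replace_all_uses_with(gelu)  # Rewire users to the new node
            gm.graph.erase_node(node)         # Drop the old ReLU node
    gm.graph.lint()
    gm.recompile()
    return gm

transformed = replace_relu_with_gelu(traced)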
torch.compile and torch.export (PyTorch 2.x)

PyTorch 2.0 introduced TorchDynamo, a CPython bytecode interceptor that captures graphs with minimal user friction.
torch.compile: JIT Compilation

import torch

def my_function(x, y):
    z = x + y
    w = z * 2.0
    return w.relu()

# One line is all it takes
compiled_fn = torch.compile(my_function)

# First call: TorchDynamo captures graph → TorchInductor generates code
x = torch.randn(1024, 1024, device='cuda')
y = torch.randn(1024, 1024, device='cuda')
result = compiled_fn(x, y)  # Triggers compilation
How TorchDynamo works:
  Python bytecode
         │
         ▼
┌──────────────────┐
│   TorchDynamo    │ ← Intercepts at CPython frame level
│    (bytecode     │
│ transformation)  │
└────────┬─────────┘
         │ FX Graph
         ▼
┌──────────────────┐
│   AOTAutograd    │ ← Generates forward + backward graphs
└────────┬─────────┘
         │ ATen IR
         ▼
┌──────────────────┐
│  TorchInductor   │ ← Backend compiler
│    (or other     │
│     backend)     │
└────────┬─────────┘
         │ Triton / C++ code
         ▼
  Optimized kernel
When Dynamo encounters something it can't trace, it inserts a graph break — splitting the graph and falling back to eager:
@torch.compile
def has_graph_break(x):
    y = x + 1                   # ← Graph 1
    print(f"Shape: {y.shape}")  # ← Graph break! (print is side effect)
    z = y * 2                   # ← Graph 2
    return z

# Use TORCH_LOGS to see graph breaks:
# TORCH_LOGS="graph_breaks" python script.py
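The splitting behavior can be pictured with a toy pass over a list of op names. This is an illustration only: split_at_breaks and TRACEABLE are hypothetical, and TorchDynamo operates on CPython bytecode rather than on strings.

```python
# Toy illustration of graph breaks. All names are hypothetical.

TRACEABLE = {"add", "mul", "sin", "cos", "relu"}


def split_at_breaks(ops):
    """Group consecutive traceable ops into graphs; other ops run eagerly."""
    segments, current = [], []
    for op in ops:
        if op in TRACEABLE:
            current.append(op)
        else:
            if current:                       # close the graph captured so far
                segments.append(("graph", current))
                current = []
            segments.append(("eager", [op]))  # fall back to eager for this op
    if current:
        segments.append(("graph", current))
    return segments


program = ["add", "print", "mul"]   # mirrors has_graph_break above
for kind, ops in split_at_breaks(program):
    print(kind, ops)
# graph ['add']
# eager ['print']
# graph ['mul']
```

Each break costs a return to the Python interpreter and prevents fusion across the boundary. In PyTorch itself, passing fullgraph=True to torch.compile turns a graph break into an error instead of a silent split, which is useful when a single graph is required.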
torch.export: Ahead-of-Time Graph Capture

# torch.export produces a clean, serializable graph
import torch
from torch.export import export

class Model(torch.nn.Module):
    def forward(self, x):
        return x.relu() + 1.0

m = Model()

# Specify dynamic dimensions explicitly
batch = torch.export.Dim("batch", min=1, max=256)
exported = export(m, (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})
print(exported.graph_module.graph)
| Aspect | Eager | jit.trace | jit.script | torch.fx | torch.compile |
|---|---|---|---|---|---|
| Capture method | None | Execute | Parse AST | Symbolic trace | Bytecode |
| Control flow | Full | ✗ Baked | ✓ Subset | ✗ Limited | ✓ Guards |
| Dynamic shapes | Full | ✗ Fixed | ✓ Limited | ✗ Fixed | ✓ Guards |
| Python compat | 100% | ~80% | ~60% | ~75% | ~95% |
| Debug ease | ★★★★★ | ★★☆☆☆ | ★★☆☆☆ | ★★★☆☆ | ★★★★☆ |
| Optimization | None | Medium | Medium | High | Highest |
| Serializable | No | Yes | Yes | Yes | via export |
| Status (2025) | Default | Legacy | Legacy | Used in PT2 | Preferred |
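The "Guards" entries in the torch.compile column refer to checks that decide whether a previously compiled graph may be reused for the current inputs. A minimal sketch of that idea, assuming a single guard on input shape: GuardedFunction and its methods are hypothetical, and real guards also cover dtypes, strides, Python globals, and more.

```python
# Toy guard-and-cache loop. All names are hypothetical.

class GuardedFunction:
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}          # guard key -> "compiled" callable
        self.compile_count = 0

    def _compile_for(self, shape):
        self.compile_count += 1
        return self.fn           # stand-in: a real backend would emit a kernel

    def __call__(self, data, shape):
        key = shape              # guard: input shape equals the traced shape
        if key not in self.cache:
            self.cache[key] = self._compile_for(shape)
        return self.cache[key](data)


double = GuardedFunction(lambda xs: [2 * x for x in xs])
double([1, 2, 3], shape=(3,))      # guard miss: compile
double([4, 5, 6], shape=(3,))      # guard hit: reuse cached version
double([1, 2], shape=(2,))         # guard miss: recompile
print(double.compile_count)        # 2
```

A guard failure triggers recompilation rather than a wrong answer, which is how a compiled function can stay correct when shapes or branch conditions change between calls.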
import torch
import time

def benchmark(fn, x, warmup=10, iters=100):
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1000  # ms

def chain_ops(x):
    x = x + 1
    x = x * 2
    x = x.relu()
    x = x - 0.5
    x = x.sigmoid()
    return x

x = torch.randn(2048, 2048, device='cuda')
eager_ms = benchmark(chain_ops, x)
compiled_fn = torch.compile(chain_ops)
compiled_ms = benchmark(compiled_fn, x)
print(f"Eager: {eager_ms:.2f} ms | Compiled: {compiled_ms:.2f} ms | "
      f"Speedup: {eager_ms/compiled_ms:.1f}×")
# Run with graph break logging enabled
TORCH_LOGS="graph_breaks" python -c "
import torch

@torch.compile
def leaky(x):
    y = x.sin()
    print('debug')  # Graph break!
    return y.cos()

leaky(torch.randn(4))
"
Trace the same model with torch.fx.symbolic_trace() and torch.jit.trace(). Compare the IR outputs. Which is more readable? Which preserves module structure?
- TorchScript (jit.trace/jit.script) is legacy; avoid it for new projects
- torch.compile (TorchDynamo + TorchInductor) is the modern answer: ~95% Python compatibility with graph-level optimization
- torch.export is for ahead-of-time capture when you need a serializable, deployable graph

Now that we can capture computation graphs, what do we do with them? Day 13 dives into operator fusion, the single most impactful optimization a compiler can perform. We'll measure memory bandwidth savings, see Triton code generated by torch.compile, and understand why a chain of 5 element-wise ops can run faster than a single matmul.