Phase III · Week 5 · Day 30 of 70 · 2.5 hours
"Relay is a functional IR — every transformation is a function that takes a graph and returns a new graph. No mutation, no side effects, no debugging nightmares."
Every model that enters TVM — whether from PyTorch, TensorFlow, or ONNX — is first converted into Relay IR. Relay is TVM's graph-level intermediate representation, and it's where high-level optimizations like operator fusion and constant folding happen. Unlike PyTorch's FX Graph (a flat DAG of operations), Relay is a functional language with a proper type system, let-bindings, closures, and algebraic data types. Understanding Relay is essential because every optimization pass in TVM's graph layer operates on Relay programs.
Relay is not just a "graph format" — it's a full programming language with:
┌────────────────────────────────────────────────────┐
│ Relay IR │
│ │
│ • First-class functions (closures, higher-order) │
│ • Let-bindings (SSA-like, no mutation) │
│ • Pattern matching on ADTs │
│ • Rich type system (Tensor, Tuple, Function, Ref) │
│ • Call-by-value evaluation │
│ • Differentiable (supports AD natively) │
└────────────────────────────────────────────────────┘
| Feature | FX Graph (PyTorch) | HLO (XLA) | Relay (TVM) | ONNX |
|---|---|---|---|---|
| Paradigm | Imperative DAG | Functional ops | Functional language | Serialization format |
| Type system | Minimal | Shaped arrays | Rich (generics, ADTs) | Tensor types |
| Control flow | Graph breaks | While/Cond ops | If/Match/Recursion | If/Loop |
| Functions | None (flat graph) | HloComputation | First-class closures | Subgraphs |
| Let-bindings | No (node list) | No | Yes (SSA-like) | No |
| Differentiable | Via AOTAutograd | No (AD happens in the frontend, e.g. JAX) | Built-in AD pass | No |
# Imperative (FX-style): nodes can be edited in place, so order matters
node = list(graph.nodes)[3]
node.args = (new_input,)  # mutates the shared graph in place: dangerous!
# Functional (Relay-style): never mutate; build new IR that references the old
new_expr = relay.Call(relay.op.get("nn.relu"), [old_expr])
# old_expr is unchanged, so other passes holding a reference to it are unaffected
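The functional approach means:
- Passes can't corrupt state shared with other passes.
- Subexpressions that are no longer referenced simply drop out of the new IR (unused let-bindings are removed by the DeadCodeElimination pass).
- Transformation correctness is easier to verify.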
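To make the no-mutation property concrete, here is a minimal sketch using a hypothetical toy IR built from frozen Python dataclasses (these are not TVM's classes). A rewrite builds a new node that points at the old one, and an attempted in-place edit raises an error instead of silently corrupting shared state.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Call:
    op: str
    args: tuple  # child expressions


def wrap_in_relu(expr):
    """Return a NEW expression; the input expression is left untouched."""
    return Call("nn.relu", (expr,))


x = Var("x")
old_expr = Call("nn.dense", (x, Var("w")))
new_expr = wrap_in_relu(old_expr)

# The old graph is shared by reference, not copied or modified
print(new_expr.args[0] is old_expr)  # → True

# An in-place edit is rejected outright
try:
    old_expr.args = (Var("other"),)
except FrozenInstanceError:
    print("mutation rejected")
```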
Relay has a static type system that catches shape mismatches at compile time, not at runtime.
import tvm
from tvm import relay
# TensorType: shape + dtype
t1 = relay.TensorType((1, 3, 224, 224), "float32")
print(t1) # Tensor[(1, 3, 224, 224), float32]
# Shape can include symbolic dimensions
n = tvm.tir.Var("n", "int32")
t_dynamic = relay.TensorType((n, 3, 224, 224), "float32")
print(t_dynamic) # Tensor[(n, 3, 224, 224), float32]
FuncType(arg_types, ret_type, type_params, type_constraints)
Example: a conv2d operation
fn(Tensor[(N, C_in, H, W), f32], ← input
Tensor[(C_out, C_in, kH, kW), f32] ← weight
) -> Tensor[(N, C_out, H', W'), f32] ← output
# Tuples group multiple values
tt = relay.TupleType([
relay.TensorType((1, 64, 56, 56), "float32"),
relay.TensorType((64,), "float32"),
relay.TensorType((64,), "float32"),
])
# Used by batch_norm which returns (output, running_mean, running_var)
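Relay's type inference is Hindley-Milner style, extended with per-operator type relations that compute output shapes (for example, conv2d's relation derives the output shape from the input, weight, strides, and padding):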
# Only the inputs carry annotations; every intermediate and output type is inferred
x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
w = relay.var("w", shape=(64, 3, 7, 7), dtype="float32")
y = relay.nn.conv2d(x, w, strides=(2, 2), padding=(3, 3))
# After type inference:
mod = tvm.IRModule.from_expr(y)
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)
# → Tensor[(1, 64, 112, 112), float32]
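The inferred output height (and, identically, width) follows the standard convolution formula with $H = 224$, padding $p = 3$, kernel $k = 7$, and stride $s = 2$: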
$$\text{Output height} = \left\lfloor \frac{H + 2p - k}{s} \right\rfloor + 1 = \left\lfloor \frac{224 + 6 - 7}{2} \right\rfloor + 1 = 112$$
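The shape rule behind that formula can be sketched in plain Python. `conv2d_out_shape` is a hypothetical helper, not TVM's actual type relation: it assumes NCHW data, OIHW weights, symmetric padding, and no dilation or groups. It also shows the kind of channel mismatch the type checker rejects before anything runs.

```python
def conv2d_out_shape(data_shape, weight_shape, strides=(1, 1), padding=(0, 0)):
    """Output shape of a 2D convolution (NCHW data, OIHW weights)."""
    n, c_in, h, w = data_shape
    c_out, w_c_in, kh, kw = weight_shape
    # A type relation rejects mismatched channels at compile time
    if c_in != w_c_in:
        raise TypeError(f"channel mismatch: data has {c_in}, weight expects {w_c_in}")
    sh, sw = strides
    ph, pw = padding
    # floor((H + 2p - k) / s) + 1; // is floor division for non-negative ints
    h_out = (h + 2 * ph - kh) // sh + 1
    w_out = (w + 2 * pw - kw) // sw + 1
    return (n, c_out, h_out, w_out)


print(conv2d_out_shape((1, 3, 224, 224), (64, 3, 7, 7), strides=(2, 2), padding=(3, 3)))
# → (1, 64, 112, 112)
```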
import tvm
from tvm import relay
import numpy as np
# Build a simple network: Linear → ReLU → Linear
# y = relu(x @ W1.T + b1) @ W2.T + b2
# nn.dense expects its weight laid out as (units, in_features)
x = relay.var("x", shape=(1, 784), dtype="float32")
w1 = relay.var("w1", shape=(256, 784), dtype="float32")
b1 = relay.var("b1", shape=(256,), dtype="float32")
w2 = relay.var("w2", shape=(10, 256), dtype="float32")
b2 = relay.var("b2", shape=(10,), dtype="float32")
# Build computation graph
dense1 = relay.nn.dense(x, w1, units=256)   # (1, 256)
biased1 = relay.add(dense1, b1)             # (1, 256), b1 broadcast over the batch
act = relay.nn.relu(biased1)                # (1, 256)
dense2 = relay.nn.dense(act, w2, units=10)  # (1, 10)
out = relay.add(dense2, b2)                 # (1, 10)
# Create a Relay function
func = relay.Function([x, w1, b1, w2, b2], out)
mod = tvm.IRModule.from_expr(func)
# Print the IR
print(mod)
Output:
def @main(%x: Tensor[(1, 784), float32],
%w1: Tensor[(256, 784), float32],
%b1: Tensor[(256), float32],
%w2: Tensor[(10, 256), float32],
%b2: Tensor[(10), float32]) {
%0 = nn.dense(%x, %w1, units=256);
%1 = add(%0, %b1);
%2 = nn.relu(%1);
%3 = nn.dense(%2, %w2, units=10);
add(%3, %b2)
}
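Because `nn.dense` multiplies its input by the transpose of its weight, a dense weight is laid out as `(units, in_features)`. The following plain-Python sketch reproduces that contract on a tiny example; `dense` here is a hypothetical helper, not a TVM function, and needs no TVM install.

```python
def dense(x, weight):
    """Plain-Python sketch of nn.dense semantics: out = x @ weight.T.

    x has shape (batch, in_features); weight has shape (units, in_features).
    """
    in_features = len(x[0])
    assert all(len(row) == in_features for row in weight), "in_features mismatch"
    # One output column per weight row (i.e. per unit)
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight] for row in x]


x = [[1.0, 2.0, 3.0]]                   # (1, 3)
w = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]  # (2, 3): 2 units, 3 input features
print(dense(x, w))  # → [[1.0, 5.0]]
```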
# Let-bindings give names to intermediate values (SSA form)
x = relay.var("x", shape=(1, 784), dtype="float32")
w = relay.var("w", shape=(256, 784), dtype="float32")  # nn.dense weight: (units, in_features)
# Explicit let-binding
v0 = relay.var("v0")
v1 = relay.var("v1")
body = relay.Let(v0, relay.nn.dense(x, w),
relay.Let(v1, relay.nn.relu(v0),
v1))
# Equivalent to:
# let v0 = nn.dense(x, w) in
# let v1 = nn.relu(v0) in
# v1
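Without let-bindings (pure expression tree), a value used twice must be written twice:
add(nn.relu(nn.dense(x, w1)), nn.dense(x, w1))
└── nn.dense(x, w1) appears twice, and a naive evaluator computes it twice
With let-bindings (A-normal form, SSA-like):
let %0 = nn.dense(x, w1) in
let %1 = nn.relu(%0) in
add(%1, %0)
└── %0 is computed once, referenced by name
In memory, Relay's default graph form can already share one node object between two uses; let-bindings make that sharing, and the evaluation order, explicit in the program text. The passes relay.transform.ToANormalForm and relay.transform.ToGraphNormalForm convert between the two forms.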
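The conversion from a tree with repeated subexpressions to let-bindings can be sketched in a few lines. This is a toy A-normal-form converter over a hypothetical tuple encoding, not TVM's `ToANormalForm` implementation; it binds every call to a fresh name and reuses the name when an identical subtree shows up again.

```python
def to_let_bindings(expr):
    """Flatten a nested expression tree into let-bindings (A-normal form).

    Hypothetical toy encoding: an expression is a tuple ("op", arg, ...) and a
    leaf is a variable-name string. Structurally identical subtrees receive a
    single binding, which is what removes the duplication.
    """
    bindings = []  # (name, op, argument names) in evaluation order
    seen = {}      # subtree -> name already bound to it

    def visit(e):
        if isinstance(e, str):  # leaf variable: referenced by name directly
            return e
        if e in seen:           # identical subtree already bound: reuse it
            return seen[e]
        op, *args = e
        arg_names = [visit(a) for a in args]  # bind arguments first
        name = f"%{len(bindings)}"
        bindings.append((name, op, arg_names))
        seen[e] = name
        return name

    result = visit(expr)
    lines = [f"let {n} = {op}({', '.join(a)}) in" for n, op, a in bindings]
    return "\n".join(lines + [result])


shared = ("nn.dense", "x", "w1")
tree = ("add", ("nn.relu", shared), shared)  # nn.dense written twice in the tree
print(to_let_bindings(tree))
# let %0 = nn.dense(x, w1) in
# let %1 = nn.relu(%0) in
# let %2 = add(%1, %0) in
# %2
```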
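TVM ships importers for the major frameworks under relay.frontend (for example from_onnx, from_pytorch, from_tensorflow, from_tflite, and from_keras). Each one returns the model as Relay IR in an IRModule, plus a dict of weights.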
import onnx
from tvm import relay
onnx_model = onnx.load("model.onnx")
mod, params = relay.frontend.from_onnx(onnx_model, shape={"input": (1, 3, 224, 224)})
# "input" must match the input name stored in the ONNX graph
# mod: tvm.IRModule containing the Relay function
# params: dict of parameter name → tvm.nd.NDArray
import torch
import torchvision
from tvm import relay
model = torchvision.models.resnet18(weights="IMAGENET1K_V1").eval()
scripted = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
# Convert TorchScript → Relay
input_infos = [("input", (1, 3, 224, 224))]
mod, params = relay.frontend.from_pytorch(scripted, input_infos)
from tvm import relay
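# From a frozen TensorFlow graph: from_tensorflow takes a tf.GraphDef, so a
# SavedModel must be frozen into one first (graph_def is assumed loaded here)
mod, params = relay.frontend.from_tensorflow(
graph_def, # frozen tf.GraphDef protobuf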
layout="NHWC", # TF default layout
shape={"input": (1, 224, 224, 3)}
)
Framework Model
│
▼
┌─────────────────────────────────────────┐
│ Frontend Importer │
│ │
│ 1. Parse framework-specific format │
│ 2. Map ops: torch.nn.Conv2d │
│ → relay.nn.conv2d │
│ 3. Handle layout: NHWC ↔ NCHW │
│ 4. Extract weights as tvm.nd.NDArrays │
│ 5. Build Relay function + params dict │
└─────────────────────────────────────────┘
│
▼
(IRModule, params: Dict[str, tvm.nd.NDArray])
| PyTorch | ONNX | Relay |
|---|---|---|
| torch.nn.Conv2d | Conv | relay.nn.conv2d |
| torch.nn.BatchNorm2d | BatchNormalization | relay.nn.batch_norm |
| torch.nn.ReLU | Relu | relay.nn.relu |
| torch.matmul | MatMul | relay.nn.dense or relay.nn.batch_matmul |
| torch.nn.Softmax | Softmax | relay.nn.softmax |
| torch.cat | Concat | relay.concatenate |
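Step 2 of the importer diagram (mapping framework ops to Relay ops) boils down to a lookup per node. The sketch below uses a hypothetical dictionary mirroring the table above; TVM's real ONNX frontend instead dispatches each ONNX op type to a converter function that builds the Relay expression. It also shows the common failure mode: an unsupported operator stops the import.

```python
# Hypothetical lookup table mirroring the mapping table above
ONNX_TO_RELAY = {
    "Conv": "nn.conv2d",
    "BatchNormalization": "nn.batch_norm",
    "Relu": "nn.relu",
    "MatMul": "nn.dense",  # simplification: batched inputs need nn.batch_matmul
    "Softmax": "nn.softmax",
    "Concat": "concatenate",
}


def map_ops(onnx_op_types):
    """Translate a sequence of ONNX op types, failing loudly on unsupported ops."""
    missing = sorted({t for t in onnx_op_types if t not in ONNX_TO_RELAY})
    if missing:
        raise NotImplementedError(f"no Relay mapping for: {missing}")
    return [ONNX_TO_RELAY[t] for t in onnx_op_types]


print(map_ops(["Conv", "BatchNormalization", "Relu"]))
# → ['nn.conv2d', 'nn.batch_norm', 'nn.relu']
```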
# Pretty-print the entire module
print(mod.astext(show_meta_data=False))
# Print just the main function signature
# (ret_type is only populated after type inference)
mod = relay.transform.InferType()(mod)
func = mod["main"]
print(f"Inputs: {[p.name_hint for p in func.params]}")
print(f"Return: {func.ret_type}")
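from tvm.contrib import relay_viz
# The default plotter prints a text tree to the terminal
relay_viz.RelayVisualizer(mod).render()
# Graphviz output needs the DOT plotter and parser (requires the graphviz package)
viz = relay_viz.RelayVisualizer(
    mod, plotter=relay_viz.DotPlotter(), parser=relay_viz.DotVizParser()
)
viz.render("relay_graph")  # writes relay_graph.pdf
# render() takes no format argument; other output formats such as SVG are
# configured on DotPlotter (see its docstring for your TVM version)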
def analyze_relay(mod, params):
    """Count primitive ops and parameters in a Relay module."""
    op_counts = {}
    def visit(node):
        # Only count calls to primitive operators (not calls to Relay functions)
        if isinstance(node, relay.Call) and isinstance(node.op, tvm.ir.Op):
            name = node.op.name
            op_counts[name] = op_counts.get(name, 0) + 1
    relay.analysis.post_order_visit(mod["main"], visit)
    # Total number of elements across all weight tensors
    param_count = sum(int(np.prod(p.shape)) for p in params.values())
    print("Operation counts:")
    for op, count in sorted(op_counts.items(), key=lambda x: -x[1]):
        print(f" {op}: {count}")
    print(f"\nTotal parameters: {param_count:,}")
analyze_relay(mod, params)  # e.g. the ResNet-18 module imported via from_pytorch above
# For ResNet-18 expect 20 x nn.conv2d, 20 x nn.batch_norm and 17 x nn.relu, plus
# pooling, the residual adds, and the final dense layer; exact op names vary by
# importer and TVM version. The total also includes BatchNorm running
# statistics, so it exceeds the 11,689,512 trainable parameters PyTorch reports.
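The parameter total in `analyze_relay` is simply the sum of each tensor's element count. As a quick arithmetic check that needs no TVM, here it is for the two-layer MLP from earlier, with weights in `nn.dense`'s `(units, in_features)` layout (`mlp_param_shapes` is a hypothetical name for this sketch):

```python
from math import prod

# Parameter shapes of the 784 -> 256 -> 10 MLP
mlp_param_shapes = {
    "w1": (256, 784),
    "b1": (256,),
    "w2": (10, 256),
    "b2": (10,),
}

# 256*784 + 256 + 10*256 + 10 = 200,704 + 256 + 2,560 + 10
total = sum(prod(shape) for shape in mlp_param_shapes.values())
print(f"{total:,}")  # → 203,530
```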
relay.Expr (base class)
├── relay.Constant ← literal tensor value
├── relay.Var ← named variable
├── relay.Call ← function application (op invocation)
├── relay.Let ← let-binding
├── relay.If ← conditional
├── relay.Tuple ← tuple construction
├── relay.TupleGetItem ← tuple element access
├── relay.Function ← function definition
├── relay.GlobalVar ← reference to a global function
├── relay.RefCreate ← create a mutable reference (rarely used)
├── relay.RefRead ← read a mutable reference
└── relay.RefWrite ← write a mutable reference
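The node kinds are easiest to internalize by evaluating them. Below is a minimal sketch of a call-by-value interpreter over hypothetical mini versions of `Constant`, `Var`, `Call`, `Let`, `Tuple`, and `TupleGetItem` that operate on Python scalars; these are not TVM classes, and only two ops are supported.

```python
import operator
from dataclasses import dataclass


@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Call:
    op: str
    args: tuple

@dataclass(frozen=True)
class Let:
    var: Var
    value: object
    body: object

@dataclass(frozen=True)
class TupleExpr:
    items: tuple

@dataclass(frozen=True)
class TupleGetItem:
    tuple_value: object
    index: int


OPS = {"add": operator.add, "multiply": operator.mul}


def evaluate(expr, env=None):
    """Call-by-value: arguments and let-bound values are computed before use."""
    env = {} if env is None else env
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Call):
        return OPS[expr.op](*(evaluate(a, env) for a in expr.args))
    if isinstance(expr, Let):
        bound = evaluate(expr.value, env)  # evaluated exactly once...
        return evaluate(expr.body, {**env, expr.var.name: bound})  # ...then reused by name
    if isinstance(expr, TupleExpr):
        return tuple(evaluate(item, env) for item in expr.items)
    if isinstance(expr, TupleGetItem):
        return evaluate(expr.tuple_value, env)[expr.index]
    raise TypeError(f"unsupported node: {type(expr).__name__}")


# let v = 1 + 2 in (v * v, v).0
v = Var("v")
prog = Let(v, Call("add", (Const(1.0), Const(2.0))),
           TupleGetItem(TupleExpr((Call("multiply", (v, v)), v)), 0))
print(evaluate(prog))  # → 9.0
```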
# Read-only visitor: ExprVisitor memoizes visited nodes, so a shared
# subexpression is counted once
class OpCounter(relay.ExprVisitor):
    def __init__(self):
        super().__init__()
        self.ops = {}
    def visit_call(self, call):
        if isinstance(call.op, tvm.ir.Op):
            name = call.op.name
            self.ops[name] = self.ops.get(name, 0) + 1
        # MUST call super to continue traversal into the arguments
        super().visit_call(call)
counter = OpCounter()
counter.visit(mod["main"])
print(counter.ops)
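Memoization is what keeps counts honest on a DAG: a node referenced from two places is still visited once. The following is a minimal sketch of that behavior on a hypothetical toy graph (plain Python `Node` objects, not TVM), using a post-order walk that visits children before the node itself.

```python
class Node:
    """Hypothetical DAG node: an op name plus input nodes (not a TVM class)."""
    def __init__(self, op, *inputs):
        self.op = op
        self.inputs = inputs


class OpCounter:
    """Post-order visitor that visits each shared node exactly once."""
    def __init__(self):
        self.visited = set()  # ids of nodes already seen (memoization)
        self.ops = {}
        self.order = []

    def visit(self, node):
        if id(node) in self.visited:
            return
        self.visited.add(id(node))
        for child in node.inputs:  # children first: post-order
            self.visit(child)
        self.ops[node.op] = self.ops.get(node.op, 0) + 1
        self.order.append(node.op)


x = Node("var")
shared = Node("nn.dense", x)
out = Node("add", Node("nn.relu", shared), shared)  # shared is referenced twice

counter = OpCounter()
counter.visit(out)
print(counter.ops)    # → {'var': 1, 'nn.dense': 1, 'nn.relu': 1, 'add': 1}
print(counter.order)  # → ['var', 'nn.dense', 'nn.relu', 'add']
```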
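# Mutating visitor (builds new IR; the original is left untouched)
class ReLUReplacer(relay.ExprMutator):
    """Replace all nn.relu with nn.leaky_relu."""
    def visit_call(self, call):
        call = super().visit_call(call)  # recurse first: rebuilds the call with visited args
        if isinstance(call.op, tvm.ir.Op) and call.op.name == "nn.relu":
            return relay.nn.leaky_relu(call.args[0], alpha=0.01)
        return call
replacer = ReLUReplacer()
new_func = replacer.visit(mod["main"])
# Wrap the rewritten function in a fresh module and re-run type inference
new_mod = relay.transform.InferType()(tvm.IRModule.from_expr(new_func))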
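The mutator pattern (rebuild bottom-up, swap matching nodes, keep shared nodes shared, never touch the input) can be sketched on a hypothetical toy graph in plain Python; `Node` and this `ReLUReplacer` are illustrative stand-ins, not the TVM classes above.

```python
class Node:
    """Hypothetical DAG node: op name, input nodes, and attributes (not TVM)."""
    def __init__(self, op, *inputs, **attrs):
        self.op = op
        self.inputs = inputs
        self.attrs = attrs


class ReLUReplacer:
    """Rebuild the graph bottom-up, swapping nn.relu for nn.leaky_relu."""
    def __init__(self):
        self.memo = {}  # id(old node) -> new node, so shared nodes stay shared

    def visit(self, node):
        if id(node) in self.memo:
            return self.memo[id(node)]
        new_inputs = tuple(self.visit(i) for i in node.inputs)  # recurse first
        if node.op == "nn.relu":
            new = Node("nn.leaky_relu", *new_inputs, alpha=0.01)
        elif all(a is b for a, b in zip(new_inputs, node.inputs)):
            new = node  # nothing changed below: reuse the original node
        else:
            new = Node(node.op, *new_inputs, **node.attrs)
        self.memo[id(node)] = new
        return new


x = Node("var")
relu = Node("nn.relu", x)
out = Node("add", relu, relu)

new_out = ReLUReplacer().visit(out)
print(new_out.inputs[0].op, new_out.inputs[0].attrs)  # → nn.leaky_relu {'alpha': 0.01}
print(new_out.inputs[0] is new_out.inputs[1])         # → True (sharing preserved)
print(out.inputs[0].op)                               # → nn.relu (original untouched)
```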
import tvm
from tvm import relay
import numpy as np
def build_simple_cnn():
    """Build: Conv2d(3→16) → ReLU → MaxPool → Flatten → Dense(10)."""
    x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
    # Conv2d: 3 input channels → 16 output channels, 3×3 kernel
    w_conv = relay.var("w_conv", shape=(16, 3, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(x, w_conv, kernel_size=(3, 3), padding=(1, 1))
    act = relay.nn.relu(conv)
    pool = relay.nn.max_pool2d(act, pool_size=(2, 2), strides=(2, 2))
    # Flatten: (1, 16, 16, 16) → (1, 4096)
    flat = relay.nn.batch_flatten(pool)
    # Dense: 4096 → 10
    w_dense = relay.var("w_dense", shape=(10, 4096), dtype="float32")
    b_dense = relay.var("b_dense", shape=(10,), dtype="float32")
    dense = relay.nn.dense(flat, w_dense)
    out = relay.add(dense, b_dense)
    func = relay.Function([x, w_conv, w_dense, b_dense], out)
    mod = tvm.IRModule.from_expr(func)
    mod = relay.transform.InferType()(mod)
    return mod
mod = build_simple_cnn()
print(mod)
# Verify the inferred output shape is (1, 10)
out_shape = [int(d) for d in mod["main"].ret_type.shape]
assert out_shape == [1, 10], out_shape
# Import the same model from ONNX and PyTorch, compare Relay IR
import onnx
import torch
import torchvision
from tvm import relay
model = torchvision.models.mobilenet_v2(weights="IMAGENET1K_V1").eval()
# Path A: PyTorch → Relay
scripted = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
mod_pt, _ = relay.frontend.from_pytorch(scripted, [("input", (1, 3, 224, 224))])
# Path B: PyTorch → ONNX → Relay
torch.onnx.export(model, torch.randn(1, 3, 224, 224), "/tmp/mbv2.onnx",
input_names=["input"])
mod_onnx, _ = relay.frontend.from_onnx(onnx.load("/tmp/mbv2.onnx"),
{"input": (1, 3, 224, 224)})
# Count ops in each — are they identical?
# (They may differ slightly due to different op decompositions)
class ShapeAnalyzer(relay.ExprVisitor):
    """Track tensor shapes at every Call node, in post-order."""
    def __init__(self):
        super().__init__()
        self.shapes = []
    def visit_call(self, call):
        super().visit_call(call)  # visit arguments first
        # checked_type is only populated after InferType; tuple-valued calls
        # (e.g. nn.batch_norm) are skipped
        if isinstance(call.checked_type, relay.TensorType):
            shape = [int(x) for x in call.checked_type.shape]
            op_name = call.op.name if isinstance(call.op, tvm.ir.Op) else "<function call>"
            self.shapes.append((op_name, shape))
# Run on a type-inferred module
mod_typed = relay.transform.InferType()(mod)
analyzer = ShapeAnalyzer()
analyzer.visit(mod_typed["main"])
for op, shape in analyzer.shapes:
    print(f" {op:30s} → {shape}")
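As a cross-check for what a shape analysis should report on the small CNN from `build_simple_cnn`, here is the same shape propagation done by hand in plain Python. The helper `simple_cnn_shapes` is hypothetical and hard-codes that network's hyperparameters; it lists ops in the same argument-first order a post-order visitor would.

```python
def simple_cnn_shapes(n=1):
    """Propagate shapes through Conv(3->16, 3x3, pad 1) -> ReLU -> MaxPool(2) -> Flatten -> Dense(10)."""
    shapes = []
    c, h, w = 3, 32, 32
    # conv2d 3x3, stride 1, padding 1: spatial size preserved
    c, h, w = 16, (h + 2 * 1 - 3) // 1 + 1, (w + 2 * 1 - 3) // 1 + 1
    shapes.append(("nn.conv2d", [n, c, h, w]))
    shapes.append(("nn.relu", [n, c, h, w]))  # elementwise: shape unchanged
    h, w = (h - 2) // 2 + 1, (w - 2) // 2 + 1  # 2x2 pool, stride 2, no padding
    shapes.append(("nn.max_pool2d", [n, c, h, w]))
    shapes.append(("nn.batch_flatten", [n, c * h * w]))
    shapes.append(("nn.dense", [n, 10]))
    shapes.append(("add", [n, 10]))  # bias broadcast over the batch
    return shapes


for op, shape in simple_cnn_shapes():
    print(f"  {op:30s} → {shape}")
```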
ExprVisitor and ExprMutator are the building blocks for writing custom analysis and transformation passes over Relay IR.
Next up, Day 31: Relay Optimization Passes. Now that you can read and construct Relay programs, we'll learn the pass infrastructure that transforms them. You'll see how FuseOps merges operators for memory efficiency, how FoldConstant pre-computes static expressions, and how to write your own custom Relay passes.