Phase III · Week 5 · Day 33 of 70 · 2.5 hours
"A schedule is a series of promises to the compiler: 'I guarantee this transformation preserves correctness — now make it fast.'"
| ← Previous | Next → | 📅 Week | 🔷 Phase | 📚 Curriculum |
|---|---|---|---|---|
| Day 32: Tensor Expression (TE) | Day 34: TVM Runtime & Deployment | Week 5: TVM Foundations | Phase III: Apache TVM Deep Dive | ML Compilers |
Yesterday you wrote TE declarations, which say what to compute. Today you learn TIR (Tensor IR), which says how. TIR is TVM's low-level, loop-explicit intermediate representation. When you call tvm.lower(), your TE compute plus its schedule is transformed into TIR: explicit loop nests, buffer allocations, and memory scope annotations. Understanding TIR is essential for three reasons: (1) it is what you debug when performance is wrong, (2) advanced schedule primitives such as compute_at, cache_read, and tensorize are best understood by reading their TIR output, and (3) TVMScript (tvm.script) lets you write TIR directly, bypassing TE entirely for maximum control.
```text
TE (te.compute, te.placeholder)   ← declarative, what
        │
        │ tvm.lower()
        ▼
TIR (PrimFunc)                    ← imperative, how
        │
        │ tvm.build()
        ▼
Target Code (LLVM IR / CUDA C / ...)
```
TIR is an imperative IR — it has explicit loops, conditionals, buffer loads/stores, and memory allocation. Compare this with Relay (graph-level, functional) and TE (declarative, mathematical).
| Node | Purpose | Example |
|---|---|---|
| BufferStore | Write to a buffer | C[i, j] = value |
| BufferLoad | Read from a buffer | A[i, k] |
| For | Loop construct | for i in range(0, M) |
| Block | Computation scope with iteration-variable bindings | Wraps a computation with its bindings and read/write regions |
| Allocate | Buffer allocation | Stack/shared/global memory |
| IfThenElse | Conditional | Guard for boundary checks |
| SeqStmt | Sequence of statements | Multiple ops in order |
A PrimFunc is TIR's fundamental unit — one kernel that can be compiled to a target. It declares buffers, their shapes, dtypes, and memory scopes, then contains the loop body.
```python
import tvm
from tvm import te, tir
import numpy as np

M, N, K = 128, 128, 128
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

# Create schedule and lower to TIR
s = te.create_schedule(C.op)
mod = tvm.lower(s, [A, B, C], simple_mode=True)
print(mod)
```
Output (simplified):

```python
@T.prim_func
def main(A: T.Buffer((128, 128), "float32"),
         B: T.Buffer((128, 128), "float32"),
         C: T.Buffer((128, 128), "float32")):
    for i in range(128):
        for j in range(128):
            C[i, j] = T.float32(0)
            for k in range(128):
                C[i, j] = C[i, j] + A[i, k] * B[k, j]
```
Structurally, that PrimFunc is a small tree of TIR nodes:

```text
PrimFunc
├── params: [A, B, C]                    # handle variables, one per buffer
├── buffer_map: {var → Buffer}           # shape, dtype, strides, scope
├── body: For(i, 0, 128)                 # the computation
│   └── For(j, 0, 128)
│       └── SeqStmt
│           ├── BufferStore(C, 0.0, [i, j])       # init
│           └── For(k, 0, 128)
│               └── BufferStore(C, ..., [i, j])   # accumulate
└── attrs: {"global_symbol": "main", "tir.noalias": True, ...}
```
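To make the node vocabulary concrete, here is a toy model of that tree in plain Python. The classes are hypothetical stand-ins named after the TIR nodes (they are not TVM's tvm.tir classes), and the tiny interpreter executes the same For → SeqStmt → BufferStore nest so you can check that it really computes a matrix product.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

# Toy stand-ins named after TIR nodes. Hypothetical: these are NOT tvm.tir classes.

@dataclass
class Const:
    value: float

@dataclass
class BufferLoad:
    buffer: str
    indices: Tuple[str, ...]  # loop-variable names, e.g. ("i", "k")

@dataclass
class Add:
    a: "Expr"
    b: "Expr"

@dataclass
class Mul:
    a: "Expr"
    b: "Expr"

Expr = Union[Const, BufferLoad, Add, Mul]

@dataclass
class BufferStore:
    buffer: str
    value: Expr
    indices: Tuple[str, ...]

@dataclass
class For:
    var: str
    extent: int
    body: "Stmt"

@dataclass
class SeqStmt:
    seq: List["Stmt"]

Stmt = Union[BufferStore, For, SeqStmt]
Memory = Dict[str, Dict[Tuple[int, ...], float]]

def eval_expr(e: Expr, env: Dict[str, int], mem: Memory) -> float:
    if isinstance(e, Const):
        return e.value
    if isinstance(e, BufferLoad):
        return mem[e.buffer][tuple(env[v] for v in e.indices)]
    if isinstance(e, Add):
        return eval_expr(e.a, env, mem) + eval_expr(e.b, env, mem)
    return eval_expr(e.a, env, mem) * eval_expr(e.b, env, mem)  # Mul

def run(s: Stmt, env: Dict[str, int], mem: Memory) -> None:
    if isinstance(s, BufferStore):
        idx = tuple(env[v] for v in s.indices)
        mem[s.buffer][idx] = eval_expr(s.value, env, mem)
    elif isinstance(s, For):
        for x in range(s.extent):          # bind the loop var for the body only
            run(s.body, {**env, s.var: x}, mem)
    else:                                   # SeqStmt: run children in order
        for child in s.seq:
            run(child, env, mem)

def matmul_tree(n: int) -> Stmt:
    """For(i) -> For(j) -> SeqStmt[init store, For(k) -> accumulate store]."""
    acc = Add(BufferLoad("C", ("i", "j")),
              Mul(BufferLoad("A", ("i", "k")), BufferLoad("B", ("k", "j"))))
    return For("i", n, For("j", n, SeqStmt([
        BufferStore("C", Const(0.0), ("i", "j")),          # init
        For("k", n, BufferStore("C", acc, ("i", "j"))),    # accumulate
    ])))
```

Usage: fill mem with "A" and "B" dicts keyed by (row, col) plus an empty "C", then call run(matmul_tree(n), {}, mem). Dropping the init BufferStore from the SeqStmt reproduces the "reduction init missing" bug from the debugging table: C would be read before it is ever written.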
compute_at: Fusing Producer into Consumer

compute_at is one of the most powerful schedule primitives. It moves a producer's computation inside the consumer's loop nest, so each produced value is consumed while it is still in cache or registers.
```python
# Without compute_at: compute B fully, then compute C (M, N as defined above)
A = te.placeholder((M, N), name="A")
B = te.compute((M, N), lambda i, j: A[i, j] + 1, name="B")
C = te.compute((M, N), lambda i, j: B[i, j] * 2, name="C")

s = te.create_schedule(C.op)
print("=== Before compute_at ===")
print(tvm.lower(s, [A, C], simple_mode=True))
# for i: for j: B[i,j] = A[i,j] + 1
# for i: for j: C[i,j] = B[i,j] * 2   ← two separate passes

s = te.create_schedule(C.op)
s[B].compute_at(s[C], C.op.axis[1])  # fuse B into C's j-loop
print("=== After compute_at(C, j) ===")
print(tvm.lower(s, [A, C], simple_mode=True))
# for i: for j: B[0] = A[i,j] + 1; C[i,j] = B[0] * 2   ← fused!
```
Memory impact: Without fusion, B requires $M \times N$ storage. After compute_at, B shrinks to a single scalar — computed and consumed immediately.
```text
Before compute_at:            After compute_at(C, j):
┌──────────────────┐          ┌──────────────────┐
│ for i in M:      │          │ for i in M:      │
│   for j in N:    │          │   for j in N:    │
│     B[i,j] = ... │          │     b = A[i,j]+1 │
│                  │          │     C[i,j] = b*2 │
│ for i in M:      │          └──────────────────┘
│   for j in N:    │            1 pass, scalar B
│     C[i,j] = ... │
└──────────────────┘
  2 passes, full B
```
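The same transformation written out by hand in plain Python (no TVM) makes the storage difference countable. The function names and the "elements held for B" counter are illustrative; they model the two loop nests in the diagram, not code that TVM generates.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def unfused(a: Matrix) -> Tuple[Matrix, int]:
    """Two passes (no compute_at): materialize all of B, then read it back."""
    m, n = len(a), len(a[0])
    b = [[a[i][j] + 1 for j in range(n)] for i in range(m)]   # full M x N temporary
    c = [[b[i][j] * 2 for j in range(n)] for i in range(m)]
    return c, m * n                                            # elements held for B

def fused(a: Matrix) -> Tuple[Matrix, int]:
    """One pass: B is computed inside C's j-loop, so it shrinks to one scalar."""
    m, n = len(a), len(a[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            b = a[i][j] + 1          # producer, computed right where it is consumed
            c[i][j] = b * 2          # consumer
    return c, 1                      # elements held for B
```

Both functions return identical results; only the lifetime of the intermediate changes, which is exactly the promise compute_at makes to the compiler.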
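cache_read / cache_write: Explicit Staging Buffers

These primitives insert an explicit copy stage in a faster memory scope (shared memory on GPU, an L1 scratchpad on accelerators). cache_read makes the listed consumers read from a staged copy of a tensor; cache_write makes the computation write into a cache first and then copies the result out.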
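```python
from tvm import te

# Re-declare the matmul (the name C was reused for the element-wise example above)
M, N, K = 128, 128, 128
A = te.placeholder((M, K), name="A")
B = te.placeholder((K, N), name="B")
k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

s = te.create_schedule(C.op)
AA = s.cache_read(A, "shared", [C])   # stage A into shared memory for reader C
CC = s.cache_write(C, "local")        # accumulate C in registers, then copy out

# Typical GPU pattern: one 32×32 tile of C per thread block, one element per thread
bx, tx = s[C].split(C.op.axis[0], factor=32)
by, ty = s[C].split(C.op.axis[1], factor=32)
s[C].reorder(bx, by, tx, ty)
s[C].bind(bx, te.thread_axis("blockIdx.y"))
s[C].bind(by, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.y"))
s[C].bind(ty, te.thread_axis("threadIdx.x"))  # consecutive threads -> consecutive j (coalesced)
s[CC].compute_at(s[C], ty)            # each thread's accumulator lives in its registers
s[AA].compute_at(s[C], by)            # one shared copy of A's rows per thread block
# A complete kernel would also bind AA's copy loops to threads (cooperative fetch)
```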
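A plain-Python model shows what the staged schedule buys: each "thread block" copies its rows of A into a scratch buffer once (the role cache_read(A, "shared", ...) plays), and each output accumulates in a local scalar before a single write to C (the role of cache_write(C, "local")). The function name, the read counter, and the tiling are illustrative assumptions; real shared-memory traffic also depends on how the copy is split across threads.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def blocked_matmul(a: Matrix, b: Matrix, tile: int, stage_a: bool) -> Tuple[Matrix, int]:
    """Tiled matmul returning (C, number of reads of A from 'global' memory)."""
    m, k_dim, n = len(a), len(b), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    global_reads_a = 0
    for bx in range(0, m, tile):                 # one "thread block" per tile of C
        for by in range(0, n, tile):
            rows = range(bx, min(bx + tile, m))
            if stage_a:
                # cache_read analogue: copy this block's rows of A into a scratch buffer once
                shared = {i: list(a[i]) for i in rows}
                global_reads_a += len(rows) * k_dim
            for i in rows:
                for j in range(by, min(by + tile, n)):
                    acc = 0.0                    # cache_write analogue: local accumulator
                    for kk in range(k_dim):
                        if stage_a:
                            a_val = shared[i][kk]
                        else:
                            a_val = a[i][kk]
                            global_reads_a += 1
                        acc += a_val * b[kk][j]
                    c[i][j] = acc                # single write of the result to global C
    return c, global_reads_a
```

With staging, every A element a block needs is fetched from global memory once per block instead of once per output column, so global reads of A drop by a factor equal to the tile width in this sketch.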
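storage_align: Avoiding Bank Conflicts

GPU shared memory is divided into banks (typically 32 banks, each 4 bytes wide). When threads in a warp access different addresses in the same bank, the accesses are serialized (a bank conflict). storage_align pads the allocation's row stride to break this pattern: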
```python
# Pad shared memory by 1 element per row to avoid bank conflicts
s[AA].storage_align(AA.op.axis[0], factor=32, offset=1)
```
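$$\text{stride} = \min\{\, s \ge N \;:\; s \bmod \text{factor} = \text{offset} \,\}$$

where N is the unpadded row length. When N is a multiple of factor this is simply N + offset (with factor=32, offset=1: 32 → 33, 128 → 129).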
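Why this works: with a row stride of 32 floats (or any multiple of 32), element [t, c] lands in bank c for every row t, so 32 threads reading one column all hit the same bank and are fully serialized. With offset=1 the stride becomes 33, element [t, c] lands in bank (t + c) mod 32, and the 32 threads hit 32 different banks.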
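The arithmetic is easy to check in a few lines of Python. This sketch assumes 32 four-byte banks, one float32 per bank word, and thread t of a warp reading element [t, c] of a row-major buffer; aligned_stride mirrors the padding rule stated above, and both helper names are illustrative, not TVM APIs.

```python
from collections import Counter

NUM_BANKS = 32  # assumed: typical NVIDIA shared memory with 4-byte banks

def aligned_stride(row_len: int, factor: int, offset: int) -> int:
    """Smallest stride >= row_len with stride % factor == offset (assumes 0 <= offset < factor)."""
    return row_len + (factor + offset - row_len % factor) % factor

def conflict_degree(stride: int, column: int = 0, warp: int = 32) -> int:
    """Worst-case number of threads sharing one bank when thread t reads [t, column]."""
    banks = Counter((t * stride + column) % NUM_BANKS for t in range(warp))
    return max(banks.values())  # 1 means conflict-free; 32 means fully serialized
```

For example, conflict_degree(32) is 32 while conflict_degree(aligned_stride(32, 32, 1)) is 1. Because every thread reads a different row, all addresses are distinct, so the count per bank is the serialization factor.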
tensorize: Mapping to Hardware Intrinsics

tensorize replaces a loop nest with a hardware-specific intrinsic (e.g., NVIDIA Tensor Core wmma, ARM dot product, Intel VNNI).
```python
from tvm import te

# Define a 16×16×16 matrix multiply intrinsic (Tensor Core style):
# float16 inputs, float32 accumulation
def intrin_wmma_gemm():
    a = te.placeholder((16, 16), name="a", dtype="float16")
    b = te.placeholder((16, 16), name="b", dtype="float16")
    k = te.reduce_axis((0, 16), name="k")
    c = te.compute((16, 16),
                   lambda i, j: te.sum(a[i, k].astype("float32") *
                                       b[k, j].astype("float32"), axis=k))
    # ... buffer/intrinsic definition omitted for brevity
    return te.decl_tensor_intrin(c.op, ...)

# Use it in a schedule. C must be a matmul whose compute matches the intrinsic
# (float16 inputs cast to float32 inside te.sum).
s = te.create_schedule(C.op)
io, ii = s[C].split(C.op.axis[0], factor=16)
jo, ji = s[C].split(C.op.axis[1], factor=16)
ko, ki = s[C].split(C.op.reduce_axis[0], factor=16)
s[C].reorder(io, jo, ko, ii, ji, ki)    # the 16×16×16 nest must be innermost
s[C].tensorize(ii, intrin_wmma_gemm())  # replace the (ii, ji, ki) block with wmma
```
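In plain Python, the effect of tensorize is to replace the three innermost loops with a single call. Here mma_16x16x16 is a hypothetical software stand-in for the hardware instruction (it is not a TVM or CUDA API); the point is that the outer (io, jo, ko) loops remain while the inner 16×16×16 nest becomes one opaque operation, which is why ii, ji, ki must be the innermost loops before tensorizing.

```python
from typing import List

Matrix = List[List[float]]
TILE = 16  # matches the 16×16×16 intrinsic shape above

def mma_16x16x16(c: Matrix, a: Matrix, b: Matrix, i0: int, j0: int, k0: int) -> None:
    """Hypothetical stand-in for one hardware MMA:
    C[i0:i0+16, j0:j0+16] += A[i0:i0+16, k0:k0+16] @ B[k0:k0+16, j0:j0+16]."""
    for i in range(i0, i0 + TILE):
        for j in range(j0, j0 + TILE):
            acc = c[i][j]
            for k in range(k0, k0 + TILE):
                acc += a[i][k] * b[k][j]
            c[i][j] = acc

def tensorized_matmul(a: Matrix, b: Matrix) -> Matrix:
    """Loops (io, jo, ko) stay explicit; each (ii, ji, ki) nest is one intrinsic call."""
    n = len(a)
    assert n % TILE == 0, "this sketch assumes extents divisible by 16"
    c = [[0.0] * n for _ in range(n)]
    for io in range(0, n, TILE):
        for jo in range(0, n, TILE):
            for ko in range(0, n, TILE):
                mma_16x16x16(c, a, b, io, jo, ko)
    return c
```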
tvm.lower()

tvm.lower() is your primary debugging tool. It shows the TIR after schedule transformations have been applied but before target code generation.
```python
M, N = 256, 256
A = te.placeholder((M, N), name="A")
B = te.compute((M, N), lambda i, j: A[i, j] * 2.0, name="B")
s = te.create_schedule(B.op)

# Step 1: tile
io, ii = s[B].split(B.op.axis[0], factor=32)
jo, ji = s[B].split(B.op.axis[1], factor=8)
s[B].reorder(io, jo, ii, ji)

# Step 2: vectorize innermost, parallelize outermost
s[B].vectorize(ji)
s[B].parallel(io)

print(tvm.lower(s, [A, B], simple_mode=True))
```
Expected TIR output (simplified; exact printing varies across TVM versions, e.g. vectorized accesses may appear as T.Ramp indices on flattened buffers):

```python
@T.prim_func
def main(A: T.Buffer((256, 256), "float32"),
         B: T.Buffer((256, 256), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    for i_outer in T.parallel(8):          # parallel over 256/32 = 8
        for j_outer in range(32):          # 256/8 = 32
            for i_inner in range(32):
                # vectorized (8-wide): the j_inner loop became one vector operation
                B[i_outer*32 + i_inner, j_outer*8 : j_outer*8 + 8] = (
                    A[i_outer*32 + i_inner, j_outer*8 : j_outer*8 + 8] * T.float32(2)
                )
```
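The index arithmetic behind those loop extents can be checked without TVM. This plain-Python sketch (an illustration with a made-up function name, not TVM code) replays the split and reorder above and confirms that i = i_outer*32 + i_inner and j = j_outer*8 + j_inner visit every element exactly once. It assumes the factors divide the extents; when they do not, TVM has to emit boundary guards.

```python
from typing import List, Tuple

def split_reorder_order(m: int, n: int, fi: int, fj: int) -> List[Tuple[int, int]]:
    """(i, j) visit order of split(i, fi), split(j, fj), reorder(io, jo, ii, ji)."""
    assert m % fi == 0 and n % fj == 0, "sketch assumes factors divide the extents"
    order = []
    for io in range(m // fi):            # 256 // 32 = 8 (the parallel loop)
        for jo in range(n // fj):        # 256 // 8 = 32
            for ii in range(fi):
                for ji in range(fj):     # the 8-wide vectorized lane group
                    order.append((io * fi + ii, jo * fj + ji))
    return order
```

Each run of 8 consecutive entries shares one row and covers 8 adjacent columns, which is what makes the ji loop a legal target for vectorize.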
| Symptom | TIR Check | Fix |
|---|---|---|
| Slow kernel | Inner loop not vectorized | Add s[X].vectorize(inner) |
| GPU build fails or occupancy is low | Missing bind() calls | Bind loops to blockIdx / threadIdx |
| Memory blowup | Large temporary buffers | Use compute_at to shrink lifetimes |
| Bank conflicts | Shared-memory row stride is a multiple of 32 words (e.g., a power of 2 such as 64 or 128) | storage_align with offset=1 |
| Correctness bug | Reduction init missing | Check for an init BufferStore before the accumulate loop |
TVM's schedule system descends from Halide. Understanding the lineage clarifies design choices:
| Concept | Halide | TVM TE |
|---|---|---|
| Compute declaration | Func f; f(x, y) = ... | te.compute((M, N), lambda i, j: ...) |
| Split | f.split(x, xo, xi, 32) | xo, xi = s[f].split(x, factor=32) |
| Reorder | f.reorder(xi, yi, xo, yo) (innermost first) | s[f].reorder(yo, xo, yi, xi) (outermost first) |
| Vectorize | f.vectorize(xi) | s[f].vectorize(xi) |
| Parallelize | f.parallel(xo) | s[f].parallel(xo) |
| Fuse into consumer | g.compute_at(f, yo) | s[g].compute_at(s[f], yo) |
| GPU binding | f.gpu_blocks(xo).gpu_threads(xi) | s[f].bind(xo, te.thread_axis("blockIdx.x")) |
| Hardware intrinsics | Not built-in | s[f].tensorize(xi, intrin) |
| Auto-tuning | Halide auto-scheduler | AutoTVM / MetaSchedule |
Key difference: Halide is primarily for image processing pipelines (stencils, multi-stage filters). TVM extends the model for neural network operators — reductions, tensor contractions, and hardware-specific intrinsics like Tensor Cores.
tvm.script

For advanced cases, bypass TE and write TIR directly in TVMScript:
```python
import tvm
from tvm.script import tir as T

@T.prim_func
def matmul_tiled(a: T.handle, b: T.handle, c: T.handle):
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for io, jo in T.grid(4, 4):          # outer tiles: 128/32 = 4
        for ii, ji in T.grid(32, 32):    # inner tile: 32×32
            with T.block("C"):
                vi = T.axis.spatial(128, io * 32 + ii)
                vj = T.axis.spatial(128, jo * 32 + ji)
                # the += below reads C[vi, vj] as well as writing it
                T.reads(C[vi, vj], A[vi, 0:128], B[0:128, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = T.float32(0)
                for k in range(128):
                    C[vi, vj] += A[vi, k] * B[k, vj]

# This is valid TIR and can be built directly
mod = tvm.IRModule({"main": matmul_tiled})
print(mod.script())
```
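T.reads and T.writes declare which buffer regions each block instance touches, and later schedule transformations and analyses rely on those declarations. The plain-Python trace below is an illustrative model of one instance of block "C" (the helper names are made up, not a TVM API): it records every element the block body accesses and checks each access against rectangular regions. Note that the += accumulation reads C[vi, vj] as well as writing it.

```python
from typing import Set, Tuple

Access = Tuple[str, int, int]  # (buffer, row, column)

def trace_block_c(vi: int, vj: int, k_extent: int = 128) -> Tuple[Set[Access], Set[Access]]:
    """Replay one instance of block "C" in matmul_tiled and record every element it touches."""
    reads: Set[Access] = set()
    writes: Set[Access] = {("C", vi, vj)}           # C[vi, vj] = 0
    for k in range(k_extent):                       # C[vi, vj] += A[vi, k] * B[k, vj]
        reads |= {("C", vi, vj), ("A", vi, k), ("B", k, vj)}
        writes.add(("C", vi, vj))
    return reads, writes

def covered(acc: Access, buf: str, rows: range, cols: range) -> bool:
    """True if acc lies inside the rectangular region buf[rows, cols]."""
    name, r, c = acc
    return name == buf and r in rows and c in cols
```

Usage: for a given (vi, vj), every traced read should fall inside A[vi, 0:128], B[0:128, vj], or C[vi, vj], and the only write should be C[vi, vj].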
When to use tvm.script over TE:
- Irregular access patterns (sparse, ragged)
- Custom reduction strategies
- Explicit memory management that TE can't express
- Contributing new operators to TVM itself
Write a 2D convolution in TE, apply three different schedules, and compare the tvm.lower() output:
```python
import tvm
from tvm import te

H, W, KH, KW = 64, 64, 3, 3
data = te.placeholder((H, W), name="data")
kernel = te.placeholder((KH, KW), name="kernel")
rh = te.reduce_axis((0, KH), name="rh")
rw = te.reduce_axis((0, KW), name="rw")
conv = te.compute((H - KH + 1, W - KW + 1),
                  lambda i, j: te.sum(data[i + rh, j + rw] * kernel[rh, rw],
                                      axis=[rh, rw]),
                  name="conv")

# Schedule A: naive (no transforms)
# Schedule B: tile 8×8, vectorize inner
# Schedule C: tile 8×8, vectorize inner, unroll reduction
# Compare TIR output for each: how do the loop nests differ?
```
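Before comparing schedules it helps to have a reference result to check built kernels against. The sketch below is plain Python (not TVM) with illustrative function names: conv2d_valid computes the same sum as the TE declaration, and conv2d_tiled visits outputs in 8×8 blocks. Because the 62×62 output is not a multiple of 8, the tiled version clips edge tiles with min(), which is the kind of boundary guard to look for in Schedule B's TIR.

```python
from typing import List

Matrix = List[List[float]]

def conv2d_valid(data: Matrix, kernel: Matrix) -> Matrix:
    """Reference for the TE compute above: no padding, stride 1, kernel not flipped."""
    h, w = len(data), len(data[0])
    kh, kw = len(kernel), len(kernel[0])
    out = [[0.0] * (w - kw + 1) for _ in range(h - kh + 1)]
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            acc = 0.0
            for rh in range(kh):
                for rw in range(kw):
                    acc += data[i + rh][j + rw] * kernel[rh][rw]
            out[i][j] = acc
    return out

def conv2d_tiled(data: Matrix, kernel: Matrix, tile: int = 8) -> Matrix:
    """Same sum, but outputs are visited in tile x tile blocks (Schedule B's loop order)."""
    h, w = len(data), len(data[0])
    kh, kw = len(kernel), len(kernel[0])
    oh, ow = h - kh + 1, w - kw + 1
    out = [[0.0] * ow for _ in range(oh)]
    for io in range(0, oh, tile):
        for jo in range(0, ow, tile):
            # min() clips edge tiles, since 62 is not a multiple of 8
            for i in range(io, min(io + tile, oh)):
                for j in range(jo, min(jo + tile, ow)):
                    acc = 0.0
                    for rh in range(kh):
                        for rw in range(kw):
                            acc += data[i + rh][j + rw] * kernel[rh][rw]
                    out[i][j] = acc
    return out
```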
Create a three-stage pipeline: ReLU → BatchNorm → Scale. Fuse all stages using compute_at so the entire pipeline runs in a single pass with no intermediate buffers.
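The target loop structure can be prototyped in plain Python before writing the TE version. This sketch assumes inference-mode BatchNorm with scalar (single-channel) statistics, and the function and parameter names are illustrative: the three-pass version materializes two intermediate lists, while the fused version keeps each intermediate in a local variable, which is the structure compute_at should produce for the TE stages.

```python
import math
from typing import List

def pipeline_three_pass(x: List[float], mean: float, var: float, gamma: float,
                        beta: float, scale: float, eps: float = 1e-5) -> List[float]:
    """Unfused: each stage materializes a full intermediate list."""
    relu = [max(v, 0.0) for v in x]                                        # buffer 1
    bn = [(v - mean) / math.sqrt(var + eps) * gamma + beta for v in relu]   # buffer 2
    return [v * scale for v in bn]

def pipeline_fused(x: List[float], mean: float, var: float, gamma: float,
                   beta: float, scale: float, eps: float = 1e-5) -> List[float]:
    """Fused: one loop, every intermediate is a local scalar."""
    out = []
    for v in x:
        r = max(v, 0.0)                                         # ReLU
        b = (r - mean) / math.sqrt(var + eps) * gamma + beta    # BatchNorm (inference)
        out.append(b * scale)                                   # Scale
    return out
```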
Write a GPU-targeted matmul schedule that:
1. Tiles to 32×32 thread blocks
2. Uses cache_read to stage A tiles into shared memory
3. Uses cache_write to accumulate in local registers
4. Applies storage_align to avoid bank conflicts
5. Is lowered with tvm.lower() so you can inspect the memory scopes in the TIR
- TIR is the output of tvm.lower() and the input to target code generation
- compute_at fuses producer-consumer stages, eliminating intermediate buffers and improving locality
- cache_read/cache_write explicitly manage data movement between memory hierarchies (global → shared → local)
- storage_align prevents GPU shared memory bank conflicts by padding row strides
- tensorize maps loop nests to hardware intrinsics (Tensor Cores, VNNI, NEON dot)
- tvm.lower(simple_mode=True) is your primary debugging tool: always inspect TIR before building
- Write TIR directly with tvm.script when TE can't express what you need

Next, Day 34: TVM Runtime & Deployment. You've lowered schedules into optimized TIR loops. Now you'll learn how those loops become executable code: the TVM runtime architecture (Module, PackedFunc, NDArray, DLPack), compilation targets (LLVM, CUDA, ARM), cross-compilation for edge devices, and saving/loading compiled modules for production deployment.