Phase III · Week 6 · Day 37 of 70 · 2.5 hours
"The best framework isn't the one with the most features — it's the one where every feature composes with every other."
Yesterday you saw two generations of TVM auto-tuning: AutoTVM (template-based) and Ansor (template-free). Both work, but they're separate systems with different APIs, cost models, and search strategies. MetaSchedule unifies them into a single framework built on TIR schedule traces — the same primitives you learned in Week 5. This means you can: (1) write custom schedule rules in the same language you already know, (2) replay and mutate schedules deterministically, and (3) share a single database of tuning results across projects. MetaSchedule is the future of TVM tuning, and as of TVM 0.12+ it's the recommended approach.
┌─────────────────────────────────────────────────────────────┐
│ Before MetaSchedule │
│ │
│ AutoTVM Ansor │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Template API │ │ Sketch API │ │
│ │ ConfigSpace │ │ SearchPolicy │ │
│ │ XGBTuner │ │ CostModel │ │
│ │ JSON logs │ │ JSON logs │ │
│ └──────────────┘ └──────────────┘ │
│ │ │ │
│ └─── Different APIs ─────────────┘ │
│ Different search strategies │
│ Incompatible log formats │
│ Neither uses TIR schedule primitives directly │
│ │
├─────────────────────────────────────────────────────────────┤
│ After MetaSchedule │
│ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ MetaSchedule │ │
│ │ ┌────────────┐ ┌───────────┐ ┌────────────────┐ │ │
│ │ │ Schedule │ │ Mutators │ │ Database │ │ │
│ │ │ Rules │ │ │ │ (JSON files) │ │ │
│ │ └────────────┘ └───────────┘ └────────────────┘ │ │
│ │ Built on TIR Schedule Traces │ │
│ └──────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
| Principle | Description |
|---|---|
| Trace-based | Schedules are recorded as a sequence of TIR primitive calls |
| Composable | Rules and mutators compose freely — mix and match |
| Replay-able | Any schedule can be replayed deterministically from its trace |
| Database-backed | Results stored in a database (JSON files by default) for cross-session reuse |
| Unified | One API for both template-style and template-free tuning |
A trace records every schedule primitive call as a structured log. This is the key abstraction — it makes schedules inspectable, replayable, and mutable.
import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle):
    A = T.match_buffer(a, (1024, 1024), "float32")
    B = T.match_buffer(b, (1024, 1024), "float32")
    C = T.match_buffer(c, (1024, 1024), "float32")
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = 0.0
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

# Create a schedule from the PrimFunc
sch = tir.Schedule(matmul)

# Apply primitives; each call is recorded in the trace
block = sch.get_block("C")
i, j, k = sch.get_loops(block)

# Tile the i-axis
io, ii = sch.split(i, factors=[32, 32])
# Tile the j-axis
jo, ji = sch.split(j, factors=[32, 32])
# Tile the k-axis
ko, ki = sch.split(k, factors=[64, 16])
# Reorder
sch.reorder(io, jo, ko, ii, ki, ji)

# Print the trace: the recorded sequence of decisions
print(sch.trace)
Instruction 0: GetBlock(name="C")
Instruction 1: GetLoops(block=b0)
Instruction 2: Split(loop=l0, factors=[32, 32]) # i → io, ii
Instruction 3: Split(loop=l1, factors=[32, 32]) # j → jo, ji
Instruction 4: Split(loop=l2, factors=[64, 16]) # k → ko, ki
Instruction 5: Reorder(order=[l3, l5, l7, l4, l8, l6])
The trace is deterministic — replaying these instructions on the same PrimFunc always produces the same schedule. This is what makes mutation-based search possible.
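To make the record-and-replay idea concrete without needing a TVM install, here is a minimal pure-Python sketch. ToySchedule and replay are hypothetical names for illustration only, not TVM classes. The sketch records each primitive call as a tuple and then re-applies the recorded list to a fresh loop nest.

```python
# Toy model of a schedule trace. This is NOT the TVM API; it only
# illustrates recording primitive calls and replaying them later.

class ToySchedule:
    """Holds a loop nest as an ordered list of (name, extent) pairs."""

    def __init__(self, loops):
        self.loops = list(loops)
        self.trace = []  # recorded (primitive, args) tuples

    def split(self, name, factors):
        idx = next(n for n, (loop, _) in enumerate(self.loops) if loop == name)
        extent = self.loops[idx][1]
        outer, inner = factors
        assert outer * inner == extent, "factors must multiply to the extent"
        # Replace the loop with its outer and inner parts
        self.loops[idx:idx + 1] = [(name + "o", outer), (name + "i", inner)]
        self.trace.append(("split", (name, tuple(factors))))

    def reorder(self, *names):
        by_name = dict(self.loops)
        assert sorted(names) == sorted(by_name), "reorder must list every loop"
        self.loops = [(n, by_name[n]) for n in names]
        self.trace.append(("reorder", names))


def replay(trace, loops):
    """Re-apply a recorded trace to a fresh schedule over the same loops."""
    sch = ToySchedule(loops)
    for primitive, args in trace:
        getattr(sch, primitive)(*args)
    return sch


matmul_loops = [("i", 1024), ("j", 1024), ("k", 1024)]
sch = ToySchedule(matmul_loops)
sch.split("i", [32, 32])
sch.split("j", [32, 32])
sch.split("k", [64, 16])
sch.reorder("io", "jo", "ko", "ii", "ki", "ji")

replayed = replay(sch.trace, matmul_loops)
print(replayed.loops == sch.loops)  # True: same trace, same schedule
```

Because the trace is plain data, a search algorithm can copy it, change one recorded decision, and replay the result to get a neighboring schedule.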
MetaSchedule uses Schedule Rules to define the search space. A rule is a function that takes a block and produces candidate schedule decisions.
from tvm.meta_schedule import schedule_rule

# Common rules shipped with TVM:
rules = [
    schedule_rule.AutoInline(  # inline trivial computations
        into_producer=False,
        into_consumer=True,
        # AutoInline has no defaults for the flags below; these values
        # are meant to mirror TVM's default CPU rules
        inline_const_tensor=True,
        disallow_if_then_else=True,
        require_injective=True,
        require_ordered=True,
    ),
    schedule_rule.MultiLevelTiling(  # the workhorse: multi-level tiling
        structure="SSRSRS",  # Spatial-Spatial-Reduce-Spatial-Reduce-Spatial
        tile_binds=None,  # for GPU ("SSSRRSRS"): ["blockIdx.x", "vthread.x", "threadIdx.x"]
        max_innermost_factor=64,
        vector_load_lens=None,
        reuse_read=None,
        reuse_write=None,
    ),
    schedule_rule.ParallelizeVectorizeUnroll(
        max_jobs_per_core=16,
        max_vectorize_extent=32,
        unroll_max_steps=[0, 16, 64, 512],
        unroll_explicit=True,
    ),
]
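The tile sizes that MultiLevelTiling explores are "perfect" tilings: factor lists whose product equals the loop extent, with the innermost factor capped by max_innermost_factor. The sketch below enumerates that space for a small loop in pure Python. The helper name perfect_tiles is hypothetical, and this brute-force enumeration only illustrates the shape of the search space, not how TVM samples it.

```python
import random


def perfect_tiles(extent, n_levels, max_innermost_factor):
    """Enumerate every n_levels-way factorization of extent (toy version)."""
    if n_levels == 1:
        # The last level is the innermost tile, so apply the cap here
        return [[extent]] if extent <= max_innermost_factor else []
    tiles = []
    for outer in range(1, extent + 1):
        if extent % outer == 0:
            for rest in perfect_tiles(extent // outer, n_levels - 1,
                                      max_innermost_factor):
                tiles.append([outer] + rest)
    return tiles


candidates = perfect_tiles(extent=64, n_levels=2, max_innermost_factor=16)
print(candidates)
# [[4, 16], [8, 8], [16, 4], [32, 2], [64, 1]]

# A tuner draws one candidate at random as its initial decision
rng = random.Random(0)
print(rng.choice(candidates))
```

Tilings such as [2, 32] are excluded because the inner factor exceeds the cap, which is how a single rule parameter prunes the space before any measurement happens.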
The "SSRSRS" string encodes the tiling pattern. Each letter is one tile level, not one loop: an S level takes one tile of every spatial loop, and an R level takes one tile of every reduction loop.
S = one tile level over all spatial loops
R = one tile level over all reduction loops
SSRSRS for a matmul C[i,j] += A[i,k] * B[k,j] (spatial: i, j; reduction: k):
Level 1 (S): i0, j0
Level 2 (S): i1, j1
Level 3 (R): k0
Level 4 (S): i2, j2
Level 5 (R): k1
Level 6 (S): i3, j3
Result: i and j are each split into 4 parts and k into 2, then reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3)
For GPU targets, the tiling structure becomes "SSSRRSRS" with thread/block bindings:
Level 1 (S): blockIdx.x ← coarsest spatial
Level 2 (S): vthread ← virtual thread (bank conflict avoidance)
Level 3 (S): threadIdx.x ← fine-grained spatial
Level 4 (R): reduction outer
Level 5 (R): reduction inner
Level 6 (S): register tile
Level 7 (R): reduction innermost
Level 8 (S): vectorize ← innermost spatial
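A structure string can be expanded mechanically into a loop order. The sketch below follows the convention used by Ansor and MetaSchedule, where each letter is one tile level across all loops of that kind (every spatial loop for S, every reduction loop for R). The helper name expand_tile_structure is hypothetical and the code is pure Python, not a TVM API.

```python
def expand_tile_structure(structure, spatial, reduction):
    """Return the loop order implied by a tiling structure string.

    Each "S" consumes one tile level of every spatial axis and each "R"
    one tile level of every reduction axis (toy helper, hypothetical name).
    """
    order = []
    level = {"S": 0, "R": 0}
    axes = {"S": spatial, "R": reduction}
    for kind in structure:
        if kind not in axes:
            raise ValueError(f"unknown tile kind: {kind!r}")
        order.extend(f"{axis}{level[kind]}" for axis in axes[kind])
        level[kind] += 1
    return order


cpu_order = expand_tile_structure("SSRSRS", ["i", "j"], ["k"])
print(cpu_order)
# ['i0', 'j0', 'i1', 'j1', 'k0', 'i2', 'j2', 'k1', 'i3', 'j3']

gpu_order = expand_tile_structure("SSSRRSRS", ["i", "j"], ["k"])
print(gpu_order)
# ['i0', 'j0', 'i1', 'j1', 'i2', 'j2', 'k0', 'k1', 'i3', 'j3', 'k2', 'i4', 'j4']
```

For the GPU structure, the first three spatial levels are the ones bound to blockIdx.x, vthread, and threadIdx.x, so with two spatial axes each binding covers a pair of loops.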
Instead of generating schedules from scratch each time, MetaSchedule mutates existing good schedules to explore neighboring configurations.
Best schedule so far (trace):
Split(i, [32, 32])
Split(j, [32, 32])
Split(k, [64, 16])
Reorder(io, jo, ko, ii, ki, ji)
│
│ Mutator: MutateTileSize
▼
Candidate 1: Candidate 2:
Split(i, [64, 16]) ← changed Split(i, [32, 32])
Split(j, [32, 32]) Split(j, [16, 64]) ← changed
Split(k, [64, 16]) Split(k, [128, 8]) ← changed
Reorder(...) Reorder(...)
| Mutator | What It Mutates |
|---|---|
| MutateTileSize | Changes tile factors within valid divisor sets |
| MutateUnroll | Changes the unroll step from the allowed set |
| MutateParallel | Changes how much outer-loop parallelism is used (the parallel extent) |
| MutateComputeLocation | Moves compute_at to a different loop level |
| MutateThreadBinding | Changes the sampled thread-binding extent on GPU targets |
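A tile-size mutation has to keep the product of the factors equal to the loop extent, otherwise the mutated trace no longer describes a valid split. One simple way to guarantee that, sketched below in pure Python, is to move a prime factor from one tile level to another. The function name mutate_tile_size is hypothetical; this is written in the spirit of MutateTileSize, not copied from TVM's implementation.

```python
import random
from math import prod


def mutate_tile_size(factors, rng):
    """Move one prime factor from one tile level to another (toy mutator)."""
    factors = list(factors)
    donors = [i for i, f in enumerate(factors) if f > 1]
    if not donors or len(factors) < 2:
        return None  # nothing to move
    src = rng.choice(donors)
    dst = rng.choice([i for i in range(len(factors)) if i != src])
    # Smallest prime divisor of the donor factor
    p = next(d for d in range(2, factors[src] + 1) if factors[src] % d == 0)
    factors[src] //= p
    factors[dst] *= p
    return factors


rng = random.Random(42)
parent = [32, 32]
child = mutate_tile_size(parent, rng)
print(parent, "->", child)
print(prod(child) == prod(parent))  # True: the extent is preserved
```

Starting from [32, 32], the only possible children are [16, 64] and [64, 16], which matches the neighboring candidates shown in the mutation diagram.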
import random

# The evolutionary search loop (conceptual sketch). The helpers
# generate_initial_schedule, cost_model, select_top_k, mutator, and
# measure_on_hardware are placeholders, not TVM APIs.
population = [generate_initial_schedule(rules) for _ in range(64)]
for generation in range(num_generations):
    # 1. Score population with cost model
    scores = cost_model.predict(population)
    # 2. Select the top-k candidates by predicted score
    elites = select_top_k(population, scores, k=16)
    # 3. Mutate elites to create new candidates
    children = [mutator.mutate(random.choice(elites)) for _ in range(48)]
    # 4. Measure a few on hardware (to retrain cost model)
    measured = measure_on_hardware(children[:8])
    cost_model.update(measured)
    # 5. Next generation
    population = elites + children
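The loop above can be turned into something runnable by swapping every hardware-dependent piece for a stand-in. In the sketch below, true_latency is a synthetic function that pretends an inner tile of 32 is ideal, and the "cost model" is a nearest-neighbor lookup over past measurements rather than a learned model. All names are hypothetical and nothing here calls TVM.

```python
import math
import random

EXTENT = 1024


def true_latency(tile):
    """Synthetic 'hardware': pretend an inner tile of 32 is ideal."""
    inner_i, inner_j = tile
    return 1.0 + abs(math.log2(inner_i) - 5) + abs(math.log2(inner_j) - 5)


def random_tile(rng):
    return tuple(2 ** rng.randint(0, 10) for _ in range(2))


def mutate(tile, rng):
    """Double or halve one inner tile size, staying within [1, EXTENT]."""
    tile = list(tile)
    idx = rng.randrange(len(tile))
    if rng.random() < 0.5:
        tile[idx] = max(1, tile[idx] // 2)
    else:
        tile[idx] = min(EXTENT, tile[idx] * 2)
    return tuple(tile)


def log_distance(a, b):
    return sum(abs(math.log2(x) - math.log2(y)) for x, y in zip(a, b))


def evolutionary_search(generations=30, pop_size=16, n_elites=4,
                        n_measure=4, seed=0):
    rng = random.Random(seed)
    measured = {}  # tile -> measured latency (plays the role of the database)

    def predict(tile):
        # Stand-in cost model: reuse a measurement if one exists, otherwise
        # borrow the latency of the nearest measured tile.
        if tile in measured or not measured:
            return measured.get(tile, 0.0)
        nearest = min(measured, key=lambda m: log_distance(m, tile))
        return measured[nearest]

    population = [random_tile(rng) for _ in range(pop_size)]
    for _ in range(generations):
        # 1-2. Rank by predicted cost and keep the elites
        population.sort(key=predict)
        elites = population[:n_elites]
        # 3. Mutate elites into children
        children = [mutate(rng.choice(elites), rng)
                    for _ in range(pop_size - n_elites)]
        # 4. Measure a few unmeasured candidates to refine the cost model
        unmeasured = [c for c in elites + children if c not in measured]
        for tile in unmeasured[:n_measure]:
            measured[tile] = true_latency(tile)
        # 5. Next generation
        population = elites + children
    best = min(measured, key=measured.get)
    return best, measured[best]


best, latency = evolutionary_search()
print("best tile:", best, "latency:", latency)
```

The key design point carried over from the real system is the split between cheap predictions, which rank every candidate, and expensive measurements, which are spent on only a few candidates per generation.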
from tvm import meta_schedule as ms

mod = tvm.IRModule({"main": matmul})
# MetaSchedule's CPU rules read the core count from the target,
# so set --num-cores to match your machine
target = tvm.target.Target("llvm -mcpu=core-avx2 --num-cores=4")

# Tune a single PrimFunc
database = ms.tune_tir(
    mod=mod,
    target=target,
    max_trials_global=500,
    num_trials_per_iter=64,
    work_dir="./tune_matmul",
    # Uses default rules, mutators, and cost model
)

# Retrieve the best schedule found for this workload
best_sch = ms.tir_integration.compile_tir(database, mod, target)
if best_sch is not None:  # None means no valid record was found
    print(best_sch.mod.script())
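import tvm.relay as relay

# Load model (traced_model and input_infos are assumed to come from
# your own torch.jit.trace call and input shape list)
mod, params = relay.frontend.from_pytorch(traced_model, input_infos)

# MetaSchedule's GPU rules read hardware limits from the target. If this
# string is rejected, try adding -max_threads_per_block and
# -max_shared_memory_per_block, or use a tagged target for your GPU.
target = tvm.target.Target("cuda -arch=sm_80")

# Tune all operators via MetaSchedule
database = ms.relay_integration.tune_relay(
    mod=mod,
    params=params,
    target=target,
    work_dir="./tune_resnet",
    max_trials_global=3000,
    strategy="evolutionary",
)

# Compile with tuned schedules
lib = ms.relay_integration.compile_relay(
    database=database,
    mod=mod,
    params=params,
    target=target,
)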
MetaSchedule stores results in a structured database. The built-in JSONDatabase keeps workloads and tuning records in two JSON-lines files:
# Re-use previous tuning results
database = ms.database.JSONDatabase(
    path_workload="./tune_resnet/database_workload.json",
    path_tuning_record="./tune_resnet/database_tuning_record.json",
)

# Check what's cached
records = database.get_all_tuning_records()
print(f"Database has {len(records)} tuning records")

# Apply cached schedules to a new compilation
# (new_mod and new_params are assumed to come from another model import)
lib = ms.relay_integration.compile_relay(
    database=database,
    mod=new_mod,  # works if operators match
    params=new_params,
    target=target,
)
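Whether a cached record applies to a new model comes down to how workloads are keyed. MetaSchedule identifies a workload by a structural hash of its IRModule. The pure-Python sketch below imitates that with a hash over operator name, shape, and dtype. ToyJSONDatabase and workload_key are hypothetical names, and the latencies are made-up numbers used only to show the lookup.

```python
import hashlib
import json
import os
import tempfile


def workload_key(op, shape, dtype):
    """Hash of what identifies a workload here: operator, shape, dtype."""
    blob = json.dumps({"op": op, "shape": list(shape), "dtype": dtype},
                      sort_keys=True)
    return hashlib.sha256(blob.encode()).hexdigest()[:16]


class ToyJSONDatabase:
    """Append-only JSON-lines store of (workload key, trace, latency)."""

    def __init__(self, path):
        self.path = path

    def commit(self, key, trace, latency_ms):
        record = {"key": key, "trace": trace, "latency_ms": latency_ms}
        with open(self.path, "a") as f:
            f.write(json.dumps(record) + "\n")

    def best(self, key):
        if not os.path.exists(self.path):
            return None
        with open(self.path) as f:
            records = [json.loads(line) for line in f if line.strip()]
        hits = [r for r in records if r["key"] == key]
        return min(hits, key=lambda r: r["latency_ms"]) if hits else None


path = os.path.join(tempfile.mkdtemp(), "tuning_records.json")
db = ToyJSONDatabase(path)
conv = workload_key("conv2d", (1, 64, 56, 56), "float32")
db.commit(conv, ["split(i, [8, 7])"], 0.42)   # made-up latency
db.commit(conv, ["split(i, [4, 14])"], 0.31)  # made-up latency

# A later session (or another model) with an identical operator hits the cache
db2 = ToyJSONDatabase(path)
same = db2.best(workload_key("conv2d", (1, 64, 56, 56), "float32"))
other = db2.best(workload_key("conv2d", (1, 128, 28, 28), "float32"))
print(same["trace"], other)
```

The second lookup misses because the shape differs, which is the same reason a database tuned on one network only helps another network where operator shapes coincide.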
import tvm
from tvm.meta_schedule.schedule_rule import PyScheduleRule
from tvm.meta_schedule.utils import derived_object

@derived_object
class MyCustomTiling(PyScheduleRule):
    """Custom tiling that only uses power-of-2 factors."""

    def _initialize_with_tune_context(self, context):
        self.target = context.target

    def apply(self, sch, block):
        candidates = []
        for loop in sch.get_loops(block):
            extent = sch.get(loop).extent
            if not isinstance(extent, tvm.tir.IntImm):
                continue  # skip loops with a symbolic extent
            # Only allow power-of-2 tile sizes that divide the extent
            factors = [2**i for i in range(1, 10) if int(extent) % 2**i == 0]
            for f in factors:
                # Each candidate is an independent copy with one split applied
                new_sch = sch.copy()
                new_sch.split(loop, factors=[None, f])
                candidates.append(new_sch)
        return candidates

    def clone(self):
        return MyCustomTiling()
import random
from tvm.meta_schedule.mutator import PyMutator
from tvm.meta_schedule.utils import derived_object

@derived_object
class MyTileMutator(PyMutator):
    """Mutate tile sizes by moving a factor of 2 between two tile levels."""

    def _initialize_with_tune_context(self, context):
        pass

    def apply(self, trace, _):
        # Tile sizes are decisions of SamplePerfectTile instructions;
        # Split only consumes them, so it has no decision to mutate.
        for inst in trace.insts:
            if inst.kind.name != "SamplePerfectTile":
                continue
            factors = [int(f) for f in trace.decisions[inst]]
            if len(factors) < 2:
                continue
            src, dst = random.sample(range(len(factors)), 2)
            if factors[src] % 2 != 0:
                continue
            # Halve one factor and double another so the product is unchanged
            # (this sketch does not re-check max_innermost_factor)
            factors[src] //= 2
            factors[dst] *= 2
            # with_decision returns a new trace; the original is untouched
            return trace.with_decision(inst, factors, remove_postproc=True)
        return None  # no valid mutation found

    def clone(self):
        return MyTileMutator()
Tune a depthwise convolution (N=1, C=128, H=W=56, kernel=3×3) with both MetaSchedule and Ansor. Use 300 trials each. Compare:
1. Best latency achieved
2. Time to reach 90% of best latency
3. Number of hardware measurements required
After tuning a matmul with MetaSchedule, print the best schedule's trace. For each instruction:
1. What primitive does it call?
2. What were the decisions (tile sizes, annotations)?
3. Replay the trace on a fresh schedule to verify determinism.
Tune ResNet-18 for cuda. Save the database. Then load it and compile ResNet-34 — do any workloads hit the cache? Why or why not? (Hint: which operators have identical shapes?)
MultiLevelTiling ("SSRSRS" / "SSSRRSRS") is the most important schedule rule.

Even the best auto-tuner can't beat a vendor library hand-optimized over years for a specific operator on specific hardware. Day 38 introduces BYOC (Bring Your Own Codegen), TVM's framework for offloading subgraphs to external libraries like cuDNN, TensorRT, or DNNL, combining TVM's graph-level optimizations with vendor-tuned kernels.