Phase I · Week 2 · Day 10 of 70 · 2.5 hours
"If the operation you need doesn't exist in ATen, you don't file a feature request — you write a kernel."
| Previous | Next | Week | Phase | Curriculum |
|---|---|---|---|---|
| Day 9: Memory Management | Day 11: torch.profiler & Trace Analysis | Week 2: PyTorch Internals | Phase I: Foundations | Curriculum Home |
Fusing operations into a single CUDA kernel eliminates round-trips to global
memory. A naive GELU(x + bias) needs two kernel launches and writes, then re-reads, an
intermediate tensor. A fused kernel keeps that intermediate in registers. This roughly halves
global-memory traffic (about 4N element transfers down to 2N for N elements) and saves one
kernel launch. Custom kernels are how production systems deliver their speedups:
FlashAttention and xFormers ship hand-written CUDA C++ extensions, and Triton compiles
Python-level kernels to GPU code. Knowing this pathway is essential for any ML compiler engineer.
PyTorch offers two paths for custom C++/CUDA code:
┌─────────────────────────────────────────────────┐
│ Python User Code │
│ my_op = my_extension.fused_bias_gelu │
├─────────────┬───────────────────────────────────┤
│ Ahead-of- │ JIT Compilation │
│ Time Build │ torch.utils.cpp_extension.load() │
│ setup.py / │ ← compiles on first call, │
│ setuptools │ caches in ~/.cache/torch_extensions/ │
├─────────────┴───────────────────────────────────┤
│ torch.utils.cpp_extension │
│ CppExtension / CUDAExtension │
├─────────────────────────────────────────────────┤
│ pybind11 bindings (auto-included by PyTorch) │
├─────────────────────────────────────────────────┤
│ Your C++ / CUDA source files │
│ (.cpp for CPU, .cu for CUDA) │
└─────────────────────────────────────────────────┘
JIT compilation with `load()` is the fastest way to iterate on a custom op:
from torch.utils.cpp_extension import load
# load() builds the extension the first time it runs and reuses the cached
# build on later runs.
cuda_flags = ['-O3', '--use_fast_math']
my_extension = load(
    name='my_extension',
    sources=['my_op.cpp', 'my_op_cuda.cu'],
    extra_cuda_cflags=cuda_flags,
    verbose=True,
)
# Functions exported by the compiled module are called like ordinary Python functions.
y = my_extension.fused_bias_gelu(x, bias)
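Behind the scenes, load():
1. Generates a Ninja build that invokes nvcc for .cu files and the host C++ compiler (e.g. g++) for .cpp files
2. Links against libtorch and libc10 (plus the CUDA runtime for CUDA sources)
3. Produces a shared object (.so) exposing a pybind11 module, which it imports and returns
4. Caches the build under ~/.cache/torch_extensions/ (override with TORCH_EXTENSIONS_DIR) and rebuilds only when sources or flags change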
GELU (Gaussian Error Linear Unit):
$$\text{GELU}(x) = x \cdot \Phi(x) \approx 0.5 \, x \left(1 + \tanh\!\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715\, x^3\right)\right]\right)$$
We want to compute $\text{GELU}(x + \text{bias})$ in a single kernel.
// fused_bias_gelu_cuda.cu
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <limits>
// Host/device function: GELU, tanh approximation
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// Marked __host__ __device__ so the exact same float formula can also be used
// on the CPU as a reference when testing the kernels.
__host__ __device__ __forceinline__ float gelu_fwd(float x) {
    constexpr float kSqrt2OverPi = 0.7978845608f;  // sqrt(2/pi)
    constexpr float kCoeff = 0.044715f;
    const float cube = x * x * x;
    const float inner = kSqrt2OverPi * (x + kCoeff * cube);
    return 0.5f * x * (1.0f + tanhf(inner));
}
// Host/device function: derivative of the tanh-approximate GELU.
// With u = sqrt(2/pi) * (x + 0.044715 * x^3):
//   d/dx gelu(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * du/dx
//   du/dx        = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
// Marked __host__ __device__ so it can also serve as a CPU-side reference.
__host__ __device__ __forceinline__ float gelu_bwd(float x) {
    constexpr float kSqrt2OverPi = 0.7978845608f;  // sqrt(2/pi)
    constexpr float kCoeff = 0.044715f;
    const float cube = x * x * x;
    const float inner = kSqrt2OverPi * (x + kCoeff * cube);
    const float tanh_inner = tanhf(inner);
    // sech^2(u) == 1 - tanh^2(u); reuses the tanh already computed
    const float sech2 = 1.0f - tanh_inner * tanh_inner;
    const float inner_deriv = kSqrt2OverPi * (1.0f + 3.0f * kCoeff * x * x);
    return 0.5f * (1.0f + tanh_inner) + 0.5f * x * sech2 * inner_deriv;
}
// Forward kernel: output = GELU(input + bias), with bias indexed by column.
// Launch: 1-D grid with at least rows * cols threads in total; each thread
// handles one element. Assumes input/output are dense row-major rows x cols
// buffers and bias holds cols values.
// Index math is 64-bit: rows * cols and blockIdx.x * blockDim.x can exceed
// INT_MAX for large activations, which would silently wrap in 32-bit int.
__global__ void fused_bias_gelu_fwd_kernel(
    const float* __restrict__ input,
    const float* __restrict__ bias,
    float* __restrict__ output,
    int rows, int cols
) {
    const int64_t total = static_cast<int64_t>(rows) * cols;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= total) {
        return;  // tail threads of the last block
    }
    const int64_t col = idx % cols;
    output[idx] = gelu_fwd(input[idx] + bias[col]);
}
// Backward kernel: grad_input = grad_out * GELU'(input + bias).
// Launch: 1-D grid with at least rows * cols threads in total; each thread
// handles one element. Assumes all element-wise buffers are dense row-major
// rows x cols and bias holds cols values.
// Index math is 64-bit to avoid int overflow on tensors with > INT_MAX elements.
__global__ void fused_bias_gelu_bwd_kernel(
    const float* __restrict__ grad_out,
    const float* __restrict__ input,
    const float* __restrict__ bias,
    float* __restrict__ grad_input,
    int rows, int cols
) {
    const int64_t total = static_cast<int64_t>(rows) * cols;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= total) {
        return;  // tail threads of the last block
    }
    const int64_t col = idx % cols;
    // Recompute the pre-activation instead of storing it in the forward pass.
    const float pre_act = input[idx] + bias[col];
    grad_input[idx] = grad_out[idx] * gelu_bwd(pre_act);
}
// C++ wrapper functions (called from Python)

// Host wrapper for the forward pass: validates arguments and launches
// fused_bias_gelu_fwd_kernel on PyTorch's current CUDA stream.
// Requirements: input and bias are float32 CUDA tensors on the same device,
// input has at least one dimension, bias is 1-D with
// bias.size(0) == input.size(-1). Non-contiguous tensors are made contiguous
// first because the kernel indexes raw pointers as a dense row-major buffer.
// Returns a new tensor with input's shape holding GELU(input + bias).
torch::Tensor fused_bias_gelu_fwd_cuda(
    torch::Tensor input,
    torch::Tensor bias
) {
    TORCH_CHECK(input.is_cuda(), "input must be on CUDA");
    TORCH_CHECK(bias.is_cuda(), "bias must be on CUDA");
    TORCH_CHECK(input.device() == bias.device(),
                "input and bias must be on the same CUDA device");
    // The kernel reads/writes through float*; other dtypes need a dispatched kernel.
    TORCH_CHECK(input.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "fused_bias_gelu: input and bias must be float32");
    TORCH_CHECK(input.dim() >= 1, "input must have at least one dimension");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1-D");
    TORCH_CHECK(input.size(-1) == bias.size(0), "bias size mismatch");

    input = input.contiguous();
    bias = bias.contiguous();
    auto output = torch::empty_like(input);

    const int64_t total = input.numel();
    // Nothing to compute. Returning early also avoids dividing by a zero-sized
    // last dimension and an invalid zero-block launch.
    if (total == 0) {
        return output;
    }
    const int64_t cols = input.size(-1);
    const int64_t rows = total / cols;
    TORCH_CHECK(rows <= std::numeric_limits<int>::max() &&
                cols <= std::numeric_limits<int>::max(),
                "fused_bias_gelu: rows and cols must each fit in int32");

    const int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    // Launch on input's device and on the stream PyTorch is currently using,
    // so the kernel is ordered with the surrounding PyTorch ops.
    const c10::cuda::CUDAGuard device_guard(input.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    fused_bias_gelu_fwd_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
        input.data_ptr<float>(),
        bias.data_ptr<float>(),
        output.data_ptr<float>(),
        static_cast<int>(rows), static_cast<int>(cols)
    );
    // Surface launch-configuration errors here rather than at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
// Host wrapper for the backward pass: launches fused_bias_gelu_bwd_kernel on
// PyTorch's current CUDA stream and returns grad_input = grad_out * GELU'(input + bias).
// Requirements: grad_out, input and bias are float32 CUDA tensors on the same
// device; grad_out has input's shape; bias is 1-D with bias.size(0) == input.size(-1).
// Non-contiguous tensors are made contiguous because the kernel indexes raw
// pointers as dense row-major buffers.
torch::Tensor fused_bias_gelu_bwd_cuda(
    torch::Tensor grad_out,
    torch::Tensor input,
    torch::Tensor bias
) {
    TORCH_CHECK(grad_out.is_cuda() && input.is_cuda() && bias.is_cuda(),
                "grad_out, input and bias must be on CUDA");
    TORCH_CHECK(grad_out.device() == input.device() && bias.device() == input.device(),
                "grad_out, input and bias must be on the same CUDA device");
    // The kernel reads/writes through float*; other dtypes need a dispatched kernel.
    TORCH_CHECK(grad_out.scalar_type() == torch::kFloat32 &&
                input.scalar_type() == torch::kFloat32 &&
                bias.scalar_type() == torch::kFloat32,
                "fused_bias_gelu backward: grad_out, input and bias must be float32");
    TORCH_CHECK(input.dim() >= 1, "input must have at least one dimension");
    TORCH_CHECK(bias.dim() == 1, "bias must be 1-D");
    TORCH_CHECK(input.size(-1) == bias.size(0), "bias size mismatch");
    TORCH_CHECK(grad_out.sizes() == input.sizes(),
                "grad_out must have the same shape as input");

    grad_out = grad_out.contiguous();
    input = input.contiguous();
    bias = bias.contiguous();
    auto grad_input = torch::empty_like(input);

    const int64_t total = input.numel();
    // Nothing to compute. Returning early also avoids dividing by a zero-sized
    // last dimension and an invalid zero-block launch.
    if (total == 0) {
        return grad_input;
    }
    const int64_t cols = input.size(-1);
    const int64_t rows = total / cols;
    TORCH_CHECK(rows <= std::numeric_limits<int>::max() &&
                cols <= std::numeric_limits<int>::max(),
                "fused_bias_gelu: rows and cols must each fit in int32");

    const int threads = 256;
    const int64_t blocks = (total + threads - 1) / threads;
    // Launch on input's device and on the stream PyTorch is currently using.
    const c10::cuda::CUDAGuard device_guard(input.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    fused_bias_gelu_bwd_kernel<<<static_cast<unsigned int>(blocks), threads, 0, stream>>>(
        grad_out.data_ptr<float>(),
        input.data_ptr<float>(),
        bias.data_ptr<float>(),
        grad_input.data_ptr<float>(),
        static_cast<int>(rows), static_cast<int>(cols)
    );
    // Surface launch-configuration errors here rather than at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return grad_input;
}
fused_bias_gelu.cpp)#include <torch/extension.h>
// Forward declarations of CUDA functions
torch::Tensor fused_bias_gelu_fwd_cuda(torch::Tensor input,
torch::Tensor bias);
torch::Tensor fused_bias_gelu_bwd_cuda(torch::Tensor grad_out,
torch::Tensor input,
torch::Tensor bias);
// Input validation + dispatch
// Forward entry point exposed to Python: refuses non-CUDA input, then hands
// off to the CUDA host wrapper.
torch::Tensor fused_bias_gelu_fwd(torch::Tensor input,
                                  torch::Tensor bias) {
    const bool input_on_gpu = input.device().is_cuda();
    TORCH_CHECK(input_on_gpu,
                "fused_bias_gelu only supports CUDA tensors");
    return fused_bias_gelu_fwd_cuda(input, bias);
}
// Backward entry point exposed to Python. Mirrors the forward entry point's
// device check, applied to every tensor: CPU tensors passed through to the
// CUDA kernel would otherwise be dereferenced as device pointers.
torch::Tensor fused_bias_gelu_bwd(torch::Tensor grad_out,
                                  torch::Tensor input,
                                  torch::Tensor bias) {
    TORCH_CHECK(grad_out.device().is_cuda() &&
                input.device().is_cuda() &&
                bias.device().is_cuda(),
                "fused_bias_gelu only supports CUDA tensors");
    return fused_bias_gelu_bwd_cuda(grad_out, input, bias);
}
// pybind11 module definition.
// The module name is the TORCH_EXTENSION_NAME macro, which
// torch.utils.cpp_extension defines at build time from the extension's `name`.
// Python sees two functions: `forward` and `backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward",
          &fused_bias_gelu_fwd,
          "Fused bias + GELU forward (CUDA)");
    m.def("backward",
          &fused_bias_gelu_bwd,
          "Fused bias + GELU backward (CUDA)");
}
To make the custom op work with loss.backward(), wrap it in a
torch.autograd.Function:
import torch
from torch.utils.cpp_extension import load
# JIT-compile the C++ binding and the CUDA sources into one extension module.
extension_sources = ['fused_bias_gelu.cpp', 'fused_bias_gelu_cuda.cu']
fused_op = load(
    name='fused_bias_gelu',
    sources=extension_sources,
    extra_cuda_cflags=['-O3'],
)
class FusedBiasGELU(torch.autograd.Function):
    """Autograd bridge for the fused bias + GELU extension.

    forward returns GELU(input + bias), where bias has one value per entry of
    input's last dimension. backward returns gradients for input and bias.
    """

    @staticmethod
    def forward(ctx, input, bias):
        # Save the raw inputs; the backward kernel recomputes input + bias.
        ctx.save_for_backward(input, bias)
        return fused_op.forward(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        # The kernel indexes raw pointers, so hand it a dense gradient.
        grad_input = fused_op.backward(grad_output.contiguous(), input, bias)
        grad_bias = None
        if ctx.needs_input_grad[1]:
            # bias is broadcast along every leading dim, so its gradient sums
            # grad_input over them. Flatten to (rows, cols) first: for 1-D
            # input, sum(dim=()) would reduce over *all* dims and return a
            # scalar instead of a (cols,) gradient.
            grad_bias = grad_input.reshape(-1, grad_input.size(-1)).sum(dim=0)
        return grad_input, grad_bias
def fused_bias_gelu(input, bias):
    """Convenience wrapper: GELU(input + bias) via the fused kernel, autograd-aware."""
    apply_op = FusedBiasGELU.apply
    return apply_op(input, bias)
Always validate your backward implementation:
from torch.autograd import gradcheck
# The kernels above are float32-only (they access tensors through
# data_ptr<float>()), so the float64 inputs gradcheck normally wants would
# raise a dtype error. Check in float32 with a larger finite-difference step
# and looser tolerances instead. gradcheck warns about non-double inputs; a
# float64 check needs a double code path in the kernel, e.g. via
# AT_DISPATCH_FLOATING_TYPES.
input = torch.randn(4, 8, dtype=torch.float32,
                    device='cuda', requires_grad=True)
bias = torch.randn(8, dtype=torch.float32,
                   device='cuda', requires_grad=True)
# gradcheck compares analytical gradients against finite differences
# (and raises on a mismatch by default).
assert gradcheck(FusedBiasGELU.apply, (input, bias),
                 eps=1e-3, atol=1e-2, rtol=1e-2)
print("Gradient check PASSED")
For proper integration with torch.compile and other PyTorch subsystems,
register your op with the dispatcher using torch.library:
import torch
from torch import Tensor
# Define the operator schemas. The backward gets its own op so that
# torch.compile can trace the backward pass through the dispatcher too.
torch.library.define(
    "myops::fused_bias_gelu",
    "(Tensor input, Tensor bias) -> Tensor"
)
torch.library.define(
    "myops::fused_bias_gelu_backward",
    "(Tensor grad_out, Tensor input, Tensor bias) -> Tensor"
)


# Register CUDA implementations. Backend kernels run *below* the dispatcher's
# autograd layer, so they must call the raw extension functions. Calling
# FusedBiasGELU.apply here would record no autograd graph, and gradients
# through torch.ops.myops.fused_bias_gelu would not be wired up.
@torch.library.impl("myops::fused_bias_gelu", "cuda")
def fused_bias_gelu_cuda(input: Tensor, bias: Tensor) -> Tensor:
    return fused_op.forward(input, bias)


@torch.library.impl("myops::fused_bias_gelu_backward", "cuda")
def fused_bias_gelu_backward_cuda(grad_out: Tensor, input: Tensor,
                                  bias: Tensor) -> Tensor:
    # The kernel indexes raw pointers, so hand it a dense gradient.
    return fused_op.backward(grad_out.contiguous(), input, bias)


# Register "fake" implementations for torch.compile tracing: they describe the
# output's shape/dtype/device without touching data.
# register_fake needs PyTorch >= 2.4 and replaces the deprecated impl_abstract.
@torch.library.register_fake("myops::fused_bias_gelu")
def fused_bias_gelu_abstract(input: Tensor, bias: Tensor) -> Tensor:
    torch._check(bias.dim() == 1, lambda: "bias must be 1-D")
    torch._check(input.shape[-1] == bias.shape[0], lambda: "bias size mismatch")
    return torch.empty_like(input)


@torch.library.register_fake("myops::fused_bias_gelu_backward")
def fused_bias_gelu_backward_abstract(grad_out: Tensor, input: Tensor,
                                      bias: Tensor) -> Tensor:
    return torch.empty_like(input)


# Autograd formula for the forward op (register_autograd: PyTorch >= 2.4).
def _fused_bias_gelu_setup_context(ctx, inputs, output):
    input, bias = inputs
    ctx.save_for_backward(input, bias)


def _fused_bias_gelu_backward(ctx, grad_output):
    input, bias = ctx.saved_tensors
    grad_input = torch.ops.myops.fused_bias_gelu_backward(grad_output, input, bias)
    # bias is broadcast along every leading dim: sum grad_input over them.
    grad_bias = grad_input.reshape(-1, grad_input.shape[-1]).sum(dim=0)
    return grad_input, grad_bias


torch.library.register_autograd(
    "myops::fused_bias_gelu",
    _fused_bias_gelu_backward,
    setup_context=_fused_bias_gelu_setup_context,
)

# Now usable as:
y = torch.ops.myops.fused_bias_gelu(x, bias)
For distribution, use setup.py:
# setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# nvcc flags for the distributable build.
# - The sm_80 / sm_89 cubins cover A100 and RTX 4090.
# - The extra compute_89 PTX target lets the driver JIT-compile the kernels on
#   newer GPUs (e.g. sm_90 / H100) instead of failing with
#   "no kernel image is available for execution on the device".
# - NOTE: --use_fast_math may swap in faster, less accurate math routines, and
#   the JIT build above does not use it. Re-run gradcheck against this build.
nvcc_flags = [
    '-O3', '--use_fast_math',
    '-gencode=arch=compute_80,code=sm_80',       # A100
    '-gencode=arch=compute_89,code=sm_89',       # RTX 4090
    '-gencode=arch=compute_89,code=compute_89',  # PTX fallback for newer GPUs
]

setup(
    name='fused_bias_gelu',
    ext_modules=[
        CUDAExtension(
            name='fused_bias_gelu',
            sources=[
                'fused_bias_gelu.cpp',
                'fused_bias_gelu_cuda.cu',
            ],
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': nvcc_flags,
            },
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)
Build and install:
# Build the extension and install this checkout in editable mode:
pip install -e .
# OR for development iteration:
# (builds the compiled extension in place next to the sources, without installing)
python setup.py build_ext --inplace
import torch
import torch.nn.functional as F
import time
def unfused_bias_gelu(x, bias):
    """Unfused reference: add bias, then apply PyTorch's tanh-approximate GELU."""
    pre_activation = x + bias
    return F.gelu(pre_activation, approximate='tanh')
# Compare against unfused PyTorch implementation
x = torch.randn(512, 4096, device='cuda')
bias = torch.randn(4096, device='cuda')
# Warmup
for _ in range(10):
    unfused_bias_gelu(x, bias)
    # fused_bias_gelu(x, bias)  # your custom op
torch.cuda.synchronize()
# Benchmark unfused. time.perf_counter() is monotonic and high-resolution
# (time.time() is neither guaranteed); the synchronize calls make the
# measurement cover the GPU work rather than just the asynchronous launches.
n_iters = 1000
t0 = time.perf_counter()
for _ in range(n_iters):
    y_ref = unfused_bias_gelu(x, bias)
torch.cuda.synchronize()
t_unfused = time.perf_counter() - t0
print(f"Unfused: {t_unfused*1000:.1f} ms ({n_iters} iters)")
# Compare with your fused version — expect 1.3-2x speedup
# TASK: Also measure peak memory for both versions
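// Extend the CUDA kernel to support fp16 (while keeping fp32 working).
// Hints:
// - Use the AT_DISPATCH_FLOATING_TYPES_AND_HALF macro; inside it, scalar_t is
//   float, double, or at::Half (PyTorch's half type, not CUDA's raw __half)
// - For native half arithmetic, use __hadd(), __hmul() on __half values
//   (reinterpret at::Half* as __half*; both are 16-bit with the same layout)
// - Or cast to float for computation and back to scalar_t for storage
//   (simplest; note that double inputs then get float-precision math)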
// Template the kernel:
// Templated forward kernel: output = GELU(input + bias) for any scalar_t that
// converts to and from float (e.g. float, double, at::Half).
// Launch: 1-D grid with at least rows * cols threads; one element per thread.
// Assumes dense row-major rows x cols buffers and a bias of length cols.
// Index math is 64-bit to avoid int overflow on tensors with > INT_MAX elements.
template <typename scalar_t>
__global__ void fused_bias_gelu_fwd_kernel(
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ bias,
    scalar_t* __restrict__ output,
    int rows, int cols
) {
    const int64_t total = static_cast<int64_t>(rows) * cols;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < total) {
        const int64_t col = idx % cols;
        // Load as scalar_t, compute in float, store back as scalar_t.
        // NOTE: for scalar_t = double this limits the result to float precision.
        const float val = static_cast<float>(input[idx])
                        + static_cast<float>(bias[col]);
        output[idx] = static_cast<scalar_t>(gelu_fwd(val));
    }
}
// Dispatch:
// AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "fused_gelu", [&] {
// fused_bias_gelu_fwd_kernel<scalar_t><<<blocks, threads>>>(...);
// });
// The basic kernel loads one float per thread. Optimize by loading
// 4 floats at once using float4:
//
// Each thread loads one float4 (4 consecutive elements of a row), adds the
// matching 4 bias values, applies GELU, and stores one float4.
// Preconditions (not checked here):
//   - cols is a multiple of 4 and cols_div4 == cols / 4, so a float4 never
//     straddles two rows;
//   - input and output are 16-byte aligned (required for float4 access).
// Launch: 1-D grid with at least rows * cols_div4 threads.
// Index math is 64-bit to avoid int overflow on very large tensors.
__global__ void fused_bias_gelu_fwd_vec4(
    const float4* __restrict__ input,
    const float* __restrict__ bias,
    float4* __restrict__ output,
    int rows, int cols_div4
) {
    const int64_t total = static_cast<int64_t>(rows) * cols_div4;
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= total) {
        return;  // tail threads of the last block
    }
    // First scalar column covered by this float4
    const int64_t col = (idx % cols_div4) * 4;
    const float4 in = input[idx];
    float4 out;
    out.x = gelu_fwd(in.x + bias[col + 0]);
    out.y = gelu_fwd(in.y + bias[col + 1]);
    out.z = gelu_fwd(in.z + bias[col + 2]);
    out.w = gelu_fwd(in.w + bias[col + 3]);
    output[idx] = out;
}
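// TASK: Integrate this into the extension and benchmark.
// Expect a modest speedup over the scalar version. Both kernels already issue
// coalesced accesses; the gain comes from one 128-bit load/store per thread
// instead of four 32-bit ones, which means fewer instructions and memory
// requests in flight. Fall back to the scalar kernel when cols % 4 != 0 or
// the pointers are not 16-byte aligned.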
Key takeaways:
- torch.utils.cpp_extension.load() provides zero-setup JIT compilation for rapid C++/CUDA kernel development.
- torch.autograd.Function bridges custom CUDA kernels with PyTorch's autograd engine; always verify with gradcheck.
- torch.library registers ops with the dispatcher for compatibility with torch.compile, vmap, and other transforms.
- AT_DISPATCH_FLOATING_TYPES_AND_HALF handles dtype dispatch across fp32/fp64/fp16 without duplicating kernel code.

Next — Day 11: torch.profiler & Trace Analysis. We'll learn to profile PyTorch programs
with torch.profiler, visualize execution traces in Chrome/Perfetto, identify GPU
idle time, and measure CPU-GPU overlap to find the real bottleneck.