Phase IV: Vision (ViT, 3D, Video) | Week 7 | 2.5 hours

> "Mask 75% of patches, reconstruct pixels. Simple, scalable, and it works spectacularly." (He et al., 2022)
BERT masks 15% of tokens and predicts them. MAE masks 75% of image patches and reconstructs their pixels. Why so aggressive? Images are far more spatially redundant than text: a missing patch can often be inferred from its neighbors, so a low masking ratio leaves a task that is too easy to teach much.
```text
Original image:            After 75% masking:
┌──┬──┬──┬──┐              ┌──┬──┬──┬──┐
│P1│P2│P3│P4│              │  │P2│  │  │
├──┼──┼──┼──┤              ├──┼──┼──┼──┤
│P5│P6│P7│P8│   mask 75%   │  │  │P7│  │
├──┼──┼──┼──┤ ───────────► ├──┼──┼──┼──┤
│P9│10│11│12│              │P9│  │  │12│
├──┼──┼──┼──┤              ├──┼──┼──┼──┤
│13│14│15│16│              │  │  │  │  │
└──┴──┴──┴──┘              └──┴──┴──┴──┘

Encoder: processes ONLY visible patches (25%)
Decoder: reconstructs ALL patches from encoded visible + mask tokens
```
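The masking in the diagram takes only a few lines. Below is a minimal sketch, assuming PyTorch and a 4×4 grid; `random_patch_mask` and `show` are hypothetical helpers. Each patch gets a uniform random score and the lowest-scoring 25% are kept, which is the same argsort trick the full implementation in this section uses.

```python
import torch

def random_patch_mask(grid=4, mask_ratio=0.75, generator=None):
    """Return a (grid, grid) bool tensor: True = visible, False = masked."""
    n = grid * grid
    n_keep = int(n * (1 - mask_ratio))          # 16 * 0.25 = 4 visible patches
    noise = torch.rand(n, generator=generator)  # one random score per patch
    keep_ids = noise.argsort()[:n_keep]         # lowest scores are kept
    visible = torch.zeros(n, dtype=torch.bool)
    visible[keep_ids] = True
    return visible.view(grid, grid)

def show(visible):
    """Print labels P1..P16 for visible cells and blanks for masked ones."""
    grid = visible.shape[0]
    for r in range(grid):
        cells = []
        for c in range(grid):
            idx = r * grid + c + 1
            cells.append(f"P{idx:<2}" if visible[r, c] else "   ")
        print("|" + "|".join(cells) + "|")

g = torch.Generator().manual_seed(0)
vis = random_patch_mask(generator=g)
show(vis)
print("visible:", int(vis.sum()), "of", vis.numel())
```

Which four patches survive depends on the seed; the count is always 4 of 16.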
MAE pipeline:

1. Patchify the image and embed all patches.
2. Randomly mask 75% of the patches and keep the other 25% visible.
3. Encoder (large ViT):
   - Processes ONLY the visible patches (25%).
   - Uses full self-attention, but on a 4× shorter sequence.
   - This is where the compute savings come from: the linear layers see 4× fewer tokens, and the attention-score matrix is about 16× smaller.
4. Insert mask tokens at the masked positions:
   - The encoder output is first projected to the decoder width.
   - The mask token is a single shared, learnable vector.
   - Positional embeddings are added to all tokens, so each mask token knows which location it stands for.
5. Decoder (small transformer):
   - Processes the full set (encoded visible tokens + mask tokens).
   - Used only during pretraining, so it can be lightweight.
6. Loss: MSE on the masked patches only,

$$\mathcal{L} = \frac{1}{|M|} \sum_{i \in M} \frac{1}{D} \left\| \hat{x}_i - x_i \right\|^2$$

where $M$ is the set of masked patch indices, $\hat{x}_i$ and $x_i$ are the predicted and original pixels of patch $i$, and $D = C \cdot P^2$ is the number of pixel values per patch. Optionally, each target patch $x_i$ is normalized by its own pixel mean and standard deviation before computing the loss.
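Step 6, including the optional per-patch normalized-pixel target, can be sketched as a standalone function. This is an illustrative version, not reference code: `masked_mse` is a hypothetical name, and the `eps` value and the (unbiased) per-patch variance used for normalization are assumptions.

```python
import torch

def masked_mse(pred, target, mask, norm_pix=False, eps=1e-6):
    """
    pred, target: (B, N, D) patch pixels; mask: (B, N) with 1 = masked, 0 = visible.
    norm_pix=True standardizes each target patch by its own mean/variance
    (eps and the variance estimator are illustrative choices).
    """
    if norm_pix:
        mean = target.mean(dim=-1, keepdim=True)
        var = target.var(dim=-1, keepdim=True)
        target = (target - mean) / (var + eps) ** 0.5
    per_patch = ((pred - target) ** 2).mean(dim=-1)  # (B, N): (1/D) * ||x_hat - x||^2
    return (per_patch * mask).sum() / mask.sum()     # average over masked patches only

torch.manual_seed(0)
target = torch.rand(2, 4, 12)
mask = torch.tensor([[1., 1., 1., 0.], [0., 1., 1., 1.]])
pred = target.clone()
pred[0, 3] += 5.0                       # error on a VISIBLE patch: ignored by the loss
print(masked_mse(pred, target, mask))   # tensor(0.)
```

Errors on visible patches contribute nothing, so the model is never rewarded for simply copying its input.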
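| Masking ratio | Encoder cost (training speed) | Representation quality |
|---|---|---|
| 15% (BERT-like) | Slow: encoder processes 85% of patches | Lower: task is too easy, solvable by copying from neighbors |
| 50% | Medium: encoder processes 50% of patches | Good |
| 75% | Fast: only 25% of patches pass through the encoder | Best: forces semantic understanding |
| 90% | Very fast: encoder processes 10% of patches | Degrades: task becomes too hard |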
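To put rough numbers on the speed column, the sketch below counts the tokens the encoder sees for a 224×224 image with 16×16 patches (196 patches; at 75% masking the encoder sees 49). It is a crude FLOP proxy under stated assumptions: per-layer linear (QKV/MLP) cost is taken as proportional to the token count and attention-score cost to its square, while the CLS token, the patch embedding, and the decoder (which still processes all positions) are ignored. Real end-to-end speedups are therefore smaller. `encoder_cost` is a hypothetical helper.

```python
def encoder_cost(mask_ratio, img_size=224, patch_size=16):
    """Rough relative encoder cost vs. encoding all patches (proxy, not a FLOP count)."""
    n = (img_size // patch_size) ** 2   # 196 patches for 16x16 patches at 224 px
    keep = int(n * (1 - mask_ratio))    # tokens the encoder actually sees
    linear = keep / n                   # QKV/MLP work scales with token count
    attn = (keep / n) ** 2              # attention-score work scales with tokens^2
    return n, keep, linear, attn

for r in (0.15, 0.50, 0.75, 0.90):
    n, keep, lin, att = encoder_cost(r)
    print(f"mask {r:.0%}: encoder sees {keep}/{n} tokens, "
          f"linear cost x{lin:.2f}, attention-score cost x{att:.3f}")
```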
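At 75%, large regions are missing and many masked patches have no visible neighbor at all, so the model cannot simply copy adjacent pixels; it has to infer what the object is in order to fill in distant patches. This pushes it toward learning high-level semantics.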
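The "can't just copy neighbors" argument can be checked with a quick simulation. A minimal sketch, assuming a 14×14 patch grid (ViT-B/16 at 224 px) and defining "neighbor" as the 4 adjacent patches. That definition is a simplification, since attention can read any visible patch. `isolated_fraction` is a hypothetical helper that estimates how often a masked patch has no visible neighbor.

```python
import torch
import torch.nn.functional as F

def isolated_fraction(mask_ratio, grid=14, trials=200, seed=0):
    """Fraction of masked patches with NO visible 4-neighbor (up/down/left/right)."""
    g = torch.Generator().manual_seed(seed)
    n = grid * grid
    keep = int(n * (1 - mask_ratio))
    isolated, masked_total = 0, 0
    for _ in range(trials):
        ids = torch.rand(n, generator=g).argsort()[:keep]   # random visible subset
        visible = torch.zeros(n, dtype=torch.bool)
        visible[ids] = True
        v = visible.view(grid, grid)
        p = F.pad(v.float(), (1, 1, 1, 1))   # zero padding: border cells have fewer neighbors
        has_vis_neighbor = (p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]) > 0
        masked = ~v
        isolated += (masked & ~has_vis_neighbor).sum().item()
        masked_total += masked.sum().item()
    return isolated / masked_total

for r in (0.15, 0.50, 0.75, 0.90):
    print(f"mask {r:.0%}: {isolated_fraction(r):.3f} of masked patches have no visible neighbor")
```

At low ratios almost every masked patch sits next to a visible one; at 75% a substantial share do not.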
BERT and MAE rest on the same principle:

- BERT: mask words → the model learns language semantics by predicting them.
- MAE: mask patches → the model learns visual semantics by reconstructing them.
- Key insight: what you can predict from partial information reveals what you understand.
| Property | MAE | DINO |
|---|---|---|
| Task | Pixel reconstruction | Feature matching (student matches an EMA teacher's outputs) |
| Supervision | Self (reconstruct masked patches) | Self (student-teacher) |
| What's learned | Low-level + high-level | Primarily high-level |
| Best for | Fine-tuning (e.g., classification) | Frozen-feature transfer (k-NN, linear probe) |
| Training speed | 3× or more faster than encoding all patches (75% mask) | Slower (EMA teacher, multi-crop) |
```python
import torch
import torch.nn as nn
from einops import rearrange


class TransformerBlock(nn.Module):
    """Pre-norm ViT block: x + MHSA(LN(x)), then x + MLP(LN(x)).

    Assumed minimal definition; any standard ViT block with the same
    (dim, num_heads) constructor works here.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # self-attention + residual
        x = x + self.mlp(self.norm2(x))                     # MLP + residual
        return x


class MAE(nn.Module):
    """Masked Autoencoder for self-supervised ViT pretraining."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        encoder_depth=12,
        encoder_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_heads=16,
        mask_ratio=0.75,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        n_patches = (img_size // patch_size) ** 2
        patch_dim = in_channels * patch_size ** 2

        # --- Encoder ---
        self.patch_embed = nn.Linear(patch_dim, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.encoder_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, encoder_heads)
            for _ in range(encoder_depth)
        ])
        self.encoder_norm = nn.LayerNorm(embed_dim)

        # --- Decoder ---
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, n_patches + 1, decoder_embed_dim)
        )
        self.decoder_blocks = nn.ModuleList([
            TransformerBlock(decoder_embed_dim, decoder_heads)
            for _ in range(decoder_depth)
        ])
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_dim)

        # Initialize
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.mask_token, std=0.02)

    def patchify(self, imgs):
        """Convert images (B, C, H, W) to patch sequences (B, N, p*p*C)."""
        p = self.patch_size
        x = rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        return x

    def random_masking(self, x, mask_ratio):
        """Random masking: keep (1-mask_ratio) patches, mask the rest."""
        B, N, D = x.shape
        keep = int(N * (1 - mask_ratio))

        # Random permutation per sample
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = noise.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1)

        # Keep first `keep` indices
        ids_keep = ids_shuffle[:, :keep]
        x_visible = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))

        # Binary mask: 0 = keep, 1 = masked
        mask = torch.ones(B, N, device=x.device)
        mask[:, :keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)
        return x_visible, mask, ids_restore

    def forward_encoder(self, x):
        # Patchify and embed
        patches = self.patchify(x)
        x = self.patch_embed(patches)
        x = x + self.pos_embed[:, 1:]  # no CLS position yet

        # Mask
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # Prepend CLS
        cls = self.cls_token + self.pos_embed[:, :1]
        cls = cls.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)

        # Encode (only visible patches!)
        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)
        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        # Project to decoder dimension
        x = self.decoder_embed(x)

        # Insert mask tokens: N + 1 - (keep + 1) = number of masked patches
        B = x.shape[0]
        mask_tokens = self.mask_token.repeat(
            B, ids_restore.shape[1] + 1 - x.shape[1], 1
        )

        # Unshuffle: put visible and mask tokens back in original order
        x_no_cls = torch.cat([x[:, 1:], mask_tokens], dim=1)
        x_no_cls = torch.gather(
            x_no_cls, dim=1,
            index=ids_restore.unsqueeze(-1).expand(-1, -1, x.shape[2])
        )
        x = torch.cat([x[:, :1], x_no_cls], dim=1)  # re-add CLS

        # Add decoder positional embeddings
        x = x + self.decoder_pos_embed

        # Decode
        for block in self.decoder_blocks:
            x = block(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        return x[:, 1:]  # remove CLS

    def forward(self, imgs):
        latent, mask, ids_restore = self.forward_encoder(imgs)
        pred = self.forward_decoder(latent, ids_restore)
        target = self.patchify(imgs)

        # MSE loss on masked patches only
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)                 # per-patch MSE
        loss = (loss * mask).sum() / mask.sum()  # masked patches only
        return loss, pred, mask
```
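The argsort-based shuffle/unshuffle in `random_masking` and `forward_decoder` is the easiest part to get wrong. This standalone sketch replays that index bookkeeping on a toy tensor, with zeros standing in for the learnable mask token (an illustrative substitution), so you can confirm that every visible patch returns to its original position and that the binary mask marks exactly the remaining slots.

```python
import torch

torch.manual_seed(0)
B, N, D, keep = 2, 8, 3, 2
x = torch.arange(B * N * D, dtype=torch.float32).view(B, N, D)  # distinct values per patch

noise = torch.rand(B, N)
ids_shuffle = noise.argsort(dim=1)         # random permutation per sample
ids_restore = ids_shuffle.argsort(dim=1)   # its inverse permutation

# Encoder side: keep the first `keep` patches of the shuffled order.
x_visible = torch.gather(x, 1, ids_shuffle[:, :keep, None].expand(-1, -1, D))

# Decoder side: append placeholders for the masked slots, then unshuffle.
placeholders = torch.zeros(B, N - keep, D)
x_full = torch.cat([x_visible, placeholders], dim=1)
x_full = torch.gather(x_full, 1, ids_restore[:, :, None].expand(-1, -1, D))

# Binary mask in original patch order: 0 = visible, 1 = masked.
mask = torch.ones(B, N)
mask[:, :keep] = 0
mask = torch.gather(mask, 1, ids_restore)

visible = mask == 0
print("visible patches restored in place:", torch.equal(x_full[visible], x[visible]))
```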
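Masking ratio experiment: Train a small MAE (embed_dim=256, encoder_depth=4) on CIFAR-10 with mask ratios of 25%, 50%, 75%, and 90%. CIFAR-10 images are 32×32, so set img_size=32 and use a small patch size such as patch_size=4 (64 patches); also pick encoder_heads that divide 256 (e.g., 8). Compare reconstruction quality and downstream linear-probe accuracy.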
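For the linear probe, the encoder stays frozen and only a single linear layer is trained on its features. A minimal sketch, assuming features have already been extracted (for example, the CLS token from `forward_encoder` with `mask_ratio` set to 0 so no patches are dropped; that extraction choice is an assumption). `linear_probe` is a hypothetical helper trained with full-batch SGD, and the toy Gaussian blobs stand in for real encoder features only to check the plumbing.

```python
import torch
import torch.nn as nn

def linear_probe(train_feats, train_labels, test_feats, test_labels,
                 num_classes, epochs=100, lr=0.1):
    """Fit one linear layer on frozen features; return test accuracy."""
    # Standardize with train statistics (a common but optional choice).
    mean, std = train_feats.mean(0), train_feats.std(0) + 1e-6
    tr, te = (train_feats - mean) / std, (test_feats - mean) / std
    head = nn.Linear(tr.shape[1], num_classes)
    opt = torch.optim.SGD(head.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = nn.functional.cross_entropy(head(tr), train_labels)
        loss.backward()
        opt.step()
    with torch.no_grad():
        pred = head(te).argmax(dim=1)
    return (pred == test_labels).float().mean().item()

# Toy stand-in for encoder features: 3 well-separated Gaussian blobs in 16-D.
torch.manual_seed(0)
centers = 3 * torch.randn(3, 16)

def make_blobs(n):
    y = torch.randint(0, 3, (n,))
    return centers[y] + torch.randn(n, 16), y

x_tr, y_tr = make_blobs(300)
x_te, y_te = make_blobs(100)
acc = linear_probe(x_tr, y_tr, x_te, y_te, num_classes=3)
print(f"linear-probe accuracy: {acc:.2f}")
```

Run the same probe on features from each masking ratio; only the feature extractor should change between runs.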
Reconstruction visualization: Train MAE for 50 epochs on a small dataset. Visualize: original → masked → reconstructed for 8 images. How well does the model fill in missing patches?
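To visualize reconstructions you need the inverse of `patchify`. A sketch in plain PyTorch, assuming the same `(p1 p2 c)` patch layout as the einops pattern in `MAE.patchify`; `patchify_imgs`, `unpatchify`, and `visualization_triplet` are hypothetical helpers. The third output pastes predicted pixels into masked positions only, which is usually the most readable view. If you train with normalized-pixel targets, de-normalize `pred` with each original patch's mean and standard deviation first.

```python
import torch

def patchify_imgs(imgs, p):
    """(B, C, H, W) -> (B, N, p*p*C), layout 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)'."""
    B, C, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(B, C, h, p, w, p).permute(0, 2, 4, 3, 5, 1)  # b, h, w, p1, p2, c
    return x.reshape(B, h * w, p * p * C)

def unpatchify(patches, p, h, w, channels=3):
    """(B, h*w, p*p*C) -> (B, C, h*p, w*p); inverse of patchify_imgs."""
    B = patches.shape[0]
    x = patches.reshape(B, h, w, p, p, channels).permute(0, 5, 1, 3, 2, 4)  # b, c, h, p1, w, p2
    return x.reshape(B, channels, h * p, w * p)

def visualization_triplet(imgs, pred, mask, p):
    """Return (masked input, raw reconstruction, visible + predicted paste) as images."""
    B, C, H, W = imgs.shape
    h, w = H // p, W // p
    target = patchify_imgs(imgs, p)
    m = mask.unsqueeze(-1)                                        # (B, N, 1); 1 = masked
    masked_input = unpatchify(target * (1 - m), p, h, w, C)       # masked patches blacked out
    reconstruction = unpatchify(pred, p, h, w, C)                 # decoder output everywhere
    pasted = unpatchify(target * (1 - m) + pred * m, p, h, w, C)  # keep visible, fill masked
    return masked_input, reconstruction, pasted
```

Each returned image can be shown with matplotlib via `plt.imshow(img.permute(1, 2, 0).clamp(0, 1))`.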
MAE vs DINO downstream: Compare MAE and DINOv2 features on the same downstream task (k-NN on CIFAR-10 or linear probe). Which produces better features for transfer? Why?
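For the k-NN comparison, one common protocol is cosine-similarity k-NN with temperature-weighted votes, in the style of DINO's weighted k-NN evaluation. A minimal sketch: `knn_classify` is a hypothetical helper, and k=20, T=0.07 are illustrative defaults, not tuned values. Extract features from both models the same way (e.g., CLS token or mean of patch tokens) so the comparison is fair; the toy blobs below only check the plumbing.

```python
import torch
import torch.nn.functional as F

def knn_classify(train_feats, train_labels, test_feats, num_classes, k=20, T=0.07):
    """Cosine-similarity k-NN; each neighbor votes with weight exp(sim / T)."""
    train = F.normalize(train_feats, dim=1)
    test = F.normalize(test_feats, dim=1)
    sims = test @ train.T                         # (n_test, n_train) cosine similarities
    topk_sims, topk_idx = sims.topk(k, dim=1)     # k nearest neighbors per test point
    weights = (topk_sims / T).exp()               # closer neighbors get larger votes
    votes = torch.zeros(test.shape[0], num_classes)
    votes.scatter_add_(1, train_labels[topk_idx], weights)
    return votes.argmax(dim=1)

# Toy stand-in for frozen features: 3 Gaussian blobs in 16-D.
torch.manual_seed(0)
centers = 3 * torch.randn(3, 16)

def make_blobs(n):
    y = torch.randint(0, 3, (n,))
    return centers[y] + torch.randn(n, 16), y

x_tr, y_tr = make_blobs(300)
x_te, y_te = make_blobs(100)
pred = knn_classify(x_tr, y_tr, x_te, num_classes=3)
acc = (pred == y_te).float().mean().item()
print(f"k-NN accuracy: {acc:.2f}")
```

Because k-NN has no trainable parameters, it measures how well the raw feature geometry separates classes, which is where DINO-style features are expected to shine.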
MAE completes the self-supervised vision picture. You've now seen two paradigms: matching-based (DINO) and reconstruction-based (MAE). Both produce features without labels. Next week: extending vision to 3D, depth, and video.