21. Vision Transformer (ViT)
Learning Objectives¶
- Understanding Vision Transformer architecture
- Patch Embedding principles
- CLS token and Position Embedding
- ViT variants (DeiT, Swin Transformer)
- PyTorch implementation and applications
1. Vision Transformer Overview¶
Core Idea¶
Traditional CNN: local features → global features (hierarchical)
ViT: Convert image to patch sequence → process with Transformer
Image (224×224) → 196 patches of 16×16 → Transformer Encoder
Why Transformer for Vision?¶
1. Self-Attention Advantages
- Captures long-range dependencies
- Considers global context
2. Scalability
- Surpasses CNN on large datasets
- Easy to scale
3. Architecture Unification
- Can unify Vision + Language
- Favorable for multimodal learning
2. ViT Architecture¶
Overall Structure¶
Input Image (224×224×3)
↓
[Patch Embedding] → 196 patch vectors (each 768-dim)
↓
[Add CLS Token] → 197 tokens
↓
[Add Position Embedding]
↓
[Transformer Encoder × L layers]
↓
[Extract CLS Token output]
↓
[MLP Head] → Classification result
Formula Summary¶
# Input
x ∈ R^(H×W×C)  # e.g., 224×224×3
# Patch splitting
P = patch_size  # e.g., 16
N = (H/P) × (W/P)  # number of patches: 14 × 14 = 196
# Patch Embedding
x_p ∈ R^(N×(P²·C))  # 196×768 (16×16×3 = 768)
z_0 = [x_class; x_p·E] + E_pos  # E: projection matrix, E_pos: position embedding for N+1 tokens
# Transformer (for l = 1, ..., L)
z'_l = MSA(LN(z_{l-1})) + z_{l-1}  # Multi-Head Self-Attention
z_l = MLP(LN(z'_l)) + z'_l  # Feed Forward
# Output
y = LN(z_L^0)  # Final representation of CLS token
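The shape arithmetic in the formulas can be checked with a few lines of plain Python. This is a minimal sketch; the helper name `vit_sequence_shapes` is hypothetical and not part of any library.

```python
def vit_sequence_shapes(height=224, width=224, channels=3, patch_size=16):
    """Return (num_patches, patch_dim, seq_len) for a ViT input.

    Hypothetical helper for illustration only.
    """
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch size")
    num_patches = (height // patch_size) * (width // patch_size)  # N
    patch_dim = patch_size * patch_size * channels                # P²·C
    seq_len = num_patches + 1                                     # +1 for the CLS token
    return num_patches, patch_dim, seq_len


print(vit_sequence_shapes())          # (196, 768, 197)
print(vit_sequence_shapes(384, 384))  # (576, 768, 577)
```

Raising the resolution from 224 to 384 roughly triples the sequence length, which matters because self-attention cost grows quadratically with it.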
3. Patch Embedding¶
Concept¶
# Split image into patches
# (B, 3, 224, 224) → (B, 196, 768)
# Method 1: reshape + permute
# (a plain image.reshape(B, N, P*P*C) would mix pixels from different patches)
patches = image.reshape(B, C, H // P, P, W // P, P)
patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, N, C * P * P)
# Method 2: Conv2d (more efficient; also applies the learned projection)
# stride=kernel_size for non-overlapping patches
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
patches = conv(image)  # (B, 768, 14, 14)
patches = patches.flatten(2).transpose(1, 2)  # (B, 196, 768)
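To see why the strided convolution is a valid patch embedding, the sketch below compares it with explicit patch extraction followed by a linear map that reuses the convolution's weights. The tensor sizes are small illustrative values, not the ViT-B defaults.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 16  # small illustrative sizes
N = (H // P) * (W // P)

image = torch.randn(B, C, H, W)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)

with torch.no_grad():
    # Path A: strided convolution, then flatten the spatial grid into a sequence
    out_conv = conv(image).flatten(2).transpose(1, 2)        # (B, N, D)

    # Path B: cut out patches explicitly, then apply the same weights as a linear map
    patches = image.reshape(B, C, H // P, P, W // P, P)
    patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, N, C * P * P)
    weight = conv.weight.reshape(D, C * P * P)               # (D, C·P·P)
    out_linear = patches @ weight.T + conv.bias              # (B, N, D)

patch_paths_match = torch.allclose(out_conv, out_linear, atol=1e-4)
print(out_conv.shape, patch_paths_match)
```

Both paths agree up to floating-point error, so the Conv2d form is simply the patch split and the projection matrix E fused into one operation.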
PyTorch Implementation¶
class PatchEmbedding(nn.Module):
    """Patch Embedding Layer (⭐⭐)"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # Extract patches + embed with Conv2d
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.projection(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)        # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (B, num_patches, embed_dim)
        return x
4. CLS Token and Position Embedding¶
CLS Token¶
# Concept borrowed from BERT
# Special token that learns representation of entire image
class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
# Broadcast to batch
cls_tokens = class_token.expand(batch_size, -1, -1) # (B, 1, D)
# Concatenate before patch embeddings
x = torch.cat([cls_tokens, patch_embeddings], dim=1) # (B, N+1, D)
Position Embedding¶
# Provide patch position information (self-attention alone is order-agnostic)
class PositionEmbedding(nn.Module):
    """Learnable Position Embedding (⭐⭐)"""
    def __init__(self, num_patches, embed_dim):
        super().__init__()
        # +1 for CLS token
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim)
        )

    def forward(self, x):
        return x + self.pos_embedding
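The reason a position embedding is needed can be demonstrated directly: without it, self-attention is permutation-equivariant, so reordering the patches only reorders the outputs. The sketch below uses PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the attention layer; the sizes and the reversal permutation are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True).eval()

tokens = torch.randn(1, 10, 32)      # (B, N, D), no position information added
perm = torch.arange(10).flip(0)      # reverse the token order

with torch.no_grad():
    out, _ = attn(tokens, tokens, tokens)
    shuffled = tokens[:, perm]
    out_perm, _ = attn(shuffled, shuffled, shuffled)

# Without positions: shuffling the input only shuffles the output the same way
is_equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-4)

# With positions: the same shuffle now changes the result
pos = torch.randn(1, 10, 32)
with torch.no_grad():
    with_pos = tokens + pos
    out_pos, _ = attn(with_pos, with_pos, with_pos)
    shuffled_pos = tokens[:, perm] + pos
    out_pos_perm, _ = attn(shuffled_pos, shuffled_pos, shuffled_pos)
order_matters = not torch.allclose(out_pos[:, perm], out_pos_perm, atol=1e-4)

print(is_equivariant, order_matters)
```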
Position Embedding Visualization¶
def visualize_position_embedding(pos_embed, img_size=224, patch_size=16):
    """Visualize position embedding similarity (⭐⭐)"""
    # pos_embed: (1, N+1, D)
    # Exclude CLS token
    pos_embed = pos_embed[0, 1:]  # (N, D)
    # Cosine similarity matrix (normalize first so that embedding norms do not dominate)
    pos_embed = torch.nn.functional.normalize(pos_embed, dim=-1)
    similarity = torch.mm(pos_embed, pos_embed.T)  # (N, N)
    # Similarity with a specific patch (here: the center of the patch grid)
    grid_size = img_size // patch_size  # patches per side, e.g. 14
    center_idx = grid_size * (grid_size // 2) + (grid_size // 2)
    center_sim = similarity[center_idx].reshape(grid_size, grid_size)
    return center_sim  # Similarity map with center patch
5. Complete Vision Transformer Implementation¶
Basic ViT¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention (⭐⭐⭐)"""
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        # QKV computation
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        # Merge heads and project
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
class MLP(nn.Module):
    """MLP Block (⭐⭐)"""
    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
class TransformerBlock(nn.Module):
    """Transformer Encoder Block (⭐⭐⭐)"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        # Pre-norm residual connections
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) (⭐⭐⭐⭐)"""
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0
    ):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Patch Embedding
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        # CLS Token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # Position Embedding
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim)
        )
        self.dropout = nn.Dropout(dropout)
        # Transformer Blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Classification Head
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        # Patch Embedding
        x = self.patch_embed(x)  # (B, N, D)
        # Add CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)
        # Add Position Embedding
        x = x + self.pos_embed
        x = self.dropout(x)
        # Transformer Blocks
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Extract and classify CLS token
        cls_output = x[:, 0]
        return self.head(cls_output)
ViT Model Variants¶
# ViT-Base (ViT-B/16)
vit_base = VisionTransformer(
    img_size=224, patch_size=16,
    embed_dim=768, depth=12, num_heads=12
)
# ViT-Large (ViT-L/16)
vit_large = VisionTransformer(
    img_size=224, patch_size=16,
    embed_dim=1024, depth=24, num_heads=16
)
# ViT-Huge (ViT-H/14)
vit_huge = VisionTransformer(
    img_size=224, patch_size=14,
    embed_dim=1280, depth=32, num_heads=16
)
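To get a feel for how the variants differ in size, the parameter count of the `VisionTransformer` class above can be worked out by hand. The sketch below mirrors that class layer by layer in plain Python; `vit_param_count` is a hypothetical helper, and the totals assume a 1000-class head and biases on every linear layer, as in the implementation above.

```python
def vit_param_count(img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                    embed_dim=768, depth=12, mlp_ratio=4.0):
    """Count learnable parameters of the VisionTransformer defined above (illustrative)."""
    num_patches = (img_size // patch_size) ** 2
    hidden_dim = int(embed_dim * mlp_ratio)

    # Embedding stage
    patch_embed = in_channels * patch_size ** 2 * embed_dim + embed_dim  # Conv2d weight + bias
    cls_token = embed_dim
    pos_embed = (num_patches + 1) * embed_dim

    # One Transformer block
    layer_norms = 2 * (2 * embed_dim)                        # two LayerNorms (weight + bias)
    qkv = embed_dim * 3 * embed_dim + 3 * embed_dim
    proj = embed_dim * embed_dim + embed_dim
    mlp = (embed_dim * hidden_dim + hidden_dim) + (hidden_dim * embed_dim + embed_dim)
    block = layer_norms + qkv + proj + mlp

    # Final norm and classification head
    final_norm = 2 * embed_dim
    head = embed_dim * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + depth * block + final_norm + head


print(f"ViT-B/16: {vit_param_count() / 1e6:.1f}M")                             # ViT-B/16: 86.6M
print(f"ViT-L/16: {vit_param_count(embed_dim=1024, depth=24) / 1e6:.1f}M")     # ViT-L/16: 304.3M
```

The number of heads does not appear because splitting `embed_dim` across heads does not change the size of the QKV and output projections.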
6. DeiT (Data-efficient Image Transformer)¶
Key Improvements¶
Problem: ViT requires large-scale pretraining data (e.g., JFT-300M)
Solution: Knowledge distillation + strong data augmentation for ImageNet-only training
1. Distillation Token: Learn knowledge from CNN teacher
2. Strong Data Augmentation
3. Regularization (Stochastic Depth, Label Smoothing)
Distillation Token¶
class DeiT(nn.Module):
    """Data-efficient Image Transformer (⭐⭐⭐⭐)"""
    def __init__(self, img_size=224, patch_size=16, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        # CLS Token + Distillation Token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # Position Embedding (+2 for CLS and DIST)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 2, embed_dim)
        )
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Two heads
        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)  # (B, N+2, D)
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Use both CLS and DIST tokens
        cls_output = self.head(x[:, 0])
        dist_output = self.head_dist(x[:, 1])
        if self.training:
            return cls_output, dist_output
        else:
            # Average during inference
            return (cls_output + dist_output) / 2
DeiT Training¶
def train_deit_with_distillation(student, teacher, dataloader, epochs=100):
    """DeiT Knowledge Distillation Training (⭐⭐⭐)"""
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_dist = nn.CrossEntropyLoss()
    teacher.eval()
    student.train()  # the student returns (cls_output, dist_output) only in training mode
    for epoch in range(epochs):
        for images, labels in dataloader:
            # Teacher prediction (turned into hard labels with argmax below)
            with torch.no_grad():
                teacher_output = teacher(images)
            # Student predictions
            cls_output, dist_output = student(images)
            # Losses: CLS head learns from ground truth, DIST head from the teacher
            loss_cls = criterion_ce(cls_output, labels)
            loss_dist = criterion_dist(dist_output, teacher_output.argmax(dim=1))
            loss = 0.5 * loss_cls + 0.5 * loss_dist
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
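The loop above uses hard-label distillation (the teacher's argmax becomes the target). For comparison, here is a minimal sketch of both the hard variant and a soft variant based on temperature-scaled KL divergence. The function names and the defaults `tau=3.0` and `lam=0.1` are illustrative assumptions, not values taken from this document.

```python
import torch
import torch.nn.functional as F


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels):
    """CLS head vs. ground truth, DIST head vs. the teacher's argmax prediction."""
    teacher_labels = teacher_logits.argmax(dim=1)
    return (0.5 * F.cross_entropy(cls_logits, labels)
            + 0.5 * F.cross_entropy(dist_logits, teacher_labels))


def soft_distillation_loss(cls_logits, dist_logits, teacher_logits, labels,
                           tau=3.0, lam=0.1):
    """CLS head vs. ground truth, DIST head vs. the teacher's softened distribution."""
    ce = F.cross_entropy(cls_logits, labels)
    kl = F.kl_div(
        F.log_softmax(dist_logits / tau, dim=1),     # student log-probabilities
        F.log_softmax(teacher_logits / tau, dim=1),  # teacher log-probabilities
        reduction="batchmean",
        log_target=True,
    ) * tau ** 2                                     # rescale gradients for temperature
    return (1 - lam) * ce + lam * kl


torch.manual_seed(0)
cls_logits, dist_logits = torch.randn(4, 10), torch.randn(4, 10)
teacher_logits, labels = torch.randn(4, 10), torch.randint(0, 10, (4,))
print(hard_distillation_loss(cls_logits, dist_logits, teacher_logits, labels))
print(soft_distillation_loss(cls_logits, dist_logits, teacher_logits, labels))
```

The hard variant discards the teacher's confidence, while the soft variant keeps it at the cost of two extra hyperparameters.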
7. Swin Transformer¶
Key Idea¶
Problem: ViT's self-attention is O(n²) in the number of patches n → difficult to process high-resolution images
Solution: Hierarchical structure + Shifted Window Attention
Features:
1. Window Attention: attention only within local windows
2. Shifted Windows: information exchange between windows
3. Hierarchical structure: progressive feature map resolution reduction
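The saving from window attention can be made concrete with the cost estimates given in the Swin Transformer paper: global attention over an h×w token grid costs about 4hwC² + 2(hw)²C operations, while attention inside M×M windows costs about 4hwC² + 2M²hwC. The sketch below evaluates both; the grid sizes and C=96 are illustrative choices.

```python
def msa_flops(h, w, C):
    """Approximate cost of global multi-head self-attention on an h×w token grid."""
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M=7):
    """Approximate cost of window attention with M×M windows on the same grid."""
    return 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C


for side in (56, 112, 224):
    ratio = msa_flops(side, side, 96) / wmsa_flops(side, side, 96)
    print(f"{side}x{side} tokens: global attention is about {ratio:.0f}x more expensive")
```

The window cost grows linearly with the number of tokens, so doubling the grid side multiplies it by four, whereas the quadratic term of global attention grows by a factor of sixteen.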
Window Attention¶
def window_partition(x, window_size):
    """Partition image into windows (⭐⭐⭐)"""
    # x: (B, H, W, C) → (B * num_windows, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size,
               W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    """Merge windows back to image (⭐⭐⭐)"""
    # windows: (B * num_windows, window_size, window_size, C) → (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size,
                     window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x
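A quick round-trip check helps confirm that the two helpers are exact inverses and shows the window ordering. The sketch repeats compact versions of both functions so that it runs on its own; the 8×8 feature map and window size 4 are arbitrary small values.

```python
import torch


def window_partition(x, window_size):
    # (B, H, W, C) -> (B * num_windows, window_size, window_size, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    # (B * num_windows, window_size, window_size, C) -> (B, H, W, C)
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)


x = torch.arange(2 * 8 * 8 * 3, dtype=torch.float32).view(2, 8, 8, 3)
windows = window_partition(x, window_size=4)
restored = window_reverse(windows, window_size=4, H=8, W=8)

print(windows.shape)                            # torch.Size([8, 4, 4, 3])
print(torch.equal(windows[1], x[0, :4, 4:]))    # True: windows are ordered row by row
print(torch.equal(restored, x))                 # True: partition followed by reverse is lossless
```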
class WindowAttention(nn.Module):
    """Window-based Multi-Head Self-Attention (⭐⭐⭐⭐)"""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # Relative position bias: one learnable value per head for every
        # possible (row offset, column offset) pair inside a window
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Create relative position index
        self._create_relative_position_index()

    def _create_relative_position_index(self):
        coords = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords, coords], indexing='ij'))  # (2, ws, ws)
        coords_flatten = coords.flatten(1)  # (2, N)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, N, N)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (N, N, 2)
        # Shift offsets to start at 0, then fold (row, col) into a single table index
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)  # (N, N)
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        # x: (B * num_windows, N, C) with N = window_size ** 2
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # mask must be broadcastable to attn's shape (B_, num_heads, N, N)
            attn = attn + mask
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
Shifted Window¶
class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block with (Shifted) Window Attention (⭐⭐⭐⭐)"""
    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x, H, W):
        B, L, C = x.shape  # L must equal H * W
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # Window partition
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # Window attention
        # Simplified: no attention mask is passed here, so with shift_size > 0 the
        # tokens wrapped around by torch.roll can attend to each other
        attn_windows = self.attn(x_windows)
        # Window reverse
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, L, C)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x
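The block above calls the attention layer without a mask, so after the cyclic shift a window can contain tokens that were not neighbours in the original image. A common remedy, following the approach of the reference Swin implementation, is an additive mask that blocks attention between tokens from different regions. The sketch below builds such a mask; the function name `shifted_window_mask` is hypothetical, and it assumes `0 < shift_size < window_size` with H and W divisible by `window_size`.

```python
import torch


def shifted_window_mask(H, W, window_size, shift_size):
    """Additive attention mask of shape (num_windows, N, N), with N = window_size ** 2.

    Hypothetical helper; assumes 0 < shift_size < window_size.
    """
    # Label each position with the id of the region it belongs to after the cyclic shift
    img_mask = torch.zeros(1, H, W, 1)
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    region = 0
    for h in slices:
        for w in slices:
            img_mask[:, h, w, :] = region
            region += 1
    # Same layout as window_partition, inlined to keep the sketch self-contained
    m = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
    m = m.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)  # (nW, N)
    # Token pairs with different region ids get a large negative bias before softmax
    diff = m.unsqueeze(1) - m.unsqueeze(2)                                  # (nW, N, N)
    return torch.zeros_like(diff).masked_fill(diff != 0, -100.0)


mask = shifted_window_mask(H=8, W=8, window_size=4, shift_size=2)
print(mask.shape)               # torch.Size([4, 16, 16])
print(bool(mask[0].any()))      # False: the top-left window holds a single region
print(bool(mask[3].any()))      # True: the bottom-right window mixes several regions
```

To apply it, the attention scores of shape `(B * nW, heads, N, N)` can be viewed as `(B, nW, heads, N, N)` and summed with `mask.unsqueeze(1).unsqueeze(0)`, so that the same mask is shared across the batch and all heads.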
8. Using Pretrained Models¶
Using torchvision¶
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, vit_l_16, swin_t, swin_s
# ViT-B/16 (pretrained)
model = vit_b_16(weights='IMAGENET1K_V1')
# Use as feature extractor
model.heads = nn.Identity()
features = model(image)  # (B, 768)
# Fine-tuning
model = vit_b_16(weights='IMAGENET1K_V1')
model.heads = nn.Linear(768, num_classes)
# Differential learning rates: everything outside the head counts as backbone
# (model.encoder alone would leave out the patch projection and the class token)
backbone_params = [p for n, p in model.named_parameters() if not n.startswith('heads')]
params = [
    {'params': backbone_params, 'lr': 1e-5},  # backbone
    {'params': model.heads.parameters(), 'lr': 1e-3}  # head
]
optimizer = torch.optim.AdamW(params)
Using timm Library¶
import timm
# List available ViT models
vit_models = timm.list_models('vit*', pretrained=True)
print(f"Available ViT models: {len(vit_models)}")
# Load model
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# Custom classification head
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=10  # Automatically replace head
)
# DeiT model
deit_model = timm.create_model('deit_base_patch16_224', pretrained=True)
# Swin Transformer
swin_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
9. Practical Fine-tuning¶
CIFAR-10 Fine-tuning¶
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
def finetune_vit_cifar10(epochs=10):
    """ViT CIFAR-10 Fine-tuning (⭐⭐⭐)"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Data preprocessing (resize 32×32 CIFAR images to the ViT input size)
    # Note: ImageNet mean/std are used here; check model.pretrained_cfg, since
    # some timm ViT weights expect a different normalization
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    # Dataset
    train_data = datasets.CIFAR10('data', train=True, download=True, transform=transform_train)
    test_data = datasets.CIFAR10('data', train=False, download=True, transform=transform_test)
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)
    # Model
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=10)
    model = model.to(device)
    # Optimizer (differential learning rates)
    backbone_params = [p for n, p in model.named_parameters() if 'head' not in n]
    head_params = [p for n, p in model.named_parameters() if 'head' in n]
    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': 1e-5},
        {'params': head_params, 'lr': 1e-3}
    ], weight_decay=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    # Training
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Evaluation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        accuracy = 100. * correct / total
        print(f'Epoch {epoch+1}/{epochs}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy:.2f}%')
        scheduler.step()
    return model
10. ViT vs CNN Comparison¶
Characteristics Comparison¶
| Feature | CNN | ViT |
|---|---|---|
| Inductive bias | Strong (locality, translation equivariance) | Weak (mostly learned from data) |
| Data requirement | Less | More |
| Computational complexity | O(n) in the number of pixels | O(n²) in the number of patches |
| Long-range dependencies | Difficult | Easy |
| Interpretability | Filter visualization | Attention visualization |
Usage Guidelines¶
Prefer CNN:
- Small datasets
- Limited computational resources
- Real-time inference required
Prefer ViT:
- Large datasets or pretrained models available
- Tasks requiring global context
- Planning multimodal learning
Summary¶
Key Concepts¶
- Patch Embedding: Convert image to patch sequence
- CLS Token: Learn global image representation
- Position Embedding: Provide patch position information
- DeiT: Data-efficient training (knowledge distillation)
- Swin: Window-based efficient attention
Model Selection Guide¶
General classification: ViT-B/16 or DeiT
High resolution: Swin Transformer
Limited resources: ViT-Small, DeiT-Tiny
Best performance: ViT-Large, Swin-Large
PyTorch Practical Tips¶
# 1. Recommend using timm
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# 2. Differential learning rates recommended (lower for the pretrained backbone, higher for the new head)
optimizer = torch.optim.AdamW([
    {'params': backbone_params, 'lr': 1e-5},
    {'params': head_params, 'lr': 1e-3}
])
# 3. Pay attention to input size (224, 384, etc)
# 4. Use strong data augmentation
References¶
- ViT Original: https://arxiv.org/abs/2010.11929
- DeiT: https://arxiv.org/abs/2012.12877
- Swin Transformer: https://arxiv.org/abs/2103.14030
- timm Library: https://github.com/huggingface/pytorch-image-models