1"""
219. Vision Transformer (ViT) Implementation
3
4A minimal but complete implementation of Vision Transformer:
5- Patch Embedding
6- Position Embedding
7- Multi-Head Self-Attention
8- Transformer Encoder
9- Classification Head
10"""
11
12import torch
13import torch.nn as nn
14import torch.nn.functional as F
15from torch.utils.data import DataLoader
16from torchvision import datasets, transforms
17import matplotlib.pyplot as plt
18import numpy as np
19import math
20
print("=" * 60)
print("Vision Transformer (ViT) Implementation")
print("=" * 60)

# Prefer GPU when available; models and batches are moved to this device later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# ============================================
# 1. Patch Embedding
# ============================================
print("\n[1] Patch Embedding")
print("-" * 40)
35
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each.

    A Conv2d whose kernel size and stride both equal the patch size is
    mathematically identical to flattening every patch and applying a shared
    linear projection, but runs as a single efficient convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # One conv = patch extraction + linear embedding in a single op
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map (B, C, H, W) images to (B, num_patches, embed_dim) tokens."""
        embedded = self.projection(x)              # (B, D, H/P, W/P)
        flat = embedded.flatten(2)                 # (B, D, num_patches)
        return flat.transpose(1, 2)                # (B, num_patches, D)
56
57
# Smoke test: 224/16 = 14, so 14*14 = 196 patches, each embedded to 768 dims
patch_embed = PatchEmbedding(img_size=224, patch_size=16, embed_dim=768)
test_img = torch.randn(2, 3, 224, 224)
patches = patch_embed(test_img)
print(f"Input image: {test_img.shape}")
print(f"Patch embeddings: {patches.shape}")
print(f"Number of patches: {patch_embed.num_patches}")


# ============================================
# 2. Multi-Head Self-Attention
# ============================================
print("\n[2] Multi-Head Self-Attention")
print("-" * 40)
72
73
class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention.

    Args:
        embed_dim: Token embedding dimension; must be divisible by num_heads.
        num_heads: Number of parallel attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        if embed_dim % num_heads != 0:
            # Fail fast: without this check, head_dim silently truncates and
            # the reshape in forward() raises an opaque RuntimeError on the
            # first call instead of a clear error at construction time.
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k) scaling

        # QKV projection in one matrix for efficiency (one matmul, not three)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Self-attention over N tokens: (B, N, C) -> (B, N, C)."""
        B, N, C = x.shape

        # QKV: (B, N, 3*embed_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention weights: (B, heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, heads merged back: (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
105
106
# Smoke test: attention is shape-preserving, (B, N, D) -> (B, N, D)
mha = MultiHeadAttention(embed_dim=768, num_heads=12)
attn_out = mha(patches)
print(f"Attention output: {attn_out.shape}")


# ============================================
# 3. Transformer Block
# ============================================
print("\n[3] Transformer Block")
print("-" * 40)
118
119
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    Hidden width is ``embed_dim * mlp_ratio``; dropout follows both the
    activation and the output projection.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Shape-preserving: (..., embed_dim) -> (..., embed_dim)."""
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
136
137
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    LayerNorm precedes each sub-layer (attention, then MLP) and a residual
    connection wraps each one — the "pre-norm" arrangement that trains more
    stably than the original post-norm Transformer.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        """Shape-preserving: (B, N, C) -> (B, N, C)."""
        attn_out = self.attn(self.norm1(x))   # token mixing
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))     # channel mixing
        return x + mlp_out
152
153
# Smoke test: encoder block is also shape-preserving
block = TransformerBlock(embed_dim=768, num_heads=12)
block_out = block(patches)
print(f"Transformer block output: {block_out.shape}")


# ============================================
# 4. Vision Transformer (Full Model)
# ============================================
print("\n[4] Vision Transformer Model")
print("-" * 40)
165
166
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) for image classification.

    Pipeline: patchify -> prepend learnable CLS token -> add learned
    position embeddings -> stack of pre-norm Transformer blocks ->
    final LayerNorm -> linear head on the CLS token.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0
    ):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2

        # Image -> sequence of patch tokens
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # Learnable classification token, prepended to every sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)

        # Learnable position embeddings for CLS + every patch
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)

        self.dropout = nn.Dropout(dropout)

        # Encoder stack
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        )

        self.norm = nn.LayerNorm(embed_dim)

        # Classification head on the CLS token
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for Linear weights; unit/zero for LayerNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x, return_features=False):
        """Classify a batch of images.

        Args:
            x: (B, C, H, W) image batch.
            return_features: if True, return the (B, N+1, D) token sequence
                after the final LayerNorm instead of class logits.
        """
        batch = x.shape[0]

        tokens = self.patch_embed(x)                    # (B, N, D)
        cls = self.cls_token.expand(batch, -1, -1)      # (B, 1, D)
        tokens = torch.cat([cls, tokens], dim=1)        # (B, N+1, D)

        # Positions are added before the encoder; dropout regularizes tokens
        tokens = self.dropout(tokens + self.pos_embed)

        for layer in self.blocks:
            tokens = layer(tokens)

        tokens = self.norm(tokens)

        if return_features:
            return tokens

        # Logits come from the CLS token only
        return self.head(tokens[:, 0])
245
246
# Create different ViT variants
def vit_tiny(num_classes=1000):
    """ViT-Tiny: 192-dim embeddings, 12 layers, 3 heads (~5M params)."""
    tiny_cfg = {"embed_dim": 192, "depth": 12, "num_heads": 3}
    return VisionTransformer(num_classes=num_classes, **tiny_cfg)
252
253
def vit_small(num_classes=1000):
    """ViT-Small: 384-dim embeddings, 12 layers, 6 heads (~22M params)."""
    small_cfg = {"embed_dim": 384, "depth": 12, "num_heads": 6}
    return VisionTransformer(num_classes=num_classes, **small_cfg)
258
259
def vit_base(num_classes=1000):
    """ViT-Base: 768-dim embeddings, 12 layers, 12 heads (~86M params)."""
    base_cfg = {"embed_dim": 768, "depth": 12, "num_heads": 12}
    return VisionTransformer(num_classes=num_classes, **base_cfg)
264
265
# Smoke test: full model produces (batch, num_classes) logits
model = vit_tiny(num_classes=10)
test_output = model(test_img)
print(f"ViT-Tiny input: {test_img.shape}")
print(f"ViT-Tiny output: {test_output.shape}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")


# ============================================
# 5. CIFAR-10 Training
# ============================================
print("\n[5] Training on CIFAR-10")
print("-" * 40)
279
280
# Custom ViT for CIFAR-10 (32x32 images)
class ViTForCIFAR(nn.Module):
    """ViT sized for CIFAR-10's 32x32 inputs.

    Uses 4x4 patches (an 8x8 grid, 64 tokens) and a much smaller encoder
    than ImageNet-scale ViTs, which suits the small dataset.
    """

    # Fixed architecture for 32x32 CIFAR images
    _CONFIG = dict(
        img_size=32,
        patch_size=4,      # 32/4 = 8x8 = 64 patches
        in_channels=3,
        embed_dim=256,
        depth=6,
        num_heads=8,
        mlp_ratio=2.0,
        dropout=0.1,
    )

    def __init__(self, num_classes=10):
        super().__init__()
        self.vit = VisionTransformer(num_classes=num_classes, **self._CONFIG)

    def forward(self, x):
        """Delegate straight to the wrapped VisionTransformer."""
        return self.vit(x)
301
302
def _cifar10_loaders(batch_size):
    """Build CIFAR-10 train/test DataLoaders with standard augmentation."""
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_data = datasets.CIFAR10('data', train=True, download=True, transform=transform_train)
    test_data = datasets.CIFAR10('data', train=False, transform=transform_test)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, test_loader


def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimization pass over loader; return the mean batch loss."""
    model.train()
    total_loss = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()

        # Gradient clipping stabilizes early ViT training
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


def _evaluate(model, loader):
    """Return top-1 accuracy (%) of model over loader, without gradients."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    return 100. * correct / total


def train_vit_cifar10(epochs=10, batch_size=128, lr=1e-3):
    """Train a ViTForCIFAR on CIFAR-10.

    Args:
        epochs: number of training epochs (cosine schedule spans all of them).
        batch_size: DataLoader batch size.
        lr: initial AdamW learning rate.

    Returns:
        (model, train_losses, test_accs): the trained model, per-epoch mean
        training losses, and per-epoch test accuracies (%).
    """
    train_loader, test_loader = _cifar10_loaders(batch_size)

    model = ViTForCIFAR(num_classes=10).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # AdamW + cosine decay. NOTE: no warmup is implemented here despite what
    # the original comment claimed; add a warmup scheduler for long runs.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    test_accs = []

    for epoch in range(epochs):
        train_losses.append(_train_one_epoch(model, train_loader, optimizer, criterion))
        test_accs.append(_evaluate(model, test_loader))

        print(f"Epoch {epoch+1}/{epochs}: Loss={train_losses[-1]:.4f}, Acc={test_accs[-1]:.2f}%")

        scheduler.step()

    return model, train_losses, test_accs
379
380
# Train (reduced epochs for demo); rebinds `model` to the trained ViTForCIFAR
print("\nStarting training...")
model, losses, accs = train_vit_cifar10(epochs=5)


# ============================================
# 6. Visualizations
# ============================================
print("\n[6] Visualizations")
print("-" * 40)

# Plot training curves: per-epoch loss (left) and test accuracy (right)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(losses)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')

ax2.plot(accs)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Test Accuracy')

plt.tight_layout()
plt.savefig('vit_training.png', dpi=150)
plt.close()
print("Training curves saved to vit_training.png")
409
410
# Visualize position embeddings
def visualize_position_embeddings(model, filename='vit_pos_embed.png'):
    """Plot cosine similarity between learned position embeddings.

    For six reference patches (corners, center, top-center), shows how
    similar every other patch's position embedding is, folded back onto
    the 2-D patch grid. Nearby patches should look most similar once the
    model has trained.

    Args:
        model: wrapper exposing ``model.vit.pos_embed`` (e.g. ViTForCIFAR).
        filename: output path for the saved figure.
    """
    pos_embed = model.vit.pos_embed.detach().cpu()
    # Drop the CLS slot — it has no spatial position
    pos_embed = pos_embed[0, 1:]  # (N, D)

    # Cosine similarity = dot product of L2-normalized embeddings
    pos_embed_norm = pos_embed / pos_embed.norm(dim=1, keepdim=True)
    similarity = pos_embed_norm @ pos_embed_norm.T

    # Fold the flat patch index back onto a square grid
    num_patches = pos_embed.shape[0]
    grid_size = int(num_patches ** 0.5)

    # Visualize similarity for corner and center patches
    fig, axes = plt.subplots(2, 3, figsize=(12, 8))

    patch_indices = [
        (0, "Top-Left"),
        (grid_size - 1, "Top-Right"),
        (num_patches // 2 - grid_size // 2, "Center"),
        (num_patches - grid_size, "Bottom-Left"),
        (num_patches - 1, "Bottom-Right"),
        (grid_size // 2, "Top-Center")
    ]

    for idx, (patch_idx, name) in enumerate(patch_indices):
        ax = axes[idx // 3, idx % 3]
        sim = similarity[patch_idx].reshape(grid_size, grid_size)
        im = ax.imshow(sim.numpy(), cmap='viridis')
        ax.set_title(f'{name} (patch {patch_idx})')
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046)

    plt.suptitle('Position Embedding Similarity')
    plt.tight_layout()
    plt.savefig(filename, dpi=150)
    plt.close()
    # Bug fix: this previously printed the literal text "(unknown)" instead
    # of the actual output path.
    print(f"Position embeddings saved to {filename}")
451
452
# Render position-embedding similarity for the trained CIFAR model
visualize_position_embeddings(model)


# ============================================
# 7. Attention Visualization
# ============================================
print("\n[7] Attention Visualization")
print("-" * 40)
461
462
class ViTWithAttention(nn.Module):
    """Re-run a wrapped ViT's forward pass while capturing attention maps.

    forward() returns ONLY the list of per-block attention tensors
    (one per encoder block, each (B, heads, N+1, N+1)) — not the logits.
    Intended for inference/visualization; run the wrapped model in eval().
    """
    def __init__(self, vit_model):
        super().__init__()
        # vit_model is a ViTForCIFAR-style wrapper, so the inner
        # VisionTransformer is reached as self.vit.vit below.
        self.vit = vit_model

    def forward(self, x):
        B = x.shape[0]

        # Re-create the embedding pipeline of VisionTransformer.forward.
        # NOTE(review): the embedding dropout is skipped here — a no-op in
        # eval() mode, but this path diverges from the real forward in train().
        x = self.vit.vit.patch_embed(x)
        cls_tokens = self.vit.vit.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.vit.vit.pos_embed

        # One attention tensor per encoder block
        attn_weights = []

        for block in self.vit.vit.blocks:
            # Recompute Q/K/V exactly as MultiHeadAttention.forward does,
            # so the softmax weights can be captured before the value matmul.
            norm_x = block.norm1(x)
            B, N, C = norm_x.shape
            qkv = block.attn.qkv(norm_x).reshape(B, N, 3, block.attn.num_heads, block.attn.head_dim)
            qkv = qkv.permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]

            attn = (q @ k.transpose(-2, -1)) * block.attn.scale
            attn = attn.softmax(dim=-1)
            attn_weights.append(attn)

            # Advance the features via the block's own forward path.
            # NOTE(review): attention is therefore computed twice per block —
            # simple but wasteful; acceptable for one-off visualization.
            x = x + block.attn(block.norm1(x))
            x = x + block.mlp(block.norm2(x))

        return attn_weights
497
498
# Get attention for a sample image (first CIFAR-10 test image)
model.eval()
test_data = datasets.CIFAR10('data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
]))

sample_img, label = test_data[0]
sample_img = sample_img.unsqueeze(0).to(device)  # add batch dim -> (1, 3, 32, 32)

vit_attn = ViTWithAttention(model)
with torch.no_grad():
    attentions = vit_attn(sample_img)

# Visualize CLS token attention from last layer
attn_last = attentions[-1][0]  # (heads, N, N) for the single image
cls_attn = attn_last[:, 0, 1:]  # row 0 = CLS query; drop its self-attention entry

# Average over heads, then fold the patch sequence back onto the 2-D grid
cls_attn_avg = cls_attn.mean(0).cpu()
grid_size = int(cls_attn_avg.shape[0] ** 0.5)
cls_attn_map = cls_attn_avg.reshape(grid_size, grid_size)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Original image, min-max rescaled so the normalized tensor displays cleanly
orig_img = sample_img[0].cpu().permute(1, 2, 0)
orig_img = (orig_img - orig_img.min()) / (orig_img.max() - orig_img.min())
axes[0].imshow(orig_img)
axes[0].set_title(f'Original (Label: {test_data.classes[label]})')
axes[0].axis('off')

# Attention map
axes[1].imshow(cls_attn_map.numpy(), cmap='hot')
axes[1].set_title('CLS Token Attention')
axes[1].axis('off')

plt.tight_layout()
plt.savefig('vit_attention.png', dpi=150)
plt.close()
print("Attention visualization saved to vit_attention.png")
541
542
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("Vision Transformer Summary")
print("=" * 60)

# Reference notes for the reader; printed verbatim at the end of the run.
summary = """
Key Components:
1. Patch Embedding: Image -> Patches -> Linear projection
2. CLS Token: Learnable token for classification
3. Position Embedding: Learnable position information
4. Transformer Blocks: Multi-Head Attention + MLP
5. Classification Head: Linear layer on CLS output

ViT Variants:
- ViT-Tiny: 192 dim, 3 heads, 12 layers (~5M params)
- ViT-Small: 384 dim, 6 heads, 12 layers (~22M params)
- ViT-Base: 768 dim, 12 heads, 12 layers (~86M params)
- ViT-Large: 1024 dim, 16 heads, 24 layers (~307M params)

Training Tips:
1. Use AdamW optimizer with weight decay
2. Learning rate warmup + cosine decay
3. Strong data augmentation (RandAugment, Mixup)
4. Gradient clipping
5. Pre-training on large datasets helps significantly

Output Files:
- vit_training.png: Training loss and accuracy curves
- vit_pos_embed.png: Position embedding similarity
- vit_attention.png: CLS token attention visualization
"""
print(summary)
print("=" * 60)