1"""
2Vision Transformer (ViT) vs CNN Comparison on CIFAR-10
3
4This script compares Vision Transformers and Convolutional Neural Networks
5on image classification. It implements both architectures from scratch and
6provides detailed performance comparisons.
7
8Key Concepts:
9- Vision Transformer: Patch embedding, positional encoding, transformer blocks
10- CNN: Convolutional layers, batch normalization, residual connections
11- Performance metrics: Accuracy, training time, parameter count
12- Visualization: Attention maps (ViT) vs feature maps (CNN)
13
14Requirements:
15 pip install torch torchvision matplotlib numpy
16"""
17
18import argparse
19import time
20from typing import Tuple, List
21import math
22
23import torch
24import torch.nn as nn
25import torch.nn.functional as F
26from torch.utils.data import DataLoader
27import torchvision
28import torchvision.transforms as transforms
29import matplotlib.pyplot as plt
30import numpy as np
31
32
33class PatchEmbedding(nn.Module):
34 """Convert image into patches and embed them."""
35
36 def __init__(self, img_size: int = 32, patch_size: int = 4,
37 in_channels: int = 3, embed_dim: int = 128):
38 super().__init__()
39 self.img_size = img_size
40 self.patch_size = patch_size
41 self.n_patches = (img_size // patch_size) ** 2
42
43 # Linear projection of flattened patches
44 self.projection = nn.Conv2d(
45 in_channels, embed_dim,
46 kernel_size=patch_size, stride=patch_size
47 )
48
49 def forward(self, x: torch.Tensor) -> torch.Tensor:
50 # x: (batch_size, channels, height, width)
51 x = self.projection(x) # (B, embed_dim, n_patches**0.5, n_patches**0.5)
52 x = x.flatten(2) # (B, embed_dim, n_patches)
53 x = x.transpose(1, 2) # (B, n_patches, embed_dim)
54 return x
55
56
57class MultiHeadAttention(nn.Module):
58 """Multi-head self-attention mechanism."""
59
60 def __init__(self, embed_dim: int = 128, num_heads: int = 4, dropout: float = 0.1):
61 super().__init__()
62 self.embed_dim = embed_dim
63 self.num_heads = num_heads
64 self.head_dim = embed_dim // num_heads
65 assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
66
67 self.qkv = nn.Linear(embed_dim, embed_dim * 3)
68 self.projection = nn.Linear(embed_dim, embed_dim)
69 self.dropout = nn.Dropout(dropout)
70
71 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
72 batch_size, seq_len, embed_dim = x.shape
73
74 # Compute Q, K, V
75 qkv = self.qkv(x) # (B, seq_len, 3*embed_dim)
76 qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
77 qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, num_heads, seq_len, head_dim)
78 q, k, v = qkv[0], qkv[1], qkv[2]
79
80 # Scaled dot-product attention
81 scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
82 attn_weights = F.softmax(scores, dim=-1)
83 attn_weights = self.dropout(attn_weights)
84
85 # Apply attention to values
86 attn_output = torch.matmul(attn_weights, v) # (B, num_heads, seq_len, head_dim)
87 attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
88 output = self.projection(attn_output)
89
90 return output, attn_weights.mean(dim=1) # Return averaged attention across heads
91
92
93class TransformerBlock(nn.Module):
94 """Transformer encoder block with self-attention and MLP."""
95
96 def __init__(self, embed_dim: int = 128, num_heads: int = 4,
97 mlp_ratio: float = 4.0, dropout: float = 0.1):
98 super().__init__()
99 self.norm1 = nn.LayerNorm(embed_dim)
100 self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
101 self.norm2 = nn.LayerNorm(embed_dim)
102
103 mlp_hidden_dim = int(embed_dim * mlp_ratio)
104 self.mlp = nn.Sequential(
105 nn.Linear(embed_dim, mlp_hidden_dim),
106 nn.GELU(),
107 nn.Dropout(dropout),
108 nn.Linear(mlp_hidden_dim, embed_dim),
109 nn.Dropout(dropout)
110 )
111
112 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
113 # Self-attention with residual connection
114 attn_output, attn_weights = self.attn(self.norm1(x))
115 x = x + attn_output
116
117 # MLP with residual connection
118 x = x + self.mlp(self.norm2(x))
119
120 return x, attn_weights
121
122
123class VisionTransformer(nn.Module):
124 """Simple Vision Transformer for CIFAR-10."""
125
126 def __init__(self, img_size: int = 32, patch_size: int = 4, in_channels: int = 3,
127 num_classes: int = 10, embed_dim: int = 128, depth: int = 6,
128 num_heads: int = 4, mlp_ratio: float = 4.0, dropout: float = 0.1):
129 super().__init__()
130
131 # Patch embedding
132 self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
133
134 # CLS token (learnable)
135 self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
136
137 # Positional embedding
138 num_patches = self.patch_embed.n_patches
139 self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
140 self.dropout = nn.Dropout(dropout)
141
142 # Transformer blocks
143 self.blocks = nn.ModuleList([
144 TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
145 for _ in range(depth)
146 ])
147
148 # Classification head
149 self.norm = nn.LayerNorm(embed_dim)
150 self.head = nn.Linear(embed_dim, num_classes)
151
152 # Initialize weights
153 nn.init.trunc_normal_(self.cls_token, std=0.02)
154 nn.init.trunc_normal_(self.pos_embed, std=0.02)
155
156 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
157 batch_size = x.shape[0]
158
159 # Patch embedding
160 x = self.patch_embed(x)
161
162 # Add CLS token
163 cls_tokens = self.cls_token.expand(batch_size, -1, -1)
164 x = torch.cat([cls_tokens, x], dim=1)
165
166 # Add positional embedding
167 x = x + self.pos_embed
168 x = self.dropout(x)
169
170 # Transformer blocks
171 attn_weights_list = []
172 for block in self.blocks:
173 x, attn_weights = block(x)
174 attn_weights_list.append(attn_weights)
175
176 # Classification using CLS token
177 x = self.norm(x)
178 cls_token_final = x[:, 0]
179 logits = self.head(cls_token_final)
180
181 return logits, attn_weights_list
182
183
184class SimpleCNN(nn.Module):
185 """Simple CNN baseline for comparison."""
186
187 def __init__(self, num_classes: int = 10):
188 super().__init__()
189
190 # Convolutional blocks
191 self.conv1 = self._make_conv_block(3, 64)
192 self.conv2 = self._make_conv_block(64, 128)
193 self.conv3 = self._make_conv_block(128, 256)
194 self.conv4 = self._make_conv_block(256, 512)
195
196 # Global average pooling and classifier
197 self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
198 self.fc = nn.Linear(512, num_classes)
199
200 def _make_conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
201 return nn.Sequential(
202 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
203 nn.BatchNorm2d(out_channels),
204 nn.ReLU(inplace=True),
205 nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
206 nn.BatchNorm2d(out_channels),
207 nn.ReLU(inplace=True),
208 nn.MaxPool2d(kernel_size=2, stride=2)
209 )
210
211 def forward(self, x: torch.Tensor) -> torch.Tensor:
212 x = self.conv1(x)
213 x = self.conv2(x)
214 x = self.conv3(x)
215 x = self.conv4(x)
216
217 x = self.avgpool(x)
218 x = torch.flatten(x, 1)
219 x = self.fc(x)
220
221 return x
222
223
224def count_parameters(model: nn.Module) -> int:
225 """Count trainable parameters in a model."""
226 return sum(p.numel() for p in model.parameters() if p.requires_grad)
227
228
229def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
230 optimizer: torch.optim.Optimizer, device: torch.device) -> Tuple[float, float]:
231 """Train for one epoch."""
232 model.train()
233 running_loss = 0.0
234 correct = 0
235 total = 0
236
237 for inputs, labels in dataloader:
238 inputs, labels = inputs.to(device), labels.to(device)
239
240 optimizer.zero_grad()
241
242 # Forward pass (handle both ViT and CNN outputs)
243 outputs = model(inputs)
244 if isinstance(outputs, tuple):
245 outputs = outputs[0] # ViT returns (logits, attention_weights)
246
247 loss = criterion(outputs, labels)
248
249 # Backward pass
250 loss.backward()
251 optimizer.step()
252
253 # Statistics
254 running_loss += loss.item()
255 _, predicted = outputs.max(1)
256 total += labels.size(0)
257 correct += predicted.eq(labels).sum().item()
258
259 epoch_loss = running_loss / len(dataloader)
260 epoch_acc = 100.0 * correct / total
261
262 return epoch_loss, epoch_acc
263
264
265def evaluate(model: nn.Module, dataloader: DataLoader,
266 criterion: nn.Module, device: torch.device) -> Tuple[float, float]:
267 """Evaluate model on validation/test set."""
268 model.eval()
269 running_loss = 0.0
270 correct = 0
271 total = 0
272
273 with torch.no_grad():
274 for inputs, labels in dataloader:
275 inputs, labels = inputs.to(device), labels.to(device)
276
277 outputs = model(inputs)
278 if isinstance(outputs, tuple):
279 outputs = outputs[0]
280
281 loss = criterion(outputs, labels)
282
283 running_loss += loss.item()
284 _, predicted = outputs.max(1)
285 total += labels.size(0)
286 correct += predicted.eq(labels).sum().item()
287
288 eval_loss = running_loss / len(dataloader)
289 eval_acc = 100.0 * correct / total
290
291 return eval_loss, eval_acc
292
293
294def visualize_attention(model: VisionTransformer, dataloader: DataLoader,
295 device: torch.device, save_path: str = "vit_attention.png"):
296 """Visualize attention maps from ViT."""
297 model.eval()
298
299 # Get a batch of images
300 images, labels = next(iter(dataloader))
301 images = images.to(device)
302
303 with torch.no_grad():
304 _, attn_weights_list = model(images[:4]) # Use first 4 images
305
306 # Visualize attention from last layer
307 attn_weights = attn_weights_list[-1][:4].cpu() # (4, seq_len, seq_len)
308
309 fig, axes = plt.subplots(2, 4, figsize=(16, 8))
310
311 for idx in range(4):
312 # Original image
313 img = images[idx].cpu().permute(1, 2, 0).numpy()
314 img = (img - img.min()) / (img.max() - img.min())
315 axes[0, idx].imshow(img)
316 axes[0, idx].set_title(f"Image {idx + 1}")
317 axes[0, idx].axis("off")
318
319 # Attention map (CLS token attending to patches)
320 attn_map = attn_weights[idx, 0, 1:].reshape(8, 8) # Skip CLS token itself
321 axes[1, idx].imshow(attn_map, cmap="viridis")
322 axes[1, idx].set_title("CLS Token Attention")
323 axes[1, idx].axis("off")
324
325 plt.tight_layout()
326 plt.savefig(save_path, dpi=150)
327 print(f"Attention visualization saved to {save_path}")
328
329
330def plot_comparison(vit_history: dict, cnn_history: dict,
331 save_path: str = "comparison.png"):
332 """Plot training comparison between ViT and CNN."""
333 fig, axes = plt.subplots(1, 2, figsize=(12, 4))
334
335 # Training loss
336 axes[0].plot(vit_history["train_loss"], label="ViT", marker="o")
337 axes[0].plot(cnn_history["train_loss"], label="CNN", marker="s")
338 axes[0].set_xlabel("Epoch")
339 axes[0].set_ylabel("Loss")
340 axes[0].set_title("Training Loss")
341 axes[0].legend()
342 axes[0].grid(True)
343
344 # Validation accuracy
345 axes[1].plot(vit_history["val_acc"], label="ViT", marker="o")
346 axes[1].plot(cnn_history["val_acc"], label="CNN", marker="s")
347 axes[1].set_xlabel("Epoch")
348 axes[1].set_ylabel("Accuracy (%)")
349 axes[1].set_title("Validation Accuracy")
350 axes[1].legend()
351 axes[1].grid(True)
352
353 plt.tight_layout()
354 plt.savefig(save_path, dpi=150)
355 print(f"Comparison plot saved to {save_path}")
356
357
358def main(args):
359 # Set device
360 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
361 print(f"Using device: {device}")
362
363 # Data preparation
364 transform_train = transforms.Compose([
365 transforms.RandomCrop(32, padding=4),
366 transforms.RandomHorizontalFlip(),
367 transforms.ToTensor(),
368 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
369 ])
370
371 transform_test = transforms.Compose([
372 transforms.ToTensor(),
373 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
374 ])
375
376 train_dataset = torchvision.datasets.CIFAR10(
377 root="./data", train=True, download=True, transform=transform_train
378 )
379 test_dataset = torchvision.datasets.CIFAR10(
380 root="./data", train=False, download=True, transform=transform_test
381 )
382
383 train_loader = DataLoader(
384 train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2
385 )
386 test_loader = DataLoader(
387 test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2
388 )
389
390 # Initialize models
391 vit_model = VisionTransformer(
392 img_size=32, patch_size=4, num_classes=10,
393 embed_dim=128, depth=6, num_heads=4
394 ).to(device)
395
396 cnn_model = SimpleCNN(num_classes=10).to(device)
397
398 print(f"\nViT Parameters: {count_parameters(vit_model):,}")
399 print(f"CNN Parameters: {count_parameters(cnn_model):,}")
400
401 # Training setup
402 criterion = nn.CrossEntropyLoss()
403 vit_optimizer = torch.optim.AdamW(vit_model.parameters(), lr=args.lr, weight_decay=0.05)
404 cnn_optimizer = torch.optim.AdamW(cnn_model.parameters(), lr=args.lr, weight_decay=0.05)
405
406 # Learning rate schedulers
407 vit_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(vit_optimizer, T_max=args.epochs)
408 cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(cnn_optimizer, T_max=args.epochs)
409
410 # Training histories
411 vit_history = {"train_loss": [], "val_acc": []}
412 cnn_history = {"train_loss": [], "val_acc": []}
413
414 # Training loop
415 print("\n" + "="*50)
416 print("Starting Training")
417 print("="*50)
418
419 for epoch in range(args.epochs):
420 print(f"\nEpoch {epoch + 1}/{args.epochs}")
421
422 # Train ViT
423 vit_start = time.time()
424 vit_loss, vit_train_acc = train_epoch(
425 vit_model, train_loader, criterion, vit_optimizer, device
426 )
427 vit_time = time.time() - vit_start
428 vit_scheduler.step()
429
430 # Train CNN
431 cnn_start = time.time()
432 cnn_loss, cnn_train_acc = train_epoch(
433 cnn_model, train_loader, criterion, cnn_optimizer, device
434 )
435 cnn_time = time.time() - cnn_start
436 cnn_scheduler.step()
437
438 # Evaluate
439 _, vit_val_acc = evaluate(vit_model, test_loader, criterion, device)
440 _, cnn_val_acc = evaluate(cnn_model, test_loader, criterion, device)
441
442 # Store history
443 vit_history["train_loss"].append(vit_loss)
444 vit_history["val_acc"].append(vit_val_acc)
445 cnn_history["train_loss"].append(cnn_loss)
446 cnn_history["val_acc"].append(cnn_val_acc)
447
448 # Print results
449 print(f"ViT - Loss: {vit_loss:.4f}, Train Acc: {vit_train_acc:.2f}%, "
450 f"Val Acc: {vit_val_acc:.2f}%, Time: {vit_time:.2f}s")
451 print(f"CNN - Loss: {cnn_loss:.4f}, Train Acc: {cnn_train_acc:.2f}%, "
452 f"Val Acc: {cnn_val_acc:.2f}%, Time: {cnn_time:.2f}s")
453
454 # Final comparison
455 print("\n" + "="*50)
456 print("Final Results")
457 print("="*50)
458 print(f"ViT - Best Val Acc: {max(vit_history['val_acc']):.2f}%")
459 print(f"CNN - Best Val Acc: {max(cnn_history['val_acc']):.2f}%")
460
461 # Visualizations
462 visualize_attention(vit_model, test_loader, device)
463 plot_comparison(vit_history, cnn_history)
464
465
466if __name__ == "__main__":
467 parser = argparse.ArgumentParser(description="ViT vs CNN Comparison on CIFAR-10")
468 parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
469 parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
470 parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
471
472 args = parser.parse_args()
473 main(args)