# 28_vit_vs_cnn.py
  1"""
  2Vision Transformer (ViT) vs CNN Comparison on CIFAR-10
  3
  4This script compares Vision Transformers and Convolutional Neural Networks
  5on image classification. It implements both architectures from scratch and
  6provides detailed performance comparisons.
  7
  8Key Concepts:
  9- Vision Transformer: Patch embedding, positional encoding, transformer blocks
 10- CNN: Convolutional layers, batch normalization, residual connections
 11- Performance metrics: Accuracy, training time, parameter count
 12- Visualization: Attention maps (ViT) vs feature maps (CNN)
 13
 14Requirements:
 15    pip install torch torchvision matplotlib numpy
 16"""
 17
import argparse
import math
import time
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
 31
 32
 33class PatchEmbedding(nn.Module):
 34    """Convert image into patches and embed them."""
 35
 36    def __init__(self, img_size: int = 32, patch_size: int = 4,
 37                 in_channels: int = 3, embed_dim: int = 128):
 38        super().__init__()
 39        self.img_size = img_size
 40        self.patch_size = patch_size
 41        self.n_patches = (img_size // patch_size) ** 2
 42
 43        # Linear projection of flattened patches
 44        self.projection = nn.Conv2d(
 45            in_channels, embed_dim,
 46            kernel_size=patch_size, stride=patch_size
 47        )
 48
 49    def forward(self, x: torch.Tensor) -> torch.Tensor:
 50        # x: (batch_size, channels, height, width)
 51        x = self.projection(x)  # (B, embed_dim, n_patches**0.5, n_patches**0.5)
 52        x = x.flatten(2)  # (B, embed_dim, n_patches)
 53        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
 54        return x
 55
 56
 57class MultiHeadAttention(nn.Module):
 58    """Multi-head self-attention mechanism."""
 59
 60    def __init__(self, embed_dim: int = 128, num_heads: int = 4, dropout: float = 0.1):
 61        super().__init__()
 62        self.embed_dim = embed_dim
 63        self.num_heads = num_heads
 64        self.head_dim = embed_dim // num_heads
 65        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
 66
 67        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
 68        self.projection = nn.Linear(embed_dim, embed_dim)
 69        self.dropout = nn.Dropout(dropout)
 70
 71    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 72        batch_size, seq_len, embed_dim = x.shape
 73
 74        # Compute Q, K, V
 75        qkv = self.qkv(x)  # (B, seq_len, 3*embed_dim)
 76        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
 77        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, seq_len, head_dim)
 78        q, k, v = qkv[0], qkv[1], qkv[2]
 79
 80        # Scaled dot-product attention
 81        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
 82        attn_weights = F.softmax(scores, dim=-1)
 83        attn_weights = self.dropout(attn_weights)
 84
 85        # Apply attention to values
 86        attn_output = torch.matmul(attn_weights, v)  # (B, num_heads, seq_len, head_dim)
 87        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
 88        output = self.projection(attn_output)
 89
 90        return output, attn_weights.mean(dim=1)  # Return averaged attention across heads
 91
 92
 93class TransformerBlock(nn.Module):
 94    """Transformer encoder block with self-attention and MLP."""
 95
 96    def __init__(self, embed_dim: int = 128, num_heads: int = 4,
 97                 mlp_ratio: float = 4.0, dropout: float = 0.1):
 98        super().__init__()
 99        self.norm1 = nn.LayerNorm(embed_dim)
100        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
101        self.norm2 = nn.LayerNorm(embed_dim)
102
103        mlp_hidden_dim = int(embed_dim * mlp_ratio)
104        self.mlp = nn.Sequential(
105            nn.Linear(embed_dim, mlp_hidden_dim),
106            nn.GELU(),
107            nn.Dropout(dropout),
108            nn.Linear(mlp_hidden_dim, embed_dim),
109            nn.Dropout(dropout)
110        )
111
112    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
113        # Self-attention with residual connection
114        attn_output, attn_weights = self.attn(self.norm1(x))
115        x = x + attn_output
116
117        # MLP with residual connection
118        x = x + self.mlp(self.norm2(x))
119
120        return x, attn_weights
121
122
class VisionTransformer(nn.Module):
    """Minimal ViT classifier: patch embed + CLS token + encoder stack."""

    def __init__(self, img_size: int = 32, patch_size: int = 4, in_channels: int = 3,
                 num_classes: int = 10, embed_dim: int = 128, depth: int = 6,
                 num_heads: int = 4, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # Learnable classification token, prepended to every patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Learnable position embedding covering CLS + all patches.
        seq_len = self.patch_embed.n_patches + 1
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.dropout = nn.Dropout(dropout)

        # Encoder stack of `depth` identical blocks.
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        )

        # Final norm + linear classification head on the CLS token.
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Return (logits, per-block head-averaged attention matrices)."""
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat([cls, self.patch_embed(x)], dim=1)
        tokens = self.dropout(tokens + self.pos_embed)

        collected_attn: List[torch.Tensor] = []
        for block in self.blocks:
            tokens, attn = block(tokens)
            collected_attn.append(attn)

        # Classification reads only the (normalized) CLS token.
        logits = self.head(self.norm(tokens)[:, 0])
        return logits, collected_attn
182
183
class SimpleCNN(nn.Module):
    """VGG-style CNN baseline: four double-conv stages + GAP classifier."""

    def __init__(self, num_classes: int = 10):
        super().__init__()

        # Channel widths double at each stage; each stage halves H and W.
        widths = [(3, 64), (64, 128), (128, 256), (256, 512)]
        self.conv1, self.conv2, self.conv3, self.conv4 = (
            self._make_conv_block(cin, cout) for cin, cout in widths
        )

        # Collapse spatial dims, then classify.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_conv_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Two 3x3 conv+BN+ReLU layers followed by 2x2 max-pooling."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return self.fc(torch.flatten(self.avgpool(x), 1))
222
223
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable (requires_grad) parameters of *model*."""
    trainable = (p for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in trainable)
227
228
def train_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                optimizer: torch.optim.Optimizer, device: torch.device) -> Tuple[float, float]:
    """Run one training pass over *dataloader*.

    Returns:
        (mean batch loss, accuracy in percent) for the epoch.
    """
    model.train()
    loss_total = 0.0
    n_correct = 0
    n_seen = 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()

        logits = model(batch_x)
        # ViT returns (logits, attention); CNN returns logits alone.
        if isinstance(logits, tuple):
            logits = logits[0]

        loss = criterion(logits, batch_y)
        loss.backward()
        optimizer.step()

        loss_total += loss.item()
        n_seen += batch_y.size(0)
        n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

    return loss_total / len(dataloader), 100.0 * n_correct / n_seen
263
264
def evaluate(model: nn.Module, dataloader: DataLoader,
             criterion: nn.Module, device: torch.device) -> Tuple[float, float]:
    """Compute (mean batch loss, accuracy %) over *dataloader*, no gradients."""
    model.eval()
    loss_total = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            # ViT returns (logits, attention); CNN returns logits alone.
            if isinstance(logits, tuple):
                logits = logits[0]

            loss_total += criterion(logits, batch_y).item()
            n_seen += batch_y.size(0)
            n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

    return loss_total / len(dataloader), 100.0 * n_correct / n_seen
292
293
def visualize_attention(model: VisionTransformer, dataloader: DataLoader,
                       device: torch.device, save_path: str = "vit_attention.png"):
    """Visualize last-layer CLS-token attention maps from a ViT.

    Fixes over the naive version:
    - The patch-grid side is derived from ``model.patch_embed.n_patches``
      instead of being hard-coded to 8, so it works for any img_size /
      patch_size combination.
    - Handles batches smaller than 4 without indexing errors.
    - Closes the figure after saving so repeated calls don't leak figures.

    Args:
        model: trained VisionTransformer whose forward returns
            (logits, per-block attention list).
        dataloader: source of images (first batch is used).
        device: device to run inference on.
        save_path: output PNG path.
    """
    model.eval()

    # Grab one batch; labels are not needed for the visualization.
    images, _ = next(iter(dataloader))
    images = images.to(device)
    n_show = min(4, images.size(0))  # robust to small batches

    with torch.no_grad():
        _, attn_weights_list = model(images[:n_show])

    # Head-averaged attention from the last encoder block: (n_show, seq, seq).
    attn_weights = attn_weights_list[-1][:n_show].cpu()

    # Patches form a square grid; side length is sqrt(n_patches).
    side = math.isqrt(model.patch_embed.n_patches)

    fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8), squeeze=False)

    for idx in range(n_show):
        # Original image, min-max rescaled for display (undoes normalization).
        img = images[idx].cpu().permute(1, 2, 0).numpy()
        img = (img - img.min()) / (img.max() - img.min())
        axes[0, idx].imshow(img)
        axes[0, idx].set_title(f"Image {idx + 1}")
        axes[0, idx].axis("off")

        # Row 0 of the attention matrix is the CLS token; drop its
        # self-attention entry and reshape the patch scores to the grid.
        attn_map = attn_weights[idx, 0, 1:].reshape(side, side)
        axes[1, idx].imshow(attn_map, cmap="viridis")
        axes[1, idx].set_title("CLS Token Attention")
        axes[1, idx].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)  # avoid accumulating open figures across calls
    print(f"Attention visualization saved to {save_path}")
328
329
def plot_comparison(vit_history: dict, cnn_history: dict,
                   save_path: str = "comparison.png"):
    """Plot side-by-side training-loss and validation-accuracy curves.

    Fix: the figure is explicitly closed after saving — without this,
    repeated calls leave figures open and leak memory.

    Args:
        vit_history: dict with "train_loss" and "val_acc" lists for the ViT.
        cnn_history: dict with the same structure for the CNN.
        save_path: output PNG path.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: training loss per epoch.
    axes[0].plot(vit_history["train_loss"], label="ViT", marker="o")
    axes[0].plot(cnn_history["train_loss"], label="CNN", marker="s")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Training Loss")
    axes[0].legend()
    axes[0].grid(True)

    # Right panel: validation accuracy per epoch.
    axes[1].plot(vit_history["val_acc"], label="ViT", marker="o")
    axes[1].plot(cnn_history["val_acc"], label="CNN", marker="s")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("Accuracy (%)")
    axes[1].set_title("Validation Accuracy")
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)  # release the figure; repeated calls otherwise leak
    print(f"Comparison plot saved to {save_path}")
356
357
def main(args):
    """Train a ViT and a CNN side by side on CIFAR-10 and compare them."""
    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # CIFAR-10 channel statistics, shared by both pipelines.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    # Augmentation only on the training split.
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    train_set = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=train_tf)
    test_set = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=test_tf)

    train_loader = DataLoader(train_set, batch_size=args.batch_size,
                              shuffle=True, num_workers=2)
    test_loader = DataLoader(test_set, batch_size=args.batch_size,
                             shuffle=False, num_workers=2)

    # The two contenders.
    vit_model = VisionTransformer(img_size=32, patch_size=4, num_classes=10,
                                  embed_dim=128, depth=6, num_heads=4).to(device)
    cnn_model = SimpleCNN(num_classes=10).to(device)

    print(f"\nViT Parameters: {count_parameters(vit_model):,}")
    print(f"CNN Parameters: {count_parameters(cnn_model):,}")

    # Identical training recipe for a fair comparison.
    criterion = nn.CrossEntropyLoss()
    vit_optimizer = torch.optim.AdamW(vit_model.parameters(), lr=args.lr, weight_decay=0.05)
    cnn_optimizer = torch.optim.AdamW(cnn_model.parameters(), lr=args.lr, weight_decay=0.05)
    vit_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(vit_optimizer, T_max=args.epochs)
    cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(cnn_optimizer, T_max=args.epochs)

    vit_history = {"train_loss": [], "val_acc": []}
    cnn_history = {"train_loss": [], "val_acc": []}

    print("\n" + "=" * 50)
    print("Starting Training")
    print("=" * 50)

    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")

        # ViT: one epoch, timed.
        tic = time.time()
        vit_loss, vit_train_acc = train_epoch(
            vit_model, train_loader, criterion, vit_optimizer, device)
        vit_time = time.time() - tic
        vit_scheduler.step()

        # CNN: one epoch, timed.
        tic = time.time()
        cnn_loss, cnn_train_acc = train_epoch(
            cnn_model, train_loader, criterion, cnn_optimizer, device)
        cnn_time = time.time() - tic
        cnn_scheduler.step()

        # Validation on the held-out test split.
        _, vit_val_acc = evaluate(vit_model, test_loader, criterion, device)
        _, cnn_val_acc = evaluate(cnn_model, test_loader, criterion, device)

        vit_history["train_loss"].append(vit_loss)
        vit_history["val_acc"].append(vit_val_acc)
        cnn_history["train_loss"].append(cnn_loss)
        cnn_history["val_acc"].append(cnn_val_acc)

        print(f"ViT - Loss: {vit_loss:.4f}, Train Acc: {vit_train_acc:.2f}%, "
              f"Val Acc: {vit_val_acc:.2f}%, Time: {vit_time:.2f}s")
        print(f"CNN - Loss: {cnn_loss:.4f}, Train Acc: {cnn_train_acc:.2f}%, "
              f"Val Acc: {cnn_val_acc:.2f}%, Time: {cnn_time:.2f}s")

    print("\n" + "=" * 50)
    print("Final Results")
    print("=" * 50)
    print(f"ViT - Best Val Acc: {max(vit_history['val_acc']):.2f}%")
    print(f"CNN - Best Val Acc: {max(cnn_history['val_acc']):.2f}%")

    # Artifacts: attention maps and the loss/accuracy comparison plot.
    visualize_attention(vit_model, test_loader, device)
    plot_comparison(vit_history, cnn_history)
464
465
466if __name__ == "__main__":
467    parser = argparse.ArgumentParser(description="ViT vs CNN Comparison on CIFAR-10")
468    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
469    parser.add_argument("--batch_size", type=int, default=128, help="Batch size")
470    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
471
472    args = parser.parse_args()
473    main(args)