19_vit_implementation.py

Download
python 578 lines 16.7 KB
  1"""
  219. Vision Transformer (ViT) Implementation
  3
  4A minimal but complete implementation of Vision Transformer:
  5- Patch Embedding
  6- Position Embedding
  7- Multi-Head Self-Attention
  8- Transformer Encoder
  9- Classification Head
 10"""
 11
 12import torch
 13import torch.nn as nn
 14import torch.nn.functional as F
 15from torch.utils.data import DataLoader
 16from torchvision import datasets, transforms
 17import matplotlib.pyplot as plt
 18import numpy as np
 19import math
 20
 21print("=" * 60)
 22print("Vision Transformer (ViT) Implementation")
 23print("=" * 60)
 24
 25device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 26print(f"Using device: {device}")
 27
 28
 29# ============================================
 30# 1. Patch Embedding
 31# ============================================
 32print("\n[1] Patch Embedding")
 33print("-" * 40)
 34
 35
 36class PatchEmbedding(nn.Module):
 37    """Convert image to patches and embed them"""
 38    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
 39        super().__init__()
 40        self.img_size = img_size
 41        self.patch_size = patch_size
 42        self.num_patches = (img_size // patch_size) ** 2
 43
 44        # Use Conv2d for efficient patch extraction and embedding
 45        self.projection = nn.Conv2d(
 46            in_channels, embed_dim,
 47            kernel_size=patch_size, stride=patch_size
 48        )
 49
 50    def forward(self, x):
 51        # x: (B, C, H, W)
 52        x = self.projection(x)  # (B, embed_dim, H/P, W/P)
 53        x = x.flatten(2)        # (B, embed_dim, num_patches)
 54        x = x.transpose(1, 2)   # (B, num_patches, embed_dim)
 55        return x
 56
 57
 58# Test patch embedding
 59patch_embed = PatchEmbedding(img_size=224, patch_size=16, embed_dim=768)
 60test_img = torch.randn(2, 3, 224, 224)
 61patches = patch_embed(test_img)
 62print(f"Input image: {test_img.shape}")
 63print(f"Patch embeddings: {patches.shape}")
 64print(f"Number of patches: {patch_embed.num_patches}")
 65
 66
 67# ============================================
 68# 2. Multi-Head Self-Attention
 69# ============================================
 70print("\n[2] Multi-Head Self-Attention")
 71print("-" * 40)
 72
 73
 74class MultiHeadAttention(nn.Module):
 75    """Multi-Head Self-Attention"""
 76    def __init__(self, embed_dim, num_heads, dropout=0.0):
 77        super().__init__()
 78        self.embed_dim = embed_dim
 79        self.num_heads = num_heads
 80        self.head_dim = embed_dim // num_heads
 81        self.scale = self.head_dim ** -0.5
 82
 83        # QKV projection in one matrix for efficiency
 84        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
 85        self.proj = nn.Linear(embed_dim, embed_dim)
 86        self.dropout = nn.Dropout(dropout)
 87
 88    def forward(self, x):
 89        B, N, C = x.shape
 90
 91        # QKV: (B, N, 3*embed_dim) -> (B, N, 3, heads, head_dim)
 92        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
 93        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
 94        q, k, v = qkv[0], qkv[1], qkv[2]
 95
 96        # Attention: (B, heads, N, head_dim) @ (B, heads, head_dim, N) = (B, heads, N, N)
 97        attn = (q @ k.transpose(-2, -1)) * self.scale
 98        attn = attn.softmax(dim=-1)
 99        attn = self.dropout(attn)
100
101        # Output: (B, heads, N, N) @ (B, heads, N, head_dim) = (B, heads, N, head_dim)
102        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
103        x = self.proj(x)
104        return x
105
106
# Smoke test: self-attention over the patch tokens produced in section 1.
mha = MultiHeadAttention(embed_dim=768, num_heads=12)
attn_out = mha(patches)
print(f"Attention output: {attn_out.shape}")


# ============================================
# 3. Transformer Block
# ============================================
print("\n[3] Transformer Block", "-" * 40, sep="\n")
118
119
class MLP(nn.Module):
    """Token-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embed_dim: Size of the input and output features.
        mlp_ratio: Hidden width as a multiple of ``embed_dim`` (truncated to int).
        dropout: Dropout probability used after the activation and after ``fc2``.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        width = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, width)
        self.fc2 = nn.Linear(width, embed_dim)
        # A single Dropout module is reused at both positions.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
136
137
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    Computes ``x + attn(norm1(x))`` and then adds ``mlp(norm2(.))`` of that
    result. Each sub-layer therefore sees a normalized input, while the
    residual path itself is never normalized.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        # Creation order (norm1, attn, norm2, mlp) fixes state_dict key order.
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
152
153
# Smoke test: one encoder layer applied to the section-1 patch tokens.
block = TransformerBlock(embed_dim=768, num_heads=12)
block_out = block(patches)
print(f"Transformer block output: {block_out.shape}")


# ============================================
# 4. Vision Transformer (Full Model)
# ============================================
print("\n[4] Vision Transformer Model", "-" * 40, sep="\n")
165
166
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable CLS token -> add learned
    position embeddings -> dropout -> ``depth`` pre-norm Transformer blocks ->
    final LayerNorm -> linear ``head`` applied to the CLS token.

    Args:
        img_size: Side length of the square input image, in pixels.
        patch_size: Side length of each square patch, in pixels.
        in_channels: Number of input image channels.
        num_classes: Number of logits produced by ``head``.
        embed_dim: Token embedding size.
        depth: Number of Transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability after the position embedding and inside
            every block.
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0
    ):
        super().__init__()

        # Patch Embedding. Its num_patches is the single source of truth for
        # the token count, so pos_embed cannot drift out of sync with it.
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # CLS Token: learnable token prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)

        # Position Embedding: one learned vector per patch plus one for CLS.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)

        self.dropout = nn.Dropout(dropout)

        # Transformer Encoder
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Classification Head
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialize Linear and LayerNorm layers.

        Linear weights get truncated-normal(std=0.02) and zero biases.
        LayerNorms get weight 1 and bias 0. ``cls_token``/``pos_embed`` keep
        their ``randn * 0.02`` init, and the patch-embedding ``Conv2d`` keeps
        PyTorch's default init.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, return_features=False):
        """Classify a batch of images.

        Args:
            x: Images of shape ``(B, in_channels, img_size, img_size)``. The
                spatial size must match ``img_size`` so the token count
                equals ``pos_embed``'s length.
            return_features: If True, return the final-normed token sequence
                ``(B, num_patches + 1, embed_dim)`` (CLS first) instead of logits.

        Returns:
            Logits of shape ``(B, num_classes)``, or token features when
            ``return_features`` is True.
        """
        B = x.shape[0]

        # Patch Embedding
        x = self.patch_embed(x)  # (B, N, D)

        # Add CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)

        # Add Position Embedding
        x = x + self.pos_embed
        x = self.dropout(x)

        # Transformer Blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        if return_features:
            return x

        # CLS Token for classification
        cls_output = x[:, 0]
        return self.head(cls_output)
245
246
# Create different ViT variants. All are 12 blocks deep and use the
# VisionTransformer defaults for image/patch size; only the width and the
# head count differ.
def _vit_variant(embed_dim, num_heads, num_classes):
    """Build a depth-12 VisionTransformer with the given width and head count."""
    return VisionTransformer(
        embed_dim=embed_dim, depth=12, num_heads=num_heads, num_classes=num_classes
    )


def vit_tiny(num_classes=1000):
    """ViT-Tiny: 192-dim tokens, 3 heads, 12 blocks."""
    return _vit_variant(192, 3, num_classes)


def vit_small(num_classes=1000):
    """ViT-Small: 384-dim tokens, 6 heads, 12 blocks."""
    return _vit_variant(384, 6, num_classes)


def vit_base(num_classes=1000):
    """ViT-Base: 768-dim tokens, 12 heads, 12 blocks."""
    return _vit_variant(768, 12, num_classes)
264
265
# Smoke test: ViT-Tiny with a 10-way head on the 224x224 batch from section 1.
model = vit_tiny(num_classes=10)
test_output = model(test_img)
print(
    f"ViT-Tiny input: {test_img.shape}",
    f"ViT-Tiny output: {test_output.shape}",
    f"Parameters: {sum(p.numel() for p in model.parameters()):,}",
    sep="\n",
)


# ============================================
# 5. CIFAR-10 Training
# ============================================
print("\n[5] Training on CIFAR-10", "-" * 40, sep="\n")
279
280
# Custom ViT for CIFAR-10 (32x32 images)
class ViTForCIFAR(nn.Module):
    """Small ViT configured for 32x32 CIFAR-10 inputs.

    4x4 patches give an 8x8 grid of 64 tokens. The encoder is 256-dim with
    6 blocks, 8 heads, MLP ratio 2 and 10% dropout. The wrapped model is
    exposed as ``self.vit``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        config = dict(
            img_size=32,
            patch_size=4,  # 32 / 4 = 8 -> 8 x 8 = 64 patches
            in_channels=3,
            num_classes=num_classes,
            embed_dim=256,
            depth=6,
            num_heads=8,
            mlp_ratio=2.0,
            dropout=0.1,
        )
        self.vit = VisionTransformer(**config)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)``."""
        return self.vit(x)
301
302
303def train_vit_cifar10(epochs=10, batch_size=128, lr=1e-3):
304    """Train ViT on CIFAR-10"""
305    # Data
306    transform_train = transforms.Compose([
307        transforms.RandomCrop(32, padding=4),
308        transforms.RandomHorizontalFlip(),
309        transforms.ToTensor(),
310        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
311    ])
312
313    transform_test = transforms.Compose([
314        transforms.ToTensor(),
315        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
316    ])
317
318    train_data = datasets.CIFAR10('data', train=True, download=True, transform=transform_train)
319    test_data = datasets.CIFAR10('data', train=False, transform=transform_test)
320
321    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
322    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
323
324    # Model
325    model = ViTForCIFAR(num_classes=10).to(device)
326    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
327
328    # Optimizer with warmup
329    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
330    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
331    criterion = nn.CrossEntropyLoss()
332
333    # Training
334    train_losses = []
335    test_accs = []
336
337    for epoch in range(epochs):
338        # Train
339        model.train()
340        total_loss = 0
341
342        for images, labels in train_loader:
343            images, labels = images.to(device), labels.to(device)
344
345            optimizer.zero_grad()
346            outputs = model(images)
347            loss = criterion(outputs, labels)
348            loss.backward()
349
350            # Gradient clipping
351            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
352            optimizer.step()
353
354            total_loss += loss.item()
355
356        train_losses.append(total_loss / len(train_loader))
357
358        # Evaluate
359        model.eval()
360        correct = 0
361        total = 0
362
363        with torch.no_grad():
364            for images, labels in test_loader:
365                images, labels = images.to(device), labels.to(device)
366                outputs = model(images)
367                _, predicted = outputs.max(1)
368                total += labels.size(0)
369                correct += predicted.eq(labels).sum().item()
370
371        accuracy = 100. * correct / total
372        test_accs.append(accuracy)
373
374        print(f"Epoch {epoch+1}/{epochs}: Loss={train_losses[-1]:.4f}, Acc={accuracy:.2f}%")
375
376        scheduler.step()
377
378    return model, train_losses, test_accs
379
380
# Demo run: 5 epochs instead of the function's default 10; the other
# hyperparameters are spelled out at their default values.
print("\nStarting training...")
model, losses, accs = train_vit_cifar10(epochs=5, batch_size=128, lr=1e-3)


# ============================================
# 6. Visualizations
# ============================================
print("\n[6] Visualizations", "-" * 40, sep="\n")
391
# Plot training curves against 1-based epoch numbers so the x-axis matches
# the "Epoch k/N" lines printed during training. Markers keep a single-epoch
# run visible (a one-point line plot draws nothing).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
epoch_numbers = list(range(1, len(losses) + 1))

ax1.plot(epoch_numbers, losses, marker='o')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')

ax2.plot(epoch_numbers, accs, marker='o')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Test Accuracy')

plt.tight_layout()
plt.savefig('vit_training.png', dpi=150)
plt.close(fig)
print("Training curves saved to vit_training.png")
409
410
411# Visualize position embeddings
412def visualize_position_embeddings(model, filename='vit_pos_embed.png'):
413    """Visualize learned position embeddings"""
414    pos_embed = model.vit.pos_embed.detach().cpu()
415    # Remove CLS token
416    pos_embed = pos_embed[0, 1:]  # (N, D)
417
418    # Compute similarity
419    pos_embed_norm = pos_embed / pos_embed.norm(dim=1, keepdim=True)
420    similarity = pos_embed_norm @ pos_embed_norm.T
421
422    # Get grid size
423    num_patches = pos_embed.shape[0]
424    grid_size = int(num_patches ** 0.5)
425
426    # Visualize similarity for corner and center patches
427    fig, axes = plt.subplots(2, 3, figsize=(12, 8))
428
429    patch_indices = [
430        (0, "Top-Left"),
431        (grid_size - 1, "Top-Right"),
432        (num_patches // 2 - grid_size // 2, "Center"),
433        (num_patches - grid_size, "Bottom-Left"),
434        (num_patches - 1, "Bottom-Right"),
435        (grid_size // 2, "Top-Center")
436    ]
437
438    for idx, (patch_idx, name) in enumerate(patch_indices):
439        ax = axes[idx // 3, idx % 3]
440        sim = similarity[patch_idx].reshape(grid_size, grid_size)
441        im = ax.imshow(sim.numpy(), cmap='viridis')
442        ax.set_title(f'{name} (patch {patch_idx})')
443        ax.axis('off')
444        plt.colorbar(im, ax=ax, fraction=0.046)
445
446    plt.suptitle('Position Embedding Similarity')
447    plt.tight_layout()
448    plt.savefig(filename, dpi=150)
449    plt.close()
450    print(f"Position embeddings saved to {filename}")
451
452
visualize_position_embeddings(model, filename='vit_pos_embed.png')


# ============================================
# 7. Attention Visualization
# ============================================
print("\n[7] Attention Visualization", "-" * 40, sep="\n")
461
462
class ViTWithAttention(nn.Module):
    """Wrapper that replays a ViT encoder and collects per-block attention maps.

    Accepts either a ``ViTForCIFAR`` (encoder at ``.vit``) or a bare
    ``VisionTransformer``. ``forward`` returns a list with one tensor per
    Transformer block, of shape ``(B, heads, T, T)`` where ``T`` is the
    number of patches plus one (CLS). These are the softmax weights before
    attention dropout. The final norm and classification head are not run.
    """
    def __init__(self, vit_model):
        super().__init__()
        self.vit = vit_model

    def forward(self, x):
        encoder = getattr(self.vit, "vit", self.vit)
        B = x.shape[0]

        # Token embedding, mirroring VisionTransformer.forward up to the blocks
        # (including the embedding dropout, an identity in eval mode).
        x = encoder.patch_embed(x)
        cls_tokens = encoder.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = encoder.dropout(x + encoder.pos_embed)

        # Attention weights from every block, in order.
        attn_weights = []

        for block in encoder.blocks:
            mha = block.attn
            # One attention pass per block: the recorded weights are reused to
            # build the block output instead of calling mha() a second time.
            h = block.norm1(x)
            _, N, C = h.shape
            qkv = mha.qkv(h).reshape(B, N, 3, mha.num_heads, mha.head_dim)
            qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
            q, k, v = qkv[0], qkv[1], qkv[2]

            attn = (q @ k.transpose(-2, -1)) * mha.scale
            attn = attn.softmax(dim=-1)
            attn_weights.append(attn)

            # Same steps as MultiHeadAttention.forward: dropout on the
            # weights, weighted sum of values, output projection.
            out = (mha.dropout(attn) @ v).transpose(1, 2).reshape(B, N, C)
            x = x + mha.proj(out)
            x = x + block.mlp(block.norm2(x))

        return attn_weights
497
498
# Get attention for a sample image
model.eval()
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)
test_data = datasets.CIFAR10('data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std)
]))

sample_img, label = test_data[0]
sample_img = sample_img.unsqueeze(0).to(device)

vit_attn = ViTWithAttention(model)
with torch.no_grad():
    attentions = vit_attn(sample_img)

# CLS-token attention from the last block: row 0 of the (heads, T, T) map,
# dropping column 0 (CLS attending to itself), averaged over heads.
attn_last = attentions[-1][0]  # (heads, T, T)
cls_attn = attn_last[:, 0, 1:]
cls_attn_avg = cls_attn.mean(0).cpu()
grid_size = math.isqrt(cls_attn_avg.shape[0])  # exact integer sqrt
cls_attn_map = cls_attn_avg.reshape(grid_size, grid_size)

# Plot
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Original image: invert Normalize per channel (x * std + mean) to recover
# true colors; clamp guards against tiny float excursions outside [0, 1].
channel_mean = torch.tensor(cifar_mean).view(3, 1, 1)
channel_std = torch.tensor(cifar_std).view(3, 1, 1)
orig_img = (sample_img[0].cpu() * channel_std + channel_mean).clamp(0, 1)
axes[0].imshow(orig_img.permute(1, 2, 0).numpy())
axes[0].set_title(f'Original (Label: {test_data.classes[label]})')
axes[0].axis('off')

# Attention map
axes[1].imshow(cls_attn_map.numpy(), cmap='hot')
axes[1].set_title('CLS Token Attention')
axes[1].axis('off')

plt.tight_layout()
plt.savefig('vit_attention.png', dpi=150)
plt.close(fig)
print("Attention visualization saved to vit_attention.png")
541
542
# ============================================
# Summary
# ============================================
print("", "=" * 60, sep="\n")
print("Vision Transformer Summary", "=" * 60, sep="\n")

summary = """
Key Components:
1. Patch Embedding: Image -> Patches -> Linear projection
2. CLS Token: Learnable token for classification
3. Position Embedding: Learnable position information
4. Transformer Blocks: Multi-Head Attention + MLP
5. Classification Head: Linear layer on CLS output

ViT Variants:
- ViT-Tiny: 192 dim, 3 heads, 12 layers (~5M params)
- ViT-Small: 384 dim, 6 heads, 12 layers (~22M params)
- ViT-Base: 768 dim, 12 heads, 12 layers (~86M params)
- ViT-Large: 1024 dim, 16 heads, 24 layers (~307M params)

Training Tips:
1. Use AdamW optimizer with weight decay
2. Learning rate warmup + cosine decay
3. Strong data augmentation (RandAugment, Mixup)
4. Gradient clipping
5. Pre-training on large datasets helps significantly

Output Files:
- vit_training.png: Training loss and accuracy curves
- vit_pos_embed.png: Position embedding similarity
- vit_attention.png: CLS token attention visualization
"""
# The closing rule follows the summary on its own line.
print(summary, "=" * 60, sep="\n")