09. Transfer Learning

Learning Objectives

  • Understand the concept and benefits of transfer learning
  • Utilize pretrained models
  • Learn fine-tuning strategies
  • Work through a practical image classification project

1. What is Transfer Learning?

Concept

Model trained on ImageNet
        ↓
    Low-level features (edges, textures) → Reuse
        ↓
    High-level features → Adapt to new data
        ↓
    New classification task

Benefits

  • High performance with limited data
  • Faster training
  • Better generalization

2. Transfer Learning Strategies

Strategy 1: Feature Extraction

# Freeze pretrained model weights
for param in model.parameters():
    param.requires_grad = False

# Replace only the last layer
model.fc = nn.Linear(2048, num_classes)  # 2048 = ResNet-50's feature dimension

  • Use pretrained features as-is
  • Train only the final classification layer
  • Suitable when data is limited
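
Since only the new head is trainable here, a common follow-up (a sketch, assuming the model.fc head above) is to hand the optimizer just those parameters, so frozen weights never consume optimizer state:

# Only the replaced head is passed to the optimizer
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)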

Strategy 2: Fine-tuning

# Train all or some layers
for param in model.parameters():
    param.requires_grad = True

# Use low learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

  • Start from pretrained weights
  • Fine-tune the entire network
  • Suitable when sufficient data available

Strategy 3: Gradual Unfreezing

# Step 1: Last layer only
for param in model.parameters():
    param.requires_grad = False
model.fc.requires_grad_(True)
train_for_epochs(5)  # placeholder helper; see the sketch below

# Step 2: Last block too
model.layer4.requires_grad_(True)
train_for_epochs(5)

# Step 3: Entire network
model.requires_grad_(True)
train_for_epochs(10)
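
The train_for_epochs calls above are placeholders. A minimal sketch of such a helper, assuming model, train_loader, criterion, and optimizer already exist in the surrounding scope:

def train_for_epochs(n):
    # Hypothetical helper: one standard training pass per epoch
    model.train()
    for epoch in range(n):
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

If the optimizer was built from only the trainable parameters, rebuild it after each unfreezing step so the newly thawed layers are actually updated.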

3. PyTorch Implementation

Basic Transfer Learning

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms, datasets

# 1. Load pretrained model
model = models.resnet50(weights='IMAGENET1K_V2')

# 2. Use as feature extractor (freeze weights)
for param in model.parameters():
    param.requires_grad = False

# 3. Replace last layer with a new classification head
num_classes = 10  # set this to your task's number of classes
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)
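
A quick sanity check (a sketch) confirms that only the new head is trainable — layers created in step 3 default to requires_grad=True:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable: {trainable:,} / {total:,}')  # head only vs. full network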

Data Preprocessing

# Use ImageNet normalization
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

4. Training Strategies

Discriminative Learning Rates

# Different learning rates per layer group (names are ResNet-specific):
# earlier layers change little, the new head learns fastest
optimizer = torch.optim.Adam([
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 5e-5},
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 5e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3},
])

Learning Rate Scheduling

# Warmup + Cosine Decay
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1  # 10% warmup
)
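
OneCycleLR is stepped once per batch, not once per epoch — hence steps_per_epoch above. A minimal loop sketch:

for epoch in range(epochs):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the schedule every batch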

5. Various Pretrained Models

torchvision Models

# Classification
resnet50 = models.resnet50(weights='IMAGENET1K_V2')
efficientnet = models.efficientnet_b0(weights='IMAGENET1K_V1')
vit = models.vit_b_16(weights='IMAGENET1K_V1')

# Object detection
fasterrcnn = models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')

# Segmentation
deeplabv3 = models.segmentation.deeplabv3_resnet50(weights='DEFAULT')

timm Library

import timm

# Check available models
print(timm.list_models('*efficientnet*'))

# Load model
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)
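
A quick shape check (a sketch) verifies the new 10-class head:

import torch

x = torch.randn(1, 3, 224, 224)
model.eval()
with torch.no_grad():
    print(model(x).shape)  # torch.Size([1, 10])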

6. Practical Project: Flower Classification

Data Preparation

# Flowers102 dataset
from torchvision.datasets import Flowers102

train_data = Flowers102(
    root='data',
    split='train',
    transform=train_transform,
    download=True
)

test_data = Flowers102(
    root='data',
    split='test',
    transform=val_transform
)
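
Wrap both splits in DataLoaders to feed training (a sketch; the batch sizes are starting points, not tuned values):

from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)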

Model and Training

class FlowerClassifier(nn.Module):
    def __init__(self, num_classes=102):
        super().__init__()
        self.backbone = models.efficientnet_b0(weights='IMAGENET1K_V1')

        # Replace last layer
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

# Training
model = FlowerClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
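
A minimal train/evaluate loop sketch, assuming the DataLoaders above — note the model.train()/model.eval() switch, a common source of bugs (see Considerations below):

for epoch in range(10):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
    print(f'Epoch {epoch}: test accuracy {correct / len(test_data):.3f}')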

7. Considerations

Strategy by Data Size

Data Size              Strategy            Description
Very small (<1,000)    Feature extraction  Train only the last layer
Small (1,000-10,000)   Gradual unfreezing  Unfreeze from the later layers
Medium (10,000+)       Full fine-tuning    Train all layers with a low LR

Domain Similarity

Similar to ImageNet (animals, everyday objects):
    → Even the deeper layers transfer well; freeze most of the network

Different from ImageNet (medical, satellite imagery):
    → Only low-level features transfer; fine-tune the deeper layers

Common Mistakes

  1. Missing ImageNet normalization
  2. Learning rate too high
  3. Forgetting to switch train/eval mode
  4. Including frozen weights in the optimizer (see the sketch below)
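
For mistake 4, the usual fix is to give the optimizer only parameters that require gradients (a sketch):

optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4
)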

8. Performance Improvement Tips

Data Augmentation

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    normalize
])

Label Smoothing

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

Mixup / CutMix

import numpy as np

def mixup(x, y, alpha=0.2):
    # Sample a mixing coefficient, then blend each image with a
    # randomly chosen partner from the same batch
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam
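
At training time the loss is blended with the same lam (a usage sketch):

mixed_x, y_a, y_b, lam = mixup(x, y)
pred = model(mixed_x)
loss = lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)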

Summary

Core Concepts

  1. Feature Extraction: Reuse pretrained features
  2. Fine-tuning: Adjust entire network with low LR
  3. Gradual Unfreezing: Sequential training from later layers

Checklist

  • [ ] Use ImageNet normalization
  • [ ] Choose an appropriate learning rate (typically 1e-5 to 1e-4 for fine-tuning)
  • [ ] Switch model.train() / model.eval()
  • [ ] Apply data augmentation
  • [ ] Set up early stopping (see the sketch below)
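
A minimal early-stopping sketch, with train_one_epoch and validate as hypothetical helpers:

best_val, patience, bad_epochs = float('inf'), 5, 0
for epoch in range(epochs):
    train_one_epoch()        # hypothetical helper
    val_loss = validate()    # hypothetical helper
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
        torch.save(model.state_dict(), 'best.pt')  # keep the best checkpoint
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # no improvement for `patience` epochs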

Next Steps

In 13_RNN_Basics.md, we'll learn recurrent neural networks.
