07. Transfer Learning

Learning Objectives

  • The concept and benefits of transfer learning
  • Using pretrained models
  • Fine-tuning strategies
  • A hands-on image classification project

1. ์ „์ดํ•™์Šต์ด๋ž€?

๊ฐœ๋…

ImageNet์œผ๋กœ ํ•™์Šต๋œ ๋ชจ๋ธ
        โ†“
    ์ €์ˆ˜์ค€ ํŠน์ง• (์—์ง€, ํ…์Šค์ฒ˜) โ†’ ์žฌ์‚ฌ์šฉ
        โ†“
    ๊ณ ์ˆ˜์ค€ ํŠน์ง• โ†’ ์ƒˆ ๋ฐ์ดํ„ฐ์— ๋งž๊ฒŒ ์กฐ์ •
        โ†“
    ์ƒˆ๋กœ์šด ๋ถ„๋ฅ˜ ์ž‘์—…

Benefits

  • High performance even with little data
  • Faster training
  • Better generalization

2. Transfer Learning Strategies

Strategy 1: Feature Extraction

# Freeze the pretrained model's weights
for param in model.parameters():
    param.requires_grad = False

# Replace only the final layer
model.fc = nn.Linear(2048, num_classes)

  • Uses the pretrained features as-is
  • Trains only the final classification layer
  • Suitable when data is scarce

Strategy 2: Fine-tuning

# Train all (or some) layers
for param in model.parameters():
    param.requires_grad = True

# Use a low learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

  • Starts from the pretrained weights
  • Fine-tunes the whole network
  • Suitable when data is plentiful

์ „๋žต 3: ์ ์ง„์  ํ•ด๋™ (Gradual Unfreezing)

# 1๋‹จ๊ณ„: ๋งˆ์ง€๋ง‰ ์ธต๋งŒ
for param in model.parameters():
    param.requires_grad = False
model.fc.requires_grad_(True)
train_for_epochs(5)

# 2๋‹จ๊ณ„: ๋งˆ์ง€๋ง‰ ๋ธ”๋ก๋„
model.layer4.requires_grad_(True)
train_for_epochs(5)

# 3๋‹จ๊ณ„: ์ „์ฒด
model.requires_grad_(True)
train_for_epochs(10)

3. PyTorch Implementation

Basic Transfer Learning

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms, datasets

# 1. Load a pretrained model
model = models.resnet50(weights='IMAGENET1K_V2')

# 2. Use it as a feature extractor (freeze the weights)
for param in model.parameters():
    param.requires_grad = False

# 3. Replace the final layer (num_classes is your target class count)
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)

๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

# ImageNet ์ •๊ทœํ™” ์‚ฌ์šฉ
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

4. Training Strategies

Discriminative Learning Rates

# A different learning rate per layer group
optimizer = torch.optim.Adam([
    {'params': model.layer1.parameters(), 'lr': 1e-5},
    {'params': model.layer2.parameters(), 'lr': 5e-5},
    {'params': model.layer3.parameters(), 'lr': 1e-4},
    {'params': model.layer4.parameters(), 'lr': 5e-4},
    {'params': model.fc.parameters(), 'lr': 1e-3},
])

Learning Rate Scheduling

# Warmup + cosine decay
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.1  # 10% warmup
)
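Unlike epoch-based schedulers, OneCycleLR expects scheduler.step() once per batch. A toy loop (tiny linear model and batch counts, purely illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # toy model for the sketch
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs, steps_per_epoch = 3, 5  # tiny numbers for illustration
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    pct_start=0.1,
)

lrs = []
for _ in range(epochs):
    for _ in range(steps_per_epoch):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        scheduler.step()  # once per batch, NOT once per epoch
        lrs.append(optimizer.param_groups[0]['lr'])

print(len(lrs))  # → 15 (one scheduler step per batch)
```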

5. A Variety of Pretrained Models

torchvision Models

# Classification
resnet50 = models.resnet50(weights='IMAGENET1K_V2')
efficientnet = models.efficientnet_b0(weights='IMAGENET1K_V1')
vit = models.vit_b_16(weights='IMAGENET1K_V1')

# ๊ฐ์ฒด ๊ฒ€์ถœ์šฉ
fasterrcnn = models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')

# ์„ธ๊ทธ๋ฉ˜ํ…Œ์ด์…˜์šฉ
deeplabv3 = models.segmentation.deeplabv3_resnet50(weights='DEFAULT')

timm ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ

import timm

# List the available models matching a pattern
print(timm.list_models('*efficientnet*'))

# ๋ชจ๋ธ ๋กœ๋“œ
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=10)

6. ์‹ค์ „ ํ”„๋กœ์ ํŠธ: ๊ฝƒ ๋ถ„๋ฅ˜

๋ฐ์ดํ„ฐ ์ค€๋น„

# Flowers102 ๋ฐ์ดํ„ฐ์…‹
from torchvision.datasets import Flowers102

train_data = Flowers102(
    root='data',
    split='train',
    transform=train_transform,
    download=True
)

test_data = Flowers102(
    root='data',
    split='test',
    transform=val_transform
)

๋ชจ๋ธ ๋ฐ ํ•™์Šต

class FlowerClassifier(nn.Module):
    def __init__(self, num_classes=102):
        super().__init__()
        self.backbone = models.efficientnet_b0(weights='IMAGENET1K_V1')

        # ๋งˆ์ง€๋ง‰ ์ธต ๊ต์ฒด
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)

# Training setup
model = FlowerClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

7. ์ฃผ์˜์‚ฌํ•ญ

๋ฐ์ดํ„ฐ ํฌ๊ธฐ๋ณ„ ์ „๋žต

๋ฐ์ดํ„ฐ ํฌ๊ธฐ ์ „๋žต ์„ค๋ช…
๋งค์šฐ ์ ์Œ (<1000) ํŠน์„ฑ ์ถ”์ถœ ๋งˆ์ง€๋ง‰ ์ธต๋งŒ ํ•™์Šต
์ ์Œ (1000-10000) ์ ์ง„์  ํ•ด๋™ ํ›„๋ฐ˜ ์ธต๋ถ€ํ„ฐ ํ•ด๋™
๋ณดํ†ต (10000+) ์ „์ฒด ๋ฏธ์„ธ ์กฐ์ • ๋‚ฎ์€ ํ•™์Šต๋ฅ ๋กœ ์ „์ฒด ํ•™์Šต

๋„๋ฉ”์ธ ์œ ์‚ฌ์„ฑ

ImageNet๊ณผ ์œ ์‚ฌ (๋™๋ฌผ, ์‚ฌ๋ฌผ):
    โ†’ ์–•์€ ์ธต๋„ ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ ๊ฐ€๋Šฅ

ImageNet๊ณผ ๋‹ค๋ฆ„ (์˜๋ฃŒ, ์œ„์„ฑ):
    โ†’ ๊นŠ์€ ์ธต๊นŒ์ง€ ๋ฏธ์„ธ ์กฐ์ • ํ•„์š”

์ผ๋ฐ˜์ ์ธ ์‹ค์ˆ˜

  1. ImageNet ์ •๊ทœํ™” ๋ˆ„๋ฝ
  2. ๋„ˆ๋ฌด ๋†’์€ ํ•™์Šต๋ฅ 
  3. ํ›ˆ๋ จ/ํ‰๊ฐ€ ๋ชจ๋“œ ์ „ํ™˜ ์žŠ์Œ
  4. ๊ฐ€์ค‘์น˜ ๊ณ ์ • ํ›„ optimizer์— ํฌํ•จ
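Mistake 3 is easy to demonstrate: dropout (and batch norm) behave differently in the two modes, so forgetting model.eval() at validation time corrupts the metrics. A small illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

layer.train()
train_out = layer(x)   # roughly half the activations zeroed, survivors scaled by 2

layer.eval()
eval_out = layer(x)    # dropout is a no-op in eval mode

print(bool((eval_out == x).all()))   # → True
print(bool((train_out == x).all()))  # → False
```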

8. ์„ฑ๋Šฅ ํ–ฅ์ƒ ํŒ

๋ฐ์ดํ„ฐ ์ฆ๊ฐ•

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    normalize
])

Label Smoothing

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

Mixup / CutMix

import numpy as np
import torch

def mixup(x, y, alpha=0.2):
    """Blend random pairs of samples; labels are mixed with the same ratio."""
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return mixed_x, y_a, y_b, lam
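The mixed batch is paired with a loss that combines both label sets using the same mixing coefficient. A self-contained sketch with a toy linear model (the model and shapes are illustrative):

```python
import numpy as np
import torch
import torch.nn as nn

def mixup(x, y, alpha=0.2):
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    return mixed_x, y, y[idx], lam

model = nn.Linear(8, 3)  # toy classifier
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

mixed_x, y_a, y_b, lam = mixup(x, y)
pred = model(mixed_x)
# The loss uses the same convex combination as the inputs
loss = lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
print(loss.item() > 0)  # cross-entropy terms are positive
```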

Summary

Key Concepts

  1. Feature extraction: reuse the pretrained features
  2. Fine-tuning: adjust the whole network with a low learning rate
  3. Gradual unfreezing: train progressively, starting from the later layers

์ฒดํฌ๋ฆฌ์ŠคํŠธ

  • [ ] ImageNet ์ •๊ทœํ™” ์‚ฌ์šฉ
  • [ ] ์ ์ ˆํ•œ ํ•™์Šต๋ฅ  ์„ ํƒ (1e-4 ~ 1e-5)
  • [ ] model.train() / model.eval() ์ „ํ™˜
  • [ ] ๋ฐ์ดํ„ฐ ์ฆ๊ฐ• ์ ์šฉ
  • [ ] ์กฐ๊ธฐ ์ข…๋ฃŒ ์„ค์ •

๋‹ค์Œ ๋‹จ๊ณ„

13_RNN_Basics.md์—์„œ ์ˆœํ™˜ ์‹ ๊ฒฝ๋ง์„ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.
