# 07_transfer_learning.py
  1"""
  207. ์ „์ดํ•™์Šต (Transfer Learning)
  3
  4์‚ฌ์ „ ํ•™์Šต๋œ ๋ชจ๋ธ์„ ํ™œ์šฉํ•œ ์ „์ดํ•™์Šต์„ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
  5"""
  6
  7import torch
  8import torch.nn as nn
  9import torch.nn.functional as F
 10from torch.utils.data import DataLoader, TensorDataset
 11import numpy as np
 12
 13print("=" * 60)
 14print("PyTorch ์ „์ดํ•™์Šต (Transfer Learning)")
 15print("=" * 60)
 16
 17device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 18print(f"์‚ฌ์šฉ ์žฅ์น˜: {device}")
 19
 20
 21# ============================================
 22# 1. ์‚ฌ์ „ ํ•™์Šต ๋ชจ๋ธ ๋กœ๋“œ
 23# ============================================
 24print("\n[1] ์‚ฌ์ „ ํ•™์Šต ๋ชจ๋ธ ๋กœ๋“œ")
 25print("-" * 40)
 26
 27try:
 28    import torchvision.models as models
 29
 30    # ๋‹ค์–‘ํ•œ ์‚ฌ์ „ ํ•™์Šต ๋ชจ๋ธ
 31    print("์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ์‚ฌ์ „ ํ•™์Šต ๋ชจ๋ธ:")
 32    pretrained_models = {
 33        'ResNet-18': lambda: models.resnet18(weights='IMAGENET1K_V1'),
 34        'ResNet-50': lambda: models.resnet50(weights='IMAGENET1K_V2'),
 35        'EfficientNet-B0': lambda: models.efficientnet_b0(weights='IMAGENET1K_V1'),
 36        'MobileNet-V2': lambda: models.mobilenet_v2(weights='IMAGENET1K_V1'),
 37    }
 38
 39    for name, loader in pretrained_models.items():
 40        model = loader()
 41        params = sum(p.numel() for p in model.parameters())
 42        print(f"  {name}: {params:,} ํŒŒ๋ผ๋ฏธํ„ฐ")
 43
 44    TORCHVISION_AVAILABLE = True
 45except ImportError:
 46    print("torchvision์ด ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ๋ฐ๋ชจ ๋ชจ๋“œ๋กœ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.")
 47    TORCHVISION_AVAILABLE = False
 48
 49
 50# ============================================
 51# 2. ํŠน์„ฑ ์ถ”์ถœ (Feature Extraction)
 52# ============================================
 53print("\n[2] ํŠน์„ฑ ์ถ”์ถœ (Feature Extraction)")
 54print("-" * 40)
 55
 56if TORCHVISION_AVAILABLE:
 57    # ResNet-18 ๋กœ๋“œ
 58    model = models.resnet18(weights='IMAGENET1K_V1')
 59
 60    # ์›๋ž˜ ๋ถ„๋ฅ˜๊ธฐ ํ™•์ธ
 61    print(f"์›๋ž˜ FC ์ธต: {model.fc}")
 62
 63    # ๋ชจ๋“  ๊ฐ€์ค‘์น˜ ๊ณ ์ •
 64    for param in model.parameters():
 65        param.requires_grad = False
 66
 67    # ๋งˆ์ง€๋ง‰ ์ธต ๊ต์ฒด
 68    num_features = model.fc.in_features
 69    model.fc = nn.Sequential(
 70        nn.Dropout(0.5),
 71        nn.Linear(num_features, 10)  # 10 ํด๋ž˜์Šค
 72    )
 73
 74    print(f"์ƒˆ FC ์ธต: {model.fc}")
 75
 76    # ํ•™์Šต ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ ํ™•์ธ
 77    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
 78    total = sum(p.numel() for p in model.parameters())
 79    print(f"ํ•™์Šต ๊ฐ€๋Šฅ ํŒŒ๋ผ๋ฏธํ„ฐ: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
 80
 81
 82# ============================================
 83# 3. ๋ฏธ์„ธ ์กฐ์ • (Fine-tuning)
 84# ============================================
 85print("\n[3] ๋ฏธ์„ธ ์กฐ์ • (Fine-tuning)")
 86print("-" * 40)
 87
 88if TORCHVISION_AVAILABLE:
 89    # ์ƒˆ๋กœ์šด ๋ชจ๋ธ ๋กœ๋“œ
 90    model = models.resnet18(weights='IMAGENET1K_V1')
 91
 92    # ๋งˆ์ง€๋ง‰ ์ธต ๊ต์ฒด
 93    model.fc = nn.Linear(model.fc.in_features, 10)
 94
 95    # ์ „์ฒด ํ•™์Šต ๊ฐ€๋Šฅ (๊ธฐ๋ณธ)
 96    print("์ „์ฒด ๋ฏธ์„ธ ์กฐ์ •:")
 97    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
 98    print(f"  ํ•™์Šต ๊ฐ€๋Šฅ ํŒŒ๋ผ๋ฏธํ„ฐ: {trainable:,}")
 99
100
# ============================================
# 4. Gradual unfreezing
# ============================================
print("\n[4] ์ ์ง„์  ํ•ด๋™ (Gradual Unfreezing)")
print("-" * 40)

if TORCHVISION_AVAILABLE:
    model = models.resnet18(weights='IMAGENET1K_V1')

    # Stage 1: freeze every pretrained layer.
    for param in model.parameters():
        param.requires_grad = False

    # The new nn.Linear head is trainable by default, so only it learns.
    model.fc = nn.Linear(model.fc.in_features, 10)

    def count_trainable(model):
        """Return the number of parameters that will receive gradients."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("์ ์ง„์  ํ•ด๋™ ๊ณผ์ •:")
    print(f"  1๋‹จ๊ณ„ (FC๋งŒ): {count_trainable(model):,} ํŒŒ๋ผ๋ฏธํ„ฐ")

    # Stage 2: unfreeze the deepest residual stage (layer4).
    for param in model.layer4.parameters():
        param.requires_grad = True
    print(f"  2๋‹จ๊ณ„ (FC + layer4): {count_trainable(model):,} ํŒŒ๋ผ๋ฏธํ„ฐ")

    # Stage 3: unfreeze layer3 as well.
    for param in model.layer3.parameters():
        param.requires_grad = True
    print(f"  3๋‹จ๊ณ„ (FC + layer4 + layer3): {count_trainable(model):,} ํŒŒ๋ผ๋ฏธํ„ฐ")

    # Stage 4: unfreeze the entire network.
    for param in model.parameters():
        param.requires_grad = True
    print(f"  4๋‹จ๊ณ„ (์ „์ฒด): {count_trainable(model):,} ํŒŒ๋ผ๋ฏธํ„ฐ")
137
138
# ============================================
# 5. Discriminative learning rates
# ============================================
print("\n[5] ์ฐจ๋“ฑ ํ•™์Šต๋ฅ ")
print("-" * 40)

if TORCHVISION_AVAILABLE:
    model = models.resnet18(weights='IMAGENET1K_V1')
    model.fc = nn.Linear(model.fc.in_features, 10)

    # Earlier (more generic) layers get smaller learning rates; the freshly
    # initialized head gets the largest.
    # NOTE(review): the stem's bn1 parameters are in no group, so this
    # optimizer never updates them — presumably acceptable for a demo.
    optimizer = torch.optim.Adam([
        {'params': model.conv1.parameters(), 'lr': 1e-5},
        {'params': model.layer1.parameters(), 'lr': 2e-5},
        {'params': model.layer2.parameters(), 'lr': 5e-5},
        {'params': model.layer3.parameters(), 'lr': 1e-4},
        {'params': model.layer4.parameters(), 'lr': 2e-4},
        {'params': model.fc.parameters(), 'lr': 1e-3},
    ])

    print("์ธต๋ณ„ ํ•™์Šต๋ฅ :")
    for i, group in enumerate(optimizer.param_groups):
        print(f"  ๊ทธ๋ฃน {i}: lr = {group['lr']}")
162
163
# ============================================
# 6. Data preprocessing (ImageNet normalization)
# ============================================
print("\n[6] ImageNet ์ •๊ทœํ™”")
print("-" * 40)

try:
    from torchvision import transforms

    # Channel statistics the pretrained ImageNet models were trained with;
    # inputs must be normalized the same way at inference/fine-tune time.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Training: random crop + flip for augmentation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])

    # Validation: deterministic resize + center crop.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])

    print(f"ImageNet Mean: {imagenet_mean}")
    print(f"ImageNet Std: {imagenet_std}")
    print("ํ›ˆ๋ จ ๋ณ€ํ™˜: RandomResizedCrop, Flip, Normalize")
    print("๊ฒ€์ฆ ๋ณ€ํ™˜: Resize, CenterCrop, Normalize")
except ImportError:
    # Was a bare 'except:', which would also swallow KeyboardInterrupt and
    # SystemExit; only a missing torchvision should trigger the fallback.
    print("transforms ๋กœ๋“œ ์‹คํŒจ")
197
198
# ============================================
# 7. Full transfer-learning pipeline
# ============================================
print("\n[7] ์ „์ดํ•™์Šต ์ „์ฒด ํŒŒ์ดํ”„๋ผ์ธ")
print("-" * 40)
204
205class TransferLearningPipeline:
206    """์ „์ดํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ"""
207
208    def __init__(self, backbone='resnet18', num_classes=10, strategy='finetune'):
209        self.strategy = strategy
210
211        if TORCHVISION_AVAILABLE:
212            # ๋ฐฑ๋ณธ ๋กœ๋“œ
213            if backbone == 'resnet18':
214                self.model = models.resnet18(weights='IMAGENET1K_V1')
215                in_features = self.model.fc.in_features
216                self.model.fc = nn.Linear(in_features, num_classes)
217            elif backbone == 'resnet50':
218                self.model = models.resnet50(weights='IMAGENET1K_V2')
219                in_features = self.model.fc.in_features
220                self.model.fc = nn.Linear(in_features, num_classes)
221            else:
222                raise ValueError(f"Unknown backbone: {backbone}")
223
224            # ์ „๋žต์— ๋”ฐ๋ฅธ ๊ฐ€์ค‘์น˜ ๊ณ ์ •
225            if strategy == 'feature_extract':
226                self._freeze_backbone()
227            elif strategy == 'finetune':
228                pass  # ์ „์ฒด ํ•™์Šต ๊ฐ€๋Šฅ
229            elif strategy == 'gradual':
230                self._freeze_backbone()
231        else:
232            # ๋ฐ๋ชจ์šฉ ๊ฐ„๋‹จํ•œ ๋ชจ๋ธ
233            self.model = nn.Sequential(
234                nn.Conv2d(3, 64, 3, padding=1),
235                nn.ReLU(),
236                nn.AdaptiveAvgPool2d(1),
237                nn.Flatten(),
238                nn.Linear(64, num_classes)
239            )
240
241    def _freeze_backbone(self):
242        """FC ์ œ์™ธ ๋ชจ๋“  ์ธต ๊ณ ์ •"""
243        for name, param in self.model.named_parameters():
244            if 'fc' not in name:
245                param.requires_grad = False
246
247    def unfreeze_layer(self, layer_name):
248        """ํŠน์ • ์ธต ํ•ด๋™"""
249        layer = getattr(self.model, layer_name, None)
250        if layer:
251            for param in layer.parameters():
252                param.requires_grad = True
253
254    def get_optimizer(self, lr=1e-4):
255        """์ตœ์ ํ™”๊ธฐ ์ƒ์„ฑ"""
256        if self.strategy == 'feature_extract':
257            # ํ•™์Šต ๊ฐ€๋Šฅํ•œ ํŒŒ๋ผ๋ฏธํ„ฐ๋งŒ
258            params = filter(lambda p: p.requires_grad, self.model.parameters())
259            return torch.optim.Adam(params, lr=lr)
260        else:
261            return torch.optim.Adam(self.model.parameters(), lr=lr)
262
263    def summary(self):
264        """๋ชจ๋ธ ์š”์•ฝ"""
265        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
266        total = sum(p.numel() for p in self.model.parameters())
267        print(f"์ „๋žต: {self.strategy}")
268        print(f"ํ•™์Šต ๊ฐ€๋Šฅ: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
269
# Compare the freezing strategies side by side.
print("\n์ „๋žต๋ณ„ ๋น„๊ต:")
for strategy in ['feature_extract', 'finetune']:
    print(f"\n{strategy}:")
    pipeline = TransferLearningPipeline('resnet18', 10, strategy)
    pipeline.summary()
276
277
# ============================================
# 8. Training example with dummy data
# ============================================
print("\n[8] ํ•™์Šต ์˜ˆ์‹œ (๋”๋ฏธ ๋ฐ์ดํ„ฐ)")
print("-" * 40)

# Random stand-in for an image dataset: 100 RGB "images", 10 classes.
X_train = torch.randn(100, 3, 224, 224)
y_train = torch.randint(0, 10, (100,))

train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Feature extraction: only the new classifier head receives updates.
pipeline = TransferLearningPipeline('resnet18', 10, 'feature_extract')
model = pipeline.model.to(device)
optimizer = pipeline.get_optimizer(lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Short training loop (2 epochs) just to demonstrate the mechanics.
model.train()
for epoch in range(2):
    epoch_loss = 0.0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {epoch_loss/len(train_loader):.4f}")
314
315
# ============================================
# 9. Transfer-learning checklist
# ============================================
print("\n[9] ์ „์ดํ•™์Šต ์ฒดํฌ๋ฆฌ์ŠคํŠธ")
print("-" * 40)

# Printed verbatim as a quick-reference summary for the reader.
checklist = """
โœ“ ์‚ฌ์ „ ํ•™์Šต ๋ชจ๋ธ ์„ ํƒ
  - ์ž‘์—…๊ณผ ์œ ์‚ฌํ•œ ๋ฐ์ดํ„ฐ๋กœ ํ•™์Šต๋œ ๋ชจ๋ธ
  - ImageNet ๋ชจ๋ธ์ด ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ ์ข‹์Œ

โœ“ ์ „์ฒ˜๋ฆฌ
  - ImageNet ์ •๊ทœํ™” ์‚ฌ์šฉ
  - ๋ชจ๋ธ ์ž…๋ ฅ ํฌ๊ธฐ ๋งž์ถ”๊ธฐ (๋ณดํ†ต 224ร—224)

โœ“ ์ „๋žต ์„ ํƒ
  - ๋ฐ์ดํ„ฐ ์ ์Œ: ํŠน์„ฑ ์ถ”์ถœ (FC๋งŒ ํ•™์Šต)
  - ๋ฐ์ดํ„ฐ ์ถฉ๋ถ„: ๋ฏธ์„ธ ์กฐ์ • (์ „์ฒด ํ•™์Šต)
  - ์ค‘๊ฐ„: ์ ์ง„์  ํ•ด๋™

โœ“ ํ•™์Šต๋ฅ 
  - ํŠน์„ฑ ์ถ”์ถœ: 1e-3 ~ 1e-2
  - ๋ฏธ์„ธ ์กฐ์ •: 1e-5 ~ 1e-4
  - ์ฐจ๋“ฑ ํ•™์Šต๋ฅ  ๊ณ ๋ ค

โœ“ ์ •๊ทœํ™”
  - Dropout, Weight Decay
  - ๋ฐ์ดํ„ฐ ์ฆ๊ฐ•
  - ์กฐ๊ธฐ ์ข…๋ฃŒ

โœ“ ๋ชจ๋“œ ์ „ํ™˜
  - ํ›ˆ๋ จ: model.train()
  - ํ‰๊ฐ€: model.eval()
"""
print(checklist)
351
352
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("์ „์ดํ•™์Šต ์ •๋ฆฌ")
print("=" * 60)

# Closing recap of the three strategies and the key code idioms.
summary = """
์ „์ดํ•™์Šต ์ „๋žต:

1. ํŠน์„ฑ ์ถ”์ถœ (Feature Extraction)
   - ์‚ฌ์ „ ํ•™์Šต ๊ฐ€์ค‘์น˜ ๊ณ ์ •
   - ๋งˆ์ง€๋ง‰ ์ธต๋งŒ ํ•™์Šต
   - ๋ฐ์ดํ„ฐ ์ ์„ ๋•Œ ์ ํ•ฉ

2. ๋ฏธ์„ธ ์กฐ์ • (Fine-tuning)
   - ์ „์ฒด ๋„คํŠธ์›Œํฌ ํ•™์Šต
   - ๋‚ฎ์€ ํ•™์Šต๋ฅ  ์‚ฌ์šฉ
   - ๋ฐ์ดํ„ฐ ์ถฉ๋ถ„ํ•  ๋•Œ

3. ์ ์ง„์  ํ•ด๋™ (Gradual Unfreezing)
   - ํ›„๋ฐ˜ ์ธต๋ถ€ํ„ฐ ์ˆœ์ฐจ์  ํ•ด๋™
   - ๊ท ํ˜• ์žกํžŒ ์ ‘๊ทผ

ํ•ต์‹ฌ ์ฝ”๋“œ:
    # ๊ฐ€์ค‘์น˜ ๊ณ ์ •
    for param in model.parameters():
        param.requires_grad = False

    # ๋งˆ์ง€๋ง‰ ์ธต ๊ต์ฒด
    model.fc = nn.Linear(in_features, num_classes)

    # ImageNet ์ •๊ทœํ™”
    transforms.Normalize([0.485, 0.456, 0.406],
                        [0.229, 0.224, 0.225])
"""
print(summary)
print("=" * 60)