1"""
207. ์ ์ดํ์ต (Transfer Learning)
3
4์ฌ์ ํ์ต๋ ๋ชจ๋ธ์ ํ์ฉํ ์ ์ดํ์ต์ ๊ตฌํํฉ๋๋ค.
5"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10from torch.utils.data import DataLoader, TensorDataset
11import numpy as np
12
# Banner and compute-device selection for the whole demo script.
_RULE = "=" * 60
print(_RULE)
print("PyTorch ์ ์ดํ์ต (Transfer Learning)")
print(_RULE)

# Prefer the GPU when one is present; models/tensors below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"์ฌ์ฉ ์ฅ์น: {device}")
19
20
# ============================================
# 1. Loading pretrained models
# ============================================
print("\n[1] ์ฌ์ ํ์ต ๋ชจ๋ธ ๋ก๋")
print("-" * 40)

try:
    import torchvision.models as models

    # A registry of common ImageNet backbones; the lambdas delay
    # construction (and weight download) until iteration below.
    print("์ฌ์ฉ ๊ฐ๋ฅํ ์ฌ์ ํ์ต ๋ชจ๋ธ:")
    pretrained_models = {
        'ResNet-18': lambda: models.resnet18(weights='IMAGENET1K_V1'),
        'ResNet-50': lambda: models.resnet50(weights='IMAGENET1K_V2'),
        'EfficientNet-B0': lambda: models.efficientnet_b0(weights='IMAGENET1K_V1'),
        'MobileNet-V2': lambda: models.mobilenet_v2(weights='IMAGENET1K_V1'),
    }

    for name, build in pretrained_models.items():
        net = build()
        n_params = sum(p.numel() for p in net.parameters())
        print(f" {name}: {n_params:,} ํ๋ผ๋ฏธํฐ")

    TORCHVISION_AVAILABLE = True
except ImportError:
    # Without torchvision the rest of the script runs in demo mode.
    print("torchvision์ด ์ค์น๋์ง ์์์ต๋๋ค. ๋ฐ๋ชจ ๋ชจ๋๋ก ์งํํฉ๋๋ค.")
    TORCHVISION_AVAILABLE = False
48
49
# ============================================
# 2. Feature extraction
# ============================================
print("\n[2] ํน์ฑ ์ถ์ถ (Feature Extraction)")
print("-" * 40)

if TORCHVISION_AVAILABLE:
    # Start from an ImageNet-pretrained ResNet-18.
    model = models.resnet18(weights='IMAGENET1K_V1')
    print(f"์๋ FC ์ธต: {model.fc}")

    # Freeze every pretrained weight in one call; only the new head
    # attached below will receive gradients.
    model.requires_grad_(False)

    # Replace the 1000-way ImageNet classifier with a 10-class head.
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(num_features, 10),  # 10 classes
    )
    print(f"์ FC ์ธต: {model.fc}")

    # Report how small the trainable fraction is (head only).
    sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
    trainable = sum(n for n, grad_on in sizes if grad_on)
    total = sum(n for n, _ in sizes)
    print(f"ํ์ต ๊ฐ๋ฅ ํ๋ผ๋ฏธํฐ: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
80
81
# ============================================
# 3. Fine-tuning
# ============================================
print("\n[3] ๋ฏธ์ธ ์กฐ์ (Fine-tuning)")
print("-" * 40)

if TORCHVISION_AVAILABLE:
    # Fresh pretrained backbone with a task-sized head. Nothing is
    # frozen here, so every layer is updated during training.
    model = models.resnet18(weights='IMAGENET1K_V1')
    model.fc = nn.Linear(model.fc.in_features, 10)

    print("์ ์ฒด ๋ฏธ์ธ ์กฐ์ :")
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" ํ์ต ๊ฐ๋ฅ ํ๋ผ๋ฏธํฐ: {n_trainable:,}")
99
100
# ============================================
# 4. Gradual unfreezing
# ============================================
print("\n[4] ์ ์ง์ ํด๋ (Gradual Unfreezing)")
print("-" * 40)

if TORCHVISION_AVAILABLE:
    model = models.resnet18(weights='IMAGENET1K_V1')

    # Stage 1: freeze the whole backbone, then attach a fresh head
    # (new modules are trainable by default).
    model.requires_grad_(False)
    model.fc = nn.Linear(model.fc.in_features, 10)

    def count_trainable(model):
        """Total number of parameters currently receiving gradients."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("์ ์ง์ ํด๋ ๊ณผ์ :")
    print(f" 1๋จ๊ณ (FC๋ง): {count_trainable(model):,} ํ๋ผ๋ฏธํฐ")

    # Stages 2-4: open up deeper blocks one at a time, back to front,
    # finishing with the entire network.
    for block_to_open, label in (
        (model.layer4, " 2๋จ๊ณ (FC + layer4)"),
        (model.layer3, " 3๋จ๊ณ (FC + layer4 + layer3)"),
        (model, " 4๋จ๊ณ (์ ์ฒด)"),
    ):
        block_to_open.requires_grad_(True)
        print(f"{label}: {count_trainable(model):,} ํ๋ผ๋ฏธํฐ")
137
138
# ============================================
# 5. Discriminative learning rates
# ============================================
print("\n[5] ์ฐจ๋ฑ ํ์ต๋ฅ ")
print("-" * 40)

if TORCHVISION_AVAILABLE:
    model = models.resnet18(weights='IMAGENET1K_V1')
    model.fc = nn.Linear(model.fc.in_features, 10)

    # Early layers hold generic features -> tiny steps; the randomly
    # initialized head gets the largest learning rate.
    lr_plan = (
        (model.conv1, 1e-5),
        (model.layer1, 2e-5),
        (model.layer2, 5e-5),
        (model.layer3, 1e-4),
        (model.layer4, 2e-4),
        (model.fc, 1e-3),
    )
    optimizer = torch.optim.Adam(
        [{'params': part.parameters(), 'lr': rate} for part, rate in lr_plan]
    )

    print("์ธต๋ณ ํ์ต๋ฅ :")
    for i, group in enumerate(optimizer.param_groups):
        print(f" ๊ทธ๋ฃน {i}: lr = {group['lr']}")
162
163
# ============================================
# 6. Data preprocessing (ImageNet normalization)
# ============================================
print("\n[6] ImageNet ์ ๊ทํ")
print("-" * 40)

# Per-channel statistics of the ImageNet training set; pretrained
# torchvision backbones expect inputs normalized with these values.
# Defined outside the try-block so they exist even without torchvision.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

try:
    from torchvision import transforms

    # Training: random crop + horizontal flip for augmentation,
    # then tensor conversion and normalization.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])

    # Validation: deterministic resize + center crop (no augmentation).
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std)
    ])

    print(f"ImageNet Mean: {imagenet_mean}")
    print(f"ImageNet Std: {imagenet_std}")
    print("ํ๋ จ ๋ณํ: RandomResizedCrop, Flip, Normalize")
    print("๊ฒ์ฆ ๋ณํ: Resize, CenterCrop, Normalize")
except ImportError:
    # Was a bare `except:`, which would also swallow KeyboardInterrupt,
    # SystemExit and real bugs; only a missing torchvision is tolerable.
    print("transforms ๋ก๋ ์คํจ")
197
198
# ============================================
# 7. Full transfer-learning pipeline
# ============================================
# Section banner; the pipeline class itself is defined just below.
print("\n[7] ์ ์ดํ์ต ์ ์ฒด ํ์ดํ๋ผ์ธ")
print("-" * 40)
204
class TransferLearningPipeline:
    """Bundle backbone loading, freezing strategy and optimizer creation.

    Args:
        backbone: 'resnet18' or 'resnet50' (torchvision backbones).
        num_classes: output width of the replacement classification head.
        strategy: 'feature_extract' (freeze everything but the head),
            'finetune' (all layers trainable), or 'gradual' (start frozen
            and unfreeze blocks later via ``unfreeze_layer``).

    Raises:
        ValueError: if torchvision is available and ``backbone`` is unknown.

    Falls back to a tiny demo CNN when torchvision is unavailable.
    """

    def __init__(self, backbone='resnet18', num_classes=10, strategy='finetune'):
        self.strategy = strategy

        # `globals().get` instead of a bare name: the class then also works
        # when imported without the module-level TORCHVISION_AVAILABLE flag
        # (the bare lookup used to raise NameError in that situation).
        if globals().get('TORCHVISION_AVAILABLE', False):
            # Load the requested pretrained backbone.
            if backbone == 'resnet18':
                self.model = models.resnet18(weights='IMAGENET1K_V1')
            elif backbone == 'resnet50':
                self.model = models.resnet50(weights='IMAGENET1K_V2')
            else:
                raise ValueError(f"Unknown backbone: {backbone}")

            # Swap the ImageNet head for one sized to this task.
            in_features = self.model.fc.in_features
            self.model.fc = nn.Linear(in_features, num_classes)

            # Apply the freezing strategy: 'gradual' starts exactly like
            # feature extraction and is opened up later on demand.
            if strategy in ('feature_extract', 'gradual'):
                self._freeze_backbone()
            # 'finetune' leaves every layer trainable.
        else:
            # Demo model: tiny CNN so the rest of the script still runs.
            self.model = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(64, num_classes)
            )

    def _freeze_backbone(self):
        """Freeze every parameter except those of the final `fc` head."""
        for name, param in self.model.named_parameters():
            if 'fc' not in name:
                param.requires_grad = False

    def unfreeze_layer(self, layer_name):
        """Make the named sub-module (e.g. 'layer4') trainable again."""
        layer = getattr(self.model, layer_name, None)
        # `is not None` rather than truthiness: nn.Sequential defines
        # __len__, so an empty container would have been skipped by the
        # previous `if layer:` check.
        if layer is not None:
            for param in layer.parameters():
                param.requires_grad = True

    def get_optimizer(self, lr=1e-4):
        """Build an Adam optimizer matching the chosen strategy.

        For 'feature_extract', only currently-trainable parameters are
        registered; otherwise all parameters are passed so that layers
        unfrozen later (under 'gradual') are already known to Adam.
        """
        if self.strategy == 'feature_extract':
            params = filter(lambda p: p.requires_grad, self.model.parameters())
            return torch.optim.Adam(params, lr=lr)
        return torch.optim.Adam(self.model.parameters(), lr=lr)

    def summary(self):
        """Print the strategy and trainable/total parameter counts."""
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        print(f"์ ๋ต: {self.strategy}")
        print(f"ํ์ต ๊ฐ๋ฅ: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")
269
# Quick demo: compare how the two main strategies differ in the
# fraction of trainable parameters.
print("\n์ ๋ต๋ณ ๋น๊ต:")
for strat in ('feature_extract', 'finetune'):
    print(f"\n{strat}:")
    TransferLearningPipeline('resnet18', 10, strat).summary()
276
277
# ============================================
# 8. Training example with dummy data
# ============================================
print("\n[8] ํ์ต ์์ (๋๋ฏธ ๋ฐ์ดํฐ)")
print("-" * 40)

# Random tensors standing in for a 100-image, 10-class dataset.
X_train = torch.randn(100, 3, 224, 224)
y_train = torch.randint(0, 10, (100,))

train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Feature-extraction setup: frozen backbone, trainable head only.
pipeline = TransferLearningPipeline('resnet18', 10, 'feature_extract')
model = pipeline.model.to(device)
optimizer = pipeline.get_optimizer(lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Two short epochs, just to demonstrate the loop structure.
model.train()
for ep in range(2):
    running = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)

        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()

        running += loss.item()

    print(f"Epoch {ep+1}: Loss = {running/len(train_loader):.4f}")
314
315
# ============================================
# 9. Transfer-learning checklist
# ============================================
print("\n[9] ์ ์ดํ์ต ์ฒดํฌ๋ฆฌ์คํธ")
print("-" * 40)

# User-facing Korean checklist text, printed verbatim below.
checklist = """
โ ์ฌ์ ํ์ต ๋ชจ๋ธ ์ ํ
  - ์์๊ณผ ์ ์ฌํ ๋ฐ์ดํฐ๋ก ํ์ต๋ ๋ชจ๋ธ
  - ImageNet ๋ชจ๋ธ์ด ๋๋ถ๋ถ์ ๊ฒฝ์ฐ ์ข์

โ ์ ์ฒ๋ฆฌ
  - ImageNet ์ ๊ทํ ์ฌ์ฉ
  - ๋ชจ๋ธ ์๋ ฅ ํฌ๊ธฐ ๋ง์ถ๊ธฐ (๋ณดํต 224ร224)

โ ์ ๋ต ์ ํ
  - ๋ฐ์ดํฐ ์ ์: ํน์ฑ ์ถ์ถ (FC๋ง ํ์ต)
  - ๋ฐ์ดํฐ ์ถฉ๋ถ: ๋ฏธ์ธ ์กฐ์ (์ ์ฒด ํ์ต)
  - ์ค๊ฐ: ์ ์ง์ ํด๋

โ ํ์ต๋ฅ 
  - ํน์ฑ ์ถ์ถ: 1e-3 ~ 1e-2
  - ๋ฏธ์ธ ์กฐ์ : 1e-5 ~ 1e-4
  - ์ฐจ๋ฑ ํ์ต๋ฅ ๊ณ ๋ ค

โ ์ ๊ทํ
  - Dropout, Weight Decay
  - ๋ฐ์ดํฐ ์ฆ๊ฐ
  - ์กฐ๊ธฐ ์ข๋ฃ

โ ๋ชจ๋ ์ ํ
  - ํ๋ จ: model.train()
  - ํ๊ฐ: model.eval()
"""
print(checklist)
351
352
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("์ ์ดํ์ต ์ ๋ฆฌ")
print("=" * 60)

# Closing summary text (printed verbatim; includes a short code recap).
summary = """
์ ์ดํ์ต ์ ๋ต:

1. ํน์ฑ ์ถ์ถ (Feature Extraction)
  - ์ฌ์ ํ์ต ๊ฐ์ค์น ๊ณ ์ 
  - ๋ง์ง๋ง ์ธต๋ง ํ์ต
  - ๋ฐ์ดํฐ ์ ์ ๋ ์ ํฉ

2. ๋ฏธ์ธ ์กฐ์ (Fine-tuning)
  - ์ ์ฒด ๋คํธ์ํฌ ํ์ต
  - ๋ฎ์ ํ์ต๋ฅ ์ฌ์ฉ
  - ๋ฐ์ดํฐ ์ถฉ๋ถํ ๋

3. ์ ์ง์ ํด๋ (Gradual Unfreezing)
  - ํ๋ฐ ์ธต๋ถํฐ ์์ฐจ์ ํด๋
  - ๊ท ํ ์กํ ์ ๊ทผ

ํต์ฌ ์ฝ๋:
  # ๊ฐ์ค์น ๊ณ ์ 
  for param in model.parameters():
      param.requires_grad = False

  # ๋ง์ง๋ง ์ธต ๊ต์ฒด
  model.fc = nn.Linear(in_features, num_classes)

  # ImageNet ์ ๊ทํ
  transforms.Normalize([0.485, 0.456, 0.406],
                       [0.229, 0.224, 0.225])
"""
print(summary)
print("=" * 60)