1"""
213. μ€μ μ΄λ―Έμ§ λΆλ₯ νλ‘μ νΈ (CIFAR-10)
3
4CIFAR-10 λΆλ₯λ₯Ό μν μ 체 νμ΅ νμ΄νλΌμΈμ ꡬνν©λλ€.
5"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10from torch.utils.data import DataLoader
11import numpy as np
12import matplotlib.pyplot as plt
13import time
14
# Banner and device selection; everything below runs on `device`.
_RULE = "=" * 60
print(_RULE)
print("CIFAR-10 μ΄λ―Έμ§ λΆλ₯ νλ‘μ νΈ")
print(_RULE)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"μ¬μ© μ₯μΉ: {device}")
21
22
# ============================================
# 1. Data preparation
# ============================================
print("\n[1] λ°μ΄ν° μ€λΉ")
print("-" * 40)

try:
    from torchvision import datasets, transforms

    # CIFAR-10 per-channel normalization constants (train-set mean/std).
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    # Train-time transform: augmentation (crop/flip/jitter) + normalization.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Test-time transform: normalization only, no augmentation.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    # Dataset download/load (cached under ./data).
    train_data = datasets.CIFAR10('data', train=True, download=True,
                                  transform=train_transform)
    test_data = datasets.CIFAR10('data', train=False,
                                 transform=test_transform)

    train_loader = DataLoader(train_data, batch_size=128, shuffle=True,
                              num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=256)

    classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    print(f"νλ ¨ λ°μ΄ν°: {len(train_data)}")
    print(f"ν
μ€νΈ λ°μ΄ν°: {len(test_data)}")
    print(f"ν΄λμ€: {classes}")

    DATA_AVAILABLE = True

except Exception as e:
    # Download/import can fail offline; fall back so the script still runs.
    print(f"λ°μ΄ν° λ‘λ μ€ν¨: {e}")
    print("λλ―Έ λ°μ΄ν°λ‘ μ§νν©λλ€.")
    DATA_AVAILABLE = False


# ============================================
# 2. Model definition
# ============================================
print("\n[2] λͺ¨λΈ μ μ")
print("-" * 40)
81
class CIFAR10CNN(nn.Module):
    """Plain VGG-style CNN for 32x32 CIFAR-10 images.

    Three double Conv-BN-ReLU stages, each followed by 2x2 max pooling
    (32 -> 16 -> 8 -> 4 spatial) and spatial dropout, then a two-layer
    fully connected head with dropout.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def stage(in_ch, out_ch):
            # Two 3x3 Conv-BN-ReLU layers, then pooling + Dropout2d.
            return [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2, 2),
                nn.Dropout2d(0.25),
            ]

        # Module order matches the original layer-by-layer listing, so
        # state_dict keys and RNG-driven initialization are unchanged.
        self.features = nn.Sequential(
            *stage(3, 64),     # Block 1: 32 -> 16
            *stage(64, 128),   # Block 2: 16 -> 8
            *stage(128, 256),  # Block 3: 8 -> 4
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        return self.classifier(self.features(x))
130
class ResBlock(nn.Module):
    """Basic residual block: two 3x3 convs plus a (possibly projected) skip."""

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # 1x1 projection on the skip path whenever the shape changes;
        # otherwise the shortcut is the identity (empty Sequential).
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return relu(F(x) + shortcut(x))."""
        skip = self.shortcut(x)
        out = self.conv2(F.relu(self.bn1(self.conv1(x))))
        return F.relu(self.bn2(out) + skip)

class ResNetCIFAR(nn.Module):
    """Small ResNet for CIFAR: three 2-block stages of width 64/128/256."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Each later stage doubles the width and halves the resolution.
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, in_ch, out_ch, num_blocks, stride):
        # First block may downsample/widen; the remaining keep the shape.
        blocks = [ResBlock(in_ch, out_ch, stride)]
        blocks.extend(ResBlock(out_ch, out_ch) for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        pooled = self.avgpool(out)
        return self.fc(torch.flatten(pooled, 1))
182
# Instantiate both architectures on the chosen device and report sizes.
model = CIFAR10CNN().to(device)
resnet = ResNetCIFAR().to(device)

def _param_count(net):
    # Total element count over all learnable parameters of ``net``.
    return sum(p.numel() for p in net.parameters())

print(f"CIFAR10CNN νλΌλ―Έν°: {_param_count(model):,}")
print(f"ResNetCIFAR νλΌλ―Έν°: {_param_count(resnet):,}")


# ============================================
# 3. Mixup data augmentation
# ============================================
print("\n[3] Mixup λ°μ΄ν° μ¦κ°")
print("-" * 40)
196
def mixup_data(x, y, alpha=0.2):
    """Apply Mixup: blend each sample with a randomly paired one.

    Draws lam ~ Beta(alpha, alpha) (lam = 1 when alpha <= 0, i.e. no
    mixing) and returns (mixed_x, y_a, y_b, lam) where
    mixed_x = lam * x + (1 - lam) * x[perm], y_a are the original labels
    and y_b the labels of the shuffled partners.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm]
    return blended, y, y[perm], lam
210
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: lam-weighted blend of the losses against both label sets."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b
214
# Quick sanity check of Mixup on a small random batch.
x = torch.randn(4, 3, 32, 32)
y = torch.arange(4)
mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=0.2)
print(f"Mixup lambda: {lam:.4f}")
220
221
# ============================================
# 4. Training functions
# ============================================
print("\n[4] νμ΅ ν¨μ")
print("-" * 40)

def train_epoch(model, loader, optimizer, criterion, use_mixup=False):
    """Run one training epoch and return (avg_loss, accuracy).

    Args:
        model: network to train. The batch device is derived from the
            model's own parameters, so the function works on CPU or GPU
            without relying on the module-level ``device`` global.
        loader: DataLoader yielding (data, target) batches.
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss()``.
        use_mixup: when True, apply Mixup augmentation. Accuracy is not
            meaningful against mixed targets, so 0 is returned for it.

    Returns:
        (avg_loss, accuracy): mean per-batch loss and accuracy in percent
        (0 when ``use_mixup`` is True or the loader is empty).
    """
    model.train()
    # Fix: previously depended on the global ``device``; derive it from
    # the model instead so the function is self-contained and testable.
    dev = next(model.parameters()).device
    total_loss = 0.0
    correct = 0
    total = 0

    for data, target in loader:
        data, target = data.to(dev), target.to(dev)

        if use_mixup:
            data, target_a, target_b, lam = mixup_data(data, target)

        optimizer.zero_grad()
        output = model(data)

        if use_mixup:
            loss = mixup_criterion(criterion, output, target_a, target_b, lam)
        else:
            loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Mixed targets make plain accuracy meaningless; skip under Mixup.
        if not use_mixup:
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)

    # Fix: guard against an empty loader (original raised ZeroDivisionError).
    avg_loss = total_loss / len(loader) if len(loader) > 0 else 0.0
    accuracy = 100. * correct / total if total > 0 else 0
    return avg_loss, accuracy
261
def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate. The batch device is derived from the
            model's own parameters (fix: no longer relies on the
            module-level ``device`` global).
        loader: DataLoader yielding (data, target) batches.
        criterion: loss function, e.g. ``nn.CrossEntropyLoss()``.

    Returns:
        (avg_loss, accuracy): mean per-batch loss and accuracy in percent
        (both 0.0 for an empty loader, instead of raising).
    """
    model.eval()
    dev = next(model.parameters()).device
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(dev), target.to(dev)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)

    # Fix: guard against an empty loader (original raised ZeroDivisionError).
    avg_loss = total_loss / len(loader) if len(loader) > 0 else 0.0
    accuracy = 100. * correct / total if total > 0 else 0.0
    return avg_loss, accuracy
282
283
# ============================================
# 5. Full training pipeline
# ============================================
print("\n[5] νμ΅ μ€ν")
print("-" * 40)

def train_model(model, train_loader, test_loader, epochs=10, use_mixup=False):
    """Train ``model`` for ``epochs`` epochs and return the metric history.

    Uses SGD (momentum 0.9, weight decay 5e-4) with cosine-annealed LR.
    After every epoch, evaluates on ``test_loader``, records metrics, and
    tracks the best test accuracy seen so far.

    Returns:
        dict with lists 'train_loss', 'train_acc', 'test_loss', 'test_acc',
        one entry per epoch.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=5e-4
    )
    # Cosine schedule decays the LR smoothly toward 0 over the run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    best_acc = 0

    for epoch in range(epochs):
        start_time = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, use_mixup
        )
        test_loss, test_acc = evaluate(model, test_loader, criterion)

        # Step once per epoch (after optimizer steps inside train_epoch).
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        elapsed = time.time() - start_time

        if test_acc > best_acc:
            best_acc = test_acc

        print(f"Epoch {epoch+1:3d}: Train Acc={train_acc:5.2f}%, "
              f"Test Acc={test_acc:5.2f}%, Time={elapsed:.1f}s")

    print(f"\nμ΅κ³ ν
μ€νΈ μ νλ: {best_acc:.2f}%")
    return history
328
# Run a short demo training only when the dataset loaded successfully.
if not DATA_AVAILABLE:
    print("λ°μ΄ν° μμ - νμ΅ μ€ν΅")
    history = None
else:
    model = CIFAR10CNN().to(device)
    history = train_model(model, train_loader, test_loader, epochs=5)
336
337
# ============================================
# 6. Result visualization
# ============================================
print("\n[6] κ²°κ³Ό μκ°ν")
print("-" * 40)

if history:
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # One spec per panel: (axis, train key, test key, y-label, title).
    panels = [
        (axes[0], 'train_loss', 'test_loss', 'Loss', 'Loss'),
        (axes[1], 'train_acc', 'test_acc', 'Accuracy (%)', 'Accuracy'),
    ]
    for ax, train_key, test_key, ylabel, title in panels:
        ax.plot(history[train_key], label='Train')
        ax.plot(history[test_key], label='Test')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('cifar10_training.png', dpi=100)
    plt.close()
    print("κ·Έλν μ μ₯: cifar10_training.png")
367
368
# ============================================
# 7. Per-class accuracy
# ============================================
print("\n[7] ν΄λμ€λ³ λΆμ")
print("-" * 40)

if DATA_AVAILABLE:
    def per_class_accuracy(model, loader, classes):
        """Print the accuracy over ``loader`` broken down per class."""
        model.eval()
        hits = [0] * len(classes)
        counts = [0] * len(classes)

        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                predictions = model(data).argmax(dim=1)

                # Tally per-label totals and correct predictions.
                for p, t in zip(predictions, target):
                    label = t.item()
                    counts[label] += 1
                    if p == label:
                        hits[label] += 1

        print("ν΄λμ€λ³ μ νλ:")
        for idx, cls in enumerate(classes):
            if counts[idx] > 0:
                acc = 100 * hits[idx] / counts[idx]
                print(f"  {cls:12s}: {acc:5.2f}%")

    per_class_accuracy(model, test_loader, classes)
400
401
# ============================================
# 8. Prediction visualization
# ============================================
print("\n[8] μμΈ‘ μκ°ν")
print("-" * 40)

if DATA_AVAILABLE:
    def visualize_predictions(model, loader, classes, n=8):
        """Plot the first ``n`` test images with predicted vs. true labels."""
        model.eval()
        batch_x, batch_y = next(iter(loader))
        batch_x, batch_y = batch_x[:n].to(device), batch_y[:n]

        with torch.no_grad():
            pred = model(batch_x).argmax(dim=1)

        fig, axes = plt.subplots(2, 4, figsize=(12, 6))
        for idx, ax in enumerate(axes.flat):
            if idx >= n:
                continue
            # Undo the channel normalization so the image displays correctly.
            img = batch_x[idx].cpu().numpy().transpose(1, 2, 0)
            img = np.clip(img * np.array(std) + np.array(mean), 0, 1)

            ax.imshow(img)
            # Green title for a correct prediction, red otherwise.
            ok = pred[idx] == batch_y[idx]
            ax.set_title(f"Pred: {classes[pred[idx]]}\nTrue: {classes[batch_y[idx]]}",
                         color='green' if ok else 'red')
            ax.axis('off')

        plt.tight_layout()
        plt.savefig('cifar10_predictions.png', dpi=100)
        plt.close()
        print("μμΈ‘ μκ°ν μ μ₯: cifar10_predictions.png")

    visualize_predictions(model, test_loader, classes)
438
439
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("CIFAR-10 νλ‘μ νΈ μ 리")
print("=" * 60)

# NOTE: ``summary`` is printed output (runtime data), so its text is
# deliberately left untranslated and unmodified.
summary = """
μ£Όμ κΈ°λ²:

1. λ°μ΄ν° μ¦κ°
  - RandomCrop, HorizontalFlip
  - ColorJitter
  - Mixup/CutMix

2. λͺ¨λΈ ꡬ쑰
  - Conv-BN-ReLU λΈλ‘
  - Dropout2d, Dropout
  - ResNet λΈλ‘ (Skip Connection)

3. νμ΅ μ€μ 
  - SGD + Momentum + Weight Decay
  - Cosine Annealing LR
  - Label Smoothing

μμ μ νλ:
  - κΈ°λ³Έ CNN: 75-80%
  - + λ°μ΄ν° μ¦κ°: 80-85%
  - + Mixup: 85-88%
  - ResNet + μ μ΄νμ΅: 90%+

λ€μ λ¨κ³:
  - λ κΉμ λͺ¨λΈ (ResNet-50)
  - AutoAugment
  - Knowledge Distillation
"""
print(summary)
print("=" * 60)