# 13_cifar10_project.py
  1"""
  213. μ‹€μ „ 이미지 λΆ„λ₯˜ ν”„λ‘œμ νŠΈ (CIFAR-10)
  3
  4CIFAR-10 λΆ„λ₯˜λ₯Ό μœ„ν•œ 전체 ν•™μŠ΅ νŒŒμ΄ν”„λΌμΈμ„ κ΅¬ν˜„ν•©λ‹ˆλ‹€.
  5"""
  6
  7import torch
  8import torch.nn as nn
  9import torch.nn.functional as F
 10from torch.utils.data import DataLoader
 11import numpy as np
 12import matplotlib.pyplot as plt
 13import time
 14
 15print("=" * 60)
 16print("CIFAR-10 이미지 λΆ„λ₯˜ ν”„λ‘œμ νŠΈ")
 17print("=" * 60)
 18
 19device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 20print(f"μ‚¬μš© μž₯치: {device}")
 21
 22
 23# ============================================
 24# 1. 데이터 μ€€λΉ„
 25# ============================================
 26print("\n[1] 데이터 μ€€λΉ„")
 27print("-" * 40)
 28
 29try:
 30    from torchvision import datasets, transforms
 31
 32    # CIFAR-10 μ •κ·œν™” κ°’
 33    mean = (0.4914, 0.4822, 0.4465)
 34    std = (0.2470, 0.2435, 0.2616)
 35
 36    # ν›ˆλ ¨ λ³€ν™˜ (데이터 증강)
 37    train_transform = transforms.Compose([
 38        transforms.RandomCrop(32, padding=4),
 39        transforms.RandomHorizontalFlip(),
 40        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
 41        transforms.ToTensor(),
 42        transforms.Normalize(mean, std)
 43    ])
 44
 45    # ν…ŒμŠ€νŠΈ λ³€ν™˜
 46    test_transform = transforms.Compose([
 47        transforms.ToTensor(),
 48        transforms.Normalize(mean, std)
 49    ])
 50
 51    # 데이터셋 λ‘œλ“œ
 52    train_data = datasets.CIFAR10('data', train=True, download=True,
 53                                   transform=train_transform)
 54    test_data = datasets.CIFAR10('data', train=False,
 55                                  transform=test_transform)
 56
 57    train_loader = DataLoader(train_data, batch_size=128, shuffle=True,
 58                              num_workers=2, pin_memory=True)
 59    test_loader = DataLoader(test_data, batch_size=256)
 60
 61    classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
 62               'dog', 'frog', 'horse', 'ship', 'truck')
 63
 64    print(f"ν›ˆλ ¨ 데이터: {len(train_data)}")
 65    print(f"ν…ŒμŠ€νŠΈ 데이터: {len(test_data)}")
 66    print(f"클래슀: {classes}")
 67
 68    DATA_AVAILABLE = True
 69
 70except Exception as e:
 71    print(f"데이터 λ‘œλ“œ μ‹€νŒ¨: {e}")
 72    print("더미 λ°μ΄ν„°λ‘œ μ§„ν–‰ν•©λ‹ˆλ‹€.")
 73    DATA_AVAILABLE = False
 74
 75
 76# ============================================
 77# 2. λͺ¨λΈ μ •μ˜
 78# ============================================
 79print("\n[2] λͺ¨λΈ μ •μ˜")
 80print("-" * 40)
 81
 82class CIFAR10CNN(nn.Module):
 83    """CIFAR-10용 CNN"""
 84    def __init__(self, num_classes=10):
 85        super().__init__()
 86        self.features = nn.Sequential(
 87            # Block 1: 32 β†’ 16
 88            nn.Conv2d(3, 64, 3, padding=1),
 89            nn.BatchNorm2d(64),
 90            nn.ReLU(inplace=True),
 91            nn.Conv2d(64, 64, 3, padding=1),
 92            nn.BatchNorm2d(64),
 93            nn.ReLU(inplace=True),
 94            nn.MaxPool2d(2, 2),
 95            nn.Dropout2d(0.25),
 96
 97            # Block 2: 16 β†’ 8
 98            nn.Conv2d(64, 128, 3, padding=1),
 99            nn.BatchNorm2d(128),
100            nn.ReLU(inplace=True),
101            nn.Conv2d(128, 128, 3, padding=1),
102            nn.BatchNorm2d(128),
103            nn.ReLU(inplace=True),
104            nn.MaxPool2d(2, 2),
105            nn.Dropout2d(0.25),
106
107            # Block 3: 8 β†’ 4
108            nn.Conv2d(128, 256, 3, padding=1),
109            nn.BatchNorm2d(256),
110            nn.ReLU(inplace=True),
111            nn.Conv2d(256, 256, 3, padding=1),
112            nn.BatchNorm2d(256),
113            nn.ReLU(inplace=True),
114            nn.MaxPool2d(2, 2),
115            nn.Dropout2d(0.25),
116        )
117
118        self.classifier = nn.Sequential(
119            nn.Flatten(),
120            nn.Linear(256 * 4 * 4, 512),
121            nn.ReLU(inplace=True),
122            nn.Dropout(0.5),
123            nn.Linear(512, num_classes)
124        )
125
126    def forward(self, x):
127        x = self.features(x)
128        x = self.classifier(x)
129        return x
130
131class ResBlock(nn.Module):
132    """Residual Block"""
133    def __init__(self, in_ch, out_ch, stride=1):
134        super().__init__()
135        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
136        self.bn1 = nn.BatchNorm2d(out_ch)
137        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
138        self.bn2 = nn.BatchNorm2d(out_ch)
139
140        self.shortcut = nn.Sequential()
141        if stride != 1 or in_ch != out_ch:
142            self.shortcut = nn.Sequential(
143                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
144                nn.BatchNorm2d(out_ch)
145            )
146
147    def forward(self, x):
148        out = F.relu(self.bn1(self.conv1(x)))
149        out = self.bn2(self.conv2(out))
150        out += self.shortcut(x)
151        return F.relu(out)
152
class ResNetCIFAR(nn.Module):
    """Small ResNet for CIFAR: a 3x3 stem plus three stages of two ResBlocks.

    Channels go 64 -> 128 -> 256. Stages 2 and 3 downsample with stride 2.
    Global average pooling then feeds a linear classifier.

    Args:
        num_classes: number of output logits.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, in_ch, out_ch, num_blocks, stride):
        """Stack ResBlocks; only the first one changes stride/channel count."""
        head = ResBlock(in_ch, out_ch, stride)
        tail = [ResBlock(out_ch, out_ch) for _ in range(num_blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        pooled = torch.flatten(self.avgpool(out), 1)
        return self.fc(pooled)

# Instantiate both models once to report their parameter counts.
model = CIFAR10CNN().to(device)
print(f"CIFAR10CNN 파라미터: {sum(p.numel() for p in model.parameters()):,}")

resnet = ResNetCIFAR().to(device)
print(f"ResNetCIFAR 파라미터: {sum(p.numel() for p in resnet.parameters()):,}")


# ============================================
# 3. Mixup data augmentation
# ============================================
print("\n[3] Mixup 데이터 증강")
print("-" * 40)

def mixup_data(x, y, alpha=0.2):
    """Blend each sample with a randomly chosen partner from the same batch.

    Args:
        x: input batch; dim 0 is the batch dimension.
        y: labels aligned with ``x`` along dim 0.
        alpha: Beta(alpha, alpha) parameter. Values <= 0 disable mixing
            (lambda = 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the original labels,
        the partner labels and the mixing weight given to the original sample.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm]
    partner_y = y[perm]

    return lam * x + (1 - lam) * partner_x, y, partner_y, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the Mixup loss: ``lam`` * loss(y_a) + (1 - ``lam``) * loss(y_b)."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# Quick sanity check of mixup_data on a random batch.
x = torch.randn(4, 3, 32, 32)
y = torch.tensor([0, 1, 2, 3])
mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=0.2)
print(f"Mixup lambda: {lam:.4f}")


# ============================================
# 4. Training functions
# ============================================
print("\n[4] 학습 함수")
print("-" * 40)

def train_epoch(model, loader, optimizer, criterion, use_mixup=False):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: network to optimise; it is switched to training mode.
        loader: iterable of ``(data, target)`` batches.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss called as ``criterion(output, target)``. It is assumed
            to average over the batch (the ``CrossEntropyLoss`` default used in
            this file), because the loss is re-weighted by batch size below.
        use_mixup: if True, each batch is blended with ``mixup_data``. The loss
            and the accuracy then weight the two label sets by ``lam``.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample average loss
        and ``accuracy`` is a percentage. Both are 0.0 for an empty loader.
    """
    model.train()
    # Move batches to wherever the model lives rather than relying on a
    # module-level device variable.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    correct = 0.0
    total = 0

    for data, target in loader:
        data, target = data.to(dev), target.to(dev)
        batch_size = target.size(0)

        if use_mixup:
            data, target_a, target_b, lam = mixup_data(data, target)

        optimizer.zero_grad()
        output = model(data)

        if use_mixup:
            loss = mixup_criterion(criterion, output, target_a, target_b, lam)
        else:
            loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short final batch does not
        # skew the epoch average.
        total_loss += loss.item() * batch_size
        total += batch_size

        pred = output.argmax(dim=1)
        if use_mixup:
            # Mixed inputs carry two labels; credit each in proportion to lam
            # (previously the accuracy was left at 0 when mixup was on).
            correct += (lam * (pred == target_a).sum().item()
                        + (1 - lam) * (pred == target_b).sum().item())
        else:
            correct += (pred == target).sum().item()

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, 100. * correct / total

def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``. It is assumed
            to average over the batch (the ``CrossEntropyLoss`` default used in
            this file), because the loss is re-weighted by batch size below.

    Returns:
        ``(avg_loss, accuracy)``. ``avg_loss`` is the per-sample average loss
        and ``accuracy`` is a percentage. Both are 0.0 for an empty loader.
    """
    model.eval()
    # Move batches to wherever the model lives rather than relying on a
    # module-level device variable.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(dev), target.to(dev)
            output = model(data)
            batch_size = target.size(0)

            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += criterion(output, target).item() * batch_size
            correct += (output.argmax(dim=1) == target).sum().item()
            total += batch_size

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, 100. * correct / total


# ============================================
# 5. Full training pipeline
# ============================================
print("\n[5] 학습 실행")
print("-" * 40)

def train_model(model, train_loader, test_loader, epochs=10, use_mixup=False,
                lr=0.1, momentum=0.9, weight_decay=5e-4):
    """Train with SGD + cosine annealing, evaluating on ``test_loader`` each epoch.

    Args:
        model: network to train. Callers in this file move it to ``device``
            before calling.
        train_loader: ``(data, target)`` batches used for training.
        test_loader: ``(data, target)`` batches used for evaluation.
        epochs: number of epochs; also the cosine schedule's ``T_max``.
        use_mixup: forwarded to ``train_epoch``.
        lr: initial SGD learning rate (default 0.1, the previous hard-coded
            value).
        momentum: SGD momentum (default 0.9, the previous hard-coded value).
        weight_decay: SGD weight decay (default 5e-4, the previous hard-coded
            value).

    Returns:
        dict of per-epoch lists under 'train_loss', 'train_acc', 'test_loss'
        and 'test_acc'.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    best_acc = 0

    for epoch in range(epochs):
        start_time = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, use_mixup
        )
        test_loss, test_acc = evaluate(model, test_loader, criterion)

        # One scheduler step per epoch, after this epoch's optimizer updates.
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)

        elapsed = time.time() - start_time
        best_acc = max(best_acc, test_acc)

        print(f"Epoch {epoch+1:3d}: Train Acc={train_acc:5.2f}%, "
              f"Test Acc={test_acc:5.2f}%, Time={elapsed:.1f}s")

    print(f"\n최고 테스트 정확도: {best_acc:.2f}%")
    return history

if DATA_AVAILABLE:
    # Short demo run (5 epochs) with a freshly initialised CNN.
    model = CIFAR10CNN().to(device)
    history = train_model(model, train_loader, test_loader, epochs=5)
else:
    print("데이터 없음 - 학습 스킵")
    history = None


# ============================================
# 6. Result visualization
# ============================================
print("\n[6] 결과 시각화")
print("-" * 40)

if history:
    # 1-based epoch numbers so the plot matches the "Epoch N" training log.
    epoch_axis = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, history key suffix, y-label, title) for the two panels.
    panels = (
        (axes[0], 'loss', 'Loss', 'Loss'),
        (axes[1], 'acc', 'Accuracy (%)', 'Accuracy'),
    )
    for ax, metric, ylabel, title in panels:
        ax.plot(epoch_axis, history[f'train_{metric}'], label='Train')
        ax.plot(epoch_axis, history[f'test_{metric}'], label='Test')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('cifar10_training.png', dpi=100)
    plt.close()
    print("그래프 저장: cifar10_training.png")


# ============================================
# 7. Per-class accuracy
# ============================================
print("\n[7] 클래스별 분석")
print("-" * 40)

if DATA_AVAILABLE:
    def per_class_accuracy(model, loader, classes):
        """Print ``model``'s accuracy separately for each class in ``classes``.

        Args:
            model: trained network; it is switched to eval mode.
            loader: iterable of ``(data, target)`` batches with integer labels
                in ``range(len(classes))``.
            classes: class names, indexed by label.
        """
        model.eval()
        num_classes = len(classes)
        class_correct = torch.zeros(num_classes, dtype=torch.long)
        class_total = torch.zeros(num_classes, dtype=torch.long)

        with torch.no_grad():
            for data, target in loader:
                pred = model(data.to(device)).argmax(dim=1).cpu()
                target = target.cpu()
                # Count totals/hits per class in one vectorised call instead of
                # a Python loop doing a device sync (.item()) per sample.
                class_total += torch.bincount(target, minlength=num_classes)
                class_correct += torch.bincount(target[pred == target],
                                                minlength=num_classes)

        print("클래스별 정확도:")
        for cls, n_correct, n_total in zip(classes, class_correct.tolist(),
                                           class_total.tolist()):
            if n_total > 0:
                acc = 100 * n_correct / n_total
                print(f"  {cls:12s}: {acc:5.2f}%")

    per_class_accuracy(model, test_loader, classes)


# ============================================
# 8. Prediction visualization
# ============================================
print("\n[8] 예측 시각화")
print("-" * 40)

if DATA_AVAILABLE:
    def visualize_predictions(model, loader, classes, n=8):
        """Save a grid of the first ``n`` images of one batch with predicted vs. true labels.

        Titles are green for correct predictions and red for mistakes. The
        figure is written to ``cifar10_predictions.png``.

        Args:
            model: trained network; it is switched to eval mode.
            loader: iterable of normalised ``(data, target)`` batches; only the
                first batch is used.
            classes: class names, indexed by label.
            n: number of images to show (capped at the batch size).
        """
        model.eval()
        data, target = next(iter(loader))
        # Never request more images than the batch actually holds.
        n = min(n, data.size(0))
        data, target = data[:n].to(device), target[:n]

        with torch.no_grad():
            pred = model(data).argmax(dim=1).cpu()

        cols = 4
        rows = max(1, (n + cols - 1) // cols)
        # squeeze=False keeps `axes` 2-D even when there is a single row.
        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows),
                                 squeeze=False)
        for i, ax in enumerate(axes.flat):
            # Turn the frame off for every cell, including unused ones.
            ax.axis('off')
            if i >= n:
                continue

            img = data[i].cpu().numpy().transpose(1, 2, 0)
            # Undo Normalize(mean, std) so the image displays in [0, 1].
            img = img * np.array(std) + np.array(mean)
            img = np.clip(img, 0, 1)

            ax.imshow(img)
            pred_label, true_label = pred[i].item(), target[i].item()
            color = 'green' if pred_label == true_label else 'red'
            ax.set_title(f"Pred: {classes[pred_label]}\nTrue: {classes[true_label]}",
                         color=color)

        plt.tight_layout()
        plt.savefig('cifar10_predictions.png', dpi=100)
        plt.close()
        print("예측 시각화 저장: cifar10_predictions.png")

    visualize_predictions(model, test_loader, classes)


# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("CIFAR-10 프로젝트 정리")
print("=" * 60)

# NOTE(review): CutMix and label smoothing are listed below as techniques,
# but this script does not implement them.
summary = """
주요 기법:

1. 데이터 증강
   - RandomCrop, HorizontalFlip
   - ColorJitter
   - Mixup/CutMix

2. 모델 구조
   - Conv-BN-ReLU 블록
   - Dropout2d, Dropout
   - ResNet 블록 (Skip Connection)

3. 학습 설정
   - SGD + Momentum + Weight Decay
   - Cosine Annealing LR
   - Label Smoothing

예상 정확도:
   - 기본 CNN: 75-80%
   - + 데이터 증강: 80-85%
   - + Mixup: 85-88%
   - ResNet + 전이학습: 90%+

다음 단계:
   - 더 깊은 모델 (ResNet-50)
   - AutoAugment
   - Knowledge Distillation
"""
print(summary)
print("=" * 60)