"""05. CNN basics - PyTorch version.

Implements convolutional neural networks (CNNs) in PyTorch and runs
MNIST and CIFAR-10 classification demos, including feature-map and
filter visualization, data augmentation, and model save/load.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
 16print("=" * 60)
 17print("PyTorch CNN ๊ธฐ์ดˆ")
 18print("=" * 60)
 19
 20device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 21print(f"์‚ฌ์šฉ ์žฅ์น˜: {device}")
 22
 23
 24# ============================================
 25# 1. ํ•ฉ์„ฑ๊ณฑ ์—ฐ์‚ฐ ์ดํ•ด
 26# ============================================
 27print("\n[1] ํ•ฉ์„ฑ๊ณฑ ์—ฐ์‚ฐ ์ดํ•ด")
 28print("-" * 40)
 29
 30# Conv2d ๊ธฐ๋ณธ
 31conv = nn.Conv2d(
 32    in_channels=1,    # ์ž…๋ ฅ ์ฑ„๋„
 33    out_channels=3,   # ํ•„ํ„ฐ ๊ฐœ์ˆ˜ (์ถœ๋ ฅ ์ฑ„๋„)
 34    kernel_size=3,    # ํ•„ํ„ฐ ํฌ๊ธฐ
 35    stride=1,         # ์ด๋™ ๊ฐ„๊ฒฉ
 36    padding=1         # ํŒจ๋”ฉ
 37)
 38
 39print(f"Conv2d ํŒŒ๋ผ๋ฏธํ„ฐ:")
 40print(f"  weight shape: {conv.weight.shape}")  # (out, in, H, W)
 41print(f"  bias shape: {conv.bias.shape}")       # (out,)
 42
 43# ์ž…๋ ฅ/์ถœ๋ ฅ ํ™•์ธ
 44x = torch.randn(1, 1, 8, 8)  # (batch, channel, H, W)
 45out = conv(x)
 46print(f"\n์ž…๋ ฅ: {x.shape} โ†’ ์ถœ๋ ฅ: {out.shape}")
 47
 48
 49# ์ถœ๋ ฅ ํฌ๊ธฐ ๊ณ„์‚ฐ
 50def calc_output_size(input_size, kernel_size, stride=1, padding=0):
 51    return (input_size - kernel_size + 2 * padding) // stride + 1
 52
 53print("\n์ถœ๋ ฅ ํฌ๊ธฐ ๊ณต์‹: (์ž…๋ ฅ - ์ปค๋„ + 2ร—ํŒจ๋”ฉ) / ์ŠคํŠธ๋ผ์ด๋“œ + 1")
 54for k, s, p in [(3, 1, 0), (3, 1, 1), (3, 2, 0), (5, 1, 2)]:
 55    out_size = calc_output_size(32, k, s, p)
 56    print(f"  ์ž…๋ ฅ=32, kernel={k}, stride={s}, pad={p} โ†’ ์ถœ๋ ฅ={out_size}")
 57
 58
 59# ============================================
 60# 2. ํ’€๋ง ์—ฐ์‚ฐ
 61# ============================================
 62print("\n[2] ํ’€๋ง ์—ฐ์‚ฐ")
 63print("-" * 40)
 64
 65# MaxPool2d
 66pool = nn.MaxPool2d(kernel_size=2, stride=2)
 67
 68x = torch.tensor([[[[1, 2, 3, 4],
 69                    [5, 6, 7, 8],
 70                    [9, 10, 11, 12],
 71                    [13, 14, 15, 16]]]], dtype=torch.float32)
 72
 73print(f"์ž…๋ ฅ:\n{x.squeeze()}")
 74print(f"\nMaxPool2d(2,2) ์ถœ๋ ฅ:\n{pool(x).squeeze()}")
 75
 76# AvgPool2d
 77avg_pool = nn.AvgPool2d(2, 2)
 78print(f"\nAvgPool2d(2,2) ์ถœ๋ ฅ:\n{avg_pool(x).squeeze()}")
 79
 80
 81# ============================================
 82# 3. MNIST CNN
 83# ============================================
 84print("\n[3] MNIST CNN")
 85print("-" * 40)
 86
 87class MNISTNet(nn.Module):
 88    """MNIST์šฉ ๊ฐ„๋‹จํ•œ CNN"""
 89    def __init__(self):
 90        super().__init__()
 91        # Conv ๋ธ”๋ก 1: 1โ†’32 ์ฑ„๋„, 28โ†’14
 92        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
 93        self.bn1 = nn.BatchNorm2d(32)
 94        self.pool1 = nn.MaxPool2d(2, 2)
 95
 96        # Conv ๋ธ”๋ก 2: 32โ†’64 ์ฑ„๋„, 14โ†’7
 97        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
 98        self.bn2 = nn.BatchNorm2d(64)
 99        self.pool2 = nn.MaxPool2d(2, 2)
100
101        # FC ๋ธ”๋ก
102        self.fc1 = nn.Linear(64 * 7 * 7, 128)
103        self.dropout = nn.Dropout(0.5)
104        self.fc2 = nn.Linear(128, 10)
105
106    def forward(self, x):
107        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
108        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
109        x = x.view(-1, 64 * 7 * 7)
110        x = F.relu(self.fc1(x))
111        x = self.dropout(x)
112        x = self.fc2(x)
113        return x
114
model = MNISTNet()
print(model)

# Total parameter count across all layers.
param_sizes = [p.numel() for p in model.parameters()]
total = sum(param_sizes)
print(f"\n์ด ํŒŒ๋ผ๋ฏธํ„ฐ: {total:,}")
# ============================================
# 4. MNIST training
# ============================================
print("\n[4] MNIST ํ•™์Šต")
print("-" * 40)

# Data pipeline: Normalize uses MNIST's conventional global mean/std
# (0.1307, 0.3081) — TODO confirm against the dataset docs.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

try:
    train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST('data', train=False, transform=transform)

    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=1000)

    print(f"ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ: {len(train_data)} ์ƒ˜ํ”Œ")
    print(f"ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ: {len(test_data)} ์ƒ˜ํ”Œ")

    # Model, loss, optimizer
    model = MNISTNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    epochs = 3
    train_losses = []

    for epoch in range(epochs):
        model.train()  # enable dropout / batch-norm training behavior
        epoch_loss = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Standard step: clear grads, backprop, update weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate running loss and accuracy statistics.
            epoch_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        avg_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={acc:.2f}%")

    # Evaluation on the held-out test set
    model.eval()  # disable dropout; batch norm uses running statistics
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    print(f"\nํ…Œ์ŠคํŠธ ์ •ํ™•๋„: {100. * correct / total:.2f}%")

except Exception as e:
    # NOTE(review): deliberately broad best-effort fallback for offline runs;
    # any failure (typically a download error) drops to a dummy forward pass.
    print(f"MNIST ๋กœ๋“œ ์‹คํŒจ (์˜คํ”„๋ผ์ธ?): {e}")
    print("๋ฐ๋ชจ ๋ชจ๋“œ๋กœ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.")

    # Smoke test with dummy data (shape check only; model stays in train mode).
    x_dummy = torch.randn(4, 1, 28, 28)
    model = MNISTNet()
    out = model(x_dummy)
    print(f"๋”๋ฏธ ์ž…๋ ฅ: {x_dummy.shape} โ†’ ์ถœ๋ ฅ: {out.shape}")
# ============================================
# 5. Feature-map visualization
# ============================================
print("\n[5] ํŠน์ง• ๋งต ์‹œ๊ฐํ™”", "-" * 40, sep="\n")
def visualize_feature_maps(model, image, layer_name='conv1'):
    """Capture one conv layer's activations for *image* and save a 4x4 grid.

    A forward hook on ``layer_name`` records the layer output during a
    single no-grad forward pass; up to 16 maps are written to
    ``cnn_feature_maps.png``.
    """
    model.eval()

    # Capture the intermediate output via a forward hook.
    captured = {}

    def _grab(module, inputs, output):
        # Detach so plotting never holds onto the autograd graph.
        captured['output'] = output.detach()

    handle = getattr(model, layer_name).register_forward_hook(_grab)

    with torch.no_grad():
        model(image)

    handle.remove()
    feature_maps = captured['output']

    # Plot the first (up to) 16 channels of the first sample.
    n_maps = min(16, feature_maps.shape[1])
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))

    for idx, ax in enumerate(axes.flat):
        if idx < n_maps:
            ax.imshow(feature_maps[0, idx].cpu().numpy(), cmap='viridis')
        ax.axis('off')

    plt.suptitle(f'{layer_name} Feature Maps')
    plt.tight_layout()
    plt.savefig('cnn_feature_maps.png', dpi=100)
    plt.close()
    print("ํŠน์ง• ๋งต ์ €์žฅ: cnn_feature_maps.png")
243# ์‹œ๊ฐํ™” (ํ•™์Šต๋œ ๋ชจ๋ธ์ด ์žˆ๋Š” ๊ฒฝ์šฐ)
244try:
245    sample_image = train_data[0][0].unsqueeze(0).to(device)
246    visualize_feature_maps(model, sample_image, 'conv1')
247except:
248    print("์‹œ๊ฐํ™” ์Šคํ‚ต (๋ฐ์ดํ„ฐ ์—†์Œ)")
249
250
# ============================================
# 6. Filter visualization
# ============================================
print("\n[6] ํ•„ํ„ฐ ์‹œ๊ฐํ™”", "-" * 40, sep="\n")
def visualize_filters(model, layer_name='conv1'):
    """Save the first input channel of up to 16 conv filters as a 4x4 grid."""
    weights = getattr(model, layer_name).weight.detach().cpu()

    # Plot at most the first 16 filters.
    n_filters = min(16, weights.shape[0])
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))

    for idx, ax in enumerate(axes.flat):
        if idx < n_filters:
            # Show only the filter's first input channel.
            ax.imshow(weights[idx, 0].numpy(), cmap='gray')
        ax.axis('off')

    plt.suptitle(f'{layer_name} Filters')
    plt.tight_layout()
    plt.savefig('cnn_filters.png', dpi=100)
    plt.close()
    print("ํ•„ํ„ฐ ์ €์žฅ: cnn_filters.png")
try:
    visualize_filters(model, 'conv1')
except Exception:
    # Narrowed from a bare `except:` so Ctrl-C still interrupts the script.
    print("ํ•„ํ„ฐ ์‹œ๊ฐํ™” ์Šคํ‚ต")
# ============================================
# 7. CIFAR-10 CNN
# ============================================
print("\n[7] CIFAR-10 CNN", "-" * 40, sep="\n")
class CIFAR10Net(nn.Module):
    """VGG-style CNN for CIFAR-10: three double-conv blocks + 2-layer head.

    Input:  (N, 3, 32, 32) RGB images.
    Output: (N, 10) class logits.
    """

    def __init__(self):
        super().__init__()
        # Build the feature extractor as three repeated double-conv blocks:
        # 3->64 (32->16), 64->128 (16->8), 128->256 (8->4).
        # NOTE: the module order/count is identical to the hand-written
        # version, so state_dict keys (features.0, features.1, ...) match.
        layers = []
        for in_ch, out_ch in [(3, 64), (64, 128), (128, 256)]:
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),   # halves the spatial size each block
                nn.Dropout2d(0.25),
            ]
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
cifar_model = CIFAR10Net()
print(cifar_model)

# Total parameter count.
total = sum(p.numel() for p in cifar_model.parameters())
print(f"\n์ด ํŒŒ๋ผ๋ฏธํ„ฐ: {total:,}")

# Smoke test: a random batch of two 32x32 RGB images.
x_test = torch.randn(2, 3, 32, 32)
out = cifar_model(x_test)
print(f"์ž…๋ ฅ: {x_test.shape} โ†’ ์ถœ๋ ฅ: {out.shape}")
# ============================================
# 8. Data augmentation
# ============================================
print("\n[8] ๋ฐ์ดํ„ฐ ์ฆ๊ฐ•", "-" * 40, sep="\n")

# CIFAR-10 per-channel mean/std shared by both pipelines.
_cifar_mean = (0.4914, 0.4822, 0.4465)
_cifar_std = (0.2470, 0.2435, 0.2616)

# Training pipeline: random crop/flip/jitter for regularization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])

# Test pipeline: deterministic — only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])

print("ํ›ˆ๋ จ ๋ณ€ํ™˜: RandomCrop, Flip, ColorJitter, Normalize")
print("ํ…Œ์ŠคํŠธ ๋ณ€ํ™˜: ToTensor, Normalize")
# ============================================
# 9. Model save/load
# ============================================
print("\n[9] ๋ชจ๋ธ ์ €์žฅ/๋กœ๋“œ", "-" * 40, sep="\n")

# Persist only the weights (state_dict), not the pickled module object.
checkpoint_path = 'cifar_cnn.pth'
torch.save(cifar_model.state_dict(), checkpoint_path)
print("๋ชจ๋ธ ์ €์žฅ: cifar_cnn.pth")

# Reload into a fresh instance; weights_only=True avoids arbitrary unpickling.
loaded_model = CIFAR10Net()
state = torch.load(checkpoint_path, weights_only=True)
loaded_model.load_state_dict(state)
loaded_model.eval()
print("๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ")
# ============================================
# Recap
# ============================================
closing_banner = "=" * 60
print("\n" + closing_banner)
print("CNN ๊ธฐ์ดˆ ์ •๋ฆฌ")
print(closing_banner)

summary = """
CNN ๊ตฌ์„ฑ์š”์†Œ:
1. Conv2d: ์ง€์—ญ ํŒจํ„ด ์ถ”์ถœ
2. BatchNorm2d: ํ•™์Šต ์•ˆ์ •ํ™”
3. ReLU: ๋น„์„ ํ˜•์„ฑ
4. MaxPool2d: ๊ณต๊ฐ„ ์ถ•์†Œ
5. Dropout2d: ๊ณผ์ ํ•ฉ ๋ฐฉ์ง€
6. Flatten + Linear: ๋ถ„๋ฅ˜

์ถœ๋ ฅ ํฌ๊ธฐ ๊ณต์‹:
    output = (input - kernel + 2*padding) / stride + 1

์ผ๋ฐ˜์ ์ธ ํŒจํ„ด:
    Conv โ†’ BN โ†’ ReLU โ†’ Pool (๋ฐ˜๋ณต) โ†’ Flatten โ†’ FC

๊ถŒ์žฅ ์„ค์ •:
- kernel_size=3, padding=1 (same padding)
- ์ฑ„๋„ ์ฆ๊ฐ€: 64 โ†’ 128 โ†’ 256
- Pool๋กœ ๊ณต๊ฐ„ ์ถ•์†Œ
- FC ์•ž์— Dropout
"""
print(summary)
print(closing_banner)