"""
05. CNN Basics - PyTorch version

Implements convolutional neural networks (CNNs) with PyTorch.
Performs MNIST and CIFAR-10 classification.
"""

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

print("=" * 60)
print("PyTorch CNN Basics")
print("=" * 60)

# Prefer the GPU when one is available; models/tensors are moved here later.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ============================================
# 1. Understanding the convolution operation
# ============================================
print("\n[1] Understanding the convolution operation")
print("-" * 40)

# Conv2d basics
conv = nn.Conv2d(
    in_channels=1,    # input channels
    out_channels=3,   # number of filters (= output channels)
    kernel_size=3,    # filter size
    stride=1,         # step between filter applications
    padding=1         # zero padding around the border
)

print("Conv2d parameters:")
print(f"  weight shape: {conv.weight.shape}")  # (out, in, H, W)
print(f"  bias shape: {conv.bias.shape}")      # (out,)

# Check input/output shapes: padding=1 with kernel=3 preserves 8x8
x = torch.randn(1, 1, 8, 8)  # (batch, channel, H, W)
out = conv(x)
print(f"\ninput: {x.shape} -> output: {out.shape}")


# Output size calculation
def calc_output_size(input_size, kernel_size, stride=1, padding=0):
    """Return the spatial output size of a conv/pool layer.

    Standard formula (floor division, matching PyTorch's behavior):
        out = (input - kernel + 2 * padding) // stride + 1
    """
    return (input_size - kernel_size + 2 * padding) // stride + 1

print("\nOutput size formula: (input - kernel + 2*padding) / stride + 1")
for k, s, p in [(3, 1, 0), (3, 1, 1), (3, 2, 0), (5, 1, 2)]:
    out_size = calc_output_size(32, k, s, p)
    print(f"  input=32, kernel={k}, stride={s}, pad={p} -> output={out_size}")


# ============================================
# 2. Pooling operations
# ============================================
print("\n[2] Pooling operations")
print("-" * 40)

# MaxPool2d: keeps the maximum of each 2x2 window
pool = nn.MaxPool2d(kernel_size=2, stride=2)

x = torch.tensor([[[[1, 2, 3, 4],
                    [5, 6, 7, 8],
                    [9, 10, 11, 12],
                    [13, 14, 15, 16]]]], dtype=torch.float32)

print(f"input:\n{x.squeeze()}")
print(f"\nMaxPool2d(2,2) output:\n{pool(x).squeeze()}")

# AvgPool2d: averages each 2x2 window
avg_pool = nn.AvgPool2d(2, 2)
print(f"\nAvgPool2d(2,2) output:\n{avg_pool(x).squeeze()}")


# ============================================
# 3. MNIST CNN
# ============================================
print("\n[3] MNIST CNN")
print("-" * 40)

class MNISTNet(nn.Module):
    """A small CNN for 28x28 single-channel MNIST digits.

    Two Conv-BN-ReLU-MaxPool stages (1->32->64 channels, 28->14->7
    spatially) followed by a two-layer classifier head with dropout.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, 28 -> 14 spatial
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Stage 2: 32 -> 64 channels, 14 -> 7 spatial
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Classifier head (dropout between the two linear layers)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        h = self.pool1(F.relu(self.bn1(self.conv1(x))))
        h = self.pool2(F.relu(self.bn2(self.conv2(h))))
        h = h.reshape(-1, 64 * 7 * 7)  # flatten to (batch, 3136)
        h = self.dropout(F.relu(self.fc1(h)))
        return self.fc2(h)

model = MNISTNet()
print(model)

# Total number of learnable parameter elements
total = sum(p.numel() for p in model.parameters())
print(f"\nTotal parameters: {total:,}")


# ============================================
# 4. MNIST training
# ============================================
print("\n[4] MNIST training")
print("-" * 40)

# Standard MNIST normalization (dataset mean / std)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

try:
    train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_data = datasets.MNIST('data', train=False, transform=transform)

    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=1000)

    print(f"train data: {len(train_data)} samples")
    print(f"test data: {len(test_data)} samples")

    # Model, loss, optimizer
    model = MNISTNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    epochs = 3
    train_losses = []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        correct = 0
        total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        acc = 100. * correct / total
        avg_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Acc={acc:.2f}%")

    # Evaluate on the held-out test set (no grad, eval mode for BN/dropout)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    print(f"\ntest accuracy: {100. * correct / total:.2f}%")

except Exception as e:
    # Best-effort: the download fails when offline, so fall back to a smoke test.
    print(f"MNIST load failed (offline?): {e}")
    print("Continuing in demo mode.")

    # Sanity-check the model with dummy data instead
    x_dummy = torch.randn(4, 1, 28, 28)
    model = MNISTNet()
    out = model(x_dummy)
    print(f"dummy input: {x_dummy.shape} -> output: {out.shape}")


# ============================================
# 5. Feature map visualization
# ============================================
print("\n[5] Feature map visualization")
print("-" * 40)

def visualize_feature_maps(model, image, layer_name='conv1'):
    """Save a 4x4 grid of feature maps from one conv layer.

    Runs `image` (assumed to be a (1, C, H, W) tensor — confirm at call
    site) through `model` with a forward hook on `layer_name` to capture
    that layer's output, then writes the first up-to-16 channel
    activations to 'cnn_feature_maps.png'.
    """
    model.eval()

    # Capture the intermediate output with a forward hook
    activations = {}

    def hook_fn(module, input, output):
        activations['output'] = output.detach()

    hook = getattr(model, layer_name).register_forward_hook(hook_fn)

    with torch.no_grad():
        model(image)

    hook.remove()
    feature_maps = activations['output']

    # Plot up to 16 channels in a 4x4 grid
    n_maps = min(16, feature_maps.shape[1])
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))

    for i, ax in enumerate(axes.flat):
        if i < n_maps:
            ax.imshow(feature_maps[0, i].cpu().numpy(), cmap='viridis')
        ax.axis('off')  # also hide frames on unused subplots

    plt.suptitle(f'{layer_name} Feature Maps')
    plt.tight_layout()
    plt.savefig('cnn_feature_maps.png', dpi=100)
    plt.close()
    print("feature maps saved: cnn_feature_maps.png")

# Visualize with the current (possibly untrained) model
try:
    sample_image = train_data[0][0].unsqueeze(0).to(device)
    visualize_feature_maps(model, sample_image, 'conv1')
except Exception as e:  # e.g. train_data never loaded in demo mode
    print(f"Skipping visualization (no data): {e}")


# ============================================
# 6. Filter visualization
# ============================================
print("\n[6] Filter visualization")
print("-" * 40)

def visualize_filters(model, layer_name='conv1'):
    """Save a 4x4 grid of conv filter weights from `layer_name`.

    Shows the first input channel of the first up-to-16 filters and
    writes the grid to 'cnn_filters.png'.
    """
    filters = getattr(model, layer_name).weight.detach().cpu()

    # First 16 filters at most
    n_filters = min(16, filters.shape[0])
    fig, axes = plt.subplots(4, 4, figsize=(8, 8))

    for i, ax in enumerate(axes.flat):
        if i < n_filters:
            # Weights for the first input channel only
            ax.imshow(filters[i, 0].numpy(), cmap='gray')
        ax.axis('off')  # also hide frames on unused subplots

    plt.suptitle(f'{layer_name} Filters')
    plt.tight_layout()
    plt.savefig('cnn_filters.png', dpi=100)
    plt.close()
    print("filters saved: cnn_filters.png")

try:
    visualize_filters(model, 'conv1')
except Exception as e:  # best-effort demo step
    print(f"Skipping filter visualization: {e}")


# ============================================
# 7. CIFAR-10 CNN
# ============================================
print("\n[7] CIFAR-10 CNN")
print("-" * 40)

class CIFAR10Net(nn.Module):
    """A VGG-style CNN for 32x32 RGB CIFAR-10 images.

    Three identical conv stages (3->64->128->256 channels, halving the
    spatial size each time: 32->16->8->4) feed a dropout-regularized
    two-layer classifier that outputs 10 class logits.
    """

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """One stage: two Conv-BN-ReLU layers, 2x2 max-pool, 2D dropout."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),
        ]

    def __init__(self):
        super().__init__()
        # Stage layout: 3->64 (32->16), 64->128 (16->8), 128->256 (8->4)
        self.features = nn.Sequential(
            *self._conv_stage(3, 64),
            *self._conv_stage(64, 128),
            *self._conv_stage(128, 256),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, 10)."""
        return self.classifier(self.features(x))

cifar_model = CIFAR10Net()
print(cifar_model)

# Parameter count
total = sum(p.numel() for p in cifar_model.parameters())
print(f"\nTotal parameters: {total:,}")

# Sanity-check the forward pass with a small random batch
x_test = torch.randn(2, 3, 32, 32)
out = cifar_model(x_test)
print(f"input: {x_test.shape} -> output: {out.shape}")


# ============================================
# 8. Data augmentation
# ============================================
print("\n[8] Data augmentation")
print("-" * 40)

# Training-time augmentation plus CIFAR-10 channel statistics
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616))
])

# Evaluation uses normalization only — no random transforms
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2470, 0.2435, 0.2616))
])

print("train transforms: RandomCrop, Flip, ColorJitter, Normalize")
print("test transforms: ToTensor, Normalize")


# ============================================
# 9. Saving / loading the model
# ============================================
print("\n[9] Saving / loading the model")
print("-" * 40)

# Save only the weights (state_dict), not the pickled module object
torch.save(cifar_model.state_dict(), 'cifar_cnn.pth')
print("model saved: cifar_cnn.pth")

# Load into a freshly constructed network of the same architecture;
# weights_only=True avoids arbitrary-code-execution via pickle.
loaded_model = CIFAR10Net()
loaded_model.load_state_dict(torch.load('cifar_cnn.pth', weights_only=True))
loaded_model.eval()
print("model loaded")


# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("CNN basics summary")
print("=" * 60)

summary = """
CNN building blocks:
1. Conv2d: extracts local patterns
2. BatchNorm2d: stabilizes training
3. ReLU: non-linearity
4. MaxPool2d: spatial downsampling
5. Dropout2d: regularization
6. Flatten + Linear: classification

Output size formula:
  output = (input - kernel + 2*padding) / stride + 1

Common pattern:
  Conv -> BN -> ReLU -> Pool (repeated) -> Flatten -> FC

Recommended settings:
- kernel_size=3, padding=1 (same padding)
- grow channels: 64 -> 128 -> 256
- shrink spatial dims with pooling
- Dropout before the FC classifier
"""
print(summary)
print("=" * 60)