1"""
2PyTorch Low-Level LeNet-5 ꡬν
3
4nn.Conv2d, nn.Linear λμ F.conv2d, torch.matmul μ¬μ©
5νλΌλ―Έν°λ₯Ό μλμΌλ‘ κ΄λ¦¬
6"""
7
8import torch
9import torch.nn.functional as F
10import math
11from typing import Tuple, List
12
13
class LeNetLowLevel:
    """
    Low-level LeNet-5 implementation.

    Deliberately avoids nn.Module: every layer is expressed with primitive
    ops (F.conv2d, F.avg_pool2d, torch.matmul) and the parameters are
    created and managed by hand.
    """

    def __init__(self, num_classes: int = 10):
        # Prefer GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Conv1: 1 -> 6 channels, 5x5 kernel
        self.conv1_weight = self._init_conv_weight(1, 6, 5)
        self.conv1_bias = torch.zeros(6, requires_grad=True, device=self.device)

        # Conv2: 6 -> 16 channels, 5x5 kernel
        self.conv2_weight = self._init_conv_weight(6, 16, 5)
        self.conv2_bias = torch.zeros(16, requires_grad=True, device=self.device)

        # Conv3: 16 -> 120 channels, 5x5 kernel
        self.conv3_weight = self._init_conv_weight(16, 120, 5)
        self.conv3_bias = torch.zeros(120, requires_grad=True, device=self.device)

        # FC1: 120 -> 84
        self.fc1_weight = self._init_linear_weight(120, 84)
        self.fc1_bias = torch.zeros(84, requires_grad=True, device=self.device)

        # FC2: 84 -> num_classes
        self.fc2_weight = self._init_linear_weight(84, num_classes)
        self.fc2_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)

    def _init_conv_weight(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int
    ) -> torch.Tensor:
        """Kaiming (He) initialization for a conv weight.

        BUGFIX: the returned weight must be a *leaf* tensor.  The previous
        code did `torch.randn(..., requires_grad=True) * std`, whose result
        (the product) is a non-leaf tensor; autograd does not populate
        `.grad` on non-leaf tensors, so the weights were silently never
        updated by the manual SGD step.  Scale first, then enable grad.
        """
        fan_in = in_channels * kernel_size * kernel_size
        std = math.sqrt(2.0 / fan_in)
        weight = torch.randn(
            out_channels, in_channels, kernel_size, kernel_size,
            device=self.device
        ) * std
        weight.requires_grad_(True)
        return weight

    def _init_linear_weight(
        self,
        in_features: int,
        out_features: int
    ) -> torch.Tensor:
        """Xavier (Glorot) initialization for a linear weight.

        Same leaf-tensor fix as _init_conv_weight: scale before enabling
        requires_grad so the returned tensor accumulates gradients.
        """
        std = math.sqrt(2.0 / (in_features + out_features))
        weight = torch.randn(
            out_features, in_features,
            device=self.device
        ) * std
        weight.requires_grad_(True)
        return weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: (N, 1, 32, 32) input images

        Returns:
            logits: (N, num_classes) — raw scores, no softmax applied
        """
        # Layer 1: Conv -> ReLU -> AvgPool
        # (N, 1, 32, 32) -> (N, 6, 28, 28) -> (N, 6, 14, 14)
        x = F.conv2d(x, self.conv1_weight, self.conv1_bias, stride=1, padding=0)
        x = F.relu(x)
        x = F.avg_pool2d(x, kernel_size=2, stride=2)

        # Layer 2: Conv -> ReLU -> AvgPool
        # (N, 6, 14, 14) -> (N, 16, 10, 10) -> (N, 16, 5, 5)
        x = F.conv2d(x, self.conv2_weight, self.conv2_bias, stride=1, padding=0)
        x = F.relu(x)
        x = F.avg_pool2d(x, kernel_size=2, stride=2)

        # Layer 3: Conv -> ReLU
        # (N, 16, 5, 5) -> (N, 120, 1, 1)
        x = F.conv2d(x, self.conv3_weight, self.conv3_bias, stride=1, padding=0)
        x = F.relu(x)

        # Flatten: (N, 120, 1, 1) -> (N, 120)
        x = x.view(x.size(0), -1)

        # FC1: (N, 120) -> (N, 84)
        x = torch.matmul(x, self.fc1_weight.t()) + self.fc1_bias
        x = F.relu(x)

        # FC2: (N, 84) -> (N, num_classes)
        x = torch.matmul(x, self.fc2_weight.t()) + self.fc2_bias

        return x

    def parameters(self) -> List[torch.Tensor]:
        """Return the list of all trainable parameters."""
        return [
            self.conv1_weight, self.conv1_bias,
            self.conv2_weight, self.conv2_bias,
            self.conv3_weight, self.conv3_bias,
            self.fc1_weight, self.fc1_bias,
            self.fc2_weight, self.fc2_bias,
        ]

    def zero_grad(self):
        """Reset accumulated gradients to zero, in place."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def to(self, device):
        """Move all parameters (and any existing grads) to `device`; returns self."""
        self.device = device
        for param in self.parameters():
            # .data swap keeps the same leaf tensor object while relocating storage.
            param.data = param.data.to(device)
            if param.grad is not None:
                param.grad = param.grad.to(device)
        return self
136
137
def sgd_step(params: List[torch.Tensor], lr: float):
    """Plain manual SGD update: p <- p - lr * p.grad.

    Parameters with no accumulated gradient are left untouched.
    """
    with torch.no_grad():
        for p in params:
            grad = p.grad
            if grad is None:
                continue
            # In-place: equivalent to p -= lr * grad.
            p.sub_(grad, alpha=lr)
144
145
def sgd_step_with_momentum(
    params: List[torch.Tensor],
    velocities: List[torch.Tensor],
    lr: float,
    momentum: float = 0.9
):
    """SGD with classical momentum.

    v <- momentum * v + grad ; p <- p - lr * v

    Velocity buffers are updated in place; parameters with no
    accumulated gradient are skipped.
    """
    with torch.no_grad():
        for p, v in zip(params, velocities):
            if p.grad is None:
                continue
            v.mul_(momentum)
            v.add_(p.grad)
            p.sub_(v, alpha=lr)
158
159
def train_epoch(
    model: LeNetLowLevel,
    dataloader,
    lr: float = 0.01
) -> Tuple[float, float]:
    """Train for one epoch; returns (mean loss, accuracy) over the epoch."""
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in dataloader:
        images, labels = images.to(model.device), labels.to(model.device)

        # Forward
        logits = model.forward(images)

        # Cross-entropy computed explicitly as log_softmax + NLL.
        loss = F.nll_loss(F.log_softmax(logits, dim=1), labels)

        # Backward
        model.zero_grad()
        loss.backward()

        # Manual parameter update
        sgd_step(model.parameters(), lr)

        # Accumulate metrics, weighting loss by batch size.
        batch = images.size(0)
        running_loss += loss.item() * batch
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch

    return running_loss / n_seen, n_correct / n_seen
199
200
@torch.no_grad()
def evaluate(
    model: LeNetLowLevel,
    dataloader
) -> Tuple[float, float]:
    """Evaluate the model; returns (mean loss, accuracy). No grads tracked."""
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in dataloader:
        images, labels = images.to(model.device), labels.to(model.device)

        # Forward + loss
        logits = model.forward(images)
        loss = F.cross_entropy(logits, labels)

        # Accumulate metrics, weighting loss by batch size.
        batch = images.size(0)
        running_loss += loss.item() * batch
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch

    return running_loss / n_seen, n_correct / n_seen
231
232
def main():
    """Training script: MNIST (resized to 32x32) with the low-level LeNet-5."""
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    print("=== LeNet-5 Low-Level Training ===\n")

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Dataset: MNIST resized to 32x32, the input size LeNet-5 expects.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(
        root='./data', train=False, download=True, transform=transform
    )

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}\n")

    # Model
    model = LeNetLowLevel(num_classes=10)
    model.to(device)

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}\n")

    # Training loop
    epochs = 10
    lr = 0.01

    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, lr)
        test_loss, test_acc = evaluate(model, test_loader)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"  Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

        # Halve the learning rate every 5 epochs.
        if (epoch + 1) % 5 == 0:
            lr *= 0.5
            # BUGFIX: the arrow here was a mojibake character; use ASCII.
            print(f"  LR -> {lr}")

        print()

    print("Training complete!")

    # Final result
    final_loss, final_acc = evaluate(model, test_loader)
    print(f"\nFinal Test Accuracy: {final_acc:.4f}")
296
297
# Convolution operation visualization
def visualize_conv_operation():
    """Visualize simple 3x3 filters on a toy 5x5 input and save the figure."""
    import matplotlib.pyplot as plt

    # Toy input: a 3x3 bright square centered in a 5x5 dark image.
    input_img = torch.zeros(1, 1, 5, 5)
    input_img[0, 0, 1:4, 1:4] = 1.0

    # Edge-detection (and identity) kernels, shaped (out_ch, in_ch, kH, kW).
    filters = {
        'Horizontal': torch.tensor(
            [[-1.0, -1.0, -1.0],
             [ 0.0,  0.0,  0.0],
             [ 1.0,  1.0,  1.0]]
        ).view(1, 1, 3, 3),
        'Vertical': torch.tensor(
            [[-1.0, 0.0, 1.0],
             [-1.0, 0.0, 1.0],
             [-1.0, 0.0, 1.0]]
        ).view(1, 1, 3, 3),
        'Identity': torch.tensor(
            [[0.0, 0.0, 0.0],
             [0.0, 1.0, 0.0],
             [0.0, 0.0, 0.0]]
        ).view(1, 1, 3, 3),
    }

    fig, axes = plt.subplots(2, len(filters) + 1, figsize=(12, 6))

    # Leftmost column: the input image; the cell below it stays blank.
    axes[0, 0].imshow(input_img[0, 0], cmap='gray')
    axes[0, 0].set_title('Input')
    axes[0, 0].axis('off')
    axes[1, 0].axis('off')

    # One column per filter: the kernel on top, its convolved output below.
    for col, (name, kernel) in enumerate(filters.items(), start=1):
        output = F.conv2d(input_img, kernel, padding=1)

        axes[0, col].imshow(kernel[0, 0], cmap='RdBu', vmin=-1, vmax=1)
        axes[0, col].set_title(f'{name} Filter')
        axes[0, col].axis('off')

        axes[1, col].imshow(output[0, 0].detach(), cmap='gray')
        axes[1, col].set_title('Output')
        axes[1, col].axis('off')

    plt.tight_layout()
    plt.savefig('conv_visualization.png', dpi=150)
    print("Saved conv_visualization.png")
354
355
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()