1"""
2PyTorch Low-Level ResNet ๊ตฌํ
3
4nn.Conv2d, nn.BatchNorm2d ๋์ F.conv2d, ์๋ BN ์ฌ์ฉ
5BasicBlock๊ณผ Bottleneck ๋ชจ๋ ๊ตฌํ
6"""
7
8import torch
9import torch.nn.functional as F
10import math
11from typing import Tuple, List, Dict, Optional, Literal
12
13
class BatchNorm2dManual:
    """Manual 2-D batch normalization (hand-rolled nn.BatchNorm2d).

    Normalizes each channel over the (N, H, W) dimensions, applies a
    learnable affine transform (gamma, beta), and tracks running
    statistics for inference.
    """

    def __init__(self, num_features: int, device: torch.device):
        # Learnable affine parameters (leaf tensors, trained by gradient).
        self.gamma = torch.ones(num_features, requires_grad=True, device=device)
        self.beta = torch.zeros(num_features, requires_grad=True, device=device)
        # Running statistics used in eval mode; updated without gradients.
        self.running_mean = torch.zeros(num_features, device=device)
        self.running_var = torch.ones(num_features, device=device)
        self.momentum = 0.1
        self.eps = 1e-5

    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """Normalize x of shape (N, C, H, W); returns same-shape tensor.

        In training mode batch statistics are used and the running
        statistics are updated; in eval mode the running statistics are
        used directly.
        """
        if training:
            # Biased variance is used for normalization itself,
            # matching nn.BatchNorm2d.
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)

            with torch.no_grad():
                # BUG FIX: running_var must track the UNBIASED variance
                # estimate (as nn.BatchNorm2d does); the original stored
                # the biased one, skewing eval-mode normalization.
                var_unbiased = x.var(dim=(0, 2, 3), unbiased=True)
                self.running_mean = (
                    (1 - self.momentum) * self.running_mean +
                    self.momentum * mean.reshape(-1)
                )
                self.running_var = (
                    (1 - self.momentum) * self.running_var +
                    self.momentum * var_unbiased
                )
        else:
            mean = self.running_mean.view(1, -1, 1, 1)
            var = self.running_var.view(1, -1, 1, 1)

        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        gamma = self.gamma.view(1, -1, 1, 1)
        beta = self.beta.view(1, -1, 1, 1)

        return gamma * x_norm + beta

    def parameters(self) -> List[torch.Tensor]:
        """Return the learnable parameters (gamma, beta)."""
        return [self.gamma, self.beta]
51
52
class ConvBN:
    """Conv2d (bias-free) followed by manual batch normalization.

    The convolution carries no bias: the following BN's beta parameter
    makes a conv bias redundant.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        device: torch.device = None
    ):
        self.stride = stride
        self.padding = padding

        # Kaiming (He) initialization suited to ReLU networks.
        fan_in = in_channels * kernel_size * kernel_size
        std = math.sqrt(2.0 / fan_in)

        # BUG FIX: the original `torch.randn(..., requires_grad=True) * std`
        # produced a NON-LEAF tensor, so .grad was never populated and the
        # weight was silently excluded from optimization.  Scale first,
        # then promote the result to a leaf parameter.
        weight = torch.randn(
            out_channels, in_channels, kernel_size, kernel_size,
            device=device
        ) * std
        self.weight = weight.requires_grad_(True)

        self.bn = BatchNorm2dManual(out_channels, device)

    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """Apply convolution then batch norm; no activation."""
        x = F.conv2d(x, self.weight, None, self.stride, self.padding)
        x = self.bn(x, training)
        return x

    def parameters(self) -> List[torch.Tensor]:
        """Conv weight plus the BN affine parameters."""
        return [self.weight] + self.bn.parameters()
87
88
class BasicBlock:
    """ResNet BasicBlock (used by ResNet-18 and ResNet-34).

    Layout: Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN -> (+shortcut) -> ReLU.
    """
    expansion = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        device: torch.device = None
    ):
        self.conv1 = ConvBN(in_channels, out_channels, 3, stride, 1, device)
        self.conv2 = ConvBN(out_channels, out_channels, 3, 1, 1, device)

        # A 1x1 projection shortcut is needed only when the spatial size
        # or channel count changes; otherwise the identity path is used.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            ConvBN(in_channels, out_channels, 1, stride, 0, device)
            if needs_projection else None
        )

    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """Run the residual block and return the activated sum."""
        residual = x if self.shortcut is None else self.shortcut(x, training)

        out = F.relu(self.conv1(x, training))
        out = self.conv2(out, training)

        # Skip connection, then the final activation.
        return F.relu(out + residual)

    def parameters(self) -> List[torch.Tensor]:
        """All learnable tensors of this block."""
        collected = self.conv1.parameters() + self.conv2.parameters()
        if self.shortcut is not None:
            collected += self.shortcut.parameters()
        return collected
132
133
class Bottleneck:
    """ResNet Bottleneck block (used by ResNet-50/101/152).

    Layout: Conv1x1 -> BN -> ReLU -> Conv3x3 -> BN -> ReLU ->
    Conv1x1 -> BN -> (+shortcut) -> ReLU.
    """
    expansion = 4

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        device: torch.device = None
    ):
        expanded = out_channels * self.expansion
        # 1x1 channel reduction.
        self.conv1 = ConvBN(in_channels, out_channels, 1, 1, 0, device)
        # 3x3 spatial convolution; carries the stride.
        self.conv2 = ConvBN(out_channels, out_channels, 3, stride, 1, device)
        # 1x1 channel expansion back to expanded width.
        self.conv3 = ConvBN(out_channels, expanded, 1, 1, 0, device)

        # Projection shortcut only when shape or width changes.
        needs_projection = stride != 1 or in_channels != expanded
        self.shortcut = (
            ConvBN(in_channels, expanded, 1, stride, 0, device)
            if needs_projection else None
        )

    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """Run the bottleneck and return the activated residual sum."""
        residual = x if self.shortcut is None else self.shortcut(x, training)

        out = F.relu(self.conv1(x, training))
        out = F.relu(self.conv2(out, training))
        out = self.conv3(out, training)

        # Skip connection, then the final activation.
        return F.relu(out + residual)

    def parameters(self) -> List[torch.Tensor]:
        """All learnable tensors of this block."""
        collected = (
            self.conv1.parameters()
            + self.conv2.parameters()
            + self.conv3.parameters()
        )
        if self.shortcut is not None:
            collected += self.shortcut.parameters()
        return collected
185
186
class ResNetLowLevel:
    """
    Low-level ResNet implementation.

    No nn.Module: parameters are plain tensors managed by hand.
    """

    # (block class, blocks per stage) for each standard ResNet depth.
    CONFIGS = {
        'resnet18': (BasicBlock, [2, 2, 2, 2]),
        'resnet34': (BasicBlock, [3, 4, 6, 3]),
        'resnet50': (Bottleneck, [3, 4, 6, 3]),
        'resnet101': (Bottleneck, [3, 4, 23, 3]),
        'resnet152': (Bottleneck, [3, 8, 36, 3]),
    }

    def __init__(
        self,
        config_name: str = 'resnet50',
        num_classes: int = 1000,
        input_channels: int = 3
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        block_class, num_blocks = self.CONFIGS[config_name]
        self.expansion = block_class.expansion

        # Stem: Conv7x7 + BN + ReLU + MaxPool (pool applied in forward()).
        fan_in = input_channels * 7 * 7
        std = math.sqrt(2.0 / fan_in)
        # BUG FIX: `torch.randn(..., requires_grad=True) * std` yields a
        # NON-LEAF tensor whose .grad is never populated, silently freezing
        # the stem.  Scale first, then mark the result as a leaf parameter.
        conv1 = torch.randn(64, input_channels, 7, 7, device=self.device) * std
        self.conv1_weight = conv1.requires_grad_(True)
        self.bn1 = BatchNorm2dManual(64, self.device)

        # Residual stages.
        self.in_channels = 64
        self.layer1 = self._make_layer(block_class, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block_class, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_class, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_class, 512, num_blocks[3], stride=2)

        # Classifier (same non-leaf fix as the stem weight).
        fc_in = 512 * self.expansion
        std = math.sqrt(2.0 / (fc_in + num_classes))
        fc = torch.randn(num_classes, fc_in, device=self.device) * std
        self.fc_weight = fc.requires_grad_(True)
        self.fc_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)

    def _make_layer(
        self,
        block_class,
        out_channels: int,
        num_blocks: int,
        stride: int
    ) -> List:
        """Build one stage: first block may downsample/widen, the rest keep stride 1."""
        blocks = [block_class(self.in_channels, out_channels, stride, self.device)]
        self.in_channels = out_channels * self.expansion

        for _ in range(1, num_blocks):
            blocks.append(block_class(self.in_channels, out_channels, 1, self.device))

        return blocks

    def forward(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: (N, C, H, W) input images
            training: use batch statistics and update running stats when True

        Returns:
            logits: (N, num_classes)
        """
        # Stem
        x = F.conv2d(x, self.conv1_weight, None, stride=2, padding=3)
        x = self.bn1(x, training)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        # Residual stages
        for block in self.layer1:
            x = block(x, training)
        for block in self.layer2:
            x = block(x, training)
        for block in self.layer3:
            x = block(x, training)
        for block in self.layer4:
            x = block(x, training)

        # Global average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)

        # Classifier
        x = torch.matmul(x, self.fc_weight.t()) + self.fc_bias

        return x

    def parameters(self) -> List[torch.Tensor]:
        """Return all trainable parameters (leaf tensors with requires_grad)."""
        params = [self.conv1_weight] + self.bn1.parameters()

        for layer in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for block in layer:
                params += block.parameters()

        params += [self.fc_weight, self.fc_bias]
        return params

    def zero_grad(self):
        """Zero all accumulated gradients in place."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def _bn_modules(self):
        """Yield every BatchNorm2dManual in the network (stem + all blocks)."""
        yield self.bn1
        for layer in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for block in layer:
                convs = [block.conv1, block.conv2]
                conv3 = getattr(block, 'conv3', None)  # Bottleneck only
                if conv3 is not None:
                    convs.append(conv3)
                if block.shortcut is not None:
                    convs.append(block.shortcut)
                for conv in convs:
                    yield conv.bn

    def to(self, device):
        """Move the whole model to `device` and return self.

        BUG FIX: the original moved only the stem and FC tensors, leaving
        every residual block on the construction device.  Rewriting .data
        in place moves all parameters while preserving tensor identity
        (so existing references such as momentum buffers keyed by
        parameter stay valid); BN buffers are moved explicitly.
        """
        self.device = device

        for param in self.parameters():
            param.data = param.data.to(device)

        for bn in self._bn_modules():
            bn.running_mean = bn.running_mean.to(device)
            bn.running_var = bn.running_var.to(device)

        return self

    def count_parameters(self) -> int:
        """Total number of trainable scalars."""
        return sum(p.numel() for p in self.parameters())
335
336
class ResNetSmall(ResNetLowLevel):
    """CIFAR-10 variant: 3x3 stem and no max-pool (inputs are only 32x32)."""

    def __init__(
        self,
        config_name: str = 'resnet18',
        num_classes: int = 10,
        input_channels: int = 3
    ):
        # Intentionally does NOT call super().__init__(): the stem differs
        # (3x3 conv, stride 1, no max-pool), so everything is built here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        block_class, num_blocks = self.CONFIGS[config_name]
        self.expansion = block_class.expansion

        # Stem: 3x3 conv (instead of 7x7), Kaiming init.
        fan_in = input_channels * 3 * 3
        std = math.sqrt(2.0 / fan_in)
        # BUG FIX: scaling a requires_grad tensor created a NON-LEAF whose
        # .grad stayed None, so the stem never trained.  Scale first, then
        # promote to a leaf parameter.
        conv1 = torch.randn(64, input_channels, 3, 3, device=self.device) * std
        self.conv1_weight = conv1.requires_grad_(True)
        self.bn1 = BatchNorm2dManual(64, self.device)

        # No max-pool: 32x32 inputs are already small.

        # Residual stages.
        self.in_channels = 64
        self.layer1 = self._make_layer(block_class, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block_class, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_class, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_class, 512, num_blocks[3], stride=2)

        # Classifier (same leaf fix as the stem weight).
        fc_in = 512 * self.expansion
        std = math.sqrt(2.0 / (fc_in + num_classes))
        fc = torch.randn(num_classes, fc_in, device=self.device) * std
        self.fc_weight = fc.requires_grad_(True)
        self.fc_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)

    def forward(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """Forward pass for 32x32 inputs; returns (N, num_classes) logits."""
        # Stem (no max-pool)
        x = F.conv2d(x, self.conv1_weight, None, stride=1, padding=1)
        x = self.bn1(x, training)
        x = F.relu(x)

        # Residual stages
        for block in self.layer1:
            x = block(x, training)
        for block in self.layer2:
            x = block(x, training)
        for block in self.layer3:
            x = block(x, training)
        for block in self.layer4:
            x = block(x, training)

        # Global average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)

        # Classifier
        x = torch.matmul(x, self.fc_weight.t()) + self.fc_bias

        return x
403
404
def sgd_step_with_momentum(
    params: List[torch.Tensor],
    velocities: List[torch.Tensor],
    lr: float,
    momentum: float = 0.9,
    weight_decay: float = 1e-4
):
    """SGD update with classical momentum and L2 weight decay.

    For each parameter p with gradient g:
        g <- g + weight_decay * p    (in place, on p.grad)
        v <- momentum * v + g
        p <- p - lr * v
    Parameters without a gradient are skipped.
    """
    with torch.no_grad():
        for p, v in zip(params, velocities):
            if p.grad is None:
                continue
            # L2 regularization folded into the gradient.
            p.grad.add_(p, alpha=weight_decay)
            v.mul_(momentum).add_(p.grad)
            p.sub_(v, alpha=lr)
419
420
def train_epoch(
    model: ResNetLowLevel,
    dataloader,
    lr: float,
    velocities: List[torch.Tensor]
) -> Tuple[float, float]:
    """Train for one epoch; returns (mean loss, accuracy) over the epoch."""
    running_loss = 0.0
    correct = 0
    seen = 0

    for images, labels in dataloader:
        images, labels = images.to(model.device), labels.to(model.device)

        logits = model.forward(images, training=True)
        loss = F.cross_entropy(logits, labels)

        # Backprop and one momentum-SGD step.
        model.zero_grad()
        loss.backward()
        sgd_step_with_momentum(model.parameters(), velocities, lr)

        batch = images.size(0)
        running_loss += loss.item() * batch
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += batch

    return running_loss / seen, correct / seen
450
451
@torch.no_grad()
def evaluate(model: ResNetLowLevel, dataloader) -> Tuple[float, float]:
    """Evaluate the model; returns (mean loss, accuracy). Runs gradient-free."""
    loss_sum = 0.0
    correct = 0
    count = 0

    for images, labels in dataloader:
        images = images.to(model.device)
        labels = labels.to(model.device)

        # Eval mode: BN uses running statistics.
        logits = model.forward(images, training=False)

        batch = images.size(0)
        loss_sum += F.cross_entropy(logits, labels).item() * batch
        correct += (logits.argmax(dim=1) == labels).sum().item()
        count += batch

    return loss_sum / count, correct / count
472
473
def visualize_gradient_flow(model: ResNetLowLevel):
    """Plot the mean absolute gradient of every block parameter.

    Saves the bar chart as `gradient_flow.png` in the working directory;
    parameters without gradients are skipped.
    """
    import matplotlib.pyplot as plt

    all_blocks = model.layer1 + model.layer2 + model.layer3 + model.layer4

    gradients = []
    names = []
    for i, block in enumerate(all_blocks):
        for j, param in enumerate(block.parameters()):
            if param.grad is None:
                continue
            gradients.append(param.grad.abs().mean().item())
            names.append(f"block{i}_param{j}")

    plt.figure(figsize=(12, 4))
    plt.bar(range(len(gradients)), gradients)
    plt.xlabel('Layer')
    plt.ylabel('Mean |Gradient|')
    plt.title('Gradient Flow through ResNet')
    plt.tight_layout()
    plt.savefig('gradient_flow.png')
    print("Saved gradient_flow.png")
495
496
def main():
    """Train a ResNet-18 variant on CIFAR-10 as a demo of the low-level API.

    Downloads CIFAR-10 into ./data, trains for 100 epochs with a step
    LR schedule, prints progress every 10 epochs.
    """
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    print("=== ResNet Low-Level Training (CIFAR-10) ===\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Data preprocessing: standard CIFAR-10 augmentation for training,
    # normalization only for testing.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}\n")

    # Model: small-stem ResNet-18 sized for 32x32 inputs.
    model = ResNetSmall(config_name='resnet18', num_classes=10)
    model.to(device)

    print(f"ResNet-18 for CIFAR-10")
    print(f"Total parameters: {model.count_parameters():,}\n")

    # Momentum buffers, one per parameter, matching model.parameters() order.
    velocities = [torch.zeros_like(p) for p in model.parameters()]

    # Training hyperparameters.
    epochs = 100
    lr = 0.1

    for epoch in range(epochs):
        # Step learning-rate schedule: decay by 10x at fixed epochs.
        if epoch in [30, 60, 80]:
            lr *= 0.1
            print(f"LR โ {lr}")

        train_loss, train_acc = train_epoch(model, train_loader, lr, velocities)

        if (epoch + 1) % 10 == 0:
            test_loss, test_acc = evaluate(model, test_loader)
            print(f"Epoch {epoch+1}/{epochs}")
            print(f" Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f" Test - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}\n")

    final_loss, final_acc = evaluate(model, test_loader)
    print(f"Final Test Accuracy: {final_acc:.4f}")
563
564
565if __name__ == "__main__":
566 main()