resnet_lowlevel.py

Download
python 567 lines 17.3 KB
  1"""
PyTorch low-level ResNet implementation.

Uses F.conv2d and a hand-written BatchNorm instead of nn.Conv2d / nn.BatchNorm2d.
Implements both the BasicBlock and the Bottleneck residual blocks.
  6"""
  7
  8import torch
  9import torch.nn.functional as F
 10import math
 11from typing import Tuple, List, Dict, Optional, Literal
 12
 13
 14class BatchNorm2dManual:
 15    """์ˆ˜๋™ Batch Normalization"""
 16
 17    def __init__(self, num_features: int, device: torch.device):
 18        self.gamma = torch.ones(num_features, requires_grad=True, device=device)
 19        self.beta = torch.zeros(num_features, requires_grad=True, device=device)
 20        self.running_mean = torch.zeros(num_features, device=device)
 21        self.running_var = torch.ones(num_features, device=device)
 22        self.momentum = 0.1
 23        self.eps = 1e-5
 24
 25    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
 26        if training:
 27            mean = x.mean(dim=(0, 2, 3), keepdim=True)
 28            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
 29
 30            with torch.no_grad():
 31                self.running_mean = (
 32                    (1 - self.momentum) * self.running_mean +
 33                    self.momentum * mean.squeeze()
 34                )
 35                self.running_var = (
 36                    (1 - self.momentum) * self.running_var +
 37                    self.momentum * var.squeeze()
 38                )
 39        else:
 40            mean = self.running_mean.view(1, -1, 1, 1)
 41            var = self.running_var.view(1, -1, 1, 1)
 42
 43        x_norm = (x - mean) / torch.sqrt(var + self.eps)
 44        gamma = self.gamma.view(1, -1, 1, 1)
 45        beta = self.beta.view(1, -1, 1, 1)
 46
 47        return gamma * x_norm + beta
 48
 49    def parameters(self) -> List[torch.Tensor]:
 50        return [self.gamma, self.beta]
 51
 52
 53class ConvBN:
 54    """Conv + BatchNorm ์กฐํ•ฉ"""
 55
 56    def __init__(
 57        self,
 58        in_channels: int,
 59        out_channels: int,
 60        kernel_size: int,
 61        stride: int = 1,
 62        padding: int = 0,
 63        device: torch.device = None
 64    ):
 65        self.stride = stride
 66        self.padding = padding
 67
 68        # Kaiming ์ดˆ๊ธฐํ™”
 69        fan_in = in_channels * kernel_size * kernel_size
 70        std = math.sqrt(2.0 / fan_in)
 71
 72        self.weight = torch.randn(
 73            out_channels, in_channels, kernel_size, kernel_size,
 74            requires_grad=True, device=device
 75        ) * std
 76
 77        # Conv์— bias ์—†์Œ (BN์ด ์žˆ์œผ๋ฏ€๋กœ)
 78        self.bn = BatchNorm2dManual(out_channels, device)
 79
 80    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
 81        x = F.conv2d(x, self.weight, None, self.stride, self.padding)
 82        x = self.bn(x, training)
 83        return x
 84
 85    def parameters(self) -> List[torch.Tensor]:
 86        return [self.weight] + self.bn.parameters()
 87
 88
 89class BasicBlock:
 90    """
 91    ResNet BasicBlock (ResNet-18, 34์šฉ)
 92
 93    ๊ตฌ์กฐ: Conv3ร—3 โ†’ BN โ†’ ReLU โ†’ Conv3ร—3 โ†’ BN โ†’ (+shortcut) โ†’ ReLU
 94    """
 95    expansion = 1
 96
 97    def __init__(
 98        self,
 99        in_channels: int,
100        out_channels: int,
101        stride: int = 1,
102        device: torch.device = None
103    ):
104        self.conv1 = ConvBN(in_channels, out_channels, 3, stride, 1, device)
105        self.conv2 = ConvBN(out_channels, out_channels, 3, 1, 1, device)
106
107        # Shortcut (์ฐจ์›์ด ๋‹ค๋ฅผ ๋•Œ๋งŒ)
108        self.shortcut = None
109        if stride != 1 or in_channels != out_channels:
110            self.shortcut = ConvBN(in_channels, out_channels, 1, stride, 0, device)
111
112    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
113        identity = x
114
115        out = self.conv1(x, training)
116        out = F.relu(out)
117        out = self.conv2(out, training)
118
119        if self.shortcut is not None:
120            identity = self.shortcut(x, training)
121
122        out = out + identity  # Skip connection!
123        out = F.relu(out)
124
125        return out
126
127    def parameters(self) -> List[torch.Tensor]:
128        params = self.conv1.parameters() + self.conv2.parameters()
129        if self.shortcut is not None:
130            params += self.shortcut.parameters()
131        return params
132
133
class Bottleneck:
    """
    Three-convolution residual block used by ResNet-50, -101 and -152.

    Data flow: 1x1 ConvBN (reduce) -> ReLU -> 3x3 ConvBN -> ReLU ->
    1x1 ConvBN (expand by ``expansion``) -> (+ shortcut) -> ReLU.
    """
    expansion = 4  # block output channels = out_channels * expansion

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        device: torch.device = None
    ):
        width = out_channels
        widened = out_channels * self.expansion

        self.conv1 = ConvBN(in_channels, width, 1, 1, 0, device)   # channel reduction
        self.conv2 = ConvBN(width, width, 3, stride, 1, device)    # main 3x3 (carries stride)
        self.conv3 = ConvBN(width, widened, 1, 1, 0, device)       # channel expansion

        # Projection shortcut only when the output shape differs from the input.
        self.shortcut = None
        if in_channels != widened or stride != 1:
            self.shortcut = ConvBN(in_channels, widened, 1, stride, 0, device)

    def __call__(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """Run the block on ``x``; ``training`` is forwarded to every ConvBN."""
        branch = x
        for unit in (self.conv1, self.conv2):
            branch = F.relu(unit(branch, training))
        branch = self.conv3(branch, training)

        skip = x if self.shortcut is None else self.shortcut(x, training)
        return F.relu(branch + skip)

    def parameters(self) -> List[torch.Tensor]:
        """Parameters of conv1..conv3 and (if present) the shortcut, in that order."""
        units = [self.conv1, self.conv2, self.conv3]
        if self.shortcut is not None:
            units.append(self.shortcut)
        return [param for unit in units for param in unit.parameters()]
185
186
class ResNetLowLevel:
    """
    ResNet built from raw tensors and functional ops (no nn.Module).

    Parameters are plain leaf tensors created with ``requires_grad``. They are
    collected by ``parameters()`` and updated by a hand-written optimizer.

    Args:
        config_name: One of the keys of ``CONFIGS`` (default 'resnet50').
        num_classes: Number of output logits (default 1000).
        input_channels: Channels of the input image (default 3).

    Raises:
        KeyError: If ``config_name`` is not in ``CONFIGS``.
    """

    # name -> (block class, number of blocks in layer1..layer4)
    CONFIGS = {
        'resnet18': (BasicBlock, [2, 2, 2, 2]),
        'resnet34': (BasicBlock, [3, 4, 6, 3]),
        'resnet50': (Bottleneck, [3, 4, 6, 3]),
        'resnet101': (Bottleneck, [3, 4, 23, 3]),
        'resnet152': (Bottleneck, [3, 8, 36, 3]),
    }

    def __init__(
        self,
        config_name: str = 'resnet50',
        num_classes: int = 1000,
        input_channels: int = 3
    ):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if config_name not in self.CONFIGS:
            raise KeyError(
                f"Unknown config {config_name!r}; expected one of {sorted(self.CONFIGS)}"
            )
        block_class, num_blocks = self.CONFIGS[config_name]
        self.expansion = block_class.expansion

        # Stem: 7x7 stride-2 conv + BN + ReLU + max-pool (applied in forward).
        # Kaiming-normal std for the ReLU that follows.
        self.conv1_weight = self._leaf_randn(
            (64, input_channels, 7, 7),
            math.sqrt(2.0 / (input_channels * 7 * 7)),
            self.device,
        )
        self.bn1 = BatchNorm2dManual(64, self.device)

        # Residual stages; _make_layer advances self.in_channels stage by stage.
        self.in_channels = 64
        self.layer1 = self._make_layer(block_class, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block_class, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_class, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_class, 512, num_blocks[3], stride=2)

        # Classifier: Xavier/Glorot-normal std = sqrt(2 / (fan_in + fan_out)).
        fc_in = 512 * self.expansion
        self.fc_weight = self._leaf_randn(
            (num_classes, fc_in),
            math.sqrt(2.0 / (fc_in + num_classes)),
            self.device,
        )
        self.fc_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)

    @staticmethod
    def _leaf_randn(shape, std: float, device) -> torch.Tensor:
        """Return an N(0, std^2) tensor that is a *leaf* requiring grad.

        ``torch.randn(..., requires_grad=True) * std`` would instead be a
        non-leaf tensor whose ``.grad`` is never populated by ``backward()``,
        so the optimizer would silently skip it.
        """
        return (torch.randn(*shape, device=device) * std).requires_grad_()

    def _make_layer(
        self,
        block_class,
        out_channels: int,
        num_blocks: int,
        stride: int
    ) -> List:
        """Build one residual stage (a list of ``num_blocks`` blocks).

        Only the first block uses ``stride`` and may change the channel count;
        the remaining blocks use stride 1. Updates ``self.in_channels`` to this
        stage's output channel count for the next stage.
        """
        blocks = [block_class(self.in_channels, out_channels, stride, self.device)]
        self.in_channels = out_channels * self.expansion

        for _ in range(1, num_blocks):
            blocks.append(block_class(self.in_channels, out_channels, 1, self.device))

        return blocks

    def _blocks(self) -> List:
        """All residual blocks of layer1..layer4, in forward order."""
        return [
            block
            for layer in (self.layer1, self.layer2, self.layer3, self.layer4)
            for block in layer
        ]

    def forward(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: (N, C, H, W) input images.
            training: Training mode (batch statistics in every BatchNorm and
                running-average updates) if True; evaluation mode otherwise.

        Returns:
            logits: (N, num_classes)
        """
        # Stem
        x = F.conv2d(x, self.conv1_weight, None, stride=2, padding=3)
        x = F.relu(self.bn1(x, training))
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        # Residual stages
        for block in self._blocks():
            x = block(x, training)

        # Global average pooling -> (N, C)
        x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)

        # Classifier
        return torch.matmul(x, self.fc_weight.t()) + self.fc_bias

    def parameters(self) -> List[torch.Tensor]:
        """Return all trainable parameters: stem, residual blocks, classifier."""
        params = [self.conv1_weight] + self.bn1.parameters()
        for block in self._blocks():
            params += block.parameters()
        params += [self.fc_weight, self.fc_bias]
        return params

    def zero_grad(self):
        """Zero (in place) every parameter gradient that has been allocated."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def to(self, device):
        """Move every parameter and BatchNorm buffer to ``device``; returns self.

        Covers the stem, all residual blocks and the classifier. Parameters
        remain leaf tensors: ``.to()`` on a leaf that requires grad returns a
        non-leaf copy when the device changes, which would never receive
        ``.grad``. Tensors already on ``device`` are kept as-is. State held
        outside the model (e.g. momentum buffers) is not moved.
        """
        self.device = torch.device(device)

        def move_param(t: torch.Tensor) -> torch.Tensor:
            moved = t.to(device)
            if moved is t:
                return t  # already on device: keep identity and any .grad
            return moved.detach().requires_grad_(t.requires_grad)

        def move_bn(bn) -> None:
            bn.gamma = move_param(bn.gamma)
            bn.beta = move_param(bn.beta)
            bn.running_mean = bn.running_mean.to(device)
            bn.running_var = bn.running_var.to(device)

        # Stem
        self.conv1_weight = move_param(self.conv1_weight)
        move_bn(self.bn1)

        # Residual blocks: each holds ConvBN units named conv1/conv2[/conv3][/shortcut].
        for block in self._blocks():
            for name in ('conv1', 'conv2', 'conv3', 'shortcut'):
                conv_bn = getattr(block, name, None)
                if conv_bn is not None:
                    conv_bn.weight = move_param(conv_bn.weight)
                    move_bn(conv_bn.bn)

        # Classifier
        self.fc_weight = move_param(self.fc_weight)
        self.fc_bias = move_param(self.fc_bias)

        return self

    def count_parameters(self) -> int:
        """Return the total number of trainable scalar parameters."""
        return sum(p.numel() for p in self.parameters())
335
336
class ResNetSmall(ResNetLowLevel):
    """ResNet variant for small inputs such as 32x32 CIFAR-10 images.

    Differs from ResNetLowLevel only in the stem: a 3x3 stride-1 convolution
    replaces the 7x7 stride-2 one, and there is no max-pool, so layer1 sees the
    full input resolution.
    """

    def __init__(
        self,
        config_name: str = 'resnet18',
        num_classes: int = 10,
        input_channels: int = 3
    ):
        # Deliberately skips ResNetLowLevel.__init__ (different stem) and sets
        # every attribute the inherited methods rely on directly.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        block_class, num_blocks = self.CONFIGS[config_name]
        self.expansion = block_class.expansion

        # Stem: 3x3 conv (instead of 7x7), Kaiming-normal std.
        # Scale BEFORE requires_grad_() so the weight is a leaf and gets .grad.
        std = math.sqrt(2.0 / (input_channels * 3 * 3))
        self.conv1_weight = (
            torch.randn(64, input_channels, 3, 3, device=self.device) * std
        ).requires_grad_()
        self.bn1 = BatchNorm2dManual(64, self.device)

        # No max-pool: the input is already only 32x32.

        # Residual stages
        self.in_channels = 64
        self.layer1 = self._make_layer(block_class, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block_class, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_class, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_class, 512, num_blocks[3], stride=2)

        # Classifier: Xavier/Glorot-normal std, created as a leaf tensor.
        fc_in = 512 * self.expansion
        std = math.sqrt(2.0 / (fc_in + num_classes))
        self.fc_weight = (
            torch.randn(num_classes, fc_in, device=self.device) * std
        ).requires_grad_()
        self.fc_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)

    def forward(self, x: torch.Tensor, training: bool = True) -> torch.Tensor:
        """Forward pass: (N, C, H, W) images -> (N, num_classes) logits."""
        # Stem (no max-pool)
        x = F.conv2d(x, self.conv1_weight, None, stride=1, padding=1)
        x = F.relu(self.bn1(x, training))

        # Residual stages
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            for block in layer:
                x = block(x, training)

        # Global average pooling -> (N, C)
        x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)

        # Classifier
        return torch.matmul(x, self.fc_weight.t()) + self.fc_bias
403
404
def sgd_step_with_momentum(
    params: List[torch.Tensor],
    velocities: List[torch.Tensor],
    lr: float,
    momentum: float = 0.9,
    weight_decay: float = 1e-4
):
    """Apply one in-place SGD update with momentum and L2 weight decay.

    For every parameter ``p`` that has a gradient::

        d_p = p.grad + weight_decay * p
        v   = momentum * v + d_p
        p   = p - lr * v

    This is the torch.optim.SGD rule with dampening=0 and nesterov=False
    (velocity buffers starting at zero). Parameters whose ``.grad`` is None
    are skipped and their velocity is left untouched. ``p.grad`` itself is NOT
    modified, so it can still be inspected after the step.

    Args:
        params: Parameters to update in place.
        velocities: Momentum buffers, one per parameter in the same order;
            updated in place.
        lr: Learning rate.
        momentum: Momentum factor (default 0.9).
        weight_decay: L2 penalty coefficient (default 1e-4).

    Raises:
        ValueError: If ``params`` and ``velocities`` differ in length.
    """
    params = list(params)
    velocities = list(velocities)
    if len(params) != len(velocities):
        raise ValueError(
            f"got {len(params)} params but {len(velocities)} velocity buffers"
        )

    with torch.no_grad():
        for param, velocity in zip(params, velocities):
            if param.grad is None:
                continue
            # Out-of-place add keeps param.grad equal to the raw loss gradient.
            d_p = param.grad.add(param, alpha=weight_decay) if weight_decay else param.grad
            velocity.mul_(momentum).add_(d_p)
            param.sub_(velocity, alpha=lr)
419
420
def train_epoch(
    model: ResNetLowLevel,
    dataloader,
    lr: float,
    velocities: List[torch.Tensor]
) -> Tuple[float, float]:
    """Run one training epoch over ``dataloader``.

    Each ``(images, labels)`` batch goes through the following steps:
    - forward pass in training mode;
    - cross-entropy loss;
    - zeroing of gradients and backward pass;
    - one ``sgd_step_with_momentum`` update, with ``velocities`` as the momentum buffers.

    Returns:
        ``(mean loss per sample, accuracy)`` over all samples seen.
        Raises ZeroDivisionError if the dataloader yields no samples.
    """
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(model.device)
        targets = targets.to(model.device)

        logits = model.forward(inputs, training=True)
        loss = F.cross_entropy(logits, targets)

        model.zero_grad()
        loss.backward()
        sgd_step_with_momentum(model.parameters(), velocities, lr)

        # F.cross_entropy returns the batch mean; re-weight by batch size.
        batch = inputs.size(0)
        loss_sum += loss.item() * batch
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch

    return loss_sum / n_seen, n_correct / n_seen
450
451
@torch.no_grad()
def evaluate(model: ResNetLowLevel, dataloader) -> Tuple[float, float]:
    """Evaluate ``model`` (evaluation mode, no autograd) over ``dataloader``.

    Returns:
        ``(mean loss per sample, accuracy)`` over all samples seen.
        Raises ZeroDivisionError if the dataloader yields no samples.
    """
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in dataloader:
        inputs = inputs.to(model.device)
        targets = targets.to(model.device)

        logits = model.forward(inputs, training=False)
        batch_loss = F.cross_entropy(logits, targets)

        # F.cross_entropy returns the batch mean; re-weight by batch size.
        batch = inputs.size(0)
        loss_sum += batch_loss.item() * batch
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch

    return loss_sum / n_seen, n_correct / n_seen
472
473
def visualize_gradient_flow(model: ResNetLowLevel, save_path: str = 'gradient_flow.png'):
    """Save a bar chart of the mean |gradient| of every residual-block parameter.

    Call after ``loss.backward()``. Parameters whose ``.grad`` is None are
    skipped. Bars are labelled ``block{i}_param{j}``, where ``i`` indexes the
    blocks of layer1..layer4 in order and ``j`` indexes
    ``block.parameters()``.

    Args:
        model: Model whose ``layer1``..``layer4`` blocks are inspected.
        save_path: Output image path (default 'gradient_flow.png').
    """
    import matplotlib.pyplot as plt

    blocks = model.layer1 + model.layer2 + model.layer3 + model.layer4

    gradients = []
    names = []
    for i, block in enumerate(blocks):
        for j, param in enumerate(block.parameters()):
            if param.grad is not None:
                gradients.append(param.grad.abs().mean().item())
                names.append(f"block{i}_param{j}")

    if not gradients:
        print("No gradients found; call loss.backward() before visualize_gradient_flow().")
        return

    fig, ax = plt.subplots(figsize=(12, 4))
    try:
        positions = range(len(gradients))
        ax.bar(positions, gradients)
        ax.set_xticks(list(positions))
        ax.set_xticklabels(names, rotation=90, fontsize=6)
        ax.set_xlabel('Parameter')
        ax.set_ylabel('Mean |Gradient|')
        ax.set_title('Gradient Flow through ResNet')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive across calls.
        plt.close(fig)
    print(f"Saved {save_path}")
495
496
def main():
    """Train ResNet-18 (ResNetSmall) on CIFAR-10 using the low-level components.

    The steps are:
    - download CIFAR-10 to ./data;
    - train for 100 epochs with momentum SGD (lr 0.1, multiplied by 0.1 at
      epochs 30, 60 and 80);
    - evaluate on the test set every 10 epochs;
    - print the final test accuracy.
    """
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    print("=== ResNet Low-Level Training (CIFAR-10) ===\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Data preprocessing: per-channel CIFAR-10 mean/std normalization.
    # The training set additionally gets random-crop and horizontal-flip augmentation.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}\n")

    # Model
    model = ResNetSmall(config_name='resnet18', num_classes=10)
    model.to(device)

    print("ResNet-18 for CIFAR-10")
    print(f"Total parameters: {model.count_parameters():,}\n")

    # One zero-initialized momentum buffer per parameter, created after
    # model.to(device) so the buffers sit on the same device as the parameters.
    velocities = [torch.zeros_like(p) for p in model.parameters()]

    # Training
    epochs = 100
    lr = 0.1
    lr_milestones = (30, 60, 80)  # step decay: lr *= 0.1 at each milestone

    for epoch in range(epochs):
        if epoch in lr_milestones:
            lr *= 0.1
            print(f"LR → {lr}")

        train_loss, train_acc = train_epoch(model, train_loader, lr, velocities)

        if (epoch + 1) % 10 == 0:
            test_loss, test_acc = evaluate(model, test_loader)
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f"  Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}\n")

    _, final_acc = evaluate(model, test_loader)
    print(f"Final Test Accuracy: {final_acc:.4f}")


if __name__ == "__main__":
    main()