lenet_lowlevel.py

Download
python 358 lines 10.0 KB
  1"""
PyTorch low-level LeNet-5 implementation.

Uses F.conv2d and torch.matmul instead of nn.Conv2d / nn.Linear;
parameters are created and managed manually (no nn.Module).
  6"""
  7
  8import torch
  9import torch.nn.functional as F
 10import math
 11from typing import Tuple, List
 12
 13
 14class LeNetLowLevel:
 15    """
 16    LeNet-5 Low-Level κ΅¬ν˜„
 17
 18    nn.Module λ―Έμ‚¬μš©, F.conv2d λ“± κΈ°λ³Έ μ—°μ‚°λ§Œ μ‚¬μš©
 19    """
 20
 21    def __init__(self, num_classes: int = 10):
 22        # Device
 23        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 24
 25        # Conv1: 1 β†’ 6 channels, 5x5 kernel
 26        self.conv1_weight = self._init_conv_weight(1, 6, 5)
 27        self.conv1_bias = torch.zeros(6, requires_grad=True, device=self.device)
 28
 29        # Conv2: 6 β†’ 16 channels, 5x5 kernel
 30        self.conv2_weight = self._init_conv_weight(6, 16, 5)
 31        self.conv2_bias = torch.zeros(16, requires_grad=True, device=self.device)
 32
 33        # Conv3: 16 β†’ 120 channels, 5x5 kernel
 34        self.conv3_weight = self._init_conv_weight(16, 120, 5)
 35        self.conv3_bias = torch.zeros(120, requires_grad=True, device=self.device)
 36
 37        # FC1: 120 β†’ 84
 38        self.fc1_weight = self._init_linear_weight(120, 84)
 39        self.fc1_bias = torch.zeros(84, requires_grad=True, device=self.device)
 40
 41        # FC2: 84 β†’ num_classes
 42        self.fc2_weight = self._init_linear_weight(84, num_classes)
 43        self.fc2_bias = torch.zeros(num_classes, requires_grad=True, device=self.device)
 44
 45    def _init_conv_weight(
 46        self,
 47        in_channels: int,
 48        out_channels: int,
 49        kernel_size: int
 50    ) -> torch.Tensor:
 51        """Kaiming μ΄ˆκΈ°ν™”"""
 52        fan_in = in_channels * kernel_size * kernel_size
 53        std = math.sqrt(2.0 / fan_in)
 54        weight = torch.randn(
 55            out_channels, in_channels, kernel_size, kernel_size,
 56            requires_grad=True, device=self.device
 57        ) * std
 58        return weight
 59
 60    def _init_linear_weight(
 61        self,
 62        in_features: int,
 63        out_features: int
 64    ) -> torch.Tensor:
 65        """Xavier μ΄ˆκΈ°ν™”"""
 66        std = math.sqrt(2.0 / (in_features + out_features))
 67        weight = torch.randn(
 68            out_features, in_features,
 69            requires_grad=True, device=self.device
 70        ) * std
 71        return weight
 72
 73    def forward(self, x: torch.Tensor) -> torch.Tensor:
 74        """
 75        Forward pass
 76
 77        Args:
 78            x: (N, 1, 32, 32) μž…λ ₯ 이미지
 79
 80        Returns:
 81            logits: (N, num_classes)
 82        """
 83        # Layer 1: Conv β†’ ReLU β†’ AvgPool
 84        # (N, 1, 32, 32) β†’ (N, 6, 28, 28) β†’ (N, 6, 14, 14)
 85        x = F.conv2d(x, self.conv1_weight, self.conv1_bias, stride=1, padding=0)
 86        x = F.relu(x)
 87        x = F.avg_pool2d(x, kernel_size=2, stride=2)
 88
 89        # Layer 2: Conv β†’ ReLU β†’ AvgPool
 90        # (N, 6, 14, 14) β†’ (N, 16, 10, 10) β†’ (N, 16, 5, 5)
 91        x = F.conv2d(x, self.conv2_weight, self.conv2_bias, stride=1, padding=0)
 92        x = F.relu(x)
 93        x = F.avg_pool2d(x, kernel_size=2, stride=2)
 94
 95        # Layer 3: Conv β†’ ReLU
 96        # (N, 16, 5, 5) β†’ (N, 120, 1, 1)
 97        x = F.conv2d(x, self.conv3_weight, self.conv3_bias, stride=1, padding=0)
 98        x = F.relu(x)
 99
100        # Flatten: (N, 120, 1, 1) β†’ (N, 120)
101        x = x.view(x.size(0), -1)
102
103        # FC1: (N, 120) β†’ (N, 84)
104        x = torch.matmul(x, self.fc1_weight.t()) + self.fc1_bias
105        x = F.relu(x)
106
107        # FC2: (N, 84) β†’ (N, num_classes)
108        x = torch.matmul(x, self.fc2_weight.t()) + self.fc2_bias
109
110        return x
111
112    def parameters(self) -> List[torch.Tensor]:
113        """ν•™μŠ΅ κ°€λŠ₯ν•œ νŒŒλΌλ―Έν„° λ°˜ν™˜"""
114        return [
115            self.conv1_weight, self.conv1_bias,
116            self.conv2_weight, self.conv2_bias,
117            self.conv3_weight, self.conv3_bias,
118            self.fc1_weight, self.fc1_bias,
119            self.fc2_weight, self.fc2_bias,
120        ]
121
122    def zero_grad(self):
123        """Gradient μ΄ˆκΈ°ν™”"""
124        for param in self.parameters():
125            if param.grad is not None:
126                param.grad.zero_()
127
128    def to(self, device):
129        """Device 이동"""
130        self.device = device
131        for param in self.parameters():
132            param.data = param.data.to(device)
133            if param.grad is not None:
134                param.grad = param.grad.to(device)
135        return self
136
137
def sgd_step(params: List[torch.Tensor], lr: float):
    """Apply one vanilla SGD update in place: ``p <- p - lr * p.grad``.

    Tensors whose ``.grad`` is ``None`` are left untouched. The update is
    performed under ``torch.no_grad()`` so autograd does not record it.
    """
    with torch.no_grad():
        updatable = (p for p in params if p.grad is not None)
        for p in updatable:
            p.sub_(lr * p.grad)
144
145
def sgd_step_with_momentum(
    params: List[torch.Tensor],
    velocities: List[torch.Tensor],
    lr: float,
    momentum: float = 0.9
):
    """SGD with (heavy-ball) momentum, updating tensors in place.

    For every ``(p, v)`` pair whose ``p.grad`` is set:
    ``v <- momentum * v + p.grad`` and then ``p <- p - lr * v``.
    Pairs are matched positionally via ``zip``; parameters without a
    gradient leave both ``p`` and ``v`` unchanged.
    """
    with torch.no_grad():
        for p, buf in zip(params, velocities):
            if p.grad is None:
                continue
            buf.mul_(momentum)
            buf.add_(p.grad)
            p.sub_(lr * buf)
158
159
def train_epoch(
    model: 'LeNetLowLevel',
    dataloader,
    lr: float = 0.01
) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``dataloader`` using plain SGD.

    The loss is computed explicitly as ``log_softmax`` followed by
    ``nll_loss``, which is mathematically the same as ``F.cross_entropy``.

    Args:
        model: Object exposing ``device``, ``forward``, ``zero_grad`` and
            ``parameters`` (e.g. ``LeNetLowLevel``).
        dataloader: Iterable of ``(images, labels)`` batches.
        lr: SGD learning rate.

    Returns:
        ``(avg_loss, accuracy)``. The loss is a per-sample average (each
        batch's mean loss weighted by its size). Accuracy uses the logits
        from each batch's forward pass, i.e. before that batch's update.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in dataloader:
        images = images.to(model.device)
        labels = labels.to(model.device)

        # Forward
        logits = model.forward(images)

        # Cross-entropy computed by hand: log_softmax + nll_loss.
        log_probs = F.log_softmax(logits, dim=1)
        loss = F.nll_loss(log_probs, labels)

        # Backward (clear stale gradients first, since they accumulate).
        model.zero_grad()
        loss.backward()

        # Update
        sgd_step(model.parameters(), lr)

        # Metrics
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("train_epoch: dataloader yielded no samples")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    return avg_loss, accuracy
199
200
@torch.no_grad()
def evaluate(
    model: 'LeNetLowLevel',
    dataloader
) -> Tuple[float, float]:
    """Compute loss and accuracy of ``model`` over ``dataloader``.

    Runs under ``torch.no_grad()``, so no autograd graph is built and no
    parameters are modified.

    Args:
        model: Object exposing ``device`` and ``forward`` (e.g. ``LeNetLowLevel``).
        dataloader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``(avg_loss, accuracy)``. The loss is the per-sample mean
        cross-entropy (each batch's mean weighted by its size).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in dataloader:
        images = images.to(model.device)
        labels = labels.to(model.device)

        logits = model.forward(images)
        loss = F.cross_entropy(logits, labels)

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError("evaluate: dataloader yielded no samples")

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples

    return avg_loss, accuracy
231
232
def main(epochs: int = 10, lr: float = 0.01, batch_size: int = 64):
    """Train ``LeNetLowLevel`` on MNIST and print per-epoch metrics.

    MNIST is downloaded to ``./data`` and resized to 32x32 to match the
    network's expected input. The learning rate is halved every 5 epochs.

    Args:
        epochs: Number of training epochs.
        lr: Initial SGD learning rate.
        batch_size: Training mini-batch size (the test loader uses 256).
    """
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    print("=== LeNet-5 Low-Level Training ===\n")

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Dataset: MNIST resized to 32x32 (the input size LeNet-5 expects).
    # Normalize takes (mean,), (std,); these are the values commonly used for MNIST.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(
        root='./data', train=False, download=True, transform=transform
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}\n")

    # Model
    model = LeNetLowLevel(num_classes=10)
    model.to(device)

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}\n")

    # Training
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, train_loader, lr)
        test_loss, test_acc = evaluate(model, test_loader)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"  Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

        # Step decay: halve the learning rate every 5 epochs.
        if (epoch + 1) % 5 == 0:
            lr *= 0.5
            print(f"  LR → {lr}")

        print()

    print("Training complete!")

    # Final result (only the accuracy is reported).
    _, final_acc = evaluate(model, test_loader)
    print(f"\nFinal Test Accuracy: {final_acc:.4f}")
296
297
# Visualization of the convolution operation
def visualize_conv_operation(save_path: str = 'conv_visualization.png'):
    """Plot a few 3x3 filters and their responses on a toy image, then save it.

    The input is a 5x5 image of zeros with a 3x3 square of ones in the middle.
    It is convolved with a horizontal-edge, a vertical-edge and an identity
    kernel using ``padding=1`` so each output stays 5x5. The saved figure has
    two rows: the top row shows the input followed by each kernel, and the
    bottom row shows each kernel's output.

    Args:
        save_path: File path the PNG figure is written to.
    """
    import matplotlib.pyplot as plt

    # Simple input: a 3x3 square in the center of a 5x5 canvas.
    input_img = torch.zeros(1, 1, 5, 5)
    input_img[0, 0, 1:4, 1:4] = 1.0

    # Edge-detection filters, each shaped (out_ch, in_ch, kH, kW) = (1, 1, 3, 3).
    filters = {
        # Bottom row minus top row: responds to intensity changes between rows.
        'Horizontal': torch.tensor([
            [-1, -1, -1],
            [ 0,  0,  0],
            [ 1,  1,  1]
        ]).float().view(1, 1, 3, 3),

        # Right column minus left column: responds to changes between columns.
        'Vertical': torch.tensor([
            [-1, 0, 1],
            [-1, 0, 1],
            [-1, 0, 1]
        ]).float().view(1, 1, 3, 3),

        # Passes the input through unchanged.
        'Identity': torch.tensor([
            [0, 0, 0],
            [0, 1, 0],
            [0, 0, 0]
        ]).float().view(1, 1, 3, 3),
    }

    fig, axes = plt.subplots(2, len(filters) + 1, figsize=(12, 6))
    try:
        # Input image (top-left); the cell below it is left empty.
        axes[0, 0].imshow(input_img[0, 0], cmap='gray')
        axes[0, 0].set_title('Input')
        axes[0, 0].axis('off')
        axes[1, 0].axis('off')

        # Apply each filter. F.conv2d computes cross-correlation (no kernel flip).
        for col, (name, kernel) in enumerate(filters.items(), start=1):
            output = F.conv2d(input_img, kernel, padding=1)

            # Filter
            axes[0, col].imshow(kernel[0, 0], cmap='RdBu', vmin=-1, vmax=1)
            axes[0, col].set_title(f'{name} Filter')
            axes[0, col].axis('off')

            # Output
            axes[1, col].imshow(output[0, 0].detach(), cmap='gray')
            axes[1, col].set_title('Output')
            axes[1, col].axis('off')

        fig.tight_layout()
        fig.savefig(save_path, dpi=150)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

    print(f"Saved {save_path}")
354
355
if __name__ == '__main__':
    # Script entry point: run the MNIST training loop.
    main()