1"""
2PyTorch Low-Level VGG 구현
3
4nn.Conv2d, nn.Linear 대신 F.conv2d, torch.matmul 사용
5파라미터를 수동으로 관리하며 블록 단위로 구성
6"""
7
8import torch
9import torch.nn.functional as F
10import math
11from typing import Tuple, List, Dict, Optional
12
13
# VGG configuration: integer = conv layer output channels, 'M' = 2x2 max-pool
VGG_CONFIGS = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
21
22
class VGGLowLevel:
    """
    Low-level VGG implementation.

    Avoids nn.Module entirely: convolutions are F.conv2d calls, fully
    connected layers are torch.matmul, and all parameters are plain
    tensors managed by hand, organized block by block.
    """

    def __init__(
        self,
        config_name: str = 'VGG16',
        num_classes: int = 1000,
        input_channels: int = 3,
        use_bn: bool = False
    ):
        """
        Args:
            config_name: VGG variant ('VGG11', 'VGG13', 'VGG16', 'VGG19')
            num_classes: number of output classes
            input_channels: number of input channels (RGB = 3)
            use_bn: whether to apply Batch Normalization after each conv
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.config = VGG_CONFIGS[config_name]
        self.use_bn = use_bn

        # Feature-extractor parameters: one dict per conv layer, the
        # string 'M' marking a (parameter-free) max-pool position.
        self.conv_params = []
        self.bn_params = [] if use_bn else None
        self._build_features(input_channels)

        # Classifier parameters (fc1/fc2/fc3 weight + bias attributes).
        self._build_classifier(num_classes)

    def _init_conv_weight(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create a conv (weight, bias) pair with Kaiming (He) initialization.

        BUG FIX: the original wrote
            torch.randn(..., requires_grad=True) * std
        The multiplication makes `weight` a NON-LEAF tensor, so `.grad`
        never populates after backward() and the optimizer silently
        skips it -- the network could not learn. Scale first, then
        promote the scaled tensor to a grad-requiring leaf.
        """
        fan_in = in_channels * kernel_size * kernel_size
        std = math.sqrt(2.0 / fan_in)

        weight = (
            torch.randn(
                out_channels, in_channels, kernel_size, kernel_size,
                device=self.device
            ) * std
        ).requires_grad_(True)
        bias = torch.zeros(out_channels, requires_grad=True, device=self.device)

        return weight, bias

    def _init_bn_params(self, num_features: int) -> Dict[str, torch.Tensor]:
        """Initialize BatchNorm parameters.

        gamma/beta are learnable; running_mean/running_var are buffers
        (updated under no_grad, never trained).
        """
        return {
            'gamma': torch.ones(num_features, requires_grad=True, device=self.device),
            'beta': torch.zeros(num_features, requires_grad=True, device=self.device),
            'running_mean': torch.zeros(num_features, device=self.device),
            'running_var': torch.ones(num_features, device=self.device),
        }

    def _init_linear_weight(
        self,
        in_features: int,
        out_features: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create a linear (weight, bias) pair with Xavier initialization.

        Same leaf-tensor fix as _init_conv_weight: scale the random draw
        first, then mark the result as a grad-requiring leaf.
        """
        std = math.sqrt(2.0 / (in_features + out_features))

        weight = (
            torch.randn(out_features, in_features, device=self.device) * std
        ).requires_grad_(True)
        bias = torch.zeros(out_features, requires_grad=True, device=self.device)

        return weight, bias

    def _build_features(self, input_channels: int):
        """Build conv-layer parameters following self.config."""
        in_channels = input_channels

        for v in self.config:
            if v == 'M':
                # MaxPool has no parameters; keep a placeholder so that
                # conv_params and bn_params stay index-aligned.
                self.conv_params.append('M')
                if self.use_bn:
                    self.bn_params.append(None)
            else:
                out_channels = v
                weight, bias = self._init_conv_weight(in_channels, out_channels, 3)
                self.conv_params.append({'weight': weight, 'bias': bias})

                if self.use_bn:
                    self.bn_params.append(self._init_bn_params(out_channels))

                in_channels = out_channels

    def _build_classifier(self, num_classes: int):
        """Build FC-layer parameters.

        Assumes a 224x224 input, which yields a 7x7x512 = 25088 feature
        map after the conv stack. Subclasses (e.g. a CIFAR variant with
        32x32 inputs -> 1x1x512) override this method.
        """
        # FC1: 25088 -> 4096
        self.fc1_weight, self.fc1_bias = self._init_linear_weight(512 * 7 * 7, 4096)
        # FC2: 4096 -> 4096
        self.fc2_weight, self.fc2_bias = self._init_linear_weight(4096, 4096)
        # FC3: 4096 -> num_classes
        self.fc3_weight, self.fc3_bias = self._init_linear_weight(4096, num_classes)

    def _batch_norm(
        self,
        x: torch.Tensor,
        bn_params: Dict[str, torch.Tensor],
        training: bool = True,
        momentum: float = 0.1,
        eps: float = 1e-5
    ) -> torch.Tensor:
        """Manual Batch Normalization over the channel dim of (N, C, H, W).

        FIX: running_var is now updated with the *unbiased* variance,
        matching torch.nn.BatchNorm2d (the biased variance is still used
        to normalize the current batch). Also uses reshape(-1) instead of
        squeeze() so a single-channel input does not collapse the stats
        to a 0-d tensor.
        """
        if training:
            # Batch statistics (biased variance for normalization).
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)

            # Exponential-moving-average update of the running buffers.
            with torch.no_grad():
                n = x.numel() / x.size(1)  # elements per channel
                unbiased_var = var * n / max(n - 1.0, 1.0)
                bn_params['running_mean'] = (
                    (1 - momentum) * bn_params['running_mean'] +
                    momentum * mean.reshape(-1)
                )
                bn_params['running_var'] = (
                    (1 - momentum) * bn_params['running_var'] +
                    momentum * unbiased_var.reshape(-1)
                )
        else:
            # Inference: use the accumulated running statistics.
            mean = bn_params['running_mean'].view(1, -1, 1, 1)
            var = bn_params['running_var'].view(1, -1, 1, 1)

        # Normalize, then scale and shift.
        x_norm = (x - mean) / torch.sqrt(var + eps)
        gamma = bn_params['gamma'].view(1, -1, 1, 1)
        beta = bn_params['beta'].view(1, -1, 1, 1)

        return gamma * x_norm + beta

    def forward(
        self,
        x: torch.Tensor,
        training: bool = True
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            x: (N, C, H, W) input images
            training: training mode (affects BN statistics and Dropout)

        Returns:
            logits: (N, num_classes)
        """
        # Feature extraction. NOTE: the original also applied ReLU right
        # after each max-pool; that is a no-op on already non-negative
        # activations, so ReLU now lives only in the conv branch.
        for i, params in enumerate(self.conv_params):
            if params == 'M':
                x = F.max_pool2d(x, kernel_size=2, stride=2)
            else:
                x = F.conv2d(x, params['weight'], params['bias'],
                             stride=1, padding=1)
                if self.use_bn and self.bn_params[i] is not None:
                    x = self._batch_norm(x, self.bn_params[i], training)
                x = F.relu(x)

        # Flatten to (N, features).
        x = x.view(x.size(0), -1)

        # FC1 -> ReLU -> Dropout (dropout is identity when training=False).
        x = torch.matmul(x, self.fc1_weight.t()) + self.fc1_bias
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=training)

        # FC2 -> ReLU -> Dropout
        x = torch.matmul(x, self.fc2_weight.t()) + self.fc2_bias
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=training)

        # FC3: raw logits (softmax is folded into the loss).
        return torch.matmul(x, self.fc3_weight.t()) + self.fc3_bias

    def parameters(self) -> List[torch.Tensor]:
        """Return all learnable parameters (conv, BN gamma/beta, FC)."""
        params = []

        # Conv parameters.
        for p in self.conv_params:
            if p != 'M':
                params.extend([p['weight'], p['bias']])

        # BN parameters (running stats are buffers, not trained).
        if self.use_bn:
            for bn in self.bn_params:
                if bn is not None:
                    params.extend([bn['gamma'], bn['beta']])

        # FC parameters.
        params.extend([
            self.fc1_weight, self.fc1_bias,
            self.fc2_weight, self.fc2_bias,
            self.fc3_weight, self.fc3_bias,
        ])

        return params

    def zero_grad(self):
        """Zero the gradients of all learnable parameters."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def to(self, device):
        """Move all parameters and buffers to `device`; returns self.

        BUG FIX: the original did `tensor.to(device)` on grad-requiring
        leaves; when the device actually changes, .to() returns a
        NON-LEAF tensor whose .grad never populates. Detach and re-mark
        the moved copy as a leaf instead.
        """
        self.device = device

        def _move(t: torch.Tensor) -> torch.Tensor:
            # Preserve leaf + requires_grad status across the move.
            if t.device == torch.device(device):
                return t
            moved = t.detach().to(device)
            if t.requires_grad:
                moved.requires_grad_(True)
            return moved

        # Conv parameters.
        for p in self.conv_params:
            if p != 'M':
                p['weight'] = _move(p['weight'])
                p['bias'] = _move(p['bias'])

        # BN parameters and buffers.
        if self.use_bn:
            for bn in self.bn_params:
                if bn is not None:
                    for key in bn:
                        bn[key] = _move(bn[key])

        # FC parameters.
        for attr in ['fc1_weight', 'fc1_bias', 'fc2_weight',
                     'fc2_bias', 'fc3_weight', 'fc3_bias']:
            setattr(self, attr, _move(getattr(self, attr)))

        return self

    def count_parameters(self) -> int:
        """Total number of learnable scalar parameters."""
        return sum(p.numel() for p in self.parameters())
278
279
class VGGSmall(VGGLowLevel):
    """
    Compact VGG for CIFAR-10.

    A 32x32 input passes through five 2x2 poolings, leaving a 1x1x512
    feature map, so the classifier narrows to 512 -> 512 -> 512 -> classes.
    """

    def _build_classifier(self, num_classes: int):
        """Build the classifier sized for small (32x32) inputs."""
        layer_dims = [(512, 512), (512, 512), (512, num_classes)]
        for idx, (fan_in, fan_out) in enumerate(layer_dims, start=1):
            w, b = self._init_linear_weight(fan_in, fan_out)
            setattr(self, f'fc{idx}_weight', w)
            setattr(self, f'fc{idx}_bias', b)
293
294
def sgd_step_with_momentum(
    params: List[torch.Tensor],
    velocities: List[torch.Tensor],
    lr: float,
    momentum: float = 0.9,
    weight_decay: float = 5e-4
):
    """Apply one in-place SGD step with momentum and L2 weight decay.

    Update rule (same as torch.optim.SGD):
        g <- g + weight_decay * p
        v <- momentum * v + g
        p <- p - lr * v
    Parameters without a gradient are skipped.
    """
    with torch.no_grad():
        for p, v in zip(params, velocities):
            g = p.grad
            if g is None:
                continue
            # Fold L2 regularization into the gradient (in place).
            g.add_(p, alpha=weight_decay)
            # Momentum accumulation, then the parameter step.
            v.mul_(momentum).add_(g)
            p.sub_(v, alpha=lr)
312
313
def train_epoch(
    model: VGGLowLevel,
    dataloader,
    lr: float = 0.01,
    momentum: float = 0.9,
    weight_decay: float = 5e-4
) -> Tuple[float, float]:
    """Train the model for one epoch.

    Args:
        model: VGGLowLevel-style model (forward / parameters / zero_grad / device)
        dataloader: iterable of (images, labels) batches
        lr, momentum, weight_decay: SGD hyperparameters

    Returns:
        (average loss, accuracy) over the epoch.

    Fixes vs the original:
    - Momentum velocities were stored on the *function* object and reused
      whenever the parameter count matched, so two different models with
      equal parameter counts silently shared stale velocity buffers.
      Velocities are now cached on the model instance itself.
    - model.parameters() rebuilds a list on every call; it is now
      materialized once per epoch instead of several times per batch.
    """
    params = model.parameters()

    # Lazily create per-model velocity buffers (reset if the shape of the
    # parameter list ever changes).
    if getattr(model, '_velocities', None) is None or len(model._velocities) != len(params):
        model._velocities = [torch.zeros_like(p) for p in params]

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for images, labels in dataloader:
        images = images.to(model.device)
        labels = labels.to(model.device)

        # Forward
        logits = model.forward(images, training=True)

        # Loss (cross entropy applies log-softmax internally)
        loss = F.cross_entropy(logits, labels)

        # Backward
        model.zero_grad()
        loss.backward()

        # Parameter update
        sgd_step_with_momentum(params, model._velocities, lr, momentum, weight_decay)

        # Metrics (loss weighted by batch size for a correct epoch average)
        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        predictions = logits.argmax(dim=1)
        total_correct += (predictions == labels).sum().item()
        total_samples += batch_size

    return total_loss / total_samples, total_correct / total_samples
358
359
@torch.no_grad()
def evaluate(model: VGGLowLevel, dataloader) -> Tuple[float, float]:
    """Evaluate the model; returns (average loss, accuracy) over the loader."""
    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(model.device)
        batch_labels = batch_labels.to(model.device)

        # Inference mode: BN uses running stats, dropout disabled.
        outputs = model.forward(batch_images, training=False)
        batch_loss = F.cross_entropy(outputs, batch_labels)

        n = batch_images.size(0)
        loss_sum += batch_loss.item() * n
        correct += (outputs.argmax(dim=1) == batch_labels).sum().item()
        seen += n

    return loss_sum / seen, correct / seen
380
381
def visualize_features(model: VGGLowLevel, image: torch.Tensor) -> List[torch.Tensor]:
    """
    Extract the feature map of each conv block (snapshot before pooling).

    Args:
        model: VGGLowLevel-style model
        image: input tensor (N, C, H, W), moved to model.device internally

    Returns:
        List of feature maps, one per block boundary: the activation just
        before each max-pool, plus the final activation.

    FIX: the original skipped batch normalization entirely, so for a
    use_bn=True model the returned activations did not match what
    model.forward() actually computes. BN is now applied (eval mode,
    running statistics) whenever the model carries BN parameters.
    """
    features = []
    x = image.to(model.device)

    for i, params in enumerate(model.conv_params):
        if params == 'M':
            features.append(x.clone())  # snapshot before pooling
            x = F.max_pool2d(x, kernel_size=2, stride=2)
        else:
            x = F.conv2d(x, params['weight'], params['bias'],
                         stride=1, padding=1)
            if model.use_bn and model.bn_params[i] is not None:
                # Eval-mode BN so the trace matches inference behavior.
                x = model._batch_norm(x, model.bn_params[i], training=False)
            x = F.relu(x)

    features.append(x.clone())  # final block output
    return features
403
404
def main():
    """Demo: train the small low-level VGG on CIFAR-10.

    Downloads CIFAR-10 to ./data on first run, trains a VGG16-BN variant
    sized for 32x32 inputs for 100 epochs with step LR decay, and prints
    periodic train/test metrics.
    """
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    print("=== VGG Low-Level Training (CIFAR-10) ===\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Training preprocessing: augmentation + normalization with CIFAR-10
    # per-channel statistics.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    # Test preprocessing: normalization only, no augmentation.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    ])

    # CIFAR-10 datasets (download=True fetches on first run).
    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}\n")

    # Model: compact VGG16 with BatchNorm, classifier sized for CIFAR.
    model = VGGSmall(config_name='VGG16', num_classes=10, use_bn=True)
    model.to(device)

    print(f"VGG16-BN for CIFAR-10")
    print(f"Total parameters: {model.count_parameters():,}\n")

    # Training loop.
    epochs = 100
    lr = 0.1

    for epoch in range(epochs):
        # Step learning-rate schedule: decay 10x at epochs 30, 60, 80.
        if epoch in [30, 60, 80]:
            lr *= 0.1
            print(f"LR → {lr}")

        train_loss, train_acc = train_epoch(model, train_loader, lr)

        # Evaluate every 10 epochs.
        if (epoch + 1) % 10 == 0:
            test_loss, test_acc = evaluate(model, test_loader)
            print(f"Epoch {epoch+1}/{epochs}")
            print(f" Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f" Test - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}\n")

    # Final evaluation on the full test set.
    final_loss, final_acc = evaluate(model, test_loader)
    print(f"Final Test Accuracy: {final_acc:.4f}")


if __name__ == "__main__":
    main()