# vgg_lowlevel.py
"""
PyTorch low-level VGG implementation.

Uses F.conv2d and torch.matmul instead of nn.Conv2d and nn.Linear,
managing parameters by hand and assembling the network block by block.
"""
  7
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
 12
 13
# VGG configurations: integer = conv output channels, 'M' = 2x2 max-pool
VGG_CONFIGS = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
 21
 22
 23class VGGLowLevel:
 24    """
 25    VGG Low-Level 구현
 26
 27    nn.Module 미사용, F.conv2d 등 기본 연산만 사용
 28    """
 29
 30    def __init__(
 31        self,
 32        config_name: str = 'VGG16',
 33        num_classes: int = 1000,
 34        input_channels: int = 3,
 35        use_bn: bool = False
 36    ):
 37        """
 38        Args:
 39            config_name: VGG 변형 ('VGG11', 'VGG13', 'VGG16', 'VGG19')
 40            num_classes: 출력 클래스 수
 41            input_channels: 입력 채널 수 (RGB=3)
 42            use_bn: Batch Normalization 사용 여부
 43        """
 44        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 45        self.config = VGG_CONFIGS[config_name]
 46        self.use_bn = use_bn
 47
 48        # Feature extractor 파라미터
 49        self.conv_params = []
 50        self.bn_params = [] if use_bn else None
 51        self._build_features(input_channels)
 52
 53        # Classifier 파라미터
 54        self._build_classifier(num_classes)
 55
 56    def _init_conv_weight(
 57        self,
 58        in_channels: int,
 59        out_channels: int,
 60        kernel_size: int = 3
 61    ) -> Tuple[torch.Tensor, torch.Tensor]:
 62        """Kaiming 초기화로 Conv 가중치 생성"""
 63        fan_in = in_channels * kernel_size * kernel_size
 64        std = math.sqrt(2.0 / fan_in)
 65
 66        weight = torch.randn(
 67            out_channels, in_channels, kernel_size, kernel_size,
 68            requires_grad=True, device=self.device
 69        ) * std
 70        bias = torch.zeros(out_channels, requires_grad=True, device=self.device)
 71
 72        return weight, bias
 73
 74    def _init_bn_params(self, num_features: int) -> Dict[str, torch.Tensor]:
 75        """BatchNorm 파라미터 초기화"""
 76        return {
 77            'gamma': torch.ones(num_features, requires_grad=True, device=self.device),
 78            'beta': torch.zeros(num_features, requires_grad=True, device=self.device),
 79            'running_mean': torch.zeros(num_features, device=self.device),
 80            'running_var': torch.ones(num_features, device=self.device),
 81        }
 82
 83    def _init_linear_weight(
 84        self,
 85        in_features: int,
 86        out_features: int
 87    ) -> Tuple[torch.Tensor, torch.Tensor]:
 88        """Xavier 초기화로 Linear 가중치 생성"""
 89        std = math.sqrt(2.0 / (in_features + out_features))
 90
 91        weight = torch.randn(
 92            out_features, in_features,
 93            requires_grad=True, device=self.device
 94        ) * std
 95        bias = torch.zeros(out_features, requires_grad=True, device=self.device)
 96
 97        return weight, bias
 98
 99    def _build_features(self, input_channels: int):
100        """Feature extractor (Conv layers) 구축"""
101        in_channels = input_channels
102
103        for v in self.config:
104            if v == 'M':
105                # MaxPool은 파라미터 없음
106                self.conv_params.append('M')
107                if self.use_bn:
108                    self.bn_params.append(None)
109            else:
110                out_channels = v
111                weight, bias = self._init_conv_weight(in_channels, out_channels, 3)
112                self.conv_params.append({'weight': weight, 'bias': bias})
113
114                if self.use_bn:
115                    bn = self._init_bn_params(out_channels)
116                    self.bn_params.append(bn)
117
118                in_channels = out_channels
119
120    def _build_classifier(self, num_classes: int):
121        """Classifier (FC layers) 구축"""
122        # 7×7×512 = 25088 (224×224 입력 기준)
123        # CIFAR-10 (32×32) 사용시 1×1×512 = 512
124
125        # FC1: 25088 → 4096
126        self.fc1_weight, self.fc1_bias = self._init_linear_weight(512 * 7 * 7, 4096)
127
128        # FC2: 4096 → 4096
129        self.fc2_weight, self.fc2_bias = self._init_linear_weight(4096, 4096)
130
131        # FC3: 4096 → num_classes
132        self.fc3_weight, self.fc3_bias = self._init_linear_weight(4096, num_classes)
133
134    def _batch_norm(
135        self,
136        x: torch.Tensor,
137        bn_params: Dict[str, torch.Tensor],
138        training: bool = True,
139        momentum: float = 0.1,
140        eps: float = 1e-5
141    ) -> torch.Tensor:
142        """수동 Batch Normalization"""
143        if training:
144            # 현재 배치의 mean, var 계산
145            mean = x.mean(dim=(0, 2, 3), keepdim=True)
146            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
147
148            # Running statistics 업데이트
149            with torch.no_grad():
150                bn_params['running_mean'] = (
151                    (1 - momentum) * bn_params['running_mean'] +
152                    momentum * mean.squeeze()
153                )
154                bn_params['running_var'] = (
155                    (1 - momentum) * bn_params['running_var'] +
156                    momentum * var.squeeze()
157                )
158        else:
159            mean = bn_params['running_mean'].view(1, -1, 1, 1)
160            var = bn_params['running_var'].view(1, -1, 1, 1)
161
162        # Normalize
163        x_norm = (x - mean) / torch.sqrt(var + eps)
164
165        # Scale and shift
166        gamma = bn_params['gamma'].view(1, -1, 1, 1)
167        beta = bn_params['beta'].view(1, -1, 1, 1)
168
169        return gamma * x_norm + beta
170
171    def forward(
172        self,
173        x: torch.Tensor,
174        training: bool = True
175    ) -> torch.Tensor:
176        """
177        Forward pass
178
179        Args:
180            x: (N, C, H, W) 입력 이미지
181            training: 학습 모드 (BN, Dropout에 영향)
182
183        Returns:
184            logits: (N, num_classes)
185        """
186        # Feature extraction
187        for i, params in enumerate(self.conv_params):
188            if params == 'M':
189                x = F.max_pool2d(x, kernel_size=2, stride=2)
190            else:
191                x = F.conv2d(x, params['weight'], params['bias'],
192                            stride=1, padding=1)
193
194                if self.use_bn and self.bn_params[i] is not None:
195                    x = self._batch_norm(x, self.bn_params[i], training)
196
197                x = F.relu(x)
198
199        # Flatten
200        x = x.view(x.size(0), -1)
201
202        # Classifier
203        # FC1
204        x = torch.matmul(x, self.fc1_weight.t()) + self.fc1_bias
205        x = F.relu(x)
206        if training:
207            x = F.dropout(x, p=0.5, training=True)
208
209        # FC2
210        x = torch.matmul(x, self.fc2_weight.t()) + self.fc2_bias
211        x = F.relu(x)
212        if training:
213            x = F.dropout(x, p=0.5, training=True)
214
215        # FC3
216        x = torch.matmul(x, self.fc3_weight.t()) + self.fc3_bias
217
218        return x
219
220    def parameters(self) -> List[torch.Tensor]:
221        """학습 가능한 파라미터 반환"""
222        params = []
223
224        # Conv 파라미터
225        for p in self.conv_params:
226            if p != 'M':
227                params.extend([p['weight'], p['bias']])
228
229        # BN 파라미터
230        if self.use_bn:
231            for bn in self.bn_params:
232                if bn is not None:
233                    params.extend([bn['gamma'], bn['beta']])
234
235        # FC 파라미터
236        params.extend([
237            self.fc1_weight, self.fc1_bias,
238            self.fc2_weight, self.fc2_bias,
239            self.fc3_weight, self.fc3_bias,
240        ])
241
242        return params
243
244    def zero_grad(self):
245        """Gradient 초기화"""
246        for param in self.parameters():
247            if param.grad is not None:
248                param.grad.zero_()
249
250    def to(self, device):
251        """Device 이동"""
252        self.device = device
253
254        # Conv 파라미터
255        for p in self.conv_params:
256            if p != 'M':
257                p['weight'] = p['weight'].to(device)
258                p['bias'] = p['bias'].to(device)
259
260        # BN 파라미터
261        if self.use_bn:
262            for bn in self.bn_params:
263                if bn is not None:
264                    for key in bn:
265                        bn[key] = bn[key].to(device)
266
267        # FC 파라미터
268        for attr in ['fc1_weight', 'fc1_bias', 'fc2_weight',
269                     'fc2_bias', 'fc3_weight', 'fc3_bias']:
270            tensor = getattr(self, attr)
271            setattr(self, attr, tensor.to(device))
272
273        return self
274
275    def count_parameters(self) -> int:
276        """파라미터 수 계산"""
277        return sum(p.numel() for p in self.parameters())
278
279
class VGGSmall(VGGLowLevel):
    """
    Compact VGG for CIFAR-10.

    Input: 32x32 -> feature map after the conv stack: 1x1x512.
    """

    def _build_classifier(self, num_classes: int):
        """Classifier head sized for small (32x32) inputs.

        Five 2x2 poolings shrink 32x32 down to 1x1x512, so the flattened
        feature vector has 512 entries instead of 25088.
        """
        head_dims = [(512, 512), (512, 512), (512, num_classes)]
        for idx, (fan_in, fan_out) in enumerate(head_dims, start=1):
            w, b = self._init_linear_weight(fan_in, fan_out)
            setattr(self, f'fc{idx}_weight', w)
            setattr(self, f'fc{idx}_bias', b)
293
294
def sgd_step_with_momentum(
    params: List[torch.Tensor],
    velocities: List[torch.Tensor],
    lr: float,
    momentum: float = 0.9,
    weight_decay: float = 5e-4
):
    """One SGD step with classical momentum and L2 weight decay.

    Mutates `params` and `velocities` in place; entries whose .grad is
    None (never used in the loss) are skipped.
    """
    with torch.no_grad():
        for p, vel in zip(params, velocities):
            if p.grad is None:
                continue
            # Fold the L2 penalty into the gradient: g <- g + wd * p.
            p.grad.add_(p, alpha=weight_decay)
            # Momentum buffer: v <- momentum * v + g, then p <- p - lr * v.
            vel.mul_(momentum).add_(p.grad)
            p.sub_(vel, alpha=lr)
312
313
def train_epoch(
    model: VGGLowLevel,
    dataloader,
    lr: float = 0.01,
    momentum: float = 0.9,
    weight_decay: float = 5e-4
) -> Tuple[float, float]:
    """Train for one epoch; returns (mean loss, accuracy).

    Momentum buffers are cached on the function object itself and are
    rebuilt whenever the parameter count changes (e.g. a new model).
    """
    params = model.parameters()
    cached = getattr(train_epoch, 'velocities', None)
    if cached is None or len(cached) != len(params):
        train_epoch.velocities = [torch.zeros_like(p) for p in params]

    loss_sum = 0.0
    correct = 0
    seen = 0

    for images, labels in dataloader:
        images = images.to(model.device)
        labels = labels.to(model.device)

        # Forward + loss.
        logits = model.forward(images, training=True)
        loss = F.cross_entropy(logits, labels)

        # Backward + parameter update.
        model.zero_grad()
        loss.backward()
        sgd_step_with_momentum(
            model.parameters(),
            train_epoch.velocities,
            lr, momentum, weight_decay
        )

        # Running metrics (loss weighted by batch size).
        batch = images.size(0)
        loss_sum += loss.item() * batch
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += batch

    return loss_sum / seen, correct / seen
358
359
@torch.no_grad()
def evaluate(model: VGGLowLevel, dataloader) -> Tuple[float, float]:
    """Compute (mean loss, accuracy) of `model` over `dataloader`."""
    loss_sum = 0.0
    correct = 0
    seen = 0

    for images, labels in dataloader:
        images = images.to(model.device)
        labels = labels.to(model.device)

        # Eval mode: running BN statistics, no dropout.
        logits = model.forward(images, training=False)

        batch = images.size(0)
        loss_sum += F.cross_entropy(logits, labels).item() * batch
        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += batch

    return loss_sum / seen, correct / seen
380
381
def visualize_features(model: VGGLowLevel, image: torch.Tensor) -> List[torch.Tensor]:
    """
    Extract the feature map produced by each conv block.

    Replays the feature extractor the same way `forward` does — BUGFIX:
    the previous version skipped batch normalization entirely, so for
    `use_bn=True` models the captured maps did not match the activations
    the network actually produces. BN is applied in eval mode (running
    statistics) so visualization never mutates them.

    Returns:
        List of feature maps captured before each pooling stage, plus the
        output of the final block.
    """
    features = []
    x = image.to(model.device)

    for i, params in enumerate(model.conv_params):
        if params == 'M':
            features.append(x.clone())  # snapshot before resolution drops
            x = F.max_pool2d(x, kernel_size=2, stride=2)
        else:
            x = F.conv2d(x, params['weight'], params['bias'],
                         stride=1, padding=1)

            # Keep parity with forward(): BN (eval statistics) before ReLU.
            if model.use_bn and model.bn_params[i] is not None:
                x = model._batch_norm(x, model.bn_params[i], training=False)

            x = F.relu(x)

    features.append(x)  # final block output
    return features
403
404
def main():
    """Demo: train a small VGG on CIFAR-10."""
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader

    print("=== VGG Low-Level Training (CIFAR-10) ===\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Per-channel CIFAR-10 statistics used for input normalization.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2470, 0.2435, 0.2616)

    # Augmentation only on the training split.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std)
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std)
    ])

    # CIFAR-10 datasets (downloaded to ./data on first run).
    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train
    )
    test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test
    )

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=2)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}\n")

    # CIFAR-sized classifier head on a VGG16 feature stack with batch norm.
    model = VGGSmall(config_name='VGG16', num_classes=10, use_bn=True)
    model.to(device)

    print("VGG16-BN for CIFAR-10")
    print(f"Total parameters: {model.count_parameters():,}\n")

    epochs = 100
    lr = 0.1

    for epoch in range(epochs):
        # Step-decay schedule: divide the learning rate by 10 at milestones.
        if epoch in (30, 60, 80):
            lr *= 0.1
            print(f"LR → {lr}")

        train_loss, train_acc = train_epoch(model, train_loader, lr)

        # Evaluate every 10 epochs.
        if (epoch + 1) % 10 == 0:
            test_loss, test_acc = evaluate(model, test_loader)
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"  Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f"  Test  - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}\n")

    # Final held-out evaluation.
    final_loss, final_acc = evaluate(model, test_loader)
    print(f"Final Test Accuracy: {final_acc:.4f}")


if __name__ == "__main__":
    main()