# 06_cnn_architectures.py
  1"""
  206. CNN ์‹ฌํ™” - ์œ ๋ช… ์•„ํ‚คํ…์ฒ˜
  3
  4VGG, ResNet, EfficientNet ๋“ฑ ์œ ๋ช… ์•„ํ‚คํ…์ฒ˜๋ฅผ PyTorch๋กœ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.
  5"""
  6
  7import torch
  8import torch.nn as nn
  9import torch.nn.functional as F
 10
 11print("=" * 60)
 12print("PyTorch CNN ์‹ฌํ™” - ์œ ๋ช… ์•„ํ‚คํ…์ฒ˜")
 13print("=" * 60)
 14
 15
 16# ============================================
 17# 1. VGG ๋ธ”๋ก ๋ฐ ๋ชจ๋ธ
 18# ============================================
 19print("\n[1] VGG16 ๊ตฌํ˜„")
 20print("-" * 40)
 21
 22def make_vgg_block(in_channels, out_channels, num_convs):
 23    """VGG ๋ธ”๋ก ์ƒ์„ฑ"""
 24    layers = []
 25    for i in range(num_convs):
 26        layers.append(nn.Conv2d(
 27            in_channels if i == 0 else out_channels,
 28            out_channels, kernel_size=3, padding=1
 29        ))
 30        layers.append(nn.ReLU(inplace=True))
 31    layers.append(nn.MaxPool2d(2, 2))
 32    return nn.Sequential(*layers)
 33
 34
 35class VGG16(nn.Module):
 36    """VGG16 ๊ตฌํ˜„"""
 37    def __init__(self, num_classes=1000):
 38        super().__init__()
 39        # ํŠน์ง• ์ถ”์ถœ๋ถ€
 40        self.features = nn.Sequential(
 41            make_vgg_block(3, 64, 2),    # 224โ†’112
 42            make_vgg_block(64, 128, 2),  # 112โ†’56
 43            make_vgg_block(128, 256, 3), # 56โ†’28
 44            make_vgg_block(256, 512, 3), # 28โ†’14
 45            make_vgg_block(512, 512, 3), # 14โ†’7
 46        )
 47
 48        # ๋ถ„๋ฅ˜๊ธฐ
 49        self.classifier = nn.Sequential(
 50            nn.Linear(512 * 7 * 7, 4096),
 51            nn.ReLU(inplace=True),
 52            nn.Dropout(0.5),
 53            nn.Linear(4096, 4096),
 54            nn.ReLU(inplace=True),
 55            nn.Dropout(0.5),
 56            nn.Linear(4096, num_classes),
 57        )
 58
 59    def forward(self, x):
 60        x = self.features(x)
 61        x = x.view(x.size(0), -1)
 62        x = self.classifier(x)
 63        return x
 64
 65vgg = VGG16(num_classes=10)
 66print(f"VGG16 ํŒŒ๋ผ๋ฏธํ„ฐ: {sum(p.numel() for p in vgg.parameters()):,}")
 67
 68# ํ…Œ์ŠคํŠธ
 69x = torch.randn(1, 3, 224, 224)
 70out = vgg(x)
 71print(f"์ž…๋ ฅ: {x.shape} โ†’ ์ถœ๋ ฅ: {out.shape}")
 72
 73
 74# ============================================
 75# 2. ResNet Basic Block
 76# ============================================
 77print("\n[2] ResNet ๊ตฌํ˜„")
 78print("-" * 40)
 79
 80class BasicBlock(nn.Module):
 81    """ResNet Basic Block (ResNet-18, 34์šฉ)"""
 82    expansion = 1
 83
 84    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
 85        super().__init__()
 86        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
 87                               stride=stride, padding=1, bias=False)
 88        self.bn1 = nn.BatchNorm2d(out_channels)
 89        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
 90                               padding=1, bias=False)
 91        self.bn2 = nn.BatchNorm2d(out_channels)
 92        self.downsample = downsample
 93
 94    def forward(self, x):
 95        identity = x
 96
 97        out = F.relu(self.bn1(self.conv1(x)))
 98        out = self.bn2(self.conv2(out))
 99
100        if self.downsample is not None:
101            identity = self.downsample(x)
102
103        out += identity  # Skip connection!
104        out = F.relu(out)
105        return out
106
107
class Bottleneck(nn.Module):
    """ResNet bottleneck block (ResNet-50/101/152): 1x1 -> 3x3 -> 1x1 convs.

    The final 1x1 conv expands to `out_channels * expansion` channels; the
    3x3 conv carries the stride.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Use self.expansion (not a hard-coded 4) so the output width stays
        # consistent with the class attribute ResNet._make_layer relies on.
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return F.relu(out)
136
137
class ResNet(nn.Module):
    """Generic ResNet: stem conv -> four residual stages -> avgpool -> fc.

    Args:
        block: residual block class exposing `expansion` and the constructor
            signature `block(in_channels, out_channels, stride, downsample)`.
        layers: number of blocks in each of the four stages.
        num_classes: size of the final fully connected layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are built.
        self.in_channels = 64

        # Stem: 7x7/2 conv + 3x3/2 max-pool (224 -> 56 for ImageNet input).
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)

        # Four residual stages; stages 2-4 halve the spatial size.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Classifier head; adaptive pooling makes it input-size agnostic.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride=1):
        """Build one stage; only its first block downsamples/expands the identity."""
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            # 1x1 conv projection so the skip path matches the main path.
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion)
            )

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion

        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
189
190
def resnet18(num_classes=1000):
    """ResNet-18: BasicBlock with [2, 2, 2, 2] blocks per stage."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

def resnet34(num_classes=1000):
    """ResNet-34: BasicBlock with [3, 4, 6, 3] blocks per stage."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)

def resnet50(num_classes=1000):
    """ResNet-50: Bottleneck with [3, 4, 6, 3] blocks per stage."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)

# Smoke test
resnet = resnet18(num_classes=10)
print(f"ResNet-18 ํŒŒ๋ผ๋ฏธํ„ฐ: {sum(p.numel() for p in resnet.parameters()):,}")

x = torch.randn(1, 3, 224, 224)
out = resnet(x)
print(f"์ž…๋ ฅ: {x.shape} → ์ถœ๋ ฅ: {out.shape}")
207
208
# ============================================
# 3. SE Block (Squeeze-and-Excitation)
# ============================================
print("\n[3] SE Block")
print("-" * 40)

class SEBlock(nn.Module):
    """Squeeze-and-Excitation: channel-wise attention from global pooling.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck ratio of the two-layer excitation MLP.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pool to one scalar per channel.
        y = self.squeeze(x).view(b, c)
        # Excitation: per-channel gate in (0, 1).
        y = self.excitation(y).view(b, c, 1, 1)
        # Scale: reweight the input channels.
        return x * y.expand_as(x)

# Smoke test
se = SEBlock(64)
x = torch.randn(2, 64, 32, 32)
out = se(x)
print(f"SE Block: {x.shape} → {out.shape}")
241
242
# ============================================
# 4. MBConv block (EfficientNet)
# ============================================
print("\n[4] MBConv Block (EfficientNet)")
print("-" * 40)

class MBConv(nn.Module):
    """Mobile inverted bottleneck conv: expand -> depthwise -> SE -> project.

    Args:
        in_channels / out_channels: block input/output channels.
        expand_ratio: width multiplier of the hidden depthwise stage.
        stride: stride of the depthwise conv.
        se_ratio: SE bottleneck size as a fraction of `in_channels`.
    """

    def __init__(self, in_channels, out_channels, expand_ratio=6,
                 stride=1, se_ratio=0.25):
        super().__init__()
        hidden_dim = in_channels * expand_ratio
        # Residual only when the block keeps shape (stride 1, same channels).
        self.use_skip = stride == 1 and in_channels == out_channels

        layers = []

        # Expansion (1x1 conv); skipped when expand_ratio == 1.
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.SiLU(inplace=True)
            ])

        # Depthwise 3x3 conv (groups == channels).
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride,
                      padding=1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(inplace=True)
        ])

        self.conv = nn.Sequential(*layers)

        # Squeeze-and-excitation; bottleneck width is derived from the
        # *input* channels, matching the EfficientNet convention.
        se_channels = max(1, int(in_channels * se_ratio))
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(hidden_dim, se_channels, 1),
            nn.SiLU(inplace=True),
            nn.Conv2d(se_channels, hidden_dim, 1),
            nn.Sigmoid()
        )

        # Linear projection back to out_channels (no activation).
        self.project = nn.Sequential(
            nn.Conv2d(hidden_dim, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        identity = x
        out = self.conv(x)
        out = out * self.se(out)
        out = self.project(out)

        if self.use_skip:
            out = out + identity
        return out

# Smoke test
mbconv = MBConv(32, 32, expand_ratio=6)
x = torch.randn(2, 32, 28, 28)
out = mbconv(x)
print(f"MBConv: {x.shape} → {out.shape}")
308
309
# ============================================
# 5. Pretrained models from torchvision
# ============================================
print("\n[5] torchvision ์‚ฌ์ „ ํ•™์Šต ๋ชจ๋ธ")
print("-" * 40)

try:
    import torchvision.models as models

    # Instantiate several architectures (structure only, no weights).
    model_names = ['resnet18', 'resnet50', 'vgg16', 'mobilenet_v2']

    for name in model_names:
        model = getattr(models, name)(weights=None)
        params = sum(p.numel() for p in model.parameters())
        print(f"{name}: {params:,} ํŒŒ๋ผ๋ฏธํ„ฐ")

    # Load a pretrained ResNet-50 (downloads the weights on first use).
    print("\n์‚ฌ์ „ ํ•™์Šต๋œ ResNet50 ๋กœ๋“œ:")
    resnet50_pretrained = models.resnet50(weights='IMAGENET1K_V2')
    print(f"  ๋งˆ์ง€๋ง‰ ์ธต: {resnet50_pretrained.fc}")

    # Transfer learning: swap the head for a 10-class task.
    resnet50_pretrained.fc = nn.Linear(2048, 10)
    print(f"  ์ˆ˜์ •๋œ ๋งˆ์ง€๋ง‰ ์ธต: {resnet50_pretrained.fc}")

except ImportError:
    print("torchvision์ด ์„ค์น˜๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
338
339
# ============================================
# 6. Model comparison
# ============================================
print("\n[6] ๋ชจ๋ธ ๋น„๊ต")
print("-" * 40)

def count_parameters(model):
    """Return the total number of parameters in `model`."""
    return sum(p.numel() for p in model.parameters())
348
def measure_forward_time(model, input_shape, iterations=100):
    """Average forward-pass latency of `model` in milliseconds.

    Runs 10 warm-up passes, then times `iterations` forward passes on a
    random tensor of `input_shape` under `torch.no_grad()`.
    """
    import time
    model.eval()
    x = torch.randn(*input_shape)
    with torch.no_grad():
        # Warm-up (lazy initialization, caches, thread pools).
        for _ in range(10):
            _ = model(x)
        # perf_counter is monotonic and higher resolution than time.time,
        # so the measurement cannot go negative on clock adjustments.
        start = time.perf_counter()
        for _ in range(iterations):
            _ = model(x)
        end = time.perf_counter()
    return (end - start) / iterations * 1000  # ms per forward pass
363
# Compare the hand-written models defined above.
models_to_compare = {
    'VGG16 (simple)': VGG16(num_classes=10),
    'ResNet-18': resnet18(num_classes=10),
    'ResNet-50': resnet50(num_classes=10),
}

print(f"{'Model':<20} {'Params':>12} {'Time (ms)':>12}")
print("-" * 46)

for name, model in models_to_compare.items():
    params = count_parameters(model)
    try:
        time_ms = measure_forward_time(model, (1, 3, 224, 224), iterations=10)
        print(f"{name:<20} {params:>12,} {time_ms:>12.2f}")
    except Exception:
        # Timing is best-effort; still report the parameter count. A bare
        # except would also swallow KeyboardInterrupt/SystemExit.
        print(f"{name:<20} {params:>12,} {'N/A':>12}")
381
382
# ============================================
# 7. Simple ResNet experiment
# ============================================
print("\n[7] Skip Connection ํšจ๊ณผ ์‹คํ—˜")
print("-" * 40)

class ResBlockWithoutSkip(nn.Module):
    """Conv-BN-ReLU x2 block WITHOUT a residual connection (for contrast)."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out)  # no skip!
402
class ResBlockWithSkip(nn.Module):
    """Conv-BN-ReLU x2 block WITH a residual connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + identity)  # with skip!
417
418# ๊นŠ์€ ๋„คํŠธ์›Œํฌ ๋น„๊ต
419def make_deep_net(block_class, num_blocks, channels=64):
420    layers = [nn.Conv2d(3, channels, 3, padding=1), nn.ReLU()]
421    for _ in range(num_blocks):
422        layers.append(block_class(channels))
423    layers.append(nn.AdaptiveAvgPool2d(1))
424    layers.append(nn.Flatten())
425    layers.append(nn.Linear(channels, 10))
426    return nn.Sequential(*layers)
427
428# ๊ธฐ์šธ๊ธฐ ํ™•์ธ
429def check_gradient_flow(model, depth):
430    model.train()
431    x = torch.randn(1, 3, 32, 32, requires_grad=True)
432    out = model(x)
433    loss = out.sum()
434    loss.backward()
435
436    # ์ฒซ ๋ฒˆ์งธ Conv ๊ธฐ์šธ๊ธฐ ํ™•์ธ
437    first_conv_grad = None
438    for module in model.modules():
439        if isinstance(module, nn.Conv2d):
440            if module.weight.grad is not None:
441                first_conv_grad = module.weight.grad.abs().mean().item()
442                break
443
444    return first_conv_grad
445
print("๊ธฐ์šธ๊ธฐ ํ๋ฆ„ ๋น„๊ต (๊นŠ์€ ๋„คํŠธ์›Œํฌ):")
# The deeper the plain stack, the smaller the gradient reaching the first
# conv; the residual stack should keep it healthy.
for depth in [5, 10, 20]:
    net_no_skip = make_deep_net(ResBlockWithoutSkip, depth)
    net_with_skip = make_deep_net(ResBlockWithSkip, depth)

    grad_no_skip = check_gradient_flow(net_no_skip, depth)
    grad_with_skip = check_gradient_flow(net_with_skip, depth)

    print(f"  ๊นŠ์ด {depth:2d}: Skip ์—†์Œ = {grad_no_skip:.6f}, Skip ์žˆ์Œ = {grad_with_skip:.6f}")
455
456
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("CNN ์•„ํ‚คํ…์ฒ˜ ์ •๋ฆฌ")
print("=" * 60)

summary = """
์ฃผ์š” ์•„ํ‚คํ…์ฒ˜:

1. VGG (2014)
   - 3×3 Conv๋งŒ ์‚ฌ์šฉ
   - ๊นŠ์ด = ์„ฑ๋Šฅ (๋‹จ์ˆœํ•˜์ง€๋งŒ ํŒŒ๋ผ๋ฏธํ„ฐ ๋งŽ์Œ)

2. ResNet (2015)
   - Skip Connection์œผ๋กœ ๊ธฐ์šธ๊ธฐ ์†Œ์‹ค ํ•ด๊ฒฐ
   - 100+ ์ธต๋„ ํ•™์Šต ๊ฐ€๋Šฅ
   - ๊ฐ€์žฅ ๋„๋ฆฌ ์‚ฌ์šฉ๋จ

3. EfficientNet (2019)
   - Compound Scaling
   - MBConv (Depthwise Separable + SE)
   - ํšจ์œจ์ ์ธ ํŒŒ๋ผ๋ฏธํ„ฐ ์‚ฌ์šฉ

ํ•ต์‹ฌ ๊ธฐ๋ฒ•:
- Batch Normalization
- Skip Connection (Residual)
- Depthwise Separable Conv
- Squeeze-and-Excitation

์‹ค์ „ ์„ ํƒ:
- ๋น ๋ฅธ ์ถ”๋ก : MobileNet, EfficientNet-B0
- ๋†’์€ ์ •ํ™•๋„: EfficientNet-B4~B7
- ๊ท ํ˜•: ResNet-50
"""
print(summary)
print("=" * 60)