vit_lowlevel.py

Download
python 418 lines 11.3 KB
  1"""
PyTorch low-level Vision Transformer (ViT) implementation.

Does not use nn.TransformerEncoder; every component, starting from the
patch embedding, is implemented by hand.
  6"""
  7
  8import torch
  9import torch.nn as nn
 10import torch.nn.functional as F
 11import math
 12from typing import Optional, Tuple
 13from dataclasses import dataclass
 14
 15
 16@dataclass
 17class ViTConfig:
 18    """ViT ์„ค์ •"""
 19    image_size: int = 224
 20    patch_size: int = 16
 21    in_channels: int = 3
 22    num_classes: int = 1000
 23    hidden_size: int = 768
 24    num_layers: int = 12
 25    num_heads: int = 12
 26    mlp_ratio: float = 4.0
 27    dropout: float = 0.0
 28    attention_dropout: float = 0.0
 29
 30
 31class PatchEmbedding(nn.Module):
 32    """
 33    ์ด๋ฏธ์ง€ โ†’ ํŒจ์น˜ โ†’ ์ž„๋ฒ ๋”ฉ
 34
 35    (B, C, H, W) โ†’ (B, N, D)
 36    """
 37
 38    def __init__(
 39        self,
 40        image_size: int = 224,
 41        patch_size: int = 16,
 42        in_channels: int = 3,
 43        hidden_size: int = 768
 44    ):
 45        super().__init__()
 46        self.image_size = image_size
 47        self.patch_size = patch_size
 48        self.num_patches = (image_size // patch_size) ** 2
 49
 50        # Linear projection (Conv2d๋กœ ํšจ์œจ์  ๊ตฌํ˜„)
 51        # kernel_size = stride = patch_size โ†’ ๊ฒน์น˜์ง€ ์•Š๋Š” ํŒจ์น˜
 52        self.projection = nn.Conv2d(
 53            in_channels, hidden_size,
 54            kernel_size=patch_size, stride=patch_size
 55        )
 56
 57    def forward(self, x: torch.Tensor) -> torch.Tensor:
 58        """
 59        Args:
 60            x: (B, C, H, W)
 61
 62        Returns:
 63            patches: (B, N, D) where N = num_patches
 64        """
 65        # (B, C, H, W) โ†’ (B, D, H/P, W/P)
 66        x = self.projection(x)
 67
 68        # (B, D, H', W') โ†’ (B, D, N) โ†’ (B, N, D)
 69        x = x.flatten(2).transpose(1, 2)
 70
 71        return x
 72
 73
 74class MultiHeadAttention(nn.Module):
 75    """Multi-Head Self-Attention"""
 76
 77    def __init__(
 78        self,
 79        hidden_size: int,
 80        num_heads: int,
 81        dropout: float = 0.0
 82    ):
 83        super().__init__()
 84        assert hidden_size % num_heads == 0
 85
 86        self.num_heads = num_heads
 87        self.head_dim = hidden_size // num_heads
 88        self.scale = self.head_dim ** -0.5
 89
 90        # QKV๋ฅผ ํ•˜๋‚˜์˜ projection์œผ๋กœ
 91        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
 92        self.attn_dropout = nn.Dropout(dropout)
 93        self.proj = nn.Linear(hidden_size, hidden_size)
 94        self.proj_dropout = nn.Dropout(dropout)
 95
 96    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 97        """
 98        Args:
 99            x: (B, N, D)
100
101        Returns:
102            output: (B, N, D)
103            attention: (B, H, N, N)
104        """
105        B, N, D = x.shape
106
107        # QKV ๊ณ„์‚ฐ: (B, N, 3D)
108        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
109        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, head_dim)
110        q, k, v = qkv[0], qkv[1], qkv[2]
111
112        # Attention scores: (B, H, N, N)
113        attn = (q @ k.transpose(-2, -1)) * self.scale
114        attn = F.softmax(attn, dim=-1)
115        attn = self.attn_dropout(attn)
116
117        # Apply attention: (B, H, N, head_dim)
118        x = attn @ v
119        x = x.transpose(1, 2).reshape(B, N, D)  # (B, N, D)
120
121        # Output projection
122        x = self.proj(x)
123        x = self.proj_dropout(x)
124
125        return x, attn
126
127
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The same ``nn.Dropout`` module is applied after the GELU activation and
    again after the second linear layer.

    Args:
        hidden_size: input and output width D.
        mlp_ratio: expansion factor; the inner width is
            ``int(hidden_size * mlp_ratio)``.
        dropout: dropout probability.
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0
    ):
        super().__init__()
        inner_width = int(hidden_size * mlp_ratio)
        self.fc1 = nn.Linear(hidden_size, inner_width)
        self.fc2 = nn.Linear(inner_width, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(expanded))
151
152
class TransformerBlock(nn.Module):
    """One pre-LayerNorm ViT encoder block.

    Computes ``x = x + Attn(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.

    Args:
        hidden_size: token width D.
        num_heads: number of attention heads.
        mlp_ratio: expansion factor of the MLP.
        dropout: dropout probability inside the MLP.
        attention_dropout: dropout probability passed to the attention module.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = MultiHeadAttention(hidden_size, num_heads, attention_dropout)
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.mlp = MLP(hidden_size, mlp_ratio, dropout)

    def forward(
        self,
        x: torch.Tensor,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the block.

        Returns:
            A pair ``(x, weights)``; ``weights`` is the attention tensor from
            the attention module when ``return_attention`` is true, else None.
        """
        attended, weights = self.attn(self.norm1(x))
        x = x + attended                    # attention residual
        x = x + self.mlp(self.norm2(x))     # feed-forward residual
        return x, (weights if return_attention else None)
185
186
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add a
    learnable position embedding -> ``config.num_layers`` pre-LN Transformer
    blocks -> final LayerNorm -> linear head on the [CLS] token.

    Args:
        config: model hyper-parameters.
    """

    def __init__(self, config: ViTConfig):
        super().__init__()
        self.config = config

        # (B, C, H, W) -> (B, N, D)
        self.patch_embed = PatchEmbedding(
            config.image_size, config.patch_size,
            config.in_channels, config.hidden_size
        )
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token prepended to every sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        # Learnable position embedding: one slot for [CLS] plus one per patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))

        # Dropout applied right after adding the position embedding.
        self.pos_dropout = nn.Dropout(config.dropout)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                config.hidden_size, config.num_heads,
                config.mlp_ratio, config.dropout, config.attention_dropout
            )
            for _ in range(config.num_layers)
        ])

        # Final norm
        self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)

        # Classification head
        self.head = nn.Linear(config.hidden_size, config.num_classes)

        self._init_weights()

    def _init_weights(self):
        """Initialise embeddings, linear layers and LayerNorms.

        The [CLS] token, position embedding and all Linear weights get a
        truncated normal with std 0.02. Linear biases are zeroed. LayerNorms
        are reset to weight 1 and bias 0. The patch-embedding Conv2d keeps
        PyTorch's default initialisation.
        """
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward_features(
        self,
        x: torch.Tensor,
        return_all_tokens: bool = False,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, list]:
        """Extract features (everything before the classification head).

        Args:
            x: image batch of shape (B, C, H, W).
            return_all_tokens: if true, return every normalised token
                (B, N+1, D); otherwise return only the [CLS] token (B, D).
            return_attention: if true, collect each block's attention tensor.

        Returns:
            ``(features, attentions)``. ``attentions`` holds one tensor per
            block when ``return_attention`` is true, and is empty otherwise.

        Raises:
            ValueError: if the input yields a different number of patches than
                the position embedding was built for.
        """
        B = x.shape[0]

        # Patch embedding: (B, N, D)
        x_spatial = tuple(x.shape[-2:])
        x = self.patch_embed(x)

        # Fail early with a descriptive message instead of an opaque
        # broadcasting error when adding the position embedding below.
        expected_patches = self.pos_embed.shape[1] - 1
        if x.shape[1] != expected_patches:
            raise ValueError(
                f"Input of spatial size {x_spatial} yields {x.shape[1]} patches, but "
                f"the position embedding was built for {expected_patches} patches "
                f"(image_size={self.config.image_size}, "
                f"patch_size={self.config.patch_size})."
            )

        # Prepend [CLS]: (B, N+1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Position embedding
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        # Transformer blocks
        attentions = []
        for block in self.blocks:
            x, attn = block(x, return_attention=return_attention)
            if return_attention:
                attentions.append(attn)

        # Final norm
        x = self.norm(x)

        if return_all_tokens:
            return x, attentions

        # Only the [CLS] token
        return x[:, 0], attentions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify an image batch; returns logits of shape (B, num_classes)."""
        features, _ = self.forward_features(x)
        return self.head(features)
283
284
class ViTForImageClassification(nn.Module):
    """VisionTransformer wrapper that optionally computes a classification loss.

    ``forward`` returns a dict with keys ``'logits'`` and ``'loss'``. The loss
    is ``F.cross_entropy(logits, labels)`` when labels are given, else None.
    """

    def __init__(self, config: ViTConfig):
        super().__init__()
        self.vit = VisionTransformer(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ):
        logits = self.vit(pixel_values)
        loss = None if labels is None else F.cross_entropy(logits, labels)
        return {'logits': logits, 'loss': loss}
307
308
# Attention visualisation
def visualize_attention(
    model: VisionTransformer,
    image: torch.Tensor,
    layer_idx: int = -1,
    head_idx: int = 0,
    save_path: str = 'vit_attention.png'
):
    """Plot the first image of a batch next to one head's [CLS] attention map.

    Runs ``model.forward_features(image, return_attention=True)`` under
    ``torch.no_grad()`` in eval mode, then restores the model's previous
    train/eval mode. It takes the attention row of the [CLS] token (query 0)
    over all patch tokens, reshapes it to a square grid, and saves the figure
    to ``save_path``.

    Args:
        model: the VisionTransformer to inspect.
        image: image batch (B, C, H, W); only ``image[0]`` is plotted.
        layer_idx: index into the per-layer attention list (negative allowed).
        head_idx: attention head to display.
        save_path: output file for the figure.

    Raises:
        ValueError: if the number of patch tokens is not a perfect square.
    """
    import matplotlib.pyplot as plt

    # Remember the caller's mode so this helper does not silently leave a
    # model that is being trained in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, attentions = model.forward_features(image, return_attention=True)
    finally:
        model.train(was_training)

    # Attention of the chosen layer, first image, chosen head: (N+1, N+1)
    attn = attentions[layer_idx][0, head_idx]

    # Row 0 is the [CLS] query; drop column 0 (attention to itself).
    cls_attn = attn[0, 1:]

    num_patch_tokens = cls_attn.shape[0]
    grid = int(round(num_patch_tokens ** 0.5))
    if grid * grid != num_patch_tokens:
        raise ValueError(
            f"Cannot reshape {num_patch_tokens} patch tokens into a square grid"
        )
    cls_attn = cls_attn.reshape(grid, grid)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Original image, min-max normalised to [0, 1] for display.
    img = image[0].detach().permute(1, 2, 0).cpu()
    if img.shape[-1] == 1:
        img = img.squeeze(-1)  # imshow rejects (H, W, 1); show as grayscale
    # Clamp the range so a constant image does not divide by zero (NaNs).
    value_range = (img.max() - img.min()).clamp_min(1e-12)
    img = (img - img.min()) / value_range
    axes[0].imshow(img, cmap='gray' if img.ndim == 2 else None)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Attention map
    axes[1].imshow(cls_attn.cpu(), cmap='viridis')
    axes[1].set_title(f"[CLS] Attention (Layer {layer_idx}, Head {head_idx})")
    axes[1].axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved {save_path}")
352
353
# Preset configurations for the standard ViT model sizes.
# Each preset accepts keyword overrides for any ViTConfig field.
def vit_tiny(**overrides):
    """Return a ViT-Tiny config (hidden 192, 12 layers, 3 heads).

    Keyword arguments override any ViTConfig field, including the preset
    ones, e.g. ``vit_tiny(patch_size=8, num_classes=10)``.
    """
    params = dict(hidden_size=192, num_layers=12, num_heads=3)
    params.update(overrides)
    return ViTConfig(**params)
357
def vit_small(**overrides):
    """Return a ViT-Small config (hidden 384, 12 layers, 6 heads).

    Keyword arguments override any ViTConfig field, including the preset
    ones, e.g. ``vit_small(patch_size=32, num_classes=10)``.
    """
    params = dict(hidden_size=384, num_layers=12, num_heads=6)
    params.update(overrides)
    return ViTConfig(**params)
360
def vit_base(**overrides):
    """Return a ViT-Base config (hidden 768, 12 layers, 12 heads).

    Keyword arguments override any ViTConfig field, including the preset
    ones, e.g. ``vit_base(patch_size=32)`` for ViT-B/32.
    """
    params = dict(hidden_size=768, num_layers=12, num_heads=12)
    params.update(overrides)
    return ViTConfig(**params)
363
def vit_large(**overrides):
    """Return a ViT-Large config (hidden 1024, 24 layers, 16 heads).

    Keyword arguments override any ViTConfig field, including the preset
    ones, e.g. ``vit_large(patch_size=32, num_classes=10)``.
    """
    params = dict(hidden_size=1024, num_layers=24, num_heads=16)
    params.update(overrides)
    return ViTConfig(**params)
366
367
# Smoke tests
if __name__ == "__main__":
    print("=== Vision Transformer Low-Level Implementation ===\n")

    # ViT-Base configuration
    config = vit_base()
    print(f"Config: {config}\n")

    # Build the model
    model = VisionTransformer(config)

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}")
    print("Expected ~86M for ViT-Base/16\n")

    # Test input
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224)
    num_patches = (224 // 16) ** 2

    # Inference only: eval mode disables dropout, and no_grad avoids
    # building an autograd graph for these checks.
    model.eval()
    with torch.no_grad():
        # Forward
        logits = model(x)
        print(f"Input shape: {x.shape}")
        print(f"Output shape: {logits.shape}")
        assert logits.shape == (batch_size, config.num_classes)

        # Features with attention
        features, attentions = model.forward_features(x, return_attention=True)
        print(f"\nFeatures shape: {features.shape}")
        print(f"Number of attention maps: {len(attentions)}")
        print(f"Attention shape: {attentions[0].shape}")
        assert features.shape == (batch_size, config.hidden_size)
        assert len(attentions) == config.num_layers
        assert attentions[0].shape == (
            batch_size, config.num_heads, num_patches + 1, num_patches + 1
        )

        # Patch embedding test
        print("\n=== Patch Embedding Test ===")
        patch_embed = PatchEmbedding(224, 16, 3, 768)
        patches = patch_embed(x)
        print(f"Image: {x.shape}")
        print(f"Patches: {patches.shape}")
        print(f"Number of patches: {patches.shape[1]}")
        print(f"Expected: (224/16)² = {(224//16)**2}")
        assert patches.shape == (batch_size, num_patches, 768)

    # Parameter counts for the preset sizes
    print("\n=== Different ViT Sizes ===")
    for name, config_fn in [('Tiny', vit_tiny), ('Small', vit_small),
                            ('Base', vit_base), ('Large', vit_large)]:
        cfg = config_fn()
        sized_model = VisionTransformer(cfg)
        params = sum(p.numel() for p in sized_model.parameters())
        print(f"ViT-{name}: {params/1e6:.1f}M params")

    print("\nAll tests passed!")