1"""
2PyTorch Low-Level Vision Transformer (ViT) ๊ตฌํ
3
4nn.TransformerEncoder ๋ฏธ์ฌ์ฉ
5ํจ์น ์๋ฒ ๋ฉ๋ถํฐ ์ง์ ๊ตฌํ
6"""
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11import math
12from typing import Optional, Tuple
13from dataclasses import dataclass
14
15
@dataclass
class ViTConfig:
    """Hyperparameters for the Vision Transformer (defaults match ViT-Base/16)."""
    image_size: int = 224           # input image side length (square images assumed)
    patch_size: int = 16            # side length of each square patch
    in_channels: int = 3            # input channels (3 = RGB)
    num_classes: int = 1000         # classification head output size
    hidden_size: int = 768          # token embedding dimension D
    num_layers: int = 12            # number of transformer encoder blocks
    num_heads: int = 12             # attention heads; must divide hidden_size evenly
    mlp_ratio: float = 4.0          # MLP hidden dim = hidden_size * mlp_ratio
    dropout: float = 0.0            # dropout after pos-embed and inside the MLP
    attention_dropout: float = 0.0  # dropout on attention weights / output projection
29
30
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Maps (B, C, H, W) -> (B, N, D), where N = (image_size / patch_size) ** 2.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        hidden_size: int = 768
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        grid = image_size // patch_size
        self.num_patches = grid * grid

        # A conv with kernel_size == stride == patch_size is exactly a
        # per-patch linear projection (non-overlapping patches).
        self.projection = nn.Conv2d(
            in_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches.

        Args:
            x: (B, C, H, W) image batch.

        Returns:
            (B, N, D) patch embeddings, N = num_patches.
        """
        feat = self.projection(x)                 # (B, D, H/P, W/P)
        # (B, D, H', W') -> (B, D, N) -> (B, N, D)
        return feat.flatten(start_dim=2).transpose(1, 2)
72
73
class MultiHeadAttention(nn.Module):
    """Multi-head scaled-dot-product self-attention with fused QKV projection."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout: float = 0.0
    ):
        super().__init__()
        assert hidden_size % num_heads == 0

        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = self.head_dim ** -0.5

        # Single projection produces Q, K and V in one matmul.
        self.qkv = nn.Linear(hidden_size, hidden_size * 3)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over the token sequence.

        Args:
            x: (B, N, D) token embeddings.

        Returns:
            output: (B, N, D) attended tokens.
            attention: (B, H, N, N) attention weights (after dropout).
        """
        batch, seq_len, dim = x.shape

        # Fused QKV: (B, N, 3D) -> (3, B, H, N, head_dim)
        fused = self.qkv(x).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product scores, normalized per query row: (B, H, N, N)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_dropout(scores.softmax(dim=-1))

        # Weighted sum of values, heads merged back: (B, N, D)
        context = torch.matmul(weights, v)
        context = context.transpose(1, 2).reshape(batch, seq_len, dim)

        out = self.proj_dropout(self.proj(context))
        return out, weights
126
127
class MLP(nn.Module):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(
        self,
        hidden_size: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0
    ):
        super().__init__()
        inner_dim = int(hidden_size * mlp_ratio)

        self.fc1 = nn.Linear(hidden_size, inner_dim)
        self.fc2 = nn.Linear(inner_dim, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(B, N, D) -> (B, N, D), expanding to inner_dim in between."""
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
151
152
class TransformerBlock(nn.Module):
    """Pre-LayerNorm encoder block: LN -> MHSA -> residual, then LN -> MLP -> residual."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = MultiHeadAttention(hidden_size, num_heads, attention_dropout)
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6)
        self.mlp = MLP(hidden_size, mlp_ratio, dropout)

    def forward(
        self,
        x: torch.Tensor,
        return_attention: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply the block; second return is the attention map or None."""
        # Attention sub-layer with pre-norm and residual connection.
        attn_out, weights = self.attn(self.norm1(x))
        x = x + attn_out

        # MLP sub-layer with pre-norm and residual connection.
        x = x + self.mlp(self.norm2(x))

        return x, (weights if return_attention else None)
185
186
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT): patch embedding + [CLS] token + transformer
    encoder stack + linear classification head."""

    def __init__(self, config: ViTConfig):
        super().__init__()
        self.config = config

        # Patch embedding: (B, C, H, W) -> (B, N, D)
        self.patch_embed = PatchEmbedding(
            config.image_size, config.patch_size,
            config.in_channels, config.hidden_size
        )
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        # Learnable position embedding: one slot per patch plus the [CLS] token.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))

        # Dropout applied right after adding position embeddings.
        self.pos_dropout = nn.Dropout(config.dropout)

        # Stack of pre-LN transformer encoder blocks.
        self.blocks = nn.ModuleList([
            TransformerBlock(
                config.hidden_size, config.num_heads,
                config.mlp_ratio, config.dropout, config.attention_dropout
            )
            for _ in range(config.num_layers)
        ])

        # Final LayerNorm over the encoder output.
        self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)

        # Classification head on the [CLS] representation.
        self.head = nn.Linear(config.hidden_size, config.num_classes)

        # Initialize weights (truncated normal, std=0.02).
        self._init_weights()

    def _init_weights(self):
        """Initialize parameters: truncated normal for embeddings and Linear
        weights, zeros for biases, ones/zeros for LayerNorm."""
        # Position embedding and [CLS] token: truncated normal.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Linear layers and LayerNorms across all submodules.
        # NOTE(review): the Conv2d patch projection keeps PyTorch's default
        # init here; the reference ViT also trunc-normal-initializes it.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward_features(
        self,
        x: torch.Tensor,
        return_all_tokens: bool = False,
        return_attention: bool = False
    ):
        """Feature extraction (everything before the classification head).

        Args:
            x: (B, C, H, W) image batch.
            return_all_tokens: if True, return all N+1 token embeddings
                instead of only the [CLS] token.
            return_attention: if True, also collect each block's attention map.

        Returns:
            (features, attentions): features is (B, D) or (B, N+1, D);
            attentions is a list of (B, H, N+1, N+1) maps (empty unless
            return_attention is True).
        """
        B = x.shape[0]

        # Patch embedding: (B, N, D)
        x = self.patch_embed(x)

        # Prepend the [CLS] token: (B, N+1, D)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add position embedding, then dropout.
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        # Run the transformer encoder stack.
        attentions = []
        for block in self.blocks:
            x, attn = block(x, return_attention=return_attention)
            if return_attention:
                attentions.append(attn)

        # Final norm
        x = self.norm(x)

        if return_all_tokens:
            return x, attentions

        # Return only the [CLS] token representation.
        return x[:, 0], attentions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify: returns (B, num_classes) logits."""
        features, _ = self.forward_features(x)
        return self.head(features)
283
284
class ViTForImageClassification(nn.Module):
    """Wraps VisionTransformer and optionally computes a cross-entropy loss."""

    def __init__(self, config: ViTConfig):
        super().__init__()
        self.vit = VisionTransformer(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ):
        """Return a dict with 'logits' and 'loss' ('loss' is None without labels)."""
        logits = self.vit(pixel_values)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return {
            'logits': logits,
            'loss': loss
        }
307
308
# Attention visualization
def visualize_attention(
    model: VisionTransformer,
    image: torch.Tensor,
    layer_idx: int = -1,
    head_idx: int = 0
):
    """Save a side-by-side plot of the input image and its [CLS] attention map.

    Writes the figure to 'vit_attention.png' in the working directory.
    """
    import matplotlib.pyplot as plt

    model.eval()
    with torch.no_grad():
        _, attentions = model.forward_features(image, return_attention=True)

    # Attention of the chosen layer/head for the first batch element: (N, N)
    layer_attn = attentions[layer_idx][0, head_idx]

    # How strongly the [CLS] token attends to each image patch: (N-1,)
    cls_to_patches = layer_attn[0, 1:]

    # Fold back into a 2D patch grid (assumes a square grid of patches).
    side = int(cls_to_patches.shape[0] ** 0.5)
    attn_map = cls_to_patches.reshape(side, side)

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Original image, min-max normalized for display.
    img = image[0].permute(1, 2, 0).cpu()
    img = (img - img.min()) / (img.max() - img.min())
    axes[0].imshow(img)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Attention heatmap.
    axes[1].imshow(attn_map.cpu(), cmap='viridis')
    axes[1].set_title(f"[CLS] Attention (Layer {layer_idx}, Head {head_idx})")
    axes[1].axis('off')

    plt.tight_layout()
    plt.savefig('vit_attention.png')
    print("Saved vit_attention.png")
352
353
354# ๋ค์ํ ํฌ๊ธฐ์ ViT ์ค์
def vit_tiny():
    """ViT-Tiny configuration (192-dim, 12 layers, 3 heads)."""
    return ViTConfig(hidden_size=192, num_layers=12, num_heads=3)
357
def vit_small():
    """ViT-Small configuration (384-dim, 12 layers, 6 heads)."""
    return ViTConfig(hidden_size=384, num_layers=12, num_heads=6)
360
def vit_base():
    """ViT-Base configuration (768-dim, 12 layers, 12 heads)."""
    return ViTConfig(hidden_size=768, num_layers=12, num_heads=12)
363
def vit_large():
    """ViT-Large configuration (1024-dim, 24 layers, 16 heads)."""
    return ViTConfig(hidden_size=1024, num_layers=24, num_heads=16)
366
367
368# ํ
์คํธ
369if __name__ == "__main__":
370 print("=== Vision Transformer Low-Level Implementation ===\n")
371
372 # ViT-Base ์ค์
373 config = vit_base()
374 print(f"Config: {config}\n")
375
376 # ๋ชจ๋ธ ์์ฑ
377 model = VisionTransformer(config)
378
379 # ํ๋ผ๋ฏธํฐ ์
380 total_params = sum(p.numel() for p in model.parameters())
381 print(f"Total parameters: {total_params:,}")
382 print(f"Expected ~86M for ViT-Base/16\n")
383
384 # ํ
์คํธ ์
๋ ฅ
385 batch_size = 2
386 x = torch.randn(batch_size, 3, 224, 224)
387
388 # Forward
389 logits = model(x)
390 print(f"Input shape: {x.shape}")
391 print(f"Output shape: {logits.shape}")
392
393 # Features with attention
394 features, attentions = model.forward_features(x, return_attention=True)
395 print(f"\nFeatures shape: {features.shape}")
396 print(f"Number of attention maps: {len(attentions)}")
397 print(f"Attention shape: {attentions[0].shape}")
398
399 # Patch embedding ํ
์คํธ
400 print("\n=== Patch Embedding Test ===")
401 patch_embed = PatchEmbedding(224, 16, 3, 768)
402 patches = patch_embed(x)
403 print(f"Image: {x.shape}")
404 print(f"Patches: {patches.shape}")
405 print(f"Number of patches: {patches.shape[1]}")
406 print(f"Expected: (224/16)ยฒ = {(224//16)**2}")
407
408 # ๋ค์ํ ํฌ๊ธฐ ํ
์คํธ
409 print("\n=== Different ViT Sizes ===")
410 for name, config_fn in [('Tiny', vit_tiny), ('Small', vit_small),
411 ('Base', vit_base), ('Large', vit_large)]:
412 cfg = config_fn()
413 model = VisionTransformer(cfg)
414 params = sum(p.numel() for p in model.parameters())
415 print(f"ViT-{name}: {params/1e6:.1f}M params")
416
417 print("\nAll tests passed!")