19. Vision Transformer (ViT)

ํ•™์Šต ๋ชฉํ‘œ

  • Vision Transformer ์•„ํ‚คํ…์ฒ˜ ์ดํ•ด
  • Patch Embedding ์›๋ฆฌ
  • CLS ํ† ํฐ๊ณผ Position Embedding
  • ViT ๋ณ€ํ˜• ๋ชจ๋ธ๋“ค (DeiT, Swin Transformer)
  • PyTorch ๊ตฌํ˜„ ๋ฐ ํ™œ์šฉ

1. Vision Transformer ๊ฐœ์š”

ํ•ต์‹ฌ ์•„์ด๋””์–ด

๊ธฐ์กด CNN: ์ง€์—ญ์  ํŠน์ง• โ†’ ์ „์—ญ ํŠน์ง• (๊ณ„์ธต์ )
ViT: ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜ ์‹œํ€€์Šค๋กœ ๋ณ€ํ™˜ โ†’ Transformer๋กœ ์ฒ˜๋ฆฌ

์ด๋ฏธ์ง€ (224ร—224) โ†’ 16ร—16 ํŒจ์น˜ 196๊ฐœ โ†’ Transformer Encoder

์™œ Transformer๋ฅผ Vision์—?

1. Self-Attention์˜ ์žฅ์ 
   - ์žฅ๊ฑฐ๋ฆฌ ์˜์กด์„ฑ ํฌ์ฐฉ
   - ์ „์—ญ์  ์ปจํ…์ŠคํŠธ ๊ณ ๋ ค

2. ํ™•์žฅ์„ฑ
   - ๋Œ€๊ทœ๋ชจ ๋ฐ์ดํ„ฐ์…‹์—์„œ CNN ๋Šฅ๊ฐ€
   - ์Šค์ผ€์ผ๋ง์ด ์šฉ์ด

3. ์•„ํ‚คํ…์ฒ˜ ํ†ตํ•ฉ
   - Vision + Language ํ†ตํ•ฉ ๊ฐ€๋Šฅ
   - ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ํ•™์Šต์— ์œ ๋ฆฌ

2. ViT ์•„ํ‚คํ…์ฒ˜

์ „์ฒด ๊ตฌ์กฐ

์ž…๋ ฅ ์ด๋ฏธ์ง€ (224ร—224ร—3)
        โ†“
[Patch Embedding] โ†’ 196๊ฐœ ํŒจ์น˜ ๋ฒกํ„ฐ (๊ฐ 768์ฐจ์›)
        โ†“
[CLS Token ์ถ”๊ฐ€] โ†’ 197๊ฐœ ํ† ํฐ
        โ†“
[Position Embedding ์ถ”๊ฐ€]
        โ†“
[Transformer Encoder ร— L layers]
        โ†“
[CLS Token ์ถœ๋ ฅ ์ถ”์ถœ]
        โ†“
[MLP Head] โ†’ ๋ถ„๋ฅ˜ ๊ฒฐ๊ณผ

์ˆ˜์‹ ์ •๋ฆฌ

# ์ž…๋ ฅ
x โˆˆ R^(Hร—Wร—C)  # ์˜ˆ: 224ร—224ร—3

# ํŒจ์น˜ ๋ถ„ํ• 
P = patch_size  # ์˜ˆ: 16
N = (H/P) ร— (W/P)  # ํŒจ์น˜ ๊ฐœ์ˆ˜: 196

# Patch Embedding
x_p โˆˆ R^(Nร—(PยฒยทC))  # 196ร—768 (16ร—16ร—3 = 768)
z_0 = [x_class; x_pยทE] + E_pos  # E: ํˆฌ์˜ ํ–‰๋ ฌ

# Transformer
z_l = MSA(LN(z_{l-1})) + z_{l-1}  # Multi-Head Self-Attention
z_l = MLP(LN(z_l)) + z_l         # Feed Forward

# ์ถœ๋ ฅ
y = LN(z_L^0)  # CLS ํ† ํฐ์˜ ์ตœ์ข… ํ‘œํ˜„

3. Patch Embedding

๊ฐœ๋…

# ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜๋กœ ๋ถ„ํ• 
# (B, 3, 224, 224) โ†’ (B, 196, 768)

# ๋ฐฉ๋ฒ• 1: reshape
patches = image.reshape(B, N, P*P*C)  # ์ง์ ‘ ์žฌ๊ตฌ์„ฑ

# ๋ฐฉ๋ฒ• 2: Conv2d (๋” ํšจ์œจ์ )
# stride=kernel_size๋กœ ๊ฒน์น˜์ง€ ์•Š๋Š” ํŒจ์น˜ ์ถ”์ถœ
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
patches = conv(image)  # (B, 768, 14, 14)
patches = patches.flatten(2).transpose(1, 2)  # (B, 196, 768)

PyTorch ๊ตฌํ˜„

class PatchEmbedding(nn.Module):
    """Patch Embedding Layer (โญโญ)"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # Conv2d๋กœ ํŒจ์น˜ ์ถ”์ถœ + ์ž„๋ฒ ๋”ฉ
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.projection(x)  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2)        # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)   # (B, num_patches, embed_dim)
        return x
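The Conv2d projection is exactly equivalent to flattening non-overlapping patches and applying one shared Linear layer. A small sanity check, with toy sizes chosen only for this demonstration:

```python
import torch
import torch.nn as nn

# Toy sizes (assumptions for the check, not ViT's real ones)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 16
N = (H // P) * (W // P)

conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
x = torch.randn(B, C, H, W)

# Path 1: Conv2d projection (as in the PatchEmbedding layer)
out_conv = conv(x).flatten(2).transpose(1, 2)            # (B, N, D)

# Path 2: explicit patches + shared Linear using the same weights
patches = x.unfold(2, P, P).unfold(3, P, P)              # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, N, C * P * P)
out_lin = patches @ conv.weight.reshape(D, -1).T + conv.bias

assert torch.allclose(out_conv, out_lin, atol=1e-4)
```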

4. CLS Token๊ณผ Position Embedding

CLS Token

# BERT์—์„œ ์ฐจ์šฉํ•œ ๊ฐœ๋…
# ์ „์ฒด ์ด๋ฏธ์ง€์˜ ํ‘œํ˜„์„ ํ•™์Šตํ•˜๋Š” ํŠน๋ณ„ ํ† ํฐ

class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
# ๋ฐฐ์น˜์— ๋ธŒ๋กœ๋“œ์บ์ŠคํŠธ
cls_tokens = class_token.expand(batch_size, -1, -1)  # (B, 1, D)
# ํŒจ์น˜ ์ž„๋ฒ ๋”ฉ ์•ž์— ์—ฐ๊ฒฐ
x = torch.cat([cls_tokens, patch_embeddings], dim=1)  # (B, N+1, D)

Position Embedding

# ํŒจ์น˜์˜ ์œ„์น˜ ์ •๋ณด ์ œ๊ณต (Transformer๋Š” ์œ„์น˜ ์ •๋ณด ์—†์Œ)

class PositionEmbedding(nn.Module):
    """Learnable Position Embedding (โญโญ)"""
    def __init__(self, num_patches, embed_dim):
        super().__init__()
        # +1 for CLS token
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim)
        )

    def forward(self, x):
        return x + self.pos_embedding
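When fine-tuning at a resolution other than the pretraining one, the learned position table no longer matches the patch grid; the usual fix is 2-D interpolation of the grid portion. A sketch (the helper name `resize_pos_embed` is ours; timm ships a similar utility):

```python
import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, old_grid, new_grid):
    """Interpolate a learned position table to a new patch grid.
    pos_embed: (1, old_grid**2 + 1, D); index 0 is the CLS slot."""
    cls_tok, grid_tok = pos_embed[:, :1], pos_embed[:, 1:]
    dim = grid_tok.shape[-1]
    # (1, N, D) → (1, D, g, g) so F.interpolate can treat it as an image
    grid_tok = grid_tok.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    grid_tok = F.interpolate(grid_tok, size=(new_grid, new_grid),
                             mode='bicubic', align_corners=False)
    grid_tok = grid_tok.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)
    return torch.cat([cls_tok, grid_tok], dim=1)

# e.g. 224→384 input with patch 16: grid 14×14 → 24×24
pe = torch.randn(1, 14 * 14 + 1, 768)
pe_384 = resize_pos_embed(pe, 14, 24)
assert pe_384.shape == (1, 24 * 24 + 1, 768)
```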

์œ„์น˜ ์ž„๋ฒ ๋”ฉ ์‹œ๊ฐํ™”

def visualize_position_embedding(pos_embed, img_size=224, patch_size=16):
    """์œ„์น˜ ์ž„๋ฒ ๋”ฉ ์œ ์‚ฌ๋„ ์‹œ๊ฐํ™” (โญโญ)"""
    # pos_embed: (1, N+1, D)
    # CLS ํ† ํฐ ์ œ์™ธ
    pos_embed = pos_embed[0, 1:]  # (N, D)

    # ์œ ์‚ฌ๋„ ํ–‰๋ ฌ
    similarity = torch.mm(pos_embed, pos_embed.T)  # (N, N)

    # ํŠน์ • ํŒจ์น˜์™€์˜ ์œ ์‚ฌ๋„
    num_patches = (img_size // patch_size)
    center_idx = num_patches * (num_patches // 2) + (num_patches // 2)
    center_sim = similarity[center_idx].reshape(num_patches, num_patches)

    return center_sim  # ์ค‘์•™ ํŒจ์น˜์™€์˜ ์œ ์‚ฌ๋„ ๋งต

5. Vision Transformer ์ „์ฒด ๊ตฌํ˜„

๊ธฐ๋ณธ ViT

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-Head Self-Attention (โญโญโญ)"""
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape

        # QKV ๊ณ„์‚ฐ
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # Output
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x
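In PyTorch 2.x the attention core above can be handed to the fused `F.scaled_dot_product_attention`, which computes the same math. A quick equivalence check with arbitrary small shapes:

```python
import torch
import torch.nn.functional as F

# Random Q/K/V with toy sizes (batch, heads, tokens, head_dim)
B, heads, N, head_dim = 2, 4, 10, 8
q = torch.randn(B, heads, N, head_dim)
k = torch.randn(B, heads, N, head_dim)
v = torch.randn(B, heads, N, head_dim)

# Manual attention, exactly as in the forward() above
manual = ((q @ k.transpose(-2, -1)) * head_dim ** -0.5).softmax(dim=-1) @ v
# Fused kernel (PyTorch 2.x); default scale is 1/sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(manual, fused, atol=1e-4)
```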


class MLP(nn.Module):
    """MLP Block (โญโญ)"""
    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """Transformer Encoder Block (โญโญโญ)"""
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) (โญโญโญโญ)"""
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0
    ):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2

        # Patch Embedding
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )

        # CLS Token (scaled-down init, as in standard ViT implementations)
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)

        # Position Embedding
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim) * 0.02
        )

        self.dropout = nn.Dropout(dropout)

        # Transformer Blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Classification Head
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]

        # Patch Embedding
        x = self.patch_embed(x)  # (B, N, D)

        # Add CLS Token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)

        # Add Position Embedding
        x = x + self.pos_embed
        x = self.dropout(x)

        # Transformer Blocks
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # CLS Token๋งŒ ์ถ”์ถœํ•˜์—ฌ ๋ถ„๋ฅ˜
        cls_output = x[:, 0]
        return self.head(cls_output)
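As a sanity check on the configuration above, the ViT-B/16 parameter count can be estimated by hand (assuming the implementation in this lesson: learned position table, single Linear head), landing near the ~86M reported for ViT-Base:

```python
# Back-of-the-envelope parameter count for ViT-B/16
D, L, N, P, C, num_classes = 768, 12, 196, 16, 3, 1000

patch_embed = (P * P * C + 1) * D           # Conv2d kernel + bias
tokens = (N + 1) * D + D                    # position table + CLS token
attn = 4 * D * D + 4 * D                    # qkv + output projection (with biases)
mlp = 2 * 4 * D * D + 4 * D + D             # fc1 + fc2 (with biases), ratio 4
block = attn + mlp + 2 * (2 * D)            # + two LayerNorms per block
head = (D + 1) * num_classes
total = patch_embed + tokens + L * block + head + 2 * D  # + final LayerNorm
print(f"{total / 1e6:.1f}M")  # → 86.6M
```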

ViT ๋ชจ๋ธ ๋ณ€ํ˜•

# ViT-Base (ViT-B/16)
vit_base = VisionTransformer(
    img_size=224, patch_size=16,
    embed_dim=768, depth=12, num_heads=12
)

# ViT-Large (ViT-L/16)
vit_large = VisionTransformer(
    img_size=224, patch_size=16,
    embed_dim=1024, depth=24, num_heads=16
)

# ViT-Huge (ViT-H/14)
vit_huge = VisionTransformer(
    img_size=224, patch_size=14,
    embed_dim=1280, depth=32, num_heads=16
)

6. DeiT (Data-efficient Image Transformer)

ํ•ต์‹ฌ ๊ฐœ์„ ์ 

๋ฌธ์ œ: ViT๋Š” ๋Œ€๊ทœ๋ชจ ๋ฐ์ดํ„ฐ ํ•„์š” (JFT-300M ๋“ฑ)
ํ•ด๊ฒฐ: ์ง€์‹ ์ฆ๋ฅ˜ + ๊ฐ•๋ ฅํ•œ ๋ฐ์ดํ„ฐ ์ฆ๊ฐ•์œผ๋กœ ImageNet๋งŒ์œผ๋กœ ํ•™์Šต

1. Distillation Token: CNN ๊ต์‚ฌ ๋ชจ๋ธ์˜ ์ง€์‹ ํ•™์Šต
2. ๊ฐ•๋ ฅํ•œ Data Augmentation
3. Regularization (Stochastic Depth, Dropout)

Distillation Token

class DeiT(nn.Module):
    """Data-efficient Image Transformer (โญโญโญโญ)"""
    def __init__(self, img_size=224, patch_size=16, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)

        # CLS Token + Distillation Token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Position Embedding (+2 for CLS and DIST)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 2, embed_dim)
        )

        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # ๋‘ ๊ฐœ์˜ Head
        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)

        x = x + self.pos_embed

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # CLS์™€ DIST ํ† ํฐ ๋ชจ๋‘ ์‚ฌ์šฉ
        cls_output = self.head(x[:, 0])
        dist_output = self.head_dist(x[:, 1])

        if self.training:
            return cls_output, dist_output
        else:
            # ์ถ”๋ก  ์‹œ ํ‰๊ท 
            return (cls_output + dist_output) / 2

DeiT ํ•™์Šต

def train_deit_with_distillation(student, teacher, dataloader, epochs=100):
    """DeiT ์ง€์‹ ์ฆ๋ฅ˜ ํ•™์Šต (โญโญโญ)"""
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)
    criterion_ce = nn.CrossEntropyLoss()
    criterion_dist = nn.CrossEntropyLoss()

    teacher.eval()

    for epoch in range(epochs):
        for images, labels in dataloader:
            # Teacher prediction (converted to hard pseudo-labels below)
            with torch.no_grad():
                teacher_output = teacher(images)

            # Student predictions
            cls_output, dist_output = student(images)

            # Losses
            loss_cls = criterion_ce(cls_output, labels)
            loss_dist = criterion_dist(dist_output, teacher_output.argmax(dim=1))

            loss = 0.5 * loss_cls + 0.5 * loss_dist

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
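The loss above is DeiT's *hard* distillation: the teacher's argmax acts as a second label. The paper also evaluates a *soft* variant using temperature-scaled KL divergence; a minimal sketch:

```python
import torch
import torch.nn.functional as F

def soft_distill_loss(dist_out, teacher_out, tau=3.0):
    """Temperature-scaled KL between teacher and student distributions."""
    log_p_student = F.log_softmax(dist_out / tau, dim=1)
    log_p_teacher = F.log_softmax(teacher_out / tau, dim=1)
    # tau² keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_p_student, log_p_teacher,
                    log_target=True, reduction='batchmean') * tau * tau

# Random logits just to exercise the function
student_logits = torch.randn(4, 10)
teacher_logits = torch.randn(4, 10)
loss = soft_distill_loss(student_logits, teacher_logits)
assert loss.item() >= 0  # KL divergence is non-negative
```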

7. Swin Transformer

ํ•ต์‹ฌ ์•„์ด๋””์–ด

๋ฌธ์ œ: ViT์˜ O(nยฒ) ๋ณต์žก๋„ โ†’ ๊ณ ํ•ด์ƒ๋„ ์ด๋ฏธ์ง€ ์ฒ˜๋ฆฌ ์–ด๋ ค์›€
ํ•ด๊ฒฐ: ๊ณ„์ธต์  ๊ตฌ์กฐ + Shifted Window Attention

ํŠน์ง•:
1. Window Attention: ์ง€์—ญ ์œˆ๋„์šฐ ๋‚ด์—์„œ๋งŒ attention
2. Shifted Windows: ์œˆ๋„์šฐ ๊ฐ„ ์ •๋ณด ๊ตํ™˜
3. ๊ณ„์ธต์  ๊ตฌ์กฐ: ํŠน์ง• ๋งต ํ•ด์ƒ๋„ ์ ์ง„์  ๊ฐ์†Œ

Window Attention

def window_partition(x, window_size):
    """์ด๋ฏธ์ง€๋ฅผ ์œˆ๋„์šฐ๋กœ ๋ถ„ํ•  (โญโญโญ)"""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size,
               W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """์œˆ๋„์šฐ๋ฅผ ๋‹ค์‹œ ์ด๋ฏธ์ง€๋กœ ํ•ฉ์นจ (โญโญโญ)"""
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size,
                     window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """Window-based Multi-Head Self-Attention (โญโญโญโญ)"""
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # Relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # ์ƒ๋Œ€ ์œ„์น˜ ์ธ๋ฑ์Šค ์ƒ์„ฑ
        self._create_relative_position_index()

    def _create_relative_position_index(self):
        coords = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords, coords], indexing='ij'))
        coords_flatten = coords.flatten(1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        B_, N, C = x.shape

        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Add relative position bias
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            attn = attn + mask

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x

Shifted Window

class SwinTransformerBlock(nn.Module):
    """Swin Transformer Block with (Shifted) Window Attention (โญโญโญโญ)"""
    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x, H, W):
        B, L, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Window partition
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # Window attention (the mask required for shifted windows is
        # omitted here for brevity)
        attn_windows = self.attn(x_windows)

        # Window reverse
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, L, C)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x
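One caveat: the block above does not build the attention mask that shifted blocks need. After the cyclic roll, a window can contain tokens that were not neighbors in the original image, and those pairs must be masked out of the softmax. A sketch of the standard mask construction (the helper name `make_shift_mask` is ours):

```python
import torch

def make_shift_mask(H, W, window_size, shift_size):
    """Build the (num_windows, M*M, M*M) additive attention mask
    used by shifted-window attention."""
    img_mask = torch.zeros(1, H, W, 1)
    regions = (slice(0, -window_size),
               slice(-window_size, -shift_size),
               slice(-shift_size, None))
    cnt = 0
    for h in regions:
        for w in regions:
            img_mask[:, h, w, :] = cnt  # label each contiguous region
            cnt += 1

    # Inline window partition of the label map
    m = img_mask.view(1, H // window_size, window_size,
                      W // window_size, window_size, 1)
    m = m.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)

    # Pairs with different labels get a large negative bias
    diff = m.unsqueeze(1) - m.unsqueeze(2)
    return torch.where(diff != 0, torch.tensor(-100.0), torch.tensor(0.0))

mask = make_shift_mask(14, 14, window_size=7, shift_size=3)
assert mask.shape == (4, 49, 49)
```

The result would be passed as `mask` to WindowAttention's forward, where it is added to the attention logits before the softmax.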

8. ์‚ฌ์ „ํ•™์Šต ๋ชจ๋ธ ํ™œ์šฉ

torchvision ์‚ฌ์šฉ

from torchvision.models import vit_b_16, vit_l_16, swin_t, swin_s

# ViT-B/16 (pretrained)
model = vit_b_16(weights='IMAGENET1K_V1')

# ํŠน์ง• ์ถ”์ถœ๊ธฐ๋กœ ์‚ฌ์šฉ
model.heads = nn.Identity()
features = model(image)  # (B, 768)

# Fine-tuning
model = vit_b_16(weights='IMAGENET1K_V1')
model.heads = nn.Linear(768, num_classes)

# ํ•™์Šต๋ฅ  ์ฐจ๋“ฑ ์ ์šฉ
params = [
    {'params': model.encoder.parameters(), 'lr': 1e-5},  # backbone
    {'params': model.heads.parameters(), 'lr': 1e-3}     # head
]
optimizer = torch.optim.AdamW(params)

timm ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ

import timm

# ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ViT ๋ชจ๋ธ ๋ชฉ๋ก
vit_models = timm.list_models('vit*', pretrained=True)
print(f"Available ViT models: {len(vit_models)}")

# ๋ชจ๋ธ ๋กœ๋“œ
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# ์ปค์Šคํ…€ ๋ถ„๋ฅ˜ ํ—ค๋“œ
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=10  # ์ž๋™์œผ๋กœ head ๊ต์ฒด
)

# DeiT ๋ชจ๋ธ
deit_model = timm.create_model('deit_base_patch16_224', pretrained=True)

# Swin Transformer
swin_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)

9. ์‹ค์ „ Fine-tuning

CIFAR-10 Fine-tuning

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm

def finetune_vit_cifar10(epochs=10):
    """ViT CIFAR-10 Fine-tuning (โญโญโญ)"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ (ViT ์ž…๋ ฅ ํฌ๊ธฐ์— ๋งž๊ฒŒ)
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # ๋ฐ์ดํ„ฐ์…‹
    train_data = datasets.CIFAR10('data', train=True, download=True, transform=transform_train)
    test_data = datasets.CIFAR10('data', train=False, transform=transform_test)

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)

    # ๋ชจ๋ธ
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=10)
    model = model.to(device)

    # ์˜ตํ‹ฐ๋งˆ์ด์ € (์ฐจ๋“ฑ ํ•™์Šต๋ฅ )
    backbone_params = [p for n, p in model.named_parameters() if 'head' not in n]
    head_params = [p for n, p in model.named_parameters() if 'head' in n]

    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': 1e-5},
        {'params': head_params, 'lr': 1e-3}
    ], weight_decay=0.01)

    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # ํ•™์Šต
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # ํ‰๊ฐ€
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        accuracy = 100. * correct / total
        print(f'Epoch {epoch+1}/{epochs}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy:.2f}%')

        scheduler.step()

    return model

10. ViT vs CNN ๋น„๊ต

ํŠน์„ฑ ๋น„๊ต

ํŠน์„ฑ CNN ViT
๊ท€๋‚ฉ์  ํŽธํ–ฅ ์ง€์—ญ์„ฑ, ๋“ฑ๋ณ€์„ฑ ์—†์Œ
๋ฐ์ดํ„ฐ ์š”๊ตฌ๋Ÿ‰ ์ ์Œ ๋งŽ์Œ
๊ณ„์‚ฐ ๋ณต์žก๋„ O(n) O(nยฒ)
์žฅ๊ฑฐ๋ฆฌ ์˜์กด์„ฑ ์–ด๋ ค์›€ ์šฉ์ด
ํ•ด์„ ๊ฐ€๋Šฅ์„ฑ ํ•„ํ„ฐ ์‹œ๊ฐํ™” Attention ์‹œ๊ฐํ™”

์‚ฌ์šฉ ๊ฐ€์ด๋“œ๋ผ์ธ

CNN ์„ ํ˜ธ:
- ์†Œ๊ทœ๋ชจ ๋ฐ์ดํ„ฐ์…‹
- ์ œํ•œ๋œ ๊ณ„์‚ฐ ๋ฆฌ์†Œ์Šค
- ์‹ค์‹œ๊ฐ„ ์ถ”๋ก  ํ•„์š”

ViT ์„ ํ˜ธ:
- ๋Œ€๊ทœ๋ชจ ๋ฐ์ดํ„ฐ์…‹ ๋˜๋Š” ์‚ฌ์ „ํ•™์Šต ๋ชจ๋ธ ํ™œ์šฉ
- ์ „์—ญ ์ปจํ…์ŠคํŠธ๊ฐ€ ์ค‘์š”ํ•œ ํƒœ์Šคํฌ
- ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ํ•™์Šต ๊ณ„ํš

์ •๋ฆฌ

ํ•ต์‹ฌ ๊ฐœ๋…

  1. Patch Embedding: ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜ ์‹œํ€€์Šค๋กœ ๋ณ€ํ™˜
  2. CLS Token: ์ „์ฒด ์ด๋ฏธ์ง€ ํ‘œํ˜„ ํ•™์Šต
  3. Position Embedding: ํŒจ์น˜ ์œ„์น˜ ์ •๋ณด ์ œ๊ณต
  4. DeiT: ๋ฐ์ดํ„ฐ ํšจ์œจ์  ํ•™์Šต (์ง€์‹ ์ฆ๋ฅ˜)
  5. Swin: ์œˆ๋„์šฐ ๊ธฐ๋ฐ˜ ํšจ์œจ์  attention

๋ชจ๋ธ ์„ ํƒ ๊ฐ€์ด๋“œ

์ผ๋ฐ˜ ๋ถ„๋ฅ˜: ViT-B/16 ๋˜๋Š” DeiT
๊ณ ํ•ด์ƒ๋„: Swin Transformer
์ œํ•œ๋œ ์ž์›: ViT-Small, DeiT-Tiny
์ตœ๊ณ  ์„ฑ๋Šฅ: ViT-Large, Swin-Large

PyTorch ์‹ค์ „ ํŒ

# 1. timm ์‚ฌ์šฉ ๊ถŒ์žฅ
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)

# 2. ์ฐจ๋“ฑ ํ•™์Šต๋ฅ  ํ•„์ˆ˜
optimizer = torch.optim.AdamW([
    {'params': backbone_params, 'lr': 1e-5},
    {'params': head_params, 'lr': 1e-3}
])

# 3. ์ž…๋ ฅ ํฌ๊ธฐ ์ฃผ์˜ (224, 384, ๋“ฑ)

# 4. ๊ฐ•๋ ฅํ•œ ๋ฐ์ดํ„ฐ ์ฆ๊ฐ• ์‚ฌ์šฉ

์ฐธ๊ณ  ์ž๋ฃŒ

  • ViT ์›๋ณธ: https://arxiv.org/abs/2010.11929
  • DeiT: https://arxiv.org/abs/2012.12877
  • Swin Transformer: https://arxiv.org/abs/2103.14030
  • timm ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ: https://github.com/huggingface/pytorch-image-models