학습 목표
- Vision Transformer 아키텍처 이해
- Patch Embedding 원리
- CLS 토큰과 Position Embedding
- ViT 변형 모델들 (DeiT, Swin Transformer)
- PyTorch 구현 및 활용
핵심 아이디어
기존 CNN: 지역적 특징 → 전역 특징 (계층적)
ViT: 이미지를 패치 시퀀스로 변환 → Transformer로 처리
이미지 (224×224) → 16×16 패치 196개 → Transformer Encoder
1. Self-Attention의 장점
- 장거리 의존성 포착
- 전역적 컨텍스트 고려
2. 확장성
- 대규모 데이터셋에서 CNN 능가
- 스케일링이 용이
3. 아키텍처 통합
- Vision + Language 통합 가능
- 멀티모달 학습에 유리
2. ViT 아키텍처
전체 구조
입력 이미지 (224×224×3)
↓
[Patch Embedding] → 196개 패치 벡터 (각 768차원)
↓
[CLS Token 추가] → 197개 토큰
↓
[Position Embedding 추가]
↓
[Transformer Encoder × L layers]
↓
[CLS Token 출력 추출]
↓
[MLP Head] → 분류 결과
수식 정리
# 입력
x ∈ R^(H×W×C)  # 예: 224×224×3
# 패치 분할
P = patch_size  # 예: 16
N = (H/P) × (W/P)  # 패치 개수: 196
# Patch Embedding
x_p ∈ R^(N×(P²·C))  # 196×768 (16×16×3 = 768)
z_0 = [x_class; x_p·E] + E_pos  # E: 투영 행렬
# Transformer
z_l = MSA(LN(z_{l-1})) + z_{l-1}  # Multi-Head Self-Attention
z_l = MLP(LN(z_l)) + z_l  # Feed Forward
# 출력
y = LN(z_L^0)  # CLS 토큰의 최종 표현
3. Patch Embedding
개념
# Split the image into patches:
# (B, 3, 224, 224) -> (B, 196, 768)
# Method 1: reshape
patches = image.reshape(B, N, P*P*C)  # direct rearrangement of pixels
# Method 2: Conv2d (more efficient)
# stride == kernel_size extracts non-overlapping patches
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
patches = conv(image)  # (B, 768, 14, 14)
patches = patches.flatten(2).transpose(1, 2)  # (B, 196, 768)
PyTorch 구현
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them.

    A Conv2d whose kernel size equals its stride performs patch extraction
    and the linear projection in a single fused operation.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        # One conv step per patch: extraction + embedding together.
        self.projection = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) -> (B, num_patches, embed_dim)
        embedded = self.projection(x)   # (B, embed_dim, H/P, W/P)
        flat = embedded.flatten(2)      # (B, embed_dim, num_patches)
        return flat.transpose(1, 2)     # (B, num_patches, embed_dim)
4. CLS Token과 Position Embedding
CLS Token
# Concept borrowed from BERT:
# a special token that learns a representation of the whole image.
class_token = nn.Parameter(torch.randn(1, 1, embed_dim))
# Broadcast across the batch
cls_tokens = class_token.expand(batch_size, -1, -1)  # (B, 1, D)
# Prepend to the patch embeddings
x = torch.cat([cls_tokens, patch_embeddings], dim=1)  # (B, N+1, D)
Position Embedding
# 패치의 위치 정보 제공 (Transformer는 위치 정보가 없음)
class PositionEmbedding(nn.Module):
    """Learnable position embedding added to the token sequence.

    Transformers are permutation-invariant, so each of the
    num_patches + 1 tokens (patches plus the CLS token) gets a
    trainable positional offset.
    """

    def __init__(self, num_patches, embed_dim):
        super().__init__()
        # One extra slot accounts for the CLS token.
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_patches + 1, embed_dim)
        )

    def forward(self, x):
        # Broadcasts over the batch dimension.
        return x + self.pos_embedding
위치 임베딩 시각화
def visualize_position_embedding(pos_embed, img_size=224, patch_size=16):
    """Similarity map between the center patch's position embedding and all others.

    Args:
        pos_embed: (1, N+1, D) learned position embeddings; index 0 is CLS.
        img_size: input image side length.
        patch_size: patch side length.

    Returns:
        (grid, grid) tensor of dot-product similarities to the central
        patch, where grid = img_size // patch_size.
    """
    patch_embeds = pos_embed[0, 1:]            # drop the CLS slot -> (N, D)
    sim = patch_embeds @ patch_embeds.T        # (N, N) pairwise dot products
    grid = img_size // patch_size
    # Row-major index of the central patch.
    center = grid * (grid // 2) + (grid // 2)
    return sim[center].reshape(grid, grid)
기본 ViT
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Input and output are both (B, N, embed_dim); embed_dim must be
    divisible by num_heads.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # A single linear layer produces Q, K and V in one pass.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, tokens, channels = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        projected = self.qkv(x).reshape(
            batch, tokens, 3, self.num_heads, self.head_dim
        )
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)
        # Scaled dot-product attention, per head.
        weights = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = self.dropout(weights.softmax(dim=-1))
        # Merge heads back into the channel dimension.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear.

    The hidden width is embed_dim * mlp_ratio (4x by default, as in ViT);
    dropout follows each linear layer.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden)
        self.fc2 = nn.Linear(hidden, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        expanded = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(expanded))
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Multi-head self-attention followed by an MLP, each wrapped in a
    residual connection with LayerNorm applied before the sub-layer.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)

    def forward(self, x):
        # Residual around attention, then residual around the MLP.
        x = self.attn(self.norm1(x)) + x
        return self.mlp(self.norm2(x)) + x
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) for image classification.

    Pipeline: patchify -> prepend CLS token -> add position embedding ->
    `depth` pre-norm Transformer blocks -> LayerNorm -> linear head on the
    CLS token's final representation.

    Defaults mirror ViT-B/16 (embed_dim=768, depth=12, num_heads=12).
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        dropout=0.0
    ):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Patch extraction + linear embedding.
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )
        # CLS token and position embeddings.
        # Fix: initialize with trunc_normal(std=0.02), as in the ViT paper
        # and timm, instead of unit-variance randn which destabilizes
        # early training.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)
        )
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.dropout = nn.Dropout(dropout)
        # Stack of pre-norm encoder blocks.
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Classification head applied to the CLS token.
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """x: (B, C, H, W) -> logits (B, num_classes)."""
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, N, D)
        # Prepend one CLS token per sample.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, N+1, D)
        # Positional information, then dropout.
        x = x + self.pos_embed
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        # Classify from the CLS token only.
        return self.head(x[:, 0])
ViT 모델 변형
# Standard ViT configurations: depth, width and head count scale together.
# ViT-Base (ViT-B/16)
vit_base = VisionTransformer(
    img_size=224, patch_size=16,
    embed_dim=768, depth=12, num_heads=12
)
# ViT-Large (ViT-L/16)
vit_large = VisionTransformer(
    img_size=224, patch_size=16,
    embed_dim=1024, depth=24, num_heads=16
)
# ViT-Huge (ViT-H/14): smaller patches -> a longer token sequence
vit_huge = VisionTransformer(
    img_size=224, patch_size=14,
    embed_dim=1280, depth=32, num_heads=16
)
핵심 개선점
문제: ViT는 대규모 데이터 필요 (JFT-300M 등)
해결: 지식 증류 + 강력한 데이터 증강으로 ImageNet만으로 학습
1. Distillation Token: CNN 교사 모델의 지식 학습
2. 강력한 Data Augmentation
3. Regularization (Stochastic Depth, Dropout)
Distillation Token
class DeiT(nn.Module):
    """Data-efficient Image Transformer (DeiT).

    Identical to ViT except for an extra learnable distillation token that
    is trained against a teacher model's predictions. During training both
    head outputs are returned so the caller can apply separate losses; at
    inference the CLS and distillation head outputs are averaged.
    """

    def __init__(self, img_size=224, patch_size=16, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)
        # CLS token + distillation token.
        # Fix: trunc_normal(std=0.02) initialization (as in DeiT/timm)
        # instead of unit-variance randn.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Position embedding (+2 slots for CLS and DIST).
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 2, embed_dim)
        )
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        # Separate classification heads for the two tokens.
        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """x: (B, C, H, W) -> logits, or (cls_logits, dist_logits) in training."""
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_tokens = self.dist_token.expand(B, -1, -1)
        # Token order: [CLS, DIST, patches...].
        x = torch.cat([cls_tokens, dist_tokens, x], dim=1)
        x = x + self.pos_embed
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        cls_output = self.head(x[:, 0])
        dist_output = self.head_dist(x[:, 1])
        if self.training:
            # Caller applies the CE loss to cls_output and the
            # distillation loss to dist_output.
            return cls_output, dist_output
        # Inference: average the two heads.
        return (cls_output + dist_output) / 2
DeiT 학습
def train_deit_with_distillation(student, teacher, dataloader, epochs=100):
    """Train a DeiT student with hard-label knowledge distillation.

    The student must return (cls_output, dist_output) in training mode.
    The CLS head is supervised by the ground-truth labels; the distillation
    head is supervised by the frozen teacher's argmax predictions ("hard
    distillation" from the DeiT paper). The two losses are averaged 50/50.

    Args:
        student: DeiT-style model returning two logit tensors when training.
        teacher: frozen teacher model producing class logits.
        dataloader: iterable of (images, labels) batches.
        epochs: number of passes over the dataloader.
    """
    optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3)
    # Fix: one CrossEntropyLoss suffices for both targets (the original
    # built two identical instances).
    criterion = nn.CrossEntropyLoss()
    teacher.eval()  # teacher is frozen: no dropout / running-stat updates
    for epoch in range(epochs):
        # Fix: ensure the student is in training mode so its forward
        # returns the (cls, dist) tuple.
        student.train()
        for images, labels in dataloader:
            # Teacher predictions -> hard pseudo-labels (no gradients).
            with torch.no_grad():
                teacher_labels = teacher(images).argmax(dim=1)
            cls_output, dist_output = student(images)
            loss_cls = criterion(cls_output, labels)
            loss_dist = criterion(dist_output, teacher_labels)
            loss = 0.5 * loss_cls + 0.5 * loss_dist
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
핵심 아이디어
문제: ViT의 O(n²) 복잡도 → 고해상도 이미지 처리 어려움
해결: 계층적 구조 + Shifted Window Attention
특징:
1. Window Attention: 지역 윈도우 내에서만 attention
2. Shifted Windows: 윈도우 간 정보 교환
3. 계층적 구조: 특징 맵 해상도 점진적 감소
Window Attention
def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping windows.

    Returns (B * num_windows, window_size, window_size, C); windows are
    ordered row-major within each image. H and W must be divisible by
    window_size.
    """
    B, H, W, C = x.shape
    grid_h, grid_w = H // window_size, W // window_size
    tiled = x.view(B, grid_h, window_size, grid_w, window_size, C)
    # Bring the two window-grid axes together, then flatten them out.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition.

    Stitches (B * num_windows, window_size, window_size, C) windows back
    into a (B, H, W, C) feature map.
    """
    windows_per_image = (H // window_size) * (W // window_size)
    B = windows.shape[0] // windows_per_image
    grid = windows.view(B, H // window_size, W // window_size,
                        window_size, window_size, -1)
    # Interleave window rows with in-window rows, then flatten to H x W.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention with a learned relative
    position bias (Swin Transformer).

    Operates on flattened windows: the input x is (num_windows*B, N, C)
    with N = window_size**2. Every ordered pair of positions inside a
    window receives a per-head bias looked up from
    `relative_position_bias_table` via a precomputed index.
    """
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # One learnable bias per head for each possible relative offset;
        # offsets range over (2*window_size - 1) values per axis.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Precompute the (N, N) table of relative-offset indices.
        self._create_relative_position_index()

    def _create_relative_position_index(self):
        # Pairwise 2D coordinate differences between all window positions.
        coords = torch.arange(self.window_size)
        coords = torch.stack(torch.meshgrid([coords, coords], indexing='ij'))
        coords_flatten = coords.flatten(1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift both axes to start at 0, then linearize (row * width + col).
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)
        # Buffer, not a parameter: moves with the module but is not trained.
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x, mask=None):
        # x: (num_windows*B, N, C). mask, if given, is added to the
        # attention logits before softmax.
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled dot-product attention logits.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Look up and add the relative position bias: (heads, N, N).
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(N, N, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # NOTE(review): mask is broadcast-added to (B_, heads, N, N);
            # the reference Swin (nW, N, N) mask shape would not broadcast
            # here without a reshape — confirm the caller's mask shape.
            attn = attn + mask
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x
Shifted Window
class SwinTransformerBlock(nn.Module):
    """Swin Transformer block: (shifted) window attention + MLP,
    each behind a LayerNorm with a residual connection.

    With shift_size > 0 the feature map is cyclically rolled before
    windowing so that consecutive blocks see different window boundaries.
    """
    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x, H, W):
        # x: (B, H*W, C) token sequence; H and W give its 2D layout.
        B, L, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        # Cyclic shift so that windows straddle the previous boundaries.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
        # Partition into windows and flatten each to a token sequence.
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        # Attention within each window.
        # NOTE(review): the reference Swin implementation passes an
        # attention mask here when shift_size > 0 so tokens wrapped around
        # by the cyclic shift cannot attend across the image border. This
        # simplified version omits the mask — confirm this is acceptable.
        attn_windows = self.attn(x_windows)
        # Merge windows back into the (shifted) feature map.
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, L, C)
        # Residual connections around attention, then around the MLP.
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x
8. 사전학습 모델 활용
torchvision 사용
from torchvision.models import vit_b_16, vit_l_16, swin_t, swin_s

# ViT-B/16 (pretrained)
model = vit_b_16(weights='IMAGENET1K_V1')
# Use as a feature extractor: drop the classification head
model.heads = nn.Identity()
features = model(image)  # (B, 768)
# Fine-tuning: replace the head with a task-specific classifier
model = vit_b_16(weights='IMAGENET1K_V1')
model.heads = nn.Linear(768, num_classes)
# Discriminative learning rates: small for the backbone, large for the head
params = [
    {'params': model.encoder.parameters(), 'lr': 1e-5},  # backbone
    {'params': model.heads.parameters(), 'lr': 1e-3}  # head
]
optimizer = torch.optim.AdamW(params)
timm 라이브러리
import timm
# List the available pretrained ViT models
vit_models = timm.list_models('vit*', pretrained=True)
print(f"Available ViT models: {len(vit_models)}")
# Load a model
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# Custom classification head
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=10  # head is replaced automatically
)
# DeiT model
deit_model = timm.create_model('deit_base_patch16_224', pretrained=True)
# Swin Transformer
swin_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
9. 실전 Fine-tuning
CIFAR-10 Fine-tuning
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
def finetune_vit_cifar10(epochs=10):
    """Fine-tune a pretrained ViT-Small on CIFAR-10.

    Resizes CIFAR images to the ViT input size with ImageNet
    normalization, trains with AdamW using discriminative learning rates
    (backbone 1e-5, head 1e-3) plus cosine annealing, and prints the test
    accuracy after every epoch.

    Returns:
        The fine-tuned model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ImageNet statistics; shared by both pipelines.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    # Datasets and loaders.
    train_data = datasets.CIFAR10('data', train=True, download=True, transform=transform_train)
    test_data = datasets.CIFAR10('data', train=False, transform=transform_test)
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_data, batch_size=64, shuffle=False, num_workers=4)

    # Pretrained backbone with a fresh 10-class head.
    model = timm.create_model('vit_small_patch16_224', pretrained=True, num_classes=10)
    model = model.to(device)

    # Discriminative learning rates: small for the backbone, large for the head.
    backbone_params, head_params = [], []
    for name, param in model.named_parameters():
        (head_params if 'head' in name else backbone_params).append(param)
    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': 1e-5},
        {'params': head_params, 'lr': 1e-3}
    ], weight_decay=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        # --- train ---
        model.train()
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # --- evaluate ---
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).max(1)[1]
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        accuracy = 100. * correct / total
        print(f'Epoch {epoch+1}/{epochs}: Loss={total_loss/len(train_loader):.4f}, Acc={accuracy:.2f}%')
        scheduler.step()
    return model
10. ViT vs CNN 비교
특성 비교
| 특성 | CNN | ViT |
| 귀납적 편향 | 지역성, 등변성 | 없음 |
| 데이터 요구량 | 적음 | 많음 |
| 계산 복잡도 | O(n) | O(n²) |
| 장거리 의존성 | 어려움 | 용이 |
| 해석 가능성 | 필터 시각화 | Attention 시각화 |
사용 가이드라인
CNN 선호:
- 소규모 데이터셋
- 제한된 계산 리소스
- 실시간 추론 필요
ViT 선호:
- 대규모 데이터셋 또는 사전학습 모델 활용
- 전역 컨텍스트가 중요한 태스크
- 멀티모달 학습 계획
정리
핵심 개념
- Patch Embedding: 이미지를 패치 시퀀스로 변환
- CLS Token: 전체 이미지 표현 학습
- Position Embedding: 패치 위치 정보 제공
- DeiT: 데이터 효율적 학습 (지식 증류)
- Swin: 윈도우 기반 효율적 attention
모델 선택 가이드
일반 분류: ViT-B/16 또는 DeiT
고해상도: Swin Transformer
제한된 자원: ViT-Small, DeiT-Tiny
최고 성능: ViT-Large, Swin-Large
PyTorch 실전 팁
# 1. Prefer timm for model creation
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)
# 2. Discriminative learning rates are essential
optimizer = torch.optim.AdamW([
    {'params': backbone_params, 'lr': 1e-5},
    {'params': head_params, 'lr': 1e-3}
])
# 3. Mind the input size (224, 384, etc.)
# 4. Use strong data augmentation
참고 자료
- ViT 원본: https://arxiv.org/abs/2010.11929
- DeiT: https://arxiv.org/abs/2012.12877
- Swin Transformer: https://arxiv.org/abs/2103.14030
- timm 라이브러리: https://github.com/huggingface/pytorch-image-models