22. Vision Transformer (ViT)


Overview

Vision Transformer (ViT) applies the Transformer architecture to image classification. An image is split into fixed-size patches, and each patch is processed like a token. Introduced in "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020).


Mathematical Background

1. Image Patchification

Input image: x ∈ R^(H Γ— W Γ— C)
Patch size: P Γ— P

Patch sequence:
x_p ∈ R^(N Γ— (PΒ²Β·C))  where N = (H Γ— W) / PΒ²

Example:
- Image: 224 Γ— 224 Γ— 3
- Patch size: 16 Γ— 16
- N = (224 Γ— 224) / (16 Γ— 16) = 196 patches
- Each flattened patch: 16 Γ— 16 Γ— 3 = 768 dimensions
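
The arithmetic above can be checked directly. A minimal patchify sketch in plain PyTorch, assuming a single 224 Γ— 224 RGB image (tensor names are illustrative):

import torch

H, W, C, P = 224, 224, 3, 16
img = torch.randn(1, C, H, W)                  # (B, C, H, W), channel-first

# Cut into non-overlapping P x P patches, then flatten each patch.
patches = img.unfold(2, P, P).unfold(3, P, P)  # (B, C, H/P, W/P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5)    # (B, H/P, W/P, C, P, P)
patches = patches.reshape(1, (H // P) * (W // P), P * P * C)

print(patches.shape)                           # torch.Size([1, 196, 768])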

2. Patch Embedding

Linear projection:
z_0 = [x_class; x_p¹E; x_p²E; ...; x_pᴺE] + E_pos

where:
- x_class: learnable [CLS] token
- E ∈ R^((PΒ²Β·C) Γ— D): patch embedding matrix
- E_pos ∈ R^((N+1) Γ— D): position embeddings

z_0 ∈ R^((N+1) Γ— D): the initial embedding sequence
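
A minimal sketch of this step, assuming flattened patches of shape (B, N, PΒ²Β·C); the class name PatchEmbed is illustrative, not the repo's API:

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    def __init__(self, num_patches=196, patch_dim=768, d_model=768):
        super().__init__()
        self.proj = nn.Linear(patch_dim, d_model)                   # E
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))   # x_class
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, d_model))               # E_pos

    def forward(self, patches):                  # patches: (B, N, P^2*C)
        B = patches.size(0)
        x = self.proj(patches)                   # (B, N, D)
        cls = self.cls_token.expand(B, -1, -1)   # one [CLS] per sample
        x = torch.cat([cls, x], dim=1)           # prepend [CLS] -> (B, N+1, D)
        return x + self.pos_embed                # z_0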

3. Transformer Encoder

Encoder block (repeated for l = 1, ..., L):

z'_l = MSA(LN(z_{l-1})) + z_{l-1}
z_l = MLP(LN(z'_l)) + z'_l

Final output:
y = LN(z_L⁰)  # only the [CLS] token is used

where z_L⁰ is the [CLS] token of the final (L-th) layer

ViT μ•„ν‚€ν…μ²˜ λ³€ν˜•

ViT-Base (B/16):
- Hidden size: 768
- Layers: 12
- Attention heads: 12
- MLP size: 3072
- Patch size: 16
- Parameters: 86M

ViT-Large (L/16):
- Hidden size: 1024
- Layers: 24
- Attention heads: 16
- MLP size: 4096
- Patch size: 16
- Parameters: 307M

ViT-Huge (H/14):
- Hidden size: 1280
- Layers: 32
- Attention heads: 16
- MLP size: 5120
- Patch size: 14
- Parameters: 632M
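
For reference, the three variants collected as a plain config table (the dict layout is illustrative, not the repo's format):

VIT_CONFIGS = {
    "ViT-B/16": dict(d_model=768,  layers=12, heads=12, mlp_dim=3072, patch=16),
    "ViT-L/16": dict(d_model=1024, layers=24, heads=16, mlp_dim=4096, patch=16),
    "ViT-H/14": dict(d_model=1280, layers=32, heads=16, mlp_dim=5120, patch=14),
}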

File Structure

10_ViT/
β”œβ”€β”€ README.md
β”œβ”€β”€ pytorch_lowlevel/
β”‚   └── vit_lowlevel.py         # ViT from scratch
β”œβ”€β”€ paper/
β”‚   └── vit_paper.py            # paper reproduction
└── exercises/
    β”œβ”€β”€ 01_patch_embedding.md   # patch embedding visualization
    └── 02_attention_maps.md    # attention map visualization

Key Concepts

1. CNN vs ViT

CNN:
- Local receptive fields
- Inductive biases: locality, translation equivariance
- Advantage on small datasets

ViT:
- Global receptive field from the first layer
- Minimal inductive bias
- Advantage on large datasets (e.g., JFT-300M)
- On small datasets: pre-training is required

2. Position Embedding

1D learnable (ViT default):
- N+1 learnable vectors
- Order information is learned from data

2D positional (variant):
- Separate embeddings for (row, col)
- Reflects the 2D image structure

Sinusoidal:
- Fixed trigonometric functions (see the sketch below)
- Can extrapolate to unseen positions
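
A minimal sketch of the fixed sinusoidal variant, following the original Transformer formula (the function name is illustrative):

import math
import torch

def sinusoidal_pos_embed(n_positions, d_model):
    pos = torch.arange(n_positions, dtype=torch.float).unsqueeze(1)   # (L, 1)
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float)
                    * (-math.log(10000.0) / d_model))                 # (D/2,)
    pe = torch.zeros(n_positions, d_model)
    pe[:, 0::2] = torch.sin(pos * div)   # even dims: sine
    pe[:, 1::2] = torch.cos(pos * div)   # odd dims: cosine
    return pe  # fixed, not learned; can be evaluated at unseen lengths

pe = sinusoidal_pos_embed(197, 768)      # N + 1 positions for ViT-B/16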

3. [CLS] Token vs Global Average Pooling

[CLS] token:
- Prepended at the first position
- Aggregates a representation of the whole image
- BERT-style

Global average pooling:
- Mean over all patch tokens
- CNN-style
- Comparable accuracy (see the sketch below)
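
The two pooling choices side by side, where z_L is the final encoder output of shape (B, N+1, D); the function names are illustrative:

def cls_head(z_L):
    return z_L[:, 0]               # take the [CLS] token at position 0

def gap_head(z_L):
    return z_L[:, 1:].mean(dim=1)  # average over the N patch tokens

Either representation is then passed through LayerNorm and a linear classifier.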

κ΅¬ν˜„ 레벨

Level 2: PyTorch Low-Level (pytorch_lowlevel/)

  • F.linear, F.layer_norm μ‚¬μš©
  • nn.TransformerEncoder λ―Έμ‚¬μš©
  • νŒ¨μΉ˜ν™” 직접 κ΅¬ν˜„

Level 3: Paper Implementation (paper/)

  • λ…Όλ¬Έ μ •ν™•ν•œ 사양
  • JFT/ImageNet pre-training
  • Fine-tuning μ½”λ“œ

Level 4: Code Analysis (separate)

  β€’ Analysis of the timm library
  β€’ Analysis of the HuggingFace ViT implementation

Learning Checklist

  β€’ [ ] Understand the patch embedding equation
  β€’ [ ] Role of position embeddings
  β€’ [ ] Role of the [CLS] token
  β€’ [ ] Strengths and weaknesses vs CNNs
  β€’ [ ] Attention map visualization
  β€’ [ ] Fine-tuning strategies

References

- Dosovitskiy, A., et al. (2020). "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". arXiv:2010.11929.