31. Variational Autoencoder (VAE)

Overview

The Variational Autoencoder (VAE) is a foundational generative-model architecture that learns a latent representation of the data and can generate new samples from it. Reference: "Auto-Encoding Variational Bayes" (Kingma & Welling, 2013).


μˆ˜ν•™μ  λ°°κ²½

1. 생성 λͺ¨λΈ λͺ©ν‘œ

Goal: model p(x)
- x: observed data (e.g., images)
- z: latent variable

Generative process:
z ~ p(z)         # Prior (typically N(0, I))
x ~ p(x|z)       # Decoder/Generator

Problem: p(x) = ∫ p(x|z)p(z)dz is intractable

2. Variational Inference

The true posterior p(z|x) is also intractable
β†’ learn an approximate distribution q(z|x) (the Encoder)

ELBO (Evidence Lower BOund):
log p(x) β‰₯ E_q[log p(x|z)] - KL(q(z|x) || p(z))
         ────────────────   ─────────────────────
         Reconstruction     Regularization
         Loss               (Prior matching)

The bound follows from log p(x) = ELBO + KL(q(z|x) || p(z|x)) and KL β‰₯ 0, so maximizing the ELBO also pushes q(z|x) toward the true posterior.

Objective to maximize:
L(ΞΈ, Ο†; x) = E_{q_Ο†(z|x)}[log p_ΞΈ(x|z)] - KL(q_Ο†(z|x) || p(z))

3. Reparameterization Trick

Problem: sampling z ~ q(z|x) = N(ΞΌ, σ²) is not differentiable

Solution: reparameterization
Ξ΅ ~ N(0, I)
z = ΞΌ + Οƒ βŠ™ Ξ΅

Gradients can now backpropagate through ΞΌ and Οƒ (see the sketch after the diagram).

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Encoder                                β”‚
β”‚  x β†’ [ΞΌ, log σ²]                        β”‚
β”‚                                         β”‚
β”‚  Reparameterization                     β”‚
β”‚  Ξ΅ ~ N(0, I)                           β”‚
β”‚  z = ΞΌ + Οƒ βŠ™ Ξ΅                         β”‚
β”‚                                         β”‚
β”‚  Decoder                                β”‚
β”‚  z β†’ xΜ‚                                  β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
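A minimal PyTorch sketch of the trick (the function name and signature are illustrative):

import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)   # sigma = exp(log sigma^2 / 2)
    eps = torch.randn_like(std)     # all randomness lives in eps
    return mu + std * eps           # gradients flow to mu and logvar

Because eps carries all the randomness, z is a deterministic, differentiable function of the encoder outputs.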

4. Loss Function

L = L_recon + Ξ² * L_KL

Reconstruction Loss (images):
- Binary: BCE(x, xΜ‚) = -Ξ£[xΒ·log(xΜ‚) + (1-x)Β·log(1-xΜ‚)]
- Continuous: MSE(x, xΜ‚) = ||x - xΜ‚||Β²

KL Divergence (Gaussian prior):
KL(N(ΞΌ, σ²) || N(0, 1)) = -Β½ Ξ£(1 + log σ² - ΞΌΒ² - σ²)

Ξ²-VAE:
Ξ² > 1: stronger disentanglement
Ξ² < 1: better reconstruction
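
A sketch of the combined loss, assuming sigmoid (Bernoulli) decoder outputs and sum reduction; the function name and beta default are assumptions:

import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, logvar, beta=1.0):
    """ELBO-based loss: reconstruction + beta * KL."""
    # Reconstruction: BCE for binary pixels (swap in MSE for continuous data)
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    # Closed-form KL(N(mu, sigma^2) || N(0, I))
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl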

VAE μ•„ν‚€ν…μ²˜

ν‘œμ€€ VAE (MNIST)

Encoder:
Input (28Γ—28Γ—1)
    ↓
Conv2d(1β†’32, k=3, s=2, p=1)  β†’ (14Γ—14Γ—32)
    ↓ ReLU
Conv2d(32β†’64, k=3, s=2, p=1) β†’ (7Γ—7Γ—64)
    ↓ ReLU
Flatten β†’ (7Γ—7Γ—64 = 3136)
    ↓
Linear(3136β†’256)
    ↓ ReLU
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Linear(256β†’z)  β”‚ Linear(256β†’z)  β”‚
β”‚     ΞΌ          β”‚    log σ²      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Reparameterization:
z = ΞΌ + Οƒ βŠ™ Ξ΅,  Ξ΅ ~ N(0, I)

Decoder:
z (latent_dim)
    ↓
Linear(z→256)
    ↓ ReLU
Linear(256β†’3136)
    ↓ ReLU
Reshape β†’ (7Γ—7Γ—64)
    ↓
ConvT2d(64β†’32, k=3, s=2, p=1, op=1) β†’ (14Γ—14Γ—32)
    ↓ ReLU
ConvT2d(32β†’1, k=3, s=2, p=1, op=1)  β†’ (28Γ—28Γ—1)
    ↓ Sigmoid
Output (28Γ—28Γ—1)
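
A sketch of this architecture as a single PyTorch module; the layer sizes follow the diagram, while the class and attribute names are illustrative:

import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, latent_dim: int = 2):
        super().__init__()
        # Encoder: (1, 28, 28) -> 256-d features
        self.enc = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.ReLU(),   # 14x14x32
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # 7x7x64
            nn.Flatten(),
            nn.Linear(7 * 7 * 64, 256), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        # Decoder: latent -> (1, 28, 28)
        self.dec_fc = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 7 * 7 * 64), nn.ReLU(),
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),                                              # 14x14x32
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),                                           # 28x28x1
        )

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)        # reparameterization trick
        x_hat = self.dec(self.dec_fc(z).view(-1, 64, 7, 7))
        return x_hat, mu, logvar

Calling model(x) returns (x_hat, mu, logvar), which plug directly into the loss sketch above.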

File Structure

11_VAE/
β”œβ”€β”€ README.md
β”œβ”€β”€ numpy/
β”‚   └── vae_numpy.py          # NumPy VAE (forward pass only)
β”œβ”€β”€ pytorch_lowlevel/
β”‚   └── vae_lowlevel.py       # PyTorch Low-Level VAE
β”œβ”€β”€ paper/
β”‚   └── vae_paper.py          # Paper reproduction
└── exercises/
    β”œβ”€β”€ 01_latent_space.md    # Latent space visualization
    └── 02_interpolation.md   # Latent space interpolation

Key Concepts

1. Latent Space

Properties of a good latent space:
1. Continuity: nearby points decode to similar outputs
2. Completeness: every point decodes to a meaningful output
3. Disentanglement: each dimension controls an independent factor of variation

VAE vs AE:
- AE: embeds points β†’ discontinuous latent space with empty regions
- VAE: embeds distributions β†’ continuous latent space, supports sampling (see the interpolation sketch below)
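
The 02_interpolation.md exercise probes the continuity property by decoding evenly spaced points between two latent codes. A sketch, assuming the VAE module from the architecture section above:

import torch

@torch.no_grad()
def interpolate(model, x_a, x_b, steps=8):
    """Decode points on the line between the latent codes of x_a and x_b."""
    z_a = model.fc_mu(model.enc(x_a))                # posterior means, (1, d)
    z_b = model.fc_mu(model.enc(x_b))
    alphas = torch.linspace(0, 1, steps).view(-1, 1)
    z = (1 - alphas) * z_a + alphas * z_b            # (steps, d) by broadcasting
    return model.dec(model.dec_fc(z).view(-1, 64, 7, 7))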

2. VAE Variants

Ξ²-VAE (Ξ² > 1):
- Stronger KL regularization
- Better disentanglement
- Worse reconstruction

Conditional VAE (CVAE):
- Adds a condition c: q(z|x, c), p(x|z, c)
- Enables conditional generation
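
One common conditioning scheme (an assumption here, not the only option) concatenates a one-hot label onto the latent code before decoding; the encoder input can be conditioned the same way:

import torch
import torch.nn.functional as F

def condition(z: torch.Tensor, y: torch.Tensor, num_classes: int = 10):
    """Append a one-hot class label c to a latent code, giving p(x|z, c)."""
    c = F.one_hot(y, num_classes).float()
    return torch.cat([z, c], dim=1)   # decoder input dim: latent_dim + num_classes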

VQ-VAE:
- Discrete codebook instead of a continuous latent space
- Used in DALL-E, AudioLM, and others
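
A minimal sketch of the codebook lookup at the heart of VQ-VAE, with a straight-through estimator so gradients bypass the non-differentiable argmin; the codebook size is illustrative and the paper's codebook/commitment losses are omitted:

import torch
import torch.nn as nn

class VectorQuantizer(nn.Module):
    def __init__(self, num_codes: int = 512, dim: int = 64):
        super().__init__()
        self.codebook = nn.Embedding(num_codes, dim)

    def forward(self, z_e):                             # z_e: (B, dim)
        dists = torch.cdist(z_e, self.codebook.weight)  # (B, num_codes)
        idx = dists.argmin(dim=1)                       # nearest code per vector
        z_q = self.codebook(idx)
        # Straight-through: forward uses z_q, backward copies grads to z_e
        return z_e + (z_q - z_e).detach(), idx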

3. Training Stability

KL Annealing:
- Start with Ξ²=0 (focus on reconstruction)
- Gradually anneal Ξ²β†’1 (restore the regularizer)

Free Bits:
- Guarantees a minimum KL per latent dimension (prevents posterior collapse)
- L_KL = Ξ£_i max(KL_i, Ξ»), applied per dimension (both techniques are sketched below)
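
A sketch of both stabilizers; the linear warm-up schedule and the free_bits default are assumptions:

import torch

def kl_beta(step: int, warmup_steps: int = 10_000) -> float:
    """KL annealing: ramp beta linearly from 0 to 1 over warmup_steps."""
    return min(1.0, step / warmup_steps)

def free_bits_kl(mu, logvar, free_bits: float = 0.5):
    """Clamp the per-dimension KL from below so each latent dimension
    keeps at least free_bits nats, discouraging posterior collapse."""
    kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())  # (B, d)
    return torch.clamp(kl_per_dim, min=free_bits).sum(dim=1).mean()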

Implementation Levels

Level 1: NumPy (numpy/)

  • Forward pass only

Level 2: PyTorch Low-Level (pytorch_lowlevel/)

  • Uses F.conv2d and F.linear directly
  • Implements the reparameterization trick
  • Implements the ELBO loss

Level 3: Paper Implementation (paper/)

  • Ξ²-VAE implementation
  • Conditional VAE (CVAE) implementation
  • Latent space visualization

Learning Checklist

  • [ ] Understand the ELBO derivation
  • [ ] Understand the reparameterization trick
  • [ ] Compute the KL divergence term
  • [ ] Understand the role of Ξ²
  • [ ] Visualize the latent space
  • [ ] Implement a Conditional VAE

References

  • Kingma & Welling (2013). "Auto-Encoding Variational Bayes"
  • Higgins et al. (2017). "Ξ²-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework"
  • 30_Generative_Models_VAE.md