31. Variational Autoencoder (VAE)

Overview

The Variational Autoencoder (VAE) is a foundational generative model that learns latent representations of data and can generate new samples from them. It was introduced in "Auto-Encoding Variational Bayes" (Kingma & Welling, 2013).
Mathematical Background

1. Generative Model Goal

Goal: model p(x)
- x: observed data (images, etc.)
- z: latent variable

Generation process:
z ~ p(z)       # Prior (usually N(0, I))
x ~ p(x|z)     # Decoder/Generator

Problem: p(x) = ∫ p(x|z) p(z) dz is intractable
2. Variational Inference

The posterior p(z|x) is also intractable,
→ so we learn an approximate distribution q(z|x) (the Encoder).

ELBO (Evidence Lower BOund):
log p(x) ≥ E_q[log p(x|z)] - KL(q(z|x) || p(z))
           └─Reconstruction─┘ └──Regularization──┘
                                 (prior matching)

Objective to maximize:
L(θ, φ; x) = E_{q_φ(z|x)}[log p_θ(x|z)] - KL(q_φ(z|x) || p(z))
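The bound can be verified numerically in a toy model where every quantity is tractable. The model below (p(z) = N(0, 1), p(x|z) = N(z, 1), hence p(x) = N(0, 2) with exact posterior q(z|x) = N(x/2, 1/2)) is an illustrative assumption, not part of this repo:

```python
import numpy as np

# Toy model in which everything is tractable, so the ELBO can be checked:
#   p(z) = N(0, 1),  p(x|z) = N(z, 1)  =>  p(x) = N(0, 2)
# and the exact posterior is q(z|x) = N(x/2, 1/2).

def log_normal(x, mu, var):
    return -0.5 * np.log(2 * np.pi * var) - (x - mu) ** 2 / (2 * var)

def elbo(x, mu_q, var_q, n_samples=200_000, seed=0):
    rng = np.random.default_rng(seed)
    z = mu_q + np.sqrt(var_q) * rng.standard_normal(n_samples)
    recon = log_normal(x, z, 1.0).mean()                  # E_q[log p(x|z)], Monte Carlo
    kl = 0.5 * (var_q + mu_q ** 2 - 1.0 - np.log(var_q))  # KL(N(mu,var) || N(0,1))
    return recon - kl

x = 1.3
log_px = log_normal(x, 0.0, 2.0)        # exact log p(x)
tight = elbo(x, mu_q=x / 2, var_q=0.5)  # q = true posterior -> bound is tight
loose = elbo(x, mu_q=0.0, var_q=1.0)    # mismatched q -> strictly below log p(x)
```

With the exact posterior the ELBO equals log p(x); any other q leaves a gap of KL(q(z|x) || p(z|x)).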
3. Reparameterization Trick

Problem: sampling z ~ q(z|x) = N(μ, σ²) is not differentiable.
Solution: reparameterize the sample:
ε ~ N(0, I)
z = μ + σ ⊙ ε
Now gradients can backpropagate through μ and σ!

┌───────────────────────────┐
│ Encoder                   │
│   x → [μ, log σ²]         │
│          ↓                │
│ Reparameterization        │
│   ε ~ N(0, I)             │
│   z = μ + σ ⊙ ε           │
│          ↓                │
│ Decoder                   │
│   z → x̂                  │
└───────────────────────────┘
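A NumPy sketch of the trick (no autograd here; names are illustrative): the randomness is isolated in ε, so z becomes a deterministic, differentiable function of the encoder outputs.

```python
import numpy as np

rng = np.random.default_rng(0)

mu = np.array([0.5, -1.0])
log_var = np.array([0.2, -0.4])      # encoders output log sigma^2 for numerical stability
sigma = np.exp(0.5 * log_var)

eps = rng.standard_normal(mu.shape)  # eps ~ N(0, I): all randomness lives here
z = mu + sigma * eps                 # z = mu + sigma * eps

# Gradients w.r.t. the encoder outputs are now simple and exact:
dz_dmu = np.ones_like(mu)            # dz/dmu = 1
dz_dsigma = eps                      # dz/dsigma = eps

# Sanity check: many such samples reproduce N(mu, sigma^2)
zs = mu + sigma * rng.standard_normal((200_000, 2))
```

Sampling z ~ N(μ, σ²) directly would put the stochastic node between the loss and the encoder; moving it into ε sidesteps that.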
4. Loss Function

L = L_recon + β · L_KL

Reconstruction loss (images):
- Binary: BCE(x, x̂) = -Σ[x·log(x̂) + (1-x)·log(1-x̂)]
- Continuous: MSE(x, x̂) = ||x - x̂||²

KL divergence (Gaussian posterior vs. standard normal prior):
KL(N(μ, σ²) || N(0, 1)) = -½ Σ(1 + log σ² - μ² - σ²)

β-VAE:
- β > 1: stronger disentanglement
- β < 1: better reconstruction
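A NumPy sketch of this loss (the names x, x_hat, mu, log_var and the sum-over-dimensions / mean-over-batch reduction are assumptions; conventions vary between implementations):

```python
import numpy as np

def vae_loss(x, x_hat, mu, log_var, beta=1.0, eps=1e-7):
    x_hat = np.clip(x_hat, eps, 1 - eps)  # avoid log(0) in the BCE
    # Reconstruction: binary cross-entropy, summed over pixels
    recon = -np.sum(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat), axis=1)
    # KL(N(mu, sigma^2) || N(0, I)) in closed form, summed over latent dims
    kl = -0.5 * np.sum(1 + log_var - mu ** 2 - np.exp(log_var), axis=1)
    return np.mean(recon + beta * kl)     # mean over the batch

rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=(4, 784)).astype(float)  # fake binary images
x_hat = rng.uniform(0.01, 0.99, size=(4, 784))       # fake reconstructions
mu = rng.standard_normal((4, 16))
log_var = rng.standard_normal((4, 16))

loss = vae_loss(x, x_hat, mu, log_var, beta=1.0)
```

Note that the KL term is always non-negative, so increasing β can only increase the loss; with μ = 0 and log σ² = 0 the KL term vanishes exactly.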
VAE Architecture

Standard VAE (MNIST)

Encoder:
Input (28×28×1)
  ↓
Conv2d(1→32, k=3, s=2, p=1) → (14×14×32)
  ↓ ReLU
Conv2d(32→64, k=3, s=2, p=1) → (7×7×64)
  ↓ ReLU
Flatten → (7×7×64 = 3136)
  ↓
Linear(3136→256)
  ↓ ReLU
┌─────────────────┬─────────────────┐
│ Linear(256→z)   │ Linear(256→z)   │
│       μ         │     log σ²      │
└─────────────────┴─────────────────┘

Reparameterization:
z = μ + σ ⊙ ε,  ε ~ N(0, I)

Decoder:
z (latent_dim)
  ↓
Linear(z→256)
  ↓ ReLU
Linear(256→3136)
  ↓ ReLU
Reshape → (7×7×64)
  ↓
ConvT2d(64→32, k=3, s=2, p=1, op=1) → (14×14×32)
  ↓ ReLU
ConvT2d(32→1, k=3, s=2, p=1, op=1) → (28×28×1)
  ↓ Sigmoid
Output (28×28×1)
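The feature-map sizes above follow from the standard conv / transposed-conv output-size formulas; a quick arithmetic check:

```python
# conv:  out = (in + 2p - k) // s + 1
# convT: out = (in - 1) * s - 2p + k + op   (op = output_padding)

def conv_out(n, k, s, p):
    return (n + 2 * p - k) // s + 1

def convt_out(n, k, s, p, op):
    return (n - 1) * s - 2 * p + k + op

h = conv_out(28, k=3, s=2, p=1)        # 28 -> 14
h = conv_out(h, k=3, s=2, p=1)         # 14 -> 7
flat = h * h * 64                      # 7 * 7 * 64 = 3136

g = convt_out(7, k=3, s=2, p=1, op=1)  # 7 -> 14
g = convt_out(g, k=3, s=2, p=1, op=1)  # 14 -> 28
```

The output_padding of 1 is what makes each transposed conv exactly invert the spatial downsampling of the corresponding stride-2 conv.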
File Structure

11_VAE/
├── README.md
├── numpy/
│   └── vae_numpy.py          # NumPy VAE (forward only)
├── pytorch_lowlevel/
│   └── vae_lowlevel.py       # PyTorch low-level VAE
├── paper/
│   └── vae_paper.py          # Paper reproduction
└── exercises/
    ├── 01_latent_space.md    # Latent space visualization
    └── 02_interpolation.md   # Latent space interpolation
Core Concepts

1. Latent Space

Characteristics of a good latent space:
1. Continuity: nearby points produce similar outputs
2. Completeness: every point decodes to a meaningful output
3. Disentanglement (ideally): each dimension controls an independent feature

VAE vs AE:
- AE: point embeddings → a discontinuous latent space with empty regions
- VAE: distribution embeddings → a continuous latent space that can be sampled
2. VAE Variants

β-VAE (β > 1):
- Stronger KL regularization
- Better disentanglement
- Worse reconstruction

Conditional VAE (CVAE):
- Adds a condition c: q(z|x, c), p(x|z, c)
- Enables conditional generation

VQ-VAE:
- Discrete codebook instead of a continuous latent space
- Used in DALL-E, AudioLM, etc.
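One common way a CVAE injects the condition c (an assumption here; implementations differ) is to concatenate a one-hot label to the encoder input and to z before the decoder:

```python
import numpy as np

num_classes, latent_dim = 10, 16

def one_hot(y, n):
    out = np.zeros(n)
    out[y] = 1.0
    return out

x = np.random.default_rng(0).uniform(size=784)  # a flattened 28x28 image
c = one_hot(3, num_classes)                     # condition: digit class 3

encoder_in = np.concatenate([x, c])             # q(z|x, c) sees both x and c
z = np.zeros(latent_dim)                        # stand-in for a sampled z
decoder_in = np.concatenate([z, c])             # p(x|z, c) sees both z and c
```

At generation time, fixing c and sampling z ~ N(0, I) then yields samples of the chosen class.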
3. Training Stability

KL Annealing:
- Start with β = 0 (focus on reconstruction)
- Gradually increase β → 1 (add the regularization back in)

Free Bits:
- Enforce a minimum KL (prevents posterior collapse)
- L_KL = max(KL, λ)
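Both stabilizers are a few lines each. The linear warm-up schedule and the per-dimension clamp below are common choices, sketched here as assumptions rather than the repo's exact implementation:

```python
import numpy as np

def beta_schedule(step, warmup_steps=10_000):
    """KL annealing: ramp beta linearly from 0 to 1 over warmup_steps."""
    return min(1.0, step / warmup_steps)

def free_bits_kl(kl_per_dim, lam=0.5):
    """Per-dimension free bits: a dimension whose KL is below lam contributes
    a constant lam (so it receives no gradient pushing it further toward the
    prior); dimensions above lam are penalized as usual."""
    return np.sum(np.maximum(kl_per_dim, lam))

betas = [beta_schedule(s) for s in (0, 5_000, 20_000)]
clamped = free_bits_kl(np.array([0.1, 2.0]), lam=0.5)
```

In a training loop, beta_schedule(step) would multiply the KL term of the loss, and free_bits_kl would replace the plain KL sum.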
Implementation Levels

Level 2: PyTorch Low-Level (pytorch_lowlevel/)
- Use F.conv2d and F.linear directly
- Implement the reparameterization trick
- Implement the ELBO loss function

Level 3: Paper Implementation (paper/)
- Implement β-VAE
- Implement CVAE (conditional generation)
- Visualize the latent space
Learning Checklist

- [ ] Understand the ELBO derivation
- [ ] Understand the reparameterization trick
- [ ] Compute the KL divergence in closed form
- [ ] Understand the role of β
- [ ] Visualize the latent space
- [ ] Implement a Conditional VAE
References

- Kingma, D. P., & Welling, M. (2013). "Auto-Encoding Variational Bayes"
- Higgins, I., et al. (2017). "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework"
- ../Deep_Learning/16_VAE.md