31. Variational Autoencoder (VAE)

Overview

The Variational Autoencoder (VAE) is a foundational generative model that learns latent representations of data and can generate new samples. It was introduced in "Auto-Encoding Variational Bayes" (Kingma & Welling, 2013).


Mathematical Background

1. Generative Model Goal

Goal: model p(x)
- x: observed data (images, etc.)
- z: latent variable

Generation process:
z ~ p(z)         # Prior (usually N(0, I))
x ~ p(x|z)       # Decoder/Generator

Problem: p(x) = โˆซ p(x|z)p(z)dz is intractable
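To see concretely why this integral is hard, here is a toy 1-D Monte Carlo sketch; the decoder f and the noise scale are made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 1-D model: z ~ N(0, 1), x | z ~ N(f(z), 0.1^2), with a made-up decoder f.
f = np.tanh
sigma_x = 0.1
x_obs = 0.5

# Naive Monte Carlo estimate of p(x) = integral of p(x|z) p(z) dz,
# using samples from the prior.
z = rng.standard_normal(100_000)
lik = np.exp(-(x_obs - f(z)) ** 2 / (2 * sigma_x**2)) / (sigma_x * np.sqrt(2 * np.pi))
p_x = lik.mean()
# Most prior samples land where p(x|z) is near zero, so the estimate has
# high variance; in high dimensions virtually no prior sample explains x,
# which is why an encoder q(z|x) that proposes likely z values is needed.
```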

2. Variational Inference

Posterior p(z|x) is also intractable
โ†’ Learn approximate distribution q(z|x) (Encoder)

ELBO (Evidence Lower BOund):
log p(x) โ‰ฅ E_q[log p(x|z)] - KL(q(z|x) || p(z))
         โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€   โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
         Reconstruction     Regularization
         Loss               (Prior matching)

Objective to maximize:
L(ฮธ, ฯ†; x) = E_q_ฯ†(z|x)[log p_ฮธ(x|z)] - KL(q_ฯ†(z|x) || p(z))

3. Reparameterization Trick

Problem: sampling z ~ q(z|x) = N(ฮผ, ฯƒยฒ) is not differentiable

Solution: Reparameterization
ฮต ~ N(0, I)
z = ฮผ + ฯƒ โŠ™ ฮต

Now gradients can backpropagate through ฮผ and ฯƒ!

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚  Encoder                                โ”‚
โ”‚  x โ†’ [ฮผ, log ฯƒยฒ]                        โ”‚
โ”‚                                         โ”‚
โ”‚  Reparameterization                     โ”‚
โ”‚  ฮต ~ N(0, I)                           โ”‚
โ”‚  z = ฮผ + ฯƒ โŠ™ ฮต                         โ”‚
โ”‚                                         โ”‚
โ”‚  Decoder                                โ”‚
โ”‚  z โ†’ xฬ‚                                  โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
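The trick in the diagram can be sketched in a few lines of PyTorch; the tensor shapes below are arbitrary:

```python
import torch

def reparameterize(mu, logvar):
    """z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar."""
    std = torch.exp(0.5 * logvar)   # sigma recovered from log sigma^2
    eps = torch.randn_like(std)     # eps ~ N(0, I); no gradient flows into eps
    return mu + std * eps

mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)
z = reparameterize(mu, logvar)
z.sum().backward()                  # gradients reach mu and logvar
```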

4. Loss Function

L = L_recon + ฮฒ * L_KL

Reconstruction Loss (images):
- Binary: BCE(x, xฬ‚) = -ฮฃ[xยทlog(xฬ‚) + (1-x)ยทlog(1-xฬ‚)]
- Continuous: MSE(x, xฬ‚) = ||x - xฬ‚||ยฒ

KL Divergence (Gaussian prior):
KL(N(ฮผ, ฯƒยฒ) || N(0, 1)) = -ยฝ ฮฃ(1 + log ฯƒยฒ - ฮผยฒ - ฯƒยฒ)

ฮฒ-VAE:
ฮฒ > 1: stronger disentanglement
ฮฒ < 1: better reconstruction
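Putting the two terms together, a sketch of the ELBO loss for a binary (BCE) likelihood, with the closed-form Gaussian KL from above; the batch averaging convention is one common choice:

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_hat, mu, logvar, beta=1.0):
    # Reconstruction: BCE summed over pixels, averaged over the batch
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum') / x.size(0)
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon + beta * kl

x = torch.rand(8, 784)
x_hat = torch.rand(8, 784).clamp(1e-6, 1 - 1e-6)
mu, logvar = torch.zeros(8, 16), torch.zeros(8, 16)
loss = vae_loss(x, x_hat, mu, logvar)   # KL term is 0 when mu=0, logvar=0
```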

VAE Architecture

Standard VAE (MNIST)

Encoder:
Input (28ร—28ร—1)
    โ†“
Conv2d(1โ†’32, k=3, s=2, p=1)  โ†’ (14ร—14ร—32)
    โ†“ ReLU
Conv2d(32โ†’64, k=3, s=2, p=1) โ†’ (7ร—7ร—64)
    โ†“ ReLU
Flatten โ†’ (7ร—7ร—64 = 3136)
    โ†“
Linear(3136โ†’256)
    โ†“ ReLU
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ Linear(256โ†’z)  โ”‚ Linear(256โ†’z)  โ”‚
โ”‚     ฮผ          โ”‚    log ฯƒยฒ      โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

Reparameterization:
z = ฮผ + ฯƒ โŠ™ ฮต,  ฮต ~ N(0, I)

Decoder:
z (latent_dim)
    โ†“
Linear(zโ†’256)
    โ†“ ReLU
Linear(256โ†’3136)
    โ†“ ReLU
Reshape โ†’ (7ร—7ร—64)
    โ†“
ConvT2d(64โ†’32, k=3, s=2, p=1, op=1) โ†’ (14ร—14ร—32)
    โ†“ ReLU
ConvT2d(32โ†’1, k=3, s=2, p=1, op=1)  โ†’ (28ร—28ร—1)
    โ†“ Sigmoid
Output (28ร—28ร—1)
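The full architecture above maps directly onto an nn.Module; this is a minimal sketch (the latent dimension of 16 is an arbitrary choice):

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self, latent_dim=16):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),                                # 7*7*64 = 3136
            nn.Linear(3136, 256),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 3136),
            nn.ReLU(),
            nn.Unflatten(1, (64, 7, 7)),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),   # 14x14 -> 28x28
            nn.Sigmoid(),
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)   # reparameterization trick
        return self.decoder(z), mu, logvar

x = torch.rand(2, 1, 28, 28)
x_hat, mu, logvar = VAE()(x)
```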

File Structure

11_VAE/
โ”œโ”€โ”€ README.md
โ”œโ”€โ”€ numpy/
โ”‚   โ””โ”€โ”€ vae_numpy.py          # NumPy VAE (forward only)
โ”œโ”€โ”€ pytorch_lowlevel/
โ”‚   โ””โ”€โ”€ vae_lowlevel.py       # PyTorch Low-Level VAE
โ”œโ”€โ”€ paper/
โ”‚   โ””โ”€โ”€ vae_paper.py          # Paper reproduction
โ””โ”€โ”€ exercises/
    โ”œโ”€โ”€ 01_latent_space.md    # Latent space visualization
    โ””โ”€โ”€ 02_interpolation.md   # Latent space interpolation

Core Concepts

1. Latent Space

Good latent space characteristics:
1. Continuity: nearby points produce similar outputs
2. Completeness: all points generate meaningful outputs
3. Disentanglement: each dimension controls an independent factor of variation

VAE vs AE:
- AE: embeds each input as a point โ†’ latent space is discontinuous, with empty regions that decode poorly
- VAE: embeds each input as a distribution โ†’ latent space is continuous and can be sampled
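Continuity is what makes latent interpolation work: points on a path between two codes also decode to plausible outputs. A sketch with an untrained, hypothetical decoder standing in for a real VAE decoder:

```python
import torch
import torch.nn as nn

# Hypothetical 2-D-latent decoder (untrained, for shape illustration only).
decoder = nn.Sequential(nn.Linear(2, 64), nn.ReLU(), nn.Linear(64, 784), nn.Sigmoid())

z_a, z_b = torch.randn(2), torch.randn(2)
steps = torch.linspace(0, 1, 8).unsqueeze(1)    # 8 interpolation weights in [0, 1]
z_path = (1 - steps) * z_a + steps * z_b        # linear path in latent space
images = decoder(z_path).view(8, 28, 28)        # decode each point on the path
```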

2. VAE Variants

ฮฒ-VAE (ฮฒ > 1):
- Stronger KL regularization
- Better disentanglement
- Worse reconstruction

Conditional VAE (CVAE):
- Add condition c: q(z|x, c), p(x|z, c)
- Enables conditional generation
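One common way to implement the conditioning is to concatenate a one-hot label to both the encoder input and the latent code; the MLP sizes here are made up:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_classes, latent_dim = 10, 16
enc = nn.Linear(784 + n_classes, 2 * latent_dim)              # outputs [mu, logvar]
dec = nn.Sequential(nn.Linear(latent_dim + n_classes, 784), nn.Sigmoid())

x = torch.rand(4, 784)
c = F.one_hot(torch.tensor([0, 1, 2, 3]), n_classes).float()  # condition
mu, logvar = enc(torch.cat([x, c], dim=1)).chunk(2, dim=1)    # q(z | x, c)
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
x_hat = dec(torch.cat([z, c], dim=1))                         # p(x | z, c)
```

At generation time, sampling z from the prior and varying c produces class-conditional samples.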

VQ-VAE:
- Discrete codebook instead of continuous latent space
- Used in DALL-E, AudioLM, etc.

3. Training Stability

KL Annealing:
- Initial: ฮฒ=0 (focus on reconstruction)
- Gradually ฮฒโ†’1 (add regularization)

Free Bits:
- Ensure minimum KL (prevent posterior collapse)
- L_KL = max(KL, ฮป)
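Both tricks are a few lines in practice. The sketch below uses a linear annealing schedule and a per-dimension free-bits variant; the warmup length and ฮป are arbitrary:

```python
import torch

def beta_schedule(step, warmup_steps=10_000):
    """Linear KL annealing: beta ramps from 0 to 1 over warmup_steps."""
    return min(1.0, step / warmup_steps)

def free_bits_kl(kl_per_dim, lam=0.5):
    """Clamp each dimension's KL at lam: dims below the floor contribute a
    constant, so there is no gradient pressure to shrink them further."""
    return torch.clamp(kl_per_dim, min=lam).sum()

kl = torch.tensor([0.1, 0.8, 2.0])
total = free_bits_kl(kl)        # 0.5 + 0.8 + 2.0 = 3.3
```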

Implementation Levels

Level 2: PyTorch Low-Level (pytorch_lowlevel/)

  • Directly use F.conv2d, F.linear
  • Implement reparameterization trick
  • Implement ELBO loss function

Level 3: Paper Implementation (paper/)

  • Implement ฮฒ-VAE
  • Implement CVAE (Conditional)
  • Latent space visualization

Learning Checklist

  • [ ] Understand ELBO derivation process
  • [ ] Understand reparameterization trick
  • [ ] Calculate KL divergence
  • [ ] Understand role of ฮฒ
  • [ ] Visualize latent space
  • [ ] Implement Conditional VAE

References

  • Kingma & Welling (2013). "Auto-Encoding Variational Bayes"
  • Higgins et al. (2017). "ฮฒ-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework"
  • ../Deep_Learning/16_VAE.md