30. Generative Models - VAE (Variational Autoencoder)

Learning Objectives

  • Theoretical foundation of VAE (Variational Inference)
  • Understanding Latent Space and probabilistic generation
  • Deriving the ELBO loss function
  • Reparameterization Trick
  • Beta-VAE and Disentanglement
  • PyTorch implementation and visualization

1. VAE Theory

Autoencoder vs VAE

Autoencoder:
    Input → Encoder → Latent vector z (deterministic) → Decoder → Reconstruction

VAE:
    Input → Encoder → Mean mu, standard deviation sigma → Sample z ~ N(mu, sigma^2) → Decoder → Reconstruction
                         ↓
                    Latent space is continuous and regularized toward a standard normal prior
                    → New images can be generated by decoding samples from the prior

Why Probabilistic?

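Problems with a regular autoencoder:
- The latent space can be discontinuous, with gaps between the codes of training examples
- Decoding a z that lies away from the encoded training data produces meaningless output
- There is no distribution to sample z from, so it is hard to use as a generative model

VAE solution:
- Regularize each encoded distribution toward the standard normal prior N(0, I)
- This yields a continuous, densely packed latent space
- New data can be generated by sampling z ~ N(0, I) and decoding it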

Graphical Model

Generative Process:
    z ~ p(z) = N(0, I)           # Prior distribution
    x ~ p_theta(x|z)             # Decoder

Inference:
    q_phi(z|x) ≈ p_theta(z|x)    # Encoder approximates the intractable posterior

Goal:
    Maximize log p_theta(x) (data likelihood)
    → Since log p_theta(x) is intractable, maximize the ELBO (Evidence Lower Bound) instead
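The generative process needs only the prior and the decoder; the encoder is used only during training. A minimal sketch of this ancestral sampling, assuming a hypothetical untrained linear decoder whose sigmoid output is read as per-pixel Bernoulli probabilities (standing in for VAEDecoder from Section 4):

```python
import torch

torch.manual_seed(0)

latent_dim, data_dim, num_samples = 2, 784, 4

# Hypothetical stand-in decoder: sigmoid outputs are the Bernoulli
# parameters of p_theta(x|z). A trained VAEDecoder plays this role.
decoder = torch.nn.Sequential(torch.nn.Linear(latent_dim, data_dim), torch.nn.Sigmoid())

with torch.no_grad():
    z = torch.randn(num_samples, latent_dim)  # z ~ p(z) = N(0, I)
    probs = decoder(z)                        # parameters of p_theta(x|z)
    x = torch.bernoulli(probs)                # x ~ p_theta(x|z), binary pixels

print(z.shape, probs.shape, x.shape)
```

VAE.generate in Section 4 returns the decoder's probabilities (the Bernoulli mean) rather than drawing x, which is the usual choice when displaying samples.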

2. ELBO Loss Function

Derivation

log p(x) = log ∫ p(x, z) dz

         = log ∫ p(x|z) p(z) dz

         = log ∫ q(z|x) * [p(x|z) p(z) / q(z|x)] dz

         ≥ ∫ q(z|x) log[p(x|z) p(z) / q(z|x)] dz    (Jensen's inequality: log is concave)

         = E_q[log p(x|z)] + E_q[log p(z) - log q(z|x)]

         = E_q[log p(x|z)] - KL(q(z|x) || p(z))

         = ELBO (Evidence Lower Bound)
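The gap in the Jensen step is exactly a KL divergence: log p(x) = ELBO + KL(q(z|x) || p(z|x)), so the bound is tight when the encoder matches the true posterior. This can be checked numerically on a model small enough for log p(x) and p(z|x) to be computed exactly. A minimal sketch, assuming a hypothetical 1-D linear-Gaussian model (the values of w, b, s2 and x are arbitrary):

```python
import math

# Hypothetical 1-D linear-Gaussian model where log p(x) is tractable:
#   z ~ N(0, 1),  x | z ~ N(w*z + b, s2)
w, b, s2 = 2.0, 0.5, 0.3
x = 1.7

def log_normal(v, mean, var):
    """log N(v; mean, var) for scalars."""
    return -0.5 * (math.log(2 * math.pi * var) + (v - mean) ** 2 / var)

def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for scalar Gaussians."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1)

def elbo(mu_q, var_q):
    """ELBO for q(z|x) = N(mu_q, var_q); both terms are in closed form."""
    # E_q[(x - b - w*z)^2] = (x - b - w*mu_q)^2 + w^2 * var_q
    expected_loglik = -0.5 * (math.log(2 * math.pi * s2)
                              + ((x - b - w * mu_q) ** 2 + w ** 2 * var_q) / s2)
    return expected_loglik - kl_gauss(mu_q, var_q, 0.0, 1.0)

# Exact evidence: marginally, x ~ N(b, w^2 + s2)
log_px = log_normal(x, b, w ** 2 + s2)

# Exact posterior p(z|x) = N(post_mean, post_var)
post_var = 1.0 / (1.0 + w ** 2 / s2)
post_mean = post_var * w * (x - b) / s2

print(f"log p(x)            = {log_px:.6f}")
print(f"ELBO, q = posterior = {elbo(post_mean, post_var):.6f}")
print(f"ELBO, q = N(0, 1)   = {elbo(0.0, 1.0):.6f}")
```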

Meaning of Two Terms

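# ELBO = Reconstruction term - KL divergence  (training minimizes -ELBO)

# 1. Reconstruction term: E_q[log p(x|z)]
#    - How well the decoder reconstructs x from z
#    - Its negative is the reconstruction loss: BCE for a Bernoulli decoder,
#      MSE (up to scale and a constant) for a Gaussian decoder

# 2. KL divergence: KL(q(z|x) || p(z))
#    - How close the encoded distribution is to the prior N(0, I)
#    - Acts as a regularizer on the latent space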
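The link between the reconstruction term and the familiar losses can be checked numerically: for a single sample z, BCE with reduction='sum' equals the negative log-likelihood of a Bernoulli decoder, and half the summed squared error plus a constant equals the negative log-likelihood of a Gaussian decoder with unit variance. A minimal sketch on random tensors (the shapes and the 0.3 pixel probability are arbitrary):

```python
import math
import torch
import torch.nn.functional as F
from torch.distributions import Bernoulli, Normal

torch.manual_seed(0)
# Binary targets (Bernoulli.log_prob validates that values lie in {0, 1})
x = torch.bernoulli(torch.full((8, 784), 0.3))
# Decoder outputs: per-pixel probabilities, kept away from exactly 0 or 1
x_recon = torch.rand(8, 784).clamp(1e-4, 1 - 1e-4)

# Bernoulli decoder: BCE (sum) == -log p(x|z)
bce = F.binary_cross_entropy(x_recon, x, reduction='sum')
bern_nll = -Bernoulli(probs=x_recon).log_prob(x).sum()

# Gaussian decoder, unit variance: -log p(x|z) = 0.5 * SSE + 0.5 * D * log(2*pi)
sse = F.mse_loss(x_recon, x, reduction='sum')
gauss_nll = -Normal(x_recon, 1.0).log_prob(x).sum()
const = 0.5 * x.numel() * math.log(2 * math.pi)

print(bce.item(), bern_nll.item())
print((0.5 * sse + const).item(), gauss_nll.item())
```

So a sum-reduced MSE loss corresponds to a fixed-variance Gaussian decoder; dropping the factor of 0.5 effectively doubles the weight of the reconstruction term relative to the KL term.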

Loss Function Implementation

import torch
import torch.nn.functional as F

def vae_loss(x, x_recon, mu, log_var):
    """VAE loss function (⭐⭐⭐)

    Args:
        x: original images (batch, ...), values in [0, 1] for BCE
        x_recon: reconstructed images (batch, ...), decoder outputs in [0, 1]
        mu: mean (batch, latent_dim)
        log_var: log variance (batch, latent_dim)

    Returns:
        total_loss, recon_loss, kl_loss (each summed over the batch)
    """
    # Reconstruction loss (binary images: BCE, continuous images: MSE)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Or MSE
    # recon_loss = F.mse_loss(x_recon, x, reduction='sum')

    # KL Divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    total_loss = recon_loss + kl_loss

    return total_loss, recon_loss, kl_loss
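Per latent dimension, KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2), which rearranges to the -0.5 * sum(1 + log_var - mu^2 - exp(log_var)) expression in vae_loss. A quick cross-check against torch.distributions.kl_divergence on random inputs (the shapes are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 20)
log_var = torch.randn(4, 20)

# Closed form used in vae_loss (summed over batch and latent dimensions)
kl_closed = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Same quantity via torch.distributions: diagonal Gaussian q(z|x) vs N(0, I)
q = Normal(mu, torch.exp(0.5 * log_var))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_per_dim = kl_divergence(q, p)   # elementwise, shape (4, 20)
kl_lib = kl_per_dim.sum()

print(kl_closed.item(), kl_lib.item())
```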

3. Reparameterization Trick

Problem

Sampling z ~ q(z|x) = N(mu, sigma^2)

Problem: the sampling step has no gradient with respect to mu and sigma → backpropagation cannot reach the encoder

Solution: Reparameterization

import torch

def reparameterize(mu, log_var):
    """Reparameterization Trick (⭐⭐⭐)

    z = mu + sigma * epsilon
    epsilon ~ N(0, I)

    This way, randomness is in epsilon,
    and we can differentiate with respect to mu, sigma
    """
    std = torch.exp(0.5 * log_var)  # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(std)     # epsilon ~ N(0, I)
    z = mu + std * eps
    return z

Visual Understanding

[Not differentiable]
mu, sigma → Sampling → z → Decoder

[Differentiable - Reparameterization]
mu ──────────────┐
                 │
                 ▼
sigma ────────▶ (mu + sigma * eps) ──▶ z ──▶ Decoder
                       ▲
eps ~ N(0, I) ─────────┘ (treated as constant)
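The gradients in the diagram can be checked with autograd: z = mu + std * eps gives dz/dmu = 1 and dz/dlog_var = 0.5 * std * eps. A minimal sketch with hand-picked values; torch.distributions.Normal exposes the same distinction, where sample() is not differentiable and rsample() uses the reparameterized form:

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.tensor([0.5, -1.0], requires_grad=True)
log_var = torch.tensor([0.0, 1.0], requires_grad=True)

std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)   # sampled noise; no gradient flows into it
z = mu + std * eps            # reparameterized sample

z.sum().backward()

# dz/dmu = 1 and dz/dlog_var = 0.5 * std * eps
print(mu.grad, log_var.grad)

print(Normal(mu, std).sample().requires_grad)   # False: sample() runs without gradient tracking
print(Normal(mu, std).rsample().requires_grad)  # True: rsample() computes loc + eps * scale
```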

4. VAE Model Implementation

Encoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAEEncoder(nn.Module):
    """VAE Encoder (⭐⭐⭐)

    Image → mu, log_var
    """
    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),  # 28 → 14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),           # 14 → 7
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

    def forward(self, x):
        h = self.conv_layers(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var

Decoder

class VAEDecoder(nn.Module):
    """VAE Decoder (⭐⭐⭐)

    z → image
    """
    def __init__(self, latent_dim=20, out_channels=1):
        super().__init__()

        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)

        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 7 → 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),  # 14 → 28
            nn.Sigmoid()  # [0, 1]
        )

    def forward(self, z):
        h = self.fc(z)
        h = h.view(-1, 64, 7, 7)
        x_recon = self.deconv_layers(h)
        return x_recon
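The size comments (28 → 14 → 7 and back) follow from the output-size formulas in the PyTorch documentation for nn.Conv2d and nn.ConvTranspose2d with dilation=1. A small pure-Python check of that arithmetic, which is also why both fully connected layers use 64 * 7 * 7 features; the helper names are hypothetical:

```python
def conv2d_out(size, kernel, stride, padding):
    """Spatial output size of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride, padding, output_padding=0):
    """Spatial output size of nn.ConvTranspose2d (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: two Conv2d(kernel=3, stride=2, padding=1) layers
enc = [28]
for _ in range(2):
    enc.append(conv2d_out(enc[-1], kernel=3, stride=2, padding=1))

# Decoder: two ConvTranspose2d(kernel=4, stride=2, padding=1) layers
dec = [7]
for _ in range(2):
    dec.append(conv_transpose2d_out(dec[-1], kernel=4, stride=2, padding=1))

print(enc, dec)  # [28, 14, 7] [7, 14, 28]
```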

Complete VAE

class VAE(nn.Module):
    """Variational Autoencoder (⭐⭐⭐)"""
    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim)
        self.decoder = VAEDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

    def generate(self, num_samples, device):
        """Generate new samples by decoding z ~ N(0, I)"""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        samples = self.decoder(z)
        return samples

    def reconstruct(self, x):
        """Reconstruct images (no gradient tracking)"""
        with torch.no_grad():
            x_recon, _, _ = self.forward(x)
        return x_recon

5. Training Loop

def train_vae(model, dataloader, epochs=50, lr=1e-3):
    """Train the VAE (⭐⭐⭐)"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_recon = 0
        total_kl = 0

        for batch_idx, (data, _) in enumerate(dataloader):
            data = data.to(device)

            optimizer.zero_grad()

            # Forward
            x_recon, mu, log_var = model(data)

            # Loss
            loss, recon_loss, kl_loss = vae_loss(data, x_recon, mu, log_var)

            # Normalize by batch size
            loss = loss / data.size(0)

            # Backward
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon_loss.item() / data.size(0)
            total_kl += kl_loss.item() / data.size(0)

        avg_loss = total_loss / len(dataloader)
        avg_recon = total_recon / len(dataloader)
        avg_kl = total_kl / len(dataloader)

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}")

    return model
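A quick way to sanity-check the loop logic before a full MNIST run is a short smoke test on synthetic data. A minimal sketch, assuming a hypothetical small MLP VAE (TinyVAE) and random binary patterns in place of MNIST; it uses the same per-sample negative ELBO as train_vae and only checks that the loss goes down, not any particular value:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TinyVAE(nn.Module):
    """Hypothetical MLP VAE used only for this smoke test."""
    def __init__(self, data_dim=64, latent_dim=4, hidden=128):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(data_dim, hidden), nn.ReLU())
        self.fc_mu = nn.Linear(hidden, latent_dim)
        self.fc_logvar = nn.Linear(hidden, latent_dim)
        self.dec = nn.Sequential(nn.Linear(latent_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, data_dim), nn.Sigmoid())

    def forward(self, x):
        h = self.enc(x)
        mu, log_var = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)  # reparameterization
        return self.dec(z), mu, log_var

# Synthetic binary data: 4 random prototype patterns with 5% of bits flipped
prototypes = torch.bernoulli(torch.full((4, 64), 0.5))
idx = torch.randint(0, 4, (512,))
flips = torch.bernoulli(torch.full((512, 64), 0.05))
data = (prototypes[idx] + flips) % 2

model = TinyVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = []
for step in range(300):
    batch = data[torch.randint(0, 512, (64,))]
    optimizer.zero_grad()
    x_recon, mu, log_var = model(batch)
    recon = F.binary_cross_entropy(x_recon, batch, reduction='sum')
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    loss = (recon + kl) / batch.size(0)  # per-sample negative ELBO, as in train_vae
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first 10 steps: {sum(losses[:10]) / 10:.2f}, last 10 steps: {sum(losses[-10:]) / 10:.2f}")
```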

6. Beta-VAE

Idea

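Beta-VAE objective = Reconstruction - beta * KL   (equals the ELBO only when beta = 1)

beta > 1: Greater weight on the KL term
    → Latent space more strongly regularized
    → Encourages disentangled representations
    → Each latent dimension tends to capture a single, independent factor of variation
    (usually at the cost of blurrier reconstructions)

beta = 1: Regular VAE
beta < 1: Focus on reconstruction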

Implementation

def beta_vae_loss(x, x_recon, mu, log_var, beta=4.0):
    """Beta-VAE loss function (⭐⭐⭐)"""
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # Beta weight on the KL term
    total_loss = recon_loss + beta * kl_loss

    return total_loss, recon_loss, kl_loss

Disentanglement Example

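Illustrative example of what a Beta-VAE trained on MNIST (beta=4) might learn:
    z[0]: Digit tilt
    z[1]: Line thickness
    z[2]: Digit type
    ...

Ideally, varying one dimension while holding the others fixed changes only that feature. Which dimension ends up encoding which factor depends on the training run.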

7. Latent Space Visualization

2D Latent Space

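import matplotlib.pyplot as plt

def visualize_latent_space(model, dataloader, device):
    """Visualize the latent space by plotting the first two dimensions of mu (⭐⭐)

    Most informative for a model trained with latent_dim=2.
    """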
    model.eval()

    latents = []
    labels = []

    with torch.no_grad():
        for data, label in dataloader:
            data = data.to(device)
            mu, _ = model.encoder(data)
            latents.append(mu.cpu())
            labels.append(label)

    latents = torch.cat(latents, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(latents[:, 0], latents[:, 1], c=labels, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter)
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space')
    plt.savefig('latent_space.png')
    plt.close()

Latent Space Exploration

def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    """Decode fixed_z (shape (1, latent_dim)) while sweeping z[dim_idx] over range_vals (⭐⭐)"""
    model.eval()
    images = []

    with torch.no_grad():
        for val in range_vals:
            z = fixed_z.clone()
            z[0, dim_idx] = val
            img = model.decoder(z.to(device))
            images.append(img.cpu())

    return torch.cat(images, dim=0)

Manifold Generation

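import numpy as np
import torch
import matplotlib.pyplot as plt

def generate_manifold(model, n=20, latent_dim=2, device='cpu'):
    """Decode a grid over a 2D latent space into an image manifold (⭐⭐⭐)"""
    model.eval()

    # Build an n x n grid over the range (-3, 3)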
    grid_x = torch.linspace(-3, 3, n)
    grid_y = torch.linspace(-3, 3, n)

    figure = np.zeros((28 * n, 28 * n))

    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z = torch.zeros(1, latent_dim)
                z[0, 0] = xi
                z[0, 1] = yi

                x_decoded = model.decoder(z.to(device))
                digit = x_decoded[0, 0].cpu().numpy()

                figure[i * 28:(i + 1) * 28,
                       j * 28:(j + 1) * 28] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='gray')
    plt.axis('off')
    plt.savefig('vae_manifold.png')
    plt.close()
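An evenly spaced grid on [-3, 3] spends many tiles on the low-density tails of the N(0, I) prior. A common alternative is to space the grid evenly in probability and map it through the inverse CDF of the standard normal. A minimal sketch using Python's statistics.NormalDist; gaussian_grid and its low/high defaults are hypothetical choices:

```python
from statistics import NormalDist

def gaussian_grid(n, low=0.05, high=0.95):
    """n grid points in z-space, evenly spaced in probability under N(0, 1)."""
    std_normal = NormalDist(mu=0.0, sigma=1.0)
    step = (high - low) / (n - 1)
    # inv_cdf maps a probability to the corresponding quantile of N(0, 1)
    return [std_normal.inv_cdf(low + i * step) for i in range(n)]

grid = gaussian_grid(5)
print([round(v, 3) for v in grid])
```

To use it in generate_manifold, one could replace torch.linspace(-3, 3, n) with torch.tensor(gaussian_grid(n)); points then cluster near the origin, where the prior puts most of its mass.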

8. VAE vs GAN Comparison

Characteristic       VAE                              GAN
Training method      Likelihood maximization (ELBO)   Adversarial training
Loss function        ELBO (explicit)                  Min-max game (implicit)
Training stability   Relatively stable                Often unstable
Image quality        Tends to be blurry               Typically sharp
Latent space         Structured                       Hard to interpret
Mode coverage        Good                             Mode collapse possible
Density estimation   Approximate (ELBO lower bound)   No explicit density

Advantages and Disadvantages

VAE Advantages:
- Explicit density model with a tractable lower bound (the ELBO) on log-likelihood
- Stable training
- Meaningful latent space
- Can both reconstruct and generate

VAE Disadvantages:
- Blurry images, largely because the pixel-wise reconstruction loss rewards averaging over plausible outputs
- The KL term limits how much information the latent code can carry

GAN Advantages:
- Sharp high-quality images
- Implicit density → more flexible

GAN Disadvantages:
- Unstable training
- Mode Collapse
- Difficult to evaluate

9. Advanced VAE Variants

Conditional VAE (CVAE)

class CVAE(nn.Module):
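    """Conditional VAE (⭐⭐⭐)

    Generates a specific type of sample given a condition (e.g., a class label).
    CVAEEncoder is assumed to map (x, y) -> (mu, log_var) and CVAEDecoder
    to map (z, y) -> image, where y is the one-hot condition.
    """
    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # The condition is fed to both networks as a one-hot vector
        self.encoder = CVAEEncoder(in_channels, latent_dim, num_classes)
        self.decoder = CVAEDecoder(latent_dim, in_channels, num_classes)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        # nn.Module has no reparameterize method, so CVAE defines its own
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def forward(self, x, label):
        # One-hot encoding
        y = F.one_hot(label, self.num_classes).float()

        mu, log_var = self.encoder(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z, y)

        return x_recon, mu, log_var

    def generate(self, label, num_samples, device):
        """Generate num_samples images of one class (label: int or scalar tensor)"""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        y = F.one_hot(torch.as_tensor(label, device=device), self.num_classes).float()
        y = y.expand(num_samples, -1)
        return self.decoder(z, y)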
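The CVAE above relies on CVAEEncoder and CVAEDecoder, which receive the one-hot condition y alongside their usual input. One possible sketch, assuming the same convolutional layout as VAEEncoder/VAEDecoder and simple concatenation of y (both class definitions are hypothetical, not taken from this lesson):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAEEncoder(nn.Module):
    """Hypothetical CVAE encoder: (image, one-hot label) -> mu, log_var."""
    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),  # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),           # 14 -> 7
            nn.ReLU(),
            nn.Flatten(),
        )
        # The condition is concatenated to the flattened image features
        self.fc_mu = nn.Linear(64 * 7 * 7 + num_classes, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7 + num_classes, latent_dim)

    def forward(self, x, y):
        h = torch.cat([self.conv_layers(x), y], dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

class CVAEDecoder(nn.Module):
    """Hypothetical CVAE decoder: (z, one-hot label) -> image."""
    def __init__(self, latent_dim=20, out_channels=1, num_classes=10):
        super().__init__()
        # The condition is concatenated to the latent code
        self.fc = nn.Linear(latent_dim + num_classes, 64 * 7 * 7)
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),            # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),  # 14 -> 28
            nn.Sigmoid(),
        )

    def forward(self, z, y):
        h = self.fc(torch.cat([z, y], dim=1))
        return self.deconv_layers(h.view(-1, 64, 7, 7))

# Shape check with random inputs
x = torch.rand(8, 1, 28, 28)
y = F.one_hot(torch.randint(0, 10, (8,)), 10).float()
mu, log_var = CVAEEncoder()(x, y)
x_recon = CVAEDecoder()(mu, y)
print(mu.shape, x_recon.shape)
```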

VQ-VAE (Vector Quantized VAE)

# VQ-VAE uses a discrete codebook instead of a continuous latent space
# Effective for high-quality image/audio generation

class VectorQuantizer(nn.Module):
    """Vector quantization layer of VQ-VAE (⭐⭐⭐⭐)"""
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Codebook
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

    def forward(self, z):
        # z: (batch, channels, H, W)
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # Find the nearest codebook vector
        distances = torch.cdist(z_flat, self.embeddings.weight)
        indices = torch.argmin(distances, dim=1)
        z_q = self.embeddings(indices).view(z.shape[0], z.shape[2], z.shape[3], -1)
        z_q = z_q.permute(0, 3, 1, 2)

        # Loss: codebook loss (moves codes toward sg[z]) + commitment loss (keeps z close to sg[z_q])
        loss = F.mse_loss(z_q, z.detach()) + self.commitment_cost * F.mse_loss(z, z_q.detach())

        # Straight-through estimator
        z_q = z + (z_q - z).detach()

        return z_q, loss, indices
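The line z_q = z + (z_q - z).detach() is the straight-through estimator: the forward pass outputs the quantized vectors, while the backward pass copies the decoder's gradient directly to the encoder output, since argmin itself has no gradient. A tiny standalone demonstration with a hypothetical 2-vector codebook:

```python
import torch

z = torch.tensor([[0.2, 0.9], [1.4, -0.3]], requires_grad=True)  # encoder outputs
codebook = torch.tensor([[0.0, 1.0], [1.0, 0.0]])                # 2 code vectors

# Nearest code vector per row (same lookup as VectorQuantizer.forward)
indices = torch.argmin(torch.cdist(z, codebook), dim=1)
z_q = codebook[indices]

# Straight-through: forward value is z_q, backward treats quantization as identity
z_st = z + (z_q - z).detach()
z_st.sum().backward()

print(indices, z_st, z.grad)
```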

10. Complete MNIST VAE Example

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Assumes the VAE class (with VAEEncoder and VAEDecoder) from Section 4 is defined in this file

# Settings
latent_dim = 20
batch_size = 128
epochs = 30
lr = 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data
transform = transforms.ToTensor()
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Model
model = VAE(in_channels=1, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training
for epoch in range(epochs):
    model.train()
    train_loss = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)

        optimizer.zero_grad()
        x_recon, mu, log_var = model(data)

        # Loss
        recon_loss = F.binary_cross_entropy(x_recon, data, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = (recon_loss + kl_loss) / data.size(0)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f"Epoch {epoch+1}: Loss = {train_loss / len(train_loader):.4f}")

# Generation
model.eval()
with torch.no_grad():
    samples = model.generate(16, device)
    # Save or visualize...

print("VAE training complete!")

Summary

Key Concepts

  1. VAE: Generative model with a probabilistic latent space
  2. ELBO: Reconstruction term - KL divergence (the training loss is reconstruction loss + KL)
  3. Reparameterization: z = mu + sigma * epsilon
  4. Beta-VAE: Control KL weight for disentanglement
  5. Latent space: Continuous, structured

Core Code

# Encoder output
mu, log_var = encoder(x)

# Reparameterization
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(std)

# Decoder
x_recon = decoder(z)

# Loss (sum-reduce both terms so their scales match)
recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
kl = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp())
loss = recon + kl

Use Case Scenarios

Purpose                      Recommended method
Data generation              VAE or GAN
Latent space analysis        VAE (especially Beta-VAE)
High-quality images          GAN or VQ-VAE
Conditional generation       CVAE
Compression/reconstruction   VAE

Next Steps

Learn about Diffusion models, a more recent family of generative models, in 32_Diffusion_Models.md.
