# 30. Generative Models - VAE (Variational Autoencoder)

Previous: Generative Adversarial Networks (GAN) | Next: Diffusion Models
## Learning Objectives

- Theoretical foundation of VAE (variational inference)
- Understanding the latent space and probabilistic generation
- Deriving the ELBO loss function
- The reparameterization trick
- Beta-VAE and disentanglement
- PyTorch implementation and visualization
## 1. VAE Theory

### Autoencoder vs VAE

Autoencoder:

```
Input → Encoder → Latent vector z (deterministic) → Decoder → Reconstruction
```

VAE:

```
Input → Encoder → Mean mu, variance sigma → Sample z ~ N(mu, sigma^2) → Decoder → Reconstruction
```

The latent space is continuous and follows a normal distribution, so new images can be generated by sampling from it.
### Why Probabilistic?

Problems with a regular autoencoder:

- The latent space is discontinuous
- Decoding a z that never appeared in training produces strange output
- Hard to use as a generative model

The VAE solution:

- Regularize the latent space toward a normal distribution
- The latent space becomes continuous
- New samples can be generated by decoding any z ~ N(0, I)
### Graphical Model

Generative process:

```
z ~ p(z) = N(0, I)      # Prior distribution
x ~ p_theta(x|z)        # Decoder
```

Inference:

```
q_phi(z|x) ≈ p(z|x)     # Encoder approximates the posterior
```

Goal: maximize the data likelihood log p(x). Since this is intractable, we maximize the ELBO (Evidence Lower Bound) instead.
## 2. ELBO Loss Function

### Derivation

```
log p(x) = log ∫ p(x, z) dz
         = log ∫ p(x|z) p(z) dz
         = log ∫ q(z|x) [p(x|z) p(z) / q(z|x)] dz
         ≥ ∫ q(z|x) log[p(x|z) p(z) / q(z|x)] dz    (Jensen's inequality)
         = E_q[log p(x|z)] - KL(q(z|x) || p(z))
         = ELBO (Evidence Lower Bound)
```
### Meaning of the Two Terms

ELBO = Reconstruction - KL Divergence

1. Reconstruction term `E_q[log p(x|z)]`: how well the decoder reconstructs x from z; implemented as a reconstruction loss (MSE or BCE).
2. KL divergence `KL(q(z|x) || p(z))`: how close the encoded distribution is to the prior N(0, I); regularizes the latent space.
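When both distributions are Gaussian the KL term has a closed form, which is exactly what the loss implementation below computes. With q(z|x) = N(mu, sigma^2 I) and p(z) = N(0, I):

```latex
\mathrm{KL}\bigl(\mathcal{N}(\mu,\sigma^2 I)\,\|\,\mathcal{N}(0,I)\bigr)
  = -\tfrac{1}{2}\sum_{j=1}^{d}\bigl(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\bigr)
```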
### Loss Function Implementation

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_recon, mu, log_var):
    """VAE loss function (★★★)
    Args:
        x: original images (batch, ...)
        x_recon: reconstructed images (batch, ...)
        mu: mean (batch, latent_dim)
        log_var: log variance (batch, latent_dim)
    Returns:
        total_loss, recon_loss, kl_loss
    """
    # Reconstruction loss (binary images: BCE, continuous images: MSE)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # or MSE:
    # recon_loss = F.mse_loss(x_recon, x, reduction='sum')

    # KL divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    total_loss = recon_loss + kl_loss
    return total_loss, recon_loss, kl_loss
```
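One quick sanity check (an illustrative sketch, not from the original): when the encoder output matches the prior exactly (mu = 0, log_var = 0), the KL term evaluates to zero.

```python
import torch

mu = torch.zeros(4, 20)       # encoder mean equal to the prior mean
log_var = torch.zeros(4, 20)  # log variance 0, i.e. sigma = 1
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
print(kl.item())  # 0.0, since KL(N(0, I) || N(0, I)) = 0
```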
## 3. Reparameterization Trick

### Problem

Sampling z ~ q(z|x) = N(mu, sigma^2) is not differentiable, so gradients cannot be backpropagated through the sampling step.
### Solution: Reparameterization

```python
import torch

def reparameterize(mu, log_var):
    """Reparameterization trick (★★★)
    z = mu + sigma * epsilon,  epsilon ~ N(0, I)

    The randomness lives entirely in epsilon, so we can
    differentiate with respect to mu and sigma.
    """
    std = torch.exp(0.5 * log_var)   # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(std)      # epsilon ~ N(0, I)
    z = mu + std * eps
    return z
```
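A small check (illustrative, not part of the original) that gradients really flow through the sampled z back to mu and log_var:

```python
import torch

mu = torch.zeros(2, 20, requires_grad=True)
log_var = torch.zeros(2, 20, requires_grad=True)
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(std)   # reparameterized sample
z.sum().backward()                     # works: eps is treated as a constant
print(mu.grad.sum().item())            # 40.0, one unit of gradient per element
```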
### Visual Understanding

```
[Not differentiable]
mu, sigma → Sampling → z → Decoder

[Differentiable - reparameterization]
mu ────────────────┐
                   ▼
sigma ──────→ (mu + sigma * eps) ──→ z ──→ Decoder
                   ▲
eps ~ N(0, I) ─────┘   (treated as a constant)
```
## 4. VAE Model Implementation

### Encoder

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAEEncoder(nn.Module):
    """VAE encoder (★★★)
    Image → mu, log_var
    """
    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),  # 28 → 14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),           # 14 → 7
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

    def forward(self, x):
        h = self.conv_layers(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var
```
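The spatial sizes in the comments (28 → 14 → 7) come from the stride-2 convolutions; a standalone check of the arithmetic:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 28, 28)
down1 = nn.Conv2d(1, 32, 3, stride=2, padding=1)   # (28 + 2*1 - 3)//2 + 1 = 14
down2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)  # (14 + 2*1 - 3)//2 + 1 = 7
h = down2(down1(x))
print(h.shape)  # torch.Size([1, 64, 7, 7]) -> flattened size 64 * 7 * 7 = 3136
```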
### Decoder

```python
class VAEDecoder(nn.Module):
    """VAE decoder (★★★)
    z → image
    """
    def __init__(self, latent_dim=20, out_channels=1):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),            # 7 → 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),  # 14 → 28
            nn.Sigmoid()  # [0, 1]
        )

    def forward(self, z):
        h = self.fc(z)
        h = h.view(-1, 64, 7, 7)
        x_recon = self.deconv_layers(h)
        return x_recon
```
### Complete VAE

```python
class VAE(nn.Module):
    """Variational Autoencoder (★★★)"""
    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim)
        self.decoder = VAEDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

    def generate(self, num_samples, device):
        """Generate new samples from the prior."""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        samples = self.decoder(z)
        return samples

    def reconstruct(self, x):
        """Reconstruct input images."""
        with torch.no_grad():
            x_recon, _, _ = self.forward(x)
        return x_recon
```
## 5. Training Loop

```python
def train_vae(model, dataloader, epochs=50, lr=1e-3):
    """Train a VAE (★★★)"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_recon = 0
        total_kl = 0

        for batch_idx, (data, _) in enumerate(dataloader):
            data = data.to(device)
            optimizer.zero_grad()

            # Forward
            x_recon, mu, log_var = model(data)

            # Loss
            loss, recon_loss, kl_loss = vae_loss(data, x_recon, mu, log_var)
            # Normalize by batch size
            loss = loss / data.size(0)

            # Backward
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon_loss.item() / data.size(0)
            total_kl += kl_loss.item() / data.size(0)

        avg_loss = total_loss / len(dataloader)
        avg_recon = total_recon / len(dataloader)
        avg_kl = total_kl / len(dataloader)
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}")

    return model
```
## 6. Beta-VAE

### Idea

```
ELBO = Reconstruction - beta * KL
```

- beta > 1: greater weight on the KL term → more regularized latent space → disentangled representations, where each latent dimension captures an independent feature
- beta = 1: regular VAE
- beta < 1: focus on reconstruction
### Implementation

```python
def beta_vae_loss(x, x_recon, mu, log_var, beta=4.0):
    """Beta-VAE loss function (★★★)"""
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Beta weighting
    total_loss = recon_loss + beta * kl_loss
    return total_loss, recon_loss, kl_loss
```
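A concrete number (illustrative, not from the original): with mu = 1 and log_var = 0, every latent element contributes exactly 0.5 to the KL term, and beta simply rescales that penalty without touching the reconstruction term.

```python
import torch

mu = torch.ones(4, 20)        # batch 4, latent_dim 20
log_var = torch.zeros(4, 20)  # sigma = 1
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
print(kl.item())              # 40.0 = 4 * 20 * 0.5
print((4.0 * kl).item())      # 160.0: beta = 4 quadruples the KL penalty
```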
### Disentanglement Example

A Beta-VAE trained on MNIST (beta=4) can learn factors such as:

- z[0]: digit tilt
- z[1]: line thickness
- z[2]: digit type
- ...

Adjusting one dimension while holding the others fixed changes only that feature.
## 7. Latent Space Visualization

### 2D Latent Space

```python
import matplotlib.pyplot as plt
import torch

def visualize_latent_space(model, dataloader, device):
    """Visualize the latent space (★★)
    Plots the first two latent dimensions (assumes latent_dim=2).
    """
    model.eval()
    latents = []
    labels = []
    with torch.no_grad():
        for data, label in dataloader:
            data = data.to(device)
            mu, _ = model.encoder(data)
            latents.append(mu.cpu())
            labels.append(label)

    latents = torch.cat(latents, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(latents[:, 0], latents[:, 1], c=labels, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter)
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space')
    plt.savefig('latent_space.png')
    plt.close()
```
### Latent Space Exploration

```python
def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    """Traverse a single latent dimension (★★)"""
    model.eval()
    images = []
    with torch.no_grad():
        for val in range_vals:
            z = fixed_z.clone()
            z[0, dim_idx] = val
            img = model.decoder(z.to(device))
            images.append(img.cpu())
    return torch.cat(images, dim=0)
```
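Example usage of the traversal (the `TinyModel` below is a hypothetical stand-in with a `.decoder` attribute, used only to show the call pattern; in practice pass the trained VAE). The traversal function is repeated so the snippet runs standalone.

```python
import torch
import torch.nn as nn

def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    # Same traversal as above: vary one dimension, freeze the rest
    model.eval()
    images = []
    with torch.no_grad():
        for val in range_vals:
            z = fixed_z.clone()
            z[0, dim_idx] = val
            images.append(model.decoder(z.to(device)).cpu())
    return torch.cat(images, dim=0)

class TinyModel(nn.Module):
    """Hypothetical stand-in exposing a .decoder, for demonstration only."""
    def __init__(self, latent_dim=4):
        super().__init__()
        self.decoder = nn.Linear(latent_dim, 784)

model = TinyModel()
fixed_z = torch.zeros(1, 4)
out = explore_latent_dimension(model, dim_idx=0,
                               range_vals=torch.linspace(-3, 3, 7),
                               fixed_z=fixed_z, device='cpu')
print(out.shape)  # torch.Size([7, 784]): one decoded output per traversal step
```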
### Manifold Generation

```python
import numpy as np
import matplotlib.pyplot as plt
import torch

def generate_manifold(model, n=20, latent_dim=2, device='cpu'):
    """Generate the manifold of a 2D latent space (★★★)"""
    model.eval()
    # Grid over the range (-3, 3)
    grid_x = torch.linspace(-3, 3, n)
    grid_y = torch.linspace(-3, 3, n)
    figure = np.zeros((28 * n, 28 * n))

    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z = torch.zeros(1, latent_dim)
                z[0, 0] = xi
                z[0, 1] = yi
                x_decoded = model.decoder(z.to(device))
                digit = x_decoded[0, 0].cpu().numpy()
                figure[i * 28:(i + 1) * 28,
                       j * 28:(j + 1) * 28] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='gray')
    plt.axis('off')
    plt.savefig('vae_manifold.png')
    plt.close()
```
## 8. VAE vs GAN Comparison
| Characteristic | VAE | GAN |
|---|---|---|
| Training method | Likelihood maximization | Adversarial training |
| Loss function | ELBO (explicit) | Min-max (implicit) |
| Training stability | Stable | Unstable |
| Image quality | Tends to be blurry | Sharp |
| Latent space | Structured | Hard to interpret |
| Mode Coverage | Good | Mode Collapse possible |
| Density estimation | Possible | Not possible |
### Advantages and Disadvantages
VAE Advantages:
- Explicit density model
- Stable training
- Meaningful latent space
- Can both reconstruct and generate
VAE Disadvantages:
- Blurry images due to reconstruction loss
- KL term limits latent space expressiveness
GAN Advantages:
- Sharp high-quality images
- Implicit density → more flexible
GAN Disadvantages:
- Unstable training
- Mode Collapse
- Difficult to evaluate
## 9. Advanced VAE Variants

### Conditional VAE (CVAE)

```python
class CVAE(nn.Module):
    """Conditional VAE (★★★)
    Given a condition (e.g. a class label), generates samples of that type.
    """
    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        # Concatenate the condition as a one-hot vector
        # (CVAEEncoder / CVAEDecoder: condition-aware versions of the
        #  encoder/decoder above, defined elsewhere)
        self.encoder = CVAEEncoder(in_channels, latent_dim, num_classes)
        self.decoder = CVAEDecoder(latent_dim, in_channels, num_classes)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def forward(self, x, label):
        # One-hot encoding
        y = F.one_hot(label, self.num_classes).float()
        mu, log_var = self.encoder(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z, y)
        return x_recon, mu, log_var

    def generate(self, label, num_samples, device):
        """Generate samples of a specific class (label: int)."""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        y = F.one_hot(torch.tensor([label]), self.num_classes).float().to(device)
        y = y.expand(num_samples, -1)
        return self.decoder(z, y)
```
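`CVAEEncoder` and `CVAEDecoder` are not defined in this section. A minimal MLP sketch of the idea, concatenating the one-hot condition to the input on both sides, might look like this (class names match the usage above, but the layer sizes and MLP structure are illustrative assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAEEncoder(nn.Module):
    """Encoder conditioned on a one-hot label y: q(z | x, y). Illustrative sketch."""
    def __init__(self, input_dim=784, latent_dim=20, num_classes=10):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(input_dim + num_classes, 400), nn.ReLU())
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)

    def forward(self, x, y):
        h = self.fc(torch.cat([x.flatten(1), y], dim=1))  # concat the condition
        return self.fc_mu(h), self.fc_logvar(h)

class CVAEDecoder(nn.Module):
    """Decoder conditioned on y: p(x | z, y). Illustrative sketch."""
    def __init__(self, latent_dim=20, output_dim=784, num_classes=10):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 400), nn.ReLU(),
            nn.Linear(400, output_dim), nn.Sigmoid())

    def forward(self, z, y):
        return self.fc(torch.cat([z, y], dim=1))  # condition enters the decoder too

enc = CVAEEncoder()
dec = CVAEDecoder()
x = torch.rand(8, 1, 28, 28)
y = F.one_hot(torch.randint(0, 10, (8,)), 10).float()
mu, log_var = enc(x, y)
x_recon = dec(mu, y)      # use the mean as z for a quick shape check
print(mu.shape, x_recon.shape)
```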
### VQ-VAE (Vector Quantized VAE)

```python
# VQ-VAE replaces the continuous latent space with a discrete codebook.
# Effective for high-quality image/audio generation.
class VectorQuantizer(nn.Module):
    """Vector quantization for VQ-VAE (★★★★)"""
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        # Codebook
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

    def forward(self, z):
        # z: (batch, channels, H, W)
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # Find the nearest codebook vector
        distances = torch.cdist(z_flat, self.embeddings.weight)
        indices = torch.argmin(distances, dim=1)
        z_q = self.embeddings(indices).view(z.shape[0], z.shape[2], z.shape[3], -1)
        z_q = z_q.permute(0, 3, 1, 2)

        # Loss: codebook update + weighted commitment loss
        loss = F.mse_loss(z_q, z.detach()) + self.commitment_cost * F.mse_loss(z_q.detach(), z)

        # Straight-through estimator
        z_q = z + (z_q - z).detach()
        return z_q, loss, indices
```
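A standalone shape and straight-through check (repeating the quantizer here, with the codebook and commitment terms in the standard VQ-VAE order, so the snippet runs on its own):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        # Snap each spatial position of z to its nearest codebook vector
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)
        indices = torch.argmin(torch.cdist(z_flat, self.embeddings.weight), dim=1)
        z_q = self.embeddings(indices).view(z.shape[0], z.shape[2], z.shape[3], -1)
        z_q = z_q.permute(0, 3, 1, 2)
        loss = (F.mse_loss(z_q, z.detach())                       # codebook term
                + self.commitment_cost * F.mse_loss(z_q.detach(), z))  # commitment
        z_q = z + (z_q - z).detach()     # straight-through estimator
        return z_q, loss, indices

vq = VectorQuantizer(num_embeddings=512, embedding_dim=64)
z = torch.randn(2, 64, 7, 7, requires_grad=True)
z_q, loss, indices = vq(z)
z_q.sum().backward()             # gradient passes straight through to z
print(z_q.shape, indices.shape)  # quantized map + one code id per spatial position
```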
## 10. Complete MNIST VAE Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Settings
latent_dim = 20
batch_size = 128
epochs = 30
lr = 1e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data
transform = transforms.ToTensor()
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Model (the VAE class defined above)
model = VAE(in_channels=1, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        x_recon, mu, log_var = model(data)

        # Loss
        recon_loss = F.binary_cross_entropy(x_recon, data, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = (recon_loss + kl_loss) / data.size(0)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f"Epoch {epoch+1}: Loss = {train_loss / len(train_loader):.4f}")

# Generation
model.eval()
with torch.no_grad():
    samples = model.generate(16, device)
    # Save or visualize...

print("VAE training complete!")
```
## Summary

### Key Concepts

- VAE: generative model with a probabilistic latent space
- Loss: reconstruction + KL divergence (the negative ELBO)
- Reparameterization: z = mu + sigma * epsilon
- Beta-VAE: control the KL weight for disentanglement
- Latent space: continuous and structured

### Core Code

```python
# Encoder output
mu, log_var = encoder(x)

# Reparameterization
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(std)

# Decoder
x_recon = decoder(z)

# Loss
recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
kl = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp())
loss = recon + kl
```
### Use Case Scenarios
| Purpose | Recommended Method |
|---|---|
| Data generation | VAE or GAN |
| Latent space analysis | VAE (especially Beta-VAE) |
| High-quality images | GAN or VQ-VAE |
| Conditional generation | CVAE |
| Compression/reconstruction | VAE |
## Next Steps

Learn about the latest family of generative models, Diffusion models, in 32_Diffusion_Models.md.