30. Generative Models - VAE (Variational Autoencoder)
Previous: Generative Adversarial Networks (GAN) | Next: Variational Autoencoder (VAE)
30. Generative Models - VAE (Variational Autoencoder)¶
Learning Goals¶
- Theoretical foundations of VAE (Variational Inference)
- Understanding the probabilistic nature of the latent space
- Deriving the ELBO loss function
- Reparameterization Trick
- Beta-VAE and Disentanglement
- PyTorch implementation and visualization
1. VAE μ΄λ‘ ¶
Autoencoder vs VAE¶
Autoencoder:
Input → Encoder → latent vector z (deterministic) → Decoder → Reconstruction
VAE:
Input → Encoder → mean mu, variance sigma → sample z ~ N(mu, sigma) → Decoder → Reconstruction
→ the latent space is continuous and follows a normal distribution
→ new images can be generated
Why Probabilistic?¶
Problems with a plain Autoencoder:
- the latent space is discontinuous
- z values unseen during training produce strange outputs
- hard to use as a generative model
VAE's solution:
- regularize the latent space toward a normal distribution
- continuous latent space
- can generate by sampling any z ~ N(0, I)
Graphical Model¶
Generative process:
z ~ p(z) = N(0, I)    # prior
x ~ p_theta(x|z)      # decoder
Inference:
q_phi(z|x) ≈ p(z|x)   # the encoder approximates the posterior
Goal:
maximize log p(x) (data likelihood)
→ maximize the ELBO (Evidence Lower Bound)
2. ELBO μμ€ ν¨μ¶
μ λ¶
log p(x) = log β« p(x, z) dz
= log β« p(x|z) p(z) dz
= log β« q(z|x) * [p(x|z) p(z) / q(z|x)] dz
β₯ β« q(z|x) log[p(x|z) p(z) / q(z|x)] dz (Jensen's inequality)
= E_q[log p(x|z)] - KL(q(z|x) || p(z))
= ELBO (Evidence Lower Bound)
Meaning of the Two Terms¶
# ELBO = Reconstruction - KL Divergence
# 1. Reconstruction term: E_q[log p(x|z)]
#    - how well the decoder reconstructs x from z
#    - reconstruction loss (MSE or BCE)
# 2. KL Divergence: KL(q(z|x) || p(z))
#    - how close the encoded distribution is to the prior N(0, I)
#    - regularizes the latent space
μμ€ ν¨μ ꡬν¶
def vae_loss(x, x_recon, mu, log_var):
    """Compute the VAE objective (negative ELBO).

    Args:
        x: original images, shape (batch, ...)
        x_recon: reconstructed images, shape (batch, ...)
        mu: latent means, shape (batch, latent_dim)
        log_var: latent log-variances, shape (batch, latent_dim)

    Returns:
        Tuple of (total_loss, recon_loss, kl_loss), each summed over the batch.
    """
    # Reconstruction term: BCE for binary images (swap in MSE for continuous ones):
    # recon_loss = F.mse_loss(x_recon, x, reduction='sum')
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss, recon_loss, kl_loss
3. Reparameterization Trick¶
The Problem¶
Sampling z ~ q(z|x) = N(mu, sigma)
Problem: sampling is not differentiable → backpropagation impossible
ν΄κ²°: Reparameterization¶
def reparameterize(mu, log_var):
    """Reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, I).

    Moving the randomness into epsilon makes z differentiable with respect
    to mu and log_var, so gradients can flow through the sampling step.
    """
    sigma = (0.5 * log_var).exp()    # sigma = exp(0.5 * log(sigma^2))
    noise = torch.randn_like(sigma)  # epsilon ~ N(0, I)
    return sigma * noise + mu
μκ°μ μ΄ν΄¶
[λ―ΈλΆ λΆκ°]
mu, sigma β μνλ§ β z β Decoder
[λ―ΈλΆ κ°λ₯ - Reparameterization]
mu βββββββββββββββ
β
βΌ
sigma βββββββββΆ (mu + sigma * eps) βββΆ z βββΆ Decoder
β²
eps ~ N(0, I) ββββββββββ (μμ μ·¨κΈ)
4. VAE λͺ¨λΈ ꡬν¶
Encoder¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class VAEEncoder(nn.Module):
    """VAE encoder: maps an image to the parameters (mu, log_var) of q(z|x).

    Designed for 28x28 inputs (e.g. MNIST): two stride-2 convolutions
    halve the spatial size twice (28 -> 14 -> 7) before the linear heads.
    """

    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),  # 28 -> 14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),           # 14 -> 7
            nn.ReLU(),
            nn.Flatten(),
        )
        feat_dim = 64 * 7 * 7
        self.fc_mu = nn.Linear(feat_dim, latent_dim)
        self.fc_logvar = nn.Linear(feat_dim, latent_dim)

    def forward(self, x):
        """Return (mu, log_var) for input batch x."""
        features = self.conv_layers(x)
        return self.fc_mu(features), self.fc_logvar(features)
Decoder¶
class VAEDecoder(nn.Module):
    """VAE decoder: maps a latent vector z back to an image.

    Mirrors VAEEncoder for 28x28 outputs: a linear projection to a 7x7
    feature map, then two stride-2 transposed convolutions (7 -> 14 -> 28)
    ending in a Sigmoid so pixel values lie in [0, 1].
    """

    def __init__(self, latent_dim=20, out_channels=1):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),            # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),  # 14 -> 28
            nn.Sigmoid(),  # outputs in [0, 1]
        )

    def forward(self, z):
        """Decode latent vectors z into images."""
        feature_map = self.fc(z).view(-1, 64, 7, 7)
        return self.deconv_layers(feature_map)
μ 체 VAE¶
class VAE(nn.Module):
    """Variational Autoencoder: VAEEncoder + reparameterized sampling + VAEDecoder."""

    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim)
        self.decoder = VAEDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable.
        sigma = (0.5 * log_var).exp()
        return torch.randn_like(sigma) * sigma + mu

    def forward(self, x):
        """Encode x, sample z, decode; returns (x_recon, mu, log_var)."""
        mu, log_var = self.encoder(x)
        x_recon = self.decoder(self.reparameterize(mu, log_var))
        return x_recon, mu, log_var

    def generate(self, num_samples, device):
        """Decode num_samples latent vectors drawn from the prior N(0, I)."""
        prior_z = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decoder(prior_z)

    def reconstruct(self, x):
        """Reconstruct x without tracking gradients."""
        with torch.no_grad():
            recon, _, _ = self.forward(x)
        return recon
5. νμ΅ λ£¨ν¶
def train_vae(model, dataloader, epochs=50, lr=1e-3):
    """Train a VAE with Adam on the summed ELBO loss.

    Args:
        model: a VAE whose forward returns (x_recon, mu, log_var).
        dataloader: yields (images, labels) batches; labels are ignored.
        epochs: number of passes over the dataset.
        lr: Adam learning rate.

    Returns:
        The trained model (moved to GPU when available).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        running = {'loss': 0.0, 'recon': 0.0, 'kl': 0.0}
        for images, _ in dataloader:
            images = images.to(device)
            batch = images.size(0)
            optimizer.zero_grad()
            recon, mu, log_var = model(images)
            loss, recon_loss, kl_loss = vae_loss(images, recon, mu, log_var)
            # Normalize by batch size so gradient magnitudes are comparable
            # across batch sizes.
            loss = loss / batch
            loss.backward()
            optimizer.step()
            running['loss'] += loss.item()
            running['recon'] += recon_loss.item() / batch
            running['kl'] += kl_loss.item() / batch
        n_batches = len(dataloader)
        avg_loss = running['loss'] / n_batches
        avg_recon = running['recon'] / n_batches
        avg_kl = running['kl'] / n_batches
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}")
    return model
6. Beta-VAE¶
μμ΄λμ΄¶
ELBO = Reconstruction - beta * KL
beta > 1: heavier weight on the KL term
→ the latent space is more strongly regularized
→ Disentangled representations
→ each latent dimension captures an independent factor
beta = 1: standard VAE
beta < 1: focus on reconstruction
ꡬν¶
def beta_vae_loss(x, x_recon, mu, log_var, beta=4.0):
    """Beta-VAE objective: reconstruction + beta * KL.

    beta > 1 trades reconstruction quality for a more strongly regularized,
    more disentangled latent space (beta = 1 is a standard VAE).

    Returns:
        Tuple of (total_loss, recon_loss, kl_loss); beta scales only the total.
    """
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Beta weighting is applied to the KL term only.
    return recon_loss + beta * kl_loss, recon_loss, kl_loss
Disentanglement μμ¶
Beta-VAE trained on MNIST (beta=4):
z[0]: stroke slant
z[1]: line thickness
z[2]: digit class
...
Adjusting each dimension independently changes only that feature
7. Latent Space μκ°ν¶
2D Latent Space¶
def visualize_latent_space(model, dataloader, device):
    """Scatter-plot the encoder means for a whole dataset.

    Plots only the first two latent coordinates, colored by class label,
    and writes the figure to 'latent_space.png'.
    """
    model.eval()
    all_mu, all_labels = [], []
    with torch.no_grad():
        for images, targets in dataloader:
            mu, _ = model.encoder(images.to(device))
            all_mu.append(mu.cpu())
            all_labels.append(targets)
    points = torch.cat(all_mu, dim=0).numpy()
    colors = torch.cat(all_labels, dim=0).numpy()
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(points[:, 0], points[:, 1], c=colors, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter)
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space')
    plt.savefig('latent_space.png')
    plt.close()
Latent Space νμ¶
def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    """Decode images while sweeping one latent dimension.

    For each value in range_vals, copies fixed_z, overwrites coordinate
    dim_idx of the first row, and decodes it; returns the decoded outputs
    concatenated along the batch dimension.
    """
    model.eval()
    decoded = []
    with torch.no_grad():
        for value in range_vals:
            probe = fixed_z.clone()
            probe[0, dim_idx] = value
            decoded.append(model.decoder(probe.to(device)).cpu())
    return torch.cat(decoded, dim=0)
Manifold μμ±¶
def generate_manifold(model, n=20, latent_dim=2, device='cpu'):
    """Decode an n x n grid over the first two latent dimensions.

    Sweeps z[0] and z[1] across [-3, 3], tiles the decoded 28x28 digits
    into one large image, and saves it to 'vae_manifold.png'.
    """
    model.eval()
    # Grid over [-3, 3]: the high-density region of the N(0, I) prior.
    xs = torch.linspace(-3, 3, n)
    ys = torch.linspace(-3, 3, n)
    canvas = np.zeros((28 * n, 28 * n))
    with torch.no_grad():
        for row, y_val in enumerate(ys):
            for col, x_val in enumerate(xs):
                z = torch.zeros(1, latent_dim)
                z[0, 0] = x_val
                z[0, 1] = y_val
                decoded = model.decoder(z.to(device))
                tile = decoded[0, 0].cpu().numpy()
                canvas[row * 28:(row + 1) * 28, col * 28:(col + 1) * 28] = tile
    plt.figure(figsize=(10, 10))
    plt.imshow(canvas, cmap='gray')
    plt.axis('off')
    plt.savefig('vae_manifold.png')
    plt.close()
8. VAE vs GAN Comparison¶
| Property | VAE | GAN |
|---|---|---|
| Training method | Likelihood maximization | Adversarial training |
| Loss function | ELBO (explicit) | Min-max (implicit) |
| Training stability | Stable | Unstable |
| Image quality | Tends to be blurry | Sharp |
| Latent space | Structured | Hard to interpret |
| Mode coverage | Good | Mode collapse possible |
| Density estimation | Possible | Not possible |
Pros and Cons¶
VAE pros:
- explicit density model
- stable training
- meaningful latent space
- supports both reconstruction and generation
VAE cons:
- blurry images due to the reconstruction loss
- the KL term limits latent-space expressiveness
GAN pros:
- sharp, high-quality images
- implicit density → more flexible
GAN cons:
- unstable training
- Mode Collapse
- hard to evaluate
9. κ³ κΈ VAE λ³ν¶
Conditional VAE (CVAE)¶
class CVAE(nn.Module):
    """Conditional VAE: generates samples conditioned on a class label.

    The label is one-hot encoded and concatenated into both the encoder
    and decoder inputs (handled inside CVAEEncoder / CVAEDecoder).

    Fix: the original forward() called self.reparameterize, but CVAE
    neither defined that method nor inherited from VAE, so forward raised
    AttributeError. The method is now defined on the class.
    """

    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        # The condition is concatenated as a one-hot vector inside these modules.
        self.encoder = CVAEEncoder(in_channels, latent_dim, num_classes)
        self.decoder = CVAEDecoder(latent_dim, in_channels, num_classes)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        """z = mu + sigma * eps, eps ~ N(0, I) — differentiable sampling."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x, label):
        """Returns (x_recon, mu, log_var) for input x conditioned on label."""
        # One-hot encoding of the integer class labels.
        y = F.one_hot(label, self.num_classes).float()
        mu, log_var = self.encoder(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z, y)
        return x_recon, mu, log_var

    def generate(self, label, num_samples, device):
        """Generate num_samples images of the given class from prior samples."""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        y = F.one_hot(label, self.num_classes).float().to(device)
        y = y.expand(num_samples, -1)
        return self.decoder(z, y)
VQ-VAE (Vector Quantized VAE)¶
# VQ-VAE uses a discrete codebook instead of a continuous latent space
# Effective for high-quality image/audio generation
class VectorQuantizer(nn.Module):
    """VQ-VAE vector quantization layer.

    Snaps each spatial feature vector to its nearest codebook entry and
    returns the quantized features, the codebook/commitment loss, and the
    chosen codebook indices. Gradients pass through the (non-differentiable)
    quantization via the straight-through estimator.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        # Codebook, initialized uniformly in [-1/K, 1/K].
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, z):
        # z: (batch, channels, H, W) with channels == embedding_dim.
        batch, _, height, width = z.shape
        flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)
        # Nearest codebook entry per spatial position.
        indices = torch.cdist(flat, self.embeddings.weight).argmin(dim=1)
        quantized = (
            self.embeddings(indices)
            .view(batch, height, width, -1)
            .permute(0, 3, 1, 2)
        )
        # Codebook loss pulls embeddings toward encoder outputs;
        # commitment loss pulls encoder outputs toward embeddings.
        loss = F.mse_loss(quantized.detach(), z) \
            + self.commitment_cost * F.mse_loss(quantized, z.detach())
        # Straight-through estimator: forward uses quantized values,
        # backward copies gradients straight to z.
        quantized = z + (quantized - z).detach()
        return quantized, loss, indices
10. MNIST VAE μμ μμ ¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# --- Hyperparameters --------------------------------------------------------
latent_dim = 20
batch_size = 128
epochs = 30
lr = 1e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Data: MNIST digits as [0, 1] tensors -----------------------------------
transform = transforms.ToTensor()
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# --- Model & optimizer -------------------------------------------------------
model = VAE(in_channels=1, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# --- Training: minimize the per-sample negative ELBO -------------------------
for epoch in range(epochs):
    model.train()
    train_loss = 0
    for data, _ in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        x_recon, mu, log_var = model(data)
        # Summed BCE reconstruction + closed-form KL to N(0, I),
        # normalized by batch size.
        recon_loss = F.binary_cross_entropy(x_recon, data, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = (recon_loss + kl_loss) / data.size(0)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f"Epoch {epoch+1}: Loss = {train_loss / len(train_loader):.4f}")

# --- Generation: decode z ~ N(0, I) samples ----------------------------------
model.eval()
with torch.no_grad():
    samples = model.generate(16, device)
    # Save or visualize the samples here...
print("VAE νμ΅ μλ£!")
Summary¶
Key Concepts¶
- VAE: generative model with a probabilistic latent space
- ELBO: Reconstruction + KL Divergence
- Reparameterization: z = mu + sigma * epsilon
- Beta-VAE: tune the KL weight for disentanglement
- Latent space: continuous and structured
ν΅μ¬ μ½λ¶
# Encoder μΆλ ₯
mu, log_var = encoder(x)
# Reparameterization
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(std)
# Decoder
x_recon = decoder(z)
# Loss
recon = F.binary_cross_entropy(x_recon, x)
kl = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp())
loss = recon + kl
Usage Scenarios¶
| Purpose | Recommended approach |
|---|---|
| Data generation | VAE or GAN |
| Latent-space analysis | VAE (especially Beta-VAE) |
| High-quality images | GAN or VQ-VAE |
| Conditional generation | CVAE |
| Compression/reconstruction | VAE |
Next Steps¶
32_Diffusion_Models.md covers Diffusion models, the state-of-the-art generative models.