30. Generative Models - VAE (Variational Autoencoder)


Learning Objectives

  • Theoretical foundations of VAE (variational inference)
  • Understanding the latent space and probabilistic generation
  • Deriving the ELBO loss function
  • The Reparameterization Trick
  • Beta-VAE and disentanglement
  • PyTorch implementation and visualization

1. VAE Theory

Autoencoder vs VAE

Autoencoder:
    input β†’ Encoder β†’ latent vector z (deterministic) β†’ Decoder β†’ reconstruction

VAE:
    input β†’ Encoder β†’ mean mu, variance sigma^2 β†’ sample z ~ N(mu, sigma^2) β†’ Decoder β†’ reconstruction
                         ↓
                    the latent space is continuous and follows a normal distribution
                    β†’ new images can be generated

Why Probabilistic?

Problems with a plain Autoencoder:
- the latent space is discontinuous
- feeding in a z not seen during training produces strange outputs
- hard to use as a generative model

VAE's solution:
- regularize the latent space toward a normal distribution
- a continuous latent space
- generate by sampling an arbitrary z ~ N(0, I)

Graphical Model

Generative process:
    z ~ p(z) = N(0, I)           # prior
    x ~ p_theta(x|z)             # decoder

Inference:
    q_phi(z|x) β‰ˆ p(z|x)          # encoder approximates the posterior

Objective:
    maximize log p(x) (the likelihood of the data)
    β†’ maximize the ELBO (Evidence Lower Bound)

2. ELBO Loss Function

Derivation

log p(x) = log ∫ p(x, z) dz

         = log ∫ p(x|z) p(z) dz

         = log ∫ q(z|x) * [p(x|z) p(z) / q(z|x)] dz

         β‰₯ ∫ q(z|x) log[p(x|z) p(z) / q(z|x)] dz    (Jensen's inequality)

         = E_q[log p(x|z)] - KL(q(z|x) || p(z))

         = ELBO (Evidence Lower Bound)

두 ν•­μ˜ 의미

# ELBO = Reconstruction - KL Divergence

# 1. Reconstruction Term: E_q[log p(x|z)]
#    - 디코더가 zλ‘œλΆ€ν„° xλ₯Ό μ–Όλ§ˆλ‚˜ 잘 λ³΅μ›ν•˜λŠ”κ°€
#    - μž¬κ΅¬μ„± 손싀 (MSE λ˜λŠ” BCE)

# 2. KL Divergence: KL(q(z|x) || p(z))
#    - μΈμ½”λ”©λœ 뢄포가 사전 뢄포(N(0,I))와 μ–Όλ§ˆλ‚˜ κ°€κΉŒμš΄κ°€
#    - 잠재 곡간 μ •κ·œν™”

Loss Function Implementation

import torch
import torch.nn.functional as F

def vae_loss(x, x_recon, mu, log_var):
    """VAE loss function (⭐⭐⭐)

    Args:
        x: original images (batch, ...)
        x_recon: reconstructed images (batch, ...)
        mu: mean (batch, latent_dim)
        log_var: log variance (batch, latent_dim)

    Returns:
        total_loss, recon_loss, kl_loss
    """
    # Reconstruction loss (binary images: BCE, continuous images: MSE)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # or MSE:
    # recon_loss = F.mse_loss(x_recon, x, reduction='sum')

    # KL Divergence: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    total_loss = recon_loss + kl_loss

    return total_loss, recon_loss, kl_loss
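
A quick sanity check of the two terms on dummy data (shapes here are illustrative): with mu = 0 and log_var = 0 the KL term vanishes exactly, since every latent dimension already matches N(0, 1).

```python
import torch
import torch.nn.functional as F

x = torch.rand(4, 1, 28, 28)        # dummy "images" in [0, 1]
x_recon = torch.rand(4, 1, 28, 28)
mu = torch.zeros(4, 20)
log_var = torch.zeros(4, 20)        # sigma = 1 everywhere

recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# 1 + log_var - mu^2 - exp(log_var) = 1 + 0 - 0 - 1 = 0 per element
print(kl.item())  # 0.0
```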

3. Reparameterization Trick

The Problem

Sampling z ~ q(z|x) = N(mu, sigma^2)

Problem: sampling is not differentiable β†’ backpropagation is impossible

Solution: reparameterization

def reparameterize(mu, log_var):
    """Reparameterization Trick (⭐⭐⭐)

    z = mu + sigma * epsilon
    epsilon ~ N(0, I)

    This moves the randomness into epsilon, so the output
    is differentiable with respect to mu and sigma
    """
    std = torch.exp(0.5 * log_var)  # sigma = exp(0.5 * log(sigma^2))
    eps = torch.randn_like(std)     # epsilon ~ N(0, I)
    z = mu + std * eps
    return z
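
To see that gradients really flow through the trick, a small standalone check (the function is repeated here so the snippet runs on its own): the gradient of z.sum() with respect to mu is exactly 1 per element, independent of the sampled noise.

```python
import torch

def reparameterize(mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + std * eps

mu = torch.zeros(2, 3, requires_grad=True)
log_var = torch.zeros(2, 3, requires_grad=True)

z = reparameterize(mu, log_var)
z.sum().backward()

# dz/dmu = 1 for every element, regardless of the sampled eps
print(mu.grad)  # tensor of ones, shape (2, 3)
```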

μ‹œκ°μ  이해

[λ―ΈλΆ„ λΆˆκ°€]
mu, sigma β†’ μƒ˜ν”Œλ§ β†’ z β†’ Decoder

[λ―ΈλΆ„ κ°€λŠ₯ - Reparameterization]
mu ──────────────┐
                 β”‚
                 β–Ό
sigma ────────▢ (mu + sigma * eps) ──▢ z ──▢ Decoder
                       β–²
eps ~ N(0, I) β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ (μƒμˆ˜ μ·¨κΈ‰)

4. VAE Model Implementation

Encoder

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAEEncoder(nn.Module):
    """VAE Encoder (⭐⭐⭐)

    image β†’ mu, log_var
    """
    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),  # 28 β†’ 14
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),           # 14 β†’ 7
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

    def forward(self, x):
        h = self.conv_layers(x)
        mu = self.fc_mu(h)
        log_var = self.fc_logvar(h)
        return mu, log_var

Decoder

class VAEDecoder(nn.Module):
    """VAE Decoder (⭐⭐⭐)

    z β†’ image
    """
    def __init__(self, latent_dim=20, out_channels=1):
        super().__init__()

        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)

        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 7 β†’ 14
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),  # 14 β†’ 28
            nn.Sigmoid()  # [0, 1]
        )

    def forward(self, z):
        h = self.fc(z)
        h = h.view(-1, 64, 7, 7)
        x_recon = self.deconv_layers(h)
        return x_recon

Full VAE

class VAE(nn.Module):
    """Variational Autoencoder (⭐⭐⭐)"""
    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim)
        self.decoder = VAEDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z)
        return x_recon, mu, log_var

    def generate(self, num_samples, device):
        """Generate new samples"""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        samples = self.decoder(z)
        return samples

    def reconstruct(self, x):
        """Reconstruct images"""
        with torch.no_grad():
            x_recon, _, _ = self.forward(x)
        return x_recon
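
A standalone shape check of the full pipeline (condensed copies of the encoder/decoder above, MNIST-sized 28x28 input assumed):

```python
import torch
import torch.nn as nn

# condensed copies of VAEEncoder/VAEDecoder so the check runs standalone
class Enc(nn.Module):
    def __init__(self, latent_dim=20):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1), nn.ReLU(),   # 28 β†’ 14
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # 14 β†’ 7
            nn.Flatten())
        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

    def forward(self, x):
        h = self.conv(x)
        return self.fc_mu(h), self.fc_logvar(h)

class Dec(nn.Module):
    def __init__(self, latent_dim=20):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),    # 7 β†’ 14
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1), nn.Sigmoid())  # 14 β†’ 28

    def forward(self, z):
        return self.deconv(self.fc(z).view(-1, 64, 7, 7))

enc, dec = Enc(), Dec()
x = torch.rand(4, 1, 28, 28)
mu, log_var = enc(x)
z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)  # reparameterization
x_recon = dec(z)

print(mu.shape)       # torch.Size([4, 20])
print(x_recon.shape)  # torch.Size([4, 1, 28, 28])
```

The sigmoid output guarantees x_recon lies in [0, 1], which is what the BCE reconstruction loss expects.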

5. Training Loop

def train_vae(model, dataloader, epochs=50, lr=1e-3):
    """Train the VAE (⭐⭐⭐)"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_recon = 0
        total_kl = 0

        for batch_idx, (data, _) in enumerate(dataloader):
            data = data.to(device)

            optimizer.zero_grad()

            # Forward
            x_recon, mu, log_var = model(data)

            # Loss
            loss, recon_loss, kl_loss = vae_loss(data, x_recon, mu, log_var)

            # Normalize by batch size
            loss = loss / data.size(0)

            # Backward
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon_loss.item() / data.size(0)
            total_kl += kl_loss.item() / data.size(0)

        avg_loss = total_loss / len(dataloader)
        avg_recon = total_recon / len(dataloader)
        avg_kl = total_kl / len(dataloader)

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}")

    return model

6. Beta-VAE

Idea

ELBO = Reconstruction - beta * KL

beta > 1: larger weight on the KL term
    β†’ the latent space is regularized more strongly
    β†’ disentangled representations
    β†’ each latent dimension captures an independent factor of variation

beta = 1: standard VAE
beta < 1: emphasis on reconstruction

Implementation

def beta_vae_loss(x, x_recon, mu, log_var, beta=4.0):
    """Beta-VAE loss function (⭐⭐⭐)"""
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # beta weighting
    total_loss = recon_loss + beta * kl_loss

    return total_loss, recon_loss, kl_loss
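
A common practical addition (not part of the lesson above, but widely used with Beta-VAE): anneal beta from 0 up to its target over the first epochs to reduce posterior collapse. A minimal linear schedule might look like:

```python
def beta_schedule(epoch, warmup_epochs=10, beta_max=4.0):
    """Linearly anneal beta from 0 to beta_max over warmup_epochs."""
    if epoch >= warmup_epochs:
        return beta_max
    return beta_max * epoch / warmup_epochs

print(beta_schedule(0))   # 0.0
print(beta_schedule(5))   # 2.0
print(beta_schedule(20))  # 4.0
```

The per-epoch value would be passed as the `beta` argument of `beta_vae_loss`.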

Disentanglement Example

Beta-VAE trained on MNIST (beta=4):
    z[0]: slant of the digit
    z[1]: stroke width
    z[2]: digit class
    ...

Varying a single dimension changes only the corresponding feature

7. Latent Space Visualization

2D Latent Space

import matplotlib.pyplot as plt

def visualize_latent_space(model, dataloader, device):
    """Visualize the latent space (⭐⭐)

    Plots the first two latent dimensions (assumes latent_dim >= 2)
    """
    model.eval()

    latents = []
    labels = []

    with torch.no_grad():
        for data, label in dataloader:
            data = data.to(device)
            mu, _ = model.encoder(data)
            latents.append(mu.cpu())
            labels.append(label)

    latents = torch.cat(latents, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(latents[:, 0], latents[:, 1], c=labels, cmap='tab10', alpha=0.6)
    plt.colorbar(scatter)
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space')
    plt.savefig('latent_space.png')
    plt.close()

Latent Space Traversal

def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    """Traverse a single latent dimension (⭐⭐)"""
    model.eval()
    images = []

    with torch.no_grad():
        for val in range_vals:
            z = fixed_z.clone()
            z[0, dim_idx] = val
            img = model.decoder(z.to(device))
            images.append(img.cpu())

    return torch.cat(images, dim=0)
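
Example usage (with a stand-in decoder so the snippet runs standalone; with a trained model you would pass `model` and its device instead):

```python
import torch

class DummyModel:
    """Stand-in with the same interface as the VAE above (illustrative)."""
    def eval(self):
        pass
    def decoder(self, z):
        # pretend decoder: maps a (1, 20) latent to a (1, 1, 28, 28) image
        return torch.sigmoid(z.sum()) * torch.ones(1, 1, 28, 28)

def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    model.eval()
    images = []
    with torch.no_grad():
        for val in range_vals:
            z = fixed_z.clone()
            z[0, dim_idx] = val
            images.append(model.decoder(z.to(device)).cpu())
    return torch.cat(images, dim=0)

fixed_z = torch.zeros(1, 20)
imgs = explore_latent_dimension(DummyModel(), dim_idx=0,
                                range_vals=torch.linspace(-3, 3, 7),
                                fixed_z=fixed_z, device='cpu')
print(imgs.shape)  # torch.Size([7, 1, 28, 28])
```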

Manifold Generation

import numpy as np
import matplotlib.pyplot as plt

def generate_manifold(model, n=20, latent_dim=2, device='cpu'):
    """Render the manifold of a 2D latent space (⭐⭐⭐)"""
    model.eval()

    # build an n x n grid over the (-3, 3) range
    grid_x = torch.linspace(-3, 3, n)
    grid_y = torch.linspace(-3, 3, n)

    figure = np.zeros((28 * n, 28 * n))

    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z = torch.zeros(1, latent_dim)
                z[0, 0] = xi
                z[0, 1] = yi

                x_decoded = model.decoder(z.to(device))
                digit = x_decoded[0, 0].cpu().numpy()

                figure[i * 28:(i + 1) * 28,
                       j * 28:(j + 1) * 28] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='gray')
    plt.axis('off')
    plt.savefig('vae_manifold.png')
    plt.close()

8. VAE vs GAN Comparison

| Property           | VAE                     | GAN                    |
|--------------------|-------------------------|------------------------|
| Training scheme    | likelihood maximization | adversarial training   |
| Loss function      | ELBO (explicit)         | min-max (implicit)     |
| Training stability | stable                  | unstable               |
| Image quality      | tends to be blurry      | sharp                  |
| Latent space       | structured              | hard to interpret      |
| Mode coverage      | good                    | mode collapse possible |
| Density estimation | tractable               | intractable            |

Pros and Cons

VAE pros:
- explicit density model
- stable training
- meaningful latent space
- supports both reconstruction and generation

VAE cons:
- blurry images due to the reconstruction loss
- the KL term limits the expressiveness of the latent space

GAN pros:
- sharp, high-quality images
- implicit density β†’ more flexible

GAN cons:
- unstable training
- mode collapse
- hard to evaluate

9. Advanced VAE Variants

Conditional VAE (CVAE)

class CVAE(nn.Module):
    """Conditional VAE (⭐⭐⭐)

    Condition on extra information (e.g., a class label) to generate a specific type
    """
    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # the condition is one-hot encoded and concatenated inside the encoder/decoder
        self.encoder = CVAEEncoder(in_channels, latent_dim, num_classes)
        self.decoder = CVAEDecoder(latent_dim, in_channels, num_classes)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        # same trick as the plain VAE; needed because forward() calls it below
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def forward(self, x, label):
        # One-hot encoding
        y = F.one_hot(label, self.num_classes).float()

        mu, log_var = self.encoder(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z, y)

        return x_recon, mu, log_var

    def generate(self, label, num_samples, device):
        """Generate samples of a specific class"""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        y = F.one_hot(label, self.num_classes).float().to(device)
        y = y.expand(num_samples, -1)
        return self.decoder(z, y)
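
`CVAEEncoder` and `CVAEDecoder` are referenced above but not defined. One possible minimal implementation, a sketch that mirrors the unconditional encoder/decoder while concatenating the one-hot label (the layer sizes are assumptions matching the MNIST setup):

```python
import torch
import torch.nn as nn

class CVAEEncoder(nn.Module):
    """Encoder conditioned on a one-hot label (illustrative sketch)."""
    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(), nn.Flatten())
        self.fc_mu = nn.Linear(64 * 7 * 7 + num_classes, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7 + num_classes, latent_dim)

    def forward(self, x, y):
        h = torch.cat([self.conv(x), y], dim=1)  # concat image features and label
        return self.fc_mu(h), self.fc_logvar(h)

class CVAEDecoder(nn.Module):
    """Decoder conditioned on a one-hot label (illustrative sketch)."""
    def __init__(self, latent_dim=20, out_channels=1, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(latent_dim + num_classes, 64 * 7 * 7)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1), nn.Sigmoid())

    def forward(self, z, y):
        h = self.fc(torch.cat([z, y], dim=1)).view(-1, 64, 7, 7)
        return self.deconv(h)

# shape check
enc, dec = CVAEEncoder(), CVAEDecoder()
x = torch.rand(4, 1, 28, 28)
y = torch.eye(10)[torch.tensor([0, 1, 2, 3])]  # one-hot labels, shape (4, 10)
mu, log_var = enc(x, y)
x_recon = dec(mu, y)
print(x_recon.shape)  # torch.Size([4, 1, 28, 28])
```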

VQ-VAE (Vector Quantized VAE)

# VQ-VAE uses a discrete codebook instead of a continuous latent space
# effective for high-quality image/audio generation

class VectorQuantizer(nn.Module):
    """Vector quantization layer for VQ-VAE (⭐⭐⭐⭐)"""
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # codebook
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

    def forward(self, z):
        # z: (batch, channels, H, W)
        z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        # find the nearest codebook vector for each position
        distances = torch.cdist(z_flat, self.embeddings.weight)
        indices = torch.argmin(distances, dim=1)
        z_q = self.embeddings(indices).view(z.shape[0], z.shape[2], z.shape[3], -1)
        z_q = z_q.permute(0, 3, 1, 2)

        # loss: codebook learning + commitment loss
        loss = F.mse_loss(z_q.detach(), z) + self.commitment_cost * F.mse_loss(z_q, z.detach())

        # Straight-through estimator
        z_q = z + (z_q - z).detach()

        return z_q, loss, indices
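
A standalone shape check of the quantization logic (condensed from the class above):

```python
import torch
import torch.nn as nn

num_embeddings, embedding_dim = 8, 4
codebook = nn.Embedding(num_embeddings, embedding_dim)

z = torch.randn(2, embedding_dim, 5, 5)  # (batch, channels, H, W)
z_flat = z.permute(0, 2, 3, 1).reshape(-1, embedding_dim)

# nearest codebook entry for each of the 2*5*5 = 50 spatial positions
indices = torch.argmin(torch.cdist(z_flat, codebook.weight), dim=1)
z_q = codebook(indices).view(2, 5, 5, embedding_dim).permute(0, 3, 1, 2)

# straight-through estimator: forward uses z_q, backward sees identity
z_st = z + (z_q - z).detach()

print(indices.shape)  # torch.Size([50])
print(z_q.shape)      # torch.Size([2, 4, 5, 5])
```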

10. Complete MNIST VAE Example

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# config
latent_dim = 20
batch_size = 128
epochs = 30
lr = 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data
transform = transforms.ToTensor()
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# λͺ¨λΈ
model = VAE(in_channels=1, latent_dim=latent_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# training
for epoch in range(epochs):
    model.train()
    train_loss = 0

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)

        optimizer.zero_grad()
        x_recon, mu, log_var = model(data)

        # Loss
        recon_loss = F.binary_cross_entropy(x_recon, data, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = (recon_loss + kl_loss) / data.size(0)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f"Epoch {epoch+1}: Loss = {train_loss / len(train_loader):.4f}")

# generation
model.eval()
with torch.no_grad():
    samples = model.generate(16, device)
    # save or visualize...

print("VAE training complete!")

Summary

Key Concepts

  1. VAE: a generative model with a probabilistic latent space
  2. ELBO loss: reconstruction + KL divergence
  3. Reparameterization: z = mu + sigma * epsilon
  4. Beta-VAE: disentanglement by adjusting the KL weight
  5. Latent space: continuous and structured

Key Code

# Encoder output
mu, log_var = encoder(x)

# Reparameterization
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(std)

# Decoder
x_recon = decoder(z)

# Loss
recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
kl = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp())
loss = recon + kl

Usage Scenarios

| Goal                       | Recommended method        |
|----------------------------|---------------------------|
| Data generation            | VAE or GAN                |
| Latent space analysis      | VAE (especially Beta-VAE) |
| High-quality images        | GAN or VQ-VAE             |
| Conditional generation     | CVAE                      |
| Compression/reconstruction | VAE                       |

λ‹€μŒ 단계

32_Diffusion_Models.mdμ—μ„œ μ΅œμ‹  생성 λͺ¨λΈμΈ Diffusion λͺ¨λΈμ„ ν•™μŠ΅ν•©λ‹ˆλ‹€.
