28. ์ƒ์„ฑ ๋ชจ๋ธ - GAN (Generative Adversarial Networks)


Learning Objectives

  • Understand the core idea of GANs and adversarial training
  • Design Generator and Discriminator architectures
  • Compare loss functions (adversarial, Wasserstein, WGAN-GP)
  • Implement the DCGAN architecture
  • Apply training stabilization techniques
  • Understand the key ideas of StyleGAN

1. GAN ๊ธฐ์ดˆ ์ด๋ก 

๊ฐœ๋…

GAN = ๋‘ ์‹ ๊ฒฝ๋ง์˜ ๊ฒฝ์Ÿ์  ํ•™์Šต

Generator (G): ๊ฐ€์งœ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
    ๋žœ๋ค ๋…ธ์ด์ฆˆ z โ†’ ๊ฐ€์งœ ์ด๋ฏธ์ง€

Discriminator (D): ์ง„์งœ/๊ฐ€์งœ ํŒ๋ณ„
    ์ด๋ฏธ์ง€ โ†’ ์ง„์งœ(1) / ๊ฐ€์งœ(0)

๋ชฉํ‘œ: G๊ฐ€ D๋ฅผ ์†์ผ ์ˆ˜ ์žˆ์„ ๋งŒํผ ์ข‹์€ ๊ฐ€์งœ ์ƒ์„ฑ

์ ๋Œ€์  ํ•™์Šต (Adversarial Training)

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”     โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚  Noise z   โ”‚โ”€โ”€โ”€โ”€โ–ถโ”‚ Generator  โ”‚โ”€โ”€โ”€โ”ฌโ”€โ”€โ–ถ ๊ฐ€์งœ ์ด๋ฏธ์ง€
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜     โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜   โ”‚
                                    โ”‚
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                      โ–ผ
โ”‚ ์ง„์งœ ์ด๋ฏธ์ง€โ”‚โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ–ถ Discriminator โ”€โ”€โ–ถ ์ง„์งœ/๊ฐ€์งœ
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

D ํ•™์Šต: ์ง„์งœ=1, ๊ฐ€์งœ=0 ํŒ๋ณ„ ์ •ํ™•๋„ ์ตœ๋Œ€ํ™”
G ํ•™์Šต: D๊ฐ€ ๊ฐ€์งœ๋ฅผ ์ง„์งœ๋กœ ํŒ๋ณ„ํ•˜๋„๋ก ์œ ๋„

Min-Max Game

# GAN objective (min-max game)
# min_G max_D V(D, G) = E[log D(x)] + E[log(1 - D(G(z)))]

# D's goal: maximize V(D, G)
#   - D(x) → 1 (real classified as real)
#   - D(G(z)) → 0 (fake classified as fake)

# G's goal: minimize V(D, G)
#   - D(G(z)) → 1 (D fooled into calling fakes real)

2. Basic GAN Implementation

Generator

import torch
import torch.nn as nn

class Generator(nn.Module):
    """Simple Generator (⭐⭐)"""
    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()  # output in [-1, 1]
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)

Discriminator

class Discriminator(nn.Module):
    """Simple Discriminator (⭐⭐)"""
    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()  # output: probability in [0, 1]
        )

    def forward(self, img):
        return self.model(img)

Training Loop

def train_gan(generator, discriminator, dataloader, epochs=100, latent_dim=100):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    generator.to(device)
    discriminator.to(device)

    criterion = nn.BCELoss()

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            batch_size = real_imgs.size(0)
            real_imgs = real_imgs.to(device)

            # ๋ ˆ์ด๋ธ”
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ==================
            # Train Discriminator
            # ==================
            optimizer_D.zero_grad()

            # ์ง„์งœ ์ด๋ฏธ์ง€
            real_output = discriminator(real_imgs)
            d_loss_real = criterion(real_output, real_labels)

            # ๊ฐ€์งœ ์ด๋ฏธ์ง€
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)
            fake_output = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(fake_output, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # ==================
            # Train Generator
            # ==================
            optimizer_G.zero_grad()

            # push D into classifying the fakes as real
            fake_output = discriminator(fake_imgs)
            g_loss = criterion(fake_output, real_labels)  # note: real labels

            g_loss.backward()
            optimizer_G.step()

        print(f"Epoch [{epoch+1}/{epochs}] D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

3. Loss Functions

Vanilla GAN Loss (BCE)

# Binary Cross Entropy
criterion = nn.BCELoss()

# D loss (targets must be tensors: ones for real, zeros for fake)
out_real, out_fake = D(real), D(G(z))
d_loss = criterion(out_real, torch.ones_like(out_real)) + \
         criterion(out_fake, torch.zeros_like(out_fake))

# G loss (non-saturating): -log(D(G(z)))
g_loss = criterion(out_fake, torch.ones_like(out_fake))

# G loss (original, saturating): log(1 - D(G(z))) - rarely used in practice,
# since its gradient vanishes when D confidently rejects early fakes
# g_loss = -criterion(out_fake, torch.zeros_like(out_fake))
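A practical aside: many implementations drop D's final Sigmoid and use nn.BCEWithLogitsLoss instead, which applies the sigmoid inside the loss and is numerically more stable. A minimal sketch of the equivalence:

```python
import torch
import torch.nn as nn

# Numerically safer variant (a sketch): let D output raw logits and use
# BCEWithLogitsLoss, which folds the sigmoid into the loss computation.
torch.manual_seed(0)
logits = torch.randn(8, 1)    # stand-in for D's raw (pre-sigmoid) output
targets = torch.ones(8, 1)    # "real" labels

stable = nn.BCEWithLogitsLoss()(logits, targets)
manual = nn.BCELoss()(torch.sigmoid(logits), targets)  # same value, less stable
```

Both losses compute the same quantity, but the fused version avoids log(0) when the sigmoid saturates.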

Wasserstein Loss (WGAN)

def wasserstein_loss(y_pred, y_true):
    """Wasserstein distance (Earth Mover's Distance)"""
    return torch.mean(y_pred * y_true)

# Critic (D) objective: maximize E[D(real)] - E[D(fake)]
# → minimize the negation with a standard optimizer
d_loss = -torch.mean(D(real)) + torch.mean(D(G(z)))

# G loss - minimize
g_loss = -torch.mean(D(G(z)))

# Weight Clipping (WGAN)
for p in discriminator.parameters():
    p.data.clamp_(-0.01, 0.01)
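Putting the critic loss and weight clipping together, one critic phase might look like the sketch below (toy 2-D data and tiny networks are illustrative; n_critic = 5 and RMSprop with lr = 5e-5 follow the original WGAN recipe):

```python
import torch
import torch.nn as nn

# WGAN critic phase (a sketch on toy 2-D data; networks are illustrative).
torch.manual_seed(0)
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
generator = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
opt_c = torch.optim.RMSprop(critic.parameters(), lr=5e-5)  # as in the WGAN paper

n_critic = 5  # several critic steps per generator step
for _ in range(n_critic):
    real = torch.randn(32, 2) + 3.0                 # toy "real" distribution
    fake = generator(torch.randn(32, 4)).detach()   # don't backprop into G here

    # minimize -(E[D(real)] - E[D(fake)]) == maximize the Wasserstein estimate
    loss_c = -(critic(real).mean() - critic(fake).mean())
    opt_c.zero_grad()
    loss_c.backward()
    opt_c.step()

    # crude 1-Lipschitz enforcement via weight clipping
    for p in critic.parameters():
        p.data.clamp_(-0.01, 0.01)
```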

WGAN-GP (Gradient Penalty)

def gradient_penalty(discriminator, real_imgs, fake_imgs, device):
    """Gradient Penalty for WGAN-GP (โญโญโญ)"""
    batch_size = real_imgs.size(0)

    # ๋žœ๋ค interpolation
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = alpha * real_imgs + (1 - alpha) * fake_imgs
    interpolated.requires_grad_(True)

    # Discriminator output on the interpolated samples
    d_interpolated = discriminator(interpolated)

    # Gradients of D's output w.r.t. the interpolated input
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]

    # Gradient norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Penalty: (||grad|| - 1)^2
    penalty = ((gradient_norm - 1) ** 2).mean()

    return penalty

# WGAN-GP Loss
lambda_gp = 10
gp = gradient_penalty(discriminator, real_imgs, fake_imgs, device)
d_loss = -torch.mean(D(real)) + torch.mean(D(G(z))) + lambda_gp * gp

Hinge Loss

# Discriminator loss
d_loss_real = torch.mean(torch.relu(1.0 - D(real)))
d_loss_fake = torch.mean(torch.relu(1.0 + D(G(z))))
d_loss = d_loss_real + d_loss_fake

# Generator loss
g_loss = -torch.mean(D(G(z)))

4. DCGAN Architecture

ํ•ต์‹ฌ ์›์น™

1. Pooling ์ œ๊ฑฐ โ†’ Strided Conv (D), Transposed Conv (G)
2. BatchNorm ์‚ฌ์šฉ (G์˜ ์ถœ๋ ฅ์ธต, D์˜ ์ž…๋ ฅ์ธต ์ œ์™ธ)
3. G์—์„œ ReLU (์ถœ๋ ฅ์ธต์€ Tanh)
4. D์—์„œ LeakyReLU

DCGAN Generator

class DCGANGenerator(nn.Module):
    """DCGAN Generator (โญโญโญ)

    z (100,) → (512, 4, 4) → (256, 8, 8) → (128, 16, 16) → (64, 32, 32) → (3, 64, 64)
    """
    def __init__(self, latent_dim=100, ngf=64, nc=3):
        super().__init__()

        self.main = nn.Sequential(
            # input: z reshaped to (latent_dim, 1, 1)
            # output: (ngf*8, 4, 4)
            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            # (ngf*8, 4, 4) โ†’ (ngf*4, 8, 8)
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            # (ngf*4, 8, 8) โ†’ (ngf*2, 16, 16)
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            # (ngf*2, 16, 16) โ†’ (ngf, 32, 32)
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            # (ngf, 32, 32) โ†’ (nc, 64, 64)
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, z):
        # z: (batch, latent_dim) โ†’ (batch, latent_dim, 1, 1)
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.main(z)

DCGAN Discriminator

class DCGANDiscriminator(nn.Module):
    """DCGAN Discriminator (โญโญโญ)

    (3, 64, 64) โ†’ (64, 32, 32) โ†’ (128, 16, 16) โ†’ (256, 8, 8) โ†’ (512, 4, 4) โ†’ (1,)
    """
    def __init__(self, nc=3, ndf=64):
        super().__init__()

        self.main = nn.Sequential(
            # (nc, 64, 64) โ†’ (ndf, 32, 32)
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf, 32, 32) โ†’ (ndf*2, 16, 16)
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf*2, 16, 16) โ†’ (ndf*4, 8, 8)
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf*4, 8, 8) โ†’ (ndf*8, 4, 4)
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # (ndf*8, 4, 4) โ†’ (1, 1, 1)
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, img):
        return self.main(img).view(-1, 1)
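The discriminator mirrors this with strided convolutions; a quick check using the formula out = floor((in + 2·padding - kernel)/stride) + 1 (layer names illustrative):

```python
import torch
import torch.nn as nn

# Conv2d output size: out = floor((in + 2*padding - kernel)/stride) + 1
# Each (4, 2, 1) conv halves the resolution; the final (4, 1, 0) conv maps 4x4 -> 1x1.
down = nn.Conv2d(3, 64, 4, 2, 1)
head = nn.Conv2d(512, 1, 4, 1, 0)

print(down(torch.randn(1, 3, 64, 64)).shape)   # torch.Size([1, 64, 32, 32])
print(head(torch.randn(1, 512, 4, 4)).shape)   # torch.Size([1, 1, 1, 1])
```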

5. Training Stabilization Techniques

Spectral Normalization

from torch.nn.utils import spectral_norm

class SNDiscriminator(nn.Module):
    """Discriminator with Spectral Normalization (⭐⭐⭐)"""
    def __init__(self, nc=3, ndf=64):
        super().__init__()

        self.main = nn.Sequential(
            spectral_norm(nn.Conv2d(nc, ndf, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
        )

    def forward(self, img):
        return self.main(img).view(-1, 1)

Label Smoothing

# One-sided label smoothing
real_labels = torch.ones(batch_size, 1, device=device) * 0.9  # 1.0 โ†’ 0.9
fake_labels = torch.zeros(batch_size, 1, device=device)

# Or noisy real labels
real_labels = 0.7 + 0.3 * torch.rand(batch_size, 1, device=device)  # in [0.7, 1.0)

Two Time-Scale Update Rule (TTUR)

# D๋Š” ๋” ๋†’์€ ํ•™์Šต๋ฅ 
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0004, betas=(0.0, 0.9))
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0001, betas=(0.0, 0.9))

Progressive Growing

# ์ž‘์€ ํ•ด์ƒ๋„์—์„œ ์‹œ์ž‘ํ•ด์„œ ์ ์ง„์ ์œผ๋กœ ํ‚ค์›€
# ProGAN (Progressive GAN)

resolutions = [4, 8, 16, 32, 64, 128, 256, 512, 1024]

# ๊ฐ ํ•ด์ƒ๋„์—์„œ ์ผ์ • epoch ํ•™์Šต ํ›„ ๋‹ค์Œ ํ•ด์ƒ๋„๋กœ
# Fade-in: ์ƒˆ ๋ ˆ์ด์–ด๋ฅผ ์ ์ง„์ ์œผ๋กœ ์ถ”๊ฐ€
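The fade-in step can be sketched as a simple blend between the upsampled old output and the new block's output (an illustrative sketch, not the full ProGAN implementation):

```python
import torch
import torch.nn.functional as F

# Fade-in blend (illustrative): while a new high-resolution block is added,
# mix its output with the upsampled output of the old low-resolution path.
def fade_in(old_rgb, new_rgb, alpha):
    """alpha ramps from 0 to 1 over the transition phase."""
    old_up = F.interpolate(old_rgb, scale_factor=2, mode='nearest')
    return (1 - alpha) * old_up + alpha * new_rgb

old = torch.randn(1, 3, 16, 16)   # stable 16x16 path
new = torch.randn(1, 3, 32, 32)   # freshly added 32x32 block
out = fade_in(old, new, alpha=0.3)
print(out.shape)                  # torch.Size([1, 3, 32, 32])
```

At alpha = 0 the network behaves exactly like the old, smaller one; at alpha = 1 the new block has fully taken over.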

6. StyleGAN Overview

ํ•ต์‹ฌ ์•„์ด๋””์–ด

Classic GAN: z → G → image
StyleGAN:    z → Mapping Network → w → Synthesis Network → image

Mapping Network: an 8-layer MLP that maps z to a "disentangled" latent w
AdaIN: injects the style w at each layer of the synthesis network

Mapping Network

class MappingNetwork(nn.Module):
    """StyleGAN Mapping Network (โญโญโญโญ)"""
    def __init__(self, latent_dim=512, w_dim=512, num_layers=8):
        super().__init__()

        layers = []
        for i in range(num_layers):
            layers.append(nn.Linear(latent_dim if i == 0 else w_dim, w_dim))
            layers.append(nn.LeakyReLU(0.2))

        self.mapping = nn.Sequential(*layers)

    def forward(self, z):
        return self.mapping(z)

AdaIN (Adaptive Instance Normalization)

class AdaIN(nn.Module):
    """Adaptive Instance Normalization (โญโญโญโญ)"""
    def __init__(self, num_features, w_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features)
        self.style = nn.Linear(w_dim, num_features * 2)  # scale + bias

    def forward(self, x, w):
        # x: (batch, channels, H, W)
        # w: (batch, w_dim)

        normalized = self.norm(x)

        style = self.style(w)  # (batch, channels*2)
        gamma, beta = style.chunk(2, dim=1)
        gamma = gamma.unsqueeze(-1).unsqueeze(-1)
        beta = beta.unsqueeze(-1).unsqueeze(-1)

        return gamma * normalized + beta
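The same operation can be written functionally to see what AdaIN does to the feature statistics (gamma and beta here are random stand-ins for the style projection of w):

```python
import torch
import torch.nn.functional as F

# Functional AdaIN sketch: normalize each (sample, channel) feature map,
# then re-scale/shift with style parameters (stand-ins for style(w)).
x = torch.randn(2, 8, 16, 16)          # content feature maps
gamma = torch.rand(2, 8, 1, 1) + 0.5   # style scale
beta = torch.randn(2, 8, 1, 1)         # style shift

out = gamma * F.instance_norm(x) + beta
# After AdaIN, per-channel statistics follow the style: mean(out[n, c]) ~ beta[n, c]
print(out.shape)                       # torch.Size([2, 8, 16, 16])
```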

Style Mixing

# Generate w from two different z vectors
z1, z2 = torch.randn(2, latent_dim)
w1, w2 = mapping(z1), mapping(z2)

# Use w1 up to a certain layer, then w2 afterwards
# → style mixing (coarse styles from w1, fine styles from w2)

7. ์ด๋ฏธ์ง€ ์ƒ์„ฑ ๋ฐ ํ‰๊ฐ€

์ƒ˜ํ”Œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ

def generate_samples(generator, latent_dim, num_samples=64):
    generator.eval()
    device = next(generator.parameters()).device  # use the model's own device
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim, device=device)
        fake_imgs = generator(z)
    return fake_imgs

# ์‹œ๊ฐํ™”
import matplotlib.pyplot as plt
import torchvision.utils as vutils

def show_generated_images(images, nrow=8):
    """์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ๊ทธ๋ฆฌ๋“œ ํ‘œ์‹œ"""
    # [-1, 1] โ†’ [0, 1]
    images = (images + 1) / 2
    grid = vutils.make_grid(images.cpu(), nrow=nrow, normalize=False)
    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.savefig('generated_samples.png')
    plt.close()

Interpolation

def interpolate_latent(generator, z1, z2, steps=10):
    """Latent-space interpolation (⭐⭐)"""
    generator.eval()
    images = []

    with torch.no_grad():
        for alpha in torch.linspace(0, 1, steps):
            z = (1 - alpha) * z1 + alpha * z2
            img = generator(z.unsqueeze(0))
            images.append(img)

    return torch.cat(images, dim=0)

# Spherical interpolation (slerp) - often gives smoother results
def slerp(z1, z2, alpha):
    """Spherical linear interpolation"""
    omega = torch.acos((z1 * z2).sum() / (z1.norm() * z2.norm()))
    return (torch.sin((1 - alpha) * omega) / torch.sin(omega)) * z1 + \
           (torch.sin(alpha * omega) / torch.sin(omega)) * z2

FID (Fréchet Inception Distance)

# Computing FID (using the pytorch-fid library)
# pip install pytorch-fid

# from pytorch_fid import fid_score
# fid = fid_score.calculate_fid_given_paths(
#     [real_images_path, fake_images_path],
#     batch_size=50,
#     device=device,
#     dims=2048
# )
# ๋‚ฎ์„์ˆ˜๋ก ์ข‹์Œ (0์ด ์™„๋ฒฝ)

8. Complete MNIST GAN Example

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ
latent_dim = 100
lr = 0.0002
batch_size = 64
epochs = 50

# ๋ฐ์ดํ„ฐ
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])

mnist = datasets.MNIST('data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

# ๊ฐ„๋‹จํ•œ ๋ชจ๋ธ
G = Generator(latent_dim=latent_dim, img_shape=(1, 28, 28)).to(device)
D = Discriminator(img_shape=(1, 28, 28)).to(device)

# ์˜ตํ‹ฐ๋งˆ์ด์ €
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

criterion = nn.BCELoss()

# Training
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)

        # ๋ ˆ์ด๋ธ”
        real = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Train D
        optimizer_D.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_imgs = G(z)

        d_loss = criterion(D(real_imgs), real) + criterion(D(fake_imgs.detach()), fake)
        d_loss.backward()
        optimizer_D.step()

        # Train G
        optimizer_G.zero_grad()
        g_loss = criterion(D(fake_imgs), real)
        g_loss.backward()
        optimizer_G.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: D={d_loss.item():.4f}, G={g_loss.item():.4f}")

        # Save samples
        with torch.no_grad():
            sample = G(torch.randn(16, latent_dim, device=device))
            # saving code...

Summary

Key Concepts

  1. GAN: adversarial training between a Generator and a Discriminator
  2. Loss functions: BCE, Wasserstein, gradient penalty
  3. DCGAN: transposed convolutions + BatchNorm + LeakyReLU
  4. Stabilization: spectral normalization, TTUR, progressive growing
  5. StyleGAN: style control via a Mapping Network + AdaIN

GAN Training Tips

1. Keep D and G in balance (if D gets too strong, G cannot learn)
2. Apply BatchNorm to the full minibatch (do not split real and fake)
3. Use Adam with beta1 = 0.5 (reduced momentum)
4. Learning rate: typically 0.0001 - 0.0002
5. Inspect generated images regularly

Loss Function Comparison

| Loss function | Pros                 | Cons                   |
|---------------|----------------------|------------------------|
| BCE           | Simple               | Prone to mode collapse |
| WGAN          | Stable training      | Needs weight clipping  |
| WGAN-GP       | Very stable          | Extra compute cost     |
| Hinge         | Simple and effective | -                      |

๋‹ค์Œ ๋‹จ๊ณ„

30_Generative_Models_VAE.md์—์„œ VAE (Variational Autoencoder)๋ฅผ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.
