32. Diffusion Models

์ด์ „: Variational Autoencoder (VAE) | ๋‹ค์Œ: ํ™•์‚ฐ ๋ชจ๋ธ(DDPM)


32. Diffusion Models

ํ•™์Šต ๋ชฉํ‘œ

  • Diffusion Process ์ด๋ก  ์ดํ•ด (Forward/Reverse)
  • DDPM (Denoising Diffusion Probabilistic Models) ์›๋ฆฌ
  • Score-based Generative Models ๊ฐœ๋…
  • U-Net ์•„ํ‚คํ…์ฒ˜ for Diffusion
  • Stable Diffusion ํ•ต์‹ฌ ์›๋ฆฌ
  • Classifier-free Guidance
  • ๊ฐ„๋‹จํ•œ DDPM PyTorch ๊ตฌํ˜„

1. Diffusion Process ๊ฐœ์š”

ํ•ต์‹ฌ ์•„์ด๋””์–ด

๋ฐ์ดํ„ฐ์— ์ ์ง„์ ์œผ๋กœ ๋…ธ์ด์ฆˆ๋ฅผ ์ถ”๊ฐ€ (Forward Process)
    x_0 โ†’ x_1 โ†’ x_2 โ†’ ... โ†’ x_T (์ˆœ์ˆ˜ ๋…ธ์ด์ฆˆ)

๋…ธ์ด์ฆˆ๋ฅผ ์ ์ง„์ ์œผ๋กœ ์ œ๊ฑฐํ•˜์—ฌ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ (Reverse Process)
    x_T โ†’ x_{T-1} โ†’ ... โ†’ x_0 (๊นจ๋—ํ•œ ์ด๋ฏธ์ง€)

ํ•ต์‹ฌ: Reverse Process๋ฅผ ์‹ ๊ฒฝ๋ง์œผ๋กœ ํ•™์Šต

์‹œ๊ฐ์  ์ดํ•ด

Forward (๋…ธ์ด์ฆˆ ์ถ”๊ฐ€):
[๊นจ๋—ํ•œ ์ด๋ฏธ์ง€] โ”€โ”€t=0โ”€โ”€โ–ถ [์•ฝ๊ฐ„ ๋…ธ์ด์ฆˆ] โ”€โ”€t=500โ”€โ”€โ–ถ [๋” ๋งŽ์€ ๋…ธ์ด์ฆˆ] โ”€โ”€t=1000โ”€โ”€โ–ถ [์™„์ „ ๋…ธ์ด์ฆˆ]

Reverse (๋…ธ์ด์ฆˆ ์ œ๊ฑฐ):
[์™„์ „ ๋…ธ์ด์ฆˆ] โ”€โ”€t=1000โ”€โ”€โ–ถ [์•ฝ๊ฐ„ ์„ ๋ช…] โ”€โ”€t=500โ”€โ”€โ–ถ [๋” ์„ ๋ช…] โ”€โ”€t=0โ”€โ”€โ–ถ [๊นจ๋—ํ•œ ์ด๋ฏธ์ง€]

2. DDPM (Denoising Diffusion Probabilistic Models)

Forward Process (q)

# Forward process: q(x_t | x_{t-1})
# x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * epsilon

# Closed form (ํ•œ ๋ฒˆ์— x_0์—์„œ x_t๋กœ):
# x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon

# alpha_t = 1 - beta_t
# alpha_bar_t = alpha_1 * alpha_2 * ... * alpha_t  (๋ˆ„์ ๊ณฑ, cumulative product)

์ˆ˜ํ•™์  ์ •์˜

import torch

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear noise schedule (โญโญ).

    Returns a 1-D tensor of `timesteps` beta values evenly spaced from
    `beta_start` to `beta_end` (the schedule used in the original DDPM paper).
    """
    schedule = torch.linspace(beta_start, beta_end, steps=timesteps)
    return schedule

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule (better performance than linear) (โญโญโญ).

    Defines alpha_bar_t as a squared cosine of normalized time (Nichol &
    Dhariwal, 2021) and derives the per-step betas from successive ratios.
    Betas are clamped to keep every step numerically well-behaved.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    # alpha_bar curve, normalized so that alpha_bar_0 == 1
    abar = torch.cos((grid / timesteps + s) / (1 + s) * torch.pi * 0.5) ** 2
    abar = abar / abar[0]
    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - abar[1:] / abar[:-1]
    return torch.clamp(betas, 0.0001, 0.9999)

def get_index_from_list(vals, t, x_shape):
    """Pick the schedule value for each sample's timestep `t`.

    Gathers `vals[t]` per batch element and reshapes the result to
    (batch, 1, 1, ..., 1) so it broadcasts against a tensor of shape
    `x_shape`. `t` may live on any device; `vals` is indexed on CPU and
    the result is moved back to `t`'s device.
    """
    gathered = vals.gather(-1, t.cpu())
    broadcast_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(broadcast_shape).to(t.device)

Forward Diffusion ๊ตฌํ˜„

class DiffusionSchedule:
    """Precomputed constants for the DDPM forward/reverse processes (โญโญโญ).

    Builds a beta schedule ('linear' or cosine) and derives every quantity
    used during training and sampling: alphas, their cumulative products,
    and the posterior variance of q(x_{t-1} | x_t, x_0).
    """
    def __init__(self, timesteps=1000, beta_schedule='linear'):
        self.timesteps = timesteps

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        else:
            betas = cosine_beta_schedule(timesteps)

        self.betas = betas
        self.alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with the convention alpha_bar_{-1} = 1.
        # BUG FIX: the original used F.pad, but torch.nn.functional is never
        # imported in this snippet; torch.cat is equivalent and torch-only.
        self.alphas_cumprod_prev = torch.cat(
            [torch.ones(1, dtype=self.alphas_cumprod.dtype), self.alphas_cumprod[:-1]]
        )

        # Values needed in closed-form q(x_t | x_0) and the reverse mean
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # Variance of the posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)

    def q_sample(self, x_0, t, noise=None):
        """Forward process: sample x_t directly from x_0 (โญโญโญ).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x_0: clean data batch.
            t: per-sample timestep indices (LongTensor of shape (batch,)).
            noise: optional pre-drawn Gaussian noise; drawn fresh if None.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alphas_cumprod_t = get_index_from_list(
            self.sqrt_alphas_cumprod, t, x_0.shape
        )
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )

        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise

3. ๋…ธ์ด์ฆˆ ์˜ˆ์ธก ๋„คํŠธ์›Œํฌ

๋ชฉํ‘œ

๋ชจ๋ธ์ด x_t์—์„œ ์ถ”๊ฐ€๋œ ๋…ธ์ด์ฆˆ epsilon์„ ์˜ˆ์ธก
epsilon_theta(x_t, t) โ‰ˆ epsilon

์†์‹ค ํ•จ์ˆ˜:
L = E[||epsilon - epsilon_theta(x_t, t)||^2]

๊ฐ„๋‹จํ•œ U-Net ๊ตฌ์กฐ

import torch.nn as nn
import torch.nn.functional as F

class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal timestep embedding (โญโญโญ).

    Maps a batch of scalar timesteps to `dim`-dimensional vectors using the
    transformer-style sin/cos frequency bank, so the U-Net can condition on t.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # output embedding size (first half sin, second half cos)

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # BUG FIX: the original called math.log(10000), but `math` is never
        # imported in this snippet; compute the decay factor with torch instead.
        scale = torch.log(torch.tensor(10000.0, device=device)) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        # (batch,) x (half_dim,) -> (batch, half_dim) phase arguments
        args = time[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)


class Block(nn.Module):
    """Basic conv block with time-embedding injection.

    Runs two 3x3 convolutions (each followed by ReLU then BatchNorm, as in
    the original tutorial code), adds a projected time embedding between
    them, and finally changes resolution: stride-2 conv when downsampling,
    stride-2 transposed conv when `up=True` (which also expects the skip
    connection concatenated on the channel axis, hence 2*in_ch inputs).
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        # NOTE: layer creation order is kept stable so parameter init is
        # reproducible under a fixed random seed.
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First conv stage
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Project the time embedding to out_ch and broadcast over H, W
        temb = self.relu(self.time_mlp(t))[:, :, None, None]
        # Second conv stage on the time-conditioned features
        h = self.bnorm2(self.relu(self.conv2(h + temb)))
        # Resolution change (down- or up-sample)
        return self.transform(h)


class SimpleUNet(nn.Module):
    """Simple U-Net for diffusion noise prediction (โญโญโญ).

    Predicts the noise epsilon_theta(x_t, t). Three stride-2 down blocks
    and three stride-2 up blocks with skip connections, so the input's
    height/width must be divisible by 8 (e.g. 32x32; pad 28x28 MNIST).
    """
    def __init__(self, in_channels=3, out_channels=3, time_dim=256):
        super().__init__()

        # Time embedding: sinusoidal features -> small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU()
        )

        # Initial projection
        self.conv0 = nn.Conv2d(in_channels, 64, 3, padding=1)

        # Downsampling; per-stage output channels: 128, 256, 256
        self.downs = nn.ModuleList([
            Block(64, 128, time_dim),
            Block(128, 256, time_dim),
            Block(256, 256, time_dim),
        ])

        # Upsampling.
        # BUG FIX: the up in_ch values must mirror the down outputs in
        # reverse (256, 256, 128) so torch.cat(x, skip) has exactly the
        # 2*in_ch channels each up Block expects. The original
        # (256,128),(128,64),(64,64) crashed at the second up block
        # (128 + 256 = 384 channels fed to a conv expecting 256).
        self.ups = nn.ModuleList([
            Block(256, 256, time_dim, up=True),
            Block(256, 128, time_dim, up=True),
            Block(128, 64, time_dim, up=True),
        ])

        # Output: 1x1 conv back to the image channel count
        self.output = nn.Conv2d(64, out_channels, 1)

    def forward(self, x, timestep):
        # Embed the timestep once; shared by every block
        t = self.time_mlp(timestep)

        # Initial conv
        x = self.conv0(x)

        # Downsampling path, remembering each stage's output for the skips
        residuals = []
        for down in self.downs:
            x = down(x, t)
            residuals.append(x)

        # Upsampling path with skip connections (last-in, first-out)
        for up in self.ups:
            residual = residuals.pop()
            x = torch.cat((x, residual), dim=1)
            x = up(x, t)

        return self.output(x)

4. ํ•™์Šต ๊ณผ์ •

ํ•™์Šต ์•Œ๊ณ ๋ฆฌ์ฆ˜ (DDPM)

1. x_0 ~ q(x_0): ๋ฐ์ดํ„ฐ์—์„œ ์ƒ˜ํ”Œ
2. t ~ Uniform(1, T): ๋žœ๋ค ํƒ€์ž„์Šคํ…
3. epsilon ~ N(0, I): ๋žœ๋ค ๋…ธ์ด์ฆˆ
4. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
5. Loss = ||epsilon - epsilon_theta(x_t, t)||^2
6. ์—ญ์ „ํŒŒ ๋ฐ ์—…๋ฐ์ดํŠธ

ํ•™์Šต ์ฝ”๋“œ

def train_diffusion(model, schedule, dataloader, epochs=100, lr=1e-4):
    """Diffusion ๋ชจ๋ธ ํ•™์Šต (โญโญโญ)"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        total_loss = 0

        for batch_idx, (images, _) in enumerate(dataloader):
            images = images.to(device)
            batch_size = images.size(0)

            # ๋žœ๋ค ํƒ€์ž„์Šคํ…
            t = torch.randint(0, schedule.timesteps, (batch_size,), device=device).long()

            # ๋…ธ์ด์ฆˆ ์ถ”๊ฐ€
            noise = torch.randn_like(images)
            x_t = schedule.q_sample(images, t, noise)

            # ๋…ธ์ด์ฆˆ ์˜ˆ์ธก
            noise_pred = model(x_t, t)

            # ์†์‹ค ๊ณ„์‚ฐ
            loss = criterion(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

5. ์ƒ˜ํ”Œ๋ง (Reverse Process)

DDPM ์ƒ˜ํ”Œ๋ง

@torch.no_grad()
def sample_ddpm(model, schedule, shape, device):
    """DDPM ancestral sampling (โญโญโญ).

    Starts from x_T ~ N(0, I) and applies the learned reverse transition
    T times; every step but the last adds posterior-variance noise.
    """
    model.eval()

    # Start from pure Gaussian noise
    x = torch.randn(shape, device=device)

    for step in reversed(range(schedule.timesteps)):
        t = torch.full((shape[0],), step, device=device, dtype=torch.long)

        # Per-sample schedule constants, broadcast to x's shape
        beta = get_index_from_list(schedule.betas, t, x.shape)
        sqrt_one_minus_abar = get_index_from_list(
            schedule.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        recip_sqrt_alpha = get_index_from_list(
            schedule.sqrt_recip_alphas, t, x.shape
        )

        # Predicted noise and the reverse-process mean
        eps = model(x, t)
        mean = recip_sqrt_alpha * (x - beta * eps / sqrt_one_minus_abar)

        if step == 0:
            # Final step is deterministic: return the mean
            x = mean
        else:
            variance = get_index_from_list(
                schedule.posterior_variance, t, x.shape
            )
            x = mean + variance.sqrt() * torch.randn_like(x)

    return x

DDIM ์ƒ˜ํ”Œ๋ง (๋” ๋น ๋ฆ„)

@torch.no_grad()
def sample_ddim(model, schedule, shape, device, num_inference_steps=50, eta=0.0):
    """DDIM sampling (โญโญโญโญ).

    Faster sampling with fewer steps than DDPM.
    eta=0: deterministic; eta=1: matches DDPM's stochasticity.

    Args:
        model: noise predictor called as model(x, t_tensor).
        schedule: provides `timesteps` and the `alphas_cumprod` tensor.
        shape: output tensor shape (batch, C, H, W).
        num_inference_steps: number of reverse steps actually taken.
        eta: stochasticity knob in [0, 1].
    """
    model.eval()

    # Evenly strided subset of timesteps, traversed from high to low
    step_size = schedule.timesteps // num_inference_steps
    timesteps = list(reversed(range(0, schedule.timesteps, step_size)))

    x = torch.randn(shape, device=device)

    for i, t in enumerate(timesteps):
        t_tensor = torch.full((shape[0],), t, device=device, dtype=torch.long)

        alpha_cumprod_t = schedule.alphas_cumprod[t]
        if i < len(timesteps) - 1:
            alpha_cumprod_prev = schedule.alphas_cumprod[timesteps[i + 1]]
        else:
            # Final step: alpha_bar "previous" is 1 (i.e. clean data).
            # BUG FIX: the original used the Python float 1.0 here, but
            # torch.sqrt() only accepts tensors, so the last iteration
            # always raised TypeError. Use a tensor of ones instead.
            alpha_cumprod_prev = torch.ones_like(alpha_cumprod_t)

        # Predict the noise at this timestep
        noise_pred = model(x, t_tensor)

        # Predicted clean sample x_0 from the DDIM inversion formula
        pred_x0 = (x - torch.sqrt(1 - alpha_cumprod_t) * noise_pred) / torch.sqrt(alpha_cumprod_t)

        # Stochasticity scale (zero when eta == 0)
        sigma = eta * torch.sqrt((1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t)) * \
                      torch.sqrt(1 - alpha_cumprod_t / alpha_cumprod_prev)

        # "Direction pointing to x_t" term
        pred_dir = torch.sqrt(1 - alpha_cumprod_prev - sigma**2) * noise_pred

        # Extra noise only when eta > 0
        noise = torch.randn_like(x) if eta > 0 else 0

        x = torch.sqrt(alpha_cumprod_prev) * pred_x0 + pred_dir + sigma * noise

    return x

6. Score-based Models

Score Function

Score = gradient of log probability
s(x) = โˆ‡_x log p(x)

๋…ธ์ด์ฆˆ๊ฐ€ ์ถ”๊ฐ€๋œ ๋ฐ์ดํ„ฐ์˜ score:
s_theta(x_t, t) โ‰ˆ โˆ‡_{x_t} log p(x_t)

DDPM๊ณผ์˜ ๊ด€๊ณ„

# DDPM์—์„œ ๋…ธ์ด์ฆˆ ์˜ˆ์ธก๊ณผ score์˜ ๊ด€๊ณ„:
# epsilon_theta(x_t, t) = -sqrt(1 - alpha_bar_t) * s_theta(x_t, t)

# Score ์˜ˆ์ธก โ†’ ๋…ธ์ด์ฆˆ ์˜ˆ์ธก์œผ๋กœ ๋ณ€ํ™˜ ๊ฐ€๋Šฅ

7. Stable Diffusion ์›๋ฆฌ

Latent Diffusion

์ด๋ฏธ์ง€ ๊ณต๊ฐ„์ด ์•„๋‹Œ ์ž ์žฌ ๊ณต๊ฐ„์—์„œ diffusion

1. Encoder: ์ด๋ฏธ์ง€ โ†’ ์ž ์žฌ ํ‘œํ˜„ z
2. Diffusion: z์—์„œ ๋…ธ์ด์ฆˆ ์ถ”๊ฐ€/์ œ๊ฑฐ
3. Decoder: ์ž ์žฌ ํ‘œํ˜„ โ†’ ์ด๋ฏธ์ง€

์žฅ์ :
- ๊ณ„์‚ฐ ํšจ์œจ์„ฑ (์ž‘์€ ํ•ด์ƒ๋„์—์„œ diffusion)
- ๊ณ ํ•ด์ƒ๋„ ์ด๋ฏธ์ง€ ์ƒ์„ฑ ๊ฐ€๋Šฅ

์•„ํ‚คํ…์ฒ˜

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚  Text Prompt โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
       โ”‚ CLIP Text Encoder
       โ–ผ
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚              Cross-Attention              โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                          โ”‚
โ”‚  z_T โ”€โ”€โ–ถ U-Net โ”€โ”€โ–ถ z_{T-1} โ”€โ”€โ–ถ ... โ”€โ”€โ–ถ z_0  โ”‚
โ”‚         (time embedding)                 โ”‚
โ”‚                                          โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
       โ”‚ VAE Decoder
       โ–ผ
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚    Image     โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ (Cross-Attention)

class CrossAttention(nn.Module):
    """Text-image cross attention (โญโญโญโญ).

    Queries come from the image features, keys/values from the text
    context, so each spatial location attends over the prompt tokens.
    """
    def __init__(self, query_dim, context_dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d_k) attention scaling
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context):
        # x: image features (batch, hw, query_dim)
        # context: text embeddings (batch, seq_len, context_dim)
        b, n, _ = x.shape
        h = self.heads

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        # BUG FIX: the original used einops.rearrange, which is never
        # imported (and is a third-party dependency). Split heads with
        # plain torch ops: (b, seq, h*d) -> (b, h, seq, d).
        def split_heads(tensor):
            bs, seq, _ = tensor.shape
            return tensor.view(bs, seq, h, -1).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scaled dot-product attention over the context tokens
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = attn @ v                               # (b, h, n, d)
        out = out.transpose(1, 2).reshape(b, n, -1)  # merge heads back

        return self.to_out(out)

8. Classifier-free Guidance

์•„์ด๋””์–ด

์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ๊ณผ ๋น„์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ์„ ํ˜ผํ•ฉ

epsilon_guided = epsilon_uncond + w * (epsilon_cond - epsilon_uncond)

w > 1: ์กฐ๊ฑด์„ ๋” ๊ฐ•ํ•˜๊ฒŒ ๋ฐ˜์˜ (๋” ์„ ๋ช…ํ•˜์ง€๋งŒ ๋‹ค์–‘์„ฑ ๊ฐ์†Œ)
w = 1: ์ผ๋ฐ˜ ์กฐ๊ฑด๋ถ€ ์ƒ์„ฑ
w < 1: ์กฐ๊ฑด ์•ฝํ™”

๊ตฌํ˜„

def classifier_free_guidance_sample(model, schedule, shape, condition, w=7.5, device='cuda'):
    """Classifier-free guidance sampling (โญโญโญโญ).

    At each reverse step the model is run twice — once with the condition,
    once without — and the two noise estimates are blended with guidance
    weight `w` (w > 1 strengthens the condition at the cost of diversity).

    NOTE: relies on a `sampling_step` helper (the DDPM/DDIM update rule)
    that is not defined in this snippet.
    """
    model.eval()

    x = torch.randn(shape, device=device)

    for step in reversed(range(schedule.timesteps)):
        t = torch.full((shape[0],), step, device=device, dtype=torch.long)

        # Conditional and unconditional noise predictions
        eps_cond = model(x, t, condition)
        eps_uncond = model(x, t, None)  # None (or an empty embedding) = no condition

        # Guidance blend: uncond + w * (cond - uncond)
        guided = eps_uncond + w * (eps_cond - eps_uncond)

        # One reverse-process update (DDPM or DDIM)
        x = sampling_step(x, guided, t, schedule)

    return x

ํ•™์Šต ์‹œ ์กฐ๊ฑด ๋“œ๋กญ์•„์›ƒ

def train_with_cfg(model, dataloader, drop_prob=0.1):
    """Training sketch for classifier-free guidance (condition dropout) (โญโญโญ).

    With probability `drop_prob` per example the conditioning signal is
    dropped, so the same model learns both the conditional and the
    unconditional noise distribution needed at guidance time.

    NOTE(review): this is pseudo-code — item-assigning None into a real
    Tensor (`conditions[mask] = None`) raises a TypeError; in practice a
    learned "null" embedding (or zeroed token) is substituted instead.
    """
    for images, conditions in dataloader:
        # Drop the condition for a random subset of the batch
        mask = torch.rand(images.size(0)) < drop_prob
        conditions[mask] = None  # placeholder: use a null/empty embedding in practice

        # ... the usual DDPM training step on (images, conditions) follows

9. ๊ฐ„๋‹จํ•œ DDPM ์ „์ฒด ์˜ˆ์ œ

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Configuration
# BUG FIX: SimpleUNet halves the resolution three times, so H/W must be
# divisible by 8. Raw MNIST is 28x28 (28 -> 14 -> 7 -> 3 breaks the skip
# connections); pad to 32x32 in the transform below.
image_size = 32
channels = 1
timesteps = 1000
batch_size = 64
epochs = 50
lr = 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data: pad 28x28 -> 32x32, then rescale pixels from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)  # [0, 1] -> [-1, 1]
])

train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Noise schedule
schedule = DiffusionSchedule(timesteps=timesteps, beta_schedule='linear')

# Model (simple version)
model = SimpleUNet(in_channels=channels, out_channels=channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training: regress the model onto the noise added at a random timestep
for epoch in range(epochs):
    total_loss = 0

    for images, _ in train_loader:
        images = images.to(device)
        n = images.size(0)  # local count; don't shadow the global batch_size

        # One random timestep per sample
        t = torch.randint(0, timesteps, (n,), device=device).long()

        # Forward process: add noise
        noise = torch.randn_like(images)
        x_t = schedule.q_sample(images, t, noise)

        # Predict the injected noise
        noise_pred = model(x_t, t)

        # Simple DDPM loss
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

# Sampling: run the learned reverse process from pure noise
model.eval()
with torch.no_grad():
    samples = sample_ddpm(model, schedule, (16, channels, image_size, image_size), device)
    samples = (samples + 1) / 2  # [-1, 1] -> [0, 1]

# Visualization: 4x4 grid of generated digits
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i, 0].cpu(), cmap='gray')
    ax.axis('off')
plt.savefig('diffusion_samples.png')
print("์ƒ˜ํ”Œ ์ €์žฅ: diffusion_samples.png")

์ •๋ฆฌ

ํ•ต์‹ฌ ๊ฐœ๋…

  1. Forward Process: ์ ์ง„์  ๋…ธ์ด์ฆˆ ์ถ”๊ฐ€ q(x_t|x_0)
  2. Reverse Process: ์ ์ง„์  ๋…ธ์ด์ฆˆ ์ œ๊ฑฐ p(x_{t-1}|x_t)
  3. DDPM: ๋…ธ์ด์ฆˆ ์˜ˆ์ธก์œผ๋กœ ์—ญ๊ณผ์ • ํ•™์Šต
  4. DDIM: ๊ฒฐ์ •๋ก ์  ์ƒ˜ํ”Œ๋ง์œผ๋กœ ๋น ๋ฅธ ์ƒ์„ฑ
  5. Latent Diffusion: ์ž ์žฌ ๊ณต๊ฐ„์—์„œ ํšจ์œจ์  ์ƒ์„ฑ
  6. CFG: ์กฐ๊ฑด ๊ฐ•๋„ ์กฐ์ ˆ

ํ•ต์‹ฌ ์ˆ˜์‹

Forward: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon

Loss: L = E[||epsilon - epsilon_theta(x_t, t)||^2]

Reverse: x_{t-1} = (1/sqrt(alpha_t)) * (x_t - (beta_t/sqrt(1-alpha_bar_t)) * epsilon_theta) + sigma_t * z

Diffusion vs GAN vs VAE

| ํŠน์„ฑ | Diffusion | GAN | VAE |
|------|-----------|-----|-----|
| ํ•™์Šต ์•ˆ์ •์„ฑ | ๋งค์šฐ ๋†’์Œ | ๋‚ฎ์Œ | ๋†’์Œ |
| ์ด๋ฏธ์ง€ ํ’ˆ์งˆ | ์ตœ๊ณ  | ์ข‹์Œ | ํ๋ฆผ |
| ์ƒ˜ํ”Œ๋ง ์†๋„ | ๋А๋ฆผ | ๋น ๋ฆ„ | ๋น ๋ฆ„ |
| Mode Coverage | ์ข‹์Œ | Mode Collapse | ์ข‹์Œ |
| ๋ฐ€๋„ ์ถ”์ • | ๊ฐ€๋Šฅ | ๋ถˆ๊ฐ€ | ๊ฐ€๋Šฅ |

๋‹ค์Œ ๋‹จ๊ณ„

17_Attention_Deep_Dive.md์—์„œ Attention ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ์‹ฌ์ธต์ ์œผ๋กœ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.

to navigate between lessons