# 33. Diffusion Models (DDPM)
## Overview

Denoising Diffusion Probabilistic Models (DDPM) are a powerful class of generative models that produce data by reversing a gradual noising process. Reference: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020).
## Mathematical Background

### 1. Forward Diffusion Process
Goal: gradually add Gaussian noise to the data x_0:

    q(x_t | x_{t-1}) = N(x_t; √(1-β_t) x_{t-1}, β_t I)

where:

- x_0: the original data
- x_t: the noisy data at timestep t
- β_t: the noise schedule (β_1, ..., β_T)
- T: the total number of timesteps (typically 1000)

Closed form (using α_t = 1-β_t and ᾱ_t = ∏_{i=1}^t α_i):

    q(x_t | x_0) = N(x_t; √ᾱ_t x_0, (1-ᾱ_t) I)
    x_t = √ᾱ_t x_0 + √(1-ᾱ_t) ε,   ε ~ N(0, I)

As t → T: x_T ≈ N(0, I) (pure noise).
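The closed form means x_t can be sampled in a single step, with no need to simulate the chain. A minimal PyTorch sketch, assuming the linear schedule defined later in this chapter; the names `q_sample` and `alpha_bar` are illustrative, not taken from the repo files:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)          # linear schedule (Ho et al., 2020)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # ᾱ_t = ∏_{i≤t} α_i

def q_sample(x0, t, noise):
    """Draw x_t ~ q(x_t | x_0) in one shot via the closed form."""
    ab = alpha_bar[t].view(-1, 1, 1, 1)        # per-example ᾱ_t, broadcastable
    return ab.sqrt() * x0 + (1 - ab).sqrt() * noise

x0 = torch.randn(8, 3, 32, 32)                 # stand-in batch scaled to [-1, 1]
t = torch.randint(0, T, (8,))                  # 0-indexed timesteps
xt = q_sample(x0, t, torch.randn_like(x0))
```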
### 2. Reverse Diffusion Process

Goal: learn the denoising distribution p(x_{t-1} | x_t).

The true reversal q(x_{t-1} | x_t) is intractable, but conditioning on x_0 gives a tractable posterior:

    q(x_{t-1} | x_t, x_0) = N(x_{t-1}; μ̃_t(x_t, x_0), β̃_t I)

where:

    μ̃_t(x_t, x_0) = (√ᾱ_{t-1} β_t)/(1-ᾱ_t) x_0 + (√α_t (1-ᾱ_{t-1}))/(1-ᾱ_t) x_t
    β̃_t = (1-ᾱ_{t-1})/(1-ᾱ_t) · β_t

Learned reverse process:

    p_θ(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), Σ_θ(x_t, t))

Simplification: predict the noise ε instead of the mean,

    ε_θ(x_t, t) ≈ ε

Substituting x_0 = (x_t - √(1-ᾱ_t) ε)/√ᾱ_t into μ̃_t gives
μ_θ(x_t, t) = 1/√α_t (x_t - β_t/√(1-ᾱ_t) ε_θ(x_t, t)), which is exactly the mean used during sampling below.
### 3. Training Objective

Simplified objective derived from the variational lower bound (ELBO):

    L = E_{t, x_0, ε}[ ||ε - ε_θ(x_t, t)||² ]

where:

- t ~ Uniform(1, T)
- x_0 ~ q(x_0)
- ε ~ N(0, I)
- x_t = √ᾱ_t x_0 + √(1-ᾱ_t) ε

This is just a simple MSE loss on the predicted noise!
    ┌──────────────────────────────────────────┐
    │ Training:                                │
    │  1. Sample x_0, t, ε                     │
    │  2. Build x_t = √ᾱ_t x_0 + √(1-ᾱ_t) ε    │
    │  3. Predict ε̂ = ε_θ(x_t, t)              │
    │  4. Loss = ||ε - ε̂||²                    │
    └──────────────────────────────────────────┘
### 4. Sampling (Generation)

Start from x_T ~ N(0, I). For t = T, T-1, ..., 1:

    z ~ N(0, I) if t > 1, otherwise z = 0
    ε̂ = ε_θ(x_t, t)
    x_{t-1} = 1/√α_t (x_t - (1-α_t)/√(1-ᾱ_t) ε̂) + σ_t z

where σ_t = √β̃_t or √β_t (the variance schedule).

Final result: x_0 is the generated sample.
## DDPM Architecture

### UNet with Time Embedding

Time embedding (sinusoidal positional encoding):

    t (scalar)
      ↓
    PE(t, dim) = [sin(t/10000^(0/d)), cos(t/10000^(0/d)),
                  sin(t/10000^(2/d)), cos(t/10000^(2/d)), ...]
      ↓
    Linear(dim→4·dim) + SiLU + Linear(4·dim→4·dim)
      ↓
    time_emb (broadcast over the spatial dimensions)
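A sketch of this module in PyTorch; the class name and sizes are illustrative assumptions, not the repo's actual code:

```python
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of the scalar timestep followed by a small MLP."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.SiLU(), nn.Linear(4 * dim, 4 * dim)
        )

    def forward(self, t):  # t: (batch,) tensor of integer timesteps
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
        args = t[:, None].float() * freqs[None, :]
        pe = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return self.mlp(pe)  # (batch, 4*dim), added inside each ResBlock

emb = TimeEmbedding(64)(torch.randint(0, 1000, (8,)))  # shape (8, 256)
```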
UNet structure (e.g., for a 32×32×3 image):

    Input x_t (32×32×3) + time_emb
      ↓
    ┌──────────────────────────────────────────┐
    │ Encoder (downsampling)                   │
    ├──────────────────────────────────────────┤
    │ Conv(3→64) + TimeEmb + ResBlock          │ → skip1
    │   ↓ Downsample                           │
    │ Conv(64→128) + TimeEmb + ResBlock        │ → skip2
    │   ↓ Downsample                           │
    │ Conv(128→256) + TimeEmb + ResBlock       │ → skip3
    │   ↓ Downsample                           │
    └──────────────────────────────────────────┘
      ↓
    ┌──────────────────────────────────────────┐
    │ Bottleneck                               │
    │ Conv(256→512) + Attention + ResBlock     │
    └──────────────────────────────────────────┘
      ↓
    ┌──────────────────────────────────────────┐
    │ Decoder (upsampling)                     │
    ├──────────────────────────────────────────┤
    │   ↑ Upsample + Concat(skip3)             │
    │ Conv(512+256→256) + TimeEmb + ResBlock   │
    │   ↑ Upsample + Concat(skip2)             │
    │ Conv(256+128→128) + TimeEmb + ResBlock   │
    │   ↑ Upsample + Concat(skip1)             │
    │ Conv(128+64→64) + TimeEmb + ResBlock     │
    └──────────────────────────────────────────┘
      ↓
    GroupNorm + Conv(64→3)
      ↓
    Output ε_θ(x_t, t) (32×32×3)
### ResBlock with Time Embedding

    x, time_emb → ResBlock → out
    ┌──────────────────────────────────────────┐
    │ GroupNorm → SiLU → Conv                  │
    │   ↓                                      │
    │ + time_emb (broadcast)                   │
    │   ↓                                      │
    │ GroupNorm → SiLU → Conv                  │
    │   ↓                                      │
    │ + skip connection (with projection)      │
    └──────────────────────────────────────────┘
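A compact PyTorch sketch of this block. The GroupNorm group count of 8 and the 1×1 projection on the skip path are common choices assumed here, not confirmed by the repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Residual block that injects the time embedding additively."""
    def __init__(self, in_ch, out_ch, temb_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.temb_proj = nn.Linear(temb_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # 1×1 projection on the skip path when channel counts differ
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, temb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.temb_proj(temb)[:, :, None, None]  # broadcast over H and W
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

out = ResBlock(64, 128, 256)(torch.randn(8, 64, 16, 16), torch.randn(8, 256))
```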
## Noise Schedule

### Linear Schedule

    # Linear schedule (Ho et al., 2020)
    import torch

    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)     # β_1 = 1e-4, ..., β_T = 0.02

    # Precompute for efficiency
    alphas = 1.0 - betas                      # α_t = 1 - β_t
    alpha_bar = torch.cumprod(alphas, dim=0)  # ᾱ_t = ∏_{i=1}^t α_i
    sqrt_alpha_bar = alpha_bar.sqrt()         # used in the forward process
    sqrt_one_minus_alpha_bar = (1 - alpha_bar).sqrt()
### Cosine Schedule (Improved)

    # Cosine schedule (Nichol & Dhariwal, 2021)
    import torch

    T, s = 1000, 0.008
    t = torch.arange(T + 1)
    f = torch.cos((t / T + s) / (1 + s) * torch.pi / 2) ** 2
    alpha_bar = f / f[0]                        # ᾱ_t = f(t) / f(0)
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]  # β_t = 1 - ᾱ_t / ᾱ_{t-1}
    betas = betas.clamp(max=0.999)              # clip as in the paper, avoiding a singular step near t = T

    # A smoother noise schedule; better suited to high resolutions
## File Structure

    13_Diffusion/
    ├── README.md
    ├── pytorch_lowlevel/
    │   ├── ddpm_mnist.py         # DDPM for MNIST (28×28)
    │   └── ddpm_cifar.py         # DDPM for CIFAR-10 (32×32)
    ├── paper/
    │   ├── ddpm_paper.py         # Full DDPM implementation
    │   ├── ddim_sampling.py      # DDIM fast sampling
    │   └── cosine_schedule.py    # Improved noise schedule
    └── exercises/
        ├── 01_noise_schedule.md  # Noise schedule visualization
        └── 02_sampling_steps.md  # DDPM vs DDIM comparison
## Key Concepts

### 1. DDPM vs DDIM Sampling

DDPM (Ho et al., 2020):

- Stochastic sampling (noise z added at every step)
- Requires all T steps (e.g., 1000)
- High quality, but slow

DDIM (Song et al., 2020):

- Deterministic sampling (z = 0)
- Timestep skipping: uses a subsequence [τ_1, τ_2, ..., τ_S]
- 10-50× faster (e.g., 50 steps)
- Slight drop in quality

DDIM update:

    x_{t-1} = √ᾱ_{t-1} x̂_0 + √(1-ᾱ_{t-1}) ε_θ(x_t, t)
    where x̂_0 = (x_t - √(1-ᾱ_t) ε_θ(x_t, t)) / √ᾱ_t
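A sketch of one deterministic DDIM step under these formulas; `alpha_bar` is the precomputed ᾱ table and `model` predicts ε (names assumed for illustration, not taken from ddim_sampling.py):

```python
import torch

def ddim_step(model, x, t, t_prev, alpha_bar):
    """Deterministic (η = 0) DDIM update from timestep t to t_prev."""
    eps = model(x, torch.full((x.shape[0],), t, dtype=torch.long))
    ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
    x0_hat = (x - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()  # predicted clean image
    return ab_prev.sqrt() * x0_hat + (1 - ab_prev).sqrt() * eps

# Walking a 50-step subsequence of a 1000-step schedule:
# taus = torch.linspace(999, 0, 50, dtype=torch.long)
# for t, t_prev in zip(taus[:-1], taus[1:]):
#     x = ddim_step(model, x, t, t_prev, alpha_bar)
```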
### 2. Classifier Guidance

Goal: generate samples conditioned on a class y.

Conditional score:

    ∇_x log p(x_t|y) ≈ ∇_x log p(x_t) + s·∇_x log p(y|x_t)
                       └─────────────┘   └──────────────┘
                        unconditional    classifier gradient

Guided noise prediction:

    ε̂ = ε_θ(x_t, t) - s·√(1-ᾱ_t)·∇_x log p_φ(y|x_t)

s: guidance scale (s > 1 → stronger conditioning)
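A minimal sketch of this guided prediction, assuming a hypothetical noise-aware `classifier(x, t)` that returns class logits for noisy inputs; the gradient is obtained with autograd:

```python
import torch

def guided_eps(model, classifier, x, t, y, s, alpha_bar_t):
    """ε̂ = ε_θ(x_t, t) - s·√(1-ᾱ_t)·∇_x log p(y|x_t)."""
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        log_probs = torch.log_softmax(classifier(x_in, t), dim=-1)
        selected = log_probs[torch.arange(len(y)), y].sum()  # Σ_batch log p(y|x_t)
        grad = torch.autograd.grad(selected, x_in)[0]        # ∇_x log p(y|x_t)
    return model(x, t) - s * (1 - alpha_bar_t) ** 0.5 * grad
```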
### 3. Classifier-Free Guidance

No separate classifier needed! Train a single model to handle both the conditional and the unconditional case:

    ε_θ(x_t, t, c)  (with probability p)
    ε_θ(x_t, t, ∅)  (with probability 1-p)   (∅ = null class)

Guided prediction:

    ε̂ = ε_θ(x_t, t, ∅) + w·(ε_θ(x_t, t, c) - ε_θ(x_t, t, ∅))

w: guidance weight (w = 0 → unconditional, w > 1 → stronger conditioning)

Used in: Stable Diffusion, DALL-E 2, Imagen
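The guided prediction is just two forward passes combined. A sketch, where `null_token` stands for whatever index or embedding the model reserves for ∅ (an assumption for illustration):

```python
import torch

def cfg_eps(model, x, t, c, null_token, w):
    """Classifier-free guided noise prediction."""
    eps_cond = model(x, t, c)             # conditional pass
    eps_uncond = model(x, t, null_token)  # unconditional (null-class) pass
    return eps_uncond + w * (eps_cond - eps_uncond)
```

In practice the two passes are often batched together (concatenate x with itself and pair it with [c, ∅]) so that guidance costs a single wider forward pass.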
### 4. Training Tips

1. EMA (Exponential Moving Average); a sketch follows this list:
   - Maintain θ_ema = 0.9999·θ_ema + 0.0001·θ
   - Use θ_ema for sampling
2. Progressive training:
   - Start at a low resolution
   - Increase gradually (8×8 → 16×16 → 32×32)
3. Data augmentation:
   - Random horizontal flips
   - Normalize to [-1, 1]
4. Learning rate:
   - MNIST/CIFAR: 2e-4
   - High resolution: 1e-4
5. Batch size:
   - Small images: 128-256
   - Large images: 32-64
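A minimal sketch of the EMA update from tip 1; the decay of 0.9999 matches the update rule above:

```python
import copy
import torch

# ema_model = copy.deepcopy(model)  # created once before training starts

@torch.no_grad()
def ema_update(ema_model, model, decay=0.9999):
    """θ_ema ← decay·θ_ema + (1-decay)·θ, called after each optimizer step."""
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)
```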
## Implementation Levels

### Level 2: PyTorch Low-Level (pytorch_lowlevel/)

- Implement the forward/reverse diffusion processes
- Implement the (linear) noise schedule
- Build a UNet with time embedding
- Train on MNIST (28×28) and CIFAR-10 (32×32)

### Level 3: Paper Implementation (paper/)

- Full DDPM with the cosine schedule
- DDIM sampling (fast inference)
- Classifier-free guidance
- FID/IS evaluation metrics
## Training Loop

    # Sketch; model, dataloader, optimizer, num_epochs, and the precomputed
    # tables from the Noise Schedule section are assumed to be defined.
    import torch
    import torch.nn.functional as F

    for epoch in range(num_epochs):
        for x0, _ in dataloader:
            # Sample random timesteps (0-indexed into the tables)
            t = torch.randint(0, T, (x0.shape[0],))
            # Sample noise
            noise = torch.randn_like(x0)
            # Forward diffusion: build the noisy image x_t
            xt = (sqrt_alpha_bar[t].view(-1, 1, 1, 1) * x0
                  + sqrt_one_minus_alpha_bar[t].view(-1, 1, 1, 1) * noise)
            # Predict the noise
            noise_pred = model(xt, t)
            # MSE loss
            loss = F.mse_loss(noise_pred, noise)
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
## Sampling Loop

    # DDPM sampling (same assumed names as the training loop)
    import torch

    x = torch.randn(batch_size, 3, 32, 32)  # start from pure noise
    for t in reversed(range(T)):             # t = T-1, ..., 0 (0-indexed)
        # Predict the noise
        t_batch = torch.full((batch_size,), t, dtype=torch.long)
        noise_pred = model(x, t_batch)
        # Compute the posterior mean
        alpha_t = alphas[t]
        alpha_bar_t = alpha_bar[t]
        mean = (x - (1 - alpha_t) / (1 - alpha_bar_t).sqrt() * noise_pred) / alpha_t.sqrt()
        # Add noise (except at the final step)
        if t > 0:
            sigma_t = betas[t].sqrt()
            x = mean + sigma_t * torch.randn_like(x)
        else:
            x = mean
    # x now holds the generated images
## Learning Checklist

- [ ] Understand the closed form of the forward diffusion process
- [ ] Derive the reverse process from the ELBO
- [ ] Implement the noise schedules (linear, cosine)
- [ ] Build a UNet with time embedding
- [ ] Understand DDPM vs DDIM sampling
- [ ] Implement classifier-free guidance
- [ ] Compute FID scores for evaluation
## References
- Ho et al. (2020). "Denoising Diffusion Probabilistic Models"
- Song et al. (2020). "Denoising Diffusion Implicit Models"
- Nichol & Dhariwal (2021). "Improved Denoising Diffusion Probabilistic Models"
- Ho & Salimans (2022). "Classifier-Free Diffusion Guidance"
- 32_Diffusion_Models.md