1"""
2DDPM (Denoising Diffusion Probabilistic Model) Implementation
3
4This script implements a simple DDPM for image generation following
5"Denoising Diffusion Probabilistic Models" (Ho et al., 2020).
6
7Key concepts:
8- Forward diffusion: gradually add Gaussian noise to data
9- Reverse diffusion: learn to denoise and generate samples
10- Linear beta schedule for noise variance
11- Simple UNet architecture with time embedding
12
13References:
14- Ho et al. (2020): https://arxiv.org/abs/2006.11239
15"""
16
17import torch
18import torch.nn as nn
19import torch.nn.functional as F
20from torch.utils.data import DataLoader
21from torchvision import datasets, transforms
22import matplotlib.pyplot as plt
23import numpy as np
24from tqdm import tqdm
25
26
27# ============================================================================
28# Noise Schedule
29# ============================================================================
30
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """
    Build a linearly spaced noise-variance schedule.

    Args:
        timesteps: number of diffusion steps (T)
        beta_start: variance at the first step
        beta_end: variance at the last step

    Returns:
        betas: [T] tensor increasing linearly from beta_start to beta_end
    """
    return torch.linspace(beta_start, beta_end, steps=timesteps)
44
45
def get_diffusion_params(betas):
    """
    Precompute the quantities used by the forward and reverse processes.

    Args:
        betas: [T] noise schedule

    Returns:
        dict mapping parameter names to [T] tensors: betas, alphas,
        alphas_cumprod, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod,
        posterior_variance
    """
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    # alpha_bar shifted right by one step, with alpha_bar at t=0 defined as 1
    alpha_bar_prev = F.pad(alpha_bar[:-1], (1, 0), value=1.0)

    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alpha_bar,
        'sqrt_alphas_cumprod': alpha_bar.sqrt(),
        'sqrt_one_minus_alphas_cumprod': (1.0 - alpha_bar).sqrt(),
        # Variance of the true posterior q(x_{t-1} | x_t, x_0)
        'posterior_variance': betas * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar),
    }
74
75
76# ============================================================================
77# Time Embedding
78# ============================================================================
79
class SinusoidalPositionEmbedding(nn.Module):
    """
    Transformer-style sinusoidal embedding of the diffusion timestep.

    Each timestep t maps to a `dim`-dimensional vector whose first half
    holds sines and second half cosines at geometrically spaced frequencies.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """
        Args:
            t: [batch_size] timesteps
        Returns:
            [batch_size, dim] embeddings
        """
        half = self.dim // 2
        # Frequencies decay geometrically from 1 down to 1/10000
        decay = np.log(10000) / (half - 1)
        freqs = torch.exp(-decay * torch.arange(half, device=t.device))
        angles = t[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
103
104
105# ============================================================================
106# Simple UNet Architecture
107# ============================================================================
108
class SimpleUNet(nn.Module):
    """
    Simplified UNet for DDPM with time conditioning.

    Architecture:
        - Encoder: three conv blocks, each followed by 2x max-pooling
        - Bottleneck conv block at the coarsest resolution
        - Decoder: three conv blocks with encoder skip connections
        - Sinusoidal time embedding projected and added inside every block

    Channel flow with base_channels = C (decoder inputs include the skip):
        enc1: in -> C      enc2: C -> 2C        enc3: 2C -> 4C
        bottleneck: 4C -> 4C
        dec3: (4C+4C) -> 2C    dec2: (2C+2C) -> C    dec1: (C+C) -> C
        out: C -> out_channels

    NOTE: GroupNorm(8, ...) requires every block's output channel count to
    be divisible by 8, i.e. base_channels must be a multiple of 8.
    """
    def __init__(self, in_channels=1, out_channels=1, time_dim=128, base_channels=64):
        super().__init__()

        # Time embedding: sinusoidal encoding refined by a small MLP
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbedding(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.GELU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # Encoder (downsampling)
        self.enc1 = self._make_block(in_channels, base_channels, time_dim)
        self.enc2 = self._make_block(base_channels, base_channels * 2, time_dim)
        self.enc3 = self._make_block(base_channels * 2, base_channels * 4, time_dim)

        # Bottleneck
        self.bottleneck = self._make_block(base_channels * 4, base_channels * 4, time_dim)

        # Decoder (upsampling); input channels include the concatenated skip
        self.dec3 = self._make_block(base_channels * 8, base_channels * 2, time_dim)
        self.dec2 = self._make_block(base_channels * 4, base_channels, time_dim)
        self.dec1 = self._make_block(base_channels * 2, base_channels, time_dim)

        # Output layer: 1x1 conv back to the image channel count
        self.out = nn.Conv2d(base_channels, out_channels, 1)

    def _make_block(self, in_ch, out_ch, time_dim):
        """Create a conv block (conv-norm-act x2) with a time projection."""
        return nn.ModuleDict({
            'conv1': nn.Conv2d(in_ch, out_ch, 3, padding=1),
            'conv2': nn.Conv2d(out_ch, out_ch, 3, padding=1),
            'time_proj': nn.Linear(time_dim, out_ch),
            'norm1': nn.GroupNorm(8, out_ch),
            'norm2': nn.GroupNorm(8, out_ch),
        })

    def _forward_block(self, x, t_emb, block):
        """Run one conv block, injecting the time embedding after norm1."""
        h = block['conv1'](x)
        h = block['norm1'](h)

        # Broadcast the per-sample time embedding over the spatial dims
        t_proj = block['time_proj'](t_emb)[:, :, None, None]
        h = h + t_proj

        h = F.gelu(h)
        h = block['conv2'](h)
        h = block['norm2'](h)
        h = F.gelu(h)

        return h

    def forward(self, x, t):
        """
        Args:
            x: [B, C, H, W] noisy images
            t: [B] timesteps
        Returns:
            [B, C, H, W] predicted noise
        """
        # Time embedding shared by all blocks
        t_emb = self.time_mlp(t)

        # Encoder
        x1 = self._forward_block(x, t_emb, self.enc1)
        x2 = F.max_pool2d(x1, 2)

        x2 = self._forward_block(x2, t_emb, self.enc2)
        x3 = F.max_pool2d(x2, 2)

        x3 = self._forward_block(x3, t_emb, self.enc3)
        x4 = F.max_pool2d(x3, 2)

        # Bottleneck
        x4 = self._forward_block(x4, t_emb, self.bottleneck)

        # Decoder with skip connections.
        # BUG FIX: upsample to the skip tensor's exact spatial size instead
        # of a fixed scale_factor=2. max_pool2d floors odd sizes (MNIST:
        # 28 -> 14 -> 7 -> 3), so scale_factor=2 would rebuild 3 -> 6 and
        # torch.cat would raise on the mismatch against the 7x7 skip tensor.
        x = F.interpolate(x4, size=x3.shape[-2:], mode='nearest')
        x = torch.cat([x, x3], dim=1)
        x = self._forward_block(x, t_emb, self.dec3)

        x = F.interpolate(x, size=x2.shape[-2:], mode='nearest')
        x = torch.cat([x, x2], dim=1)
        x = self._forward_block(x, t_emb, self.dec2)

        x = F.interpolate(x, size=x1.shape[-2:], mode='nearest')
        x = torch.cat([x, x1], dim=1)
        x = self._forward_block(x, t_emb, self.dec1)

        return self.out(x)
209
210
211# ============================================================================
212# Diffusion Process
213# ============================================================================
214
def forward_diffusion(x0, t, params, device):
    """
    Sample x_t ~ q(x_t | x_0) in closed form.

    q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)

    Args:
        x0: [B, C, H, W] clean images
        t: [B] timesteps
        params: dict from get_diffusion_params
        device: torch device (unused here; kept for interface compatibility)

    Returns:
        noisy_x: [B, C, H, W] noisy images
        noise: [B, C, H, W] the Gaussian noise that was added
    """
    eps = torch.randn_like(x0)

    # Per-sample coefficients, broadcast over channel and spatial dims
    mean_coef = params['sqrt_alphas_cumprod'][t].view(-1, 1, 1, 1)
    std_coef = params['sqrt_one_minus_alphas_cumprod'][t].view(-1, 1, 1, 1)

    return mean_coef * x0 + std_coef * eps, eps
239
240
@torch.no_grad()
def sample(model, params, image_size, batch_size, timesteps, device):
    """
    Draw samples by running the learned reverse diffusion chain.

    Starts from pure Gaussian noise x_T and iteratively applies the DDPM
    update:
        x_{t-1} = (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t) * eps_theta)
                  / sqrt(alpha_t) + sqrt(beta_t) * z
    with z ~ N(0, I) for t > 0 and z = 0 at the final step.

    Args:
        model: trained noise-prediction UNet
        params: dict from get_diffusion_params
        image_size: (C, H, W)
        batch_size: number of samples
        timesteps: number of diffusion steps (T)
        device: torch device

    Returns:
        [batch_size, C, H, W] generated images
    """
    model.eval()

    # x_T: pure Gaussian noise
    x = torch.randn(batch_size, *image_size, device=device)

    for step in tqdm(reversed(range(timesteps)), desc='Sampling', total=timesteps):
        t = torch.full((batch_size,), step, device=device, dtype=torch.long)

        # Predict the noise component of x_t
        eps_theta = model(x, t)

        # Per-sample schedule values, broadcast over image dims
        alpha_t = params['alphas'][t].view(-1, 1, 1, 1)
        alpha_bar_t = params['alphas_cumprod'][t].view(-1, 1, 1, 1)
        beta_t = params['betas'][t].view(-1, 1, 1, 1)

        # Mean of the reverse transition p(x_{t-1} | x_t)
        mean = (x - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * eps_theta) / torch.sqrt(alpha_t)

        # No noise is injected at the last (t = 0) step
        z = torch.randn_like(x) if step > 0 else torch.zeros_like(x)

        x = mean + torch.sqrt(beta_t) * z

    return x
287
288
289# ============================================================================
290# Training
291# ============================================================================
292
def train_ddpm(epochs=10, batch_size=128, timesteps=1000, device='cuda'):
    """
    Train a DDPM noise-prediction network on MNIST.

    Each step samples random timesteps, noises the batch via the closed-form
    forward process, and regresses the UNet's output onto the injected noise
    (the simplified DDPM objective of Ho et al., 2020).

    Args:
        epochs: number of training epochs
        batch_size: batch size
        timesteps: number of diffusion steps (T)
        device: 'cuda' or 'cpu' (falls back to CPU if CUDA is unavailable)

    Returns:
        (model, params): the trained UNet and the diffusion parameters
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # MNIST scaled to [-1, 1], matching the Gaussian diffusion prior
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))  # Scale to [-1, 1]
    ])
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Model, noise schedule, and optimizer
    model = SimpleUNet(in_channels=1, out_channels=1).to(device)
    params = get_diffusion_params(linear_beta_schedule(timesteps).to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for images, _ in dataloader:
            images = images.to(device)

            # Uniformly sampled timestep for each image in the batch
            t = torch.randint(0, timesteps, (images.shape[0],), device=device)

            # Noise the batch and train the model to recover that noise
            noisy_images, noise = forward_diffusion(images, t, params, device)
            loss = F.mse_loss(model(noisy_images, t), noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

        # Periodically visualize samples from the current model
        if (epoch + 1) % 5 == 0:
            samples = sample(model, params, (1, 28, 28), 16, timesteps, device)
            samples = (samples + 1) / 2  # Denormalize to [0, 1]

            fig, axes = plt.subplots(4, 4, figsize=(8, 8))
            for i, ax in enumerate(axes.flat):
                ax.imshow(samples[i].cpu().squeeze(), cmap='gray')
                ax.axis('off')
            plt.suptitle(f'Generated Samples - Epoch {epoch+1}')
            plt.tight_layout()
            plt.savefig(f'ddpm_samples_epoch_{epoch+1}.png')
            plt.close()

    print("Training completed!")
    return model, params
369
370
371# ============================================================================
372# Main
373# ============================================================================
374
if __name__ == '__main__':
    # Train the diffusion model from scratch on MNIST
    model, params = train_ddpm(epochs=10, batch_size=128, timesteps=1000, device='cuda')

    # Draw a final batch of samples with the trained model
    print("\nGenerating final samples...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    images = sample(model, params, (1, 28, 28), 64, 1000, device)
    images = (images + 1) / 2  # Map from [-1, 1] back to [0, 1]

    # Display the samples on an 8x8 grid and save to disk
    fig, grid = plt.subplots(8, 8, figsize=(12, 12))
    for img, ax in zip(images, grid.flat):
        ax.imshow(img.cpu().squeeze(), cmap='gray')
        ax.axis('off')
    plt.suptitle('DDPM Generated Samples (Final)')
    plt.tight_layout()
    plt.savefig('ddpm_final_samples.png')
    plt.show()