1"""
217. Simple Diffusion Model (DDPM) Implementation
3
4A minimal implementation of Denoising Diffusion Probabilistic Models
5for MNIST digit generation.
6"""
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11from torch.utils.data import DataLoader
12from torchvision import datasets, transforms
13import torchvision.utils as vutils
14import matplotlib.pyplot as plt
15import numpy as np
16import math
17
print("=" * 60)
print("Simple Diffusion Model (DDPM) Implementation")
print("=" * 60)

# Use the GPU when one is available; all models and schedule tensors
# below are moved onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# ============================================
# 1. Noise Schedule
# ============================================
print("\n[1] Noise Schedule")
print("-" * 40)
31
32
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced beta schedule from ``beta_start`` to ``beta_end``.

    Returns a 1-D tensor of ``timesteps`` variance values, as used in the
    original DDPM formulation (Ho et al., 2020).
    """
    betas = torch.linspace(beta_start, beta_end, timesteps)
    return betas
36
37
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine beta schedule of Nichol & Dhariwal (2021).

    Places ``timesteps + 1`` cumulative-alpha values on a squared-cosine
    curve, normalizes the sequence to start at 1, and derives each beta
    from the ratio of consecutive cumulative alphas.  Betas are clipped to
    [1e-4, 0.9999] to keep the reverse process numerically stable.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    curve = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = curve / curve[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(betas, 0.0001, 0.9999)
46
47
def get_index_from_list(vals, t, x_shape):
    """Gather per-sample schedule constants for a batch of timesteps.

    Args:
        vals: 1-D tensor of schedule constants (length == timesteps).
        t: long tensor of shape (batch,), one timestep index per sample.
        x_shape: shape of the data tensor the result will broadcast over.

    Returns:
        Tensor of shape (batch, 1, ..., 1) on ``t``'s device, ready to be
        broadcast against a batch shaped like ``x_shape``.
    """
    batch_size = t.shape[0]
    # gather() requires the index tensor to live on the same device as the
    # source.  The original `t.cpu()` crashed once the schedule tensors
    # were moved to CUDA; index on vals' own device instead.
    out = vals.gather(-1, t.to(vals.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
53
54
class DiffusionSchedule:
    """Precomputes every constant needed by the forward/reverse processes.

    All derived tensors are stored on ``device`` so they can be combined
    directly with batches living on that device.
    """

    def __init__(self, timesteps=1000, beta_schedule='linear', device='cpu'):
        self.timesteps = timesteps
        self.device = device

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        else:
            betas = cosine_beta_schedule(timesteps)

        self.betas = betas.to(device)
        # Derive everything from self.betas, which already lives on
        # `device`.  The original code mixed the CPU-resident local `betas`
        # into `posterior_variance`, which raised a device-mismatch error
        # under CUDA; the extra `.to(device)` calls were also redundant.
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Calculations for diffusion q(x_t | x_0) and posterior q(x_{t-1} | x_t, x_0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)

        # Posterior variance beta_tilde_t (Ho et al., Eq. 7); zero at t=0
        # because alpha_bar_{-1} == 1.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def q_sample(self, x_0, t, noise=None):
        """Forward diffusion: sample x_t from q(x_t | x_0) in closed form.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x_0: clean batch; assumed to live on ``self.device`` — confirm
                at call sites.
            t: long tensor of shape (batch,) with per-sample timesteps.
            noise: optional pre-drawn standard-normal noise with the same
                shape as ``x_0``; drawn internally when omitted.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        sqrt_alphas_cumprod_t = get_index_from_list(
            self.sqrt_alphas_cumprod, t, x_0.shape
        )
        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
        )

        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
98
# Test schedule
# Sanity check: betas should span the linear range and alpha_bar should
# decay from ~1 toward 0 as t grows.
schedule = DiffusionSchedule(timesteps=1000, device=device)
print(f"Timesteps: {schedule.timesteps}")
print(f"Beta range: [{schedule.betas[0]:.6f}, {schedule.betas[-1]:.6f}]")
print(f"Alpha_bar range: [{schedule.alphas_cumprod[-1]:.6f}, {schedule.alphas_cumprod[0]:.6f}]")


# ============================================
# 2. U-Net Architecture
# ============================================
print("\n[2] U-Net Architecture")
print("-" * 40)
111
112
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sin/cos embedding of a batch of scalar timesteps.

    Maps a (batch,) tensor of timesteps to a (batch, dim) tensor whose
    first half holds sines and second half cosines over a geometric range
    of frequencies.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        # Frequencies spaced geometrically from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=time.device))
        args = time[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)
127
128
class Block(nn.Module):
    """Conv stage of the U-Net, conditioned on a time embedding.

    Applies two 3x3 convolutions (each followed by ReLU then BatchNorm, in
    the original ordering), injects a per-channel time embedding between
    them, and finally resamples: a stride-2 conv when downsampling or a
    stride-2 transposed conv when ``up=True``.  On the up path the input is
    expected to be the channel-concatenation of the previous stage and a
    skip connection, hence ``2 * in_ch`` input channels.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        hidden = self.bnorm1(self.relu(self.conv1(x)))
        # Broadcast the (batch, out_ch) embedding over the spatial dims.
        emb = self.relu(self.time_mlp(t))[..., None, None]
        hidden = self.bnorm2(self.relu(self.conv2(hidden + emb)))
        return self.transform(hidden)
154
155
class SimpleUNet(nn.Module):
    """Minimal time-conditioned U-Net that predicts the noise in x_t.

    Two downsampling stages (28 -> 14 -> 7 for MNIST), a convolutional
    bottleneck, and two upsampling stages with skip connections back to the
    matching down stage.  Output has the same shape as the input so the
    network can predict per-pixel noise.
    """

    def __init__(self, in_channels=1, out_channels=1, time_dim=256, base_channels=64):
        super().__init__()

        # Timestep -> embedding vector shared by every block.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU()
        )

        # Lift the image into feature space.
        self.conv0 = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder: channels double while resolution halves.
        self.downs = nn.ModuleList([
            Block(base_channels, base_channels * 2, time_dim),
            Block(base_channels * 2, base_channels * 4, time_dim),
        ])

        # Bottleneck at the coarsest resolution.
        self.bot1 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)
        self.bot2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1)

        # Decoder: each stage consumes concatenated skip features
        # (Block's up=True doubles conv1's input channels accordingly).
        self.ups = nn.ModuleList([
            Block(base_channels * 4, base_channels * 2, time_dim, up=True),
            Block(base_channels * 2, base_channels, time_dim, up=True),
        ])

        # Project back to image channels.
        self.output = nn.Conv2d(base_channels, out_channels, 1)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)

        # Encoder pass, remembering each stage's output for the skips.
        skips = []
        for stage in self.downs:
            x = stage(x, t)
            skips.append(x)

        x = F.relu(self.bot1(x))
        x = F.relu(self.bot2(x))

        # Decoder pass: pair each up stage with its skip, deepest first
        # (equivalent to popping skips off a stack).
        for stage, skip in zip(self.ups, reversed(skips)):
            x = stage(torch.cat((x, skip), dim=1), t)

        return self.output(x)
211
212
# Test U-Net
# Smoke test with MNIST-sized input: output must match the input shape so
# the network can serve as a per-pixel noise predictor.
unet = SimpleUNet(in_channels=1, out_channels=1)
x = torch.randn(4, 1, 28, 28)
t = torch.randint(0, 1000, (4,))
out = unet(x, t)
print(f"U-Net input: {x.shape}")
print(f"U-Net output: {out.shape}")
print(f"Parameters: {sum(p.numel() for p in unet.parameters()):,}")


# ============================================
# 3. Training
# ============================================
print("\n[3] Training Loop")
print("-" * 40)
228
229
def train_diffusion(model, schedule, dataloader, epochs=5, lr=1e-3):
    """Train a noise-prediction model with the simple DDPM objective.

    For each batch: draw random timesteps, noise the clean images with the
    closed-form forward process, and regress the model's output onto the
    injected noise with MSE (Ho et al., 2020, Eq. 14).

    Args:
        model: network mapping (x_t, t) -> predicted noise.
        schedule: object providing ``q_sample`` and ``timesteps``.
        dataloader: yields (images, labels); labels are ignored.
        epochs: number of passes over the data.
        lr: Adam learning rate.

    Returns:
        List of per-epoch average losses.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    losses = []

    for epoch in range(epochs):
        # Re-enter train mode every epoch: the sampling helpers call
        # model.eval() and never restore train mode, which would silently
        # freeze BatchNorm statistics on any subsequent training run.
        model.train()
        total_loss = 0.0

        for images, _ in dataloader:
            images = images.to(device)
            batch_size = images.size(0)

            # One uniformly random timestep per sample.
            t = torch.randint(0, schedule.timesteps, (batch_size,), device=device).long()

            # Forward-diffuse to x_t with known noise.
            noise = torch.randn_like(images)
            x_t = schedule.q_sample(images, t, noise)

            # Predict the injected noise and regress onto it.
            noise_pred = model(x_t, t)
            loss = criterion(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}: Loss = {avg_loss:.6f}")

    return losses
269
270
# ============================================
# 4. Sampling (Reverse Process)
# ============================================
# Section banner for the reverse-process (generation) demos below.
print("\n[4] Sampling (Reverse Process)")
print("-" * 40)
277
@torch.no_grad()
def sample_ddpm(model, schedule, shape, device, show_progress=True):
    """Ancestral DDPM sampling: run all T reverse steps from pure noise.

    Starts from N(0, I) and iteratively denoises with the learned model,
    adding posterior noise at every step except the last.
    """
    model.eval()

    x = torch.randn(shape, device=device)

    for step in reversed(range(schedule.timesteps)):
        t = torch.full((shape[0],), step, device=device, dtype=torch.long)

        # Per-sample schedule constants, broadcastable over the images.
        beta_t = get_index_from_list(schedule.betas, t, x.shape)
        sqrt_one_minus_abar_t = get_index_from_list(
            schedule.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        recip_sqrt_alpha_t = get_index_from_list(
            schedule.sqrt_recip_alphas, t, x.shape
        )

        eps = model(x, t)

        # Posterior mean mu_theta(x_t, t) of the reverse step.
        mean = recip_sqrt_alpha_t * (x - beta_t * eps / sqrt_one_minus_abar_t)

        if step == 0:
            # Final step is deterministic: no noise is re-injected.
            x = mean
        else:
            var_t = get_index_from_list(schedule.posterior_variance, t, x.shape)
            x = mean + torch.sqrt(var_t) * torch.randn_like(x)

        if show_progress and step % 100 == 0:
            print(f" Sampling step {schedule.timesteps - step}/{schedule.timesteps}")

    return x
319
320
@torch.no_grad()
def sample_ddim(model, schedule, shape, device, num_steps=50, eta=0.0):
    """DDIM sampling (Song et al., 2021): few steps, optionally deterministic.

    ``eta=0`` gives the fully deterministic DDIM update; larger ``eta``
    re-injects noise on the coarse timestep grid.
    """
    model.eval()

    # Evenly strided subset of the training timesteps, walked backwards.
    stride = schedule.timesteps // num_steps
    seq = list(range(0, schedule.timesteps, stride))[::-1]

    x = torch.randn(shape, device=device)

    for idx, step in enumerate(seq):
        t_batch = torch.full((shape[0],), step, device=device, dtype=torch.long)

        abar = schedule.alphas_cumprod[step]
        is_last = idx == len(seq) - 1
        abar_prev = (
            torch.tensor(1.0, device=device)
            if is_last
            else schedule.alphas_cumprod[seq[idx + 1]]
        )

        eps = model(x, t_batch)

        # Deterministic estimate of the clean image x_0.
        x0_hat = (x - torch.sqrt(1 - abar) * eps) / torch.sqrt(abar)

        # Stochasticity knob; identically zero when eta == 0.
        sigma = eta * torch.sqrt(
            (1 - abar_prev) / (1 - abar) *
            (1 - abar / abar_prev)
        )

        # Component pointing from x_0 back toward x_t.
        direction = torch.sqrt(1 - abar_prev - sigma ** 2) * eps

        noise = torch.randn_like(x) if eta > 0 and not is_last else 0
        x = torch.sqrt(abar_prev) * x0_hat + direction + sigma * noise

    return x
363
364
# ============================================
# 5. Visualize Diffusion Process
# ============================================
# Section banner for the forward-process visualization demo.
print("\n[5] Visualize Forward Process")
print("-" * 40)
370
371
def visualize_forward_process(schedule, image, timesteps_to_show):
    """Save a strip of one image at increasing forward-noise levels.

    Args:
        schedule: DiffusionSchedule providing ``q_sample`` and ``device``.
        image: single image tensor of shape (C, H, W).
        timesteps_to_show: iterable of int timesteps to visualize.

    Writes 'diffusion_forward.png' to the working directory.
    """
    fig, axes = plt.subplots(1, len(timesteps_to_show), figsize=(15, 3))

    # Keep the image and timestep on the schedule's device; the original
    # CPU-resident tensors crashed q_sample when the schedule constants
    # lived on CUDA.
    x0 = image.unsqueeze(0).to(schedule.device)

    for idx, t in enumerate(timesteps_to_show):
        t_tensor = torch.tensor([t], device=schedule.device)
        noisy = schedule.q_sample(x0, t_tensor)

        axes[idx].imshow(noisy[0, 0].cpu(), cmap='gray')
        axes[idx].set_title(f't = {t}')
        axes[idx].axis('off')

    plt.suptitle('Forward Diffusion Process')
    plt.tight_layout()
    plt.savefig('diffusion_forward.png', dpi=150)
    plt.close()
    print("Forward process visualization saved to diffusion_forward.png")
389
390
# ============================================
# 6. Training Example
# ============================================
print("\n[6] Training on MNIST")
print("-" * 40)

# Hyperparameters
timesteps = 1000
batch_size = 64
epochs = 5  # Increase for better results
lr = 1e-3

# Data: rescale MNIST from [0, 1] to [-1, 1] so the pixel range matches
# the standard-normal noise used by the diffusion process.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)  # [0, 1] -> [-1, 1]
])

print("Loading MNIST dataset...")
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)

# Schedule (rebuilt here so this demo section is self-contained)
schedule = DiffusionSchedule(timesteps=timesteps, beta_schedule='linear', device=device)

# Model
model = SimpleUNet(in_channels=1, out_channels=1).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Visualize forward process on the first training digit
sample_img, _ = train_data[0]
visualize_forward_process(schedule, sample_img, [0, 100, 300, 500, 700, 900, 999])

# Train
print("\nTraining diffusion model...")
losses = train_diffusion(model, schedule, train_loader, epochs=epochs, lr=lr)
427
# Plot loss
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Diffusion Model Training Loss')
plt.savefig('diffusion_loss.png', dpi=150)
plt.close()
print("Loss curve saved to diffusion_loss.png")

# Sample
# Full 1000-step ancestral sampling; outputs land in [-1, 1] and are
# rescaled to [0, 1] before being written out as an image grid.
print("\nGenerating samples with DDPM...")
samples = sample_ddpm(model, schedule, (16, 1, 28, 28), device)
samples = (samples + 1) / 2  # [-1, 1] -> [0, 1]
samples = samples.clamp(0, 1)

grid = vutils.make_grid(samples.cpu(), nrow=4, normalize=False, padding=2)
plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
plt.axis('off')
plt.title('DDPM Generated Samples')
plt.savefig('diffusion_samples.png', dpi=150)
plt.close()
print("Generated samples saved to diffusion_samples.png")

# DDIM sampling (faster)
# Same trained model, but only 50 deterministic (eta=0) reverse steps.
print("\nGenerating samples with DDIM (50 steps)...")
samples_ddim = sample_ddim(model, schedule, (16, 1, 28, 28), device, num_steps=50, eta=0.0)
samples_ddim = (samples_ddim + 1) / 2
samples_ddim = samples_ddim.clamp(0, 1)

grid_ddim = vutils.make_grid(samples_ddim.cpu(), nrow=4, normalize=False, padding=2)
plt.figure(figsize=(8, 8))
plt.imshow(grid_ddim.permute(1, 2, 0).squeeze(), cmap='gray')
plt.axis('off')
plt.title('DDIM Generated Samples (50 steps)')
plt.savefig('diffusion_samples_ddim.png', dpi=150)
plt.close()
print("DDIM samples saved to diffusion_samples_ddim.png")
467
468
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("Diffusion Model Summary")
print("=" * 60)

# Recap of the concepts demonstrated above, printed verbatim.
summary = """
Key Concepts:
1. Forward Process: Gradually add noise to data
   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon

2. Reverse Process: Learn to denoise step by step
   Model predicts noise epsilon at each step

3. Training: Simple MSE loss on noise prediction
   L = ||epsilon - epsilon_theta(x_t, t)||^2

4. Sampling:
   - DDPM: 1000 steps, stochastic
   - DDIM: 50-100 steps, deterministic

Noise Schedules:
- Linear: Simple, widely used
- Cosine: Better quality for small images

Key Parameters:
- timesteps: Number of diffusion steps (1000)
- beta_start, beta_end: Noise schedule bounds
- U-Net: Time-conditioned denoising network

Output Files:
- diffusion_forward.png: Forward process visualization
- diffusion_loss.png: Training loss curve
- diffusion_samples.png: DDPM generated samples
- diffusion_samples_ddim.png: DDIM generated samples
"""
print(summary)
print("=" * 60)