17_diffusion_simple.py

Download
python 508 lines 15.6 KB
  1"""
  217. Simple Diffusion Model (DDPM) Implementation
  3
  4A minimal implementation of Denoising Diffusion Probabilistic Models
  5for MNIST digit generation.
  6"""
  7
  8import torch
  9import torch.nn as nn
 10import torch.nn.functional as F
 11from torch.utils.data import DataLoader
 12from torchvision import datasets, transforms
 13import torchvision.utils as vutils
 14import matplotlib.pyplot as plt
 15import numpy as np
 16import math
 17
 18print("=" * 60)
 19print("Simple Diffusion Model (DDPM) Implementation")
 20print("=" * 60)
 21
 22device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 23print(f"Using device: {device}")
 24
 25
 26# ============================================
 27# 1. Noise Schedule
 28# ============================================
 29print("\n[1] Noise Schedule")
 30print("-" * 40)
 31
 32
 33def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
 34    """Linear noise schedule"""
 35    return torch.linspace(beta_start, beta_end, timesteps)
 36
 37
 38def cosine_beta_schedule(timesteps, s=0.008):
 39    """Cosine noise schedule (better performance)"""
 40    steps = timesteps + 1
 41    x = torch.linspace(0, timesteps, steps)
 42    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
 43    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
 44    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
 45    return torch.clip(betas, 0.0001, 0.9999)
 46
 47
 48def get_index_from_list(vals, t, x_shape):
 49    """Extract values from schedule at timestep t for each sample in batch"""
 50    batch_size = t.shape[0]
 51    out = vals.gather(-1, t.cpu())
 52    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
 53
 54
 55class DiffusionSchedule:
 56    """Manages all diffusion schedule parameters"""
 57    def __init__(self, timesteps=1000, beta_schedule='linear', device='cpu'):
 58        self.timesteps = timesteps
 59        self.device = device
 60
 61        if beta_schedule == 'linear':
 62            betas = linear_beta_schedule(timesteps)
 63        else:
 64            betas = cosine_beta_schedule(timesteps)
 65
 66        self.betas = betas.to(device)
 67        self.alphas = (1.0 - betas).to(device)
 68        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(device)
 69        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0).to(device)
 70
 71        # Calculations for diffusion q(x_t | x_0) and posterior q(x_{t-1} | x_t, x_0)
 72        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod).to(device)
 73        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod).to(device)
 74        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas).to(device)
 75
 76        # Posterior variance
 77        self.posterior_variance = (
 78            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
 79        ).to(device)
 80
 81    def q_sample(self, x_0, t, noise=None):
 82        """Forward diffusion: sample x_t from q(x_t | x_0)
 83
 84        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
 85        """
 86        if noise is None:
 87            noise = torch.randn_like(x_0)
 88
 89        sqrt_alphas_cumprod_t = get_index_from_list(
 90            self.sqrt_alphas_cumprod, t, x_0.shape
 91        )
 92        sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
 93            self.sqrt_one_minus_alphas_cumprod, t, x_0.shape
 94        )
 95
 96        return sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
 97
 98
 99# Test schedule
# Build a default (linear) schedule and report its endpoints.
schedule = DiffusionSchedule(timesteps=1000, device=device)
print(
    f"Timesteps: {schedule.timesteps}",
    f"Beta range: [{schedule.betas[0]:.6f}, {schedule.betas[-1]:.6f}]",
    f"Alpha_bar range: [{schedule.alphas_cumprod[-1]:.6f}, {schedule.alphas_cumprod[0]:.6f}]",
    sep="\n",
)
104
105
106# ============================================
107# 2. U-Net Architecture
108# ============================================
# Section header: title line followed by a 40-character rule.
print("\n[2] U-Net Architecture", "-" * 40, sep="\n")
111
112
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embeddings for timestep.

    Maps a 1-D tensor of timesteps to a ``(batch, dim)`` tensor whose first
    ``dim // 2`` columns are sines and next ``dim // 2`` columns are cosines
    of the timestep scaled by geometrically spaced frequencies. For odd
    `dim`, a trailing zero column is appended so the output always has
    `dim` features.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) keeps dim == 2 or 3 (half_dim == 1) from dividing by zero.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -scale)
        angles = time[:, None] * freqs[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 columns for odd dim; pad with
            # zeros so downstream layers sized for `dim` still fit.
            pad = embeddings.new_zeros(embeddings.shape[0], 1)
            embeddings = torch.cat((embeddings, pad), dim=-1)
        return embeddings
127
128
class Block(nn.Module):
    """Two-convolution block conditioned on a time embedding.

    With ``up=False`` the block ends in a stride-2 convolution; with
    ``up=True`` it expects ``2 * in_ch`` input channels (features
    concatenated with a skip connection) and ends in a stride-2 transposed
    convolution.
    """
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        first_in = 2 * in_ch if up else in_ch
        resample = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_in, out_ch, 3, padding=1)
        self.transform = resample(out_ch, out_ch, 4, 2, 1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        # First conv -> ReLU -> BatchNorm.
        feat = self.conv1(x)
        feat = self.bnorm1(self.relu(feat))

        # Project the time embedding to out_ch and broadcast it over H and W.
        t_proj = self.relu(self.time_mlp(t))
        feat = feat + t_proj[..., None, None]

        # Second conv -> ReLU -> BatchNorm, then the down/up-sampling layer.
        feat = self.conv2(feat)
        feat = self.bnorm2(self.relu(feat))
        return self.transform(feat)
154
155
class SimpleUNet(nn.Module):
    """Small time-conditioned U-Net used as the noise predictor.

    Layout: initial conv -> two downsampling Blocks -> two bottleneck convs
    -> two upsampling Blocks fed with skip connections -> 1x1 output conv.
    """
    def __init__(self, in_channels=1, out_channels=1, time_dim=256, base_channels=64):
        super().__init__()
        c1, c2, c4 = base_channels, base_channels * 2, base_channels * 4

        # Timestep -> sinusoidal features -> learned projection.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
        )

        # Lift the input to the base channel width.
        self.conv0 = nn.Conv2d(in_channels, c1, 3, padding=1)

        # Encoder: each Block doubles the channel count.
        self.downs = nn.ModuleList([
            Block(c1, c2, time_dim),
            Block(c2, c4, time_dim),
        ])

        # Bottleneck at the widest channel count.
        self.bot1 = nn.Conv2d(c4, c4, 3, padding=1)
        self.bot2 = nn.Conv2d(c4, c4, 3, padding=1)

        # Decoder: mirrors the encoder; each Block receives the skip concat.
        self.ups = nn.ModuleList([
            Block(c4, c2, time_dim, up=True),
            Block(c2, c1, time_dim, up=True),
        ])

        # Project back to the requested number of output channels.
        self.output = nn.Conv2d(c1, out_channels, 1)

    def forward(self, x, timestep):
        t_emb = self.time_mlp(timestep)
        h = self.conv0(x)

        # Encoder pass, remembering each stage's output for the decoder.
        skips = []
        for stage in self.downs:
            h = stage(h, t_emb)
            skips.append(h)

        h = F.relu(self.bot2(F.relu(self.bot1(h))))

        # Decoder pass: consume skips in reverse order, concatenated on channels.
        for stage in self.ups:
            h = stage(torch.cat((h, skips.pop()), dim=1), t_emb)

        return self.output(h)
211
212
213# Test U-Net
# Smoke-test the U-Net on a random MNIST-sized batch and report its size.
unet = SimpleUNet(in_channels=1, out_channels=1)
x = torch.randn(4, 1, 28, 28)
t = torch.randint(0, 1000, (4,))
out = unet(x, t)
print(f"U-Net input: {x.shape}", f"U-Net output: {out.shape}", sep="\n")
print(f"Parameters: {sum(p.numel() for p in unet.parameters()):,}")
221
222
223# ============================================
224# 3. Training
225# ============================================
# Section header: title line followed by a 40-character rule.
print("\n[3] Training Loop", "-" * 40, sep="\n")
228
229
def train_diffusion(model, schedule, dataloader, epochs=5, lr=1e-3):
    """Train diffusion model with the noise-prediction MSE objective.

    For every batch, a random timestep is drawn per sample, the images are
    noised with ``schedule.q_sample``, and the model is trained to predict
    the noise that was added.

    Args:
        model: noise-prediction network called as ``model(x_t, t)``.
        schedule: object providing ``timesteps`` and ``q_sample``.
        dataloader: iterable yielding ``(images, labels)`` batches; labels
            are ignored.
        epochs: number of passes over `dataloader`.
        lr: Adam learning rate.

    Returns:
        List with the mean batch loss of each epoch.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    model.to(device)
    # The samplers in this file call model.eval(); switch back explicitly so
    # BatchNorm uses batch statistics when training follows sampling.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    losses = []

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0

        for images, _ in dataloader:
            images = images.to(device)
            batch_size = images.size(0)

            # Random timesteps
            t = torch.randint(0, schedule.timesteps, (batch_size,), device=device).long()

            # Add noise
            noise = torch.randn_like(images)
            x_t = schedule.q_sample(images, t, noise)

            # Predict noise
            noise_pred = model(x_t, t)

            # Loss
            loss = criterion(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Report an empty loader clearly instead of dividing by zero.
            raise ValueError("dataloader yielded no batches")

        avg_loss = total_loss / num_batches
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}: Loss = {avg_loss:.6f}")

    return losses
269
270
271# ============================================
272# 4. Sampling (Reverse Process)
273# ============================================
# Section header: title line followed by a 40-character rule.
print("\n[4] Sampling (Reverse Process)", "-" * 40, sep="\n")
276
277
@torch.no_grad()
def sample_ddpm(model, schedule, shape, device, show_progress=True):
    """Generate samples by running every reverse step, from the last to 0.

    Starts from standard normal noise of `shape` on `device`. At each step
    the model's noise prediction gives the mean of the reverse step, and
    fresh noise scaled by the posterior standard deviation is added on all
    steps except the final one. The model is put in eval mode.

    Returns:
        Tensor of `shape` holding the generated samples.
    """
    model.eval()

    total = schedule.timesteps
    batch = shape[0]

    # Initial state: pure Gaussian noise.
    x = torch.randn(shape, device=device)

    for step in range(total - 1, -1, -1):
        t = torch.full((batch,), step, device=device, dtype=torch.long)

        # Per-sample schedule coefficients for this step.
        beta = get_index_from_list(schedule.betas, t, x.shape)
        sqrt_one_minus_abar = get_index_from_list(
            schedule.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alpha = get_index_from_list(
            schedule.sqrt_recip_alphas, t, x.shape
        )

        eps = model(x, t)

        # Mean of the reverse step given the predicted noise.
        mean = sqrt_recip_alpha * (x - beta * eps / sqrt_one_minus_abar)

        if step == 0:
            # Final step is deterministic: no noise is added.
            x = mean
        else:
            var = get_index_from_list(schedule.posterior_variance, t, x.shape)
            fresh = torch.randn_like(x)
            x = mean + torch.sqrt(var) * fresh

        if show_progress and step % 100 == 0:
            print(f"  Sampling step {total - step}/{total}")

    return x
315        if show_progress and i % 100 == 0:
316            print(f"  Sampling step {schedule.timesteps - i}/{schedule.timesteps}")
317
318    return x
319
320
@torch.no_grad()
def sample_ddim(model, schedule, shape, device, num_steps=50, eta=0.0):
    """DDIM sampling: faster with fewer steps.

    Visits every ``schedule.timesteps // num_steps``-th timestep (starting
    the stride at 0) in descending order. When `schedule.timesteps` is not a
    multiple of `num_steps`, the stride is rounded down, so slightly more
    than `num_steps` steps are taken. When `num_steps` exceeds
    `schedule.timesteps`, every timestep is visited once. The model is put
    in eval mode.

    Args:
        model: noise-prediction network called as ``model(x, t)``.
        schedule: object providing ``timesteps`` and ``alphas_cumprod``.
        shape: shape of the batch to generate; ``shape[0]`` is the batch size.
        device: device for the generated tensors.
        num_steps: requested number of sampling steps (must be >= 1).
        eta: scale of the per-step noise; 0 makes the update deterministic
            given the initial noise.

    Returns:
        Tensor of `shape` holding the generated samples.

    Raises:
        ValueError: if `num_steps` is less than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    model.eval()

    # Create step sequence. A stride of 0 (num_steps > timesteps) would make
    # range() raise, so fall back to visiting every timestep.
    step_size = max(schedule.timesteps // num_steps, 1)
    timesteps = list(reversed(range(0, schedule.timesteps, step_size)))
    last = len(timesteps) - 1

    x = torch.randn(shape, device=device)

    for i, t in enumerate(timesteps):
        t_tensor = torch.full((shape[0],), t, device=device, dtype=torch.long)

        alpha_cumprod_t = schedule.alphas_cumprod[t]

        if i < last:
            alpha_cumprod_prev = schedule.alphas_cumprod[timesteps[i + 1]]
        else:
            # After the final step the target is the clean sample (alpha_bar = 1).
            alpha_cumprod_prev = torch.tensor(1.0, device=device)

        # Predict noise
        noise_pred = model(x, t_tensor)

        # Predict x_0
        pred_x0 = (x - torch.sqrt(1 - alpha_cumprod_t) * noise_pred) / torch.sqrt(alpha_cumprod_t)

        # Compute variance
        sigma = eta * torch.sqrt(
            (1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t) *
            (1 - alpha_cumprod_t / alpha_cumprod_prev)
        )

        # Direction
        pred_dir = torch.sqrt(1 - alpha_cumprod_prev - sigma ** 2) * noise_pred

        # Next x; noise is only drawn when eta > 0 and this is not the last step.
        noise = torch.randn_like(x) if eta > 0 and i < last else 0
        x = torch.sqrt(alpha_cumprod_prev) * pred_x0 + pred_dir + sigma * noise

    return x
363
364
365# ============================================
366# 5. Visualize Diffusion Process
367# ============================================
# Section header: title line followed by a 40-character rule.
print("\n[5] Visualize Forward Process", "-" * 40, sep="\n")
370
371
def visualize_forward_process(schedule, image, timesteps_to_show):
    """Show image at different noise levels.

    Noises `image` with ``schedule.q_sample`` at each requested timestep,
    draws the results side by side, and saves the figure to
    ``diffusion_forward.png``.

    Args:
        schedule: object providing ``q_sample(x_0, t)``.
        image: single image tensor without a batch dimension; its first
            channel is displayed.
        timesteps_to_show: sequence of integer timesteps, one panel each.
    """
    # squeeze=False keeps `axes` two-dimensional even for a single timestep,
    # where subplots would otherwise return a bare Axes object.
    fig, axes = plt.subplots(
        1, len(timesteps_to_show), figsize=(15, 3), squeeze=False
    )

    for ax, t in zip(axes[0], timesteps_to_show):
        # Keep the timestep on the image's device so the schedule
        # coefficients end up on the same device as the image.
        t_tensor = torch.tensor([t], device=image.device)
        noisy = schedule.q_sample(image.unsqueeze(0), t_tensor)

        ax.imshow(noisy[0, 0].cpu(), cmap='gray')
        ax.set_title(f't = {t}')
        ax.axis('off')

    plt.suptitle('Forward Diffusion Process')
    plt.tight_layout()
    plt.savefig('diffusion_forward.png', dpi=150)
    plt.close()
    print("Forward process visualization saved to diffusion_forward.png")
389
390
391# ============================================
392# 6. Training Example
393# ============================================
print("\n[6] Training on MNIST", "-" * 40, sep="\n")

# Run configuration.
timesteps = 1000
batch_size = 64
epochs = 5  # Increase for better results
lr = 1e-3

# ToTensor yields pixels in [0, 1]; rescale them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda img: (img * 2) - 1),
])

print("Loading MNIST dataset...")
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Linear beta schedule stored on the selected device.
schedule = DiffusionSchedule(timesteps=timesteps, beta_schedule='linear', device=device)

# Noise-prediction network for single-channel images.
model = SimpleUNet(in_channels=1, out_channels=1).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
419
# Render the first training digit at increasing noise levels.
sample_img, _ = train_data[0]
visualize_forward_process(
    schedule, sample_img, [0, 100, 300, 500, 700, 900, 999]
)

# Fit the noise predictor; `losses` holds one mean loss per epoch.
print("\nTraining diffusion model...")
losses = train_diffusion(model, schedule, train_loader, epochs=epochs, lr=lr)

# Save the per-epoch loss curve.
plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Diffusion Model Training Loss')
plt.savefig('diffusion_loss.png', dpi=150)
plt.close()
print("Loss curve saved to diffusion_loss.png")
437
# Draw 16 digits with the full DDPM reverse process.
print("\nGenerating samples with DDPM...")
samples = sample_ddpm(model, schedule, (16, 1, 28, 28), device)
# Map [-1, 1] back to [0, 1] and clip anything outside that range.
samples = ((samples + 1) / 2).clamp(0, 1)

# Arrange the batch as a 4-column grid and save it.
grid = vutils.make_grid(samples.cpu(), nrow=4, normalize=False, padding=2)
plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
plt.axis('off')
plt.title('DDPM Generated Samples')
plt.savefig('diffusion_samples.png', dpi=150)
plt.close()
print("Generated samples saved to diffusion_samples.png")
452
# Draw 16 digits with the strided DDIM sampler (eta=0, 50 requested steps).
print("\nGenerating samples with DDIM (50 steps)...")
samples_ddim = sample_ddim(
    model, schedule, (16, 1, 28, 28), device, num_steps=50, eta=0.0
)
# Map [-1, 1] back to [0, 1] and clip anything outside that range.
samples_ddim = ((samples_ddim + 1) / 2).clamp(0, 1)

# Arrange the batch as a 4-column grid and save it.
grid_ddim = vutils.make_grid(samples_ddim.cpu(), nrow=4, normalize=False, padding=2)
plt.figure(figsize=(8, 8))
plt.imshow(grid_ddim.permute(1, 2, 0).squeeze(), cmap='gray')
plt.axis('off')
plt.title('DDIM Generated Samples (50 steps)')
plt.savefig('diffusion_samples_ddim.png', dpi=150)
plt.close()
print("DDIM samples saved to diffusion_samples_ddim.png")
467
468
469# ============================================
470# Summary
471# ============================================
# Closing banner.
print("\n" + "=" * 60, "Diffusion Model Summary", "=" * 60, sep="\n")

# Recap text printed at the end of the run.
summary = """
Key Concepts:
1. Forward Process: Gradually add noise to data
   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon

2. Reverse Process: Learn to denoise step by step
   Model predicts noise epsilon at each step

3. Training: Simple MSE loss on noise prediction
   L = ||epsilon - epsilon_theta(x_t, t)||^2

4. Sampling:
   - DDPM: 1000 steps, stochastic
   - DDIM: 50-100 steps, deterministic

Noise Schedules:
- Linear: Simple, widely used
- Cosine: Better quality for small images

Key Parameters:
- timesteps: Number of diffusion steps (1000)
- beta_start, beta_end: Noise schedule bounds
- U-Net: Time-conditioned denoising network

Output Files:
- diffusion_forward.png: Forward process visualization
- diffusion_loss.png: Training loss curve
- diffusion_samples.png: DDPM generated samples
- diffusion_samples_ddim.png: DDIM generated samples
"""
print(summary, "=" * 60, sep="\n")