13_diffusion_ddpm.py

Download
python 394 lines 12.5 KB
  1"""
  2DDPM (Denoising Diffusion Probabilistic Model) Implementation
  3
  4This script implements a simple DDPM for image generation following
  5"Denoising Diffusion Probabilistic Models" (Ho et al., 2020).
  6
  7Key concepts:
  8- Forward diffusion: gradually add Gaussian noise to data
  9- Reverse diffusion: learn to denoise and generate samples
 10- Linear beta schedule for noise variance
 11- Simple UNet architecture with time embedding
 12
 13References:
 14- Ho et al. (2020): https://arxiv.org/abs/2006.11239
 15"""
 16
 17import torch
 18import torch.nn as nn
 19import torch.nn.functional as F
 20from torch.utils.data import DataLoader
 21from torchvision import datasets, transforms
 22import matplotlib.pyplot as plt
 23import numpy as np
 24from tqdm import tqdm
 25
 26
 27# ============================================================================
 28# Noise Schedule
 29# ============================================================================
 30
 31def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
 32    """
 33    Linear schedule for beta (variance) from beta_start to beta_end.
 34
 35    Args:
 36        timesteps: number of diffusion steps (T)
 37        beta_start: minimum noise variance
 38        beta_end: maximum noise variance
 39
 40    Returns:
 41        betas: [T] tensor of noise variances
 42    """
 43    return torch.linspace(beta_start, beta_end, timesteps)
 44
 45
 46def get_diffusion_params(betas):
 47    """
 48    Precompute diffusion parameters for efficient sampling.
 49
 50    Args:
 51        betas: [T] noise schedule
 52
 53    Returns:
 54        Dictionary with precomputed parameters
 55    """
 56    alphas = 1.0 - betas
 57    alphas_cumprod = torch.cumprod(alphas, dim=0)
 58    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
 59
 60    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
 61    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
 62
 63    # Posterior variance for reverse process
 64    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
 65
 66    return {
 67        'betas': betas,
 68        'alphas': alphas,
 69        'alphas_cumprod': alphas_cumprod,
 70        'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
 71        'sqrt_one_minus_alphas_cumprod': sqrt_one_minus_alphas_cumprod,
 72        'posterior_variance': posterior_variance,
 73    }
 74
 75
 76# ============================================================================
 77# Time Embedding
 78# ============================================================================
 79
 80class SinusoidalPositionEmbedding(nn.Module):
 81    """
 82    Sinusoidal time embedding similar to Transformer positional encoding.
 83    Maps timestep t to a high-dimensional vector.
 84    """
 85    def __init__(self, dim):
 86        super().__init__()
 87        self.dim = dim
 88
 89    def forward(self, t):
 90        """
 91        Args:
 92            t: [batch_size] timesteps
 93        Returns:
 94            [batch_size, dim] embeddings
 95        """
 96        device = t.device
 97        half_dim = self.dim // 2
 98        embeddings = np.log(10000) / (half_dim - 1)
 99        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
100        embeddings = t[:, None] * embeddings[None, :]
101        embeddings = torch.cat([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
102        return embeddings
103
104
105# ============================================================================
106# Simple UNet Architecture
107# ============================================================================
108
class SimpleUNet(nn.Module):
    """
    Simplified UNet for DDPM with time conditioning.

    Architecture:
    - Encoder: three conv blocks, each followed by 2x max-pooling
    - Bottleneck: one conv block at the lowest resolution
    - Decoder: nearest-neighbour upsampling to the spatial size of the
      matching encoder feature map, channel concatenation with it (skip
      connection), then a conv block
    - The time embedding is projected and added inside every block

    Input height/width need not be divisible by 8: each decoder stage
    resizes to the exact size of its skip tensor, so 28x28 inputs
    (28 -> 14 -> 7 -> 3 after pooling) work.
    """
    def __init__(self, in_channels=1, out_channels=1, time_dim=128, base_channels=64):
        super().__init__()

        # Time embedding: sinusoidal features followed by a 2-layer MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbedding(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.GELU(),
            nn.Linear(time_dim * 4, time_dim),
        )

        # Encoder (each block is followed by 2x max-pooling in forward)
        self.enc1 = self._make_block(in_channels, base_channels, time_dim)
        self.enc2 = self._make_block(base_channels, base_channels * 2, time_dim)
        self.enc3 = self._make_block(base_channels * 2, base_channels * 4, time_dim)

        # Bottleneck
        self.bottleneck = self._make_block(base_channels * 4, base_channels * 4, time_dim)

        # Decoder: input channels = upsampled features + concatenated skip.
        self.dec3 = self._make_block(base_channels * 8, base_channels * 2, time_dim)
        self.dec2 = self._make_block(base_channels * 4, base_channels, time_dim)
        self.dec1 = self._make_block(base_channels * 2, base_channels, time_dim)

        # Output layer: 1x1 conv back to the requested channel count.
        self.out = nn.Conv2d(base_channels, out_channels, 1)

    def _make_block(self, in_ch, out_ch, time_dim):
        """
        Create a time-conditioned conv block: two 3x3 convs, each with
        GroupNorm, plus a linear projection of the time embedding.
        (There is no residual path inside the block.)
        """
        # GroupNorm requires out_ch % num_groups == 0. Use 8 groups when possible
        # (the original setting); otherwise the largest group count <= 8 that divides out_ch.
        groups = next(g for g in range(8, 0, -1) if out_ch % g == 0)
        return nn.ModuleDict({
            'conv1': nn.Conv2d(in_ch, out_ch, 3, padding=1),
            'conv2': nn.Conv2d(out_ch, out_ch, 3, padding=1),
            'time_proj': nn.Linear(time_dim, out_ch),
            'norm1': nn.GroupNorm(groups, out_ch),
            'norm2': nn.GroupNorm(groups, out_ch),
        })

    def _forward_block(self, x, t_emb, block):
        """
        Apply conv1 -> norm1 -> (+ time projection) -> GELU -> conv2 -> norm2 -> GELU.

        The time embedding is projected to the block's channel count and
        broadcast over the spatial dimensions.
        """
        h = block['norm1'](block['conv1'](x))
        h = h + block['time_proj'](t_emb)[:, :, None, None]
        h = F.gelu(h)
        h = block['norm2'](block['conv2'](h))
        return F.gelu(h)

    def _upsample_block(self, x, skip, t_emb, block):
        """Resize x to skip's spatial size, concatenate along channels, apply block."""
        # Resizing to the skip's exact size (not a fixed 2x) handles odd sizes
        # produced by floor-division in max-pooling, e.g. 7 -> 3.
        x = F.interpolate(x, size=tuple(skip.shape[-2:]), mode='nearest')
        x = torch.cat([x, skip], dim=1)
        return self._forward_block(x, t_emb, block)

    def forward(self, x, t):
        """
        Args:
            x: [B, C, H, W] noisy images
            t: [B] timesteps
        Returns:
            [B, C, H, W] predicted noise
        """
        t_emb = self.time_mlp(t)

        # Encoder
        x1 = self._forward_block(x, t_emb, self.enc1)
        x2 = self._forward_block(F.max_pool2d(x1, 2), t_emb, self.enc2)
        x3 = self._forward_block(F.max_pool2d(x2, 2), t_emb, self.enc3)

        # Bottleneck
        h = self._forward_block(F.max_pool2d(x3, 2), t_emb, self.bottleneck)

        # Decoder with skip connections
        h = self._upsample_block(h, x3, t_emb, self.dec3)
        h = self._upsample_block(h, x2, t_emb, self.dec2)
        h = self._upsample_block(h, x1, t_emb, self.dec1)

        return self.out(h)
209
210
211# ============================================================================
212# Diffusion Process
213# ============================================================================
214
def forward_diffusion(x0, t, params, device, noise=None):
    """
    Add noise to data according to the forward diffusion process.

    Samples x_t from the closed-form marginal
    q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I).

    Args:
        x0: [B, ...] clean data (e.g. [B, C, H, W] images); any rank >= 1
        t: [B] integer timesteps
        params: diffusion parameters; reads 'sqrt_alphas_cumprod' and
            'sqrt_one_minus_alphas_cumprod'
        device: unused; kept for backward compatibility (noise is created
            on x0's device)
        noise: optional tensor shaped like x0; drawn from N(0, I) if None

    Returns:
        noisy_x: tensor shaped like x0
        noise: the Gaussian noise that was added (same shape as x0)
    """
    if noise is None:
        noise = torch.randn_like(x0)

    # Broadcast the per-sample coefficients over every non-batch dimension.
    coef_shape = (-1,) + (1,) * (x0.dim() - 1)
    sqrt_alpha_cumprod_t = params['sqrt_alphas_cumprod'][t].reshape(coef_shape)
    sqrt_one_minus_alpha_cumprod_t = params['sqrt_one_minus_alphas_cumprod'][t].reshape(coef_shape)

    noisy_x = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
    return noisy_x, noise
239
240
@torch.no_grad()
def sample(model, params, image_size, batch_size, timesteps, device):
    """
    Generate images by running the learned reverse diffusion chain.

    Starting from pure Gaussian noise, each step t = timesteps-1, ..., 0 computes

        x <- 1/sqrt(alpha_t) * (x - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps)
             + sqrt(beta_t) * z

    where eps is the model's noise prediction and z ~ N(0, I), except on the
    final step (t = 0) where z = 0.

    Side effect: switches ``model`` to eval mode and does not switch it back.

    Args:
        model: noise-prediction network, called as model(x, t)
        params: diffusion parameters; reads 'alphas', 'alphas_cumprod', 'betas'
        image_size: (C, H, W)
        batch_size: number of samples
        timesteps: number of reverse steps to run
        device: torch device

    Returns:
        [batch_size, C, H, W] generated images (not clamped)
    """
    model.eval()

    def gather(key, t_batch):
        # Pick params[key] at each sample's timestep, shaped for broadcasting.
        return params[key][t_batch][:, None, None, None]

    x = torch.randn(batch_size, *image_size, device=device)

    schedule = reversed(range(timesteps))
    for step in tqdm(schedule, desc='Sampling', total=timesteps):
        t_batch = torch.full((batch_size,), step, device=device, dtype=torch.long)
        eps = model(x, t_batch)

        alpha_t = gather('alphas', t_batch)
        alpha_bar_t = gather('alphas_cumprod', t_batch)
        beta_t = gather('betas', t_batch)

        # Fresh noise on every step except the last one.
        z = torch.randn_like(x) if step > 0 else torch.zeros_like(x)

        eps_scale = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
        x = (1 / torch.sqrt(alpha_t)) * (x - eps_scale * eps) + torch.sqrt(beta_t) * z

    return x
287
288
289# ============================================================================
290# Training
291# ============================================================================
292
def _save_sample_grid(samples, nrows, ncols, figsize, title, path):
    """
    Save a grid of single-channel images in [-1, 1] to ``path``.

    Samples are mapped to [0, 1] and clamped, then drawn with fixed
    vmin/vmax so every tile shares one grey scale (by default imshow would
    rescale each tile to its own min/max).
    """
    images = ((samples + 1) / 2).clamp(0, 1).cpu()
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx < len(images):
            ax.imshow(images[idx].squeeze(), cmap='gray', vmin=0, vmax=1)
    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def train_ddpm(epochs=10, batch_size=128, timesteps=1000, device='cuda', *,
               lr=2e-4, sample_every=5, data_root='./data'):
    """
    Train DDPM on the MNIST dataset.

    Args:
        epochs: number of training epochs
        batch_size: batch size
        timesteps: number of diffusion steps (T)
        device: requested device; falls back to 'cpu' when CUDA is unavailable
        lr: Adam learning rate
        sample_every: save a 4x4 sample grid to ``ddpm_samples_epoch_{N}.png``
            every this many epochs; 0 or None disables it
        data_root: directory where MNIST is downloaded/cached

    Returns:
        (model, params): the trained SimpleUNet and the diffusion parameter
        dict (tensors on the training device)
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    dataset = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Model and diffusion parameters live on the same device as the batches.
    model = SimpleUNet(in_channels=1, out_channels=1).to(device)
    betas = linear_beta_schedule(timesteps).to(device)
    params = get_diffusion_params(betas)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        model.train()  # sample() leaves the model in eval mode
        total_loss = 0.0

        for images, _ in dataloader:
            images = images.to(device)

            # One uniformly random timestep per image.
            t = torch.randint(0, timesteps, (images.shape[0],), device=device)
            noisy_images, noise = forward_diffusion(images, t, params, device)

            # Simplified DDPM objective: MSE between true and predicted noise.
            loss = F.mse_loss(model(noisy_images, t), noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch}/{epochs}, Loss: {avg_loss:.4f}')

        if sample_every and epoch % sample_every == 0:
            samples = sample(model, params, (1, 28, 28), 16, timesteps, device)
            _save_sample_grid(samples, 4, 4, (8, 8),
                              f'Generated Samples - Epoch {epoch}',
                              f'ddpm_samples_epoch_{epoch}.png')

    print("Training completed!")
    return model, params
369
370
371# ============================================================================
372# Main
373# ============================================================================
374
if __name__ == '__main__':
    # Single source of truth for the number of diffusion steps, shared by
    # training and final sampling (they must agree with the schedule length).
    NUM_TIMESTEPS = 1000

    # Train model
    model, params = train_ddpm(epochs=10, batch_size=128, timesteps=NUM_TIMESTEPS, device='cuda')

    # Generate final samples on whatever device the model was trained on.
    print("\nGenerating final samples...")
    device = next(model.parameters()).device
    samples = sample(model, params, (1, 28, 28), 64, NUM_TIMESTEPS, device)
    # Map from the training range [-1, 1] to [0, 1]; clamp out-of-range pixels.
    samples = ((samples + 1) / 2).clamp(0, 1)

    # Visualize final samples on a shared grey scale (no per-tile autoscaling).
    fig, axes = plt.subplots(8, 8, figsize=(12, 12))
    for img, ax in zip(samples, axes.flat):
        ax.imshow(img.cpu().squeeze(), cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    fig.suptitle('DDPM Generated Samples (Final)')
    fig.tight_layout()
    fig.savefig('ddpm_final_samples.png')
    plt.show()