# 16_vae.py — Variational Autoencoder (VAE) implementation
  1"""
  216. Variational Autoencoder (VAE) Implementation
  3
  4VAE implementation for MNIST digit generation with latent space visualization.
  5"""
  6
  7import torch
  8import torch.nn as nn
  9import torch.nn.functional as F
 10from torch.utils.data import DataLoader
 11from torchvision import datasets, transforms
 12import torchvision.utils as vutils
 13import matplotlib.pyplot as plt
 14import numpy as np
 15
 16print("=" * 60)
 17print("Variational Autoencoder (VAE) Implementation")
 18print("=" * 60)
 19
 20device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 21print(f"Using device: {device}")
 22
 23
 24# ============================================
 25# 1. VAE Components
 26# ============================================
 27print("\n[1] VAE Architecture")
 28print("-" * 40)
 29
 30
 31class VAEEncoder(nn.Module):
 32    """VAE Encoder: Image -> mu, log_var"""
 33    def __init__(self, in_channels=1, latent_dim=20):
 34        super().__init__()
 35
 36        self.conv_layers = nn.Sequential(
 37            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),  # 28 -> 14
 38            nn.ReLU(),
 39            nn.Conv2d(32, 64, 3, stride=2, padding=1),           # 14 -> 7
 40            nn.ReLU(),
 41            nn.Flatten()
 42        )
 43
 44        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
 45        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)
 46
 47    def forward(self, x):
 48        h = self.conv_layers(x)
 49        mu = self.fc_mu(h)
 50        log_var = self.fc_logvar(h)
 51        return mu, log_var
 52
 53
 54class VAEDecoder(nn.Module):
 55    """VAE Decoder: z -> Image"""
 56    def __init__(self, latent_dim=20, out_channels=1):
 57        super().__init__()
 58
 59        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
 60
 61        self.deconv_layers = nn.Sequential(
 62            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),  # 7 -> 14
 63            nn.ReLU(),
 64            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),  # 14 -> 28
 65            nn.Sigmoid()
 66        )
 67
 68    def forward(self, z):
 69        h = self.fc(z)
 70        h = h.view(-1, 64, 7, 7)
 71        return self.deconv_layers(h)
 72
 73
 74class VAE(nn.Module):
 75    """Variational Autoencoder"""
 76    def __init__(self, in_channels=1, latent_dim=20):
 77        super().__init__()
 78        self.encoder = VAEEncoder(in_channels, latent_dim)
 79        self.decoder = VAEDecoder(latent_dim, in_channels)
 80        self.latent_dim = latent_dim
 81
 82    def reparameterize(self, mu, log_var):
 83        """Reparameterization trick: z = mu + sigma * epsilon"""
 84        std = torch.exp(0.5 * log_var)
 85        eps = torch.randn_like(std)
 86        return mu + std * eps
 87
 88    def forward(self, x):
 89        mu, log_var = self.encoder(x)
 90        z = self.reparameterize(mu, log_var)
 91        x_recon = self.decoder(z)
 92        return x_recon, mu, log_var
 93
 94    def generate(self, num_samples, device):
 95        """Generate new samples from prior"""
 96        z = torch.randn(num_samples, self.latent_dim, device=device)
 97        return self.decoder(z)
 98
 99    def reconstruct(self, x):
100        """Reconstruct input images"""
101        with torch.no_grad():
102            x_recon, _, _ = self.forward(x)
103        return x_recon
104
105
# Test VAE: quick shape check of a forward pass on random input.
vae = VAE(in_channels=1, latent_dim=20)
x = torch.randn(4, 1, 28, 28)
x_recon, mu, log_var = vae(x)
print(f"Input: {x.shape}")
print(f"Reconstruction: {x_recon.shape}")
print(f"Mu: {mu.shape}, Log_var: {log_var.shape}")


# ============================================
# 2. Loss Functions
# ============================================
print("\n[2] VAE Loss Functions")
print("-" * 40)

def vae_loss(x, x_recon, mu, log_var, beta=1.0):
    """VAE ELBO loss: reconstruction (BCE) + weighted KL divergence.

    Args:
        x: Original images with values in [0, 1].
        x_recon: Reconstructed images (sigmoid outputs in [0, 1]).
        mu: Latent posterior mean.
        log_var: Latent posterior log variance.
        beta: KL weight (beta > 1 for beta-VAE).

    Returns:
        (total_loss, recon_loss, kl_loss), each summed over the batch.
    """
    # Reconstruction loss (Binary Cross Entropy, summed over all pixels).
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    total_loss = recon_loss + beta * kl_loss

    return total_loss, recon_loss, kl_loss


def vae_loss_mse(x, x_recon, mu, log_var, beta=1.0):
    """VAE loss variant with summed MSE reconstruction (for non-binary data).

    Returns:
        (total_loss, recon_loss, kl_loss), each summed over the batch.
    """
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    total_loss = recon_loss + beta * kl_loss
    return total_loss, recon_loss, kl_loss


print("Loss functions defined")

# ============================================
# 3. Training Loop
# ============================================
print("\n[3] Training Loop")
print("-" * 40)


def train_vae(model, dataloader, epochs=10, lr=1e-3, beta=1.0):
    """Train a VAE with Adam, recording per-epoch average losses.

    Args:
        model: VAE instance; moved to the module-level `device`.
        dataloader: Yields (images, labels); labels are ignored.
        epochs: Number of passes over the dataset.
        lr: Adam learning rate.
        beta: KL weight passed through to `vae_loss`.

    Returns:
        dict with per-epoch 'total', 'recon', and 'kl' histories
        (each averaged per batch, normalized per sample).
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {'total': [], 'recon': [], 'kl': []}

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_recon = 0
        total_kl = 0

        for data, _ in dataloader:
            data = data.to(device)
            optimizer.zero_grad()

            # Forward pass
            x_recon, mu, log_var = model(data)

            # Loss is summed over the batch; normalize per sample for the gradient
            loss, recon, kl = vae_loss(data, x_recon, mu, log_var, beta)
            loss = loss / data.size(0)

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_recon += recon.item() / data.size(0)
            total_kl += kl.item() / data.size(0)

        avg_loss = total_loss / len(dataloader)
        avg_recon = total_recon / len(dataloader)
        avg_kl = total_kl / len(dataloader)

        history['total'].append(avg_loss)
        history['recon'].append(avg_recon)
        history['kl'].append(avg_kl)

        print(f"Epoch {epoch+1}/{epochs}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}")

    return history


# ============================================
# 4. Latent Space Visualization
# ============================================
print("\n[4] Latent Space Visualization")
print("-" * 40)


def visualize_latent_space(model, dataloader, device):
    """Scatter-plot the first two latent dimensions, colored by digit class.

    Encodes every batch using the posterior mean (mu) — not a sample — and
    saves the plot to 'vae_latent_space.png'.
    """
    model.eval()
    latents = []
    labels = []

    with torch.no_grad():
        for data, label in dataloader:
            data = data.to(device)
            mu, _ = model.encoder(data)  # deterministic embedding: use mu
            latents.append(mu.cpu())
            labels.append(label)

    latents = torch.cat(latents, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    # Only the first 2 latent dimensions are plotted.
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(latents[:, 0], latents[:, 1], c=labels, cmap='tab10', alpha=0.6, s=5)
    plt.colorbar(scatter, label='Digit')
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space (First 2 Dimensions)')
    plt.savefig('vae_latent_space.png', dpi=150)
    plt.close()
    print("Latent space visualization saved to vae_latent_space.png")


def generate_manifold(model, n=20, latent_dim=2, device='cpu', digit_size=28):
    """Decode an n x n grid over a 2D latent space and save it as one image.

    Sweeps z[0] and z[1] over [-3, 3] (other latent dims stay 0) and tiles
    the decoded digits into 'vae_manifold.png'.
    """
    model.eval()

    # Evenly spaced grid coordinates in latent space.
    grid_x = torch.linspace(-3, 3, n)
    grid_y = torch.linspace(-3, 3, n)

    figure = np.zeros((digit_size * n, digit_size * n))

    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z = torch.zeros(1, latent_dim, device=device)
                z[0, 0] = xi
                z[0, 1] = yi

                x_decoded = model.decoder(z)
                digit = x_decoded[0, 0].cpu().numpy()

                # Paste the decoded digit into its cell of the mosaic.
                figure[i * digit_size:(i + 1) * digit_size,
                       j * digit_size:(j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='gray')
    plt.axis('off')
    plt.title('VAE Manifold')
    plt.savefig('vae_manifold.png', dpi=150)
    plt.close()
    print("Manifold saved to vae_manifold.png")


def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    """Decode images while sweeping one latent dimension, others held fixed.

    Args:
        model: Model exposing `.eval()` and `.decoder(z)`.
        dim_idx: Index of the latent dimension to vary.
        range_vals: Iterable of values to assign to z[0, dim_idx].
        fixed_z: Base latent code of shape (1, latent_dim); not modified.
        device: Device to run the decoder on.

    Returns:
        Decoded outputs concatenated along the batch dimension (on CPU).
    """
    model.eval()
    images = []

    with torch.no_grad():
        for val in range_vals:
            z = fixed_z.clone()  # clone so the caller's fixed_z is untouched
            z[0, dim_idx] = val
            img = model.decoder(z.to(device))
            images.append(img.cpu())

    return torch.cat(images, dim=0)


# ============================================
# 5. Reconstruction Visualization
# ============================================
print("\n[5] Reconstruction Visualization")
print("-" * 40)


def visualize_reconstructions(model, dataloader, num_samples=8, device='cpu'):
    """Save a side-by-side grid of originals (top row) vs reconstructions.

    Takes the first `num_samples` images of the first batch and writes the
    comparison to 'vae_reconstruction.png'.
    """
    model.eval()

    # Grab one batch and keep only the first num_samples images.
    data, _ = next(iter(dataloader))
    data = data[:num_samples].to(device)

    with torch.no_grad():
        recon, _, _ = model(data)

    # Stack originals above reconstructions in a single image grid.
    comparison = torch.cat([data, recon], dim=0)
    grid = vutils.make_grid(comparison.cpu(), nrow=num_samples, normalize=True, padding=2)

    plt.figure(figsize=(12, 4))
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
    plt.title('Original (top) vs Reconstructed (bottom)')
    plt.savefig('vae_reconstruction.png', dpi=150)
    plt.close()
    print("Reconstruction comparison saved to vae_reconstruction.png")


# ============================================
# 6. Beta-VAE
# ============================================
print("\n[6] Beta-VAE for Disentanglement")
print("-" * 40)


class BetaVAE(VAE):
    """Beta-VAE: a VAE whose loss weights the KL term by beta (> 1 encourages
    disentangled latent factors)."""

    def __init__(self, in_channels=1, latent_dim=10, beta=4.0):
        super().__init__(in_channels, latent_dim)
        # KL weight applied in loss(); beta=1 reduces to a standard VAE.
        self.beta = beta

    def loss(self, x, x_recon, mu, log_var):
        """Return (total, recon, kl) using this model's beta as the KL weight."""
        return vae_loss(x, x_recon, mu, log_var, self.beta)


print(f"Beta-VAE class defined with configurable beta parameter")

# ============================================
# 7. Conditional VAE (CVAE)
# ============================================
print("\n[7] Conditional VAE")
print("-" * 40)


class CVAEEncoder(nn.Module):
    """CVAE encoder: conditions on the class label by appending the one-hot
    label as extra constant-valued input channels."""

    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # First conv consumes image channels plus num_classes label channels.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels + num_classes, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        self.fc_mu = nn.Linear(64 * 7 * 7, latent_dim)
        self.fc_logvar = nn.Linear(64 * 7 * 7, latent_dim)

    def forward(self, x, y):
        """Return (mu, log_var) for images x and one-hot labels y (batch, num_classes)."""
        # Broadcast each one-hot label over the spatial dims as extra channels.
        y_expanded = y.view(-1, self.num_classes, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        x_cond = torch.cat([x, y_expanded], dim=1)

        h = self.conv_layers(x_cond)
        return self.fc_mu(h), self.fc_logvar(h)


class CVAEDecoder(nn.Module):
    """CVAE decoder: conditions on the class label by concatenating the
    one-hot label to the latent code before decoding."""

    def __init__(self, latent_dim=20, out_channels=1, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # Input is the latent code plus the one-hot label.
        self.fc = nn.Linear(latent_dim + num_classes, 64 * 7 * 7)

        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, z, y):
        """Decode latent codes z conditioned on one-hot labels y."""
        z_cond = torch.cat([z, y], dim=1)
        h = self.fc(z_cond)
        h = h.view(-1, 64, 7, 7)
        return self.deconv_layers(h)


class CVAE(nn.Module):
    """Conditional VAE: both encoder and decoder receive the class label,
    enabling generation of a chosen digit."""

    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.encoder = CVAEEncoder(in_channels, latent_dim, num_classes)
        self.decoder = CVAEDecoder(latent_dim, in_channels, num_classes)
        self.latent_dim = latent_dim
        self.num_classes = num_classes

    def reparameterize(self, mu, log_var):
        """z = mu + sigma * epsilon with epsilon ~ N(0, I)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, x, label):
        """Return (reconstruction, mu, log_var) for images x and integer labels."""
        y = F.one_hot(label, self.num_classes).float()
        mu, log_var = self.encoder(x, y)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decoder(z, y)
        return x_recon, mu, log_var

    def generate(self, label, num_samples, device):
        """Generate num_samples images of one digit (label: scalar int tensor)."""
        z = torch.randn(num_samples, self.latent_dim, device=device)
        y = F.one_hot(label.expand(num_samples), self.num_classes).float().to(device)
        return self.decoder(z, y)


print("CVAE class defined for conditional generation")

# ============================================
# 8. Training Example
# ============================================
print("\n[8] Training VAE on MNIST")
print("-" * 40)

# Hyperparameters
latent_dim = 20
batch_size = 128
epochs = 10
lr = 1e-3

# Data: MNIST pixels in [0, 1] (required by the BCE reconstruction loss).
transform = transforms.ToTensor()
print("Loading MNIST dataset...")
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('data', train=False, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0)

# Model
vae = VAE(in_channels=1, latent_dim=latent_dim).to(device)
print(f"VAE parameters: {sum(p.numel() for p in vae.parameters()):,}")

# Train
print("\nTraining VAE...")
history = train_vae(vae, train_loader, epochs=epochs, lr=lr, beta=1.0)

# Visualizations
print("\nGenerating visualizations...")

# 1. Reconstruction
visualize_reconstructions(vae, test_loader, num_samples=10, device=device)

# 2. Generated samples from the prior
vae.eval()
with torch.no_grad():
    samples = vae.generate(64, device)
    grid = vutils.make_grid(samples.cpu(), nrow=8, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
    plt.title('Generated Samples')
    plt.savefig('vae_samples.png', dpi=150)
    plt.close()
print("Generated samples saved to vae_samples.png")

# 3. Latent space: a separate 2D-latent VAE so the space can be plotted directly.
vae_2d = VAE(in_channels=1, latent_dim=2).to(device)
print("\nTraining 2D VAE for visualization...")
_ = train_vae(vae_2d, train_loader, epochs=5, lr=lr)
visualize_latent_space(vae_2d, test_loader, device)
generate_manifold(vae_2d, n=20, latent_dim=2, device=device)

# 4. Loss curves
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(history['total'], label='Total')
plt.plot(history['recon'], label='Reconstruction')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss')

plt.subplot(1, 2, 2)
plt.plot(history['kl'], label='KL Divergence', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('KL Divergence')

plt.tight_layout()
plt.savefig('vae_loss.png', dpi=150)
plt.close()
print("Loss curves saved to vae_loss.png")


# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("VAE Summary")
print("=" * 60)

summary = """
Key Concepts:
1. VAE: Probabilistic latent variable model
2. Reparameterization: z = mu + sigma * epsilon
3. ELBO: Reconstruction + KL Divergence
4. Beta-VAE: beta > 1 for disentanglement
5. CVAE: Conditional generation

Loss Function:
L = E[log p(x|z)] - beta * KL(q(z|x) || p(z))
  = BCE(x, x_recon) + beta * (-0.5 * sum(1 + log(var) - mu^2 - var))

Latent Space Properties:
- Continuous and structured
- Can interpolate between samples
- Each dimension captures a factor of variation

Output Files:
- vae_samples.png: Generated samples
- vae_reconstruction.png: Original vs reconstructed
- vae_latent_space.png: 2D latent space visualization
- vae_manifold.png: Digit manifold
- vae_loss.png: Training loss curves
"""
print(summary)
print("=" * 60)