1"""
216. Variational Autoencoder (VAE) Implementation
3
4VAE implementation for MNIST digit generation with latent space visualization.
5"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10from torch.utils.data import DataLoader
11from torchvision import datasets, transforms
12import torchvision.utils as vutils
13import matplotlib.pyplot as plt
14import numpy as np
15
print("=" * 60)
print("Variational Autoencoder (VAE) Implementation")
print("=" * 60)

# Prefer GPU when available; all models and batches below are moved here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# ============================================
# 1. VAE Components
# ============================================
print("\n[1] VAE Architecture")
print("-" * 40)
29
30
class VAEEncoder(nn.Module):
    """Encode an image into the parameters (mu, log_var) of q(z|x).

    Two strided convolutions downsample a 28x28 input to a 7x7 feature
    map; two parallel linear heads then produce the Gaussian parameters.
    """

    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        # Spatial downsampling 28 -> 14 -> 7, then flatten to a vector.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        feature_dim = 64 * 7 * 7
        self.fc_mu = nn.Linear(feature_dim, latent_dim)
        self.fc_logvar = nn.Linear(feature_dim, latent_dim)

    def forward(self, x):
        """Return (mu, log_var), each of shape (batch, latent_dim)."""
        features = self.conv_layers(x)
        return self.fc_mu(features), self.fc_logvar(features)
52
53
class VAEDecoder(nn.Module):
    """Decode a latent vector z back into an image with pixels in [0, 1]."""

    def __init__(self, latent_dim=20, out_channels=1):
        super().__init__()
        # Project z to a 64x7x7 feature map, then upsample 7 -> 14 -> 28.
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),
            nn.Sigmoid(),  # keeps outputs in (0, 1) for BCE reconstruction
        )

    def forward(self, z):
        """Return a (batch, out_channels, 28, 28) image tensor."""
        feature_map = self.fc(z).view(-1, 64, 7, 7)
        return self.deconv_layers(feature_map)
72
73
class VAE(nn.Module):
    """Variational Autoencoder: encoder + reparameterized sampling + decoder."""

    def __init__(self, in_channels=1, latent_dim=20):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim)
        self.decoder = VAEDecoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps with eps ~ N(0, I).

        Keeps the sampling step differentiable w.r.t. mu and log_var.
        """
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def forward(self, x):
        """Return (reconstruction, mu, log_var) for a batch of images."""
        mu, log_var = self.encoder(x)
        latent = self.reparameterize(mu, log_var)
        return self.decoder(latent), mu, log_var

    def generate(self, num_samples, device):
        """Decode `num_samples` latent vectors drawn from the N(0, I) prior."""
        prior_samples = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decoder(prior_samples)

    def reconstruct(self, x):
        """Reconstruct `x` without tracking gradients."""
        with torch.no_grad():
            recon, _, _ = self.forward(x)
        return recon
104
105
# Test VAE
# Smoke test: push a random batch through an untrained VAE and print shapes.
vae = VAE(in_channels=1, latent_dim=20)
x = torch.randn(4, 1, 28, 28)
x_recon, mu, log_var = vae(x)
print(f"Input: {x.shape}")
print(f"Reconstruction: {x_recon.shape}")
print(f"Mu: {mu.shape}, Log_var: {log_var.shape}")


# ============================================
# 2. Loss Functions
# ============================================
print("\n[2] VAE Loss Functions")
print("-" * 40)
120
121
def vae_loss(x, x_recon, mu, log_var, beta=1.0):
    """Negative ELBO for a VAE with a Bernoulli likelihood.

    Args:
        x: Target images with values in [0, 1].
        x_recon: Decoder output with values in (0, 1).
        mu: Posterior means, shape (batch, latent_dim).
        log_var: Posterior log variances, same shape as mu.
        beta: Weight on the KL term (beta > 1 gives a beta-VAE).

    Returns:
        Tuple (total_loss, recon_loss, kl_loss); each is a scalar tensor
        summed over the whole batch.
    """
    # Summed (not averaged) BCE so both terms are on comparable scales.
    recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior:
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl = -0.5 * torch.sum(1 + log_var - mu ** 2 - torch.exp(log_var))
    return recon + beta * kl, recon, kl
144
145
def vae_loss_mse(x, x_recon, mu, log_var, beta=1.0):
    """VAE loss using summed MSE reconstruction instead of BCE.

    Returns (total_loss, recon_loss, kl_loss), matching `vae_loss`.
    """
    recon = F.mse_loss(x_recon, x, reduction='sum')
    kl = -0.5 * torch.sum(1 + log_var - mu ** 2 - torch.exp(log_var))
    return recon + beta * kl, recon, kl
152
153
print("Loss functions defined")


# ============================================
# 3. Training Loop
# ============================================
print("\n[3] Training Loop")
print("-" * 40)
162
163
def train_vae(model, dataloader, epochs=10, lr=1e-3, beta=1.0):
    """Train a VAE by minimizing the (beta-weighted) negative ELBO.

    Uses Adam and the module-level `device` for placement. Per-epoch
    averages are normalized per sample, then averaged over batches.

    Args:
        model: VAE instance returning (recon, mu, log_var) from forward().
        dataloader: Yields (images, labels); labels are ignored.
        epochs: Number of passes over the data.
        lr: Adam learning rate.
        beta: KL weight forwarded to `vae_loss`.

    Returns:
        dict with 'total', 'recon', and 'kl' loss lists, one value per epoch.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {'total': [], 'recon': [], 'kl': []}

    for epoch in range(epochs):
        model.train()
        sums = {'total': 0.0, 'recon': 0.0, 'kl': 0.0}

        for data, _ in dataloader:
            data = data.to(device)
            batch_size = data.size(0)

            optimizer.zero_grad()
            x_recon, mu, log_var = model(data)
            loss, recon, kl = vae_loss(data, x_recon, mu, log_var, beta)
            loss = loss / batch_size  # per-sample loss for stable step sizes
            loss.backward()
            optimizer.step()

            sums['total'] += loss.item()
            sums['recon'] += recon.item() / batch_size
            sums['kl'] += kl.item() / batch_size

        num_batches = len(dataloader)
        avg_loss = sums['total'] / num_batches
        avg_recon = sums['recon'] / num_batches
        avg_kl = sums['kl'] / num_batches

        history['total'].append(avg_loss)
        history['recon'].append(avg_recon)
        history['kl'].append(avg_kl)

        print(f"Epoch {epoch+1}/{epochs}: Loss={avg_loss:.4f}, Recon={avg_recon:.4f}, KL={avg_kl:.4f}")

    return history
207
208
# ============================================
# 4. Latent Space Visualization
# ============================================
print("\n[4] Latent Space Visualization")
print("-" * 40)
214
215
def visualize_latent_space(model, dataloader, device):
    """Scatter-plot the posterior means of the first two latent dimensions.

    Encodes every batch in `dataloader` (using mu only, no sampling),
    colors points by digit label, and saves the figure to
    vae_latent_space.png.
    """
    model.eval()
    all_mu = []
    all_labels = []

    with torch.no_grad():
        for images, targets in dataloader:
            mu, _ = model.encoder(images.to(device))
            all_mu.append(mu.cpu())
            all_labels.append(targets)

    points = torch.cat(all_mu, dim=0).numpy()
    digits = torch.cat(all_labels, dim=0).numpy()

    # Only the first two latent dimensions are shown.
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(points[:, 0], points[:, 1], c=digits, cmap='tab10', alpha=0.6, s=5)
    plt.colorbar(scatter, label='Digit')
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space (First 2 Dimensions)')
    plt.savefig('vae_latent_space.png', dpi=150)
    plt.close()
    print("Latent space visualization saved to vae_latent_space.png")
242
243
def generate_manifold(model, n=20, latent_dim=2, device='cpu', digit_size=28):
    """Generate a manifold of digits from a 2D latent space sweep.

    Sweeps the first two latent dimensions over [-3, 3] x [-3, 3]
    (remaining dimensions fixed at 0), decodes the grid, tiles the
    decoded digits into one image, and saves it to vae_manifold.png.

    Improvement over the original: all n*n grid points are decoded in a
    single batched forward pass instead of n*n separate decoder calls,
    which is substantially faster with identical output.

    Args:
        model: Trained VAE; only `model.decoder` is used.
        n: Grid resolution per axis.
        latent_dim: Latent dimensionality (must be >= 2).
        device: Device to run the decoder on.
        digit_size: Side length in pixels of each decoded digit.
    """
    model.eval()

    # Grid of latent coordinates covering [-3, 3] in each direction.
    grid_x = torch.linspace(-3, 3, n)
    grid_y = torch.linspace(-3, 3, n)

    figure = np.zeros((digit_size * n, digit_size * n))

    with torch.no_grad():
        # Build every grid point as one (n*n, latent_dim) batch; row i,
        # column j maps to flat index i * n + j.
        z = torch.zeros(n * n, latent_dim, device=device)
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z[i * n + j, 0] = xi
                z[i * n + j, 1] = yi

        decoded = model.decoder(z)

        # Tile the decoded digits into the output canvas.
        for i in range(n):
            for j in range(n):
                digit = decoded[i * n + j, 0].cpu().numpy()
                figure[i * digit_size:(i + 1) * digit_size,
                       j * digit_size:(j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap='gray')
    plt.axis('off')
    plt.title('VAE Manifold')
    plt.savefig('vae_manifold.png', dpi=150)
    plt.close()
    print("Manifold saved to vae_manifold.png")
274
275
def explore_latent_dimension(model, dim_idx, range_vals, fixed_z, device):
    """Decode variants of `fixed_z` while sweeping one latent dimension.

    Args:
        model: Trained VAE; only `model.decoder` is used.
        dim_idx: Index of the latent dimension to vary.
        range_vals: Iterable of values to assign to that dimension.
        fixed_z: Base latent code of shape (1, latent_dim); not modified.
        device: Device to run the decoder on.

    Returns:
        Decoded outputs concatenated along the batch dimension (on CPU).
    """
    model.eval()
    decoded = []

    with torch.no_grad():
        for value in range_vals:
            probe = fixed_z.clone()  # keep the caller's tensor untouched
            probe[0, dim_idx] = value
            decoded.append(model.decoder(probe.to(device)).cpu())

    return torch.cat(decoded, dim=0)
289
290
# ============================================
# 5. Reconstruction Visualization
# ============================================
print("\n[5] Reconstruction Visualization")
print("-" * 40)
296
297
def visualize_reconstructions(model, dataloader, num_samples=8, device='cpu'):
    """Save a grid of originals (top) next to their reconstructions (bottom).

    Takes the first `num_samples` images from the first batch of
    `dataloader` and writes the comparison to vae_reconstruction.png.
    """
    model.eval()

    # First batch only; labels are not needed.
    images, _ = next(iter(dataloader))
    images = images[:num_samples].to(device)

    with torch.no_grad():
        reconstructions, _, _ = model(images)

    # Originals occupy the first row of the grid, reconstructions the second.
    stacked = torch.cat([images, reconstructions], dim=0)
    grid = vutils.make_grid(stacked.cpu(), nrow=num_samples, normalize=True, padding=2)

    plt.figure(figsize=(12, 4))
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
    plt.title('Original (top) vs Reconstructed (bottom)')
    plt.savefig('vae_reconstruction.png', dpi=150)
    plt.close()
    print("Reconstruction comparison saved to vae_reconstruction.png")
320
321
# ============================================
# 6. Beta-VAE
# ============================================
print("\n[6] Beta-VAE for Disentanglement")
print("-" * 40)
327
328
class BetaVAE(VAE):
    """Beta-VAE: a VAE whose loss scales the KL term by `beta`.

    beta > 1 pressures the posterior toward the prior, which is intended
    to encourage a more disentangled latent code.
    """

    def __init__(self, in_channels=1, latent_dim=10, beta=4.0):
        super().__init__(in_channels, latent_dim)
        self.beta = beta  # KL weight applied in `loss`

    def loss(self, x, x_recon, mu, log_var):
        """Return (total, recon, kl) with the KL term weighted by self.beta."""
        return vae_loss(x, x_recon, mu, log_var, beta=self.beta)
337
338
# NOTE(review): f-string below has no placeholders; kept as-is so output text
# is unchanged.
print(f"Beta-VAE class defined with configurable beta parameter")


# ============================================
# 7. Conditional VAE (CVAE)
# ============================================
print("\n[7] Conditional VAE")
print("-" * 40)
347
348
class CVAEEncoder(nn.Module):
    """Conditional encoder: (image, one-hot label) -> (mu, log_var).

    The one-hot label is broadcast to the image resolution and appended
    as extra input channels before the convolutional stack.
    """

    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # Same backbone as VAEEncoder, widened for the label channels.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels + num_classes, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        feature_dim = 64 * 7 * 7
        self.fc_mu = nn.Linear(feature_dim, latent_dim)
        self.fc_logvar = nn.Linear(feature_dim, latent_dim)

    def forward(self, x, y):
        """Encode image `x` conditioned on one-hot labels `y` of shape (batch, num_classes)."""
        # Tile each one-hot label across the spatial dimensions of x.
        label_maps = y.view(-1, self.num_classes, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        conditioned = torch.cat([x, label_maps], dim=1)
        features = self.conv_layers(conditioned)
        return self.fc_mu(features), self.fc_logvar(features)
373
374
class CVAEDecoder(nn.Module):
    """Conditional decoder: (z, one-hot label) -> image with pixels in [0, 1]."""

    def __init__(self, latent_dim=20, out_channels=1, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # The label is concatenated to z before projection to a 64x7x7 map.
        self.fc = nn.Linear(latent_dim + num_classes, 64 * 7 * 7)

        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, out_channels, 4, stride=2, padding=1),
            nn.Sigmoid(),  # keeps outputs in (0, 1)
        )

    def forward(self, z, y):
        """Decode latent `z` conditioned on one-hot labels `y`."""
        feature_map = self.fc(torch.cat([z, y], dim=1)).view(-1, 64, 7, 7)
        return self.deconv_layers(feature_map)
395
396
class CVAE(nn.Module):
    """Conditional VAE: reconstruct and generate digits of a requested class."""

    def __init__(self, in_channels=1, latent_dim=20, num_classes=10):
        super().__init__()
        self.encoder = CVAEEncoder(in_channels, latent_dim, num_classes)
        self.decoder = CVAEDecoder(latent_dim, in_channels, num_classes)
        self.latent_dim = latent_dim
        self.num_classes = num_classes

    def reparameterize(self, mu, log_var):
        """Sample z = mu + sigma * eps, keeping gradients through mu/log_var."""
        sigma = torch.exp(0.5 * log_var)
        return mu + sigma * torch.randn_like(sigma)

    def forward(self, x, label):
        """Return (reconstruction, mu, log_var) for images `x` with integer labels."""
        one_hot = F.one_hot(label, self.num_classes).float()
        mu, log_var = self.encoder(x, one_hot)
        latent = self.reparameterize(mu, log_var)
        return self.decoder(latent, one_hot), mu, log_var

    def generate(self, label, num_samples, device):
        """Sample `num_samples` images of class `label` from the N(0, I) prior."""
        prior_z = torch.randn(num_samples, self.latent_dim, device=device)
        one_hot = F.one_hot(label.expand(num_samples), self.num_classes).float().to(device)
        return self.decoder(prior_z, one_hot)
423
424
print("CVAE class defined for conditional generation")


# ============================================
# 8. Training Example
# ============================================
print("\n[8] Training VAE on MNIST")
print("-" * 40)
433
# Hyperparameters
latent_dim = 20
batch_size = 128
epochs = 10
lr = 1e-3

# Data
# ToTensor scales pixel values to [0, 1], matching the BCE reconstruction loss.
transform = transforms.ToTensor()
print("Loading MNIST dataset...")
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('data', train=False, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0)

# Model (rebinds the `vae` name from the earlier smoke test)
vae = VAE(in_channels=1, latent_dim=latent_dim).to(device)
print(f"VAE parameters: {sum(p.numel() for p in vae.parameters()):,}")

# Train
print("\nTraining VAE...")
history = train_vae(vae, train_loader, epochs=epochs, lr=lr, beta=1.0)

# Visualizations
print("\nGenerating visualizations...")

# 1. Reconstruction
visualize_reconstructions(vae, test_loader, num_samples=10, device=device)

# 2. Generated samples: decode 64 draws from the prior and tile them 8x8.
vae.eval()
with torch.no_grad():
    samples = vae.generate(64, device)
    grid = vutils.make_grid(samples.cpu(), nrow=8, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
    plt.title('Generated Samples')
    plt.savefig('vae_samples.png', dpi=150)
    plt.close()
print("Generated samples saved to vae_samples.png")

# 3. Latent space (for 2D visualization, use latent_dim=2)
# A separate 2-dim model is trained because the 20-dim latent space
# cannot be scatter-plotted directly.
vae_2d = VAE(in_channels=1, latent_dim=2).to(device)
print("\nTraining 2D VAE for visualization...")
_ = train_vae(vae_2d, train_loader, epochs=5, lr=lr)
visualize_latent_space(vae_2d, test_loader, device)
generate_manifold(vae_2d, n=20, latent_dim=2, device=device)

# 4. Loss curves: total/reconstruction on the left, KL on the right.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(history['total'], label='Total')
plt.plot(history['recon'], label='Reconstruction')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Loss')

plt.subplot(1, 2, 2)
plt.plot(history['kl'], label='KL Divergence', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('KL Divergence')

plt.tight_layout()
plt.savefig('vae_loss.png', dpi=150)
plt.close()
print("Loss curves saved to vae_loss.png")
504
505
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("VAE Summary")
print("=" * 60)

# Printed recap of the concepts and artifacts produced by this script.
summary = """
Key Concepts:
1. VAE: Probabilistic latent variable model
2. Reparameterization: z = mu + sigma * epsilon
3. ELBO: Reconstruction + KL Divergence
4. Beta-VAE: beta > 1 for disentanglement
5. CVAE: Conditional generation

Loss Function:
L = E[log p(x|z)] - beta * KL(q(z|x) || p(z))
  = BCE(x, x_recon) + beta * (-0.5 * sum(1 + log(var) - mu^2 - var))

Latent Space Properties:
- Continuous and structured
- Can interpolate between samples
- Each dimension captures a factor of variation

Output Files:
- vae_samples.png: Generated samples
- vae_reconstruction.png: Original vs reconstructed
- vae_latent_space.png: 2D latent space visualization
- vae_manifold.png: Digit manifold
- vae_loss.png: Training loss curves
"""
print(summary)
print("=" * 60)