1"""
215. GAN and DCGAN Implementation
3
4GAN (Generative Adversarial Networks) and DCGAN (Deep Convolutional GAN)
5implementation for MNIST digit generation.
6"""
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11from torch.utils.data import DataLoader
12from torchvision import datasets, transforms
13import torchvision.utils as vutils
14import matplotlib.pyplot as plt
15import numpy as np
16import os
17
print("=" * 60)
print("GAN and DCGAN Implementation")
print("=" * 60)

# Select GPU when available; every model/tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# ============================================
# 1. Basic GAN (Fully Connected)
# ============================================
print("\n[1] Basic GAN Architecture")
print("-" * 40)
31
32
class Generator(nn.Module):
    """Fully connected generator.

    Maps a latent vector z of size `latent_dim` to a flattened image,
    then reshapes it to `img_shape`. Output values are in [-1, 1] (Tanh),
    matching data normalized with mean 0.5 / std 0.5.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        self.img_size = int(np.prod(img_shape))

        widths = [latent_dim, 128, 256, 512, 1024]
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if idx > 0:  # the first block is unnormalized (normalize=False)
                layers.append(nn.BatchNorm1d(n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Linear(widths[-1], self.img_size))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)
59
60
class Discriminator(nn.Module):
    """Fully connected discriminator.

    Flattens an image of `img_shape` and outputs one real/fake
    probability per sample, in (0, 1) via Sigmoid.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_size = int(np.prod(img_shape))

        dims = [self.img_size, 512, 256]
        layers = [nn.Flatten()]
        for n_in, n_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(n_in, n_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(dims[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
79
80
# Smoke test: run a small noise batch through G and D and report shapes.
G = Generator(latent_dim=100)
D = Discriminator()
noise = torch.randn(4, 100)
fake_imgs = G(noise)
validity = D(fake_imgs)
print(f"Generator output shape: {fake_imgs.shape}")
print(f"Discriminator output shape: {validity.shape}")
89
90
# ============================================
# 2. DCGAN Architecture
# ============================================
print("\n[2] DCGAN Architecture")
print("-" * 40)


class DCGANGenerator(nn.Module):
    """DCGAN generator built from ConvTranspose2d upsampling stages.

    Maps z (latent_dim,) to an image of shape (nc, 64, 64):
    (latent_dim, 1, 1) -> (ngf*8, 4, 4) -> (ngf*4, 8, 8)
    -> (ngf*2, 16, 16) -> (ngf, 32, 32) -> (nc, 64, 64)
    """

    def __init__(self, latent_dim=100, ngf=64, nc=1):
        super().__init__()
        self.latent_dim = latent_dim

        def up(c_in, c_out, kernel, stride, pad):
            # One upsampling stage: ConvTranspose2d -> BatchNorm -> ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        self.main = nn.Sequential(
            *up(latent_dim, ngf * 8, 4, 1, 0),  # -> (ngf*8, 4, 4)
            *up(ngf * 8, ngf * 4, 4, 2, 1),     # -> (ngf*4, 8, 8)
            *up(ngf * 4, ngf * 2, 4, 2, 1),     # -> (ngf*2, 16, 16)
            *up(ngf * 2, ngf, 4, 2, 1),         # -> (ngf, 32, 32)
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # -> (nc, 64, 64)
            nn.Tanh(),  # output in [-1, 1]
        )

        self._init_weights()

    def _init_weights(self):
        # DCGAN-paper init: N(0, 0.02) conv weights, N(1, 0.02) BN scale, zero BN bias.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, z):
        # Accept flat latent vectors; reshape to (N, latent_dim, 1, 1).
        return self.main(z.view(-1, self.latent_dim, 1, 1))
148
149
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator built from strided-convolution downsampling stages.

    Maps (nc, 64, 64) to a single probability per sample:
    (nc, 64, 64) -> (ndf, 32, 32) -> (ndf*2, 16, 16) -> (ndf*4, 8, 8)
    -> (ndf*8, 4, 4) -> (1,)
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()

        def down(c_in, c_out, batchnorm=True):
            # Stride-2 conv halves spatial size; LeakyReLU(0.2) per DCGAN paper.
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if batchnorm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        self.main = nn.Sequential(
            *down(nc, ndf, batchnorm=False),  # first stage has no BatchNorm
            *down(ndf, ndf * 2),
            *down(ndf * 2, ndf * 4),
            *down(ndf * 4, ndf * 8),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),  # (ndf*8,4,4) -> (1,1,1)
            nn.Sigmoid(),
        )

        self._init_weights()

    def _init_weights(self):
        # Same init scheme as the generator (conv N(0,0.02), BN N(1,0.02)).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.constant_(m.bias, 0)

    def forward(self, img):
        # Collapse the (N, 1, 1, 1) score map to (N, 1).
        return self.main(img).view(-1, 1)
197
198
# Smoke test: push a noise batch through the DCGAN pair and report shapes.
dc_G = DCGANGenerator(latent_dim=100, ngf=64, nc=1)
dc_D = DCGANDiscriminator(nc=1, ndf=64)
noise = torch.randn(4, 100)
fake_imgs = dc_G(noise)
validity = dc_D(fake_imgs)
print(f"DCGAN Generator output: {fake_imgs.shape}")
print(f"DCGAN Discriminator output: {validity.shape}")
207
208
# ============================================
# 3. Loss Functions
# ============================================
print("\n[3] GAN Loss Functions")
print("-" * 40)


def bce_loss(output, target):
    """Binary cross-entropy — the vanilla GAN objective for both D and G."""
    return F.binary_cross_entropy(output, target)
219
220
def wasserstein_loss(output, is_real):
    """Wasserstein (WGAN) critic loss term.

    The critic maximizes E[D(real)] - E[D(fake)]; as a minimized loss this
    is -mean(output) on real samples and +mean(output) on fake samples.
    The generator maximizes E[D(fake)].
    """
    sign = -1.0 if is_real else 1.0
    return sign * output.mean()
231
232
def hinge_loss(output, is_real):
    """Hinge loss for GAN discriminators.

    D: mean(max(0, 1 - D(real))) + mean(max(0, 1 + D(fake)))
    G: -E[D(fake)]
    """
    margin = 1.0 - output if is_real else 1.0 + output
    return margin.clamp(min=0.0).mean()
243
244
# ============================================
# 4. Gradient Penalty (WGAN-GP)
# ============================================
print("\n[4] Gradient Penalty (WGAN-GP)")
print("-" * 40)


def gradient_penalty(discriminator, real_imgs, fake_imgs, device):
    """WGAN-GP two-sided penalty: mean((||grad_xhat D(xhat)|| - 1)^2).

    xhat is a per-sample random convex combination of a real and a fake
    image; the penalty pushes the critic's gradient norm toward 1.
    """
    n = real_imgs.size(0)

    # One mixing coefficient per sample, broadcast across (C, H, W).
    mix = torch.rand(n, 1, 1, 1, device=device)
    x_hat = mix * real_imgs + (1 - mix) * fake_imgs
    x_hat.requires_grad_(True)

    d_out = discriminator(x_hat)

    # Differentiate D(xhat) w.r.t. xhat; keep the graph so the penalty
    # itself can be backpropagated through during training.
    grads, = torch.autograd.grad(
        outputs=d_out,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True,
        retain_graph=True,
    )

    per_sample_norm = grads.view(n, -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()


print("Gradient penalty function defined")
284
285
# ============================================
# 5. Training Loop
# ============================================
print("\n[5] Training Loop")
print("-" * 40)


def train_gan(generator, discriminator, dataloader, epochs=5, latent_dim=100, lr=0.0002):
    """Train a basic GAN with the non-saturating BCE objective.

    Alternates, per batch: one discriminator step on real images plus
    detached fakes, then one generator step that labels fakes as real.

    Args:
        generator: model mapping (N, latent_dim) noise to images.
        discriminator: model mapping images to a probability in (0, 1).
        dataloader: yields (images, labels); labels are ignored.
        epochs: number of full passes over the dataloader.
        latent_dim: size of the noise vector fed to the generator.
        lr: Adam learning rate used for both optimizers.

    Returns:
        (g_losses, d_losses): per-iteration loss histories.
    """
    generator.to(device)       # NOTE: uses the module-level `device`
    discriminator.to(device)

    criterion = nn.BCELoss()

    # beta1=0.5 per the DCGAN paper; stabilizes adversarial training.
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    g_losses = []
    d_losses = []

    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            batch_size = real_imgs.size(0)
            real_imgs = real_imgs.to(device)

            # Target labels: 1 for real, 0 for fake.
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            # Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Real images should be classified as real.
            real_output = discriminator(real_imgs)
            d_loss_real = criterion(real_output, real_labels)

            # Fake images are detached so D's step sends no gradient into G.
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)
            fake_output = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(fake_output, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Non-saturating loss: G maximizes log D(G(z)) via real labels.
            fake_output = discriminator(fake_imgs)
            g_loss = criterion(fake_output, real_labels)

            g_loss.backward()
            optimizer_G.step()

            g_losses.append(g_loss.item())
            d_losses.append(d_loss.item())

        print(f"Epoch [{epoch+1}/{epochs}] D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    return g_losses, d_losses
351
352
# ============================================
# 6. Sample Generation and Visualization
# ============================================
print("\n[6] Sample Generation")
print("-" * 40)


def generate_samples(generator, num_samples=64, latent_dim=100):
    """Sample `num_samples` images from the generator without tracking gradients."""
    # NOTE(review): relies on the module-level `device`; assumes the
    # generator already lives on that device.
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, device=device)
        imgs = generator(noise)
    return imgs
367
368
def save_samples(samples, filename='generated_samples.png', nrow=8):
    """Save a batch of generator outputs as an image grid.

    Args:
        samples: (N, C, H, W) tensor of images in [-1, 1].
        filename: output image path.
        nrow: images per grid row.
    """
    samples = (samples + 1) / 2  # rescale [-1, 1] -> [0, 1] for display
    grid = vutils.make_grid(samples.cpu(), nrow=nrow, normalize=False, padding=2)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    # BUG FIX: previously printed a literal placeholder instead of the path.
    print(f"Samples saved to {filename}")
379
380
def latent_interpolation(generator, z1, z2, steps=10):
    """Generate images along the straight line between latent vectors z1 and z2.

    Returns a (steps, C, H, W) tensor of generated frames.
    """
    # NOTE(review): moves each interpolated z to the module-level `device`.
    generator.eval()
    frames = []
    with torch.no_grad():
        for t in torch.linspace(0, 1, steps):
            z_mix = (1 - t) * z1 + t * z2
            frames.append(generator(z_mix.unsqueeze(0).to(device)))
    return torch.cat(frames, dim=0)
391
392
def spherical_interpolation(z1, z2, alpha):
    """Spherical linear interpolation (slerp) between latent vectors.

    The angle omega is measured between the normalized vectors, while the
    sin-based weights are applied to the original (unnormalized) ones.
    """
    u1 = z1 / z1.norm()
    u2 = z2 / z2.norm()
    omega = torch.acos((u1 * u2).sum())
    sin_omega = torch.sin(omega)
    w1 = torch.sin((1 - alpha) * omega) / sin_omega
    w2 = torch.sin(alpha * omega) / sin_omega
    return w1 * z1 + w2 * z2
400
401
# ============================================
# 7. Training Example
# ============================================
print("\n[7] Training Example (Basic GAN on MNIST)")
print("-" * 40)

# Hyperparameters
latent_dim = 100
batch_size = 64
epochs = 5  # Increase for better results

# Data: normalize MNIST pixels to [-1, 1] to match the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])

print("Loading MNIST dataset...")
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)

# Models (rebinds the G/D created in the earlier smoke test)
G = Generator(latent_dim=latent_dim, img_shape=(1, 28, 28)).to(device)
D = Discriminator(img_shape=(1, 28, 28)).to(device)

print(f"Generator parameters: {sum(p.numel() for p in G.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in D.parameters()):,}")

# Train
print("\nTraining...")
g_losses, d_losses = train_gan(G, D, train_loader, epochs=epochs, latent_dim=latent_dim)

# Generate samples from the trained generator and save them as a grid.
print("\nGenerating samples...")
samples = generate_samples(G, num_samples=64, latent_dim=latent_dim)
save_samples(samples, 'gan_mnist_samples.png')

# Plot per-iteration loss curves for both networks.
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label='Generator', alpha=0.7)
plt.plot(d_losses, label='Discriminator', alpha=0.7)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('GAN Training Loss')
plt.legend()
plt.savefig('gan_loss.png')
plt.close()
print("Loss plot saved to gan_loss.png")
450
451
# ============================================
# 8. Latent Space Exploration
# ============================================
print("\n[8] Latent Space Exploration")
print("-" * 40)

# Walk a straight line between two random latent vectors and save the
# generated frames as a single image row (nrow=10 puts all steps in one row).
z1 = torch.randn(latent_dim)
z2 = torch.randn(latent_dim)
interp_imgs = latent_interpolation(G, z1, z2, steps=10)
save_samples(interp_imgs, 'gan_interpolation.png', nrow=10)
463
464
# ============================================
# Summary
# ============================================
print("\n" + "=" * 60)
print("GAN/DCGAN Summary")
print("=" * 60)

# Printed reference card; keep in sync with the sections above.
summary = """
Key Concepts:
1. GAN: Generator vs Discriminator adversarial training
2. DCGAN: Convolutional architecture with BatchNorm, LeakyReLU
3. Loss: BCE (vanilla), Wasserstein, Hinge
4. WGAN-GP: Gradient penalty for stable training

Training Tips:
- Adam with beta1=0.5
- Learning rate: 0.0001 ~ 0.0002
- Label smoothing for stability
- Monitor D/G balance

Common Issues:
- Mode collapse: G produces limited variety
- Training instability: D too strong
- Vanishing gradients: Use WGAN/WGAN-GP

Output Files:
- gan_mnist_samples.png: Generated samples
- gan_loss.png: Training loss curves
- gan_interpolation.png: Latent space interpolation
"""
print(summary)
print("=" * 60)