# 15_gan_dcgan.py
"""
15. GAN and DCGAN Implementation

GAN (Generative Adversarial Networks) and DCGAN (Deep Convolutional GAN)
implementation for MNIST digit generation.
"""
  7
  8import torch
  9import torch.nn as nn
 10import torch.nn.functional as F
 11from torch.utils.data import DataLoader
 12from torchvision import datasets, transforms
 13import torchvision.utils as vutils
 14import matplotlib.pyplot as plt
 15import numpy as np
 16import os
 17
print("=" * 60)
print("GAN and DCGAN Implementation")
print("=" * 60)

# Use the GPU when available; every model/tensor below follows this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# ============================================
# 1. Basic GAN (Fully Connected)
# ============================================
print("\n[1] Basic GAN Architecture")
print("-" * 40)
 32
 33class Generator(nn.Module):
 34    """Basic Generator using fully connected layers"""
 35    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
 36        super().__init__()
 37        self.img_shape = img_shape
 38        self.img_size = int(np.prod(img_shape))
 39
 40        def block(in_feat, out_feat, normalize=True):
 41            layers = [nn.Linear(in_feat, out_feat)]
 42            if normalize:
 43                layers.append(nn.BatchNorm1d(out_feat))
 44            layers.append(nn.LeakyReLU(0.2, inplace=True))
 45            return layers
 46
 47        self.model = nn.Sequential(
 48            *block(latent_dim, 128, normalize=False),
 49            *block(128, 256),
 50            *block(256, 512),
 51            *block(512, 1024),
 52            nn.Linear(1024, self.img_size),
 53            nn.Tanh()
 54        )
 55
 56    def forward(self, z):
 57        img = self.model(z)
 58        return img.view(img.size(0), *self.img_shape)
 59
 60
 61class Discriminator(nn.Module):
 62    """Basic Discriminator using fully connected layers"""
 63    def __init__(self, img_shape=(1, 28, 28)):
 64        super().__init__()
 65        self.img_size = int(np.prod(img_shape))
 66
 67        self.model = nn.Sequential(
 68            nn.Flatten(),
 69            nn.Linear(self.img_size, 512),
 70            nn.LeakyReLU(0.2, inplace=True),
 71            nn.Linear(512, 256),
 72            nn.LeakyReLU(0.2, inplace=True),
 73            nn.Linear(256, 1),
 74            nn.Sigmoid()
 75        )
 76
 77    def forward(self, img):
 78        return self.model(img)
 79
 80
# Smoke test: push random noise through the MLP GAN and print the shapes.
G = Generator(latent_dim=100)
D = Discriminator()
z = torch.randn(4, 100)
fake_imgs = G(z)
validity = D(fake_imgs)
print(f"Generator output shape: {fake_imgs.shape}")
print(f"Discriminator output shape: {validity.shape}")


# ============================================
# 2. DCGAN Architecture
# ============================================
print("\n[2] DCGAN Architecture")
print("-" * 40)
 96
 97
 98class DCGANGenerator(nn.Module):
 99    """DCGAN Generator with transposed convolutions
100
101    Architecture for 64x64 output:
102    z (latent_dim,) -> (ngf*8, 4, 4) -> (ngf*4, 8, 8) -> (ngf*2, 16, 16)
103    -> (ngf, 32, 32) -> (nc, 64, 64)
104    """
105    def __init__(self, latent_dim=100, ngf=64, nc=1):
106        super().__init__()
107        self.latent_dim = latent_dim
108
109        self.main = nn.Sequential(
110            # Input: z (latent_dim,) -> (ngf*8, 4, 4)
111            nn.ConvTranspose2d(latent_dim, ngf * 8, 4, 1, 0, bias=False),
112            nn.BatchNorm2d(ngf * 8),
113            nn.ReLU(True),
114
115            # (ngf*8, 4, 4) -> (ngf*4, 8, 8)
116            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
117            nn.BatchNorm2d(ngf * 4),
118            nn.ReLU(True),
119
120            # (ngf*4, 8, 8) -> (ngf*2, 16, 16)
121            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
122            nn.BatchNorm2d(ngf * 2),
123            nn.ReLU(True),
124
125            # (ngf*2, 16, 16) -> (ngf, 32, 32)
126            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
127            nn.BatchNorm2d(ngf),
128            nn.ReLU(True),
129
130            # (ngf, 32, 32) -> (nc, 64, 64)
131            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
132            nn.Tanh()
133        )
134
135        self._init_weights()
136
137    def _init_weights(self):
138        for m in self.modules():
139            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
140                nn.init.normal_(m.weight, 0.0, 0.02)
141            elif isinstance(m, nn.BatchNorm2d):
142                nn.init.normal_(m.weight, 1.0, 0.02)
143                nn.init.constant_(m.bias, 0)
144
145    def forward(self, z):
146        z = z.view(-1, self.latent_dim, 1, 1)
147        return self.main(z)
148
149
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator: 64x64 image -> scalar realness score.

    Downsampling path via strided 4x4 convs: 64 -> 32 -> 16 -> 8 -> 4 -> 1,
    channels nc -> ndf -> ndf*2 -> ndf*4 -> ndf*8 -> 1. The first block has
    no BatchNorm, per the DCGAN recipe; output passes through Sigmoid.
    """

    def __init__(self, nc=1, ndf=64):
        super().__init__()

        def down_block(c_in, c_out):
            # strided conv halves the spatial size; BatchNorm + LeakyReLU follow
            return [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),  # 64 -> 32 (no BatchNorm here)
            nn.LeakyReLU(0.2, inplace=True),
            *down_block(ndf, ndf * 2),       # 32 -> 16
            *down_block(ndf * 2, ndf * 4),   # 16 -> 8
            *down_block(ndf * 4, ndf * 8),   # 8 -> 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),  # 4 -> 1
            nn.Sigmoid(),
        )

        self._init_weights()

    def _init_weights(self):
        """DCGAN paper init for conv and batch-norm parameters."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, 0.0, 0.02)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, 1.0, 0.02)
                nn.init.constant_(module.bias, 0)

    def forward(self, img):
        """Return (B, 1) probabilities for a batch of (B, nc, 64, 64) images."""
        return self.main(img).view(-1, 1)
197
198
# Smoke test: DCGAN consumes/produces 64x64 images; print the shapes.
dc_G = DCGANGenerator(latent_dim=100, ngf=64, nc=1)
dc_D = DCGANDiscriminator(nc=1, ndf=64)
z = torch.randn(4, 100)
fake_imgs = dc_G(z)
validity = dc_D(fake_imgs)
print(f"DCGAN Generator output: {fake_imgs.shape}")
print(f"DCGAN Discriminator output: {validity.shape}")


# ============================================
# 3. Loss Functions
# ============================================
print("\n[3] GAN Loss Functions")
print("-" * 40)
214
215
def bce_loss(output, target):
    """Binary cross-entropy loss (vanilla GAN), mean-reduced over the batch."""
    criterion = nn.BCELoss()
    return criterion(output, target)
219
220
def wasserstein_loss(output, is_real):
    """Wasserstein (WGAN) critic loss term.

    The critic maximizes E[D(real)] - E[D(fake)]; written as a quantity
    to minimize, that is -mean(output) for real samples and
    +mean(output) for fake ones.
    """
    sign = -1.0 if is_real else 1.0
    return sign * output.mean()
231
232
def hinge_loss(output, is_real):
    """Hinge loss for the discriminator.

    Real samples contribute mean(max(0, 1 - D(x))), fake samples
    mean(max(0, 1 + D(x))); the generator side uses -E[D(fake)].
    """
    margin = 1.0 - output if is_real else 1.0 + output
    return margin.clamp(min=0).mean()
243
244
# ============================================
# 4. Gradient Penalty (WGAN-GP)
# ============================================
# Enforces the 1-Lipschitz constraint of WGAN via a soft penalty.
print("\n[4] Gradient Penalty (WGAN-GP)")
print("-" * 40)
250
251
def gradient_penalty(discriminator, real_imgs, fake_imgs, device):
    """Compute the WGAN-GP gradient penalty.

    Samples random points on straight lines between real and fake images
    and penalizes the critic's gradient norm there for deviating from 1.

    Args:
        discriminator: callable mapping an image batch to per-sample scores.
        real_imgs: real image batch (B, C, H, W).
        fake_imgs: generated batch, same shape as `real_imgs`.
        device: torch device used for the interpolation coefficients.

    Returns:
        Scalar tensor: mean of (||grad D(x_hat)|| - 1)^2 over the batch.
    """
    batch_size = real_imgs.size(0)

    # One interpolation coefficient per sample, broadcast over C/H/W.
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)

    # Detach both endpoints: `fake_imgs` usually still carries the
    # generator's graph, and requires_grad_() is only valid on leaf
    # tensors (the original crashed in that case). The penalty should
    # only flow through the critic anyway.
    interpolated = alpha * real_imgs.detach() + (1 - alpha) * fake_imgs.detach()
    interpolated.requires_grad_(True)

    # Critic score at the interpolated points.
    d_interpolated = discriminator(interpolated)

    # Per-sample gradients of the critic output w.r.t. the interpolated input.
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,   # keep the graph so the penalty itself is differentiable
        retain_graph=True
    )[0]

    # L2 norm of each sample's flattened gradient.
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Penalty: (||grad|| - 1)^2, averaged over the batch.
    return ((gradient_norm - 1) ** 2).mean()
281
282
print("Gradient penalty function defined")


# ============================================
# 5. Training Loop
# ============================================
print("\n[5] Training Loop")
print("-" * 40)
291
292
def train_gan(generator, discriminator, dataloader, epochs=5, latent_dim=100, lr=0.0002):
    """Train a GAN with the standard non-saturating BCE objective.

    Alternates one discriminator step (real vs. detached fakes) and one
    generator step (fool the discriminator) per batch. Uses the
    module-level `device`.

    Args:
        generator: model mapping (B, latent_dim) noise to images.
        discriminator: model mapping images to (B, 1) real-probabilities.
        dataloader: yields (images, labels) pairs; labels are ignored.
        epochs: number of passes over `dataloader`.
        latent_dim: size of the noise vector fed to the generator.
        lr: Adam learning rate for both optimizers.

    Returns:
        (g_losses, d_losses): per-iteration loss histories (floats).
    """
    generator.to(device)
    discriminator.to(device)

    criterion = nn.BCELoss()

    # beta1=0.5 (instead of the default 0.9) is the DCGAN-recommended setting.
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

    g_losses = []
    d_losses = []

    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            batch_size = real_imgs.size(0)
            real_imgs = real_imgs.to(device)

            # Target labels: 1 for real, 0 for fake.
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---------------------
            # Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Real images should be classified as real.
            real_output = discriminator(real_imgs)
            d_loss_real = criterion(real_output, real_labels)

            # Fakes are detached so this step does not update the generator.
            z = torch.randn(batch_size, latent_dim, device=device)
            fake_imgs = generator(z)
            fake_output = discriminator(fake_imgs.detach())
            d_loss_fake = criterion(fake_output, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Non-saturating loss: reuse fake_imgs (NOT detached) with "real" targets.
            fake_output = discriminator(fake_imgs)
            g_loss = criterion(fake_output, real_labels)

            g_loss.backward()
            optimizer_G.step()

            g_losses.append(g_loss.item())
            d_losses.append(d_loss.item())

        # Report the last batch's losses once per epoch.
        print(f"Epoch [{epoch+1}/{epochs}] D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    return g_losses, d_losses
351
352
# ============================================
# 6. Sample Generation and Visualization
# ============================================
# Helpers for sampling a trained generator and exploring its latent space.
print("\n[6] Sample Generation")
print("-" * 40)
358
359
def generate_samples(generator, num_samples=64, latent_dim=100):
    """Draw `num_samples` images from the generator in eval mode.

    Runs under no_grad and draws latent noise on the module-level `device`.
    """
    generator.eval()
    with torch.no_grad():
        noise = torch.randn(num_samples, latent_dim, device=device)
        return generator(noise)
367
368
def save_samples(samples, filename='generated_samples.png', nrow=8):
    """Save a batch of generated images as an image-grid PNG.

    Args:
        samples: image tensor in [-1, 1] (generator Tanh output).
        filename: output path for the PNG.
        nrow: number of images per grid row.
    """
    samples = (samples + 1) / 2  # map Tanh range [-1, 1] -> [0, 1] for display
    grid = vutils.make_grid(samples.cpu(), nrow=nrow, normalize=False, padding=2)
    plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).squeeze(), cmap='gray')
    plt.axis('off')
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close()
    # Fixed: the message previously did not interpolate the actual filename.
    print(f"Samples saved to {filename}")
379
380
def latent_interpolation(generator, z1, z2, steps=10):
    """Decode `steps` points on the straight line between latent codes z1 and z2.

    Returns the generated images stacked along the batch dimension.
    Uses the module-level `device`.
    """
    generator.eval()
    with torch.no_grad():
        frames = [
            generator(((1 - a) * z1 + a * z2).unsqueeze(0).to(device))
            for a in torch.linspace(0, 1, steps)
        ]
    return torch.cat(frames, dim=0)
391
392
def spherical_interpolation(z1, z2, alpha):
    """Spherical linear interpolation (slerp) between latent vectors.

    Interpolates along the great-circle arc between z1 and z2, which keeps
    interpolants at a representative norm for Gaussian latents.

    Args:
        z1, z2: latent vectors of the same shape.
        alpha: interpolation weight in [0, 1] (0 -> z1, 1 -> z2).

    Returns:
        The interpolated latent vector.
    """
    z1_norm = z1 / z1.norm()
    z2_norm = z2 / z2.norm()
    # Clamp so acos cannot NaN on values marginally outside [-1, 1].
    cos_omega = (z1_norm * z2_norm).sum().clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)
    # Nearly (anti)parallel inputs: sin(omega) ~ 0 previously caused a
    # divide-by-zero/NaN; fall back to plain lerp, which slerp approaches.
    if so.abs() < 1e-7:
        return (1 - alpha) * z1 + alpha * z2
    return (torch.sin((1 - alpha) * omega) / so) * z1 + (torch.sin(alpha * omega) / so) * z2
400
401
# ============================================
# 7. Training Example
# ============================================
print("\n[7] Training Example (Basic GAN on MNIST)")
print("-" * 40)

# Hyperparameters
latent_dim = 100
batch_size = 64
epochs = 5  # Increase for better results

# Normalize to [-1, 1] so the data range matches the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])

print("Loading MNIST dataset...")
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)

# Models (re-created here so this section stands on its own).
G = Generator(latent_dim=latent_dim, img_shape=(1, 28, 28)).to(device)
D = Discriminator(img_shape=(1, 28, 28)).to(device)

print(f"Generator parameters: {sum(p.numel() for p in G.parameters()):,}")
print(f"Discriminator parameters: {sum(p.numel() for p in D.parameters()):,}")

# Train
print("\nTraining...")
g_losses, d_losses = train_gan(G, D, train_loader, epochs=epochs, latent_dim=latent_dim)

# Generate samples from the trained generator and save them as a grid.
print("\nGenerating samples...")
samples = generate_samples(G, num_samples=64, latent_dim=latent_dim)
save_samples(samples, 'gan_mnist_samples.png')

# Plot per-iteration loss curves for both networks.
plt.figure(figsize=(10, 5))
plt.plot(g_losses, label='Generator', alpha=0.7)
plt.plot(d_losses, label='Discriminator', alpha=0.7)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('GAN Training Loss')
plt.legend()
plt.savefig('gan_loss.png')
plt.close()
print("Loss plot saved to gan_loss.png")
450
451
# ============================================
# 8. Latent Space Exploration
# ============================================
print("\n[8] Latent Space Exploration")
print("-" * 40)

# Linear interpolation between two random latent codes; each step is
# decoded by the trained generator and the row is saved as one grid image.
z1 = torch.randn(latent_dim)
z2 = torch.randn(latent_dim)
interp_imgs = latent_interpolation(G, z1, z2, steps=10)
save_samples(interp_imgs, 'gan_interpolation.png', nrow=10)
463
464
# ============================================
# Summary
# ============================================
# Final recap printed to stdout. Note: the triple-quoted block below is a
# runtime string passed to print(), not a docstring.
print("\n" + "=" * 60)
print("GAN/DCGAN Summary")
print("=" * 60)

summary = """
Key Concepts:
1. GAN: Generator vs Discriminator adversarial training
2. DCGAN: Convolutional architecture with BatchNorm, LeakyReLU
3. Loss: BCE (vanilla), Wasserstein, Hinge
4. WGAN-GP: Gradient penalty for stable training

Training Tips:
- Adam with beta1=0.5
- Learning rate: 0.0001 ~ 0.0002
- Label smoothing for stability
- Monitor D/G balance

Common Issues:
- Mode collapse: G produces limited variety
- Training instability: D too strong
- Vanishing gradients: Use WGAN/WGAN-GP

Output Files:
- gan_mnist_samples.png: Generated samples
- gan_loss.png: Training loss curves
- gan_interpolation.png: Latent space interpolation
"""
print(summary)
print("=" * 60)