# 14_dcgan.py
  1"""
  2DCGAN (Deep Convolutional GAN) Implementation
  3
  4This script implements DCGAN following "Unsupervised Representation Learning
  5with Deep Convolutional Generative Adversarial Networks" (Radford et al., 2015).
  6
  7Key architecture guidelines:
  8- Replace pooling with strided convolutions (discriminator) and transposed convolutions (generator)
  9- Use batch normalization in both generator and discriminator
 10- Remove fully connected hidden layers
 11- Use ReLU in generator (except output: Tanh), LeakyReLU in discriminator
 12- Proper weight initialization
 13
 14References:
 15- Radford et al. (2015): https://arxiv.org/abs/1511.06434
 16"""
 17
 18import torch
 19import torch.nn as nn
 20import torch.optim as optim
 21from torch.utils.data import DataLoader
 22from torchvision import datasets, transforms
 23from torchvision.utils import make_grid
 24import matplotlib.pyplot as plt
 25import numpy as np
 26from tqdm import tqdm
 27
 28
 29# ============================================================================
 30# Weight Initialization
 31# ============================================================================
 32
 33def weights_init(m):
 34    """
 35    Custom weight initialization as described in DCGAN paper.
 36
 37    - Conv/ConvTranspose layers: mean=0, std=0.02
 38    - BatchNorm layers: mean=1, std=0.02
 39    """
 40    classname = m.__class__.__name__
 41    if classname.find('Conv') != -1:
 42        nn.init.normal_(m.weight.data, 0.0, 0.02)
 43    elif classname.find('BatchNorm') != -1:
 44        nn.init.normal_(m.weight.data, 1.0, 0.02)
 45        nn.init.constant_(m.bias.data, 0)
 46
 47
 48# ============================================================================
 49# Generator
 50# ============================================================================
 51
 52class Generator(nn.Module):
 53    """
 54    DCGAN Generator: transforms latent vector z to image.
 55
 56    Architecture:
 57    - Input: [batch_size, nz, 1, 1] latent vector
 58    - 4 transposed convolution blocks with BatchNorm and ReLU
 59    - Output: [batch_size, nc, 64, 64] image with Tanh activation
 60
 61    Args:
 62        nz: size of latent vector (input noise dimension)
 63        ngf: number of generator filters in first layer
 64        nc: number of output channels (1 for grayscale, 3 for RGB)
 65    """
 66    def __init__(self, nz=100, ngf=64, nc=1):
 67        super(Generator, self).__init__()
 68
 69        self.main = nn.Sequential(
 70            # Input: [batch, nz, 1, 1]
 71            # Output: [batch, ngf*8, 4, 4]
 72            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
 73            nn.BatchNorm2d(ngf * 8),
 74            nn.ReLU(True),
 75
 76            # [batch, ngf*8, 4, 4] -> [batch, ngf*4, 8, 8]
 77            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
 78            nn.BatchNorm2d(ngf * 4),
 79            nn.ReLU(True),
 80
 81            # [batch, ngf*4, 8, 8] -> [batch, ngf*2, 16, 16]
 82            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
 83            nn.BatchNorm2d(ngf * 2),
 84            nn.ReLU(True),
 85
 86            # [batch, ngf*2, 16, 16] -> [batch, ngf, 32, 32]
 87            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
 88            nn.BatchNorm2d(ngf),
 89            nn.ReLU(True),
 90
 91            # [batch, ngf, 32, 32] -> [batch, nc, 64, 64]
 92            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
 93            nn.Tanh()  # Output in [-1, 1]
 94        )
 95
 96    def forward(self, z):
 97        """
 98        Args:
 99            z: [batch_size, nz, 1, 1] latent vectors
100        Returns:
101            [batch_size, nc, 64, 64] generated images
102        """
103        return self.main(z)
104
105
106# ============================================================================
107# Discriminator
108# ============================================================================
109
class Discriminator(nn.Module):
    """
    DCGAN Discriminator: classifies 64x64 images as real or fake.

    A stack of stride-2 convolutions halves the resolution at each stage:
    64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4 -> 1x1. Per the DCGAN guidelines
    the first block has no BatchNorm; all hidden blocks use LeakyReLU(0.2)
    and the head applies Sigmoid to yield a probability.

    Args:
        nc: number of input channels (1 for grayscale, 3 for RGB)
        ndf: number of discriminator filters in the first layer
    """
    def __init__(self, nc=1, ndf=64):
        super(Discriminator, self).__init__()

        # Channel widths of the four hidden stages, narrowest first.
        widths = [ndf, ndf * 2, ndf * 4, ndf * 8]

        # Stem: [batch, nc, 64, 64] -> [batch, ndf, 32, 32] (no BatchNorm here)
        layers = [
            nn.Conv2d(nc, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Three downsampling blocks, each halving H and W and doubling channels.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Head: [batch, ndf*8, 4, 4] -> [batch, 1, 1, 1] probability
        layers += [
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: [batch_size, nc, 64, 64] images
        Returns:
            [batch_size, 1, 1, 1] probability of being real
        """
        return self.main(x)
160
161
162# ============================================================================
163# Training
164# ============================================================================
165
def train_dcgan(epochs=25, batch_size=128, nz=100, lr=0.0002, beta1=0.5, device='cuda'):
    """
    Train DCGAN on MNIST dataset.

    Downloads MNIST into ./data, trains a Generator/Discriminator pair with
    the standard non-saturating GAN objective (BCE), and writes PNG artifacts
    to the working directory: sample grids on epoch 1 and every 5th epoch,
    plus a final loss curve ('dcgan_losses.png').

    Args:
        epochs: number of training epochs
        batch_size: batch size
        nz: size of latent vector
        lr: learning rate (0.0002, the value recommended in the DCGAN paper)
        beta1: beta1 parameter for Adam optimizer (0.5 per the paper)
        device: 'cuda' or 'cpu'; silently falls back to CPU when CUDA is
            unavailable

    Returns:
        generator: trained Generator model
        discriminator: trained Discriminator model
        losses: dict with generator and discriminator losses (sampled every
            50 iterations, not every step)
    """
    # Fall back to CPU regardless of the requested device string.
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # Data preparation
    # Resize MNIST to 64x64 (the fixed resolution of the networks) and
    # normalize to [-1, 1] to match the generator's Tanh output range.
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Initialize models
    netG = Generator(nz=nz, ngf=64, nc=1).to(device)
    netD = Discriminator(nc=1, ndf=64).to(device)

    # Apply DCGAN weight initialization (N(0, 0.02) conv weights, etc.)
    netG.apply(weights_init)
    netD.apply(weights_init)

    print("Generator architecture:")
    print(netG)
    print("\nDiscriminator architecture:")
    print(netD)

    # Loss function and optimizers. BCELoss pairs with the Sigmoid output of
    # the discriminator; separate optimizers let D and G update independently.
    criterion = nn.BCELoss()

    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Fixed noise reused for every visualization so sample grids across
    # epochs show the evolution of the same latent points.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Target values for real and fake images
    real_label = 1.0
    fake_label = 0.0

    # Lists to track losses (one entry every 50 iterations)
    G_losses = []
    D_losses = []

    # Training loop
    print("\nStarting training...")
    for epoch in range(epochs):
        for i, (real_images, _) in enumerate(tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}')):
            # Last batch of an epoch may be smaller than batch_size.
            batch_size_actual = real_images.size(0)
            real_images = real_images.to(device)

            # ================================================================
            # (1) Update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
            # ================================================================
            netD.zero_grad()

            # Train with real images (targets = real_label)
            label = torch.full((batch_size_actual,), real_label, dtype=torch.float, device=device)
            output = netD(real_images).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            # D_x: mean D output on real images (ideally drifts toward ~0.5)
            D_x = output.mean().item()

            # Train with fake images. The same `label` tensor is reused and
            # overwritten in place with the fake target.
            noise = torch.randn(batch_size_actual, nz, 1, 1, device=device)
            fake_images = netG(noise)
            label.fill_(fake_label)
            # detach() stops gradients from flowing into the generator while
            # updating the discriminator.
            output = netD(fake_images.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()

            # Total discriminator loss — for logging only; the gradients were
            # already accumulated by the two backward() calls above.
            errD = errD_real + errD_fake
            optimizerD.step()

            # ================================================================
            # (2) Update Generator: maximize log(D(G(z)))
            # ================================================================
            netG.zero_grad()
            label.fill_(real_label)  # Fake images should be classified as real
            # Re-score the (non-detached) fakes with the just-updated D so
            # gradients reach the generator.
            output = netD(fake_images).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Save losses every 50 iterations to keep the curves light
            if i % 50 == 0:
                G_losses.append(errG.item())
                D_losses.append(errD.item())

        # Print statistics from the last batch of the epoch
        print(f'[{epoch+1}/{epochs}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
              f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

        # Generate and save images after epoch 1 and every 5 epochs
        if (epoch + 1) % 5 == 0 or epoch == 0:
            with torch.no_grad():
                fake_samples = netG(fixed_noise).detach().cpu()

            # Create image grid; normalize=True rescales [-1, 1] -> [0, 1]
            grid = make_grid(fake_samples, nrow=8, normalize=True)

            # Visualize. np.transpose converts CHW -> HWC for imshow; the CPU
            # tensor is implicitly converted to a numpy array here.
            plt.figure(figsize=(10, 10))
            plt.imshow(np.transpose(grid, (1, 2, 0)))
            plt.title(f'Generated Images - Epoch {epoch+1}')
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(f'dcgan_samples_epoch_{epoch+1}.png')
            plt.close()

    print("\nTraining completed!")

    # Plot losses
    plt.figure(figsize=(10, 5))
    plt.plot(G_losses, label='Generator Loss')
    plt.plot(D_losses, label='Discriminator Loss')
    plt.xlabel('Iterations (x50)')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('DCGAN Training Losses')
    plt.savefig('dcgan_losses.png')
    plt.close()

    return netG, netD, {'G_losses': G_losses, 'D_losses': D_losses}
309
310
311# ============================================================================
312# Image Generation
313# ============================================================================
314
@torch.no_grad()
def generate_images(generator, num_images=64, nz=100, device='cuda'):
    """
    Sample images from a trained generator.

    Puts the generator into eval mode, draws ``num_images`` latent vectors
    from a standard normal, and returns the generated batch on the CPU.
    Gradient tracking is disabled via the decorator.

    Args:
        generator: trained Generator model
        num_images: number of images to generate
        nz: size of latent vector
        device: 'cuda' or 'cpu'; falls back to CPU when CUDA is unavailable

    Returns:
        [num_images, nc, 64, 64] generated images (CPU tensor)
    """
    run_device = torch.device(device if torch.cuda.is_available() else 'cpu')
    generator.eval()

    # Draw latent codes and push them through the generator in one batch.
    latent = torch.randn(num_images, nz, 1, 1, device=run_device)
    return generator(latent).cpu()
339
340
341# ============================================================================
342# Main
343# ============================================================================
344
if __name__ == '__main__':
    # Train DCGAN on MNIST with the hyperparameters from the paper.
    netG, netD, losses = train_dcgan(
        epochs=25,
        batch_size=128,
        nz=100,
        lr=0.0002,
        beta1=0.5,
        device='cuda',
    )

    # Draw one final batch of samples from the trained generator.
    print("\nGenerating final samples...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    final_samples = generate_images(netG, num_images=64, nz=100, device=device)

    # Arrange the samples into an 8x8 grid and render it.
    sample_grid = make_grid(final_samples, nrow=8, normalize=True)

    fig = plt.figure(figsize=(12, 12))
    plt.imshow(np.transpose(sample_grid, (1, 2, 0)))
    plt.title('DCGAN Generated Samples (Final)')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('dcgan_final_samples.png')
    plt.show()

    # Optional: persist both networks' weights for later reuse.
    torch.save(netG.state_dict(), 'dcgan_generator.pth')
    torch.save(netD.state_dict(), 'dcgan_discriminator.pth')
    print("\nModels saved to dcgan_generator.pth and dcgan_discriminator.pth")