1"""
2DCGAN (Deep Convolutional GAN) Implementation
3
4This script implements DCGAN following "Unsupervised Representation Learning
5with Deep Convolutional Generative Adversarial Networks" (Radford et al., 2015).
6
7Key architecture guidelines:
8- Replace pooling with strided convolutions (discriminator) and transposed convolutions (generator)
9- Use batch normalization in both generator and discriminator
10- Remove fully connected hidden layers
11- Use ReLU in generator (except output: Tanh), LeakyReLU in discriminator
12- Proper weight initialization
13
14References:
15- Radford et al. (2015): https://arxiv.org/abs/1511.06434
16"""
17
18import torch
19import torch.nn as nn
20import torch.optim as optim
21from torch.utils.data import DataLoader
22from torchvision import datasets, transforms
23from torchvision.utils import make_grid
24import matplotlib.pyplot as plt
25import numpy as np
26from tqdm import tqdm
27
28
29# ============================================================================
30# Weight Initialization
31# ============================================================================
32
def weights_init(m):
    """
    Initialize a module's weights per the DCGAN paper (Radford et al., 2015).

    Conv / ConvTranspose layers get N(0, 0.02) weights; BatchNorm layers get
    N(1, 0.02) weights and zero bias. Any other module is left untouched.
    Intended to be applied recursively via ``model.apply(weights_init)``.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    elif 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)
46
47
48# ============================================================================
49# Generator
50# ============================================================================
51
class Generator(nn.Module):
    """
    DCGAN Generator: maps a latent vector z to a 64x64 image.

    The network is four ConvTranspose-BatchNorm-ReLU upsampling stages
    (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32) followed by a final transposed
    convolution to 64x64 with a Tanh activation, so outputs lie in [-1, 1].

    Args:
        nz: size of latent vector (input noise dimension)
        ngf: number of generator filters in first layer
        nc: number of output channels (1 for grayscale, 3 for RGB)
    """
    def __init__(self, nz=100, ngf=64, nc=1):
        super().__init__()

        # (in_channels, out_channels, kernel, stride, padding) per stage.
        # The first stage projects the 1x1 latent to a 4x4 feature map;
        # every later stage doubles the spatial size and halves the width.
        stage_specs = [
            (nz, ngf * 8, 4, 1, 0),
            (ngf * 8, ngf * 4, 4, 2, 1),
            (ngf * 4, ngf * 2, 4, 2, 1),
            (ngf * 2, ngf, 4, 2, 1),
        ]

        layers = []
        for in_ch, out_ch, k, s, p in stage_specs:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, k, s, p, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Output stage: [batch, ngf, 32, 32] -> [batch, nc, 64, 64]
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())  # squash outputs into [-1, 1]

        self.main = nn.Sequential(*layers)

    def forward(self, z):
        """
        Args:
            z: [batch_size, nz, 1, 1] latent vectors
        Returns:
            [batch_size, nc, 64, 64] generated images
        """
        return self.main(z)
104
105
106# ============================================================================
107# Discriminator
108# ============================================================================
109
class Discriminator(nn.Module):
    """
    DCGAN Discriminator: scores images as real (1) or fake (0).

    Four strided Conv blocks halve the spatial size at each step
    (64 -> 32 -> 16 -> 8 -> 4) while doubling the channel width; a final
    4x4 convolution collapses to a single Sigmoid probability. Per the
    DCGAN guidelines, the first block has no BatchNorm and every block
    uses LeakyReLU(0.2).

    Args:
        nc: number of input channels (1 for grayscale, 3 for RGB)
        ndf: number of discriminator filters in first layer
    """
    def __init__(self, nc=1, ndf=64):
        super().__init__()

        # First block: [batch, nc, 64, 64] -> [batch, ndf, 32, 32]
        blocks = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Middle blocks double the channel count while halving H and W:
        # ndf -> ndf*2 -> ndf*4 -> ndf*8, ending at a 4x4 feature map.
        for mult in (1, 2, 4):
            blocks += [
                nn.Conv2d(ndf * mult, ndf * mult * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Head: [batch, ndf*8, 4, 4] -> [batch, 1, 1, 1] probability
        blocks += [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*blocks)

    def forward(self, x):
        """
        Args:
            x: [batch_size, nc, 64, 64] images
        Returns:
            [batch_size, 1, 1, 1] probability of being real
        """
        return self.main(x)
160
161
162# ============================================================================
163# Training
164# ============================================================================
165
def train_dcgan(epochs=25, batch_size=128, nz=100, lr=0.0002, beta1=0.5, device='cuda'):
    """
    Train DCGAN on MNIST dataset.

    Standard alternating GAN scheme: for each batch, the discriminator is
    updated on real and fake images first, then the generator is updated to
    fool the freshly-updated discriminator.

    Side effects:
        - Downloads MNIST into ./data if not already present.
        - Writes sample grids (dcgan_samples_epoch_*.png) and a loss plot
          (dcgan_losses.png) into the working directory.

    Args:
        epochs: number of training epochs
        batch_size: batch size
        nz: size of latent vector
        lr: learning rate (0.0002, as recommended by the DCGAN paper)
        beta1: beta1 parameter for Adam optimizer (0.5 per the paper)
        device: 'cuda' or 'cpu'; falls back to CPU when CUDA is unavailable

    Returns:
        generator: trained Generator model
        discriminator: trained Discriminator model
        losses: dict with generator and discriminator losses
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # Data preparation
    # Resize MNIST (28x28) to 64x64 to match the network architecture, and
    # normalize to [-1, 1] to match the generator's Tanh output range.
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    # Initialize models
    netG = Generator(nz=nz, ngf=64, nc=1).to(device)
    netD = Discriminator(nc=1, ndf=64).to(device)

    # Apply the DCGAN weight initialization to every submodule.
    netG.apply(weights_init)
    netD.apply(weights_init)

    print("Generator architecture:")
    print(netG)
    print("\nDiscriminator architecture:")
    print(netD)

    # Loss function and optimizers.
    # BCELoss pairs with the Sigmoid output of the discriminator.
    criterion = nn.BCELoss()

    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

    # Fixed noise, reused for every saved grid so the same latent points can
    # be watched evolving over the course of training.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Target labels for real and fake images
    real_label = 1.0
    fake_label = 0.0

    # Lists to track losses (sampled every 50 iterations)
    G_losses = []
    D_losses = []

    # Training loop
    print("\nStarting training...")
    for epoch in range(epochs):
        for i, (real_images, _) in enumerate(tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}')):
            batch_size_actual = real_images.size(0)  # last batch may be smaller
            real_images = real_images.to(device)

            # ================================================================
            # (1) Update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
            # ================================================================
            netD.zero_grad()

            # Train with real images
            label = torch.full((batch_size_actual,), real_label, dtype=torch.float, device=device)
            output = netD(real_images).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()  # mean D score on real images

            # Train with fake images.
            # detach() blocks gradients from flowing into the generator here:
            # only the discriminator's parameters are updated in this step.
            # Gradients from the real and fake passes accumulate in netD
            # before the single optimizerD.step() below.
            noise = torch.randn(batch_size_actual, nz, 1, 1, device=device)
            fake_images = netG(noise)
            label.fill_(fake_label)  # reuse the label tensor in place
            output = netD(fake_images.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()  # mean D score on fakes, before G update

            # Total discriminator loss (for logging; backward already ran above)
            errD = errD_real + errD_fake
            optimizerD.step()

            # ================================================================
            # (2) Update Generator: maximize log(D(G(z)))
            # ================================================================
            netG.zero_grad()
            label.fill_(real_label)  # Fake images should be classified as real
            # Re-score the (non-detached) fakes with the just-updated D so
            # gradients flow back through D into the generator.
            output = netD(fake_images).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()  # mean D score on fakes, after D update
            optimizerG.step()

            # Save losses
            if i % 50 == 0:
                G_losses.append(errG.item())
                D_losses.append(errD.item())

        # Print statistics for the last batch of the epoch
        print(f'[{epoch+1}/{epochs}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
              f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

        # Generate and save images every 5 epochs (and after the first epoch)
        if (epoch + 1) % 5 == 0 or epoch == 0:
            with torch.no_grad():
                fake_samples = netG(fixed_noise).detach().cpu()

            # Create image grid
            grid = make_grid(fake_samples, nrow=8, normalize=True)

            # Visualize
            plt.figure(figsize=(10, 10))
            plt.imshow(np.transpose(grid, (1, 2, 0)))
            plt.title(f'Generated Images - Epoch {epoch+1}')
            plt.axis('off')
            plt.tight_layout()
            plt.savefig(f'dcgan_samples_epoch_{epoch+1}.png')
            plt.close()

    print("\nTraining completed!")

    # Plot losses
    plt.figure(figsize=(10, 5))
    plt.plot(G_losses, label='Generator Loss')
    plt.plot(D_losses, label='Discriminator Loss')
    plt.xlabel('Iterations (x50)')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('DCGAN Training Losses')
    plt.savefig('dcgan_losses.png')
    plt.close()

    return netG, netD, {'G_losses': G_losses, 'D_losses': D_losses}
309
310
311# ============================================================================
312# Image Generation
313# ============================================================================
314
@torch.no_grad()
def generate_images(generator, num_images=64, nz=100, device='cuda'):
    """
    Generate images using trained generator.

    The generator is moved onto the resolved device before sampling, so the
    noise and the model weights are guaranteed to live on the same device.
    (Previously a CPU-resident generator combined with an available GPU
    raised a device-mismatch error, since only the noise was placed on
    `device`.) Note: the model is left in eval mode.

    Args:
        generator: trained Generator model
        num_images: number of images to generate
        nz: size of latent vector (must match the generator's input size)
        device: 'cuda' or 'cpu'; falls back to CPU when CUDA is unavailable

    Returns:
        [num_images, nc, 64, 64] generated images, on the CPU
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    # Ensure model and noise are on the same device (no-op if already there).
    generator = generator.to(device)
    generator.eval()

    # Sample random noise from the standard normal distribution
    noise = torch.randn(num_images, nz, 1, 1, device=device)

    # Generate images
    fake_images = generator(noise)

    return fake_images.cpu()
339
340
341# ============================================================================
342# Main
343# ============================================================================
344
if __name__ == '__main__':
    # Train DCGAN on MNIST with the paper's recommended hyperparameters
    # (Adam with lr=2e-4, beta1=0.5; 100-dim latent space).
    netG, netD, losses = train_dcgan(
        epochs=25, batch_size=128, nz=100, lr=0.0002, beta1=0.5, device='cuda',
    )

    # Sample a final batch from the trained generator.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("\nGenerating final samples...")
    samples = generate_images(netG, num_images=64, nz=100, device=device)

    # Arrange the samples in an 8x8 grid, save to disk, and display.
    grid = make_grid(samples, nrow=8, normalize=True)
    plt.figure(figsize=(12, 12))
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.title('DCGAN Generated Samples (Final)')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('dcgan_final_samples.png')
    plt.show()

    # Persist the trained weights for later reuse.
    torch.save(netG.state_dict(), 'dcgan_generator.pth')
    torch.save(netD.state_dict(), 'dcgan_discriminator.pth')
    print("\nModels saved to dcgan_generator.pth and dcgan_discriminator.pth")