1"""
2PyTorch Low-Level Variational Autoencoder (VAE) 구현
3
4ELBO, Reparameterization Trick 직접 구현
5"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10import math
11from typing import Tuple, Optional
12from dataclasses import dataclass
13
14
@dataclass
class VAEConfig:
    """Configuration shared by the VAE models in this module."""
    image_size: int = 28  # input image height/width (square images assumed)
    in_channels: int = 1  # number of input image channels
    latent_dim: int = 20  # dimensionality of the latent vector z
    hidden_dims: Tuple[int, ...] = (32, 64)  # conv channel widths, one per stride-2 stage
    beta: float = 1.0  # KL weight for β-VAE (β = 1 recovers the standard VAE)
24
class Encoder(nn.Module):
    """VAE encoder mapping an image x to Gaussian posterior parameters (μ, log σ²)."""

    def __init__(self, config: VAEConfig):
        super().__init__()
        self.config = config

        # Strided-conv feature extractor; every stage halves H and W.
        layers = []
        prev_ch = config.in_channels
        for out_ch in config.hidden_dims:
            layers.append(nn.Conv2d(prev_ch, out_ch, kernel_size=3, stride=2, padding=1))
            layers.append(nn.ReLU())
            prev_ch = out_ch
        self.encoder = nn.Sequential(*layers)

        # Spatial size of the feature map after all stride-2 convolutions.
        self.final_size = config.image_size // (2 ** len(config.hidden_dims))
        self.flatten_dim = config.hidden_dims[-1] * self.final_size ** 2

        # Linear heads for the posterior mean and log-variance.
        self.fc_mu = nn.Linear(self.flatten_dim, config.latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_dim, config.latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images.

        Args:
            x: (B, C, H, W) input batch.

        Returns:
            mu: (B, latent_dim) posterior means.
            logvar: (B, latent_dim) posterior log-variances.
        """
        features = self.encoder(x).flatten(start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)
71
72
class Decoder(nn.Module):
    """VAE decoder mapping a latent vector z to a reconstructed image x̂."""

    def __init__(self, config: VAEConfig):
        super().__init__()
        self.config = config

        self.final_size = config.image_size // (2 ** len(config.hidden_dims))
        self.flatten_dim = config.hidden_dims[-1] * self.final_size ** 2

        # Project the latent vector up to the flattened feature-map size.
        self.fc = nn.Linear(config.latent_dim, self.flatten_dim)

        # Mirror the encoder: transposed convs, each doubling H and W.
        channels = list(reversed(config.hidden_dims))
        layers = []
        for cur_ch, next_ch in zip(channels, channels[1:]):
            layers.append(
                nn.ConvTranspose2d(
                    cur_ch, next_ch,
                    kernel_size=3, stride=2, padding=1, output_padding=1
                )
            )
            layers.append(nn.ReLU())

        # Final upsampling stage maps back to the input channel count;
        # Sigmoid keeps pixel values in [0, 1].
        layers.append(
            nn.ConvTranspose2d(
                channels[-1], config.in_channels,
                kernel_size=3, stride=2, padding=1, output_padding=1
            )
        )
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors.

        Args:
            z: (B, latent_dim) latent batch.

        Returns:
            x_recon: (B, C, H, W) reconstruction with values in [0, 1].
        """
        h = self.fc(z)
        h = h.view(-1, self.config.hidden_dims[-1], self.final_size, self.final_size)
        return self.decoder(h)
122
123
class VAE(nn.Module):
    """Variational Autoencoder: encoder + reparameterized sampling + decoder."""

    def __init__(self, config: VAEConfig):
        super().__init__()
        self.config = config
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z = μ + σ ⊙ ε with ε ~ N(0, I).

        Writing the sample as a deterministic function of (μ, σ) plus
        external noise lets gradients flow back through μ and σ.
        """
        sigma = torch.exp(0.5 * logvar)   # σ = exp(log σ² / 2)
        noise = torch.randn_like(sigma)   # ε ~ N(0, I)
        return mu + sigma * noise

    def forward(
        self,
        x: torch.Tensor,
        return_latent: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full encode → sample → decode pipeline.

        Args:
            x: (B, C, H, W) input batch.
            return_latent: also return the sampled latent z.

        Returns:
            x_recon: (B, C, H, W) reconstruction.
            mu: (B, latent_dim) posterior means.
            logvar: (B, latent_dim) posterior log-variances.
            z: (B, latent_dim) sampled latents, only when return_latent=True.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)

        if return_latent:
            return x_recon, mu, logvar, z
        return x_recon, mu, logvar

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode only; returns (mu, logvar)."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode only; returns the reconstruction."""
        return self.decoder(z)

    def sample(self, num_samples: int, device: torch.device) -> torch.Tensor:
        """Generate images by decoding latents drawn from the prior z ~ N(0, I)."""
        z = torch.randn(num_samples, self.config.latent_dim, device=device)
        return self.decode(z)
199
200
def vae_loss(
    x: torch.Tensor,
    x_recon: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
    reduction: str = "mean"
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    VAE loss: the negative ELBO.

        L = L_recon + β * L_KL

    Args:
        x: original images (B, C, H, W), values in [0, 1]
        x_recon: reconstructed images (B, C, H, W), values in [0, 1]
        mu: posterior means (B, latent_dim)
        logvar: posterior log-variances (B, latent_dim)
        beta: KL weight (β-VAE; β = 1 recovers the standard VAE)
        reduction: "mean" (divide batch totals by batch size) or "sum"

    Returns:
        total_loss: combined loss
        recon_loss: reconstruction term
        kl_loss: KL divergence term

    Raises:
        ValueError: if `reduction` is neither "mean" nor "sum".
    """
    # Previously an unknown reduction string silently behaved like "sum";
    # fail fast instead.
    if reduction not in ("mean", "sum"):
        raise ValueError(f'reduction must be "mean" or "sum", got {reduction!r}')

    # Reconstruction term: each pixel modeled as an independent Bernoulli,
    # so BCE summed over all pixels and the batch.
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')

    # KL(N(μ, σ²) || N(0, I)) = -0.5 * Σ(1 + log σ² - μ² - σ²), summed over batch.
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    total_loss = recon_loss + beta * kl_loss

    if reduction == "mean":
        batch_size = x.size(0)
        total_loss = total_loss / batch_size
        recon_loss = recon_loss / batch_size
        kl_loss = kl_loss / batch_size

    return total_loss, recon_loss, kl_loss
248
249
class BetaVAE(VAE):
    """β-VAE: weights the KL term by β > 1 to encourage disentangled latents."""

    def __init__(self, config: VAEConfig):
        super().__init__(config)

    def compute_loss(
        self,
        x: torch.Tensor,
        x_recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the β-weighted loss and a dict of scalar metrics for logging."""
        total, recon, kl = vae_loss(
            x, x_recon, mu, logvar,
            beta=self.config.beta
        )
        metrics = {
            "recon_loss": recon.item(),
            "kl_loss": kl.item(),
            "total_loss": total.item(),
        }
        return total, metrics
274
275
class ConditionalVAE(nn.Module):
    """Conditional VAE: generation conditioned on a class label.

    The label is embedded into the latent space and added to z before
    decoding; the encoder itself is unconditioned.
    """

    def __init__(self, config: VAEConfig, num_classes: int = 10):
        super().__init__()
        self.config = config
        self.num_classes = num_classes

        # Class embedding with the same dimensionality as the latent,
        # so it can be added to z directly.
        self.class_embed = nn.Embedding(num_classes, config.latent_dim)

        # Encoder and decoder are identical to the unconditional VAE.
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        # NOTE(review): a `cond_proj` Linear (latent → image space) used to be
        # registered here but was never called anywhere, so it only added dead
        # parameters to checkpoints and the optimizer; it has been removed.

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z = μ + σ ⊙ ε with ε ~ N(0, I) (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(
        self,
        x: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, C, H, W) input batch.
            labels: (B,) integer class labels.

        Returns:
            x_recon, mu, logvar
        """
        # Embed the class labels.
        c = self.class_embed(labels)  # (B, latent_dim)

        # Encode (unconditioned).
        mu, logvar = self.encoder(x)

        # Reparameterize.
        z = self.reparameterize(mu, logvar)

        # Condition the latent by adding the class embedding, then decode.
        z_cond = z + c
        x_recon = self.decoder(z_cond)

        return x_recon, mu, logvar

    def sample(
        self,
        num_samples: int,
        labels: torch.Tensor,
        device: torch.device
    ) -> torch.Tensor:
        """Sample from the prior and decode, conditioned on `labels`."""
        z = torch.randn(num_samples, self.config.latent_dim, device=device)
        c = self.class_embed(labels)
        z_cond = z + c
        return self.decoder(z_cond)
339
340
# Latent space visualization
def visualize_latent_space(
    model: VAE,
    data_loader,
    device: torch.device,
    num_samples: int = 1000
):
    """Scatter-plot the first two latent dimensions of encoded samples.

    Encodes batches from `data_loader` until at least `num_samples` points
    are collected, then saves a 2D scatter plot colored by label to
    'vae_latent_space.png'. Intended for latent_dim >= 2 (only the first
    two latent dimensions are plotted).
    """
    import matplotlib.pyplot as plt

    model.eval()
    latents = []
    labels_list = []
    # Count collected samples explicitly: the old check multiplied the number
    # of stored batches by the *current* batch size, which miscounts when the
    # loader's last batch is smaller.
    collected = 0

    with torch.no_grad():
        for data, labels in data_loader:
            if collected >= num_samples:
                break

            data = data.to(device)
            mu, _ = model.encode(data)  # use the posterior mean as the embedding
            latents.append(mu.cpu())
            labels_list.append(labels)
            collected += data.size(0)

    latents = torch.cat(latents, dim=0)[:num_samples]
    labels = torch.cat(labels_list, dim=0)[:num_samples]

    # 2D scatter of the first two latent dimensions.
    plt.figure(figsize=(10, 10))
    scatter = plt.scatter(
        latents[:, 0].numpy(),
        latents[:, 1].numpy(),
        c=labels.numpy(),
        cmap='tab10',
        alpha=0.7
    )
    plt.colorbar(scatter)
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space')
    plt.savefig('vae_latent_space.png')
    plt.close()  # free the figure so repeated calls don't accumulate memory
    print("Saved vae_latent_space.png")
383
384
def interpolate_latent(
    model: VAE,
    x1: torch.Tensor,
    x2: torch.Tensor,
    num_steps: int = 10
) -> torch.Tensor:
    """Linearly interpolate between two images in latent space.

    Encodes both images to their posterior means, walks `num_steps` evenly
    spaced points from the first mean to the second, and decodes each point.

    Returns:
        The `num_steps` decoded images concatenated along the batch dimension.
    """
    model.eval()

    with torch.no_grad():
        # Encode both endpoints (posterior means only).
        mu_start, _ = model.encode(x1)
        mu_end, _ = model.encode(x2)

        # Linear interpolation weights in [0, 1].
        weights = torch.linspace(0, 1, num_steps).to(mu_start.device)
        frames = [
            model.decode((1 - w) * mu_start + w * mu_end)
            for w in weights
        ]

    return torch.cat(frames, dim=0)
409
410
# Smoke tests
if __name__ == "__main__":
    print("=== VAE Low-Level Implementation ===\n")

    # Configuration
    cfg = VAEConfig(
        image_size=28,
        in_channels=1,
        latent_dim=20,
        hidden_dims=(32, 64),
        beta=1.0
    )
    print(f"Config: {cfg}\n")

    # Build the model
    vae = VAE(cfg)

    # Parameter count
    n_params = sum(p.numel() for p in vae.parameters())
    print(f"Total parameters: {n_params:,}\n")

    # Dummy input batch
    bs = 8
    images = torch.rand(bs, 1, 28, 28)

    # Forward pass
    recon, mean, log_var = vae(images)
    print(f"Input shape: {images.shape}")
    print(f"Reconstruction shape: {recon.shape}")
    print(f"Mu shape: {mean.shape}")
    print(f"Logvar shape: {log_var.shape}")

    # Loss computation
    loss_total, loss_recon, loss_kl = vae_loss(images, recon, mean, log_var)
    print(f"\nTotal Loss: {loss_total.item():.4f}")
    print(f"Recon Loss: {loss_recon.item():.4f}")
    print(f"KL Loss: {loss_kl.item():.4f}")

    # Unconditional sampling
    generated = vae.sample(16, images.device)
    print(f"\nSampled images shape: {generated.shape}")

    # β-VAE
    print("\n=== β-VAE Test ===")
    cfg_beta = VAEConfig(beta=4.0)  # β > 1 for disentanglement
    beta_vae = BetaVAE(cfg_beta)

    recon, mean, log_var = beta_vae(images)
    loss, metrics = beta_vae.compute_loss(images, recon, mean, log_var)
    print(f"β-VAE Loss: {metrics}")

    # Conditional VAE
    print("\n=== Conditional VAE Test ===")
    cvae = ConditionalVAE(cfg, num_classes=10)
    class_ids = torch.randint(0, 10, (bs,))

    recon, mean, log_var = cvae(images, class_ids)
    print(f"CVAE Reconstruction shape: {recon.shape}")

    # Conditional sampling
    cond_out = cvae.sample(16, torch.arange(10).repeat(2)[:16], images.device)
    print(f"Conditional samples shape: {cond_out.shape}")

    # Latent-space interpolation
    print("\n=== Latent Interpolation ===")
    frames = interpolate_latent(vae, images[:1], images[1:2], num_steps=5)
    print(f"Interpolated images shape: {frames.shape}")

    print("\nAll tests passed!")