vae_lowlevel.py

  1"""
  2PyTorch Low-Level Variational Autoencoder (VAE) 구현
  3
  4ELBO, Reparameterization Trick 직접 구현
  5"""
  6
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
from dataclasses import dataclass


@dataclass
class VAEConfig:
    """VAE configuration"""
    image_size: int = 28
    in_channels: int = 1
    latent_dim: int = 20
    hidden_dims: Tuple[int, ...] = (32, 64)
    beta: float = 1.0  # KL weight for β-VAE


class Encoder(nn.Module):
    """VAE encoder: x → (μ, log σ²)"""

    def __init__(self, config: VAEConfig):
        super().__init__()
        self.config = config

        # Convolutional layers: each stride-2 conv halves the spatial size
        modules = []
        in_channels = config.in_channels

        for h_dim in config.hidden_dims:
            modules.append(
                nn.Conv2d(in_channels, h_dim, kernel_size=3, stride=2, padding=1)
            )
            modules.append(nn.ReLU())
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        # Size of the final feature map after all stride-2 convolutions
        self.final_size = config.image_size // (2 ** len(config.hidden_dims))
        self.flatten_dim = config.hidden_dims[-1] * self.final_size * self.final_size

        # FC layers for μ and log σ²
        self.fc_mu = nn.Linear(self.flatten_dim, config.latent_dim)
        self.fc_logvar = nn.Linear(self.flatten_dim, config.latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, C, H, W)

        Returns:
            mu: (B, latent_dim)
            logvar: (B, latent_dim)
        """
        # Encode
        h = self.encoder(x)
        h = h.flatten(start_dim=1)

        # μ and log σ²
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)

        return mu, logvar


class Decoder(nn.Module):
    """VAE decoder: z → x̂"""

    def __init__(self, config: VAEConfig):
        super().__init__()
        self.config = config

        self.final_size = config.image_size // (2 ** len(config.hidden_dims))
        self.flatten_dim = config.hidden_dims[-1] * self.final_size * self.final_size

        # FC layer: project the latent vector back to the flattened feature map
        self.fc = nn.Linear(config.latent_dim, self.flatten_dim)

        # Transposed convolutions (mirror the encoder)
        modules = []
        hidden_dims = list(config.hidden_dims)[::-1]  # reversed order

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.ConvTranspose2d(
                    hidden_dims[i], hidden_dims[i + 1],
                    kernel_size=3, stride=2, padding=1, output_padding=1
                )
            )
            modules.append(nn.ReLU())

        # Final layer
        modules.append(
            nn.ConvTranspose2d(
                hidden_dims[-1], config.in_channels,
                kernel_size=3, stride=2, padding=1, output_padding=1
            )
        )
        modules.append(nn.Sigmoid())  # pixel values in [0, 1]

        self.decoder = nn.Sequential(*modules)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: (B, latent_dim)

        Returns:
            x_recon: (B, C, H, W)
        """
        h = self.fc(z)
        h = h.view(-1, self.config.hidden_dims[-1], self.final_size, self.final_size)
        x_recon = self.decoder(h)
        return x_recon


class VAE(nn.Module):
    """Variational Autoencoder"""

    def __init__(self, config: VAEConfig):
        super().__init__()
        self.config = config

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick

        z = μ + σ ⊙ ε, where ε ~ N(0, I)

        Sampling this way keeps the randomness isolated in ε, so gradients of z
        can flow back through μ and σ to the encoder.
        """
        # σ = exp(log σ² / 2) = exp(0.5 * logvar)
        std = torch.exp(0.5 * logvar)

        # ε ~ N(0, I)
        eps = torch.randn_like(std)

        # z = μ + σ ⊙ ε
        z = mu + std * eps

        return z

    def forward(
        self,
        x: torch.Tensor,
        return_latent: bool = False
    ) -> Tuple[torch.Tensor, ...]:
        """
        Forward pass

        Args:
            x: (B, C, H, W)
            return_latent: whether to also return the latent z

        Returns:
            x_recon: (B, C, H, W) reconstructed image
            mu: (B, latent_dim)
            logvar: (B, latent_dim)
            z: (B, latent_dim), only when return_latent=True
        """
        # Encode
        mu, logvar = self.encoder(x)

        # Reparameterize
        z = self.reparameterize(mu, logvar)

        # Decode
        x_recon = self.decoder(z)

        if return_latent:
            return x_recon, mu, logvar, z

        return x_recon, mu, logvar

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run only the encoder."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run only the decoder."""
        return self.decoder(z)

    def sample(self, num_samples: int, device: torch.device) -> torch.Tensor:
        """Sample from the latent space and generate images."""
        # Sample from the prior: z ~ N(0, I)
        z = torch.randn(num_samples, self.config.latent_dim, device=device)
        samples = self.decode(z)
        return samples


def vae_loss(
    x: torch.Tensor,
    x_recon: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
    reduction: str = "mean"
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    VAE loss: the negative ELBO

    L = L_recon + β * L_KL

    Args:
        x: original images (B, C, H, W)
        x_recon: reconstructed images (B, C, H, W)
        mu: mean (B, latent_dim)
        logvar: log variance (B, latent_dim)
        beta: KL weight (β-VAE)
        reduction: "mean" (per-sample average) or "sum"

    Returns:
        total_loss: total loss
        recon_loss: reconstruction loss
        kl_loss: KL divergence
    """
    batch_size = x.size(0)

    # Reconstruction loss (binary cross-entropy)
    # BCE models each pixel as an independent Bernoulli variable
    recon_loss = F.binary_cross_entropy(
        x_recon, x, reduction='sum'
    )

    # KL divergence: KL(N(μ, σ²) || N(0, I))
    # = -0.5 * Σ(1 + log σ² - μ² - σ²)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total loss
    total_loss = recon_loss + beta * kl_loss

    if reduction == "mean":
        total_loss = total_loss / batch_size
        recon_loss = recon_loss / batch_size
        kl_loss = kl_loss / batch_size

    return total_loss, recon_loss, kl_loss


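# A minimal training-step sketch showing how vae_loss is meant to be used in an
# optimization loop. This helper is not part of the original module; the optimizer
# is left to the caller (e.g. torch.optim.Adam(model.parameters(), lr=1e-3)).
def train_step(
    model: VAE,
    x: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    beta: float = 1.0
) -> dict:
    """Run one gradient step on a batch x and return scalar loss values."""
    model.train()
    optimizer.zero_grad()

    # Forward pass and negative-ELBO loss
    x_recon, mu, logvar = model(x)
    total_loss, recon_loss, kl_loss = vae_loss(x, x_recon, mu, logvar, beta=beta)

    # Backpropagate and update parameters
    total_loss.backward()
    optimizer.step()

    return {
        "total_loss": total_loss.item(),
        "recon_loss": recon_loss.item(),
        "kl_loss": kl_loss.item(),
    }

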
class BetaVAE(VAE):
    """β-VAE: a variant that weights the KL term to encourage disentanglement"""

    def __init__(self, config: VAEConfig):
        super().__init__(config)

    def compute_loss(
        self,
        x: torch.Tensor,
        x_recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the β-VAE loss."""
        total_loss, recon_loss, kl_loss = vae_loss(
            x, x_recon, mu, logvar,
            beta=self.config.beta
        )

        return total_loss, {
            "recon_loss": recon_loss.item(),
            "kl_loss": kl_loss.item(),
            "total_loss": total_loss.item()
        }


class ConditionalVAE(nn.Module):
    """Conditional VAE: class-conditional generation"""

    def __init__(self, config: VAEConfig, num_classes: int = 10):
        super().__init__()
        self.config = config
        self.num_classes = num_classes

        # Class embedding, mapped directly into the latent dimension
        self.class_embed = nn.Embedding(num_classes, config.latent_dim)

        # Encoder and decoder are the same as in the plain VAE
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(
        self,
        x: torch.Tensor,
        labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: (B, C, H, W)
            labels: (B,) class labels

        Returns:
            x_recon, mu, logvar
        """
        # Class embedding
        c = self.class_embed(labels)  # (B, latent_dim)

        # Encode
        mu, logvar = self.encoder(x)

        # Reparameterize
        z = self.reparameterize(mu, logvar)

        # Decode with condition: the class embedding is simply added to z
        z_cond = z + c
        x_recon = self.decoder(z_cond)

        return x_recon, mu, logvar

    def sample(
        self,
        num_samples: int,
        labels: torch.Tensor,
        device: torch.device
    ) -> torch.Tensor:
        """Conditional sampling."""
        z = torch.randn(num_samples, self.config.latent_dim, device=device)
        c = self.class_embed(labels)
        z_cond = z + c
        samples = self.decoder(z_cond)
        return samples


# Latent space visualization
def visualize_latent_space(
    model: VAE,
    data_loader,
    device: torch.device,
    num_samples: int = 1000
):
    """2D scatter plot of the latent space (meaningful mainly when latent_dim=2)."""
    import matplotlib.pyplot as plt

    model.eval()
    latents = []
    labels_list = []

    with torch.no_grad():
        for batch_idx, (data, labels) in enumerate(data_loader):
            if len(latents) * data.size(0) >= num_samples:
                break

            data = data.to(device)
            mu, _ = model.encode(data)
            latents.append(mu.cpu())
            labels_list.append(labels)

    latents = torch.cat(latents, dim=0)[:num_samples]
    labels = torch.cat(labels_list, dim=0)[:num_samples]

    # 2D scatter plot (uses only the first two latent dimensions)
    plt.figure(figsize=(10, 10))
    scatter = plt.scatter(
        latents[:, 0].numpy(),
        latents[:, 1].numpy(),
        c=labels.numpy(),
        cmap='tab10',
        alpha=0.7
    )
    plt.colorbar(scatter)
    plt.xlabel('z[0]')
    plt.ylabel('z[1]')
    plt.title('VAE Latent Space')
    plt.savefig('vae_latent_space.png')
    print("Saved vae_latent_space.png")


def interpolate_latent(
    model: VAE,
    x1: torch.Tensor,
    x2: torch.Tensor,
    num_steps: int = 10
) -> torch.Tensor:
    """Interpolate between two images in latent space."""
    model.eval()

    with torch.no_grad():
        # Encode both images (use the posterior means)
        mu1, _ = model.encode(x1)
        mu2, _ = model.encode(x2)

        # Linear interpolation between the two latent codes
        alphas = torch.linspace(0, 1, num_steps).to(mu1.device)
        interpolated = []

        for alpha in alphas:
            z = (1 - alpha) * mu1 + alpha * mu2
            x_recon = model.decode(z)
            interpolated.append(x_recon)

        return torch.cat(interpolated, dim=0)


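# A minimal sketch of how visualize_latent_space might be driven, assuming
# torchvision is available; the MNIST dataset, the "./data" root, and the batch
# size are illustrative assumptions rather than part of the original module.
def _demo_visualize_mnist(model: VAE, device: torch.device):
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms  # assumed optional dependency

    loader = DataLoader(
        datasets.MNIST("./data", train=False, download=True,
                       transform=transforms.ToTensor()),
        batch_size=128,
        shuffle=False,
    )
    visualize_latent_space(model, loader, device)

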
# Tests
if __name__ == "__main__":
    print("=== VAE Low-Level Implementation ===\n")

    # Configuration
    config = VAEConfig(
        image_size=28,
        in_channels=1,
        latent_dim=20,
        hidden_dims=(32, 64),
        beta=1.0
    )
    print(f"Config: {config}\n")

    # Build the model
    model = VAE(config)

    # Parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params:,}\n")

    # Test input
    batch_size = 8
    x = torch.rand(batch_size, 1, 28, 28)

    # Forward
    x_recon, mu, logvar = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Reconstruction shape: {x_recon.shape}")
    print(f"Mu shape: {mu.shape}")
    print(f"Logvar shape: {logvar.shape}")

    # Compute loss
    total_loss, recon_loss, kl_loss = vae_loss(x, x_recon, mu, logvar)
    print(f"\nTotal Loss: {total_loss.item():.4f}")
    print(f"Recon Loss: {recon_loss.item():.4f}")
    print(f"KL Loss: {kl_loss.item():.4f}")

    # Sampling test
    samples = model.sample(16, x.device)
    print(f"\nSampled images shape: {samples.shape}")

    # β-VAE test
    print("\n=== β-VAE Test ===")
    config_beta = VAEConfig(beta=4.0)  # β > 1 encourages disentanglement
    beta_vae = BetaVAE(config_beta)

    x_recon, mu, logvar = beta_vae(x)
    loss, metrics = beta_vae.compute_loss(x, x_recon, mu, logvar)
    print(f"β-VAE Loss: {metrics}")

    # Conditional VAE test
    print("\n=== Conditional VAE Test ===")
    cvae = ConditionalVAE(config, num_classes=10)
    labels = torch.randint(0, 10, (batch_size,))

    x_recon, mu, logvar = cvae(x, labels)
    print(f"CVAE Reconstruction shape: {x_recon.shape}")

    # Conditional sampling
    cond_samples = cvae.sample(16, torch.arange(10).repeat(2)[:16], x.device)
    print(f"Conditional samples shape: {cond_samples.shape}")

    # Latent space interpolation
    print("\n=== Latent Interpolation ===")
    interp = interpolate_latent(model, x[:1], x[1:2], num_steps=5)
    print(f"Interpolated images shape: {interp.shape}")

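    # Optional: one illustrative optimizer step using the train_step sketch above.
    # Adam with lr=1e-3 is an assumption for demonstration, not from the original.
    print("\n=== Training Step (sketch) ===")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    step_metrics = train_step(model, x, optimizer)
    print(f"Metrics after one step: {step_metrics}")
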
    print("\nAll tests passed!")