"""PyTorch low-level CLIP implementation.

Implements the image encoder (ViT), text encoder, and the symmetric
contrastive (InfoNCE) loss from scratch.
"""
  6
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
 14
 15@dataclass
 16class CLIPConfig:
 17    """CLIP ์„ค์ •"""
 18    # Image encoder (ViT)
 19    image_size: int = 224
 20    patch_size: int = 32
 21    vision_width: int = 768
 22    vision_layers: int = 12
 23    vision_heads: int = 12
 24
 25    # Text encoder
 26    vocab_size: int = 49408
 27    context_length: int = 77
 28    text_width: int = 512
 29    text_layers: int = 12
 30    text_heads: int = 8
 31
 32    # Shared
 33    embed_dim: int = 512  # ๊ณตํ†ต ์ž„๋ฒ ๋”ฉ ์ฐจ์›
 34
 35    # Training
 36    temperature: float = 0.07  # learnable
 37
 38
 39# ============== Vision Transformer (Image Encoder) ==============
 40
 41class PatchEmbedding(nn.Module):
 42    """์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜๋กœ ๋ถ„ํ• ํ•˜๊ณ  ์ž„๋ฒ ๋”ฉ"""
 43
 44    def __init__(
 45        self,
 46        image_size: int,
 47        patch_size: int,
 48        in_channels: int = 3,
 49        embed_dim: int = 768
 50    ):
 51        super().__init__()
 52        self.num_patches = (image_size // patch_size) ** 2
 53
 54        # Conv2d๋กœ ํŒจ์น˜ํ™” + ์ž„๋ฒ ๋”ฉ
 55        self.proj = nn.Conv2d(
 56            in_channels, embed_dim,
 57            kernel_size=patch_size, stride=patch_size
 58        )
 59
 60    def forward(self, x: torch.Tensor) -> torch.Tensor:
 61        # (B, C, H, W) -> (B, D, H/P, W/P) -> (B, N, D)
 62        x = self.proj(x)
 63        x = x.flatten(2).transpose(1, 2)
 64        return x
 65
 66
 67class MultiHeadAttention(nn.Module):
 68    """Multi-Head Self-Attention"""
 69
 70    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0):
 71        super().__init__()
 72        self.num_heads = num_heads
 73        self.head_dim = dim // num_heads
 74        self.scale = self.head_dim ** -0.5
 75
 76        self.qkv = nn.Linear(dim, dim * 3)
 77        self.proj = nn.Linear(dim, dim)
 78        self.dropout = nn.Dropout(dropout)
 79
 80    def forward(
 81        self,
 82        x: torch.Tensor,
 83        attn_mask: Optional[torch.Tensor] = None
 84    ) -> torch.Tensor:
 85        B, N, D = x.shape
 86
 87        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
 88        qkv = qkv.permute(2, 0, 3, 1, 4)
 89        q, k, v = qkv[0], qkv[1], qkv[2]
 90
 91        attn = (q @ k.transpose(-2, -1)) * self.scale
 92
 93        if attn_mask is not None:
 94            attn = attn + attn_mask
 95
 96        attn = F.softmax(attn, dim=-1)
 97        attn = self.dropout(attn)
 98
 99        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
100        x = self.proj(x)
101
102        return x
103
104
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: LN -> MHA -> residual, LN -> MLP -> residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)

        hidden = int(dim * mlp_ratio)
        # Two-layer GELU MLP with dropout after each linear.
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        x = x + self.attn(self.norm1(x), attn_mask)
        return x + self.mlp(self.norm2(x))
137
138
class VisionTransformer(nn.Module):
    """CLIP vision encoder: a ViT whose [CLS] output is projected to the joint space."""

    def __init__(self, config: CLIPConfig):
        super().__init__()

        # Patchify and linearly embed the image.
        self.patch_embed = PatchEmbedding(
            config.image_size,
            config.patch_size,
            in_channels=3,
            embed_dim=config.vision_width,
        )

        # Learnable [CLS] token prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.vision_width))

        # Learnable position embedding for [CLS] + every patch.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, config.vision_width)
        )

        self.blocks = nn.ModuleList(
            TransformerBlock(config.vision_width, config.vision_heads, mlp_ratio=4.0)
            for _ in range(config.vision_layers)
        )

        self.norm = nn.LayerNorm(config.vision_width)

        # Linear map from the vision width to the shared embedding dimension.
        self.proj = nn.Linear(config.vision_width, config.embed_dim, bias=False)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init (std=0.02) for the learnable token/position params."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images.

        Args:
            x: (B, 3, H, W) pixel tensor.

        Returns:
            (B, embed_dim) image features (not yet L2-normalized).
        """
        batch = x.shape[0]

        tokens = self.patch_embed(x)  # (B, N, D)

        # Prepend [CLS] and add position embeddings.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed

        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)

        # The [CLS] slot is the image representation; project to the joint space.
        return self.proj(tokens[:, 0])
214
215
216# ============== Text Encoder ==============
217
class TextTransformer(nn.Module):
    """CLIP text encoder: a causally-masked transformer over token ids."""

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.context_length = config.context_length

        self.token_embedding = nn.Embedding(config.vocab_size, config.text_width)

        # Learnable absolute position embedding, one row per context slot.
        self.positional_embedding = nn.Parameter(
            torch.zeros(config.context_length, config.text_width)
        )

        self.blocks = nn.ModuleList(
            TransformerBlock(config.text_width, config.text_heads, mlp_ratio=4.0)
            for _ in range(config.text_layers)
        )

        self.ln_final = nn.LayerNorm(config.text_width)

        # Linear map from the text width to the shared embedding dimension.
        self.text_projection = nn.Linear(config.text_width, config.embed_dim, bias=False)

        # Additive causal mask, registered as a buffer so it follows the
        # module across devices and into state_dict handling.
        self.register_buffer(
            "attn_mask",
            self._build_causal_mask(config.context_length),
        )

        self._init_weights()

    def _init_weights(self):
        """Normal init matching CLIP: token embeddings std=0.02, positions std=0.01."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

    def _build_causal_mask(self, context_length: int) -> torch.Tensor:
        """Return an (L, L) additive mask with -inf strictly above the diagonal."""
        full = torch.full((context_length, context_length), float("-inf"))
        return torch.triu(full, diagonal=1)

    def forward(
        self,
        text: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode a batch of token sequences.

        Args:
            text: (B, L) integer token ids.
            attention_mask: optional (B, L) 0/1 padding mask; 1 marks real tokens.

        Returns:
            (B, embed_dim) text features (not yet L2-normalized).
        """
        batch, length = text.shape

        # Token + position embeddings (positions truncated to the actual length).
        x = self.token_embedding(text) + self.positional_embedding[:length]

        causal = self.attn_mask[:length, :length]
        for block in self.blocks:
            x = block(x, causal)
        x = self.ln_final(x)

        # Pool one token per sequence. The original CLIP pools at the EOT
        # token; here we take the last *valid* position when a padding mask
        # is given, otherwise simply the final position. (Because of the
        # causal mask, that position never attends to padding that follows it.)
        if attention_mask is not None:
            last_valid = attention_mask.sum(dim=1) - 1
            pooled = x[torch.arange(batch), last_valid]
        else:
            pooled = x[:, -1]

        return self.text_projection(pooled)
310
311
312# ============== CLIP Model ==============
313
class CLIP(nn.Module):
    """Contrastive Language-Image Pre-training: a dual-encoder model."""

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config

        # The two towers.
        self.visual = VisionTransformer(config)
        self.text_encoder = TextTransformer(config)

        # Learnable inverse temperature, stored in log space for stability.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / config.temperature))

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Return L2-normalized image embeddings of shape (B, embed_dim)."""
        return F.normalize(self.visual(image), dim=-1)

    def encode_text(
        self,
        text: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return L2-normalized text embeddings of shape (B, embed_dim)."""
        return F.normalize(self.text_encoder(text, attention_mask), dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute pairwise image-text similarity logits.

        Args:
            image: (B, 3, H, W) pixels.
            text: (B, L) token ids.
            attention_mask: optional (B, L) padding mask for the text.

        Returns:
            (logits_per_image, logits_per_text), each (B, B); the second is
            the transpose of the first.
        """
        image_features = self.encode_image(image)
        text_features = self.encode_text(text, attention_mask)

        # Temperature-scaled cosine similarity (features are unit norm).
        scale = self.logit_scale.exp()
        logits_per_image = scale * image_features @ text_features.t()
        return logits_per_image, logits_per_image.t()
372
373
def clip_loss(
    logits_per_image: torch.Tensor,
    logits_per_text: torch.Tensor,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    """Symmetric CLIP contrastive loss (InfoNCE).

    The i-th image and i-th text form the positive pair, so the target for
    both directions is the diagonal of the (B, B) logit matrix.

    Args:
        logits_per_image: (B, B) image-to-text similarity logits.
        logits_per_text: (B, B) text-to-image similarity logits
            (typically the transpose of ``logits_per_image``).
        label_smoothing: optional smoothing factor forwarded to
            ``F.cross_entropy``; the default 0.0 reproduces the standard
            CLIP loss exactly.

    Returns:
        Scalar tensor: the mean of the image-to-text and text-to-image losses.
    """
    batch_size = logits_per_image.shape[0]

    # Ground truth: the diagonal entries are the positives.
    targets = torch.arange(batch_size, device=logits_per_image.device)

    loss_i2t = F.cross_entropy(logits_per_image, targets, label_smoothing=label_smoothing)
    loss_t2i = F.cross_entropy(logits_per_text, targets, label_smoothing=label_smoothing)

    # Symmetric (bidirectional) loss.
    return (loss_i2t + loss_t2i) / 2
398
399
400# ============== Zero-shot Classification ==============
401
class ZeroShotClassifier:
    """Zero-shot image classifier built on a trained CLIP model.

    Class names are expanded into prompt sentences, encoded with the text
    tower, and images are classified by similarity to the per-class text
    embeddings.
    """

    def __init__(self, model: "CLIP", device: torch.device):
        self.model = model
        self.device = device
        self.model.eval()  # inference only

    def create_text_embeddings(
        self,
        class_names: list,
        templates: list = None,
    ) -> torch.Tensor:
        """Build one averaged, normalized text embedding per class.

        Args:
            class_names: list of class-name strings.
            templates: list of prompt templates with a single ``{}`` slot;
                a small default set is used when omitted.

        Returns:
            (num_classes, embed_dim) L2-normalized text features.
        """
        if templates is None:
            templates = [
                "a photo of a {}",
                "a picture of a {}",
                "an image of a {}"
            ]

        per_class = []
        with torch.no_grad():
            for name in class_names:
                prompt_embeddings = []
                for template in templates:
                    # NOTE: a real implementation would use the CLIP BPE
                    # tokenizer; the mock tokenizer below is a stand-in.
                    tokens = self._simple_tokenize(template.format(name)).to(self.device)
                    prompt_embeddings.append(self.model.encode_text(tokens))

                # Average over templates, then re-normalize.
                mean_embedding = torch.stack(prompt_embeddings).mean(dim=0)
                per_class.append(F.normalize(mean_embedding, dim=-1))

        return torch.cat(per_class, dim=0)

    def classify(
        self,
        images: torch.Tensor,
        text_features: torch.Tensor,
    ) -> torch.Tensor:
        """Predict a class index for each image.

        Args:
            images: (B, 3, H, W) pixel tensor.
            text_features: (num_classes, embed_dim) class text embeddings.

        Returns:
            (B,) predicted class indices.
        """
        with torch.no_grad():
            image_features = self.model.encode_image(images)
            # The fixed scale of 100 mirrors CLIP's inference-time logit scale.
            logits = 100.0 * image_features @ text_features.t()
            predictions = F.softmax(logits, dim=-1).argmax(dim=-1)
        return predictions

    def _simple_tokenize(self, text: str, max_length: int = 77) -> torch.Tensor:
        """Mock tokenizer (real CLIP uses BPE): chars -> ids, SOT/EOT, right-pad."""
        body = [ord(ch) % 49408 for ch in text[:max_length - 2]]
        ids = [49406] + body + [49407]        # SOT ... EOT
        ids = ids + [0] * (max_length - len(ids))  # pad to max_length with 0
        return torch.tensor([ids])
487
488
489# ============== Image-Text Retrieval ==============
490
class ImageTextRetrieval:
    """Bidirectional image<->text retrieval on top of a trained CLIP model."""

    def __init__(self, model: "CLIP", device: torch.device):
        self.model = model
        self.device = device
        self.model.eval()  # inference only

        # Populated by the index_* methods; queries run against these caches.
        self.image_embeddings = None
        self.text_embeddings = None

    def index_images(self, images: torch.Tensor):
        """Encode and cache a gallery of images, shape (N, 3, H, W)."""
        with torch.no_grad():
            self.image_embeddings = self.model.encode_image(images.to(self.device))

    def index_texts(self, texts: torch.Tensor):
        """Encode and cache a corpus of token sequences, shape (N, L)."""
        with torch.no_grad():
            self.text_embeddings = self.model.encode_text(texts.to(self.device))

    def search_by_text(
        self,
        query_text: torch.Tensor,
        top_k: int = 5,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Retrieve the indexed images most similar to a text query.

        Returns:
            (indices, scores) of the ``top_k`` best matches per query row.
        """
        with torch.no_grad():
            query = self.model.encode_text(query_text.to(self.device))
            scores, indices = (query @ self.image_embeddings.t()).topk(top_k, dim=-1)
        return indices, scores

    def search_by_image(
        self,
        query_image: torch.Tensor,
        top_k: int = 5,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Retrieve the indexed texts most similar to an image query.

        Returns:
            (indices, scores) of the ``top_k`` best matches per query row.
        """
        with torch.no_grad():
            query = self.model.encode_image(query_image.to(self.device))
            scores, indices = (query @ self.text_embeddings.t()).topk(top_k, dim=-1)
        return indices, scores
537
538
# Smoke test / demo (runs a full forward pass on random data).
if __name__ == "__main__":
    print("=== CLIP Low-Level Implementation ===\n")

    # ViT-B/32-style configuration.
    config = CLIPConfig(
        image_size=224,
        patch_size=32,
        vision_width=768,
        vision_layers=12,
        vision_heads=12,
        vocab_size=49408,
        context_length=77,
        text_width=512,
        text_layers=12,
        text_heads=8,
        embed_dim=512,
    )
    print(f"Config: embed_dim={config.embed_dim}, vision_layers={config.vision_layers}\n")

    # Build the model.
    model = CLIP(config)

    # Parameter counts per tower.
    vision_params = sum(p.numel() for p in model.visual.parameters())
    text_params = sum(p.numel() for p in model.text_encoder.parameters())
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Vision encoder params: {vision_params:,}")
    print(f"Text encoder params: {text_params:,}")
    print(f"Total params: {total_params:,}\n")

    # Random test batch.
    batch_size = 4
    images = torch.randn(batch_size, 3, 224, 224)
    texts = torch.randint(0, config.vocab_size, (batch_size, 77))

    # Full forward pass.
    logits_per_image, logits_per_text = model(images, texts)
    print(f"Images shape: {images.shape}")
    print(f"Texts shape: {texts.shape}")
    print(f"Logits per image shape: {logits_per_image.shape}")
    print(f"Logits per text shape: {logits_per_text.shape}")

    # Contrastive loss on the random batch.
    loss = clip_loss(logits_per_image, logits_per_text)
    print(f"\nContrastive Loss: {loss.item():.4f}")

    # Standalone encoder checks.
    print("\n=== Encoding Test ===")
    image_features = model.encode_image(images)
    text_features = model.encode_text(texts)
    print(f"Image features shape: {image_features.shape}")
    print(f"Text features shape: {text_features.shape}")
    print(f"Image features norm: {image_features.norm(dim=-1).mean():.4f} (should be ~1.0)")
    print(f"Text features norm: {text_features.norm(dim=-1).mean():.4f} (should be ~1.0)")

    # Cosine-similarity matrix between the two towers.
    similarity = image_features @ text_features.t()
    print(f"\nSimilarity matrix:\n{similarity}")

    # Effect of the learnable temperature on the logits.
    print(f"\nTemperature (1/exp(logit_scale)): {1/model.logit_scale.exp().item():.4f}")
    print(f"Scaled similarity range: [{(model.logit_scale.exp() * similarity).min().item():.2f}, "
          f"{(model.logit_scale.exp() * similarity).max().item():.2f}]")

    # Zero-shot classification demo (mock classes, untrained weights).
    print("\n=== Zero-shot Classification Test ===")
    device = torch.device("cpu")
    classifier = ZeroShotClassifier(model, device)

    class_names = ["cat", "dog", "bird", "car", "plane"]
    print(f"Classes: {class_names}")

    # With a real tokenizer one would run:
    # text_features = classifier.create_text_embeddings(class_names)
    # predictions = classifier.classify(images, text_features)

    print("\nAll tests passed!")