1"""
2PyTorch Low-Level CLIP ๊ตฌํ
3
4Image Encoder, Text Encoder, Contrastive Loss ์ง์ ๊ตฌํ
5"""
6
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10import math
11from typing import Optional, Tuple
12from dataclasses import dataclass
13
14
@dataclass
class CLIPConfig:
    """Hyperparameters for the CLIP model (vision encoder, text encoder, training)."""
    # Image encoder (ViT)
    image_size: int = 224
    patch_size: int = 32
    vision_width: int = 768
    vision_layers: int = 12
    vision_heads: int = 12

    # Text encoder
    vocab_size: int = 49408
    context_length: int = 77
    text_width: int = 512
    text_layers: int = 12
    text_heads: int = 8

    # Shared
    embed_dim: int = 512  # dimensionality of the joint image/text embedding space

    # Training
    temperature: float = 0.07  # initial softmax temperature; CLIP stores it as a learnable log-scale
38
39# ============== Vision Transformer (Image Encoder) ==============
40
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Implemented as a single Conv2d whose kernel size and stride both equal
    the patch size, so each output position is the linear projection of one
    patch.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        in_channels: int = 3,
        embed_dim: int = 768
    ):
        """
        Args:
            image_size: Input image height/width (square images assumed).
            patch_size: Side length of each square patch.
            in_channels: Number of input channels (3 for RGB).
            embed_dim: Dimension of each patch embedding.

        Raises:
            ValueError: If image_size is not divisible by patch_size — the
                strided convolution would otherwise silently drop border
                pixels instead of failing loudly.
        """
        super().__init__()
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by "
                f"patch_size ({patch_size})"
            )
        self.num_patches = (image_size // patch_size) ** 2

        # Conv2d performs patchify + linear embedding in one op.
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed image patches.

        Args:
            x: (B, C, H, W) image batch.

        Returns:
            (B, N, D) patch embeddings, where N = (H/P) * (W/P).
        """
        # (B, C, H, W) -> (B, D, H/P, W/P) -> (B, N, D)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x
65
66
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional additive attention mask."""

    def __init__(self, dim: int, num_heads: int, dropout: float = 0.0):
        """
        Args:
            dim: Model (embedding) dimension.
            num_heads: Number of attention heads; must divide ``dim``.
            dropout: Dropout probability applied to the attention weights.

        Raises:
            ValueError: If dim is not divisible by num_heads — the qkv
                reshape in forward() would otherwise fail with an opaque
                shape error.
        """
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1/sqrt(d_k) softmax temperature

        self.qkv = nn.Linear(dim, dim * 3)  # fused Q/K/V projection
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: (B, N, D) input sequence.
            attn_mask: Optional additive mask broadcastable to (N, N);
                blocked positions should hold -inf.

        Returns:
            (B, N, D) attended output.
        """
        B, N, D = x.shape

        # Project once, then split into per-head Q, K, V:
        # (B, N, 3*D) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        if attn_mask is not None:
            attn = attn + attn_mask  # additive mask: -inf removes a position

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # (B, heads, N, head_dim) -> (B, N, D)
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)

        return x
103
104
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: LN -> attention -> residual, LN -> MLP -> residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(dim)

        # Position-wise feed-forward network, expanded by mlp_ratio
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply self-attention then the MLP, each with a residual connection."""
        attn_out = self.attn(self.norm1(x), attn_mask)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
137
138
class VisionTransformer(nn.Module):
    """CLIP image encoder: ViT backbone plus a linear projection into the joint space."""

    def __init__(self, config: CLIPConfig):
        super().__init__()

        # Patchify + embed
        self.patch_embed = PatchEmbedding(
            config.image_size, config.patch_size,
            in_channels=3, embed_dim=config.vision_width
        )
        num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token prepended to every patch sequence
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.vision_width))

        # Learnable position embedding covering [CLS] + all patches
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, config.vision_width)
        )

        # Encoder stack
        self.blocks = nn.ModuleList(
            TransformerBlock(
                config.vision_width,
                config.vision_heads,
                mlp_ratio=4.0
            )
            for _ in range(config.vision_layers)
        )

        self.norm = nn.LayerNorm(config.vision_width)

        # Map the vision width into the shared embedding dimension
        self.proj = nn.Linear(config.vision_width, config.embed_dim, bias=False)

        self._init_weights()

    def _init_weights(self):
        # Truncated-normal init for the learned tokens, as in ViT
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, 3, H, W) image batch.

        Returns:
            image_features: (B, embed_dim), not L2-normalized.
        """
        batch = x.shape[0]

        # (B, N, width) patch tokens
        tokens = self.patch_embed(x)

        # Prepend [CLS], then add positional information
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed

        for layer in self.blocks:
            tokens = layer(tokens)

        tokens = self.norm(tokens)

        # The [CLS] position summarizes the image; project it into the joint space
        return self.proj(tokens[:, 0])
214
215
216# ============== Text Encoder ==============
217
class TextTransformer(nn.Module):
    """CLIP text encoder: causal transformer plus a linear projection into the joint space."""

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.context_length = config.context_length

        # Token embedding table
        self.token_embedding = nn.Embedding(config.vocab_size, config.text_width)

        # Learnable absolute position embedding
        self.positional_embedding = nn.Parameter(
            torch.zeros(config.context_length, config.text_width)
        )

        # Encoder stack (runs with a causal mask, GPT-style)
        self.blocks = nn.ModuleList(
            TransformerBlock(
                config.text_width,
                config.text_heads,
                mlp_ratio=4.0
            )
            for _ in range(config.text_layers)
        )

        self.ln_final = nn.LayerNorm(config.text_width)

        # Map the text width into the shared embedding dimension
        self.text_projection = nn.Linear(config.text_width, config.embed_dim, bias=False)

        # Pre-computed additive causal mask, cached as a non-trainable buffer
        self.register_buffer(
            "attn_mask",
            self._build_causal_mask(config.context_length)
        )

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

    def _build_causal_mask(self, context_length: int) -> torch.Tensor:
        """Upper-triangular -inf mask so each token attends only to itself and the past."""
        full = torch.full((context_length, context_length), float("-inf"))
        return torch.triu(full, diagonal=1)

    def forward(
        self,
        text: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            text: (B, L) token ids.
            attention_mask: (B, L) optional padding mask (1 = real token).

        Returns:
            text_features: (B, embed_dim), not L2-normalized.
        """
        batch, length = text.shape

        # Token + position embeddings
        hidden = self.token_embedding(text) + self.positional_embedding[:length]

        # Crop the cached causal mask to the actual sequence length
        mask = self.attn_mask[:length, :length]

        for layer in self.blocks:
            hidden = layer(hidden, mask)

        hidden = self.ln_final(hidden)

        # OpenAI CLIP pools at the EOT token (highest token id); this
        # implementation pools at the last valid position instead.
        if attention_mask is not None:
            # Index of the last real (non-padding) token per sequence
            last_valid = attention_mask.sum(dim=1) - 1
            pooled = hidden[torch.arange(batch), last_valid]
        else:
            pooled = hidden[:, -1]  # fall back to the final position

        return self.text_projection(pooled)
310
311
312# ============== CLIP Model ==============
313
class CLIP(nn.Module):
    """CLIP: Contrastive Language-Image Pre-training (dual-encoder model)."""

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config

        # Modality encoders
        self.visual = VisionTransformer(config)
        self.text_encoder = TextTransformer(config)

        # Learnable temperature, stored as log(1/T) so exp() keeps it positive
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / config.temperature))

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Encode images into L2-normalized joint-space features of shape (B, embed_dim)."""
        return F.normalize(self.visual(image), dim=-1)

    def encode_text(
        self,
        text: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Encode token ids into L2-normalized joint-space features of shape (B, embed_dim)."""
        return F.normalize(self.text_encoder(text, attention_mask), dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        text: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute temperature-scaled cosine-similarity logits between every
        image and every text in the batch.

        Args:
            image: (B, 3, H, W) image batch.
            text: (B, L) token ids.
            attention_mask: (B, L) optional padding mask.

        Returns:
            logits_per_image: (B, B) image-to-text similarity logits.
            logits_per_text: (B, B) transpose of the above.
        """
        img_feats = self.encode_image(image)
        txt_feats = self.encode_text(text, attention_mask)

        # Cosine similarity (features are unit-norm) scaled by the learned temperature
        scale = self.logit_scale.exp()
        logits_per_image = scale * img_feats @ txt_feats.t()
        return logits_per_image, logits_per_image.t()
372
373
def clip_loss(
    logits_per_image: torch.Tensor,
    logits_per_text: torch.Tensor
) -> torch.Tensor:
    """
    Symmetric InfoNCE contrastive loss used by CLIP.

    The i-th image and i-th text form the positive pair, so the target for
    every row is its own index (the diagonal of the logit matrix).

    Args:
        logits_per_image: (B, B) image-to-text similarity logits.
        logits_per_text: (B, B) text-to-image similarity logits.

    Returns:
        Scalar loss: mean of the image-to-text and text-to-image cross-entropies.
    """
    # One target per row: the matching index on the diagonal
    targets = torch.arange(
        logits_per_image.shape[0], device=logits_per_image.device
    )

    i2t = F.cross_entropy(logits_per_image, targets)
    t2i = F.cross_entropy(logits_per_text, targets)

    # Average the two directions
    return (i2t + t2i) / 2
398
399
400# ============== Zero-shot Classification ==============
401
class ZeroShotClassifier:
    """Zero-shot image classifier built on a trained CLIP model.

    Class names are turned into prompt sentences, encoded with the text
    encoder, and each image is assigned to the class whose averaged text
    embedding is most similar.
    """

    def __init__(self, model: CLIP, device: torch.device):
        """
        Args:
            model: Trained CLIP model (switched to eval mode here).
            device: Device the tokenized prompts are moved to.
        """
        self.model = model
        self.device = device
        self.model.eval()

    def create_text_embeddings(
        self,
        class_names: list,
        templates: Optional[list] = None
    ) -> torch.Tensor:
        """
        Build one averaged, L2-normalized text embedding per class.

        Args:
            class_names: List of class name strings.
            templates: Prompt templates containing one ``{}`` placeholder;
                defaults to a small set of generic photo prompts.

        Returns:
            text_features: (num_classes, embed_dim) tensor.
        """
        if templates is None:
            templates = [
                "a photo of a {}",
                "a picture of a {}",
                "an image of a {}"
            ]

        all_embeddings = []

        with torch.no_grad():
            for class_name in class_names:
                class_embeddings = []

                for template in templates:
                    text = template.format(class_name)
                    # NOTE: a real implementation would use the CLIP BPE
                    # tokenizer; this mock keeps the example self-contained.
                    tokens = self._simple_tokenize(text)
                    tokens = tokens.to(self.device)

                    embedding = self.model.encode_text(tokens)
                    class_embeddings.append(embedding)

                # Average over templates, then re-normalize: the mean of
                # unit vectors is generally not itself a unit vector.
                class_embedding = torch.stack(class_embeddings).mean(dim=0)
                class_embedding = F.normalize(class_embedding, dim=-1)
                all_embeddings.append(class_embedding)

        return torch.cat(all_embeddings, dim=0)

    def classify(
        self,
        images: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Zero-shot classification.

        Args:
            images: (B, 3, H, W) image batch.
            text_features: (num_classes, embed_dim) class text embeddings.

        Returns:
            predictions: (B,) predicted class indices.
        """
        with torch.no_grad():
            image_features = self.model.encode_image(images)

            # Softmax is monotonic, so taking argmax over the raw
            # similarity logits yields identical predictions without
            # the extra normalization pass.
            logits = 100.0 * image_features @ text_features.t()
            predictions = logits.argmax(dim=-1)

        return predictions

    def _simple_tokenize(self, text: str, max_length: int = 77) -> torch.Tensor:
        """Mock character-level tokenizer (real CLIP uses a BPE tokenizer).

        Returns a (1, max_length) tensor: SOT, char tokens, EOT, zero padding.
        """
        tokens = [ord(c) % 49408 for c in text[:max_length-2]]
        tokens = [49406] + tokens + [49407]  # SOT ... EOT
        tokens = tokens + [0] * (max_length - len(tokens))  # pad to fixed length
        return torch.tensor([tokens])
487
488
489# ============== Image-Text Retrieval ==============
490
class ImageTextRetrieval:
    """Bidirectional image-text retrieval over pre-computed CLIP embeddings."""

    def __init__(self, model: CLIP, device: torch.device):
        self.model = model
        self.device = device
        self.model.eval()

        # Populated by index_images / index_texts before searching
        self.image_embeddings = None
        self.text_embeddings = None

    def index_images(self, images: torch.Tensor):
        """Encode and cache a gallery of images."""
        with torch.no_grad():
            self.image_embeddings = self.model.encode_image(images.to(self.device))

    def index_texts(self, texts: torch.Tensor):
        """Encode and cache a gallery of tokenized texts."""
        with torch.no_grad():
            self.text_embeddings = self.model.encode_text(texts.to(self.device))

    def search_by_text(
        self,
        query_text: torch.Tensor,
        top_k: int = 5
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (indices, scores) of the top_k indexed images for a text query."""
        with torch.no_grad():
            query = self.model.encode_text(query_text.to(self.device))
            sims = query @ self.image_embeddings.t()
            scores, indices = sims.topk(top_k, dim=-1)

        return indices, scores

    def search_by_image(
        self,
        query_image: torch.Tensor,
        top_k: int = 5
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (indices, scores) of the top_k indexed texts for an image query."""
        with torch.no_grad():
            query = self.model.encode_image(query_image.to(self.device))
            sims = query @ self.text_embeddings.t()
            scores, indices = sims.topk(top_k, dim=-1)

        return indices, scores
537
538
# Smoke test / demo
if __name__ == "__main__":
    print("=== CLIP Low-Level Implementation ===\n")

    # Configuration (CLIP ViT-B/32-like sizes)
    config = CLIPConfig(
        image_size=224,
        patch_size=32,
        vision_width=768,
        vision_layers=12,
        vision_heads=12,
        vocab_size=49408,
        context_length=77,
        text_width=512,
        text_layers=12,
        text_heads=8,
        embed_dim=512
    )
    print(f"Config: embed_dim={config.embed_dim}, vision_layers={config.vision_layers}\n")

    # Build the model
    model = CLIP(config)

    # Parameter counts
    vision_params = sum(p.numel() for p in model.visual.parameters())
    text_params = sum(p.numel() for p in model.text_encoder.parameters())
    total_params = sum(p.numel() for p in model.parameters())

    print(f"Vision encoder params: {vision_params:,}")
    print(f"Text encoder params: {text_params:,}")
    print(f"Total params: {total_params:,}\n")

    # Random test inputs
    batch_size = 4
    images = torch.randn(batch_size, 3, 224, 224)
    texts = torch.randint(0, config.vocab_size, (batch_size, 77))

    # Forward
    logits_per_image, logits_per_text = model(images, texts)
    print(f"Images shape: {images.shape}")
    print(f"Texts shape: {texts.shape}")
    print(f"Logits per image shape: {logits_per_image.shape}")
    print(f"Logits per text shape: {logits_per_text.shape}")

    # Loss computation
    loss = clip_loss(logits_per_image, logits_per_text)
    print(f"\nContrastive Loss: {loss.item():.4f}")

    # Individual encoding test
    print("\n=== Encoding Test ===")
    image_features = model.encode_image(images)
    text_features = model.encode_text(texts)
    print(f"Image features shape: {image_features.shape}")
    print(f"Text features shape: {text_features.shape}")
    print(f"Image features norm: {image_features.norm(dim=-1).mean():.4f} (should be ~1.0)")
    print(f"Text features norm: {text_features.norm(dim=-1).mean():.4f} (should be ~1.0)")

    # Similarity computation
    similarity = image_features @ text_features.t()
    print(f"\nSimilarity matrix:\n{similarity}")

    # Temperature effect
    print(f"\nTemperature (1/exp(logit_scale)): {1/model.logit_scale.exp().item():.4f}")
    print(f"Scaled similarity range: [{(model.logit_scale.exp() * similarity).min().item():.2f}, "
          f"{(model.logit_scale.exp() * similarity).max().item():.2f}]")

    # Zero-shot classification test
    print("\n=== Zero-shot Classification Test ===")
    device = torch.device("cpu")
    classifier = ZeroShotClassifier(model, device)

    # Mock classification setup
    class_names = ["cat", "dog", "bird", "car", "plane"]
    print(f"Classes: {class_names}")

    # A full run would build text embeddings and classify:
    # text_features = classifier.create_text_embeddings(class_names)
    # predictions = classifier.classify(images, text_features)

    print("\nAll tests passed!")