36. Self-Supervised Learning

์ด์ „: CLIP (Contrastive Language-Image Pre-training) | ๋‹ค์Œ: ํ˜„๋Œ€ ๋”ฅ๋Ÿฌ๋‹ ์•„ํ‚คํ…์ฒ˜


36. Self-Supervised Learning

ํ•™์Šต ๋ชฉํ‘œ

  • Self-Supervised Learning ๊ฐœ๋…๊ณผ ํ•„์š”์„ฑ ์ดํ•ด
  • Contrastive Learning (SimCLR, MoCo, BYOL)
  • Masked Image Modeling (MAE)
  • ์‚ฌ์ „ํ•™์Šต ํ‘œํ˜„์˜ ์ „์ด ํ•™์Šต
  • PyTorch ๊ตฌํ˜„ ๋ฐ ์‹ค์Šต

1. Self-Supervised Learning ๊ฐœ์š”

์ •์˜์™€ ํ•„์š”์„ฑ

์ •์˜: ๋ ˆ์ด๋ธ” ์—†์ด ๋ฐ์ดํ„ฐ ์ž์ฒด์—์„œ ํ•™์Šต ์‹ ํ˜ธ ์ƒ์„ฑ

์™œ ํ•„์š”ํ•œ๊ฐ€?
1. ๋ ˆ์ด๋ธ”๋ง ๋น„์šฉ: ImageNet (1400๋งŒ์žฅ ๋ ˆ์ด๋ธ”๋ง) ๋น„์šฉ ๋ง‰๋Œ€
2. ํ’๋ถ€ํ•œ ๋น„๋ ˆ์ด๋ธ” ๋ฐ์ดํ„ฐ: ์ธํ„ฐ๋„ท์˜ ๋Œ€๋ถ€๋ถ„ ๋ฐ์ดํ„ฐ
3. ์ผ๋ฐ˜ํ™” ๋Šฅ๋ ฅ: ๋‹ค์–‘ํ•œ ๋‹ค์šด์ŠคํŠธ๋ฆผ ํƒœ์Šคํฌ์— ์ „์ด

ํŒจ๋Ÿฌ๋‹ค์ž„ ๋ณ€ํ™”:
์ง€๋„ํ•™์Šต: ๋ฐ์ดํ„ฐ + ๋ ˆ์ด๋ธ” โ†’ ๋ชจ๋ธ
์ž๊ธฐ์ง€๋„ํ•™์Šต: ๋ฐ์ดํ„ฐ โ†’ Pretext Task โ†’ ํ‘œํ˜„ โ†’ ๋‹ค์šด์ŠคํŠธ๋ฆผ ํƒœ์Šคํฌ

SSL์˜ ์ข…๋ฅ˜

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚              Self-Supervised Learning Methods               โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                              โ”‚
โ”‚  Contrastive Learning          Masked Modeling               โ”‚
โ”‚  โ”œโ”€โ”€ SimCLR                   โ”œโ”€โ”€ MAE (Vision)               โ”‚
โ”‚  โ”œโ”€โ”€ MoCo                     โ”œโ”€โ”€ BERT (NLP)                 โ”‚
โ”‚  โ”œโ”€โ”€ BYOL                     โ””โ”€โ”€ BEiT                       โ”‚
โ”‚  โ””โ”€โ”€ SimSiam                                                 โ”‚
โ”‚                                                              โ”‚
โ”‚  Clustering                    Generative                    โ”‚
โ”‚  โ”œโ”€โ”€ DeepCluster              โ”œโ”€โ”€ VAE                        โ”‚
โ”‚  โ””โ”€โ”€ SwAV                     โ””โ”€โ”€ GAN                        โ”‚
โ”‚                                                              โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

2. Contrastive Learning ๊ธฐ์ดˆ

ํ•ต์‹ฌ ์•„์ด๋””์–ด

๊ฐ™์€ ์ด๋ฏธ์ง€์˜ ๋‹ค๋ฅธ augmentation โ†’ ๊ฐ€๊น๊ฒŒ (Positive)
๋‹ค๋ฅธ ์ด๋ฏธ์ง€ โ†’ ๋ฉ€๊ฒŒ (Negative)

x โ”€โ”€โ”ฌโ”€โ”€ Aug1 โ†’ view1 โ”€โ”€โ”
    โ”‚                   โ”‚ โ†’ ์œ ์‚ฌ๋„ ์ตœ๋Œ€ํ™” (positive pair)
    โ””โ”€โ”€ Aug2 โ†’ view2 โ”€โ”€โ”˜

x1 โ”€โ”€ Aug โ†’ view1 โ”€โ”
                    โ”‚ โ†’ ์œ ์‚ฌ๋„ ์ตœ์†Œํ™” (negative pair)
x2 โ”€โ”€ Aug โ†’ view2 โ”€โ”˜

InfoNCE Loss

import torch
import torch.nn.functional as F

def info_nce_loss(features, temperature=0.5):
    """InfoNCE (NT-Xent) Loss (โญโญโญ)

    features: (2N, D) - N๊ฐœ ์ด๋ฏธ์ง€์˜ ๋‘ augmentation

    Returns:
        loss: contrastive loss
    """
    batch_size = features.shape[0] // 2

    # Normalize
    features = F.normalize(features, dim=1)

    # Similarity matrix
    similarity_matrix = features @ features.T / temperature

    # Mask diagonal and same-image pairs
    labels = torch.cat([torch.arange(batch_size) for _ in range(2)])
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()

    # Mask self-similarity (diagonal)
    mask = torch.eye(labels.shape[0], dtype=torch.bool)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    # Positives: same image, different augmentation
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)

    # Negatives: different images
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    # Logits
    logits = torch.cat([positives, negatives], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long)

    return F.cross_entropy(logits, labels.to(logits.device))

3. SimCLR (Simple Framework for Contrastive Learning)

์•„ํ‚คํ…์ฒ˜

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                     SimCLR Framework                         โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                              โ”‚
โ”‚     Image x                                                  โ”‚
โ”‚        โ”‚                                                     โ”‚
โ”‚        โ–ผ                                                     โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                                            โ”‚
โ”‚  โ”‚ Data Aug    โ”‚ โ†’ ๋‘ ๊ฐœ์˜ view ์ƒ์„ฑ (t, t')                โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜                                            โ”‚
โ”‚      โ”‚     โ”‚                                                 โ”‚
โ”‚     t(x)  t'(x)                                             โ”‚
โ”‚      โ”‚     โ”‚                                                 โ”‚
โ”‚      โ–ผ     โ–ผ                                                 โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                                            โ”‚
โ”‚  โ”‚  Encoder f  โ”‚ โ†’ ํŠน์ง• ์ถ”์ถœ (ResNet ๋“ฑ)                    โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜                                            โ”‚
โ”‚      โ”‚     โ”‚                                                 โ”‚
โ”‚     h_i   h_j                                                โ”‚
โ”‚      โ”‚     โ”‚                                                 โ”‚
โ”‚      โ–ผ     โ–ผ                                                 โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                                            โ”‚
โ”‚  โ”‚Projection g โ”‚ โ†’ MLP๋กœ ์ €์ฐจ์› ์ž„๋ฒ ๋”ฉ                      โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜                                            โ”‚
โ”‚      โ”‚     โ”‚                                                 โ”‚
โ”‚     z_i   z_j                                                โ”‚
โ”‚      โ”‚     โ”‚                                                 โ”‚
โ”‚      โ””โ”€โ”€โ”ฌโ”€โ”€โ”˜                                                โ”‚
โ”‚         โ–ผ                                                    โ”‚
โ”‚    NT-Xent Loss                                              โ”‚
โ”‚                                                              โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

PyTorch ๊ตฌํ˜„

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet50

class SimCLR(nn.Module):
    """SimCLR Model (โญโญโญโญ)"""
    def __init__(self, base_encoder=resnet50, projection_dim=128):
        super().__init__()

        # Encoder (pretrained ResNet without FC)
        self.encoder = base_encoder(weights=None)
        self.encoder_dim = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()

        # Projection Head (MLP)
        self.projection_head = nn.Sequential(
            nn.Linear(self.encoder_dim, self.encoder_dim),
            nn.ReLU(),
            nn.Linear(self.encoder_dim, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)  # ํŠน์ง• ๋ฒกํ„ฐ
        z = self.projection_head(h)  # ํˆฌ์˜ ๋ฒกํ„ฐ
        return h, z

    def get_features(self, x):
        """๋‹ค์šด์ŠคํŠธ๋ฆผ ํƒœ์Šคํฌ์šฉ ํŠน์ง• ์ถ”์ถœ"""
        return self.encoder(x)


class SimCLRAugmentation:
    """SimCLR Data Augmentation (โญโญโญ)"""
    def __init__(self, size=224):
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=int(0.1 * size) // 2 * 2 + 1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.train_transform(x), self.train_transform(x)

SimCLR ํ•™์Šต

def train_simclr(model, train_loader, epochs=100, lr=0.3, temperature=0.5):
    """SimCLR Training Loop (โญโญโญ)"""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    device = next(model.parameters()).device

    for epoch in range(epochs):
        total_loss = 0

        for (x_i, x_j), _ in train_loader:
            x_i, x_j = x_i.to(device), x_j.to(device)

            # Forward
            _, z_i = model(x_i)
            _, z_j = model(x_j)

            # Concatenate
            z = torch.cat([z_i, z_j], dim=0)

            # Loss
            loss = info_nce_loss(z, temperature)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs}: Loss = {total_loss/len(train_loader):.4f}")

    return model

4. MoCo (Momentum Contrast)

ํ•ต์‹ฌ ์•„์ด๋””์–ด

SimCLR ๋ฌธ์ œ: ํฐ ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ ํ•„์š” (4096+)
MoCo ํ•ด๊ฒฐ: Momentum Encoder + Queue๋กœ ๋งŽ์€ negative ํ™•๋ณด

ํŠน์ง•:
1. Queue: ์ด์ „ ๋ฐฐ์น˜์˜ ์ž„๋ฒ ๋”ฉ ์ €์žฅ (65536๊ฐœ)
2. Momentum Encoder: ์ฒœ์ฒœํžˆ ์—…๋ฐ์ดํŠธ๋˜๋Š” ์ธ์ฝ”๋”

MoCo ์•„ํ‚คํ…์ฒ˜

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                     MoCo Framework                           โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                              โ”‚
โ”‚   Query                          Key                         โ”‚
โ”‚     โ”‚                              โ”‚                         โ”‚
โ”‚     โ–ผ                              โ–ผ                         โ”‚
โ”‚ โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                    โ”‚
โ”‚ โ”‚Encoder fโ”‚                  โ”‚Encoder  โ”‚                    โ”‚
โ”‚ โ”‚  (q)    โ”‚                  โ”‚f_k (mom)โ”‚ โ† momentum update  โ”‚
โ”‚ โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                  โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                    โ”‚
โ”‚      โ”‚                            โ”‚                          โ”‚
โ”‚      โ–ผ                            โ–ผ                          โ”‚
โ”‚     q_i                         k_i โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”       โ”‚
โ”‚      โ”‚                            โ”‚                 โ”‚        โ”‚
โ”‚      โ”‚                            โ–ผ                 โ–ผ        โ”‚
โ”‚      โ”‚                       โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”     โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”   โ”‚
โ”‚      โ”‚                       โ”‚  Queue  โ”‚ โ†โ”€โ”€ โ”‚ enqueue  โ”‚   โ”‚
โ”‚      โ”‚                       โ”‚ (k-)    โ”‚     โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜   โ”‚
โ”‚      โ”‚                       โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                    โ”‚
โ”‚      โ”‚                            โ”‚                          โ”‚
โ”‚      โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜                         โ”‚
โ”‚                   โ–ผ                                          โ”‚
โ”‚            InfoNCE Loss                                      โ”‚
โ”‚        (qยทk+ / qยทk-)                                        โ”‚
โ”‚                                                              โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

PyTorch ๊ตฌํ˜„

import copy

class MoCo(nn.Module):
    """Momentum Contrast (MoCo v2) (โญโญโญโญ)"""
    def __init__(self, base_encoder=resnet50, dim=128, K=65536, m=0.999, T=0.07):
        super().__init__()
        self.K = K  # Queue size
        self.m = m  # Momentum
        self.T = T  # Temperature

        # Query encoder
        self.encoder_q = base_encoder(weights=None)
        self.encoder_q.fc = nn.Sequential(
            nn.Linear(self.encoder_q.fc.in_features, dim * 4),
            nn.ReLU(),
            nn.Linear(dim * 4, dim)
        )

        # Key encoder (momentum)
        self.encoder_k = copy.deepcopy(self.encoder_q)
        for param_k in self.encoder_k.parameters():
            param_k.requires_grad = False  # ์—ญ์ „ํŒŒ X

        # Queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = F.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Key encoder์˜ momentum ์—…๋ฐ์ดํŠธ"""
        for param_q, param_k in zip(self.encoder_q.parameters(),
                                     self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Queue ์—…๋ฐ์ดํŠธ"""
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)

        # Replace oldest
        if ptr + batch_size > self.K:
            # Wrap around
            self.queue[:, ptr:] = keys[:self.K - ptr].T
            self.queue[:, :batch_size - (self.K - ptr)] = keys[self.K - ptr:].T
        else:
            self.queue[:, ptr:ptr + batch_size] = keys.T

        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[0] = ptr

    def forward(self, im_q, im_k):
        """
        Args:
            im_q: query image
            im_k: key image (different augmentation)
        """
        # Query
        q = self.encoder_q(im_q)
        q = F.normalize(q, dim=1)

        # Key (no gradient)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            k = self.encoder_k(im_k)
            k = F.normalize(k, dim=1)

        # Positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)

        # Negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # Logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1) / self.T

        # Labels: positives are at index 0
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        # Dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels


def train_moco(model, train_loader, epochs=200):
    """MoCo Training (โญโญโญ)"""
    optimizer = torch.optim.SGD(model.encoder_q.parameters(),
                                lr=0.03, momentum=0.9, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for (im_q, im_k), _ in train_loader:
            logits, labels = model(im_q.cuda(), im_k.cuda())
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

5. BYOL (Bootstrap Your Own Latent)

ํ•ต์‹ฌ ์•„์ด๋””์–ด

๋ฌธ์ œ: Negative samples์ด ์ •๋ง ํ•„์š”ํ•œ๊ฐ€?
BYOL: Negative ์—†์ด ํ•™์Šต! (Online + Target network)

ํ•ต์‹ฌ: Predictor ๋„คํŠธ์›Œํฌ + EMA๋กœ collapse ๋ฐฉ์ง€

BYOL ์•„ํ‚คํ…์ฒ˜

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                     BYOL Framework                           โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                              โ”‚
โ”‚   Online Network                 Target Network              โ”‚
โ”‚   (ํ•™์Šต๋จ)                       (EMA๋กœ ์—…๋ฐ์ดํŠธ)            โ”‚
โ”‚                                                              โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                    โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                 โ”‚
โ”‚  โ”‚Encoder  โ”‚                    โ”‚Encoder  โ”‚                 โ”‚
โ”‚  โ”‚(ฮธ)      โ”‚                    โ”‚(ฮพ)      โ”‚ โ† EMA           โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                    โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                 โ”‚
โ”‚       โ”‚                              โ”‚                       โ”‚
โ”‚       โ–ผ                              โ–ผ                       โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                    โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                 โ”‚
โ”‚  โ”‚Projectorโ”‚                    โ”‚Projectorโ”‚                 โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                    โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                 โ”‚
โ”‚       โ”‚                              โ”‚                       โ”‚
โ”‚       โ–ผ                              โ”‚                       โ”‚
โ”‚  โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”                         โ”‚                       โ”‚
โ”‚  โ”‚Predictorโ”‚ (Online๋งŒ)              โ”‚                       โ”‚
โ”‚  โ””โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”˜                         โ”‚                       โ”‚
โ”‚       โ”‚                              โ”‚                       โ”‚
โ”‚      q_ฮธ                           z_ฮพ                       โ”‚
โ”‚       โ”‚                              โ”‚                       โ”‚
โ”‚       โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ฌโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜                      โ”‚
โ”‚                  โ–ผ                                           โ”‚
โ”‚         MSE Loss (q_ฮธ, sg(z_ฮพ))                             โ”‚
โ”‚         sg = stop gradient                                   โ”‚
โ”‚                                                              โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

PyTorch ๊ตฌํ˜„

class BYOL(nn.Module):
    """Bootstrap Your Own Latent (โญโญโญโญ)"""
    def __init__(self, base_encoder=resnet50, hidden_dim=4096, proj_dim=256, pred_dim=256):
        super().__init__()

        # Online network
        encoder = base_encoder(weights=None)
        encoder_dim = encoder.fc.in_features
        encoder.fc = nn.Identity()

        self.online_encoder = encoder
        self.online_projector = nn.Sequential(
            nn.Linear(encoder_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, proj_dim)
        )
        self.predictor = nn.Sequential(
            nn.Linear(proj_dim, pred_dim),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(),
            nn.Linear(pred_dim, proj_dim)
        )

        # Target network (EMA)
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_projector = copy.deepcopy(self.online_projector)

        # Freeze target
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target(self, tau=0.99):
        """Target network EMA ์—…๋ฐ์ดํŠธ"""
        for online, target in zip(self.online_encoder.parameters(),
                                   self.target_encoder.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data

        for online, target in zip(self.online_projector.parameters(),
                                   self.target_projector.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data

    def forward(self, x1, x2):
        # Online predictions
        online_proj_1 = self.online_projector(self.online_encoder(x1))
        online_proj_2 = self.online_projector(self.online_encoder(x2))

        online_pred_1 = self.predictor(online_proj_1)
        online_pred_2 = self.predictor(online_proj_2)

        # Target projections (no gradient)
        with torch.no_grad():
            target_proj_1 = self.target_projector(self.target_encoder(x1))
            target_proj_2 = self.target_projector(self.target_encoder(x2))

        return online_pred_1, online_pred_2, target_proj_1, target_proj_2


def byol_loss(pred, target):
    """BYOL Loss: Negative Cosine Similarity (โญโญโญ)"""
    pred = F.normalize(pred, dim=-1)
    target = F.normalize(target, dim=-1)
    return 2 - 2 * (pred * target).sum(dim=-1).mean()


def train_byol(model, train_loader, epochs=100):
    """BYOL Training (โญโญโญ)"""
    optimizer = torch.optim.Adam(
        list(model.online_encoder.parameters()) +
        list(model.online_projector.parameters()) +
        list(model.predictor.parameters()),
        lr=3e-4
    )

    for epoch in range(epochs):
        for (x1, x2), _ in train_loader:
            x1, x2 = x1.cuda(), x2.cuda()

            pred1, pred2, target1, target2 = model(x1, x2)

            # ๋Œ€์นญ์  ์†์‹ค
            loss = byol_loss(pred1, target2) + byol_loss(pred2, target1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Target ์—…๋ฐ์ดํŠธ
            model.update_target(tau=0.99)

6. MAE (Masked Autoencoder)

ํ•ต์‹ฌ ์•„์ด๋””์–ด

NLP์˜ BERT ์•„์ด๋””์–ด๋ฅผ Vision์— ์ ์šฉ:
- ์ด๋ฏธ์ง€์˜ ์ผ๋ถ€(75%)๋ฅผ ๋งˆ์Šคํ‚น
- ๋งˆ์Šคํ‚น๋œ ๋ถ€๋ถ„์„ ๋ณต์›ํ•˜๋„๋ก ํ•™์Šต

์žฅ์ :
1. ๋†’์€ ๋งˆ์Šคํ‚น ๋น„์œจ โ†’ ๊ฐ•ํ•œ ํ•™์Šต ์‹ ํ˜ธ
2. ํšจ์œจ์ : ๋งˆ์Šคํ‚น๋˜์ง€ ์•Š์€ ํŒจ์น˜๋งŒ ์ธ์ฝ”๋”ฉ
3. ViT์™€ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ๊ฒฐํ•ฉ

MAE ์•„ํ‚คํ…์ฒ˜

โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                     MAE Architecture                         โ”‚
โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค
โ”‚                                                              โ”‚
โ”‚   Image โ†’ Patches                                            โ”‚
โ”‚         โ†“                                                    โ”‚
โ”‚   โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”           โ”‚
โ”‚   โ”‚ [P1] [P2] [P3] ... [P196]                   โ”‚           โ”‚
โ”‚   โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜           โ”‚
โ”‚         โ†“ Random Masking (75%)                               โ”‚
โ”‚   โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”           โ”‚
โ”‚   โ”‚ [P1] [M] [P3] [M] [M] ... [P196]            โ”‚           โ”‚
โ”‚   โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜           โ”‚
โ”‚         โ†“ (visible patches only)                             โ”‚
โ”‚   โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”           โ”‚
โ”‚   โ”‚ Encoder (ViT) - only on visible patches     โ”‚           โ”‚
โ”‚   โ”‚ [P1] [P3] ... [P196]                        โ”‚           โ”‚
โ”‚   โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜           โ”‚
โ”‚         โ†“ + Mask tokens + Position Embedding                 โ”‚
โ”‚   โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”           โ”‚
โ”‚   โ”‚ Decoder (small ViT)                         โ”‚           โ”‚
โ”‚   โ”‚ [P1] [M] [P3] [M] [M] ... [P196]            โ”‚           โ”‚
โ”‚   โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜           โ”‚
โ”‚         โ†“                                                    โ”‚
โ”‚   Reconstruct masked patches (MSE Loss)                      โ”‚
โ”‚                                                              โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

PyTorch ๊ตฌํ˜„

import random

class MAE(nn.Module):
    """Masked Autoencoder (โญโญโญโญ)"""
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 encoder_embed_dim=768, encoder_depth=12, encoder_heads=12,
                 decoder_embed_dim=512, decoder_depth=8, decoder_heads=16,
                 mask_ratio=0.75):
        super().__init__()
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        num_patches = (img_size // patch_size) ** 2

        # Encoder
        self.patch_embed = nn.Conv2d(in_channels, encoder_embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, encoder_embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, encoder_embed_dim))

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(encoder_embed_dim, encoder_heads, batch_first=True),
            num_layers=encoder_depth
        )
        self.encoder_norm = nn.LayerNorm(encoder_embed_dim)

        # Decoder
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim)
        self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, decoder_embed_dim))

        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(decoder_embed_dim, decoder_heads, batch_first=True),
            num_layers=decoder_depth
        )
        self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels)

    def random_masking(self, x, mask_ratio):
        """Random masking (โญโญโญ)"""
        B, N, D = x.shape
        len_keep = int(N * (1 - mask_ratio))

        # ๋žœ๋ค ์…”ํ”Œ
        noise = torch.rand(B, N, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep first len_keep
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))

        # Generate mask
        mask = torch.ones([B, N], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward(self, x):
        # Patch embedding
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)

        # Add position embedding (without cls)
        x = x + self.pos_embed[:, 1:, :]

        # Masking
        x, mask, ids_restore = self.random_masking(x, self.mask_ratio)

        # Append cls token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
        x = torch.cat([cls_tokens, x], dim=1)

        # Encoder
        x = self.encoder(x)
        x = self.encoder_norm(x)

        # Decoder embed
        x = self.decoder_embed(x)

        # Append mask tokens
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # Add position embedding
        x = x + self.decoder_pos_embed

        # Decoder
        x = self.decoder(x)
        x = self.decoder_norm(x)

        # Predictor
        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # remove cls token

        return x, mask

    def loss(self, pred, target, mask):
        """MAE Loss: MSE on masked patches only (โญโญโญ)"""
        target = self.patchify(target)
        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # mean per patch
        loss = (loss * mask).sum() / mask.sum()  # mean on masked
        return loss

    def patchify(self, imgs):
        """์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜๋กœ ๋ณ€ํ™˜"""
        p = self.patch_size
        B, C, H, W = imgs.shape
        h = w = H // p
        x = imgs.reshape(B, C, h, p, w, p)
        x = x.permute(0, 2, 4, 3, 5, 1)
        x = x.reshape(B, h * w, p * p * C)
        return x

MAE ํ•™์Šต

def train_mae(model, train_loader, epochs=400, lr=1.5e-4):
    """MAE Training (โญโญโญ)"""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                   betas=(0.9, 0.95), weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    for epoch in range(epochs):
        total_loss = 0

        for images, _ in train_loader:
            images = images.cuda()

            pred, mask = model(images)
            loss = model.loss(pred, images, mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        scheduler.step()
        print(f"Epoch {epoch+1}/{epochs}: Loss = {total_loss/len(train_loader):.4f}")

7. Linear Evaluation

ํ‘œํ˜„ ํ’ˆ์งˆ ํ‰๊ฐ€

def linear_evaluation(encoder, train_loader, test_loader, num_classes=10):
    """SSL ํ‘œํ˜„์˜ Linear Evaluation (โญโญโญ)"""
    # Encoder freeze
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad = False

    # Linear classifier
    feature_dim = 2048  # ResNet-50 ๊ธฐ์ค€
    classifier = nn.Linear(feature_dim, num_classes).cuda()

    optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)

    # Extract features
    def extract_features(loader):
        features, labels = [], []
        with torch.no_grad():
            for images, targets in loader:
                feat = encoder(images.cuda())
                features.append(feat.cpu())
                labels.append(targets)
        return torch.cat(features), torch.cat(labels)

    train_features, train_labels = extract_features(train_loader)
    test_features, test_labels = extract_features(test_loader)

    # Train linear classifier
    train_dataset = torch.utils.data.TensorDataset(train_features, train_labels)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)

    for epoch in range(100):
        for features, labels in train_loader:
            features, labels = features.cuda(), labels.cuda()
            output = classifier(features)
            loss = F.cross_entropy(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Evaluate
    classifier.eval()
    with torch.no_grad():
        output = classifier(test_features.cuda())
        pred = output.argmax(dim=1).cpu()
        accuracy = (pred == test_labels).float().mean().item()

    return accuracy * 100

8. ๋ฐฉ๋ฒ•๋ก  ๋น„๊ต

ํŠน์„ฑ ๋น„๊ต

๋ฐฉ๋ฒ• Negative ๋ฐฐ์น˜ ์‚ฌ์ด์ฆˆ ์ฃผ์š” ํŠน์ง•
SimCLR ํ•„์š” 4096+ ๋‹จ์ˆœ, ๊ฐ•๋ ฅํ•œ augmentation
MoCo ํ•„์š” (Queue) 256 ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์ 
BYOL ๋ถˆํ•„์š” 256 Predictor + EMA
SimSiam ๋ถˆํ•„์š” 256 BYOL ๋‹จ์ˆœํ™” (EMA ์—†์Œ)
MAE ๋ถˆํ•„์š” 256 Reconstruction

์„ฑ๋Šฅ ๋น„๊ต (ImageNet Linear Probe)

SimCLR (ResNet-50, 8192 batch): 69.3%
MoCo v2 (ResNet-50): 71.1%
BYOL (ResNet-50): 74.3%
MAE (ViT-Base): 67.8% โ†’ Fine-tune: 83.6%

์ •๋ฆฌ

ํ•ต์‹ฌ ๊ฐœ๋…

  1. Contrastive Learning: Positive/Negative ์Œ์œผ๋กœ ํ•™์Šต
  2. InfoNCE Loss: ๋Œ€์กฐ ์†์‹ค ํ•จ์ˆ˜
  3. Momentum Encoder: ๋А๋ฆฌ๊ฒŒ ์—…๋ฐ์ดํŠธ๋˜๋Š” target
  4. Masked Modeling: ์ผ๋ถ€ ์ž…๋ ฅ์„ ๋ณต์›ํ•˜๋„๋ก ํ•™์Šต
  5. Linear Evaluation: ๊ณ ์ •๋œ ํ‘œํ˜„ ์œ„์˜ ์„ ํ˜• ๋ถ„๋ฅ˜๊ธฐ ์„ฑ๋Šฅ

์„ ํƒ ๊ฐ€์ด๋“œ

๋Œ€๊ทœ๋ชจ ๋ฐฐ์น˜ ๊ฐ€๋Šฅ: SimCLR (๋‹จ์ˆœํ•˜๊ณ  ํšจ๊ณผ์ )
์ œํ•œ๋œ ๋ฆฌ์†Œ์Šค: MoCo (Queue๋กœ ํšจ์œจ์ )
Negative ์—†์ด: BYOL, SimSiam
ViT ๊ธฐ๋ฐ˜: MAE (๋ณต์› ๊ธฐ๋ฐ˜)

์‹ค์ „ ํŒ

# 1. Data Augmentation์ด ํ•ต์‹ฌ
# - ๊ฐ•ํ•œ augmentation = ๋” ์ข‹์€ ํ‘œํ˜„

# 2. ์˜จ๋„ ํŒŒ๋ผ๋ฏธํ„ฐ ์ฃผ์˜
# - ๋„ˆ๋ฌด ๋‚ฎ์œผ๋ฉด ํ•™์Šต ๋ถˆ์•ˆ์ •
# - ๋„ˆ๋ฌด ๋†’์œผ๋ฉด ํ•™์Šต ์‹ ํ˜ธ ์•ฝํ•จ

# 3. ๊ธด ํ•™์Šต ํ•„์š”
# - ์ตœ์†Œ 200-800 epochs ๊ถŒ์žฅ

# 4. Linear eval vs Fine-tune
# - Linear: ํ‘œํ˜„ ํ’ˆ์งˆ ํ‰๊ฐ€
# - Fine-tune: ์‹ค์ œ ์„ฑ๋Šฅ (๋” ๋†’์Œ)

์ฐธ๊ณ  ์ž๋ฃŒ

  • SimCLR: https://arxiv.org/abs/2002.05709
  • MoCo: https://arxiv.org/abs/1911.05722
  • BYOL: https://arxiv.org/abs/2006.07733
  • MAE: https://arxiv.org/abs/2111.06377
  • SimSiam: https://arxiv.org/abs/2011.10566
to navigate between lessons