36. Self-Supervised Learning¶
Learning Objectives¶
- Understand the concept of Self-Supervised Learning and why it is needed
- Contrastive Learning (SimCLR, MoCo, BYOL)
- Masked Image Modeling (MAE)
- Pre-trained representations and transfer learning
- PyTorch implementations and hands-on practice
1. Self-Supervised Learning Overview¶
Definition and Motivation¶
Definition: generating training signals from the data itself, without human labels.
Why is it needed?
1. Labeling cost: labeling datasets like ImageNet (14 million labeled images) is enormously expensive
2. Abundant unlabeled data: most data on the internet has no labels
3. Generalization: learned representations transfer to many downstream tasks
Paradigm shift:
Supervised learning: data + labels → model
Self-supervised learning: data → pretext task → representations → downstream tasks
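As a concrete illustration of a pretext task, here is a minimal, hypothetical example (rotation prediction, which is not one of the methods covered below): the "label" is the rotation applied to the image, so supervision comes for free from the data itself.

```python
import torch

# Hypothetical minimal pretext task: rotation prediction. The supervision
# signal (which multiple of 90° was applied) is generated from the data.
def make_rotation_batch(images):
    ks = torch.randint(0, 4, (images.shape[0],))          # free "labels" 0..3
    rotated = torch.stack([torch.rot90(img, int(k), dims=(1, 2))
                           for img, k in zip(images, ks)])
    return rotated, ks

images = torch.randn(8, 3, 32, 32)
rotated, labels = make_rotation_batch(images)
print(rotated.shape, labels.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])
```

A classifier trained to predict `labels` from `rotated` must learn object structure, which is the whole point of a pretext task.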
Types of SSL¶
┌──────────────────────────────────┐
│ Self-Supervised Learning Methods │
└──────────────────────────────────┘

 Contrastive Learning      Masked Modeling
 ├── SimCLR                ├── MAE (Vision)
 ├── MoCo                  ├── BERT (NLP)
 ├── BYOL                  └── BEiT
 └── SimSiam

 Clustering                Generative
 ├── DeepCluster           ├── VAE
 └── SwAV                  └── GAN
2. Contrastive Learning Basics¶
Core Idea¶
Different augmentations of the same image → pulled close together (positive)
Different images → pushed far apart (negative)
x ──┬── Aug1 → view1 ──┐
    │                  ├─ maximize similarity (positive pair)
    └── Aug2 → view2 ──┘

x1 ── Aug → view1 ──┐
                    ├─ minimize similarity (negative pair)
x2 ── Aug → view2 ──┘
InfoNCE Loss¶
import torch
import torch.nn.functional as F

def info_nce_loss(features, temperature=0.5):
    """InfoNCE (NT-Xent) Loss (⭐⭐⭐)

    Args:
        features: (2N, D) - two augmentations of N images, ordered so that
            rows i and i+N come from the same image
    Returns:
        loss: contrastive loss
    """
    device = features.device
    batch_size = features.shape[0] // 2
    # Normalize embeddings to unit length
    features = F.normalize(features, dim=1)
    # Cosine similarity matrix, scaled by temperature
    similarity_matrix = features @ features.T / temperature
    # Label matrix: True where two rows come from the same image
    labels = torch.cat([torch.arange(batch_size) for _ in range(2)]).to(device)
    labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
    # Mask self-similarity (diagonal)
    mask = torch.eye(labels.shape[0], dtype=torch.bool, device=device)
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
    # Positives: same image, different augmentation
    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    # Negatives: different images
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
    # Cross-entropy with the positive logit at index 0
    logits = torch.cat([positives, negatives], dim=1)
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
    return F.cross_entropy(logits, targets)
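For intuition about the loss above, here is a self-contained single-query check (toy vectors, illustrative names): InfoNCE is exactly cross-entropy with the positive logit placed at index 0 among all candidates.

```python
import torch
import torch.nn.functional as F

# One query, one positive, ten negatives: InfoNCE reduces to cross-entropy
# with the positive logit at index 0.
torch.manual_seed(0)
q    = F.normalize(torch.tensor([1.0, 0.0]), dim=0)   # query embedding
pos  = F.normalize(torch.tensor([0.9, 0.1]), dim=0)   # positive (similar view)
negs = F.normalize(torch.randn(10, 2), dim=1)         # random negatives

temperature = 0.5
logits = torch.cat([(q @ pos).view(1), negs @ q]) / temperature
loss = F.cross_entropy(logits.unsqueeze(0), torch.zeros(1, dtype=torch.long))
print(loss.item())
```

Making the positive more similar to the query (or the negatives less similar) drives this loss toward zero.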
3. SimCLR (Simple Framework for Contrastive Learning)¶
Architecture¶
┌──────────────────┐
│ SimCLR Framework │
└──────────────────┘

     Image x
        │
        ▼
  ┌───────────┐
  │ Data Aug  │ ← generate two views (t, t')
  └───────────┘
    t(x)  t'(x)
      │     │
      ▼     ▼
  ┌───────────┐
  │ Encoder f │ ← feature extraction (e.g. ResNet)
  └───────────┘
     h_i   h_j
      │     │
      ▼     ▼
  ┌──────────────┐
  │ Projection g │ ← MLP to a low-dimensional embedding
  └──────────────┘
     z_i   z_j
      └──┬──┘
         ▼
    NT-Xent Loss
PyTorch Implementation¶
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet50
class SimCLR(nn.Module):
    """SimCLR model (⭐⭐⭐⭐)"""
    def __init__(self, base_encoder=resnet50, projection_dim=128):
        super().__init__()
        # Encoder: ResNet backbone with the classification head removed
        self.encoder = base_encoder(weights=None)
        self.encoder_dim = self.encoder.fc.in_features
        self.encoder.fc = nn.Identity()
        # Projection head (MLP)
        self.projection_head = nn.Sequential(
            nn.Linear(self.encoder_dim, self.encoder_dim),
            nn.ReLU(),
            nn.Linear(self.encoder_dim, projection_dim)
        )

    def forward(self, x):
        h = self.encoder(x)              # representation
        z = self.projection_head(h)      # projection for the contrastive loss
        return h, z

    def get_features(self, x):
        """Extract features for downstream tasks"""
        return self.encoder(x)

class SimCLRAugmentation:
    """SimCLR data augmentation (⭐⭐⭐)"""
    def __init__(self, size=224):
        self.train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=int(0.1 * size) // 2 * 2 + 1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        # Return two independently augmented views of the same image
        return self.train_transform(x), self.train_transform(x)
Training SimCLR¶
def train_simclr(model, train_loader, epochs=100, lr=0.3, temperature=0.5):
    """SimCLR training loop (⭐⭐⭐)"""
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
device = next(model.parameters()).device
for epoch in range(epochs):
total_loss = 0
for (x_i, x_j), _ in train_loader:
x_i, x_j = x_i.to(device), x_j.to(device)
# Forward
_, z_i = model(x_i)
_, z_j = model(x_j)
# Concatenate
z = torch.cat([z_i, z_j], dim=0)
# Loss
loss = info_nce_loss(z, temperature)
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
print(f"Epoch {epoch+1}/{epochs}: Loss = {total_loss/len(train_loader):.4f}")
return model
4. MoCo (Momentum Contrast)¶
Core Idea¶
SimCLR's problem: it needs very large batches (4096+) to supply enough negatives.
MoCo's solution: a momentum encoder plus a queue provides many negatives cheaply.
Key features:
1. Queue: stores embeddings from previous batches (65,536 entries)
2. Momentum encoder: a key encoder updated slowly via an exponential moving average
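The momentum update θ_k ← m·θ_k + (1 − m)·θ_q can be seen in a tiny standalone sketch: with m = 0.999, the key encoder tracks the query encoder with a time constant of roughly 1/(1 − m) = 1000 steps, which is what keeps the queued keys consistent with each other.

```python
# EMA update used for the key encoder: theta_k <- m*theta_k + (1-m)*theta_q.
# Each encoder is reduced to a single scalar weight for illustration.
m = 0.999
theta_q, theta_k = 1.0, 0.0
for _ in range(1000):
    theta_k = m * theta_k + (1 - m) * theta_q
print(theta_k)  # ≈ 0.632 = 1 - 0.999**1000, one "time constant" elapsed
```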
MoCo Architecture¶
┌────────────────┐
│ MoCo Framework │
└────────────────┘

    Query                    Key
      │                       │
      ▼                       ▼
  ┌───────────┐          ┌───────────┐
  │ Encoder   │          │ Encoder   │
  │ f_q       │          │ f_k (mom) │ ← momentum update
  └─────┬─────┘          └─────┬─────┘
        │                      │
        ▼                      ▼
        q                      k ── enqueue ──▶ ┌─────────┐
        │                      │                │  Queue  │
        │                      │                │  (k-)   │
        │                      │                └────┬────┘
        │                      │                     │
        └──────────┬───────────┴─────────────────────┘
                   ▼
             InfoNCE Loss
            (q·k+ vs q·k-)
PyTorch Implementation¶
import copy
class MoCo(nn.Module):
    """Momentum Contrast (MoCo v2) (⭐⭐⭐⭐)"""
def __init__(self, base_encoder=resnet50, dim=128, K=65536, m=0.999, T=0.07):
super().__init__()
self.K = K # Queue size
self.m = m # Momentum
self.T = T # Temperature
# Query encoder
self.encoder_q = base_encoder(weights=None)
self.encoder_q.fc = nn.Sequential(
nn.Linear(self.encoder_q.fc.in_features, dim * 4),
nn.ReLU(),
nn.Linear(dim * 4, dim)
)
# Key encoder (momentum)
self.encoder_k = copy.deepcopy(self.encoder_q)
        for param_k in self.encoder_k.parameters():
            param_k.requires_grad = False  # updated by momentum, not backprop
# Queue
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum (EMA) update of the key encoder"""
for param_q, param_k in zip(self.encoder_q.parameters(),
self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)
@torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Update the queue: enqueue the newest keys over the oldest"""
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
# Replace oldest
if ptr + batch_size > self.K:
# Wrap around
self.queue[:, ptr:] = keys[:self.K - ptr].T
self.queue[:, :batch_size - (self.K - ptr)] = keys[self.K - ptr:].T
else:
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
def forward(self, im_q, im_k):
"""
Args:
im_q: query image
im_k: key image (different augmentation)
"""
# Query
q = self.encoder_q(im_q)
q = F.normalize(q, dim=1)
# Key (no gradient)
with torch.no_grad():
self._momentum_update_key_encoder()
k = self.encoder_k(im_k)
k = F.normalize(k, dim=1)
# Positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# Negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# Logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1) / self.T
# Labels: positives are at index 0
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
# Dequeue and enqueue
self._dequeue_and_enqueue(k)
return logits, labels
def train_moco(model, train_loader, epochs=200):
    """MoCo training loop (⭐⭐⭐)"""
optimizer = torch.optim.SGD(model.encoder_q.parameters(),
lr=0.03, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
for (im_q, im_k), _ in train_loader:
logits, labels = model(im_q.cuda(), im_k.cuda())
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
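The wrap-around logic of `_dequeue_and_enqueue` can be verified standalone on a toy queue (K = 8, batches of 3): the third batch spills past the end of the queue and overwrites the oldest slot.

```python
import torch

# Same circular-queue update as _dequeue_and_enqueue, on a toy queue
K, dim = 8, 2
queue = torch.zeros(dim, K)
ptr = 0

def enqueue(keys):
    global ptr
    bs = keys.shape[0]
    if ptr + bs > K:  # wrap around: split the batch across the boundary
        queue[:, ptr:] = keys[:K - ptr].T
        queue[:, :bs - (K - ptr)] = keys[K - ptr:].T
    else:
        queue[:, ptr:ptr + bs] = keys.T
    ptr = (ptr + bs) % K

enqueue(torch.full((3, dim), 1.0))  # fills slots 0-2
enqueue(torch.full((3, dim), 2.0))  # fills slots 3-5
enqueue(torch.full((3, dim), 3.0))  # fills slots 6-7, wraps into slot 0
print(queue[0].tolist())  # [3.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0]
```

Note that the original MoCo code sidesteps the wrap case by asserting that K is divisible by the batch size; the version above handles it explicitly.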
5. BYOL (Bootstrap Your Own Latent)¶
Core Idea¶
Question: are negative samples actually necessary?
BYOL: learns without any negatives (online network + target network).
Key: a predictor network plus EMA target updates prevent representation collapse.
BYOL Architecture¶
┌────────────────┐
│ BYOL Framework │
└────────────────┘

  Online Network           Target Network
  (trained by SGD)         (updated by EMA)

  ┌───────────┐            ┌───────────┐
  │ Encoder θ │            │ Encoder ξ │ ← EMA
  └─────┬─────┘            └─────┬─────┘
        │                        │
        ▼                        ▼
  ┌───────────┐            ┌───────────┐
  │ Projector │            │ Projector │
  └─────┬─────┘            └─────┬─────┘
        │                        │
        ▼                        │
  ┌───────────┐                  │
  │ Predictor │ (online only)    │
  └─────┬─────┘                  │
        │                        │
       q_θ                      z_ξ
        │                        │
        └───────────┬────────────┘
                    ▼
         MSE Loss (q_θ, sg(z_ξ))
            sg = stop gradient
PyTorch Implementation¶
class BYOL(nn.Module):
    """Bootstrap Your Own Latent (⭐⭐⭐⭐)"""
def __init__(self, base_encoder=resnet50, hidden_dim=4096, proj_dim=256, pred_dim=256):
super().__init__()
# Online network
encoder = base_encoder(weights=None)
encoder_dim = encoder.fc.in_features
encoder.fc = nn.Identity()
self.online_encoder = encoder
self.online_projector = nn.Sequential(
nn.Linear(encoder_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, proj_dim)
)
self.predictor = nn.Sequential(
nn.Linear(proj_dim, pred_dim),
nn.BatchNorm1d(pred_dim),
nn.ReLU(),
nn.Linear(pred_dim, proj_dim)
)
# Target network (EMA)
self.target_encoder = copy.deepcopy(self.online_encoder)
self.target_projector = copy.deepcopy(self.online_projector)
# Freeze target
for param in self.target_encoder.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
@torch.no_grad()
    def update_target(self, tau=0.99):
        """EMA update of the target network"""
for online, target in zip(self.online_encoder.parameters(),
self.target_encoder.parameters()):
target.data = tau * target.data + (1 - tau) * online.data
for online, target in zip(self.online_projector.parameters(),
self.target_projector.parameters()):
target.data = tau * target.data + (1 - tau) * online.data
def forward(self, x1, x2):
# Online predictions
online_proj_1 = self.online_projector(self.online_encoder(x1))
online_proj_2 = self.online_projector(self.online_encoder(x2))
online_pred_1 = self.predictor(online_proj_1)
online_pred_2 = self.predictor(online_proj_2)
# Target projections (no gradient)
with torch.no_grad():
target_proj_1 = self.target_projector(self.target_encoder(x1))
target_proj_2 = self.target_projector(self.target_encoder(x2))
return online_pred_1, online_pred_2, target_proj_1, target_proj_2
def byol_loss(pred, target):
    """BYOL loss: negative cosine similarity (⭐⭐⭐)"""
pred = F.normalize(pred, dim=-1)
target = F.normalize(target, dim=-1)
return 2 - 2 * (pred * target).sum(dim=-1).mean()
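A quick numeric check of the loss above (restated here so the snippet runs on its own): since it equals 2 − 2·cos(pred, target), it lies in [0, 4], vanishes for aligned vectors, and ignores scale.

```python
import torch
import torch.nn.functional as F

# byol_loss = 2 - 2*cos(pred, target): bounded in [0, 4], scale-invariant
def byol_loss(pred, target):
    pred = F.normalize(pred, dim=-1)
    target = F.normalize(target, dim=-1)
    return 2 - 2 * (pred * target).sum(dim=-1).mean()

v = torch.randn(4, 8)
print(byol_loss(v, v).item())      # ≈ 0.0 (same direction)
print(byol_loss(v, -v).item())     # ≈ 4.0 (opposite direction)
print(byol_loss(v, 3 * v).item())  # ≈ 0.0 (invariant to scale)
```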
def train_byol(model, train_loader, epochs=100):
    """BYOL training loop (⭐⭐⭐)"""
optimizer = torch.optim.Adam(
list(model.online_encoder.parameters()) +
list(model.online_projector.parameters()) +
list(model.predictor.parameters()),
lr=3e-4
)
for epoch in range(epochs):
for (x1, x2), _ in train_loader:
x1, x2 = x1.cuda(), x2.cuda()
pred1, pred2, target1, target2 = model(x1, x2)
            # Symmetric loss over both view orderings
loss = byol_loss(pred1, target2) + byol_loss(pred2, target1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
            # EMA update of the target network
            model.update_target(tau=0.99)
6. MAE (Masked Autoencoder)¶
Core Idea¶
Applies BERT's masked-prediction idea from NLP to vision:
- Mask a large fraction (75%) of the image patches
- Train the model to reconstruct the masked patches
Advantages:
1. High masking ratio → a strong learning signal
2. Efficient: only the unmasked patches are encoded
3. Combines naturally with ViT
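The efficiency claim (point 2) is easy to quantify: with 224×224 images and 16×16 patches there are 196 tokens, a 75% mask leaves 49, and since self-attention cost grows quadratically with sequence length, the encoder's attention is roughly 16× cheaper.

```python
# Back-of-envelope cost of encoding only visible patches (224px, 16px patches)
num_patches = (224 // 16) ** 2           # 196 tokens in the full image
visible = int(num_patches * (1 - 0.75))  # tokens surviving 75% masking
print(visible)                           # 49
print((num_patches / visible) ** 2)      # 16.0x cheaper self-attention
```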
MAE Architecture¶
┌──────────────────┐
│ MAE Architecture │
└──────────────────┘

  Image → Patches
     │
  ┌───────────────────────────────────┐
  │ [P1] [P2] [P3] ... [P196]         │
  └───────────────────────────────────┘
     │ Random masking (75%)
  ┌───────────────────────────────────┐
  │ [P1] [M] [P3] [M] [M] ... [P196]  │
  └───────────────────────────────────┘
     │ (visible patches only)
  ┌───────────────────────────────────┐
  │ Encoder (ViT), visible patches:   │
  │ [P1] [P3] ... [P196]              │
  └───────────────────────────────────┘
     │ + mask tokens + position embedding
  ┌───────────────────────────────────┐
  │ Decoder (small ViT):              │
  │ [P1] [M] [P3] [M] [M] ... [P196]  │
  └───────────────────────────────────┘
     │
  Reconstruct masked patches (MSE loss)
PyTorch Implementation¶
class MAE(nn.Module):
    """Masked Autoencoder (⭐⭐⭐⭐)"""
def __init__(self, img_size=224, patch_size=16, in_channels=3,
encoder_embed_dim=768, encoder_depth=12, encoder_heads=12,
decoder_embed_dim=512, decoder_depth=8, decoder_heads=16,
mask_ratio=0.75):
super().__init__()
self.patch_size = patch_size
self.mask_ratio = mask_ratio
num_patches = (img_size // patch_size) ** 2
# Encoder
self.patch_embed = nn.Conv2d(in_channels, encoder_embed_dim,
kernel_size=patch_size, stride=patch_size)
self.cls_token = nn.Parameter(torch.randn(1, 1, encoder_embed_dim))
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, encoder_embed_dim))
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(encoder_embed_dim, encoder_heads, batch_first=True),
num_layers=encoder_depth
)
self.encoder_norm = nn.LayerNorm(encoder_embed_dim)
# Decoder
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim)
self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, decoder_embed_dim))
self.decoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(decoder_embed_dim, decoder_heads, batch_first=True),
num_layers=decoder_depth
)
self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels)
def random_masking(self, x, mask_ratio):
        """Randomly mask a fraction of patch tokens (⭐⭐⭐)"""
B, N, D = x.shape
len_keep = int(N * (1 - mask_ratio))
        # Random shuffle via noise + argsort
        noise = torch.rand(B, N, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# Keep first len_keep
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
# Generate mask
mask = torch.ones([B, N], device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward(self, x):
# Patch embedding
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2) # (B, N, D)
# Add position embedding (without cls)
x = x + self.pos_embed[:, 1:, :]
# Masking
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
# Append cls token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
x = torch.cat([cls_tokens, x], dim=1)
# Encoder
x = self.encoder(x)
x = self.encoder_norm(x)
# Decoder embed
x = self.decoder_embed(x)
# Append mask tokens
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# Add position embedding
x = x + self.decoder_pos_embed
# Decoder
x = self.decoder(x)
x = self.decoder_norm(x)
# Predictor
x = self.decoder_pred(x)
x = x[:, 1:, :] # remove cls token
return x, mask
def loss(self, pred, target, mask):
        """MAE loss: MSE on masked patches only (⭐⭐⭐)"""
target = self.patchify(target)
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # mean per patch
loss = (loss * mask).sum() / mask.sum() # mean on masked
return loss
    def patchify(self, imgs):
        """Convert an image into flattened patch vectors"""
p = self.patch_size
B, C, H, W = imgs.shape
h = w = H // p
x = imgs.reshape(B, C, h, p, w, p)
x = x.permute(0, 2, 4, 3, 5, 1)
x = x.reshape(B, h * w, p * p * C)
return x
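The index bookkeeping in `random_masking` is the subtlest part of MAE. The same argsort trick, re-created standalone on toy sizes (names here are illustrative), shows the shapes involved: `ids_restore` is the inverse permutation that later puts mask tokens back in their original positions.

```python
import torch

# Same argsort-based masking as MAE.random_masking, on toy sizes
B, N, D, mask_ratio = 2, 16, 4, 0.75
x = torch.randn(B, N, D)
len_keep = int(N * (1 - mask_ratio))             # 4 visible tokens

noise = torch.rand(B, N)
ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
ids_restore = torch.argsort(ids_shuffle, dim=1)  # its inverse permutation

ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

mask = torch.ones(B, N)                          # 1 = masked, 0 = visible
mask[:, :len_keep] = 0
mask = torch.gather(mask, 1, ids_restore)        # back to original token order

print(x_masked.shape)   # torch.Size([2, 4, 4])
print(mask.sum(dim=1))  # tensor([12., 12.])  -> 12 of 16 tokens masked
```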
Training MAE¶
def train_mae(model, train_loader, epochs=400, lr=1.5e-4):
    """MAE training loop (⭐⭐⭐)"""
optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
betas=(0.9, 0.95), weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
for epoch in range(epochs):
total_loss = 0
for images, _ in train_loader:
images = images.cuda()
pred, mask = model(images)
loss = model.loss(pred, images, mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
print(f"Epoch {epoch+1}/{epochs}: Loss = {total_loss/len(train_loader):.4f}")
7. Linear Evaluation¶
Evaluating Representation Quality¶
def linear_evaluation(encoder, train_loader, test_loader, num_classes=10):
    """Linear evaluation of SSL representations (⭐⭐⭐)"""
# Encoder freeze
encoder.eval()
for param in encoder.parameters():
param.requires_grad = False
# Linear classifier
    feature_dim = 2048  # for ResNet-50
classifier = nn.Linear(feature_dim, num_classes).cuda()
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
# Extract features
def extract_features(loader):
features, labels = [], []
with torch.no_grad():
for images, targets in loader:
feat = encoder(images.cuda())
features.append(feat.cpu())
labels.append(targets)
return torch.cat(features), torch.cat(labels)
train_features, train_labels = extract_features(train_loader)
test_features, test_labels = extract_features(test_loader)
# Train linear classifier
train_dataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
for epoch in range(100):
for features, labels in train_loader:
features, labels = features.cuda(), labels.cuda()
output = classifier(features)
loss = F.cross_entropy(output, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Evaluate
classifier.eval()
with torch.no_grad():
output = classifier(test_features.cuda())
pred = output.argmax(dim=1).cpu()
accuracy = (pred == test_labels).float().mean().item()
return accuracy * 100
8. Method Comparison¶
Feature Comparison¶
| Method | Negatives | Batch size | Key characteristics |
|---|---|---|---|
| SimCLR | Required | 4096+ | Simple; relies on strong augmentations |
| MoCo | Required (queue) | 256 | Memory-efficient |
| BYOL | Not needed | 256 | Predictor + EMA |
| SimSiam | Not needed | 256 | Simplified BYOL (no EMA) |
| MAE | Not needed | 256 | Reconstruction-based |
Performance Comparison (ImageNet Linear Probe)¶
SimCLR (ResNet-50, 8192 batch): 69.3%
MoCo v2 (ResNet-50): 71.1%
BYOL (ResNet-50): 74.3%
MAE (ViT-Base): 67.8% → fine-tuned: 83.6%
Summary¶
Key Concepts¶
- Contrastive learning: learn from positive/negative pairs
- InfoNCE loss: the standard contrastive loss function
- Momentum encoder: a slowly updated target network
- Masked modeling: learn by reconstructing masked parts of the input
- Linear evaluation: train a linear classifier on frozen representations
Selection Guide¶
Large batches feasible: SimCLR (simple and effective)
Limited resources: MoCo (efficient thanks to the queue)
No negatives wanted: BYOL, SimSiam
ViT-based: MAE (reconstruction-based)
Practical Tips¶
# 1. Data augmentation is the key ingredient
#    - Stronger augmentations → better representations
# 2. Watch the temperature parameter
#    - Too low: unstable training
#    - Too high: weak learning signal
# 3. Long training is required
#    - At least 200-800 epochs is recommended
# 4. Linear eval vs. fine-tuning
#    - Linear eval: measures representation quality
#    - Fine-tuning: actual downstream performance (higher)
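Tip 2 can be seen directly: a lower temperature sharpens the softmax over candidates while a higher one flattens it (toy similarity values below).

```python
import torch
import torch.nn.functional as F

# Probability assigned to the positive (sim 0.9) vs two negatives (0.5, 0.1)
sims = torch.tensor([0.9, 0.5, 0.1])
ps = []
for t in (0.05, 0.5, 5.0):
    p = F.softmax(sims / t, dim=0)[0].item()
    ps.append(p)
    print(f"temperature={t}: p(positive)={p:.3f}")
```

A very low temperature makes the positive dominate (near-deterministic, unstable gradients); a very high one pushes the distribution toward uniform, weakening the learning signal.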
References¶
- SimCLR: https://arxiv.org/abs/2002.05709
- MoCo: https://arxiv.org/abs/1911.05722
- BYOL: https://arxiv.org/abs/2006.07733
- MAE: https://arxiv.org/abs/2111.06377
- SimSiam: https://arxiv.org/abs/2011.10566