36. Self-Supervised Learning
Previous: CLIP (Contrastive Language-Image Pre-training) | Next: Modern Deep Learning Architectures
36. Self-Supervised Learning¶
Learning Objectives¶
- Understand Self-Supervised Learning concepts and necessity
- Contrastive Learning (SimCLR, MoCo, BYOL)
- Masked Image Modeling (MAE)
- Transfer learning of pre-trained representations
- PyTorch implementation and practice
1. Self-Supervised Learning Overview¶
Definition and Necessity¶
Definition: Generate learning signals from data itself without labels
Why is it needed?
1. Labeling cost: ImageNet (14 million images) labeling cost is enormous
2. Abundant unlabeled data: Most data on the internet
3. Generalization ability: Transfer to various downstream tasks
Paradigm shift:
Supervised learning: Data + Labels β Model
Self-supervised learning: Data β Pretext Task β Representation β Downstream Task
Types of SSL¶
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β Self-Supervised Learning Methods β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
β β
β Contrastive Learning Masked Modeling β
β βββ SimCLR βββ MAE (Vision) β
β βββ MoCo βββ BERT (NLP) β
β βββ BYOL βββ BEiT β
β βββ SimSiam β
β β
β Clustering Generative β
β βββ DeepCluster βββ VAE β
β βββ SwAV βββ GAN β
β β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
2. Contrastive Learning Basics¶
Core Idea¶
Different augmentations of the same image β Close (Positive)
Different images β Far (Negative)
x βββ¬ββ Aug1 β view1 βββ
β β β maximize similarity (positive pair)
βββ Aug2 β view2 βββ
x1 ββ Aug β view1 ββ
β β minimize similarity (negative pair)
x2 ββ Aug β view2 ββ
InfoNCE Loss¶
import torch
import torch.nn.functional as F
def info_nce_loss(features, temperature=0.5):
"""InfoNCE (NT-Xent) Loss (βββ)
features: (2N, D) - two augmentations of N images
Returns:
loss: contrastive loss
"""
batch_size = features.shape[0] // 2
# Normalize
features = F.normalize(features, dim=1)
# Similarity matrix
similarity_matrix = features @ features.T / temperature
# Mask diagonal and same-image pairs
labels = torch.cat([torch.arange(batch_size) for _ in range(2)])
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
# Mask self-similarity (diagonal)
mask = torch.eye(labels.shape[0], dtype=torch.bool)
labels = labels[~mask].view(labels.shape[0], -1)
similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)
# Positives: same image, different augmentation
positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
# Negatives: different images
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)
# Logits
logits = torch.cat([positives, negatives], dim=1)
labels = torch.zeros(logits.shape[0], dtype=torch.long)
return F.cross_entropy(logits, labels.to(logits.device))
3. SimCLR (Simple Framework for Contrastive Learning)¶
Architecture¶
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β SimCLR Framework β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
β β
β Image x β
β β β
β βΌ β
β βββββββββββββββ β
β β Data Aug β β generate two views (t, t') β
β βββββββββββββββ β
β β β β
β t(x) t'(x) β
β β β β
β βΌ βΌ β
β βββββββββββββββ β
β β Encoder f β β feature extraction (ResNet, etc.) β
β βββββββββββββββ β
β β β β
β h_i h_j β
β β β β
β βΌ βΌ β
β βββββββββββββββ β
β βProjection g β β MLP to low-dim embedding β
β βββββββββββββββ β
β β β β
β z_i z_j β
β β β β
β ββββ¬βββ β
β βΌ β
β NT-Xent Loss β
β β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
PyTorch Implementation¶
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet50
class SimCLR(nn.Module):
"""SimCLR Model (ββββ)"""
def __init__(self, base_encoder=resnet50, projection_dim=128):
super().__init__()
# Encoder (pretrained ResNet without FC)
self.encoder = base_encoder(weights=None)
self.encoder_dim = self.encoder.fc.in_features
self.encoder.fc = nn.Identity()
# Projection Head (MLP)
self.projection_head = nn.Sequential(
nn.Linear(self.encoder_dim, self.encoder_dim),
nn.ReLU(),
nn.Linear(self.encoder_dim, projection_dim)
)
def forward(self, x):
h = self.encoder(x) # feature vector
z = self.projection_head(h) # projection vector
return h, z
def get_features(self, x):
"""Feature extraction for downstream tasks"""
return self.encoder(x)
class SimCLRAugmentation:
"""SimCLR Data Augmentation (βββ)"""
def __init__(self, size=224):
self.train_transform = transforms.Compose([
transforms.RandomResizedCrop(size, scale=(0.08, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([
transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=int(0.1 * size) // 2 * 2 + 1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __call__(self, x):
return self.train_transform(x), self.train_transform(x)
SimCLR Training¶
def train_simclr(model, train_loader, epochs=100, lr=0.3, temperature=0.5):
"""SimCLR Training Loop (βββ)"""
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
device = next(model.parameters()).device
for epoch in range(epochs):
total_loss = 0
for (x_i, x_j), _ in train_loader:
x_i, x_j = x_i.to(device), x_j.to(device)
# Forward
_, z_i = model(x_i)
_, z_j = model(x_j)
# Concatenate
z = torch.cat([z_i, z_j], dim=0)
# Loss
loss = info_nce_loss(z, temperature)
# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
print(f"Epoch {epoch+1}/{epochs}: Loss = {total_loss/len(train_loader):.4f}")
return model
4. MoCo (Momentum Contrast)¶
Core Idea¶
SimCLR problem: Requires large batch size (4096+)
MoCo solution: Momentum Encoder + Queue for many negatives
Features:
1. Queue: Store embeddings from previous batches (65536)
2. Momentum Encoder: Slowly updated encoder
MoCo Architecture¶
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β MoCo Framework β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
β β
β Query Key β
β β β β
β βΌ βΌ β
β βββββββββββ βββββββββββ β
β βEncoder fβ βEncoder β β
β β (q) β βf_k (mom)β β momentum update β
β ββββββ¬βββββ ββββββ¬βββββ β
β β β β
β βΌ βΌ β
β q_i k_i βββββββββββββββββ β
β β β β β
β β βΌ βΌ β
β β βββββββββββ ββββββββββββ β
β β β Queue β βββ β enqueue β β
β β β (k-) β ββββββββββββ β
β β ββββββ¬βββββ β
β β β β
β ββββββββββββββ¬ββββββββββββββββ β
β βΌ β
β InfoNCE Loss β
β (qΒ·k+ / qΒ·k-) β
β β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
PyTorch Implementation¶
import copy
class MoCo(nn.Module):
"""Momentum Contrast (MoCo v2) (ββββ)"""
def __init__(self, base_encoder=resnet50, dim=128, K=65536, m=0.999, T=0.07):
super().__init__()
self.K = K # Queue size
self.m = m # Momentum
self.T = T # Temperature
# Query encoder
self.encoder_q = base_encoder(weights=None)
self.encoder_q.fc = nn.Sequential(
nn.Linear(self.encoder_q.fc.in_features, dim * 4),
nn.ReLU(),
nn.Linear(dim * 4, dim)
)
# Key encoder (momentum)
self.encoder_k = copy.deepcopy(self.encoder_q)
for param_k in self.encoder_k.parameters():
param_k.requires_grad = False # no backprop
# Queue
self.register_buffer("queue", torch.randn(dim, K))
self.queue = F.normalize(self.queue, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
"""Momentum update of key encoder"""
for param_q, param_k in zip(self.encoder_q.parameters(),
self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, keys):
"""Update queue"""
batch_size = keys.shape[0]
ptr = int(self.queue_ptr)
# Replace oldest
if ptr + batch_size > self.K:
# Wrap around
self.queue[:, ptr:] = keys[:self.K - ptr].T
self.queue[:, :batch_size - (self.K - ptr)] = keys[self.K - ptr:].T
else:
self.queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.K
self.queue_ptr[0] = ptr
def forward(self, im_q, im_k):
"""
Args:
im_q: query image
im_k: key image (different augmentation)
"""
# Query
q = self.encoder_q(im_q)
q = F.normalize(q, dim=1)
# Key (no gradient)
with torch.no_grad():
self._momentum_update_key_encoder()
k = self.encoder_k(im_k)
k = F.normalize(k, dim=1)
# Positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# Negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# Logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1) / self.T
# Labels: positives are at index 0
labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
# Dequeue and enqueue
self._dequeue_and_enqueue(k)
return logits, labels
def train_moco(model, train_loader, epochs=200):
"""MoCo Training (βββ)"""
optimizer = torch.optim.SGD(model.encoder_q.parameters(),
lr=0.03, momentum=0.9, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
for epoch in range(epochs):
for (im_q, im_k), _ in train_loader:
logits, labels = model(im_q.cuda(), im_k.cuda())
loss = criterion(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
5. BYOL (Bootstrap Your Own Latent)¶
Core Idea¶
Question: Are negative samples really necessary?
BYOL: Learn without negatives! (Online + Target network)
Key: Predictor network + EMA to prevent collapse
BYOL Architecture¶
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β BYOL Framework β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
β β
β Online Network Target Network β
β (trained) (updated with EMA) β
β β
β βββββββββββ βββββββββββ β
β βEncoder β βEncoder β β
β β(ΞΈ) β β(ΞΎ) β β EMA β
β ββββββ¬βββββ ββββββ¬βββββ β
β β β β
β βΌ βΌ β
β βββββββββββ βββββββββββ β
β βProjectorβ βProjectorβ β
β ββββββ¬βββββ ββββββ¬βββββ β
β β β β
β βΌ β β
β βββββββββββ β β
β βPredictorβ (Online only) β β
β ββββββ¬βββββ β β
β β β β
β q_ΞΈ z_ΞΎ β
β β β β
β ββββββββββββ¬ββββββββββββββββββββ β
β βΌ β
β MSE Loss (q_ΞΈ, sg(z_ΞΎ)) β
β sg = stop gradient β
β β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
PyTorch Implementation¶
class BYOL(nn.Module):
"""Bootstrap Your Own Latent (ββββ)"""
def __init__(self, base_encoder=resnet50, hidden_dim=4096, proj_dim=256, pred_dim=256):
super().__init__()
# Online network
encoder = base_encoder(weights=None)
encoder_dim = encoder.fc.in_features
encoder.fc = nn.Identity()
self.online_encoder = encoder
self.online_projector = nn.Sequential(
nn.Linear(encoder_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, proj_dim)
)
self.predictor = nn.Sequential(
nn.Linear(proj_dim, pred_dim),
nn.BatchNorm1d(pred_dim),
nn.ReLU(),
nn.Linear(pred_dim, proj_dim)
)
# Target network (EMA)
self.target_encoder = copy.deepcopy(self.online_encoder)
self.target_projector = copy.deepcopy(self.online_projector)
# Freeze target
for param in self.target_encoder.parameters():
param.requires_grad = False
for param in self.target_projector.parameters():
param.requires_grad = False
@torch.no_grad()
def update_target(self, tau=0.99):
"""Target network EMA update"""
for online, target in zip(self.online_encoder.parameters(),
self.target_encoder.parameters()):
target.data = tau * target.data + (1 - tau) * online.data
for online, target in zip(self.online_projector.parameters(),
self.target_projector.parameters()):
target.data = tau * target.data + (1 - tau) * online.data
def forward(self, x1, x2):
# Online predictions
online_proj_1 = self.online_projector(self.online_encoder(x1))
online_proj_2 = self.online_projector(self.online_encoder(x2))
online_pred_1 = self.predictor(online_proj_1)
online_pred_2 = self.predictor(online_proj_2)
# Target projections (no gradient)
with torch.no_grad():
target_proj_1 = self.target_projector(self.target_encoder(x1))
target_proj_2 = self.target_projector(self.target_encoder(x2))
return online_pred_1, online_pred_2, target_proj_1, target_proj_2
def byol_loss(pred, target):
"""BYOL Loss: Negative Cosine Similarity (βββ)"""
pred = F.normalize(pred, dim=-1)
target = F.normalize(target, dim=-1)
return 2 - 2 * (pred * target).sum(dim=-1).mean()
def train_byol(model, train_loader, epochs=100):
"""BYOL Training (βββ)"""
optimizer = torch.optim.Adam(
list(model.online_encoder.parameters()) +
list(model.online_projector.parameters()) +
list(model.predictor.parameters()),
lr=3e-4
)
for epoch in range(epochs):
for (x1, x2), _ in train_loader:
x1, x2 = x1.cuda(), x2.cuda()
pred1, pred2, target1, target2 = model(x1, x2)
# Symmetric loss
loss = byol_loss(pred1, target2) + byol_loss(pred2, target1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Update target
model.update_target(tau=0.99)
6. MAE (Masked Autoencoder)¶
Core Idea¶
Apply BERT's idea from NLP to Vision:
- Mask portion of image (75%)
- Train to reconstruct masked portions
Advantages:
1. High masking ratio β strong learning signal
2. Efficient: encode only visible patches
3. Naturally combines with ViT
MAE Architecture¶
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β MAE Architecture β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ€
β β
β Image β Patches β
β β β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β [P1] [P2] [P3] ... [P196] β β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β Random Masking (75%) β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β [P1] [M] [P3] [M] [M] ... [P196] β β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β (visible patches only) β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β Encoder (ViT) - only on visible patches β β
β β [P1] [P3] ... [P196] β β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β + Mask tokens + Position Embedding β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β Decoder (small ViT) β β
β β [P1] [M] [P3] [M] [M] ... [P196] β β
β βββββββββββββββββββββββββββββββββββββββββββββββ β
β β β
β Reconstruct masked patches (MSE Loss) β
β β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
PyTorch Implementation¶
import random
class MAE(nn.Module):
"""Masked Autoencoder (ββββ)"""
def __init__(self, img_size=224, patch_size=16, in_channels=3,
encoder_embed_dim=768, encoder_depth=12, encoder_heads=12,
decoder_embed_dim=512, decoder_depth=8, decoder_heads=16,
mask_ratio=0.75):
super().__init__()
self.patch_size = patch_size
self.mask_ratio = mask_ratio
num_patches = (img_size // patch_size) ** 2
# Encoder
self.patch_embed = nn.Conv2d(in_channels, encoder_embed_dim,
kernel_size=patch_size, stride=patch_size)
self.cls_token = nn.Parameter(torch.randn(1, 1, encoder_embed_dim))
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, encoder_embed_dim))
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(encoder_embed_dim, encoder_heads, batch_first=True),
num_layers=encoder_depth
)
self.encoder_norm = nn.LayerNorm(encoder_embed_dim)
# Decoder
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim)
self.mask_token = nn.Parameter(torch.randn(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, decoder_embed_dim))
self.decoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(decoder_embed_dim, decoder_heads, batch_first=True),
num_layers=decoder_depth
)
self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels)
def random_masking(self, x, mask_ratio):
"""Random masking (βββ)"""
B, N, D = x.shape
len_keep = int(N * (1 - mask_ratio))
# Random shuffle
noise = torch.rand(B, N, device=x.device)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# Keep first len_keep
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
# Generate mask
mask = torch.ones([B, N], device=x.device)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward(self, x):
# Patch embedding
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2) # (B, N, D)
# Add position embedding (without cls)
x = x + self.pos_embed[:, 1:, :]
# Masking
x, mask, ids_restore = self.random_masking(x, self.mask_ratio)
# Append cls token
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
cls_tokens = cls_tokens + self.pos_embed[:, :1, :]
x = torch.cat([cls_tokens, x], dim=1)
# Encoder
x = self.encoder(x)
x = self.encoder_norm(x)
# Decoder embed
x = self.decoder_embed(x)
# Append mask tokens
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).expand(-1, -1, x.shape[2]))
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# Add position embedding
x = x + self.decoder_pos_embed
# Decoder
x = self.decoder(x)
x = self.decoder_norm(x)
# Predictor
x = self.decoder_pred(x)
x = x[:, 1:, :] # remove cls token
return x, mask
def loss(self, pred, target, mask):
"""MAE Loss: MSE on masked patches only (βββ)"""
target = self.patchify(target)
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # mean per patch
loss = (loss * mask).sum() / mask.sum() # mean on masked
return loss
def patchify(self, imgs):
"""Convert images to patches"""
p = self.patch_size
B, C, H, W = imgs.shape
h = w = H // p
x = imgs.reshape(B, C, h, p, w, p)
x = x.permute(0, 2, 4, 3, 5, 1)
x = x.reshape(B, h * w, p * p * C)
return x
MAE Training¶
def train_mae(model, train_loader, epochs=400, lr=1.5e-4):
"""MAE Training (βββ)"""
optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
betas=(0.9, 0.95), weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
for epoch in range(epochs):
total_loss = 0
for images, _ in train_loader:
images = images.cuda()
pred, mask = model(images)
loss = model.loss(pred, images, mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
scheduler.step()
print(f"Epoch {epoch+1}/{epochs}: Loss = {total_loss/len(train_loader):.4f}")
7. Linear Evaluation¶
Evaluating Representation Quality¶
def linear_evaluation(encoder, train_loader, test_loader, num_classes=10):
"""SSL representation Linear Evaluation (βββ)"""
# Freeze encoder
encoder.eval()
for param in encoder.parameters():
param.requires_grad = False
# Linear classifier
feature_dim = 2048 # ResNet-50 based
classifier = nn.Linear(feature_dim, num_classes).cuda()
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
# Extract features
def extract_features(loader):
features, labels = [], []
with torch.no_grad():
for images, targets in loader:
feat = encoder(images.cuda())
features.append(feat.cpu())
labels.append(targets)
return torch.cat(features), torch.cat(labels)
train_features, train_labels = extract_features(train_loader)
test_features, test_labels = extract_features(test_loader)
# Train linear classifier
train_dataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
for epoch in range(100):
for features, labels in train_loader:
features, labels = features.cuda(), labels.cuda()
output = classifier(features)
loss = F.cross_entropy(output, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Evaluate
classifier.eval()
with torch.no_grad():
output = classifier(test_features.cuda())
pred = output.argmax(dim=1).cpu()
accuracy = (pred == test_labels).float().mean().item()
return accuracy * 100
8. Method Comparison¶
Characteristics Comparison¶
| Method | Negatives | Batch Size | Key Features |
|---|---|---|---|
| SimCLR | Required | 4096+ | Simple, strong augmentation |
| MoCo | Required (Queue) | 256 | Memory efficient |
| BYOL | Not required | 256 | Predictor + EMA |
| SimSiam | Not required | 256 | BYOL simplified (no EMA) |
| MAE | Not required | 256 | Reconstruction |
Performance Comparison (ImageNet Linear Probe)¶
SimCLR (ResNet-50, 8192 batch): 69.3%
MoCo v2 (ResNet-50): 71.1%
BYOL (ResNet-50): 74.3%
MAE (ViT-Base): 67.8% β Fine-tune: 83.6%
Summary¶
Key Concepts¶
- Contrastive Learning: Learn with positive/negative pairs
- InfoNCE Loss: Contrastive loss function
- Momentum Encoder: Slowly updated target
- Masked Modeling: Train to reconstruct partial input
- Linear Evaluation: Linear classifier performance on frozen representations
Selection Guide¶
Large batch possible: SimCLR (simple and effective)
Limited resources: MoCo (efficient with queue)
Without negatives: BYOL, SimSiam
ViT-based: MAE (reconstruction-based)
Practical Tips¶
# 1. Data Augmentation is key
# - Strong augmentation = better representation
# 2. Be careful with temperature parameter
# - Too low: unstable training
# - Too high: weak learning signal
# 3. Long training required
# - Minimum 200-800 epochs recommended
# 4. Linear eval vs Fine-tune
# - Linear: evaluate representation quality
# - Fine-tune: actual performance (higher)
References¶
- SimCLR: https://arxiv.org/abs/2002.05709
- MoCo: https://arxiv.org/abs/1911.05722
- BYOL: https://arxiv.org/abs/2006.07733
- MAE: https://arxiv.org/abs/2111.06377
- SimSiam: https://arxiv.org/abs/2011.10566