04. Pre-training Objective Functions
Overview
The pre-training objective determines which patterns a Foundation Model learns from large-scale data. The choice of objective directly affects the model's capabilities and its downstream task performance.
1. Language Modeling Paradigms
1.1 Three Main Approaches
┌───────────────────────────────────────────────────────────────────┐
│                    Language Modeling Paradigms                    │
├───────────────────────────────────────────────────────────────────┤
│                                                                   │
│  Causal LM (Autoregressive)         Masked LM (Bidirectional)     │
│  ┌───┬───┬───┬───┬───┐              ┌───┬───┬───┬───┬───┐         │
│  │ A │ B │ C │ D │ ? │              │ A │[M]│ C │[M]│ E │         │
│  └───┴───┴───┴───┴───┘              └───┴───┴───┴───┴───┘         │
│            ↓                                  ↓                   │
│    P(x_t | x_<t)                     P(x_mask | x_context)        │
│    "predict the next token"          "recover the masked tokens"  │
│                                                                   │
│  Prefix LM (Encoder-Decoder)                                      │
│  ┌───┬───┬───┐    →    ┌───┬───┬───┐                              │
│  │ A │ B │ C │         │ X │ Y │ Z │                              │
│  └───┴───┴───┘         └───┴───┴───┘                              │
│  Bidirectional         Autoregressive                             │
│  "encode the input"    "generate the output"                      │
│                                                                   │
└───────────────────────────────────────────────────────────────────┘
1.2 Comparing the Paradigms

| Feature | Causal LM | Masked LM | Prefix LM |
|---------|-----------|-----------|-----------|
| Representative models | GPT, LLaMA | BERT, RoBERTa | T5, BART |
| Context | Left context only | Bidirectional | Encoder: bidirectional, decoder: left-to-right |
| Training signal | All tokens | Masked tokens only (15%) | Spans / sequences |
| Generation | Natural generation | Needs additional training | Natural generation |
| Understanding | Possible zero-shot | Strong representation learning | Balanced |
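The difference between the paradigms is easiest to see in the attention mask each one uses. Below is a minimal sketch of the three mask shapes; the function name build_attention_mask and the prefix_len split point are illustrative assumptions, not taken from any library.

import torch

def build_attention_mask(seq_len: int, kind: str, prefix_len: int = 0) -> torch.Tensor:
    """Return a (seq_len, seq_len) boolean mask where True = position is blocked."""
    if kind == "bidirectional":          # Masked LM: every token attends to every token
        return torch.zeros(seq_len, seq_len, dtype=torch.bool)
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    if kind == "causal":                 # Causal LM: only left context is visible
        return causal
    if kind == "prefix":                 # Prefix LM: prefix is bidirectional, the rest is causal
        mask = causal.clone()
        mask[:, :prefix_len] = False     # every position may attend to the prefix
        return mask
    raise ValueError(kind)

print(build_attention_mask(5, "causal").int())
print(build_attention_mask(5, "prefix", prefix_len=2).int())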
2. Causal Language Modeling (CLM)
2.1 Mathematical Definition
Objective:
L_CLM = -Σ_t log P(x_t | x_1, x_2, ..., x_{t-1})
Characteristics:
- Every token in the sequence provides a training signal
- Autoregressive: sequential left-to-right generation
- A causal mask blocks access to future tokens
2.2 PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalLMHead(nn.Module):
    """Output projection layer for a causal language model."""
    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_dim)
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        return self.lm_head(hidden_states)

def causal_lm_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100
) -> torch.Tensor:
    """
    Compute the causal LM loss.
    Args:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len) - the next token is the label
    """
    # Shift: the logits at position t predict the label at position t+1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten and compute cross-entropy
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index
    )
    return loss

# Creating the causal mask
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Create an upper-triangular mask that blocks future tokens.
    Returns:
        mask: (seq_len, seq_len) - True = masked (not attended to)
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask

# Usage example
batch_size, seq_len, hidden_dim, vocab_size = 4, 128, 768, 50257
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

lm_head = CausalLMHead(hidden_dim, vocab_size)
logits = lm_head(hidden_states)
loss = causal_lm_loss(logits, labels)
print(f"CLM Loss: {loss.item():.4f}")
2.3 GPT-Style Training
class GPTPretraining:
    """GPT-style pre-training."""
    def __init__(self, model, tokenizer, max_length=1024):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    def prepare_data(self, texts: list[str]) -> dict:
        """
        Split concatenated text into fixed-length chunks.

        Document 1: "The cat sat on..."
        Document 2: "Machine learning is..."
        → [BOS] The cat sat on... [EOS] [BOS] Machine learning is... [EOS]
        → split into fixed-length chunks (max_length tokens each)
        """
        # Concatenate all documents into one token stream
        all_tokens = []
        for text in texts:
            tokens = self.tokenizer.encode(text)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)

        # Split into fixed-length chunks
        chunks = []
        for i in range(0, len(all_tokens) - self.max_length, self.max_length):
            chunk = all_tokens[i:i + self.max_length]
            chunks.append(chunk)

        return {
            'input_ids': torch.tensor(chunks),
            'labels': torch.tensor(chunks)  # identical; the shift happens inside the loss
        }

    def train_step(self, batch):
        """A single training step."""
        input_ids = batch['input_ids']
        labels = batch['labels']

        # Forward pass
        outputs = self.model(input_ids)
        logits = outputs.logits

        # Loss
        loss = causal_lm_loss(logits, labels)
        return loss
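A quick usage sketch of the chunking step above, assuming the GPT-2 tokenizer from transformers; the model argument is left as None because only prepare_data is exercised here.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # gpt2 tokenizer assumed for illustration
pretrainer = GPTPretraining(model=None, tokenizer=tokenizer, max_length=32)

texts = ["The cat sat on the mat.", "Machine learning is a subfield of AI."] * 20
batch = pretrainer.prepare_data(texts)
print(batch['input_ids'].shape)  # (num_chunks, 32); labels are identical, the shift happens in the loss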
3. Masked Language Modeling (MLM)
3.1 BERT-Style MLM
Original: "The quick brown fox jumps over the lazy dog"
Masking strategy (15% of tokens):
- 80%: replaced with the [MASK] token
- 10%: replaced with a random token
- 10%: kept unchanged
Result: "The [MASK] brown fox jumps over the [MASK] dog"
               ↑                               ↑
Targets:    "quick"                         "lazy"
3.2 Implementation
import random

class MLMDataCollator:
    """Data preprocessing for Masked Language Modeling."""
    def __init__(
        self,
        tokenizer,
        mlm_probability: float = 0.15,
        mask_token_ratio: float = 0.8,
        random_token_ratio: float = 0.1
    ):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability
        self.mask_token_ratio = mask_token_ratio
        self.random_token_ratio = random_token_ratio

        # Special token IDs
        self.mask_token_id = tokenizer.mask_token_id
        self.vocab_size = tokenizer.vocab_size
        self.special_tokens = set([
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id
        ])

    def __call__(self, batch: list[dict]) -> dict:
        """Collate a batch."""
        input_ids = torch.stack([item['input_ids'] for item in batch])

        # Apply masking
        input_ids, labels = self.mask_tokens(input_ids)

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': torch.stack([item['attention_mask'] for item in batch])
        }

    def mask_tokens(
        self,
        input_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply token masking.
        Returns:
            masked_input_ids: the masked inputs
            labels: original tokens (-100 at unmasked positions)
        """
        labels = input_ids.clone()

        # Masking probability matrix
        probability_matrix = torch.full(input_ids.shape, self.mlm_probability)

        # Never mask special tokens
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for token_id in self.special_tokens:
            special_tokens_mask |= (input_ids == token_id)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Sample the positions to mask
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Unmasked positions get -100 (ignored by the loss)
        labels[~masked_indices] = -100

        # 80%: replace with [MASK]
        indices_replaced = torch.bernoulli(
            torch.full(input_ids.shape, self.mask_token_ratio)
        ).bool() & masked_indices
        input_ids[indices_replaced] = self.mask_token_id

        # 10%: replace with a random token
        # (0.1 / (1 - 0.8) = half of the masked positions not replaced by [MASK])
        indices_random = torch.bernoulli(
            torch.full(input_ids.shape, self.random_token_ratio / (1 - self.mask_token_ratio))
        ).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(self.vocab_size, input_ids.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # Remaining 10%: keep the original token (handled implicitly)
        return input_ids, labels

def mlm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """MLM loss (computed only at masked positions)."""
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100
    )
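A minimal usage sketch of the collator, assuming the bert-base-uncased tokenizer from transformers and padding to a fixed length so batch items can be stacked.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # assumed for illustration
collator = MLMDataCollator(tokenizer)

enc = tokenizer(
    "The quick brown fox jumps over the lazy dog",
    return_tensors='pt', padding='max_length', max_length=16
)
batch = collator([{'input_ids': enc['input_ids'][0],
                   'attention_mask': enc['attention_mask'][0]}])
print(batch['input_ids'][0])   # ~15% of tokens replaced by [MASK] or random ids
print(batch['labels'][0])      # original ids at masked positions, -100 elsewhere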
3.3 RoBERTa Improvements
class RoBERTaMLM:
    """
    RoBERTa: an improved MLM recipe.

    Changes relative to BERT:
    1. Dynamic masking: a different mask in every epoch
    2. Full-length 512-token sequences throughout training
    3. Much larger batches (256 → 8K)
    4. NSP removed
    5. More data, longer training
    """
    def __init__(self, tokenizer):
        self.collator = MLMDataCollator(tokenizer)

    def create_epoch_data(self, texts: list[str], epoch: int):
        """Dynamic masking: a fresh masking pattern for every epoch."""
        # Vary the seed with the epoch
        random.seed(epoch)
        torch.manual_seed(epoch)

        # Preprocess the data (applying a new mask)
        # ...
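Because mask_tokens samples positions stochastically at call time, dynamic masking falls out of the existing collator: re-collating the same example in each epoch produces a new mask pattern. A small sketch, reusing the collator and the encoded example enc from the usage snippet at the end of 3.2:

for epoch in range(3):
    torch.manual_seed(epoch)  # per-epoch seed, as in create_epoch_data above
    masked = collator([{'input_ids': enc['input_ids'][0],
                        'attention_mask': enc['attention_mask'][0]}])
    # Masked positions differ from epoch to epoch
    print(epoch, (masked['labels'][0] != -100).nonzero().flatten().tolist())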
4. Span Corruption (T5)
4.1 Concept
Original: "The quick brown fox jumps over the lazy dog"
Span corruption:
- Replace each contiguous span of tokens with a single sentinel token
- The decoder reconstructs the original spans
Input:  "The <X> fox <Y> over the lazy dog"
Output: "<X> quick brown <Y> jumps"
Characteristics:
- Average span length: 3 tokens
- Masking ratio: 15%
- Sentinels: <extra_id_0>, <extra_id_1>, ...
4.2 Implementation
class SpanCorruptionCollator:
    """T5-style span corruption."""
    def __init__(
        self,
        tokenizer,
        noise_density: float = 0.15,
        mean_span_length: float = 3.0
    ):
        self.tokenizer = tokenizer
        self.noise_density = noise_density
        self.mean_span_length = mean_span_length

        # Sentinel tokens (<extra_id_0>, <extra_id_1>, ...)
        self.sentinel_start_id = tokenizer.convert_tokens_to_ids("<extra_id_0>")

    def __call__(self, examples: list[dict]) -> dict:
        """Collate a batch."""
        batch_inputs = []
        batch_targets = []

        for example in examples:
            input_ids = example['input_ids']
            inputs, targets = self.corrupt_span(input_ids)
            batch_inputs.append(inputs)
            batch_targets.append(targets)

        # Padding
        inputs_padded = self._pad_sequences(batch_inputs)
        targets_padded = self._pad_sequences(batch_targets)

        return {
            'input_ids': inputs_padded,
            'labels': targets_padded
        }

    def corrupt_span(
        self,
        input_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply span corruption to a single sequence."""
        length = len(input_ids)
        num_noise_tokens = int(length * self.noise_density)
        num_spans = max(1, int(num_noise_tokens / self.mean_span_length))

        # Sample span start positions
        span_starts = sorted(random.sample(range(length - 1), num_spans))

        # Length of each span (Poisson-distributed, minimum 1)
        span_lengths = torch.poisson(
            torch.full((num_spans,), self.mean_span_length - 1)
        ).long() + 1

        # Build the noise-span mask
        noise_mask = torch.zeros(length, dtype=torch.bool)
        for start, span_len in zip(span_starts, span_lengths):
            end = min(start + span_len, length)
            noise_mask[start:end] = True

        # Build the input: replace each noise span with a sentinel
        input_tokens = []
        target_tokens = []
        sentinel_id = self.sentinel_start_id
        i = 0
        while i < length:
            if noise_mask[i]:
                # Start of a span: add the sentinel to both input and target
                input_tokens.append(sentinel_id)
                target_tokens.append(sentinel_id)
                # Append the span contents to the target
                while i < length and noise_mask[i]:
                    target_tokens.append(input_ids[i].item())
                    i += 1
                # In the T5 tokenizer, <extra_id_1> has a *smaller* id than <extra_id_0>,
                # so the next sentinel id is obtained by decrementing
                sentinel_id -= 1
            else:
                input_tokens.append(input_ids[i].item())
                i += 1

        return torch.tensor(input_tokens), torch.tensor(target_tokens)

    def _pad_sequences(self, sequences: list[torch.Tensor]) -> torch.Tensor:
        """Pad sequences to the same length."""
        max_len = max(len(seq) for seq in sequences)
        padded = torch.full((len(sequences), max_len), self.tokenizer.pad_token_id)
        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq
        return padded
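A toy sanity check of corrupt_span, assuming the t5-small tokenizer so the <extra_id_*> sentinel ids resolve. The decoded strings vary from run to run because the spans are sampled randomly.

from transformers import AutoTokenizer

t5_tok = AutoTokenizer.from_pretrained('t5-small')  # assumed for illustration
span_collator = SpanCorruptionCollator(t5_tok)

ids = torch.tensor(t5_tok.encode("The quick brown fox jumps over the lazy dog",
                                 add_special_tokens=False))
corrupted, targets = span_collator.corrupt_span(ids)
print(t5_tok.decode(corrupted))  # e.g. "The<extra_id_0> fox jumps over the lazy dog"
print(t5_tok.decode(targets))    # e.g. "<extra_id_0> quick brown"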
5. UL2: Unifying Language Learning Paradigms
5.1 Mixture of Denoisers (MoD)
UL2: trains on a mixture of several objectives
┌──────────────────────────────────────────────────────────────────┐
│                       Mixture of Denoisers                       │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  R-Denoiser (Regular)              S-Denoiser (Short)            │
│  - short spans (3-8 tokens)        - very short spans (≤3 tokens)│
│  - 15% masking                     - 15% masking                 │
│  - good for NLU tasks              - fine-grained understanding  │
│                                                                  │
│  X-Denoiser (Extreme)                                            │
│  - long spans (12-64 tokens)                                     │
│  - 50% masking                                                   │
│  - good for generation tasks                                     │
│                                                                  │
│  Mode switching: prepend an [R], [S], or [X] prefix to the input │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘
5.2 Implementation
class UL2Collator:
    """UL2 Mixture of Denoisers."""

    DENOISERS = {
        'R': {  # Regular
            'span_length': (3, 8),
            'noise_density': 0.15,
            'prefix': '[R]'
        },
        'S': {  # Short
            'span_length': (1, 3),
            'noise_density': 0.15,
            'prefix': '[S]'
        },
        'X': {  # Extreme
            'span_length': (12, 64),
            'noise_density': 0.5,
            'prefix': '[X]'
        }
    }

    def __init__(self, tokenizer, denoiser_weights: dict = None):
        self.tokenizer = tokenizer
        # Default weights: R=50%, S=25%, X=25%
        self.weights = denoiser_weights or {'R': 0.5, 'S': 0.25, 'X': 0.25}

    def __call__(self, examples: list[dict]) -> dict:
        """Collate a batch: apply a randomly chosen denoiser to each example."""
        batch_inputs = []
        batch_targets = []

        for example in examples:
            # Pick a denoiser
            denoiser = random.choices(
                list(self.DENOISERS.keys()),
                weights=list(self.weights.values())
            )[0]
            config = self.DENOISERS[denoiser]

            # Apply span corruption
            inputs, targets = self.apply_denoiser(
                example['input_ids'],
                config
            )
            batch_inputs.append(inputs)
            batch_targets.append(targets)

        return self._collate(batch_inputs, batch_targets)

    def apply_denoiser(
        self,
        input_ids: torch.Tensor,
        config: dict
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply corruption with the settings of one specific denoiser."""
        # Encode the mode prefix
        prefix_ids = self.tokenizer.encode(
            config['prefix'],
            add_special_tokens=False
        )

        # Span corruption according to config
        span_len = random.randint(*config['span_length'])
        # ... corruption logic (produces the corrupted input_ids and targets)

        # Prefix + corrupted input
        inputs = torch.cat([
            torch.tensor(prefix_ids),
            input_ids  # corrupted; `targets` comes from the elided corruption step above
        ])
        return inputs, targets
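One hedged way to wire up the elided corruption step is to reuse the SpanCorruptionCollator from Section 4, parameterised per denoiser mode. This is an illustrative sketch, not the reference UL2 implementation; the class name and the per-mode mean_span_length values are assumptions chosen to fall inside the span_length ranges above.

class SimpleUL2Denoiser:
    """Illustrative only: one SpanCorruptionCollator per UL2 mode (assumed wiring)."""
    def __init__(self, tokenizer, weights=None):
        self.tokenizer = tokenizer
        self.weights = weights or {'R': 0.5, 'S': 0.25, 'X': 0.25}
        # Assumed per-mode settings, loosely matching the DENOISERS table above
        self.collators = {
            'R': SpanCorruptionCollator(tokenizer, noise_density=0.15, mean_span_length=5.0),
            'S': SpanCorruptionCollator(tokenizer, noise_density=0.15, mean_span_length=2.0),
            'X': SpanCorruptionCollator(tokenizer, noise_density=0.5,  mean_span_length=32.0),
        }

    def __call__(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Pick a mode, corrupt with that mode's settings, and prepend its prefix token(s)
        mode = random.choices(list(self.weights), weights=list(self.weights.values()))[0]
        prefix = torch.tensor(self.tokenizer.encode(f'[{mode}]', add_special_tokens=False))
        corrupted, targets = self.collators[mode].corrupt_span(input_ids)
        return torch.cat([prefix, corrupted]), targets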
6. Next Sentence Prediction (NSP) vs Sentence Order Prediction (SOP)
6.1 NSP (BERT)
class NSPDataCollator:
    """
    Next Sentence Prediction.

    50%: the actual next sentence (IsNext)
    50%: a random sentence (NotNext)

    Problem: too easy → removed in RoBERTa.
    """
    def create_nsp_pair(
        self,
        sentence_a: str,
        sentence_b: str,
        all_sentences: list[str]
    ) -> tuple[str, str, int]:
        """Create one NSP training pair."""
        if random.random() < 0.5:
            # The actual next sentence
            return sentence_a, sentence_b, 1  # IsNext
        else:
            # A random sentence
            random_sentence = random.choice(all_sentences)
            return sentence_a, random_sentence, 0  # NotNext
6.2 SOP (ALBERT)
class SOPDataCollator:
    """
    Sentence Order Prediction (a harder task).

    50%: correct order (A → B)
    50%: swapped order (B → A)

    Predicts order/coherence rather than topic → a more useful training signal.
    """
    def create_sop_pair(
        self,
        sentence_a: str,
        sentence_b: str
    ) -> tuple[str, str, int]:
        """Create one SOP training pair."""
        if random.random() < 0.5:
            return sentence_a, sentence_b, 1  # correct order
        else:
            return sentence_b, sentence_a, 0  # swapped
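Both NSP and SOP are trained as a binary classification on top of the pooled [CLS] representation, added to the MLM loss. A minimal sketch of that shared head; the class name and the 768-dimensional hidden size are assumptions for illustration.

class SentencePairHead(nn.Module):
    """Binary classifier over the pooled [CLS] vector, usable for NSP or SOP."""
    def __init__(self, hidden_dim: int = 768):
        super().__init__()
        self.classifier = nn.Linear(hidden_dim, 2)  # IsNext/NotNext, or in-order/swapped

    def forward(self, pooled_cls: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        logits = self.classifier(pooled_cls)        # (batch, 2)
        return F.cross_entropy(logits, labels)

# Usage sketch: total loss = mlm_loss(mlm_logits, mlm_labels) + pair_head(pooled_cls, pair_labels)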
7. Choosing a Pre-training Objective
7.1 Recommended Objective by Task
┌─────────────────────────────┬──────────────────────────────────────────┐
│ Downstream task             │ Recommended pre-training objective       │
├─────────────────────────────┼──────────────────────────────────────────┤
│ Text generation             │ Causal LM (GPT style)                    │
│ Text classification         │ MLM (BERT) or Causal LM + fine-tuning    │
│ Question answering          │ Span Corruption (T5) or MLM              │
│ Translation / summarization │ Encoder-Decoder (T5, BART)               │
│ General (few-shot)          │ Large-scale Causal LM (GPT-3 style)      │
│ General (diverse tasks)     │ UL2 (Mixture of Denoisers)               │
└─────────────────────────────┴──────────────────────────────────────────┘
7.2 Strategy by Model Size

| Model size | Recommended approach | Rationale |
|------------|----------------------|-----------|
| < 1B | MLM + fine-tuning | Task-specific performance first |
| 1B - 10B | Causal LM | Balance of generality and efficiency |
| > 10B | Causal LM | In-context learning emerges |
8. Hands-On: Comparing the Objectives
from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    T5ForConditionalGeneration,
    AutoTokenizer
)

def compare_objectives():
    """Compare the three objectives side by side."""
    # 1. Causal LM (GPT-2)
    causal_tokenizer = AutoTokenizer.from_pretrained('gpt2')
    causal_model = AutoModelForCausalLM.from_pretrained('gpt2')

    text = "The capital of France is"
    inputs = causal_tokenizer(text, return_tensors='pt')

    # Greedy generation of a short continuation
    outputs = causal_model.generate(
        inputs['input_ids'],
        max_new_tokens=5,
        do_sample=False
    )
    print("Causal LM:", causal_tokenizer.decode(outputs[0]))
    # → "The capital of France is Paris."

    # 2. Masked LM (BERT)
    mlm_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    mlm_model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

    text = "The capital of France is [MASK]."
    inputs = mlm_tokenizer(text, return_tensors='pt')
    outputs = mlm_model(**inputs)

    mask_idx = (inputs['input_ids'] == mlm_tokenizer.mask_token_id).nonzero()[0, 1]
    predicted_id = outputs.logits[0, mask_idx].argmax()
    print("Masked LM:", mlm_tokenizer.decode(predicted_id))
    # → "paris"

    # 3. Span Corruption (T5)
    t5_tokenizer = AutoTokenizer.from_pretrained('t5-small')
    t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

    text = "translate English to French: The house is wonderful."
    inputs = t5_tokenizer(text, return_tensors='pt')
    outputs = t5_model.generate(inputs['input_ids'], max_new_tokens=20)
    print("T5:", t5_tokenizer.decode(outputs[0], skip_special_tokens=True))
    # → "La maison est merveilleuse."

if __name__ == "__main__":
    compare_objectives()
References
Papers
- Devlin et al. (2018). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
- Radford et al. (2019). "Language Models are Unsupervised Multitask Learners" (GPT-2)
- Raffel et al. (2019). "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer" (T5)
- Tay et al. (2022). "UL2: Unifying Language Learning Paradigms"