04. Pre-training Objective Functions

Overview

The pre-training objective determines which patterns a Foundation Model learns from large-scale data. The choice of objective directly shapes the model's capabilities and its downstream task performance.


1. Language Modeling Paradigms

1.1 Three Main Approaches

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    Language Modeling Paradigms                  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                 β”‚
β”‚  Causal LM (Autoregressive)        Masked LM (Bidirectional)   β”‚
β”‚  β”Œβ”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”            β”Œβ”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”        β”‚
β”‚  β”‚ A β”‚ B β”‚ C β”‚ D β”‚ ? β”‚            β”‚ A β”‚[M]β”‚ C β”‚[M]β”‚ E β”‚        β”‚
β”‚  β””β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”˜            β””β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”˜        β”‚
β”‚       ↓                                 ↓                       β”‚
β”‚  P(x_t | x_<t)                     P(x_mask | x_context)        β”‚
β”‚  "predict the next token"          "recover masked tokens"      β”‚
β”‚                                                                 β”‚
β”‚  Prefix LM (Encoder-Decoder)                                    β”‚
β”‚  β”Œβ”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β” β†’ β”Œβ”€β”€β”€β”¬β”€β”€β”€β”¬β”€β”€β”€β”                                 β”‚
β”‚  β”‚ A β”‚ B β”‚ C β”‚   β”‚ X β”‚ Y β”‚ Z β”‚                                 β”‚
β”‚  β””β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”˜   β””β”€β”€β”€β”΄β”€β”€β”€β”΄β”€β”€β”€β”˜                                 β”‚
β”‚  Bidirectional    Autoregressive                                β”‚
β”‚  "encode input"   "generate output"                             β”‚
β”‚                                                                 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

1.2 Comparing the Paradigms

Property          Causal LM            Masked LM                  Prefix LM
Representatives   GPT, LLaMA           BERT, RoBERTa              T5, BART
Context           Left context only    Bidirectional              Enc: bidirectional / Dec: left only
Training signal   Every token          Masked tokens only (15%)   Spans / sequences
Generation        Natural              Needs extra training       Natural
Understanding     Possible zero-shot   Strong representations     Balanced
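
In practice, the difference between the three paradigms boils down to the attention pattern each objective implies. The sketch below (a hypothetical helper, not from any particular library) builds all three patterns for a toy sequence; True marks positions a query token may attend to:

import torch

def paradigm_masks(seq_len: int, prefix_len: int):
    """Build the attention pattern of each paradigm. True = may attend."""
    # Causal LM: token t attends only to positions <= t (lower triangle)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

    # Masked LM: full bidirectional attention
    bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)

    # Prefix LM: the prefix is fully visible to every position,
    # while the rest of the sequence stays causal
    prefix = causal.clone()
    prefix[:, :prefix_len] = True
    return causal, bidirectional, prefix

causal, bidirectional, prefix = paradigm_masks(seq_len=5, prefix_len=2)
print(prefix.int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])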

2. Causal Language Modeling (CLM)

2.1 Mathematical Definition

Objective:
L_CLM = -Ξ£_t log P(x_t | x_1, x_2, ..., x_{t-1})

Properties:
- Every token in the sequence serves as a training signal
- Autoregressive: generates left to right, one token at a time
- A causal mask blocks access to future tokens

2.2 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalLMHead(nn.Module):
    """Causal Language Model 좜λ ₯ λ ˆμ΄μ–΄"""

    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_dim)
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        return self.lm_head(hidden_states)


def causal_lm_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100
) -> torch.Tensor:
    """
    Causal LM Loss 계산

    Args:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len) - λ‹€μŒ 토큰이 λ ˆμ΄λΈ”
    """
    # Shift: logits[:-1]이 labels[1:]을 예츑
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index
    )
    return loss


# Creating the causal mask
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Build an upper-triangular mask (blocks future tokens).

    Returns:
        mask: (seq_len, seq_len) - True = masked out
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask


# Usage example
batch_size, seq_len, hidden_dim, vocab_size = 4, 128, 768, 50257
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

lm_head = CausalLMHead(hidden_dim, vocab_size)
logits = lm_head(hidden_states)
loss = causal_lm_loss(logits, labels)
print(f"CLM Loss: {loss.item():.4f}")

2.3 GPT-Style Training

class GPTPretraining:
    """GPT-style pre-training"""

    def __init__(self, model, tokenizer, max_length=1024):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    def prepare_data(self, texts: list[str]) -> dict:
        """
        Pack documents into fixed-length chunks.

        Document 1: "The cat sat on..."
        Document 2: "Machine learning is..."

        β†’ The cat sat on... [EOS] Machine learning is... [EOS]
        β†’ split into fixed-length chunks (max_length each)
        """
        # Concatenate all documents into one token stream
        all_tokens = []
        for text in texts:
            tokens = self.tokenizer.encode(text)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)

        # Split into fixed-length chunks
        chunks = []
        for i in range(0, len(all_tokens) - self.max_length, self.max_length):
            chunk = all_tokens[i:i + self.max_length]
            chunks.append(chunk)

        return {
            'input_ids': torch.tensor(chunks),
            'labels': torch.tensor(chunks)  # identical (the shift happens in the loss)
        }

    def train_step(self, batch):
        """단일 ν•™μŠ΅ μŠ€ν…"""
        input_ids = batch['input_ids']
        labels = batch['labels']

        # Forward
        outputs = self.model(input_ids)
        logits = outputs.logits

        # Loss
        loss = causal_lm_loss(logits, labels)

        return loss
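
A quick usage sketch for the packing step (assumes the Hugging Face GPT-2 tokenizer, used purely for illustration; any tokenizer exposing encode and eos_token_id works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
# model=None is fine here because we only exercise prepare_data
pretrainer = GPTPretraining(model=None, tokenizer=tokenizer, max_length=16)

batch = pretrainer.prepare_data([
    "The cat sat on the mat.",
    "Machine learning is a subfield of AI.",
])
print(batch['input_ids'].shape)  # (num_chunks, 16)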

3. Masked Language Modeling (MLM)

3.1 BERT-Style MLM

Original: "The quick brown fox jumps over the lazy dog"

Masking strategy (15% of tokens):
- 80%: replaced with the [MASK] token
- 10%: replaced with a random token
- 10%: kept as the original

Result: "The [MASK] brown fox jumps over the [MASK] dog"
                ↓                          ↓
Targets:     "quick"                    "lazy"

3.2 Implementation

import random

class MLMDataCollator:
    """Masked Language Modeling 데이터 μ „μ²˜λ¦¬"""

    def __init__(
        self,
        tokenizer,
        mlm_probability: float = 0.15,
        mask_token_ratio: float = 0.8,
        random_token_ratio: float = 0.1
    ):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability
        self.mask_token_ratio = mask_token_ratio
        self.random_token_ratio = random_token_ratio

        # Special token IDs
        self.mask_token_id = tokenizer.mask_token_id
        self.vocab_size = tokenizer.vocab_size
        self.special_tokens = set([
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id
        ])

    def __call__(self, batch: list[dict]) -> dict:
        """Collate a batch"""
        input_ids = torch.stack([item['input_ids'] for item in batch])

        # Apply masking
        input_ids, labels = self.mask_tokens(input_ids)

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': torch.stack([item['attention_mask'] for item in batch])
        }

    def mask_tokens(
        self,
        input_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        토큰 λ§ˆμŠ€ν‚Ή μˆ˜ν–‰

        Returns:
            masked_input_ids: λ§ˆμŠ€ν‚Ήλœ μž…λ ₯
            labels: 원본 토큰 (λ§ˆμŠ€ν‚Ή μ•ˆ 된 μœ„μΉ˜λŠ” -100)
        """
        labels = input_ids.clone()

        # λ§ˆμŠ€ν‚Ή ν™•λ₯  ν–‰λ ¬
        probability_matrix = torch.full(input_ids.shape, self.mlm_probability)

        # Never mask special tokens
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for token_id in self.special_tokens:
            special_tokens_mask |= (input_ids == token_id)

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Sample positions to mask
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Unmasked positions get -100 (ignored by the loss)
        labels[~masked_indices] = -100

        # 80% of masked positions: replace with [MASK]
        indices_replaced = torch.bernoulli(
            torch.full(input_ids.shape, self.mask_token_ratio)
        ).bool() & masked_indices
        input_ids[indices_replaced] = self.mask_token_id

        # 10%: random token (0.1 / 0.2 = 50% of the positions left after [MASK])
        indices_random = torch.bernoulli(
            torch.full(input_ids.shape, self.random_token_ratio / (1 - self.mask_token_ratio))
        ).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(self.vocab_size, input_ids.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # Remaining 10%: left unchanged (handled implicitly)

        return input_ids, labels


def mlm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """MLM Loss (λ§ˆμŠ€ν‚Ήλœ μœ„μΉ˜λ§Œ 계산)"""
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100
    )
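
A usage sketch (assumes the Hugging Face 'bert-base-uncased' tokenizer, used only to supply the [MASK] and special-token IDs; the exact masked positions vary per run):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
collator = MLMDataCollator(tokenizer)

enc = tokenizer("The quick brown fox jumps over the lazy dog",
                return_tensors='pt', padding='max_length', max_length=16)
batch = collator([{'input_ids': enc['input_ids'][0],
                   'attention_mask': enc['attention_mask'][0]}])

masked = batch['labels'][0] != -100
print("Masked positions:", masked.nonzero().flatten().tolist())
print("Corrupted input :", tokenizer.decode(batch['input_ids'][0]))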

3.3 RoBERTa Improvements

class RoBERTaMLM:
    """
    RoBERTa: an improved MLM recipe

    Changes relative to BERT:
    1. Dynamic masking: a different mask pattern every epoch
    2. Full-length sequences (up to 512 tokens throughout training)
    3. Much larger batches (256 β†’ 8K)
    4. NSP objective removed
    5. More data, longer training
    """

    def __init__(self, tokenizer):
        self.collator = MLMDataCollator(tokenizer)

    def create_epoch_data(self, texts: list[str], epoch: int):
        """
        Dynamic masking: a fresh masking pattern every epoch
        """
        # Vary the seed with the epoch
        random.seed(epoch)
        torch.manual_seed(epoch)

        # Preprocess the data (a new mask pattern is sampled)
        # ...
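
The effect of dynamic masking can be demonstrated by running the collator from the sketch above twice on the same example (the pattern changes because each epoch seeds a fresh Bernoulli draw):

enc = tokenizer("Dynamic masking draws a new pattern each epoch",
                return_tensors='pt')
item = {'input_ids': enc['input_ids'][0],
        'attention_mask': enc['attention_mask'][0]}

for epoch in range(2):
    torch.manual_seed(epoch)  # different seed -> different mask
    # clone so the in-place masking does not corrupt the source tensors
    batch = collator([{k: v.clone() for k, v in item.items()}])
    print(f"epoch {epoch}:", tokenizer.decode(batch['input_ids'][0]))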

4. Span Corruption (T5)

4.1 Concept

Original: "The quick brown fox jumps over the lazy dog"

Span corruption:
- Replace each contiguous span of tokens with a single sentinel token
- The decoder reconstructs the original spans

Input:  "The <X> fox <Y> over the lazy dog"
Output: "<X> quick brown <Y> jumps"

Properties:
- Mean span length: 3 tokens
- Masking ratio: 15%
- Sentinels: <extra_id_0>, <extra_id_1>, ...

4.2 Implementation

class SpanCorruptionCollator:
    """T5 μŠ€νƒ€μΌ Span Corruption"""

    def __init__(
        self,
        tokenizer,
        noise_density: float = 0.15,
        mean_span_length: float = 3.0
    ):
        self.tokenizer = tokenizer
        self.noise_density = noise_density
        self.mean_span_length = mean_span_length

        # First sentinel token (<extra_id_0>; later sentinels count down from it)
        self.sentinel_start_id = tokenizer.convert_tokens_to_ids("<extra_id_0>")

    def __call__(self, examples: list[dict]) -> dict:
        """배치 처리"""
        batch_inputs = []
        batch_targets = []

        for example in examples:
            input_ids = example['input_ids']
            inputs, targets = self.corrupt_span(input_ids)
            batch_inputs.append(inputs)
            batch_targets.append(targets)

        # Pad
        inputs_padded = self._pad_sequences(batch_inputs)
        targets_padded = self._pad_sequences(batch_targets)

        return {
            'input_ids': inputs_padded,
            'labels': targets_padded
        }

    def corrupt_span(
        self,
        input_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Span Corruption 적용
        """
        length = len(input_ids)
        num_noise_tokens = int(length * self.noise_density)
        num_spans = max(1, int(num_noise_tokens / self.mean_span_length))

        # Sample span start positions
        span_starts = sorted(random.sample(range(length - 1), num_spans))

        # Length of each span (Poisson-distributed, mean β‰ˆ mean_span_length)
        span_lengths = torch.poisson(
            torch.full((num_spans,), self.mean_span_length - 1)
        ).long() + 1

        # Build the noise mask
        noise_mask = torch.zeros(length, dtype=torch.bool)
        for start, span_len in zip(span_starts, span_lengths):
            end = min(start + span_len, length)
            noise_mask[start:end] = True

        # Build the input: replace each noise span with a sentinel
        input_tokens = []
        target_tokens = []
        sentinel_id = self.sentinel_start_id

        i = 0
        while i < length:
            if noise_mask[i]:
                # Span start: emit a sentinel
                input_tokens.append(sentinel_id)
                target_tokens.append(sentinel_id)

                # Span λ‚΄μš©μ„ target에 μΆ”κ°€
                while i < length and noise_mask[i]:
                    target_tokens.append(input_ids[i].item())
                    i += 1

                # In the T5 vocabulary sentinel ids descend
                # (<extra_id_1> = <extra_id_0> - 1), so decrement
                sentinel_id -= 1
            else:
                input_tokens.append(input_ids[i].item())
                i += 1

        return torch.tensor(input_tokens), torch.tensor(target_tokens)

    def _pad_sequences(self, sequences: list[torch.Tensor]) -> torch.Tensor:
        """μ‹œν€€μŠ€ νŒ¨λ”©"""
        max_len = max(len(seq) for seq in sequences)
        padded = torch.full((len(sequences), max_len), self.tokenizer.pad_token_id)
        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq
        return padded
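
A usage sketch (assumes the Hugging Face 't5-small' tokenizer; the sampled spans differ from run to run):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
collator = SpanCorruptionCollator(tokenizer)

text = "The quick brown fox jumps over the lazy dog"
input_ids = torch.tensor(tokenizer.encode(text, add_special_tokens=False))

inputs, targets = collator.corrupt_span(input_ids)
print("Input :", tokenizer.decode(inputs))   # e.g. "The <extra_id_0> fox jumps over the lazy dog"
print("Target:", tokenizer.decode(targets))  # e.g. "<extra_id_0> quick brown"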

5. UL2: Unified Language Learner

5.1 Mixture of Denoisers (MoD)

UL2 trains on a mixture of objectives:

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚                    Mixture of Denoisers                        β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚                                                                β”‚
β”‚  R-Denoiser (Regular)        S-Denoiser (Sequential)           β”‚
β”‚  - short spans (3-8 tokens)  - prefix LM-style denoising       β”‚
β”‚  - 15% masking               - recovers the sequence tail      β”‚
β”‚  - good for NLU tasks        - good for generation             β”‚
β”‚                                                                β”‚
β”‚  X-Denoiser (Extreme)                                          β”‚
β”‚  - long spans (12-64 tokens)                                   β”‚
β”‚  - up to 50% masking                                           β”‚
β”‚  - good for long-form generation                               β”‚
β”‚                                                                β”‚
β”‚  Mode switching: prepend an [R], [S], or [X] prefix to input   β”‚
β”‚                                                                β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

5.2 Implementation

class UL2Collator:
    """UL2 Mixture of Denoisers"""

    DENOISERS = {
        'R': {  # Regular
            'span_length': (3, 8),
            'noise_density': 0.15,
            'prefix': '[R]'
        },
        'S': {  # Sequential (approximated here with very short spans;
                # UL2's actual S-denoiser is a prefix-LM objective)
            'span_length': (1, 3),
            'noise_density': 0.15,
            'prefix': '[S]'
        },
        'X': {  # Extreme
            'span_length': (12, 64),
            'noise_density': 0.5,
            'prefix': '[X]'
        }
    }

    def __init__(self, tokenizer, denoiser_weights: dict = None):
        self.tokenizer = tokenizer
        # Default weights: R=50%, S=25%, X=25%
        self.weights = denoiser_weights or {'R': 0.5, 'S': 0.25, 'X': 0.25}

    def __call__(self, examples: list[dict]) -> dict:
        """배치 처리: 각 μ˜ˆμ œμ— 랜덀 denoiser 적용"""
        batch_inputs = []
        batch_targets = []

        for example in examples:
            # Pick a denoiser
            denoiser = random.choices(
                list(self.DENOISERS.keys()),
                weights=list(self.weights.values())
            )[0]

            config = self.DENOISERS[denoiser]

            # Apply span corruption
            inputs, targets = self.apply_denoiser(
                example['input_ids'],
                config
            )

            batch_inputs.append(inputs)
            batch_targets.append(targets)

        return self._collate(batch_inputs, batch_targets)

    def apply_denoiser(
        self,
        input_ids: torch.Tensor,
        config: dict
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply corruption with a specific denoiser's settings."""
        # Prepend the mode prefix ([R], [S], or [X])
        prefix_ids = self.tokenizer.encode(
            config['prefix'],
            add_special_tokens=False
        )

        # Span corruption with this denoiser's span length and noise density
        corruptor = SpanCorruptionCollator(
            self.tokenizer,
            noise_density=config['noise_density'],
            mean_span_length=sum(config['span_length']) / 2
        )
        corrupted, targets = corruptor.corrupt_span(input_ids)

        # Prefix + corrupted input
        inputs = torch.cat([torch.tensor(prefix_ids), corrupted])

        return inputs, targets

    def _collate(
        self,
        inputs: list[torch.Tensor],
        targets: list[torch.Tensor]
    ) -> dict:
        """Pad variable-length sequences into batch tensors."""
        def pad(seqs: list[torch.Tensor]) -> torch.Tensor:
            max_len = max(len(s) for s in seqs)
            out = torch.full(
                (len(seqs), max_len),
                self.tokenizer.pad_token_id,
                dtype=torch.long
            )
            for i, s in enumerate(seqs):
                out[i, :len(s)] = s
            return out

        return {'input_ids': pad(inputs), 'labels': pad(targets)}
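
A usage sketch (again assuming the 't5-small' tokenizer; real UL2 registers [R]/[S]/[X] as dedicated tokens, whereas here the prefixes are simply tokenized as plain text):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')
ul2 = UL2Collator(tokenizer)

text = "Pre-training objectives shape what a foundation model learns from data"
example = {'input_ids': torch.tensor(tokenizer.encode(text, add_special_tokens=False))}

batch = ul2([example])
print(batch['input_ids'].shape, batch['labels'].shape)
print(tokenizer.decode(batch['input_ids'][0]))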

6. Next Sentence Prediction (NSP) vs Sentence Order Prediction (SOP)

6.1 NSP (BERT)

class NSPDataCollator:
    """
    Next Sentence Prediction

    50%: the actual next sentence (IsNext)
    50%: a random sentence (NotNext)

    Weakness: too easy a task β†’ dropped in RoBERTa
    """

    def create_nsp_pair(
        self,
        sentence_a: str,
        sentence_b: str,
        all_sentences: list[str]
    ) -> tuple[str, str, int]:
        """NSP 데이터 생성"""
        if random.random() < 0.5:
            # The actual next sentence
            return sentence_a, sentence_b, 1  # IsNext
        else:
            # A random sentence
            random_sentence = random.choice(all_sentences)
            return sentence_a, random_sentence, 0  # NotNext

6.2 SOP (ALBERT)

class SOPDataCollator:
    """
    Sentence Order Prediction (a harder task)

    50%: original order (A β†’ B)
    50%: swapped order (B β†’ A)

    Predicts ordering rather than topic β†’ a more useful training signal
    """

    def create_sop_pair(
        self,
        sentence_a: str,
        sentence_b: str
    ) -> tuple[str, str, int]:
        """SOP 데이터 생성"""
        if random.random() < 0.5:
            return sentence_a, sentence_b, 1  # original order
        else:
            return sentence_b, sentence_a, 0  # swapped order
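
A quick check of both collators (a sketch; the labels follow the conventions in the docstrings above):

sentences = [
    "The sky is blue.",
    "Grass is green.",
    "Cats purr when happy.",
]

nsp_pair = NSPDataCollator().create_nsp_pair(sentences[0], sentences[1], sentences)
sop_pair = SOPDataCollator().create_sop_pair(sentences[0], sentences[1])
print("NSP:", nsp_pair)  # (sent_a, next_or_random, 1=IsNext / 0=NotNext)
print("SOP:", sop_pair)  # (a, b, 1) or (b, a, 0)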

7. Choosing a Pre-training Objective

7.1 Recommended Objective by Task

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Downstream Task             β”‚ Recommended Pre-training Objective     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚ Text generation             β”‚ Causal LM (GPT style)                  β”‚
β”‚ Text classification         β”‚ MLM (BERT) or Causal LM + fine-tuning  β”‚
β”‚ Question answering          β”‚ Span corruption (T5) or MLM            β”‚
β”‚ Translation / summarization β”‚ Encoder-decoder (T5, BART)             β”‚
β”‚ General (few-shot)          β”‚ Causal LM at scale (GPT-3 style)       β”‚
β”‚ General (diverse tasks)     β”‚ UL2 (Mixture of Denoisers)             β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

7.2 λͺ¨λΈ 크기별 μ „λž΅

λͺ¨λΈ 크기 ꢌμž₯ 접근법 이유
< 1B MLM + Fine-tuning νƒœμŠ€ν¬ νŠΉν™” μ„±λŠ₯ 우수
1B - 10B Causal LM λ²”μš©μ„±κ³Ό 효율의 κ· ν˜•
> 10B Causal LM In-context Learning μΆœν˜„

8. Hands-On: Comparing Objectives

from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    T5ForConditionalGeneration,
    AutoTokenizer
)

def compare_objectives():
    """μ„Έ κ°€μ§€ λͺ©μ ν•¨μˆ˜ 비ꡐ"""

    # 1. Causal LM (GPT-2)
    causal_tokenizer = AutoTokenizer.from_pretrained('gpt2')
    causal_model = AutoModelForCausalLM.from_pretrained('gpt2')

    text = "The capital of France is"
    inputs = causal_tokenizer(text, return_tensors='pt')

    # Generate
    outputs = causal_model.generate(
        inputs['input_ids'],
        max_new_tokens=5,
        do_sample=False
    )
    print("Causal LM:", causal_tokenizer.decode(outputs[0]))
    # β†’ "The capital of France is Paris."

    # 2. Masked LM (BERT)
    mlm_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    mlm_model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

    text = "The capital of France is [MASK]."
    inputs = mlm_tokenizer(text, return_tensors='pt')

    outputs = mlm_model(**inputs)
    mask_idx = (inputs['input_ids'] == mlm_tokenizer.mask_token_id).nonzero()[0, 1]
    predicted_id = outputs.logits[0, mask_idx].argmax()
    print("Masked LM:", mlm_tokenizer.decode(predicted_id))
    # β†’ "paris"

    # 3. Span Corruption (T5)
    t5_tokenizer = AutoTokenizer.from_pretrained('t5-small')
    t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

    text = "translate English to French: The house is wonderful."
    inputs = t5_tokenizer(text, return_tensors='pt')

    outputs = t5_model.generate(inputs['input_ids'], max_new_tokens=20)
    print("T5:", t5_tokenizer.decode(outputs[0], skip_special_tokens=True))
    # β†’ "La maison est merveilleuse."

if __name__ == "__main__":
    compare_objectives()

References

Papers

  β€’ Devlin et al. (2018). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
  β€’ Radford et al. (2019). "Language Models are Unsupervised Multitask Learners" (GPT-2)
  β€’ Raffel et al. (2019). "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer" (T5)
  β€’ Tay et al. (2022). "UL2: Unifying Language Learning Paradigms"
