04. Pre-training Objectives

Overview

Pre-training objectives determine what patterns a Foundation Model learns from large-scale data. The choice of objective directly impacts the model's capabilities and downstream task performance.


1. Language Modeling Paradigms

1.1 Three Main Approaches

┌─────────────────────────────────────────────────────────────────┐
│                   Language Modeling Paradigms                   │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Causal LM (Autoregressive)        Masked LM (Bidirectional)    │
│  ┌───┬───┬───┬───┬───┐             ┌───┬───┬───┬───┬───┐        │
│  │ A │ B │ C │ D │ ? │             │ A │[M]│ C │[M]│ E │        │
│  └───┴───┴───┴───┴───┘             └───┴───┴───┴───┴───┘        │
│                                                                 │
│  P(x_t | x_<t)                     P(x_mask | x_context)        │
│  "Predict next token"              "Restore masked token"       │
│                                                                 │
│  Prefix LM (Encoder-Decoder)                                    │
│  ┌───┬───┬───┐   ┌───┬───┬───┐                                  │
│  │ A │ B │ C │ → │ X │ Y │ Z │                                  │
│  └───┴───┴───┘   └───┴───┴───┘                                  │
│  Bidirectional   Autoregressive                                 │
│  "Encode input"  "Generate output"                              │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
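The three paradigms differ mainly in which positions each token may attend to. The plain-Python sketch below (the `attention_mask` and `show` helpers are hypothetical, written only for illustration) builds the visibility matrix for each case, with `mask[i][j]` True when position i may see position j. The "prefix" case models a decoder-only Prefix LM; an encoder-decoder reaches the same visibility with two separate stacks plus cross-attention rather than one shared mask.

```python
def attention_mask(seq_len: int, kind: str, prefix_len: int = 0) -> list[list[bool]]:
    """mask[i][j] is True when position i may attend to position j."""
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            if kind == "causal":           # Causal LM: past and current tokens only
                allowed = j <= i
            elif kind == "bidirectional":  # Masked LM encoder: every token
                allowed = True
            elif kind == "prefix":         # Prefix LM: whole prefix + causal suffix
                allowed = j < prefix_len or j <= i
            else:
                raise ValueError(f"unknown kind: {kind}")
            row.append(allowed)
        mask.append(row)
    return mask


def show(mask: list[list[bool]]) -> None:
    for row in mask:
        print(" ".join("1" if allowed else "." for allowed in row))


show(attention_mask(4, "causal"))
print()
show(attention_mask(5, "prefix", prefix_len=3))
```

In the prefix case, the first three rows see only the 3-token prefix, while rows 3 and 4 see the whole prefix plus the target tokens up to themselves. This sketch uses True = visible; `create_causal_mask` in Section 2.2 uses the opposite convention (True = masked).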

1.2 Paradigm Comparison

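┌───────────────────────┬────────────────────┬────────────────────────────────┬────────────────────────────────┐
│ Feature               │ Causal LM          │ Masked LM                      │ Prefix LM (Enc-Dec)            │
├───────────────────────┼────────────────────┼────────────────────────────────┼────────────────────────────────┤
│ Representative Models │ GPT, LLaMA         │ BERT, RoBERTa                  │ T5, BART                       │
│ Context               │ Left-only          │ Bidirectional                  │ Bidir. encoder, causal decoder │
│ Training Signal       │ All tokens         │ Masked tokens only (~15%)      │ Target span or sequence        │
│ Generation            │ Natural generation │ Requires additional training   │ Natural generation             │
│ Understanding         │ Zero-shot capable  │ Strong representation learning │ Balanced                       │
└───────────────────────┴────────────────────┴────────────────────────────────┴────────────────────────────────┘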

2. Causal Language Modeling (CLM)

2.1 Mathematical Definition

Objective Function:
L_CLM = -Σ_{t=1}^{T} log P(x_t | x_1, x_2, ..., x_{t-1})

Characteristics:
- Every token after the first is a prediction target, so the whole sequence provides training signal
- Autoregressive: generation proceeds left → right, one token at a time
- A causal mask blocks attention to future tokens
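As a worked example of the objective, the sketch below scores a four-token sequence with a hypothetical bigram table standing in for a model's softmax output (a real model conditions on the whole prefix x_<t, not just the previous token). It shows the one-position shift between context and target, and the difference between the summed loss in the formula and the per-token mean that `F.cross_entropy` reports by default.

```python
import math

# Hypothetical bigram table P(next | previous), standing in for model outputs
probs = {
    "<bos>": {"the": 0.5, "cat": 0.3, "sat": 0.2},
    "the":   {"the": 0.1, "cat": 0.8, "sat": 0.1},
    "cat":   {"the": 0.2, "cat": 0.1, "sat": 0.7},
}

tokens = ["<bos>", "the", "cat", "sat"]

# Shift by one: position t-1 is the context, position t is the target
pairs = list(zip(tokens[:-1], tokens[1:]))
nll = [-math.log(probs[prev][nxt]) for prev, nxt in pairs]

loss_sum = sum(nll)              # L_CLM as written: sum over positions
loss_mean = loss_sum / len(nll)  # F.cross_entropy's default "mean" reduction

print(pairs)
print(f"sum = {loss_sum:.4f}, mean = {loss_mean:.4f}")
```

The summed loss equals -log of the probability of the whole sequence (0.5 × 0.8 × 0.7), which is why minimizing L_CLM maximizes sequence likelihood.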

2.2 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalLMHead(nn.Module):
    """Causal Language Model output layer"""

    def __init__(self, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch, seq_len, hidden_dim)
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        return self.lm_head(hidden_states)


def causal_lm_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100
) -> torch.Tensor:
    """
    Calculate Causal LM Loss

    Args:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len) - next token as label
    """
    # Shift: logits[:-1] predicts labels[1:]
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index
    )
    return loss


# Create Causal Mask
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Generate upper triangular mask (block future tokens)

    Returns:
        mask: (seq_len, seq_len) - True = masked
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return mask


# Usage example
batch_size, seq_len, hidden_dim, vocab_size = 4, 128, 768, 50257
hidden_states = torch.randn(batch_size, seq_len, hidden_dim)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

lm_head = CausalLMHead(hidden_dim, vocab_size)
logits = lm_head(hidden_states)
loss = causal_lm_loss(logits, labels)
print(f"CLM Loss: {loss.item():.4f}")

2.3 GPT-style Training

class GPTPretraining:
    """GPT-style Pre-training"""

    def __init__(self, model, tokenizer, max_length=1024):
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    def prepare_data(self, texts: list[str]) -> dict:
        """
        Split continuous text into fixed-length chunks

        Document 1: "The cat sat on..."
        Document 2: "Machine learning is..."

        → [BOS] The cat sat on... [EOS] [BOS] Machine learning is... [EOS]
        → Split into fixed-length chunks (max_length units)
        """
        # Concatenate all texts
        all_tokens = []
        for text in texts:
            tokens = self.tokenizer.encode(text)
            all_tokens.extend(tokens)
            all_tokens.append(self.tokenizer.eos_token_id)

        # Split into full chunks; the "+ 1" keeps the last chunk when the
        # stream length is an exact multiple of max_length
        chunks = []
        for i in range(0, len(all_tokens) - self.max_length + 1, self.max_length):
            chunk = all_tokens[i:i + self.max_length]
            chunks.append(chunk)

        return {
            'input_ids': torch.tensor(chunks),
            'labels': torch.tensor(chunks)  # Same (shift happens in loss)
        }

    def train_step(self, batch):
        """Single training step"""
        input_ids = batch['input_ids']
        labels = batch['labels']

        # Forward
        outputs = self.model(input_ids)
        logits = outputs.logits

        # Loss
        loss = causal_lm_loss(logits, labels)

        return loss
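The chunking step in `prepare_data` can be checked on toy data. This sketch assumes documents are already tokenized into integer IDs and uses a hypothetical EOS ID of 0; `pack_into_chunks` is an illustrative helper, not a library function. Chunks ignore document boundaries (the second chunk mixes two documents, separated only by EOS), and the `+ 1` in the range bound keeps the final chunk when the stream length is an exact multiple of `max_length`.

```python
def pack_into_chunks(docs: list[list[int]], eos_id: int, max_length: int) -> list[list[int]]:
    """Concatenate documents (each followed by EOS) and cut the stream into
    full chunks of max_length tokens; any shorter remainder is dropped."""
    stream: list[int] = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)
    return [stream[i:i + max_length]
            for i in range(0, len(stream) - max_length + 1, max_length)]


EOS = 0  # hypothetical EOS token ID for this toy example
docs = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
chunks = pack_into_chunks(docs, EOS, max_length=4)
for c in chunks:
    print(c)
```

The stream here has 12 tokens, so a bound of `len - max_length` (without the `+ 1`) would silently drop the third full chunk.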

3. Masked Language Modeling (MLM)

3.1 BERT-style MLM

Original: "The quick brown fox jumps over the lazy dog"

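Masking strategy: select 15% of token positions as prediction targets, then for each selected position:
- 80%: Replace with [MASK] token
- 10%: Replace with random token
- 10%: Keep original (the model must still predict it)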

Result: "The [MASK] brown fox jumps over the [MASK] dog"
                                        
Target:    "quick"                    "lazy"
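Because the 80/10/10 split applies only to the selected 15%, roughly 12% of all tokens become [MASK], 1.5% are replaced by random tokens, 1.5% stay unchanged but are still predicted, and the remaining 85% contribute no loss. A small seeded simulation (plain Python; `bert_mask_decisions` is a hypothetical helper) reproduces these proportions:

```python
import random


def bert_mask_decisions(n_tokens: int, mlm_prob: float = 0.15, seed: int = 0) -> dict:
    """Count how each token is treated under BERT's 15% / 80-10-10 rule."""
    rng = random.Random(seed)
    counts = {"mask": 0, "random": 0, "keep": 0, "untouched": 0}
    for _ in range(n_tokens):
        if rng.random() >= mlm_prob:
            counts["untouched"] += 1  # not selected: no loss at this position
            continue
        r = rng.random()
        if r < 0.8:
            counts["mask"] += 1       # replaced by [MASK]
        elif r < 0.9:
            counts["random"] += 1     # replaced by a random token
        else:
            counts["keep"] += 1       # left as-is, but still predicted
    return counts


n = 200_000
counts = bert_mask_decisions(n)
for k, v in counts.items():
    print(f"{k:>9}: {v / n:.4f}")
```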

3.2 Implementation

import random

class MLMDataCollator:
    """Masked Language Modeling data preprocessing"""

    def __init__(
        self,
        tokenizer,
        mlm_probability: float = 0.15,
        mask_token_ratio: float = 0.8,
        random_token_ratio: float = 0.1
    ):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability
        self.mask_token_ratio = mask_token_ratio
        self.random_token_ratio = random_token_ratio

        # Special token IDs
        self.mask_token_id = tokenizer.mask_token_id
        self.vocab_size = tokenizer.vocab_size
        self.special_tokens = set([
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id
        ])

    def __call__(self, batch: list[dict]) -> dict:
        """Process batch"""
        input_ids = torch.stack([item['input_ids'] for item in batch])

        # Masking
        input_ids, labels = self.mask_tokens(input_ids)

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': torch.stack([item['attention_mask'] for item in batch])
        }

    def mask_tokens(
        self,
        input_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Perform token masking

        Returns:
            masked_input_ids: Masked input
            labels: Original tokens (unmasked positions are -100)
        """
        input_ids = input_ids.clone()  # don't modify the caller's tensor in place
        labels = input_ids.clone()

        # Masking probability matrix
        probability_matrix = torch.full(input_ids.shape, self.mlm_probability)

        # Don't mask special tokens
        special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for token_id in self.special_tokens:
            special_tokens_mask |= (input_ids == token_id)

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

        # Select positions to mask
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Unmasked positions are -100 (ignore in loss)
        labels[~masked_indices] = -100

        # 80%: Replace with [MASK]
        indices_replaced = torch.bernoulli(
            torch.full(input_ids.shape, self.mask_token_ratio)
        ).bool() & masked_indices
        input_ids[indices_replaced] = self.mask_token_id

        # 10%: Random token
        indices_random = torch.bernoulli(
            torch.full(input_ids.shape, self.random_token_ratio / (1 - self.mask_token_ratio))
        ).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(self.vocab_size, input_ids.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        # Remaining 10%: Keep original (handled implicitly)

        return input_ids, labels


def mlm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """MLM Loss (compute only on masked positions)"""
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        ignore_index=-100
    )
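Because unselected positions carry the label -100, `mlm_loss` averages only over the selected (masked, randomized, or kept) positions. The plain-Python sketch below mirrors the default mean reduction of `F.cross_entropy` with `ignore_index`; `cross_entropy_ignore` is a hypothetical helper written for illustration.

```python
import math


def cross_entropy_ignore(logits: list[list[float]], labels: list[int],
                         ignore_index: int = -100) -> float:
    """Mean cross-entropy over positions whose label != ignore_index."""
    total, count = 0.0, 0
    for row, y in zip(logits, labels):
        if y == ignore_index:
            continue  # position not selected for prediction
        log_z = math.log(sum(math.exp(v) for v in row))  # log-sum-exp
        total += log_z - row[y]                          # -log softmax(row)[y]
        count += 1
    return total / count  # divide by the number of non-ignored positions


logits = [[2.0, 0.0, 0.0],   # position 0: not selected
          [0.0, 3.0, 0.0],   # position 1: selected, true token 1
          [0.0, 0.0, 0.0]]   # position 2: selected, true token 2
labels = [-100, 1, 2]
print(cross_entropy_ignore(logits, labels))
```

Changing the logits at position 0 leaves the result unchanged, since that position is ignored.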

3.3 RoBERTa Improvements

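class RoBERTaMLM:
    """
    RoBERTa: Improved MLM recipe

    Changes from BERT:
    1. Dynamic masking: a fresh masking pattern each time a sequence is seen
    2. Full-length sequences (512 tokens) throughout training; BERT used
       128-token sequences for most of its steps
    3. Larger batches (256 → 8K sequences)
    4. NSP removal
    5. More data, longer training
    """

    def __init__(self, tokenizer, max_length: int = 512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.collator = MLMDataCollator(tokenizer)

    def create_epoch_data(self, texts: list[str], epoch: int) -> dict:
        """
        Dynamic Masking: New masking pattern each epoch
        """
        # Seed per epoch: masks differ across epochs but stay reproducible
        random.seed(epoch)
        torch.manual_seed(epoch)

        # Tokenize without masking, then draw this epoch's mask
        encodings = self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids, labels = self.collator.mask_tokens(encodings['input_ids'].clone())

        return {
            'input_ids': input_ids,
            'labels': labels,
            'attention_mask': encodings['attention_mask']
        }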
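The difference between static and dynamic masking can be seen in a dependency-free sketch. Here `masks_per_epoch` and `sample_mask` are hypothetical helpers: with `dynamic=False` one mask is drawn at preprocessing time and reused every epoch, whereas `dynamic=True` draws a new one per epoch, seeded by the epoch number so runs stay reproducible.

```python
import random


def sample_mask(n_tokens: int, prob: float, rng: random.Random) -> list[int]:
    """Positions selected for masking, each chosen independently with `prob`."""
    return [i for i in range(n_tokens) if rng.random() < prob]


def masks_per_epoch(n_tokens: int, n_epochs: int, dynamic: bool,
                    prob: float = 0.15, seed: int = 0) -> list[list[int]]:
    if not dynamic:
        # Static: mask once at preprocessing time, reuse it every epoch
        fixed = sample_mask(n_tokens, prob, random.Random(seed))
        return [fixed for _ in range(n_epochs)]
    # Dynamic: a fresh (epoch-seeded, hence reproducible) mask every epoch
    return [sample_mask(n_tokens, prob, random.Random(seed + epoch))
            for epoch in range(n_epochs)]


static = masks_per_epoch(40, 3, dynamic=False)
dynamic = masks_per_epoch(40, 3, dynamic=True)
print("static :", static)
print("dynamic:", dynamic)
```

With dynamic masking, a token that happened not to be selected in one epoch can become a prediction target in the next, so the model sees more distinct training signals from the same data.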

4. Span Corruption (T5)

4.1 Concept

Original: "The quick brown fox jumps over the lazy dog"

Span Corruption:
- Replace each span of consecutive tokens with a single sentinel token
- The decoder outputs each sentinel followed by the tokens it replaced, ending with one final sentinel

Input: "The <X> fox <Y> over the lazy dog"
Output: "<X> quick brown <Y> jumps <Z>"

Characteristics:
- Average span length: 3 tokens
- Corruption rate: 15% of tokens
- Sentinels: <extra_id_0>, <extra_id_1>, ... (written <X>, <Y>, <Z> above)
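The example above can be reproduced with a short sketch that takes the span positions as given (real T5 samples them randomly). `span_corrupt` is a hypothetical helper that operates on whitespace-split words rather than subword IDs and writes the sentinels by name as `<extra_id_k>`.

```python
def span_corrupt(tokens: list[str], spans: list[tuple[int, int]]) -> tuple[list[str], list[str]]:
    """spans: sorted, non-overlapping (start, end) pairs, end exclusive."""
    inputs, targets = [], []
    pos = 0
    for k, (start, end) in enumerate(spans):
        sentinel = f"<extra_id_{k}>"
        inputs.extend(tokens[pos:start])   # keep the uncorrupted text
        inputs.append(sentinel)            # one sentinel per span
        targets.append(sentinel)
        targets.extend(tokens[start:end])  # the tokens the sentinel hides
        pos = end
    inputs.extend(tokens[pos:])
    targets.append(f"<extra_id_{len(spans)}>")  # closing sentinel
    return inputs, targets


tokens = "The quick brown fox jumps over the lazy dog".split()
inp, tgt = span_corrupt(tokens, [(1, 3), (4, 5)])
print(" ".join(inp))
print(" ".join(tgt))
```

Because every span collapses to a single sentinel and the target lists only the corrupted tokens, both sequences are shorter than the original, which keeps span corruption cheaper than reconstructing the full input.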

4.2 Implementation

class SpanCorruptionCollator:
    """T5-style Span Corruption"""

    def __init__(
        self,
        tokenizer,
        noise_density: float = 0.15,
        mean_span_length: float = 3.0
    ):
        self.tokenizer = tokenizer
        self.noise_density = noise_density
        self.mean_span_length = mean_span_length

        # Sentinel tokens (<extra_id_0>, <extra_id_1>, ...)
        self.sentinel_start_id = tokenizer.convert_tokens_to_ids("<extra_id_0>")

    def __call__(self, examples: list[dict]) -> dict:
        """Process batch"""
        batch_inputs = []
        batch_targets = []

        for example in examples:
            input_ids = example['input_ids']
            inputs, targets = self.corrupt_span(input_ids)
            batch_inputs.append(inputs)
            batch_targets.append(targets)

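        # Padding (padded label positions become -100 so the loss ignores them)
        inputs_padded = self._pad_sequences(batch_inputs)
        targets_padded = self._pad_sequences(batch_targets)
        targets_padded[targets_padded == self.tokenizer.pad_token_id] = -100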

        return {
            'input_ids': inputs_padded,
            'labels': targets_padded
        }

    def corrupt_span(
        self,
        input_ids: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply Span Corruption
        """
        length = len(input_ids)
        num_noise_tokens = int(length * self.noise_density)
        num_spans = max(1, int(num_noise_tokens / self.mean_span_length))

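        # Sample span start positions (spans may overlap and merge, so the
        # realized noise density only approximates noise_density)
        span_starts = sorted(random.sample(range(length - 1), num_spans))

        # Span lengths: 1 + Poisson(mean_span_length - 1), so the mean is
        # mean_span_length and every span has at least one token
        span_lengths = torch.poisson(
            torch.full((num_spans,), self.mean_span_length - 1)
        ).long() + 1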

        # Create span mask
        noise_mask = torch.zeros(length, dtype=torch.bool)
        for start, span_len in zip(span_starts, span_lengths):
            end = min(start + span_len, length)
            noise_mask[start:end] = True

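        # Construct input: replace each noise span with one sentinel.
        # T5 assigns sentinel IDs in *descending* order (<extra_id_0> has the
        # largest ID), so look each sentinel up by name instead of incrementing.
        input_tokens = []
        target_tokens = []
        sentinel_idx = 0

        i = 0
        while i < length:
            if noise_mask[i]:
                # Span start: add sentinel
                sentinel_id = self.tokenizer.convert_tokens_to_ids(f"<extra_id_{sentinel_idx}>")
                input_tokens.append(sentinel_id)
                target_tokens.append(sentinel_id)

                # Add span content to target
                while i < length and noise_mask[i]:
                    target_tokens.append(input_ids[i].item())
                    i += 1

                sentinel_idx += 1
            else:
                input_tokens.append(input_ids[i].item())
                i += 1

        # Close the target with one more sentinel, as in T5
        target_tokens.append(
            self.tokenizer.convert_tokens_to_ids(f"<extra_id_{sentinel_idx}>")
        )

        return torch.tensor(input_tokens), torch.tensor(target_tokens)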

    def _pad_sequences(self, sequences: list[torch.Tensor]) -> torch.Tensor:
        """Pad sequences"""
        max_len = max(len(seq) for seq in sequences)
        padded = torch.full((len(sequences), max_len), self.tokenizer.pad_token_id)
        for i, seq in enumerate(sequences):
            padded[i, :len(seq)] = seq
        return padded

5. UL2: Unified Language Learner

5.1 Mixture of Denoisers (MoD)

UL2 trains a single model on a mixture of denoising objectives.

┌────────────────────────────────────────────────────────────────┐
│                      Mixture of Denoisers                      │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  R-Denoiser (Regular)           S-Denoiser (Sequential)        │
│  - Short spans (3-8 tokens)     - Predict suffix from prefix   │
│  - 15% corruption               - No masked spans              │
│  - Good for NLU tasks           - Prefix-LM style              │
│                                                                │
│  X-Denoiser (Extreme)                                          │
│  - Long spans (12-64 tokens)                                   │
│  - and/or high corruption (up to 50%)                          │
│  - Good for generation tasks                                   │
│                                                                │
│  Mode switching: prepend a mode token [R], [S], or [X]         │
│  (the UL2 paper uses [NLU], [S2S], [NLG])                      │
│                                                                │
└────────────────────────────────────────────────────────────────┘
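In the UL2 paper, the S-denoiser is sequential (prefix-LM style) denoising rather than span masking, while R and X both use span corruption with different span lengths and corruption rates. The toy sketch below contrasts the three on whitespace-split words. It is a simplification: `denoise` and `span_corrupt` are hypothetical helpers, each R/X example corrupts a single span sized to the target rate (UL2 samples many spans), and the [R]/[S]/[X] mode tokens follow this lesson's notation.

```python
import random


def span_corrupt(tokens: list[str], start: int, end: int) -> tuple[list[str], list[str]]:
    """Replace tokens[start:end] with one sentinel; the target ends with a closing sentinel."""
    inputs = tokens[:start] + ["<extra_id_0>"] + tokens[end:]
    targets = ["<extra_id_0>"] + tokens[start:end] + ["<extra_id_1>"]
    return inputs, targets


def denoise(tokens: list[str], mode: str, rng: random.Random) -> tuple[list[str], list[str]]:
    """Toy R/S/X denoisers: returns (mode token + encoder input, decoder target)."""
    n = len(tokens)
    if mode == "S":
        # Sequential: the encoder sees a prefix, the decoder predicts the remainder
        split = rng.randint(1, n - 1)
        return ["[S]"] + tokens[:split], tokens[split:]
    # R: one short span (~15% of tokens); X: one long span (~50% of tokens)
    span_len = max(1, round(n * (0.15 if mode == "R" else 0.5)))
    start = rng.randint(0, n - span_len)
    inputs, targets = span_corrupt(tokens, start, start + span_len)
    return [f"[{mode}]"] + inputs, targets


rng = random.Random(0)
tokens = "the quick brown fox jumps over the lazy dog near the quiet river".split()
for mode in ["R", "S", "X"]:
    inp, tgt = denoise(tokens, mode, rng)
    print(f"{mode} input : {' '.join(inp)}")
    print(f"{mode} target: {' '.join(tgt)}")
```

R leaves most of the context intact (closer to BERT/T5-style understanding), X removes half the text (closer to generation), and S turns the task into continuation, which is why mixing them is meant to cover both regimes.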

5.2 Implementation

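from torch.nn.utils.rnn import pad_sequence


class UL2Collator:
    """UL2 Mixture of Denoisers (simplified sketch)"""

    DENOISERS = {
        'R': {  # Regular: short spans, low corruption
            'type': 'span',
            'span_length': (3, 8),
            'noise_density': 0.15,
            'prefix': '[R]'
        },
        'S': {  # Sequential: prefix-LM style, no span masking
            'type': 'prefix_lm',
            'prefix': '[S]'
        },
        'X': {  # Extreme: long spans and high corruption
            'type': 'span',
            'span_length': (12, 64),
            'noise_density': 0.5,
            'prefix': '[X]'
        }
    }

    def __init__(self, tokenizer, denoiser_weights: dict = None):
        self.tokenizer = tokenizer
        # Default weights: R=50%, S=25%, X=25%
        self.weights = denoiser_weights or {'R': 0.5, 'S': 0.25, 'X': 0.25}

    def __call__(self, examples: list[dict]) -> dict:
        """Process batch: Apply random denoiser to each example"""
        # Names and weights come from the same dict, so they stay aligned
        names = list(self.weights.keys())
        batch_inputs = []
        batch_targets = []

        for example in examples:
            # Select denoiser
            denoiser = random.choices(
                names,
                weights=[self.weights[n] for n in names]
            )[0]

            config = self.DENOISERS[denoiser]

            # Apply the selected corruption
            inputs, targets = self.apply_denoiser(
                example['input_ids'],
                config
            )

            batch_inputs.append(inputs)
            batch_targets.append(targets)

        return self._collate(batch_inputs, batch_targets)

    def apply_denoiser(
        self,
        input_ids: torch.Tensor,
        config: dict
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply corruption with specific denoiser config"""
        # Mode token (in practice [R]/[S]/[X] would be registered as special
        # tokens; here they are simply tokenized as text)
        prefix_ids = torch.tensor(
            self.tokenizer.encode(config['prefix'], add_special_tokens=False),
            dtype=torch.long
        )

        if config['type'] == 'prefix_lm':
            # S-denoiser: random split; encoder sees the prefix, decoder predicts the suffix
            split = random.randint(1, len(input_ids) - 1)
            corrupted, targets = input_ids[:split], input_ids[split:]
        else:
            # R/X-denoisers: T5-style span corruption with a sampled mean span length
            span_len = random.randint(*config['span_length'])
            corruptor = SpanCorruptionCollator(
                self.tokenizer,
                noise_density=config['noise_density'],
                mean_span_length=span_len
            )
            corrupted, targets = corruptor.corrupt_span(input_ids)

        # Prefix + corrupted input
        inputs = torch.cat([prefix_ids, corrupted])

        return inputs, targets

    def _collate(self, batch_inputs: list, batch_targets: list) -> dict:
        """Right-pad inputs with pad_token_id and labels with -100 (ignored by the loss)"""
        return {
            'input_ids': pad_sequence(
                batch_inputs, batch_first=True,
                padding_value=self.tokenizer.pad_token_id
            ),
            'labels': pad_sequence(
                batch_targets, batch_first=True, padding_value=-100
            )
        }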

6. Next Sentence Prediction (NSP) vs Sentence Order Prediction (SOP)

6.1 NSP (BERT)

class NSPDataCollator:
    """
    Next Sentence Prediction

    50%: Actual next sentence (IsNext)
    50%: Random sentence (NotNext)

    Problem: Too easy → Removed in RoBERTa
    """

    def create_nsp_pair(
        self,
        sentence_a: str,
        sentence_b: str,
        all_sentences: list[str]
    ) -> tuple[str, str, int]:
        """Create NSP data"""
        if random.random() < 0.5:
            # Actual next sentence
            return sentence_a, sentence_b, 1  # IsNext
        else:
            # Random sentence
            random_sentence = random.choice(all_sentences)
            return sentence_a, random_sentence, 0  # NotNext

6.2 SOP (ALBERT)

class SOPDataCollator:
    """
    Sentence Order Prediction (harder task)

    50%: Normal order (A → B)
    50%: Reversed (B → A)

    Order prediction not topic prediction → More useful training signal
    """

    def create_sop_pair(
        self,
        sentence_a: str,
        sentence_b: str
    ) -> tuple[str, str, int]:
        """Create SOP data"""
        if random.random() < 0.5:
            return sentence_a, sentence_b, 1  # Normal order
        else:
            return sentence_b, sentence_a, 0  # Reversed
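The practical difference between NSP and SOP negatives shows up when both are generated from the same toy corpus. In this sketch, `nsp_examples` and `sop_examples` are hypothetical helpers over a hand-written two-document corpus: NSP negatives pair a sentence with one from another document, so topic alone gives them away, while SOP negatives reuse the same two consecutive sentences in swapped order.

```python
import random

docs = [
    ["Cats sleep a lot.", "They nap in the sun.", "Kittens play often."],
    ["Rust has no GC.", "Ownership frees memory.", "Borrows are checked."],
]


def nsp_examples(docs: list[list[str]], rng: random.Random) -> list[tuple[str, str, int]]:
    """(a, b, label): 1 = b follows a; 0 = b drawn from another document."""
    out = []
    for d, doc in enumerate(docs):
        for a, b in zip(doc, doc[1:]):
            if rng.random() < 0.5:
                out.append((a, b, 1))
            else:
                other = rng.choice([x for k, x in enumerate(docs) if k != d])
                out.append((a, rng.choice(other), 0))
    return out


def sop_examples(docs: list[list[str]], rng: random.Random) -> list[tuple[str, str, int]]:
    """(a, b, label): 1 = original order; 0 = the same two sentences swapped."""
    out = []
    for doc in docs:
        for a, b in zip(doc, doc[1:]):
            out.append((a, b, 1) if rng.random() < 0.5 else (b, a, 0))
    return out


rng = random.Random(0)
for ex in nsp_examples(docs, rng):
    print("NSP", ex)
for ex in sop_examples(docs, rng):
    print("SOP", ex)
```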

7. Pre-training Objective Selection Guide

7.1 By Downstream Task

┌───────────────────────────┬───────────────────────────────────────┐
│ Downstream Task           │ Recommended Pre-training Objective    │
├───────────────────────────┼───────────────────────────────────────┤
│ Text Generation           │ Causal LM (GPT-style)                 │
│ Text Classification       │ MLM (BERT) or Causal LM + Fine-tuning │
│ Question Answering        │ Span Corruption (T5) or MLM           │
│ Translation/Summarization │ Encoder-Decoder (T5, BART)            │
│ General (Few-shot)        │ Large-scale Causal LM (GPT-3 style)   │
│ General (Diverse)         │ UL2 (Mixture of Denoisers)            │
└───────────────────────────┴───────────────────────────────────────┘

7.2 Strategy by Model Size

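┌────────────┬──────────────────────┬───────────────────────────────────────┐
│ Model Size │ Recommended Approach │ Reason                                │
├────────────┼──────────────────────┼───────────────────────────────────────┤
│ < 1B       │ MLM + Fine-tuning    │ Excellent task-specific performance   │
│ 1B - 10B   │ Causal LM            │ Balance of versatility and efficiency │
│ > 10B      │ Causal LM            │ Emergence of in-context learning      │
└────────────┴──────────────────────┴───────────────────────────────────────┘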

8. Practice: Comparing Objectives

from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    T5ForConditionalGeneration,
    AutoTokenizer
)

def compare_objectives():
    """Compare three objectives"""

    # 1. Causal LM (GPT-2)
    causal_tokenizer = AutoTokenizer.from_pretrained('gpt2')
    causal_model = AutoModelForCausalLM.from_pretrained('gpt2')

    text = "The capital of France is"
    inputs = causal_tokenizer(text, return_tensors='pt')

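    # Generate (greedy decoding)
    outputs = causal_model.generate(
        **inputs,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=causal_tokenizer.eos_token_id  # GPT-2 has no pad token
    )
    print("Causal LM:", causal_tokenizer.decode(outputs[0]))
    # → the prompt plus 5 greedy tokens (small GPT-2 does not reliably answer "Paris")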

    # 2. Masked LM (BERT)
    mlm_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    mlm_model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

    text = "The capital of France is [MASK]."
    inputs = mlm_tokenizer(text, return_tensors='pt')

    outputs = mlm_model(**inputs)
    mask_idx = (inputs['input_ids'] == mlm_tokenizer.mask_token_id).nonzero()[0, 1]
    predicted_id = outputs.logits[0, mask_idx].argmax().item()
    print("Masked LM:", mlm_tokenizer.decode([predicted_id]))
    # → top-1 prediction, typically "paris"

    # 3. Encoder-decoder (T5). The released t5-small checkpoint was pre-trained on
    #    span corruption mixed with supervised tasks, so it accepts task prefixes
    t5_tokenizer = AutoTokenizer.from_pretrained('t5-small')
    t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

    text = "translate English to French: The house is wonderful."
    inputs = t5_tokenizer(text, return_tensors='pt')

    outputs = t5_model.generate(**inputs, max_new_tokens=20)
    print("T5:", t5_tokenizer.decode(outputs[0], skip_special_tokens=True))
    # → a French translation, e.g. "La maison est merveilleuse."

if __name__ == "__main__":
    compare_objectives()

References

Papers

  • Devlin et al. (2018). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
  • Radford et al. (2019). "Language Models are Unsupervised Multitask Learners" (GPT-2)
  • Raffel et al. (2019). "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer" (T5)
  • Tay et al. (2022). "UL2: Unifying Language Learning Paradigms"