21. Continued Pre-training¶
Overview¶
Continued Pre-training (CPT) further trains an already pre-trained model to adapt it to a specific domain or task. Unlike typical fine-tuning, it keeps the language-modeling objective and runs it over large amounts of raw domain text.
1. Continued Pre-training Overview¶
1.1 Why Is It Needed?¶
Scenario:
┌────────────────────────────────────────────────────┐
│ Base Model (LLaMA-7B)                              │
│  - Training: General web text                      │
│  - Strength: General language understanding        │
│  - Weakness: Lacking domain-specific knowledge     │
│                                                    │
│ Target Domain: Medical                             │
│  - Specialized terminology (drug names, diseases)  │
│  - Domain-specific reasoning                       │
│  - Special document formats                        │
└────────────────────────────────────────────────────┘
Solution:
1. Instruction Tuning alone is insufficient for knowledge injection
2. Learn domain knowledge through Continued Pre-training
3. Then apply Instruction Tuning for task adaptation
1.2 Training Pipeline¶
General Pipeline:
┌────────────────────────────────────────────┐
│                                            │
│  Pre-trained Model                         │
│        ↓                                   │
│  [Continued Pre-training]                  │
│    - Domain data (10B+ tokens)             │
│    - Causal LM objective                   │
│    - Lower learning rate                   │
│        ↓                                   │
│  Domain-Adapted Model                      │
│        ↓                                   │
│  [Instruction Tuning]                      │
│    - Domain-specific instructions          │
│        ↓                                   │
│  Final Domain Model                        │
│                                            │
└────────────────────────────────────────────┘
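As a minimal sketch of this two-stage pipeline with Hugging Face transformers, both stages can share the same causal-LM data collator. The model name, corpora, and hyperparameters below are illustrative placeholders, not prescriptions:

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments,
    DataCollatorForLanguageModeling,
)

base_model = "meta-llama/Llama-2-7b-hf"          # placeholder base model
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(base_model)

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=2048)

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Stage 1: Continued Pre-training on raw domain text (causal LM objective)
domain_corpus = Dataset.from_dict({"text": ["<raw domain document>"]})  # placeholder corpus
cpt_args = TrainingArguments(
    output_dir="cpt-ckpts",
    learning_rate=1e-5,            # lower than the original pre-training LR
    num_train_epochs=1,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
)
Trainer(
    model=model, args=cpt_args, data_collator=collator,
    train_dataset=domain_corpus.map(tokenize, batched=True, remove_columns=["text"]),
).train()

# Stage 2: Instruction Tuning on domain-specific instruction/response pairs
instr_corpus = Dataset.from_dict(
    {"text": ["### Instruction:\n<domain question>\n### Response:\n<answer>"]}  # placeholder
)
sft_args = TrainingArguments(output_dir="sft-ckpts", learning_rate=2e-5, num_train_epochs=3)
Trainer(
    model=model, args=sft_args, data_collator=collator,
    train_dataset=instr_corpus.map(tokenize, batched=True, remove_columns=["text"]),
).train()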
2. Catastrophic Forgetting¶
2.1 Problem Definition¶
Catastrophic Forgetting:
The phenomenon in which a model loses previously learned (general) knowledge while acquiring new (domain) knowledge
Example:
┌──────────────────────────────────────────┐
│ Before CPT:                              │
│  Q: "What is the capital of France?"     │
│  A: "Paris" ✓                            │
│                                          │
│ After CPT (medical domain):              │
│  Q: "What is the capital of France?"     │
│  A: "The patient presented with..." ✗    │
└──────────────────────────────────────────┘
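One simple way to detect this regression is to track perplexity on a held-out general-domain probe set alongside the domain set, before and after CPT. A minimal sketch, assuming the model, tokenizer, and probe texts already exist:

import math
import torch

def perplexity(model, tokenizer, texts, device="cpu"):
    """Token-weighted average perplexity of `model` over a list of probe texts."""
    model.eval()
    total_loss, total_tokens = 0.0, 0
    with torch.no_grad():
        for text in texts:
            enc = tokenizer(text, return_tensors="pt", truncation=True).to(device)
            out = model(**enc, labels=enc["input_ids"])
            n = enc["input_ids"].numel()
            total_loss += out.loss.item() * n
            total_tokens += n
    return math.exp(total_loss / total_tokens)

# Rising perplexity on general text together with falling perplexity on
# domain text is the signature of catastrophic forgetting:
# general_ppl = perplexity(model, tokenizer, general_probe_texts)
# domain_ppl  = perplexity(model, tokenizer, domain_probe_texts)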
2.2 Mitigation Strategies¶
import torch
import torch.nn as nn
from typing import Dict, List, Optional
class ContinuedPretrainingWithRegularization:
"""Training with Catastrophic Forgetting mitigation"""
def __init__(
self,
model: nn.Module,
reference_model: nn.Module, # Frozen original
reg_weight: float = 0.1
):
self.model = model
self.reference_model = reference_model
self.reg_weight = reg_weight
# Reference model freeze
for param in self.reference_model.parameters():
param.requires_grad = False
def compute_loss(
self,
input_ids: torch.Tensor,
labels: torch.Tensor,
regularization: str = "kl"
) -> torch.Tensor:
"""
Regularization methods:
- "kl": KL divergence with reference model
- "ewc": Elastic Weight Consolidation
- "replay": Experience replay (separate implementation)
"""
# Main loss
outputs = self.model(input_ids, labels=labels)
lm_loss = outputs.loss
# Regularization
if regularization == "kl":
reg_loss = self._kl_regularization(input_ids)
elif regularization == "ewc":
reg_loss = self._ewc_regularization()
else:
reg_loss = 0.0
return lm_loss + self.reg_weight * reg_loss
def _kl_regularization(self, input_ids: torch.Tensor) -> torch.Tensor:
"""KL divergence-based regularization"""
with torch.no_grad():
ref_logits = self.reference_model(input_ids).logits
current_logits = self.model(input_ids).logits
# KL(current || reference)
kl_loss = nn.functional.kl_div(
nn.functional.log_softmax(current_logits, dim=-1),
nn.functional.softmax(ref_logits, dim=-1),
reduction="batchmean"
)
return kl_loss
def _ewc_regularization(self) -> torch.Tensor:
"""
Elastic Weight Consolidation
L_ewc = Ξ£α΅’ Fα΅’(ΞΈα΅’ - ΞΈα΅’*)Β²
Fα΅’: Fisher information (importance)
ΞΈα΅’*: original parameters
"""
if not hasattr(self, 'fisher_info'):
# Fisher information needs to be pre-computed
return torch.tensor(0.0)
ewc_loss = 0.0
for name, param in self.model.named_parameters():
if name in self.fisher_info:
ewc_loss += (
self.fisher_info[name] *
(param - self.original_params[name]).pow(2)
).sum()
return ewc_loss
def compute_fisher_information(
self,
dataloader,
num_samples: int = 1000
):
"""Compute Fisher Information"""
self.fisher_info = {}
self.original_params = {}
# Save original parameters
for name, param in self.model.named_parameters():
self.original_params[name] = param.clone().detach()
self.fisher_info[name] = torch.zeros_like(param)
self.model.eval()
for i, batch in enumerate(dataloader):
if i >= num_samples:
break
input_ids = batch["input_ids"]
outputs = self.model(input_ids)
log_probs = nn.functional.log_softmax(outputs.logits, dim=-1)
# Sample from output distribution
sampled = torch.multinomial(
log_probs.view(-1, log_probs.size(-1)).exp(), 1
)
# Compute gradients
loss = -log_probs.view(-1, log_probs.size(-1)).gather(1, sampled).mean()
loss.backward()
# Accumulate squared gradients
for name, param in self.model.named_parameters():
if param.grad is not None:
self.fisher_info[name] += param.grad.pow(2)
self.model.zero_grad()
# Normalize
for name in self.fisher_info:
self.fisher_info[name] /= num_samples
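A sketch of how this class might be driven in a training loop; the model, reference model, optimizer, and dataloaders are assumed to be built elsewhere:

# Hypothetical usage of ContinuedPretrainingWithRegularization.
cpt = ContinuedPretrainingWithRegularization(model, reference_model, reg_weight=0.1)

# For EWC, Fisher information must be estimated on the original (general) data first.
cpt.compute_fisher_information(general_dataloader, num_samples=1000)

for batch in domain_dataloader:
    loss = cpt.compute_loss(batch["input_ids"], batch["labels"], regularization="ewc")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()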
2.3 Experience Replay¶
class ExperienceReplayTrainer:
"""Prevent forgetting with Experience Replay"""
def __init__(
self,
model: nn.Module,
domain_dataloader,
general_dataloader, # General data
replay_ratio: float = 0.1
):
self.model = model
self.domain_dataloader = domain_dataloader
self.general_dataloader = general_dataloader
self.replay_ratio = replay_ratio
def train_step(self, optimizer) -> Dict[str, float]:
"""Mixed training with domain + general data"""
# Domain data
domain_batch = next(iter(self.domain_dataloader))
domain_loss = self._compute_lm_loss(domain_batch)
# Replay (general data)
if torch.rand(1).item() < self.replay_ratio:
general_batch = next(iter(self.general_dataloader))
replay_loss = self._compute_lm_loss(general_batch)
total_loss = domain_loss + replay_loss
else:
replay_loss = torch.tensor(0.0)
total_loss = domain_loss
# Backward
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
return {
"domain_loss": domain_loss.item(),
"replay_loss": replay_loss.item() if isinstance(replay_loss, torch.Tensor) else 0.0
}
def _compute_lm_loss(self, batch) -> torch.Tensor:
outputs = self.model(
input_ids=batch["input_ids"],
labels=batch["labels"]
)
return outputs.loss
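Usage is analogous; a sketch assuming the model, dataloaders, and optimizer already exist:

# Hypothetical usage: roughly 10% of steps also see a batch of general data.
trainer = ExperienceReplayTrainer(
    model=model,
    domain_dataloader=domain_loader,
    general_dataloader=general_loader,
    replay_ratio=0.1,
)
for step in range(num_training_steps):
    metrics = trainer.train_step(optimizer)
    if step % 100 == 0:
        print(step, metrics)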
3. Data Preparation¶
3.1 Domain Data Collection¶
class DomainDataPipeline:
"""Domain data preprocessing pipeline"""
def __init__(self, domain: str):
self.domain = domain
self.quality_filters = []
def add_filter(self, filter_fn):
self.quality_filters.append(filter_fn)
def process_document(self, doc: str) -> Optional[str]:
"""Document preprocessing"""
# Basic cleaning
doc = self._clean_text(doc)
# Quality filtering
for filter_fn in self.quality_filters:
if not filter_fn(doc):
return None
return doc
def _clean_text(self, text: str) -> str:
"""Text cleaning"""
import re
# Remove HTML tags
text = re.sub(r'<[^>]+>', '', text)
# Normalize special characters
text = re.sub(r'\s+', ' ', text)
# Domain-specific cleaning
if self.domain == "medical":
# Anonymize patient information
text = re.sub(r'\b\d{3}-\d{2}-\d{4}\b', '[ID]', text) # SSN pattern
return text.strip()
# Quality filter examples
def length_filter(min_len: int = 100, max_len: int = 100000):
def filter_fn(doc):
return min_len <= len(doc) <= max_len
return filter_fn
def language_filter(target_lang: str = "en"):
def filter_fn(doc):
from langdetect import detect
try:
return detect(doc) == target_lang
        except Exception:
return False
return filter_fn
def perplexity_filter(model, tokenizer, max_ppl: float = 100):
"""Filter low-quality (high perplexity) documents"""
def filter_fn(doc):
inputs = tokenizer(doc, return_tensors="pt", truncation=True)
with torch.no_grad():
outputs = model(**inputs, labels=inputs["input_ids"])
ppl = torch.exp(outputs.loss).item()
return ppl < max_ppl
return filter_fn
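Putting the pipeline and filters together might look like the following; the raw document list is a placeholder:

# Hypothetical usage: assemble the pipeline with the filters defined above.
pipeline = DomainDataPipeline(domain="medical")
pipeline.add_filter(length_filter(min_len=200))
pipeline.add_filter(language_filter("en"))
# pipeline.add_filter(perplexity_filter(quality_model, quality_tokenizer, max_ppl=80))

raw_documents = ["<raw medical document>"]  # placeholder corpus
clean_documents = []
for doc in raw_documents:
    processed = pipeline.process_document(doc)
    if processed is not None:
        clean_documents.append(processed)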
3.2 Data Mixing Strategy¶
import random

class CurriculumDataMixer:
"""Curriculum learning-based data mixing"""
def __init__(
self,
domain_data: List[str],
general_data: List[str],
total_steps: int
):
self.domain_data = domain_data
self.general_data = general_data
self.total_steps = total_steps
def get_mix_ratio(self, current_step: int) -> float:
"""
Progressively increase domain data ratio
Step 0: 50% domain, 50% general
Step T: 90% domain, 10% general
"""
progress = current_step / self.total_steps
domain_ratio = 0.5 + 0.4 * progress # 0.5 β 0.9
return domain_ratio
def sample_batch(self, batch_size: int, current_step: int) -> List[str]:
"""Sample batch appropriate for current step"""
domain_ratio = self.get_mix_ratio(current_step)
num_domain = int(batch_size * domain_ratio)
num_general = batch_size - num_domain
batch = (
random.sample(self.domain_data, min(num_domain, len(self.domain_data))) +
random.sample(self.general_data, min(num_general, len(self.general_data)))
)
random.shuffle(batch)
return batch
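For example, with `total_steps=10_000` the mix shifts from a 50/50 split to roughly 90/10 by the end of training:

# Illustrative run with tiny placeholder corpora.
domain_data = [f"domain doc {i}" for i in range(100)]
general_data = [f"general doc {i}" for i in range(100)]
mixer = CurriculumDataMixer(domain_data, general_data, total_steps=10_000)

print(mixer.get_mix_ratio(0))        # 0.5 -> 50% domain at the start
print(mixer.get_mix_ratio(5_000))    # 0.7 -> 70% domain halfway through
print(mixer.get_mix_ratio(10_000))   # 0.9 -> 90% domain at the end
batch = mixer.sample_batch(batch_size=32, current_step=5_000)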
4. Training Configuration¶
4.1 Learning Rate Strategy¶
from transformers import get_scheduler
def get_cpt_lr_scheduler(
optimizer,
num_training_steps: int,
warmup_ratio: float = 0.03,
min_lr_ratio: float = 0.1
):
"""
LR scheduler for Continued Pre-training
- Low initial LR (prevent base model damage)
- Long warmup
- Cosine decay
"""
num_warmup_steps = int(num_training_steps * warmup_ratio)
scheduler = get_scheduler(
"cosine",
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
return scheduler
# Recommended hyperparameters
CPT_CONFIG = {
"learning_rate": 1e-5, # Lower LR than base model
"warmup_ratio": 0.03,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"batch_size": 256, # Large batch for stability
"gradient_accumulation_steps": 16,
"num_epochs": 1, # Usually 1 epoch is sufficient
}
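A sketch of how these settings could be wired into an optimizer and the scheduler above; `model` and `num_training_steps` are assumed to be defined elsewhere:

from torch.optim import AdamW

optimizer = AdamW(
    model.parameters(),
    lr=CPT_CONFIG["learning_rate"],
    weight_decay=CPT_CONFIG["weight_decay"],
)
scheduler = get_cpt_lr_scheduler(
    optimizer,
    num_training_steps=num_training_steps,
    warmup_ratio=CPT_CONFIG["warmup_ratio"],
)
# In the training loop, clip gradients before each optimizer step:
# torch.nn.utils.clip_grad_norm_(model.parameters(), CPT_CONFIG["max_grad_norm"])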
4.2 Checkpointing Strategy¶
import math
import os
import shutil

class CPTCheckpointer:
"""Continued Pre-training checkpointer"""
def __init__(
self,
model,
save_dir: str,
eval_dataloader,
save_steps: int = 1000,
keep_last_n: int = 3
):
self.model = model
self.save_dir = save_dir
self.eval_dataloader = eval_dataloader
self.save_steps = save_steps
self.keep_last_n = keep_last_n
self.saved_checkpoints = []
self.best_ppl = float('inf')
def maybe_save(self, step: int, loss: float):
"""Conditional save"""
if step % self.save_steps == 0:
# Evaluate
ppl = self._evaluate()
# Save
ckpt_path = f"{self.save_dir}/checkpoint-{step}"
self.model.save_pretrained(ckpt_path)
self.saved_checkpoints.append((step, ppl, ckpt_path))
# Update best
if ppl < self.best_ppl:
self.best_ppl = ppl
best_path = f"{self.save_dir}/best"
self.model.save_pretrained(best_path)
print(f"New best: ppl={ppl:.2f}")
# Delete old checkpoints
self._cleanup_old_checkpoints()
def _evaluate(self) -> float:
"""Perplexity evaluation"""
self.model.eval()
total_loss = 0
total_tokens = 0
with torch.no_grad():
for batch in self.eval_dataloader:
outputs = self.model(**batch)
total_loss += outputs.loss.item() * batch["input_ids"].numel()
total_tokens += batch["input_ids"].numel()
self.model.train()
ppl = math.exp(total_loss / total_tokens)
return ppl
    def _cleanup_old_checkpoints(self):
        """Keep only the `keep_last_n` best checkpoints (by perplexity); delete the rest"""
        if len(self.saved_checkpoints) > self.keep_last_n:
            # Sort by perplexity (ascending: best first)
            sorted_ckpts = sorted(self.saved_checkpoints, key=lambda x: x[1])
to_keep = sorted_ckpts[:self.keep_last_n]
to_remove = set(self.saved_checkpoints) - set(to_keep)
for _, _, path in to_remove:
if os.path.exists(path):
shutil.rmtree(path)
self.saved_checkpoints = list(to_keep)
5. Domain-Specific Examples¶
5.1 Medical Domain¶
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

class MedicalCPT:
"""Medical domain Continued Pre-training"""
def __init__(self, base_model_name: str):
self.model = AutoModelForCausalLM.from_pretrained(base_model_name)
self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
def prepare_medical_data(self, sources: List[str]) -> List[str]:
"""Prepare medical data"""
processed = []
for source in sources:
if source == "pubmed":
# PubMed abstracts
data = self._load_pubmed()
elif source == "clinical_notes":
# Clinical notes (anonymized)
data = self._load_clinical_notes()
elif source == "medical_textbooks":
# Medical textbooks
data = self._load_textbooks()
processed.extend(data)
return processed
def _load_pubmed(self) -> List[str]:
"""Load PubMed data"""
from datasets import load_dataset
dataset = load_dataset("pubmed", split="train")
return [ex["abstract"] for ex in dataset if len(ex["abstract"]) > 100]
def train(self, data: List[str], output_dir: str):
"""Run training"""
# Tokenize
def tokenize(examples):
return self.tokenizer(
examples["text"],
truncation=True,
max_length=2048
)
dataset = Dataset.from_dict({"text": data})
tokenized = dataset.map(tokenize, batched=True)
# Training
training_args = TrainingArguments(
output_dir=output_dir,
per_device_train_batch_size=4,
gradient_accumulation_steps=16,
learning_rate=5e-6, # Low LR
num_train_epochs=1,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
logging_steps=100,
save_steps=500,
fp16=True
)
trainer = Trainer(
model=self.model,
args=training_args,
train_dataset=tokenized,
data_collator=DataCollatorForLanguageModeling(
tokenizer=self.tokenizer, mlm=False
)
)
trainer.train()
5.2 Code Domain¶
class CodeCPT:
"""Code domain Continued Pre-training"""
def prepare_code_data(self) -> List[str]:
"""Prepare code data"""
from datasets import load_dataset
# The Stack
dataset = load_dataset(
"bigcode/the-stack",
data_dir="data/python",
split="train",
streaming=True
)
processed = []
for example in dataset:
code = example["content"]
# Quality filtering
if self._is_quality_code(code):
processed.append(code)
if len(processed) >= 1000000: # 1M samples
break
return processed
def _is_quality_code(self, code: str) -> bool:
"""Code quality check"""
# Length
if len(code) < 50 or len(code) > 100000:
return False
# Comment ratio
lines = code.split("\n")
comment_lines = sum(1 for l in lines if l.strip().startswith("#"))
if len(lines) > 0 and comment_lines / len(lines) > 0.5:
return False
# Syntax check
try:
import ast
ast.parse(code)
return True
except SyntaxError:
return False
Key Summary¶
Continued Pre-training Core¶
1. Purpose: Domain knowledge injection
2. Data: Large amounts of domain text (10B+ tokens)
3. Method: Causal LM objective
4. Caution: Catastrophic forgetting
Forgetting Mitigation Strategies¶
1. KL Regularization: Minimize KL with reference model
2. EWC: Preserve important parameters
3. Experience Replay: Mix general data
4. Curriculum: Progressive domain ratio increase
Training Recommendations¶
- Learning Rate: 1/10 to 1/5 of the original pre-training LR
- Warmup: 3-5% of total steps
- Batch Size: Large (256+) for stability
- Epochs: Usually a single pass over the domain data
- Checkpointing: Save frequently, monitor perplexity
References¶
- Gururangan et al. (2020). "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks"
- Ke et al. (2023). "Continual Pre-training of Language Models"
- Ibrahim et al. (2024). "Simple and Scalable Strategies to Continually Pre-train Large Language Models"
- Xie et al. (2023). "Efficient Continual Pre-training for Building Domain Specific Large Language Models"