21. Continued Pre-training
Overview

Continued Pre-training is a method of further training an existing pre-trained model on data from a specific domain or task. Unlike typical fine-tuning, it runs the language modeling objective over large volumes of domain text.
1. Continued Pre-training Overview

1.1 Why Is It Needed?

Scenario:
┌───────────────────────────────────────────────┐
│ Base Model (LLaMA-7B)                         │
│  - Trained on: general web text               │
│  - Strength: general language understanding   │
│  - Weakness: lacks domain-specific knowledge  │
│                                               │
│ Target domain: medical                        │
│  - Specialized terminology (drug names,       │
│    disease names)                             │
│  - Domain-specific reasoning                  │
│  - Specialized document formats               │
└───────────────────────────────────────────────┘
Solution:
1. Instruction tuning alone injects new knowledge poorly
2. Continued Pre-training teaches the domain knowledge
3. Instruction tuning afterwards adapts the model to the task
1.2 Training Pipeline

A typical pipeline:
┌───────────────────────────────────────────┐
│  Pre-trained Model                        │
│        ↓                                  │
│  [Continued Pre-training]                 │
│   - Domain data (10B+ tokens)             │
│   - Causal LM objective                   │
│   - Lower learning rate                   │
│        ↓                                  │
│  Domain-Adapted Model                     │
│        ↓                                  │
│  [Instruction Tuning]                     │
│   - Domain-specific instructions          │
│        ↓                                  │
│  Final Domain Model                       │
└───────────────────────────────────────────┘
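The "Causal LM objective" in the middle stage is simply the average next-token negative log-likelihood over the domain corpus. A toy numeric illustration (the probabilities below are made up, not from a real model):

```python
import math

# Hypothetical next-token probabilities p(token_t | tokens_<t) for a
# 3-token sequence; in practice these come from the model's softmax.
probs = [0.5, 0.25, 0.125]

# Causal LM loss = mean negative log-likelihood of the observed tokens
nll = -sum(math.log(p) for p in probs) / len(probs)

# Perplexity is exp(loss), ~4.0 for these probabilities
ppl = math.exp(nll)
```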
2. Catastrophic Forgetting

2.1 Problem Definition

Catastrophic forgetting:
the phenomenon where a model loses previously acquired knowledge while learning new knowledge

Example:
┌──────────────────────────────────────────┐
│ Before CPT:                              │
│  Q: "What is the capital of France?"     │
│  A: "Paris" ✓                            │
│                                          │
│ After CPT (medical domain):              │
│  Q: "What is the capital of France?"     │
│  A: "The patient presented with..." ✗    │
└──────────────────────────────────────────┘
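Forgetting can be quantified rather than eyeballed: track perplexity on a small general-domain eval set before and after CPT, and treat a rising value as drift. A minimal sketch (the per-token NLL values are illustrative, not measured):

```python
import math

def perplexity(token_nlls):
    """Corpus perplexity from per-token negative log-likelihoods."""
    return math.exp(sum(token_nlls) / len(token_nlls))

# Illustrative per-token NLLs on a general-domain eval set
general_nll_before_cpt = [2.1, 1.8, 2.4, 2.0]
general_nll_after_cpt = [3.0, 2.9, 3.3, 3.1]

# Rising general-domain perplexity signals forgetting
assert perplexity(general_nll_after_cpt) > perplexity(general_nll_before_cpt)
```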
2.2 Mitigation Strategies

```python
import torch
import torch.nn as nn
from typing import Dict, List, Optional


class ContinuedPretrainingWithRegularization:
    """Training with regularization to mitigate catastrophic forgetting."""

    def __init__(
        self,
        model: nn.Module,
        reference_model: nn.Module,  # frozen copy of the original model
        reg_weight: float = 0.1
    ):
        self.model = model
        self.reference_model = reference_model
        self.reg_weight = reg_weight

        # Freeze the reference model
        for param in self.reference_model.parameters():
            param.requires_grad = False

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        regularization: str = "kl"
    ) -> torch.Tensor:
        """
        Regularization methods:
        - "kl": KL divergence against the reference model
        - "ewc": Elastic Weight Consolidation
        - "replay": experience replay (implemented separately)
        """
        # Main language-modeling loss
        outputs = self.model(input_ids, labels=labels)
        lm_loss = outputs.loss

        # Regularization term
        if regularization == "kl":
            reg_loss = self._kl_regularization(input_ids)
        elif regularization == "ewc":
            reg_loss = self._ewc_regularization()
        else:
            reg_loss = 0.0

        return lm_loss + self.reg_weight * reg_loss

    def _kl_regularization(self, input_ids: torch.Tensor) -> torch.Tensor:
        """KL-divergence-based regularization."""
        with torch.no_grad():
            ref_logits = self.reference_model(input_ids).logits

        current_logits = self.model(input_ids).logits

        # F.kl_div(input=log q, target=p) computes KL(p || q),
        # so this penalizes KL(reference || current)
        kl_loss = nn.functional.kl_div(
            nn.functional.log_softmax(current_logits, dim=-1),
            nn.functional.softmax(ref_logits, dim=-1),
            reduction="batchmean"
        )
        return kl_loss

    def _ewc_regularization(self) -> torch.Tensor:
        """
        Elastic Weight Consolidation:
            L_ewc = Σᵢ Fᵢ (θᵢ - θᵢ*)²
        Fᵢ : Fisher information (parameter importance)
        θᵢ*: original parameters
        """
        if not hasattr(self, 'fisher_info'):
            # Fisher information must be precomputed first
            return torch.tensor(0.0, device=next(self.model.parameters()).device)

        ewc_loss = 0.0
        for name, param in self.model.named_parameters():
            if name in self.fisher_info:
                ewc_loss += (
                    self.fisher_info[name] *
                    (param - self.original_params[name]).pow(2)
                ).sum()
        return ewc_loss

    def compute_fisher_information(
        self,
        dataloader,
        num_samples: int = 1000
    ):
        """Estimate the diagonal Fisher information."""
        self.fisher_info = {}
        self.original_params = {}

        # Snapshot the original parameters
        for name, param in self.model.named_parameters():
            self.original_params[name] = param.clone().detach()
            self.fisher_info[name] = torch.zeros_like(param)

        self.model.eval()
        for i, batch in enumerate(dataloader):
            if i >= num_samples:  # num_samples batches
                break
            input_ids = batch["input_ids"]
            outputs = self.model(input_ids)
            log_probs = nn.functional.log_softmax(outputs.logits, dim=-1)

            # Sample labels from the model's own output distribution
            sampled = torch.multinomial(
                log_probs.view(-1, log_probs.size(-1)).exp(), 1
            )

            # Gradients of the sampled log-likelihood
            loss = -log_probs.view(-1, log_probs.size(-1)).gather(1, sampled).mean()
            loss.backward()

            # Accumulate squared gradients (diagonal Fisher estimate)
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    self.fisher_info[name] += param.grad.pow(2)
            self.model.zero_grad()

        # Normalize by the number of sampled batches
        for name in self.fisher_info:
            self.fisher_info[name] /= num_samples
```
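To see what the EWC term does in isolation: parameters with high Fisher values are pulled strongly back toward their original values, while low-Fisher parameters are left almost free to move. A pure-Python toy of the penalty (all numbers invented):

```python
def ewc_penalty(theta, theta_star, fisher, reg_weight=0.1):
    """EWC penalty: reg_weight * sum_i F_i * (theta_i - theta_i*)^2"""
    return reg_weight * sum(
        f * (t - ts) ** 2 for f, t, ts in zip(fisher, theta, theta_star)
    )

theta_star = [0.0, 0.0]   # original parameters
fisher = [10.0, 0.1]      # parameter 0 is "important", parameter 1 is not

# Moving the important parameter is heavily penalized...
assert ewc_penalty([1.0, 0.0], theta_star, fisher) == 1.0
# ...while moving the unimportant one by the same amount is almost free
assert ewc_penalty([0.0, 1.0], theta_star, fisher) < 0.02
```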
2.3 Experience Replay

```python
import torch
import torch.nn as nn
from typing import Dict


class ExperienceReplayTrainer:
    """Prevent forgetting by replaying general-domain data."""

    def __init__(
        self,
        model: nn.Module,
        domain_dataloader,
        general_dataloader,  # general (non-domain) data
        replay_ratio: float = 0.1
    ):
        self.model = model
        self.domain_dataloader = domain_dataloader
        self.general_dataloader = general_dataloader
        self.replay_ratio = replay_ratio
        # Persistent iterators, so we walk through each loader
        # instead of restarting it on every step
        self._domain_iter = iter(domain_dataloader)
        self._general_iter = iter(general_dataloader)

    def _next_domain_batch(self):
        try:
            return next(self._domain_iter)
        except StopIteration:
            self._domain_iter = iter(self.domain_dataloader)
            return next(self._domain_iter)

    def _next_general_batch(self):
        try:
            return next(self._general_iter)
        except StopIteration:
            self._general_iter = iter(self.general_dataloader)
            return next(self._general_iter)

    def train_step(self, optimizer) -> Dict[str, float]:
        """One training step mixing domain and general data."""
        # Domain data
        domain_batch = self._next_domain_batch()
        domain_loss = self._compute_lm_loss(domain_batch)

        # Replay general data with probability replay_ratio
        if torch.rand(1).item() < self.replay_ratio:
            general_batch = self._next_general_batch()
            replay_loss = self._compute_lm_loss(general_batch)
            total_loss = domain_loss + replay_loss
        else:
            replay_loss = torch.tensor(0.0)
            total_loss = domain_loss

        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        return {
            "domain_loss": domain_loss.item(),
            "replay_loss": replay_loss.item()
        }

    def _compute_lm_loss(self, batch) -> torch.Tensor:
        outputs = self.model(
            input_ids=batch["input_ids"],
            labels=batch["labels"]
        )
        return outputs.loss
```
3. Data Preparation

3.1 Collecting Domain Data
```python
import re
from typing import Optional

import torch


class DomainDataPipeline:
    """Preprocessing pipeline for domain data."""

    def __init__(self, domain: str):
        self.domain = domain
        self.quality_filters = []

    def add_filter(self, filter_fn):
        self.quality_filters.append(filter_fn)

    def process_document(self, doc: str) -> Optional[str]:
        """Preprocess a single document."""
        # Basic cleaning
        doc = self._clean_text(doc)

        # Quality filtering
        for filter_fn in self.quality_filters:
            if not filter_fn(doc):
                return None
        return doc

    def _clean_text(self, text: str) -> str:
        """Text cleaning."""
        # Remove HTML tags
        text = re.sub(r'<[^>]+>', '', text)
        # Normalize whitespace
        text = re.sub(r'\s+', ' ', text)

        # Domain-specific cleaning
        if self.domain == "medical":
            # Anonymize patient information
            text = re.sub(r'\b\d{6}-\d{7}\b', '[ID]', text)  # Korean resident ID pattern
        return text.strip()


# Example quality filters
def length_filter(min_len: int = 100, max_len: int = 100000):
    def filter_fn(doc):
        return min_len <= len(doc) <= max_len
    return filter_fn


def language_filter(target_lang: str = "ko"):
    def filter_fn(doc):
        from langdetect import detect
        try:
            return detect(doc) == target_lang
        except Exception:
            return False
    return filter_fn


def perplexity_filter(model, tokenizer, max_ppl: float = 100):
    """Filter out low-quality (high-perplexity) documents."""
    def filter_fn(doc):
        inputs = tokenizer(doc, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        ppl = torch.exp(outputs.loss).item()
        return ppl < max_ppl
    return filter_fn
```
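The medical anonymization rule can be checked in isolation (the sample sentence is invented):

```python
import re

text = "Patient 900101-1234567 was admitted on 2020-01-01."

# Korean resident-ID-like pattern: 6 digits, dash, 7 digits
clean = re.sub(r'\b\d{6}-\d{7}\b', '[ID]', text)
# Dates such as 2020-01-01 do not match the pattern and are left intact
```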
3.2 Data Mixing Strategy

```python
import random
from typing import List


class CurriculumDataMixer:
    """Curriculum-based data mixing."""

    def __init__(
        self,
        domain_data: List[str],
        general_data: List[str],
        total_steps: int
    ):
        self.domain_data = domain_data
        self.general_data = general_data
        self.total_steps = total_steps

    def get_mix_ratio(self, current_step: int) -> float:
        """
        Gradually increase the share of domain data:
          step 0: 50% domain, 50% general
          step T: 90% domain, 10% general
        """
        progress = current_step / self.total_steps
        domain_ratio = 0.5 + 0.4 * progress  # 0.5 → 0.9
        return domain_ratio

    def sample_batch(self, batch_size: int, current_step: int) -> List[str]:
        """Sample a batch matching the current step's mix ratio."""
        domain_ratio = self.get_mix_ratio(current_step)
        num_domain = int(batch_size * domain_ratio)
        num_general = batch_size - num_domain

        batch = (
            random.sample(self.domain_data, min(num_domain, len(self.domain_data))) +
            random.sample(self.general_data, min(num_general, len(self.general_data)))
        )
        random.shuffle(batch)
        return batch
```
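A quick sanity check of the annealing schedule at its endpoints and midpoint (the same linear rule as `get_mix_ratio` above, reduced to a standalone function):

```python
def domain_mix_ratio(step, total_steps):
    """Domain-data fraction, annealed linearly from 0.5 to 0.9."""
    return 0.5 + 0.4 * (step / total_steps)

assert domain_mix_ratio(0, 1000) == 0.5                 # start: 50% domain
assert abs(domain_mix_ratio(500, 1000) - 0.7) < 1e-12   # midpoint: 70%
assert abs(domain_mix_ratio(1000, 1000) - 0.9) < 1e-12  # end: 90%
```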
4. Training Configuration

4.1 Learning Rate Strategy
```python
from transformers import get_scheduler


def get_cpt_lr_scheduler(
    optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.03
):
    """
    LR scheduler for continued pre-training:
    - low initial LR (avoids damaging the base model)
    - long warmup
    - cosine decay
    """
    num_warmup_steps = int(num_training_steps * warmup_ratio)

    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )
    return scheduler


# Recommended hyperparameters
CPT_CONFIG = {
    "learning_rate": 1e-5,             # lower than the base-model pre-training LR
    "warmup_ratio": 0.03,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    "batch_size": 256,                 # large batch for stability
    "gradient_accumulation_steps": 16,
    "num_epochs": 1,                   # one epoch is usually enough
}
```
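For inspection, the warmup-then-cosine shape can be reproduced in a few lines (a minimal sketch, not the transformers implementation; note that HF's "cosine" schedule decays to zero):

```python
import math

def cpt_lr(step, total_steps, base_lr=1e-5, warmup_ratio=0.03):
    """Linear warmup to base_lr, then cosine decay to zero."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup: scale linearly from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Cosine decay over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With `total_steps=1000` this warms up over the first 30 steps, peaks at 1e-5, and decays to zero by the final step.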
4.2 Checkpointing Strategy

```python
import math
import os
import shutil

import torch


class CPTCheckpointer:
    """Checkpointing for continued pre-training."""

    def __init__(
        self,
        model,
        save_dir: str,
        eval_dataloader,
        save_steps: int = 1000,
        keep_last_n: int = 3
    ):
        self.model = model
        self.save_dir = save_dir
        self.eval_dataloader = eval_dataloader
        self.save_steps = save_steps
        self.keep_last_n = keep_last_n
        self.saved_checkpoints = []
        self.best_ppl = float('inf')

    def maybe_save(self, step: int, loss: float):
        """Save conditionally, every save_steps steps."""
        if step % self.save_steps == 0:
            # Evaluate
            ppl = self._evaluate()

            # Save
            ckpt_path = f"{self.save_dir}/checkpoint-{step}"
            self.model.save_pretrained(ckpt_path)
            self.saved_checkpoints.append((step, ppl, ckpt_path))

            # Update the best checkpoint
            if ppl < self.best_ppl:
                self.best_ppl = ppl
                best_path = f"{self.save_dir}/best"
                self.model.save_pretrained(best_path)
                print(f"New best: ppl={ppl:.2f}")

            # Remove stale checkpoints
            self._cleanup_old_checkpoints()

    def _evaluate(self) -> float:
        """Token-weighted perplexity on the eval set."""
        self.model.eval()
        total_loss = 0
        total_tokens = 0

        with torch.no_grad():
            for batch in self.eval_dataloader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item() * batch["input_ids"].numel()
                total_tokens += batch["input_ids"].numel()

        self.model.train()
        ppl = math.exp(total_loss / total_tokens)
        return ppl

    def _cleanup_old_checkpoints(self):
        """Keep only the keep_last_n checkpoints with the best perplexity."""
        if len(self.saved_checkpoints) > self.keep_last_n:
            # Sort by perplexity (ascending)
            sorted_ckpts = sorted(self.saved_checkpoints, key=lambda x: x[1])
            to_keep = sorted_ckpts[:self.keep_last_n]
            to_remove = set(self.saved_checkpoints) - set(to_keep)

            for _, _, path in to_remove:
                if os.path.exists(path):
                    shutil.rmtree(path)

            self.saved_checkpoints = list(to_keep)
```
5. Domain-Specific Examples

5.1 Medical Domain
```python
from typing import List

from datasets import Dataset, load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


class MedicalCPT:
    """Continued pre-training for the medical domain."""

    def __init__(self, base_model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    def prepare_medical_data(self, sources: List[str]) -> List[str]:
        """Assemble medical training data."""
        processed = []
        for source in sources:
            if source == "pubmed":
                # PubMed abstracts
                data = self._load_pubmed()
            elif source == "clinical_notes":
                # Clinical notes (anonymized)
                data = self._load_clinical_notes()
            elif source == "medical_textbooks":
                # Medical textbooks
                data = self._load_textbooks()
            processed.extend(data)
        return processed

    def _load_pubmed(self) -> List[str]:
        """Load PubMed data."""
        dataset = load_dataset("pubmed", split="train")
        return [ex["abstract"] for ex in dataset if len(ex["abstract"]) > 100]

    def train(self, data: List[str], output_dir: str):
        """Run training."""
        # Tokenization
        def tokenize(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=2048
            )

        dataset = Dataset.from_dict({"text": data})
        tokenized = dataset.map(tokenize, batched=True)

        # Training
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=16,
            learning_rate=5e-6,        # low LR
            num_train_epochs=1,
            warmup_ratio=0.05,
            lr_scheduler_type="cosine",
            logging_steps=100,
            save_steps=500,
            fp16=True
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized,
            data_collator=DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer, mlm=False
            )
        )
        trainer.train()
```
5.2 Code Domain

```python
import ast
from typing import List

from datasets import load_dataset


class CodeCPT:
    """Continued pre-training for the code domain."""

    def prepare_code_data(self) -> List[str]:
        """Prepare code training data."""
        # The Stack
        dataset = load_dataset(
            "bigcode/the-stack",
            data_dir="data/python",
            split="train",
            streaming=True
        )

        processed = []
        for example in dataset:
            code = example["content"]
            # Quality filtering
            if self._is_quality_code(code):
                processed.append(code)
            if len(processed) >= 1000000:  # 1M samples
                break
        return processed

    def _is_quality_code(self, code: str) -> bool:
        """Code quality checks."""
        # Length
        if len(code) < 50 or len(code) > 100000:
            return False

        # Comment ratio
        lines = code.split("\n")
        comment_lines = sum(1 for l in lines if l.strip().startswith("#"))
        if len(lines) > 0 and comment_lines / len(lines) > 0.5:
            return False

        # Syntax check
        try:
            ast.parse(code)
            return True
        except SyntaxError:
            return False
```
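The `ast`-based syntax gate used above is easy to exercise on its own:

```python
import ast

def is_valid_python(code: str) -> bool:
    """Keep only samples that parse as Python."""
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False

assert is_valid_python("def f():\n    return 1\n")        # parses
assert not is_valid_python("def f(:\n    return 1\n")     # syntax error
```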
Key Takeaways

Continued Pre-training Essentials

1. Goal: inject domain knowledge
2. Data: large volumes of domain text (10B+ tokens)
3. Method: causal LM objective
4. Caveat: catastrophic forgetting

Forgetting Mitigation Strategies

1. KL regularization: minimize KL divergence from a reference model
2. EWC: preserve important parameters
3. Experience replay: mix in general data
4. Curriculum: gradually increase the domain-data share

Training Recommendations

- Learning rate: 1/10 to 1/5 of the base pre-training LR
- Warmup: 3-5%
- Batch size: large (256+)
- Epochs: 1 epoch
- Checkpointing: save often, monitor perplexity
References
- Gururangan et al. (2020). "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks"
- Ke et al. (2023). "Continual Pre-training of Language Models"
- Ibrahim et al. (2024). "Simple and Scalable Strategies to Continually Pre-train Large Language Models"
- Xie et al. (2023). "Efficient Continual Pre-training for Building Domain Specific Large Language Models"