21. Continued Pre-training

Overview

Continued pre-training further trains an existing pre-trained model on a specific domain or task. Unlike typical fine-tuning, it performs language modeling over large volumes of domain data.


1. Continued Pre-training Overview

1.1 Why Is It Needed?

Scenario:
┌─────────────────────────────────────────────────────────┐
│  Base Model (LLaMA-7B)                                  │
│  - Trained on: general web text                         │
│  - Strength: general language understanding             │
│  - Weakness: lacks domain-specific knowledge            │
│                                                         │
│  Target domain: medical                                 │
│  - Specialized terminology (drug and disease names)     │
│  - Domain-specific reasoning                            │
│  - Specialized document formats                         │
└─────────────────────────────────────────────────────────┘

Solution:
1. Instruction tuning alone injects little new knowledge
2. Continued pre-training teaches the domain knowledge
3. Instruction tuning afterwards adapts the model to the task

1.2 Training Pipeline

์ผ๋ฐ˜์ ์ธ ํŒŒ์ดํ”„๋ผ์ธ:
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚                                                         โ”‚
โ”‚  Pre-trained Model                                      โ”‚
โ”‚         โ†“                                              โ”‚
โ”‚  [Continued Pre-training]                              โ”‚
โ”‚  - ๋„๋ฉ”์ธ ๋ฐ์ดํ„ฐ (10B+ tokens)                          โ”‚
โ”‚  - Causal LM objective                                  โ”‚
โ”‚  - Lower learning rate                                  โ”‚
โ”‚         โ†“                                              โ”‚
โ”‚  Domain-Adapted Model                                   โ”‚
โ”‚         โ†“                                              โ”‚
โ”‚  [Instruction Tuning]                                  โ”‚
โ”‚  - ๋„๋ฉ”์ธ ํŠนํ™” instructions                            โ”‚
โ”‚         โ†“                                              โ”‚
โ”‚  Final Domain Model                                     โ”‚
โ”‚                                                         โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

2. Catastrophic Forgetting

2.1 Problem Definition

Catastrophic Forgetting:
the phenomenon where a model loses previously learned knowledge while acquiring new knowledge

Example:
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚  Before CPT:                           โ”‚
โ”‚  Q: "What is the capital of France?"   โ”‚
โ”‚  A: "Paris"  โœ“                         โ”‚
โ”‚                                        โ”‚
โ”‚  After CPT (medical domain):           โ”‚
โ”‚  Q: "What is the capital of France?"   โ”‚
โ”‚  A: "The patient presented with..."  โœ— โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜

2.2 Mitigation Strategies

import torch
import torch.nn as nn
from typing import Dict, List, Optional

class ContinuedPretrainingWithRegularization:
    """Catastrophic Forgetting ์™„ํ™”๋ฅผ ์œ„ํ•œ ํ•™์Šต"""

    def __init__(
        self,
        model: nn.Module,
        reference_model: nn.Module,  # Frozen original
        reg_weight: float = 0.1
    ):
        self.model = model
        self.reference_model = reference_model
        self.reg_weight = reg_weight

        # Freeze the reference model
        for param in self.reference_model.parameters():
            param.requires_grad = False

    def compute_loss(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor,
        regularization: str = "kl"
    ) -> torch.Tensor:
        """
        Regularization methods:
        - "kl": KL divergence with reference model
        - "ewc": Elastic Weight Consolidation
        - "replay": Experience replay (๋ณ„๋„ ๊ตฌํ˜„)
        """
        # Main loss
        outputs = self.model(input_ids, labels=labels)
        lm_loss = outputs.loss

        # Regularization
        if regularization == "kl":
            reg_loss = self._kl_regularization(input_ids)
        elif regularization == "ewc":
            reg_loss = self._ewc_regularization()
        else:
            reg_loss = 0.0

        return lm_loss + self.reg_weight * reg_loss

    def _kl_regularization(self, input_ids: torch.Tensor) -> torch.Tensor:
        """KL divergence ๊ธฐ๋ฐ˜ ์ •๊ทœํ™”"""
        with torch.no_grad():
            ref_logits = self.reference_model(input_ids).logits

        current_logits = self.model(input_ids).logits

        # kl_div(log q, p) computes KL(p ‖ q), i.e. KL(reference ‖ current) here
        kl_loss = nn.functional.kl_div(
            nn.functional.log_softmax(current_logits, dim=-1),
            nn.functional.softmax(ref_logits, dim=-1),
            reduction="batchmean"
        )

        return kl_loss

    def _ewc_regularization(self) -> torch.Tensor:
        """
        Elastic Weight Consolidation

        L_ewc = ฮฃแตข Fแตข(ฮธแตข - ฮธแตข*)ยฒ

        Fแตข: Fisher information (importance)
        ฮธแตข*: original parameters
        """
        if not hasattr(self, 'fisher_info'):
            # Fisher information must be computed first via compute_fisher_information
            return torch.tensor(0.0)

        ewc_loss = 0.0
        for name, param in self.model.named_parameters():
            if name in self.fisher_info:
                ewc_loss += (
                    self.fisher_info[name] *
                    (param - self.original_params[name]).pow(2)
                ).sum()

        return ewc_loss

    def compute_fisher_information(
        self,
        dataloader,
        num_samples: int = 1000
    ):
        """Fisher Information ๊ณ„์‚ฐ"""
        self.fisher_info = {}
        self.original_params = {}

        # Save the original parameter values
        for name, param in self.model.named_parameters():
            self.original_params[name] = param.clone().detach()
            self.fisher_info[name] = torch.zeros_like(param)

        self.model.eval()
        for i, batch in enumerate(dataloader):
            if i >= num_samples:  # note: counts batches, not individual samples
                break

            input_ids = batch["input_ids"]
            outputs = self.model(input_ids)
            log_probs = nn.functional.log_softmax(outputs.logits, dim=-1)

            # Sample from output distribution
            sampled = torch.multinomial(
                log_probs.view(-1, log_probs.size(-1)).exp(), 1
            )

            # Compute gradients
            loss = -log_probs.view(-1, log_probs.size(-1)).gather(1, sampled).mean()
            loss.backward()

            # Accumulate squared gradients
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    self.fisher_info[name] += param.grad.pow(2)

            self.model.zero_grad()

        # Normalize
        for name in self.fisher_info:
            self.fisher_info[name] /= num_samples
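Stripped of tensors and batching, the KL penalty above reduces to elementary arithmetic over next-token distributions. A minimal pure-Python sketch for a single position (the 3-token vocabulary and logit values are made up for illustration):

```python
import math

def softmax(logits):
    """Convert logits to a probability distribution."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p ‖ q) = Σᵢ pᵢ · log(pᵢ / qᵢ); zero iff the distributions match."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Frozen reference model's distribution over a tiny 3-token vocabulary
ref_probs = softmax([2.0, 1.0, 0.1])
# Current model after some domain updates: probability mass has drifted
cur_probs = softmax([1.0, 1.0, 1.5])

penalty = kl_divergence(ref_probs, cur_probs)    # grows as the models diverge
identical = kl_divergence(ref_probs, ref_probs)  # 0.0 when nothing has drifted
```

Adding `reg_weight * penalty` to the LM loss therefore pulls the current model's predictions back toward the reference model's, which is exactly the forgetting brake.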

2.3 Experience Replay

class ExperienceReplayTrainer:
    """Prevent forgetting via experience replay"""

    def __init__(
        self,
        model: nn.Module,
        domain_dataloader,
        general_dataloader,  # general-domain data
        replay_ratio: float = 0.1
    ):
        self.model = model
        self.domain_dataloader = domain_dataloader
        self.general_dataloader = general_dataloader
        self.replay_ratio = replay_ratio
        # Persistent iterators: calling next(iter(loader)) inside train_step
        # would restart the loader and return the same first batch every step
        self.domain_iter = iter(domain_dataloader)
        self.general_iter = iter(general_dataloader)

    def _next_batch(self, which: str):
        """Advance the given iterator, restarting its loader when exhausted"""
        try:
            return next(getattr(self, f"{which}_iter"))
        except StopIteration:
            loader = getattr(self, f"{which}_dataloader")
            setattr(self, f"{which}_iter", iter(loader))
            return next(getattr(self, f"{which}_iter"))

    def train_step(self, optimizer) -> Dict[str, float]:
        """One step of mixed domain + general training"""
        # Domain data
        domain_batch = self._next_batch("domain")
        domain_loss = self._compute_lm_loss(domain_batch)

        # Replay (general data)
        if torch.rand(1).item() < self.replay_ratio:
            general_batch = self._next_batch("general")
            replay_loss = self._compute_lm_loss(general_batch)
            total_loss = domain_loss + replay_loss
        else:
            replay_loss = torch.tensor(0.0)
            total_loss = domain_loss

        # Backward
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        return {
            "domain_loss": domain_loss.item(),
            "replay_loss": replay_loss.item() if isinstance(replay_loss, torch.Tensor) else 0.0
        }

    def _compute_lm_loss(self, batch) -> torch.Tensor:
        outputs = self.model(
            input_ids=batch["input_ids"],
            labels=batch["labels"]
        )
        return outputs.loss
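Because the replay gate is stochastic, general-domain batches appear on roughly a `replay_ratio` fraction of steps in expectation. A quick seeded simulation illustrates the mix (`simulate_replay_schedule` is an illustrative helper, not part of the trainer above; the step count is arbitrary):

```python
import random

def simulate_replay_schedule(num_steps: int, replay_ratio: float, seed: int = 0) -> int:
    """Count steps on which the stochastic gate would trigger a replay batch."""
    rng = random.Random(seed)
    return sum(1 for _ in range(num_steps) if rng.random() < replay_ratio)

hits = simulate_replay_schedule(10_000, replay_ratio=0.1)
fraction = hits / 10_000  # close to replay_ratio in expectation
```

Raising `replay_ratio` trades domain-adaptation speed for better retention of general capabilities.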

3. ๋ฐ์ดํ„ฐ ์ค€๋น„

3.1 ๋„๋ฉ”์ธ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘

class DomainDataPipeline:
    """๋„๋ฉ”์ธ ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ"""

    def __init__(self, domain: str):
        self.domain = domain
        self.quality_filters = []

    def add_filter(self, filter_fn):
        self.quality_filters.append(filter_fn)

    def process_document(self, doc: str) -> Optional[str]:
        """Preprocess a single document"""
        # Basic cleaning
        doc = self._clean_text(doc)

        # Quality filtering
        for filter_fn in self.quality_filters:
            if not filter_fn(doc):
                return None

        return doc

    def _clean_text(self, text: str) -> str:
        """Clean raw text"""
        import re

        # Strip HTML tags
        text = re.sub(r'<[^>]+>', '', text)

        # Collapse whitespace
        text = re.sub(r'\s+', ' ', text)

        # Domain-specific cleaning
        if self.domain == "medical":
            # De-identify patient information
            text = re.sub(r'\b\d{6}-\d{7}\b', '[ID]', text)  # Korean resident registration number pattern

        return text.strip()


# Example quality filters
def length_filter(min_len: int = 100, max_len: int = 100000):
    def filter_fn(doc):
        return min_len <= len(doc) <= max_len
    return filter_fn

def language_filter(target_lang: str = "ko"):
    def filter_fn(doc):
        from langdetect import detect
        try:
            return detect(doc) == target_lang
        except Exception:
            return False
    return filter_fn

def perplexity_filter(model, tokenizer, max_ppl: float = 100):
    """ํ’ˆ์งˆ์ด ๋‚ฎ์€ (perplexity ๋†’์€) ๋ฌธ์„œ ํ•„ํ„ฐ๋ง"""
    def filter_fn(doc):
        inputs = tokenizer(doc, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
        ppl = torch.exp(outputs.loss).item()
        return ppl < max_ppl
    return filter_fn
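These filters compose naturally: a document survives only if every filter accepts it, mirroring process_document above. A self-contained sketch (`keyword_filter` and `apply_filters` are illustrative helpers; the sample documents and thresholds are made up):

```python
# Standalone copy of length_filter from above, so this sketch runs on its own
def length_filter(min_len: int = 20, max_len: int = 1000):
    def filter_fn(doc):
        return min_len <= len(doc) <= max_len
    return filter_fn

def keyword_filter(required: str):
    """Toy domain filter: keep documents mentioning a required term."""
    def filter_fn(doc):
        return required.lower() in doc.lower()
    return filter_fn

def apply_filters(docs, filters):
    """Keep only documents that pass every filter."""
    return [doc for doc in docs if all(f(doc) for f in filters)]

docs = [
    "Aspirin is indicated for mild to moderate pain relief in adults.",
    "too short",
    "A long general-news paragraph with no relevant medical terminology at all.",
]
filters = [length_filter(min_len=20), keyword_filter("aspirin")]
kept = apply_filters(docs, filters)  # only the first document survives
```

Ordering filters from cheapest (length) to most expensive (perplexity) keeps preprocessing fast, since a document is dropped at the first failing check.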

3.2 ๋ฐ์ดํ„ฐ ๋ฏน์‹ฑ ์ „๋žต

import random

class CurriculumDataMixer:
    """Curriculum-learning-based data mixing"""

    def __init__(
        self,
        domain_data: List[str],
        general_data: List[str],
        total_steps: int
    ):
        self.domain_data = domain_data
        self.general_data = general_data
        self.total_steps = total_steps

    def get_mix_ratio(self, current_step: int) -> float:
        """
        ์ ์ง„์ ์œผ๋กœ ๋„๋ฉ”์ธ ๋ฐ์ดํ„ฐ ๋น„์œจ ์ฆ๊ฐ€

        Step 0: 50% domain, 50% general
        Step T: 90% domain, 10% general
        """
        progress = current_step / self.total_steps
        domain_ratio = 0.5 + 0.4 * progress  # 0.5 โ†’ 0.9
        return domain_ratio

    def sample_batch(self, batch_size: int, current_step: int) -> List[str]:
        """ํ˜„์žฌ step์— ๋งž๋Š” ๋ฐฐ์น˜ ์ƒ˜ํ”Œ๋ง"""
        domain_ratio = self.get_mix_ratio(current_step)
        num_domain = int(batch_size * domain_ratio)
        num_general = batch_size - num_domain

        batch = (
            random.sample(self.domain_data, min(num_domain, len(self.domain_data))) +
            random.sample(self.general_data, min(num_general, len(self.general_data)))
        )

        random.shuffle(batch)
        return batch
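The linear schedule in get_mix_ratio can be sanity-checked by evaluating it at a few points. A standalone copy of the formula (the total step count is arbitrary):

```python
def get_mix_ratio(current_step: int, total_steps: int) -> float:
    """Linear curriculum: domain share ramps from 50% to 90% over training."""
    progress = current_step / total_steps
    return 0.5 + 0.4 * progress  # 0.5 → 0.9

total = 10_000
checkpoints = (0, 2_500, 5_000, 10_000)
schedule = [round(get_mix_ratio(s, total), 2) for s in checkpoints]
# → [0.5, 0.6, 0.7, 0.9]
```

Starting at an even split eases the transition away from the base model's data distribution before the domain data dominates.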

4. Training Configuration

4.1 Learning Rate Strategy

from transformers import get_scheduler

def get_cpt_lr_scheduler(
    optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.03,
    min_lr_ratio: float = 0.1  # note: the plain "cosine" schedule below decays to 0
):
    """
    LR scheduler for continued pre-training

    - Low initial LR (avoids damaging the base model)
    - Long warmup
    - Cosine decay (enforcing a floor of min_lr_ratio requires a custom schedule)
    """
    num_warmup_steps = int(num_training_steps * warmup_ratio)

    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    return scheduler


# Recommended hyperparameters
CPT_CONFIG = {
    "learning_rate": 1e-5,  # lower than the base model's pre-training LR
    "warmup_ratio": 0.03,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    "batch_size": 256,  # large batch for stability
    "gradient_accumulation_steps": 16,
    "num_epochs": 1,  # one epoch is usually enough
}
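Two quantities follow directly from this config: the warmup step count and the effective batch size. A small sketch, assuming batch_size is the per-device micro-batch that gradient accumulation multiplies up (the total step count is hypothetical):

```python
config = {
    "warmup_ratio": 0.03,
    "batch_size": 256,
    "gradient_accumulation_steps": 16,
}

num_training_steps = 50_000  # hypothetical run length

# Warmup steps derived from the warmup ratio
warmup_steps = int(num_training_steps * config["warmup_ratio"])  # → 1500

# Sequences consumed per optimizer step under gradient accumulation
effective_batch = config["batch_size"] * config["gradient_accumulation_steps"]  # → 4096
```

Whether "batch_size" is global or per-device varies by training framework, so check the convention of your trainer before scaling it with accumulation.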

4.2 ์ฒดํฌํฌ์ธํŒ… ์ „๋žต

import math
import os
import shutil

class CPTCheckpointer:
    """Checkpointer for continued pre-training"""

    def __init__(
        self,
        model,
        save_dir: str,
        eval_dataloader,
        save_steps: int = 1000,
        keep_last_n: int = 3
    ):
        self.model = model
        self.save_dir = save_dir
        self.eval_dataloader = eval_dataloader
        self.save_steps = save_steps
        self.keep_last_n = keep_last_n
        self.saved_checkpoints = []
        self.best_ppl = float('inf')

    def maybe_save(self, step: int, loss: float):
        """Save conditionally, keyed on the step count"""
        if step % self.save_steps == 0:
            # Evaluate
            ppl = self._evaluate()

            # Save
            ckpt_path = f"{self.save_dir}/checkpoint-{step}"
            self.model.save_pretrained(ckpt_path)
            self.saved_checkpoints.append((step, ppl, ckpt_path))

            # Update best
            if ppl < self.best_ppl:
                self.best_ppl = ppl
                best_path = f"{self.save_dir}/best"
                self.model.save_pretrained(best_path)
                print(f"New best: ppl={ppl:.2f}")

            # Prune surplus checkpoints
            self._cleanup_old_checkpoints()

    def _evaluate(self) -> float:
        """Perplexity ํ‰๊ฐ€"""
        self.model.eval()
        total_loss = 0
        total_tokens = 0

        with torch.no_grad():
            for batch in self.eval_dataloader:
                outputs = self.model(**batch)
                total_loss += outputs.loss.item() * batch["input_ids"].numel()
                total_tokens += batch["input_ids"].numel()

        self.model.train()
        ppl = math.exp(total_loss / total_tokens)
        return ppl

    def _cleanup_old_checkpoints(self):
        """Keep only the keep_last_n best checkpoints (ranked by perplexity)"""
        if len(self.saved_checkpoints) > self.keep_last_n:
            # Sort by perplexity (ascending)
            sorted_ckpts = sorted(self.saved_checkpoints, key=lambda x: x[1])
            to_keep = sorted_ckpts[:self.keep_last_n]
            to_remove = set(self.saved_checkpoints) - set(to_keep)

            for _, _, path in to_remove:
                if os.path.exists(path):
                    shutil.rmtree(path)

            self.saved_checkpoints = list(to_keep)

5. ๋„๋ฉ”์ธ๋ณ„ ์˜ˆ์‹œ

5.1 Medical Domain

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments
)

class MedicalCPT:
    """Continued pre-training for the medical domain"""

    def __init__(self, base_model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    def prepare_medical_data(self, sources: List[str]) -> List[str]:
        """Prepare medical training data"""
        processed = []

        for source in sources:
            if source == "pubmed":
                # PubMed abstracts
                data = self._load_pubmed()
            elif source == "clinical_notes":
                # Clinical notes (de-identified)
                data = self._load_clinical_notes()
            elif source == "medical_textbooks":
                # Medical textbooks
                data = self._load_textbooks()
            else:
                continue  # skip unknown sources

            processed.extend(data)

        return processed

    def _load_pubmed(self) -> List[str]:
        """PubMed ๋ฐ์ดํ„ฐ ๋กœ๋“œ"""
        from datasets import load_dataset

        dataset = load_dataset("pubmed", split="train")
        return [ex["abstract"] for ex in dataset if len(ex["abstract"]) > 100]

    def train(self, data: List[str], output_dir: str):
        """Run training"""
        # Tokenization
        def tokenize(examples):
            return self.tokenizer(
                examples["text"],
                truncation=True,
                max_length=2048
            )

        dataset = Dataset.from_dict({"text": data})
        tokenized = dataset.map(tokenize, batched=True)

        # Training
        training_args = TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=16,
            learning_rate=5e-6,  # low LR
            num_train_epochs=1,
            warmup_ratio=0.05,
            lr_scheduler_type="cosine",
            logging_steps=100,
            save_steps=500,
            fp16=True
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized,
            data_collator=DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer, mlm=False
            )
        )

        trainer.train()

5.2 Code Domain

class CodeCPT:
    """์ฝ”๋“œ ๋„๋ฉ”์ธ Continued Pre-training"""

    def prepare_code_data(self) -> List[str]:
        """์ฝ”๋“œ ๋ฐ์ดํ„ฐ ์ค€๋น„"""
        from datasets import load_dataset

        # The Stack
        dataset = load_dataset(
            "bigcode/the-stack",
            data_dir="data/python",
            split="train",
            streaming=True
        )

        processed = []
        for example in dataset:
            code = example["content"]

            # Quality filtering
            if self._is_quality_code(code):
                processed.append(code)

            if len(processed) >= 1000000:  # 1M samples
                break

        return processed

    def _is_quality_code(self, code: str) -> bool:
        """์ฝ”๋“œ ํ’ˆ์งˆ ๊ฒ€์‚ฌ"""
        # ๊ธธ์ด
        if len(code) < 50 or len(code) > 100000:
            return False

        # ์ฃผ์„ ๋น„์œจ
        lines = code.split("\n")
        comment_lines = sum(1 for l in lines if l.strip().startswith("#"))
        if len(lines) > 0 and comment_lines / len(lines) > 0.5:
            return False

        # ๊ตฌ๋ฌธ ๊ฒ€์‚ฌ
        try:
            import ast
            ast.parse(code)
            return True
        except SyntaxError:
            return False
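The checks in _is_quality_code are plain Python and can be exercised on their own. A standalone copy for illustration (thresholds as above; the test snippets are made up):

```python
import ast

def is_quality_code(code: str, min_len: int = 50, max_len: int = 100_000) -> bool:
    """Standalone version of the quality checks above."""
    # Length bounds
    if len(code) < min_len or len(code) > max_len:
        return False

    # Reject files that are mostly comments
    lines = code.split("\n")
    comment_lines = sum(1 for l in lines if l.strip().startswith("#"))
    if lines and comment_lines / len(lines) > 0.5:
        return False

    # Must parse as valid Python
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False

good = "def add(a, b):\n    # sum two numbers\n    return a + b\n" * 2
broken = "def add(a, b:\n    return a + b\n" + "x" * 40  # missing closing paren
ok = is_quality_code(good)     # valid syntax, reasonable comment ratio
bad = is_quality_code(broken)  # fails ast.parse, so it is rejected
```

Parsing with ast rejects truncated or corrupted files cheaply, before any tokenization cost is paid.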

Key Takeaways

Continued Pre-training Essentials

1. Goal: inject domain knowledge
2. Data: large volumes of domain text (10B+ tokens)
3. Method: causal LM objective
4. Caveat: catastrophic forgetting

Forgetting Mitigation Strategies

1. KL regularization: minimize KL divergence against the reference model
2. EWC: protect important parameters
3. Experience replay: mix in general-domain data
4. Curriculum: gradually increase the domain data ratio

Training Recommendations

- Learning rate: 1/10 to 1/5 of the base model's
- Warmup: 3-5% of total steps
- Batch size: large (256+)
- Epochs: 1
- Checkpointing: save frequently, monitor perplexity

์ฐธ๊ณ  ์ž๋ฃŒ

1. Gururangan et al. (2020). "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks"
2. Ke et al. (2023). "Continual Pre-training of Language Models"
3. Ibrahim et al. (2024). "Simple and Scalable Strategies to Continually Pre-train Large Language Models"
4. Xie et al. (2023). "Efficient Continual Pre-training for Building Domain Specific Large Language Models"