05. Data Curation

05. Data Curation

Overview

The performance of Foundation Models heavily depends on data quality and diversity. "Garbage in, garbage out" is more critical than ever. This lesson covers the construction, refinement, and management of large-scale pre-training datasets.


1. Major Pre-training Datasets

1.1 Dataset Overview

┌──────────────────────────────────────────────────────────────────┐
│                Pre-training Dataset Evolution                    │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  2018: BookCorpus + Wikipedia (3.3B tokens) → BERT              │
│         │                                                        │
│  2019: WebText (40GB, Reddit links) → GPT-2                     │
│         │                                                        │
│  2020: C4 (750GB, Common Crawl filtered) → T5                   │
│         │                                                        │
│  2020: The Pile (825GB, 22 sources) → GPT-Neo, Pythia          │
│         │                                                        │
│  2022: ROOTS (1.6TB, 59 languages) → BLOOM                      │
│         │                                                        │
│  2023: RedPajama (1.2T tokens) → RedPajama-INCITE              │
│         │                                                        │
│  2024: FineWeb (15T tokens) → Latest open models               │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘

1.2 Major Dataset Comparison

| Dataset | Size | Source | Features |
|---|---|---|---|
| The Pile | 825GB | 22 diverse sources | Includes code, academic, books |
| C4 | 750GB | Common Crawl | English only, filtered |
| RedPajama | 1.2T tokens | LLaMA recipe replication | Open source |
| ROOTS | 1.6TB | 59 languages | Multilingual, BigScience |
| FineWeb | 15T tokens | Common Crawl | HuggingFace, latest |
| Dolma | 3T tokens | Various sources | Allen AI, transparency focus |

1.3 The Pile Composition

# The Pile's 22 subdatasets (raw size of each component, in GiB)
PILE_COMPONENTS = {
    # Web text
    'Pile-CC': 227.12,      # Filtered Common Crawl
    'OpenWebText2': 62.77,  # Reddit-linked webpages

    # Books and literature
    'Books3': 100.96,       # Books
    'BookCorpus2': 6.30,    # Additional books
    'Gutenberg': 10.88,     # Public domain books

    # Academic
    'PubMed Central': 90.27,   # Medical papers
    'ArXiv': 56.21,            # Scientific papers
    'PubMed Abstracts': 19.26, # Paper abstracts
    'PhilPapers': 2.38,        # Philosophy papers
    'NIH ExPorter': 1.89,      # NIH research info

    # Code
    'Github': 95.16,        # GitHub code
    'StackExchange': 32.20, # Q&A

    # Other
    'Wikipedia (en)': 6.38,
    'FreeLaw': 51.15,       # Legal documents
    'USPTO': 22.90,         # Patents
    'OpenSubtitles': 12.98, # Movie/TV subtitles
    'DM Mathematics': 7.75, # Math problems
    'Ubuntu IRC': 5.52,     # IRC logs
    'EuroParl': 4.59,       # EU parliament
    'HackerNews': 3.90,
    'YoutubeSubtitles': 3.73,
    'Enron Emails': 0.88,
}

# Calculate ratios: the component sizes add up to the dataset's ~825GB total
total = sum(PILE_COMPONENTS.values())
# Show the five largest components with their share of the whole
for name, size in sorted(PILE_COMPONENTS.items(), key=lambda x: -x[1])[:5]:
    print(f"{name}: {size:.1f}GB ({size/total*100:.1f}%)")

2. Data Collection

2.1 Using Common Crawl

import gzip
import json
from warcio.archiveiterator import ArchiveIterator
import requests

class CommonCrawlExtractor:
    """Extract text from Common Crawl

    Looks up WARC file paths in the Common Crawl URL index and pulls the
    main text out of every HTTP response record of a WARC file.
    """

    CC_INDEX_URL = "https://index.commoncrawl.org/CC-MAIN-2024-10-index"

    # (connect, read) timeouts in seconds, so a stalled server cannot hang the job
    REQUEST_TIMEOUT = (10, 60)

    def fetch_warc_paths(self, domain: str, limit: int = 100) -> list[str]:
        """Query WARC file paths for specific domain

        Args:
            domain: Domain name inserted into the ``*.{domain}/*`` URL pattern.
            limit: Maximum number of index entries to request.

        Returns:
            The ``filename`` field of every JSON line in the index response.

        Raises:
            requests.HTTPError: If the index answers with an error status.
        """
        params = {
            'url': f'*.{domain}/*',
            'output': 'json',
            'limit': limit
        }
        response = requests.get(
            self.CC_INDEX_URL,
            params=params,
            timeout=self.REQUEST_TIMEOUT
        )
        # Fail loudly instead of trying to parse an error page as JSON lines
        response.raise_for_status()

        # One JSON object per line; skip blank lines (e.g. an empty result)
        return [
            json.loads(line)['filename']
            for line in response.text.splitlines()
            if line.strip()
        ]

    def extract_text_from_warc(self, warc_url: str) -> list[dict]:
        """Extract text from WARC file

        Args:
            warc_url: WARC path relative to ``https://data.commoncrawl.org/``.

        Returns:
            One dict per response record that yielded text, with the keys
            ``url``, ``text`` and ``timestamp``.

        Raises:
            requests.HTTPError: If the download answers with an error status.
        """
        results = []

        # The context manager releases the streamed connection even on errors
        with requests.get(
            f"https://data.commoncrawl.org/{warc_url}",
            stream=True,
            timeout=self.REQUEST_TIMEOUT
        ) as response:
            response.raise_for_status()

            with gzip.open(response.raw, 'rb') as stream:
                for record in ArchiveIterator(stream):
                    # Only HTTP responses carry page content
                    if record.rec_type != 'response':
                        continue

                    url = record.rec_headers.get_header('WARC-Target-URI')
                    content = record.content_stream().read().decode('utf-8', errors='ignore')

                    # Extract text from HTML (using trafilatura, etc.)
                    text = self.extract_text(content)

                    if text:
                        results.append({
                            'url': url,
                            'text': text,
                            'timestamp': record.rec_headers.get_header('WARC-Date')
                        })

        return results

    def extract_text(self, html: str) -> str:
        """Extract main text from HTML

        Uses trafilatura when it is installed and falls back to a plain
        BeautifulSoup text dump when trafilatura is missing or fails.

        Returns:
            The extracted text, or an empty string when nothing was found.
        """
        try:
            import trafilatura
        except ImportError:
            trafilatura = None

        if trafilatura is not None:
            try:
                # extract() may return None; normalise to '' to honour `-> str`
                return trafilatura.extract(html) or ''
            except Exception:
                # Best effort: fall through to the simpler extractor below
                pass

        from bs4 import BeautifulSoup
        soup = BeautifulSoup(html, 'html.parser')
        # Remove non-content elements before dumping the text
        for tag in soup(['script', 'style', 'nav', 'footer']):
            tag.decompose()
        return soup.get_text(separator=' ', strip=True)

2.2 GitHub Code Collection

import os
from github import Github
from typing import Generator

class GitHubCodeCollector:
    """Collect code from GitHub

    Thin wrapper around a PyGithub client: searches popular repositories
    and walks a repository tree to yield the source files in it.
    """

    # Languages and extensions to collect
    LANGUAGES = {
        'python': ['.py'],
        'javascript': ['.js', '.jsx', '.ts', '.tsx'],
        'java': ['.java'],
        'cpp': ['.cpp', '.hpp', '.c', '.h'],
        'go': ['.go'],
        'rust': ['.rs'],
    }

    def __init__(self, token: str):
        """Create the API client from a GitHub access token."""
        self.github = Github(token)

    def collect_repos(
        self,
        language: str,
        min_stars: int = 100,
        limit: int = 1000
    ) -> Generator[dict, None, None]:
        """Collect popular repositories

        Args:
            language: Value for the ``language:`` search qualifier.
            min_stars: Only repositories with more stars than this match.
            limit: Maximum number of repositories to yield.

        Yields:
            Dicts with the keys ``name``, ``stars``, ``language``,
            ``license`` (license key or None) and ``url``.
        """
        from itertools import islice

        query = f"language:{language} stars:>{min_stars}"
        repos = self.github.search_repositories(query, sort='stars')

        # islice stops right at `limit` without pulling one extra result
        for repo in islice(repos, max(limit, 0)):
            yield {
                'name': repo.full_name,
                'stars': repo.stargazers_count,
                'language': repo.language,
                'license': repo.license.key if repo.license else None,
                'url': repo.html_url
            }

    def extract_code_files(
        self,
        repo_name: str,
        extensions: list[str]
    ) -> Generator[dict, None, None]:
        """Extract code files from repository

        Walks the repository tree breadth-first and yields every file whose
        path ends with one of ``extensions``. Files whose content cannot be
        decoded as UTF-8 are skipped; any other failure while walking is
        printed and ends the iteration.

        Yields:
            Dicts with the keys ``path``, ``content`` and ``size``.
        """
        from collections import deque

        repo = self.github.get_repo(repo_name)
        suffixes = tuple(extensions)

        try:
            # deque gives O(1) pops from the front of the work queue
            pending = deque(repo.get_contents(""))
            while pending:
                entry = pending.popleft()

                if entry.type == "dir":
                    pending.extend(repo.get_contents(entry.path))
                    continue

                if not entry.path.endswith(suffixes):
                    continue

                try:
                    content = entry.decoded_content.decode('utf-8')
                except Exception:
                    # Undecodable or unavailable content: skip this file
                    continue

                # The yield stays outside the inner try so that closing the
                # generator (GeneratorExit) is never swallowed.
                yield {
                    'path': entry.path,
                    'content': content,
                    'size': entry.size
                }
        except Exception as e:
            print(f"Error processing {repo_name}: {e}")

3. Data Cleaning Pipeline

3.1 Quality Filtering

import re
from typing import Optional
import fasttext
from collections import Counter

class QualityFilter:
    """Text quality filtering

    Chains cheap heuristics (length, language, line structure) and returns
    a cleaned copy of the documents that pass all of them.
    """

    # A list line starts with "-", "*", "•" or a number such as "1." / "2)"
    _BULLET_RE = re.compile(r'^\s*(?:[-*•]|\d+[.)])\s')

    def __init__(self, lang_model_path: str = 'lid.176.bin'):
        # FastText language detection model
        self.lang_detector = fasttext.load_model(lang_model_path)

    def filter_document(self, text: str, target_lang: str = 'en') -> Optional[str]:
        """
        Filter document

        Args:
            text: Raw document text.
            target_lang: Language code the document must be detected as.

        Returns:
            Cleaned text or None (filtered out)
        """
        # 1. Basic filter
        if not self._basic_filter(text):
            return None

        # 2. Language filter
        if not self._language_filter(text, target_lang):
            return None

        # 3. Quality score
        if not self._quality_score_filter(text):
            return None

        # 4. Text cleaning
        cleaned = self._clean_text(text)

        # Cleaning can shrink the text, so re-check the minimum length
        return cleaned if len(cleaned) > 100 else None

    def _basic_filter(self, text: str) -> bool:
        """Basic filtering rules (length, word count, word length, letters)"""
        # Min/max length
        if len(text) < 100 or len(text) > 100000:
            return False

        # Word count
        words = text.split()
        if len(words) < 20:
            return False

        # Average word length (too short/long suggests spam)
        avg_word_len = sum(len(w) for w in words) / len(words)
        if avg_word_len < 3 or avg_word_len > 15:
            return False

        # Alphabet ratio
        alpha_chars = sum(c.isalpha() for c in text)
        if alpha_chars / len(text) < 0.6:
            return False

        return True

    def _language_filter(self, text: str, target_lang: str) -> bool:
        """Language filtering: detected language must match with confidence > 0.8"""
        # Detect language from first 500 chars (newlines replaced by spaces)
        sample = text[:500].replace('\n', ' ')
        predictions = self.lang_detector.predict(sample, k=1)

        lang = predictions[0][0].replace('__label__', '')
        confidence = predictions[1][0]

        return lang == target_lang and confidence > 0.8

    def _quality_score_filter(self, text: str) -> bool:
        """Quality score-based filtering

        Ratios are computed over non-blank lines only, so blank lines
        between paragraphs do not lower the scores.
        """
        lines = [line for line in text.split('\n') if line.strip()]
        if not lines:
            return False

        # Line-ending punctuation ratio
        end_punct = sum(1 for line in lines if line.strip()[-1] in '.!?')
        punct_ratio = end_punct / len(lines)

        # Bullet/number list ratio (too high suggests list page)
        bullet_lines = sum(1 for line in lines if self_bullet(line))
        bullet_ratio = bullet_lines / len(lines)

        # Quality score
        if punct_ratio < 0.3:  # Too little punctuation
            return False
        if bullet_ratio > 0.5:  # Too many lists
            return False

        return True

    def _clean_text(self, text: str) -> str:
        """Clean text: drop URLs, mask emails, squeeze whitespace, strip control chars"""
        # Remove URLs
        text = re.sub(r'https?://\S+', '', text)

        # Mask emails
        text = re.sub(r'\S+@\S+\.\S+', '[EMAIL]', text)

        # Clean excessive whitespace
        text = re.sub(r'\n{3,}', '\n\n', text)
        text = re.sub(r' {2,}', ' ', text)

        # Remove control characters
        text = ''.join(c for c in text if c.isprintable() or c in '\n\t')

        return text.strip()


def self_bullet(line: str) -> bool:
    """Return True when ``line`` looks like a bullet or numbered list item."""
    return QualityFilter._BULLET_RE.match(line) is not None

3.2 Deduplication

import hashlib
from datasketch import MinHash, MinHashLSH
from typing import Generator

class DeduplicationPipeline:
    """Large-scale deduplication pipeline

    Stage 1 drops exact duplicates via a hash of the normalized text,
    stage 2 drops near-duplicates via MinHash LSH over word n-grams.
    State is kept on the instance, so documents are deduplicated across
    every call made on the same pipeline.
    """

    def __init__(
        self,
        num_perm: int = 128,
        threshold: float = 0.8,
        ngram_size: int = 5
    ):
        self.num_perm = num_perm
        self.threshold = threshold
        self.ngram_size = ngram_size

        # LSH index
        self.lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
        self.seen_hashes = set()

        # Counter for documents without an 'id'; lives on the instance so
        # generated keys stay unique across deduplicate_stream() calls
        self._auto_id = 0

    def get_minhash(self, text: str) -> "MinHash":
        """Calculate MinHash of text from its lower-cased word n-grams"""
        minhash = MinHash(num_perm=self.num_perm)

        words = text.lower().split()
        if len(words) < self.ngram_size:
            # Too short for a single n-gram: hash the whole text as one
            # shingle, otherwise every short document would share the same
            # empty signature and look like a duplicate of the others.
            shingles = [' '.join(words)]
        else:
            # Generate N-grams
            shingles = (
                ' '.join(words[i:i + self.ngram_size])
                for i in range(len(words) - self.ngram_size + 1)
            )

        for shingle in shingles:
            minhash.update(shingle.encode('utf-8'))

        return minhash

    def exact_dedup(self, text: str) -> bool:
        """
        Exact deduplication (hash-based)

        Case and whitespace differences are ignored.

        Returns:
            True if unique, False if duplicate
        """
        # Hash of normalized text
        normalized = ' '.join(text.lower().split())
        text_hash = hashlib.md5(normalized.encode()).hexdigest()

        if text_hash in self.seen_hashes:
            return False

        self.seen_hashes.add(text_hash)
        return True

    def fuzzy_dedup(self, doc_id: str, text: str) -> bool:
        """
        Fuzzy deduplication (MinHash LSH)

        Args:
            doc_id: Key under which a unique document is stored in the index.
                NOTE(review): datasketch is expected to reject a key that is
                already present — confirm against its documentation.
            text: Document text.

        Returns:
            True if unique, False if near-duplicate found
        """
        minhash = self.get_minhash(text)

        # Search for similar documents
        if self.lsh.query(minhash):
            return False

        # Add new document
        self.lsh.insert(doc_id, minhash)
        return True

    def deduplicate_stream(
        self,
        documents: Generator[dict, None, None]
    ) -> Generator[dict, None, None]:
        """
        Streaming deduplication

        Args:
            documents: Dicts with a 'text' key and an optional 'id' key.

        Yields:
            The documents that are neither exact nor near duplicates of
            anything this pipeline has seen before.
        """
        for doc in documents:
            text = doc['text']

            # Stage 1: Exact duplicates
            if not self.exact_dedup(text):
                continue

            doc_id = doc.get('id')
            if doc_id is None:
                doc_id = f"auto-{self._auto_id}"
                self._auto_id += 1

            # Stage 2: Near-duplicates
            if not self.fuzzy_dedup(doc_id, text):
                continue

            yield doc


# Usage example
def deduplicate_dataset(input_path: str, output_path: str):
    """Deduplicate dataset

    Reads one JSON document per line from ``input_path``, writes the
    documents that survive deduplication to ``output_path`` (same format)
    and prints how many documents were read and kept.
    """
    import json

    pipeline = DeduplicationPipeline(threshold=0.85)

    unique_count = 0
    total_count = 0

    def read_documents(path):
        # Count every document as it is read: the pipeline silently drops
        # duplicates, so the consumer loop below only ever sees unique ones.
        nonlocal total_count
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                total_count += 1
                yield json.loads(line)

    with open(output_path, 'w', encoding='utf-8') as out:
        for doc in pipeline.deduplicate_stream(read_documents(input_path)):
            out.write(json.dumps(doc) + '\n')
            unique_count += 1

    # An empty input has nothing to deduplicate; avoid dividing by zero
    dedup_ratio = (1 - unique_count/total_count)*100 if total_count else 0.0

    print(f"Total: {total_count}, Unique: {unique_count}")
    print(f"Dedup ratio: {dedup_ratio:.1f}%")

4. Data Mixing

4.1 Domain Mixing Strategy

import numpy as np
from dataclasses import dataclass
from typing import Iterator

@dataclass
class DataSource:
    """One input corpus taking part in the mix."""
    name: str
    path: str  # File read line by line as JSON documents
    weight: float  # Sampling weight
    quality_score: float  # Quality score (0-1)

class DataMixer:
    """
    Mix data from various sources

    Strategies:
    1. Quality-based: Sample more from high-quality sources
    2. Diversity-based: Balance all domains
    3. Scaling law-based: Search for optimal ratios
    """

    # LLaMA-style mixing ratios
    LLAMA_MIX = {
        'CommonCrawl': 0.67,    # Web
        'C4': 0.15,             # Filtered web
        'Github': 0.045,        # Code
        'Wikipedia': 0.045,     # Encyclopedia
        'Books': 0.045,         # Books
        'ArXiv': 0.025,         # Scientific
        'StackExchange': 0.02,  # Q&A
    }

    def __init__(self, sources: list[DataSource]):
        """Store the sources and normalize their weights in place."""
        self.sources = sources
        self.normalize_weights()

    def normalize_weights(self):
        """Normalize weights so they sum to 1 (mutates the sources)

        Raises:
            ValueError: If the weights do not add up to a positive number.
        """
        total = sum(s.weight for s in self.sources)
        if total <= 0:
            raise ValueError("source weights must sum to a positive value")
        for source in self.sources:
            source.weight /= total

    @staticmethod
    def _tempered_probs(weights, temperature: float) -> np.ndarray:
        """Raise weights to 1/temperature and renormalize to sum to 1."""
        adjusted = np.power(np.asarray(weights, dtype=float), 1 / temperature)
        return adjusted / adjusted.sum()

    def temperature_sampling(
        self,
        temperature: float = 1.0
    ) -> list[float]:
        """
        Adjust sampling probabilities with temperature

        temperature < 1: Focus on high-frequency sources
        temperature > 1: Distribute more evenly
        """
        weights = [s.weight for s in self.sources]
        return self._tempered_probs(weights, temperature).tolist()

    def sample_batch(
        self,
        batch_size: int,
        temperature: float = 1.0
    ) -> list[tuple[str, int]]:
        """
        Sample batch

        Returns:
            List of (source_name, num_samples)
        """
        probs = self.temperature_sampling(temperature)

        # Number of documents to sample from each source
        samples = np.random.multinomial(batch_size, probs)

        # Convert numpy integers to plain ints to match the annotation
        return [
            (source.name, int(count))
            for source, count in zip(self.sources, samples)
        ]

    def iter_mixed_data(
        self,
        batch_size: int = 1000,
        temperature: float = 1.0
    ) -> Iterator[dict]:
        """Mixed data iterator

        Plans one batch at a time and yields documents source by source.
        A source that runs out is dropped and the remaining ones are
        re-weighted; iteration ends once every source is exhausted.
        """
        active = list(self.sources)
        source_iters = {
            s.name: self._read_source(s.path)
            for s in active
        }
        end = object()  # Sentinel marking an exhausted source

        while active:
            weights = [s.weight for s in active]
            if sum(weights) <= 0:
                # Only zero-weight sources remain; they can never be drawn
                return

            probs = self._tempered_probs(weights, temperature)
            counts = np.random.multinomial(batch_size, probs)

            exhausted = set()
            for source, count in zip(active, counts):
                reader = source_iters[source.name]
                for _ in range(int(count)):
                    doc = next(reader, end)
                    if doc is end:
                        exhausted.add(source.name)
                        break
                    yield doc

            if exhausted:
                active = [s for s in active if s.name not in exhausted]

    @staticmethod
    def _read_source(path: str) -> Iterator[dict]:
        """Read data source: one JSON document per line"""
        import json

        with open(path, 'r') as f:
            for line in f:
                yield json.loads(line)


# Search for optimal mixing ratios
def find_optimal_mix(
    sources: list[DataSource],
    validation_data: list,
    model_fn,
    n_trials: int = 20
) -> dict[str, float]:
    """
    Search for optimal mixing ratios with Bayesian Optimization

    Args:
        sources: Data sources; one weight is searched per ``source.name``.
        validation_data: Passed through unchanged to ``model_fn``.
        model_fn: Called as ``model_fn(weights, validation_data)`` with
            normalized weights; must return the validation loss.
        n_trials: Number of optimization trials.

    Returns:
        The best mixing ratios found, normalized to sum to 1.
    """
    import optuna

    def normalize(raw: dict[str, float]) -> dict[str, float]:
        # Each raw weight is at least 0.01, so the total is never zero
        total = sum(raw.values())
        return {k: v/total for k, v in raw.items()}

    def objective(trial):
        # Sample weight for each source
        raw = {
            source.name: trial.suggest_float(source.name, 0.01, 1.0)
            for source in sources
        }

        # Train model and validate
        # (In practice, use small proxy model)
        return model_fn(normalize(raw), validation_data)

    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=n_trials)

    # best_params holds the raw sampled values; return the same normalized
    # ratios that the objective actually evaluated
    return normalize(study.best_params)

4.2 Multilingual Mixing

class MultilingualMixer:
    """
    Multilingual data mixing

    Goals:
    1. Keep English from dominating the mix
    2. Give low-resource languages a larger share
    3. Group by language similarity
    """

    # Default language ratios (BLOOM style)
    BLOOM_RATIOS = {
        'en': 0.30,  # English
        'zh': 0.15,  # Chinese
        'fr': 0.12,  # French
        'es': 0.10,  # Spanish
        'pt': 0.08,  # Portuguese
        'ar': 0.05,  # Arabic
        # ... other languages
    }

    def __init__(self, language_weights: dict[str, float]):
        self.language_weights = language_weights

    def exponential_smoothing(
        self,
        alpha: float = 0.3
    ) -> dict[str, float]:
        """
        Re-weight languages as P(lang) ∝ P_original(lang)^alpha

        alpha < 1 lifts the share of low-resource languages,
        alpha = 1 leaves the original ratios untouched.

        Returns:
            Language -> smoothed share, normalized to sum to 1.
        """
        powered = {}
        for code, share in self.language_weights.items():
            powered[code] = share ** alpha

        norm = sum(powered.values())
        return {code: value/norm for code, value in powered.items()}

    def sample_by_language(
        self,
        documents: list[dict],
        target_ratio: dict[str, float]
    ) -> list[dict]:
        """Draw documents so each language approaches its target share

        Documents without a 'lang' key count as 'en'. Languages in
        ``target_ratio`` that have no documents are skipped.
        """
        # Bucket the documents by language
        buckets = {}
        for doc in documents:
            code = doc.get('lang', 'en')
            if code not in buckets:
                buckets[code] = []
            buckets[code].append(doc)

        corpus_size = len(documents)
        picked = []

        for lang, ratio in target_ratio.items():
            pool = buckets.get(lang)
            if pool is None:
                continue

            quota = int(corpus_size * ratio)
            # Too few documents -> upsample with replacement,
            # otherwise downsample without replacement
            with_replacement = len(pool) < quota
            picked.extend(np.random.choice(pool, quota, replace=with_replacement))

        return picked

5. Data Quality Evaluation

5.1 Automatic Quality Scoring

import kenlm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

class DataQualityScorer:
    """Automatic data quality evaluation

    Combines up to three signals: KenLM perplexity, a sequence classifier
    and simple character/line heuristics. Both models are optional; when
    one is not configured its score falls back to a neutral value.
    """

    def __init__(
        self,
        perplexity_model_path: str = None,
        classifier_model_name: str = None
    ):
        # 1. Perplexity-based (KenLM)
        if perplexity_model_path:
            self.lm = kenlm.Model(perplexity_model_path)
        else:
            self.lm = None

        # 2. Classifier-based (e.g., Wikipedia vs Web)
        if classifier_model_name:
            self.classifier = AutoModelForSequenceClassification.from_pretrained(
                classifier_model_name
            )
            self.tokenizer = AutoTokenizer.from_pretrained(classifier_model_name)
        else:
            self.classifier = None
            self.tokenizer = None

    def perplexity_score(self, text: str) -> float:
        """
        KenLM perplexity score

        Lower is better (more natural text for language model)

        Returns:
            The perplexity, or 0.0 when no language model is configured.
        """
        if self.lm is None:
            return 0.0

        # Sentence-level perplexity. score() is a log10 probability that,
        # with eos=True, includes the end-of-sentence token, so that token
        # is counted too; this also keeps empty text from dividing by zero.
        log10_prob = self.lm.score(text, bos=True, eos=True)
        n_tokens = len(text.split()) + 1

        return 10 ** (-log10_prob / n_tokens)

    def classifier_score(self, text: str) -> float:
        """
        Quality classifier score (0-1)

        Higher is better quality

        Returns:
            Probability of class index 1, or 0.5 when no classifier is
            configured.
        """
        if self.classifier is None:
            return 0.5

        # Imported here so torch is only required when a classifier is used
        import torch

        inputs = self.tokenizer(
            text[:512],
            return_tensors='pt',
            truncation=True
        )

        with torch.no_grad():
            outputs = self.classifier(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)

        # Positive class probability
        return probs[0, 1].item()

    def heuristic_score(self, text: str) -> dict[str, float]:
        """Heuristic-based quality score

        Returns:
            Dict with the keys ``alpha_ratio``, ``words_per_line``,
            ``unique_lines_ratio``, ``punct_ratio``, ``caps_ratio`` and
            ``digit_ratio``. Empty text yields 0 for every character ratio.
        """
        lines = text.split('\n')
        words = text.split()

        # max(..., 1) keeps empty input from dividing by zero
        n_chars = max(len(text), 1)
        n_lines = max(len(lines), 1)

        return {
            # 1. Alphabet ratio
            'alpha_ratio': sum(c.isalpha() for c in text) / n_chars,

            # 2. Average words per line
            'words_per_line': len(words) / n_lines,

            # 3. Unique lines ratio
            'unique_lines_ratio': len(set(lines)) / n_lines,

            # 4. Punctuation ratio
            'punct_ratio': sum(c in '.,!?;:' for c in text) / n_chars,

            # 5. Uppercase ratio (too high suggests spam)
            'caps_ratio': sum(c.isupper() for c in text) / n_chars,

            # 6. Digit ratio
            'digit_ratio': sum(c.isdigit() for c in text) / n_chars,
        }

    def combined_score(self, text: str) -> float:
        """Combined quality score

        Starts at 1.0 and multiplies in a penalty for every heuristic or
        perplexity threshold the text violates.
        """
        heuristics = self.heuristic_score(text)

        score = 1.0

        # Alphabet ratio: penalize below 0.6
        if heuristics['alpha_ratio'] < 0.6:
            score *= 0.8

        # Uppercase ratio: penalize above 0.3
        if heuristics['caps_ratio'] > 0.3:
            score *= 0.7

        # Unique lines: penalize below 0.5 (heavy repetition)
        if heuristics['unique_lines_ratio'] < 0.5:
            score *= 0.6

        # Perplexity score (lower is better)
        ppl = self.perplexity_score(text)
        if ppl > 1000:
            score *= 0.5
        elif ppl > 500:
            score *= 0.8

        return score

6. Practice: FineWeb-style Pipeline

class FineWebPipeline:
    """
    FineWeb-style data pipeline

    Steps:
    1. URL filtering
    2. Text extraction
    3. Language detection
    4. Quality filtering
    5. Deduplication
    6. PII removal
    """

    # Substrings that mark a URL as unwanted
    _BLOCKED_TERMS = ('porn', 'xxx', 'adult', 'gambling')

    # File extensions that are not HTML pages
    _BLOCKED_EXTENSIONS = ('.pdf', '.jpg', '.png', '.gif')

    def __init__(self):
        self.quality_filter = QualityFilter()
        self.dedup = DeduplicationPipeline()
        self.quality_scorer = DataQualityScorer()

    def process_batch(
        self,
        warc_batch: list[dict]
    ) -> list[dict]:
        """Process batch

        Args:
            warc_batch: Records with the keys 'url' and 'html'.

        Returns:
            Deduplicated dicts with the keys 'url', 'text' and
            'quality_score' for every record that passed all filters.
        """
        results = []

        for record in warc_batch:
            # 1. URL filtering
            if not self._url_filter(record['url']):
                continue

            # 2. Text extraction
            text = self._extract_text(record['html'])
            if not text:
                continue

            # 3. Quality filtering
            text = self.quality_filter.filter_document(text)
            if not text:
                continue

            # 4. Quality score
            score = self.quality_scorer.combined_score(text)
            if score < 0.5:
                continue

            # 5. PII masking
            text = self._mask_pii(text)

            results.append({
                'url': record['url'],
                'text': text,
                'quality_score': score
            })

        # 6. Deduplication
        return list(self.dedup.deduplicate_stream(iter(results)))

    def _url_filter(self, url: str) -> bool:
        """URL-based filtering

        Returns:
            False for URLs containing a blocked term or pointing at a
            blocked file type (case-insensitive), True otherwise.
        """
        from urllib.parse import urlsplit

        lowered = url.lower()

        # Blacklist domains
        if any(term in lowered for term in self._BLOCKED_TERMS):
            return False

        # Blocked extensions: check the path as well as the raw URL, so a
        # query string or fragment after the file name does not hide it
        try:
            path = urlsplit(lowered).path
        except ValueError:
            # Malformed URL: fall back to the raw string only
            path = lowered
        if lowered.endswith(self._BLOCKED_EXTENSIONS) or path.endswith(self._BLOCKED_EXTENSIONS):
            return False

        return True

    def _extract_text(self, html: str) -> str:
        """Extract main text from HTML ('' when nothing is extracted)"""
        import trafilatura
        return trafilatura.extract(html) or ''

    def _mask_pii(self, text: str) -> str:
        """Mask personal information with placeholder tokens"""
        import re

        # Email
        text = re.sub(r'\b[\w.-]+@[\w.-]+\.\w+\b', '[EMAIL]', text)

        # Phone number (US format)
        text = re.sub(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', '[PHONE]', text)

        # IP address
        text = re.sub(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', '[IP]', text)

        # Credit card
        text = re.sub(r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b', '[CARD]', text)

        return text


# Execute
if __name__ == "__main__":
    pipeline = FineWebPipeline()

    # Process Common Crawl batch.
    # Fill in WARC records shaped like {'url': ..., 'html': ...}; an empty
    # list is used as the placeholder so the script runs as written.
    warc_batch = []

    cleaned_data = pipeline.process_batch(warc_batch)

    print(f"Input: {len(warc_batch)}, Output: {len(cleaned_data)}")
    # The ratio is undefined for an empty batch, so only print it otherwise
    if warc_batch:
        print(f"Filtering ratio: {(1 - len(cleaned_data)/len(warc_batch))*100:.1f}%")

References

Datasets

Papers

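  • Gao et al. (2020). "The Pile: An 800GB Dataset of Diverse Text for Language Modeling"
  • Penedo et al. (2023). "The RefinedWeb Dataset for Falcon LLM: Outperforming Curated Corpora with Web Data, and Web Data Only"
  • Soldaini et al. (2024). "Dolma: an Open Corpus of Three Trillion Tokens for Language Model Pretraining Research"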

Tools

to navigate between lessons