# 02_tokenizer.py
  1"""
  2Foundation Models - BPE Tokenizer Implementation
  3
  4Implements Byte Pair Encoding (BPE) from scratch.
  5Demonstrates vocabulary building, merge rules, encoding/decoding.
  6Compares with character-level tokenization.
  7
  8No external dependencies except collections.
  9"""
 10
 11import re
 12from collections import Counter, defaultdict
 13
 14
 15class CharacterTokenizer:
 16    """Simple character-level tokenizer for comparison."""
 17
 18    def __init__(self):
 19        self.vocab = {}
 20        self.inv_vocab = {}
 21
 22    def build_vocab(self, text):
 23        """Build vocabulary from unique characters."""
 24        unique_chars = sorted(set(text))
 25        self.vocab = {char: idx for idx, char in enumerate(unique_chars)}
 26        self.inv_vocab = {idx: char for char, idx in self.vocab.items()}
 27
 28        return len(self.vocab)
 29
 30    def encode(self, text):
 31        """Encode text to list of token IDs."""
 32        return [self.vocab[char] for char in text if char in self.vocab]
 33
 34    def decode(self, tokens):
 35        """Decode token IDs back to text."""
 36        return ''.join([self.inv_vocab[tok] for tok in tokens])
 37
 38
 39class BPETokenizer:
 40    """Byte Pair Encoding tokenizer implementation."""
 41
 42    def __init__(self, vocab_size=300):
 43        self.vocab_size = vocab_size
 44        self.merges = []  # List of merge operations
 45        self.vocab = {}   # Token to ID mapping
 46        self.inv_vocab = {}  # ID to token mapping
 47
 48    def get_stats(self, words):
 49        """
 50        Count frequency of adjacent pairs in word list.
 51
 52        Args:
 53            words: Dictionary of {word: frequency}
 54
 55        Returns:
 56            Counter of pair frequencies
 57        """
 58        pairs = Counter()
 59
 60        for word, freq in words.items():
 61            symbols = word.split()
 62            for i in range(len(symbols) - 1):
 63                pair = (symbols[i], symbols[i + 1])
 64                pairs[pair] += freq
 65
 66        return pairs
 67
 68    def merge_pair(self, pair, words):
 69        """
 70        Merge all occurrences of pair in words.
 71
 72        Args:
 73            pair: Tuple of (token1, token2) to merge
 74            words: Dictionary of {word: frequency}
 75
 76        Returns:
 77            New words dictionary with merged pairs
 78        """
 79        new_words = {}
 80        bigram = ' '.join(pair)
 81        replacement = ''.join(pair)
 82
 83        # Compile pattern for efficiency
 84        pattern = re.escape(bigram)
 85
 86        for word, freq in words.items():
 87            # Replace pair with merged token
 88            new_word = re.sub(pattern, replacement, word)
 89            new_words[new_word] = freq
 90
 91        return new_words
 92
 93    def build_vocab(self, text, verbose=False):
 94        """
 95        Build BPE vocabulary from text.
 96
 97        Args:
 98            text: Training text
 99            verbose: Print merge operations
100
101        Returns:
102            Final vocabulary size
103        """
104        # Initialize with character-level tokens
105        # Each word is space-separated characters + end marker
106        words = defaultdict(int)
107
108        for word in text.split():
109            # Add space between characters and end-of-word marker
110            word_chars = ' '.join(list(word)) + ' </w>'
111            words[word_chars] += 1
112
113        # Get initial vocabulary (unique characters)
114        initial_vocab = set()
115        for word in words.keys():
116            initial_vocab.update(word.split())
117
118        print(f"Initial vocabulary size (characters): {len(initial_vocab)}")
119        print(f"Target vocabulary size: {self.vocab_size}")
120
121        # Iteratively merge most frequent pairs
122        num_merges = self.vocab_size - len(initial_vocab)
123        print(f"Number of merges to perform: {num_merges}\n")
124
125        for i in range(num_merges):
126            # Get pair statistics
127            pairs = self.get_stats(words)
128
129            if not pairs:
130                print(f"No more pairs to merge at iteration {i}")
131                break
132
133            # Get most frequent pair
134            best_pair = max(pairs, key=pairs.get)
135            freq = pairs[best_pair]
136
137            # Merge the pair
138            words = self.merge_pair(best_pair, words)
139            self.merges.append(best_pair)
140
141            if verbose and (i < 10 or i % 50 == 0):
142                print(f"Merge {i+1}: {best_pair[0]} + {best_pair[1]} "
143                      f"= {''.join(best_pair)} (freq={freq})")
144
145        # Build final vocabulary from current state
146        final_vocab = set()
147        for word in words.keys():
148            final_vocab.update(word.split())
149
150        # Create token-to-ID mapping
151        self.vocab = {token: idx for idx, token in enumerate(sorted(final_vocab))}
152        self.inv_vocab = {idx: token for token, idx in self.vocab.items()}
153
154        print(f"\nFinal vocabulary size: {len(self.vocab)}")
155        print(f"Total merges performed: {len(self.merges)}")
156
157        return len(self.vocab)
158
159    def encode_word(self, word):
160        """
161        Encode a single word using learned merges.
162
163        Args:
164            word: String to encode
165
166        Returns:
167            List of tokens
168        """
169        # Start with character-level
170        tokens = list(word) + ['</w>']
171
172        # Apply merges in order
173        for merge in self.merges:
174            i = 0
175            while i < len(tokens) - 1:
176                if (tokens[i], tokens[i + 1]) == merge:
177                    # Merge the pair
178                    tokens = tokens[:i] + [''.join(merge)] + tokens[i + 2:]
179                else:
180                    i += 1
181
182        return tokens
183
184    def encode(self, text):
185        """
186        Encode text to token IDs.
187
188        Args:
189            text: Input text
190
191        Returns:
192            List of token IDs
193        """
194        words = text.split()
195        token_ids = []
196
197        for word in words:
198            tokens = self.encode_word(word)
199            for token in tokens:
200                if token in self.vocab:
201                    token_ids.append(self.vocab[token])
202                else:
203                    # Unknown token - use character fallback
204                    for char in token:
205                        if char in self.vocab:
206                            token_ids.append(self.vocab[char])
207
208        return token_ids
209
210    def decode(self, token_ids):
211        """
212        Decode token IDs back to text.
213
214        Args:
215            token_ids: List of token IDs
216
217        Returns:
218            Decoded text
219        """
220        tokens = [self.inv_vocab[tid] for tid in token_ids if tid in self.inv_vocab]
221        text = ''.join(tokens)
222
223        # Remove end-of-word markers and add spaces
224        text = text.replace('</w>', ' ')
225
226        return text.strip()
227
228    def get_token_stats(self):
229        """Get statistics about learned tokens."""
230        token_lengths = Counter()
231
232        for token in self.vocab.keys():
233            # Don't count special markers
234            if token != '</w>':
235                token_lengths[len(token)] += 1
236
237        return token_lengths
238
239
240# ============================================================
241# Demonstrations
242# ============================================================
243
def demo_character_tokenizer():
    """Demonstrate simple character-level tokenization."""
    print("=" * 60)
    print("DEMO 1: Character-Level Tokenizer")
    print("=" * 60)

    sample = "Hello world! Machine learning is amazing."

    char_tok = CharacterTokenizer()
    n_tokens = char_tok.build_vocab(sample)
    print(f"\nVocabulary size: {n_tokens}")
    print(f"Vocabulary: {sorted(char_tok.vocab.keys())[:20]}")

    # Round-trip the sample through the tokenizer.
    ids = char_tok.encode(sample)
    print(f"\nOriginal text: {sample}")
    print(f"Encoded ({len(ids)} tokens): {ids[:30]}")

    restored = char_tok.decode(ids)
    print(f"Decoded: {restored}")
    print(f"Matches original: {restored == sample}")
267
268
def demo_bpe_basic():
    """Demonstrate basic BPE tokenization on a tiny morphology-rich corpus."""
    print("\n" + "=" * 60)
    print("DEMO 2: BPE Tokenizer - Basic")
    print("=" * 60)

    # Simple training corpus
    text = "low lower lowest higher high highest new newer newest"

    tokenizer = BPETokenizer(vocab_size=50)
    tokenizer.build_vocab(text, verbose=True)

    # Show learned merges
    print("\n" + "-" * 60)
    print("First 10 learned merges:")
    print("-" * 60)
    for i, (a, b) in enumerate(tokenizer.merges[:10]):
        # Fixed: the merged token was fused onto the second operand
        # ("l + olo"); insert the missing " = " so the line matches the
        # verbose output format of build_vocab.
        print(f"{i+1}. {a} + {b} = {''.join([a, b])}")
287
288
def demo_bpe_encoding():
    """Demonstrate BPE encoding and decoding."""
    print("\n" + "=" * 60)
    print("DEMO 3: BPE Encoding/Decoding")
    print("=" * 60)

    # Training corpus
    corpus = """
    the quick brown fox jumps over the lazy dog
    the dog runs fast and the fox runs faster
    machine learning models learn from data
    deep learning uses neural networks
    """

    bpe = BPETokenizer(vocab_size=150)
    bpe.build_vocab(corpus, verbose=False)

    # Sentences to run through the trained tokenizer.
    samples = [
        "the fox runs",
        "machine learning",
        "deep neural networks",
        "the quick dog",
    ]

    print("\n" + "-" * 60)
    print("Encoding examples:")
    print("-" * 60)

    for sample in samples:
        # Subword pieces for display, then the actual ID round-trip.
        pieces = [t for w in sample.split() for t in bpe.encode_word(w)]
        ids = bpe.encode(sample)
        round_trip = bpe.decode(ids)

        print(f"\nSentence: {sample}")
        print(f"Tokens: {pieces}")
        print(f"Token IDs ({len(ids)}): {ids}")
        print(f"Decoded: {round_trip}")
331
332
def demo_compression_comparison():
    """Compare compression between character and BPE tokenization."""
    print("\n" + "=" * 60)
    print("DEMO 4: Compression Comparison")
    print("=" * 60)

    corpus = """
    Natural language processing enables computers to understand human language.
    Machine learning algorithms can learn patterns from data automatically.
    Deep learning models use neural networks with multiple layers.
    Transformers have revolutionized natural language understanding.
    Large language models can generate coherent and contextual text.
    """ * 5  # Repeat for more data

    # Train both tokenizers on the same corpus.
    char_tok = CharacterTokenizer()
    char_tok.build_vocab(corpus)
    char_ids = char_tok.encode(corpus)

    bpe_tok = BPETokenizer(vocab_size=200)
    bpe_tok.build_vocab(corpus, verbose=False)
    bpe_ids = bpe_tok.encode(corpus)

    def report(label, vocab_len, encoded_len):
        # One comparison section per tokenizer.
        print(f"\n{label} tokenizer:")
        print(f"  Vocabulary size: {vocab_len}")
        print(f"  Encoded length: {encoded_len} tokens")
        print(f"  Compression ratio: {len(corpus)/encoded_len:.2f}x")

    print("\n" + "-" * 60)
    print("Comparison:")
    print("-" * 60)
    print(f"Original text length: {len(corpus)} characters")
    report("Character", len(char_tok.vocab), len(char_ids))
    report("BPE", len(bpe_tok.vocab), len(bpe_ids))

    saved_pct = (1 - len(bpe_ids) / len(char_ids)) * 100
    print(f"\nBPE reduces tokens by {saved_pct:.1f}% vs character-level")
373
374
def demo_token_statistics():
    """Analyze learned token statistics."""
    print("\n" + "=" * 60)
    print("DEMO 5: Token Statistics")
    print("=" * 60)

    training_text = """
    Large language models are trained on massive amounts of text data.
    These models learn statistical patterns and relationships in language.
    Tokenization is a crucial preprocessing step for language models.
    BPE allows models to handle unknown words through subword units.
    Common words are represented as single tokens for efficiency.
    Rare words are broken into multiple subword tokens.
    """ * 10

    bpe = BPETokenizer(vocab_size=300)
    bpe.build_vocab(training_text, verbose=False)

    # Histogram of token lengths learned during training.
    length_counts = bpe.get_token_stats()

    print("\n" + "-" * 60)
    print("Token length distribution:")
    print("-" * 60)
    for length, count in sorted(length_counts.items()):
        print(f"Length {length}: {count:3d} tokens {'█' * (count // 5)}")

    # Bucket concrete tokens by their length for a qualitative look.
    print("\n" + "-" * 60)
    print("Example tokens by length:")
    print("-" * 60)

    by_length = defaultdict(list)
    for token in bpe.vocab.keys():
        if token != '</w>':
            by_length[len(token)].append(token)

    for length in sorted(by_length.keys())[:8]:
        print(f"Length {length}: {by_length[length][:10]}")
417
418
def demo_merge_frequency():
    """Analyze merge operation frequencies."""
    print("\n" + "=" * 60)
    print("DEMO 6: Most Important Merges")
    print("=" * 60)

    # Heavily repeated function words force early, high-frequency merges.
    training_text = "the the the and and or if then else while for " * 20

    bpe = BPETokenizer(vocab_size=80)
    bpe.build_vocab(training_text, verbose=False)

    print("\n" + "-" * 60)
    print("Top 20 merges (in order):")
    print("-" * 60)

    for rank, (left, right) in enumerate(bpe.merges[:20], start=1):
        print(f"{rank:2d}. '{left}' + '{right}' → '{left + right}'")
437
438
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Foundation Models: BPE Tokenizer")
    print("=" * 60)

    # Run every demo in presentation order.
    for demo in (
        demo_character_tokenizer,
        demo_bpe_basic,
        demo_bpe_encoding,
        demo_compression_comparison,
        demo_token_statistics,
        demo_merge_frequency,
    ):
        demo()

    print("\n" + "=" * 60)
    print("Key Takeaways:")
    print("=" * 60)
    for takeaway in (
        "1. BPE builds vocabulary by iteratively merging frequent pairs",
        "2. Balances vocabulary size with sequence length",
        "3. Handles unknown words through subword decomposition",
        "4. Common words → single tokens, rare words → multiple tokens",
        "5. Reduces sequence length by 30-50% vs character-level",
    ):
        print(takeaway)
    print("=" * 60)