1"""
2Foundation Models - BPE Tokenizer Implementation
3
4Implements Byte Pair Encoding (BPE) from scratch.
5Demonstrates vocabulary building, merge rules, encoding/decoding.
6Compares with character-level tokenization.
7
8No external dependencies except collections.
9"""
10
11import re
12from collections import Counter, defaultdict
13
14
class CharacterTokenizer:
    """Simple character-level tokenizer for comparison."""

    def __init__(self):
        # Forward (char -> id) and inverse (id -> char) tables,
        # populated by build_vocab().
        self.vocab = {}
        self.inv_vocab = {}

    def build_vocab(self, text):
        """Build vocabulary from unique characters."""
        self.vocab = {}
        self.inv_vocab = {}

        # Sort for a deterministic char -> id assignment.
        for idx, char in enumerate(sorted(set(text))):
            self.vocab[char] = idx
            self.inv_vocab[idx] = char

        return len(self.vocab)

    def encode(self, text):
        """Encode text to list of token IDs."""
        # Characters absent from the vocabulary are silently skipped.
        lookup = self.vocab
        return [lookup[char] for char in text if char in lookup]

    def decode(self, tokens):
        """Decode token IDs back to text."""
        return ''.join(self.inv_vocab[tok] for tok in tokens)
37
38
class BPETokenizer:
    """Byte Pair Encoding (BPE) tokenizer implementation.

    Starts from a character-level vocabulary (plus a '</w>' end-of-word
    marker) and iteratively merges the most frequent adjacent symbol pair
    until the target vocabulary size is reached.
    """

    def __init__(self, vocab_size=300):
        # Target vocabulary size (initial characters + learned merges).
        self.vocab_size = vocab_size
        self.merges = []      # Ordered list of (left, right) merge operations
        self.vocab = {}       # Token to ID mapping
        self.inv_vocab = {}   # ID to token mapping

    def get_stats(self, words):
        """
        Count frequency of adjacent symbol pairs in the word list.

        Args:
            words: Dictionary of {space-separated symbols: frequency}

        Returns:
            Counter mapping (left, right) pairs to their total frequency
        """
        pairs = Counter()

        for word, freq in words.items():
            symbols = word.split()
            # zip pairs each symbol with its right neighbor.
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += freq

        return pairs

    def merge_pair(self, pair, words):
        """
        Merge all occurrences of pair in words.

        Args:
            pair: Tuple of (token1, token2) to merge
            words: Dictionary of {space-separated symbols: frequency}

        Returns:
            New words dictionary with the pair merged into one symbol
        """
        bigram = ' '.join(pair)
        replacement = ''.join(pair)

        # BUGFIX: anchor the bigram on whitespace boundaries. Without the
        # lookarounds, merging ('b', 'c') in the word 'ab c' would match
        # the raw substring 'b c' and wrongly produce 'abc'. This mirrors
        # the pattern used by the reference subword-nmt implementation.
        pattern = re.compile(r'(?<!\S)' + re.escape(bigram) + r'(?!\S)')

        new_words = {}
        for word, freq in words.items():
            new_word = pattern.sub(replacement, word)
            # Accumulate instead of overwrite in case two keys collapse
            # to the same merged form.
            new_words[new_word] = new_words.get(new_word, 0) + freq

        return new_words

    def build_vocab(self, text, verbose=False):
        """
        Build BPE vocabulary from text.

        Args:
            text: Training text
            verbose: Print merge operations

        Returns:
            Final vocabulary size
        """
        # BUGFIX: reset learned state so retraining on a new corpus does
        # not append to merges learned by a previous call (encode_word
        # replays self.merges, so stale entries would corrupt encoding).
        self.merges = []
        self.vocab = {}
        self.inv_vocab = {}

        # Initialize with character-level tokens: each word becomes
        # space-separated characters plus an end-of-word marker.
        words = defaultdict(int)
        for word in text.split():
            word_chars = ' '.join(list(word)) + ' </w>'
            words[word_chars] += 1

        # Initial vocabulary is the set of unique symbols (chars + '</w>').
        initial_vocab = set()
        for word in words.keys():
            initial_vocab.update(word.split())

        print(f"Initial vocabulary size (characters): {len(initial_vocab)}")
        print(f"Target vocabulary size: {self.vocab_size}")

        # Iteratively merge most frequent pairs up to the target size.
        num_merges = self.vocab_size - len(initial_vocab)
        print(f"Number of merges to perform: {num_merges}\n")

        for i in range(num_merges):
            pairs = self.get_stats(words)

            if not pairs:
                # Every word collapsed to a single symbol; nothing left.
                print(f"No more pairs to merge at iteration {i}")
                break

            # Greedily pick the most frequent pair.
            best_pair = max(pairs, key=pairs.get)
            freq = pairs[best_pair]

            words = self.merge_pair(best_pair, words)
            self.merges.append(best_pair)

            if verbose and (i < 10 or i % 50 == 0):
                print(f"Merge {i+1}: {best_pair[0]} + {best_pair[1]} "
                      f"= {''.join(best_pair)} (freq={freq})")

        # Build the final vocabulary from the post-merge symbol set.
        final_vocab = set()
        for word in words.keys():
            final_vocab.update(word.split())

        # Sorted for a deterministic token -> id assignment.
        self.vocab = {token: idx for idx, token in enumerate(sorted(final_vocab))}
        self.inv_vocab = {idx: token for token, idx in self.vocab.items()}

        print(f"\nFinal vocabulary size: {len(self.vocab)}")
        print(f"Total merges performed: {len(self.merges)}")

        return len(self.vocab)

    def encode_word(self, word):
        """
        Encode a single word using learned merges.

        Args:
            word: String to encode

        Returns:
            List of subword tokens (last one ends with '</w>')
        """
        # Start at character level, then replay merges in learned order.
        tokens = list(word) + ['</w>']

        for merge in self.merges:
            i = 0
            while i < len(tokens) - 1:
                if (tokens[i], tokens[i + 1]) == merge:
                    # Collapse the pair; do not advance i so that a new
                    # adjacency created by this merge is also checked.
                    tokens = tokens[:i] + [''.join(merge)] + tokens[i + 2:]
                else:
                    i += 1

        return tokens

    def encode(self, text):
        """
        Encode text to token IDs.

        Args:
            text: Input text

        Returns:
            List of token IDs. Unknown tokens fall back to per-character
            encoding; characters missing from the vocabulary are dropped.
        """
        token_ids = []

        for word in text.split():
            for token in self.encode_word(word):
                if token in self.vocab:
                    token_ids.append(self.vocab[token])
                else:
                    # Unknown token - use character fallback
                    for char in token:
                        if char in self.vocab:
                            token_ids.append(self.vocab[char])

        return token_ids

    def decode(self, token_ids):
        """
        Decode token IDs back to text.

        Args:
            token_ids: List of token IDs

        Returns:
            Decoded text (IDs not in the vocabulary are skipped)
        """
        tokens = [self.inv_vocab[tid] for tid in token_ids if tid in self.inv_vocab]
        text = ''.join(tokens)

        # End-of-word markers become the spaces between words.
        text = text.replace('</w>', ' ')

        return text.strip()

    def get_token_stats(self):
        """Return a Counter mapping token length to token count.

        The '</w>' marker is excluded: it is a special symbol, not
        learned subword content.
        """
        token_lengths = Counter()

        for token in self.vocab.keys():
            if token != '</w>':
                token_lengths[len(token)] += 1

        return token_lengths
238
239
240# ============================================================
241# Demonstrations
242# ============================================================
243
def demo_character_tokenizer():
    """Walk through character-level tokenization end to end."""
    rule = "=" * 60
    print(rule)
    print("DEMO 1: Character-Level Tokenizer")
    print(rule)

    sample = "Hello world! Machine learning is amazing."

    char_tok = CharacterTokenizer()
    size = char_tok.build_vocab(sample)

    print(f"\nVocabulary size: {size}")
    print(f"Vocabulary: {sorted(char_tok.vocab.keys())[:20]}")

    # Encode, then verify the round trip back to the original text.
    ids = char_tok.encode(sample)
    print(f"\nOriginal text: {sample}")
    print(f"Encoded ({len(ids)} tokens): {ids[:30]}")

    roundtrip = char_tok.decode(ids)
    print(f"Decoded: {roundtrip}")
    print(f"Matches original: {roundtrip == sample}")
267
268
def demo_bpe_basic():
    """Train a tiny BPE model and show its earliest merges."""
    rule = "=" * 60
    dash = "-" * 60
    print("\n" + rule)
    print("DEMO 2: BPE Tokenizer - Basic")
    print(rule)

    # Small corpus with shared prefixes/suffixes so merges are visible.
    training_text = "low lower lowest higher high highest new newer newest"

    bpe = BPETokenizer(vocab_size=50)
    bpe.build_vocab(training_text, verbose=True)

    print("\n" + dash)
    print("First 10 learned merges:")
    print(dash)
    for rank, (left, right) in enumerate(bpe.merges[:10], start=1):
        print(f"{rank}. {left} + {right} → {left + right}")
287
288
def demo_bpe_encoding():
    """Round-trip several sentences through a trained BPE tokenizer."""
    rule = "=" * 60
    dash = "-" * 60
    print("\n" + rule)
    print("DEMO 3: BPE Encoding/Decoding")
    print(rule)

    # Training corpus (extra whitespace is harmless: split() drops it).
    corpus = """
    the quick brown fox jumps over the lazy dog
    the dog runs fast and the fox runs faster
    machine learning models learn from data
    deep learning uses neural networks
    """

    bpe = BPETokenizer(vocab_size=150)
    bpe.build_vocab(corpus, verbose=False)

    samples = [
        "the fox runs",
        "machine learning",
        "deep neural networks",
        "the quick dog",
    ]

    print("\n" + dash)
    print("Encoding examples:")
    print(dash)

    for sample in samples:
        # Subword tokens for display, IDs for the actual round trip.
        pieces = [tok for w in sample.split() for tok in bpe.encode_word(w)]
        ids = bpe.encode(sample)
        roundtrip = bpe.decode(ids)

        print(f"\nSentence: {sample}")
        print(f"Tokens: {pieces}")
        print(f"Token IDs ({len(ids)}): {ids}")
        print(f"Decoded: {roundtrip}")
331
332
def demo_compression_comparison():
    """Contrast token counts for character vs BPE tokenization."""
    rule = "=" * 60
    dash = "-" * 60
    print("\n" + rule)
    print("DEMO 4: Compression Comparison")
    print(rule)

    corpus = """
    Natural language processing enables computers to understand human language.
    Machine learning algorithms can learn patterns from data automatically.
    Deep learning models use neural networks with multiple layers.
    Transformers have revolutionized natural language understanding.
    Large language models can generate coherent and contextual text.
    """ * 5  # Repeat for more data

    # Baseline: one token per character.
    char_tokenizer = CharacterTokenizer()
    char_tokenizer.build_vocab(corpus)
    char_ids = char_tokenizer.encode(corpus)

    # BPE: subword tokens learned from the same corpus.
    bpe_tokenizer = BPETokenizer(vocab_size=200)
    bpe_tokenizer.build_vocab(corpus, verbose=False)
    bpe_ids = bpe_tokenizer.encode(corpus)

    print("\n" + dash)
    print("Comparison:")
    print(dash)
    print(f"Original text length: {len(corpus)} characters")
    print(f"\nCharacter tokenizer:")
    print(f"  Vocabulary size: {len(char_tokenizer.vocab)}")
    print(f"  Encoded length: {len(char_ids)} tokens")
    print(f"  Compression ratio: {len(corpus)/len(char_ids):.2f}x")

    print(f"\nBPE tokenizer:")
    print(f"  Vocabulary size: {len(bpe_tokenizer.vocab)}")
    print(f"  Encoded length: {len(bpe_ids)} tokens")
    print(f"  Compression ratio: {len(corpus)/len(bpe_ids):.2f}x")

    saved_pct = (1 - len(bpe_ids) / len(char_ids)) * 100
    print(f"\nBPE reduces tokens by {saved_pct:.1f}% vs character-level")
373
374
def demo_token_statistics():
    """Inspect the length distribution of learned BPE tokens."""
    rule = "=" * 60
    dash = "-" * 60
    print("\n" + rule)
    print("DEMO 5: Token Statistics")
    print(rule)

    corpus = """
    Large language models are trained on massive amounts of text data.
    These models learn statistical patterns and relationships in language.
    Tokenization is a crucial preprocessing step for language models.
    BPE allows models to handle unknown words through subword units.
    Common words are represented as single tokens for efficiency.
    Rare words are broken into multiple subword tokens.
    """ * 10

    bpe = BPETokenizer(vocab_size=300)
    bpe.build_vocab(corpus, verbose=False)

    length_counts = bpe.get_token_stats()

    # Text histogram: one █ per 5 tokens of a given length.
    print("\n" + dash)
    print("Token length distribution:")
    print(dash)
    for size in sorted(length_counts):
        total = length_counts[size]
        histogram = '█' * (total // 5)
        print(f"Length {size}: {total:3d} tokens {histogram}")

    # Bucket vocabulary entries by length for concrete examples.
    print("\n" + dash)
    print("Example tokens by length:")
    print(dash)

    by_length = defaultdict(list)
    for token in bpe.vocab:
        if token != '</w>':
            by_length[len(token)].append(token)

    for size in sorted(by_length)[:8]:
        print(f"Length {size}: {by_length[size][:10]}")
417
418
def demo_merge_frequency():
    """Show the earliest (highest-frequency) merges on a repetitive corpus."""
    rule = "=" * 60
    dash = "-" * 60
    print("\n" + rule)
    print("DEMO 6: Most Important Merges")
    print(rule)

    # Heavy repetition so the first merges are obvious whole words.
    corpus = "the the the and and or if then else while for " * 20

    bpe = BPETokenizer(vocab_size=80)
    bpe.build_vocab(corpus, verbose=False)

    print("\n" + dash)
    print("Top 20 merges (in order):")
    print(dash)

    for rank, (left, right) in enumerate(bpe.merges[:20], start=1):
        print(f"{rank:2d}. '{left}' + '{right}' → '{left + right}'")
437
438
if __name__ == "__main__":
    banner = "=" * 60

    print("\n" + banner)
    print("Foundation Models: BPE Tokenizer")
    print(banner)

    # Run every demo in order.
    for demo in (
        demo_character_tokenizer,
        demo_bpe_basic,
        demo_bpe_encoding,
        demo_compression_comparison,
        demo_token_statistics,
        demo_merge_frequency,
    ):
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    takeaways = [
        "BPE builds vocabulary by iteratively merging frequent pairs",
        "Balances vocabulary size with sequence length",
        "Handles unknown words through subword decomposition",
        "Common words → single tokens, rare words → multiple tokens",
        "Reduces sequence length by 30-50% vs character-level",
    ]
    for num, line in enumerate(takeaways, start=1):
        print(f"{num}. {line}")
    print(banner)