1"""
2Foundation Models - Evaluation Metrics
3
4Implements common evaluation metrics for language models.
5Demonstrates BLEU, ROUGE, perplexity, exact match, and F1 score.
6Shows how to evaluate model outputs on different tasks.
7
8No external dependencies except numpy.
9"""
10
11import re
12import math
13from collections import Counter, defaultdict
14import numpy as np
15
16
def tokenize(text):
    """Lowercase *text*, replace punctuation with spaces, and split on whitespace."""
    cleaned = re.sub(r'[^\w\s]', ' ', text.lower())
    return cleaned.split()
22
23
def ngrams(tokens, n):
    """Return the list of contiguous n-grams of *tokens*, each as a tuple."""
    grams = []
    for start in range(len(tokens) - n + 1):
        grams.append(tuple(tokens[start:start + n]))
    return grams
27
28
def bleu_score(reference, candidate, max_n=4, weights=None):
    """
    Compute BLEU score against a single reference.

    BLEU = BP × exp(sum(w_n × log(p_n)))

    where p_n is the clipped n-gram precision and BP is the brevity penalty.

    Args:
        reference: Reference text (string or list of tokens)
        candidate: Candidate text (string or list of tokens)
        max_n: Maximum n-gram order
        weights: Weights for each n-gram order (default: uniform)

    Returns:
        BLEU score in [0, 1]; 0 whenever any n-gram precision is zero
    """
    ref_tokens = tokenize(reference) if isinstance(reference, str) else reference
    cand_tokens = tokenize(candidate) if isinstance(candidate, str) else candidate

    if weights is None:
        weights = [1.0 / max_n] * max_n

    def gram_counts(tokens, order):
        # Multiset of contiguous n-grams of the given order.
        return Counter(
            tuple(tokens[i:i + order]) for i in range(len(tokens) - order + 1)
        )

    # Clipped n-gram precision for each order 1..max_n.
    precisions = []
    for order in range(1, max_n + 1):
        ref_counts = gram_counts(ref_tokens, order)
        cand_counts = gram_counts(cand_tokens, order)
        total = sum(cand_counts.values())
        if total == 0:
            precisions.append(0)
            continue
        # Each candidate count is clipped at its reference count.
        clipped = sum(min(cnt, ref_counts[gram]) for gram, cnt in cand_counts.items())
        precisions.append(clipped / total)

    # Brevity penalty: only penalize candidates shorter than the reference.
    ref_len = len(ref_tokens)
    cand_len = len(cand_tokens)
    if cand_len > ref_len:
        brevity = 1
    elif cand_len == 0:
        brevity = 0
    else:
        brevity = math.exp(1 - ref_len / cand_len)

    # Weighted geometric mean; any zero precision forces BLEU to 0.
    if any(p <= 0 for p in precisions):
        return 0
    log_mean = sum(w * math.log(p) for w, p in zip(weights, precisions))
    return brevity * math.exp(log_mean)
92
93
def rouge_l_score(reference, candidate):
    """
    Compute ROUGE-L score based on longest common subsequence.

    ROUGE-L reports the precision, recall, and F1 of the LCS between
    the reference and candidate token sequences.

    Args:
        reference: Reference text (string or list of tokens)
        candidate: Candidate text (string or list of tokens)

    Returns:
        Dictionary with 'precision', 'recall', 'f1', and 'lcs_length'
    """
    ref_tokens = tokenize(reference) if isinstance(reference, str) else reference
    cand_tokens = tokenize(candidate) if isinstance(candidate, str) else candidate

    # LCS length via DP, keeping only the previous row (O(len(candidate)) space).
    prev_row = [0] * (len(cand_tokens) + 1)
    for ref_tok in ref_tokens:
        row = [0]
        for j, cand_tok in enumerate(cand_tokens, start=1):
            if ref_tok == cand_tok:
                row.append(prev_row[j - 1] + 1)
            else:
                row.append(max(prev_row[j], row[-1]))
        prev_row = row
    lcs_length = prev_row[-1]

    # Precision over candidate length, recall over reference length.
    precision = lcs_length / len(cand_tokens) if cand_tokens else 0
    recall = lcs_length / len(ref_tokens) if ref_tokens else 0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'lcs_length': lcs_length
    }
147
148
def exact_match(reference, candidate, normalize=True):
    """
    Compute exact match score.

    Args:
        reference: Reference answer
        candidate: Predicted answer
        normalize: Whether to lowercase and strip before comparing

    Returns:
        1 if the (optionally normalized) strings are identical, 0 otherwise
    """
    if normalize:
        return int(reference.lower().strip() == candidate.lower().strip())
    return int(reference == candidate)
166
167
def f1_token_score(reference, candidate):
    """
    Compute token-level F1 score (for span-based QA).

    Uses bag-of-words (multiset) overlap, as in the standard SQuAD
    evaluation: a token repeated in both texts is credited once per
    shared occurrence. The previous set-based version collapsed
    duplicates, over- or under-crediting answers with repeated words.

    Args:
        reference: Reference text
        candidate: Predicted text

    Returns:
        F1 score in [0.0, 1.0]
    """
    ref_counts = Counter(tokenize(reference))
    cand_counts = Counter(tokenize(candidate))

    # Empty prediction can never match anything.
    if not cand_counts:
        return 0.0

    # Multiset intersection: per-token minimum of the two counts.
    num_common = sum((ref_counts & cand_counts).values())

    if num_common == 0:
        return 0.0

    precision = num_common / sum(cand_counts.values())
    recall = num_common / sum(ref_counts.values())

    return 2 * precision * recall / (precision + recall)
196
197
def perplexity(log_probs):
    """
    Compute perplexity from per-token log probabilities.

    PPL = exp(-1/N × sum(log P(w_i)))

    Args:
        log_probs: List of log probabilities for each token

    Returns:
        Perplexity value; infinity for an empty sequence
    """
    n = len(log_probs)
    if n == 0:
        return float('inf')
    return math.exp(-sum(log_probs) / n)
215
216
def classification_metrics(y_true, y_pred):
    """
    Compute binary classification metrics: accuracy, precision, recall, F1.

    Labels are assumed to be 0/1.

    Args:
        y_true: True labels
        y_pred: Predicted labels

    Returns:
        Dictionary with the four metrics and the confusion-matrix cells
    """
    assert len(y_true) == len(y_pred), "Length mismatch"

    # Tally the confusion matrix in one pass over (truth, prediction) pairs.
    pair_counts = Counter(zip(y_true, y_pred))
    tp = pair_counts[(1, 1)]
    fp = pair_counts[(0, 1)]
    tn = pair_counts[(0, 0)]
    fn = pair_counts[(1, 0)]

    total = len(y_true)
    accuracy = (tp + tn) / total if total else 0
    precision = tp / (tp + fp) if tp + fp else 0
    recall = tp / (tp + fn) if tp + fn else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'tn': tn,
        'fn': fn,
    }
254
255
256# ============================================================
257# Demonstrations
258# ============================================================
259
def demo_bleu():
    """Demonstrate BLEU score computation."""
    bar = "=" * 60
    print(bar)
    print("DEMO 1: BLEU Score")
    print(bar)

    reference = "The cat is sitting on the mat"
    # Candidates range from a perfect match down to a poor one.
    candidates = [
        "The cat is sitting on the mat",
        "The cat sits on the mat",
        "A cat is on the mat",
        "There is a cat",
    ]

    print(f"\nReference: {reference}\n")

    for cand in candidates:
        print(f"Candidate: {cand}")
        print(f"BLEU score: {bleu_score(reference, cand, max_n=4):.4f}")
        print()
280
281
def demo_rouge():
    """Demonstrate ROUGE-L score computation."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 2: ROUGE-L Score")
    print(bar)

    reference = "The quick brown fox jumps over the lazy dog"
    candidates = [
        "The quick brown fox jumps over the lazy dog",
        "The brown fox jumps over the dog",
        "A quick fox jumped over a dog",
        "The cat sleeps",
    ]

    print(f"\nReference: {reference}\n")

    for cand in candidates:
        result = rouge_l_score(reference, cand)
        print(f"Candidate: {cand}")
        print(
            f"ROUGE-L: Precision={result['precision']:.3f}, "
            f"Recall={result['recall']:.3f}, F1={result['f1']:.3f}\n"
        )
303
304
def demo_exact_match():
    """Demonstrate exact match evaluation."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 3: Exact Match")
    print(bar)

    # (reference, prediction) pairs covering case, punctuation, and format.
    qa_pairs = (
        ("Paris", "Paris"),
        ("Paris", "paris"),
        ("Paris", "Paris, France"),
        ("1776", "1776"),
        ("1776", "1776.0"),
    )

    print("\nQuestion Answering Evaluation:\n")

    for gold, guess in qa_pairs:
        em = exact_match(gold, guess, normalize=True)
        print(f"Reference: '{gold}' | Prediction: '{guess}' | EM: {em}")
324
325
def demo_f1_token():
    """Demonstrate token F1 score."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 4: Token-level F1 Score")
    print(bar)

    # Pairs with partial token overlap to contrast F1 against EM.
    qa_pairs = (
        ("Barack Obama", "Barack Obama"),
        ("Barack Obama", "Barack Hussein Obama"),
        ("Barack Obama", "Obama"),
        ("New York City", "New York"),
    )

    print("\nSpan-based QA Evaluation:\n")

    for gold, guess in qa_pairs:
        print(f"Reference: '{gold}'")
        print(f"Prediction: '{guess}'")
        print(
            f"EM: {exact_match(gold, guess, normalize=True)} | "
            f"F1: {f1_token_score(gold, guess):.3f}\n"
        )
347
348
def demo_perplexity():
    """Demonstrate perplexity computation."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 5: Perplexity")
    print(bar)

    # (label, per-token log probs): better models assign higher probability,
    # i.e. less negative log probs.
    models = [
        ("Good", [-0.1, -0.2, -0.15, -0.3, -0.1, -0.2]),
        ("Medium", [-1.0, -1.5, -1.2, -0.8, -1.3, -1.1]),
        ("Poor", [-3.0, -4.0, -3.5, -4.2, -3.8, -3.9]),
    ]

    print("\nPerplexity for different model qualities:")
    print("-" * 60)

    for name, log_probs in models:
        mean_lp = sum(log_probs) / len(log_probs)
        print(f"{name} model:")
        print(f" Avg log prob: {mean_lp:.3f}")
        print(f" Avg prob: {math.exp(mean_lp):.3f}")
        print(f" Perplexity: {perplexity(log_probs):.2f}\n")
376
377
def demo_classification():
    """Demonstrate classification metrics."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 6: Classification Metrics")
    print(bar)

    # Simulated binary sentiment labels.
    y_true = [1, 0, 1, 1, 0, 1, 0, 0, 1, 1]
    y_pred = [1, 0, 1, 0, 0, 1, 1, 0, 1, 1]

    m = classification_metrics(y_true, y_pred)

    print("\nBinary Classification Results:")
    print("-" * 60)
    print(f"Accuracy: {m['accuracy']:.3f}")
    print(f"Precision: {m['precision']:.3f}")
    print(f"Recall: {m['recall']:.3f}")
    print(f"F1 Score: {m['f1']:.3f}")

    print("\nConfusion Matrix:")
    print(f" TP: {m['tp']} FP: {m['fp']}")
    print(f" FN: {m['fn']} TN: {m['tn']}")
400
401
def demo_summarization_eval():
    """Evaluate summarization task."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 7: Summarization Evaluation")
    print(bar)

    reference = ("Machine learning is a subset of artificial intelligence. "
                 "It enables systems to learn from data.")

    summaries = [
        "Machine learning is part of AI and allows systems to learn from data.",
        "ML is a type of AI that learns from data.",
        "Artificial intelligence includes machine learning.",
    ]

    print(f"\nReference: {reference}\n")

    for idx, summary in enumerate(summaries, start=1):
        print(f"Summary {idx}: {summary}")
        print(f" BLEU-2: {bleu_score(reference, summary, max_n=2):.3f}")
        print(f" ROUGE-L F1: {rouge_l_score(reference, summary)['f1']:.3f}\n")
425
426
def demo_translation_eval():
    """Evaluate machine translation."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 8: Translation Evaluation")
    print(bar)

    # Example: French to English; hypotheses ordered from perfect to poor.
    reference = "The cat is on the table"
    translations = [
        "The cat is on the table",
        "The cat sits on the table",
        "A cat is on a table",
        "Cat table on",
    ]

    print(f"\nReference: {reference}\n")

    for idx, hyp in enumerate(translations, start=1):
        print(f"Translation {idx}: {hyp}")
        print(f" BLEU-4: {bleu_score(reference, hyp, max_n=4):.3f}")
        print(f" ROUGE-L F1: {rouge_l_score(reference, hyp)['f1']:.3f}\n")
452
453
def demo_multi_reference():
    """Evaluate with multiple references."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 9: Multi-Reference Evaluation")
    print(bar)

    references = [
        "It is raining heavily",
        "The rain is very strong",
        "Heavy rainfall is occurring",
    ]
    candidate = "It is raining a lot"

    print(f"\nCandidate: {candidate}\n")
    print("References:")
    for idx, ref in enumerate(references, start=1):
        print(f" {idx}. {ref}")

    print("\nScores against each reference:")
    print("-" * 60)

    bleu_scores = []
    rouge_scores = []

    for idx, ref in enumerate(references, start=1):
        b = bleu_score(ref, candidate, max_n=2)
        r = rouge_l_score(ref, candidate)['f1']
        bleu_scores.append(b)
        rouge_scores.append(r)
        print(f"Reference {idx}: BLEU={b:.3f}, ROUGE-L F1={r:.3f}")

    # Report the single best match across all references.
    print("\nBest scores:")
    print(f" BLEU: {max(bleu_scores):.3f}")
    print(f" ROUGE-L F1: {max(rouge_scores):.3f}")
492
493
if __name__ == "__main__":
    bar = "=" * 60
    print("\n" + bar)
    print("Foundation Models: Evaluation Metrics")
    print(bar)

    # Run every demo in presentation order.
    for demo in (demo_bleu, demo_rouge, demo_exact_match, demo_f1_token,
                 demo_perplexity, demo_classification, demo_summarization_eval,
                 demo_translation_eval, demo_multi_reference):
        demo()

    print("\n" + bar)
    print("Key Takeaways:")
    print(bar)
    for takeaway in (
        "1. BLEU: n-gram overlap (translation, generation)",
        "2. ROUGE-L: Longest common subsequence (summarization)",
        "3. Exact Match: Binary correctness (QA)",
        "4. Token F1: Partial credit for overlap (span QA)",
        "5. Perplexity: Model uncertainty (lower is better)",
        "6. Classification: Accuracy, precision, recall, F1",
        "7. Multi-reference: Take max/avg across references",
        "8. Choose metrics appropriate for task",
    ):
        print(takeaway)
    print(bar)