07_evaluation.py

Download
python 521 lines 13.9 KB
  1"""
  2Foundation Models - Evaluation Metrics
  3
  4Implements common evaluation metrics for language models.
  5Demonstrates BLEU, ROUGE, perplexity, exact match, and F1 score.
  6Shows how to evaluate model outputs on different tasks.
  7
  8No external dependencies except numpy.
  9"""
 10
 11import re
 12import math
 13from collections import Counter, defaultdict
 14import numpy as np
 15
 16
 17def tokenize(text):
 18    """Simple word tokenization."""
 19    text = text.lower()
 20    text = re.sub(r'[^\w\s]', ' ', text)
 21    return text.split()
 22
 23
 24def ngrams(tokens, n):
 25    """Generate n-grams from token list."""
 26    return [tuple(tokens[i:i+n]) for i in range(len(tokens) - n + 1)]
 27
 28
 29def bleu_score(reference, candidate, max_n=4, weights=None):
 30    """
 31    Compute BLEU score.
 32
 33    BLEU = BP × exp(sum(w_n × log(p_n)))
 34
 35    where p_n is n-gram precision and BP is brevity penalty.
 36
 37    Args:
 38        reference: Reference text (string or list of tokens)
 39        candidate: Candidate text (string or list of tokens)
 40        max_n: Maximum n-gram order
 41        weights: Weights for each n-gram order (default: uniform)
 42
 43    Returns:
 44        BLEU score (0-1)
 45    """
 46    if isinstance(reference, str):
 47        reference = tokenize(reference)
 48    if isinstance(candidate, str):
 49        candidate = tokenize(candidate)
 50
 51    if weights is None:
 52        weights = [1.0 / max_n] * max_n
 53
 54    # Compute n-gram precisions
 55    precisions = []
 56    for n in range(1, max_n + 1):
 57        ref_ngrams = Counter(ngrams(reference, n))
 58        cand_ngrams = Counter(ngrams(candidate, n))
 59
 60        # Clipped count: min(count, ref_count)
 61        clipped_count = 0
 62        total_count = 0
 63
 64        for ng in cand_ngrams:
 65            clipped_count += min(cand_ngrams[ng], ref_ngrams[ng])
 66            total_count += cand_ngrams[ng]
 67
 68        if total_count > 0:
 69            precision = clipped_count / total_count
 70        else:
 71            precision = 0
 72
 73        precisions.append(precision)
 74
 75    # Brevity penalty
 76    ref_len = len(reference)
 77    cand_len = len(candidate)
 78
 79    if cand_len > ref_len:
 80        bp = 1
 81    else:
 82        bp = math.exp(1 - ref_len / cand_len) if cand_len > 0 else 0
 83
 84    # Geometric mean of precisions
 85    if all(p > 0 for p in precisions):
 86        log_precision = sum(w * math.log(p) for w, p in zip(weights, precisions))
 87        bleu = bp * math.exp(log_precision)
 88    else:
 89        bleu = 0
 90
 91    return bleu
 92
 93
 94def rouge_l_score(reference, candidate):
 95    """
 96    Compute ROUGE-L score based on longest common subsequence.
 97
 98    ROUGE-L = F1 score of LCS
 99
100    Args:
101        reference: Reference text
102        candidate: Candidate text
103
104    Returns:
105        Dictionary with precision, recall, and f1
106    """
107    if isinstance(reference, str):
108        reference = tokenize(reference)
109    if isinstance(candidate, str):
110        candidate = tokenize(candidate)
111
112    # Compute LCS length using dynamic programming
113    m, n = len(reference), len(candidate)
114    dp = [[0] * (n + 1) for _ in range(m + 1)]
115
116    for i in range(1, m + 1):
117        for j in range(1, n + 1):
118            if reference[i-1] == candidate[j-1]:
119                dp[i][j] = dp[i-1][j-1] + 1
120            else:
121                dp[i][j] = max(dp[i-1][j], dp[i][j-1])
122
123    lcs_length = dp[m][n]
124
125    # Compute precision, recall, F1
126    if n > 0:
127        precision = lcs_length / n
128    else:
129        precision = 0
130
131    if m > 0:
132        recall = lcs_length / m
133    else:
134        recall = 0
135
136    if precision + recall > 0:
137        f1 = 2 * precision * recall / (precision + recall)
138    else:
139        f1 = 0
140
141    return {
142        'precision': precision,
143        'recall': recall,
144        'f1': f1,
145        'lcs_length': lcs_length
146    }
147
148
def exact_match(reference, candidate, normalize=True):
    """
    Return 1 when the prediction equals the reference, else 0.

    Args:
        reference: Reference answer
        candidate: Predicted answer
        normalize: If true, lowercase both strings and strip surrounding
            whitespace before comparing

    Returns:
        1 if exact match, 0 otherwise
    """
    def _canon(answer):
        return answer.lower().strip() if normalize else answer

    return int(_canon(reference) == _canon(candidate))
166
167
def f1_token_score(reference, candidate):
    """
    Compute token-level F1 score (for span-based QA).

    Tokens are compared as multisets (bag of words). A token repeated in one
    answer only earns credit up to the number of times it occurs in the
    other, so repetition cannot inflate the score.

    Args:
        reference: Reference text
        candidate: Predicted text

    Returns:
        F1 score in [0, 1]. Returns 1.0 if both answers tokenize to nothing,
        and 0.0 if exactly one of them does.
    """
    ref_tokens = tokenize(reference)
    cand_tokens = tokenize(candidate)

    # Empty answers: identical empties are a perfect match; otherwise no overlap.
    if not ref_tokens or not cand_tokens:
        return float(ref_tokens == cand_tokens)

    # Counter '&' keeps the minimum count of each shared token.
    num_common = sum((Counter(ref_tokens) & Counter(cand_tokens)).values())
    if num_common == 0:
        return 0.0

    precision = num_common / len(cand_tokens)
    recall = num_common / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
196
197
def perplexity(log_probs):
    """
    Compute perplexity from per-token log probabilities.

    PPL = exp(-1/N × sum(log P(w_i)))

    Args:
        log_probs: Natural-log probabilities, one per token (a list, tuple,
            or 1-D NumPy array)

    Returns:
        Perplexity value. Returns inf when log_probs is empty, or when the
        mean log probability is so negative that exp() would overflow a float.
    """
    # len() rather than truthiness, so NumPy arrays are accepted
    # ('not array' raises for arrays with more than one element).
    if len(log_probs) == 0:
        return float('inf')

    # fsum avoids accumulated rounding error on long sequences.
    avg_log_prob = math.fsum(log_probs) / len(log_probs)
    try:
        return math.exp(-avg_log_prob)
    except OverflowError:
        # Perplexity is finite only up to float range; beyond that it is
        # effectively infinite.
        return float('inf')
215
216
def classification_metrics(y_true, y_pred, pos_label=1):
    """
    Compute classification metrics: accuracy, precision, recall, F1.

    Precision, recall, F1 and the confusion counts treat pos_label as the
    positive class and every other label as negative (one-vs-rest). Any
    label encoding therefore works: 0/1, -1/+1, booleans, or strings.
    Accuracy is the fraction of positions where the prediction equals the
    true label.

    Args:
        y_true: True labels
        y_pred: Predicted labels (same length as y_true)
        pos_label: Label treated as the positive class (default 1)

    Returns:
        Dictionary with 'accuracy', 'precision', 'recall' and 'f1', plus the
        counts 'tp', 'fp', 'tn' and 'fn'. Ratios with a zero denominator are
        reported as 0.

    Raises:
        ValueError: If y_true and y_pred have different lengths.
    """
    # Explicit check instead of assert: asserts are stripped under python -O.
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"Length mismatch: {len(y_true)} labels vs {len(y_pred)} predictions"
        )

    tp = fp = tn = fn = 0
    correct = 0
    for yt, yp in zip(y_true, y_pred):
        if yt == yp:
            correct += 1
        actual_pos = yt == pos_label
        predicted_pos = yp == pos_label
        if actual_pos and predicted_pos:
            tp += 1
        elif predicted_pos:
            fp += 1
        elif actual_pos:
            fn += 1
        else:
            tn += 1

    total = len(y_true)
    accuracy = correct / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'tn': tn,
        'fn': fn,
    }
254
255
256# ============================================================
257# Demonstrations
258# ============================================================
259
def demo_bleu():
    """Print BLEU-4 for candidates of decreasing quality against one reference."""
    rule = "=" * 60
    print(rule, "DEMO 1: BLEU Score", rule, sep="\n")

    reference = "The cat is sitting on the mat"
    candidates = (
        "The cat is sitting on the mat",  # Perfect match
        "The cat sits on the mat",        # Close
        "A cat is on the mat",            # Moderate
        "There is a cat",                 # Poor
    )

    print(f"\nReference: {reference}\n")

    for candidate in candidates:
        result = bleu_score(reference, candidate, max_n=4)
        print(f"Candidate: {candidate}")
        print(f"BLEU score: {result:.4f}\n")
280
281
def demo_rouge():
    """Print ROUGE-L precision, recall and F1 for candidates against one reference."""
    rule = "=" * 60
    print("\n" + rule, "DEMO 2: ROUGE-L Score", rule, sep="\n")

    reference = "The quick brown fox jumps over the lazy dog"
    candidates = (
        "The quick brown fox jumps over the lazy dog",
        "The brown fox jumps over the dog",
        "A quick fox jumped over a dog",
        "The cat sleeps",
    )

    print(f"\nReference: {reference}\n")

    for candidate in candidates:
        result = rouge_l_score(reference, candidate)
        p, r, f = result['precision'], result['recall'], result['f1']
        print(f"Candidate: {candidate}")
        print(f"ROUGE-L: Precision={p:.3f}, Recall={r:.3f}, F1={f:.3f}\n")
303
304
def demo_exact_match():
    """Print normalized exact-match results for a few QA reference/prediction pairs."""
    rule = "=" * 60
    print("\n" + rule, "DEMO 3: Exact Match", rule, sep="\n")

    qa_pairs = (
        ("Paris", "Paris"),
        ("Paris", "paris"),
        ("Paris", "Paris, France"),
        ("1776", "1776"),
        ("1776", "1776.0"),
    )

    print("\nQuestion Answering Evaluation:\n")

    for reference, prediction in qa_pairs:
        score = exact_match(reference, prediction, normalize=True)
        print(f"Reference: '{reference}' | Prediction: '{prediction}' | EM: {score}")
324
325
def demo_f1_token():
    """Print exact match next to token-level F1 for partially matching QA spans."""
    rule = "=" * 60
    print("\n" + rule, "DEMO 4: Token-level F1 Score", rule, sep="\n")

    qa_pairs = (
        ("Barack Obama", "Barack Obama"),
        ("Barack Obama", "Barack Hussein Obama"),
        ("Barack Obama", "Obama"),
        ("New York City", "New York"),
    )

    print("\nSpan-based QA Evaluation:\n")

    for reference, prediction in qa_pairs:
        f1 = f1_token_score(reference, prediction)
        em = exact_match(reference, prediction, normalize=True)
        print(
            f"Reference: '{reference}'\n"
            f"Prediction: '{prediction}'\n"
            f"EM: {em} | F1: {f1:.3f}\n"
        )
347
348
def demo_perplexity():
    """
    Show how per-token log probabilities map to perplexity.

    Three hand-written lists of natural-log token probabilities stand in for
    good, medium and poor models. For each list, this prints:
    - the mean log probability;
    - the geometric-mean token probability exp(mean log prob), which is
      mathematically 1 / perplexity;
    - the perplexity itself.
    """
    print("\n" + "=" * 60)
    print("DEMO 5: Perplexity")
    print("=" * 60)

    # Simulate log probabilities for different model qualities
    models = [
        # Good model: high probabilities (less negative log probs)
        ("Good", [-0.1, -0.2, -0.15, -0.3, -0.1, -0.2]),
        # Medium model
        ("Medium", [-1.0, -1.5, -1.2, -0.8, -1.3, -1.1]),
        # Poor model: low probabilities (very negative log probs)
        ("Poor", [-3.0, -4.0, -3.5, -4.2, -3.8, -3.9]),
    ]

    print("\nPerplexity for different model qualities:")
    print("-" * 60)

    for name, log_probs in models:
        ppl = perplexity(log_probs)
        avg_log_prob = sum(log_probs) / len(log_probs)
        # exp of the mean log prob is the geometric (not arithmetic) mean
        # of the token probabilities, so label it accordingly.
        geo_mean_prob = math.exp(avg_log_prob)

        print(f"{name} model:")
        print(f"  Avg log prob: {avg_log_prob:.3f}")
        print(f"  Geometric mean prob: {geo_mean_prob:.3f}")
        print(f"  Perplexity: {ppl:.2f}\n")
376
377
def demo_classification():
    """Print binary-classification metrics and confusion counts for a toy sentiment task."""
    rule = "=" * 60
    print("\n" + rule, "DEMO 6: Classification Metrics", rule, sep="\n")

    # Simulate sentiment classification
    gold = [1, 0, 1, 1, 0, 1, 0, 0, 1, 1]
    predicted = [1, 0, 1, 0, 0, 1, 1, 0, 1, 1]
    m = classification_metrics(gold, predicted)

    print("\nBinary Classification Results:")
    print("-" * 60)
    # Labels are padded so the values line up in one column.
    rows = (
        ("Accuracy:  ", 'accuracy'),
        ("Precision: ", 'precision'),
        ("Recall:    ", 'recall'),
        ("F1 Score:  ", 'f1'),
    )
    for label, key in rows:
        print(f"{label}{m[key]:.3f}")

    print("\nConfusion Matrix:")
    print(f"  TP: {m['tp']}  FP: {m['fp']}")
    print(f"  FN: {m['fn']}  TN: {m['tn']}")
400
401
def demo_summarization_eval():
    """Score candidate summaries against a reference with BLEU-2 and ROUGE-L F1."""
    rule = "=" * 60
    print("\n" + rule, "DEMO 7: Summarization Evaluation", rule, sep="\n")

    reference = ("Machine learning is a subset of artificial intelligence. "
                 "It enables systems to learn from data.")

    summaries = (
        "Machine learning is part of AI and allows systems to learn from data.",
        "ML is a type of AI that learns from data.",
        "Artificial intelligence includes machine learning.",
    )

    print(f"\nReference: {reference}\n")

    for idx, summary in enumerate(summaries, start=1):
        bleu2 = bleu_score(reference, summary, max_n=2)
        rouge_f1 = rouge_l_score(reference, summary)['f1']

        print(f"Summary {idx}: {summary}")
        print(f"  BLEU-2: {bleu2:.3f}")
        print(f"  ROUGE-L F1: {rouge_f1:.3f}\n")
425
426
def demo_translation_eval():
    """Score candidate translations against a reference with BLEU-4 and ROUGE-L F1."""
    rule = "=" * 60
    print("\n" + rule, "DEMO 8: Translation Evaluation", rule, sep="\n")

    # Example: French to English
    reference = "The cat is on the table"

    translations = (
        "The cat is on the table",    # Perfect
        "The cat sits on the table",  # Good
        "A cat is on a table",        # Medium
        "Cat table on",               # Poor
    )

    print(f"\nReference: {reference}\n")

    for idx, translation in enumerate(translations, start=1):
        bleu4 = bleu_score(reference, translation, max_n=4)
        rouge_f1 = rouge_l_score(reference, translation)['f1']

        print(f"Translation {idx}: {translation}")
        print(f"  BLEU-4: {bleu4:.3f}")
        print(f"  ROUGE-L F1: {rouge_f1:.3f}\n")
452
453
def demo_multi_reference():
    """Score one candidate against several references and report the best scores."""
    rule = "=" * 60
    print("\n" + rule, "DEMO 9: Multi-Reference Evaluation", rule, sep="\n")

    references = (
        "It is raining heavily",
        "The rain is very strong",
        "Heavy rainfall is occurring",
    )
    candidate = "It is raining a lot"

    print(f"\nCandidate: {candidate}\n")
    print("References:")
    for idx, ref in enumerate(references, start=1):
        print(f"  {idx}. {ref}")

    print("\nScores against each reference:")
    print("-" * 60)

    per_reference = []
    for idx, ref in enumerate(references, start=1):
        bleu = bleu_score(ref, candidate, max_n=2)
        rouge_f1 = rouge_l_score(ref, candidate)['f1']
        per_reference.append((bleu, rouge_f1))
        print(f"Reference {idx}: BLEU={bleu:.3f}, ROUGE-L F1={rouge_f1:.3f}")

    # Best score across references. This is a simple max over per-reference
    # scores; standard multi-reference BLEU instead clips n-gram counts
    # against all references jointly.
    best_bleu = max(b for b, _ in per_reference)
    best_rouge = max(r for _, r in per_reference)

    print("\nBest scores:")
    print(f"  BLEU: {best_bleu:.3f}")
    print(f"  ROUGE-L F1: {best_rouge:.3f}")
492
493
if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("Foundation Models: Evaluation Metrics")
    print(banner)

    # Run every demonstration in order.
    for demo in (
        demo_bleu,
        demo_rouge,
        demo_exact_match,
        demo_f1_token,
        demo_perplexity,
        demo_classification,
        demo_summarization_eval,
        demo_translation_eval,
        demo_multi_reference,
    ):
        demo()

    takeaways = (
        "BLEU: n-gram overlap (translation, generation)",
        "ROUGE-L: Longest common subsequence (summarization)",
        "Exact Match: Binary correctness (QA)",
        "Token F1: Partial credit for overlap (span QA)",
        "Perplexity: Model uncertainty (lower is better)",
        "Classification: Accuracy, precision, recall, F1",
        "Multi-reference: Take max/avg across references",
        "Choose metrics appropriate for task",
    )

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    for number, text in enumerate(takeaways, start=1):
        print(f"{number}. {text}")
    print(banner)