# 25_loss_functions.py
  1"""
  225. Loss Functions Comparison
  3
  4Demonstrates various loss functions used in deep learning:
  5- Regression: MSE, MAE, Huber, Log-Cosh
  6- Classification: BCE, Cross-Entropy, Focal Loss, Label Smoothing
  7- Metric Learning: Contrastive, Triplet, InfoNCE
  8- Segmentation: Dice Loss
  9- Custom loss implementation patterns
 10"""
 11
 12import torch
 13import torch.nn as nn
 14import torch.nn.functional as F
 15import numpy as np
 16
 17
 18def print_section(title):
 19    """Print formatted section header."""
 20    print("\n" + "=" * 60)
 21    print(f"  {title}")
 22    print("=" * 60)
 23
 24
 25# ============================================================================
 26# 1. Regression Losses: MSE vs MAE vs Huber
 27# ============================================================================
 28print_section("1. Regression Losses: Handling Outliers")
 29
 30# Create synthetic data with outliers
 31torch.manual_seed(42)
 32predictions = torch.randn(100)
 33targets = predictions + torch.randn(100) * 0.1
 34
 35# Add outliers
 36outlier_indices = torch.tensor([10, 25, 50, 75])
 37targets[outlier_indices] += torch.randn(4) * 5.0
 38
 39print(f"Data shape: {predictions.shape}")
 40print(f"Outlier indices: {outlier_indices.tolist()}")
 41print(f"Outlier values: {targets[outlier_indices].tolist()}")
 42
 43# MSE Loss (L2)
 44mse_loss = F.mse_loss(predictions, targets)
 45print(f"\nMSE Loss: {mse_loss.item():.4f}")
 46print("  → Squares errors, sensitive to outliers")
 47
 48# MAE Loss (L1)
 49mae_loss = F.l1_loss(predictions, targets)
 50print(f"\nMAE Loss: {mae_loss.item():.4f}")
 51print("  → Absolute errors, robust to outliers")
 52
 53# Huber Loss (smooth L1)
 54huber_loss = F.huber_loss(predictions, targets, delta=1.0)
 55print(f"\nHuber Loss (delta=1.0): {huber_loss.item():.4f}")
 56print("  → L2 for small errors, L1 for large errors")
 57
 58# Manual Huber implementation
 59def huber_loss_manual(pred, target, delta=1.0):
 60    """Manual Huber loss implementation."""
 61    error = torch.abs(pred - target)
 62    quadratic = torch.min(error, torch.tensor(delta))
 63    linear = error - quadratic
 64    return torch.mean(0.5 * quadratic**2 + delta * linear)
 65
 66huber_manual = huber_loss_manual(predictions, targets, delta=1.0)
 67print(f"Huber Loss (manual): {huber_manual.item():.4f}")
 68
 69# Log-Cosh Loss
 70def log_cosh_loss(pred, target):
 71    """Log-Cosh loss: log(cosh(x)) ā‰ˆ smooth L1."""
 72    error = pred - target
 73    return torch.mean(torch.log(torch.cosh(error)))
 74
 75log_cosh = log_cosh_loss(predictions, targets)
 76print(f"\nLog-Cosh Loss: {log_cosh.item():.4f}")
 77print("  → Smooth approximation to MAE, less sensitive to outliers")
 78
 79
 80# ============================================================================
 81# 2. Cross-Entropy Loss: Numerical Stability
 82# ============================================================================
 83print_section("2. Cross-Entropy Loss with Numerical Stability")
 84
 85torch.manual_seed(42)
 86batch_size, num_classes = 8, 10
 87logits = torch.randn(batch_size, num_classes)
 88targets = torch.randint(0, num_classes, (batch_size,))
 89
 90print(f"Logits shape: {logits.shape}")
 91print(f"Targets: {targets.tolist()}")
 92
 93# Built-in CrossEntropyLoss
 94ce_loss = F.cross_entropy(logits, targets)
 95print(f"\nBuilt-in CE Loss: {ce_loss.item():.4f}")
 96
 97# Manual implementation (naive - unstable)
 98def cross_entropy_naive(logits, targets):
 99    """Naive CE implementation (can cause numerical overflow)."""
100    probs = torch.exp(logits) / torch.exp(logits).sum(dim=1, keepdim=True)
101    log_probs = torch.log(probs)
102    nll = -log_probs[range(len(targets)), targets]
103    return nll.mean()
104
# torch.randn logits are small enough here that the naive path does not overflow.
ce_naive = cross_entropy_naive(logits=logits, targets=targets)
print("Naive CE Loss: {:.4f}".format(ce_naive.item()))
107
# Manual implementation (stable with log-sum-exp trick)
def cross_entropy_stable(logits, targets):
    """Cross-entropy computed with the log-sum-exp trick.

    Each row is shifted by its maximum, so the largest exponent is exp(0) = 1
    and ``exp`` cannot overflow. The shift cancels out of the log-softmax:
        log_softmax(z)_i = (z_i - m) - log(sum_j exp(z_j - m)),  m = max_j z_j
    """
    shifted = logits - logits.max(dim=1, keepdim=True).values
    log_norm = shifted.exp().sum(dim=1, keepdim=True).log()
    log_probs = shifted - log_norm
    rows = range(len(targets))
    return (-log_probs[rows, targets]).mean()
117
ce_stable = cross_entropy_stable(logits, targets)
print(f"Stable CE Loss: {ce_stable.item():.4f}")

# With extreme logits (show stability)
# Scaling by 100 pushes any logit above ~0.89 past ~88.7, where float32 exp()
# overflows to inf. The naive version is run as well, so its breakdown
# (typically nan/inf) is visible next to the two stable results.
extreme_logits = logits * 100
ce_builtin_extreme = F.cross_entropy(extreme_logits, targets)
ce_stable_extreme = cross_entropy_stable(extreme_logits, targets)
ce_naive_extreme = cross_entropy_naive(extreme_logits, targets)
print("\nExtreme logits (×100):")
print(f"  Built-in CE: {ce_builtin_extreme.item():.4f}")
print(f"  Stable CE:   {ce_stable_extreme.item():.4f}")
print(f"  Naive CE:    {ce_naive_extreme.item():.4f}")


# ============================================================================
# 3. Focal Loss: Handling Class Imbalance
# ============================================================================
print_section("3. Focal Loss for Hard Examples")
134
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """
    Focal loss for multi-class logits.

        FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)

    p_t is the softmax probability of the target class. The (1 - p_t)^gamma
    factor down-weights well-classified examples so training focuses on
    hard ones. gamma=0 together with alpha=1 recovers plain cross-entropy.

    Args:
        logits: (N, C) unnormalised class scores.
        targets: (N,) integer class indices.
        alpha: Class weighting.
            - A Python number scales every sample uniformly. In the
              multi-class case it only rescales the loss; it does NOT single
              out a "positive" class.
            - A 1-D tensor of length C gives a per-class weight, looked up as
              alpha[targets].
        gamma: Focusing parameter (higher → more focus on hard examples).

    Returns:
        Scalar mean focal loss over the batch.
    """
    ce_loss = F.cross_entropy(logits, targets, reduction='none')
    # Per-sample CE is exactly -log(p_t), so p_t = exp(-CE). This reuses the
    # stable log-softmax instead of computing a second, separate softmax.
    p_t = torch.exp(-ce_loss)

    # Focal term: (1 - p_t)^gamma
    focal_term = (1 - p_t) ** gamma

    # Apply alpha weighting (scalar, or per-class weights indexed by target)
    if torch.is_tensor(alpha):
        alpha_t = alpha.to(device=logits.device, dtype=ce_loss.dtype)[targets]
    else:
        alpha_t = alpha
    loss = alpha_t * focal_term * ce_loss
    return loss.mean()
154
torch.manual_seed(42)
# Reuse the section-2 labels (one per row, drawn from range(num_classes)).
n_samples = targets.size(0)
row_idx = torch.arange(n_samples)

# Easy examples: small random scores plus a +5 boost on the true class, so the
# model is confidently correct. The boost, not the 0.5 scale, makes them easy.
easy_logits = torch.randn(n_samples, num_classes) * 0.5
easy_logits[row_idx, targets] += 5.0

# Hard examples: unboosted random scores, i.e. uncertain predictions.
hard_logits = torch.randn(n_samples, num_classes)

# alpha=1.0 isolates the (1 - p_t)^gamma focusing term. With focal_loss's
# default alpha=0.25, every "Reduction" below would include a flat 75% cut
# unrelated to example difficulty, and the gamma=0 row would not equal CE.
focal_alpha = 1.0
print(f"Focal loss below uses alpha={focal_alpha} (isolates the gamma term)\n")

print("Easy examples (high confidence):")
ce_easy = F.cross_entropy(easy_logits, targets)
focal_easy = focal_loss(easy_logits, targets, alpha=focal_alpha, gamma=2.0)
print(f"  CE Loss:    {ce_easy.item():.4f}")
print(f"  Focal Loss: {focal_easy.item():.4f}")
print(f"  Reduction:  {((ce_easy - focal_easy) / ce_easy * 100).item():.1f}%")

print("\nHard examples (uncertain):")
ce_hard = F.cross_entropy(hard_logits, targets)
focal_hard = focal_loss(hard_logits, targets, alpha=focal_alpha, gamma=2.0)
print(f"  CE Loss:    {ce_hard.item():.4f}")
print(f"  Focal Loss: {focal_hard.item():.4f}")
print(f"  Reduction:  {((ce_hard - focal_hard) / ce_hard * 100).item():.1f}%")

# With alpha=1, the gamma=0 row equals the easy-example CE printed above.
print("\nEffect of gamma parameter:")
for gamma in [0.0, 1.0, 2.0, 5.0]:
    fl = focal_loss(easy_logits, targets, alpha=focal_alpha, gamma=gamma)
    print(f"  γ={gamma:.1f}: {fl.item():.4f}")


# ============================================================================
# 4. Label Smoothing: Confidence Calibration
# ============================================================================
print_section("4. Label Smoothing for Calibration")
186
def cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1):
    """
    Cross-entropy against label-smoothed targets.

    Instead of the one-hot [0, 0, 1, 0], the target distribution is
    [ε/K, ε/K, 1-ε+ε/K, ε/K], where ε is the smoothing factor and K is the
    number of classes. Every row sums to 1. The result matches
    ``F.cross_entropy(logits, targets, label_smoothing=ε)``.

    Args:
        logits: (N, K) unnormalised class scores.
        targets: (N,) integer class indices.
        smoothing: ε; 0 gives ordinary cross-entropy.

    Returns:
        Scalar mean loss over the batch.
    """
    num_classes = logits.size(-1)
    off_value = smoothing / num_classes
    # The true class keeps its uniform share ε/K on top of 1-ε. Writing only
    # 1-ε here would leave each row summing to 1-ε/K instead of 1.
    on_value = 1.0 - smoothing + off_value

    # Create smoothed labels
    smooth_labels = torch.full_like(logits, off_value)
    smooth_labels.scatter_(1, targets.unsqueeze(1), on_value)

    # Compute loss
    log_probs = F.log_softmax(logits, dim=1)
    loss = -(smooth_labels * log_probs).sum(dim=1)
    return loss.mean()
204
torch.manual_seed(42)
n_rows, n_cls = 8, 10
logits = torch.randn(n_rows, n_cls)
targets = torch.randint(0, n_cls, (n_rows,))

# Baseline: ordinary CE and the mean softmax probability on the true class.
ce_standard = F.cross_entropy(logits, targets)
probs_standard = logits.softmax(dim=1)
confidence_standard = probs_standard[torch.arange(n_rows), targets].mean()
print("Without label smoothing:")
print(f"  CE Loss: {ce_standard.item():.4f}")
print(f"  Avg confidence on true class: {confidence_standard.item():.4f}")

ce_smooth = cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1)
print("\nWith label smoothing (ε=0.1):")
print(f"  CE Loss: {ce_smooth.item():.4f}")
print("  → Encourages model to be less confident")

# Sweep the smoothing factor on the same logits/targets.
print("\nSmoothing factor comparison:")
for eps in (0.0, 0.05, 0.1, 0.2, 0.3):
    loss = cross_entropy_with_label_smoothing(logits, targets, smoothing=eps)
    print(f"  ε={eps:.2f}: {loss.item():.4f}")


# ============================================================================
# 5. Contrastive Loss: Metric Learning for Pairs
# ============================================================================
print_section("5. Contrastive Loss for Similarity Learning")
231
def contrastive_loss(embedding1, embedding2, label, margin=1.0):
    """
    Pairwise contrastive loss.

        L = (1-y) * 0.5 * D^2 + y * 0.5 * max(0, margin - D)^2

    y = 0 marks a similar pair, which is pulled together. y = 1 marks a
    dissimilar pair, which is pushed apart until its distance reaches
    ``margin``. D is the Euclidean distance from ``F.pairwise_distance``.

    Returns:
        (mean loss over pairs, mean pair distance)
    """
    dist = F.pairwise_distance(embedding1, embedding2)
    hinge = (margin - dist).clamp(min=0)
    per_pair = (1 - label) * 0.5 * dist ** 2 + label * 0.5 * hinge ** 2
    return per_pair.mean(), dist.mean()
247
torch.manual_seed(42)
embedding_dim = 128
pairs_per_group = 16

# Positive pairs (similar): second view = first view + small Gaussian noise.
emb1_pos = torch.randn(pairs_per_group, embedding_dim)
emb2_pos = emb1_pos + 0.1 * torch.randn(pairs_per_group, embedding_dim)
labels_pos = torch.zeros(pairs_per_group)  # label 0 = similar

# Negative pairs (dissimilar): two independent draws.
emb1_neg = torch.randn(pairs_per_group, embedding_dim)
emb2_neg = torch.randn(pairs_per_group, embedding_dim)
labels_neg = torch.ones(pairs_per_group)  # label 1 = dissimilar

# Stack positives first, then negatives.
emb1, emb2, labels = (
    torch.cat(parts)
    for parts in ((emb1_pos, emb1_neg), (emb2_pos, emb2_neg), (labels_pos, labels_neg))
)

loss, avg_dist = contrastive_loss(emb1, emb2, labels, margin=1.0)

num_pos = (labels == 0).sum().item()
num_neg = (labels == 1).sum().item()
print(f"Total pairs: {len(labels)}")
print(f"  Positive pairs: {num_pos}")
print(f"  Negative pairs: {num_neg}")
print(f"\nContrastive Loss: {loss.item():.4f}")
print(f"Average pairwise distance: {avg_dist.item():.4f}")

# Per-group mean distances, to compare each against the margin.
pos_dist, neg_dist = (
    F.pairwise_distance(x, y).mean()
    for x, y in ((emb1_pos, emb2_pos), (emb1_neg, emb2_neg))
)
print(f"\nPositive pair distance: {pos_dist.item():.4f} (should be small)")
print(f"Negative pair distance: {neg_dist.item():.4f} (should be > margin)")


# ============================================================================
# 6. Triplet Loss: Anchor/Positive/Negative
# ============================================================================
print_section("6. Triplet Loss for Ranking")
285
def triplet_loss(anchor, positive, negative, margin=1.0):
    """
    Triplet margin loss: L = max(0, D(a,p) - D(a,n) + margin).

    Penalises any triplet whose positive is not at least ``margin`` closer to
    the anchor than its negative. D is ``F.pairwise_distance``.

    Returns:
        (mean loss, mean D(anchor, positive), mean D(anchor, negative))
    """
    d_ap, d_an = (F.pairwise_distance(anchor, other) for other in (positive, negative))
    hinge = (d_ap - d_an + margin).clamp(min=0)
    return hinge.mean(), d_ap.mean(), d_an.mean()
296
torch.manual_seed(42)
num_triplets, embedding_dim = 16, 128

anchors = torch.randn(num_triplets, embedding_dim)
# Positives: anchor plus small noise (same class).
positives = anchors + 0.2 * torch.randn(num_triplets, embedding_dim)
# Negatives: independent draws (different class).
negatives = torch.randn(num_triplets, embedding_dim)

loss, d_pos, d_neg = triplet_loss(anchors, positives, negatives, margin=1.0)
gap = d_neg - d_pos

print(f"Triplets: {num_triplets}")
print(f"Embedding dim: {embedding_dim}")
print(f"\nTriplet Loss: {loss.item():.4f}")
print(f"Avg D(anchor, positive): {d_pos.item():.4f}")
print(f"Avg D(anchor, negative): {d_neg.item():.4f}")
print(f"Margin satisfaction: {gap.item():.4f} (should be > margin=1.0)")

# Cross-check against PyTorch's built-in implementation.
loss_builtin = F.triplet_margin_loss(anchors, positives, negatives, margin=1.0)
print(f"\nBuilt-in TripletMarginLoss: {loss_builtin.item():.4f}")

print("\nEffect of margin:")
for m in (0.5, 1.0, 2.0, 5.0):
    loss_m = triplet_loss(anchors, positives, negatives, margin=m)[0]
    print(f"  margin={m:.1f}: {loss_m.item():.4f}")


# ============================================================================
# 7. Dice Loss: Segmentation
# ============================================================================
print_section("7. Dice Loss for Binary Segmentation")
333
def dice_loss(pred, target, smooth=1e-6):
    """
    Soft Dice loss for binary segmentation.

        Dice = (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth)
        Loss = 1 - Dice

    ``pred`` holds raw logits. They are passed through a sigmoid, so |X ∩ Y|
    becomes the sum of probability * target. All elements of the batch are
    pooled into a single Dice score. ``smooth`` keeps the ratio defined when
    both prediction and target are empty.

    Args:
        pred: logits, any shape (contiguous or not).
        target: mask with the same number of elements as ``pred``
            (presumably {0, 1} values; for such masks the loss is in [0, 1]).
        smooth: small constant added to numerator and denominator.

    Returns:
        Scalar tensor; lower is better.
    """
    probs = torch.sigmoid(pred)  # Convert logits to probabilities

    # Flatten with reshape rather than view: view() raises on non-contiguous
    # inputs (e.g. permuted or sliced masks), while reshape() copies only
    # when it has to.
    pred_flat = probs.reshape(-1)
    target_flat = target.reshape(-1)

    intersection = (pred_flat * target_flat).sum()
    total = pred_flat.sum() + target_flat.sum()

    dice = (2.0 * intersection + smooth) / (total + smooth)
    return 1 - dice
351
torch.manual_seed(42)
batch_size, height, width = 4, 64, 64

# Create synthetic segmentation masks: one random axis-aligned rectangle per
# image (top-left corner in [10, 50), sides of 10-19 pixels).
target_masks = torch.zeros(batch_size, height, width)
for i in range(batch_size):
    x, y = torch.randint(10, 50, (2,)).tolist()
    w, h = torch.randint(10, 20, (2,)).tolist()
    target_masks[i, x:x + w, y:y + h] = 1.0

# Predicted logits (noisy version of target): +5 inside the mask, plus
# Gaussian noise with std 2 everywhere.
pred_logits = target_masks * 5.0 + torch.randn(batch_size, height, width) * 2.0

print(f"Mask shape: {target_masks.shape}")
print(f"Positive pixels: {target_masks.sum().item()} / {target_masks.numel()}")

# Compute losses
dice = dice_loss(pred_logits, target_masks)
bce = F.binary_cross_entropy_with_logits(pred_logits, target_masks)

print(f"\nDice Loss: {dice.item():.4f}")
print(f"BCE Loss:  {bce.item():.4f}")
print("  → Dice is better for imbalanced segmentation")

# Dice coefficient (metric, not loss): evaluated on thresholded predictions so
# it is directly comparable with the IoU below. 1 - dice_loss would be the
# *soft* Dice on probabilities, i.e. the training objective, not the metric.
pred_probs = torch.sigmoid(pred_logits)
pred_binary = (pred_probs > 0.5).float()
intersection = (pred_binary * target_masks).sum()
pred_area = pred_binary.sum()
target_area = target_masks.sum()
dice_coef = 2 * intersection / (pred_area + target_area + 1e-6)
print(f"\nDice Coefficient: {dice_coef.item():.4f} (higher is better)")

# IoU for comparison (same hard predictions; Dice = 2·IoU / (1 + IoU))
union = pred_area + target_area - intersection
iou = intersection / (union + 1e-6)
print(f"IoU: {iou.item():.4f}")


# ============================================================================
# 8. Custom Multi-Task Loss: Uncertainty Weighting
# ============================================================================
print_section("8. Multi-Task Loss with Learned Uncertainty")
394
class MultiTaskLoss(nn.Module):
    """
    Multi-task loss with learned uncertainty weighting.
    From "Multi-Task Learning Using Uncertainty to Weigh Losses" (Kendall et al.)

    The paper's objective for two tasks is
        L = (1/2σ₁²)L₁ + (1/2σ₂²)L₂ + log(σ₁σ₂)
    where σ₁, σ₂ are learned task uncertainties.

    This module learns s_i = log(σ_i²) and returns
        sum_i [ exp(-s_i) * L_i + s_i ]
    which is exactly 2× the expression above, since s_i = 2 log σ_i. The
    constant factor does not move the optimum. Parameterising by log-variance
    keeps σ² = exp(s_i) positive without explicit constraints.
    """

    def __init__(self, num_tasks=2):
        super().__init__()
        # One learnable log-variance s_i = log(σ_i²) per task, initialised to
        # 0 (σ = 1, so every task starts with weight 1).
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        """
        Args:
            losses: sequence of task losses [L1, L2, ...]. Entry i is paired
                with ``log_vars[i]``; passing more losses than ``num_tasks``
                raises IndexError.

        Returns:
            sum_i exp(-s_i) * L_i + s_i (a scalar tensor for scalar losses).
        """
        weighted_losses = []
        for i, loss in enumerate(losses):
            # Precision weighting: exp(-log_var) = 1/σ²
            precision = torch.exp(-self.log_vars[i])
            # "+ log_var" is the regulariser that stops σ² from growing without
            # bound just to shrink the weighted task loss to zero.
            weighted_losses.append(precision * loss + self.log_vars[i])

        return torch.stack(weighted_losses).sum()
421
# Simulate two tasks
torch.manual_seed(42)

# Task 1: Regression (MSE)
pred1 = torch.randn(32, 10)
target1 = torch.randn(32, 10)
loss1 = F.mse_loss(pred1, target1)

# Task 2: Classification (CE)
pred2 = torch.randn(32, 10)
target2 = torch.randint(0, 10, (32,))
loss2 = F.cross_entropy(pred2, target2)

print("Task losses (unweighted):")
print(f"  Task 1 (regression): {loss1.item():.4f}")
print(f"  Task 2 (classification): {loss2.item():.4f}")
print(f"  Simple sum: {(loss1 + loss2).item():.4f}")

# Multi-task loss with learned weights (log_vars start at 0, i.e. σ = 1)
multi_task_loss = MultiTaskLoss(num_tasks=2)
weighted_loss = multi_task_loss([loss1, loss2])

print(f"\nWeighted multi-task loss: {weighted_loss.item():.4f}")
print("Task uncertainties (σ):")
with torch.no_grad():
    sigmas = torch.exp(0.5 * multi_task_loss.log_vars)  # σ = exp(log σ² / 2)
print(f"  Task 1: {sigmas[0].item():.4f}")
print(f"  Task 2: {sigmas[1].item():.4f}")

# Simulate training: rather than optimising, hand-set the log-variances to
# values training might reach if task 2 were noisier/harder.
print("\nAfter simulated training (task 2 is harder):")
# Update the Parameter in place under no_grad instead of rebinding `.data`.
# copy_ keeps the Parameter's own storage, dtype and device, and the change
# stays visible to autograd's version tracking.
with torch.no_grad():
    multi_task_loss.log_vars.copy_(torch.tensor([0.0, 1.0]))  # Higher uncertainty for task 2
weighted_loss = multi_task_loss([loss1, loss2])
print(f"Weighted loss: {weighted_loss.item():.4f}")
with torch.no_grad():
    precisions = torch.exp(-multi_task_loss.log_vars)  # 1/σ²
print(f"  Task 1 weight (1/σ²): {precisions[0].item():.4f}")
print(f"  Task 2 weight (1/σ²): {precisions[1].item():.4f}")
print("  → Task 2 gets lower weight due to higher uncertainty")


# ============================================================================
# 9. Loss Landscape Visualization
# ============================================================================
print_section("9. Loss Landscape Comparison")
463
def quadratic_loss(w1, w2):
    """Convex bowl L = w1² + w2², whose single minimum (0) is at the origin."""
    squared_w1 = w1**2
    squared_w2 = w2**2
    return squared_w1 + squared_w2
467
def rosenbrock_loss(w1, w2, a=1, b=100):
    """Rosenbrock function (a - w1)² + b·(w2 - w1²)².

    Non-convex with a narrow curved valley; global minimum 0 at (a, a²).
    """
    valley_term = (w2 - w1**2) ** 2
    return (a - w1) ** 2 + b * valley_term
471
def loss_landscape_stats(loss_fn, grid_size=50, range_val=2.0):
    """Evaluate ``loss_fn`` on a square grid and summarise the values.

    Both axes span [-range_val, range_val] with ``grid_size`` evenly spaced
    points (endpoints included). ``loss_fn(w1, w2)`` is called once per grid
    point with scalar arguments, and results are stored as float64.

    Returns:
        dict with the 'min', 'max', 'mean' and 'std' of the grid values.
    """
    axis = np.linspace(-range_val, range_val, grid_size)
    grid = np.array(
        [[loss_fn(a, b) for b in axis] for a in axis],
        dtype=float,
    )
    return {
        'min': grid.min(),
        'max': grid.max(),
        'mean': grid.mean(),
        'std': grid.std(),
    }
488
quad_stats = loss_landscape_stats(quadratic_loss)
rosen_stats = loss_landscape_stats(rosenbrock_loss, range_val=2.0)

# One report per landscape: header, the four summary statistics, then a remark.
for header, stats, remark in (
    ("Quadratic loss landscape:", quad_stats, "Convex, smooth, single minimum"),
    ("\nRosenbrock loss landscape:", rosen_stats, "Non-convex, narrow valley, harder optimization"),
):
    print(header)
    for key in ('min', 'max', 'mean', 'std'):
        print(f"  {key.capitalize()}: {stats[key]:.4f}")
    print(f"  → {remark}")
504
# Simple gradient descent comparison.
# Both problems deliberately use the same step size. lr=0.01 is comfortably
# stable for the quadratic (curvature 2), but far too large for Rosenbrock:
# its curvature near the start point (1.5, 1.5) is above 2000, so plain GD
# overshoots and blows up to inf/nan within a few steps. The loop records
# when that happens instead of silently printing nan coordinates.
num_steps = 10
lr = 0.01
print(f"\nGradient descent comparison ({num_steps} steps, lr={lr}):")
w_quad = torch.tensor([1.5, 1.5], requires_grad=True)
w_rosen = torch.tensor([1.5, 1.5], requires_grad=True)
rosen_diverged_at = None  # 1-based step at which w_rosen became non-finite

for step in range(num_steps):
    # Quadratic
    loss_q = w_quad[0]**2 + w_quad[1]**2
    loss_q.backward()
    with torch.no_grad():
        w_quad -= lr * w_quad.grad
        w_quad.grad.zero_()

    # Rosenbrock (stop stepping once diverged; further updates are just nan)
    if rosen_diverged_at is None:
        loss_r = (1 - w_rosen[0])**2 + 100 * (w_rosen[1] - w_rosen[0]**2)**2
        loss_r.backward()
        with torch.no_grad():
            w_rosen -= lr * w_rosen.grad
            w_rosen.grad.zero_()
        if not torch.isfinite(w_rosen).all():
            rosen_diverged_at = step + 1

print(f"Quadratic final: w = [{w_quad[0].item():.4f}, {w_quad[1].item():.4f}]")
print(f"  Distance to optimum (0,0): {torch.norm(w_quad).item():.4f}")
if rosen_diverged_at is None:
    print(f"Rosenbrock final: w = [{w_rosen[0].item():.4f}, {w_rosen[1].item():.4f}]")
    print(f"  Distance to optimum (1,1): {torch.norm(w_rosen - torch.tensor([1.0, 1.0])).item():.4f}")
else:
    print(f"Rosenbrock diverged at step {rosen_diverged_at}: lr={lr} is too large "
          "for its steep valley (a stable step here is on the order of 1e-3)")
530
531
# ============================================================================
# Summary
# ============================================================================
print_section("Summary: Loss Function Selection Guide")

SELECTION_GUIDE = """
Regression:
  - MSE: Standard, sensitive to outliers
  - MAE: Robust to outliers, but non-smooth at zero
  - Huber: Best of both worlds, smooth and robust
  - Log-Cosh: Smooth MAE alternative

Classification:
  - Cross-Entropy: Standard, use with softmax
  - Focal Loss: Class imbalance, focus on hard examples
  - Label Smoothing: Improve calibration, prevent overconfidence

Metric Learning:
  - Contrastive: Pairwise similarity
  - Triplet: Ranking, requires triplets
  - InfoNCE: Contrastive with multiple negatives

Segmentation:
  - Dice Loss: Handles class imbalance better than BCE
  - Focal Loss: Can be adapted for segmentation
  - Combo: Dice + BCE often works best

Multi-Task:
  - Uncertainty weighting: Let model learn task importance
  - Manual weighting: Domain knowledge required
  - Gradient balancing: GradNorm, PCGrad

Implementation Tips:
  1. Always use numerical stability tricks (log-sum-exp)
  2. Check for NaN/inf during training
  3. Normalize losses if combining multiple terms
  4. Visualize loss landscapes to understand optimization
  5. Consider task-specific requirements (imbalance, outliers, etc.)
"""
print(SELECTION_GUIDE)

# Closing banner: blank line, rule, message, rule.
closing_rule = "=" * 60
print("\n" + closing_rule, "Loss functions demonstration complete!", closing_rule, sep="\n")