1"""
225. Loss Functions Comparison
3
4Demonstrates various loss functions used in deep learning:
5- Regression: MSE, MAE, Huber, Log-Cosh
6- Classification: BCE, Cross-Entropy, Focal Loss, Label Smoothing
7- Metric Learning: Contrastive, Triplet, InfoNCE
8- Segmentation: Dice Loss
9- Custom loss implementation patterns
10"""
11
12import torch
13import torch.nn as nn
14import torch.nn.functional as F
15import numpy as np
16
17
def print_section(title):
    """Print a banner-style section header to stdout."""
    bar = "=" * 60
    print("\n" + bar)
    print(f" {title}")
    print(bar)
24
25# ============================================================================
26# 1. Regression Losses: MSE vs MAE vs Huber
27# ============================================================================
28print_section("1. Regression Losses: Handling Outliers")
29
30# Create synthetic data with outliers
31torch.manual_seed(42)
32predictions = torch.randn(100)
33targets = predictions + torch.randn(100) * 0.1
34
35# Add outliers
36outlier_indices = torch.tensor([10, 25, 50, 75])
37targets[outlier_indices] += torch.randn(4) * 5.0
38
39print(f"Data shape: {predictions.shape}")
40print(f"Outlier indices: {outlier_indices.tolist()}")
41print(f"Outlier values: {targets[outlier_indices].tolist()}")
42
43# MSE Loss (L2)
44mse_loss = F.mse_loss(predictions, targets)
45print(f"\nMSE Loss: {mse_loss.item():.4f}")
46print(" ā Squares errors, sensitive to outliers")
47
48# MAE Loss (L1)
49mae_loss = F.l1_loss(predictions, targets)
50print(f"\nMAE Loss: {mae_loss.item():.4f}")
51print(" ā Absolute errors, robust to outliers")
52
53# Huber Loss (smooth L1)
54huber_loss = F.huber_loss(predictions, targets, delta=1.0)
55print(f"\nHuber Loss (delta=1.0): {huber_loss.item():.4f}")
56print(" ā L2 for small errors, L1 for large errors")
57
58# Manual Huber implementation
def huber_loss_manual(pred, target, delta=1.0):
    """Hand-rolled Huber loss: quadratic within delta, linear beyond.

    Matches F.huber_loss with the same delta.
    """
    abs_err = (pred - target).abs()
    # Split each error into a capped quadratic part and a linear overflow.
    capped = abs_err.clamp(max=delta)
    overflow = abs_err - capped
    return (0.5 * capped ** 2 + delta * overflow).mean()
65
# Sanity check: the hand-rolled version should match F.huber_loss above.
huber_manual = huber_loss_manual(predictions, targets, delta=1.0)
print(f"Huber Loss (manual): {huber_manual.item():.4f}")

# Log-Cosh Loss
def log_cosh_loss(pred, target):
    """Numerically stable log-cosh regression loss.

    log(cosh(e)) behaves like e**2 / 2 near zero and |e| - log(2) for
    large |e|, giving a smooth, outlier-robust loss between MSE and MAE.

    Uses the identity log(cosh(e)) = |e| + log1p(exp(-2|e|)) - log(2)
    instead of the naive torch.log(torch.cosh(e)), whose cosh overflows
    to inf for |e| greater than roughly 90 in float32.

    Args:
        pred: Predicted values.
        target: Ground-truth values (same/broadcastable shape).

    Returns:
        Scalar tensor: mean log-cosh of the errors.
    """
    abs_err = torch.abs(pred - target)
    return torch.mean(abs_err + torch.log1p(torch.exp(-2.0 * abs_err)) - np.log(2.0))
74
log_cosh = log_cosh_loss(predictions, targets)
print(f"\nLog-Cosh Loss: {log_cosh.item():.4f}")
print(" ā Smooth approximation to MAE, less sensitive to outliers")


# ============================================================================
# 2. Cross-Entropy Loss: Numerical Stability
# ============================================================================
print_section("2. Cross-Entropy Loss with Numerical Stability")

# Random multi-class problem: 8 samples, 10 classes.
torch.manual_seed(42)
batch_size, num_classes = 8, 10
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size,))

print(f"Logits shape: {logits.shape}")
print(f"Targets: {targets.tolist()}")

# Built-in CrossEntropyLoss (log-softmax + NLL, numerically stable)
ce_loss = F.cross_entropy(logits, targets)
print(f"\nBuilt-in CE Loss: {ce_loss.item():.4f}")

# Manual implementation (naive - unstable)
def cross_entropy_naive(logits, targets):
    """Textbook cross-entropy: explicit softmax, then log, then NLL.

    Deliberately naive for demonstration: exp() of large logits
    overflows, producing inf/nan where the stable version does not.
    """
    exp_scores = torch.exp(logits)
    probs = exp_scores / exp_scores.sum(dim=1, keepdim=True)
    picked = torch.log(probs)[torch.arange(len(targets)), targets]
    return -picked.mean()
104
# On moderate logits the naive version still agrees with the built-in.
ce_naive = cross_entropy_naive(logits, targets)
print(f"Naive CE Loss: {ce_naive.item():.4f}")

# Manual implementation (stable with log-sum-exp trick)
def cross_entropy_stable(logits, targets):
    """Cross-entropy via the log-sum-exp trick.

    Shifting each row by its max before exponentiating keeps exp() in
    range, so the result stays finite even for very large logits.
    """
    shifted = logits - logits.max(dim=1, keepdim=True).values
    log_probs = shifted - shifted.exp().sum(dim=1, keepdim=True).log()
    return -log_probs[torch.arange(len(targets)), targets].mean()
117
ce_stable = cross_entropy_stable(logits, targets)
print(f"Stable CE Loss: {ce_stable.item():.4f}")

# With extreme logits (show stability)
extreme_logits = logits * 100  # Scale up to cause overflow in naive version
ce_builtin_extreme = F.cross_entropy(extreme_logits, targets)
ce_stable_extreme = cross_entropy_stable(extreme_logits, targets)
print(f"\nExtreme logits (Ć100):")
print(f" Built-in CE: {ce_builtin_extreme.item():.4f}")
print(f" Stable CE: {ce_stable_extreme.item():.4f}")


# ============================================================================
# 3. Focal Loss: Handling Class Imbalance
# ============================================================================
print_section("3. Focal Loss for Hard Examples")
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """
    Focal loss (Lin et al., 2017) for class-imbalanced classification.
    FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)

    Down-weights well-classified examples (p_t near 1) so training
    concentrates on hard ones; gamma=0 reduces to alpha-scaled CE.
    Note: `alpha` is applied as a single global scale here, not a
    per-class weight.

    Args:
        logits: (N, C) raw class scores.
        targets: (N,) integer class labels.
        alpha: Global scaling factor.
        gamma: Focusing exponent; larger means more focus on hard examples.

    Returns:
        Scalar mean focal loss.
    """
    per_example_ce = F.cross_entropy(logits, targets, reduction='none')
    # Probability assigned to the true class for each example.
    p_t = F.softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
    modulator = (1 - p_t) ** gamma
    return (alpha * modulator * per_example_ce).mean()
154
torch.manual_seed(42)
# Create predictions with varying confidence.
# NOTE(review): `targets` is reused from Section 2 above — confirm intended.
easy_logits = torch.randn(8, 10) * 0.5  # Low variance ā confident
easy_logits[range(8), targets] += 5.0  # Correct class has high logit

hard_logits = torch.randn(8, 10)  # Uncertain predictions

print("Easy examples (high confidence):")
ce_easy = F.cross_entropy(easy_logits, targets)
focal_easy = focal_loss(easy_logits, targets, gamma=2.0)
print(f" CE Loss: {ce_easy.item():.4f}")
print(f" Focal Loss: {focal_easy.item():.4f}")
print(f" Reduction: {(ce_easy - focal_easy) / ce_easy * 100:.1f}%")

print("\nHard examples (uncertain):")
ce_hard = F.cross_entropy(hard_logits, targets)
focal_hard = focal_loss(hard_logits, targets, gamma=2.0)
print(f" CE Loss: {ce_hard.item():.4f}")
print(f" Focal Loss: {focal_hard.item():.4f}")
print(f" Reduction: {(ce_hard - focal_hard) / ce_hard * 100:.1f}%")

# gamma=0 reduces focal loss to alpha-scaled cross-entropy.
print("\nEffect of gamma parameter:")
for gamma in [0.0, 1.0, 2.0, 5.0]:
    fl = focal_loss(easy_logits, targets, gamma=gamma)
    print(f" γ={gamma:.1f}: {fl.item():.4f}")


# ============================================================================
# 4. Label Smoothing: Confidence Calibration
# ============================================================================
print_section("4. Label Smoothing for Calibration")
def cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1):
    """
    Cross-entropy against smoothed target distributions.

    The one-hot target is mixed with a uniform distribution:
        y_smooth = (1 - ε) * one_hot + ε/K
    so the true class receives 1 - ε + ε/K and every other class ε/K,
    where ε is `smoothing` and K is num_classes. (The previous version
    overwrote the true-class entry with 1 - ε via scatter_, so the
    smoothed target summed to 1 - ε/K instead of 1.)

    Matches F.cross_entropy(..., label_smoothing=ε); smoothing=0.0
    reduces to standard cross-entropy.

    Args:
        logits: (N, K) raw class scores.
        targets: (N,) integer class labels.
        smoothing: ε in [0, 1) — probability mass spread over all classes.

    Returns:
        Scalar mean loss.
    """
    num_classes = logits.size(-1)
    one_hot = F.one_hot(targets, num_classes).to(logits.dtype)
    # Mix: true class gets 1 - ε + ε/K, every other class gets ε/K.
    smooth_labels = one_hot * (1.0 - smoothing) + smoothing / num_classes

    # Cross-entropy against the soft targets.
    log_probs = F.log_softmax(logits, dim=1)
    loss = -(smooth_labels * log_probs).sum(dim=1)
    return loss.mean()
204
# Fresh random problem for the calibration demo (rebinds logits/targets).
torch.manual_seed(42)
logits = torch.randn(8, 10)
targets = torch.randint(0, 10, (8,))

print("Without label smoothing:")
ce_standard = F.cross_entropy(logits, targets)
probs_standard = F.softmax(logits, dim=1)
confidence_standard = probs_standard[range(8), targets].mean()
print(f" CE Loss: {ce_standard.item():.4f}")
print(f" Avg confidence on true class: {confidence_standard.item():.4f}")

print("\nWith label smoothing (ε=0.1):")
ce_smooth = cross_entropy_with_label_smoothing(logits, targets, smoothing=0.1)
print(f" CE Loss: {ce_smooth.item():.4f}")
print(f" ā Encourages model to be less confident")

# Larger ε spreads more probability mass off the true class.
print("\nSmoothing factor comparison:")
for eps in [0.0, 0.05, 0.1, 0.2, 0.3]:
    loss = cross_entropy_with_label_smoothing(logits, targets, smoothing=eps)
    print(f" ε={eps:.2f}: {loss.item():.4f}")


# ============================================================================
# 5. Contrastive Loss: Metric Learning for Pairs
# ============================================================================
print_section("5. Contrastive Loss for Similarity Learning")
def contrastive_loss(embedding1, embedding2, label, margin=1.0):
    """
    Contrastive loss over embedding pairs (Hadsell et al., 2006).
    L = (1-y) * 0.5 * D^2 + y * 0.5 * max(0, margin - D)^2
    with y=0 for similar pairs (pulled together) and y=1 for dissimilar
    pairs (pushed apart until D >= margin).

    Returns:
        Tuple of (mean loss, mean pairwise distance).
    """
    dist = F.pairwise_distance(embedding1, embedding2)

    # y=0: pull similar pairs together (quadratic in distance).
    pull = (1 - label) * 0.5 * dist.pow(2)

    # y=1: hinge pushes dissimilar pairs beyond the margin.
    push = label * 0.5 * F.relu(margin - dist).pow(2)

    return (pull + push).mean(), dist.mean()
247
torch.manual_seed(42)
embedding_dim = 128

# Positive pairs (similar): second embedding is a lightly perturbed copy.
emb1_pos = torch.randn(16, embedding_dim)
emb2_pos = emb1_pos + torch.randn(16, embedding_dim) * 0.1  # Small noise
labels_pos = torch.zeros(16)  # Label 0 = similar

# Negative pairs (dissimilar): independent random embeddings.
emb1_neg = torch.randn(16, embedding_dim)
emb2_neg = torch.randn(16, embedding_dim)  # Independent
labels_neg = torch.ones(16)  # Label 1 = dissimilar

# Combine
emb1 = torch.cat([emb1_pos, emb1_neg])
emb2 = torch.cat([emb2_pos, emb2_neg])
labels = torch.cat([labels_pos, labels_neg])

loss, avg_dist = contrastive_loss(emb1, emb2, labels, margin=1.0)

print(f"Total pairs: {len(labels)}")
print(f" Positive pairs: {(labels == 0).sum().item()}")
print(f" Negative pairs: {(labels == 1).sum().item()}")
print(f"\nContrastive Loss: {loss.item():.4f}")
print(f"Average pairwise distance: {avg_dist.item():.4f}")

# Check distances separately
pos_dist = F.pairwise_distance(emb1_pos, emb2_pos).mean()
neg_dist = F.pairwise_distance(emb1_neg, emb2_neg).mean()
print(f"\nPositive pair distance: {pos_dist.item():.4f} (should be small)")
print(f"Negative pair distance: {neg_dist.item():.4f} (should be > margin)")


# ============================================================================
# 6. Triplet Loss: Anchor/Positive/Negative
# ============================================================================
print_section("6. Triplet Loss for Ranking")
def triplet_loss(anchor, positive, negative, margin=1.0):
    """
    Triplet ranking loss: L = max(0, D(a,p) - D(a,n) + margin).

    Encourages each anchor to sit closer to its positive than to its
    negative by at least `margin`.

    Returns:
        Tuple of (mean loss, mean D(anchor, positive), mean D(anchor, negative)).
    """
    d_ap = F.pairwise_distance(anchor, positive)
    d_an = F.pairwise_distance(anchor, negative)

    # Hinge: zero once the negative is margin farther than the positive.
    hinge = F.relu(d_ap - d_an + margin)
    return hinge.mean(), d_ap.mean(), d_an.mean()
296
torch.manual_seed(42)
num_triplets = 16
embedding_dim = 128

# Anchors
anchors = torch.randn(num_triplets, embedding_dim)

# Positives: same class, small perturbation
positives = anchors + torch.randn(num_triplets, embedding_dim) * 0.2

# Negatives: different class
negatives = torch.randn(num_triplets, embedding_dim)

loss, d_pos, d_neg = triplet_loss(anchors, positives, negatives, margin=1.0)

print(f"Triplets: {num_triplets}")
print(f"Embedding dim: {embedding_dim}")
print(f"\nTriplet Loss: {loss.item():.4f}")
print(f"Avg D(anchor, positive): {d_pos.item():.4f}")
print(f"Avg D(anchor, negative): {d_neg.item():.4f}")
print(f"Margin satisfaction: {(d_neg - d_pos).item():.4f} (should be > margin=1.0)")

# Built-in triplet margin loss (same formula, for cross-checking)
loss_builtin = F.triplet_margin_loss(anchors, positives, negatives, margin=1.0)
print(f"\nBuilt-in TripletMarginLoss: {loss_builtin.item():.4f}")

# Larger margin demands a bigger gap, so the hinge fires more often.
print("\nEffect of margin:")
for m in [0.5, 1.0, 2.0, 5.0]:
    loss_m, _, _ = triplet_loss(anchors, positives, negatives, margin=m)
    print(f" margin={m:.1f}: {loss_m.item():.4f}")


# ============================================================================
# 7. Dice Loss: Segmentation
# ============================================================================
print_section("7. Dice Loss for Binary Segmentation")
def dice_loss(pred, target, smooth=1e-6):
    """
    Soft Dice loss for binary segmentation logits.
    Dice = 2 * |X ā© Y| / (|X| + |Y|), Loss = 1 - Dice.

    `smooth` guards against division by zero when both masks are empty.
    """
    # Logits -> probabilities, then flatten both masks to 1-D.
    probs = torch.sigmoid(pred).reshape(-1)
    truth = target.reshape(-1)

    overlap = (probs * truth).sum()
    total = probs.sum() + truth.sum()

    return 1 - (2.0 * overlap + smooth) / (total + smooth)
351
torch.manual_seed(42)
batch_size, height, width = 4, 64, 64

# Create synthetic segmentation masks
target_masks = torch.zeros(batch_size, height, width)
# Add some positive regions (one random rectangle per sample)
for i in range(batch_size):
    x, y = torch.randint(10, 50, (2,))
    w, h = torch.randint(10, 20, (2,))
    target_masks[i, x:x+w, y:y+h] = 1.0

# Predicted logits (noisy version of target)
pred_logits = target_masks * 5.0 + torch.randn(batch_size, height, width) * 2.0

print(f"Mask shape: {target_masks.shape}")
print(f"Positive pixels: {target_masks.sum().item()} / {target_masks.numel()}")

# Compute losses
dice = dice_loss(pred_logits, target_masks)
bce = F.binary_cross_entropy_with_logits(pred_logits, target_masks)

print(f"\nDice Loss: {dice.item():.4f}")
print(f"BCE Loss: {bce.item():.4f}")
print(" ā Dice is better for imbalanced segmentation")

# Dice coefficient (metric, not loss)
# NOTE(review): this uses soft sigmoid probabilities via dice_loss, not the
# thresholded pred_binary below — confirm whether a hard Dice was intended.
pred_probs = torch.sigmoid(pred_logits)
pred_binary = (pred_probs > 0.5).float()
dice_coef = 1 - dice_loss(pred_logits, target_masks)
print(f"\nDice Coefficient: {dice_coef.item():.4f} (higher is better)")

# IoU for comparison (computed on the thresholded mask)
intersection = (pred_binary * target_masks).sum()
union = pred_binary.sum() + target_masks.sum() - intersection
iou = intersection / (union + 1e-6)
print(f"IoU: {iou.item():.4f}")


# ============================================================================
# 8. Custom Multi-Task Loss: Uncertainty Weighting
# ============================================================================
print_section("8. Multi-Task Loss with Learned Uncertainty")
class MultiTaskLoss(nn.Module):
    """
    Homoscedastic-uncertainty weighting for multi-task training.
    From "Multi-Task Learning Using Uncertainty to Weigh Losses"
    (Kendall et al.).

    Total = Σᵢ exp(-sᵢ) * Lᵢ + sᵢ, where sᵢ = log(σᵢ²) is a learnable
    per-task log-variance: tasks the model is uncertain about are
    down-weighted, and the +sᵢ term keeps the uncertainties from
    growing without bound.
    """

    def __init__(self, num_tasks=2):
        super().__init__()
        # One learnable log-variance per task, initialised to 0 (σ = 1).
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        """Combine per-task losses into a single scalar.

        Args:
            losses: Sequence of scalar task losses [L1, L2, ...].

        Returns:
            Scalar weighted sum over tasks.
        """
        terms = [torch.exp(-log_var) * task_loss + log_var
                 for log_var, task_loss in zip(self.log_vars, losses)]
        return torch.stack(terms).sum()
421
# Simulate two tasks
torch.manual_seed(42)

# Task 1: Regression (MSE)
pred1 = torch.randn(32, 10)
target1 = torch.randn(32, 10)
loss1 = F.mse_loss(pred1, target1)

# Task 2: Classification (CE)
pred2 = torch.randn(32, 10)
target2 = torch.randint(0, 10, (32,))
loss2 = F.cross_entropy(pred2, target2)

print("Task losses (unweighted):")
print(f" Task 1 (regression): {loss1.item():.4f}")
print(f" Task 2 (classification): {loss2.item():.4f}")
print(f" Simple sum: {(loss1 + loss2).item():.4f}")

# Multi-task loss with learned weights (log_vars start at 0, i.e. σ=1)
multi_task_loss = MultiTaskLoss(num_tasks=2)
weighted_loss = multi_task_loss([loss1, loss2])

print(f"\nWeighted multi-task loss: {weighted_loss.item():.4f}")
print(f"Task uncertainties (Ļ):")
# σ = exp(0.5 * log σ²)
print(f" Task 1: {torch.exp(0.5 * multi_task_loss.log_vars[0]).item():.4f}")
print(f" Task 2: {torch.exp(0.5 * multi_task_loss.log_vars[1]).item():.4f}")

# Simulate training: task 2 is harder (higher loss)
print("\nAfter simulated training (task 2 is harder):")
multi_task_loss.log_vars.data = torch.tensor([0.0, 1.0])  # Higher uncertainty for task 2
weighted_loss = multi_task_loss([loss1, loss2])
print(f"Weighted loss: {weighted_loss.item():.4f}")
print(f" Task 1 weight (1/ϲ): {torch.exp(-multi_task_loss.log_vars[0]).item():.4f}")
print(f" Task 2 weight (1/ϲ): {torch.exp(-multi_task_loss.log_vars[1]).item():.4f}")
print(" ā Task 2 gets lower weight due to higher uncertainty")


# ============================================================================
# 9. Loss Landscape Visualization
# ============================================================================
print_section("9. Loss Landscape Comparison")
def quadratic_loss(w1, w2):
    """Convex bowl L(w1, w2) = w1² + w2², minimised at the origin."""
    return w1 * w1 + w2 * w2
467
def rosenbrock_loss(w1, w2, a=1, b=100):
    """Rosenbrock banana function: non-convex narrow valley, minimum at (a, a²)."""
    valley = w2 - w1 ** 2
    return (a - w1) ** 2 + b * valley ** 2
471
def loss_landscape_stats(loss_fn, grid_size=50, range_val=2.0):
    """Sample loss_fn on a square grid and summarise the landscape.

    Args:
        loss_fn: Callable (w1, w2) -> scalar loss.
        grid_size: Number of samples per axis.
        range_val: Grid spans [-range_val, range_val] on both axes.

    Returns:
        Dict with 'min', 'max', 'mean', 'std' over the sampled grid.
    """
    axis = np.linspace(-range_val, range_val, grid_size)
    # Rows index w1, columns index w2.
    surface = np.array([[loss_fn(u, v) for v in axis] for u in axis])
    return {
        'min': surface.min(),
        'max': surface.max(),
        'mean': surface.mean(),
        'std': surface.std(),
    }
488
print("Quadratic loss landscape:")
quad_stats = loss_landscape_stats(quadratic_loss)
print(f" Min: {quad_stats['min']:.4f}")
print(f" Max: {quad_stats['max']:.4f}")
print(f" Mean: {quad_stats['mean']:.4f}")
print(f" Std: {quad_stats['std']:.4f}")
print(" ā Convex, smooth, single minimum")

print("\nRosenbrock loss landscape:")
rosen_stats = loss_landscape_stats(rosenbrock_loss, range_val=2.0)
print(f" Min: {rosen_stats['min']:.4f}")
print(f" Max: {rosen_stats['max']:.4f}")
print(f" Mean: {rosen_stats['mean']:.4f}")
print(f" Std: {rosen_stats['std']:.4f}")
print(" ā Non-convex, narrow valley, harder optimization")

# Simple gradient descent comparison: same start point and learning rate,
# two very different landscapes.
print("\nGradient descent comparison (10 steps, lr=0.01):")
w_quad = torch.tensor([1.5, 1.5], requires_grad=True)
w_rosen = torch.tensor([1.5, 1.5], requires_grad=True)

lr = 0.01
for step in range(10):
    # Quadratic: recompute loss, backprop, update in-place under no_grad.
    loss_q = w_quad[0]**2 + w_quad[1]**2
    loss_q.backward()
    with torch.no_grad():
        w_quad -= lr * w_quad.grad
        w_quad.grad.zero_()

    # Rosenbrock
    loss_r = (1 - w_rosen[0])**2 + 100 * (w_rosen[1] - w_rosen[0]**2)**2
    loss_r.backward()
    with torch.no_grad():
        w_rosen -= lr * w_rosen.grad
        w_rosen.grad.zero_()

print(f"Quadratic final: w = [{w_quad[0].item():.4f}, {w_quad[1].item():.4f}]")
print(f" Distance to optimum (0,0): {torch.norm(w_quad).item():.4f}")
print(f"Rosenbrock final: w = [{w_rosen[0].item():.4f}, {w_rosen[1].item():.4f}]")
print(f" Distance to optimum (1,1): {torch.norm(w_rosen - torch.tensor([1.0, 1.0])).item():.4f}")


# ============================================================================
# Summary
# ============================================================================
print_section("Summary: Loss Function Selection Guide")

print("""
Regression:
  - MSE: Standard, sensitive to outliers
  - MAE: Robust to outliers, but non-smooth at zero
  - Huber: Best of both worlds, smooth and robust
  - Log-Cosh: Smooth MAE alternative

Classification:
  - Cross-Entropy: Standard, use with softmax
  - Focal Loss: Class imbalance, focus on hard examples
  - Label Smoothing: Improve calibration, prevent overconfidence

Metric Learning:
  - Contrastive: Pairwise similarity
  - Triplet: Ranking, requires triplets
  - InfoNCE: Contrastive with multiple negatives

Segmentation:
  - Dice Loss: Handles class imbalance better than BCE
  - Focal Loss: Can be adapted for segmentation
  - Combo: Dice + BCE often works best

Multi-Task:
  - Uncertainty weighting: Let model learn task importance
  - Manual weighting: Domain knowledge required
  - Gradient balancing: GradNorm, PCGrad

Implementation Tips:
  1. Always use numerical stability tricks (log-sum-exp)
  2. Check for NaN/inf during training
  3. Normalize losses if combining multiple terms
  4. Visualize loss landscapes to understand optimization
  5. Consider task-specific requirements (imbalance, outliers, etc.)
""")

print("\n" + "=" * 60)
print("Loss functions demonstration complete!")
print("=" * 60)