# 26_optimizers.py
"""
26. Optimizers Comparison

Demonstrates various optimization algorithms:
- SGD, SGD+Momentum, SGD+Nesterov
- Adam, AdamW
- Manual implementations
- Learning rate schedulers
- Practical optimization patterns
"""
 11
 12import torch
 13import torch.nn as nn
 14import torch.nn.functional as F
 15import numpy as np
 16from typing import List, Tuple
 17
 18
 19def print_section(title: str):
 20    """Print formatted section header."""
 21    print(f"\n{'='*70}")
 22    print(f"  {title}")
 23    print('='*70)
 24
 25
 26# =============================================================================
 27# 1. SGD from Scratch
 28# =============================================================================
 29
 30def rosenbrock(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 31    """Rosenbrock function: f(x,y) = (1-x)^2 + 100(y-x^2)^2"""
 32    return (1 - x)**2 + 100 * (y - x**2)**2
 33
 34
 35def sgd_step(params: List[torch.Tensor], grads: List[torch.Tensor], lr: float):
 36    """Basic SGD step."""
 37    with torch.no_grad():
 38        for p, g in zip(params, grads):
 39            p -= lr * g
 40
 41
 42def sgd_momentum_step(params: List[torch.Tensor], grads: List[torch.Tensor],
 43                      velocities: List[torch.Tensor], lr: float, momentum: float):
 44    """SGD with momentum."""
 45    with torch.no_grad():
 46        for p, g, v in zip(params, grads, velocities):
 47            v.mul_(momentum).add_(g)
 48            p -= lr * v
 49
 50
 51def sgd_nesterov_step(params: List[torch.Tensor], grads: List[torch.Tensor],
 52                      velocities: List[torch.Tensor], lr: float, momentum: float):
 53    """SGD with Nesterov momentum."""
 54    with torch.no_grad():
 55        for p, g, v in zip(params, grads, velocities):
 56            v.mul_(momentum).add_(g)
 57            p -= lr * (g + momentum * v)
 58
 59
 60def optimize_rosenbrock():
 61    """Compare SGD variants on Rosenbrock function."""
 62    print_section("1. SGD from Scratch - Rosenbrock Optimization")
 63
 64    # Starting point
 65    x0, y0 = -1.0, 1.0
 66    lr = 0.001
 67    n_steps = 1000
 68
 69    # Basic SGD
 70    x, y = torch.tensor([x0], requires_grad=True), torch.tensor([y0], requires_grad=True)
 71    params = [x, y]
 72
 73    for step in range(n_steps):
 74        loss = rosenbrock(x, y)
 75        loss.backward()
 76
 77        with torch.no_grad():
 78            grads = [x.grad.clone(), y.grad.clone()]
 79            sgd_step(params, grads, lr)
 80            x.grad.zero_()
 81            y.grad.zero_()
 82
 83        if step % 200 == 0:
 84            print(f"SGD Step {step:4d}: x={x.item():.4f}, y={y.item():.4f}, loss={loss.item():.6f}")
 85
 86    print(f"SGD Final: x={x.item():.4f}, y={y.item():.4f} (optimal: x=1, y=1)")
 87
 88    # SGD with momentum
 89    x, y = torch.tensor([x0], requires_grad=True), torch.tensor([y0], requires_grad=True)
 90    params = [x, y]
 91    velocities = [torch.zeros_like(x), torch.zeros_like(y)]
 92    momentum = 0.9
 93
 94    for step in range(n_steps):
 95        loss = rosenbrock(x, y)
 96        loss.backward()
 97
 98        with torch.no_grad():
 99            grads = [x.grad.clone(), y.grad.clone()]
100            sgd_momentum_step(params, grads, velocities, lr, momentum)
101            x.grad.zero_()
102            y.grad.zero_()
103
104    print(f"Momentum Final: x={x.item():.4f}, y={y.item():.4f}")
105
106    # Nesterov
107    x, y = torch.tensor([x0], requires_grad=True), torch.tensor([y0], requires_grad=True)
108    params = [x, y]
109    velocities = [torch.zeros_like(x), torch.zeros_like(y)]
110
111    for step in range(n_steps):
112        loss = rosenbrock(x, y)
113        loss.backward()
114
115        with torch.no_grad():
116            grads = [x.grad.clone(), y.grad.clone()]
117            sgd_nesterov_step(params, grads, velocities, lr, momentum)
118            x.grad.zero_()
119            y.grad.zero_()
120
121    print(f"Nesterov Final: x={x.item():.4f}, y={y.item():.4f}")
122
123
124# =============================================================================
125# 2. Adam from Scratch
126# =============================================================================
127
class ManualAdam:
    """Adam optimizer implemented manually.

    Keeps exponential moving averages of the gradient (first moment) and
    squared gradient (second moment) per parameter, with bias correction
    for the zero initialization — mirroring torch.optim.Adam.
    """

    def __init__(self, params: List[torch.Tensor], lr: float = 0.001,
                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
        self.params = list(params)
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.step_count = 0
        # One moment buffer per parameter, zero-initialized.
        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]

    def step(self):
        """Perform one optimization step."""
        self.step_count += 1
        t = self.step_count
        b1, b2 = self.beta1, self.beta2

        with torch.no_grad():
            for i, param in enumerate(self.params):
                grad = param.grad
                if grad is None:
                    continue

                # Exponential moving averages of grad and grad^2.
                self.m[i] = b1 * self.m[i] + (1 - b1) * grad
                self.v[i] = b2 * self.v[i] + (1 - b2) * grad**2

                # Bias-corrected estimates (counteract zero init).
                m_hat = self.m[i] / (1 - b1**t)
                v_hat = self.v[i] / (1 - b2**t)

                # Parameter update.
                param -= self.lr * m_hat / (torch.sqrt(v_hat) + self.eps)

    def zero_grad(self):
        """Zero out gradients."""
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()
172
173
def compare_adam_implementations():
    """Compare manual Adam with PyTorch's Adam."""
    print_section("2. Adam from Scratch")

    # Simple quadratic bowl f(x, y) = x^2 + y^2, minimum at the origin.
    x0, y0 = 5.0, 3.0
    n_steps = 100
    lr = 0.1

    # Manual Adam
    x1 = torch.tensor([x0], requires_grad=True)
    y1 = torch.tensor([y0], requires_grad=True)
    manual_adam = ManualAdam([x1, y1], lr=lr)

    # PyTorch Adam
    x2 = torch.tensor([x0], requires_grad=True)
    y2 = torch.tensor([y0], requires_grad=True)
    torch_adam = torch.optim.Adam([x2, y2], lr=lr)

    print(f"Initial: x={x0:.4f}, y={y0:.4f}")

    def advance(optimizer, px, py):
        # One full backward/step/zero cycle on the quadratic.
        loss = px**2 + py**2
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    for step in range(n_steps):
        advance(manual_adam, x1, y1)
        advance(torch_adam, x2, y2)

        if step % 20 == 0:
            print(f"Step {step:3d} - Manual: ({x1.item():.4f}, {y1.item():.4f}), "
                  f"PyTorch: ({x2.item():.4f}, {y2.item():.4f})")

    print(f"\nFinal Manual: ({x1.item():.6f}, {y1.item():.6f})")
    print(f"Final PyTorch: ({x2.item():.6f}, {y2.item():.6f})")
    print(f"Difference: ({abs(x1.item()-x2.item()):.2e}, {abs(y1.item()-y2.item()):.2e})")
215
216
217# =============================================================================
218# 3. Optimizer Comparison on Toy Problem
219# =============================================================================
220
class ToyMLP(nn.Module):
    """Small MLP for classification: two hidden ReLU layers."""

    def __init__(self, input_dim: int = 2, hidden_dim: int = 32, output_dim: int = 2):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
236
237
def generate_spiral_data(n_samples: int = 1000) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate a noisy two-class spiral dataset.

    Args:
        n_samples: Requested number of points. Classes are balanced, so the
            actual count is 2 * (n_samples // 2) (n_samples - 1 when odd).

    Returns:
        (X, y): shuffled features of shape (n_total, 2) and matching
        long-dtype labels in {0, 1}.
    """
    n_per_class = n_samples // 2
    n_total = 2 * n_per_class

    theta = torch.linspace(0, 4 * np.pi, n_per_class)
    r = torch.linspace(0.5, 1.0, n_per_class)

    # Class 0
    x0 = r * torch.cos(theta)
    y0 = r * torch.sin(theta)
    class0 = torch.stack([x0, y0], dim=1)

    # Class 1 (same spiral rotated by pi)
    x1 = r * torch.cos(theta + np.pi)
    y1 = r * torch.sin(theta + np.pi)
    class1 = torch.stack([x1, y1], dim=1)

    X = torch.cat([class0, class1], dim=0)
    y = torch.cat([torch.zeros(n_per_class, dtype=torch.long),
                   torch.ones(n_per_class, dtype=torch.long)])

    # Add noise
    X += 0.1 * torch.randn_like(X)

    # Shuffle. Permute n_total (not n_samples) indices: for odd n_samples,
    # randperm(n_samples) used to yield an index past the end of X.
    perm = torch.randperm(n_total)
    return X[perm], y[perm]
265
266
def train_with_optimizer(model: nn.Module, optimizer, X: torch.Tensor,
                        y: torch.Tensor, n_epochs: int = 50) -> List[float]:
    """Train model with full-batch cross-entropy; return per-epoch losses."""
    criterion = nn.CrossEntropyLoss()
    history: List[float] = []

    for _ in range(n_epochs):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        history.append(loss.item())

    return history
282
283
def compare_optimizers():
    """Compare SGD, Adam, AdamW on toy problem."""
    print_section("3. Optimizer Comparison on Toy Problem")

    X, y = generate_spiral_data(1000)
    print(f"Generated spiral dataset: {X.shape}, {y.shape}")

    n_epochs = 100
    lr = 0.01

    # Train one fresh ToyMLP per optimizer, in a fixed order.
    configs = [
        ('SGD', lambda ps: torch.optim.SGD(ps, lr=lr)),
        ('Adam', lambda ps: torch.optim.Adam(ps, lr=lr)),
        ('AdamW', lambda ps: torch.optim.AdamW(ps, lr=lr, weight_decay=0.01)),
    ]
    results = {}
    for label, make_optimizer in configs:
        model = ToyMLP()
        optimizer = make_optimizer(model.parameters())
        losses = train_with_optimizer(model, optimizer, X, y, n_epochs)
        results[label] = (model, losses)

    print(f"\nFinal losses after {n_epochs} epochs:")
    print(f"SGD:   {results['SGD'][1][-1]:.6f}")
    print(f"Adam:  {results['Adam'][1][-1]:.6f}")
    print(f"AdamW: {results['AdamW'][1][-1]:.6f}")

    # Full-dataset accuracy for each trained model.
    with torch.no_grad():
        accs = {label: (model(X).argmax(dim=1) == y).float().mean()
                for label, (model, _) in results.items()}

    print(f"\nAccuracies:")
    print(f"SGD:   {accs['SGD']:.4f}")
    print(f"Adam:  {accs['Adam']:.4f}")
    print(f"AdamW: {accs['AdamW']:.4f}")
324
325
326# =============================================================================
327# 4. Weight Decay vs L2 Regularization
328# =============================================================================
329
def demonstrate_weight_decay_vs_l2():
    """Show difference between weight decay and L2 regularization.

    Trains two identical linear regressions: one with Adam plus an explicit
    L2 penalty added to the loss, one with AdamW's decoupled weight decay,
    then compares the resulting weight norms.
    """
    print_section("4. Weight Decay vs L2 Regularization")

    # Simple linear model
    # Synthetic regression target: y = X @ w_true + small Gaussian noise.
    torch.manual_seed(42)
    X = torch.randn(100, 10)
    y = X @ torch.randn(10) + 0.1 * torch.randn(100)

    n_steps = 100
    lr = 0.01
    wd = 0.1  # regularization strength, shared by both setups

    # Adam with L2 regularization (add to loss)
    model_l2 = nn.Linear(10, 1)
    optimizer_l2 = torch.optim.Adam(model_l2.parameters(), lr=lr)

    for _ in range(n_steps):
        optimizer_l2.zero_grad()
        pred = model_l2(X).squeeze()
        loss = F.mse_loss(pred, y)

        # Add L2 penalty to loss. The 0.5 factor makes the penalty's
        # gradient exactly wd * p, matching AdamW's decay magnitude below.
        # Note: this penalizes the bias too, and the penalty gradient gets
        # rescaled by Adam's adaptive per-parameter step sizes.
        l2_reg = sum(p.pow(2).sum() for p in model_l2.parameters())
        loss = loss + 0.5 * wd * l2_reg

        loss.backward()
        optimizer_l2.step()

    # AdamW with decoupled weight decay: decay is applied directly to the
    # parameters, independent of the adaptive gradient scaling.
    model_adamw = nn.Linear(10, 1)
    optimizer_adamw = torch.optim.AdamW(model_adamw.parameters(), lr=lr, weight_decay=wd)

    for _ in range(n_steps):
        optimizer_adamw.zero_grad()
        pred = model_adamw(X).squeeze()
        loss = F.mse_loss(pred, y)
        loss.backward()
        optimizer_adamw.step()

    # Compare weight norms
    l2_weight_norm = model_l2.weight.norm().item()
    adamw_weight_norm = model_adamw.weight.norm().item()

    print(f"Weight norm with Adam+L2: {l2_weight_norm:.4f}")
    print(f"Weight norm with AdamW:   {adamw_weight_norm:.4f}")
    print(f"\nAdamW typically produces smaller weights due to decoupled decay.")
    print("L2 regularization interacts with adaptive learning rates, weight decay doesn't.")
378
379
380# =============================================================================
381# 5. Learning Rate Schedulers
382# =============================================================================
383
def demonstrate_schedulers():
    """Show different LR schedulers.

    Each scheduler is driven for 100 steps and the current LR is printed
    periodically. `optimizer.step()` is called before each
    `scheduler.step()` — a no-op here since no gradients exist — to follow
    PyTorch's documented call order and avoid the
    "lr_scheduler.step() before optimizer.step()" UserWarning.
    """
    print_section("5. Learning Rate Schedulers")

    model = nn.Linear(10, 1)

    def run_schedule(optimizer, scheduler, n_steps=100, extra_early=False):
        # Print the current LR, then advance optimizer and scheduler in order.
        for step in range(n_steps):
            if step % 20 == 0 or (extra_early and step < 15):
                print(f"Step {step:3d}: LR = {optimizer.param_groups[0]['lr']:.6f}")
            optimizer.step()  # no gradients -> no-op, but satisfies the API contract
            scheduler.step()

    # StepLR
    print("\n--- StepLR (decay by 0.5 every 30 steps) ---")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    run_schedule(optimizer,
                 torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5))

    # CosineAnnealingLR
    print("\n--- CosineAnnealingLR (T_max=50) ---")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    run_schedule(optimizer,
                 torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50))

    # Linear warmup + cosine, expressed as a LambdaLR multiplier.
    print("\n--- Linear Warmup + Cosine ---")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    warmup_steps = 10
    total_steps = 100

    def lr_lambda(step):
        # Linear ramp 0 -> 1 over warmup, then cosine decay 1 -> 0.
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    run_schedule(optimizer,
                 torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda),
                 n_steps=total_steps, extra_early=True)

    # OneCycleLR
    print("\n--- OneCycleLR (max_lr=0.1, total_steps=100) ---")
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.1, total_steps=100
    )
    run_schedule(optimizer, scheduler)
441
442
443# =============================================================================
444# 6. Learning Rate Finder
445# =============================================================================
446
def lr_finder(model: nn.Module, X: torch.Tensor, y: torch.Tensor,
              start_lr: float = 1e-7, end_lr: float = 1.0, num_steps: int = 100):
    """Simple LR range test: geometrically sweep the LR, record the loss."""
    print_section("6. Learning Rate Finder")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=start_lr)

    # Multiplier that takes start_lr to end_lr over num_steps steps.
    growth = (end_lr / start_lr) ** (1 / num_steps)
    lrs: List[float] = []
    losses: List[float] = []
    best_loss = float('inf')

    for step in range(num_steps):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        current = loss.item()

        lrs.append(optimizer.param_groups[0]['lr'])
        losses.append(current)

        best_loss = min(best_loss, current)

        # Abort once the loss blows up well past the best seen so far.
        if current > 4 * best_loss and step > 10:
            print(f"Stopping early at step {step} (loss exploded)")
            break

        loss.backward()
        optimizer.step()

        # Geometric LR increase for the next step.
        for group in optimizer.param_groups:
            group['lr'] *= growth

    # Moving-average smoothing of the recorded losses.
    half_window = 5
    smoothed = []
    for i in range(len(losses)):
        lo = max(0, i - half_window)
        hi = min(len(losses), i + half_window + 1)
        smoothed.append(sum(losses[lo:hi]) / (hi - lo))

    min_loss_idx = smoothed.index(min(smoothed))
    suggested_lr = lrs[min_loss_idx]

    print(f"\nTested LR range: {start_lr:.2e} to {lrs[-1]:.2e}")
    print(f"Best loss: {best_loss:.4f}")
    print(f"Suggested LR (at min loss): {suggested_lr:.2e}")
    print(f"Suggested LR (1/10 of that): {suggested_lr/10:.2e}")

    # Show some data points
    print("\nSample LR vs Loss:")
    stride = max(1, len(lrs) // 10)
    for i in range(0, len(lrs), stride):
        print(f"  LR={lrs[i]:.2e}, Loss={losses[i]:.4f}")
504
505
def run_lr_finder():
    """Run LR finder on toy problem."""
    features, labels = generate_spiral_data(1000)
    lr_finder(ToyMLP(), features, labels,
              start_lr=1e-6, end_lr=1.0, num_steps=100)
511
512
513# =============================================================================
514# 7. Per-Parameter Group Learning Rates
515# =============================================================================
516
def demonstrate_param_groups():
    """Show different LRs for different layers."""
    print_section("7. Per-Parameter Group Learning Rates")

    class BackboneHead(nn.Module):
        # Toy stand-in for a pretrained backbone + fresh classifier head.
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(
                nn.Linear(10, 32),
                nn.ReLU(),
                nn.Linear(32, 32)
            )
            self.head = nn.Linear(32, 2)

        def forward(self, x):
            features = self.backbone(x)
            return self.head(features)

    model = BackboneHead()

    # Small LR for the "pretrained" backbone, larger LR for the fresh head.
    groups = [
        {'params': model.backbone.parameters(), 'lr': 1e-4},
        {'params': model.head.parameters(), 'lr': 1e-3},
    ]
    optimizer = torch.optim.Adam(groups)

    print("Parameter groups:")
    for idx, group in enumerate(optimizer.param_groups):
        group_size = sum(p.numel() for p in group['params'])
        print(f"  Group {idx}: LR={group['lr']:.2e}, {group_size} parameters")

    # Train a few steps
    X, y = generate_spiral_data(100)
    criterion = nn.CrossEntropyLoss()

    for step in range(5):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        print(f"Step {step}: Loss={loss.item():.4f}")

    print("\nThis pattern is common for transfer learning:")
    print("  - Backbone (pretrained): small LR (fine-tune)")
    print("  - Head (randomly initialized): large LR (learn from scratch)")
562
563
564# =============================================================================
565# 8. Gradient Clipping
566# =============================================================================
567
class ExplodingGradientModel(nn.Module):
    """Model that can have exploding gradients.

    Uses deliberately large uniform weight initialization so backprop
    through the stacked layers produces very large gradients.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, 50),
            nn.ReLU(),
            nn.Linear(50, 50),
            nn.ReLU(),
            nn.Linear(50, 1)
        )

        # Overwrite the default init with large uniform weights.
        for module in self.layers.modules():
            if isinstance(module, nn.Linear):
                nn.init.uniform_(module.weight, -5, 5)

    def forward(self, x):
        return self.layers(x)
589
590
591def demonstrate_gradient_clipping():
592    """Show gradient clipping techniques."""
593    print_section("8. Gradient Clipping")
594
595    torch.manual_seed(42)
596    X = torch.randn(100, 10)
597    y = torch.randn(100, 1)
598
599    lr = 0.01
600    n_steps = 20
601
602    # Without clipping
603    print("--- Without gradient clipping ---")
604    model = ExplodingGradientModel()
605    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
606
607    for step in range(n_steps):
608        optimizer.zero_grad()
609        pred = model(X)
610        loss = F.mse_loss(pred, y)
611        loss.backward()
612
613        # Compute gradient norm
614        total_norm = 0.0
615        for p in model.parameters():
616            if p.grad is not None:
617                total_norm += p.grad.norm().item() ** 2
618        total_norm = total_norm ** 0.5
619
620        optimizer.step()
621
622        if step % 5 == 0:
623            print(f"Step {step:2d}: Loss={loss.item():.4f}, Grad norm={total_norm:.4f}")
624
625    # With clip_grad_norm_
626    print("\n--- With clip_grad_norm_ (max_norm=1.0) ---")
627    model = ExplodingGradientModel()
628    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
629    max_norm = 1.0
630
631    for step in range(n_steps):
632        optimizer.zero_grad()
633        pred = model(X)
634        loss = F.mse_loss(pred, y)
635        loss.backward()
636
637        # Clip gradients
638        total_norm_before = sum(
639            p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None
640        )**0.5
641        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
642        total_norm_after = sum(
643            p.grad.norm().item()**2 for p in model.parameters() if p.grad is not None
644        )**0.5
645
646        optimizer.step()
647
648        if step % 5 == 0:
649            print(f"Step {step:2d}: Loss={loss.item():.4f}, "
650                  f"Grad before={total_norm_before:.4f}, after={total_norm_after:.4f}")
651
652    # With clip_grad_value_
653    print("\n--- With clip_grad_value_ (clip_value=0.5) ---")
654    model = ExplodingGradientModel()
655    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
656    clip_value = 0.5
657
658    for step in range(n_steps):
659        optimizer.zero_grad()
660        pred = model(X)
661        loss = F.mse_loss(pred, y)
662        loss.backward()
663
664        # Check max gradient value
665        max_grad = max(
666            p.grad.abs().max().item() for p in model.parameters() if p.grad is not None
667        )
668
669        # Clip gradients by value
670        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
671
672        max_grad_after = max(
673            p.grad.abs().max().item() for p in model.parameters() if p.grad is not None
674        )
675
676        optimizer.step()
677
678        if step % 5 == 0:
679            print(f"Step {step:2d}: Loss={loss.item():.4f}, "
680                  f"Max grad before={max_grad:.4f}, after={max_grad_after:.4f}")
681
682    print("\nGradient clipping is essential for:")
683    print("  - RNNs/LSTMs (prevent exploding gradients)")
684    print("  - Training with high learning rates")
685    print("  - Reinforcement learning (PPO uses clip_grad_norm_)")
686
687
688# =============================================================================
689# Main
690# =============================================================================
691
def main():
    """Run every optimizer demonstration in order, then print a summary."""
    print("\n" + "="*70)
    print("  PyTorch Optimizers Demonstration")
    print("="*70)

    demos = [
        optimize_rosenbrock,
        compare_adam_implementations,
        compare_optimizers,
        demonstrate_weight_decay_vs_l2,
        demonstrate_schedulers,
        run_lr_finder,
        demonstrate_param_groups,
        demonstrate_gradient_clipping,
    ]
    for demo in demos:
        demo()

    print("\n" + "="*70)
    print("  Summary")
    print("="*70)
    print("""
Key takeaways:
1. SGD variants: vanilla, momentum, Nesterov (each improves convergence)
2. Adam: adaptive learning rates per parameter (first & second moments)
3. AdamW: decoupled weight decay (better than L2 regularization)
4. Schedulers: StepLR, Cosine, OneCycle, Warmup (critical for training)
5. LR Finder: automated way to find good learning rate
6. Param groups: different LRs for different layers (transfer learning)
7. Gradient clipping: prevent exploding gradients (RNNs, RL)

Practical tips:
- Start with Adam/AdamW for most tasks
- Use SGD+momentum for CNNs if you have time to tune
- Always use a scheduler (cosine or OneCycle work well)
- Clip gradients for RNNs (max_norm=1.0 is common)
- Use param groups for transfer learning (small LR for pretrained layers)
    """)
726
727
if __name__ == '__main__':
    # Script entry point: run all demonstrations.
    main()