# 27_normalization.py
"""
27. Normalization Layers Comparison

Demonstrates various normalization techniques:
- Batch Normalization (manual + nn.BatchNorm2d)
- Layer Normalization (manual + nn.LayerNorm)
- Group Normalization (nn.GroupNorm)
- Instance Normalization (nn.InstanceNorm2d)
- RMSNorm (manual implementation)
- Training vs Inference behavior
"""

 13import torch
 14import torch.nn as nn
 15import torch.nn.functional as F
 16import numpy as np
 17
 18# Set random seed for reproducibility
 19torch.manual_seed(42)
 20np.random.seed(42)
 21
 22
 23# ============================================================================
 24# 1. BatchNorm from Scratch
 25# ============================================================================
 26def manual_batch_norm_2d(x, gamma, beta, running_mean, running_var,
 27                         momentum=0.1, eps=1e-5, training=True):
 28    """
 29    Manual implementation of BatchNorm2d.
 30    x: (N, C, H, W)
 31    gamma, beta: (C,) - learnable parameters
 32    running_mean, running_var: (C,) - running statistics
 33    """
 34    if training:
 35        # Compute batch statistics over (N, H, W) dimensions
 36        batch_mean = x.mean(dim=(0, 2, 3), keepdim=False)  # (C,)
 37        batch_var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=False)  # (C,)
 38
 39        # Update running statistics
 40        running_mean.data = (1 - momentum) * running_mean + momentum * batch_mean
 41        running_var.data = (1 - momentum) * running_var + momentum * batch_var
 42
 43        # Normalize using batch statistics
 44        mean = batch_mean
 45        var = batch_var
 46    else:
 47        # Use running statistics during inference
 48        mean = running_mean
 49        var = running_var
 50
 51    # Reshape for broadcasting: (1, C, 1, 1)
 52    mean = mean.view(1, -1, 1, 1)
 53    var = var.view(1, -1, 1, 1)
 54    gamma = gamma.view(1, -1, 1, 1)
 55    beta = beta.view(1, -1, 1, 1)
 56
 57    # Normalize and scale
 58    x_norm = (x - mean) / torch.sqrt(var + eps)
 59    out = gamma * x_norm + beta
 60
 61    return out
 62
 63
 64def section1_batchnorm_from_scratch():
 65    print("\n" + "="*80)
 66    print("1. BatchNorm from Scratch")
 67    print("="*80)
 68
 69    # Create sample input
 70    N, C, H, W = 4, 3, 8, 8
 71    x = torch.randn(N, C, H, W)
 72
 73    # Manual BatchNorm
 74    gamma = torch.ones(C)
 75    beta = torch.zeros(C)
 76    running_mean = torch.zeros(C)
 77    running_var = torch.ones(C)
 78
 79    # Training mode
 80    manual_out_train = manual_batch_norm_2d(
 81        x, gamma, beta, running_mean.clone(), running_var.clone(), training=True
 82    )
 83
 84    # PyTorch BatchNorm
 85    bn = nn.BatchNorm2d(C, momentum=0.1, eps=1e-5)
 86    bn.weight.data = gamma.clone()
 87    bn.bias.data = beta.clone()
 88    bn.running_mean.data = torch.zeros(C)
 89    bn.running_var.data = torch.ones(C)
 90
 91    bn.train()
 92    pytorch_out_train = bn(x)
 93
 94    print(f"Input shape: {x.shape}")
 95    print(f"Manual output mean: {manual_out_train.mean():.6f}")
 96    print(f"PyTorch output mean: {pytorch_out_train.mean():.6f}")
 97    print(f"Max difference (training): {(manual_out_train - pytorch_out_train).abs().max():.8f}")
 98
 99    # Eval mode - show running statistics are used
100    bn.eval()
101    with torch.no_grad():
102        pytorch_out_eval = bn(x)
103
104    print(f"\nRunning mean after training: {bn.running_mean[:3]}")
105    print(f"Running var after training: {bn.running_var[:3]}")
106    print(f"Eval mode output mean: {pytorch_out_eval.mean():.6f}")
107    print("✓ Training vs Eval mode produces different outputs")
108
109
110# ============================================================================
111# 2. LayerNorm from Scratch
112# ============================================================================
113def manual_layer_norm(x, normalized_shape, gamma, beta, eps=1e-5):
114    """
115    Manual implementation of LayerNorm.
116    x: (N, C, H, W) or any shape
117    normalized_shape: dimensions to normalize over (from the end)
118    """
119    # Compute mean and var over the last len(normalized_shape) dimensions
120    dims_to_normalize = list(range(-len(normalized_shape), 0))
121
122    mean = x.mean(dim=dims_to_normalize, keepdim=True)
123    var = x.var(dim=dims_to_normalize, unbiased=False, keepdim=True)
124
125    # Normalize
126    x_norm = (x - mean) / torch.sqrt(var + eps)
127
128    # Scale and shift (gamma and beta should match normalized_shape)
129    out = gamma * x_norm + beta
130
131    return out
132
133
134def section2_layernorm_from_scratch():
135    print("\n" + "="*80)
136    print("2. LayerNorm from Scratch")
137    print("="*80)
138
139    # Create sample input (batch-independent normalization)
140    N, C, H, W = 4, 8, 16, 16
141    x = torch.randn(N, C, H, W)
142
143    # LayerNorm over (C, H, W)
144    normalized_shape = (C, H, W)
145    gamma = torch.ones(normalized_shape)
146    beta = torch.zeros(normalized_shape)
147
148    # Manual LayerNorm
149    manual_out = manual_layer_norm(x, normalized_shape, gamma, beta)
150
151    # PyTorch LayerNorm
152    ln = nn.LayerNorm(normalized_shape, eps=1e-5)
153    ln.weight.data = gamma.clone().flatten()
154    ln.bias.data = beta.clone().flatten()
155
156    pytorch_out = ln(x)
157
158    print(f"Input shape: {x.shape}")
159    print(f"Normalized shape: {normalized_shape}")
160    print(f"Manual output mean per sample: {manual_out.mean(dim=(1,2,3))}")
161    print(f"Manual output std per sample: {manual_out.std(dim=(1,2,3))}")
162    print(f"Max difference: {(manual_out - pytorch_out).abs().max():.8f}")
163
164    # Show batch independence
165    single_sample = x[0:1]
166    manual_single = manual_layer_norm(single_sample, normalized_shape, gamma, beta)
167    print(f"\nFirst sample from batch: {manual_out[0, 0, 0, :3]}")
168    print(f"Same sample processed alone: {manual_single[0, 0, 0, :3]}")
169    print("✓ LayerNorm is batch-independent")
170
171
172# ============================================================================
173# 3. RMSNorm from Scratch
174# ============================================================================
175class RMSNorm(nn.Module):
176    """
177    Root Mean Square Layer Normalization (as used in LLaMA).
178    Only normalizes by RMS, no mean centering.
179    """
180    def __init__(self, dim, eps=1e-6):
181        super().__init__()
182        self.eps = eps
183        self.weight = nn.Parameter(torch.ones(dim))
184
185    def forward(self, x):
186        # x: (batch, seq_len, dim) or (batch, dim)
187        # Compute RMS over the last dimension
188        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
189        x_norm = x / rms
190        return self.weight * x_norm
191
192
193def section3_rmsnorm_from_scratch():
194    print("\n" + "="*80)
195    print("3. RMSNorm from Scratch")
196    print("="*80)
197
198    batch_size, seq_len, dim = 2, 4, 8
199    x = torch.randn(batch_size, seq_len, dim)
200
201    # RMSNorm
202    rms_norm = RMSNorm(dim)
203    rms_out = rms_norm(x)
204
205    # LayerNorm for comparison
206    layer_norm = nn.LayerNorm(dim)
207    ln_out = layer_norm(x)
208
209    print(f"Input shape: {x.shape}")
210    print(f"Input sample: {x[0, 0, :4]}")
211    print(f"\nRMSNorm output: {rms_out[0, 0, :4]}")
212    print(f"RMSNorm output RMS: {rms_out[0, 0].pow(2).mean().sqrt():.6f}")
213    print(f"\nLayerNorm output: {ln_out[0, 0, :4]}")
214    print(f"LayerNorm output mean: {ln_out[0, 0].mean():.6f}")
215    print(f"LayerNorm output std: {ln_out[0, 0].std():.6f}")
216
217    print("\n✓ RMSNorm: only normalizes scale (no mean centering)")
218    print("✓ LayerNorm: normalizes both mean and variance")
219
220
221# ============================================================================
222# 4. Normalization Dimension Comparison
223# ============================================================================
224def section4_normalization_dimensions():
225    print("\n" + "="*80)
226    print("4. Normalization Dimension Comparison")
227    print("="*80)
228
229    N, C, H, W = 2, 4, 8, 8
230    x = torch.randn(N, C, H, W) * 10 + 5  # Non-zero mean, large variance
231
232    print(f"Input shape: (N={N}, C={C}, H={H}, W={W})")
233    print(f"Input statistics:")
234    print(f"  Global mean: {x.mean():.4f}, std: {x.std():.4f}")
235    print(f"  Per-channel mean: {x.mean(dim=(0, 2, 3))}")
236
237    # BatchNorm2d: normalizes over (N, H, W) for each C
238    bn = nn.BatchNorm2d(C)
239    bn_out = bn(x)
240    print(f"\nBatchNorm2d (normalize over N,H,W for each C):")
241    print(f"  Output mean: {bn_out.mean():.4f}, std: {bn_out.std():.4f}")
242    print(f"  Per-channel mean: {bn_out.mean(dim=(0, 2, 3))}")
243
244    # LayerNorm: normalizes over (C, H, W) for each N
245    ln = nn.LayerNorm((C, H, W))
246    ln_out = ln(x)
247    print(f"\nLayerNorm (normalize over C,H,W for each N):")
248    print(f"  Output mean: {ln_out.mean():.4f}, std: {ln_out.std():.4f}")
249    print(f"  Per-sample mean: {ln_out.mean(dim=(1, 2, 3))}")
250
251    # GroupNorm: normalizes over (H, W) and groups of C for each N
252    gn = nn.GroupNorm(num_groups=2, num_channels=C)
253    gn_out = gn(x)
254    print(f"\nGroupNorm (2 groups, normalize over H,W for each group in each N):")
255    print(f"  Output mean: {gn_out.mean():.4f}, std: {gn_out.std():.4f}")
256
257    # InstanceNorm2d: normalizes over (H, W) for each C in each N
258    in_norm = nn.InstanceNorm2d(C)
259    in_out = in_norm(x)
260    print(f"\nInstanceNorm2d (normalize over H,W for each C in each N):")
261    print(f"  Output mean: {in_out.mean():.4f}, std: {in_out.std():.4f}")
262
263    print("\n✓ Different normalizations operate over different dimensions")
264
265
266# ============================================================================
267# 5. GroupNorm vs BatchNorm with Small Batch
268# ============================================================================
269def section5_groupnorm_vs_batchnorm():
270    print("\n" + "="*80)
271    print("5. GroupNorm vs BatchNorm with Small Batch Size")
272    print("="*80)
273
274    C, H, W = 32, 16, 16
275
276    # Create data with batch_size=1 (BatchNorm fails here)
277    x_small = torch.randn(1, C, H, W)
278
279    print(f"Input shape: {x_small.shape} (batch_size=1)")
280    print(f"Input mean: {x_small.mean():.4f}, std: {x_small.std():.4f}")
281
282    # BatchNorm with batch_size=1
283    bn = nn.BatchNorm2d(C)
284    bn.train()
285    bn_out = bn(x_small)
286    print(f"\nBatchNorm2d (training mode, batch_size=1):")
287    print(f"  Output mean: {bn_out.mean():.4f}, std: {bn_out.std():.4f}")
288    print(f"  ⚠️  With batch_size=1, variance=0, normalization unstable")
289
290    # GroupNorm works fine
291    gn = nn.GroupNorm(num_groups=8, num_channels=C)
292    gn_out = gn(x_small)
293    print(f"\nGroupNorm (8 groups):")
294    print(f"  Output mean: {gn_out.mean():.4f}, std: {gn_out.std():.4f}")
295    print(f"  ✓ GroupNorm stable with batch_size=1")
296
297    # Larger batch for comparison
298    x_large = torch.randn(16, C, H, W)
299    bn_large = bn(x_large)
300    print(f"\nBatchNorm2d with batch_size=16:")
301    print(f"  Output mean: {bn_large.mean():.4f}, std: {bn_large.std():.4f}")
302    print(f"  ✓ BatchNorm stable with larger batch")
303
304
305# ============================================================================
306# 6. Pre-Norm vs Post-Norm Transformer Block
307# ============================================================================
308class PostNormTransformerBlock(nn.Module):
309    """Traditional: Attention/FFN -> Add -> Norm"""
310    def __init__(self, dim, num_heads=4):
311        super().__init__()
312        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
313        self.norm1 = nn.LayerNorm(dim)
314        self.ffn = nn.Sequential(
315            nn.Linear(dim, dim * 4),
316            nn.GELU(),
317            nn.Linear(dim * 4, dim)
318        )
319        self.norm2 = nn.LayerNorm(dim)
320
321    def forward(self, x):
322        # Post-norm: x + Sublayer(x) -> Norm
323        attn_out, _ = self.attn(x, x, x)
324        x = self.norm1(x + attn_out)
325
326        ffn_out = self.ffn(x)
327        x = self.norm2(x + ffn_out)
328        return x
329
330
331class PreNormTransformerBlock(nn.Module):
332    """Modern: Norm -> Attention/FFN -> Add"""
333    def __init__(self, dim, num_heads=4):
334        super().__init__()
335        self.norm1 = nn.LayerNorm(dim)
336        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
337        self.norm2 = nn.LayerNorm(dim)
338        self.ffn = nn.Sequential(
339            nn.Linear(dim, dim * 4),
340            nn.GELU(),
341            nn.Linear(dim * 4, dim)
342        )
343
344    def forward(self, x):
345        # Pre-norm: x + Sublayer(Norm(x))
346        attn_out, _ = self.attn(self.norm1(x), self.norm1(x), self.norm1(x))
347        x = x + attn_out
348
349        ffn_out = self.ffn(self.norm2(x))
350        x = x + ffn_out
351        return x
352
353
354def section6_prenorm_vs_postnorm():
355    print("\n" + "="*80)
356    print("6. Pre-Norm vs Post-Norm Transformer Block")
357    print("="*80)
358
359    batch_size, seq_len, dim = 2, 10, 64
360    x = torch.randn(batch_size, seq_len, dim)
361
362    post_norm_block = PostNormTransformerBlock(dim)
363    pre_norm_block = PreNormTransformerBlock(dim)
364
365    # Forward pass
366    post_out = post_norm_block(x)
367    pre_out = pre_norm_block(x)
368
369    print(f"Input shape: {x.shape}")
370    print(f"Input norm: {x.norm():.4f}")
371    print(f"\nPost-Norm output norm: {post_out.norm():.4f}")
372    print(f"Pre-Norm output norm: {pre_out.norm():.4f}")
373
374    # Check gradient flow (dummy loss)
375    loss_post = post_out.sum()
376    loss_post.backward()
377    post_grad_norm = sum(p.grad.norm().item() for p in post_norm_block.parameters() if p.grad is not None)
378
379    # Reset and compute pre-norm gradients
380    pre_norm_block.zero_grad()
381    x_pre = x.clone().detach().requires_grad_(True)
382    pre_out = pre_norm_block(x_pre)
383    loss_pre = pre_out.sum()
384    loss_pre.backward()
385    pre_grad_norm = sum(p.grad.norm().item() for p in pre_norm_block.parameters() if p.grad is not None)
386
387    print(f"\nPost-Norm gradient norm: {post_grad_norm:.4f}")
388    print(f"Pre-Norm gradient norm: {pre_grad_norm:.4f}")
389    print("\n✓ Pre-Norm: Better gradient flow, easier training")
390    print("✓ Post-Norm: Traditional, may need warmup")
391
392
393# ============================================================================
394# 7. Training Experiment
395# ============================================================================
396class SimpleCNN(nn.Module):
397    def __init__(self, norm_type='batch'):
398        super().__init__()
399        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
400
401        if norm_type == 'batch':
402            self.norm1 = nn.BatchNorm2d(16)
403        elif norm_type == 'group':
404            self.norm1 = nn.GroupNorm(4, 16)
405        else:
406            self.norm1 = nn.Identity()
407
408        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
409
410        if norm_type == 'batch':
411            self.norm2 = nn.BatchNorm2d(32)
412        elif norm_type == 'group':
413            self.norm2 = nn.GroupNorm(8, 32)
414        else:
415            self.norm2 = nn.Identity()
416
417        self.fc = nn.Linear(32 * 7 * 7, 10)
418
419    def forward(self, x):
420        x = F.relu(self.norm1(self.conv1(x)))
421        x = F.max_pool2d(x, 2)
422        x = F.relu(self.norm2(self.conv2(x)))
423        x = F.max_pool2d(x, 2)
424        x = x.view(x.size(0), -1)
425        x = self.fc(x)
426        return x
427
428
429def train_model(model, epochs=5, batch_size=32):
430    """Train on synthetic data"""
431    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
432    criterion = nn.CrossEntropyLoss()
433
434    losses = []
435    for epoch in range(epochs):
436        # Generate synthetic data
437        x = torch.randn(batch_size, 1, 28, 28)
438        y = torch.randint(0, 10, (batch_size,))
439
440        model.train()
441        optimizer.zero_grad()
442        output = model(x)
443        loss = criterion(output, y)
444        loss.backward()
445        optimizer.step()
446
447        losses.append(loss.item())
448
449    return losses
450
451
452def section7_training_experiment():
453    print("\n" + "="*80)
454    print("7. Training Experiment: Different Normalizations")
455    print("="*80)
456
457    epochs = 10
458
459    # Train with BatchNorm
460    model_bn = SimpleCNN(norm_type='batch')
461    losses_bn = train_model(model_bn, epochs=epochs)
462
463    # Train with GroupNorm
464    model_gn = SimpleCNN(norm_type='group')
465    losses_gn = train_model(model_gn, epochs=epochs)
466
467    # Train without normalization
468    model_none = SimpleCNN(norm_type='none')
469    losses_none = train_model(model_none, epochs=epochs)
470
471    print(f"Training for {epochs} epochs on synthetic data:")
472    print(f"\nBatchNorm losses: {[f'{l:.4f}' for l in losses_bn[:5]]} ... {losses_bn[-1]:.4f}")
473    print(f"GroupNorm losses: {[f'{l:.4f}' for l in losses_gn[:5]]} ... {losses_gn[-1]:.4f}")
474    print(f"No Norm losses:   {[f'{l:.4f}' for l in losses_none[:5]]} ... {losses_none[-1]:.4f}")
475
476    print(f"\nFinal loss comparison:")
477    print(f"  BatchNorm:  {losses_bn[-1]:.4f}")
478    print(f"  GroupNorm:  {losses_gn[-1]:.4f}")
479    print(f"  No Norm:    {losses_none[-1]:.4f}")
480    print("\n✓ Normalization helps convergence")
481
482
483# ============================================================================
484# 8. Common Pitfalls
485# ============================================================================
486def section8_common_pitfalls():
487    print("\n" + "="*80)
488    print("8. Common Pitfalls")
489    print("="*80)
490
491    # Pitfall 1: Forgetting model.eval() for BatchNorm
492    print("\n--- Pitfall 1: Forgetting model.eval() for BatchNorm ---")
493    x = torch.randn(4, 3, 8, 8)
494    bn = nn.BatchNorm2d(3)
495
496    # Train mode (uses batch statistics)
497    bn.train()
498    out_train_1 = bn(x)
499    out_train_2 = bn(x)  # Different output each time!
500
501    # Eval mode (uses running statistics)
502    bn.eval()
503    with torch.no_grad():
504        out_eval_1 = bn(x)
505        out_eval_2 = bn(x)  # Same output
506
507    print(f"Train mode, pass 1 mean: {out_train_1.mean():.6f}")
508    print(f"Train mode, pass 2 mean: {out_train_2.mean():.6f}")
509    print(f"Difference: {(out_train_1 - out_train_2).abs().max():.6f}")
510    print(f"\nEval mode, pass 1 mean: {out_eval_1.mean():.6f}")
511    print(f"Eval mode, pass 2 mean: {out_eval_2.mean():.6f}")
512    print(f"Difference: {(out_eval_1 - out_eval_2).abs().max():.6f}")
513    print("⚠️  Always use model.eval() during inference!")
514
515    # Pitfall 2: Wrong dimension for LayerNorm
516    print("\n--- Pitfall 2: Wrong normalized_shape for LayerNorm ---")
517    x = torch.randn(2, 10, 64)  # (batch, seq, dim)
518
519    # Correct: normalize over last dimension
520    ln_correct = nn.LayerNorm(64)
521    out_correct = ln_correct(x)
522    print(f"Input shape: {x.shape}")
523    print(f"Correct LayerNorm(64): output mean per sample = {out_correct.mean(dim=(1,2))}")
524
525    # Wrong: normalizing over wrong dimensions
526    try:
527        ln_wrong = nn.LayerNorm((10, 64))  # This normalizes over (seq, dim)
528        out_wrong = ln_wrong(x)
529        print(f"Wrong LayerNorm((10,64)): output mean per sample = {out_wrong.mean(dim=(1,2))}")
530        print("⚠️  Make sure normalized_shape matches your intention!")
531    except Exception as e:
532        print(f"Error: {e}")
533
534    # Pitfall 3: Frozen BatchNorm
535    print("\n--- Pitfall 3: Frozen BatchNorm (for fine-tuning) ---")
536    bn = nn.BatchNorm2d(3)
537    x = torch.randn(4, 3, 8, 8)
538
539    # Normal training: running stats update
540    bn.train()
541    initial_mean = bn.running_mean.clone()
542    _ = bn(x)
543    updated_mean = bn.running_mean
544    print(f"Initial running mean: {initial_mean[:3]}")
545    print(f"After forward (train mode): {updated_mean[:3]}")
546    print(f"Difference: {(updated_mean - initial_mean).abs().sum():.6f}")
547
548    # Frozen BatchNorm: eval mode, no gradient
549    bn.eval()
550    for param in bn.parameters():
551        param.requires_grad = False
552
553    frozen_mean = bn.running_mean.clone()
554    _ = bn(x)
555    after_frozen = bn.running_mean
556    print(f"\nFrozen BN running mean: {frozen_mean[:3]}")
557    print(f"After forward (frozen): {after_frozen[:3]}")
558    print(f"Difference: {(after_frozen - frozen_mean).abs().sum():.6f}")
559    print("✓ Frozen BatchNorm: use for fine-tuning with different batch sizes")
560
561
562# ============================================================================
563# Main
564# ============================================================================
565def main():
566    print("\n" + "="*80)
567    print("PyTorch Normalization Layers Comprehensive Guide")
568    print("="*80)
569
570    section1_batchnorm_from_scratch()
571    section2_layernorm_from_scratch()
572    section3_rmsnorm_from_scratch()
573    section4_normalization_dimensions()
574    section5_groupnorm_vs_batchnorm()
575    section6_prenorm_vs_postnorm()
576    section7_training_experiment()
577    section8_common_pitfalls()
578
579    print("\n" + "="*80)
580    print("Summary")
581    print("="*80)
582    print("""
583Normalization Comparison:
584- BatchNorm: Normalizes over batch dimension, maintains running stats
585  → Best for large batches, computer vision
586  → Requires model.eval() during inference
587
588- LayerNorm: Normalizes over feature dimensions, batch-independent
589  → Best for NLP, transformers, small/variable batches
590  → No train/eval mode difference
591
592- GroupNorm: Normalizes over channel groups, batch-independent
593  → Best for small batches, instance segmentation
594  → More stable than BatchNorm with batch_size=1
595
596- InstanceNorm: Normalizes over spatial dimensions per channel
597  → Best for style transfer, GANs
598
599- RMSNorm: Simplified normalization without mean centering
600  → Used in LLaMA, faster than LayerNorm
601  → Only normalizes scale
602
603Pre-Norm vs Post-Norm:
604- Pre-Norm: Better gradient flow, easier training, modern default
605- Post-Norm: Traditional, may need learning rate warmup
606
607Common Pitfalls:
6081. Forgetting model.eval() for BatchNorm during inference
6092. Wrong normalized_shape for LayerNorm
6103. Frozen BatchNorm behavior when fine-tuning
611    """)
612    print("="*80)
613
614
615if __name__ == "__main__":
616    main()