1"""
227. Normalization Layers Comparison
3
4Demonstrates various normalization techniques:
5- Batch Normalization (manual + nn.BatchNorm2d)
6- Layer Normalization (manual + nn.LayerNorm)
7- Group Normalization (nn.GroupNorm)
8- Instance Normalization (nn.InstanceNorm2d)
9- RMSNorm (manual implementation)
10- Training vs Inference behavior
11"""
12
13import torch
14import torch.nn as nn
15import torch.nn.functional as F
16import numpy as np
17
18# Set random seed for reproducibility
19torch.manual_seed(42)
20np.random.seed(42)
21
22
23# ============================================================================
24# 1. BatchNorm from Scratch
25# ============================================================================
def manual_batch_norm_2d(x, gamma, beta, running_mean, running_var,
                         momentum=0.1, eps=1e-5, training=True):
    """Manual implementation of BatchNorm2d that matches nn.BatchNorm2d.

    Args:
        x: input activations of shape (N, C, H, W).
        gamma, beta: learnable scale/shift parameters, shape (C,).
        running_mean, running_var: running statistics, shape (C,);
            updated IN PLACE when training=True.
        momentum: fraction of the new batch statistic mixed into the
            running statistic (PyTorch convention, not classic momentum).
        eps: numerical-stability constant added to the variance.
        training: if True, normalize with batch statistics and update the
            running statistics; if False, normalize with running statistics.

    Returns:
        Normalized, affinely transformed tensor of shape (N, C, H, W).
    """
    if training:
        # Per-channel statistics over the batch and spatial dims (N, H, W).
        batch_mean = x.mean(dim=(0, 2, 3))                # (C,)
        batch_var = x.var(dim=(0, 2, 3), unbiased=False)  # (C,)

        # Bug fix: nn.BatchNorm2d normalizes with the *biased* variance but
        # accumulates the *unbiased* (Bessel-corrected) variance into
        # running_var. The previous code accumulated the biased variance,
        # so eval-mode outputs diverged from PyTorch's.
        n = x.numel() / x.size(1)  # elements per channel = N * H * W
        unbiased_var = batch_var * n / (n - 1)

        running_mean.data = (1 - momentum) * running_mean + momentum * batch_mean
        running_var.data = (1 - momentum) * running_var + momentum * unbiased_var

        # Normalize with the current batch statistics.
        mean = batch_mean
        var = batch_var
    else:
        # Inference: normalize with the accumulated running statistics.
        mean = running_mean
        var = running_var

    # Reshape per-channel vectors to (1, C, 1, 1) for broadcasting.
    mean = mean.view(1, -1, 1, 1)
    var = var.view(1, -1, 1, 1)
    gamma = gamma.view(1, -1, 1, 1)
    beta = beta.view(1, -1, 1, 1)

    # Normalize, then apply the learnable affine transform.
    x_norm = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_norm + beta
62
63
def section1_batchnorm_from_scratch():
    """Check the hand-rolled BatchNorm against nn.BatchNorm2d in train/eval modes."""
    bar = "=" * 80
    print("\n" + bar)
    print("1. BatchNorm from Scratch")
    print(bar)

    # Sample activations: 4 images, 3 channels, 8x8 spatial.
    num, ch, height, width = 4, 3, 8, 8
    x = torch.randn(num, ch, height, width)

    # Affine parameters and fresh running statistics for the manual version.
    gamma = torch.ones(ch)
    beta = torch.zeros(ch)
    running_mean = torch.zeros(ch)
    running_var = torch.ones(ch)

    # Manual BatchNorm in training mode (clones keep the originals pristine).
    manual_out_train = manual_batch_norm_2d(
        x, gamma, beta, running_mean.clone(), running_var.clone(), training=True
    )

    # Reference layer, initialized to the same parameters and statistics.
    bn = nn.BatchNorm2d(ch, momentum=0.1, eps=1e-5)
    bn.weight.data = gamma.clone()
    bn.bias.data = beta.clone()
    bn.running_mean.data = torch.zeros(ch)
    bn.running_var.data = torch.ones(ch)

    bn.train()
    pytorch_out_train = bn(x)

    print(f"Input shape: {x.shape}")
    print(f"Manual output mean: {manual_out_train.mean():.6f}")
    print(f"PyTorch output mean: {pytorch_out_train.mean():.6f}")
    print(f"Max difference (training): {(manual_out_train - pytorch_out_train).abs().max():.8f}")

    # In eval mode the layer switches to its accumulated running statistics.
    bn.eval()
    with torch.no_grad():
        pytorch_out_eval = bn(x)

    print(f"\nRunning mean after training: {bn.running_mean[:3]}")
    print(f"Running var after training: {bn.running_var[:3]}")
    print(f"Eval mode output mean: {pytorch_out_eval.mean():.6f}")
    print("✓ Training vs Eval mode produces different outputs")
108
109
110# ============================================================================
111# 2. LayerNorm from Scratch
112# ============================================================================
def manual_layer_norm(x, normalized_shape, gamma, beta, eps=1e-5):
    """Reference LayerNorm over the trailing dimensions of x.

    Args:
        x: input tensor of any shape, e.g. (N, C, H, W).
        normalized_shape: trailing dimensions to normalize over.
        gamma, beta: scale/shift parameters broadcastable to normalized_shape.
        eps: numerical-stability constant added to the variance.

    Returns:
        Tensor of the same shape as x, normalized per leading index.
    """
    # Statistics are taken over the last len(normalized_shape) axes.
    axes = tuple(range(-len(normalized_shape), 0))
    mu = x.mean(dim=axes, keepdim=True)
    sigma_sq = x.var(dim=axes, unbiased=False, keepdim=True)

    # Standardize, then apply the affine transform.
    standardized = (x - mu) / torch.sqrt(sigma_sq + eps)
    return gamma * standardized + beta
132
133
def section2_layernorm_from_scratch():
    """Validate manual_layer_norm against nn.LayerNorm and show batch independence.

    Fix: nn.LayerNorm's weight/bias must keep the shape `normalized_shape`
    (here (C, H, W)). The previous code flattened gamma/beta before
    assigning them, giving the layer 1-D parameters, which fails the
    weight-shape check inside F.layer_norm at forward time.
    """
    print("\n" + "="*80)
    print("2. LayerNorm from Scratch")
    print("="*80)

    # Create sample input (batch-independent normalization)
    N, C, H, W = 4, 8, 16, 16
    x = torch.randn(N, C, H, W)

    # LayerNorm over (C, H, W)
    normalized_shape = (C, H, W)
    gamma = torch.ones(normalized_shape)
    beta = torch.zeros(normalized_shape)

    # Manual LayerNorm
    manual_out = manual_layer_norm(x, normalized_shape, gamma, beta)

    # PyTorch LayerNorm — weight/bias keep shape (C, H, W), not flattened.
    ln = nn.LayerNorm(normalized_shape, eps=1e-5)
    ln.weight.data = gamma.clone()
    ln.bias.data = beta.clone()

    pytorch_out = ln(x)

    print(f"Input shape: {x.shape}")
    print(f"Normalized shape: {normalized_shape}")
    print(f"Manual output mean per sample: {manual_out.mean(dim=(1,2,3))}")
    print(f"Manual output std per sample: {manual_out.std(dim=(1,2,3))}")
    print(f"Max difference: {(manual_out - pytorch_out).abs().max():.8f}")

    # Show batch independence: one sample normalized alone gives the same
    # result as when it is normalized inside the full batch.
    single_sample = x[0:1]
    manual_single = manual_layer_norm(single_sample, normalized_shape, gamma, beta)
    print(f"\nFirst sample from batch: {manual_out[0, 0, 0, :3]}")
    print(f"Same sample processed alone: {manual_single[0, 0, 0, :3]}")
    print("✓ LayerNorm is batch-independent")
170
171
172# ============================================================================
173# 3. RMSNorm from Scratch
174# ============================================================================
class RMSNorm(nn.Module):
    """Root Mean Square normalization (LLaMA-style).

    Rescales the last dimension by its root-mean-square; unlike LayerNorm
    there is no mean centering, only a learned per-dimension gain.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x: (..., dim). Reciprocal RMS over the last dimension, with eps
        # inside the root for numerical stability.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
191
192
def section3_rmsnorm_from_scratch():
    """Contrast RMSNorm (scale-only) with LayerNorm (mean + variance)."""
    bar = "=" * 80
    print("\n" + bar)
    print("3. RMSNorm from Scratch")
    print(bar)

    # Token-style input: (batch, sequence length, model dim).
    batch, seq, dim = 2, 4, 8
    x = torch.randn(batch, seq, dim)

    # RMSNorm: rescales by the RMS only.
    rms_layer = RMSNorm(dim)
    rms_out = rms_layer(x)

    # LayerNorm: centers the mean and normalizes the variance.
    ln_layer = nn.LayerNorm(dim)
    ln_out = ln_layer(x)

    print(f"Input shape: {x.shape}")
    print(f"Input sample: {x[0, 0, :4]}")
    print(f"\nRMSNorm output: {rms_out[0, 0, :4]}")
    print(f"RMSNorm output RMS: {rms_out[0, 0].pow(2).mean().sqrt():.6f}")
    print(f"\nLayerNorm output: {ln_out[0, 0, :4]}")
    print(f"LayerNorm output mean: {ln_out[0, 0].mean():.6f}")
    print(f"LayerNorm output std: {ln_out[0, 0].std():.6f}")

    print("\n✓ RMSNorm: only normalizes scale (no mean centering)")
    print("✓ LayerNorm: normalizes both mean and variance")
219
220
221# ============================================================================
222# 4. Normalization Dimension Comparison
223# ============================================================================
def section4_normalization_dimensions():
    """Show which axes each normalization layer reduces over."""
    bar = "=" * 80
    print("\n" + bar)
    print("4. Normalization Dimension Comparison")
    print(bar)

    N, C, H, W = 2, 4, 8, 8
    # Shift and scale the input so the effect of normalization is obvious.
    x = torch.randn(N, C, H, W) * 10 + 5

    print(f"Input shape: (N={N}, C={C}, H={H}, W={W})")
    print("Input statistics:")
    print(f"  Global mean: {x.mean():.4f}, std: {x.std():.4f}")
    print(f"  Per-channel mean: {x.mean(dim=(0, 2, 3))}")

    # BatchNorm2d reduces over (N, H, W): one statistic per channel.
    bn_out = nn.BatchNorm2d(C)(x)
    print("\nBatchNorm2d (normalize over N,H,W for each C):")
    print(f"  Output mean: {bn_out.mean():.4f}, std: {bn_out.std():.4f}")
    print(f"  Per-channel mean: {bn_out.mean(dim=(0, 2, 3))}")

    # LayerNorm reduces over (C, H, W): one statistic per sample.
    ln_out = nn.LayerNorm((C, H, W))(x)
    print("\nLayerNorm (normalize over C,H,W for each N):")
    print(f"  Output mean: {ln_out.mean():.4f}, std: {ln_out.std():.4f}")
    print(f"  Per-sample mean: {ln_out.mean(dim=(1, 2, 3))}")

    # GroupNorm reduces over (H, W) plus the channels within each group.
    gn_out = nn.GroupNorm(num_groups=2, num_channels=C)(x)
    print("\nGroupNorm (2 groups, normalize over H,W for each group in each N):")
    print(f"  Output mean: {gn_out.mean():.4f}, std: {gn_out.std():.4f}")

    # InstanceNorm2d reduces over (H, W) only: per channel, per sample.
    in_out = nn.InstanceNorm2d(C)(x)
    print("\nInstanceNorm2d (normalize over H,W for each C in each N):")
    print(f"  Output mean: {in_out.mean():.4f}, std: {in_out.std():.4f}")

    print("\n✓ Different normalizations operate over different dimensions")
264
265
266# ============================================================================
267# 5. GroupNorm vs BatchNorm with Small Batch
268# ============================================================================
def section5_groupnorm_vs_batchnorm():
    """Demonstrate BatchNorm's small-batch weakness versus GroupNorm."""
    bar = "=" * 80
    print("\n" + bar)
    print("5. GroupNorm vs BatchNorm with Small Batch Size")
    print(bar)

    C, H, W = 32, 16, 16

    # A single-sample batch: the worst case for batch statistics.
    x_small = torch.randn(1, C, H, W)

    print(f"Input shape: {x_small.shape} (batch_size=1)")
    print(f"Input mean: {x_small.mean():.4f}, std: {x_small.std():.4f}")

    # BatchNorm in training mode must estimate statistics from one sample.
    bn = nn.BatchNorm2d(C)
    bn.train()
    bn_out = bn(x_small)
    print("\nBatchNorm2d (training mode, batch_size=1):")
    print(f"  Output mean: {bn_out.mean():.4f}, std: {bn_out.std():.4f}")
    print("  ⚠️ With batch_size=1, variance=0, normalization unstable")

    # GroupNorm's statistics are per-sample, so batch size is irrelevant.
    gn = nn.GroupNorm(num_groups=8, num_channels=C)
    gn_out = gn(x_small)
    print("\nGroupNorm (8 groups):")
    print(f"  Output mean: {gn_out.mean():.4f}, std: {gn_out.std():.4f}")
    print("  ✓ GroupNorm stable with batch_size=1")

    # The same BatchNorm layer behaves well once the batch is healthy.
    x_large = torch.randn(16, C, H, W)
    bn_large = bn(x_large)
    print("\nBatchNorm2d with batch_size=16:")
    print(f"  Output mean: {bn_large.mean():.4f}, std: {bn_large.std():.4f}")
    print("  ✓ BatchNorm stable with larger batch")
303
304
305# ============================================================================
306# 6. Pre-Norm vs Post-Norm Transformer Block
307# ============================================================================
class PostNormTransformerBlock(nn.Module):
    """Traditional transformer block: Sublayer -> residual Add -> LayerNorm."""

    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Post-norm ordering: Norm(x + Sublayer(x)) for each sublayer.
        attn_out = self.attn(x, x, x)[0]
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.ffn(x))
329
330
class PreNormTransformerBlock(nn.Module):
    """Modern transformer block: LayerNorm -> Attention/FFN -> residual Add.

    Pre-norm keeps an un-normalized residual stream, which gives better
    gradient flow than post-norm and usually trains without warmup.
    """

    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        # Pre-norm: x + Sublayer(Norm(x)).
        # Normalize once and reuse for q/k/v — the original computed
        # self.norm1(x) three times; LayerNorm is deterministic, so this
        # yields the identical result with a third of the work.
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out

        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out
352
353
def section6_prenorm_vs_postnorm():
    """Compare output and gradient magnitudes of pre-norm vs post-norm blocks."""
    bar = "=" * 80
    print("\n" + bar)
    print("6. Pre-Norm vs Post-Norm Transformer Block")
    print(bar)

    batch, seq, dim = 2, 10, 64
    x = torch.randn(batch, seq, dim)

    post_norm_block = PostNormTransformerBlock(dim)
    pre_norm_block = PreNormTransformerBlock(dim)

    # One forward pass through each variant.
    post_out = post_norm_block(x)
    pre_out = pre_norm_block(x)

    print(f"Input shape: {x.shape}")
    print(f"Input norm: {x.norm():.4f}")
    print(f"\nPost-Norm output norm: {post_out.norm():.4f}")
    print(f"Pre-Norm output norm: {pre_out.norm():.4f}")

    # Backprop a dummy loss through the post-norm block.
    post_out.sum().backward()
    post_grad_norm = sum(
        p.grad.norm().item() for p in post_norm_block.parameters() if p.grad is not None
    )

    # Fresh gradients for the pre-norm block: clear, re-run forward, backprop.
    pre_norm_block.zero_grad()
    x_pre = x.clone().detach().requires_grad_(True)
    pre_norm_block(x_pre).sum().backward()
    pre_grad_norm = sum(
        p.grad.norm().item() for p in pre_norm_block.parameters() if p.grad is not None
    )

    print(f"\nPost-Norm gradient norm: {post_grad_norm:.4f}")
    print(f"Pre-Norm gradient norm: {pre_grad_norm:.4f}")
    print("\n✓ Pre-Norm: Better gradient flow, easier training")
    print("✓ Post-Norm: Traditional, may need warmup")
391
392
393# ============================================================================
394# 7. Training Experiment
395# ============================================================================
class SimpleCNN(nn.Module):
    """Two-conv CNN for 28x28 single-channel input with pluggable normalization.

    norm_type: 'batch' -> BatchNorm2d, 'group' -> GroupNorm,
    anything else -> no normalization (Identity).
    """

    def __init__(self, norm_type='batch'):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.norm1 = self._make_norm(norm_type, 16, 4)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.norm2 = self._make_norm(norm_type, 32, 8)
        # After two 2x2 poolings, 28x28 -> 7x7 with 32 channels.
        self.fc = nn.Linear(32 * 7 * 7, 10)

    @staticmethod
    def _make_norm(norm_type, channels, groups):
        # Single place for the norm-layer choice instead of two if/elif chains.
        if norm_type == 'batch':
            return nn.BatchNorm2d(channels)
        if norm_type == 'group':
            return nn.GroupNorm(groups, channels)
        return nn.Identity()

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.norm1(self.conv1(x))), 2)
        x = F.max_pool2d(F.relu(self.norm2(self.conv2(x))), 2)
        return self.fc(torch.flatten(x, 1))
427
428
def train_model(model, epochs=5, batch_size=32):
    """Run a short training loop on freshly sampled synthetic data.

    Each "epoch" is one optimizer step on a new random batch of
    MNIST-shaped inputs (batch_size, 1, 28, 28) with random labels 0-9.

    Returns:
        List of per-step loss values, one per epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.CrossEntropyLoss()

    history = []
    for _ in range(epochs):
        # Fresh synthetic batch every step.
        inputs = torch.randn(batch_size, 1, 28, 28)
        targets = torch.randint(0, 10, (batch_size,))

        model.train()
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

        history.append(loss.item())

    return history
450
451
def section7_training_experiment():
    """Train identical CNNs with different norm layers and compare losses."""
    bar = "=" * 80
    print("\n" + bar)
    print("7. Training Experiment: Different Normalizations")
    print(bar)

    epochs = 10

    # One model per normalization choice, trained with the same recipe.
    losses_bn = train_model(SimpleCNN(norm_type='batch'), epochs=epochs)
    losses_gn = train_model(SimpleCNN(norm_type='group'), epochs=epochs)
    losses_none = train_model(SimpleCNN(norm_type='none'), epochs=epochs)

    print(f"Training for {epochs} epochs on synthetic data:")
    print(f"\nBatchNorm losses: {[f'{l:.4f}' for l in losses_bn[:5]]} ... {losses_bn[-1]:.4f}")
    print(f"GroupNorm losses: {[f'{l:.4f}' for l in losses_gn[:5]]} ... {losses_gn[-1]:.4f}")
    print(f"No Norm losses: {[f'{l:.4f}' for l in losses_none[:5]]} ... {losses_none[-1]:.4f}")

    print("\nFinal loss comparison:")
    print(f"  BatchNorm: {losses_bn[-1]:.4f}")
    print(f"  GroupNorm: {losses_gn[-1]:.4f}")
    print(f"  No Norm: {losses_none[-1]:.4f}")
    print("\n✓ Normalization helps convergence")
481
482
483# ============================================================================
484# 8. Common Pitfalls
485# ============================================================================
def section8_common_pitfalls():
    """Demonstrate three classic normalization mistakes: missing eval(),
    wrong LayerNorm normalized_shape, and frozen-BatchNorm behavior."""
    print("\n" + "="*80)
    print("8. Common Pitfalls")
    print("="*80)

    # Pitfall 1: Forgetting model.eval() for BatchNorm
    print("\n--- Pitfall 1: Forgetting model.eval() for BatchNorm ---")
    x = torch.randn(4, 3, 8, 8)
    bn = nn.BatchNorm2d(3)

    # Train mode (uses batch statistics). NOTE(review): because both passes
    # see the same batch, the two outputs are identical here — what changes
    # between the passes is the running statistics, which keep advancing.
    bn.train()
    out_train_1 = bn(x)
    out_train_2 = bn(x)  # running stats updated again on this pass

    # Eval mode (uses the accumulated running statistics) — deterministic.
    bn.eval()
    with torch.no_grad():
        out_eval_1 = bn(x)
        out_eval_2 = bn(x)  # Same output

    print(f"Train mode, pass 1 mean: {out_train_1.mean():.6f}")
    print(f"Train mode, pass 2 mean: {out_train_2.mean():.6f}")
    print(f"Difference: {(out_train_1 - out_train_2).abs().max():.6f}")
    print(f"\nEval mode, pass 1 mean: {out_eval_1.mean():.6f}")
    print(f"Eval mode, pass 2 mean: {out_eval_2.mean():.6f}")
    print(f"Difference: {(out_eval_1 - out_eval_2).abs().max():.6f}")
    print("⚠️ Always use model.eval() during inference!")

    # Pitfall 2: Wrong dimension for LayerNorm
    print("\n--- Pitfall 2: Wrong normalized_shape for LayerNorm ---")
    x = torch.randn(2, 10, 64)  # (batch, seq, dim)

    # Correct: normalize over last dimension (per token)
    ln_correct = nn.LayerNorm(64)
    out_correct = ln_correct(x)
    print(f"Input shape: {x.shape}")
    print(f"Correct LayerNorm(64): output mean per sample = {out_correct.mean(dim=(1,2))}")

    # Wrong: normalizing over (seq, dim) couples tokens within a sequence
    # and hard-codes the sequence length into the layer.
    try:
        ln_wrong = nn.LayerNorm((10, 64))  # This normalizes over (seq, dim)
        out_wrong = ln_wrong(x)
        print(f"Wrong LayerNorm((10,64)): output mean per sample = {out_wrong.mean(dim=(1,2))}")
        print("⚠️ Make sure normalized_shape matches your intention!")
    except Exception as e:
        print(f"Error: {e}")

    # Pitfall 3: Frozen BatchNorm
    print("\n--- Pitfall 3: Frozen BatchNorm (for fine-tuning) ---")
    bn = nn.BatchNorm2d(3)
    x = torch.randn(4, 3, 8, 8)

    # Normal training: a forward pass in train mode mutates running stats.
    bn.train()
    initial_mean = bn.running_mean.clone()
    _ = bn(x)
    updated_mean = bn.running_mean
    print(f"Initial running mean: {initial_mean[:3]}")
    print(f"After forward (train mode): {updated_mean[:3]}")
    print(f"Difference: {(updated_mean - initial_mean).abs().sum():.6f}")

    # Frozen BatchNorm: eval mode stops the running-stat updates, and
    # requires_grad=False stops the affine params from training.
    bn.eval()
    for param in bn.parameters():
        param.requires_grad = False

    frozen_mean = bn.running_mean.clone()
    _ = bn(x)
    after_frozen = bn.running_mean
    print(f"\nFrozen BN running mean: {frozen_mean[:3]}")
    print(f"After forward (frozen): {after_frozen[:3]}")
    print(f"Difference: {(after_frozen - frozen_mean).abs().sum():.6f}")
    print("✓ Frozen BatchNorm: use for fine-tuning with different batch sizes")
560
561
562# ============================================================================
563# Main
564# ============================================================================
def main():
    """Run every demo section in order, then print a summary cheat sheet."""
    print("\n" + "="*80)
    print("PyTorch Normalization Layers Comprehensive Guide")
    print("="*80)

    section1_batchnorm_from_scratch()
    section2_layernorm_from_scratch()
    section3_rmsnorm_from_scratch()
    section4_normalization_dimensions()
    section5_groupnorm_vs_batchnorm()
    section6_prenorm_vs_postnorm()
    section7_training_experiment()
    section8_common_pitfalls()

    print("\n" + "="*80)
    print("Summary")
    print("="*80)
    print("""
Normalization Comparison:
- BatchNorm: Normalizes over batch dimension, maintains running stats
  → Best for large batches, computer vision
  → Requires model.eval() during inference

- LayerNorm: Normalizes over feature dimensions, batch-independent
  → Best for NLP, transformers, small/variable batches
  → No train/eval mode difference

- GroupNorm: Normalizes over channel groups, batch-independent
  → Best for small batches, instance segmentation
  → More stable than BatchNorm with batch_size=1

- InstanceNorm: Normalizes over spatial dimensions per channel
  → Best for style transfer, GANs

- RMSNorm: Simplified normalization without mean centering
  → Used in LLaMA, faster than LayerNorm
  → Only normalizes scale

Pre-Norm vs Post-Norm:
- Pre-Norm: Better gradient flow, easier training, modern default
- Post-Norm: Traditional, may need learning rate warmup

Common Pitfalls:
1. Forgetting model.eval() for BatchNorm during inference
2. Wrong normalized_shape for LayerNorm
3. Frozen BatchNorm behavior when fine-tuning
    """)
    print("="*80)
613
614
# Script entry point: run the full demo when executed directly.
if __name__ == "__main__":
    main()