1"""
2Foundation Models - Scaling Laws Implementation
3
4Demonstrates Chinchilla scaling laws and compute-optimal model sizing.
5Implements power law relationships between loss, model size, and training data.
6Visualizes scaling curves and compute-optimal frontier.
7
8No external dependencies except numpy and matplotlib.
9"""
10
11import numpy as np
12import matplotlib.pyplot as plt
13
14
def chinchilla_loss(N, D, A=406.4, B=410.7, alpha=0.34, beta=0.28, E=1.69):
    """
    Predict training loss with the Chinchilla (Hoffmann et al., 2022) law.

    L(N, D) = A / N^alpha + B / D^beta + E

    Args:
        N: Number of model parameters (non-embedding)
        D: Number of training tokens
        A, B: Fitted scaling coefficients
        alpha, beta: Fitted scaling exponents
        E: Irreducible loss (entropy of natural text)

    Returns:
        Predicted loss value
    """
    # Two additive power-law penalties: finite capacity and finite data.
    model_term = A * N ** (-alpha)
    data_term = B * D ** (-beta)
    return model_term + data_term + E
32
33
def compute_optimal_ratio(A=406.4, B=410.7, alpha=0.34, beta=0.28):
    """
    Return the exponent coupling optimal model size to data: N ∝ D^(beta/alpha).

    Derived via Lagrange multipliers on L(N, D) under a fixed compute
    constraint; at the optimum, N scales as D raised to this exponent.

    Note:
        A and B are accepted for signature compatibility with the other
        scaling-law helpers, but the exponent depends only on alpha and beta.

    Returns:
        The scaling exponent beta / alpha.
    """
    return beta / alpha
48
49
def compute_flops(N, D):
    """
    Approximate total training FLOPs for a dense transformer.

    Uses the standard rule of thumb FLOPs ≈ 6 * N * D
    (~2 FLOPs/param/token forward, ~4 backward).

    Args:
        N: Model parameters
        D: Training tokens

    Returns:
        Approximate FLOPs
    """
    FORWARD_BACKWARD_FACTOR = 6
    return FORWARD_BACKWARD_FACTOR * N * D
64
65
def find_optimal_allocation(C, A=406.4, B=410.7, alpha=0.34, beta=0.28, E=1.69):
    """
    Given compute budget C (in FLOPs), find the loss-minimizing N and D.

    Minimize L(N, D) = A/N^alpha + B/D^beta + E subject to 6*N*D = C.

    Substituting D = C/(6N) and setting dL/dN = 0 gives the closed form:

        N* = (A*alpha / (B*beta))^(1/(alpha+beta)) * (C/6)^(beta/(alpha+beta))

    Bug fix: the previous implementation raised the coefficient
    (A*alpha/(B*beta)) to beta/(alpha+beta) instead of 1/(alpha+beta),
    which shifted the reported "optimum" away from the true minimizer.

    Args:
        C: Compute budget in FLOPs
        A, B: Scaling coefficients
        alpha, beta: Scaling exponents
        E: Irreducible loss

    Returns:
        (N_optimal, D_optimal, L_optimal)
    """
    exponent = beta / (alpha + beta)
    # Coefficient exponent is 1/(alpha+beta) — NOT beta/(alpha+beta).
    coeff = (A * alpha / (B * beta)) ** (1.0 / (alpha + beta))

    N_opt = coeff * (C / 6) ** exponent
    D_opt = C / (6 * N_opt)  # spend the remaining compute on tokens

    # Loss at the optimum (same formula as chinchilla_loss; inlined so this
    # function is self-contained).
    L_opt = A / (N_opt ** alpha) + B / (D_opt ** beta) + E

    return N_opt, D_opt, L_opt
90
91
def kaplan_scaling_law(N, D, Nc=8.8e13, Dc=5.4e13, alpha_N=0.076, alpha_D=0.095):
    """
    Original Kaplan et al. (2020) scaling law, simplified bottleneck form.

    L(N) = (Nc/N)^alpha_N   when data is abundant
    L(D) = (Dc/D)^alpha_D   when the model is large enough

    Loss is limited by whichever constraint is binding, i.e. the LARGER of
    the two terms, so the combined estimate is their max plus an
    irreducible-loss offset.

    Returns:
        Predicted loss
    """
    by_params = (Nc / N) ** alpha_N
    by_data = (Dc / D) ** alpha_D
    # Take the binding (larger) constraint, then add irreducible loss.
    return 1.69 + max(by_params, by_data)
106
107
108# ============================================================
109# Main Demonstrations
110# ============================================================
111
def demo_scaling_curves():
    """Print how predicted loss varies with model size and with data volume."""
    banner = "=" * 60
    print(banner)
    print("DEMO 1: Scaling Curves")
    print(banner)

    # Sweep model size at a fixed token budget (1M to 100B params).
    N_range = np.logspace(6, 11, 50)
    D_fixed = 200e9  # 200B tokens (GPT-3 scale)
    losses = chinchilla_loss(N_range, D_fixed)  # vectorized over N

    print(f"\nFixed data: {D_fixed/1e9:.0f}B tokens")
    print(f"Model size range: {N_range[0]/1e6:.1f}M to {N_range[-1]/1e9:.1f}B params")
    print(f"Loss range: {min(losses):.3f} to {max(losses):.3f}")

    # A few familiar model sizes for reference.
    for N in (1e9, 7e9, 70e9):
        loss = chinchilla_loss(N, D_fixed)
        print(f" {N/1e9:.0f}B params → Loss = {loss:.3f}")

    # Sweep data volume at a fixed model size (1B to 10T tokens).
    D_range = np.logspace(9, 13, 50)
    N_fixed = 7e9  # 7B params (LLaMA-7B scale)
    losses_data = chinchilla_loss(N_fixed, D_range)  # vectorized over D

    print(f"\nFixed model: {N_fixed/1e9:.0f}B parameters")
    print(f"Data range: {D_range[0]/1e9:.0f}B to {D_range[-1]/1e12:.0f}T tokens")
    print(f"Loss range: {min(losses_data):.3f} to {max(losses_data):.3f}")
142
143
def demo_compute_optimal():
    """Tabulate compute-optimal (N, D) for several historical training budgets."""
    bar = "=" * 60
    rule = "-" * 75
    print("\n" + bar)
    print("DEMO 2: Compute-Optimal Allocation")
    print(bar)

    # (label, training FLOPs) — historical / estimated budgets.
    budgets = (
        ("GPT-3 (2020)", 3.14e23),        # ~175B params, 300B tokens
        ("Chinchilla (2022)", 5.76e23),   # ~70B params, 1.4T tokens
        ("LLaMA-65B", 6.3e23),            # ~65B params, 1.4T tokens
        ("GPT-4 (estimated)", 1e25),      # Speculation
    )

    print("\nCompute Budget → Optimal Allocation:\n")
    print(f"{'Model':<20} {'FLOPs':<15} {'N (params)':<15} {'D (tokens)':<15} {'Loss':<10}")
    print(rule)

    for name, budget in budgets:
        N_opt, D_opt, L_opt = find_optimal_allocation(budget)
        print(f"{name:<20} {budget:.2e} {N_opt/1e9:>8.1f}B {D_opt/1e9:>8.0f}B {L_opt:.3f}")

    # GPT-3 predates Chinchilla; show what its budget could have bought.
    print("\n" + rule)
    print("Comparison: GPT-3 vs Optimal")
    print(rule)

    gpt3_N, gpt3_D = 175e9, 300e9
    gpt3_flops = compute_flops(gpt3_N, gpt3_D)
    gpt3_loss = chinchilla_loss(gpt3_N, gpt3_D)
    opt_N, opt_D, opt_loss = find_optimal_allocation(gpt3_flops)

    print(f"GPT-3 actual: {gpt3_N/1e9:.0f}B params, {gpt3_D/1e9:.0f}B tokens → Loss = {gpt3_loss:.3f}")
    print(f"Optimal: {opt_N/1e9:.0f}B params, {opt_D/1e9:.0f}B tokens → Loss = {opt_loss:.3f}")
    print(f"Improvement: {gpt3_loss - opt_loss:.3f} reduction in loss")
181
182
def demo_scaling_ratio():
    """Show the Chinchilla-optimal coupling between model size and data."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 3: Compute-Optimal Ratio")
    print(bar)

    ratio_exp = compute_optimal_ratio()
    print(f"\nChinchilla optimal ratio: N ∝ D^{ratio_exp:.3f}")
    print(f"This means: D ∝ N^{1/ratio_exp:.3f}")
    print("\nRule of thumb: For every 2x increase in model size,")
    print(f"you should increase data by ~{2**(1/ratio_exp):.2f}x")

    dash = "-" * 60
    print("\n" + dash)
    print("Scaling trajectory:")
    print(dash)

    base_N = 1e9   # starting point: 1B params
    base_D = 20e9  # Chinchilla-optimal tokens for a 1B model

    for scale in (1, 2, 4, 8, 16):
        # Grow data super-linearly with model size per the optimal exponent.
        N = base_N * scale
        D = base_D * scale ** (1 / ratio_exp)
        flops = compute_flops(N, D)
        loss = chinchilla_loss(N, D)
        print(f"{scale:>3}x: {N/1e9:>6.1f}B params, {D/1e9:>7.0f}B tokens, "
              f"{flops:.2e} FLOPs, Loss = {loss:.3f}")
210
211
def demo_comparison_with_kaplan():
    """Contrast loss predictions from the 2020 Kaplan and 2022 Chinchilla fits."""
    bar = "=" * 60
    print("\n" + bar)
    print("DEMO 4: Chinchilla vs Kaplan Scaling Laws")
    print(bar)

    D = 200e9  # shared token budget for the comparison

    print(f"\nWith {D/1e9:.0f}B training tokens:\n")
    print(f"{'Size':<10} {'Chinchilla':<15} {'Kaplan':<15} {'Difference':<10}")
    print("-" * 50)

    # (params, label) for a few familiar model sizes.
    for N, name in ((1e9, "1B"), (7e9, "7B"), (13e9, "13B"), (70e9, "70B")):
        chin_loss = chinchilla_loss(N, D)
        kaplan_loss = kaplan_scaling_law(N, D)
        diff = kaplan_loss - chin_loss
        print(f"{name:<10} {chin_loss:.4f} {kaplan_loss:.4f} {diff:+.4f}")

    print("\nNote: Kaplan (2020) predicted more aggressive scaling benefits.")
    print("Chinchilla (2022) revised this with better data efficiency emphasis.")
240
241
def plot_scaling_laws():
    """Generate the 2x2 scaling-law figure and save it as scaling_laws.png.

    Panels:
        (1) loss vs model size at fixed data,
        (2) loss vs training data at fixed model size,
        (3) the compute-optimal (N, D) frontier with real models marked,
        (4) a log-log contour map of the loss landscape L(N, D).

    Fixes: saves to a relative path instead of a machine-specific absolute
    path (which broke on any other machine and contradicted the printed
    message), and closes the figure afterwards to release memory.
    """
    print("\n" + "=" * 60)
    print("DEMO 5: Visualization (plots generated)")
    print("=" * 60)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Loss vs Model Size (fixed data)
    ax1 = axes[0, 0]
    N_range = np.logspace(6, 11, 100)
    D_fixed = 200e9
    losses = [chinchilla_loss(N, D_fixed) for N in N_range]

    ax1.loglog(N_range, losses, 'b-', linewidth=2, label='Chinchilla')
    ax1.set_xlabel('Model Parameters (N)', fontsize=11)
    ax1.set_ylabel('Loss', fontsize=11)
    ax1.set_title(f'Loss vs Model Size (D = {D_fixed/1e9:.0f}B tokens)', fontsize=12)
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Plot 2: Loss vs Training Data (fixed model)
    ax2 = axes[0, 1]
    D_range = np.logspace(9, 13, 100)
    N_fixed = 7e9
    losses_data = [chinchilla_loss(N_fixed, D) for D in D_range]

    ax2.loglog(D_range, losses_data, 'r-', linewidth=2, label='Chinchilla')
    ax2.set_xlabel('Training Tokens (D)', fontsize=11)
    ax2.set_ylabel('Loss', fontsize=11)
    ax2.set_title(f'Loss vs Training Data (N = {N_fixed/1e9:.0f}B params)', fontsize=12)
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    # Plot 3: Compute-Optimal Frontier
    ax3 = axes[1, 0]
    compute_budgets = np.logspace(21, 25, 50)
    N_opts = []
    D_opts = []

    for C in compute_budgets:
        N_opt, D_opt, _ = find_optimal_allocation(C)
        N_opts.append(N_opt)
        D_opts.append(D_opt)

    ax3.loglog(N_opts, D_opts, 'g-', linewidth=2, label='Optimal frontier')

    # Mark real (N, D) choices against the frontier.
    models = {
        'GPT-3': (175e9, 300e9),
        'Chinchilla': (70e9, 1400e9),
        'LLaMA-65B': (65e9, 1400e9),
    }

    for name, (n, d) in models.items():
        ax3.plot(n, d, 'o', markersize=8, label=name)

    ax3.set_xlabel('Model Parameters (N)', fontsize=11)
    ax3.set_ylabel('Training Tokens (D)', fontsize=11)
    ax3.set_title('Compute-Optimal Frontier', fontsize=12)
    ax3.grid(True, alpha=0.3)
    ax3.legend()

    # Plot 4: Loss Landscape (2D contour over N and D)
    ax4 = axes[1, 1]
    N_grid = np.logspace(9, 11, 30)
    D_grid = np.logspace(10, 13, 30)
    N_mesh, D_mesh = np.meshgrid(N_grid, D_grid)

    L_mesh = np.zeros_like(N_mesh)
    for i in range(len(D_grid)):
        for j in range(len(N_grid)):
            L_mesh[i, j] = chinchilla_loss(N_mesh[i, j], D_mesh[i, j])

    contour = ax4.contour(N_mesh, D_mesh, L_mesh, levels=15, cmap='viridis')
    ax4.clabel(contour, inline=True, fontsize=8)
    ax4.set_xlabel('Model Parameters (N)', fontsize=11)
    ax4.set_ylabel('Training Tokens (D)', fontsize=11)
    ax4.set_title('Loss Landscape L(N, D)', fontsize=12)
    ax4.set_xscale('log')
    ax4.set_yscale('log')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    # Relative path: the old hard-coded absolute path only existed on one
    # machine, and the printed message below already claims this filename.
    plt.savefig('scaling_laws.png', dpi=150)
    plt.close(fig)  # release the figure now that it is on disk
    print("\nPlot saved to: scaling_laws.png")
    print("Shows: (1) Loss vs N, (2) Loss vs D, (3) Optimal frontier, (4) Loss landscape")
329
330
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Foundation Models: Scaling Laws")
    print("=" * 60)

    # Run all demos in order; the plot demo writes scaling_laws.png.
    demo_scaling_curves()
    demo_compute_optimal()
    demo_scaling_ratio()
    demo_comparison_with_kaplan()
    plot_scaling_laws()

    print("\n" + "=" * 60)
    print("Key Takeaways:")
    print("=" * 60)
    print("1. Loss scales as power laws: L ~ N^(-α) and L ~ D^(-β)")
    print("2. Chinchilla law: For compute budget C, optimal N ∝ C^0.45, D ∝ C^0.55")
    # Fixed wording: Chinchilla's finding is that models were UNDERtrained
    # (too little data for their parameter count), as the parenthetical says.
    print("3. Most models are undertrained (too many params, too little data)")
    print("4. Doubling model size requires ~2.4x more data for optimality")
    print("=" * 60)