01_scaling_laws.py

Download
python 350 lines 11.0 KB
  1"""
  2Foundation Models - Scaling Laws Implementation
  3
  4Demonstrates Chinchilla scaling laws and compute-optimal model sizing.
  5Implements power law relationships between loss, model size, and training data.
  6Visualizes scaling curves and compute-optimal frontier.
  7
  8No external dependencies except numpy and matplotlib.
  9"""
 10
 11import numpy as np
 12import matplotlib.pyplot as plt
 13
 14
 15def chinchilla_loss(N, D, A=406.4, B=410.7, alpha=0.34, beta=0.28, E=1.69):
 16    """
 17    Chinchilla scaling law for loss prediction.
 18
 19    L(N, D) = A/N^alpha + B/D^beta + E
 20
 21    Args:
 22        N: Number of model parameters (non-embedding)
 23        D: Number of training tokens
 24        A, B: Scaling coefficients
 25        alpha, beta: Scaling exponents
 26        E: Irreducible loss (entropy of natural text)
 27
 28    Returns:
 29        Predicted loss value
 30    """
 31    return A / (N ** alpha) + B / (D ** beta) + E
 32
 33
 34def compute_optimal_ratio(A=406.4, B=410.7, alpha=0.34, beta=0.28):
 35    """
 36    Compute the compute-optimal ratio N/D from Chinchilla paper.
 37
 38    At optimum: dL/dN = 0 and dL/dD = 0 under compute constraint.
 39    Result: N ∝ D^(beta/alpha)
 40
 41    Returns:
 42        Optimal ratio coefficient
 43    """
 44    # From calculus of Lagrange multipliers with compute constraint
 45    # Optimal: N = k * D^(beta/alpha)
 46    ratio_exponent = beta / alpha
 47    return ratio_exponent
 48
 49
 50def compute_flops(N, D):
 51    """
 52    Estimate FLOPs for training a transformer.
 53
 54    FLOPs ≈ 6ND (forward + backward pass)
 55
 56    Args:
 57        N: Model parameters
 58        D: Training tokens
 59
 60    Returns:
 61        Approximate FLOPs
 62    """
 63    return 6 * N * D
 64
 65
 66def find_optimal_allocation(C, A=406.4, B=410.7, alpha=0.34, beta=0.28, E=1.69):
 67    """
 68    Given compute budget C (in FLOPs), find optimal N and D.
 69
 70    Constraint: 6ND = C
 71    Optimization: minimize L(N, D)
 72
 73    Solution: N = (C/6)^(beta/(alpha+beta)) * (A*alpha/(B*beta))^(beta/(alpha+beta))
 74
 75    Args:
 76        C: Compute budget in FLOPs
 77
 78    Returns:
 79        (N_optimal, D_optimal, L_optimal)
 80    """
 81    # Analytical solution from Chinchilla paper
 82    exponent = beta / (alpha + beta)
 83    coeff = (A * alpha / (B * beta)) ** exponent
 84
 85    N_opt = coeff * (C / 6) ** exponent
 86    D_opt = C / (6 * N_opt)
 87    L_opt = chinchilla_loss(N_opt, D_opt, A, B, alpha, beta, E)
 88
 89    return N_opt, D_opt, L_opt
 90
 91
 92def kaplan_scaling_law(N, D, Nc=8.8e13, Dc=5.4e13, alpha_N=0.076, alpha_D=0.095):
 93    """
 94    Original Kaplan et al. (2020) scaling law.
 95
 96    L(N) = (Nc/N)^alpha_N when data is abundant
 97    L(D) = (Dc/D)^alpha_D when model is large enough
 98
 99    Returns:
100        Predicted loss
101    """
102    # Use minimum of both constraints
103    loss_N = (Nc / N) ** alpha_N
104    loss_D = (Dc / D) ** alpha_D
105    return max(loss_N, loss_D) + 1.69  # Add irreducible loss
106
107
108# ============================================================
109# Main Demonstrations
110# ============================================================
111
def demo_scaling_curves():
    """
    Print Chinchilla loss predictions along two one-dimensional sweeps.

    Sweep 1 varies model size (1M to 100B params) at a fixed 200B-token
    budget; sweep 2 varies training tokens (1B to 10T) for a fixed
    7B-parameter model.
    """
    rule = "=" * 60
    print(rule)
    print("DEMO 1: Scaling Curves")
    print(rule)

    # --- Sweep 1: model size at fixed data --------------------------------
    D_fixed = 200e9  # 200B tokens (GPT-3 scale)
    N_range = np.logspace(6, 11, 50)  # 1M to 100B
    size_losses = [chinchilla_loss(n, D_fixed) for n in N_range]

    print(f"\nFixed data: {D_fixed/1e9:.0f}B tokens")
    print(f"Model size range: {N_range[0]/1e6:.1f}M to {N_range[-1]/1e9:.1f}B params")
    print(f"Loss range: {min(size_losses):.3f} to {max(size_losses):.3f}")

    # Spot checks at a few familiar model sizes.
    for n_params in (1e9, 7e9, 70e9):
        spot_loss = chinchilla_loss(n_params, D_fixed)
        print(f"  {n_params/1e9:.0f}B params → Loss = {spot_loss:.3f}")

    # --- Sweep 2: data at fixed model size --------------------------------
    N_fixed = 7e9  # 7B params (LLaMA-7B scale)
    D_range = np.logspace(9, 13, 50)  # 1B to 10T
    data_losses = [chinchilla_loss(N_fixed, d) for d in D_range]

    print(f"\nFixed model: {N_fixed/1e9:.0f}B parameters")
    print(f"Data range: {D_range[0]/1e9:.0f}B to {D_range[-1]/1e12:.0f}T tokens")
    print(f"Loss range: {min(data_losses):.3f} to {max(data_losses):.3f}")
142
143
def demo_compute_optimal():
    """
    Print the compute-optimal (N, D) split for several reference budgets, then
    contrast GPT-3's actual allocation with the optimum for the same FLOPs.
    """
    print("\n" + "=" * 60)
    print("DEMO 2: Compute-Optimal Allocation")
    print("=" * 60)

    # Different compute budgets (in FLOPs)
    budgets = {
        "GPT-3 (2020)": 3.14e23,                   # ~175B params, 300B tokens
        "Chinchilla (2022)": 5.76e23,              # ~70B params, 1.4T tokens
        "LLaMA-65B": compute_flops(65e9, 1.4e12),  # 6ND ≈ 5.5e23: ~65B params, 1.4T tokens
        "GPT-4 (estimated)": 1e25,                 # Speculation
    }

    # One shared format spec so the header and every data row line up.
    columns = "{:<20} {:<15} {:<15} {:<15} {:<10}"

    print("\nCompute Budget → Optimal Allocation:\n")
    print(columns.format("Model", "FLOPs", "N (params)", "D (tokens)", "Loss"))
    print("-" * 75)

    for name, budget in budgets.items():
        N_opt, D_opt, L_opt = find_optimal_allocation(budget)
        print(columns.format(
            name,
            f"{budget:.2e}",
            f"{N_opt/1e9:.1f}B",
            f"{D_opt/1e9:.0f}B",
            f"{L_opt:.3f}",
        ))

    # Compare with actual GPT-3 (not optimal by Chinchilla standards)
    print("\n" + "-" * 75)
    print("Comparison: GPT-3 vs Optimal")
    print("-" * 75)

    gpt3_N = 175e9
    gpt3_D = 300e9
    gpt3_flops = compute_flops(gpt3_N, gpt3_D)
    gpt3_loss = chinchilla_loss(gpt3_N, gpt3_D)

    opt_N, opt_D, opt_loss = find_optimal_allocation(gpt3_flops)

    print(f"GPT-3 actual:  {gpt3_N/1e9:.0f}B params, {gpt3_D/1e9:.0f}B tokens → Loss = {gpt3_loss:.3f}")
    print(f"Optimal:       {opt_N/1e9:.0f}B params, {opt_D/1e9:.0f}B tokens → Loss = {opt_loss:.3f}")
    print(f"Improvement:   {gpt3_loss - opt_loss:.3f} reduction in loss")
181
182
def demo_scaling_ratio():
    """
    Print the N-vs-D exponent implied by the Chinchilla fit and walk a
    trajectory that scales model size by powers of two along that relation.
    """
    separator = "=" * 60
    print("\n" + separator)
    print("DEMO 3: Compute-Optimal Ratio")
    print(separator)

    exp_n_per_d = compute_optimal_ratio()  # N ∝ D^exp_n_per_d
    exp_d_per_n = 1 / exp_n_per_d          # D ∝ N^exp_d_per_n
    print(f"\nChinchilla optimal ratio: N ∝ D^{exp_n_per_d:.3f}")
    print(f"This means: D ∝ N^{exp_d_per_n:.3f}")
    print("\nRule of thumb: For every 2x increase in model size,")
    print(f"you should increase data by ~{2 ** exp_d_per_n:.2f}x")

    divider = "-" * 60
    print("\n" + divider)
    print("Scaling trajectory:")
    print(divider)

    base_N = 1e9   # start at 1B params
    base_D = 20e9  # 20B tokens: the popular ~20 tokens/param Chinchilla heuristic

    for scale in (1, 2, 4, 8, 16):
        n_params = base_N * scale
        n_tokens = base_D * scale ** exp_d_per_n
        loss = chinchilla_loss(n_params, n_tokens)
        flops = compute_flops(n_params, n_tokens)

        print(f"{scale:>3}x: {n_params/1e9:>6.1f}B params, {n_tokens/1e9:>7.0f}B tokens, "
              f"{flops:.2e} FLOPs, Loss = {loss:.3f}")
210
211
def demo_comparison_with_kaplan():
    """Tabulate Chinchilla and Kaplan loss predictions side by side at a fixed 200B-token budget."""
    rule = "=" * 60
    print("\n" + rule)
    print("DEMO 4: Chinchilla vs Kaplan Scaling Laws")
    print(rule)

    D = 200e9  # 200B tokens
    sizes = {"1B": 1e9, "7B": 7e9, "13B": 13e9, "70B": 70e9}  # label -> params

    print(f"\nWith {D/1e9:.0f}B training tokens:\n")
    print(f"{'Size':<10} {'Chinchilla':<15} {'Kaplan':<15} {'Difference':<10}")
    print("-" * 50)

    for label, n_params in sizes.items():
        chin_loss = chinchilla_loss(n_params, D)
        kaplan_loss = kaplan_scaling_law(n_params, D)
        gap = kaplan_loss - chin_loss  # positive when Kaplan predicts higher loss
        print(f"{label:<10} {chin_loss:.4f}          {kaplan_loss:.4f}          {gap:+.4f}")

    print("\nNote: Kaplan (2020) predicted more aggressive scaling benefits.")
    print("Chinchilla (2022) revised this with better data efficiency emphasis.")
240
241
def plot_scaling_laws(output_path=None):
    """
    Draw a 2x2 figure of Chinchilla scaling-law views and save it as a PNG.

    Panels:
        (1) loss vs model size at fixed data (D = 200B tokens),
        (2) loss vs training data at fixed model size (N = 7B params),
        (3) compute-optimal (N, D) frontier for budgets 1e21..1e25 FLOPs,
            with GPT-3, Chinchilla and LLaMA-65B marked,
        (4) contour map of L(N, D).

    Args:
        output_path: Destination file (str or path-like). Defaults to
            ``scaling_laws.png`` in the directory containing this script, so
            the demo works on any machine rather than depending on one
            hard-coded absolute path.

    Returns:
        pathlib.Path of the written image.
    """
    from pathlib import Path

    print("\n" + "=" * 60)
    print("DEMO 5: Visualization (plots generated)")
    print("=" * 60)

    if output_path is None:
        output_path = Path(__file__).resolve().with_name("scaling_laws.png")
    output_path = Path(output_path)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Loss vs Model Size (fixed data); chinchilla_loss broadcasts over arrays.
    ax1 = axes[0, 0]
    N_range = np.logspace(6, 11, 100)
    D_fixed = 200e9
    losses = chinchilla_loss(N_range, D_fixed)

    ax1.loglog(N_range, losses, 'b-', linewidth=2, label='Chinchilla')
    ax1.set_xlabel('Model Parameters (N)', fontsize=11)
    ax1.set_ylabel('Loss', fontsize=11)
    ax1.set_title(f'Loss vs Model Size (D = {D_fixed/1e9:.0f}B tokens)', fontsize=12)
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Plot 2: Loss vs Training Data (fixed model)
    ax2 = axes[0, 1]
    D_range = np.logspace(9, 13, 100)
    N_fixed = 7e9
    losses_data = chinchilla_loss(N_fixed, D_range)

    ax2.loglog(D_range, losses_data, 'r-', linewidth=2, label='Chinchilla')
    ax2.set_xlabel('Training Tokens (D)', fontsize=11)
    ax2.set_ylabel('Loss', fontsize=11)
    ax2.set_title(f'Loss vs Training Data (N = {N_fixed/1e9:.0f}B params)', fontsize=12)
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    # Plot 3: Compute-Optimal Frontier; find_optimal_allocation is elementwise in C.
    ax3 = axes[1, 0]
    compute_budgets = np.logspace(21, 25, 50)
    N_opts, D_opts, _ = find_optimal_allocation(compute_budgets)

    ax3.loglog(N_opts, D_opts, 'g-', linewidth=2, label='Optimal frontier')

    # Add specific model points
    models = {
        'GPT-3': (175e9, 300e9),
        'Chinchilla': (70e9, 1400e9),
        'LLaMA-65B': (65e9, 1400e9),
    }

    for name, (n, d) in models.items():
        ax3.plot(n, d, 'o', markersize=8, label=name)

    ax3.set_xlabel('Model Parameters (N)', fontsize=11)
    ax3.set_ylabel('Training Tokens (D)', fontsize=11)
    ax3.set_title('Compute-Optimal Frontier', fontsize=12)
    ax3.grid(True, alpha=0.3)
    ax3.legend()

    # Plot 4: Loss Landscape (2D), evaluated on the whole mesh at once.
    ax4 = axes[1, 1]
    N_grid = np.logspace(9, 11, 30)
    D_grid = np.logspace(10, 13, 30)
    N_mesh, D_mesh = np.meshgrid(N_grid, D_grid)
    L_mesh = chinchilla_loss(N_mesh, D_mesh)

    contour = ax4.contour(N_mesh, D_mesh, L_mesh, levels=15, cmap='viridis')
    ax4.clabel(contour, inline=True, fontsize=8)
    ax4.set_xlabel('Model Parameters (N)', fontsize=11)
    ax4.set_ylabel('Training Tokens (D)', fontsize=11)
    ax4.set_title('Loss Landscape L(N, D)', fontsize=12)
    ax4.set_xscale('log')
    ax4.set_yscale('log')
    ax4.grid(True, alpha=0.3)

    fig.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=150)
    plt.close(fig)  # release the figure; this function only writes a file
    print(f"\nPlot saved to: {output_path}")
    print("Shows: (1) Loss vs N, (2) Loss vs D, (3) Optimal frontier, (4) Loss landscape")
    return output_path
329
330
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Foundation Models: Scaling Laws")
    print("=" * 60)

    demo_scaling_curves()
    demo_compute_optimal()
    demo_scaling_ratio()
    demo_comparison_with_kaplan()
    plot_scaling_laws()

    # D_opt ∝ N_opt^(alpha/beta), so doubling N multiplies the optimal D by
    # 2^(alpha/beta); derived from the same defaults used in demo 3.
    data_multiplier = 2 ** (1 / compute_optimal_ratio())

    print("\n" + "=" * 60)
    print("Key Takeaways:")
    print("=" * 60)
    print("1. Loss scales as power laws: L ~ N^(-α) and L ~ D^(-β)")
    print("2. Chinchilla law: For compute budget C, optimal N ∝ C^0.45, D ∝ C^0.55")
    print("3. Many pre-Chinchilla models (e.g. GPT-3) were undertrained (too many params, too little data)")
    print(f"4. Doubling model size requires ~{data_multiplier:.1f}x more data for optimality")
    print("=" * 60)