18-07: Scaling Laws

Phase: 18 (Large Language Model Mathematics)
Subject: 18-07
Prerequisites: 18-06 (Pre-training Objective Mathematics), 12-09 (Regression: fitting power laws), 03-05 (Exponential and Logarithmic Functions), 14-01 (Optimization Fundamentals)
Next subject: 18-08 (Inference Mathematics)


Learning Objectives

By the end of this subject, you will be able to:

  1. Explain the Kaplan scaling law L(N, D) as an empirical power law and interpret its parameters
  2. Compute optimal parameter/data ratios under the Chinchilla scaling law (D ≈ 20N)
  3. Explain the relationship between compute C, parameters N, and data D: C ≈ 6ND
  4. Derive the critical batch size from gradient noise scale and explain its role in scaling
  5. Apply compute-optimal training to determine the best model size for a given compute budget

Core Content

1. The Scaling Law Hypothesis

Empirical finding: Language model test loss follows predictable power-law relationships with model size N (parameters), dataset size D (tokens), and compute C (FLOPs).

L(N, D) ≈ (N_c / N)^(α_N) + (D_c / D)^(α_D) + L_∞

where:
- N = number of non-embedding parameters
- D = number of training tokens
- L_∞ = irreducible loss (entropy of language + unavoidable approximation error)
- N_c, D_c = constants (characteristic scales)
- α_N, α_D = power-law exponents

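In log-space, when one term dominates (for example, D is so large that (D_c/D)^(α_D) is negligible): log(L − L_∞) ≈ α_N·log N_c − α_N·log N, a straight line with slope −α_N. This is why scaling laws are fit on log-log plots. With both terms present, log(L − L_∞) = log[(N_c/N)^(α_N) + (D_c/D)^(α_D)], which does not split into a sum of logs, so the curve bends as the other term takes over.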
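A quick numerical check of the log-log claim, using the Chinchilla parametric constants quoted later in this subject as stand-in values (any two-term law behaves the same way; the fixed 1T-token dataset is a hypothetical choice). The local slope d log(L − L_∞)/d log N equals −α only while the N term dominates, and flattens toward 0 as the data term takes over.

```python
import math

# Stand-in constants (the Chinchilla parametric fit); N and D are raw counts.
A, B, ALPHA, BETA = 406.4, 410.7, 0.34, 0.28


def excess_loss(n: float, d: float) -> float:
    """Reducible loss L - L_inf = A / N^alpha + B / D^beta."""
    return A / n**ALPHA + B / d**BETA


def local_log_slope(n: float, d: float, rel_step: float = 1e-4) -> float:
    """Central-difference estimate of d log(L - L_inf) / d log N at fixed D."""
    hi, lo = n * (1 + rel_step), n * (1 - rel_step)
    rise = math.log(excess_loss(hi, d)) - math.log(excess_loss(lo, d))
    return rise / (math.log(hi) - math.log(lo))


d_fixed = 1e12  # hypothetical fixed dataset of 1T tokens
for n in (1e7, 1e9, 1e11, 1e13):
    print(f"N = {n:.0e}: local slope = {local_log_slope(n, d_fixed):+.3f} "
          f"(pure N-term slope would be {-ALPHA})")
```

Analytically the slope is −α·t_N/(t_N + t_D), where t_N and t_D are the two excess-loss terms, so it only matches −α when t_D is negligible.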

2. Kaplan et al. (2020): OpenAI Scaling Laws

The original scaling law paper from OpenAI studied GPT-style models and found:

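Model size scaling: for models trained to convergence on sufficiently large datasets (so data is not the bottleneck):

L(N) ≈ (N_c / N)^(α_N)

With α_N ≈ 0.076, N_c ≈ 8.8 × 10^13 non-embedding parameters.

Data size scaling: for large models trained on a limited dataset with early stopping:

L(D) ≈ (D_c / D)^(α_D)

With α_D ≈ 0.095, D_c ≈ 5.4 × 10^13 tokens.

Compute scaling: since C ≈ 6ND (6 FLOPs per parameter per token for a Transformer forward and backward pass):

L(C) ≈ (C_c / C)^(α_C)

With α_C ≈ 0.050-0.057.

Kaplan's fits are pure power laws with no separate irreducible term (effectively L_∞ = 0), so these constants should not be combined with an L_∞ offset.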

⚠️ THIS IS CRITICAL: the Kaplan paper's key recommendation was to increase model size more aggressively than data size. For a 10× increase in compute, increase N by ~5.5× and D by only ~1.8×. This was the dominant approach until Chinchilla.

3. Chinchilla (Hoffmann et al., 2022): Compute-Optimal Scaling

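DeepMind's Chinchilla paper revisited scaling laws with a key methodological improvement: it varied BOTH model size AND the number of training tokens systematically, matching each run's cosine learning-rate schedule to its training length. The paper estimates the compute-optimal allocation in three ways; two are shown here.

Parametric fit (the paper's Approach 3):

L(N, D) = E + A/N^α + B/D^β

Fitted parameters: E ≈ 1.69, A ≈ 406.4, B ≈ 410.7, α ≈ 0.34, β ≈ 0.28 (N and D are raw counts of parameters and tokens, not billions).

IsoFLOP analysis (the paper's Approach 2): for each fixed compute budget C, train several model sizes and find the (N_opt, D_opt) that minimizes loss subject to C ≈ 6ND.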

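Result: N_opt ∝ C^0.5 and D_opt ∝ C^0.5 (approximately): parameters and tokens scale EQUALLY with compute.

Equal exponents alone only say that the ratio D_opt/N_opt stays constant as compute grows; combined with the fitted proportionality constants, the ratio comes out at D_opt ≈ 20 · N_opt (about 20 training tokens per parameter).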
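The parametric fit L(N, D) = E + A/N^α + B/D^β can also be minimized in closed form under C = 6ND: substituting D = C/(6N) and setting dL/dN = 0 gives N_opt = G·(C/6)^(β/(α+β)), with G = (αA/(βB))^(1/(α+β)). The sketch below evaluates this with the rounded constants above. With these rounded values the exponent is 0.28/0.62 ≈ 0.45, close to the 0.46 the paper reports for this approach, so the parametric fit leans somewhat more toward data than the isoFLOP result.

```python
# Rounded Chinchilla parametric-fit constants; N and D are raw counts, not billions.
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    """L(N, D) = E + A / N^alpha + B / D^beta."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA


def parametric_optimum(compute_flops: float) -> tuple[float, float]:
    """Minimize L(N, D) subject to C = 6ND, using the closed form from dL/dN = 0."""
    g = (ALPHA * A / (BETA * B)) ** (1 / (ALPHA + BETA))
    a = BETA / (ALPHA + BETA)               # N_opt grows as C^a
    n_opt = g * (compute_flops / 6) ** a
    d_opt = compute_flops / (6 * n_opt)     # spend the rest of the budget on tokens
    return n_opt, d_opt


n, d = parametric_optimum(5e24)
print(f"N_opt grows as C^{BETA / (ALPHA + BETA):.3f}")
print(f"at C = 5e24: N ≈ {n:.3g}, D ≈ {d:.3g}, tokens/param ≈ {d / n:.0f}")
print(f"predicted loss at the optimum ≈ {chinchilla_loss(n, d):.4f}")
```

At 5 × 10^24 FLOPs this rounded fit puts the optimum near 85B parameters and about 114 tokens per parameter, well above the 20:1 rule, which is worth keeping in mind when mixing results from the different approaches.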

4. Kaplan vs. Chinchilla: The Resolution

The apparent contradiction arises from different experimental designs:

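Aspect | Kaplan | Chinchilla
Learning-rate schedule | Same schedule length for every run, regardless of token count | Cosine decay to ~0, with the cycle matched to each run's token count
Model size range | Mostly small models (768 to 1.5B non-embedding parameters) | Over 400 models, 70M to over 16B parameters
Analysis method | Separate power laws | Joint parametric fit and isoFLOP profiles
Recommendation | N grows faster than D | N and D grow equally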

Chinchilla's conclusion: many large models of the time were undertrained, with too many parameters for the amount of data they saw. Chinchilla itself, a 70B model trained on 1.4T tokens (Chinchilla-optimal: D = 20 × 70B = 1.4T), outperforms Gopher, a 280B model trained on only 300B tokens, at roughly the same training compute.
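As a rough consistency check, the sketch below applies C ≈ 6ND and the rounded parametric fit from Section 3 to the two models. These are what the formula predicts, not the measured losses, and 6ND puts the two budgets within about 20% of each other rather than exactly equal.

```python
# Rounded Chinchilla parametric-fit constants; N and D are raw counts.
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    """L(N, D) = E + A / N^alpha + B / D^beta."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA


models = {
    "Chinchilla (70B, 1.4T)": (70e9, 1.4e12),
    "Gopher (280B, 300B)": (280e9, 300e9),
}
for name, (n, d) in models.items():
    print(f"{name:<24} C ≈ 6ND = {6 * n * d:.2e} FLOPs, predicted L ≈ {chinchilla_loss(n, d):.3f}")
```

The smaller model wins mainly through the data term: B/D^β is about 0.16 for 1.4T tokens versus about 0.25 for 300B, while the larger model's advantage in the parameter term is only about 0.03.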

5. The Compute Relationship

The training compute for a Transformer:

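C_forward ≈ 2 · N · D (for a decoder-only Transformer)

The forward pass costs about 2 FLOPs per parameter per token (one multiply and one add for each weight).

The backward pass costs about twice the forward pass, because each weight matrix takes part in two matrix multiplies there: one for the gradient with respect to its input activations and one for the gradient with respect to the weights. So C_backward ≈ 2 · C_forward = 4 · N · D, and

C_total = C_forward + C_backward ≈ 3 · C_forward = 6 · N · D

This counts only the parameter matrix multiplies. It omits attention-score FLOPs (which grow with context length), LayerNorm, and other element-wise operations; Kaplan et al. note that the attention term is a small fraction of the total when d_model is large compared with n_ctx/12.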
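To see how small the omitted attention term usually is, the sketch below uses the per-token forward cost from Kaplan et al., C_forward ≈ 2N + 2·n_layer·n_ctx·d_model, with N ≈ 12·n_layer·d_model² (this assumes d_attn = d_model and d_ff = 4·d_model, and ignores embeddings and biases). The GPT-3-175B-shaped configuration is only an example input.

```python
def non_embedding_params(n_layer: int, d_model: int) -> int:
    """N ≈ 12 · n_layer · d_model², assuming d_attn = d_model and d_ff = 4 · d_model."""
    return 12 * n_layer * d_model**2


def training_flops_per_token(n_layer: int, d_model: int, n_ctx: int) -> tuple[int, int]:
    """Return (6N estimate, estimate that also counts attention scores).

    Forward pass per token ≈ 2N + 2 · n_layer · n_ctx · d_model (Kaplan et al.);
    the backward pass is taken as 2x the forward pass, so training ≈ 3x forward.
    """
    n = non_embedding_params(n_layer, d_model)
    forward_params_only = 2 * n
    forward_with_attention = 2 * n + 2 * n_layer * n_ctx * d_model
    return 3 * forward_params_only, 3 * forward_with_attention


# Example input: a GPT-3-175B-shaped config (96 layers, d_model = 12288, 2048-token context).
simple, with_attn = training_flops_per_token(n_layer=96, d_model=12288, n_ctx=2048)
print(f"6N per token:             {simple:.3e} FLOPs")
print(f"with attention scores:    {with_attn:.3e} FLOPs")
print(f"attention-score overhead: {with_attn / simple - 1:.2%}")
```

The overhead works out to n_ctx / (12 · d_model), about 1.4% for this shape; it grows for long contexts or narrow models, which is one reason to profile rather than rely on 6ND alone.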

For Chinchilla-optimal: D = 20N, so C ≈ 6·N·(20N) = 120·N².

Solving for N given C: N_opt ≈ √(C / 120).

Example: for C = 10^23 FLOPs (roughly GPT-3 scale):
N_opt ≈ √(10^23 / 120) ≈ √(8.33×10^20) ≈ 2.89×10^10 ≈ 29B parameters
D_opt ≈ 20·29B ≈ 580B tokens
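The rule-of-thumb sizing above fits in a few lines. This sketch hard-codes the two approximations from this section (C ≈ 6ND and D ≈ 20N); it is a back-of-envelope helper, not a planning tool.

```python
import math

FLOPS_PER_PARAM_TOKEN = 6   # C ≈ 6ND
TOKENS_PER_PARAM = 20       # Chinchilla rule of thumb, D ≈ 20N


def chinchilla_optimal(compute_flops: float) -> tuple[float, float]:
    """Return (N_opt, D_opt) from C = 6ND and D = 20N, i.e. N = sqrt(C / 120)."""
    n_opt = math.sqrt(compute_flops / (FLOPS_PER_PARAM_TOKEN * TOKENS_PER_PARAM))
    return n_opt, TOKENS_PER_PARAM * n_opt


for c in (1e21, 1e23, 5e24):
    n, d = chinchilla_optimal(c)
    print(f"C = {c:.0e} FLOPs -> N ≈ {n / 1e9:.1f}B params, D ≈ {d / 1e9:,.0f}B tokens")
```

Because N grows as √C, a 100× larger budget buys only a 10× larger model (and 10× more tokens).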

6. Critical Batch Size

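The gradient noise scale (McCandlish et al., 2018) estimates how large a batch can grow before returns diminish. Its simplest form is

B_crit ≈ B_simple = tr(Σ) / |G|²

where Σ is the per-example gradient covariance matrix and G is the true (full-batch) gradient. Below B_crit, doubling the batch nearly halves the number of optimizer steps needed; above it, the extra examples mostly average away noise that was already small. To reach a given loss, the number of optimizer steps S and the number of training tokens processed E trade off as

S = S_min · (1 + B_crit / B),  E = E_min · (1 + B / B_crit)

so training at B = B_crit costs 2× the minimum number of steps and 2× the minimum data.

The critical batch size grows as the loss falls. Kaplan et al. fit

B_crit(L) ≈ B_* / L^(1/α_B),  with B_* ≈ 2 × 10^8 tokens and α_B ≈ 0.21

Since 1/α_B ≈ 4.8, B_crit rises steeply as training drives the loss down: the true gradient shrinks relative to the per-example noise, so larger batches stay useful.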

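Practical implication: for LLM training, B_crit is typically 0.5-2M tokens. Large runs often use batches of a few million tokens (~4M is common), at or somewhat above B_crit, accepting some loss of data efficiency in exchange for fewer, more parallelizable steps. Because B_crit grows only as the loss falls, and the loss falls slowly, a fixed batch size (sometimes with a short ramp-up early in training) works well; you don't need exponentially growing batch sizes.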
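A sketch of the two batch-size relations above. B_* ≈ 2 × 10^8 tokens and α_B ≈ 0.21 are Kaplan et al.'s fitted values, with the loss measured in nats per token on their own dataset and tokenizer, so the printed B_crit values are illustrative rather than transferable to other setups.

```python
B_STAR = 2e8     # tokens (Kaplan et al. fit)
ALPHA_B = 0.21   # exponent from the same fit


def critical_batch_tokens(loss_nats: float) -> float:
    """B_crit(L) ≈ B_* / L^(1/alpha_B), in tokens."""
    return B_STAR / loss_nats ** (1 / ALPHA_B)


def steps_and_tokens(batch: float, b_crit: float, s_min: float, e_min: float) -> tuple[float, float]:
    """Steps and tokens to reach a fixed loss: S = S_min(1 + B_crit/B), E = E_min(1 + B/B_crit)."""
    return s_min * (1 + b_crit / batch), e_min * (1 + batch / b_crit)


for loss in (3.5, 3.0, 2.5):
    print(f"L = {loss:.1f} nats -> B_crit ≈ {critical_batch_tokens(loss) / 1e6:.2f}M tokens")

# Training exactly at B_crit costs twice the minimum steps and twice the minimum data.
print(steps_and_tokens(batch=1e6, b_crit=1e6, s_min=1.0, e_min=1.0))
```

Going from 3.5 to 2.5 nats multiplies B_crit by roughly 5, which is why the batch can grow over a run even though no exponential schedule is needed.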

7. Scaling Law Implications

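Pre-training compute allocation: for a fixed compute budget C:
- Chinchilla-optimal: scale N and D in equal proportion (both ∝ C^0.5, about 20 tokens per parameter)
- Overtraining (smaller model, more data than Chinchilla-optimal): slightly higher loss for the same training compute, but a smaller model that is cheaper and faster at inference
- Undertraining (larger model, less data than Chinchilla-optimal): higher loss for the same compute, and a larger model that is more expensive to serve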

Emergent abilities: capabilities that appear only above certain scale thresholds. Scaling laws predict smooth, continuous improvement in loss, but task metrics such as exact-match accuracy can turn that smooth improvement into an apparently step-like jump. Some argue that most reported emergence is an artifact of such metric choices rather than a genuine discontinuity.
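A toy illustration of how a smooth improvement can look step-like: suppose a task is scored by exact match on a 10-token answer and per-token errors are independent (both simplifying assumptions; the per-token accuracies below are hypothetical). Exact-match accuracy is then p^10, which stays near zero until the per-token accuracy p is high and then climbs quickly.

```python
def exact_match_rate(per_token_acc: float, answer_len: int) -> float:
    """Toy model: the answer counts as correct only if every token is right,
    with independent per-token errors (a simplifying assumption)."""
    return per_token_acc ** answer_len


# Per-token accuracy improving smoothly with scale (hypothetical values).
for p in (0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99):
    print(f"per-token accuracy {p:.2f} -> 10-token exact match {exact_match_rate(p, 10):.3f}")
```

Plotted against scale, the per-token curve rises gradually while the exact-match curve sits near zero and then shoots up, even though nothing discontinuous happened underneath.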

Scaling law extrapolation: fit power laws on small models (up to ~1B parameters) and extrapolate to predict the performance of much larger models. Labs use forecasts like this to judge whether a larger training run is worth its compute.
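A minimal sketch of the fit-and-extrapolate workflow, assuming L_∞ is already known (here 1.69) so the fit reduces to a straight line on log-log axes. The small-model losses are synthetic: generated from 1.69 + 406.4/N^0.34 with a fixed ±1% wobble standing in for run-to-run noise. Real scaling studies usually fit L_∞ jointly with a nonlinear least-squares routine and check the extrapolation on held-out model sizes.

```python
import math


def fit_power_law(ns: list[float], losses: list[float], l_inf: float) -> tuple[float, float]:
    """Fit L - l_inf = k * N^(-alpha) by ordinary least squares on log-log axes.

    Returns (k, alpha).
    """
    xs = [math.log(n) for n in ns]
    ys = [math.log(loss - l_inf) for loss in losses]
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    slope = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    intercept = my - slope * mx
    return math.exp(intercept), -slope


def predict(n: float, k: float, alpha: float, l_inf: float) -> float:
    """Extrapolated loss L(N) = l_inf + k * N^(-alpha)."""
    return l_inf + k * n ** (-alpha)


L_INF = 1.69
small_ns = [1e7, 3e7, 1e8, 3e8, 1e9]          # synthetic "small model" sizes
wobble = [1.01, 0.99, 1.005, 0.995, 1.0]       # fixed stand-in for measurement noise
small_losses = [L_INF + w * 406.4 / n**0.34 for n, w in zip(small_ns, wobble)]

k, alpha = fit_power_law(small_ns, small_losses, L_INF)
print(f"fitted alpha ≈ {alpha:.4f}, k ≈ {k:.1f}")
print(f"extrapolated L(70B) ≈ {predict(7e10, k, alpha, L_INF):.4f}, "
      f"generating law gives {L_INF + 406.4 / 7e10**0.34:.4f}")
```

Extrapolating two orders of magnitude beyond the fitted range amplifies any error in the slope, so small noise in the exponent matters more than noise in the individual losses.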



Pitfalls

⚠️ Pitfall 1: Applying Kaplan's recommendation to Chinchilla-era training. Kaplan said "increase model size more than data." Chinchilla attributed this mainly to Kaplan's fixed learning-rate schedule length, which made training on more tokens look less valuable than it is. For modern training, D ≈ 20N is compute-optimal; by that rule, training a 70B model on 300B tokens (about 4 tokens per parameter) is severely undertrained.

⚠️ Pitfall 2: Forgetting that C ≈ 6ND is an approximation. The exact constant depends on architecture details (SwiGLU vs ReLU, vocabulary size, sequence length). Use 6ND for back-of-envelope estimates; for precise planning, profile your actual implementation.

⚠️ Pitfall 3: Confusing Chinchilla-optimal with best-for-inference. Chinchilla-optimal minimizes pre-training loss for a fixed training-compute budget. For deployment, though, overtrained smaller models (such as Llama 2 7B, trained on 2T tokens) can be the better choice: faster, cheaper inference at slightly higher loss.



Worked Examples

Example 1: Compute-Optimal Model Size

Problem: You have a compute budget of C = 5 × 10^24 FLOPs. Using Chinchilla scaling (D ≈ 20N, C ≈ 6ND), what are the optimal N and D? How does the predicted loss compare with a 150B-parameter model trained on the same budget?

Solution:

C ≈ 6ND = 6N(20N) = 120N²
N_opt = √(C/120) = √(5×10^24/120) = √(4.167×10^22) ≈ 2.04×10^11 = 204B params
D_opt = 20N = 20·204B ≈ 4.08T tokens

Check: 6·204B·4.08T = 6·204·4.08 × 10^(9+12) = 6·832.3 × 10^21 ≈ 4.99×10^24 ✓

Non-optimal baseline: a 150B model trained on 500B tokens would use only C = 6·150B·500B = 4.5×10^23 FLOPs, about a tenth of the budget, so it is not a fair comparison. Spending the full budget on a 150B model instead gives D = 5×10^24 / (6·150×10^9) = 5×10^24 / (9×10^11) ≈ 5.56T tokens.

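Loss comparison, using the Chinchilla parametric fit L ≈ 1.69 + 406.4/N^0.34 + 410.7/D^0.28 (N and D are raw counts, not billions):

20:1 allocation (204B, 4.08T): 406.4/(2.04×10^11)^0.34 ≈ 406.4/7004 ≈ 0.058; 410.7/(4.08×10^12)^0.28 ≈ 410.7/3397 ≈ 0.121; L ≈ 1.69 + 0.058 + 0.121 ≈ 1.869

150B on the same budget (150B, 5.56T): 406.4/(1.5×10^11)^0.34 ≈ 406.4/6308 ≈ 0.064; 410.7/(5.56×10^12)^0.28 ≈ 410.7/3703 ≈ 0.111; L ≈ 1.69 + 0.064 + 0.111 ≈ 1.865

The two differ by only about 0.004 nats (a perplexity ratio of exp(0.004) ≈ 1.004), and under this fit the smaller, longer-trained model actually comes out slightly ahead. That is not a contradiction: the 20:1 rule comes from the isoFLOP analysis, while the parametric fit, with these rounded constants, puts the optimum at this budget nearer 85B parameters and about 114 tokens per parameter. Either way, loss near the Chinchilla optimum is flat, so modest deviations from the exact ratio cost very little.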
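The arithmetic in this example can be checked with a few lines of Python; the constants are the rounded Chinchilla fit, with N and D passed as raw counts.

```python
import math

E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28  # rounded Chinchilla fit


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    """Parametric fit L(N, D) = E + A/N^alpha + B/D^beta, with N and D as raw counts."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA


C = 5e24
n_rule = math.sqrt(C / 120)  # 20:1 rule, about 204B params
configs = {
    "20:1 rule": (n_rule, 20 * n_rule),
    "150B, same budget": (150e9, C / (6 * 150e9)),
}
for name, (n, d) in configs.items():
    print(f"{name:>18}: N = {n:.3g}, D = {d:.3g}, C = {6 * n * d:.2e}, L = {chinchilla_loss(n, d):.4f}")
```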

Example 2: Predicting Loss from Data

Problem: Using the Chinchilla data term only, L(D) ≈ 1.69 + 410/D^0.28, predict the loss when training on 100B vs. 1T tokens. What is the loss reduction?

Solution:

D = 100B = 10^11 tokens (the constants expect raw token counts, not billions): (10^11)^0.28 = 10^3.08 ≈ 1202, so L(100B) ≈ 1.69 + 410/1202 ≈ 1.69 + 0.341 ≈ 2.031

D = 1T = 10^12 tokens: (10^12)^0.28 = 10^3.36 ≈ 2291, so L(1T) ≈ 1.69 + 410/2291 ≈ 1.69 + 0.179 ≈ 1.869

Loss reduction: 2.031 − 1.869 ≈ 0.162 nats. A 10× increase in data multiplies the data term by 10^(−0.28) ≈ 0.525, roughly halving the reducible loss. Perplexity: exp(2.031) ≈ 7.6 → exp(1.869) ≈ 6.5, about a 15% reduction. Because only the data term is used, this describes a model large enough that the parameter term A/N^α is negligible.
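The same data-term-only calculation as a sketch (E = 1.69, B = 410 and β = 0.28 as given in the problem; D is a raw token count):

```python
import math


def data_limited_loss(n_tokens: float, e: float = 1.69, b: float = 410.0, beta: float = 0.28) -> float:
    """Chinchilla data term only: L(D) = E + B / D^beta (assumes the N term is negligible)."""
    return e + b / n_tokens**beta


l_100b = data_limited_loss(1e11)
l_1t = data_limited_loss(1e12)
print(f"L(100B) = {l_100b:.3f}, L(1T) = {l_1t:.3f}, reduction = {l_100b - l_1t:.3f} nats")
print(f"perplexity: {math.exp(l_100b):.2f} -> {math.exp(l_1t):.2f}")
```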

Example 3: Critical Batch Size

Problem: During training, the gradient noise scale B_crit is estimated to grow from 0.5M to 2M tokens. For a model training on 1T tokens, compute how many optimizer steps are needed at batch sizes B = 0.5M, 1M, 2M, and 4M.

Batch size B (tokens) | Steps = D/B
0.5M | 1T/0.5M = 2,000,000
1M | 1,000,000
2M | 500,000
4M | 250,000

A critical batch size of 2M means that beyond about 2M tokens per batch, returns diminish. At B = 4M the run takes half as many optimizer steps over the same 1T tokens, but each token contributes less, so the final loss will be somewhat higher. Wall-clock time can still come out lower if per-step overhead, such as gradient communication, dominates.

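Optimal strategy: use B ≈ B_crit = 2M for most of training, which gives 500K steps. Early in training, while B_crit is still closer to 0.5M, a smaller batch or a batch-size ramp uses data more efficiently.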
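A sketch of the batch-size arithmetic in this example. It holds B_crit fixed at 2M tokens (an assumption; the problem says it grows from 0.5M) and uses the McCandlish et al. relations S ∝ 1 + B_crit/B and E ∝ 1 + B/B_crit for the steps and tokens needed to reach a fixed loss, normalized to training at B = B_crit.

```python
def batch_tradeoff(batch: float, b_crit: float, total_tokens: float) -> tuple[float, float, float]:
    """Return (steps over a fixed token budget,
    steps to reach a fixed loss relative to B = b_crit,
    tokens to reach a fixed loss relative to B = b_crit)."""
    steps_fixed_budget = total_tokens / batch
    rel_steps = (1 + b_crit / batch) / 2    # S ∝ 1 + B_crit / B, normalized by its value at B = B_crit
    rel_tokens = (1 + batch / b_crit) / 2   # E ∝ 1 + B / B_crit, normalized the same way
    return steps_fixed_budget, rel_steps, rel_tokens


B_CRIT = 2e6          # tokens; the example's late-training estimate, held fixed (assumption)
TOTAL_TOKENS = 1e12

for batch in (0.5e6, 1e6, 2e6, 4e6):
    steps, rel_s, rel_e = batch_tradeoff(batch, B_CRIT, TOTAL_TOKENS)
    print(f"B = {batch / 1e6:.1f}M: {steps:,.0f} steps over 1T tokens | "
          f"to reach a fixed loss: {rel_s:.2f}x steps, {rel_e:.2f}x tokens vs. B = B_crit")
```

Under these assumptions, doubling from 2M to 4M cuts the steps needed for a given loss to 0.75× but raises the tokens needed to 1.5×, which is the quantitative version of "diminishing returns" above B_crit.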



Summary

  1. Scaling laws have the form L(N,D) ≈ L_∞ + A/N^α + B/D^β, fit on log-log plots, with L_∞ representing irreducible loss
  2. Chinchilla-optimal: D ≈ 20N (tokens ≈ 20× parameters), implying N ∝ √C and D ∝ √C, i.e. equal scaling of parameters and data
  3. Training compute: C ≈ 6ND for Transformers (about 2 forward + 4 backward FLOPs per parameter per token)
  4. Critical batch size B_crit grows during training, typically reaching 0.5-2M tokens; beyond this, larger batches yield diminishing returns
  5. Overtraining (D ≫ 20N) trades some pre-training loss for faster inference; undertraining wastes compute

Quiz

Q1: What does the Chinchilla scaling law recommend for the ratio of training tokens to parameters?

A) D ≈ N (1:1)
B) D ≈ 20N (20:1)
C) D ≈ 100N (100:1)
D) D ≈ N²

Correct: B)

Q3: According to the Chinchilla authors, what mainly explains why Chinchilla's compute-optimal allocation (N_opt ∝ C^0.5) differs from Kaplan's (N_opt ∝ C^0.73)?

A) Chinchilla used bigger computers
B) Chinchilla matched each run's learning-rate schedule to its training length, while Kaplan used one fixed schedule length for all runs
C) Chinchilla used a different architecture
D) The exponents measure different things

Correct: B)

Q5: For a 7B model trained Chinchilla-optimally (140B tokens), what happens if you instead train on 2T tokens?

A) The model will overfit and perform worse
B) The model will continue to improve (it's "overtrained"), but each additional token gives diminishing returns
C) Training will crash
D) The model size must also increase

Correct: B) Language models trained on fresh (non-repeated) text don't overfit in the classical sense: the loss keeps decreasing along the power law. Training on 2T tokens instead of 140B (about 14× more) reduces loss further, though each additional token helps less. This is deliberate "overtraining": compared with a compute-optimal model of the same quality, it gives a smaller model (cheaper inference) at the cost of more training compute. The tradeoff favors inference-heavy deployments.




Next Steps

Continue to 18-08 (Inference Mathematics) to understand how trained models generate text, including sampling strategies, beam search, and log-probability computation.