18-07: Scaling Laws
Phase: 18 (Large Language Model Mathematics)
Subject: 18-07
Prerequisites: 18-06 (Pre-training Objective Mathematics), 12-09 (Regression: fitting power laws), 03-05 (Exponential and Logarithmic Functions), 14-01 (Optimization Fundamentals)
Next subject: 18-08 (Inference Mathematics)
Learning Objectives
By the end of this subject, you will be able to:
- Derive the Kaplan scaling law L(N, D) from first principles of power-law behavior
- Compute optimal parameter/data ratios under the Chinchilla scaling law (D ≈ 20N)
- Explain the relationship between compute C, parameters N, and data D: C ≈ 6ND
- Derive the critical batch size from gradient noise scale and explain its role in scaling
- Apply compute-optimal training to determine the best model size for a given compute budget
Core Content
1. The Scaling Law Hypothesis
Empirical finding: Language model test loss follows predictable power-law relationships with model size N (parameters), dataset size D (tokens), and compute C (FLOPs).
L(N, D) ≈ (N_c / N)^(α_N) + (D_c / D)^(α_D) + L_∞
where:
- N = number of non-embedding parameters
- D = number of training tokens
- L_∞ = irreducible loss (entropy of language + unavoidable approximation error)
- N_c, D_c = constants (characteristic scales)
- α_N, α_D = power-law exponents
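On a log-log plot: when one term dominates (for example, data is plentiful so the D term is negligible), log(L − L_∞) ≈ α_N·log N_c − α_N·log N, a straight line with slope −α_N. The log of a sum is not the sum of the logs, so this linearity holds only where a single term dominates. That straight-line behavior is why scaling laws are fit on log-log plots.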
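A minimal sketch of the log-log fit, assuming L_∞ is already known so the reducible part L − L_∞ can be logged directly. The helper name `fit_power_law` is hypothetical, and the points are synthetic and noise-free (generated from α_N ≈ 0.076 and N_c ≈ 8.8 × 10^13); real fits use noisy measured losses and usually need more care.

```python
import math


def fit_power_law(sizes, reducible_losses):
    """Fit L - L_inf = (N_c / N)**alpha by least squares in log-log space.

    log(L - L_inf) = alpha*log(N_c) - alpha*log(N) is a line with slope -alpha,
    so an ordinary linear regression on the logs recovers alpha and N_c.
    """
    xs = [math.log(n) for n in sizes]
    ys = [math.log(r) for r in reducible_losses]
    x_bar = sum(xs) / len(xs)
    y_bar = sum(ys) / len(ys)
    slope = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys)) / sum(
        (x - x_bar) ** 2 for x in xs
    )
    alpha = -slope
    intercept = y_bar - slope * x_bar  # equals alpha * log(N_c)
    n_c = math.exp(intercept / alpha)
    return alpha, n_c


# Synthetic reducible losses generated with alpha_N = 0.076 and N_c = 8.8e13.
sizes = [1e6, 1e7, 1e8, 1e9]
reducible = [(8.8e13 / n) ** 0.076 for n in sizes]
alpha, n_c = fit_power_law(sizes, reducible)
print(f"alpha = {alpha:.4f}, N_c = {n_c:.3g}")
```

Extrapolating to a larger model is then just evaluating (n_c / N) ** alpha at the new N.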
2. Kaplan et al. (2020): OpenAI Scaling Laws
The original scaling law paper from OpenAI studied GPT-style models and found:
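Model size scaling: For models trained to convergence on datasets large enough that data is not the bottleneck:
L(N) ≈ (N_c / N)^(α_N)
With α_N ≈ 0.076, N_c ≈ 8.8 × 10^13 (non-embedding parameters).
Data size scaling: For large models trained on a limited dataset with early stopping:
L(D) ≈ (D_c / D)^(α_D)
With α_D ≈ 0.095, D_c ≈ 5.4 × 10^13 tokens.
Compute scaling: Since C ≈ 6ND (6 FLOPs per parameter per token for the transformer forward and backward passes):
L(C) ≈ (C_c / C)^(α_C)
With α_C ≈ 0.050-0.057 (0.050 for compute-efficient training).
Kaplan's fits are pure power laws with no explicit L_∞ term; an irreducible-loss term appears in later fits such as Chinchilla's.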
⚠️ THIS IS CRITICAL: The Kaplan paper's key recommendation was to increase model size more aggressively than data size. Their fit gives N_opt ∝ C^0.73, so for a 10× increase in compute, N grows by ~5.4× and D by only ~1.9×. This was the dominant approach until Chinchilla.
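The allocation rule is exponent arithmetic: if N_opt ∝ C^a and C ≈ 6ND, then D_opt ∝ C^(1−a). A small sketch, assuming a ≈ 0.73 for Kaplan and a ≈ 0.50 for Chinchilla; `scale_factors` is a hypothetical helper name.

```python
def scale_factors(compute_multiplier: float, n_exponent: float) -> tuple[float, float]:
    """Growth of N and D when compute grows by `compute_multiplier`.

    Assumes N_opt is proportional to C**n_exponent; with C ~ 6*N*D,
    D_opt must then be proportional to C**(1 - n_exponent).
    """
    return compute_multiplier ** n_exponent, compute_multiplier ** (1 - n_exponent)


for label, a in (("Kaplan, a = 0.73", 0.73), ("Chinchilla, a = 0.50", 0.50)):
    n_factor, d_factor = scale_factors(10.0, a)
    print(f"{label}: 10x compute -> N x{n_factor:.2f}, D x{d_factor:.2f}")
```

The two factors always multiply back to the compute multiplier, which is the only constraint C ≈ 6ND imposes; the exponent decides how that factor is split.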
3. Chinchilla (Hoffmann et al., 2022): Compute-Optimal Scaling
DeepMind's Chinchilla paper revisited scaling laws with a key methodological improvement: they varied BOTH model size AND data size systematically, fitting a joint parametric loss:
Parametric fit (the paper's Approach 3):
L(N, D) = E + A/N^(α) + B/D^(β)
Fitted parameters: E ≈ 1.69, A ≈ 406.4, B ≈ 410.7, α ≈ 0.34, β ≈ 0.28, with N and D as raw counts (parameters and tokens, not billions).
IsoFLOP analysis (the paper's Approach 2): For each fixed compute budget C, find (N_opt, D_opt) that minimizes loss subject to C ≈ 6ND.
Result: N_opt ∝ C^0.50 and D_opt ∝ C^0.50, so parameters and tokens scale equally with compute.
The exponents fix only the growth rate; the fitted constants set the ratio, giving D_opt ≈ 20 · N_opt (about 20 training tokens per parameter).
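The isoFLOP procedure can be sketched by sweeping N at fixed C, setting D = C/(6N), and taking the lowest predicted loss. This sketch uses the rounded parametric fit as a stand-in for measured losses, which is an assumption: with these rounded constants the sweep at 5 × 10^24 FLOPs lands near 85B parameters and roughly 114 tokens per parameter, not at 20:1. The 20:1 rule of thumb comes from the isoFLOP and training-curve analyses, so treat this as an illustration of the procedure rather than of the paper's number. The helper names are hypothetical.

```python
import math

# Rounded Chinchilla parametric fit; N and D are raw counts.
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    return E + A / n_params**ALPHA + B / n_tokens**BETA


def isoflop_minimum(compute_flops: float, n_lo: float = 1e8, n_hi: float = 1e13, points: int = 4001):
    """Sweep N on a log grid, set D = C / (6N), and return the lowest-loss (N, D, L)."""
    lo, hi = math.log10(n_lo), math.log10(n_hi)
    best = None
    for i in range(points):
        n = 10 ** (lo + (hi - lo) * i / (points - 1))
        d = compute_flops / (6 * n)
        loss = chinchilla_loss(n, d)
        if best is None or loss < best[2]:
            best = (n, d, loss)
    return best


n_best, d_best, l_best = isoflop_minimum(5e24)
# Closed-form optimum of the same fit: N = G * (C/6)**(beta / (alpha + beta)).
g = (ALPHA * A / (BETA * B)) ** (1 / (ALPHA + BETA))
n_closed = g * (5e24 / 6) ** (BETA / (ALPHA + BETA))
print(f"grid N={n_best:.3g}, closed-form N={n_closed:.3g}, D/N={d_best / n_best:.0f}, L={l_best:.4f}")
```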
4. Kaplan vs. Chinchilla: The Resolution
The apparent contradiction arises from different experimental designs:
| Aspect | Kaplan | Chinchilla |
|---|---|---|
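| Learning rate schedule | One schedule length for all runs; smaller token counts read from intermediate checkpoints | Cosine schedule length matched to each run's token count |
| Model size range | 768 to ~1.5B non-embedding parameters | 70M to over 16B parameters, 5B to 500B tokens |
| Analysis method | Separate power laws in N, D, and C | IsoFLOP profiles, training-curve envelopes, and a joint parametric fit |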
| Recommendation | N grows faster than D | N and D grow equally |
Chinchilla's conclusion: Many large models were undertrained, with too many parameters for their training data. At a similar compute budget, a 70B model trained on 1.4T tokens (Chinchilla-optimal: D = 20 × 70B = 1.4T) outperforms Gopher, a 280B model trained on 300B tokens.
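As a rough sanity check, not a reproduction of the paper's measured results, the sketch below plugs both configurations into the rounded parametric fit from Section 3. By the 6ND rule the two runs use comparable compute (about 5.9 × 10^23 vs 5.0 × 10^23 FLOPs), and the fit predicts a lower loss for the 70B/1.4T run. `chinchilla_loss` is an illustrative helper name.

```python
# Rounded Chinchilla parametric fit; N and D are raw counts.
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    return E + A / n_params**ALPHA + B / n_tokens**BETA


runs = {
    "Chinchilla (70B params, 1.4T tokens)": (70e9, 1.4e12),
    "Gopher (280B params, 300B tokens)": (280e9, 300e9),
}
for name, (n, d) in runs.items():
    print(f"{name}: C ~ {6 * n * d:.2e} FLOPs, predicted L = {chinchilla_loss(n, d):.3f}")
```

Most of Gopher's predicted deficit sits in the data term B/D^β, which is what "undertrained" means here.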
5. The Compute Relationship
The training compute for a Transformer:
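C_forward ≈ 2 · N · D (for a decoder-only model)
2 FLOPs per parameter per token for the forward pass (one multiply and one add per weight).
C_backward ≈ 2 · C_forward = 4 · N · D (the backward pass computes gradients with respect to both activations and weights, each costing roughly one forward pass)
C_total = C_forward + C_backward ≈ 2ND + 4ND = 6 · N · D
The 6ND estimate ignores attention-score FLOPs (which grow with context length), LayerNorm, softmax, and similar operations; these are small relative to 6ND when the model width is large compared with the context length.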
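To gauge how small the ignored attention term is, here is a sketch using the per-token forward estimate C_forward ≈ 2N + 2·n_layer·n_ctx·d_attn from Kaplan et al. (2020). The shape is GPT-3 175B's (96 layers, d_model = 12288, 2048-token context); taking d_attn = d_model and backward ≈ 2× forward for every term are assumptions of this sketch.

```python
def forward_flops_per_token(n_params: float, n_layer: int, n_ctx: int, d_attn: int) -> float:
    """2N for the weight matmuls plus 2 * n_layer * n_ctx * d_attn for attention scores."""
    return 2 * n_params + 2 * n_layer * n_ctx * d_attn


# GPT-3 175B shape; d_attn = d_model is an assumption of this sketch.
n_params, n_layer, n_ctx, d_model = 175e9, 96, 2048, 12288
fwd = forward_flops_per_token(n_params, n_layer, n_ctx, d_model)
attention_share = (fwd - 2 * n_params) / fwd
train_per_token = 3 * fwd  # forward + backward, with backward ~ 2x forward
print(f"attention share of forward FLOPs: {attention_share:.1%}")
print(f"training FLOPs per token: {train_per_token:.3g} vs 6N = {6 * n_params:.3g}")
```

For this shape the attention term is under 2% of the forward FLOPs, so 6N per token is within a couple of percent.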
For Chinchilla-optimal: D = 20N, so C ≈ 6·N·(20N) = 120·N².
Solving for N given C: N_opt ≈ √(C / 120).
Example: For C = 10^23 FLOPs (the same order of magnitude as GPT-3's ≈ 3 × 10^23):
N_opt ≈ √(10^23 / 120) ≈ √(8.33×10^20) ≈ 2.89×10^10 ≈ 29B parameters
D_opt ≈ 20·29B ≈ 580B tokens
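The same calculation as a small helper, a sketch assuming C = 6ND and a fixed tokens-per-parameter ratio (20 by default); `chinchilla_optimal` is a hypothetical name.

```python
import math


def chinchilla_optimal(compute_flops: float, tokens_per_param: float = 20.0) -> tuple[float, float]:
    """Return (N_opt, D_opt) assuming C = 6*N*D and D = tokens_per_param * N."""
    n_opt = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    return n_opt, tokens_per_param * n_opt


for budget in (1e23, 5e24):
    n, d = chinchilla_optimal(budget)
    print(f"C = {budget:.0e} FLOPs: N_opt ~ {n / 1e9:.0f}B params, D_opt ~ {d / 1e9:.0f}B tokens")
```

Exposing `tokens_per_param` makes it easy to see how the answer shifts if a different ratio is preferred, for example when deliberately overtraining.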
6. Critical Batch Size
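The gradient noise scale B_N determines how large a batch can be before returns diminish. In its simple form:
B_crit ≈ B_N = tr(Σ) / |G|²
where Σ is the per-example gradient covariance matrix and G is the true (full-batch) gradient. Empirically (McCandlish et al., 2018), the optimizer steps S and training examples E needed to reach a given loss trade off as:
(S / S_min − 1) · (E / E_min − 1) = 1, with B_crit = E_min / S_min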
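Below B_crit, doubling the batch roughly halves the number of steps at little extra data cost; above B_crit, extra batch size mostly costs data. The critical batch size grows as the loss falls; Kaplan et al. fit:
B_crit(L) ≈ B_* / L^(1/α_B), with B_* ≈ 2 × 10^8 tokens and α_B ≈ 0.21
As loss decreases, you can efficiently use larger batches.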
Practical implication: B_crit is typically 0.5-2M tokens in LLM training and rises as loss falls, so a fixed batch of a few million tokens (~4M) stays near B_crit for most of a long run. Exponentially growing batch sizes are not needed.
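A toy sketch of the simple noise scale B_N = tr(Σ)/|G|², assuming per-example gradients for an entire (tiny) dataset are available, so their mean is the true gradient G. In real training these quantities are instead estimated from minibatch gradient norms measured at two batch sizes; `simple_noise_scale` is a hypothetical helper.

```python
def simple_noise_scale(per_example_grads: list[list[float]]) -> float:
    """B_N = tr(Sigma) / |G|^2, treating the given gradients as the whole dataset."""
    n = len(per_example_grads)
    dim = len(per_example_grads[0])
    # True gradient G: the mean of the per-example gradients.
    mean = [sum(g[j] for g in per_example_grads) / n for j in range(dim)]
    # tr(Sigma): average squared distance of per-example gradients from the mean.
    trace_sigma = sum(
        sum((g[j] - mean[j]) ** 2 for j in range(dim)) for g in per_example_grads
    ) / n
    grad_norm_sq = sum(m * m for m in mean)
    return trace_sigma / grad_norm_sq


grads = [[3.0, 1.0], [-1.0, 1.0], [1.0, 3.0], [1.0, -1.0]]  # G = (1, 1) plus noise
print(simple_noise_scale(grads))  # 2.0
```

A value of 2 means that at batch size 2 the total variance of the minibatch gradient, tr(Σ)/B, equals |G|²; much larger batches would mostly average away noise that is already small.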
7. Scaling Law Implications
Pre-training compute allocation: For a fixed compute budget C:
- Chinchilla: grow N and D by the same factor as C grows (each ∝ √C)
- Overtraining (more data than Chinchilla-optimal): slightly higher loss for the same training compute, but a smaller model that is cheaper and faster at inference
- Undertraining (more parameters than Chinchilla-optimal): higher loss for the same compute and a larger model to serve, so compute is wasted
Emergent abilities: Capabilities that appear only above certain scale thresholds. Scaling laws predict smooth, continuous improvement in loss, but task metrics with sharp thresholds (e.g., exact-match accuracy) can make that improvement look step-like; some work argues that apparent emergence is largely this metric effect.
Scaling law extrapolation: Fit power laws on small models (e.g., up to ~1B parameters) and extrapolate to predict the loss of much larger models. Labs use such forecasts when deciding whether a larger training run is worthwhile.
Pitfalls
⚠️ Pitfall 1: Applying Kaplan's recommendation to Chinchilla-era training. Kaplan said "increase model size more than data." Chinchilla traced this largely to Kaplan's fixed learning-rate schedule and showed that models sized this way end up undertrained. For modern training, D ≈ 20N is compute-optimal; training a 70B model on 300B tokens (20N would be 1.4T) leaves it severely undertrained.
⚠️ Pitfall 2: Forgetting that C ≈ 6ND is an approximation. The exact constant depends on architecture details (SwiGLU vs ReLU, vocabulary size, sequence length). Use 6ND for back-of-envelope estimates; for precise planning, profile your actual implementation.
⚠️ Pitfall 3: Confusing Chinchilla-optimal with best-for-inference. Chinchilla-optimal minimizes loss for a fixed training compute budget. For deployment, overtrained smaller models (like Llama 2 7B on 2T tokens) can be better: they run faster at inference at slightly higher loss.
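To make Pitfall 3 concrete, a sketch assuming C ≈ 6ND, the 20:1 rule, and forward cost per generated token ≈ 2N (ignoring attention and KV-cache effects). It compares Llama 2 7B's training budget with the Chinchilla-optimal model for that same budget; it does not estimate the loss gap.

```python
import math

n_model, n_tokens = 7e9, 2e12            # Llama 2 7B: 7B parameters, 2T training tokens
compute = 6 * n_model * n_tokens         # C ~ 6ND
n_chinchilla = math.sqrt(compute / 120)  # compute-optimal N for the same C, using D = 20N
tokens_per_param = n_tokens / n_model
# Forward FLOPs per generated token scale roughly as 2N, so serving cost tracks N.
serving_ratio = n_chinchilla / n_model
print(f"C ~ {compute:.2e} FLOPs, {tokens_per_param:.0f} tokens per parameter")
print(f"Chinchilla-optimal for this C: ~{n_chinchilla / 1e9:.0f}B params on ~{20 * n_chinchilla / 1e9:.0f}B tokens")
print(f"that model would cost ~{serving_ratio:.1f}x more FLOPs per generated token")
```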
Key Terms
- Batch Size
- Chinchilla optimum
- Critical Batch Size
- D_opt ≈ 20 · N_opt
- Kaplan et al. (2020): OpenAI Scaling Laws
Worked Examples
Example 1: Compute-Optimal Model Size
Problem: You have a compute budget of C = 5 × 10^24 FLOPs. Using Chinchilla scaling (D ≈ 20N, C ≈ 6ND), what are the optimal N and D? What loss reduction would you expect going from N=150B, D=500B to optimal?
Solution:
C ≈ 6ND = 6N(20N) = 120N²
N_opt = √(C/120) = √(5×10^24/120) = √(4.167×10^22) ≈ 2.04×10^11 = 204B params
D_opt = 20N = 20·204B ≈ 4.08T tokens
Check: 6·204B·4.08T = 6·204·4.08 × 10^(9+12) = 6·832.3 × 10^21 ≈ 4.99×10^24 ✓
Non-optimal: N=150B, D=500B gives C = 6·150B·500B = 6·150·0.5 × 10^21 = 4.5×10^23, only about 1/11 of the budget, so on its own it is not a fair comparison. To compare at equal compute, keep N = 150B and spend the whole budget on data: D = 5×10^24/(6·150×10^9) = 5×10^24/(9×10^11) ≈ 5.56T tokens.
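Loss comparison (Chinchilla parametric fit with N and D as raw counts: L ≈ 1.69 + 406.4/N^0.34 + 410.7/D^0.28):
20:1 allocation (204B, 4.08T): (2.04×10^11)^0.34 ≈ 7003, so 406.4/7003 ≈ 0.0580; (4.08×10^12)^0.28 ≈ 3396, so 410.7/3396 ≈ 0.1209; L_opt ≈ 1.69 + 0.0580 + 0.1209 ≈ 1.869
Same budget at N = 150B (150B, 5.56T): (1.5×10^11)^0.34 ≈ 6308, so 406.4/6308 ≈ 0.0644; (5.56×10^12)^0.28 ≈ 3703, so 410.7/3703 ≈ 0.1109; L ≈ 1.69 + 0.0644 + 0.1109 ≈ 1.865
Original run (150B, 500B): (5×10^11)^0.28 ≈ 1887, so 410.7/1887 ≈ 0.2177; L ≈ 1.69 + 0.0644 + 0.2177 ≈ 1.972
Going from (150B, 500B) to the 20:1 allocation lowers the predicted loss from ≈1.972 to ≈1.869 (≈0.10 nats), but almost all of that comes from using ~11× more compute. At equal compute, the two allocations differ by only ≈0.004 nats (perplexity ratio exp(0.004) ≈ 1.004), and this fitted form actually scores the 150B model slightly lower, because with these rounded constants its own optimum at 5×10^24 FLOPs lies at fewer parameters and more tokens than 20:1. Either way, the Chinchilla optimum is flat: predicted loss is insensitive to the exact N/D ratio near it.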
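The Example 1 arithmetic can be re-run with the rounded parametric fit, treating N and D as raw counts. A minimal sketch; `chinchilla_loss` is an illustrative helper name, and the printed losses are predictions of the fit, not measurements.

```python
import math

# Rounded Chinchilla parametric fit; N and D are raw counts.
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28


def chinchilla_loss(n_params: float, n_tokens: float) -> float:
    """Predicted loss L(N, D) = E + A/N^alpha + B/D^beta."""
    return E + A / n_params**ALPHA + B / n_tokens**BETA


C = 5e24                    # compute budget in FLOPs
n_opt = math.sqrt(C / 120)  # from C = 6 * N * (20 N) = 120 N^2
d_opt = 20 * n_opt
n_fixed = 150e9             # keep N = 150B and spend the whole budget on data
d_fixed = C / (6 * n_fixed)

l_opt = chinchilla_loss(n_opt, d_opt)
l_fixed = chinchilla_loss(n_fixed, d_fixed)
l_small_budget = chinchilla_loss(150e9, 500e9)  # the (150B, 500B) run, ~4.5e23 FLOPs

print(f"20:1 split : N={n_opt:.3g}, D={d_opt:.3g}, L={l_opt:.4f}")
print(f"N = 150B   : D={d_fixed:.3g}, L={l_fixed:.4f}")
print(f"150B, 500B : L={l_small_budget:.4f}")
```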
Example 2: Predicting Loss from Data
Problem: Using the Chinchilla data term only: L(D) ≈ 1.69 + 410.7/D^0.28, predict loss when training on 100B vs. 1T tokens. What's the loss reduction?
Solution:
D must be a raw token count here: N and D in the fitted form are counts, not billions. (Plugging in D = 100 gives a nonsensical loss above 100 nats.)
D = 100B = 10^11: D^0.28 = e^(0.28·ln 10^11) = e^(0.28·25.328) = e^(7.092) ≈ 1202, so L(100B) ≈ 1.69 + 410.7/1202 ≈ 1.69 + 0.3416 = 2.0316
D = 1T = 10^12: D^0.28 = e^(0.28·27.631) = e^(7.737) ≈ 2291, so L(1T) ≈ 1.69 + 410.7/2291 ≈ 1.69 + 0.1793 = 1.8693
Loss reduction: ≈ 0.162 nats. Perplexity: exp(2.032) ≈ 7.63 → exp(1.869) ≈ 6.48 (≈15% reduction in PPL). Because the A/N^α term is dropped, these values are the infinite-model limit; any finite model sits higher.
Only the ratio of data sizes matters for the reducible part: 10× more data multiplies it by 10^(−0.28) ≈ 0.525. The same holds for a calibrated form L(D) ≈ L_∞ + (D_c/D)^0.28. For a model observed at L = 2.5 at D = 100B, with L_∞ ≈ 1.69:
(D_c/100B)^0.28 = 0.81, so D_c = 0.81^(1/0.28) · 100B = 0.81^3.571 · 100B ≈ 0.47 · 100B ≈ 47B
L(1T) ≈ 1.69 + (47/1000)^0.28 = 1.69 + e^(0.28·ln 0.047) = 1.69 + e^(0.28·(−3.058)) = 1.69 + e^(−0.856) ≈ 1.69 + 0.425 = 2.115
Loss reduction: 2.5 − 2.115 = 0.385 nats. Perplexity: exp(2.5) ≈ 12.2 → exp(2.115) ≈ 8.3 (32% reduction in PPL).
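A quick check of the data-only arithmetic, a sketch that assumes D is the raw token count and uses the rounded constants B ≈ 410.7 and β ≈ 0.28; `data_only_loss` is a hypothetical helper name.

```python
E, B, BETA = 1.69, 410.7, 0.28  # Chinchilla constants; data term only (N -> infinity)


def data_only_loss(n_tokens: float) -> float:
    return E + B / n_tokens**BETA


l_100b = data_only_loss(100e9)
l_1t = data_only_loss(1e12)
print(f"L(100B) = {l_100b:.3f}, L(1T) = {l_1t:.3f}, drop = {l_100b - l_1t:.3f} nats")

# Calibrated variant: a reducible part of 0.81 at 100B shrinks by 10^-0.28 per 10x data.
shrink_per_10x = 10 ** -BETA
l_calibrated_1t = E + 0.81 * shrink_per_10x
print(f"per-10x factor = {shrink_per_10x:.3f}, calibrated L(1T) = {l_calibrated_1t:.3f}")
```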
Example 3: Critical Batch Size
Problem: During training, the gradient noise scale B_crit is estimated to grow from 0.5M to 2M tokens. For a model training on 1T tokens, compute how many optimizer steps are needed at batch sizes B = 0.5M, 1M, 2M, and 4M.
| Batch Size | Steps = D/B |
|---|---|
| 0.5M | 1T/0.5M = 2,000,000 |
| 1M | 1,000,000 |
| 2M | 500,000 |
| 4M | 250,000 |
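The critical batch size of 2M means that beyond 2M tokens per batch, returns diminish. At fixed D = 1T, B = 4M halves the step count, but each step above B_crit makes less than proportional progress, so the run ends at a somewhat higher loss than with 2M. It can still be worth it when wall-clock time matters most, for example when per-step communication overhead dominates.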
Optimal strategy: use B ≈ B_crit = 2M for most of training. Steps = 500K.
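Example 3's table fixes the token budget at 1T. A complementary view, assuming the empirical model of McCandlish et al. (2018) in which reaching a fixed target loss needs S = S_min(1 + B_crit/B) steps and E = E_min(1 + B/B_crit) tokens, with B_crit held at 2M for simplicity (in reality it grows during training). `relative_cost` is a hypothetical helper.

```python
B_CRIT = 2e6  # tokens; held fixed here as a simplifying assumption


def relative_cost(batch_tokens: float) -> tuple[float, float]:
    """Return (S / S_min, E / E_min) needed to reach a fixed loss at this batch size."""
    return 1 + B_CRIT / batch_tokens, 1 + batch_tokens / B_CRIT


for b in (0.5e6, 1e6, 2e6, 4e6):
    steps, tokens = relative_cost(b)
    print(f"B = {b / 1e6:.1f}M: steps x{steps:.2f} of minimum, tokens x{tokens:.2f} of minimum")
```

Under this model, going from 2M to 4M cuts the steps needed for a given loss by only 25% while consuming 50% more tokens; at B = B_crit both costs sit at twice their minimum, which is why it is a natural operating point.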
Summary
- Scaling laws have the form L(N,D) ≈ L_∞ + A/N^α + B/D^β, fit on log-log data, with L_∞ representing irreducible loss
- Chinchilla-optimal: D ≈ 20N (tokens ≈ 20× parameters), implying N ∝ √C and D ∝ √C, i.e., equal scaling of parameters and data
- Training compute: C ≈ 6ND for Transformers (approximately 2 forward + 4 backward FLOPs per parameter per token)
- Critical batch size B_crit grows during training, typically reaching 0.5-2M tokens; beyond this, larger batches yield diminishing returns
- Overtraining (D ≫ 20N) trades some pre-training loss for faster inference; undertraining wastes compute
Quiz
Q1: What does the Chinchilla scaling law recommend for the ratio of training tokens to parameters?
A) D ≈ N (1:1) B) D ≈ 20N (20:1) C) D ≈ 100N (100:1) D) D ≈ N²
Correct: B)
- If you chose B: The Chinchilla paper found that compute-optimal training uses ~20ร as many tokens as parameters. For a 70B model, this means ~1.4T tokens. This was a major correction to the Kaplan-era practice of using far fewer tokens (e.g., 300B for a 280B model). Correct!
Q3: Why are Chinchilla's scaling exponents (α ≈ 0.34) larger than Kaplan's (α ≈ 0.076)?
A) Chinchilla used bigger computers B) Chinchilla trained each model size to near-convergence with appropriate learning rate schedules, while Kaplan fixed training duration C) Chinchilla used a different architecture D) The exponents measure different things
Correct: B)
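- If you chose B: Kaplan used one learning-rate schedule length for every run, so runs stopped at fewer tokens were evaluated before their learning rate had fully decayed. That overstated the loss at small D, made additional data look less valuable, and tilted the recommendation toward larger models. Chinchilla matched each run's cosine schedule to its token count. This methodological difference accounts for much of the discrepancy; the exponent values are also not directly comparable, since Chinchilla's form fits a separate irreducible term E while Kaplan's pure power laws do not. Correct!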
Q5: For a 7B model trained Chinchilla-optimally (140B tokens), what happens if you instead train on 2T tokens?
A) The model will overfit and perform worse B) The model will continue to improve (it's "overtrained"), but each additional token gives diminishing returns C) Training will crash D) The model size must also increase
Correct: B)
- If you chose B: Language models trained on fresh (non-repeated) text don't overfit in the classical sense: the loss keeps decreasing along the power law. Training on 2T tokens instead of 140B (≈14× more) reduces loss further. This is intentional "overtraining": it yields a smaller model (cheaper inference) but costs more training compute. The tradeoff favors inference-heavy deployments. Correct!
Next Steps
Continue to 18-08: Inference Mathematics to understand how trained models generate text, including sampling strategies, beam search, and log-probability computation.