15-06 — Mixed-Precision Training
Phase: Numerical Methods for ML | Subject: 15-06 | Prerequisites: 15-01-floating-point-arithmetic.md, 15-04-backpropagation-implementation.md | Next subject: 15-07-gpu-computation-model.md
Learning Objectives
By the end of this subject, you will be able to:
- Compare float32, float16, and bfloat16 formats and their ML tradeoffs
- Explain why float16 training requires loss scaling
- Describe the master-weight-copy strategy for mixed-precision training
- Understand why bfloat16 eliminates the need for loss scaling
- Quantify the memory and speed benefits of mixed precision
Core Content
Floating-Point Formats Compared
| Format | Total Bits | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|---|
| float32 | 32 | 1 | 8 | 23 | $\sim 10^{\pm 38}$ | $\sim 7.2$ decimal digits |
| float16 | 16 | 1 | 5 | 10 | $\sim 10^{\pm 4.5}$ | $\sim 3.3$ decimal digits |
| bfloat16 | 16 | 1 | 8 | 7 | $\sim 10^{\pm 38}$ | $\sim 2.4$ decimal digits |
⚠️ CRITICAL: float16 vs bfloat16. bfloat16 keeps float32's 8 exponent bits, so it has the same range as float32, but its mantissa is cut to 7 bits, so its precision is low. float16 has a narrow range (5 exponent bits, max ~65504) and moderate precision (10 mantissa bits). This difference matters enormously for training.
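The range and precision columns can be re-derived from the bit counts alone. The sketch below assumes an IEEE-754-style encoding (implicit leading bit, bias $2^{e-1}-1$, top exponent code reserved for Inf/NaN); `format_stats` is a hypothetical helper name, not a library function, and it counts the implicit leading bit when reporting decimal digits.

```python
import math

def format_stats(exp_bits: int, man_bits: int) -> dict:
    """Derive range and precision of a binary float format from its bit layout."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = bias          # largest unbiased exponent of a finite value
    min_exp = 1 - bias      # unbiased exponent of the smallest normal value
    return {
        "max": (2 - 2.0 ** -man_bits) * 2.0 ** max_exp,
        "min_normal": 2.0 ** min_exp,
        "min_subnormal": 2.0 ** (min_exp - man_bits),
        "epsilon": 2.0 ** -man_bits,   # gap between 1.0 and the next value
        "decimal_digits": (man_bits + 1) * math.log10(2),
    }

formats = {"float32": (8, 23), "float16": (5, 10), "bfloat16": (8, 7)}
for name, (e, m) in formats.items():
    s = format_stats(e, m)
    print(f"{name:9s} max={s['max']:.4g} min_normal={s['min_normal']:.4g} "
          f"eps={s['epsilon']:.3g} digits={s['decimal_digits']:.1f}")
```

For float16 this yields a maximum of 65504 and a smallest normal value of about $6.1 \times 10^{-5}$, matching the figures quoted in this section.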
Why float16 Alone Fails for Training
Training neural networks involves:
- Gradients that span many orders of magnitude, from roughly $10^{-7}$ or below up to roughly $10^3$, depending on the layer, the architecture, and the stage of training
- Weight updates $\Delta w = -\eta \cdot g$ that can be $10^{-4}$ or smaller
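With float16 (max value ~65504):
- Large activations overflow to Inf
- Gradients below $\sim 6 \times 10^{-5}$ (the smallest normal float16) fall into the subnormal range and lose precision; below about $3 \times 10^{-8}$ (half the smallest subnormal, $2^{-24} \approx 6 \times 10^{-8}$) they round to zero
Even without overflow, float16's ~3.3 decimal digits mean:
- Addition $65504 + 1 = 65504$ (the 1 is lost in rounding)
- Weight updates smaller than about 0.05% ($2^{-11}$) of the weight magnitude are discarded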
This is the loss-of-update problem: small gradients produce no weight change, and training stalls.
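These rounding effects can be reproduced without a GPU. The sketch below is an illustration only: it uses the `"e"` (half-precision) format of Python's standard `struct` module to round values to float16, and `to_fp16` is a hypothetical helper name introduced here.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Near the top of the range the spacing between float16 values is 32,
# so adding 1 changes nothing.
print(to_fp16(65504.0 + 1.0))   # 65504.0

# A value far below the smallest subnormal rounds to zero.
print(to_fp16(1e-8))            # 0.0

# An update of 0.01% of the weight is below half an ulp and is discarded.
w = to_fp16(1.0)
update = to_fp16(1e-4)          # representable on its own
w_new = to_fp16(w - update)
print(w_new == w)               # True
```

The last case is the loss-of-update problem in miniature: the update is a valid float16 number, yet the subtraction leaves the weight unchanged.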
Mixed-Precision Training (Micikevicius et al., 2018)
The solution: keep a master copy of weights in float32 while performing forward/backward computation in float16.
Algorithm:
1. Maintain float32 master weights $\mathbf{W}_{\text{master}}$
2. Cast to float16: $\mathbf{W}_{\text{fp16}} \leftarrow \mathbf{W}_{\text{master}}$
3. Forward pass in float16
4. Compute the loss in float32 (higher precision for the scalar)
5. Multiply the loss by the loss scale $S$ (to prevent gradient underflow)
6. Backward pass in float16, which yields gradients scaled by $S$
7. Cast gradients to float32, divide by $S$
8. Update master weights: $\mathbf{W}_{\text{master}} \leftarrow \mathbf{W}_{\text{master}} - \eta \cdot \nabla L_{\text{fp32}}$
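The algorithm can be traced for a single scalar weight. This is a minimal sketch under stated assumptions, not the implementation from the paper: `struct` rounding stands in for real float16/float32 hardware, `grad_fn` is a hypothetical callback returning the true gradient of the unscaled loss, and overflow checks are omitted.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def mixed_precision_step(w_master, grad_fn, lr, loss_scale):
    """One scalar update: float16 compute, float32 master weight."""
    w_fp16 = to_fp16(w_master)                     # cast master weight to float16
    # Backward runs on the scaled loss, so the float16 gradient it
    # produces is loss_scale times the true gradient.
    scaled_grad_fp16 = to_fp16(loss_scale * grad_fn(w_fp16))
    grad_fp32 = to_fp32(scaled_grad_fp16 / loss_scale)   # unscale in float32
    return to_fp32(w_master - lr * grad_fp32)      # update the master weight

def tiny_grad(w):
    """Hypothetical gradient that is too small for float16."""
    return 2e-8

w0 = to_fp32(0.01)
print(mixed_precision_step(w0, tiny_grad, lr=0.1, loss_scale=1.0) == w0)     # True
print(mixed_precision_step(w0, tiny_grad, lr=0.1, loss_scale=2.0**16) < w0)  # True
```

With a scale of 1 the gradient rounds to zero in float16 and the weight never moves; with $S = 2^{16}$ the same gradient survives the float16 stage and the master weight decreases.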
Loss Scaling
The core trick: multiply the loss by a large constant $S$ (e.g., $2^{16} = 65536$) before the backward pass. This scales all gradients up by $S$, pushing small gradients into float16's representable range. After backward, divide by $S$ to recover true gradients.
Dynamic loss scaling: Start with a large scale such as $S = 2^{16}$. If no Inf/NaN appears in the gradients for $N$ consecutive steps, increase $S$ (typically by doubling). If Inf/NaN appears, skip the update and decrease $S$ (typically by halving).
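The grow-and-back-off rule fits in a few lines. The class below is a hypothetical sketch, not a library API; its default values are assumptions chosen to resemble those commonly used by framework loss scalers (initial scale $2^{16}$, doubling after 2000 clean steps, halving on overflow).

```python
import math

class DynamicLossScaler:
    """Grow the scale after a run of clean steps; shrink it on overflow."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, grads) -> bool:
        """Return True if the optimizer step should be applied."""
        if any(not math.isfinite(g) for g in grads):
            self.scale *= self.backoff_factor   # overflow: shrink, skip this step
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor    # enough clean steps: grow
            self._clean_steps = 0
        return True

scaler = DynamicLossScaler(growth_interval=3)
print(scaler.update([0.1, float("inf")]), scaler.scale)   # False 32768.0
for _ in range(3):
    scaler.update([0.1, -0.2])
print(scaler.scale)                                       # 65536.0
```

Skipping the step on overflow matters: the gradients from that step contain Inf/NaN and would corrupt the master weights if applied.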
bfloat16: The Game-Changer
bfloat16 (Google Brain, 2018) takes a different approach: keep the 8-bit exponent of float32 but truncate the mantissa to 7 bits.
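Advantages:
- Same dynamic range as float32, so overflow and underflow are no more likely than in float32 and loss scaling is normally unnecessary
- Conversion from float32 can be a simple truncation (drop the lower 16 mantissa bits); conversion back to float32 pads those bits with zeros
- Training in pure bfloat16 can work without a float32 master copy of the weights (though float32 is still recommended for optimizer state)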
Disadvantage: Lower precision (~2.4 decimal digits) than float16 (~3.3 digits). For neural network training, however, dynamic range usually matters more than precision: the rounding noise from bfloat16's coarser mantissa is generally tolerated and is often described as a mild, benign regularizer.
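Truncation to bfloat16 amounts to masking the float32 bit pattern. The sketch below illustrates the truncating conversion described above using only the standard library; `to_bf16_truncate` is a hypothetical helper, and real hardware often rounds to nearest instead of truncating.

```python
import struct

def fp32_bits(x: float) -> int:
    """Return the 32-bit pattern of x rounded to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def to_bf16_truncate(x: float) -> float:
    """Keep the sign, the 8 exponent bits, and the top 7 mantissa bits."""
    bits = fp32_bits(x) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bf16_truncate(3.14159))       # 3.140625
print(to_bf16_truncate(70000.0))       # 69632.0 (above float16's max, still finite)
print(to_bf16_truncate(1e-30) > 0)     # True: tiny values keep float32's range
print(to_bf16_truncate(1.0 + 2**-8))   # 1.0: increments below 2^-7 are dropped
```

The second and third lines show the range advantage; the first and last show the price paid in precision.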
Memory and Speed Impact
| Precision | Activation Memory | Weight Memory | Bandwidth | Throughput (A100) |
|---|---|---|---|---|
| float32 | 1× | 1× | 1× | 1× (base) |
| float16 | 0.5× | 0.5× | 0.5× | ~2× (tensor cores) |
| bfloat16 | 0.5× | 0.5× | 0.5× | ~2× (tensor cores) |
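For a 1B parameter model:
- float32 weights: 4 GB
- float16/bfloat16 weights: 2 GB
- Adam state kept in fp32 (two moment buffers plus master weights): 12 GB, whether or not the forward/backward pass runs in 16 bits; it drops to 8 GB only if the master copy is removed, and to 6 GB only if the moments are also stored in 16 bits
The 2× saving therefore applies to activations and to the 16-bit weight and gradient copies, not to the whole training footprint. How much larger a model or batch fits on the same hardware depends on how much of the memory is taken by activations.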
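The per-parameter byte counts can be tabulated with a small helper. This is a rough sketch that ignores activations, buffers, and fragmentation, uses decimal gigabytes ($10^9$ bytes), and assumes Adam with two fp32 moment buffers (8 bytes per parameter) unless told otherwise; `training_memory_gb` is a hypothetical name.

```python
def training_memory_gb(n_params, weight_bytes, grad_bytes,
                       master_bytes=0, optimizer_bytes_per_param=8):
    """Sum weights, gradients, optional fp32 master copy, and Adam moments."""
    per_param = weight_bytes + grad_bytes + master_bytes + optimizer_bytes_per_param
    return n_params * per_param / 1e9

n = 1e9
print(training_memory_gb(n, 4, 4))                   # 16.0  pure float32
print(training_memory_gb(n, 2, 2, master_bytes=4))   # 16.0  16-bit compute + fp32 master
print(training_memory_gb(n, 2, 2))                   # 12.0  pure bfloat16, fp32 moments
```

Under these assumptions the parameter-related memory only shrinks when the fp32 master copy is dropped; the activation savings come on top of these figures.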
Key Terms
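- Gradients: derivatives of the loss with respect to the weights, computed in the backward pass; in float16 they can underflow, which is why loss scaling is used
- Weight updates: the change $\Delta w = -\eta \cdot g$ applied to each weight; small updates can be lost to rounding when added to a low-precision weight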
Worked Examples
Example 1: Float16 Underflow
A gradient $g = 3 \times 10^{-5}$ with learning rate $\eta = 0.001$. What happens in float16?
Solution: Weight update: $\Delta w = 0.001 \cdot (3 \times 10^{-5}) = 3 \times 10^{-8}$.
Smallest positive normal float16: $2^{-14} \approx 6.1 \times 10^{-5}$. Subnormal minimum: $2^{-24} \approx 5.96 \times 10^{-8}$.
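$3 \times 10^{-8}$ is below the float16 subnormal minimum, so it cannot be represented: it rounds to either zero or $2^{-24}$, an error of about 100% either way. (It sits just above the halfway point $2^{-25} \approx 2.98 \times 10^{-8}$, so round-to-nearest gives $2^{-24}$, roughly twice the true value; anything smaller rounds to zero.) Added to a float16 weight of ordinary magnitude, the update vanishes entirely. The gradient itself, $3 \times 10^{-5}$, is also below the smallest normal value and is stored as a subnormal with reduced precision.
With loss scaling $S = 2^{17} = 131072$: Scaled gradient: $S \cdot g = 131072 \cdot (3 \times 10^{-5}) \approx 3.93$, well within float16's normal range. Update in fp32 after unscaling: $\eta \cdot (3.93 / S) = 0.001 \cdot (3 \times 10^{-5}) = 3 \times 10^{-8}$, which is small but representable in fp32 and is applied to the fp32 master weight (provided the weight is small enough, roughly $|w| < 1$, for the update to exceed half an fp32 ulp).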
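The numbers in this example can be checked directly. The snippet is a sketch that emulates float16 rounding with the standard `struct` module (`to_fp16` is a hypothetical helper) and uses Python's float64 for the unscaling step in place of float32.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

g, lr, S = 3e-5, 1e-3, 2.0**17

# The unscaled update is not representable: it snaps to a neighboring
# float16 value (0 or 2**-24), nowhere near the true 3e-8.
print(to_fp16(lr * g))

# The scaled gradient is an ordinary normal float16 number.
print(to_fp16(S * g))            # 3.931640625

# Unscaling in higher precision recovers the intended update to ~0.01%.
print(lr * to_fp16(S * g) / S)
```

The residual error after unscaling comes from rounding the scaled gradient to float16's 11 significant bits, not from underflow.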
Answer: Without scaling, the update cannot be represented in float16 and is lost. With scaling $S = 2^{17}$, the gradient is representable in float16, and the fp32 master weight correctly receives the small update.
Example 2: Float16 Overflow
Compute $\operatorname{softmax}([1000, 0, -1000])$ in float16. What happens?
Solution: Without any stabilization, $e^{1000}$ overflows float16 (max ~65504); even $e^{12} \approx 162754$ already exceeds the float16 max. So $e^{1000} = \infty$ and $e^{-1000}$ rounds to 0, and the normalization computes $\infty / \infty$ for the first entry: softmax is [NaN, 0, 0].
With the max-subtraction trick (the same shift used in log-sum-exp), subtract the maximum, 1000, before exponentiating. The input 1000 and the shifted values $0, -1000, -2000$ are all representable in float16, and the exponentials become $e^{0} = 1$, $e^{-1000} \approx 0$, $e^{-2000} \approx 0$ ($e^{-1000}$ is ~0 in any precision). Softmax works fine with the trick.
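Both cases can be simulated with the standard library. In this sketch, `to_fp16`, `exp_fp16`, and `softmax_fp16` are hypothetical helpers that round every intermediate result to float16 and map out-of-range values to infinity, which approximates what float16 hardware would do.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to float16, mapping out-of-range finite values to infinity."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def exp_fp16(x: float) -> float:
    """exp() with a float16 result; overflow becomes +inf."""
    try:
        return to_fp16(math.exp(x))
    except OverflowError:        # math.exp raises instead of returning inf
        return math.inf

def softmax_fp16(logits, subtract_max: bool):
    shift = max(logits) if subtract_max else 0.0
    exps = [exp_fp16(to_fp16(v - shift)) for v in logits]
    total = to_fp16(sum(exps))
    return [to_fp16(e / total) for e in exps]

print(softmax_fp16([1000.0, 0.0, -1000.0], subtract_max=False))  # [nan, 0.0, 0.0]
print(softmax_fp16([1000.0, 0.0, -1000.0], subtract_max=True))   # [1.0, 0.0, 0.0]
```

The NaN in the first result comes from dividing infinity by infinity during normalization.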
Answer: Without log-sum-exp: $e^{1000}$ overflows float16 → softmax produces NaN. With the trick (subtract max=1000): all exponents are ≤ 0, no overflow, correct result $(1, 0, 0)$.
Example 3: Memory Savings
A 7B-parameter transformer is trained with Adam in pure float32 vs mixed precision (fp16 forward/backward, fp32 master weights + optimizer states). Compare memory.
Solution: Adam stores: weights, first moment ($m$), second moment ($v$) — 3× the parameter count.
Float32: Weights: $7 \times 10^9 \times 4 = 28$ GB. Adam states ($m$, $v$): $2 \times 28 = 56$ GB. Gradients: 28 GB. Total: 112 GB.
Mixed precision (fp16 forward/backward, fp32 master): fp32 master weights: 28 GB. fp16 weights copy (forward): $7 \times 10^9 \times 2 = 14$ GB. fp32 optimizer states ($m$, $v$): 56 GB. fp16 gradients: 14 GB. Total: $28 + 14 + 56 + 14 = 112$ GB — same! But the forward activations are half size, saving significant memory there.
The main savings come from activations (half the memory) and from fp16 communication in model parallelism.
Pure bfloat16 (no fp32 master): bfloat16 weights: 14 GB. bfloat16 gradients: 14 GB. fp32 optimizer states ($m$, $v$): 56 GB. Total: $14 + 14 + 56 = 84$ GB (excluding activations), which saves 28 GB.
Answer: Float32 Adam: ~112 GB. Mixed precision with fp32 master: the same ~112 GB for weights, gradients, and optimizer state, but ~2× activation savings. bfloat16 with fp32 optimizer states and no master copy: ~84 GB. The 2× reduction in activation memory allows roughly 2× larger batch sizes, or a larger model, within the same activation budget.
Quiz
Q1: What happens to very small gradients when the backward pass runs in float16 without loss scaling?
A) They lose precision or round to zero because they fall below float16's representable range B) They overflow to Inf C) They are automatically promoted to float32 D) They are unaffected, because float16 and float32 have the same range
Correct: A)
- If you chose A: Gradients below about $6 \times 10^{-5}$ become subnormal, and sufficiently small ones round to zero, so the corresponding weights stop updating. Correct!
- If you chose B: This is incorrect. Overflow affects large values (above ~65504); small gradients underflow instead.
- If you chose C: This is incorrect. No automatic promotion happens in a float16 backward pass; gradients are cast to float32 only after they have been computed.
- If you chose D: This is incorrect. float16 has only 5 exponent bits; it is bfloat16 that shares float32's range.
Q2: Which 16-bit format has approximately the same dynamic range as float32 ($\sim 10^{\pm 38}$)?
A) float16 B) Both float16 and bfloat16 C) bfloat16 D) Neither 16-bit format
Correct: C)
- If you chose A: This is incorrect. float16 has 5 exponent bits and a maximum of ~65504.
- If you chose B: This is incorrect. Only bfloat16 keeps float32's 8 exponent bits; float16 has 5.
- If you chose C: bfloat16 keeps float32's 8 exponent bits and gives up mantissa bits instead, so its range is $\sim 10^{\pm 38}$. Correct!
- If you chose D: This is incorrect. bfloat16 was designed to match float32's exponent range.
Q3: Why does mixed-precision training apply weight updates to a float32 master copy of the weights?
A) Because small updates $\Delta w = -\eta \cdot g$ would be rounded away when added to float16 weights B) Because float32 matrix multiplies are faster than float16 ones on tensor cores C) Because float16 cannot represent negative numbers D) Because the forward pass requires float32 weights
Correct: A)
- If you chose A: An update much smaller than the weight falls below half an ulp in float16 and leaves the weight unchanged; float32 has enough mantissa bits to accumulate it. Correct!
- If you chose B: This is incorrect. On tensor cores the 16-bit formats are the faster ones.
- If you chose C: This is incorrect. float16 has a sign bit and represents negative numbers.
- If you chose D: This is incorrect. The forward pass uses the float16 copy of the weights.
Q4: Which statement about the floating-point formats compared in this subject is TRUE?
A) float16 has more exponent bits than bfloat16 B) bfloat16 has more mantissa bits than float16 C) float16 has 10 mantissa bits and bfloat16 has 7 D) float16 and bfloat16 both use 8 exponent bits
Correct: C)
- If you chose A: This is incorrect. float16 has 5 exponent bits; bfloat16 has 8.
- If you chose B: This is incorrect. bfloat16 has 7 mantissa bits; float16 has 10.
- If you chose C: float16 splits its 16 bits as 1 sign + 5 exponent + 10 mantissa, and bfloat16 as 1 sign + 8 exponent + 7 mantissa. Correct!
- If you chose D: This is incorrect. Only bfloat16 (and float32) use 8 exponent bits.
Q5: Based on the practice problems in this subject, how many representable values does bfloat16 have in the range $[1, 2)$?
A) $2^{10} = 1024$ B) $2^7 = 128$ C) $2^3 = 8$ D) $65504$
Correct: B)
- If you chose A: This is incorrect. 1024 is the count for float16, which has 10 mantissa bits.
- If you chose B: With 7 mantissa bits, bfloat16 has $2^7 = 128$ representable values between consecutive powers of two. Correct!
- If you chose C: This is incorrect. 8 is the ratio between the float16 and bfloat16 counts, not the count itself.
- If you chose D: This is incorrect. 65504 is the largest finite float16 value, not a count of values.
Q6: How does the format comparison explain why float16 alone fails for training?
A) float16's narrow exponent range and limited mantissa are the direct cause of the underflow, overflow, and lost-update problems B) The comparison shows float16 has a wider range than float32, which destabilizes training C) The two topics are unrelated D) float16 fails only because it is slower than float32
Correct: A)
- If you chose A: The 5 exponent bits limit the range to roughly $6 \times 10^{-5}$ to $65504$ for normal values, and the 10 mantissa bits limit how small an update can be relative to a weight. Correct!
- If you chose B: This is incorrect. float16's range is far narrower than float32's.
- If you chose C: This is incorrect. The failure modes follow directly from the bit layout in the format table.
- If you chose D: This is incorrect. float16 is faster on tensor cores; the problem is numerical, not speed.
Q7: What is a common pitfall when working with mixed-precision training (Micikevicius et al., 2018)?
A) Keeping a float32 master copy of the weights B) Using dynamic loss scaling C) Storing Adam's moment estimates in float16, where small increments are rounded away D) Computing the loss in float32
Correct: C)
- If you chose A: This is incorrect. The float32 master copy is a core part of the method, not a pitfall.
- If you chose B: This is incorrect. Dynamic loss scaling is the recommended way to choose $S$.
- If you chose C: Optimizer states accumulate over thousands of steps, and float16 rounding discards the small per-step increments. Correct!
- If you chose D: This is incorrect. Computing the scalar loss in float32 is part of the recommended algorithm.
Q8: When should you apply loss scaling?
A) Only when training in pure float32 B) Only at inference time C) When the backward pass runs in float16, to keep small gradients representable D) Whenever bfloat16 is used, because its range is narrower than float16's
Correct: C)
- If you chose A: This is incorrect. float32 has enough range that gradient underflow is rarely an issue.
- If you chose B: This is incorrect. Loss scaling acts on gradients, which exist only during training.
- If you chose C: Multiplying the loss by $S$ scales every gradient by $S$, pushing small gradients into float16's representable range. Correct!
- If you chose D: This is incorrect. bfloat16 has the same exponent range as float32, which is wider than float16's, so loss scaling is normally not needed.
Practice Problems
- Compute the number of representable values between 1.0 and 2.0 in float16 vs bfloat16.
  Answer: Between $2^0 = 1$ and $2^1 = 2$, the mantissa provides $2^{10} = 1024$ representable values in float16. In bfloat16: $2^7 = 128$ representable values. For comparison, float32 has $2^{23} = 8,388,608$ values. bfloat16 has 8× fewer values than float16 in the $[1, 2)$ range, which is much coarser precision. But for neural net training, gradients and activations are rarely needed to 3+ digits of accuracy anyway.
- A gradient's magnitude is $g = 10^{-4}$. What's the minimum loss scale $S$ needed to make $S \cdot g$ representable as a float16 normal number?
  Answer: Smallest normal float16: $2^{-14} \approx 6.1 \times 10^{-5}$. Need $S \cdot 10^{-4} \geq 6.1 \times 10^{-5} \implies S \geq 0.61$. So $S=1$ is sufficient: this gradient is already in float16's normal range. Loss scaling is only needed for gradients below ~$6 \times 10^{-5}$. For training large models, gradients commonly reach $10^{-7}$ to $10^{-8}$, which requires $S \sim 10^3$ to $10^4$.
- Explain why bfloat16 has the same exponent range as float32.
  Answer: Both use 8 exponent bits with the same bias of 127, so the exponent encoding is identical; bfloat16 simply drops the lower 16 mantissa bits of float32. This means: (1) bfloat16 and float32 have essentially the same overflow and underflow thresholds; (2) converting float32 to bfloat16 can be done by truncation (hardware often rounds to nearest instead), and converting back is exact; (3) loss scaling is normally not needed.
- Dynamic loss scaling increases $S$ when no Inf/NaN is detected, decreases it otherwise. Why not just pick a very large fixed $S$?
  Answer: A fixed large $S$ would cause overflow for large gradients. If $S = 2^{20} \approx 10^6$ and a gradient is $g = 0.1$, then $S \cdot g \approx 1.05 \times 10^5$, which exceeds the float16 max (65504) and overflows to Inf. Dynamic scaling finds the sweet spot: large enough to prevent underflow of small gradients, not so large that it overflows large gradients. As training progresses and gradients typically shrink, $S$ can be increased.
- Why do optimizer states (Adam's $m$ and $v$) typically remain in float32 even when weights are in float16 or bfloat16?
  Answer: Optimizer states accumulate information over thousands of steps. $m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t$ requires precise accumulation: float16's 3.3 digits of precision would lose the small $(1-\beta_1)g_t$ term when it is added to the much larger $\beta_1 m_{t-1}$. Example: if $m_{t-1} \approx 1$ and $g_t \approx 10^{-4}$ with $\beta_1 = 0.9$, then $\beta_1 m_{t-1} \approx 0.9$ and $(1-0.9) \cdot 10^{-4} = 10^{-5}$. In float16, $0.9 + 10^{-5}$ rounds back to the float16 value of $0.9$ (the increment is lost). Float32's 7.2 digits can represent the sum correctly. bfloat16 would be even worse (~2.4 digits). Optimizer states stay in float32 for stability.
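The lost Adam increment can be reproduced numerically. This is an illustrative sketch using `struct` rounding in place of real float16/float32 arithmetic; `to_fp16` and `to_fp32` are hypothetical helpers, and the values of $m_{t-1}$, $g_t$, and $\beta_1$ are the ones assumed in the answer above.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

beta1, m_prev, g = 0.9, 1.0, 1e-4
decayed = beta1 * m_prev          # beta1 * m_{t-1}, about 0.9
increment = (1 - beta1) * g       # (1 - beta1) * g_t, about 1e-5

m16 = to_fp16(to_fp16(decayed) + to_fp16(increment))
m32 = to_fp32(to_fp32(decayed) + to_fp32(increment))

print(m16 == to_fp16(decayed))    # True: float16 drops the increment
print(m32 == to_fp32(decayed))    # False: float32 keeps it
```

Repeated over thousands of steps, the float16 moment estimate would simply decay without ever absorbing the small gradients.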
Summary
Key takeaways:
- float16: 5 exponent + 10 mantissa bits — limited range (~65504 max), needs loss scaling
- bfloat16: 8 exponent + 7 mantissa bits — float32 range, no loss scaling needed, coarser precision
- Mixed-precision: fp32 master weights + fp16 compute + loss scaling = near-float32 accuracy at fp16 speed
- Loss scaling multiplies loss by $S$ to push small gradients into float16's representable range
- Dynamic loss scaling adapts $S$ based on gradient overflow detection
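- Memory savings: 2× for activations and for the 16-bit working copies of weights and gradients; fp32 master weights and optimizer states are not reduced
- bfloat16 is generally preferred for large-scale training on hardware that supports it (TPUs, NVIDIA Ampere and newer GPUs), and many recent large language models are trained with it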
Pitfalls
- Assuming float16 works for training without loss scaling: Float16's limited normal range ($\sim 6 \times 10^{-5}$ to $65504$) means gradients routinely underflow and activations can overflow. Simply casting everything to float16 and running `.backward()` can silently produce zero gradients for many parameters. Loss scaling is effectively mandatory with float16; the alternative is bfloat16, which has the same exponent range as float32.
- Storing optimizer states in float16 or bfloat16: Adam's $m$ and $v$ accumulate gradient information over thousands of steps. In float16, the update $m \leftarrow \beta_1 m + (1-\beta_1) g$ loses the increment when $|(1-\beta_1) g| / |\beta_1 m| \lesssim 2^{-11} \approx 0.0005$. Bfloat16 is even worse (7 mantissa bits). Keep optimizer states in float32, even when weights and activations are in lower precision.
- Forgetting that bfloat16 has lower precision than float16: Bfloat16 has only 7 mantissa bits (~2.4 decimal digits) vs float16's 10 bits (~3.3 digits). While the dynamic range advantage removes the need for loss scaling, the coarser precision means weight updates below ~0.4% of the weight magnitude are lost. This is usually benign for training (the rounding noise is often likened to implicit regularization), but for precision-sensitive computations such as batch norm statistics or the loss itself, float32 should still be used.
- Using a fixed loss scale that's too large: A fixed $S = 2^{20}$ may prevent underflow for small gradients but will overflow large gradients (e.g., $g=0.1$ gives $S \cdot g = 1.05 \times 10^5 > 65504$). Dynamic loss scaling, which adapts $S$ whenever Inf/NaN gradients are detected, is the standard remedy. Start conservative ($S = 2^{16}$) and let the algorithm increase it over time.
- Not verifying that mixed-precision gradients match float32 gradients: After converting a training loop to mixed precision, always compare the resulting gradients against a float32 baseline on the first few batches. Numerical differences beyond ~1% relative error may indicate underflow or overflow that loss scaling missed. This is especially important for models with unusual loss landscapes or very small learning rates.
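The gradient check in the last pitfall can be as simple as a maximum relative error. The function and the sample gradient values below are hypothetical and only illustrate the comparison; in practice the two lists would hold flattened gradients from a float32 run and a mixed-precision run on the same batch.

```python
def max_relative_error(reference, candidate, eps=1e-12):
    """Largest elementwise |candidate - reference| / (|reference| + eps)."""
    return max(abs(c - r) / (abs(r) + eps)
               for r, c in zip(reference, candidate))

fp32_grads = [1.0e-3, -2.5e-5, 4.0e-7]   # hypothetical float32 baseline
mixed_grads = [1.0e-3, -2.5e-5, 0.0]     # last entry underflowed to zero

err = max_relative_error(fp32_grads, mixed_grads)
print(err > 0.01)                        # True: exceeds the ~1% tolerance
```

A gradient that underflows to zero shows up as a relative error close to 1, which makes this check a cheap way to catch a loss scale that is too small.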
Next Steps
Next up: 15-07-gpu-computation-model.md — SIMD/SIMT execution, GPU memory hierarchy, thread blocks and warps, and why matrix multiply is so fast on GPUs.