15-06 — Mixed-Precision Training

Phase: Numerical Methods for ML | Subject: 15-06 | Prerequisites: 15-01-floating-point-arithmetic.md, 15-04-backpropagation-implementation.md | Next subject: 15-07-gpu-computation-model.md


Learning Objectives

By the end of this subject, you will be able to:

  1. Compare float32, float16, and bfloat16 formats and their ML tradeoffs
  2. Explain why float16 training requires loss scaling
  3. Describe the master-weight-copy strategy for mixed-precision training
  4. Understand why bfloat16 eliminates the need for loss scaling
  5. Quantify the memory and speed benefits of mixed precision

Core Content

Floating-Point Formats Compared

| Format | Total Bits | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|---|
| float32 | 32 | 1 | 8 | 23 | $\sim 10^{\pm 38}$ | $\sim 7.2$ decimal digits |
| float16 | 16 | 1 | 5 | 10 | $\sim 10^{\pm 4.5}$ | $\sim 3.3$ decimal digits |
| bfloat16 | 16 | 1 | 8 | 7 | $\sim 10^{\pm 38}$ | $\sim 2.2$ decimal digits |

⚠️ CRITICAL: float16 vs bfloat16. bfloat16 has the same exponent range as float32 (8 exponent bits) but low precision (7 mantissa bits); it is float32 with a truncated mantissa. float16 has low range (5 exponent bits, max ~65504) and moderate precision (10 mantissa bits). This matters enormously for training.
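The range and spacing figures for these formats follow directly from the bit counts. The sketch below derives them, assuming IEEE-style encoding (implicit leading 1, exponent bias $2^{e-1} - 1$, top exponent code reserved for Inf/NaN). The function name `format_props` is a hypothetical helper, not part of any library.

```python
def format_props(exp_bits: int, man_bits: int) -> dict:
    """Derive headline properties of an IEEE-style binary format from its bit layout."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = bias          # largest unbiased exponent of a finite value
    min_exp = 1 - bias      # smallest unbiased exponent of a normal value
    return {
        "max": (2 - 2.0 ** -man_bits) * 2.0 ** max_exp,
        "min_normal": 2.0 ** min_exp,
        "min_subnormal": 2.0 ** (min_exp - man_bits),
        "epsilon": 2.0 ** -man_bits,   # spacing of values just above 1.0
    }

layouts = {"float32": (8, 23), "float16": (5, 10), "bfloat16": (8, 7)}
for name, (e, m) in layouts.items():
    p = format_props(e, m)
    print(f"{name:9s} max={p['max']:.3e}  min_normal={p['min_normal']:.3e}  "
          f"epsilon={p['epsilon']:.3e}")
```

The float16 row reproduces the maximum of 65504, and bfloat16 shares float32's smallest normal value while having a much larger epsilon.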

Why float16 Alone Fails for Training

Training neural networks involves:

  1. Gradients that span many orders of magnitude — from ~$10^{-7}$ (later layers, small learning rates) to ~$10^3$ (early layers in large models)
  2. Weight updates $\Delta w = -\eta \cdot g$ that can be $10^{-4}$ or smaller

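With float16 (max value ~65504):

- Large activations overflow to Inf
- Gradients below $\sim 6 \times 10^{-5}$ (the smallest normal float16) fall into the subnormal range and lose precision; below about half the smallest subnormal ($2^{-24} \approx 6 \times 10^{-8}$) they round to zero

Even without overflow, float16's 3.3 decimal digits mean:

- Addition $65504 + 1 = 65504$ (the 1 is lost in rounding)
- Weight updates smaller than $\sim 0.05\%$ of the weight magnitude (half of the $2^{-10}$ relative spacing between adjacent values) are discarded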

This is the loss-of-update problem: small gradients produce no weight change, and training stalls.
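The rounding behaviour can be checked without a GPU. This is a minimal sketch that emulates float16 storage with the standard library's `struct` half-precision format (`"e"`); the helper `fp16` is a hypothetical name, and computing in Python floats before rounding once to float16 is an assumption that matches a correctly rounded float16 addition for these inputs.

```python
import math
import struct

def fp16(x: float) -> float:
    """Round x to the nearest float16 value; return +/-inf if it overflows."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

# Near the top of the range, adjacent float16 values are 32 apart.
print(fp16(65504.0 + 1.0))      # 65504.0: the +1 is rounded away

# Just above 1.0, adjacent float16 values are 2**-10 (about 0.001) apart.
w = fp16(1.0)
print(fp16(w + 1e-4) == w)      # True: the update is below half a spacing
print(fp16(w + 1e-3) == w)      # False: this update is large enough to register
```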

Mixed-Precision Training (Micikevicius et al., 2018)

The solution: keep a master copy of weights in float32 while performing forward/backward computation in float16.

Algorithm:

  1. Maintain float32 master weights $\mathbf{W}_{\text{master}}$
  2. Cast to float16: $\mathbf{W}_{\text{fp16}} \leftarrow \mathbf{W}_{\text{master}}$
  3. Forward pass in float16
  4. Compute the loss in float32 (higher precision for the scalar) and multiply it by the loss scale $S$ (to prevent gradient underflow)
  5. Backward pass in float16, which produces gradients scaled by $S$
  6. Cast gradients to float32 and divide by $S$ to obtain $\nabla L_{\text{fp32}}$
  7. Update master weights: $\mathbf{W}_{\text{master}} \leftarrow \mathbf{W}_{\text{master}} - \eta \cdot \nabla L_{\text{fp32}}$
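The algorithm can be sketched on a one-parameter model. This is a toy illustration under stated assumptions, not the paper's implementation: the model $\hat{y} = w x$, the loss $L = \tfrac{1}{2}(\hat{y} - y)^2$, the helper names `fp16`, `fp32`, and `mixed_precision_step`, and all numeric values are hypothetical, and float16/float32 storage is emulated with `struct`.

```python
import math
import struct

def fp16(x: float) -> float:
    """Round x to the nearest float16 value; return +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def fp32(x: float) -> float:
    """Round x to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def mixed_precision_step(w_master, x, y, lr, scale):
    """One SGD step on L = 0.5 * (w * x - y)**2. Returns (new_master_weight, applied)."""
    w16, x16 = fp16(w_master), fp16(x)      # float16 working copies
    y_hat = fp16(w16 * x16)                 # forward pass in float16
    err = fp32(y_hat - y)                   # loss residual kept in float32
    upstream = fp16(scale * err)            # gradient of the scaled loss w.r.t. y_hat
    grad16 = fp16(upstream * x16)           # backward pass in float16 (still scaled)
    if math.isinf(grad16) or math.isnan(grad16):
        return w_master, False              # overflow: skip this update
    grad32 = fp32(grad16 / scale)           # unscale in float32
    return fp32(w_master - lr * grad32), True

w0 = fp32(1e-3)
w_plain, _ = mixed_precision_step(w0, x=1e-3, y=0.0, lr=1.0, scale=1.0)
w_scaled, _ = mixed_precision_step(w0, x=1e-3, y=0.0, lr=1.0, scale=2.0 ** 16)
print(w_plain == w0)    # True: the unscaled gradient rounded to zero in float16
print(w_scaled < w0)    # True: the scaled gradient survived and moved the weight
```

The true gradient here is about $10^{-9}$, far below float16's smallest subnormal, so only the scaled run produces a nonzero update.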

Loss Scaling

The core trick: multiply the loss by a large constant $S$ (e.g., $2^{16} = 65536$) before the backward pass. This scales all gradients up by $S$, pushing small gradients into float16's representable range. After backward, divide by $S$ to recover true gradients.

Dynamic loss scaling: Start with a large scale such as $S = 2^{16}$. If no Inf/NaN appears in the gradients for $N$ consecutive steps, increase $S$ (commonly by a factor of 2). If Inf/NaN appears, skip the update and decrease $S$ (commonly by a factor of 2).
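A minimal sketch of the dynamic scaling rule follows. The class name `DynamicLossScaler` and its parameters are hypothetical; the growth factor of 2, backoff factor of 0.5, and growth interval of 2000 steps are assumed defaults chosen for illustration, not values given in this text.

```python
import math

class DynamicLossScaler:
    """Tracks the loss scale S and decides whether each step may be applied."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval   # the N consecutive clean steps
        self.good_steps = 0

    def update(self, grads) -> bool:
        """Return True if the step should be applied, False if it must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            self.scale *= self.backoff_factor    # overflow seen: shrink S
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale *= self.growth_factor     # long clean run: grow S
            self.good_steps = 0
        return True

scaler = DynamicLossScaler(growth_interval=3)
print(scaler.update([0.1, math.inf]), scaler.scale)   # False 32768.0
for _ in range(3):
    scaler.update([0.1, 0.2])
print(scaler.scale)                                   # 65536.0
```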

bfloat16: The Game-Changer

bfloat16 (Google Brain, 2018) takes a different approach: keep the 8-bit exponent of float32 but truncate the mantissa to 7 bits.

Advantages:

- Same dynamic range as float32, so values overflow only where float32 would, and no loss scaling is needed
- Conversion from float32 can be a simple truncation (drop the lower 16 mantissa bits); conversion back to float32 is exact
- Training in pure bfloat16 works without a float32 master copy of the weights (though float32 is still recommended for optimizer state)
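The truncation can be shown directly on the bit pattern. This sketch emulates bfloat16 by zeroing the low 16 bits of a float32 encoding; `to_bfloat16` is a hypothetical helper, and real hardware typically rounds to nearest rather than truncating, so treat this as an illustration of the format rather than of any particular device.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Truncate x to bfloat16 by keeping only the top 16 bits of its float32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFF0000                       # zero the 16 low mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bfloat16(3e-8))     # tiny values survive, with under 1% truncation error
print(to_bfloat16(1e30))     # large values stay finite
print(to_bfloat16(1.001))    # 1.0: spacing just above 1.0 is 2**-7, so precision is coarse
```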

Disadvantage: Lower precision (2.2 decimal digits) than float16 (3.3 digits). But for neural network training, dynamic range usually matters more than precision: the rounding noise from bfloat16's coarser mantissa is generally tolerated, and is sometimes described as acting like mild regularization.

Memory and Speed Impact

| Precision | Activation Memory | Weight Memory | Bandwidth | Throughput (A100) |
|---|---|---|---|---|
| float32 | 1× | 1× | 1× | 1× (base) |
| float16 | 0.5× | 0.5× | 0.5× | ~2× (tensor cores) |
| bfloat16 | 0.5× | 0.5× | 0.5× | ~2× (tensor cores) |

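For a 1B parameter model:

- float32 weights: 4 GB
- float16/bfloat16 weights: 2 GB
- Plus optimizer states (Adam: 2× moment buffers + master weights): 12 GB when kept in fp32, which is the usual choice even under mixed precision; 6–8 GB only if some or all of that state is stored in 16 bits

The 2× saving applies to the 16-bit weights, activations, and bandwidth. Total training memory shrinks by less than 2× when the optimizer state stays in float32, as Example 3 works through.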



Key Terms

Worked Examples

Example 1: Float16 Underflow

A gradient $g = 3 \times 10^{-5}$ with learning rate $\eta = 0.001$. What happens in float16?

Solution: Weight update: $\Delta w = 0.001 \cdot (3 \times 10^{-5}) = 3 \times 10^{-8}$.

Smallest positive normal float16: $2^{-14} \approx 6.1 \times 10^{-5}$. Subnormal minimum: $2^{-24} \approx 5.96 \times 10^{-8}$.

$3 \times 10^{-8}$ is below the float16 subnormal minimum, so it cannot be represented: it rounds to zero or to $2^{-24}$ (an error of about 100% either way), and it vanishes entirely when added to a float16 weight of ordinary magnitude. The weight update is lost.

With loss scaling $S = 2^{17} = 131072$: the scaled gradient is $S \cdot g = 131072 \cdot (3 \times 10^{-5}) \approx 3.93$, well within float16 range. The update is computed in fp32 after unscaling: $\eta \cdot (3.93 / S) = 0.001 \cdot (3 \times 10^{-5}) = 3 \times 10^{-8}$. This is small, but it is applied to the fp32 master weight as long as it exceeds half the weight's float32 spacing (a relative size of roughly $6 \times 10^{-8}$).

Answer: Without scaling, the update is lost in float16. With scaling $S = 2^{17}$, the gradient is representable in float16, and the fp32 master weight receives the small update.
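The numbers in this example can be replayed with emulated float16 and float32 storage. A minimal sketch: the helpers `fp16` and `fp32` are hypothetical names built on `struct`, and the weight value `w = 0.1` is an assumption added for illustration, since the example does not specify one.

```python
import math
import struct

def fp16(x: float) -> float:
    """Round x to the nearest float16 value; return +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def fp32(x: float) -> float:
    """Round x to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

g, lr, S = 3e-5, 1e-3, 2.0 ** 17
w = 0.1                                   # hypothetical weight value

# Pure float16: the update disappears when added to the weight.
w16 = fp16(w)
print(fp16(w16 - lr * g) == w16)          # True

# Loss scaling: the scaled gradient is an ordinary float16 number (about 3.93).
scaled16 = fp16(S * g)
grad32 = fp32(scaled16 / S)               # unscale in float32
w32 = fp32(w)
w32_new = fp32(w32 - lr * grad32)         # update the float32 master weight
print(scaled16, w32_new < w32)
```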

Example 2: Float16 Overflow

Compute $\operatorname{softmax}([1000, 0, -1000])$ in float16. What happens?

Solution: $e^{1000}$ overflows float16 (max ~65504); even $e^{12} \approx 162754$ already exceeds the float16 maximum.

With the max-subtraction trick (the same shift used in log-sum-exp), subtract the largest logit, 1000, which is exactly representable in float16. The shifted exponents are $0, -1000, -2000$, giving $e^{0} = 1$, $e^{-1000} \approx 0$, and $e^{-2000} \approx 0$. Since $e^{-1000}$ is effectively 0 in any precision, softmax works fine WITH the trick.

Without the trick: $e^{1000} = \infty$, $e^{0} = 1$, and $e^{-1000}$ underflows to 0. The denominator is $\infty$, so the first entry is $\infty / \infty = \text{NaN}$ and the softmax is [NaN, 0, 0].

Answer: Without the trick, $e^{1000}$ overflows float16 and softmax produces NaN. With the trick (subtract max = 1000), all exponents are ≤ 0, there is no overflow, and the result is the correct $(1, 0, 0)$.
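Both outcomes can be reproduced with emulated float16 arithmetic. In this sketch, `fp16`, `exp16`, and `softmax16` are hypothetical helpers; each intermediate is computed in a Python float and then rounded to float16, which is an assumption about how a float16 kernel would behave rather than a model of any specific library.

```python
import math
import struct

def fp16(x: float) -> float:
    """Round x to the nearest float16 value; return +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def exp16(x: float) -> float:
    """exp(x) rounded to float16; overflow becomes +inf."""
    try:
        return fp16(math.exp(x))
    except OverflowError:          # math.exp raises for very large arguments
        return math.inf

def softmax16(logits, subtract_max: bool):
    shift = max(logits) if subtract_max else 0.0
    exps = [exp16(fp16(v - shift)) for v in logits]
    total = fp16(sum(exps))
    return [fp16(e / total) for e in exps]

logits = [1000.0, 0.0, -1000.0]
print(softmax16(logits, subtract_max=False))   # [nan, 0.0, 0.0]
print(softmax16(logits, subtract_max=True))    # [1.0, 0.0, 0.0]
```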

Example 3: Memory Savings

A 7B-parameter transformer is trained with Adam in pure float32 vs mixed precision (fp16 forward/backward, fp32 master weights + optimizer states). Compare memory.

Solution: Adam stores: weights, first moment ($m$), second moment ($v$) — 3× the parameter count.

Float32:

- Weights: $7 \times 10^9 \times 4 = 28$ GB
- Adam states ($m$, $v$): $2 \times 28 = 56$ GB
- Gradients: 28 GB
- Total: 112 GB

Mixed precision (fp16 forward/backward, fp32 master):

- fp32 master weights: 28 GB
- fp16 weights copy (forward): $7 \times 10^9 \times 2 = 14$ GB
- fp32 optimizer states ($m$, $v$): 56 GB
- fp16 gradients: 14 GB
- Total: $28 + 14 + 56 + 14 = 112$ GB, the same as float32. But the forward activations are half size, saving significant memory there.

The main savings come from activations (half the memory) and from fp16 communication in model parallelism.

Pure bfloat16 (no fp32 master):

- bfloat16 weights: 14 GB
- bfloat16 gradients: 14 GB
- fp32 optimizer states: 56 GB
- Total: 84 GB (excluding activations), saving 28 GB. The figure is 70 GB if gradients are left out of the count.

Answer: Float32 Adam: ~112 GB. Mixed precision with fp32 master: the same weight, gradient, and optimizer memory, but ~2× activation savings. bfloat16 with fp32 optimizer states: ~84 GB including gradients. The 2× reduction in activation memory leaves room for roughly 2× larger batch sizes.
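The bookkeeping in this example is easy to parameterize. The sketch below uses a hypothetical helper, `training_memory_gb`, and counts gradients in every configuration, so its pure-bfloat16 total includes 14 GB of 16-bit gradients; activations are excluded throughout, and 1 GB is taken as $10^9$ bytes.

```python
def training_memory_gb(n_params: int, weight_bytes: int, grad_bytes: int,
                       master_bytes: int, moment_bytes: int = 4) -> float:
    """Memory for weights, gradients, optional master copy, and Adam's two moments."""
    per_param = weight_bytes + grad_bytes + master_bytes + 2 * moment_bytes
    return n_params * per_param / 1e9

n = 7_000_000_000
fp32_total = training_memory_gb(n, weight_bytes=4, grad_bytes=4, master_bytes=0)
mixed_total = training_memory_gb(n, weight_bytes=2, grad_bytes=2, master_bytes=4)
bf16_total = training_memory_gb(n, weight_bytes=2, grad_bytes=2, master_bytes=0)

print(fp32_total, mixed_total, bf16_total)   # 112.0 112.0 84.0
print(bf16_total - n * 2 / 1e9)              # 70.0 when gradients are not counted
```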


Quiz

Q1: Why are gradients hard to represent in float16 during training?

A) They span many orders of magnitude, and the smallest ones fall below float16's representable range B) They are always larger than float16's maximum of ~65504 C) They can only take integer values D) They change sign too often for a 16-bit format

Correct: A)

Q2: Which of the following approximates the dynamic range shared by float32 and bfloat16?

A) $\sim 10^{\pm 4.5}$ B) $\sim 10^{\pm 7}$ C) $\sim 10^{\pm 38}$ D) $\sim 10^{\pm 308}$

Correct: C)

Q3: In mixed-precision training, where are weight updates applied?

A) To the float32 master weights, so that small updates are not rounded away B) Directly to the float16 weights, with no float32 copy C) To the loss scale $S$ D) To the activations saved during the forward pass

Correct: A)

Q4: Which statement about the floating-point formats compared in this subject is TRUE?

A) float16 has a wider range than float32 B) bfloat16 has more mantissa bits than float16 C) bfloat16 has the same 8 exponent bits as float32 but only 7 mantissa bits D) float16 and bfloat16 use different total bit widths

Correct: C)

Q5: Based on the practice problems in this subject, how many representable values does bfloat16 have in $[1, 2)$?

A) $2^{10} = 1024$ representable values B) $2^7 = 128$ representable values C) $2^{23} = 8,388,608$ representable values D) $2^{8} = 256$ representable values

Correct: B)

Q6: How does the comparison of floating-point formats explain why float16 alone fails for training?

A) float16's 5 exponent bits give it a narrow range, so large values overflow and small gradients underflow B) float16 has more mantissa bits than float32, which slows training C) The two topics are unrelated D) float16 has the same range as float32 but cannot store negative numbers

Correct: A)

Q7: What is a common pitfall when working with mixed-precision training (Micikevicius et al., 2018)?

A) There are no common pitfalls B) Assuming the loss must be computed in float16 C) Confusing float16 with bfloat16: float16 needs loss scaling because of its narrow range, while bfloat16 generally does not D) Keeping any values in float32

Correct: C)

Q8: When should you apply loss scaling?

A) Never; it has no practical use B) Only when training entirely in float32 C) When computing gradients in float16, to keep small gradients inside the representable range D) Only when training in bfloat16

Correct: C)

Practice Problems

  1. Compute the number of representable values between 1.0 and 2.0 in float16 vs bfloat16.

    Answer: Between $2^0 = 1$ and $2^1 = 2$, the mantissa provides $2^{10} = 1024$ representable values in float16. In bfloat16: $2^7 = 128$ representable values. For comparison, float32 has $2^{23} = 8,388,608$ values. bfloat16 has 8× fewer values than float16 in the $[1, 2)$ range, which is much coarser precision. But for neural net training, gradients and activations are rarely computed to 3+ digits of accuracy anyway.

  2. A gradient's magnitude is $g = 10^{-4}$. What's the minimum loss scale $S$ needed to make $S \cdot g$ representable as a float16 normal number?

    Answer: Smallest normal float16: $2^{-14} \approx 6.1 \times 10^{-5}$. Need $S \cdot 10^{-4} \geq 6.1 \times 10^{-5} \implies S \geq 0.61$. So $S=1$ is sufficient: this gradient is already representable in float16's normal range. Loss scaling is only needed for gradients below ~$6 \times 10^{-5}$. For training large models, gradients commonly reach $10^{-7}$ to $10^{-8}$, which requires $S \sim 10^3$ to $10^4$.

  3. Explain why bfloat16 has the same exponent range as float32.

    Answer: Both use 8 exponent bits with the same bias of 127, so the exponent encoding is identical; bfloat16 is simply float32 with the lower 16 mantissa bits dropped. This means: (1) bfloat16 and float32 have essentially the same overflow and normal-range underflow thresholds; (2) conversion from float32 can be a simple truncation rather than a complex re-encoding, and conversion back is exact; (3) loss scaling is generally not needed.

  4. Dynamic loss scaling increases $S$ when no Inf/NaN is detected, decreases it otherwise. Why not just pick a very large fixed $S$?

    Answer: A fixed large $S$ would cause overflow for large gradients. If $S = 2^{20} \approx 10^6$ and a gradient is $g = 0.1$, then $S \cdot g \approx 10^5$, which exceeds the float16 maximum of 65504 and overflows. Dynamic scaling finds the sweet spot: large enough to prevent underflow of small gradients, not so large that it overflows large gradients. As training progresses and gradients typically shrink, $S$ can be increased.

  5. Why do optimizer states (Adam's $m$ and $v$) typically remain in float32 even when weights are in float16 or bfloat16?

    Answer: Optimizer states accumulate information over thousands of steps. $m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t$ requires precise accumulation: float16's 3.3 digits of precision would lose the small $(1-\beta_1)g_t$ term when added to the much larger $\beta_1 m_{t-1}$. Example: if $\beta_1 = 0.9$, $m_{t-1} \approx 1$, and $g_t \approx 10^{-4}$, then $\beta_1 m_{t-1} = 0.9$ and $(1-0.9) \cdot 10^{-4} = 10^{-5}$. In float16, $0.9 + 10^{-5}$ rounds back to $0.9$, because adjacent values near 0.9 are about $4.9 \times 10^{-4}$ apart (the update is lost). Float32's 7.2 digits can represent this correctly. bfloat16 would be even worse (2.2 digits). Optimizer states stay in float32 for stability.
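Several of the answers above can be checked numerically. The sketch below is a hedged illustration using emulated float16 and float32 storage via `struct`; every helper name (`fp16`, `fp32`, `count_unit_interval_fp16`, `count_unit_interval_bf16`, `min_loss_scale`) is hypothetical, and restricting the loss scale to powers of two is an assumption made for convenience.

```python
import math
import struct

def fp16(x: float) -> float:
    """Round x to the nearest float16 value; return +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def fp32(x: float) -> float:
    """Round x to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# Problem 1: enumerate every 16-bit pattern and count the values in [1, 2).
def count_unit_interval_fp16() -> int:
    return sum(1 for b in range(2 ** 16)
               if 1.0 <= struct.unpack("<e", struct.pack("<H", b))[0] < 2.0)

def count_unit_interval_bf16() -> int:
    # A bfloat16 pattern is the top half of a float32 pattern.
    return sum(1 for b in range(2 ** 16)
               if 1.0 <= struct.unpack("<f", struct.pack("<I", b << 16))[0] < 2.0)

# Problem 2: smallest power-of-two S that lifts |g| into float16's normal range.
def min_loss_scale(g: float, min_normal: float = 2.0 ** -14) -> float:
    exponent = math.ceil(math.log2(min_normal / g))
    return 2.0 ** max(exponent, 0)

print(count_unit_interval_fp16(), count_unit_interval_bf16())
print(min_loss_scale(1e-4), min_loss_scale(1e-7), min_loss_scale(1e-8))

# Problem 4: a fixed S = 2**20 overflows a gradient of 0.1; S = 2**16 does not.
print(fp16(2.0 ** 20 * 0.1), fp16(2.0 ** 16 * 0.1))

# Problem 5: Adam's first-moment update with beta1 = 0.9, m = 1, g = 1e-4.
decayed16, step16 = fp16(0.9 * 1.0), fp16(0.1 * 1e-4)
decayed32, step32 = fp32(0.9 * 1.0), fp32(0.1 * 1e-4)
print(fp16(decayed16 + step16) == decayed16)   # True: the increment is lost
print(fp32(decayed32 + step32) > decayed32)    # True: float32 keeps it
```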


Summary

Key takeaways:


Pitfalls



Next Steps

Next up: 15-07-gpu-computation-model.md — SIMD/SIMT execution, GPU memory hierarchy, thread blocks and warps, and why matrix multiply is so fast on GPUs.