
16-09 — Batch Normalization

Phase: 16 — Neural Network Mathematics
Subject: 16-09
Prerequisites: 16-05 (Backpropagation), 16-06 (Gradient Flow), Phase 6 (partial derivatives), Phase 11 (expectation and variance)
Next subject: 16-10 — Other Normalization Methods


Learning Objectives

By the end of this subject, you will be able to:

  1. Write the batch normalization forward pass equations and explain each component
  2. Derive the backpropagation through batch normalization, including gradients through mean and variance
  3. Explain the distinction between training and inference modes and why running averages are used
  4. Analyze why batch normalization enables higher learning rates and provides implicit regularization
  5. Understand the "internal covariate shift" motivation and its modern reinterpretation

Core Content

1. The Problem Batch Normalization Solves

During training, the distribution of each layer's inputs constantly changes as the weights of previous layers update. The original BN paper called this internal covariate shift and argued that it slows training and demands careful hyperparameter tuning, because each layer must continuously adapt to a changing input distribution.

⚠️ THIS IS CRITICAL — Batch Normalization (Ioffe & Szegedy, 2015) is one of the most impactful innovations in deep learning. It enables training deeper networks with higher learning rates, reduces sensitivity to initialization, and acts as a regularizer. It's a standard component of almost all modern CNN architectures.

Modern understanding: The original "internal covariate shift" motivation has been debated. A widely held current view is that BN works mainly by smoothing the optimization landscape, making the loss surface better conditioned, with fewer pathological curvatures.

2. Batch Normalization: Forward Pass

For a mini-batch of m examples, consider the activations x = {x₁, ..., x_m} at a particular neuron:

Step 1 — Compute batch statistics:

μ_B = (1/m) Σᵢ xᵢ (batch mean)
σ²_B = (1/m) Σᵢ (xᵢ − μ_B)² (batch variance)

Step 2 — Normalize:

x̂ᵢ = (xᵢ − μ_B) / √(σ²_B + ε)

Where ε ≈ 10⁻⁵ prevents division by zero.

Step 3 — Scale and shift (learnable parameters):

yᵢ = γ·x̂ᵢ + β

Where γ (scale) and β (shift) are learnable parameters, initialized to γ = 1, β = 0.
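The three steps translate directly into code. The sketch below is a minimal illustration, not a framework implementation: it handles a single neuron, represents the mini-batch as a plain Python list, and the name `bn_forward` and its returned cache are illustrative choices.

```python
import math

def bn_forward(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch normalization of one neuron's activations x (a list of floats)."""
    m = len(x)
    mu = sum(x) / m                                 # Step 1: batch mean
    var = sum((xi - mu) ** 2 for xi in x) / m       # Step 1: biased batch variance
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]       # Step 2: normalize
    y = [gamma * xh + beta for xh in x_hat]         # Step 3: scale and shift
    return y, (x_hat, inv_std)                      # cache for a backward pass

y, _ = bn_forward([1.0, 2.0, 3.0, 10.0], gamma=2.0, beta=0.5)
mean_y = sum(y) / len(y)
var_y = sum((v - mean_y) ** 2 for v in y) / len(y)
print(round(mean_y, 4), round(var_y, 4))  # 0.5 4.0
```

Real frameworks apply the same computation independently to every feature (or channel) of a tensor.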

Why γ and β? Normalization alone would constrain the layer's representational power (for example, forcing zero mean and unit variance keeps the inputs of a following sigmoid in the region where it is nearly linear). The learnable γ and β restore full expressiveness: choosing γ = √(σ²_B + ε) and β = μ_B would undo the normalization entirely, and if the optimal representation needs some other mean or variance, the network can learn it.
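A quick numerical check that γ and β can undo the normalization, using a hypothetical plain-Python helper: for a given batch, setting γ to that batch's standard deviation √(σ²_B + ε) and β to its mean μ_B returns the original activations. In a real network γ and β are fixed parameters, so this identity holds only while the batch statistics stay the same.

```python
import math

def bn_forward(x, gamma, beta, eps=1e-5):
    """Training-mode BN for one neuron: normalize with batch statistics, then scale and shift."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    return [gamma * (xi - mu) / math.sqrt(var + eps) + beta for xi in x]

x = [0.3, -1.2, 2.5, 0.9]
eps = 1e-5
mu = sum(x) / len(x)
var = sum((xi - mu) ** 2 for xi in x) / len(x)

# gamma = this batch's std and beta = its mean reverse Steps 1-2.
y = bn_forward(x, gamma=math.sqrt(var + eps), beta=mu, eps=eps)
print(all(abs(a - b) < 1e-12 for a, b in zip(x, y)))  # True
```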

3. Training vs Inference Mode

During TRAINING, μ_B and σ²_B are computed from the current mini-batch. During INFERENCE, we need deterministic behavior — we can't depend on batch statistics (and the batch might be size 1 for online prediction).

Solution — Running averages:

During training, maintain exponential moving averages:

μ_run ← momentum · μ_run + (1 − momentum) · μ_B
σ²_run ← momentum · σ²_run + (1 − momentum) · σ²_B

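Typical momentum in this convention: 0.9 or 0.99. Frameworks differ here: Keras's BatchNormalization layer uses this convention (default momentum 0.99), while PyTorch's momentum argument is the weight on the new batch statistic (default 0.1, equivalent to 0.9 in the formula above).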

At inference, use μ_run and σ²_run:

y = γ · (x − μ_run)/√(σ²_run + ε) + β

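Important: Implementations differ in how they estimate the variance. Training always normalizes with the biased batch variance (dividing by m), but the running variance can be built from either that biased estimate or the unbiased one (dividing by m − 1, i.e. (m/(m−1))·σ²_B). The original paper uses the unbiased estimate for inference, and PyTorch, for example, normalizes with the biased variance while storing the unbiased estimate in its running variance. Check which convention a framework uses before comparing numbers across implementations.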
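The sketch below puts the training/inference split and the running averages together for a single neuron. It is illustrative only: the class name `SingleNeuronBatchNorm` is hypothetical, it uses the momentum convention of the update rule above (momentum weights the old value), and it assumes the variant that normalizes with the biased batch variance but stores the unbiased estimate in the running variance.

```python
import math
import random

class SingleNeuronBatchNorm:
    """Hypothetical BN layer for one neuron with training and inference modes.

    Running statistics: run <- momentum * run + (1 - momentum) * batch_stat.
    Training normalizes with the biased batch variance; the running variance
    stores the unbiased estimate m / (m - 1) * var.
    """

    def __init__(self, momentum=0.9, eps=1e-5):
        self.gamma, self.beta = 1.0, 0.0      # learnable in a real network; fixed here
        self.momentum, self.eps = momentum, eps
        self.run_mean, self.run_var = 0.0, 1.0
        self.training = True

    def __call__(self, x):
        if self.training:                      # training mode needs a batch of at least 2 values
            m = len(x)
            mu = sum(x) / m
            var = sum((xi - mu) ** 2 for xi in x) / m
            k = self.momentum
            self.run_mean = k * self.run_mean + (1 - k) * mu
            self.run_var = k * self.run_var + (1 - k) * var * m / (m - 1)
        else:                                  # inference: fixed statistics, any batch size
            mu, var = self.run_mean, self.run_var
        inv_std = 1.0 / math.sqrt(var + self.eps)
        return [self.gamma * (xi - mu) * inv_std + self.beta for xi in x]

random.seed(0)
bn = SingleNeuronBatchNorm(momentum=0.9)
for _ in range(200):                           # training batches drawn from N(5, 2^2)
    bn([random.gauss(5.0, 2.0) for _ in range(32)])
bn.training = False
print(round(bn.run_mean, 2), round(bn.run_var, 2))  # near 5 and 4; exact values depend on the draws
print(bn([5.0]))                                     # a single example is fine at inference
```

Switching `training` off is what makes the output deterministic: the same input always maps to the same output, regardless of what else is in the batch.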

4. Backpropagation Through Batch Normalization

Backprop through BN is non-trivial because μ_B and σ²_B depend on ALL examples in the batch, making the gradient computation slightly involved.

Given ∂L/∂yᵢ for each example i, we need ∂L/∂xᵢ, ∂L/∂γ, and ∂L/∂β.

Let's derive ∂L/∂xᵢ:

The computation graph is not a simple chain: xᵢ feeds x̂ᵢ directly, and it also feeds μ_B and σ²_B, which in turn feed every x̂ⱼ in the batch. So xᵢ affects L through three paths: directly through x̂ᵢ, and indirectly through μ_B and σ²_B.

Step 1 — Gradients through the affine transform: ∂L/∂x̂ᵢ = ∂L/∂yᵢ · ∂yᵢ/∂x̂ᵢ = ∂L/∂yᵢ · γ

∂L/∂γ = Σᵢ ∂L/∂yᵢ · ∂yᵢ/∂γ = Σᵢ ∂L/∂yᵢ · x̂ᵢ

∂L/∂β = Σᵢ ∂L/∂yᵢ · 1 = Σᵢ ∂L/∂yᵢ

Step 2 — Gradients through normalization:

∂L/∂σ²_B = Σᵢ ∂L/∂x̂ᵢ · ∂x̂ᵢ/∂σ²_B = Σᵢ ∂L/∂x̂ᵢ · (−1/2)(xᵢ − μ_B)(σ²_B + ε)^(−3/2)

∂L/∂μ_B = Σᵢ ∂L/∂x̂ᵢ · ∂x̂ᵢ/∂μ_B + ∂L/∂σ²_B · ∂σ²_B/∂μ_B = Σᵢ ∂L/∂x̂ᵢ · (−1/√(σ²_B + ε)) + ∂L/∂σ²_B · (−2/m)Σⱼ(xⱼ − μ_B)

The second term simplifies since Σⱼ(xⱼ − μ_B) = 0, so: ∂L/∂μ_B = Σᵢ ∂L/∂x̂ᵢ · (−1/√(σ²_B + ε))

Step 3 — Full gradient ∂L/∂xᵢ:

∂L/∂xᵢ = ∂L/∂x̂ᵢ · (1/√(σ²_B + ε)) + ∂L/∂σ²_B · 2(xᵢ − μ_B)/m + ∂L/∂μ_B · (1/m)

Combining:

∂L/∂xᵢ = (γ / √(σ²_B + ε)) · [∂L/∂yᵢ − (1/m)(Σⱼ ∂L/∂yⱼ) − x̂ᵢ·(1/m)(Σⱼ ∂L/∂yⱼ · x̂ⱼ)]

Or more compactly (the form commonly implemented):

∂L/∂xᵢ = (1 / (m·√(σ²_B + ε))) · [m·∂L/∂x̂ᵢ − Σⱼ ∂L/∂x̂ⱼ − x̂ᵢ·Σⱼ ∂L/∂x̂ⱼ·x̂ⱼ]

Where ∂L/∂x̂ᵢ = γ·∂L/∂yᵢ.
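To check the derivation, the sketch below implements ∂L/∂xᵢ twice, once by following Steps 1-3 term by term and once with the compact formula, and compares them against central finite differences. The helper names, the test loss L = Σᵢ cᵢ·yᵢ², and the values of x, c, γ, β, ε are arbitrary choices for illustration; ε is nonzero on purpose, since the formulas hold for any ε.

```python
import math

def bn_forward(x, gamma, beta, eps):
    """Training-mode BN for one neuron; returns y plus values reused by the backward pass."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    return [gamma * xh + beta for xh in x_hat], mu, inv_std, x_hat

def bn_backward_staged(x, dy, gamma, eps):
    """Steps 1-3 term by term: direct path, variance path, mean path."""
    m = len(x)
    _, mu, inv_std, x_hat = bn_forward(x, gamma, 0.0, eps)
    dx_hat = [gamma * d for d in dy]                                    # Step 1
    dvar = -0.5 * inv_std ** 3 * sum(d * (xi - mu) for d, xi in zip(dx_hat, x))  # Step 2
    dmu = -inv_std * sum(dx_hat)   # the dvar * (-2/m) * sum(x - mu) term is zero
    return [d * inv_std + dvar * 2.0 * (xi - mu) / m + dmu / m          # Step 3
            for d, xi in zip(dx_hat, x)]

def bn_backward_compact(dy, x_hat, inv_std, gamma):
    """inv_std/m * (m*g_i - sum_j g_j - x_hat_i * sum_j g_j*x_hat_j), with g = gamma*dy."""
    m = len(dy)
    g = [gamma * d for d in dy]
    sum_g = sum(g)
    sum_gx = sum(gi * xh for gi, xh in zip(g, x_hat))
    return [inv_std / m * (m * gi - sum_g - xh * sum_gx) for gi, xh in zip(g, x_hat)]

x = [0.5, -1.0, 2.0, 0.25, 1.5]
c = [1.0, -2.0, 0.5, 3.0, -1.0]
gamma, beta, eps = 1.7, 0.3, 1e-3

def loss(xs):
    y, _, _, _ = bn_forward(xs, gamma, beta, eps)
    return sum(ci * yi ** 2 for ci, yi in zip(c, y))    # L = sum_i c_i * y_i^2

y, mu, inv_std, x_hat = bn_forward(x, gamma, beta, eps)
dy = [2.0 * ci * yi for ci, yi in zip(c, y)]            # dL/dy_i
staged = bn_backward_staged(x, dy, gamma, eps)
compact = bn_backward_compact(dy, x_hat, inv_std, gamma)

h = 1e-6
numeric = []
for i in range(len(x)):
    x_plus, x_minus = list(x), list(x)
    x_plus[i] += h
    x_minus[i] -= h
    numeric.append((loss(x_plus) - loss(x_minus)) / (2 * h))

print(max(abs(s - n) for s, n in zip(staged, numeric)))   # finite-difference error only
print(max(abs(s - k) for s, k in zip(staged, compact)))   # round-off only
```

The compact form needs just two reductions over the batch (Σ g and Σ g·x̂), which is why it is the version usually implemented.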

Key insight: The gradient through BN involves terms that couple all batch elements together (the sums over j). This is what makes the computation slightly more expensive but also contributes to BN's regularizing effect.
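The coupling can be seen directly. In the sketch below (a hypothetical helper using the compact formula, with ε = 0), the upstream gradient is nonzero for only one example, yet every example receives a gradient. The input gradient also sums to zero and, with ε = 0, is orthogonal to x̂: BN removes the components of the gradient that would shift or rescale the whole batch, since those cannot change its output.

```python
import math

def bn_input_grad(x, dy, gamma=1.0, eps=0.0):
    """Compact BN input gradient for one neuron; also returns x_hat for inspection."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    g = [gamma * d for d in dy]
    sum_g = sum(g)
    sum_gx = sum(gi * xh for gi, xh in zip(g, x_hat))
    dx = [inv_std / m * (m * gi - sum_g - xh * sum_gx) for gi, xh in zip(g, x_hat)]
    return dx, x_hat

x = [0.2, 1.4, -0.7, 3.1]
dy = [0.0, 0.0, 0.0, 1.0]            # upstream gradient reaches only the last example
dx, x_hat = bn_input_grad(x, dy)
print([round(d, 3) for d in dx])     # every example receives a nonzero gradient
print(sum(dx), sum(d * xh for d, xh in zip(dx, x_hat)))  # both zero up to round-off
```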

5. Why Batch Normalization Works

1. Smoother optimization landscape: Santurkar et al. (2018) argue that BN reduces the Lipschitz constants of the loss and its gradients, so the loss changes more smoothly as parameters change, with fewer sharp ravines and cliffs. This allows larger learning rates without divergence.

2. Reduced sensitivity to initialization: Even with poor initial weights, BN ensures activations have controlled mean and variance. The network can recover from bad initialization more easily.

3. Implicit regularization: The mini-batch statistics introduce noise (μ_B and σ²_B are noisy estimates of the true mean/variance). This noise acts as a regularizer — similar to dropout but through normalization rather than zeroing. Networks with BN often need LESS explicit regularization.

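4. Mitigates vanishing/exploding gradients: By keeping each layer's activations at a controlled scale, BN stops them from growing or shrinking systematically across layers, which tends to keep gradient magnitudes in a healthy range (though it does not guarantee well-behaved gradients in arbitrarily deep networks).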

5. Allows higher learning rates: The smooth loss landscape means gradient descent can take larger steps without diverging. BN + high learning rate often trains faster than unnormalized networks with any learning rate.
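Two of these effects can be illustrated in a few lines. The sketch below (plain Python, single neuron, γ = 1, β = 0, helper name hypothetical) shows that the same input value is normalized differently depending on which other examples share its batch, which is the noise behind point 3, and that rescaling all pre-normalization activations by a positive factor, as badly scaled initial weights would, leaves the output unchanged, which relates to point 2.

```python
import math

def bn(x, eps=1e-5):
    """Normalize one neuron's activations over a mini-batch (gamma = 1, beta = 0)."""
    m = len(x)
    mu = sum(x) / m
    var = sum((xi - mu) ** 2 for xi in x) / m
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]

# Same input value 1.0, two different sets of batch companions.
out_a = bn([1.0, 0.0, 2.0, 3.0])[0]
out_b = bn([1.0, 4.0, 5.0, 6.0])[0]
print(round(out_a, 3), round(out_b, 3))   # -0.447 -1.604

# Rescaling every pre-BN activation by a positive factor (as scaling the incoming
# weights would) leaves the output unchanged; exactly so when eps = 0.
x = [0.3, -1.1, 2.4, 0.8]
scale_gap = max(abs(p - q) for p, q in zip(bn(x, eps=0.0), bn([10.0 * v for v in x], eps=0.0)))
print(scale_gap)                          # round-off only
```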

6. Where to Place Batch Normalization

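Traditionally: Conv/FC → BN → Activation

This placement is debated; some practitioners prefer Conv/FC → Activation → BN, though it is less common.

For residual networks, the original ResNet block uses Conv → BN → ReLU → Conv → BN → Addition → ReLU. The pre-activation variant (He et al., 2016) instead places BN and ReLU before each convolution: BN → ReLU → Conv → BN → ReLU → Conv → Addition.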
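The traditional Conv/FC → BN → Activation ordering can be sketched for a fully connected layer. In this minimal plain-Python version (lists stand in for tensors, and all function names are hypothetical), BN computes separate statistics for each output feature over the batch dimension, and the linear layer has no bias because BN follows it.

```python
import math

def linear_no_bias(X, W):
    """X: m rows of n_in values; W: n_out rows of n_in weights. No bias, since BN follows."""
    return [[sum(w * x for w, x in zip(w_row, row)) for w_row in W] for row in X]

def batch_norm(Z, gamma, beta, eps=1e-5):
    """Normalize each feature (column) over the batch (rows), then scale and shift."""
    m, n = len(Z), len(Z[0])
    out = [[0.0] * n for _ in range(m)]
    for j in range(n):
        col = [Z[i][j] for i in range(m)]
        mu = sum(col) / m
        var = sum((v - mu) ** 2 for v in col) / m
        for i in range(m):
            out[i][j] = gamma[j] * (Z[i][j] - mu) / math.sqrt(var + eps) + beta[j]
    return out

def relu(Z):
    return [[max(0.0, v) for v in row] for row in Z]

X = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]      # batch of m = 3 inputs, 2 features each
W = [[0.2, -0.4], [1.0, 0.5], [-0.3, 0.8]]     # 3 output features
H = relu(batch_norm(linear_no_bias(X, W), gamma=[1.0] * 3, beta=[0.0] * 3))
```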



Worked Examples

Example 1: Forward Pass Computation

Problem: A mini-batch of 4 activations for a single neuron: x = [2, 4, 6, 8]. Compute the BN output with γ = 1, β = 0, ε = 0. (Use ε = 0 for simplicity in hand calculation.)

Solution:

μ_B = (2+4+6+8)/4 = 5
σ²_B = ((2−5)² + (4−5)² + (6−5)² + (8−5)²)/4 = (9+1+1+9)/4 = 5

x̂ = [(2−5)/√5, (4−5)/√5, (6−5)/√5, (8−5)/√5] = [−3/2.236, −1/2.236, 1/2.236, 3/2.236] ≈ [−1.342, −0.447, 0.447, 1.342]

Check: Mean of x̂ ≈ (−1.342 − 0.447 + 0.447 + 1.342)/4 = 0 ✓
Variance of x̂ ≈ (1.801 + 0.200 + 0.200 + 1.801)/4 = 1.0005 ≈ 1 ✓ (the small excess is rounding error; exactly, (9/5 + 1/5 + 1/5 + 9/5)/4 = 1)

Answer: Normalized outputs are approximately [−1.34, −0.45, 0.45, 1.34], with mean 0 and variance 1.
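Reproducing Example 1 in plain Python, without intermediate rounding, gives the same x̂ and confirms a mean of 0 and a variance of 1 up to floating-point round-off.

```python
import math

x = [2.0, 4.0, 6.0, 8.0]
m = len(x)
mu = sum(x) / m                                   # 5.0
var = sum((xi - mu) ** 2 for xi in x) / m         # 5.0
x_hat = [(xi - mu) / math.sqrt(var) for xi in x]  # eps = 0, gamma = 1, beta = 0
print([round(v, 3) for v in x_hat])               # [-1.342, -0.447, 0.447, 1.342]
mean_hat = sum(x_hat) / m
var_hat = sum(v * v for v in x_hat) / m
print(mean_hat, var_hat)                          # 0 and 1 up to round-off
```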

Example 2: Effect of γ and β

Problem: Starting from Example 1's normalized outputs, apply γ = 2 and β = 1. What are the final BN outputs and their mean/variance?

Solution:

y = γ·x̂ + β = [2·(−1.342)+1, 2·(−0.447)+1, 2·0.447+1, 2·1.342+1] = [−1.684, 0.106, 1.894, 3.684]

Mean: (−1.684 + 0.106 + 1.894 + 3.684)/4 = 4/4 = 1 = β ✓
Variance: γ² times the variance of x̂, i.e. 2²·1 = 4 ✓

γ controls the standard deviation (multiplies by |γ|). β controls the mean (adds β). The network can learn any mean and variance for the output distribution.
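The same check for Example 2, in plain Python. Because the code starts from the unrounded x̂ = ±3/√5, ±1/√5, the outer values come out as −1.683 and 3.683 rather than the −1.684 and 3.684 obtained by doubling the rounded x̂; the mean (1) and variance (4) are unaffected.

```python
import math

x_hat = [(xi - 5.0) / math.sqrt(5.0) for xi in [2.0, 4.0, 6.0, 8.0]]  # Example 1, unrounded
gamma, beta = 2.0, 1.0
y = [gamma * v + beta for v in x_hat]
mean_y = sum(y) / len(y)
var_y = sum((v - mean_y) ** 2 for v in y) / len(y)
print([round(v, 3) for v in y])   # [-1.683, 0.106, 1.894, 3.683]
print(mean_y, var_y)              # 1 and 4 up to round-off
```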

Example 3: Gradient Through BN (Simplified)

Problem: For a batch of 2 examples: x₁ = 1, x₂ = 3, with γ = 1, β = 0, and ∂L/∂y₁ = 0.5, ∂L/∂y₂ = −0.5, compute ∂L/∂x₁ and ∂L/∂x₂. Use ε = 0.

Solution:

μ_B = (1+3)/2 = 2
σ²_B = ((1−2)² + (3−2)²)/2 = (1+1)/2 = 1
x̂₁ = (1−2)/1 = −1, x̂₂ = (3−2)/1 = 1

∂L/∂x̂₁ = γ·∂L/∂y₁ = 0.5
∂L/∂x̂₂ = γ·∂L/∂y₂ = −0.5

∂L/∂σ²_B = Σᵢ ∂L/∂x̂ᵢ · (−1/2)(xᵢ−μ_B)·(σ²_B)^(−3/2) = 0.5·(−1/2)(1−2)·1 + (−0.5)·(−1/2)(3−2)·1 = 0.5·(0.5) + (−0.5)·(−0.5) = 0.25 + 0.25 = 0.5

∂L/∂μ_B = Σᵢ ∂L/∂x̂ᵢ · (−1/√1) = 0.5·(−1) + (−0.5)·(−1) = −0.5 + 0.5 = 0

Now ∂L/∂x₁ = ∂L/∂x̂₁·(1/√1) + ∂L/∂σ²_B·2(x₁−μ_B)/2 + ∂L/∂μ_B·(1/2) = 0.5·1 + 0.5·2(−1)/2 + 0·0.5 = 0.5 − 0.5 = 0

Similarly ∂L/∂x₂ = (−0.5)·1 + 0.5·2(1)/2 + 0·0.5 = −0.5 + 0.5 = 0.

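Interesting result: Both gradients are zero! The reason is that BN's output is unchanged by shifting the whole batch or rescaling it by a positive factor, so any component of the upstream gradient along the all-ones direction (a shift) or along x̂ (a rescaling) is projected out. Here ∂L/∂y = (0.5, −0.5) = −0.5·x̂: it sums to zero and lies entirely along x̂, asking only to shrink the spread of the outputs, which BN's normalized output can never change. In fact, with m = 2 and ε = 0, x̂ is always (−1, 1) or (1, −1) whenever x₁ ≠ x₂, so the input gradient is zero for any upstream gradient.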
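The zero result can be checked, and generalized, with the compact gradient formula. In the sketch below (a hypothetical helper, ε = 0), a batch of two examples gives a zero input gradient both for Example 3's upstream gradient and for an arbitrary other one, because with m = 2 the normalized outputs are always ±1 and cannot respond to x; with m = 3 the gradient is generally nonzero.

```python
import math

def bn_input_grad(x, dy, gamma=1.0, eps=0.0):
    """Compact BN input gradient for one neuron (eps = 0, as in Example 3)."""
    m = len(x)
    mu = sum(x) / m
    inv_std = 1.0 / math.sqrt(sum((xi - mu) ** 2 for xi in x) / m + eps)
    x_hat = [(xi - mu) * inv_std for xi in x]
    g = [gamma * d for d in dy]
    sum_g = sum(g)
    sum_gx = sum(gi * xh for gi, xh in zip(g, x_hat))
    return [inv_std / m * (m * gi - sum_g - xh * sum_gx) for gi, xh in zip(g, x_hat)]

print(bn_input_grad([1.0, 3.0], [0.5, -0.5]))             # Example 3: [0.0, 0.0]
print(bn_input_grad([1.0, 3.0], [0.75, 0.25]))            # another upstream gradient: [0.0, 0.0]
print(bn_input_grad([1.0, 3.0, 4.0], [0.5, -0.5, 0.0]))   # m = 3: nonzero entries
```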

Practice Problems

(Answers are below. Try each problem before checking.)

Problem 1: Show that for any mini-batch, the normalized outputs x̂ᵢ satisfy: (a) Σᵢ x̂ᵢ = 0, (b) (1/m)Σᵢ x̂ᵢ² = 1.

Problem 2: Derive ∂L/∂γ and ∂L/∂β from the BN forward equations.

Problem 3: In inference mode, BN uses running averages μ_run and σ²_run. Derive the update rule for μ_run using momentum α and explain why high momentum (close to 1) is preferred.

Problem 4: A network layer with 256 neurons uses BN. There are 256 γ parameters and 256 β parameters. During inference, what is the total number of operations per example for this BN layer? Compare with a simple linear layer of 256→256.

Problem 5: Prove that if we apply BN right AFTER a linear layer (before activation), the bias term b in the linear layer becomes redundant. (Hint: show that any effect of b can be absorbed by β.)

Answers

Problem 1: (Take ε = 0, so σ_B = √σ²_B.)
(a) Σᵢ x̂ᵢ = Σᵢ (xᵢ−μ_B)/σ_B = (1/σ_B)[Σᵢ xᵢ − m·μ_B] = (1/σ_B)[m·μ_B − m·μ_B] = 0 ✓
(b) (1/m)Σᵢ x̂ᵢ² = (1/m)Σᵢ (xᵢ−μ_B)²/σ²_B = (1/σ²_B)·(1/m)Σᵢ (xᵢ−μ_B)² = σ²_B/σ²_B = 1 ✓

Problem 2:
∂L/∂γ = Σᵢ ∂L/∂yᵢ · ∂yᵢ/∂γ = Σᵢ ∂L/∂yᵢ · x̂ᵢ
∂L/∂β = Σᵢ ∂L/∂yᵢ · ∂yᵢ/∂β = Σᵢ ∂L/∂yᵢ · 1 = Σᵢ ∂L/∂yᵢ
The β gradient is simply the sum of incoming gradients. The γ gradient weights each example's gradient by how extreme its normalized activation was.

Problem 3:
μ_run ← α·μ_run + (1−α)·μ_B
This is an exponential moving average of the batch means; with high α (e.g., 0.99), μ_run moves slowly and serves as an estimate of the population mean. High α is preferred because:
- The estimate is more stable (less sensitive to any individual batch).
- Since α < 1, old batches are still gradually forgotten. The activation distribution drifts during training as parameters change, so recent batches reflect the current network better, and the exponential weighting builds in this recency.
- α = 0.9 or 0.99 are typical, effectively averaging over roughly the last 1/(1−α) ≈ 10 or 100 batches.

Problem 4:
BN inference per neuron: subtract μ_run, divide by √(σ²_run+ε), multiply by γ, add β: 4 operations per value × 256 values ≈ 1,024 operations. (Because μ_run, σ²_run, γ, and β are fixed at inference, this can be precomputed into one multiply and one add per value.)
Linear layer 256→256: a 256×256 matrix-vector product is 65,536 multiply-adds ≈ 131,000 operations.
BN is computationally cheap compared to the linear layer (under 1% of its cost). Training costs more (computing batch statistics and their gradients) but remains small relative to the linear operations.

Problem 5:
With a linear layer z = Wx + b followed by BN, BN computes ẑ = (z − μ_z)/σ_z and then γ·ẑ + β.
For z = Wx + b, the batch mean is μ_z = μ_{Wx} + b, so
z − μ_z = (Wx + b) − (μ_{Wx} + b) = Wx − μ_{Wx}.
The bias cancels, and σ_z = σ_{Wx} because adding a constant does not change the variance. Therefore, when BN follows a linear layer, the bias b has no effect on the output; any shift the network needs is provided by β.
In practice: omit the bias from linear layers that are immediately followed by BN.
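A numerical version of Problem 5, as a minimal plain-Python sketch (the batch, the weights `w`, and the helper names are arbitrary illustrative choices): normalizing one neuron's pre-activations gives the same result whether the linear layer's bias is 0 or 5.

```python
import math

def bn(z, eps=1e-5):
    """Normalize one neuron's pre-activations over the batch (gamma = 1, beta = 0)."""
    m = len(z)
    mu = sum(z) / m
    var = sum((v - mu) ** 2 for v in z) / m
    return [(v - mu) / math.sqrt(var + eps) for v in z]

batch = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0], [-2.0, 1.5]]   # 4 examples, 2 inputs each
w = [0.7, -0.3]                                              # weights of one output neuron

def linear(b):
    """z = w . x + b for every example in the batch."""
    return [sum(wi * xi for wi, xi in zip(w, x)) + b for x in batch]

gap = max(abs(p - q) for p, q in zip(bn(linear(0.0)), bn(linear(5.0))))
print(gap)   # round-off only: the bias never reaches BN's output
```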

Summary

  1. Batch Normalization normalizes activations within a mini-batch: x̂ = (x−μ_B)/√(σ²_B+ε), then applies learnable scale γ and shift β.
  2. Training vs inference: During training, use batch statistics; during inference, use running averages (exponential moving averages from training).
  3. Backprop through BN is non-trivial because μ_B and σ²_B depend on all batch elements, so gradients couple every example in the batch; the same batch dependence makes each output noisy, which provides implicit regularization.
  4. Benefits: Smooths optimization landscape, enables higher learning rates, reduces sensitivity to initialization, provides implicit regularization.
  5. Practical note: Biases in linear layers before BN are redundant (BN's β replaces them). Omit the bias.


Quiz

Q1: What are the learnable parameters in a Batch Normalization layer?

A) μ and σ²
B) γ and β (scale and shift)
C) ε (epsilon)
D) The momentum parameter

Answer and Explanations
Correct: B) γ and β (scale and shift)
γ and β are learned via gradient descent, initialized to 1 and 0 respectively. μ and σ² are computed from data (not learned). ε is a fixed hyperparameter (typically 10⁻⁵). Momentum is a fixed hyperparameter for the running averages.
- A) μ and σ² are computed from batch statistics, not learned via gradient descent.
- B) ✓ Correct. γ and β restore representational power lost by normalization.
- C) ε is a fixed constant for numerical stability.
- D) Momentum is a hyperparameter, not learned.

Q2: Why are running averages used during inference instead of batch statistics?

A) Batch statistics cannot be computed for a single example
B) Running averages are more accurate
C) Running averages are faster
D) Batch statistics require backpropagation

Answer and Explanations
Correct: A) Batch statistics cannot be computed for a single example
At inference, we may process one example at a time. μ_B and σ²_B can't be meaningfully computed from a single example (the biased σ²_B would be 0, so every input would map to β). Running averages provide stable, precomputed statistics from the training data, and they also make each prediction independent of whatever else happens to be in the batch.
- A) ✓ Correct. Inference batches can be size 1, making batch statistics meaningless.
- B) Running averages approximate population statistics but aren't necessarily more accurate than large-batch statistics.
- C) The computational cost is essentially the same.
- D) Backprop is not done during inference regardless.

Q3: Batch Normalization provides implicit regularization because:

A) It adds L2 penalty to activations
B) The mini-batch statistics introduce noise, as each batch's mean and variance are noisy estimates
C) It reduces the number of parameters
D) It uses dropout internally

Answer and Explanations
Correct: B) The mini-batch statistics introduce noise, as each batch's mean and variance are noisy estimates
Each mini-batch's μ_B and σ²_B are noisy estimates of the population statistics. This noise in the normalization acts as a regularizer, similar to how dropout adds noise by zeroing neurons. Networks with BN often require less (or no) explicit dropout.
- A) Incorrect. BN doesn't add any penalty term.
- B) ✓ Correct. The stochasticity from batch statistics provides a regularizing effect.
- C) Incorrect. BN adds parameters (γ and β).
- D) Incorrect. BN is a separate mechanism from dropout.

Q4: In the backpropagation through BN, why does ∂L/∂xᵢ depend on all ∂L/∂yⱼ (for all j in the batch)?

A) Because of the γ scaling parameter
B) Because μ_B and σ²_B are functions of all xⱼ in the batch
C) Because the loss is computed over the whole batch
D) It doesn't: ∂L/∂xᵢ only depends on ∂L/∂yᵢ

Answer and Explanations
Correct: B) Because μ_B and σ²_B are functions of all xⱼ in the batch
xᵢ appears not only directly in x̂ᵢ = (xᵢ−μ_B)/σ_B, but also in μ_B (via the mean) and σ²_B (via the variance), both of which aggregate over ALL batch elements. Thus ∂L/∂xᵢ receives contributions from every ∂L/∂yⱼ.
- A) γ is element-wise, not cross-coupled.
- B) ✓ Correct. The coupling through μ_B and σ²_B makes each ∂L/∂xᵢ depend on the entire batch.
- C) While true, this isn't the mechanism: the coupling is through the normalization statistics.
- D) Incorrect. The indirect paths through μ_B and σ²_B create cross-dependencies.

Q5: If a linear layer is immediately followed by Batch Normalization, why is the bias term redundant?

A) BN subtracts the mean, which cancels any constant offset from the bias
B) BN multiplies by γ, which nullifies the bias
C) The bias makes the network too deep
D) The bias interferes with the running averages

Answer and Explanations
Correct: A) BN subtracts the mean, which cancels any constant offset from the bias
z = Wx + b. BN computes (z−μ_z)/σ_z, with μ_z = μ_{Wx} + b. So (z−μ_z) = (Wx+b) − (μ_{Wx}+b) = Wx − μ_{Wx}. The bias b cancels out completely (and σ_z does not depend on b). Any desired shift can be achieved by BN's own β parameter.
- A) ✓ Correct. The mean subtraction eliminates any constant bias term.
- B) γ scales the normalized value; it does not nullify the bias.
- C) Irrelevant: the bias does not change the depth of the network.
- D) The bias also shifts the running mean by (approximately) b, so it cancels at inference too.

Next Steps

Move on to 16-10 — Other Normalization Methods to learn about Layer Normalization, Instance Normalization, Group Normalization, and RMS Normalization — alternatives to BN that work in different contexts, especially sequence models and transformers.