
17-10: The Transformer Block (Detailed)

Phase: 17 - Deep Learning Architectures (Math)
Subject: 17-10
Prerequisites: 17-09 (Transformer Architecture), 17-08 (Multi-Head Attention), 17-05 (Residual Connections), 16-10 (LayerNorm/RMSNorm), 16-05 (Backpropagation)
Next subject: 18-01 - Tokenization Mathematics


Learning Objectives

By the end of this subject, you will be able to:

  1. Derive the complete forward pass of one Transformer block as a system of matrix equations
  2. Trace the gradient through the attention mechanism, residual connections, and FFN
  3. Mathematically prove why pre-norm is more stable than post-norm for deep Transformers
  4. Compute the gradient norm scaling through repeated blocks under both norm placements
  5. Explain how RMSNorm differs from LayerNorm and why modern LLMs prefer it

Core Content

1. The Transformer Block as a System of Equations

Let x ∈ ℝ^(n×d) be the input to a block (after positional encoding + previous blocks). A modern pre-norm Transformer block computes:

Step 1: Attention sub-block

x₁ = RMSNorm(x)
a = MultiHead(x₁, x₁, x₁)   [self-attention]
x₂ = x + a   [residual]

Step 2: FFN sub-block

x₃ = RMSNorm(x₂)
f = FFN(x₃) = W₂ · σ(W₁ · x₃ + b₁) + b₂
x₄ = x₂ + f   [residual]

⚠️ THIS IS CRITICAL: Note the order: NORM → SUBLAYER → ADD. This is pre-norm. The normalization is applied BEFORE each sub-layer, and the residual connection adds back the un-normalized input (x, then x₂), not the normalized copy. This means the residual stream carries the "raw" signal, and normalization only affects what the sub-layer sees.
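The two steps above can be sketched in plain Python. This is a minimal illustration, not a reference implementation: it assumes a row-vector convention (h @ W instead of W · x), no biases, ReLU for σ, and a toy single-head attention with identity Q/K/V projections. The helper names (rmsnorm, self_attention, ffn, pre_norm_block) are hypothetical.

```python
import math

def rmsnorm(rows, gamma, eps=1e-8):
    """Row-wise RMSNorm: gamma * x / sqrt(mean(x^2) + eps)."""
    out = []
    for row in rows:
        rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([g * v / rms for g, v in zip(gamma, row)])
    return out

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def add(a, b):
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]

def self_attention(h):
    """Toy single-head attention; Q = K = V = h is an assumption for brevity."""
    d = len(h[0])
    out = []
    for q in h:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in h]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[j] / z * h[j][c] for j in range(len(h)))
                    for c in range(d)])
    return out

def ffn(h, w1, w2):
    """Row-vector FFN without biases: relu(h @ w1) @ w2."""
    hidden = [[max(0.0, v) for v in row] for row in matmul(h, w1)]
    return matmul(hidden, w2)

def pre_norm_block(x, gamma1, gamma2, w1, w2):
    x1 = rmsnorm(x, gamma1)      # NORM
    a = self_attention(x1)       # SUBLAYER
    x2 = add(x, a)               # ADD: the residual uses the un-normalized x
    x3 = rmsnorm(x2, gamma2)     # NORM
    f = ffn(x3, w1, w2)          # SUBLAYER
    x4 = add(x2, f)              # ADD: the residual uses the un-normalized x2
    return x2, x4

x = [[1.0, 2.0], [3.0, 4.0]]
identity = [[1.0, 0.0], [0.0, 1.0]]
x2, x4 = pre_norm_block(x, [1.0, 1.0], [1.0, 1.0], identity, identity)
print(x4)
```

Note that x and x2 enter each addition directly; only the copies handed to the sub-layers are normalized.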

2. Post-Norm vs. Pre-Norm: The Critical Difference

Post-norm (original Transformer):

x₂ = LayerNorm(x + Sublayer(x))

The normalization is applied AFTER the residual addition. The residual stream passes through LayerNorm at every block.

Pre-norm (modern Transformers: GPT-2 and later, Llama, etc.):

x₂ = x + Sublayer(Norm(x))

The normalization is applied BEFORE the sub-layer. The residual stream accumulates WITHOUT normalization gating.
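The difference in what happens to the residual stream can be seen in a small sketch. It assumes a LayerNorm without γ and β and a hypothetical stand-in sub-layer (0.5 · tanh, chosen only for illustration); neither comes from a real model.

```python
import math

def layernorm(row, eps=1e-8):
    mu = sum(row) / len(row)
    var = sum((v - mu) ** 2 for v in row) / len(row)
    return [(v - mu) / math.sqrt(var + eps) for v in row]

def sublayer(row):
    """Hypothetical stand-in for attention or FFN: a fixed elementwise map."""
    return [0.5 * math.tanh(v) for v in row]

def post_norm_step(row):
    # x_next = LN(x + Sublayer(x)): the whole sum is normalized.
    return layernorm([v + s for v, s in zip(row, sublayer(row))])

def pre_norm_step(row):
    # x_next = x + Sublayer(LN(x)): only the sub-layer input is normalized.
    return [v + s for v, s in zip(row, sublayer(layernorm(row)))]

def rms(row):
    return math.sqrt(sum(v * v for v in row) / len(row))

x_post = x_pre = [1.0, -2.0, 3.0, 0.5]
for _ in range(8):
    x_post = post_norm_step(x_post)
    x_pre = pre_norm_step(x_pre)

print(rms(x_post))  # stays close to 1: the stream is re-normalized every block
print(rms(x_pre))   # not constrained: the stream accumulates sub-layer outputs
```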

3. Why Pre-Norm Wins: The Gradient Analysis

Let's analyze gradient propagation through L consecutive blocks.

Post-norm (simplified):

Block ℓ: x_{ℓ+1} = LN(x_ℓ + F_ℓ(x_ℓ))

The Jacobian involves LN derivatives:

∂x_{ℓ+1}/∂x_ℓ = ∂LN/∂(·) · (I + ∂F_ℓ/∂x_ℓ)

The LayerNorm derivative is: ∂LN(x)/∂x = (γ/σ)(I − 11ᵀ/d − x̂x̂ᵀ/d), where x̂ = (x − μ)/σ is the normalized vector. The matrix in parentheses is a projection (eigenvalues 0 or 1), so the eigenvalues of the derivative are bounded by γ/σ.

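The scale γ is typically initialized to 1, so attenuation comes mainly from a large σ (or from γ shrinking during training): whenever γ/σ is below 1, the LN derivative attenuates gradients. After L blocks:

||∂x_L/∂x₀|| ≤ (||∂LN|| · (1 + ||J_F||))^L

Each block's gradient is multiplied by ||∂LN||, which can be less than 1, so even the residual path can decay exponentially with depth.

Pre-norm (simplified):

Block ℓ: x_{ℓ+1} = x_ℓ + F_ℓ(LN(x_ℓ))

The Jacobian keeps the identity outside the normalization:

∂x_{ℓ+1}/∂x_ℓ = I + ∂F_ℓ/∂(LN(x_ℓ)) · ∂LN/∂x_ℓ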

The identity term I is UNTOUCHED. No matter what LN does inside F_ℓ, the residual path has gradient multiplier I.

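After L blocks:

∂x_L/∂x₀ = (ordered product over ℓ of (I + J_ℓ)) = I + (terms involving F derivatives)

where J_ℓ is the sub-layer Jacobian of block ℓ. The expansion always contains the identity, so there is a direct path whose gradient does not shrink with depth. Heuristically ||∂x_L/∂x₀|| ≥ 1 as long as the remaining terms do not cancel the identity, so gradients cannot vanish through the residual path alone.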

Empirical confirmation: Post-norm Transformers are typically hard to train beyond ~12-24 layers without careful learning-rate warmup or other stabilization, while pre-norm Transformers scale to much deeper stacks (GPT-3: 96 layers, Llama 70B: 80 layers).
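To get a feel for the two depth scalings, here is a scalar toy version of the analysis. It replaces the Jacobians by non-negative scalars, which is an assumption made purely for illustration; the values 0.8 and 0.1 and both function names are hypothetical.

```python
def post_norm_bound(ln_gain, jf_norm, depth):
    """Upper bound (||dLN|| * (1 + ||J_F||))**L from the post-norm analysis."""
    return (ln_gain * (1.0 + jf_norm)) ** depth

def pre_norm_scalar(ln_gain, jf_norm, depth):
    """Scalar stand-in for the product of (I + J_F * dLN): each factor keeps +1."""
    return (1.0 + jf_norm * ln_gain) ** depth

for depth in (6, 24, 96):
    print(depth,
          post_norm_bound(0.8, 0.1, depth),
          pre_norm_scalar(0.8, 0.1, depth))
```

With an LN gain below 1, the post-norm bound shrinks geometrically with depth, while the pre-norm factor never drops below 1 in this scalar setting. With matrices the sub-layer terms can have either sign, so the scalar picture is only a guide.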

4. Gradient Through Attention (Pre-Norm Block)

Let's trace the full gradient. Loss L → x₄. We need ∂L/∂x and ∂L/∂(attention params).

Backward through the attention sub-block:

x₂ = x + MultiHead(Norm(x), Norm(x), Norm(x))

Given ∂L/∂x₂:

∂L/∂x = ∂L/∂x₂   (identity path from residual)

Plus gradient through the attention path:

∂L/∂x += ∂L/∂x₂ · ∂Attn/∂(Norm(x)) · ∂Norm/∂x

The identity term guarantees ∂L/∂x doesn't vanish. The attention path contributes additional gradient that depends on the attention computation itself.
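The "identity path plus sub-layer path" accumulation can be checked numerically. The sketch below swaps attention for a hypothetical linear sub-layer F(u) = A·u (an assumption to keep the example short), uses RMSNorm with γ = 1, and compares the analytic gradient with central finite differences of the scalar loss g · x₂.

```python
import math

def rmsnorm(x):
    r = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / r for v in x]

def rmsnorm_vjp(x, u):
    """J^T u for J = (1/rms)(I - xhat xhat^T / d); J is symmetric when gamma = 1."""
    d = len(x)
    r = math.sqrt(sum(v * v for v in x) / d)
    xhat = [v / r for v in x]
    dot = sum(h * ui for h, ui in zip(xhat, u))
    return [(ui - h * dot / d) / r for h, ui in zip(xhat, u)]

# Hypothetical sub-layer weights standing in for attention.
A = [[0.5, -1.0, 0.0], [0.2, 0.3, 0.7], [-0.4, 0.1, 0.6]]

def sublayer(u):
    return [sum(a * v for a, v in zip(row, u)) for row in A]

def block(x):
    return [xi + fi for xi, fi in zip(x, sublayer(rmsnorm(x)))]

def backward(x, grad_x2):
    # Sub-layer path: pull the gradient back through A, then through the norm.
    grad_u = [sum(A[i][j] * grad_x2[i] for i in range(len(A)))
              for j in range(len(x))]
    grad_path = rmsnorm_vjp(x, grad_u)
    # Identity path: the residual copies grad_x2 unchanged.
    return [g + p for g, p in zip(grad_x2, grad_path)]

x = [1.0, -2.0, 0.5]
g = [0.3, -0.1, 0.8]
analytic = backward(x, g)

eps = 1e-6
numeric = []
for j in range(len(x)):
    xp = list(x)
    xm = list(x)
    xp[j] += eps
    xm[j] -= eps
    lp = sum(gi * yi for gi, yi in zip(g, block(xp)))
    lm = sum(gi * yi for gi, yi in zip(g, block(xm)))
    numeric.append((lp - lm) / (2 * eps))

print(analytic)
print(numeric)
```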

Gradient through attention parameters:

∂L/∂W_i^Q = ∂L/∂x₂ · (∂Attn/∂Q_i) · (∂Q_i/∂W_i^Q)

Where ∂Attn/∂Q_i involves the softmax gradient from 17-06. It can be small if the attention is saturated, but the FFN gradient provides a complementary signal.

5. Gradient Through FFN (Pre-Norm Block)

x₄ = x₂ + FFN(Norm(x₂))

Backward:

∂L/∂x₂ = ∂L/∂x₄   (identity)
∂L/∂x₂ += ∂L/∂x₄ · ∂FFN/∂(Norm(x₂)) · ∂Norm/∂x₂

The FFN gradient decomposes as:

∂FFN/∂x = W₂ · diag(σ'(W₁x + b₁)) · W₁

For ReLU: diag entries are 0 or 1. For GELU: entries range from about −0.13 to about 1.13.

Key insight: The FFN gradient can vanish if too many neurons are inactive (dying ReLU), but the residual identity path always provides a baseline gradient.
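A small sketch of the FFN Jacobian makes the dying-ReLU case concrete. The weights below are hypothetical and chosen so the arithmetic is easy to follow; the column-vector convention matches the formula above.

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def ffn_jacobian(w1, b1, w2, x):
    """W2 @ diag(relu'(W1 x + b1)) @ W1 for a column vector x."""
    z = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(w1, b1)]
    mask = [1.0 if zi > 0 else 0.0 for zi in z]
    # Multiplying by diag(mask) on the left zeroes the rows of inactive neurons.
    masked_w1 = [[m * w for w in row] for m, row in zip(mask, w1)]
    return matmul(w2, masked_w1), mask

w1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]   # hidden width 3, d = 2
w2 = [[1.0, 0.0, 2.0], [0.0, 1.0, 2.0]]
x = [1.0, 2.0]

jac_active, mask_active = ffn_jacobian(w1, [0.0, 0.0, 0.0], w2, x)
jac_dead, mask_dead = ffn_jacobian(w1, [-10.0, -10.0, -10.0], w2, x)

print(mask_active, jac_active)
print(mask_dead, jac_dead)
```

When every hidden unit is inactive the FFN Jacobian is the zero matrix, yet the block Jacobian I + ∂FFN/∂(Norm(x₂)) · ∂Norm/∂x₂ still reduces to I.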

6. Effect on Training Dynamics

Post-norm: The output of each block is ALWAYS normalized → all positions have the same scale → gradient magnitude is controlled but potentially attenuated.

Pre-norm: The residual stream accumulates → later layers can have larger activation norms → larger gradients → faster learning. This is beneficial early in training but can make later layers have larger updates than early layers.

Modern practice: Pre-norm + proper initialization gives stable training in practice. The identity gradient path acts as a "gradient highway": it passes the upstream gradient through unchanged rather than amplifying it, which keeps any layer from being starved of learning signal.

7. RMSNorm vs LayerNorm

Modern LLMs (Llama, Mistral, Gemma) use RMSNorm instead of LayerNorm:

LayerNorm:

μ = (1/d)Σx_i
σ² = (1/d)Σ(x_i − μ)²
LN(x) = γ ⊙ (x − μ)/σ + β

RMSNorm:

rms = √((1/d)Σx_i²)
RMSNorm(x) = γ ⊙ x / rms   (no mean subtraction, no β shift)
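The two definitions can be compared side by side. This is a minimal sketch for a single vector; the small eps inside the square root is an assumption added for numerical safety and is not part of the formulas above.

```python
import math

def layernorm(x, gamma, beta, eps=1e-8):
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for g, v, b in zip(gamma, x, beta)]

def rmsnorm(x, gamma, eps=1e-8):
    mean_sq = sum(v * v for v in x) / len(x)
    return [g * v / math.sqrt(mean_sq + eps) for g, v in zip(gamma, x)]

ones, zeros = [1.0] * 4, [0.0] * 4
shifted = [11.0, 12.0, 13.0, 16.0]        # large mean offset (mean = 13)
centered = [v - 13.0 for v in shifted]    # same values with the mean removed

print(layernorm(shifted, ones, zeros))  # offset is subtracted before scaling
print(rmsnorm(shifted, ones))           # offset is kept, only the scale changes
print(rmsnorm(centered, ones))          # matches LayerNorm once the mean is zero
```

For zero-mean inputs with β = 0 the two normalizations coincide; they differ only in how they treat the mean.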

Why RMSNorm?

  1. Fewer operations (no mean computation)
  2. Slightly faster (one less pass over data)
  3. Empirically equivalent or better for Transformers
  4. The mean subtraction might remove useful information about activation magnitudes

Gradient of RMSNorm:

∂RMSNorm(x)/∂x = (γ/rms) · (I − x̂x̂ᵀ/d)

Where x̂ = x/rms. The structure is similar to LayerNorm without the 11ᵀ term.
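One consequence of this formula is worth checking numerically: since ||x̂||² = d, the Jacobian maps x itself to zero, which reflects the fact that rescaling x does not change RMSNorm(x). The sketch below assumes a scalar γ; the function names are hypothetical.

```python
import math

def rms(x):
    return math.sqrt(sum(v * v for v in x) / len(x))

def rmsnorm(x):
    r = rms(x)
    return [v / r for v in x]

def rmsnorm_jacobian(x, gamma=1.0):
    """(gamma / rms) * (I - xhat xhat^T / d), with scalar gamma (assumption)."""
    d, r = len(x), rms(x)
    xhat = [v / r for v in x]
    return [[gamma / r * ((1.0 if i == j else 0.0) - xhat[i] * xhat[j] / d)
             for j in range(d)] for i in range(d)]

x = [3.0, -1.0, 2.0, 0.5]
J = rmsnorm_jacobian(x)

# Directional derivative along x itself: should be numerically zero.
radial = [sum(J[i][j] * x[j] for j in range(len(x))) for i in range(len(x))]
print(radial)

# Scale invariance of the forward map, consistent with the zero radial derivative.
print(rmsnorm(x))
print(rmsnorm([2.0 * v for v in x]))
```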

8. Activation Checkpointing (Gradient Checkpointing)

Since the O(n²) attention matrices dominate memory, Transformers use activation checkpointing: store only block inputs, recompute intermediate activations during backward pass.

Forward: store only x₀, x₁, ..., x_L (block inputs, O(L·n·d) memory)
Backward: recompute each block's forward pass to get intermediate activations, then backprop through the block.

This trades compute (roughly one extra forward pass, i.e. 2× forward compute) for memory (the O(n²) attention matrices are not stored).
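The store-inputs-then-recompute pattern can be illustrated on a toy chain of scalar blocks x + tanh(w·x). This is only a sketch of the idea under that toy assumption, not how any particular framework implements checkpointing; all names are hypothetical.

```python
import math

def block_forward(x, w):
    """Toy scalar block x + tanh(w * x); returns the output and the intermediate."""
    h = math.tanh(w * x)
    return x + h, h

def block_backward(grad_out, x, h, w):
    """Given d(loss)/d(output), return d(loss)/d(input) and d(loss)/dw."""
    dh = grad_out * (1.0 - h * h)
    return grad_out + dh * w, dh * x

def grads_checkpointed(x0, weights):
    inputs = []                       # only block inputs are stored
    x = x0
    for w in weights:
        inputs.append(x)
        x, _ = block_forward(x, w)    # intermediate h is discarded
    grad, grad_w = 1.0, []            # loss = final output
    for x_in, w in zip(reversed(inputs), reversed(weights)):
        _, h = block_forward(x_in, w)             # recompute during backward
        grad, gw = block_backward(grad, x_in, h, w)
        grad_w.append(gw)
    return grad, grad_w[::-1]

def grads_stored(x0, weights):
    saved = []                        # inputs AND intermediates are stored
    x = x0
    for w in weights:
        x_next, h = block_forward(x, w)
        saved.append((x, h))
        x = x_next
    grad, grad_w = 1.0, []
    for (x_in, h), w in zip(reversed(saved), reversed(weights)):
        grad, gw = block_backward(grad, x_in, h, w)
        grad_w.append(gw)
    return grad, grad_w[::-1]

weights = [0.3, -0.7, 1.1]
print(grads_checkpointed(0.5, weights))
print(grads_stored(0.5, weights))
```

Both routines return the same gradients; the checkpointed one calls block_forward twice per block in exchange for never holding the intermediates of more than one block at a time.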

9. Full Block Parameter Update

For one training step, the parameter gradients for a pre-norm block are:

Attention parameters:

∂L/∂W_i^{Q,K,V} = accumulated via attention backward (17-08)
∂L/∂W^O = accumulated via output projection backward

FFN parameters (per token, with z = W₁·x₃ + b₁ and h = σ(z), so that f = W₂·h + b₂; gradients are summed over positions):

∂L/∂W₂ = (∂L/∂x₄) · hᵀ
∂L/∂W₁ = (∂L/∂h ⊙ σ'(z)) · x₃ᵀ, with ∂L/∂h = W₂ᵀ · ∂L/∂x₄
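Both FFN weight gradients are outer products of an upstream gradient with the layer's input. A minimal single-token sketch, assuming ReLU for σ and hypothetical weights; z denotes the pre-activation W₁·x₃ + b₁ and h the post-activation fed into W₂ (these two names are introduced here for illustration):

```python
def outer(u, v):
    return [[ui * vj for vj in v] for ui in u]

def ffn_param_grads(w1, b1, w2, x3, grad_x4):
    """Return (dL/dW1, dL/dW2) for f = W2 relu(W1 x3 + b1), one token."""
    z = [sum(w * v for w, v in zip(row, x3)) + b for row, b in zip(w1, b1)]
    h = [max(0.0, zi) for zi in z]                        # post-activation
    # Gradient arriving at h: W2^T times the upstream gradient.
    grad_h = [sum(w2[i][k] * grad_x4[i] for i in range(len(w2)))
              for k in range(len(h))]
    # Elementwise product with relu'(z).
    grad_z = [gh if zi > 0 else 0.0 for gh, zi in zip(grad_h, z)]
    return outer(grad_z, x3), outer(grad_x4, h)

w1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0, 2.0], [0.0, 1.0, 2.0]]

grad_w1, grad_w2 = ffn_param_grads(w1, b1, w2, [1.0, 2.0], [1.0, -1.0])
print(grad_w1)
print(grad_w2)
```

The third hidden unit is inactive for this input (z = −3), so its row in the W₁ gradient and its column in the W₂ gradient are zero.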

RMSNorm parameters:

∂L/∂γ₁ = Σ over positions: ∂L/∂x₂ · (∂x₂/∂γ₁)
∂L/∂γ₂ = Σ over positions: ∂L/∂x₄ · (∂x₄/∂γ₂)



Key Terms

Worked Examples

Example 1: Pre-Norm Forward Pass (Single Block, Simplified)

Problem: Input x = [[1,2],[3,4]] (n=2, d=2). RMSNorm with γ=[1,1]. FFN: W₁=[[1,0],[0,1]], W₂=[[1,0],[0,1]], no biases, ReLU activation. Skip attention (a=0). Compute x₄.

Solution: RMSNorm(x):
Row 0: rms = √((1+4)/2) = √2.5 ≈ 1.581. x̂₀ = [1/√2.5, 2/√2.5] ≈ [0.632, 1.265]
Row 1: rms = √((9+16)/2) = √12.5 ≈ 3.536. x̂₁ = [3/√12.5, 4/√12.5] ≈ [0.849, 1.131]

With γ=[1,1]: norm output x₃ = [[0.632, 1.265],[0.849, 1.131]]

FFN (since W₁=W₂=I and every entry is positive): f = ReLU(x₃) = [[0.632, 1.265],[0.849, 1.131]]

x₄ = x₂ + f = x + f (a=0) = [[1+0.632, 2+1.265],[3+0.849, 4+1.131]] = [[1.632, 3.265],[3.849, 5.131]]
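The arithmetic of this example can be reproduced with a few lines of Python, using the same simplifications as the problem statement (γ = [1, 1], W₁ = W₂ = I, no biases, a = 0). Values are computed at full precision and rounded to three decimals only for printing.

```python
import math

def rmsnorm_row(row):
    rms = math.sqrt(sum(v * v for v in row) / len(row))
    return [v / rms for v in row]

x = [[1.0, 2.0], [3.0, 4.0]]
x3 = [rmsnorm_row(row) for row in x]              # gamma = [1, 1]
f = [[max(0.0, v) for v in row] for row in x3]    # W1 = W2 = I, ReLU
x4 = [[xv + fv for xv, fv in zip(xr, fr)] for xr, fr in zip(x, f)]

print([[round(v, 3) for v in row] for row in x3])
print([[round(v, 3) for v in row] for row in x4])
```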

Example 2: Gradient Identity Path Verification

Problem: For a pre-norm block: x₂ = x + F(Norm(x)). Show that if ∂F/∂(Norm(x)) = 0 (F doesn't depend on its input, and hence not on x), the gradient ∂x₂/∂x = I.

Solution: ∂x₂/∂x = I + ∂F(Norm(x))/∂x = I + ∂F/∂(Norm(x)) · ∂Norm/∂x = I + 0 · ∂Norm/∂x = I

The identity gradient is always present, regardless of what F does. If F learned to be the zero function, gradients still flow perfectly through the block.

Example 3: Post-Norm Gradient Degradation

Problem: A post-norm block: x₂ = LN(x + F(x)). Assume LayerNorm acts as L2 normalization (simplified: divides by ||·||). If ||x + F(x)|| grows by factor 2 each block, how does the gradient scale after 10 blocks?

Solution: After 1 block: ∂x₁/∂x₀ ≈ (1/||x₀+F(x₀)||)·(I + J_F)

If norm grows 2× per block: after block ℓ, ||x_ℓ|| ≈ 2^ℓ·||x₀||.

The LN derivative scales as 1/||x|| ≈ 2^{-ℓ} (taking ||x₀|| = 1). Neglecting the (1+||J||) factors, i.e. ||J|| ≈ 0: ||∂x_10/∂x₀|| ≈ ∏_{ℓ=1}^{10} (1/2^ℓ) = 2^{-(1+2+...+10)} = 2^{-55} ≈ 2.8×10^{-17}

Catastrophic gradient vanishing! Even with residual connections, the post-norm placement destroys gradients in deep networks.

The √d_k scaling doesn't fix this: it's a separate normalization problem specific to post-norm placement.
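The exponent in this example is just the sum 1 + 2 + ... + 10. A short check of the stylized product, under the same simplification of ignoring the (1 + ||J||) factors (the function name is hypothetical):

```python
def stylized_post_norm_scale(growth, depth):
    """Product of 1 / growth**l for l = 1..depth, ignoring (1 + ||J||) factors."""
    scale = 1.0
    for l in range(1, depth + 1):
        scale /= growth ** l
    return scale

exponent = sum(range(1, 11))
print(exponent, stylized_post_norm_scale(2.0, 10))
```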


Quiz

Q1: What is the key mathematical advantage of pre-norm over post-norm in Transformer blocks?

A) Pre-norm is faster to compute
B) Pre-norm places normalization BEFORE the sub-layer, so the residual skip's gradient is I, never attenuated by normalization
C) Pre-norm uses fewer parameters
D) Pre-norm eliminates the need for residual connections

Answer & Explanation: B. Pre-norm: ∂(x + Sublayer(Norm(x)))/∂x = I + term. The I is never multiplied by Norm's derivative. Post-norm passes the residual through Norm; when the norm of its Jacobian is below 1, the attenuation compounds across blocks and gradients can decay exponentially.

Q2: How does RMSNorm differ from LayerNorm?

A) RMSNorm uses batch statistics
B) RMSNorm removes mean subtraction and only normalizes by the root mean square
C) RMSNorm adds an extra residual connection
D) RMSNorm uses a different activation function

Answer & Explanation: B. RMSNorm(x) = γ · x/rms(x) where rms = √(mean(x²)). LayerNorm additionally subtracts the mean: LN(x) = γ · (x − μ)/σ + β. Modern LLMs (Llama, Mistral, Gemma) use RMSNorm.

Q3: In the backward pass of a pre-norm block, how does gradient flow from output x₄ to input x?

A) Only through the FFN path
B) Through I (identity, from both residuals) + paths through attention and FFN
C) Only through attention
D) It does not flow at all

Answer & Explanation: B. x₂ = x + Attn(Norm(x)) gives ∂x₂/∂x = I + ∂Attn/∂x. x₄ = x₂ + FFN(Norm(x₂)) gives ∂x₄/∂x₂ = I + ∂FFN/∂x₂. The I terms guarantee baseline gradient flow.

Q4: What is the purpose of activation checkpointing?

A) To make training faster
B) To reduce memory by recomputing activations during backward instead of storing them
C) To check for NaN gradients
D) To reduce the number of parameters

Answer & Explanation: B. Attention matrices are O(n²) per layer, dominating memory. Checkpointing stores only block inputs and recomputes attention during backward, trading 2× forward compute for much lower peak memory.

Q5: Why can post-norm not scale to 100+ layers while pre-norm can?

A) Post-norm uses more parameters
B) In post-norm, the LayerNorm Jacobian multiplies gradient at every block: ||∂x_L/∂x₀|| ≤ (||∂LN|| · (1 + ||J_F||))^L. Pre-norm's I term keeps a direct path, so heuristically ||∂x_L/∂x₀|| ≥ 1.
C) Post-norm requires more FLOPs
D) Post-norm uses a different activation function

Answer & Explanation: B. Post-norm: x_{ℓ+1} = LN(x_ℓ + F(x_ℓ)). ∂LN/∂(·) has eigenvalues bounded by γ/σ. After L blocks, gradients can decay exponentially. Pre-norm's identity term is untouched by normalization, enabling GPT-3's 96 layers.

Practice Problems

Problem 1

Write the forward equations for a pre-norm Transformer block. Identify where the identity gradient path enters.

Answer: x₁ = Norm(x); a = Attn(x₁); x₂ = x + a; x₃ = Norm(x₂); f = FFN(x₃); x₄ = x₂ + f. Identity paths: x → x₂ (first residual) and x₂ → x₄ (second residual). Both contribute I to the Jacobian.

Problem 2

Why does the identity term I in the pre-norm Jacobian not get attenuated by normalization?

Answer: In pre-norm: x_{ℓ+1} = x_ℓ + F(Norm(x_ℓ)). The derivative is I + ∂F(Norm(x_ℓ))/∂x_ℓ. I comes from the skip connection x_ℓ, which is NOT passed through Norm. In post-norm: x_{ℓ+1} = Norm(x_ℓ + F(x_ℓ)), the entire sum goes through Norm.

Problem 3

How does RMSNorm differ from LayerNorm mathematically?

Answer: LayerNorm: x̂ = (x − μ)/σ, using both first (μ) and second (σ) moments. RMSNorm: x̂ = x/rms, using only the root mean square. Fewer operations, similar performance in Transformers.

Problem 4

Compute the gradient of RMSNorm(x) for x ∈ ℝ² with γ = 1.

Answer: rms = √((x₁² + x₂²)/2). RMSNorm(x)_i = x_i/rms. ∂RMSNorm_i/∂x_j = (1/rms)(δ_ij − x̂_i·x̂_j/2) where x̂_i = x_i/rms. In matrix form: (1/rms)(I − x̂x̂ᵀ/2).

Problem 5

Why does activation checkpointing reduce memory at the cost of compute?

Answer: Without checkpointing: all intermediate activations (including O(n²) attention matrices from every layer) are stored. Checkpointing stores only block inputs (O(n·d)). During backward, each block's forward is recomputed to obtain activations for its backward pass. This doubles forward compute but reduces memory from O(L·n²) to O(L·n·d + n²).

Summary

  1. Modern Transformers use pre-norm: Norm → Sublayer → Add, giving the residual stream an un-normalized identity gradient path
  2. Post-norm (original Transformer: Add → Norm) causes normalization to gate even the residual gradient, which makes very deep networks hard to train
  3. The identity Jacobian term I in pre-norm keeps a direct gradient path (heuristically ||∂x_L/∂x₀|| ≥ 1), so gradients do not vanish through the residual
  4. RMSNorm (normalize by RMS only, no mean subtraction) replaces LayerNorm in modern LLMs for efficiency with comparable performance
  5. Activation checkpointing trades compute for memory by storing only block inputs and recomputing intermediates during backward pass

Pitfalls


Next Steps

Continue to Phase 18 with 18-01 (Tokenization Mathematics) to learn how text is converted into the numerical tokens that feed into the Transformer.