18-10 — Transformer Variants (Math Focus)
Phase: 18 — Large Language Model Mathematics Subject: 18-10 Prerequisites: 18-09 (KV Caching), 18-04 (RoPE), 17-07 (Scaled Dot-Product Attention), 17-10 (Transformer Block Detailed), 14-03 (SGD — sparsity concept) Next subject: 19-01 — Mixture of Experts (MoE) — Deep
Learning Objectives
By the end of this subject, you will be able to:
- Derive the computational complexity of sliding window attention and compare it to full attention
- Formulate FlashAttention's tiling strategy and prove it computes exact attention (not approximate) with O(n) memory
- Derive the Mixture of Experts (MoE) routing function g(x) = softmax(top-k(W_g · x)) and compute its gate Jacobian
- Compare the parameter count, FLOPs per token, and memory footprint of dense vs. MoE transformers
- Analyze the load balancing loss in MoE and explain why it's necessary for training stability
Core Content
1. Sparse Attention Patterns
Full causal attention has O(n²) cost for sequence length n. Sparse attention restricts which positions can attend to which, reducing this to O(n·w) or O(n·k), where the window size w or the number of attended positions k is much smaller than n.
Sliding Window Attention
Each position attends only to tokens within a window of size w (looking backward and counting the current token):
A_{ij} ≠ 0 only if max(0, i−w+1) ≤ j ≤ i
Complexity: O(n·w) per attention layer instead of O(n²).
Memory: KV cache can be truncated to only store the last w tokens. Each layer caches only w key-value pairs per head.
Mathematical formulation: Let M_w be the sliding window mask:
M_w[i, j] = 0 if max(0, i−w+1) ≤ j ≤ i, −∞ otherwise
Attention: A = softmax(QKᵀ/√d_k + M_w) · V
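The mask and the masked attention can be sketched in NumPy as follows. This is a minimal single-head illustration, assuming the window counts the current token (so each row has at most w unmasked entries); the function names and toy shapes are hypothetical, not taken from any library.

```python
import numpy as np

def sliding_window_mask(n: int, w: int) -> np.ndarray:
    """Additive mask: 0 where position i may attend to j, -inf elsewhere.

    Assumed convention: causal window of w tokens, i - w + 1 <= j <= i.
    """
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    allowed = (j <= i) & (j > i - w)
    return np.where(allowed, 0.0, -np.inf)

def masked_attention(Q, K, V, mask):
    """softmax(Q K^T / sqrt(d_k) + mask) V with a numerically stable softmax."""
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k) + mask
    S = S - S.max(axis=-1, keepdims=True)  # each row keeps j = i, so the max is finite
    P = np.exp(S)                          # masked entries become exp(-inf) = 0
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

rng = np.random.default_rng(0)
n, d, w = 8, 4, 3
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
M = sliding_window_mask(n, w)
out = masked_attention(Q, K, V, M)
nonzero_per_row = (M == 0).sum(axis=1)  # grows 1, 2, ... then saturates at w
```

This sketch still builds the full n×n score matrix, so it shows the masking pattern only; an O(n·w) implementation would compute just the banded entries.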
For Mistral 7B: w = 4096 with 32K context. Every layer uses only sliding window attention, with no global tokens; information from beyond the window still reaches a token indirectly, because each stacked layer extends the receptive field by up to w positions.
Global Tokens
Some designated "global" tokens attend to and are attended by ALL positions. These carry summary information across the full sequence.
For global token set G: A_{i,g} and A_{g,i} are unmasked for g ∈ G
Complexity: O(n·(w + |G|)) ≈ O(n·w) when |G| ≪ w.
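A small sketch of the combined window-plus-global pattern, counting unmasked entries per row to check the O(n·(w + |G|)) bound. It assumes a window of w tokens including the current one and lets global tokens see all positions, including later ones (as in bidirectional encoders); `window_plus_global_mask` is a hypothetical helper.

```python
import numpy as np

def window_plus_global_mask(n, w, global_idx):
    """Boolean mask: True where position i may attend to position j."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    allowed = (j <= i) & (j > i - w)   # causal sliding window
    allowed[:, global_idx] = True      # every position attends to each global token
    allowed[global_idx, :] = True      # each global token attends to every position
    return allowed

allowed = window_plus_global_mask(n=16, w=4, global_idx=[0])
per_row = allowed.sum(axis=1)
# Non-global rows never exceed w + |G| entries; the global row has n entries.
```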
Random Sparse Attention (BigBird)
Each position attends to: its local window + random subset of previous positions + global tokens. The random connections ensure that information can propagate across the full sequence in O(log n) hops (like an expander graph).
2. FlashAttention — Exact Attention in O(n) Memory
⚠️ THIS IS CRITICAL — Standard attention materializes the full n×n attention matrix in GPU HBM, requiring O(n²) memory. FlashAttention computes the SAME exact output using O(n) memory by tiling and recomputing.
The core problem: The attention matrix S = QKᵀ ∈ ℝ^(n×n) is too large for long sequences. For n = 32K and fp16: n² = 1B elements = 2 GB, and we need it per head, per layer.
FlashAttention's insight: Softmax can be computed incrementally using online statistics.
Online Softmax Algorithm
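Standard (numerically stable) softmax needs the maximum before it can accumulate the exp-sum, so it makes separate passes over the data: one for the max, one for the exp-sum, and one for the division. Online softmax tracks a running max and a running exp-sum together in a single pass, rescaling the accumulated values whenever the max changes: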
Algorithm (for a vector):
Initialize: m = −∞, ℓ = 0, output = zeros
For each block of scores v (with its matching block of value vectors V_block):
1. m_new = max(m, max(v))
2. ℓ = ℓ · exp(m − m_new) + Σ exp(v − m_new)
3. output = output · exp(m − m_new)
4. output += exp(v − m_new) · V_block
5. m = m_new
After the last block: output = output / ℓ
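The single-pass update can be sketched for one attention row as below. This is an illustrative NumPy version, assuming the output being accumulated is the softmax-weighted sum of value vectors and that normalization by ℓ happens once at the end; the function name and block size are hypothetical.

```python
import numpy as np

def online_softmax_weighted_sum(scores, values, block_size):
    """Compute sum_j softmax(scores)_j * values[j] in one pass over blocks."""
    m = -np.inf                        # running max
    l = 0.0                            # running sum of exp(score - m)
    acc = np.zeros(values.shape[1])    # running unnormalized output
    for start in range(0, len(scores), block_size):
        s = scores[start:start + block_size]
        v = values[start:start + block_size]
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)      # rescales old statistics to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l                     # normalize once at the end

rng = np.random.default_rng(0)
scores = rng.standard_normal(10) * 5
values = rng.standard_normal((10, 3))

# Reference: ordinary stable softmax over the whole row.
weights = np.exp(scores - scores.max())
weights /= weights.sum()
reference = weights @ values
streamed = online_softmax_weighted_sum(scores, values, block_size=4)
```

On the first block the old statistics are scaled by exp(−∞) = 0, which is why initializing m to −∞ is safe.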
Applying to attention: The attention row i is online softmax over q_i · K^T:
o_i = Σ_{j} softmax(q_i · k_j / √d) · v_j
FlashAttention tiles over both the query dimension (outer loop) and key/value dimension (inner loop):
Tiling strategy:
1. Split Q into blocks Q₁, Q₂, ..., Q_{T_r}
2. Split K, V into blocks (K₁,V₁), (K₂,V₂), ..., (K_{T_c},V_{T_c})
3. For each Q block, iterate over K,V blocks:
   a. Load Q_block, K_block, V_block from HBM into SRAM (on-chip)
   b. Compute S_block = Q_block · K_blockᵀ / √d
   c. Compute online softmax update using running m and ℓ
   d. Accumulate output: O_block = diag(ℓ_new)^(−1) · (diag(ℓ)·e^(m−m_new)·O_block + e^(S_block−m_new)·V_block)
4. Write O_block back to HBM
The output is EXACT (not approximate): FlashAttention produces exactly the same result as standard attention. The only difference is that the n×n matrix S is never fully materialized in HBM — it exists only in small tiles in SRAM.
Memory complexity: Standard = O(n²) HBM. FlashAttention = O(n) HBM (only stores Q,K,V,O, not S).
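The exactness claim can be checked numerically with a tiled NumPy sketch. It is a CPU illustration of the loop structure only, not the GPU kernel: it assumes non-causal single-head attention and defers the division by ℓ to the end of each query tile instead of renormalizing at every step. All names are hypothetical.

```python
import numpy as np

def tiled_attention(Q, K, V, block_r, block_c):
    """softmax(Q K^T / sqrt(d)) V without ever forming the full n x n matrix."""
    n, d = Q.shape
    O = np.zeros((n, V.shape[1]))
    for r in range(0, n, block_r):                   # outer loop: query tiles
        Qb = Q[r:r + block_r]
        m = np.full(Qb.shape[0], -np.inf)            # running row max
        l = np.zeros(Qb.shape[0])                    # running row exp-sum
        acc = np.zeros((Qb.shape[0], V.shape[1]))    # unnormalized output tile
        for c in range(0, K.shape[0], block_c):      # inner loop: key/value tiles
            S = Qb @ K[c:c + block_c].T / np.sqrt(d)  # only a block_r x block_c tile
            m_new = np.maximum(m, S.max(axis=1))
            scale = np.exp(m - m_new)
            P = np.exp(S - m_new[:, None])
            l = l * scale + P.sum(axis=1)
            acc = acc * scale[:, None] + P @ V[c:c + block_c]
            m = m_new
        O[r:r + block_r] = acc / l[:, None]
    return O

def reference_attention(Q, K, V):
    """Standard attention that materializes the full score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((37, 16)) for _ in range(3))
out_tiled = tiled_attention(Q, K, V, block_r=8, block_c=16)
out_ref = reference_attention(Q, K, V)
```

The two outputs agree up to floating-point rounding, and the largest intermediate in the tiled version is a block_r × block_c tile.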
3. Mixture of Experts (MoE) — Routing Mathematics
MoE transformers replace the dense FFN with multiple "expert" FFNs, using a router to select a subset per token.
Router/Gate Function
Given an input token representation x ∈ ℝ^d:
z = W_g · x ∈ ℝ^E (logits for E experts)
g = softmax(top-k(z)) ∈ ℝ^E (sparse gate scores)
where top-k keeps the k largest logits and sets the rest to −∞ before the softmax (k ≪ E, typically k=1 or k=2). As a result g_j = 0 for every non-top-k entry, and the k selected gates sum to 1.
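A minimal sketch of the gate for a single token, assuming the softmax is taken over the k selected logits only; `top_k_gate` is a hypothetical helper and the logits are made-up values standing in for W_g · x.

```python
import numpy as np

def top_k_gate(z, k):
    """g = softmax(top-k(z)): softmax over the k largest logits, 0 elsewhere."""
    top = np.argsort(z)[-k:]               # indices of the k largest logits
    g = np.zeros_like(z, dtype=float)
    e = np.exp(z[top] - z[top].max())      # stable softmax over the selected logits
    g[top] = e / e.sum()
    return g, top

z = np.array([1.0, 3.0, 0.5, 2.0])         # stands in for W_g @ x with E = 4
g, top = top_k_gate(z, k=2)
# Experts 1 and 3 are selected; g is zero for experts 0 and 2.
```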
Sparse MoE Forward Pass
The FFN output for token x is a weighted sum of the selected experts:
y = Σ_{j∈T} g_j · FFN_j(x)
where T = {indices of top-k values in z} and FFN_j is the j-th expert FFN.
For k=2 (typical):
y = g_{e₁} · FFN_{e₁}(x) + g_{e₂} · FFN_{e₂}(x)
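The forward pass for one token might look like the sketch below. The experts here are toy one-hidden-layer ReLU FFNs with random weights, chosen only for illustration; the shapes, the activation, and the names `ffn` and `moe_forward` are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d, E, k = 8, 4, 2
W_g = rng.standard_normal((E, d))                      # router weights
# Toy experts (hypothetical): W1 is 4d x d, W2 is d x 4d.
experts = [
    (rng.standard_normal((4 * d, d)) * 0.1, rng.standard_normal((d, 4 * d)) * 0.1)
    for _ in range(E)
]

def ffn(x, W1, W2):
    return W2 @ np.maximum(W1 @ x, 0.0)

def moe_forward(x):
    z = W_g @ x                                        # router logits
    top = np.argsort(z)[-k:]                           # selected expert indices T
    e = np.exp(z[top] - z[top].max())
    g = e / e.sum()                                    # gates over selected experts
    # Only the k selected experts are evaluated.
    y = sum(g_j * ffn(x, *experts[j]) for g_j, j in zip(g, top))
    return y, top, g

x = rng.standard_normal(d)
y, top, g = moe_forward(x)
```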
Parameter count: If each expert has 8d² parameters (standard FFN: d × 4d then 4d × d), a model with E experts has:
Params_MoE ≈ L · (4d² + E · 8d²)
where 4d² accounts for attention projections and E · 8d² for all expert FFNs (the router's E · d weights are negligible). Per token, only k experts are active, so the projection and FFN cost per token, counting one multiply-accumulate per weight, is:
FLOPs_per_token ≈ L · (4d² + k · 8d²) = L · 4d² · (1 + 2k)
So the parameter count grows with E while the FLOPs per token depend only on k: this is sparse activation.
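The two formulas are easy to evaluate directly. The sketch below ignores router weights, embeddings, and attention-score cost, as the formulas do, and treats the per-token cost as the number of weights touched; `moe_counts` and the chosen sizes are illustrative assumptions.

```python
def moe_counts(num_layers, d, num_experts, k):
    """Return (total parameters, weights touched per token)."""
    attn = 4 * d * d                 # Q, K, V, O projections
    ffn = 8 * d * d                  # d x 4d and 4d x d
    total_params = num_layers * (attn + num_experts * ffn)
    active_params = num_layers * (attn + k * ffn)
    return total_params, active_params

total, active = moe_counts(num_layers=24, d=4096, num_experts=8, k=2)
dense_total, dense_active = moe_counts(num_layers=24, d=4096, num_experts=1, k=1)

param_ratio = total / dense_total        # 68d² / 12d² ≈ 5.67
compute_ratio = active / dense_active    # 20d² / 12d² ≈ 1.67
```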
The Jacobian of the Router
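The top-k operation is non-differentiable (hard selection), so the gradient flows only through the softmax over the selected logits.
For the selected experts (i, j ∈ T), the softmax couples the gates:
∂g_i/∂z_j = g_i(δ_ij − g_j), with diagonal entries ∂g_j/∂z_j = g_j(1 − g_j)
Holding the expert outputs fixed, the layer output therefore changes as:
∂y/∂z_j = Σ_{i∈T} FFN_i(x) · ∂g_i/∂z_j = g_j · (FFN_j(x) − y)
For non-selected experts (j ∉ T):
g_j = 0, so ∂g_i/∂z_j = 0 for every i (gradient blocked)
The non-selected experts receive NO gradient for this token. They only learn from tokens that route to them.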
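A finite-difference check of the gate gradient for the selected experts. Because the softmax couples the selected gates, ∂g_i/∂z_j = g_i(δ_ij − g_j), which gives ∂y/∂z_j = g_j · (FFN_j(x) − y) when the expert outputs are held fixed. The sketch below verifies that identity numerically; the random matrix F stands in for the expert outputs and is purely illustrative.

```python
import numpy as np

def gate(z_sel):
    """Softmax over the selected logits only."""
    e = np.exp(z_sel - z_sel.max())
    return e / e.sum()

rng = np.random.default_rng(0)
k, d = 2, 5
z_sel = rng.standard_normal(k)      # logits of the k selected experts
F = rng.standard_normal((k, d))     # row i stands in for FFN_i(x), held fixed

g = gate(z_sel)
y = g @ F
analytic = g[:, None] * (F - y)     # row j is dy/dz_j = g_j (FFN_j(x) - y)

eps = 1e-6
numeric = np.zeros_like(analytic)
for j in range(k):
    dz = np.zeros(k)
    dz[j] = eps
    numeric[j] = (gate(z_sel + dz) @ F - gate(z_sel - dz) @ F) / (2 * eps)
```

The rows of the analytic Jacobian sum to zero, reflecting that adding a constant to all selected logits leaves the gates unchanged.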
4. Load Balancing Loss
Without regularization, the router can collapse to always picking the same few experts. The load balancing loss encourages uniform expert utilization:
L_balance = α · E · Σ_{j=1}^{E} f_j · P_j
where:
- f_j = fraction of tokens routed to expert j (empirical)
- P_j = average gate probability for expert j (softmax averaged over batch)
- α = balancing coefficient (typically 0.01)
An alternative is the coefficient of variation approach:
L_balance = α · (std(f) / mean(f))² = α · CV(f)²
where f = [f_1, ..., f_E] is the vector of expert fractions.
When all experts are used equally, f_j = 1/E for all j, so CV = 0 and the CV form gives L_balance = 0. The product form does not vanish: with f_j = P_j = 1/E it equals α · E · Σ (1/E)(1/E) = α, its value under perfectly uniform routing.
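Both forms of the loss can be compared on a uniform and a skewed routing distribution. This is a sketch assuming top-1 routing (so the fractions sum to 1) and the population standard deviation for CV; the function names and the skewed example are made up for illustration.

```python
import numpy as np

def balance_loss_product(f, P, alpha=0.01):
    """alpha * E * sum_j f_j * P_j."""
    f, P = np.asarray(f, dtype=float), np.asarray(P, dtype=float)
    return alpha * len(f) * np.sum(f * P)

def balance_loss_cv(f, alpha=0.01):
    """alpha * CV(f)^2, using the population standard deviation (ddof=0)."""
    f = np.asarray(f, dtype=float)
    return alpha * (f.std() / f.mean()) ** 2

E = 4
uniform = np.full(E, 1 / E)
skewed = np.array([0.7, 0.1, 0.1, 0.1])

product_uniform = balance_loss_product(uniform, uniform)  # equals alpha
product_skewed = balance_loss_product(skewed, skewed)
cv_uniform = balance_loss_cv(uniform)                     # equals 0
cv_skewed = balance_loss_cv(skewed)
```

Under uniform routing the product form evaluates to α while the CV form evaluates to 0; both increase as routing becomes skewed.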
5. Expert Capacity
To prevent individual experts from being overwhelmed:
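capacity = (tokens_per_batch / E) · capacity_factor
where capacity_factor > 1 (typically 1.25). With top-k routing each token produces k assignments, so the expected load per expert is k · tokens_per_batch / E and the capacity scales by the same factor k. If more than capacity tokens are routed to an expert, the excess tokens are "dropped" (their contribution to that expert is skipped). The dropped token ratio should be small (<1%).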
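A short sketch of the capacity rule and the resulting drops. The optional `k` argument (for top-k routing) and the rounding down to an integer are assumptions of this example, as are the routing counts.

```python
import math

def expert_capacity(tokens_per_batch, num_experts, capacity_factor, k=1):
    """Max tokens one expert accepts; k=1 reproduces the formula above."""
    return math.floor(k * tokens_per_batch / num_experts * capacity_factor)

def dropped_tokens(counts, capacity):
    """Tokens beyond capacity are skipped by that expert."""
    return [max(0, c - capacity) for c in counts]

counts = [100, 80, 50, 26]                  # hypothetical routing of 256 tokens
cap = expert_capacity(tokens_per_batch=256, num_experts=4, capacity_factor=1.25)
dropped = dropped_tokens(counts, cap)
drop_ratio = sum(dropped) / sum(counts)     # fraction of assignments dropped
```

Here the capacity is 80, so only the first expert overflows, dropping 20 of 256 assignments (about 7.8%), well above the <1% target.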
Pitfalls
⚠️ Pitfall 1: Thinking FlashAttention is approximate. It computes EXACT attention — same output as standard attention. The only difference is memory: the n×n attention matrix is never materialized in HBM. If you get different results, it's a bug, not an approximation.
⚠️ Pitfall 2: Counting ALL MoE parameters as active. A model with E=8 experts and k=2 has ~8× the FFN parameters of a dense model, but only ~2× the FFN FLOPs per token. The total parameter count is misleading: inference memory needs all E experts, but compute only needs k.
⚠️ Pitfall 3: Forgetting the load balancing loss in training. Without it, the router collapses to always picking the same 1-2 experts. The other E-2 experts receive zero gradient and never learn. The balancing loss is NOT optional — models will NOT "figure it out" on their own.
Key Terms
- Expert Capacity
- FFN output for token x
- FlashAttention — Exact Attention in O(n) Memory
- Load Balancing Loss
- Mixture of Experts (MoE) — Routing Mathematics
Worked Examples
Example 1: Sliding Window Attention Complexity
Problem: A Transformer with L=32 layers, d=4096, processes a sequence of n=16384 tokens. Compute the attention FLOPs per token for full attention vs. sliding window (w=4096). What is the speedup?
Solution:
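Full attention at position t: each head computes t scores and a weighted sum over t value vectors, at d_head multiply-accumulates per key and per value. Counting each multiply-accumulate once:
Attention FLOPs ≈ 2 · H · d_head · t = 2 · d · t
For a token near the end of the sequence (t ≈ n): 2 · d · n = 2 · 4096 · 16384 ≈ 1.34 × 10^8 FLOPs per layer per token
Sliding window (t > w): Attention FLOPs ≈ 2 · d · w = 2 · 4096 · 4096 ≈ 3.36 × 10^7 FLOPs per layer per token
Speedup: 1.34×10^8 / 3.36×10^7 = n/w = 4× for tokens at the end of the sequence. The ratio is the same in each of the L=32 layers.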
But the bigger win is KV cache memory: sliding window only needs to store w K,V pairs. Cache size drops from n=16384 to w=4096 (4× reduction), and inference latency for long sequences drops proportionally.
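The arithmetic in this example, plus the cache sizes in bytes, can be reproduced with a few lines. The cache estimate assumes fp16 storage and that every head keeps its own K and V at full width d (no grouped-query attention); those are assumptions of this sketch, not part of the problem statement.

```python
d, n, w, num_layers = 4096, 16384, 4096, 32

# Per-layer, per-token attention cost at the end of the sequence: 2 * d * t
full_cost = 2 * d * n
window_cost = 2 * d * min(n, w)
speedup = full_cost / window_cost

def kv_cache_bytes(tokens):
    """K and V, each tokens x d values at 2 bytes (fp16), across all layers."""
    return 2 * tokens * d * 2 * num_layers

full_cache_gib = kv_cache_bytes(n) / 2**30
window_cache_gib = kv_cache_bytes(min(n, w)) / 2**30
print(speedup, full_cache_gib, window_cache_gib)
```

Under these assumptions the cache shrinks from 8 GiB to 2 GiB, the same 4× factor as the compute.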
Example 2: MoE Parameter and FLOPs Calculation
Problem: A MoE Transformer has L=24, d=4096, E=8 experts, k=2. Standard FFN uses 8d² parameters (d×4d + 4d×d). Compute total parameters and compare FLOPs per token to the dense equivalent (same config, one FFN).
Solution:
Per layer, dense:
- Attention: 4d² (Q,K,V,O projections) = 4 · 4096² ≈ 67.1M
- FFN: 8d² = 8 · 4096² ≈ 134.2M
Total per layer: 201.3M
Total: 24 · 201.3M ≈ 4.83B params
Per layer, MoE:
- Attention: 4d² = 67.1M
- Router: E · d = 8 · 4096 = 32,768 (negligible)
- Experts: E · 8d² = 8 · 134.2M ≈ 1,073.7M
Total per layer: ~1,140.9M
Total: 24 · 1,140.9M ≈ 27.4B params
FLOPs per token (counting one multiply-accumulate per weight and excluding the attention-score cost):
- Dense: attention (4d²) + FFN (8d²) = 12d² ≈ 12 · 16.8M = 201M FLOPs per layer
- MoE: attention (4d²) + k · FFN (2 · 8d²) = 20d² ≈ 20 · 16.8M = 336M FLOPs per layer
Summary: MoE has 27.4B parameters (5.7× dense) but only 336M FLOPs/token per layer (1.67× dense). The MoE achieves vastly more parameters for modestly more compute by activating only a subset of experts per token.
Example 3: Load Balancing Loss
Problem: A batch of 256 tokens is routed over E=4 experts. The routing counts are [100, 80, 50, 26] and average softmax probabilities are [0.35, 0.30, 0.22, 0.13]. Compute L_balance with α=0.01.
Solution:
f = [100, 80, 50, 26] / 256 = [0.3906, 0.3125, 0.1953, 0.1016]
P = [0.35, 0.30, 0.22, 0.13]
L_balance = 0.01 · 4 · (0.3906·0.35 + 0.3125·0.30 + 0.1953·0.22 + 0.1016·0.13)
= 0.04 · (0.1367 + 0.0938 + 0.0430 + 0.0132)
= 0.04 · 0.2867
≈ 0.01147
Interpretation: Expert 0 is overloaded (39% of tokens, against 25% under uniform routing) and contributes most to the loss. For comparison, perfectly uniform routing (f_j = P_j = 0.25) gives 0.04 · 0.25 = 0.01, so this batch sits about 15% above the balanced value. The loss encourages the router to distribute tokens more evenly. Without it, the router might route everything to expert 0.
Quiz
Q1: What does the concept of Common Pitfalls primarily refer to in this subject?
A) A historical anecdote about Common Pitfalls B) The definition and application of Common Pitfalls C) A computational error related to Common Pitfalls D) A visual representation of Common Pitfalls
Correct: B)
- If you chose A: This is incorrect. Common Pitfalls is defined as: the definition and application of common pitfalls. The other options describe different aspects that are not the primary focus.
- If you chose B: Common Pitfalls is defined as: the definition and application of common pitfalls. The other options describe different aspects that are not the primary focus. Correct!
- If you chose C: This is incorrect. Common Pitfalls is defined as: the definition and application of common pitfalls. The other options describe different aspects that are not the primary focus.
- If you chose D: This is incorrect. Common Pitfalls is defined as: the definition and application of common pitfalls. The other options describe different aspects that are not the primary focus.
Q2: What is the primary purpose of Expert Capacity?
A) It caps how many tokens each expert processes per batch so that no single expert is overwhelmed B) It sets the number of experts E in the model C) It determines how many experts k are active per token D) It replaces the load balancing loss
Correct: A)
- If you chose A: Expert capacity = (tokens_per_batch / E) · capacity_factor limits the tokens an expert accepts; tokens beyond the limit are dropped. Correct!
- If you chose B: This is incorrect. E is an architecture choice. Expert capacity limits how many tokens each of those E experts accepts per batch.
- If you chose C: This is incorrect. k is set by the router's top-k selection. Expert capacity limits how many routed tokens an expert accepts.
- If you chose D: This is incorrect. Capacity only bounds the effect of imbalance by dropping excess tokens; the load balancing loss is still needed to keep the drop rate small.
Q3: Which statement about FFN output for token x is TRUE?
A) FFN output for token x is not related to this subject B) FFN output for token x is a fundamental concept covered in this subject C) FFN output for token x is mentioned only as a historical footnote D) FFN output for token x is an advanced topic beyond this subject's scope
Correct: B)
- If you chose A: This is incorrect. FFN output for token x is a fundamental concept covered in this subject. This subject covers FFN output for token x as part of its core content.
- If you chose B: FFN output for token x is a fundamental concept covered in this subject. This subject covers FFN output for token x as part of its core content. Correct!
- If you chose C: This is incorrect. FFN output for token x is a fundamental concept covered in this subject. This subject covers FFN output for token x as part of its core content.
- If you chose D: This is incorrect. FFN output for token x is a fundamental concept covered in this subject. This subject covers FFN output for token x as part of its core content.
Q4: What is the HBM memory complexity of FlashAttention in the sequence length n?
A) O(n²), because the full attention matrix S is stored B) O(n) HBM (only stores Q,K,V,O, not S) C) O(n·w), because it uses a sliding window D) O(1), because everything stays in SRAM
Correct: B)
- If you chose A: This is incorrect. O(n²) is the cost of standard attention, which materializes S in HBM. FlashAttention keeps S only as small tiles in SRAM.
- If you chose B: FlashAttention stores only Q, K, V, and O in HBM, which is O(n); the n×n matrix S is never fully materialized. Correct!
- If you chose C: This is incorrect. O(n·w) describes sliding window attention. FlashAttention computes exact full attention and needs no window.
- If you chose D: This is incorrect. Q, K, V, and O still live in HBM, so the memory is O(n), not O(1).
Q5: How are FFN output for token x and Load Balancing Loss related?
A) FFN output for token x is a special case of Load Balancing Loss B) FFN output for token x and Load Balancing Loss are closely related concepts C) FFN output for token x and Load Balancing Loss are completely unrelated topics D) FFN output for token x is the inverse of Load Balancing Loss
Correct: B)
- If you chose A: This is incorrect. Both FFN output for token x and Load Balancing Loss are covered in this subject as interconnected topics.
- If you chose B: Both FFN output for token x and Load Balancing Loss are covered in this subject as interconnected topics. Correct!
- If you chose C: This is incorrect. Both FFN output for token x and Load Balancing Loss are covered in this subject as interconnected topics.
- If you chose D: This is incorrect. Both FFN output for token x and Load Balancing Loss are covered in this subject as interconnected topics.
Q6: What is a common pitfall when working with Sparse Attention Patterns?
A) The main error with Sparse Attention Patterns is using it when it is not needed B) Sparse Attention Patterns has no common misconceptions C) A common mistake is confusing Sparse Attention Patterns with a similar concept D) Sparse Attention Patterns is always computed the same way in all contexts
Correct: C)
- If you chose A: This is incorrect. Students often confuse Sparse Attention Patterns with similar-sounding or related concepts. Pay attention to the precise definitions.
- If you chose B: This is incorrect. Students often confuse Sparse Attention Patterns with similar-sounding or related concepts. Pay attention to the precise definitions.
- If you chose C: Students often confuse Sparse Attention Patterns with similar-sounding or related concepts. Pay attention to the precise definitions. Correct!
- If you chose D: This is incorrect. Students often confuse Sparse Attention Patterns with similar-sounding or related concepts. Pay attention to the precise definitions.
Q7: When should you apply FlashAttention?
A) Avoid FlashAttention unless explicitly instructed B) Apply FlashAttention to solve problems in this subject's domain C) Use FlashAttention only in pure mathematics contexts D) FlashAttention is not practically useful
Correct: B)
- If you chose A: This is incorrect. FlashAttention is a practical tool used throughout this subject to solve relevant problems.
- If you chose B: FlashAttention is a practical tool used throughout this subject to solve relevant problems. Correct!
- If you chose C: This is incorrect. FlashAttention is a practical tool used throughout this subject to solve relevant problems.
- If you chose D: This is incorrect. FlashAttention is a practical tool used throughout this subject to solve relevant problems.
Practice Problems
Problem 1
For sliding window attention with window size w = 2048 and sequence length n = 8192, what fraction of the full attention matrix is non-zero (not masked)?
Answer
Full attention: n² = 8192² ≈ 67.1M entries
Sliding window (causal, lower triangle with window): non-zero entries in row i = min(i+1, w)
- For i ≥ w: w entries
- For i < w: i+1 entries
Total non-zero ≈ w·(n−w) + w²/2 (the w²/2 term covers the first w rows)
≈ 2048·(8192−2048) + 2048²/2
= 2048·6144 + 2,097,152
= 12,582,912 + 2,097,152 ≈ 14.68M
Fraction = 14.68M / 67.1M ≈ 0.219 ≈ 22%
The attention matrix is ~78% sparse.
Problem 2
FlashAttention processes attention in tiles. If Q and K,V are each split into tiles of size B_r × d and B_c × d respectively, how many tiles are loaded from HBM for one attention computation with n=4096, B_r=128, B_c=128?
Answer
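Number of Q tiles = ⌈n / B_r⌉ = ⌈4096/128⌉ = 32
Number of K,V tiles = ⌈n / B_c⌉ = ⌈4096/128⌉ = 32
Each Q tile iterates over all K,V tiles: 32 × 32 = 1024 inner iterations, each loading one K tile and one V tile. With each Q tile loaded once, that is 32 + 2 · 1024 = 2080 tile loads.
Each tile is B_r × d (Q) or B_c × d (K,V), so total data loaded:
= 32 · B_r · d + 32 · 32 · B_c · d · 2 (Q once, K+V per inner loop)
= 32·128·d + 32·32·128·d·2
= 4096·d + 262,144·d = 266,240·d
For d=128: ~34M elements ≈ 68 MB (fp16).
Standard attention would materialize S: n² = 16.8M elements ≈ 34 MB (fp16).
FlashAttention re-reads K and V once per Q tile, so the data it loads (68 MB) exceeds the size of S itself. But each tile fits in SRAM, and standard attention must write S to HBM and read it back for the softmax, which already costs about 2 × 34 MB of traffic before the softmax output is stored and re-read. The key is avoiding HBM writes/reads of the large S matrix.
Problem 3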
An MoE layer has E=16 experts, k=2 active per token. If 1024 tokens are processed, what is the expected number of tokens per expert under uniform routing? If expert 3 receives 200 tokens and capacity_factor=1.25, how many tokens are dropped?
Answer
Expected per expert: (tokens · k) / E = (1024 · 2) / 16 = 2048/16 = 128 tokens
Capacity = expected · capacity_factor = 128 · 1.25 = 160 tokens
Expert 3 receives 200 tokens > 160 capacity.
Dropped = 200 − 160 = 40 tokens (20% of the tokens routed to expert 3).
This high drop rate indicates expert 3 is overloaded, which the load balancing loss should address. Dropping tokens means those tokens don't get an FFN contribution from that expert, potentially degrading quality.
Problem 4
Compare the total inference-time memory for a dense model (7B params) vs. an MoE model (7B active params, 56B total params, E=8) for a KV cache of 4 GB. Assume fp16.
Answer
Dense model:
- Weights: 7B · 2 bytes = 14 GB
- KV cache: 4 GB
- Total: 18 GB
MoE model (7B active, 56B total):
- ALL weights must be in memory (any expert might be called): 56B · 2 bytes = 112 GB
- KV cache: 4 GB
- Total: 116 GB
This is the key challenge with MoE inference: while FLOPs are low, ALL expert weights must reside in GPU memory (or be swapped in from CPU, adding latency). The 56B MoE model requires 8× the weight memory of the 7B dense model (112 GB vs. 14 GB), or about 6.4× the total memory once the KV cache is included (116 GB vs. 18 GB), despite similar per-token compute.
Problem 5
Explain why the top-k operation in MoE routers creates a gradient bottleneck for non-selected experts.
Answer
For a token routed to experts e₁ and e₂ (k=2), only these two experts receive the token and produce output. The other E−2 experts are never invoked. Since their gate scores g_j = 0 (hard zero from top-k) and the output y doesn't depend on them, ∂L/∂z_j = 0 for j ∉ {e₁,e₂}. These experts receive zero gradient from this token. They only learn from tokens that ARE routed to them. This creates a "rich get richer" dynamic: frequently used experts get more gradients and improve, making them even more likely to be selected. The load balancing loss counteracts this.
Summary
- Sparse attention (sliding window, global tokens, random) reduces attention from O(n²) to O(n·w) by restricting each position's attention span
- FlashAttention uses tiling and online softmax to compute EXACT attention with O(n) memory instead of O(n²), by never materializing the full attention matrix in HBM
- MoE replaces dense FFN with E expert FFNs; a router g(x) = softmax(top-k(W_g·x)) selects k ≪ E experts per token, giving high parameter count at near-constant FLOPs
- Load balancing loss (α·E·Σ f_j·P_j or α·CV(f)²) prevents router collapse to a few experts during training
- MoE memory is dominated by total expert weights (the FFN weights are E× those of the dense model), making inference memory a challenge despite low FLOPs
Next Steps
Continue to Phase 19 with 19-01 — Mixture of Experts (MoE) — Deep for an exhaustive treatment of MoE routing, expert capacity, and advanced load balancing strategies.