19-02 — Attention Variants
Phase: 19 - Advanced LLM Mathematics
Subject: 19-02
Prerequisites: 19-01 (MoE Deep), 18-10 (Transformer Variants), 18-09 (KV Caching), 17-08 (Multi-Head Attention), 17-07 (Scaled Dot-Product Attention)
Next subject: 19-03 - Quantization Mathematics
Learning Objectives
By the end of this subject, you will be able to:
- Derive the KV cache size reduction from Multi-Head (MHA) to Multi-Query (MQA) to Grouped-Query (GQA) and compute exact memory savings for a given configuration (h heads, g groups, n layers, d_head)
- Prove that Multi-Query Attention is a strict capacity reduction of Multi-Head Attention and characterize what representational power is lost
- Trace the FlashAttention-2 algorithmic improvements over FlashAttention-1 (reduced non-matmul FLOPs, parallelized over sequence length) and explain why the causal variant is cheaper
- Formulate ring attention as a block-wise communication scheme and derive the communication volume per device for distributed long-context attention
- Compare the latency-memory tradeoffs of MQA vs GQA vs MHA under batch inference with growing KV cache sizes
Core Content
1. From Multi-Head to Multi-Query: The KV Projection Spectrum
1.1 Multi-Head Attention (MHA) — Baseline
Standard MHA with h heads, each of dimension d_head:
Given input X ∈ ℝ^{n×d}, for each head i ∈ {1,...,h}:
$Q_i = X · W_i^Q ∈ ℝ^{n×d_head}
K_i = X · W_i^K ∈ ℝ^{n×d_head}
V_i = X · W_i^V ∈ ℝ^{n×d_head}
head_i = softmax(Q_i · K_i^T / √d_head) · V_i
MultiHead(X) = Concat(head_1, ..., head_h) · W^O
$
where W_i^Q, W_i^K, W_i^V ∈ ℝ^{d×d_head} and W^O ∈ ℝ^{hd_head×d}.
Parameter count for projections:
$Params_Q = h · d · d_head = d · d
Params_K = h · d · d_head = d · d
Params_V = h · d · d_head = d · d
Params_O = d · d
Total = 4d²
$
Since h·d_head = d (canonical setup).
KV cache size per layer:
$KV_cache_per_layer = 2 · h · n · d_head (K and V, each size h × n × d_head)
= 2 · n · d (since h·d_head = d)
$
For fp16: 2 · n · d · 2 bytes per layer. For a 32-layer model with d=4096 and n=32K: 32 · 2 · 32768 · 4096 · 2 ≈ 17.2 GB KV cache.
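The arithmetic above can be checked with a small helper. This is a minimal sketch: the function name and its parameters are illustrative, and h=32 (so d_head=128) is assumed for the d=4096 model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, seq_len, d_head, bytes_per_elem=2, batch=1):
    """Bytes needed to cache K and V for `batch` sequences of `seq_len` tokens."""
    # The leading factor 2 covers the two cached tensors (K and V).
    return 2 * batch * n_layers * n_kv_heads * seq_len * d_head * bytes_per_elem

# MHA example from the text: 32 layers, d = 4096, n = 32K, fp16.
# Assumption: h = 32 heads, so d_head = 4096 / 32 = 128.
mha_bytes = kv_cache_bytes(n_layers=32, n_kv_heads=32, seq_len=32768, d_head=128)
print(f"{mha_bytes:,} bytes = {mha_bytes / 1e9:.1f} GB")
```

Passing a smaller `n_kv_heads` gives the MQA and GQA figures derived in the following sections.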
1.2 Multi-Query Attention (MQA) — Minimal KV
MQA uses a SINGLE shared K and V projection across all heads, while keeping separate Q projections:
$K_shared = X · W^K ∈ ℝ^{n×d_head} (one projection!)
V_shared = X · W^V ∈ ℝ^{n×d_head}
For each head i:
Q_i = X · W_i^Q ∈ ℝ^{n×d_head}
head_i = softmax(Q_i · K_shared^T / √d_head) · V_shared
$
Key insight: Every head computes attention using the same K and V, but different Q. The heads still see different "queries," so they can attend to different positions, but the "key" and "value" spaces are shared.
⚠️ THIS IS CRITICAL — MQA reduces KV cache by a factor of h (number of heads) compared to MHA: from 2·h·n·d_head to 2·1·n·d_head. For h=32, this is a 32× reduction. The tradeoff is representational capacity: all heads share the same key/value representations, which can hurt performance on tasks requiring diverse attention patterns. GQA with g groups gives a middle ground — h/g reduction with better quality retention.
Parameter count:
Params_Q = d · d (same as MHA)
Params_K = d · d_head (h× reduction)
Params_V = d · d_head (h× reduction)
Params_O = d · d
Total = 2d² + 2d·d_head = d²(2 + 2/h)
For h=32: ~2.06d² vs 4d² for MHA, about a 48% reduction in attention projection parameters.
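As a sanity check on the parameter counts, the sketch below sums the four projection matrices for h query heads and g K,V heads. The function name is illustrative, biases are ignored, and d=4096 is an assumed value (the ratio depends only on h).

```python
def attention_projection_params(d, h, g):
    """Weights in the Q, K, V, O projections with h query heads and g K,V heads."""
    d_head = d // h
    q = d * h * d_head      # W^Q: d x (h * d_head)
    k = d * g * d_head      # W^K: d x (g * d_head)
    v = d * g * d_head      # W^V: d x (g * d_head)
    o = h * d_head * d      # W^O: (h * d_head) x d
    return q + k + v + o

d, h = 4096, 32             # d is an assumed value for illustration
mha = attention_projection_params(d, h, g=h)
mqa = attention_projection_params(d, h, g=1)
print(mha / d**2, mqa / d**2, 1 - mqa / mha)
```

Only the K and V terms shrink; Q and O each keep d² weights, which is why the total cannot drop below 2d².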
KV cache size:
$KV_cache_MQA = 2 · n · d_head = 2 · n · d / h $
For the same 32-layer, d=4096, n=32K, h=32 model, d_head = 4096/32 = 128, so: 32 · 2 · 32768 · 128 · 2 ≈ 537 MB. That's a 32× reduction from MHA, matching the factor h.
1.3 Grouped-Query Attention (GQA) — The Sweet Spot
GQA offers a continuum between MHA and MQA. Heads are partitioned into g groups; heads within a group share K and V projections:
$Number of KV heads: g (1 ≤ g ≤ h)
Heads per group: h/g (integer)
For group j ∈ {1,...,g}:
K_j = X · W_j^K ∈ ℝ^{n×d_head}
V_j = X · W_j^V ∈ ℝ^{n×d_head}
For head i in group j:
Q_i = X · W_i^Q
head_i = softmax(Q_i · K_j^T / √d_head) · V_j
$
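The definition above can be sketched in NumPy. This is an illustrative single-sequence version, not a reference implementation: it assumes consecutive heads form a group (head i uses K,V group i // (h/g)), omits W^O and masking, and uses hypothetical names throughout.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gqa(X, Wq, Wk, Wv, h, g):
    """X: (n, d); Wq: (d, h*d_head); Wk, Wv: (d, g*d_head). Returns (n, h*d_head)."""
    n = X.shape[0]
    d_head = Wq.shape[1] // h
    Q = (X @ Wq).reshape(n, h, d_head).transpose(1, 0, 2)   # (h, n, d_head)
    K = (X @ Wk).reshape(n, g, d_head).transpose(1, 0, 2)   # (g, n, d_head)
    V = (X @ Wv).reshape(n, g, d_head).transpose(1, 0, 2)   # (g, n, d_head)
    # Share each K,V head with the h/g query heads of its group.
    K = np.repeat(K, h // g, axis=0)                        # (h, n, d_head)
    V = np.repeat(V, h // g, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)     # (h, n, n)
    heads = softmax(scores) @ V                             # (h, n, d_head)
    return heads.transpose(1, 0, 2).reshape(n, h * d_head)  # concatenated heads

rng = np.random.default_rng(0)
n, d, h, g = 5, 16, 4, 2
d_head = d // h
X = rng.normal(size=(n, d))
Wq = rng.normal(size=(d, h * d_head))
Wk = rng.normal(size=(d, g * d_head))
Wv = rng.normal(size=(d, g * d_head))
out = gqa(X, Wq, Wk, Wv, h, g)
print(out.shape)
```

Setting g=h recovers MHA and g=1 recovers MQA. Only the g projected K,V heads need to be cached; the `np.repeat` expansion is a convenience of this sketch.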
Parameter count:
$Params_K = g · d · d_head = g · d · d/h = d·d · g/h
Params_V = d·d · g/h
$
For Llama 3 70B: h=64, g=8, so g/h = 1/8. KV parameters are 1/8 of MHA.
KV cache size:
$KV_cache_GQA = 2 · g · n · d_head = 2 · n · d · g/h $
Ratio vs MHA: g/h. For g=8, h=64: 8× smaller KV cache.
Why GQA > MQA in practice: MQA (g=1) can underperform MHA on complex tasks because all heads must share the same key/value representation — they can't learn specialized key/value projections. GQA with g=4-8 provides most of the memory savings with near-MHA quality.
2. FlashAttention-2: Algorithmic Improvements
2.1 FlashAttention-1 Recap
FlashAttention-1 introduced tiling and recomputation to avoid materializing the n×n attention matrix S = QK^T in HBM:
- Outer loop: iterate over K,V blocks
- Inner loop: iterate over Q blocks
- Online softmax: maintain running m (max) and ℓ (sum) to compute correct softmax incrementally
Limitation: With Q blocks in the inner loop, the partial output of every Q block is read from and written back to HBM on each outer iteration, and work is parallelized only over batch and heads. This underutilizes the GPU when sequences are long and batches are small. Also, the causal mask variant does unnecessary computation on masked entries.
2.2 FlashAttention-2 Key Changes
Change 1: Switch loop order — parallelize over Q blocks
FA2 puts the Q block loop on the parallel dimension (thread blocks), eliminating the sequential bottleneck:
// Each thread block handles one Q block in parallel
For thread_block t (handles Q_block_t):
    Load Q_block_t into SRAM
    For each K,V block:
        Load K_block, V_block
        Compute S = Q_block_t · K_block^T / √d
        Apply softmax incrementally
        Accumulate output
    Write O_block_t to HBM
This exposes parallelism across the sequence length dimension.
Change 2: Reduce non-matmul FLOPs
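FA1 rescaled the output at every inner loop step, including a division by the running sum ℓ. These non-matmul operations are expensive because they do not run on the matrix-multiply units. FA2 restructures the computation to minimize them:
Instead of normalizing at every inner step, FA2 keeps an unnormalized accumulator, rescales it only by exp(m_old − m_new) when the running max changes, and postpones the division by ℓ until the final write:
$m_new = max(m_old, rowmax(S))
O_accumulator = exp(m_old − m_new) · O_accumulator + exp(S − m_new) · V
ℓ_accumulator = exp(m_old − m_new) · ℓ_accumulator + rowsum(exp(S − m_new))
// ... at the end:
O_final = O_accumulator / ℓ_accumulator
$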
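The loop order and the deferred normalization can be illustrated on the CPU with NumPy. This is a sketch under stated assumptions (single head, no mask, hypothetical block sizes and names); it mirrors the structure of the algorithm, not the actual CUDA kernel.

```python
import numpy as np

def tiled_attention(Q, K, V, block_q=4, block_k=4):
    """Exact softmax attention computed tile by tile with an online softmax."""
    n, d = Q.shape
    O = np.empty((n, V.shape[1]))
    for qs in range(0, n, block_q):              # Q blocks are independent: parallel on a GPU
        q = Q[qs:qs + block_q]
        m = np.full(q.shape[0], -np.inf)         # running row max
        l = np.zeros(q.shape[0])                 # running softmax denominator
        acc = np.zeros((q.shape[0], V.shape[1])) # unnormalized output accumulator
        for ks in range(0, K.shape[0], block_k): # sequential sweep over K,V blocks
            s = q @ K[ks:ks + block_k].T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=1))
            scale = np.exp(m - m_new)            # rescale old statistics to the new max
            p = np.exp(s - m_new[:, None])
            l = scale * l + p.sum(axis=1)
            acc = scale[:, None] * acc + p @ V[ks:ks + block_k]
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]    # one division per Q block, at the end
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(10, 8)) for _ in range(3))
S = Q @ K.T / np.sqrt(8)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(tiled_attention(Q, K, V), reference))
```

The full n×n score matrix is never formed; only one block_q × block_k tile exists at a time.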
Change 3: Causal attention optimization
For causal masking (lower triangular), roughly half the S entries are −∞. FA2 skips these entirely by only iterating K,V blocks up to the current Q position:
$For Q_block at position t:
For K,V blocks only up to position t:
Compute attention
// Skip blocks beyond t (they'd be entirely masked)
$
Skipping these blocks roughly halves the attention FLOPs, so causal attention approaches a 2× speedup over full bidirectional FA2 as the number of blocks grows.
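A quick count of tiles shows where the factor comes from. The sketch assumes Q and K,V use the same block size and treats the partially masked diagonal tile as full work; the function name is illustrative.

```python
def kv_blocks_visited(num_blocks, causal):
    """Number of (Q block, K/V block) tiles computed for a sequence of num_blocks blocks."""
    if causal:
        # Q block t needs K/V blocks 0..t; block t is the partially masked diagonal tile.
        return sum(t + 1 for t in range(num_blocks))
    return num_blocks * num_blocks

for T in (4, 32, 256):
    full = kv_blocks_visited(T, causal=False)
    causal = kv_blocks_visited(T, causal=True)
    print(T, full, causal, round(full / causal, 3))
```

The ratio is T² / (T(T+1)/2) = 2T/(T+1), which tends to 2 for long sequences but is noticeably lower for short ones.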
2.3 FlashAttention-3
FA3 targets Hopper architecture (H100) features: asynchronous execution (warp specialization) and TMA (Tensor Memory Accelerator) for hardware-accelerated data movement. Key innovation: overlap computation of one tile with data loading of the next tile (software pipelining).
3. Ring Attention: Distributed Long-Context
3.1 The Problem
Standard attention requires the full K and V on each device. For D devices with sequence length n, the naive approach means:
- Each device stores all K,V (O(n) per device), which is possible only if n fits in memory
- For very long contexts (n > 1M), even per-device O(n) is too large
3.2 Ring Attention Algorithm
Devices are arranged in a logical ring. The sequence is split into D chunks, one per device:
$Device i initially holds: Q_i, K_i, V_i (chunk of length n/D)
Goal: compute O_i = Attention(Q_i, K_{full}, V_{full})
$
Algorithm:
For step s = 0 to D−1:
1. Compute partial attention using the K,V chunk currently held: O_i += Attention(Q_i, K_{(i−s) mod D}, V_{(i−s) mod D}) (merged with online softmax)
2. If s < D−1, send the held K,V chunk to the next device in the ring and receive a chunk from the previous device
3. Update position offsets for causal masking
After D steps, O_i contains the full attention output.
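The algorithm can be simulated in a single process to confirm that it reproduces full attention. This is a sketch under assumptions: non-causal, single head, n divisible by D, and plain Python lists standing in for devices; no real communication library is involved.

```python
import numpy as np

def ring_attention(Q, K, V, D):
    """Simulate non-causal ring attention over D 'devices' (single head)."""
    n, d = Q.shape
    Qc, Kc, Vc = (np.split(A, D) for A in (Q, K, V))  # chunk i starts on device i
    held = list(range(D))                             # held[i]: K,V chunk index on device i
    m = [np.full(n // D, -np.inf) for _ in range(D)]  # per-device running max
    l = [np.zeros(n // D) for _ in range(D)]          # per-device running denominator
    acc = [np.zeros((n // D, V.shape[1])) for _ in range(D)]
    for step in range(D):
        for i in range(D):                            # concurrent on real hardware
            j = held[i]
            s = Qc[i] @ Kc[j].T / np.sqrt(d)
            m_new = np.maximum(m[i], s.max(axis=1))
            scale = np.exp(m[i] - m_new)
            p = np.exp(s - m_new[:, None])
            l[i] = scale * l[i] + p.sum(axis=1)
            acc[i] = scale[:, None] * acc[i] + p @ Vc[j]
            m[i] = m_new
        if step < D - 1:
            # Device i passes its chunk to device i+1 and receives from device i-1.
            held = [held[(i - 1) % D] for i in range(D)]
    return np.concatenate([acc[i] / l[i][:, None] for i in range(D)])

rng = np.random.default_rng(1)
Q, K, V = (rng.normal(size=(12, 8)) for _ in range(3))
S = Q @ K.T / np.sqrt(8)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(ring_attention(Q, K, V, D=3), reference))
```

Each device only ever holds one K,V chunk of n/D tokens, yet its output rows match exact attention over all n keys.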
Communication cost:
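$per_device_send = (D−1) · 2 · (n/D) · g · d_head (K and V chunks, in elements)
= 2 · n · d · (g/h) · (D−1)/D
total_communication = D · per_device_send = 2 · (D−1) · n · d · g/h
$
Each K,V chunk travels around the ring once, making D−1 hops, so each device performs D−1 sends of one chunk (K and V for n/D tokens). Per-device traffic therefore stays below the size of the full K,V (2 · n · d · g/h elements) however large D is, while the total across all devices grows linearly in D. Ring attention keeps K,V memory per device at O(n/D) while computing exact full attention.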
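The per-device volume can be computed directly. The sketch below assumes both K and V are sent, that no send happens after the final step (D−1 sends per device), and fp16 storage; the function name is illustrative. The example values are those of Practice Problem 4 (MHA, so g = h = 32).

```python
def ring_comm_bytes_per_device(n, D, g, d_head, bytes_per_elem=2):
    """Bytes one device sends per layer: one K,V chunk per hop."""
    chunk_elems = 2 * (n // D) * g * d_head   # K and V for n/D tokens and g K,V heads
    hops = D - 1                              # assumption: no send after the last step
    return hops * chunk_elems * bytes_per_elem

per_device = ring_comm_bytes_per_device(n=131072, D=8, g=32, d_head=128)
print(f"{per_device:,} bytes per device, {8 * per_device:,} bytes across the ring")
```

With GQA the volume scales by g/h, since only the g K,V heads travel around the ring.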
Memory per device:
Ring Attention: O(n/D · d) for K,V storage
Naive: O(n · d) for K,V storage
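Speedup ceiling: The total computation is still O(n²). Each device computes D block attentions of size (n/D) × (n/D), i.e. D · (n/D)² = n²/D score entries, so compute per device falls as 1/D. Per step, however, a device computes (n/D)² scores while transferring a chunk of n/D tokens, so the compute available to hide each transfer shrinks as D grows for fixed n. Beyond some D, communication dominates compute.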
4. Memory Analysis: MHA vs MQA vs GQA at Scale
4.1 Comparative KV Cache Table
For an L-layer model with sequence length n, hidden dimension d, h heads, and g KV groups (figures are per sequence; multiply by the batch size B for a batch):
| Variant | KV cache per layer (bytes, fp16) | Total for L=32, d=4096, n=32K, h=32 |
|---|---|---|
| MHA (g=h) | 2 · h · n · d_head · 2 = 4nd | 4 · 32768 · 4096 ≈ 537 MB/layer → 17.2 GB |
| GQA (g=8) | 2 · g · n · d_head · 2 = 4nd · g/h | 17.2 · 8/32 = 4.3 GB |
| GQA (g=4) | 4nd · g/h | 17.2 · 4/32 = 2.15 GB |
| MQA (g=1) | 4nd/h | 17.2/32 ≈ 537 MB |
Key insight: MHA memory is dominated by KV cache for long sequences. GQA provides a linear trade-off: memory ∝ g/h.
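Decoding tends to be memory-bandwidth bound: generating each token computes only one new query row per sequence, yet it must read that sequence's entire KV cache (every cached K and V row in every layer). Once the KV cache is larger than the model weights, KV reads dominate the memory traffic. MQA/GQA directly attack this bottleneck by shrinking the cache by a factor of h/g.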
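A rough estimate of the bytes read per decode step makes the bottleneck concrete. This is a back-of-the-envelope sketch: the function names are illustrative, the 14 GB weight size and batch size of 8 are hypothetical values chosen only for the example, and activations are ignored.

```python
def kv_cache_bytes_per_seq(n_layers, g, n, d_head, bytes_per_elem=2):
    """K and V bytes cached for one sequence of n tokens."""
    return 2 * n_layers * g * n * d_head * bytes_per_elem

def decode_step_bytes(weight_bytes, kv_bytes, batch):
    """Approximate bytes read to generate one token for every sequence in the batch."""
    # Weights are read once per step and shared; each sequence reads its own cache.
    return weight_bytes + batch * kv_bytes

WEIGHT_BYTES = 14e9   # hypothetical model size, not taken from the text
BATCH = 8             # hypothetical batch size

for name, g in (("MHA", 32), ("GQA g=8", 8), ("MQA", 1)):
    kv = kv_cache_bytes_per_seq(n_layers=32, g=g, n=32768, d_head=128)
    total = decode_step_bytes(WEIGHT_BYTES, kv, BATCH)
    print(f"{name}: KV share of memory traffic = {BATCH * kv / total:.2f}")
```

Dividing the total by the device's memory bandwidth gives a lower bound on the time per decode step under these assumptions.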
Pitfalls
⚠️ Pitfall 1: Computing KV cache with the wrong formula. The correct formula is 2 (K and V) × B × L × (number of K,V heads) × n × d_head × bytes_per_element. The number of K,V heads is g for GQA (h for MHA, 1 for MQA). Using h (attention heads) instead overestimates the cache size by the grouping factor h/g.
⚠️ Pitfall 2: Assuming FlashAttention-2 is "just faster FlashAttention-1." FlashAttention-2 restructures the loops: Q blocks form the outer loop, which is parallelized across thread blocks along the sequence length, and the non-matmul work (rescaling and division) is reduced. The causal variant is faster still because it only computes tiles in the lower triangle.
⚠️ Pitfall 3: Forgetting that ring attention has a communication cost. Ring attention lets the context length grow with the number of devices by distributing the sequence, but each device must pass K,V blocks to its neighbor. For D devices, each device sends O(n·d/D) elements per step and close to O(n·d) per layer in total (scaled by g/h under GQA). This is acceptable for very long sequences where holding the full K,V on one device would be impossible, provided each transfer can overlap with the block computation.
Key Terms
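- Multi-Head Attention (MHA): h heads, each with its own Q, K, and V projection
- Multi-Query Attention (MQA): h query heads sharing a single K,V head (g=1)
- Grouped-Query Attention (GQA): h query heads sharing g K,V heads (1 ≤ g ≤ h)
- KV cache: stored K and V for all past tokens, 2 · L · g · n · d_head elements per sequence
- FlashAttention-2: tiled exact attention parallelized over Q blocks with reduced non-matmul FLOPs
- Online softmax: running max m and sum ℓ that let softmax be computed block by block
- Causal attention optimization: skipping K,V blocks that are entirely masked
- Ring attention: block-wise attention across D devices that pass K,V chunks around a ring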
Worked Examples
Example 1: GQA Configuration Tradeoff
A model has d=3072, h=24, d_head=d/h=128. Compute the KV projection parameters and KV cache for n=8192 tokens for: (a) MHA, (b) GQA with g=6, (c) MQA. Express in fp16 bytes.
Solution:
Sizes below use 1 MB = 2²⁰ bytes.
(a) MHA (g=h=24):
- K params: h · d · d_head = 24 · 3072 · 128 = 9,437,184
- V params: same = 9,437,184
- KV cache: 2 · h · n · d_head = 2 · 24 · 8192 · 128 = 50,331,648 elements
- In fp16: 50,331,648 · 2 = 100,663,296 bytes = 96 MB per layer
(b) GQA (g=6):
- K params: g · d · d_head = 6 · 3072 · 128 = 2,359,296
- V params: same = 2,359,296
- KV cache: 2 · g · n · d_head = 2 · 6 · 8192 · 128 = 12,582,912 elements
- In fp16: 12,582,912 · 2 = 25,165,824 bytes = 24 MB per layer
- Ratio vs MHA: g/h = 6/24 = 0.25 (4× reduction)
(c) MQA (g=1):
- K params: 1 · 3072 · 128 = 393,216
- V params: 393,216
- KV cache: 2 · 1 · 8192 · 128 = 2,097,152 elements
- In fp16: 2,097,152 · 2 = 4,194,304 bytes = 4 MB per layer
- Ratio vs MHA: 1/24 ≈ 0.042 (24× reduction)
Example 2: Ring Attention Step-by-Step
3 devices (D=3), sequence split into 3 chunks of length m each. Total n = 3m.
$Initial state:
Device 0: Q_0, K_0, V_0
Device 1: Q_1, K_1, V_1
Device 2: Q_2, K_2, V_2
$
Each device is responsible for a contiguous segment of the sequence: device i holds positions [i·m, (i+1)·m). For causal masking, Q at position p can only attend to K,V at positions ≤ p, so Q_i uses K_j, V_j only for j ≤ i.
Step 0: Each device computes attention with its own K,V, applying the causal mask inside the block.
- Device 0: O_0 = Attn(Q_0, K_0, V_0)
- Device 1: O_1 = Attn(Q_1, K_1, V_1)
- Device 2: O_2 = Attn(Q_2, K_2, V_2)
Step 1: Send K,V clockwise (device i sends to device i+1). Device 0 now holds K_2, V_2; device 1 holds K_0, V_0; device 2 holds K_1, V_1. Each device merges the new block using online softmax accumulation:
- Device 0: chunk 2 lies in the future of chunk 0, so the block is fully masked and skipped
- Device 1: O_1 += Attn(Q_1, K_0, V_0); chunk 0 lies in the past, so every key is visible
- Device 2: O_2 += Attn(Q_2, K_1, V_1); chunk 1 lies in the past
Step 2: Send K,V clockwise again. Device 0 holds K_1, V_1; device 1 holds K_2, V_2; device 2 holds K_0, V_0.
- Device 0: chunk 1 lies in the future, skipped
- Device 1: chunk 2 lies in the future, skipped
- Device 2: O_2 += Attn(Q_2, K_0, V_0)
Result:
- Device 0 (positions 0...m−1): attends only to chunk 0 (step 0) and has nothing to compute in steps 1 and 2.
- Device 1 (positions m...2m−1): attends to chunk 1 (step 0) and chunk 0 (step 1).
- Device 2 (positions 2m...3m−1): attends to chunk 2 (step 0), chunk 1 (step 1), and chunk 0 (step 2).
Communication: Without masking, each K,V chunk makes D−1 = 2 hops. In the causal case, chunk j is needed only by devices j+1, ..., D−1, so it needs D−1−j hops: K_0, V_0 make two hops, K_1, V_1 one, and K_2, V_2 need not move at all. The work is also uneven: device 0 computes one block while device 2 computes three.
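The schedule in this example can be generated for any D. This is a small illustrative sketch, assuming contiguous chunks and clockwise sends as above; the function name and the labels are hypothetical.

```python
def causal_ring_schedule(D):
    """For each step and device, list the K,V chunk held and how the causal mask treats it."""
    rows = []
    for step in range(D):
        for i in range(D):
            j = (i - step) % D          # chunk held by device i at this step
            if j > i:
                kind = "skip"           # chunk lies in the future: fully masked
            elif j == i:
                kind = "diagonal"       # own chunk: causal mask inside the block
            else:
                kind = "full"           # chunk lies in the past: every key visible
            rows.append((step, i, j, kind))
    return rows

for step, device, chunk, kind in causal_ring_schedule(3):
    print(f"step {step}: device {device} holds chunk {chunk} -> {kind}")
```

Counting the non-skipped rows per device gives i+1 blocks for device i, which is the load imbalance noted above.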
Example 3: Comparing Attention FLOPs
For a single attention layer with n=4096, d=3072, h=24, g=6. Compute FLOPs for: (a) computing Q, K, V projections, (b) computing attention scores QK^T, (c) applying attention to V, (d) the output projection. Compare MHA vs GQA.
Solution:
d_head = 3072/24 = 128.
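(a) Projections (each multiply-accumulate is counted as one operation; double the figures to count multiplies and adds separately):
- Q (both variants): n · d · d_head · h = n · d · d = 4096 · 3072² = 4096 · 9,437,184 ≈ 3.87×10¹⁰
MHA:
- K: same as Q = 3.87×10¹⁰
- V: same = 3.87×10¹⁰
GQA:
- K: n · d · d_head · g = 4096 · 3072 · 128 · 6 = 4096 · 3072 · 768 ≈ 9.66×10⁹
- V: same ≈ 9.66×10⁹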
Projection savings (GQA vs MHA): K and V are g/h = 0.25× of MHA.
(b) Attention scores QK^T: Same for both! Each head computes Q_i · K_j^T ∈ ℝ^{n×n} regardless of whether K is shared.
FLOPs = h · n · n · d_head = 24 · 4096² · 128 = 24 · 16,777,216 · 128 ≈ 5.15×10¹⁰
(c) Attention × V: Same for both. h · n · n · d_head = 5.15×10¹⁰
(d) Output projection: Same for both. n · d · d = 3.87×10¹⁰
Total FLOPs ratio: GQA saves only on K,V projections. The attention scores dominate for long sequences (∝ n²). So GQA's FLOP savings are modest for long sequences — its primary benefit is KV cache memory, not compute.
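The totals for this example can be tallied with a short script. It follows the same counting convention as the solution (one multiply-accumulate per operation); the function name and dictionary keys are illustrative.

```python
def attention_layer_macs(n, d, h, g):
    """Multiply-accumulate counts for one attention layer with h query and g K,V heads."""
    d_head = d // h
    return {
        "q_proj": n * d * h * d_head,
        "k_proj": n * d * g * d_head,
        "v_proj": n * d * g * d_head,
        "scores": h * n * n * d_head,   # QK^T, computed per query head
        "attn_v": h * n * n * d_head,   # softmax(QK^T) V, computed per query head
        "o_proj": n * h * d_head * d,
    }

mha = attention_layer_macs(n=4096, d=3072, h=24, g=24)
gqa = attention_layer_macs(n=4096, d=3072, h=24, g=6)
ratio = sum(gqa.values()) / sum(mha.values())
print(f"MHA total: {sum(mha.values()):.3e}, GQA total: {sum(gqa.values()):.3e}, ratio: {ratio:.3f}")
```

Increasing n makes the two n² terms dominate, so the ratio moves toward 1 while the KV cache saving stays at g/h.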
Quiz
Q1: When computing the KV cache size of a GQA model, which head count belongs in the formula?
A) The number of query heads h B) The number of K,V groups g C) The number of heads per group h/g D) The number of layers L
Correct: B)
- If you chose A: This is incorrect. Using h gives the MHA cache size and overestimates the GQA cache by h/g.
- If you chose B: Only the g K,V heads are cached, so the cache is 2 · g · n · d_head elements per layer. Correct!
- If you chose C: This is incorrect. h/g is the number of query heads sharing one K,V head, not the number of K,V heads.
- If you chose D: This is incorrect. L multiplies the per-layer cache but is not a head count.
Q2: How does Multi-Query Attention change the projections relative to Multi-Head Attention?
A) It shares one Q projection across heads and keeps separate K and V B) It removes the output projection W^O C) It uses a single shared K and V projection while keeping a separate Q projection per head D) It halves d_head for every head
Correct: C)
- If you chose A: This is incorrect. MQA keeps separate Q projections; K and V are the shared ones.
- If you chose B: This is incorrect. The output projection is unchanged (d · d parameters).
- If you chose C: Every head attends with its own Q over the same K_shared and V_shared. Correct!
- If you chose D: This is incorrect. d_head is unchanged; only the number of K,V heads drops to 1.
Q3: Which statement about FlashAttention-2 is TRUE?
A) It materializes the full n×n attention matrix in HBM B) It approximates softmax to save memory C) It only runs on Hopper GPUs D) It parallelizes over Q blocks and postpones the division by ℓ until the final write
Correct: D)
- If you chose A: This is incorrect. Tiling exists precisely to avoid materializing S = QK^T in HBM.
- If you chose B: This is incorrect. The online softmax produces the exact result.
- If you chose C: This is incorrect. Hopper-specific features are the target of FlashAttention-3.
- If you chose D: These are two of the three changes covered in this subject, alongside the causal optimization. Correct!
Q4: Based on Example 3, which cost does GQA reduce relative to MHA?
A) The K and V projections B) The attention scores QK^T C) The attention × V product D) The output projection
Correct: A)
- If you chose A: K and V projection cost scales with g/h; every other term is unchanged. Correct!
- If you chose B: This is incorrect. Each of the h heads still computes an n×n score matrix.
- If you chose C: This is incorrect. Each of the h heads still multiplies its n×n weights by V.
- If you chose D: This is incorrect. The output projection is n · d · d in both variants.
Q5: Which technique do FlashAttention-2 and ring attention have in common?
A) Both combine block-wise partial results with an online softmax B) Both approximate attention to reduce FLOPs C) Both require multiple devices D) Both reduce the number of K,V heads
Correct: A)
- If you chose A: FlashAttention-2 merges K,V blocks within one GPU and ring attention merges K,V chunks across devices, both with running max and sum. Correct!
- If you chose B: This is incorrect. Both compute exact attention.
- If you chose C: This is incorrect. FlashAttention-2 is a single-device kernel.
- If you chose D: This is incorrect. Reducing K,V heads is what MQA and GQA do.
Q6: For fixed L, n, d, and h, how does KV cache memory scale with the number of K,V groups g?
A) Linearly, in proportion to g/h B) Quadratically in g C) It is independent of g D) Inversely with g
Correct: A)
- If you chose A: The cache is 4nd · g/h bytes per layer in fp16, so halving g halves the memory. Correct!
- If you chose B: This is incorrect. g appears only to the first power in 2 · g · n · d_head.
- If you chose C: This is incorrect. g is the number of K,V heads that are cached.
- If you chose D: This is incorrect. Fewer groups mean a smaller cache, not a larger one.
Q7: In Example 1 (d=3072, h=24, n=8192, fp16), what is the per-layer KV cache for MQA?
A) 100,663,296 bytes B) 25,165,824 bytes C) 4,194,304 bytes D) 393,216 bytes
Correct: C)
- If you chose A: This is incorrect. That is the MHA cache (g=24).
- If you chose B: This is incorrect. That is the GQA cache with g=6.
- If you chose C: 2 · 1 · 8192 · 128 elements · 2 bytes = 4,194,304 bytes. Correct!
- If you chose D: This is incorrect. 393,216 is the K projection parameter count, not a cache size.
Practice Problems
Problem 1
A model with d=4096, h=32, L=40 layers, batch size 64, sequence length 8192. Compute the total KV cache memory in GB (fp16) for MHA, GQA(g=8), and MQA.
Answer
d_head = 4096/32 = 128. The KV cache is stored per batch item, because different sequences have different K,V.
MHA (g=32): per batch item per layer = 2 · h · n · d_head · 2 bytes = 4nd = 4 · 8192 · 4096 = 134,217,728 bytes ≈ 134.2 MB. Total = 40 · 64 · 134.2 MB ≈ 343.6 GB.
GQA (g=8): per batch item per layer = 4nd · g/h = 134.2 MB · 8/32 ≈ 33.55 MB. Total = 40 · 64 · 33.55 MB ≈ 85.9 GB.
MQA (g=1): per batch item per layer = 134.2 MB / 32 ≈ 4.19 MB. Total = 40 · 64 · 4.19 MB ≈ 10.7 GB.
This illustrates why MQA/GQA matters at scale: MHA requires about 344 GB of KV cache, more than four times the 80 GB of a single H100. GQA at 86 GB still exceeds one 80 GB GPU and must be split across at least two, before counting model weights. MQA at 11 GB is practical.
Problem 2
Prove that the K,V projection parameter count for GQA is exactly g/h times that of MHA. Then show what happens when g=1 and g=h.
Answer
MHA K projection params: h · d · d_head = h · d · (d/h) = d²
GQA K projection params: g · d · d_head = g · d · (d/h) = d² · g/h
Ratio: (d² · g/h) / d² = g/h ∎
When g=1 (MQA): ratio = 1/h, the largest saving. When g=h: ratio = 1, identical to MHA.
The same analysis applies to V projections. Total KV projection params: MHA = 2d², GQA = 2d² · g/h.
Problem 3
In FlashAttention-2, why does switching the loop order (making Q blocks the outer loop and parallelizing them across thread blocks, instead of FA1's outer loop over K,V blocks) improve GPU utilization?
Answer
GPUs run many thread blocks concurrently across their compute units. In FA1, the outer loop runs over K,V blocks and the inner loop over Q blocks, so each outer iteration updates the partial outputs of all Q blocks through HBM, and parallelism comes only from the batch and head dimensions. FA2 moves Q blocks to the parallel dimension: each thread block handles ONE Q block, and many thread blocks (many Q blocks) run concurrently. The inner K,V loop within each thread block remains sequential, but many such loops run simultaneously. This increases occupancy (more thread blocks active at once), which matters most for long sequences with small batches. Additionally, since each thread block processes exactly one Q block from start to finish, its accumulator stays on-chip and no synchronization is needed between thread blocks: they are embarrassingly parallel over the Q dimension.
Problem 4
For ring attention with D=8 devices, n=128K tokens, d=4096, h=32, compute: (a) tokens per device, (b) memory per device for KV storage in fp16, (c) total communication volume per device in bytes.
Answer
d_head = 4096/32 = 128, and n = 128K = 131,072 tokens. Sizes below use 1 MB = 2²⁰ bytes and 1 GB = 2³⁰ bytes.
(a) Tokens per device: n/D = 131,072/8 = 16,384 tokens.
(b) Memory per device for K (fp16): n/D · d · 2 = 16384 · 4096 · 2 = 134,217,728 bytes = 128 MB. For V: same = 128 MB. Total KV: 256 MB per device. Compare naive (store all K,V on each device): n · d · 2 = 131072 · 4096 · 2 bytes = 1 GB each for K and V, 2 GB total. Ring attention saves 8× memory. With GQA (g=8): ring attention memory = 256 MB · 8/32 = 64 MB; naive = 2 GB · 8/32 = 512 MB.
(c) Communication per device: each device forwards one K,V chunk per hop, D−1 = 7 times. Per send (MHA): n/D · d · 2 bytes (K) + n/D · d · 2 bytes (V) = 4 · 16384 · 4096 = 268,435,456 bytes = 256 MB. Total per device = 7 · 256 MB = 1.75 GB. For the full ring: D · 1.75 GB = 14 GB total communication across the network.
This is why ring attention is practical: per-device memory drops 8×, while per-device communication is (D−1) times the local chunk, which stays below the 2 GB size of the full K,V however many devices are used.
Problem 5
A model uses GQA with g=4, h=32, d=6656, L=60 layers, n=4096, B=128. The available GPU has 80 GB memory. Model weights (fp16) are 40 GB. Compute whether the KV cache fits in the remaining memory for prefill and for each subsequent decode step.
Answer
d_head = 6656/32 = 208. Sizes below use 1 MB = 2²⁰ bytes and 1 GB = 2³⁰ bytes.
KV cache per batch item per layer: 2 · g · n · d_head = 2 · 4 · 4096 · 208 = 6,815,744 elements. In fp16: 6,815,744 · 2 = 13,631,488 bytes = 13.0 MB.
Total KV cache after prefill: L · B · 13.0 MB = 60 · 128 · 13.0 MB = 99,840 MB ≈ 97.5 GB.
Model weights: 40 GB. Remaining memory: 80 − 40 = 40 GB. KV cache (97.5 GB) > remaining memory (40 GB). **Does not fit!**
Each decode step then appends one token per sequence: 2 · g · d_head · 2 bytes · L · B = 2 · 4 · 208 · 2 · 60 · 128 = 25,559,040 bytes ≈ 24.4 MB per step, so the shortfall only grows during decoding.
Solutions:
- Reduce batch size: B ≤ 40 GB / (60 · 13.0 MB) = 40,960 MB / 780 MB ≈ 52.
- Use MQA (g=1): cache drops to 97.5/4 ≈ 24.4 GB, which fits with margin.
- Use PagedAttention / vLLM-style memory management to share KV blocks.
- Reduce sequence length.
This calculation shows why GQA with g=4 can be insufficient for high-throughput serving and why further optimizations (MQA, paged KV cache, quantization of KV cache) are needed.
Summary
- Multi-Query Attention (MQA, g=1) shares K,V across all heads: an h× KV cache reduction, with K,V projection parameters cut from 2d² to (2/h)d², but it can degrade quality on complex tasks
- Grouped-Query Attention (GQA) interpolates between MHA and MQA with g KV groups (1 ≤ g ≤ h): KV cache scales as g/h while preserving near-MHA quality at g=4-8
- FlashAttention-2 improves over FA1 by parallelizing over Q blocks, reducing non-matmul FLOPs, and skipping fully masked blocks, which roughly halves the work for causal attention while keeping O(n) memory
- Ring attention distributes long-context attention across D devices in a ring topology: each device sends O(n/D) data D−1 times, reducing per-device K,V memory to O(n/D) while computing exact attention
- The dominant bottleneck shifts for different configurations: long sequences → attention FLOPs (O(n²)), large batch → KV cache memory (O(B·L·n·d·g/h)), large model → weight loading bandwidth
Next Steps
Continue to 19-03 — Quantization Mathematics to learn about uniform affine quantization, GPTQ, AWQ, and NF4 — the mathematics behind compressing model weights to 8-bit, 4-bit, and beyond.