17-08 — Multi-Head Attention
Phase: 17 — Deep Learning Architectures (Math)
Subject: 17-08
Prerequisites: 17-07 (Scaled Dot-Product Attention), 17-06 (Attention Mechanism), 8-4 (Matrix multiplication — comfort with projections)
Next subject: 17-09 — Transformer Architecture
Learning Objectives
By the end of this subject, you will be able to:
- Derive the full multi-head attention equation from single-head scaled dot-product attention
- Explain the projection matrices W^Q, W^K, W^V, W^O and their shapes
- Quantify the computational cost: O(n²d + nd²) and explain when each term dominates
- Prove that multi-head attention with h heads does NOT increase the matrix-multiply FLOPs vs single-head (for the same total dimension)
- Explain why multiple heads are necessary — the "representation subspace" argument
Core Content
1. The Limitation of Single-Head Attention
Single-head attention computes:
Attention(Q,K,V) = softmax(QKᵀ/√d_k)V
This produces ONE attention distribution per query position. But sequences have multiple types of relationships:
- Syntactic dependencies (verb-subject agreement)
- Semantic relationships (pronoun-antecedent)
- Positional relationships (nearby vs faraway tokens)
- Content-based groupings (noun phrases)
A single attention distribution per query has to blend all of these into one set of weights, so attending strongly to one relationship dilutes the others. Multi-head attention addresses this by running attention several times in parallel, each time with different learned projections.
2. Multi-Head Attention Equation
Instead of one attention function on d_model-dimensional Q, K, V, we:
- Project Q, K, V h times with different learned weight matrices
- Apply scaled dot-product attention to each projected set
- Concatenate the h outputs
- Project back to d_model
Formally:
head_i = Attention(Q_i, K_i, V_i) = softmax(Q_i K_iᵀ / √d_k) V_i
where:
Q_i = Q · W_i^Q ∈ ℝ^(n_q × d_k)
K_i = K · W_i^K ∈ ℝ^(n_k × d_k)
V_i = V · W_i^V ∈ ℝ^(n_k × d_v)
And the final output:
MultiHead(Q, K, V) = Concat(head₁, ..., head_h) · W^O
Where:
- W_i^Q ∈ ℝ^(d_model × d_k)
- W_i^K ∈ ℝ^(d_model × d_k)
- W_i^V ∈ ℝ^(d_model × d_v)
- W^O ∈ ℝ^(h·d_v × d_model)
⚠️ THIS IS CRITICAL — Each head projects Q, K, V into a different subspace before computing attention. This allows different heads to specialize in different types of relationships. The concatenation + output projection merges these diverse perspectives into a single output.
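A minimal NumPy sketch of the equations above, looping over heads explicitly so each line maps onto one formula. It assumes the row-vector convention used here (Q_i = Q · W_i^Q), d_k = d_v, no masking, and no bias terms; the function names and sizes are illustrative, not any library's API. Passing a second sequence as K and V shows the n_q ≠ n_k case.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o):
    """Multi-head attention with an explicit loop over heads.

    Q: (n_q, d_model); K, V: (n_k, d_model).
    W_q, W_k: lists of h arrays of shape (d_model, d_k).
    W_v: list of h arrays of shape (d_model, d_v); W_o: (h * d_v, d_model).
    No masking and no bias terms.
    """
    heads = []
    for Wq_i, Wk_i, Wv_i in zip(W_q, W_k, W_v):
        Q_i, K_i, V_i = Q @ Wq_i, K @ Wk_i, V @ Wv_i  # per-head projections
        d_k = Q_i.shape[-1]
        A_i = softmax(Q_i @ K_i.T / np.sqrt(d_k))      # (n_q, n_k); rows sum to 1
        heads.append(A_i @ V_i)                         # (n_q, d_v)
    return np.concatenate(heads, axis=-1) @ W_o         # (n_q, d_model)

# Illustrative sizes: d_model = 16 split into h = 4 heads with d_k = d_v = 4.
rng = np.random.default_rng(0)
d_model, h = 16, 4
d_k = d_model // h
W_q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_k = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_v = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_o = rng.standard_normal((h * d_k, d_model))

X = rng.standard_normal((5, d_model))  # self-attention: Q = K = V = X
Y = rng.standard_normal((9, d_model))  # another sequence used as K and V (n_k = 9)
print(multi_head_attention(X, X, X, W_q, W_k, W_v, W_o).shape)  # (5, 16)
print(multi_head_attention(X, Y, Y, W_q, W_k, W_v, W_o).shape)  # (5, 16)
```

The output always has n_q rows: K and V only decide what is attended to, not how many outputs there are.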
3. Standard Dimensions
In the original Transformer and most implementations:
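- d_model = 512 (original Transformer base), 768 (BERT base), or 12288 (GPT-3 175B)
- h = 8 (number of heads) in the original Transformer base; BERT base uses h = 12 and GPT-3 175B uses h = 96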
- d_k = d_v = d_model / h
So for d_model = 512, h = 8: each head works with d_k = d_v = 64.
The total output dimension before W^O is h·d_v = 8·64 = 512 = d_model.
4. Computational Cost Analysis
Let's compute FLOPs for multi-head attention with sequence length n, model dimension d, h heads, d_k = d_v = d/h.
Projections (Q, K, V for all heads):
- Each: (n × d) @ (d × h·d_k) = (n × d) @ (d × d), i.e. nd² multiply-adds = O(nd²)
- 3 projections: 3nd², plus W^O: nd²
- Total projections: 4nd² = O(nd²)
Attention per head:
- QKᵀ: (n × d_k) @ (d_k × n), i.e. n²·d_k = n²·d/h multiply-adds
- For h heads: h · n²·d/h = n²d = O(n²d)
Value-weighted sum per head:
- A·V: (n × n) @ (n × d_v), i.e. n²·d_v = n²·d/h multiply-adds
- For h heads: h · n²·d/h = n²d = O(n²d)
Total: 4nd² + 2n²d multiply-adds = O(n²d + nd²)
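This is the same cost as single-head attention with dimension d: none of the matrix-multiply counts above depends on h. Multi-head attention does not add matrix-multiply work; it reorganizes the computation. The only extra cost is lower-order: h softmaxes over n × n scores instead of one, and h attention matrices (h·n² entries) to store instead of one.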
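When does each term dominate?
- n < d (short sequences, large models): nd² dominates; projections are the bottleneck
- n > d (long sequences): n²d dominates; the attention matrix is the bottleneck
With the constants included (4nd² vs 2n²d), the two terms are equal at n = 2d, so the projections still dominate somewhat for d < n < 2d.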
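To check the claim that h does not change the matrix-multiply cost, this sketch counts multiply-accumulates term by term, summing the per-head attention work over h heads. The function name is hypothetical, and it ignores softmax, scaling, and biases.

```python
def attention_macs(n, d, h):
    """Multiply-accumulate counts for one multi-head self-attention layer.

    Assumes d_k = d_v = d // h with h dividing d; softmax, the 1/sqrt(d_k)
    scaling, and biases are ignored.
    """
    d_k = d // h
    projections = 4 * n * d * d                     # W^Q, W^K, W^V, W^O: (n x d) @ (d x d)
    scores = sum(n * n * d_k for _ in range(h))     # Q_i K_i^T, one per head
    weighted = sum(n * n * d_k for _ in range(h))   # A_i V_i, one per head
    return projections, scores + weighted

for h in (1, 8, 64):
    print(h, attention_macs(n=512, d=512, h=h))
# 1 (536870912, 268435456)
# 8 (536870912, 268435456)
# 64 (536870912, 268435456)
```

The totals are identical for h = 1, 8, and 64: splitting into heads shrinks each head's d_k by exactly the factor by which it multiplies the number of heads.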
5. Why Multiple Heads? The Subspace Argument
Consider a single-head attention with d_model = 512. The attention pattern is a single n×n matrix, and the value aggregation averages over all 512 dimensions with the same weights.
With h=8 heads of d_k=64 each:
- Head 1 might learn positional patterns (attend to nearby tokens)
- Head 2 might learn syntactic patterns (attend to the verb for a subject)
- Head 3 might learn semantic patterns (attend to related entities)
- etc.
Each head operates in a different 64-dimensional subspace of the full 512-dimensional space. The attention distribution can be DIFFERENT in each subspace — enabling the model to simultaneously track multiple types of relationships.
Mathematically, the attention weights of all heads together form an (h × n × n) tensor rather than a single n × n matrix. Each of the h slices is computed independently and is a separate distribution over keys for every query, which makes the collection richer than any single n × n pattern.
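As a small illustration (NumPy, random per-head queries and keys, sizes chosen arbitrarily), the weights of all heads can be held in one (h, n, n) array, where each head's rows are separate distributions over keys. The memory helper assumes float32 storage and is hypothetical.

```python
import numpy as np

def attention_weight_bytes(batch_size, h, n, bytes_per_element=4):
    # One n x n weight matrix per (batch item, head); 4 bytes assumes float32.
    return batch_size * h * n * n * bytes_per_element

rng = np.random.default_rng(0)
h, n, d_k = 8, 6, 4
Q = rng.standard_normal((h, n, d_k))      # per-head queries stacked on axis 0
K = rng.standard_normal((h, n, d_k))      # per-head keys
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
A = np.exp(scores - scores.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)        # softmax over keys, separately per head

print(A.shape)                                            # (8, 6, 6)
print(np.allclose(A.sum(axis=-1), 1.0))                   # True
print(attention_weight_bytes(batch_size=32, h=8, n=128))  # 16777216
```

The stored weights grow linearly with h and quadratically with n, even though the parameter count does not depend on either.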
6. Parameter Count
For each head i:
- W_i^Q: d_model × d_k = d × (d/h) = d²/h params
- W_i^K: d²/h params
- W_i^V: d²/h params
For h heads: 3h·(d²/h) = 3d² parameters
Plus W^O: h·d_v × d_model = d × d = d² parameters
Total: 4d² parameters (ignoring bias vectors; if each projection has a bias, they add another 4d).
A single head with dimension d would have W^Q, W^K, W^V, and W^O each d × d: also 4d². Multi-head attention repartitions the SAME parameter budget, so it is parameter-neutral.
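The count above can be checked by listing every weight matrix's shape. This is a sketch assuming d_k = d_v = d_model/h; the optional bias handling assumes one bias per projection output, as in a standard linear layer, which not every implementation uses. The function name is illustrative.

```python
def mha_parameter_count(d_model, h, bias=False):
    """Count multi-head attention parameters from the per-head weight shapes.

    Assumes d_k = d_v = d_model // h. With bias=True, each projection gets one
    bias entry per output unit (an assumption; some implementations omit biases).
    """
    d_k = d_model // h
    shapes = [(d_model, d_k)] * (3 * h)   # W_i^Q, W_i^K, W_i^V for every head
    shapes.append((h * d_k, d_model))     # W^O
    count = sum(rows * cols for rows, cols in shapes)
    if bias:
        count += 3 * h * d_k + d_model    # per-head Q/K/V biases plus the W^O bias
    return count

for h in (1, 2, 8, 64):
    print(h, mha_parameter_count(512, h))       # 1048576 for every h
print(mha_parameter_count(512, 8, bias=True))   # 1050624 = 4d^2 + 4d
```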
7. Efficient Implementation
In practice, the per-head projections are implemented as one big matrix multiply:
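Q_all = Q · W_all^Q where W_all^Q = [W_1^Q | W_2^Q | ... | W_h^Q] ∈ ℝ^(d × d) places the h per-head matrices side by side, so the i-th block of d_k columns of Q_all is exactly Q_i.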
Then the result is reshaped from (n, d) to (n, h, d_k) and transposed to (h, n, d_k) for batched attention computation:
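import math
import torch

n, d, h = 10, 512, 8                      # example sizes chosen for illustration
d_k = d // h
X = torch.randn(n, d)                     # input sequence (self-attention: Q = K = V = X)
W_Q, W_K, W_V, W_O = (torch.randn(d, d) / math.sqrt(d) for _ in range(4))
# Concatenated view: one (d x d) projection, then split into h heads
Q = (X @ W_Q).view(n, h, d_k).transpose(0, 1)      # (h, n, d_k)
K = (X @ W_K).view(n, h, d_k).transpose(0, 1)      # (h, n, d_k)
V = (X @ W_V).view(n, h, d_k).transpose(0, 1)      # (h, n, d_v) with d_v = d_k
# Then batched attention over all heads at once
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (h, n, n)
attn = torch.softmax(scores, dim=-1)
out = attn @ V                                     # (h, n, d_v)
out = out.transpose(0, 1).reshape(n, d)            # (n, d): concatenate heads
out = out @ W_O                                    # (n, d)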
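A quick NumPy check of the fused-projection trick (sizes arbitrary, weights random): concatenating the per-head matrices W_i^Q column-wise gives W_all^Q, and one matmul followed by the reshape and transpose yields exactly the per-head Q_i from h separate matmuls. Reversing the transpose and reshape recovers the fused (n, d) result, which is the "concatenate heads" step.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, h = 6, 32, 4
d_k = d // h
X = rng.standard_normal((n, d))
W_heads = [rng.standard_normal((d, d_k)) for _ in range(h)]  # W_1^Q ... W_h^Q

# Fused matrix: the per-head matrices placed side by side, shape (d, d).
W_all = np.concatenate(W_heads, axis=1)

# Fused path: one matmul, then (n, d) -> (n, h, d_k) -> (h, n, d_k).
Q_fused = (X @ W_all).reshape(n, h, d_k).transpose(1, 0, 2)

# Per-head path: h separate (n x d) @ (d x d_k) matmuls.
Q_loop = np.stack([X @ W for W in W_heads])                  # (h, n, d_k)

print(np.allclose(Q_fused, Q_loop))  # True

# Undo the split: (h, n, d_k) -> (n, h, d_k) -> (n, d).
merged = Q_fused.transpose(1, 0, 2).reshape(n, d)
print(np.allclose(merged, X @ W_all))  # True
```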
8. Attention Dropout
During training, dropout can be applied to the attention weights A (after softmax):
A_dropout = dropout(A, p)
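This randomly zeros out some attention connections, which discourages the model from relying on any single attention path: a form of regularization specific to attention. As with ordinary dropout, it is applied only during training; at inference the full attention weights are used.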
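A sketch of attention dropout in NumPy, assuming inverted dropout (the convention PyTorch's nn.Dropout uses): each weight is zeroed with probability p and survivors are scaled by 1/(1 − p) so expected values are unchanged. A side effect is that rows of A no longer sum exactly to 1 during training. The function name is hypothetical.

```python
import numpy as np

def attention_dropout(A, p, rng, training=True):
    """Inverted dropout applied to attention weights A (any shape, e.g. (h, n, n))."""
    if not training or p == 0.0:
        return A                                # inference: weights pass through unchanged
    keep = rng.random(A.shape) >= p             # keep each weight with probability 1 - p
    return A * keep / (1.0 - p)                 # rescale survivors to preserve expectations

rng = np.random.default_rng(0)
scores = rng.standard_normal((2, 4, 4))         # (h, n, n) scores for h = 2 heads
A = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # softmax over keys
A_train = attention_dropout(A, p=0.1, rng=rng)
A_eval = attention_dropout(A, p=0.1, rng=rng, training=False)
print(A.sum(axis=-1))        # each row sums to 1 (up to rounding)
print(A_train.sum(axis=-1))  # rows generally no longer sum to 1 during training
```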
Key Terms
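- Apply: run scaled dot-product attention, softmax(Q_i K_iᵀ/√d_k)V_i, separately on each head's projected Q_i, K_i, V_i
- Concatenate: place the h head outputs side by side, giving an (n × h·d_v) tensor
- Multi-head attention: h attention heads computed in parallel on separately projected Q, K, V, whose concatenated outputs are mapped back to d_model by W^O
- Project: multiply by a learned matrix (W_i^Q, W_i^K, W_i^V, or W^O) to map vectors into a head's d_k- or d_v-dimensional space, or back to d_model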
Worked Examples
Example 1: Dimensions Through Multi-Head
Problem: d_model=768, h=12 heads. Sequence length n=128. Compute the shapes of all intermediate tensors.
Solution: d_k = d_v = 768/12 = 64
Input: Q, K, V each ∈ ℝ^(128×768)
After per-head projection: Q_i ∈ ℝ^(128×64), K_i ∈ ℝ^(128×64), V_i ∈ ℝ^(128×64) for each head
Scores per head: S_i = Q_i·K_iᵀ/√64 ∈ ℝ^(128×128)
Attention weights: A_i = softmax(S_i) ∈ ℝ^(128×128)
Head output: head_i = A_i·V_i ∈ ℝ^(128×64)
Concatenated: Concat(head₁,...,head₁₂) ∈ ℝ^(128×768)
Final: out = Concat·W^O ∈ ℝ^(128×768)
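The shapes in this example can be traced with a short NumPy script. It stacks the per-head weights along a leading axis of size h (one possible layout, assumed here) so that matmul broadcasting computes all 12 heads at once; the random weights only serve to produce tensors of the right shape.

```python
import numpy as np

d_model, h, n = 768, 12, 128
d_k = d_v = d_model // h                       # 64
rng = np.random.default_rng(0)
X = rng.standard_normal((n, d_model))          # self-attention: Q = K = V = X

# Per-head weights stacked on a leading axis of size h (layout assumed for this sketch).
W_Q = rng.standard_normal((h, d_model, d_k)) / np.sqrt(d_model)
W_K = rng.standard_normal((h, d_model, d_k)) / np.sqrt(d_model)
W_V = rng.standard_normal((h, d_model, d_v)) / np.sqrt(d_model)
W_O = rng.standard_normal((h * d_v, d_model)) / np.sqrt(h * d_v)

Q_heads = X @ W_Q                              # (n, d_model) @ (h, d_model, d_k) -> (h, n, d_k)
K_heads = X @ W_K
V_heads = X @ W_V
S = Q_heads @ K_heads.transpose(0, 2, 1) / np.sqrt(d_k)  # (h, n, n) scores
A = np.exp(S - S.max(axis=-1, keepdims=True))
A /= A.sum(axis=-1, keepdims=True)                       # softmax over keys, per head
heads = A @ V_heads                                      # (h, n, d_v)
concat = heads.transpose(1, 0, 2).reshape(n, h * d_v)    # (n, h*d_v): heads side by side
out = concat @ W_O                                       # (n, d_model)

for name, t in [("Q_i (all heads)", Q_heads), ("S", S), ("A", A),
                ("head outputs", heads), ("Concat", concat), ("out", out)]:
    print(name, t.shape)
```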
Example 2: Parameter Count
Problem: Compute the number of parameters in the multi-head attention for d_model=512, h=8.
Solution: Per head: W_i^Q, W_i^K, W_i^V each 512×64 = 32,768 params
8 heads × 3 matrices × 32,768 = 786,432
W^O: (8×64) × 512 = 512 × 512 = 262,144
Total: 786,432 + 262,144 = 1,048,576 ≈ 1M parameters
Verification: 4d² = 4·512² = 4·262,144 = 1,048,576 ✓
Example 3: FLOPs Analysis
Problem: For n=2048, d=4096, h=32, compute approximate FLOPs and determine which term dominates.
Solution (counting multiply-adds, with d_k = d/h = 128):
Projections: 4nd² = 4 · 2048 · 4096² ≈ 4 · 34.4B ≈ 137.4B (Q, K, V, and W^O)
Attention scores: n²d = 2048² · 4096 ≈ 4.2M · 4096 ≈ 17.2B (summed over all 32 heads)
Value-weighted sum: n²d ≈ 17.2B
Total: ≈ 171.8B multiply-adds ≈ 343.6B FLOPs (2 FLOPs per multiply-add). The projection term dominates by a factor of 4nd² / (2n²d) = 2d/n = 4, because d > n; the two terms would only break even at n = 2d = 8192.
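The arithmetic in this example can be reproduced in a few lines. The sketch counts multiply-accumulates (MACs) for all four projections and both attention matmuls, then converts to FLOPs at 2 FLOPs per MAC; it ignores softmax, scaling, and biases, and the function name is illustrative.

```python
def attention_cost(n, d):
    """Multiply-accumulates (MACs) for one multi-head self-attention layer.

    With d_k = d_v = d/h the head count cancels, so h is not a parameter.
    Softmax, the 1/sqrt(d_k) scaling, and biases are ignored.
    """
    projections = 4 * n * d**2   # W^Q, W^K, W^V, W^O: each (n x d) @ (d x d)
    attention = 2 * n**2 * d     # Q_i K_i^T and A_i V_i, summed over all heads
    return projections, attention

proj, attn = attention_cost(n=2048, d=4096)
print(f"projections: {proj / 1e9:.1f}B MACs")                # 137.4B MACs
print(f"attention:   {attn / 1e9:.1f}B MACs")                # 34.4B MACs
print(f"total:       {2 * (proj + attn) / 1e9:.1f}B FLOPs")  # 343.6B FLOPs
print(f"projection / attention ratio: {proj / attn:.1f}")    # 4.0
```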
Quiz
Q1: What does multi-head attention compute?
A) h scaled dot-product attentions on separately projected Q, K, V, concatenated and multiplied by W^O B) One attention distribution over all d_model dimensions, repeated h times with identical weights C) A single attention whose output is h times wider than d_model D) h attentions that all share one set of projection matrices
Correct: A)
- If you chose A: Each head has its own W_i^Q, W_i^K, W_i^V; the head outputs are concatenated and projected back to d_model by W^O. Correct!
- If you chose B: This is incorrect. Each head uses different learned projections, so the h attention distributions can differ.
- If you chose C: This is incorrect. With d_v = d_model/h, the concatenation has width h·d_v = d_model, and W^O maps it to d_model.
- If you chose D: This is incorrect. Every head has its own projection matrices; if they shared them, all heads would compute the same thing.
Q2: With the row-vector convention Q_i = Q · W_i^Q, what are the shapes of one head's query projection W_i^Q and of the output projection W^O?
A) W_i^Q ∈ ℝ^(d_model × d_model), W^O ∈ ℝ^(d_model × d_model) B) W_i^Q ∈ ℝ^(d_k × d_model), W^O ∈ ℝ^(d_model × h·d_v) C) W_i^Q ∈ ℝ^(d_model × d_k), W^O ∈ ℝ^(d_v × d_model) D) W_i^Q ∈ ℝ^(d_model × d_k), W^O ∈ ℝ^(h·d_v × d_model)
Correct: D)
- If you chose A: This is incorrect. Each head projects to only d_k = d_model/h dimensions; d_model × d_model is the size of the fused matrix W_all^Q that covers all h heads.
- If you chose B: This is incorrect. These shapes are transposed: with Q_i = Q · W_i^Q, the input dimension d_model comes first.
- If you chose C: This is incorrect. W^O acts on the concatenation of all h heads, so its input dimension is h·d_v, not d_v.
- If you chose D: W_i^Q maps d_model to d_k, and W^O maps the h·d_v concatenated features back to d_model. Correct!
Q3: Which statement about applying attention within each head is TRUE?
A) Each head computes softmax(Q_i K_iᵀ/√d_k)V_i independently of the other heads B) The softmax is taken jointly over the scores of all h heads C) Heads are computed one after another, each using the previous head's output D) Each head scales its scores by 1/√d_model
Correct: A)
- If you chose A: The heads are independent: each applies scaled dot-product attention to its own Q_i, K_i, V_i. Correct!
- If you chose B: This is incorrect. The softmax is applied per head over the key positions, so each head's rows sum to 1 on their own.
- If you chose C: This is incorrect. The heads do not depend on each other and are computed in parallel, in practice as one batched matmul over an h axis.
- If you chose D: This is incorrect. The scaling uses the per-head key dimension d_k = d_model/h, not d_model.
Q4: Based on Example 2, how many parameters (ignoring biases) does multi-head attention have for d_model = 512 and h = 8?
A) 786,432 B) 262,144 C) 1,048,576 D) 8,388,608
Correct: C)
- If you chose A: This is incorrect. 786,432 counts only the Q, K, V projections and omits W^O (262,144).
- If you chose B: This is incorrect. 262,144 = 512² is a single d × d matrix, such as W^O alone.
- If you chose C: The Q, K, V projections (786,432) plus W^O (262,144) give 4d² = 4 · 512² = 1,048,576. Correct!
- If you chose D: This is incorrect. 8,388,608 = 8 · 4d² multiplies the budget by h, as if every head had full d × d matrices; each head's projections are only d × d/h.
Q5: How are the concatenation step and the output projection W^O related?
A) Concatenation averages the h head outputs into one (n × d_v) tensor B) Concatenation stacks the h head outputs into an (n × h·d_v) tensor, and W^O maps it back to d_model while mixing information across heads C) Concatenation alone produces the final output, so W^O can be skipped D) W^O is applied before the heads are computed
Correct: B)
- If you chose A: This is incorrect. Concatenation places the heads side by side (width h·d_v); it does not average them.
- If you chose B: W^O acts on the full concatenated tensor, so each output dimension can combine all heads. Correct!
- If you chose C: This is incorrect. Without W^O the heads are never mixed, and if h·d_v ≠ d_model the output shape is wrong.
- If you chose D: This is incorrect. W^O is the final step, applied after concatenation; the input projections are W_i^Q, W_i^K, W_i^V.
Q6: Which of the following is a common pitfall when comparing multi-head and single-head attention?
A) Choosing d_k = d_model/h so that h·d_k = d_model B) Applying W^O after concatenating the heads C) Computing the softmax separately for each head D) Assuming that h heads cost h times the compute of one head with dimension d_model
Correct: D)
- If you chose A: This is incorrect. This is standard practice and keeps the shapes aligned.
- If you chose B: This is incorrect. Applying W^O after concatenation is required, not a mistake.
- If you chose C: This is incorrect. A per-head softmax is exactly how multi-head attention is defined.
- If you chose D: Each head works in d_k = d/h dimensions, so all h heads together cost O(n²d), the same as one head with dimension d. Correct!
Q7: For sequence length n and model dimension d, when does the O(n²d) attention term dominate the O(nd²) projection term?
A) Whenever h > 1 B) When d is much larger than n C) Never; the projections always dominate D) When n is much larger than d (long sequences)
Correct: D)
- If you chose A: This is incorrect. The number of heads cancels out of both terms, because h heads of size d/h cost the same as one head of size d.
- If you chose B: This is incorrect. When d ≫ n, the nd² projection term dominates.
- If you chose C: This is incorrect. Once n is large enough relative to d, the n²d term overtakes the projections.
- If you chose D: The ratio n²d / nd² = n/d, so the attention term dominates for long sequences. Correct!
Practice Problems
Problem 1
What is d_k if d_model=1024 and h=16? Verify that h·d_k = d_model.
Answer
d_k = 1024/16 = 64. Verification: 16·64 = 1024 = d_model ✓.
Problem 2
Why doesn't multi-head attention increase the total parameter count compared to single-head attention with the same total dimension?
Answer
Single-head with d_model: W^Q, W^K, W^V each d×d = d², plus W^O d×d = d² → 4d² params. Multi-head: h copies of (d×d_k), where d_k = d/h. Total for Q projections: h·d·(d/h) = d². Same for K and V. Plus W^O = d². Total: 4d². The computation is the same; only the organization differs.
Problem 3
Explain why different heads can learn different attention patterns. What prevents all heads from converging to the same pattern?
Answer
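Each head has its OWN projection matrices W_i^Q, W_i^K, W_i^V, so each head projects into a different subspace, and attention is computed independently per head. Nothing in the architecture forces the heads apart; what breaks the symmetry is random initialization. If every head, including its block of W^O, started with identical weights, the heads would receive identical gradients and remain copies of each other. Starting from different random weights, and because different subspaces can capture different aspects of the data, heads tend to specialize. In practice, some redundancy is still observed: not all heads are unique.
Problem 4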
How many attention weight matrices (A_i) are produced by a multi-head attention layer with h=8, batch_size=32, n=128?
Answer
8 heads × 32 batch items = 256 attention matrices, each 128×128. Total memory: 256 · 128² · 4 bytes (float32) = 16,777,216 bytes ≈ 16.8 MB for one layer's attention weights alone.
Problem 5
In multi-head cross-attention, where do Q, K, and V come from?
Answer
Q comes from the decoder's hidden states (or previous layer output). K and V come from the encoder's output. Each head projects these separately, allowing different heads to attend to different aspects of the encoded source sequence.
Summary
- Multi-head attention runs h parallel attention operations with learned projections: head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
- Output = Concat(head₁,...,head_h)·W^O merges the h per-head outputs into d_model dimensions
- Computational cost is O(n²d + nd²) — asymptotically identical to single-head with same total dimension
- Different heads learn to attend to different representation subspaces, enabling the model to jointly track multiple relationship types
- Parameter count is exactly 4d² (ignoring biases, with d_k = d_v = d/h) regardless of the number of heads: the parameter budget is repartitioned, not increased
Pitfalls
- Thinking multi-head attention costs h× more compute than single-head attention. Each head operates on a reduced dimension d_k = d/h, so the per-head QKᵀ cost is O(n²d/h). Summed over h heads: O(n²d) — exactly the same asymptotic cost as single-head with dimension d. Multi-head repartitions the same FLOP budget across independent subspaces.
- Forgetting the W^O output projection. The outputs of the heads are concatenated into an (n × h·d_v) tensor, which has not yet been mapped back to d_model. If h·d_v ≠ d_model, the output dimension is wrong and the residual connection (in Transformers) breaks; even when the shapes happen to match, skipping W^O means each output dimension depends on only one head, so the heads are never mixed. Always apply the final linear projection after concatenation.
- Using d_k ≠ d_model/h without adjusting the rest of the layer. W^O ∈ ℝ^(h·d_v × d_model) can in principle map any concatenated width back to d_model, but the fused implementation (reshaping an (n, d_model) projection into (n, h, d_k)) and the cost and parameter analysis above assume h·d_k = d_model. If h does not divide d_model evenly, that reshape fails. Standard practice: ensure h·d_k = d_model.
- Getting the tensor reshape/transpose order wrong. The standard pattern is: (batch, seq, d_model) → project to (batch, seq, d_model) → reshape to (batch, seq, h, d_k) → transpose to (batch, h, seq, d_k). Skipping the transpose and reshaping directly to (batch, h, seq, d_k), or merging heads back with a reshape before transposing, produces silently incorrect attention: the code runs and the model may still train, but each "head" mixes features from different sequence positions.
- Assuming every head learns a unique, interpretable pattern. In practice, many heads are redundant or degenerate — some learn nearly identical patterns, some attend uniformly to all positions, some attend exclusively to [CLS] or [SEP] tokens. This is normal. Pruning redundant heads is an active research area; don't expect all h heads to be maximally diverse.
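A small NumPy demonstration of the reshape pitfall, using integer features so equality checks are exact (sizes are arbitrary). One common form of the mistake is reshaping (batch, seq, d_model) straight to (batch, h, seq, d_k) without the transpose: the shape is right, but the contents are not.

```python
import numpy as np

batch, seq, d_model, h = 2, 5, 12, 3
d_k = d_model // h
x = np.arange(batch * seq * d_model).reshape(batch, seq, d_model)  # stands in for X @ W_Q

# Correct: split the feature axis into heads, then move heads before the sequence axis.
good = x.reshape(batch, seq, h, d_k).transpose(0, 2, 1, 3)  # (batch, h, seq, d_k)

# Buggy: same target shape, but reshaping without the transpose reinterprets memory,
# so a "head" row is filled with features taken from different tokens.
bad = x.reshape(batch, h, seq, d_k)                           # (batch, h, seq, d_k)

# Reference: head i at position t should hold token t's features i*d_k:(i+1)*d_k.
expected = np.stack([x[:, :, i * d_k:(i + 1) * d_k] for i in range(h)], axis=1)
print(np.array_equal(good, expected))  # True
print(np.array_equal(bad, expected))   # False

# Merging heads back must undo the steps in reverse order: transpose, then reshape.
merged = good.transpose(0, 2, 1, 3).reshape(batch, seq, d_model)
print(np.array_equal(merged, x))       # True
```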
Next Steps
Continue to 17-09 — Transformer Architecture to see how multi-head attention, feedforward networks, and residual connections are combined into the full Transformer.