
17-08 — Multi-Head Attention

Phase: 17, Deep Learning Architectures (Math)
Subject: 17-08
Prerequisites: 17-07 (Scaled Dot-Product Attention), 17-06 (Attention Mechanism), 8-4 (Matrix multiplication; comfort with projections)
Next subject: 17-09, Transformer Architecture


Learning Objectives

By the end of this subject, you will be able to:

  1. Derive the full multi-head attention equation from single-head scaled dot-product attention
  2. Explain the projection matrices W^Q, W^K, W^V, W^O and their shapes
  3. Quantify the computational cost: O(n²d + nd²) and explain when each term dominates
  4. Prove that multi-head attention with h heads does NOT increase total FLOPs vs single-head (for same total dimension)
  5. Explain why multiple heads are necessary — the "representation subspace" argument

Core Content

1. The Limitation of Single-Head Attention

Single-head attention computes:

Attention(Q,K,V) = softmax(QKᵀ/√d_k)V

This produces ONE attention distribution per query position. But sequences have multiple types of relationships:

  - Syntactic dependencies (verb-subject agreement)
  - Semantic relationships (pronoun-antecedent)
  - Positional relationships (nearby vs faraway tokens)
  - Content-based groupings (noun phrases)

A single attention distribution can't simultaneously model all of these. Multi-head attention addresses this by running attention multiple times in parallel with different learned projections.
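To make the single-head formula concrete, here is a minimal pure-Python sketch of scaled dot-product attention. The helper names (`softmax`, `matmul`, `attention`) and the toy values are illustrative assumptions, not part of the lesson; a real implementation would use a tensor library.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def attention(Q, K, V):
    """Compute softmax(Q K^T / sqrt(d_k)) V and return (output, weights)."""
    d_k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, K_T)]
    weights = [softmax(row) for row in scores]
    return matmul(weights, V), weights

# Toy input (hypothetical values): 3 positions, d_k = d_v = 2.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

out, weights = attention(Q, K, V)
# weights is 3x3: exactly one distribution (one row) per query position.
```

Each row of `weights` sums to 1, which is the single distribution per query that the section describes.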

2. Multi-Head Attention Equation

Instead of one attention function on d_model-dimensional Q, K, V, we:

  1. Project Q, K, V h times with different learned weight matrices
  2. Apply scaled dot-product attention to each projected set
  3. Concatenate the h outputs
  4. Project back to d_model

Formally:

head_i = Attention(Q_i, K_i, V_i) = softmax(Q_i K_iᵀ / √d_k) V_i

where:

Q_i = Q · W_i^Q   (∈ ℝ^(n_q × d_k))
K_i = K · W_i^K   (∈ ℝ^(n_k × d_k))
V_i = V · W_i^V   (∈ ℝ^(n_k × d_v))

And the final output:

MultiHead(Q, K, V) = Concat(head₁, ..., head_h) · W^O

Where:

  - W_i^Q ∈ ℝ^(d_model × d_k)
  - W_i^K ∈ ℝ^(d_model × d_k)
  - W_i^V ∈ ℝ^(d_model × d_v)
  - W^O ∈ ℝ^(h·d_v × d_model)

⚠️ THIS IS CRITICAL — Each head projects Q, K, V into a different subspace before computing attention. This allows different heads to specialize in different types of relationships. The concatenation + output projection merges these diverse perspectives into a single output.
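The four steps (project, apply, concatenate, project back) can be sketched head by head in plain Python. This is an unoptimized illustration under stated assumptions: `rand_matrix` is a hypothetical initializer, the sizes are toy values, and n_q differs from n_k only to show that the shapes in the equations above still line up.

```python
import math
import random

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention for one head."""
    d_k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in matmul(Q, K_T)]
    return matmul(weights, V)

def rand_matrix(rows, cols, rng):
    """Hypothetical initializer: small Gaussian weights."""
    return [[rng.gauss(0.0, 1.0 / math.sqrt(rows)) for _ in range(cols)] for _ in range(rows)]

def multi_head(Q, K, V, W_Q, W_K, W_V, W_O):
    """W_Q, W_K, W_V are lists of h per-head matrices; W_O is (h*d_v, d_model)."""
    # Steps 1 and 2: project, then apply attention, once per head.
    heads = [
        attention(matmul(Q, wq), matmul(K, wk), matmul(V, wv))
        for wq, wk, wv in zip(W_Q, W_K, W_V)
    ]
    # Step 3: concatenate along the feature axis, row by row.
    concat = [sum((head[r] for head in heads), []) for r in range(len(Q))]
    # Step 4: project back to d_model.
    return matmul(concat, W_O)

rng = random.Random(0)
d_model, h = 8, 2
d_k = d_v = d_model // h
n_q, n_k = 3, 5

Q = rand_matrix(n_q, d_model, rng)
K = rand_matrix(n_k, d_model, rng)
V = rand_matrix(n_k, d_model, rng)
W_Q = [rand_matrix(d_model, d_k, rng) for _ in range(h)]
W_K = [rand_matrix(d_model, d_k, rng) for _ in range(h)]
W_V = [rand_matrix(d_model, d_v, rng) for _ in range(h)]
W_O = rand_matrix(h * d_v, d_model, rng)

out = multi_head(Q, K, V, W_Q, W_K, W_V, W_O)  # shape (n_q, d_model)
```

The explicit loop over heads is chosen for readability; it mirrors the equations one to one rather than aiming for speed.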

3. Standard Dimensions

In the original Transformer and most implementations, the model dimension is split evenly across the heads:

d_k = d_v = d_model / h

So for d_model = 512, h = 8: each head works with d_k = d_v = 64.

The total output dimension before W^O is h·d_v = 8·64 = 512 = d_model.
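As a quick check of this convention, the sketch below computes the per-head dimension for a few configurations used in this subject. The function name `head_dim` is hypothetical, and rejecting non-divisible pairs is a design choice of this sketch rather than something the lesson prescribes.

```python
def head_dim(d_model: int, h: int) -> int:
    """Return d_k = d_v = d_model / h, rejecting non-divisible configurations."""
    if d_model % h != 0:
        raise ValueError(f"d_model={d_model} is not divisible by h={h}")
    return d_model // h

for d_model, h in [(512, 8), (768, 12), (1024, 16)]:
    d_k = head_dim(d_model, h)
    print(f"d_model={d_model}, h={h} -> d_k=d_v={d_k}, h*d_v={h * d_k}")
```

All three configurations give d_k = d_v = 64, and h·d_v recovers d_model in each case.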

4. Computational Cost Analysis

Let's compute FLOPs for multi-head attention with sequence length n, model dimension d, h heads, d_k = d_v = d/h.

Projections (Q, K, V for all heads):
  Each: n × d @ d × (h·d_k) = n × d @ d × d = O(nd²)
  3 projections: 3·O(nd²)
  Plus W^O: O(nd²)
  Total projections: O(nd²)

Attention scores per head:
  QKᵀ: n × d_k @ d_k × n = O(n²·d_k) = O(n²·d/h)
  For h heads: h · O(n²·d/h) = O(n²d)

Value-weighted sum per head:
  A·V: n × n @ n × d_v = O(n²·d_v) = O(n²·d/h)
  For h heads: h · O(n²·d/h) = O(n²d)

Total: O(n²d + nd²)

This is the same asymptotic cost as single-head attention with dimension d! Multi-head doesn't add computational overhead — it simply reorganizes the computation.

When does each term dominate?

  - n < d (short sequences, large models): nd² dominates, so the projections are the bottleneck
  - n > d (long sequences): n²d dominates, so the attention matrix is the bottleneck
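The counts in this section can be turned into a small calculator. The sketch below counts multiply-accumulates and keeps the constant factors that the O(·) notation drops (four projection matrices, two attention matmuls); it ignores softmax and biases. With those constants the two terms balance at n = 2d, which agrees with the n-versus-d rule above up to a constant factor. The function name `attention_macs` is an assumption made for illustration.

```python
def attention_macs(n: int, d: int) -> dict:
    """Multiply-accumulate counts for one self-attention layer.

    The head count h cancels out: h heads each cost n * n * (d / h).
    Softmax and biases are ignored.
    """
    projections = 4 * n * d * d   # W^Q, W^K, W^V, W^O: each (n x d) @ (d x d)
    scores = n * n * d            # Q K^T, summed over all heads
    weighted_sum = n * n * d      # A V, summed over all heads
    return {
        "projections": projections,
        "attention": scores + weighted_sum,
        "total": projections + scores + weighted_sum,
    }

for n, d in [(128, 768), (2048, 4096), (16384, 4096)]:
    c = attention_macs(n, d)
    larger = "projections" if c["projections"] > c["attention"] else "attention"
    print(f"n={n:>6}, d={d}: projections={c['projections']:.3e}, "
          f"attention={c['attention']:.3e}, larger term: {larger}")
```

Only the last configuration (n = 16384, d = 4096) has n > 2d, so it is the only one where the attention term is the larger of the two.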

5. Why Multiple Heads? The Subspace Argument

Consider a single-head attention with d_model = 512. The attention pattern is a single n×n matrix, and the value aggregation averages over all 512 dimensions with the same weights.

With h=8 heads of d_k=64 each:

  - Head 1 might learn positional patterns (attend to nearby tokens)
  - Head 2 might learn syntactic patterns (attend to verb for subject)
  - Head 3 might learn semantic patterns (attend to related entities)
  - and so on for the remaining heads

Each head operates in a different 64-dimensional subspace of the full 512-dimensional space. The attention distribution can be DIFFERENT in each subspace — enabling the model to simultaneously track multiple types of relationships.

Mathematically: The joint attention distribution across all heads is a richer object than a single n×n matrix. It's an (h × n × n) tensor, though the attention is computed independently per head.

6. Parameter Count

For each head i:

  - W_i^Q: d_model × d_k = d × (d/h) = d²/h params
  - W_i^K: d²/h params
  - W_i^V: d²/h params

For h heads: 3h·(d²/h) = 3d² parameters

Plus W^O: h·d_v × d_model = d × d = d² parameters

Total: 4d² parameters (ignoring biases).

Single-head attention with dimension d has W^Q, W^K, W^V each d×d plus W^O d×d: also 4d². Multi-head attention is therefore parameter-neutral: it repartitions the SAME parameter budget.
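The parameter-neutrality claim is easy to check numerically. The sketch below assumes d_k = d_v = d_model / h and no bias terms; `mha_params` is a hypothetical helper name.

```python
def mha_params(d_model: int, h: int) -> int:
    """Weight count for multi-head attention with d_k = d_v = d_model / h (no biases)."""
    d_k = d_v = d_model // h
    per_head = 3 * d_model * d_k   # W_i^Q, W_i^K, W_i^V
    w_o = (h * d_v) * d_model      # output projection W^O
    return h * per_head + w_o

d_model = 512
for h in [1, 2, 4, 8, 16]:
    # The count is 4 * d_model^2 for every head count that divides d_model.
    assert mha_params(d_model, h) == 4 * d_model ** 2

print(mha_params(512, 8))  # 1048576
```

Changing h only changes how the 4d² weights are sliced into heads, not how many there are.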

7. Efficient Implementation

In practice, the per-head projections are implemented as one big matrix multiply:

Q_all = Q · W_all^Q where W_all^Q ∈ ℝ^(d × d)

Then the result is reshaped from (n, d) to (n, h, d_k) and transposed to (h, n, d_k) for batched attention computation:

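import math
import torch

n, d, h = 6, 16, 4  # illustrative sizes (assumed, not from the lesson)
d_k = d_v = d // h
X = torch.randn(n, d)
W_Q, W_K, W_V, W_O = (torch.randn(d, d) / math.sqrt(d) for _ in range(4))

# Fused projections, then split the feature axis into heads
Q = (X @ W_Q).view(n, h, d_k).transpose(0, 1)  # (h, n, d_k)
K = (X @ W_K).view(n, h, d_k).transpose(0, 1)  # (h, n, d_k)
V = (X @ W_V).view(n, h, d_v).transpose(0, 1)  # (h, n, d_v)

# Batched attention over the leading head axis
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (h, n, n)
attn = torch.softmax(scores, dim=-1)
out = attn @ V  # (h, n, d_v)
out = out.transpose(0, 1).reshape(n, d)  # (n, d), heads concatenated
out = out @ W_O  # (n, d)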
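Why is one big multiply followed by a reshape equivalent to h separate projections? Because head i's matrix W_i^Q is simply a block of d_k adjacent columns of W_all^Q. The stdlib-only sketch below checks this on random toy data; the sizes and variable names are assumptions chosen for illustration.

```python
import random

def matmul(A, B):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

rng = random.Random(0)
n, d, h = 4, 6, 3
d_k = d // h

X = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
W_all = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(d)]  # (d, d)

# Per-head matrix W_i = columns [i*d_k, (i+1)*d_k) of W_all.
W_heads = [[row[i * d_k:(i + 1) * d_k] for row in W_all] for i in range(h)]

# One big multiply, shape (n, d).
Q_all = matmul(X, W_all)

# Reshape (n, d) -> (n, h, d_k) and transpose -> (h, n, d_k), written as slicing.
Q_split = [[row[i * d_k:(i + 1) * d_k] for row in Q_all] for i in range(h)]

# The same projections computed head by head.
Q_per_head = [matmul(X, W_i) for W_i in W_heads]

max_diff = max(
    abs(a - b)
    for fused, separate in zip(Q_split, Q_per_head)
    for row_f, row_s in zip(fused, separate)
    for a, b in zip(row_f, row_s)
)
print(max_diff)
```

The fused form is preferred in practice because a single large matrix multiply generally uses hardware better than h small ones.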

8. Attention Dropout

During training, dropout can be applied to the attention weights A (after softmax):

A_dropout = dropout(A, p)

This randomly zeros out some attention connections, forcing the model to not rely on any single attention path — a form of regularization specific to attention.
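A minimal sketch of attention dropout in plain Python follows. It assumes the common "inverted dropout" convention, where surviving weights are scaled by 1/(1-p) so their expected value is unchanged; the lesson itself only states that entries are zeroed, so the scaling is an assumption of this example, as is the function name `attention_dropout`.

```python
import random

def attention_dropout(A, p, rng):
    """Zero each attention weight with probability p; scale survivors by 1/(1-p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep = 1.0 - p
    return [[a / keep if rng.random() >= p else 0.0 for a in row] for row in A]

rng = random.Random(0)
A = [[0.7, 0.2, 0.1],
     [0.1, 0.8, 0.1],
     [0.3, 0.3, 0.4]]

A_train = attention_dropout(A, p=0.1, rng=rng)  # training: entries may be zeroed
A_eval = attention_dropout(A, p=0.0, rng=rng)   # p = 0 leaves A unchanged
```

Note that after dropout the rows of A generally no longer sum to 1; the mask is applied after softmax and the rows are not renormalized.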



Key Terms

Worked Examples

Example 1: Dimensions Through Multi-Head

Problem: d_model=768, h=12 heads. Sequence length n=128. Compute the shapes of all intermediate tensors.

Solution: d_k = d_v = 768/12 = 64

Input: Q, K, V each ∈ ℝ^(128×768)

After per-head projection: Q_i ∈ ℝ^(128×64), K_i ∈ ℝ^(128×64), V_i ∈ ℝ^(128×64) for each head

Scores per head: S_i = Q_i·K_iᵀ/√64 ∈ ℝ^(128×128)

Attention weights: A_i = softmax(S_i) ∈ ℝ^(128×128)

Head output: head_i = A_i·V_i ∈ ℝ^(128×64)

Concatenated: Concat(head₁,...,head₁₂) ∈ ℝ^(128×768)

Final: out = Concat·W^O ∈ ℝ^(128×768)
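The same bookkeeping can be automated for any configuration. The sketch below assumes self-attention (n_q = n_k = n) and d_k = d_v = d_model / h; `mha_shapes` is a hypothetical helper, and its dictionary keys simply reuse the tensor names from this example.

```python
def mha_shapes(n: int, d_model: int, h: int) -> dict:
    """Shapes of the intermediate tensors in multi-head self-attention."""
    d_k = d_model // h
    return {
        "Q, K, V": (n, d_model),
        "Q_i, K_i, V_i": (n, d_k),
        "S_i, A_i": (n, n),
        "head_i": (n, d_k),
        "Concat": (n, h * d_k),
        "out": (n, d_model),
    }

shapes = mha_shapes(n=128, d_model=768, h=12)
for name, shape in shapes.items():
    print(f"{name:<14} {shape}")
```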

Example 2: Parameter Count

Problem: Compute the number of parameters in the multi-head attention for d_model=512, h=8.

Solution: Per head: W_i^Q, W_i^K, W_i^V each 512×64 = 32,768 params

All heads: 8 heads × 3 matrices × 32,768 = 786,432

W^O: (8×64) × 512 = 512 × 512 = 262,144

Total: 786,432 + 262,144 = 1,048,576 ≈ 1M parameters

Verification: 4d² = 4·512² = 4·262,144 = 1,048,576 ✓

Example 3: FLOPs Analysis

Problem: For n=2048, d=4096, h=32, compute approximate FLOPs and determine which term dominates.

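Solution: Count multiply-accumulates (MACs), keeping the constant factors that the O(·) notation hides.

Projections: 4 matrices (W^Q, W^K, W^V, W^O), each n·d² = 2048 · 4096² ≈ 2048 · 16.8M ≈ 34.4B, so 4 · 34.4B ≈ 137.4B MACs

Attention scores (all heads): n²·d = 2048² · 4096 ≈ 4.2M · 4096 ≈ 17.2B MACs

Value-weighted sum (all heads): n²·d ≈ 17.2B MACs

Total: ≈ 171.8B MACs. The projection term (137.4B) is about 4× the attention term (34.4B) because d > n. With these constants the two terms balance at n = 2d (4nd² = 2n²d), so at d = 4096 the attention term takes over only for sequences longer than about 8192 tokens.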


Quiz

Q1: What does the concept of Multi-head attention primarily refer to in this subject?

A) The definition and application of Multi-head attention
B) A visual representation of Multi-head attention
C) A historical anecdote about Multi-head attention
D) A computational error related to Multi-head attention

Correct: A)

Q2: What is the primary purpose of the Project step in multi-head attention?

A) It replaces all other steps of the attention computation
B) It is primarily a historical notation system
C) It is used only in advanced research contexts
D) It maps Q, K, and V into a different learned subspace for each head

Correct: D)

Q3: Which statement about the Apply step (scaled dot-product attention on each projected set) is TRUE?

A) It is a fundamental concept covered in this subject
B) It is mentioned only as a historical footnote
C) It is an advanced topic beyond this subject's scope
D) It is not related to this subject

Correct: A)

Q4: Based on Worked Example 2 (d_model=512, h=8), what is the total parameter count of the multi-head attention layer?

A) 64
B) 262,144
C) 1,048,576
D) 786,432

Correct: C)

Q5: How are the Apply and Concatenate steps related?

A) Apply is the inverse of Concatenate
B) They are closely related: Concatenate joins the h head outputs that Apply produces
C) Apply is a special case of Concatenate
D) They are completely unrelated

Correct: B)

Q6: What is a common pitfall when working with the limitation of single-head attention?

A) The limitation of single-head attention is always computed the same way in all contexts
B) The limitation of single-head attention has no common misconceptions
C) The main error with the limitation of single-head attention is using it when it is not needed
D) A common mistake is confusing the limitation of single-head attention with a similar concept

Correct: D)

Q7: When should you apply the multi-head attention equation?

A) Avoid the multi-head attention equation unless explicitly instructed
B) Use the multi-head attention equation only in pure mathematics contexts
C) The multi-head attention equation is not practically useful
D) Apply the multi-head attention equation to solve problems in this subject's domain

Correct: D)

Practice Problems

Problem 1

What is d_k if d_model=1024 and h=16? Verify that h·d_k = d_model.

Answer: d_k = 1024/16 = 64. Verification: 16·64 = 1024 = d_model ✓.

Problem 2

Why doesn't multi-head attention increase the total parameter count compared to single-head attention with the same total dimension?

Answer: Single-head with d_model: W^Q, W^K, W^V each d×d = d², plus W^O d×d = d² → 4d² params. Multi-head: h copies of (d×d_k), where d_k = d/h. Total for Q projections: h·d·(d/h) = d². Same for K and V. Plus W^O = d². Total: 4d². The parameter budget is the same; only its organization into heads differs.

Problem 3

Explain why different heads can learn different attention patterns. What prevents all heads from converging to the same pattern?

Answer: Each head has its OWN projection matrices W_i^Q, W_i^K, W_i^V, so each head projects into a different subspace, and attention is computed independently per head. Nothing fundamentally prevents convergence, but random initialization gives each head a different starting point, and different subspaces capture different aspects of the data, so heads tend to specialize. In practice, some redundancy is observed: not all heads are unique.

Problem 4

How many attention weight matrices (A_i) are produced by a multi-head attention layer with h=8, batch_size=32, n=128?

Answer: 8 heads × 32 batch items = 256 attention matrices, each 128×128. Total memory: 256 · 128² · 4 bytes (float32) ≈ 16.8 MB for attention weights alone.
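The memory estimate generalizes to other settings with a one-line formula. The sketch below is a rough calculator for a single layer's attention weights only; `attention_weight_bytes` is a hypothetical helper, and float32 storage (4 bytes per value) is assumed by default.

```python
def attention_weight_bytes(batch_size: int, h: int, n: int, bytes_per_value: int = 4) -> int:
    """Bytes needed to store every attention matrix A_i for one layer."""
    num_matrices = batch_size * h
    return num_matrices * n * n * bytes_per_value

total = attention_weight_bytes(batch_size=32, h=8, n=128)
print(total)            # 16777216 bytes
print(total / 1e6)      # size in MB (decimal)
print(total / 2 ** 20)  # size in MiB (binary)
```

Because the cost grows with n², doubling the sequence length to 256 quadruples this figure.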

Problem 5

In multi-head cross-attention, where do Q, K, and V come from?

Answer: Q comes from the decoder's hidden states (or previous layer output). K and V come from the encoder's output. Each head projects these separately, allowing different heads to attend to different aspects of the encoded source sequence.

Summary

  1. Multi-head attention runs h parallel attention operations with learned projections: head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
  2. Output = Concat(head₁,...,head_h)·W^O merges the h per-head outputs into d_model dimensions
  3. Computational cost is O(n²d + nd²) — asymptotically identical to single-head with same total dimension
  4. Different heads learn to attend to different representation subspaces, enabling the model to jointly track multiple relationship types
  5. Parameter count is exactly 4d² regardless of number of heads — the parameter budget is repartitioned, not increased

Pitfalls



Next Steps

Continue to 17-09 — Transformer Architecture to see how multi-head attention, feedforward networks, and residual connections are combined into the full Transformer.