17-07 — Scaled Dot-Product Attention
Phase: 17, Deep Learning Architectures (Math)
Subject: 17-07
Prerequisites: 17-06 (Attention Mechanism), 16-03 (Softmax Function), 11-10 (Information Theory; entropy concepts helpful)
Next subject: 17-08, Multi-Head Attention
Learning Objectives
By the end of this subject, you will be able to:
- Derive the variance argument for why scaling by √d_k is mathematically necessary
- Compute the full scaled dot-product attention forward pass: Attention(Q,K,V) = softmax(QKᵀ/√d_k)V
- Explain the shape conventions for Q, K, V in both self-attention and cross-attention
- Derive and apply causal (autoregressive) masking with the mask matrix
- Analyze the softmax saturation problem when d_k is large (without scaling) and prove scaling fixes it
Core Content
1. The Problem with Unscaled Dot-Product Attention
Recall from 17-06: dot-product attention computes scores as S = QKᵀ. When the key dimension d_k is large, the dot products grow in magnitude.
The variance argument:
Assume the components of q and k are independent random variables with mean 0 and variance 1 (reasonable at initialization).
qᵀk = Σ_{i=1}^{d_k} q_i · k_i
Each term q_i·k_i has mean 0 and variance E[q_i²]·E[k_i²] = 1·1 = 1 (by independence).
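Because the d_k terms are independent, their variances add: Var(qᵀk) = d_k. By the Central Limit Theorem, for large d_k the distribution is approximately N(0, d_k).
So the raw scores S have entries with typical magnitude on the order of √d_k. Large d_k → large-magnitude scores → the softmax enters its saturation region.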
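The variance argument can be checked empirically. The following is a minimal Monte Carlo sketch in plain Python; the function name `dot_product_variance`, the sample count, and the seed are illustrative choices, not part of the original material.

```python
import random
import statistics

def dot_product_variance(d_k, n_samples=5000, seed=0):
    """Estimate Var(q.k) for q, k with i.i.d. N(0, 1) components."""
    rng = random.Random(seed)
    dots = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pvariance(dots)

for d_k in (4, 16, 64):
    raw = dot_product_variance(d_k)
    # Dividing every score by sqrt(d_k) divides the variance by d_k.
    print(f"d_k={d_k:3d}  Var(q.k) ~ {raw:6.2f}  Var(q.k / sqrt(d_k)) ~ {raw / d_k:4.2f}")
```

The estimated raw variance should track d_k, while the rescaled variance should stay close to 1 for every d_k, up to sampling noise.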
2. Why Softmax Saturation is Bad
Consider softmax(s) where some s_i are very large and others very small:
softmax(s)_i = e^{s_i} / Σ e^{s_j}
If the largest score exceeds all the others by a wide margin (say 10 or more):
- The largest softmax output is ≈ 1.0
- All other outputs are ≈ 0.0
This means:
1. Vanishing gradients: ∂softmax_i/∂s_i = softmax_i·(1 − softmax_i), which is ≈ 0 both for the near-zero outputs and for the near-one output, so almost no learning signal passes through the softmax
2. Hard attention: The model effectively does argmax, losing the benefits of soft attention
3. Training instability: Small changes in keys can flip which position gets attention, causing erratic gradients
⚠️ THIS IS CRITICAL — Without scaling, attention would be near one-hot in high dimensions. The model would attend exclusively to one position, and gradients for all other positions would vanish. The √d_k scaling keeps the scores in a regime where the softmax produces a meaningful distribution.
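To make the saturation effect concrete, here is a small sketch comparing a moderate and an extreme score vector. The score values and helper names are illustrative assumptions; the Jacobian diagonal uses the standard identity ∂p_i/∂s_i = p_i·(1 − p_i).

```python
import math

def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian_diag(scores):
    """Diagonal of the softmax Jacobian: d p_i / d s_i = p_i * (1 - p_i)."""
    return [p * (1.0 - p) for p in softmax(scores)]

moderate = [0.0, 1.0, 2.0]    # spread of 2: soft distribution
extreme = [0.0, 10.0, 20.0]   # spread of 20: saturated

for name, s in (("moderate", moderate), ("extreme", extreme)):
    probs = [f"{p:.2e}" for p in softmax(s)]
    grads = [f"{g:.2e}" for g in softmax_jacobian_diag(s)]
    print(f"{name:>8}: probs={probs}  dp_i/ds_i={grads}")
```

In the extreme case every diagonal entry is tiny, including the one for the winning position, which is why saturation starves the whole row of gradient.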
3. The Solution: Scale by √d_k
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
After scaling: (qᵀk)/√d_k has variance 1 (dividing by √d_k cancels the d_k growth).
The scaled scores have mean 0, variance 1 — they stay in the "linear" regime of the softmax, producing a balanced attention distribution.
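Geometric intuition: qᵀk = ||q||·||k||·cos(θ). For standard normal components, ||q|| ≈ ||k|| ≈ √d_k, so qᵀk ≈ d_k·cos(θ). Random vectors in high dimensions are nearly orthogonal: cos(θ) has mean 0 and standard deviation ≈ 1/√d_k. The raw score d_k·cos(θ) therefore has typical magnitude √d_k, and the scaled score √d_k·cos(θ) has typical magnitude 1, in agreement with the variance argument. After training, the learned projections (as in Multi-Head Attention in 17-08) no longer satisfy the i.i.d. assumption exactly, so the variance-1 property holds only approximately, but scaling by √d_k remains the standard choice.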
4. Full Forward Pass (Matrix Operations)
Given:
- Q ∈ ℝ^(n_q × d_k): n_q query vectors
- K ∈ ℝ^(n_k × d_k): n_k key vectors
- V ∈ ℝ^(n_k × d_v): n_k value vectors
Step 1: Compute scaled scores
S = QKᵀ / √d_k ∈ ℝ^(n_q × n_k)
Step 2: (Optional) Apply mask
S_masked = S + M where M[i,j] = −∞ for positions to mask, 0 otherwise
Step 3: Softmax (row-wise)
A[i,:] = softmax(S_masked[i,:]) for each row i, giving A ∈ ℝ^(n_q × n_k)
Step 4: Weighted sum
O = AV ∈ ℝ^(n_q × d_v)
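The four steps above can be sketched in plain Python with nested lists. This is a dependency-free illustration rather than an efficient implementation (a real model would use a vectorized array library); the function name and the toy inputs are assumptions made for the example.

```python
import math

def softmax(row):
    """Stable softmax over one row; exp(-inf) evaluates to 0.0."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v, mask: additive n_q x n_k or None.

    Returns (O, A) with O: n_q x d_v and A: n_q x n_k.
    Assumes every row keeps at least one unmasked position.
    """
    scale = math.sqrt(len(Q[0]))
    # Step 1: scaled scores S = Q K^T / sqrt(d_k)
    S = [[sum(a * b for a, b in zip(q, k)) / scale for k in K] for q in Q]
    # Step 2: optional additive mask (0 keeps a position, -inf removes it)
    if mask is not None:
        S = [[s + m for s, m in zip(s_row, m_row)] for s_row, m_row in zip(S, mask)]
    # Step 3: row-wise softmax
    A = [softmax(row) for row in S]
    # Step 4: weighted sum O = A V
    d_v = len(V[0])
    O = [[sum(a * v[c] for a, v in zip(a_row, V)) for c in range(d_v)] for a_row in A]
    return O, A

# Toy shapes: n_q = 2, n_k = 3, d_k = 2, d_v = 3
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
O, A = scaled_dot_product_attention(Q, K, V)
print(len(O), len(O[0]))   # output shape n_q x d_v
```

Because V is the identity here, O equals A, which makes it easy to inspect the attention weights directly.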
5. Causal (Autoregressive) Masking
For autoregressive models (like GPT), we must prevent position i from attending to positions j > i (the future). We use a causal mask:
M_causal[i,j] = 0 if j ≤ i
M_causal[i,j] = −∞ if j > i
After adding M_causal to S: S_masked[i,j] = −∞ for j > i, and since e^{−∞} = 0 the softmax assigns those positions weight 0. Position i can only attend to positions ≤ i.
The mask matrix (for n=4):
[ 0  −∞  −∞  −∞ ]
[ 0   0  −∞  −∞ ]
[ 0   0   0  −∞ ]
[ 0   0   0   0 ]
Implementation note: In practice, we use a very negative number like −1e9 instead of −∞. The softmax of a sufficiently negative number is ~0 (numerically indistinguishable from true 0).
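A minimal sketch of building the causal mask and applying it to dummy scores follows. It uses a true −∞ (Python's `float("-inf")`) rather than −1e9, and the uniform scores are an assumption chosen so the result is easy to read.

```python
import math

NEG_INF = float("-inf")

def causal_mask(n):
    """n x n additive mask: 0 where j <= i, -inf where j > i."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]

def softmax(row):
    """Stable softmax over one row; exp(-inf) evaluates to 0.0."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

scores = [[0.5] * 4 for _ in range(4)]   # uniform dummy scores
masked = [[s + m for s, m in zip(s_row, m_row)]
          for s_row, m_row in zip(scores, causal_mask(4))]
weights = [softmax(row) for row in masked]
for row in weights:
    print([round(w, 3) for w in row])
```

With uniform scores, row i spreads its weight evenly over positions 0..i and assigns exactly 0 to every future position.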
6. Padding Mask
For variable-length sequences in a batch, we pad shorter sequences to the max length. Padding tokens should be ignored:
M_padding[i,j] = 0 if key position j is not padding
M_padding[i,j] = −∞ if key position j is padding
In practice, the final mask is M = M_causal + M_padding, applied to the scores before softmax.
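As an illustration of combining the two masks, the sketch below builds a padding mask from a per-key boolean list and adds it to the causal mask. The helper names and the `is_pad` example are hypothetical; frameworks typically expose their own mask conventions.

```python
NEG_INF = float("-inf")

def causal_mask(n):
    """n x n additive mask: 0 where j <= i, -inf where j > i."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]

def padding_mask(is_pad, n_q):
    """n_q x n_k additive mask; is_pad[j] is True where key position j is padding."""
    row = [NEG_INF if pad else 0.0 for pad in is_pad]
    return [list(row) for _ in range(n_q)]

def combine(a, b):
    """Elementwise sum of two additive masks of the same shape."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

is_pad = [False, False, False, True]   # hypothetical sequence: last token is padding
M = combine(causal_mask(4), padding_mask(is_pad, 4))
for row in M:
    print(row)
```

The padded key (column 3) is masked in every row, including the last row where the causal mask alone would have allowed it.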
7. Shape Conventions
Throughout Transformers, the convention is:
| Symbol | Shape | Meaning |
|---|---|---|
| n_q | scalar | Number of query positions |
| n_k | scalar | Number of key positions |
| d_k | scalar | Key/query dimension |
| d_v | scalar | Value dimension |
| d_model | scalar | Model hidden dimension |
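For self-attention: n_q = n_k = n (sequence length). For cross-attention: n_q is the length of the sequence that produces the queries and n_k is the length of the sequence that provides the keys and values, so the two can differ. In both cases, with h attention heads (see 17-08), d_k = d_v = d_model / h per head.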
8. Gradient Check (Numerical Stability)
Recall from 17-06 the attention gradient. With scaling:
S = QKᵀ/√d_k → ∂L/∂Q = (∂L/∂S · K) / √d_k, ∂L/∂K = ((∂L/∂S)ᵀ · Q) / √d_k
The 1/√d_k factor propagates through the backward pass cleanly, reducing gradient magnitudes proportionally — this actually helps with training stability.
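The formula for ∂L/∂Q can be checked against finite differences. The sketch below uses a toy loss L = Σ G[i][j]·S[i][j], so that ∂L/∂S = G by construction; the matrices Q, K, G and all function names are arbitrary illustrative choices.

```python
import math

def scores(Q, K):
    """S = Q K^T / sqrt(d_k)."""
    scale = math.sqrt(len(Q[0]))
    return [[sum(a * b for a, b in zip(q, k)) / scale for k in K] for q in Q]

def loss(Q, K, G):
    """Toy loss L = sum_ij G[i][j] * S[i][j], so dL/dS = G."""
    S = scores(Q, K)
    return sum(g * s for g_row, s_row in zip(G, S) for g, s in zip(g_row, s_row))

def grad_Q_analytic(K, G):
    """dL/dQ = (G K) / sqrt(d_k)."""
    d_k = len(K[0])
    scale = math.sqrt(d_k)
    return [[sum(g_row[j] * K[j][c] for j in range(len(K))) / scale
             for c in range(d_k)] for g_row in G]

def grad_Q_numeric(Q, K, G, eps=1e-6):
    """Central finite differences on each entry of Q."""
    grad = []
    for i in range(len(Q)):
        row = []
        for c in range(len(Q[0])):
            Q_plus = [r[:] for r in Q]
            Q_minus = [r[:] for r in Q]
            Q_plus[i][c] += eps
            Q_minus[i][c] -= eps
            row.append((loss(Q_plus, K, G) - loss(Q_minus, K, G)) / (2 * eps))
        grad.append(row)
    return grad

Q = [[1.0, 2.0], [0.5, -1.0]]
K = [[0.3, -0.7], [1.2, 0.4], [-0.5, 0.9]]
G = [[1.0, -2.0, 0.5], [0.0, 1.5, -1.0]]   # arbitrary upstream gradient dL/dS

analytic = grad_Q_analytic(K, G)
numeric = grad_Q_numeric(Q, K, G)
max_err = max(abs(a - b) for ra, rb in zip(analytic, numeric) for a, b in zip(ra, rb))
print(f"max |analytic - numeric| = {max_err:.2e}")
```

The same pattern, with the roles of Q and K swapped and G transposed, would check ∂L/∂K.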
9. Why Not Scale by d_k Instead?
Scale by d_k (instead of √d_k) would make the variance 1/d_k — scores would shrink as d_k grows. The softmax would approach uniform (all 1/n_k), losing the ability to focus attention.
Scale by 1 (no scaling): variance d_k, leading to near-one-hot softmax.
Scale by √d_k: variance 1, keeping softmax in the "interesting" regime. Goldilocks principle!
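The three choices can be compared on raw scores drawn with variance d_k, as the variance argument predicts. This is an illustrative sketch: the values d_k = 64, n_k = 8, and the seed are assumptions, and Shannon entropy is used only as a convenient measure of how spread out the attention distribution is.

```python
import math
import random

def softmax(row):
    """Stable softmax over one row."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    """Shannon entropy in nats; terms with p = 0 contribute 0."""
    return -sum(x * math.log(x) for x in p if x > 0.0)

d_k, n_k = 64, 8
rng = random.Random(0)
raw = [rng.gauss(0.0, math.sqrt(d_k)) for _ in range(n_k)]   # raw scores, variance d_k

results = {}
for name, divisor in (("none", 1.0), ("sqrt(d_k)", math.sqrt(d_k)), ("d_k", float(d_k))):
    p = softmax([s / divisor for s in raw])
    results[name] = (max(p), entropy(p))
    print(f"{name:>9}: max weight={max(p):.3f}  entropy={entropy(p):.3f}  "
          f"(uniform would be {math.log(n_k):.3f})")
```

Dividing by a larger factor acts like raising the softmax temperature, so the entropy increases from the unscaled case to √d_k scaling to d_k scaling, where it sits very close to the uniform value ln(n_k).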
Key Terms
- Scaled Dot-Product Attention
- Causal (Autoregressive) Masking
- Full Forward Pass (Matrix Operations)
- Gradient Check (Numerical Stability)
- Padding Mask
Worked Examples
Example 1: Variance Growth and Scaling
Problem: For d_k = 64, assume q_i, k_i ~ N(0,1) i.i.d. What is the variance of qᵀk? After scaling by √64? After scaling by 64?
Solution: Var(qᵀk) = d_k·Var(q_i·k_i) = 64·1 = 64. Std = 8.
After √d_k scaling: Var(qᵀk/8) = 64/64 = 1. Std = 1. After d_k scaling: Var(qᵀk/64) = 64/4096 = 1/64. Std = 1/8.
The √d_k scaling puts scores in the ideal variance-1 regime.
Example 2: Causal Attention Computation
Problem: Q = K = V = [[1,0],[2,0],[3,0]] — 3 positions, d=2. Compute causal scaled dot-product attention with d_k=2.
Solution: S = QKᵀ/√2 = (1/√2)·[[1,2,3],[2,4,6],[3,6,9]] ≈ [[0.707, 1.414, 2.121],[1.414, 2.828, 4.243],[2.121, 4.243, 6.364]]
Apply causal mask (upper triangle → −∞): [[0.707, −∞, −∞], [1.414, 2.828, −∞], [2.121, 4.243, 6.364]]
Row 0 softmax: only position 0 is unmasked → A[0] = [1, 0, 0]
Row 1 softmax: A[1] = [e^1.414/(e^1.414 + e^2.828), e^2.828/(e^1.414 + e^2.828), 0] = [4.11/21.03, 16.92/21.03, 0] ≈ [0.196, 0.804, 0]
Row 2 softmax: e^2.121 ≈ 8.34, e^4.243 ≈ 69.6, e^6.364 ≈ 580.5 → A[2] ≈ [0.0127, 0.1057, 0.8816]
Output = A·V = A·[[1,0],[2,0],[3,0]]
O[0] = [1, 0]
O[1] = 0.196·[1,0] + 0.804·[2,0] = [1.804, 0]
O[2] = 0.0127·[1,0] + 0.1057·[2,0] + 0.8816·[3,0] ≈ [2.869, 0]
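The hand computation in this example can be cross-checked numerically. The following is a small self-contained sketch that recomputes A and O for the same Q = K = V; the variable names are illustrative.

```python
import math

def softmax(row):
    """Stable softmax over one row; exp(-inf) evaluates to 0.0."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

X = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]   # Q = K = V from the example
n, d_k = len(X), len(X[0])

# Scaled scores with the causal mask applied (-inf above the diagonal).
S = [[sum(a * b for a, b in zip(X[i], X[j])) / math.sqrt(d_k) if j <= i else float("-inf")
      for j in range(n)] for i in range(n)]
A = [softmax(row) for row in S]
O = [[sum(A[i][j] * X[j][c] for j in range(n)) for c in range(d_k)] for i in range(n)]

for i in range(n):
    print(f"A[{i}] = {[round(a, 4) for a in A[i]]}   O[{i}] = {[round(o, 4) for o in O[i]]}")
```

Running it prints each row of the attention matrix next to the corresponding output row, which is a quick way to catch arithmetic slips in the softmax step.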
Example 3: Effect of Scale on Softmax Distribution
Problem: Under the variance argument, raw scores have standard deviation √d_k. Take representative raw scores s = [0, √d_k, 2·√d_k] with d_k = 64, i.e. s = [0, 8, 16]. Compute the softmax with no scaling, √d_k scaling, and d_k scaling.
Solution:
No scaling: s = [0, 8, 16] softmax = [e⁰, e⁸, e¹⁶] / (1 + 2981 + 8.89e6) ≈ [1.12e-7, 0.0003, 0.9997], which is saturated
√d_k = 8: s = [0, 1, 2] softmax = [e⁰, e¹, e²] / (1 + 2.718 + 7.389) ≈ [0.090, 0.245, 0.665], which is peaked but still soft
d_k = 64: s = [0, 0.125, 0.25] softmax ≈ [1, 1.133, 1.284] / 3.417 ≈ [0.293, 0.332, 0.376], which is nearly uniform
The √d_k scaling achieves the best balance between the saturated one-hot regime (no scaling) and the overly uniform regime (d_k scaling).
Quiz
Q1: What does causal (autoregressive) masking do in scaled dot-product attention?
A) It sets the attention weights for padding tokens to zero
B) It multiplies the attention weights by a lower-triangular matrix after the softmax
C) It adds −∞ to the scores at positions j > i before the softmax, so position i attends only to positions ≤ i
D) It adds −∞ to the scores at positions j < i, so position i attends only to future positions
Correct: C)
- If you chose A: This is incorrect. That describes the padding mask, not the causal mask.
- If you chose B: This is incorrect. In this subject the mask is additive and is applied to the scores before the softmax.
- If you chose C: Correct! M_causal[i,j] = −∞ for j > i, and e^{−∞} = 0, so future positions receive zero weight.
- If you chose D: This is incorrect. That would hide the past and expose the future, the opposite of what an autoregressive model needs.
Q2: Assume the components of q and k are independent with mean 0 and variance 1. What is the variance of qᵀk?
A) 1
B) √d_k
C) d_k
D) d_k²
Correct: C)
- If you chose A: This is incorrect. Variance 1 is what you get after dividing by √d_k, not before.
- If you chose B: This is incorrect. √d_k is the standard deviation of qᵀk, not its variance.
- If you chose C: Correct! Each term q_i·k_i has variance 1, and the variances of the d_k independent terms add.
- If you chose D: This is incorrect. Variances of independent terms add linearly, giving d_k rather than d_k².
Q3: Which statement about the full forward pass is TRUE?
A) The softmax is applied column-wise, so the weights for each key sum to 1
B) The mask is applied after the softmax
C) The scores are scaled, optionally masked, passed through a row-wise softmax, and then multiplied by V
D) The output O has shape n_q × n_k
Correct: C)
- If you chose A: This is incorrect. The softmax is row-wise, so the weights for each query sum to 1.
- If you chose B: This is incorrect. The mask is added to the scores before the softmax.
- If you chose C: Correct! This is the order of Steps 1 to 4 in the forward pass.
- If you chose D: This is incorrect. A has shape n_q × n_k; the output O = AV has shape n_q × d_v.
Q4: Which formula defines scaled dot-product attention?
A) softmax(QKᵀ/√d_k)V
B) softmax(QKᵀ/d_k)V
C) softmax(QKᵀ)V
D) softmax(QKᵀ)V/√d_k
Correct: A)
- If you chose A: Correct! The scores are divided by √d_k before the softmax, which keeps their variance near 1.
- If you chose B: This is incorrect. Dividing by d_k shrinks the variance to 1/d_k and pushes the softmax toward uniform.
- If you chose C: This is incorrect. Without scaling the scores have variance d_k and the softmax saturates.
- If you chose D: This is incorrect. Scaling after the softmax does not prevent saturation; the division must happen before it.
Q5: How does the 1/√d_k factor from the forward pass show up in the backward pass?
A) It does not appear in the backward pass at all
B) It is inverted, so gradients are multiplied by √d_k
C) It appears in ∂L/∂K but not in ∂L/∂Q
D) It scales both ∂L/∂Q and ∂L/∂K by 1/√d_k
Correct: D)
- If you chose A: This is incorrect. By the chain rule, a constant factor in the forward pass also multiplies the gradient.
- If you chose B: This is incorrect. The factor is carried through unchanged, not inverted.
- If you chose C: This is incorrect. Q and K enter S = QKᵀ/√d_k symmetrically, so both gradients carry the factor.
- If you chose D: Correct! ∂L/∂Q = (∂L/∂S · K)/√d_k and ∂L/∂K = ((∂L/∂S)ᵀ · Q)/√d_k.
Q6: What is a common pitfall when working with a padding mask?
A) Adding the mask to the scores before the softmax
B) Using a mask value such as −1 or 0, which leaves non-trivial weight on padded positions
C) Masking key positions j rather than query positions i
D) Adding it to the causal mask
Correct: B)
- If you chose A: This is incorrect. Adding the mask before the softmax is the intended usage.
- If you chose B: Correct! The mask value must be −∞ or a very negative number so that exp(masked_score) ≈ 0.
- If you chose C: This is incorrect. M_padding is defined over key positions j, which is what stops padded tokens from being attended to.
- If you chose D: This is incorrect. Combining the two as M = M_causal + M_padding is the approach described in this subject.
Q7: What happens to the softmax when d_k is large and the scores are not scaled?
A) It approaches the uniform distribution 1/n_k
B) It is unaffected, because the softmax normalizes its inputs
C) The scores shrink toward zero
D) It saturates toward a near-one-hot distribution, and gradients vanish for the non-winning positions
Correct: D)
- If you chose A: This is incorrect. A near-uniform softmax is the result of over-scaling by d_k, not of omitting the scaling.
- If you chose B: This is incorrect. The softmax normalizes the outputs to sum to 1, but its shape depends strongly on the magnitude of the inputs.
- If you chose C: This is incorrect. Unscaled scores grow with d_k; their standard deviation is √d_k.
- If you chose D: Correct! Scores with variance d_k produce large gaps between entries, which drives the softmax into saturation.
Practice Problems
Problem 1
Why does qᵀk have variance proportional to d_k when the components of q and k are independent with zero mean and unit variance?
Answer
qᵀk = Σ q_i·k_i. Since q_i and k_i are independent with zero mean, Var(q_i·k_i) = E[q_i²]·E[k_i²] = Var(q_i)·Var(k_i) = 1. The d_k terms are independent, so their variances add and Var(qᵀk) = d_k.
Problem 2
What is the shape of S = QKᵀ/√d_k if Q ∈ ℝ^(8×64), K ∈ ℝ^(12×64)?
Answer
S ∈ ℝ^(8×12): 8 queries × 12 keys. The scaling doesn't change the shape.
Problem 3
Write the causal mask matrix for n=3. What happens to masked positions in the softmax?
Answer
M = [[0, −∞, −∞],[0, 0, −∞],[0, 0, 0]]: zeros on and below the diagonal, −∞ strictly above it. For masked positions, e^{−∞} = 0, so the softmax weight is 0. The probability mass is distributed among the unmasked positions only, and each row still sums to 1.
Problem 4
If d_k = 256 and scores without scaling have std ≈ 16, what will the scaled scores' std be? What about scaling by d_k?
Answer
With √d_k = 16 scaling: std ≈ 16/16 = 1. With d_k = 256 scaling: std ≈ 16/256 = 1/16 = 0.0625. The √d_k scaling gives a standard normal range where softmax is "interesting."
Problem 5
Explain why softmax saturation leads to vanishing gradients in attention.
Answer
When softmax saturates (some outputs ≈ 1, others ≈ 0), the Jacobian ∂A_i/∂S_j has entries ≈ 0. The gradient ∂L/∂S passes through this Jacobian, giving near-zero signals to positions that didn't "win" the softmax. Those key-value pairs stop receiving learning signals, even though they might be useful with different parameter values.
Summary
- Scaled dot-product attention: Attention(Q,K,V) = softmax(QKᵀ/√d_k)V — scaling is essential
- Without scaling, qᵀk has variance d_k, causing softmax saturation into near-one-hot distributions
- Scaling by √d_k normalizes the variance to ≈ 1, keeping softmax in the balanced regime
- Causal masking (upper triangular −∞) prevents attending to future positions for autoregressive models
- Gradient flow is critically dependent on the softmax not saturating — scaling ensures meaningful gradients for all positions
Pitfalls
- Omitting the √d_k scaling factor. This is an easy mistake to make in a hand-written attention implementation. Without scaling, qᵀk has variance proportional to d_k, the softmax saturates toward one-hot, and gradients vanish for the non-winning positions. The model may still train, but it can converge slowly and generalize poorly. Always divide by √d_k before the softmax.
- Dividing by d_k instead of √d_k. Over-scaling pushes the softmax toward uniform (all entries ≈ 1/n_k), making attention unable to focus on specific positions. The model defaults to simple averaging, losing the benefit of content-based lookup. The √d_k is the mathematically correct factor — it normalizes the variance of qᵀk to 1 under i.i.d. assumptions.
- Using -1 or 0 instead of -∞ (or -1e9) for mask values. The mask is added BEFORE softmax, and exp(-1) ≈ 0.368, not 0, so a position "masked" with -1 keeps a weight comparable to the unmasked ones. Masked positions will still have non-trivial probability mass, distorting the attention distribution. Always use a sufficiently negative value (typically -1e9 for float32; in float16 the most negative finite value is -65504) so that exp(masked_score) ≈ 0 numerically.
- Forgetting to apply the causal mask during training. Without masking, the model "cheats" by attending to future tokens during training. The training loss will be artificially low but the model will be completely broken at inference time (when it can't see the future). Always verify that your mask is upper-triangular with -∞ for j > i.
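- Not considering numerical stability with mixed masking. Adding two −∞ entries is harmless (−∞ + −∞ = −∞), but if the combined causal and padding masks leave a row with every position masked, the softmax has no finite score to normalize and returns NaN. With finite mask values, the sum of two large negatives can also overflow in low precision (e.g. −65504 + −65504 in float16). Build the combined mask once, for example from a single boolean "keep" matrix, and handle fully masked rows explicitly.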
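The masking pitfalls above can be reproduced in a few lines. This sketch shows that summing two −∞ mask entries is safe, that a fully masked row produces NaN, and one possible guard; `safe_softmax` and its all-zero convention are assumptions for illustration, not a standard API.

```python
import math

NEG_INF = float("-inf")

def softmax(row):
    """Stable softmax over one row (subtracts the row max)."""
    m = max(row)
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]

def safe_softmax(row):
    """Return all-zero weights for a fully masked row instead of NaN."""
    if all(s == NEG_INF for s in row):
        return [0.0] * len(row)
    return softmax(row)

both_masked = NEG_INF + NEG_INF                      # causal and padding mask on the same entry
partially_masked = softmax([0.3, both_masked])       # one position survives
fully_masked = softmax([NEG_INF, NEG_INF])           # -inf - (-inf) is NaN
guarded = safe_softmax([NEG_INF, NEG_INF])

print(both_masked, partially_masked, fully_masked, guarded)
```

Whether a fully masked row should yield zeros, a uniform distribution, or an error is a design decision; the point is to make that choice explicitly rather than let NaN propagate into the loss.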
Next Steps
Continue to 17-08 — Multi-Head Attention to learn how splitting attention into multiple heads enables the model to attend to different representation subspaces simultaneously.