
17-07 — Scaled Dot-Product Attention

Phase: 17, Deep Learning Architectures (Math)
Subject: 17-07
Prerequisites: 17-06 (Attention Mechanism), 16-03 (Softmax Function), 11-10 (Information Theory; entropy concepts helpful)
Next subject: 17-08, Multi-Head Attention


Learning Objectives

By the end of this subject, you will be able to:

  1. Derive the variance argument for why scaling by √d_k is mathematically necessary
  2. Compute the full scaled dot-product attention forward pass: Attention(Q,K,V) = softmax(QKᵀ/√d_k)V
  3. Explain the shape conventions for Q, K, V in both self-attention and cross-attention
  4. Derive and apply causal (autoregressive) masking with the mask matrix
  5. Analyze the softmax saturation problem when d_k is large (without scaling) and prove scaling fixes it

Core Content

1. The Problem with Unscaled Dot-Product Attention

Recall from 17-06: dot-product attention computes scores as S = QKᵀ. When the key dimension d_k is large, the dot products grow in magnitude.

The variance argument:

Assume the components of q and k are independent random variables with mean 0 and variance 1 (reasonable at initialization).

qᵀk = Σ_{i=1}^{d_k} q_i · k_i

Each term q_i·k_i has mean 0 and variance E[q_i²]·E[k_i²] = 1·1 = 1 (by independence).

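The d_k terms q_i·k_i are independent of one another, so their variances add: Var(qᵀk) = d_k · 1 = d_k. By the Central Limit Theorem, qᵀk is also approximately normal for large d_k, so qᵀk ≈ N(0, d_k): the variance is d_k!

So the raw scores in S have a typical magnitude (standard deviation) of √d_k. Large d_k → large-magnitude scores → the softmax enters its saturation region.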
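The variance argument is easy to check empirically. The sketch below (plain Python, standard library only; the trial count and seed are arbitrary choices) samples q and k with i.i.d. N(0, 1) components and estimates Var(qᵀk) for a few values of d_k. The estimates should land near d_k, and dividing by d_k (which is what scaling the scores by √d_k does to the variance) should bring them near 1.

```python
import random
import statistics

def dot_product_variance(d_k, trials=5000, seed=0):
    """Monte Carlo estimate of Var(q . k) for q_i, k_i ~ N(0, 1) i.i.d."""
    rng = random.Random(seed)
    samples = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        samples.append(sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.variance(samples)  # sample variance

for d_k in (4, 16, 64):
    var_raw = dot_product_variance(d_k)
    # Var(X / sqrt(d_k)) = Var(X) / d_k, so this is the variance of the scaled score.
    var_scaled = var_raw / d_k
    print(f"d_k={d_k:3d}  Var(q.k)~{var_raw:6.2f}  Var(q.k/sqrt(d_k))~{var_scaled:5.3f}")
```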

2. Why Softmax Saturation is Bad

Consider softmax(s) where some s_i are very large and others very small:

softmax(s)_i = e^{s_i} / Σ e^{s_j}

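If the largest score exceeds every other score by a wide margin (say 10 or more):
- The largest softmax output is ≈ 1.0
- All other outputs are ≈ 0.0

This means:
1. Vanishing gradients: the softmax Jacobian is ∂softmax_i/∂s_j = softmax_i·(δ_ij − softmax_j). For the near-zero outputs, ∂softmax_i/∂s_i = softmax_i·(1 − softmax_i) ≈ softmax_i ≈ 0, so those positions get almost no learning signal. For the dominant output, 1 − softmax_i ≈ 0, so its entries are small as well.
2. Hard attention: the model effectively computes an argmax, losing the benefits of soft attention.
3. Training instability: small changes in the keys can flip which position receives the attention, causing erratic gradients.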
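To see the vanishing-gradient effect directly, the sketch below builds the softmax Jacobian J[i][j] = A_i(δ_ij − A_j) for the scores [0, 1, 2] multiplied by growing factors (an arbitrary illustrative choice) and reports the largest entry in absolute value. As the scores spread apart, the softmax approaches one-hot and every Jacobian entry shrinks toward 0.

```python
import math

def softmax(scores):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_jacobian(scores):
    """J[i][j] = dA_i/ds_j = A_i * (delta_ij - A_j)."""
    a = softmax(scores)
    n = len(a)
    return [[a[i] * ((1.0 if i == j else 0.0) - a[j]) for j in range(n)] for i in range(n)]

def largest_gradient_entry(scores):
    """Largest |J[i][j]|: an upper bound on how much any score can move any output."""
    return max(abs(x) for row in softmax_jacobian(scores) for x in row)

base = [0.0, 1.0, 2.0]
for factor in (1, 4, 16):
    scores = [factor * s for s in base]
    probs = [round(p, 4) for p in softmax(scores)]
    print(f"scores x{factor:<2d} softmax={probs}  max|J|={largest_gradient_entry(scores):.1e}")
```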

⚠️ THIS IS CRITICAL: Without scaling, attention tends toward near one-hot distributions in high dimensions. The model would attend almost exclusively to one position, and gradients for all other positions would nearly vanish. The √d_k scaling keeps the scores in a regime where the softmax produces a meaningful distribution.

3. The Solution: Scale by √d_k

Attention(Q, K, V) = softmax(QKᵀ / √d_k) V

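After scaling: since Var(X/c) = Var(X)/c², the scaled score (qᵀk)/√d_k has variance d_k/d_k = 1. Dividing by √d_k exactly cancels the d_k growth.

The scaled scores have mean 0 and variance 1, so they stay in the non-saturated regime of the softmax, producing a spread-out attention distribution with usable gradients.

Geometric intuition: qᵀk = ||q||·||k||·cos(θ). For standard normal components, ||q|| ≈ ||k|| ≈ √d_k, so |qᵀk| ≈ d_k·|cos(θ)|. For independent random vectors, however, cos(θ) is itself of order 1/√d_k (its variance is exactly 1/d_k), so |qᵀk| is of order √d_k, matching the variance argument, and the scaled score is of order 1. Only if q and k were strongly aligned (|cos(θ)| of order 1) would the scaled score keep growing like √d_k. The variance argument therefore describes typical behavior at initialization; learned projections (as in Multi-Head Attention, 17-08) change these statistics during training, but in practice scaling by √d_k works well.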
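The statement that cos(θ) between independent random vectors has standard deviation about 1/√d_k can be checked with a short simulation. This is a sketch with an arbitrary trial count and seed: it samples q and k with i.i.d. N(0, 1) components and compares the empirical standard deviation of cos(θ) against 1/√d_k.

```python
import math
import random
import statistics

def cosine_std(d_k, trials=1500, seed=0):
    """Empirical std of cos(theta) between independent N(0, I) vectors in R^d_k."""
    rng = random.Random(seed)
    cosines = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dot = sum(a * b for a, b in zip(q, k))
        norm_q = math.sqrt(sum(a * a for a in q))
        norm_k = math.sqrt(sum(b * b for b in k))
        cosines.append(dot / (norm_q * norm_k))
    return statistics.pstdev(cosines)

for d_k in (16, 64, 256):
    print(f"d_k={d_k:3d}  std(cos)~{cosine_std(d_k):.4f}  1/sqrt(d_k)={1 / math.sqrt(d_k):.4f}")
```

Because std(cos θ) shrinks like 1/√d_k while ||q||·||k|| grows like d_k, their product (the raw score) grows only like √d_k.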

4. Full Forward Pass (Matrix Operations)

Given:
- Q ∈ ℝ^(n_q × d_k): n_q query vectors
- K ∈ ℝ^(n_k × d_k): n_k key vectors
- V ∈ ℝ^(n_k × d_v): n_k value vectors

Step 1: Compute scaled scores

S = QKᵀ / √d_k ∈ ℝ^(n_q × n_k)

Step 2: (Optional) Apply mask

S_masked = S + M where M[i,j] = −∞ for positions to mask, 0 otherwise

Step 3: Softmax (row-wise)

A[i,:] = softmax(S_masked[i,:]) for each row i, so A ∈ ℝ^(n_q × n_k) and every row of A sums to 1

Step 4: Weighted sum

O = AV ∈ ℝ^(n_q × d_v)
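The four steps map directly onto code. Below is a minimal pure-Python sketch (lists of lists instead of a tensor library, so it runs without dependencies); the helper names are illustrative, and the optional additive mask uses float("-inf") as in Step 2. It assumes every mask row leaves at least one position unmasked.

```python
import math

def matmul(a, b):
    """(m x p) @ (p x n) for lists of lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax_row(row):
    m = max(row)  # assumes at least one finite (unmasked) entry
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = len(Q[0])
    # Step 1: S = Q K^T / sqrt(d_k), shape (n_q, n_k)
    S = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
    # Step 2: optional additive mask (0 keeps a position, -inf removes it)
    if mask is not None:
        S = [[s + m for s, m in zip(s_row, m_row)] for s_row, m_row in zip(S, mask)]
    # Step 3: row-wise softmax, shape (n_q, n_k)
    A = [softmax_row(row) for row in S]
    # Step 4: O = A V, shape (n_q, d_v)
    return matmul(A, V), A

Q = [[1.0, 0.0], [0.0, 1.0]]              # n_q = 2, d_k = 2
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # n_k = 3
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # d_v = 2
O, A = scaled_dot_product_attention(Q, K, V)
print(A)  # 2 x 3, each row sums to 1
print(O)  # 2 x 2
```

Returning A alongside O makes the attention weights easy to inspect; a practical implementation would typically use a tensor library and batch over sequences and heads.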

5. Causal (Autoregressive) Masking

For autoregressive models (like GPT), we must prevent position i from attending to positions j > i (the future). We use a causal mask:

$$M_{\text{causal}}[i,j] = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$

After adding M_causal to S, S_masked[i,j] = −∞ for j > i, and since e^{−∞} = 0, those entries receive softmax weight 0. Position i can therefore attend only to positions ≤ i.

The mask matrix (for n=4):

$$M_{\text{causal}} = \begin{bmatrix} 0 & -\infty & -\infty & -\infty \\ 0 & 0 & -\infty & -\infty \\ 0 & 0 & 0 & -\infty \\ 0 & 0 & 0 & 0 \end{bmatrix}$$
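One direct way to build this matrix, as a plain-Python sketch (the function name is illustrative):

```python
NEG_INF = float("-inf")

def causal_mask(n):
    """Additive causal mask: M[i][j] = 0 if j <= i, else -inf (a future position)."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]

for row in causal_mask(4):
    print(row)
```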

Implementation note: in practice, a very negative finite number such as −1e9 is often used instead of −∞. Its exponential underflows to 0 in floating point, so the resulting softmax weight is numerically indistinguishable from a true 0.
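The sketch below compares the two choices on one row of scores (the score values are arbitrary). Both masks give exactly the same weights. It also shows one practical difference, a property of IEEE floating point rather than a claim from the text: if every entry in a row is −∞, the stable softmax computes −∞ − (−∞) = NaN, whereas −1e9 keeps the row finite.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

scores = [0.7, 1.4, 2.1]  # one row of S (arbitrary values)
inf_mask = [0.0, float("-inf"), float("-inf")]
big_mask = [0.0, -1e9, -1e9]

with_inf = softmax([s + m for s, m in zip(scores, inf_mask)])
with_big = softmax([s + m for s, m in zip(scores, big_mask)])
print(with_inf)  # masked weights are exactly 0.0
print(with_big)  # exp(about -1e9) underflows to 0.0, so the result matches

# A row in which *every* entry is masked: -inf gives NaN, -1e9 stays finite.
print(softmax([float("-inf"), float("-inf")]))
print(softmax([-1e9, -1e9]))
```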

6. Padding Mask

For variable-length sequences in a batch, we pad shorter sequences to the max length. Padding tokens should be ignored:

$$M_{\text{padding}}[i,j] = \begin{cases} 0 & \text{if position } j \text{ is not padding} \\ -\infty & \text{if position } j \text{ is padding} \end{cases}$$

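When both are needed (for example, a batch of variable-length sequences in an autoregressive model), the final mask is the sum M = M_causal + M_padding, applied to the scores before the softmax. An entry masked by either term stays at −∞.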
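A plain-Python sketch of building and combining the two masks for self-attention (n_q = n_k). It uses −1e9 as the finite stand-in for −∞; the function names and the example padding pattern are hypothetical.

```python
NEG = -1e9  # finite stand-in for -inf

def causal_mask(n):
    """M[i][j] = 0 if j <= i, else NEG."""
    return [[0.0 if j <= i else NEG for j in range(n)] for i in range(n)]

def padding_mask(is_pad):
    """Every query row masks the key columns j where is_pad[j] is True."""
    n = len(is_pad)
    return [[NEG if is_pad[j] else 0.0 for j in range(n)] for _ in range(n)]

def combine_masks(*masks):
    """Elementwise sum: an entry masked by any term stays very negative."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*masks)]

# Hypothetical sequence of length 4 whose last position is padding.
M = combine_masks(causal_mask(4), padding_mask([False, False, False, True]))
for row in M:
    print(row)
```

Entries masked by both terms become −2e9, which behaves the same as −1e9 after the softmax.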

7. Shape Conventions

Throughout Transformers, the convention is:

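| Symbol | Type | Meaning |
|---|---|---|
| n_q | scalar | Number of query positions |
| n_k | scalar | Number of key positions |
| d_k | scalar | Key/query dimension |
| d_v | scalar | Value dimension |
| d_model | scalar | Model hidden dimension |

For self-attention: Q, K, and V come from the same sequence, so n_q = n_k = n (sequence length). With h heads, the usual choice is d_k = d_v = d_model / h per head.

For cross-attention: queries come from one sequence and keys/values from another (for example, decoder positions attending over encoder outputs), so n_q and n_k generally differ, while K and V always share the same n_k.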
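The shape bookkeeping for both cases can be written down as a small helper. This is an illustrative sketch; the function name and the example sizes (d_model = 512, h = 8, sequence lengths 10, 7, and 12) are assumptions chosen only to make the numbers concrete.

```python
def attention_shapes(n_q, n_k, d_k, d_v):
    """(rows, cols) of each matrix in Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""
    return {
        "Q": (n_q, d_k),
        "K": (n_k, d_k),
        "V": (n_k, d_v),
        "S": (n_q, n_k),  # QK^T / sqrt(d_k)
        "A": (n_q, n_k),  # row-wise softmax of S
        "O": (n_q, d_v),  # A V
    }

# Self-attention, example sizes: d_model = 512, h = 8 heads, so d_k = d_v = 64.
d_model, h = 512, 8
self_attn = attention_shapes(n_q=10, n_k=10, d_k=d_model // h, d_v=d_model // h)

# Cross-attention, example sizes: 7 query positions over 12 key/value positions.
cross_attn = attention_shapes(n_q=7, n_k=12, d_k=64, d_v=64)

for name, shapes in (("self", self_attn), ("cross", cross_attn)):
    print(name, shapes)
```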

8. Gradient Check (Numerical Stability)

Recall from 17-06 the attention gradient. With scaling:

S = QKᵀ/√d_k → ∂L/∂Q = (∂L/∂S · K) / √d_k, ∂L/∂K = ((∂L/∂S)ᵀ · Q) / √d_k

The 1/√d_k factor appears in both gradients, so the backward pass is rescaled by the same factor as the forward pass. Combined with a non-saturated softmax, this helps keep gradient magnitudes well-behaved during training.
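These two formulas can be verified with a finite-difference check. The sketch covers only the step S = QKᵀ/√d_k: it uses a surrogate loss L = Σ G_ij·S_ij, so that ∂L/∂S = G exactly, with G standing in for whatever upstream gradient the softmax and later layers would supply. The sizes and random values are arbitrary.

```python
import math
import random

def scores(Q, K):
    """S = Q K^T / sqrt(d_k) for lists of row vectors."""
    d_k = len(Q[0])
    return [[sum(q * k for q, k in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in K] for q_row in Q]

def surrogate_loss(Q, K, G):
    """L = sum_ij G[i][j] * S[i][j], so dL/dS = G exactly."""
    S = scores(Q, K)
    return sum(g * s for g_row, s_row in zip(G, S) for g, s in zip(g_row, s_row))

def analytic_grads(Q, K, G):
    d_k = len(Q[0])
    c = 1.0 / math.sqrt(d_k)
    # dL/dQ = (dL/dS . K) / sqrt(d_k)
    dQ = [[c * sum(G[i][j] * K[j][t] for j in range(len(K))) for t in range(d_k)] for i in range(len(Q))]
    # dL/dK = ((dL/dS)^T . Q) / sqrt(d_k)
    dK = [[c * sum(G[i][j] * Q[i][t] for i in range(len(Q))) for t in range(d_k)] for j in range(len(K))]
    return dQ, dK

def numeric_grad(f, M, eps=1e-6):
    """Central finite differences of f() with respect to each entry of M (perturbed in place)."""
    grad = [[0.0] * len(M[0]) for _ in M]
    for i in range(len(M)):
        for t in range(len(M[0])):
            original = M[i][t]
            M[i][t] = original + eps
            up = f()
            M[i][t] = original - eps
            down = f()
            M[i][t] = original
            grad[i][t] = (up - down) / (2 * eps)
    return grad

rng = random.Random(0)
n_q, n_k, d_k = 2, 3, 4
Q = [[rng.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(n_q)]
K = [[rng.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(n_k)]
G = [[rng.gauss(0.0, 1.0) for _ in range(n_k)] for _ in range(n_q)]

dQ, dK = analytic_grads(Q, K, G)
dQ_num = numeric_grad(lambda: surrogate_loss(Q, K, G), Q)
dK_num = numeric_grad(lambda: surrogate_loss(Q, K, G), K)
max_err = max(
    abs(a - b)
    for exact, approx in ((dQ, dQ_num), (dK, dK_num))
    for row_a, row_b in zip(exact, approx)
    for a, b in zip(row_a, row_b)
)
print(f"max |analytic - numeric| = {max_err:.1e}")
```

Because this surrogate loss is linear in Q (and in K), central differences are exact up to floating-point rounding, so any real mismatch would point to an error in the formulas.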

9. Why Not Scale by d_k Instead?

Scaling by d_k (instead of √d_k) would make the variance 1/d_k, so the scores would shrink as d_k grows. The softmax would approach uniform (every weight ≈ 1/n_k), losing the ability to focus attention.

Scaling by 1 (no scaling): variance d_k, leading to a near-one-hot softmax.

Scaling by √d_k: variance 1, keeping the softmax in the "interesting" regime. Goldilocks principle!
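The three options can be compared on typical random scores. This sketch (arbitrary sizes d_k = 128 and n_k = 16, arbitrary seed) scores one random query against 16 random keys, applies each divisor, and averages the largest attention weight over many trials. A value near 1 means near one-hot; a value near 1/16 = 0.0625 means near uniform.

```python
import math
import random

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def mean_max_weight(divisor, d_k=128, n_k=16, trials=100, seed=0):
    """Average largest attention weight for one random query over n_k random keys.

    The same seed is used for every divisor, so all scalings see identical q and k."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        keys = [[rng.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(n_k)]
        scores = [sum(a * b for a, b in zip(q, k)) / divisor(d_k) for k in keys]
        total += max(softmax(scores))
    return total / trials

results = {
    "divide by 1": mean_max_weight(lambda d: 1.0),
    "divide by sqrt(d_k)": mean_max_weight(math.sqrt),
    "divide by d_k": mean_max_weight(float),
}
for name, value in results.items():
    print(f"{name:20s} mean max weight ~ {value:.3f}   (uniform would be {1 / 16:.4f})")
```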




Worked Examples

Example 1: Variance Growth and Scaling

Problem: For d_k = 64, assume q_i, k_i ~ N(0,1) i.i.d. What is the variance of qᵀk? After scaling by √64? After scaling by 64?

Solution: Var(qᵀk) = d_k·Var(q_i·k_i) = 64·1 = 64. Std = 8.

After √d_k scaling: Var(qᵀk/8) = 64/64 = 1. Std = 1.
After d_k scaling: Var(qᵀk/64) = 64/4096 = 1/64. Std = 1/8.

The √d_k scaling puts scores in the ideal variance-1 regime.

Example 2: Causal Attention Computation

Problem: Q = K = V = [[1,0],[2,0],[3,0]] — 3 positions, d=2. Compute causal scaled dot-product attention with d_k=2.

Solution: S = QKᵀ/√2 = (1/√2)·[[1,2,3],[2,4,6],[3,6,9]] ≈ [[0.707, 1.414, 2.121],[1.414, 2.828, 4.243],[2.121, 4.243, 6.364]]

Apply causal mask (upper triangle → −∞): [[0.707, −∞, −∞], [1.414, 2.828, −∞], [2.121, 4.243, 6.364]]

Row 0 softmax: only position 0 is unmasked, so A[0] = [1, 0, 0].
Row 1 softmax: e^1.414 ≈ 4.11 and e^2.828 ≈ 16.92, so A[1] = [4.11/(4.11+16.92), 16.92/21.03, 0] ≈ [0.196, 0.804, 0].

Row 2 softmax: subtracting the row maximum gives [−4.243, −2.121, 0], whose exponentials are ≈ [0.0144, 0.1199, 1] with sum ≈ 1.1343, so A[2] ≈ [0.013, 0.106, 0.882].

Output = A·V = A·[[1,0],[2,0],[3,0]]

O[0] = [1, 0]
O[1] = 0.196·[1,0] + 0.804·[2,0] = [1.804, 0]
O[2] = 0.013·[1,0] + 0.106·[2,0] + 0.882·[3,0] ≈ [2.87, 0]
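The example can be reproduced numerically. This is a small pure-Python sketch; the causal mask is applied by computing scores only for keys j ≤ i, which is equivalent to adding −∞ above the diagonal. Running it gives A[1] ≈ [0.196, 0.804, 0], A[2] ≈ [0.013, 0.106, 0.882], and O[2] ≈ [2.869, 0].

```python
import math

def causal_attention(Q, K, V):
    """Scaled dot-product attention where query i only sees keys j <= i."""
    d_k = len(Q[0])
    n = len(Q)
    A = []
    for i in range(n):
        # Scores for the visible keys only; future keys get weight 0 below.
        scores = [sum(q * k for q, k in zip(Q[i], K[j])) / math.sqrt(d_k) for j in range(i + 1)]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        A.append([e / total for e in exps] + [0.0] * (n - i - 1))
    O = [[sum(A[i][j] * V[j][c] for j in range(n)) for c in range(len(V[0]))] for i in range(n)]
    return A, O

X = [[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]  # Q = K = V
A, O = causal_attention(X, X, X)
for i in range(3):
    print([round(a, 3) for a in A[i]], [round(o, 3) for o in O[i]])
```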

Example 3: Effect of Scale on Softmax Distribution

Problem: For scores s = [0, d_k, 2·d_k] with d_k = 64. Compute softmax with no scaling, √d_k scaling, and d_k scaling.

Solution:

No scaling: s = [0, 64, 128] softmax = [e⁰, e⁶⁴, e¹²⁸] / (e⁰+e⁶⁴+e¹²⁸) ≈ [0, 0, 1] — saturated

√d_k = 8: s = [0, 8, 16] softmax ≈ [e⁰, e⁸, e¹⁶] / (1+2981+8.89e6) = [1.12e-7, 0.0003, 0.9997] — still peaked

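d_k = 64: s = [0, 1, 2] softmax = [e⁰, e¹, e²] / (1+2.718+7.389) = [0.09, 0.245, 0.665]: the most spread out of the three

Note that these raw scores grow in proportion to d_k, as they would for strongly aligned q and k; typical random scores instead have standard deviation √d_k (Example 1). For such large-margin inputs the √d_k-scaled softmax remains peaked. The case for √d_k rests on typical scores, which it normalizes to unit variance, avoiding both the saturated one-hot regime (no scaling) and the overly uniform regime (d_k scaling, which shrinks typical scores to standard deviation 1/√d_k).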
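The three distributions in this example can be recomputed with a few lines of plain Python (a sketch using the same stable softmax as earlier examples):

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64
raw = [0.0, float(d_k), 2.0 * d_k]  # s = [0, d_k, 2 d_k]
results = {}
for label, divisor in (("none", 1.0), ("sqrt(d_k)", math.sqrt(d_k)), ("d_k", float(d_k))):
    results[label] = softmax([s / divisor for s in raw])
    print(label, [f"{p:.3g}" for p in results[label]])
```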


Quiz

Q1: What does the concept of Causal (Autoregressive) Masking primarily refer to in this subject?

A) A computational error related to Causal (Autoregressive) Masking
B) A visual representation of Causal (Autoregressive) Masking
C) The definition and application of Causal (Autoregressive) Masking
D) A historical anecdote about Causal (Autoregressive) Masking

Correct: C)

Q2: What is the primary purpose of dividing QKᵀ by √d_k?

A) To make each row of attention weights sum to 1
B) To keep the score variance near 1 so the softmax does not saturate
C) To prevent positions from attending to the future
D) To ignore padding tokens

Correct: B)

Q3: Which statement about Full Forward Pass (Matrix Operations) is TRUE?

A) Full Forward Pass (Matrix Operations) is mentioned only as a historical footnote
B) Full Forward Pass (Matrix Operations) is an advanced topic beyond this subject's scope
C) Full Forward Pass (Matrix Operations) is a fundamental concept covered in this subject
D) Full Forward Pass (Matrix Operations) is not related to this subject

Correct: C)

Q4: Which formula defines scaled dot-product attention?

A) softmax(QKᵀ/√d_k)V
B) softmax(QKᵀ)V
C) softmax(QKᵀ/d_k)V
D) softmax(QKᵀ)V/√d_k

Correct: A)

Q5: How are Full Forward Pass (Matrix Operations) and Gradient Check (Numerical Stability) related?

A) Full Forward Pass (Matrix Operations) and Gradient Check (Numerical Stability) are completely unrelated topics
B) Full Forward Pass (Matrix Operations) is the inverse of Gradient Check (Numerical Stability)
C) Full Forward Pass (Matrix Operations) is a special case of Gradient Check (Numerical Stability)
D) Full Forward Pass (Matrix Operations) and Gradient Check (Numerical Stability) are closely related concepts

Correct: D)

Q6: What is a common pitfall when working with Padding Mask?

A) Padding Mask has no common misconceptions
B) A common mistake is confusing the padding mask (which hides padded key positions) with the causal mask (which hides future positions)
C) The main error with Padding Mask is using it when it is not needed
D) Padding Mask is always computed the same way in all contexts

Correct: B)

Q7: When does the problem with unscaled dot-product attention become severe?

A) Only when d_k = 1
B) When d_k is large, because the raw scores have variance d_k and saturate the softmax
C) Only in cross-attention
D) Only when a causal mask is applied

Correct: B)

Practice Problems

Problem 1

Why does qᵀk have variance proportional to d_k when q and k components are i.i.d. with unit variance?

Answer: qᵀk = Σ q_i·k_i. Since q_i and k_i are independent with mean 0, Var(q_i·k_i) = E[q_i²]·E[k_i²] = Var(q_i)·Var(k_i) = 1. The d_k terms are independent of one another, so their variances add, giving variance d_k. This is a direct consequence of the variance of a sum of independent random variables.

Problem 2

What is the shape of S = QKᵀ/√d_k if Q ∈ ℝ^(8×64), K ∈ ℝ^(12×64)?

Answer: S ∈ ℝ^(8×12), i.e., 8 queries × 12 keys. The scaling doesn't change the shape.

Problem 3

Write the causal mask matrix for n=3. What happens to masked positions in the softmax?

Answer: M = [[0, −∞, −∞],[0, 0, −∞],[0, 0, 0]] (zeros on and below the diagonal, −∞ above it). For masked positions e^{−∞} = 0, so their softmax weight is 0. The probability mass is redistributed among the unmasked positions, which still sum to 1.

Problem 4

If d_k = 256 and scores without scaling have std ≈ 16, what will the scaled scores' std be? What about scaling by d_k?

Answer: With √d_k = 16 scaling: std ≈ 1. With d_k = 256 scaling: std ≈ 1/16 = 0.0625. The √d_k scaling keeps the scores on the scale of a standard normal, where the softmax is "interesting."

Problem 5

Explain why softmax saturation leads to vanishing gradients in attention.

Answer: When the softmax saturates (one output ≈ 1, the others ≈ 0), every entry of the Jacobian ∂A_i/∂S_j = A_i(δ_ij − A_j) is ≈ 0. The gradient ∂L/∂S passes through this Jacobian, so positions that didn't "win" the softmax receive near-zero signals. Those key-value pairs stop receiving learning signals, even though they might be useful with different parameter values.

Summary

  1. Scaled dot-product attention: Attention(Q,K,V) = softmax(QKᵀ/√d_k)V — scaling is essential
  2. Without scaling, qᵀk has variance d_k, causing softmax saturation into near-one-hot distributions
  3. Scaling by √d_k normalizes the variance to ≈ 1, keeping softmax in the balanced regime
  4. Causal masking (−∞ strictly above the diagonal) prevents attending to future positions in autoregressive models
  5. Gradient flow depends critically on the softmax not saturating; scaling helps keep gradients meaningful for all positions


Next Steps

Continue to 17-08 — Multi-Head Attention to learn how splitting attention into multiple heads enables the model to attend to different representation subspaces simultaneously.