18-09 — KV Caching
Phase: 18 — Large Language Model Mathematics
Subject: 18-09
Prerequisites: 18-08 (Inference Mathematics), 18-05 (Decoder-Only Architecture), 17-07 (Scaled Dot-Product Attention), 09-09 (Numerical Linear Algebra — memory analysis)
Next subject: 18-10 — Transformer Variants (Math Focus)
Learning Objectives
By the end of this subject, you will be able to:
- Derive the KV cache mechanism from the causal attention equations and prove correctness
- Compute the exact memory footprint of the KV cache as a function of batch size, sequence length, layers, heads, and head dimension
- Analyze the compute-to-memory ratio and explain why LLM inference is memory-bandwidth-bound
- Compare multi-query attention (MQA) and grouped-query attention (GQA) in terms of KV cache reduction
- Derive the computational complexity reduction per new token: attention work drops from O(L·n²·d) (recomputing the whole prefix) to O(L·n·d) with KV caching
Core Content
1. The Problem: Naive Autoregressive Inference
Without caching, generating token t requires computing attention over ALL previous t tokens from scratch:
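At step t (generating token t), for each layer ℓ:
1. Compute q_t = W_Q · x_t (query for the new token)
2. Compute K_t = W_K · X_{0:t} (keys for ALL tokens 0..t)
3. Compute V_t = W_V · X_{0:t} (values for ALL tokens 0..t)
4. Compute attention: A = softmax(q_t · K_t^T / √d_head) · V_t
Here X_{0:t} is this layer's input for positions 0..t. Above the first layer, those inputs are outputs of earlier layers, so in practice a naive implementation reruns the full forward pass over the prefix at every step.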
This recomputes K and V for tokens 0..t−1 at EVERY step. The total work for generating T tokens:
Total K,V computations = Σ_{t=1}^{T} t = T(T+1)/2 = O(T²) per layer
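With L layers and model dimension d, projecting one token's K and V costs about 2 · 2d² = 4d² FLOPs (2 FLOPs per multiply-add), so the total is about L · 4d² · T(T+1)/2 ≈ 2 · L · d² · T² FLOPs, i.e. O(L · d² · T²), dominated by the K,V projections repeated for every prefix.
⚠️ THIS IS CRITICAL: Without caching, the K,V projection work for generating T tokens is O(T²) per layer, and T can be in the thousands. KV caching reduces it to O(T): one new K,V pair per token per layer.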
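To make the counting argument concrete, the short Python sketch below tallies token-level K,V projections with and without a cache. It counts projection operations only (no FLOP constants, no attention), and the function name `kv_projection_counts` is a hypothetical helper, not from any library.

```python
def kv_projection_counts(T, L=1):
    """Token-level K,V projections needed to generate T tokens over L layers.

    Naive: step t recomputes K and V for all t tokens, so 1 + 2 + ... + T per layer.
    Cached: each step projects only the new token, so T per layer.
    """
    naive = L * sum(range(1, T + 1))  # equals L * T * (T + 1) // 2
    cached = L * T
    return naive, cached


for T in (10, 100, 1000):
    naive, cached = kv_projection_counts(T, L=32)
    print(f"T={T:5d}  naive={naive:>12,}  cached={cached:>8,}  ratio={naive / cached:.1f}")
```

The ratio is (T+1)/2, so the savings grow linearly with the number of generated tokens.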
2. The KV Cache: Core Idea
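Observation: The key and value vectors of previous tokens DON'T CHANGE when new tokens are added. Proof sketch by induction over layers: the input to layer 1 at position i is the embedding of token i. If every position j's input to layer ℓ depends only on tokens 0..j, then position i's output of layer ℓ depends only on tokens 0..i, because the causal mask restricts its attention to positions 0..i (and the FFN acts on each position separately). So token i's hidden state at every layer, and with it k_i and v_i, is fixed once token i has been processed; appending later tokens cannot change it.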
Solution: Store (cache) the K and V vectors for all previously processed tokens. At each new step, compute only the K and V for the new token, append to the cache, and compute attention using the full cache.
Cache structure: For a model with:
- L layers
- H heads per layer
- d_head dimensions per head
- Sequence length n (growing)
The KV cache stores:
K_cache ∈ ℝ^(L × H × n × d_head) (keys)
V_cache ∈ ℝ^(L × H × n × d_head) (values)
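Together the two tensors hold 2 · L · H · n · d_head elements per sequence. For batched inference each tensor gains a batch dimension of size B, multiplying the total by B.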
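One way to picture this layout is a toy data structure indexed by layer and head, where each entry is a growing list of d_head-vectors (the n axis). The sketch below assumes plain Python lists in place of real tensors and covers only appending and single-query attention; the class `KVCache` and its methods are hypothetical names for illustration.

```python
import math


class KVCache:
    """Toy KV cache: keys[layer][head] is a list with one d_head-vector per token."""

    def __init__(self, num_layers, num_heads):
        self.keys = [[[] for _ in range(num_heads)] for _ in range(num_layers)]
        self.values = [[[] for _ in range(num_heads)] for _ in range(num_layers)]

    def append(self, layer, head, k, v):
        # Add the new token's key and value; existing entries are never modified.
        self.keys[layer][head].append(list(k))
        self.values[layer][head].append(list(v))

    def length(self, layer=0, head=0):
        # Number of tokens cached so far (n in the shape L x H x n x d_head).
        return len(self.keys[layer][head])

    def attend(self, layer, head, q):
        # Single-query scaled dot-product attention over all cached positions.
        K, V = self.keys[layer][head], self.values[layer][head]
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
        m = max(scores)  # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        return [sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))]


cache = KVCache(num_layers=2, num_heads=4)
cache.append(0, 0, k=[1.0, 0.0], v=[2.0, 0.0])
cache.append(0, 0, k=[1.0, 0.0], v=[0.0, 4.0])
# Both keys get the same score, so the output is the mean of the two values.
print(cache.attend(0, 0, q=[0.5, 0.5]))  # → [1.0, 2.0]
print(cache.length(0, 0), cache.length(1, 3))  # → 2 0
```

Because `append` only ever adds entries, the n axis grows by one per step while earlier keys and values stay untouched, which is the property the observation above relies on.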
3. KV Cache Algorithm
Initialization (prefill): Process the prompt of length P in one forward pass. Store K and V for all P tokens in the cache.
For each layer ℓ, head h:
K_cache[ℓ, h, 0:P] = K_{ℓ,h}[0:P]
V_cache[ℓ, h, 0:P] = V_{ℓ,h}[0:P]
Generation (each new token):
For token at position t = P, P+1, ..., T−1:
- Embed the new token to obtain x_t, the input to the first layer. Inside the layer loop, x_t denotes the current layer's input at position t, i.e. the previous layer's output.
- For each layer ℓ and head h:
  a. Compute q_{t,ℓ,h} = W_Q^{ℓ,h} · x_t (query, one token)
  b. Compute k_{t,ℓ,h} = W_K^{ℓ,h} · x_t (key, one token)
  c. Compute v_{t,ℓ,h} = W_V^{ℓ,h} · x_t (value, one token)
  d. Append to cache: K_cache[ℓ, h, t] = k_{t,ℓ,h}, V_cache[ℓ, h, t] = v_{t,ℓ,h}
  e. Compute attention scores: score = q_{t,ℓ,h} · K_cache[ℓ, h, 0:t+1]^T / √d_head
  f. Attention output = softmax(score) · V_cache[ℓ, h, 0:t+1]
  g. Combine heads and continue to the next sub-layer, then to layer ℓ+1
Key point in step e: the query q_{t,ℓ,h} (shape 1 × d_head) is multiplied against K_cache[ℓ, h, 0:t+1] (shape (t+1) × d_head). This single matrix-vector product costs O(t · d_head) and replaces recomputing attention for every position of the prefix, which costs O(t² · d_head).
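To check the correctness claim numerically, here is a toy sketch under strong simplifying assumptions: pure Python, random weights, each layer is only multi-head causal attention plus a residual connection (no output projection, FFN, or LayerNorm), and the prompt is prefilled token by token instead of in one parallel pass. All function names are hypothetical. Because each step only appends to the cache, the last-token output of the cached loop should match a full causal recomputation of the prefix.

```python
import math
import random


def matvec(W, x):
    # W is a list of rows; returns W · x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, K, V):
    # Scaled dot-product attention for one query over the rows of K and V.
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    m = max(scores)  # subtract the max for numerical stability
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))]


def make_weights(L, H, d_head, d, seed=0):
    # W[layer][head] = (W_Q, W_K, W_V), each d_head x d, random toy values.
    rng = random.Random(seed)

    def rand():
        return [[rng.gauss(0, 0.5) for _ in range(d)] for _ in range(d_head)]

    return [[(rand(), rand(), rand()) for _ in range(H)] for _ in range(L)]


def layer_out(x, head_outputs):
    # Concatenate heads and add a residual (no output projection, FFN, or LayerNorm).
    concat = [val for out in head_outputs for val in out]
    return [a + b for a, b in zip(x, concat)]


def full_recompute(W, X):
    """Naive decoding: rerun causal attention over the whole prefix X."""
    hs = X
    for heads in W:
        new_hs = []
        for t, x in enumerate(hs):
            outs = []
            for Wq, Wk, Wv in heads:
                K = [matvec(Wk, hs[i]) for i in range(t + 1)]  # causal: positions 0..t
                V = [matvec(Wv, hs[i]) for i in range(t + 1)]
                outs.append(attend(matvec(Wq, x), K, V))
            new_hs.append(layer_out(x, outs))
        hs = new_hs
    return hs[-1]  # final-layer output at the last position


def cached_step(W, cache, x):
    """One step: project only the new token, append k and v, attend over the cache."""
    for layer, heads in enumerate(W):
        outs = []
        for h, (Wq, Wk, Wv) in enumerate(heads):
            K_cache, V_cache = cache[layer][h]
            K_cache.append(matvec(Wk, x))  # steps b and d
            V_cache.append(matvec(Wv, x))  # steps c and d
            outs.append(attend(matvec(Wq, x), K_cache, V_cache))  # steps a, e, f
        x = layer_out(x, outs)  # step g
    return x


L, H, d_head = 3, 2, 2
d = H * d_head
W = make_weights(L, H, d_head, d)
rng = random.Random(1)
tokens = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(6)]  # stand-in embeddings
cache = [[([], []) for _ in range(H)] for _ in range(L)]

P = 3
for x in tokens[:P]:  # prefill (token by token here, for brevity)
    out = cached_step(W, cache, x)
for t in range(P, len(tokens)):  # decode
    out = cached_step(W, cache, tokens[t])
    ref = full_recompute(W, tokens[: t + 1])
    assert max(abs(a - b) for a, b in zip(out, ref)) < 1e-9
print("cached decode matches full recomputation")
```

Dropping the FFN and LayerNorm does not weaken the check for this argument: both act on each position independently, so they preserve the property that position i depends only on tokens 0..i.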
4. Memory Footprint Analysis
Per token: Each token stores H K-vectors and H V-vectors per layer, each of dimension d_head.
bytes_per_token = 2 · L · H · d_head · bytes_per_element
Where bytes_per_element = 2 for fp16, 1 for int8 quantization.
For Llama 2 7B: L=32, H=32, d_head=128, fp16: bytes_per_token = 2 · 32 · 32 · 128 · 2 = 524,288 bytes = 512 KB per token (here and below, KB, MB, and GB are binary units: 1 KB = 1024 bytes)
For a sequence of 4096 tokens (one batch element): memory = 4096 · 512 KB = 2 GB per sequence
For batch size 8 at seq_len 4096: 16 GB — this is why GPU memory is the bottleneck.
Scaling analysis: KV cache memory grows linearly with sequence length:
Memory(n) = 2 · B · L · H · d_head · n · bytes_per_elem
For typical models, at n = 4096 with B=1: 1-4 GB. At n = 128K (long context): 32-128 GB. This is the primary limiter on context length.
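The memory formula is easy to evaluate directly. This sketch assumes fp16 (2 bytes per element) and binary units (1 GB = 1024³ bytes) to match the numbers above; `kv_cache_bytes` and its `n_kv_heads` parameter are hypothetical names, with `n_kv_heads` equal to H for standard multi-head attention.

```python
GB = 1024**3  # binary units, matching the text


def kv_cache_bytes(n, L, n_kv_heads, d_head, B=1, bytes_per_elem=2):
    """Memory(n) = 2 · B · L · H · d_head · n · bytes_per_elem; the 2 counts K and V.

    n_kv_heads is H for standard multi-head attention (fewer under MQA/GQA).
    """
    return 2 * B * L * n_kv_heads * d_head * n * bytes_per_elem


llama2_7b = dict(L=32, n_kv_heads=32, d_head=128)  # fp16 by default
print(kv_cache_bytes(1, **llama2_7b))  # → 524288 bytes per token (512 KB)
print(kv_cache_bytes(4096, **llama2_7b) / GB)  # → 2.0
print(kv_cache_bytes(4096, B=8, **llama2_7b) / GB)  # → 16.0
print(kv_cache_bytes(128 * 1024, **llama2_7b) / GB)  # → 64.0
```

The 128K-token case comes to 64 GB for the 7B shape, inside the 32-128 GB range quoted above.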
5. Compute-vs-Memory Analysis
Compute per token (KV cached):
- Attention: O(H · d_head · n) = O(d · n) per layer per token (since H · d_head = d)
- FFN + projections: O(d²) per layer per token
- Total: L · O(d² + d·n) per token
For n ≪ d (short sequences), the O(d²) projection and FFN terms dominate the FLOPs. For n ≫ d (long sequences), the O(d·n) attention term dominates; it overtakes the d² terms once n exceeds d times a model-dependent constant.
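To see where the attention term overtakes the d² terms, the sketch below plugs in illustrative constants. These are assumptions, not figures from the text: 2 FLOPs per multiply-add, four d × d projections (Q, K, V, output), and a plain two-matrix FFN with hidden size 4d (gated FFNs such as Llama's use different constants). `decode_flops_per_token` is a hypothetical helper.

```python
def decode_flops_per_token(d, n, L, ffn_mult=4):
    """Rough FLOPs for one KV-cached decode step, split into dense and attention parts.

    Assumed constants: 2 FLOPs per multiply-add, four d x d projections
    (Q, K, V, output), and a two-matrix FFN with hidden size ffn_mult * d.
    """
    projections = 2 * 4 * d * d  # Q, K, V, output projections
    ffn = 2 * 2 * d * (ffn_mult * d)  # up- and down-projection
    attention = 2 * d * n + 2 * d * n  # q·K^T scores plus weights·V over n cached tokens
    return L * (projections + ffn), L * attention


d, L = 4096, 32
for n in (512, 4096, 32768):
    dense, attn = decode_flops_per_token(d, n, L)
    print(f"n={n:6d}  dense={dense:.2e}  attention={attn:.2e}  "
          f"attention share={attn / (dense + attn):.0%}")
```

Under these constants the dense work is 24·d² and the attention work is 4·d·n per layer, so the two cross at n = 6d (about 24.6K tokens for d = 4096).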
Memory bandwidth: For each token generated, we must read the ENTIRE KV cache from GPU HBM (high-bandwidth memory) and write one new entry.
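At n = 4096 with the Llama 2 7B shape (L=32, H=32, d_head=128, d = 4096) in fp16, the KV cache read per token per layer is 2 (K and V) · H · n · d_head · 2 bytes = 2 · 32 · 4096 · 128 · 2 = 67,108,864 bytes = 64 MB. With L=32: 2 GB read per token, the entire cache computed in Section 4.
If GPU HBM bandwidth is 2 TB/s, reading ~2 GB takes ~1 ms. The attention arithmetic for that token is only about 4 · L · d · n ≈ 2.1 GFLOP (scores plus the weighted sum of V, 2 FLOPs per multiply-add), roughly 1 FLOP per byte read, which a GPU sustaining 100-300 TFLOPS of fp16 finishes in roughly 0.01-0.02 ms. Decode attention is therefore memory-bandwidth-bound: over 95% of its time is spent moving KV cache data rather than computing. (At small batch sizes the model weights must also be streamed from HBM at every step, which adds further memory traffic.)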
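The bandwidth argument can be reproduced with a back-of-the-envelope calculation. This sketch assumes standard MHA with the Llama 2 7B shape, the 2 TB/s bandwidth from the text, and an assumed ~300 TFLOPS of fp16 throughput; it ignores weight reads and any overlap of compute with memory traffic. `decode_attention_time` is a hypothetical helper.

```python
def decode_attention_time(n, L, H, d_head, bandwidth=2e12, peak_flops=3e14, bytes_per_elem=2):
    """Per-token decode attention: bytes of KV cache streamed vs. FLOPs performed (MHA).

    bandwidth (bytes/s) is the 2 TB/s figure from the text; peak_flops (~300 TFLOPS
    fp16) is an assumed round number. Weight reads and compute/memory overlap are ignored.
    """
    cache_bytes = 2 * L * H * d_head * n * bytes_per_elem  # read all of K and V
    flops = 4 * L * H * d_head * n  # 2·d·n for q·K^T plus 2·d·n for weights·V, per layer
    return cache_bytes, flops, cache_bytes / bandwidth, flops / peak_flops


nbytes, flops, t_mem, t_compute = decode_attention_time(n=4096, L=32, H=32, d_head=128)
print(f"cache read: {nbytes / 1024**3:.2f} GB in {t_mem * 1e3:.2f} ms")
print(f"attention math: {flops / 1e9:.2f} GFLOP in {t_compute * 1e3:.4f} ms")
print(f"arithmetic intensity: {flops / nbytes:.1f} FLOP/byte")
```

With fp16 the arithmetic intensity works out to exactly 1 FLOP per byte for any n, because both the bytes read and the FLOPs scale as L · H · d_head · n.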
6. Multi-Query Attention (MQA) — KV Cache Compression
Standard multi-head attention: each head has its own K and V projections, so we cache H separate K,V per layer.
MQA (Shazeer, 2019): All heads share a SINGLE K and V projection. Only Q has multiple heads.
q_h = W_Q^h · x (h = 1..H, separate Q projections)
k = W_K · x (shared K)
v = W_V · x (shared V)
KV cache reduction: Factor of H. For H=32, cache size drops to 1/32.
Tradeoff: Slight quality reduction. The model has less representational capacity for attention patterns since all heads work with the same keys and values.
7. Grouped-Query Attention (GQA) — Middle Ground
GQA (Ainslie et al., 2023, used in Llama 2 70B): Heads are divided into G groups. Each group shares K and V.
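For each group g ∈ {1..G}:
k_g = W_K^g · x (shared K for heads in group g)
v_g = W_V^g · x (shared V for heads in group g)
Query head h ∈ {1..H} belongs to group g = ⌈h · G / H⌉, so consecutive blocks of H/G query heads share one K,V pair.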
KV cache reduction: Factor of H/G. For H=64, G=8: 8× reduction. Standard MHA is the special case G = H, and MQA is G = 1.
Llama 2 settings: 7B/13B use H=G (full MHA, no sharing). 70B uses G=8 with H=64, giving 8 K,V heads.
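The group assignment and the shared-cache read at decode time can be sketched as follows. Heads are 0-indexed here, so consecutive blocks of H/G query heads map to group h // (H/G); MHA (G = H) and MQA (G = 1) are the two extremes. `kv_group` and `gqa_decode` are hypothetical helper names, and the cache is a plain list per group rather than a real tensor layout.

```python
import math


def kv_group(h, H, G):
    # 0-indexed query head h -> 0-indexed K,V group; H must be a multiple of G.
    assert H % G == 0, "H must be a multiple of G"
    return h // (H // G)


def attend(q, K, V):
    # Scaled dot-product attention for one query over cached keys/values.
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))]


def gqa_decode(queries, k_cache, v_cache):
    """queries: H query vectors; k_cache/v_cache: G per-group lists of cached vectors."""
    H, G = len(queries), len(k_cache)
    return [attend(q, k_cache[kv_group(h, H, G)], v_cache[kv_group(h, H, G)])
            for h, q in enumerate(queries)]


# Group assignment for H=64 under MHA (G=H), GQA (G=8) and MQA (G=1).
for name, G in [("MHA", 64), ("GQA", 8), ("MQA", 1)]:
    print(name, "K,V heads cached:", G, "reduction:", 64 // G, "x",
          "head 9 -> group", kv_group(9, 64, G))

# Toy check: H=4 query heads, G=2 groups, one cached token per group.
out = gqa_decode([[1.0, 0.0]] * 4, k_cache=[[[1.0, 0.0]], [[0.0, 1.0]]],
                 v_cache=[[[5.0, 5.0]], [[7.0, 7.0]]])
print(out)  # → [[5.0, 5.0], [5.0, 5.0], [7.0, 7.0], [7.0, 7.0]]
```

Only G key and value lists are stored per layer, while all H queries still attend separately, which is where the H/G memory reduction comes from.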
Pitfalls
⚠️ Pitfall 1: Forgetting that the KV cache grows during decode. Each new token adds one entry per layer per K,V head. A 4096-token sequence produces a 2 GB cache for Llama 2 7B in fp16. Memory planning MUST budget for the longest sequence the cache can reach, not just the prompt length.
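⚠️ Pitfall 2: Confusing prefill with decode. Prefill processes the whole prompt in parallel and POPULATES the cache. Decode reads FROM the cache and appends one entry per step. Prefill is compute-bound: it multiplies the weights against all n prompt tokens at once (O(n·d²) projection and FFN work plus O(n²·d) attention). Decode is memory-bandwidth-bound: each step does little arithmetic per byte of cache (and weights) it reads.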
⚠️ Pitfall 3: Assuming MQA/GQA are free. They shrink the KV cache by making several query heads share one key and one value projection, so the model has fewer independent key/value projections than full MHA (8 instead of 64 for G=8, H=64). Each head still has its own query, but the quality-memory tradeoff is real.
Key Terms
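- KV cache (K_cache, V_cache): stored key and value vectors for every processed token, per layer and head
- Prefill: the single parallel forward pass over the prompt that populates the cache
- Decode: token-by-token generation that reads the cache and appends one entry per step
- bytes_per_token: 2 · L · H · d_head · bytes_per_element
- Memory-bandwidth-bound: limited by bytes moved from HBM rather than by arithmetic
- Multi-query attention (MQA): all query heads share one K and one V projection
- Grouped-query attention (GQA): query heads are split into G groups, each sharing one K and one V projection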
Worked Examples
Example 1: KV Cache Memory for Llama 2 7B
Problem: Compute the KV cache memory in GB for Llama 2 7B (L=32, H=32, d_head=128, fp16) at sequence lengths n = 2048, 4096, 8192, 32768. Single batch.
Solution:
bytes_per_token = 2 · L · H · d_head · 2 (fp16) = 2 · 32 · 32 · 128 · 2 = 524,288 bytes = 0.5 MB/token
| n (tokens) | KV cache memory (fp16, B=1) |
|---|---|
| 2048 | 2048 · 0.5 MB = 1.0 GB |
| 4096 | 4096 · 0.5 MB = 2.0 GB |
| 8192 | 8192 · 0.5 MB = 4.0 GB |
| 32768 | 32768 · 0.5 MB = 16.0 GB |
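For B=8 at n=4096: 16 GB of cache. With model weights of ~14 GB (7B params × 2 bytes in fp16), the total is ≈ 30 GB: it fits on a 40 GB A100 with little headroom (activations and framework overhead also need memory) and comfortably on an 80 GB A100.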
Example 2: MQA vs MHA Cache Comparison
Problem: A model has L=40, H=64, d_head=128, fp16. Compare KV cache size at n=4096 for standard MHA, GQA with G=8, and MQA.
Solution:
Standard MHA:
bytes_per_token = 2 · 40 · 64 · 128 · 2 = 1,310,720 bytes = 1.25 MB/token
Memory(n=4096) = 4096 · 1.25 MB = 5.0 GB
GQA (G=8): G = 8 K,V heads
bytes_per_token = 2 · 40 · 8 · 128 · 2 = 163,840 bytes = 0.156 MB/token
Memory(n=4096) = 4096 · 0.156 MB = 0.625 GB (8× reduction)
MQA (G=1): 1 K,V head
bytes_per_token = 2 · 40 · 1 · 128 · 2 = 20,480 bytes ≈ 0.020 MB/token
Memory(n=4096) = 4096 · 0.020 MB = 0.078 GB (64× reduction)
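The three cache sizes can be recomputed in a few lines. This sketch assumes fp16 and binary units and treats G as the number of K,V heads (G = H for MHA, G = 1 for MQA); `kv_bytes_per_token` is a hypothetical helper.

```python
MB, GB = 1024**2, 1024**3  # binary units, as in the text


def kv_bytes_per_token(L, n_kv_heads, d_head, bytes_per_elem=2):
    # 2 (K and V) · layers · K,V heads · head dim · bytes per element
    return 2 * L * n_kv_heads * d_head * bytes_per_elem


H, L, d_head, n = 64, 40, 128, 4096
for name, G in [("MHA", H), ("GQA", 8), ("MQA", 1)]:  # G = number of K,V heads
    per_token = kv_bytes_per_token(L, G, d_head)
    print(f"{name}: {per_token:,} B/token = {per_token / MB:.3f} MB/token, "
          f"{per_token * n / GB:.3f} GB at n={n}, {H // G}x smaller than MHA")
```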
Example 3: Compute Complexity With and Without Cache
Problem: For L=32, d=4096, seq_len growing from 100 (prompt) to 1100 (1000 generated tokens), compute total attention FLOPs with and without KV caching.
Solution:
Attention FLOPs per token at position t (per head, simplified): ~2 · t · d_head (t dot products of length d_head, counting 2 FLOPs per multiply-add; the weighted sum over V costs the same again and is omitted here)
With H heads: ~2 · H · t · d_head = 2 · d · t (since H·d_head = d)
Without caching (per layer, per step): rerunning the forward pass over the prefix computes the full t × t score matrix at step t: O(t²·d). For the entire generation: O(T³·d).
Simplified total without caching (counting all t² scores at step t, as a dense masked implementation computes them; skipping the masked half would halve this):
Σ_{t=100}^{1100} 2·d·t² = 2·4096·Σ t²
Σ_{t=100}^{1100} t² ≈ (1100³ − 100³)/3 ≈ (1.331×10^9 − 10^6)/3 ≈ 4.43×10^8 (exact sum: 443,943,500 ≈ 4.44×10^8)
Total ≈ 2·4096·4.44×10^8 ≈ 3.64×10^12 FLOPs per layer; with L=32: ≈ 1.16×10^14 FLOPs
With caching (per layer, per step): at step t, compute q·K^T over t cached keys: O(t·d). For the entire generation:
Σ_{t=100}^{1100} 2·d·t = 2·4096·Σ t
Σ_{t=100}^{1100} t = (100 + 1100)·1001/2 = 600,600
Total ≈ 2·4096·600,600 ≈ 4.92×10^9 FLOPs per layer; with L=32: ≈ 1.57×10^11 FLOPs
Reduction: 1.16×10^14 / 1.57×10^11 ≈ 740× reduction in attention FLOPs.
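Summing the per-step counts exactly is a useful check on the closed-form approximations. This sketch follows the simplified convention above (2 FLOPs per multiply-add, scores only, full t × t score matrix without a cache); `attention_flops` is a hypothetical helper.

```python
def attention_flops(d, L, t_start, t_end):
    """Simplified score FLOPs summed over steps t_start..t_end (inclusive), for L layers."""
    steps = range(t_start, t_end + 1)
    without_cache = L * sum(2 * d * t * t for t in steps)  # full t x t scores each step
    with_cache = L * sum(2 * d * t for t in steps)  # one query against t cached keys
    return without_cache, with_cache


no_cache, cached = attention_flops(d=4096, L=32, t_start=100, t_end=1100)
print(f"without cache: {no_cache:.3e} FLOPs")
print(f"with cache:    {cached:.3e} FLOPs")
print(f"reduction:     {no_cache / cached:.0f}x")
```

The exact ratio is Σt²/Σt ≈ 739, consistent with the ≈ 740× estimate.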
Quiz
Q1: Which statement about prefill and decode (Pitfall 2) is correct?
A) Prefill populates the KV cache in one parallel pass and is compute-bound; decode appends one entry per step and is memory-bandwidth-bound B) Prefill reads from the KV cache; decode populates it C) Both phases are memory-bandwidth-bound D) Prefill benefits from the KV cache more than decode does
Correct: A)
- If you chose A: Prefill computes K and V for every prompt token in parallel and stores them; each decode step reads the whole cache and appends one new entry. Correct!
- If you chose B: This is incorrect. The roles are reversed: prefill populates the cache, and decode reads from it and appends.
- If you chose C: This is incorrect. Prefill multiplies the weights against all prompt tokens at once, so its large matrix multiplications make it compute-bound; decode is the memory-bandwidth-bound phase.
- If you chose D: This is incorrect. Prefill has no earlier steps to reuse (Practice Problem 4); the cache pays off during decode.
Q2: According to the compute-vs-memory analysis, why is decode attention memory-bandwidth-bound?
A) Because attention with a KV cache still costs O(n²) FLOPs per new token B) Because the FFN dominates the FLOPs C) Because each step must stream the entire KV cache from HBM while performing only about 1 FLOP per byte read D) Because the KV cache is stored in CPU memory
Correct: C)
- If you chose A: This is incorrect. With a cache, attention per new token costs O(d·n) FLOPs, not O(n²).
- If you chose B: This is incorrect. FFN FLOPs are arithmetic work; they do not explain why memory traffic is the bottleneck.
- If you chose C: Reading ~2 GB of cache at 2 TB/s takes ~1 ms, while the ~2 GFLOP of attention arithmetic takes only ~0.01-0.02 ms. Correct!
- If you chose D: This is incorrect. The analysis assumes the cache lives in GPU HBM; even there, bandwidth is the limit.
Q3: In the KV cache algorithm, what is computed for layer ℓ and head h at decode step t?
A) K and V for all tokens 0..t, followed by full attention for every position B) Only the attention outputs for the prompt tokens C) q, k, and v for the new token only; k and v are appended to the cache, and q attends over K_cache and V_cache for positions 0..t D) Nothing new; the output is read directly from the cache
Correct: C)
- If you chose A: This is incorrect. That is the naive procedure the cache exists to avoid.
- If you chose B: This is incorrect. Prompt tokens are processed once, during prefill.
- If you chose C: Steps a-f project only the new token, append its key and value, and attend with one query over the t+1 cached positions. Correct!
- If you chose D: This is incorrect. The new token's q, k, and v must still be computed; only earlier tokens' K and V are reused.
Q4: In Practice Problem 5 (L=80, d_head=128, fp16, 40 GB of KV cache memory), how many tokens fit when GQA reduces the 64 K,V heads to 8, as in Llama 2 70B?
A) ~128K tokens B) ~16K tokens C) ~2K tokens D) ~1M tokens
Correct: A)
- If you chose A: With 8 K,V heads, bytes_per_token = 2 · 80 · 8 · 128 · 2 = 320 KB, and 40 GB / 320 KB = 131,072 ≈ 128K tokens. Correct!
- If you chose B: This is incorrect. ~16K tokens is the answer for full MHA with 64 K,V heads (2.5 MB/token).
- If you chose C: This is incorrect. This applies the factor of 8 in the wrong direction (2.5 MB × 8 = 20 MB/token).
- If you chose D: This is incorrect. ~1M tokens corresponds to MQA with a single K,V head (40 KB/token).
Q5: Why is it valid for the KV cache algorithm to reuse the cached K and V vectors of earlier tokens?
A) Because W_K and W_V are updated after every generated token B) Because K and V do not depend on the input tokens C) Because the softmax normalizes away any change in earlier keys D) Because the causal mask makes each earlier token's hidden state, and hence its K and V at every layer, depend only on tokens up to its own position
Correct: D)
- If you chose A: This is incorrect. Weights are fixed during inference; if they changed, cached values would be stale.
- If you chose B: This is incorrect. K and V are projections of each token's layer input, so they depend on the tokens.
- If you chose C: This is incorrect. Softmax does not cancel changes in the keys; the keys simply never change.
- If you chose D: Appending later tokens cannot affect positions before them, so each k_i and v_i is fixed once computed. Correct!
Q6: What is a common mistake in KV cache memory planning (Pitfall 1)?
A) Assuming cache memory grows linearly with sequence length B) Sizing the cache for the prompt alone, forgetting that it grows by one entry per layer per K,V head with every generated token C) Counting both K and V in bytes_per_token D) Multiplying by the batch size B for batched inference
Correct: B)
- If you chose A: This is incorrect. Linear growth in n is correct: Memory(n) = 2 · B · L · H · d_head · n · bytes_per_elem.
- If you chose B: Memory planning must budget for the longest sequence the cache can reach. Correct!
- If you chose C: This is incorrect. The factor of 2 for K and V is required.
- If you chose D: This is incorrect. Each sequence in a batch has its own cache, so the total scales with B.
Q7: In naive autoregressive inference (no KV cache), how does the total K,V projection work for generating T tokens scale, per layer?
A) O(1) B) O(log T) C) O(T) D) O(T²), since step t recomputes K and V for all t tokens: Σ t = T(T+1)/2
Correct: D)
- If you chose A: This is incorrect. Every step does at least one projection, so the work grows with T.
- If you chose B: This is incorrect. Each step touches the whole prefix, so the work grows much faster than log T.
- If you chose C: This is incorrect. O(T) is the cost with a KV cache, where each step projects only the new token.
- If you chose D: Recomputing K and V for the whole prefix at every step gives T(T+1)/2 projections per layer. Correct!
Practice Problems
Problem 1
For a model with L=24, H=16, d_head=64, compute the KV cache size per token in KB (fp16). How much memory for 2048 tokens?
Answer
bytes_per_token = 2 · 24 · 16 · 64 · 2 = 98,304 bytes = 96 KB/token
Memory(2048) = 2048 · 96 KB = 196,608 KB = 192 MB
Problem 2
With GQA, if H=32 and G=4, how many distinct K,V projections exist per layer? By what factor is the KV cache reduced vs. standard MHA?
Answer
Distinct K,V projections = G = 4 (each shared by H/G = 8 heads). Reduction factor = H/G = 32/4 = 8×. The Q projections remain at 32 (one per head); only K and V are grouped.
Problem 3
A GPU has 80 GB HBM and 2 TB/s bandwidth. The KV cache is 4 GB. Approximately what fraction of the per-token time is spent reading the cache, if every decode step reads the full cache and ~4 ms per token goes to other, non-overlapped operations?
Answer
Time to read 4 GB at 2 TB/s ≈ 4/2000 s = 2 ms. With ~4 ms of other work that does not overlap with the cache read, the per-token time is about 6 ms, so the cache read accounts for 2/6 ≈ 33% of it. Adding ~0.02 ms of attention arithmetic barely changes this (2/6.02 ≈ 33%). The cache read is mandatory and grows linearly with sequence length while the other work does not, so for long enough sequences it becomes the dominant cost.
Problem 4
Explain why the prefill phase (processing the prompt) does NOT benefit from KV caching but the decode phase does.
Answer
During prefill, the model sees the entire prompt at once. All K and V are computed in parallel (because the prompt is known). There's nothing to "cache" from previous steps because there are no previous steps: this is the first and only forward pass over the prompt. The KV cache is POPULATED during prefill (we store all K,V for future use) but doesn't help prefill itself. During decode, we reuse the cached K,V from prefill and previous decode steps, avoiding recomputation.
Problem 5
For a model with L=80, H=64, d_head=128 (like Llama 70B), what is the maximum sequence length that fits in 40 GB of KV cache memory with batch size 1 and fp16?
Answer
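bytes_per_token = 2 · 80 · 64 · 128 · 2 = 2,621,440 bytes = 2.5 MB/token
max_n = 40 GB / 2.5 MB = 40,960 MB / 2.5 MB = 16,384 tokens (≈16K)
Llama 2 70B actually uses GQA with 8 K,V heads (G=8), which shrinks the cache 8× to 2 · 80 · 8 · 128 · 2 = 327,680 bytes = 320 KB/token. The same 40 GB then holds 40 GB / 320 KB = 131,072 tokens, i.e. ~128K tokens in the same memory. (This is memory capacity only; the model's trained context window, 4096 tokens for Llama 2, is a separate limit.)
Summary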
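- KV caching stores the key and value vectors for all processed tokens, so each generation step computes K and V only for the new token instead of recomputing them for the whole prefix (O(n²) total projection work over a generation becomes O(n))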
- Cache memory: 2 · L · H · d_head · n · bytes_per_element per sequence; grows linearly with n and is the primary memory bottleneck for long contexts
- With caching, per-token compute is O(L·d² + L·d·n) — but inference is memory-bandwidth-bound because the entire cache must be read each step
- MQA (1 shared K,V) reduces KV cache by factor H; GQA (G shared K,V groups) reduces by factor H/G, balancing quality and memory
- Prefill populates the cache (one parallel forward pass over the prompt); decode reads from the cache and appends one entry per step (sequential; each step's cost grows linearly with n instead of recomputing the prefix)
Next Steps
Continue to 18-10 — Transformer Variants (Math Focus) for a mathematical treatment of sparse attention, FlashAttention, and Mixture of Experts.