18-09 — KV Caching

Phase: 18 — Large Language Model Mathematics
Subject: 18-09
Prerequisites: 18-08 (Inference Mathematics), 18-05 (Decoder-Only Architecture), 17-07 (Scaled Dot-Product Attention), 09-09 (Numerical Linear Algebra, memory analysis)
Next subject: 18-10 — Transformer Variants (Math Focus)


Learning Objectives

By the end of this subject, you will be able to:

  1. Derive the KV cache mechanism from the causal attention equations and prove correctness
  2. Compute the exact memory footprint of the KV cache as a function of batch size, sequence length, layers, heads, and head dimension
  3. Analyze the compute-to-memory ratio and explain why LLM inference is memory-bandwidth-bound
  4. Compare multi-query attention (MQA) and grouped-query attention (GQA) in terms of KV cache reduction
  5. Derive the computational complexity reduction: O(L·n²) → O(L·n) per new token with KV caching

Core Content

1. The Problem: Naive Autoregressive Inference

Without caching, generating token t requires computing attention over ALL previous t tokens from scratch:

At step t (generating token t):

For each layer ℓ:
  1. Compute q_t = W_Q · x_t (query for the new token)
  2. Compute K_t = W_K · X_{0:t} (keys for ALL tokens 0..t)
  3. Compute V_t = W_V · X_{0:t} (values for ALL tokens 0..t)
  4. Compute attention: A = softmax(q_t · K_t^T / √d_k) · V_t

This recomputes K and V for tokens 0..t−1 at EVERY step. The total work for generating T tokens:

Total K,V computations = Σ_{t=1}^{T} t = T(T+1)/2 = O(T²) per layer

With L layers and model dimension d: total FLOPs ≈ L · d² · T² (dominated by the K,V projections repeated for all prefixes).
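The triangular-number count above can be checked with a few lines of Python. This is a minimal sketch: the function name `kv_projection_counts` is hypothetical, and one "unit" of work is taken to mean projecting a single token to its key and value vectors in one layer.

```python
def kv_projection_counts(T: int):
    """Count per-layer K,V projection evaluations for generating T tokens.

    Returns (without_cache, with_cache). One unit of work is projecting
    one token to its key and value vectors.
    """
    # Without caching, step t re-projects all t tokens seen so far.
    without_cache = sum(t for t in range(1, T + 1))
    # With caching, each token is projected exactly once.
    with_cache = T
    return without_cache, with_cache


naive, cached = kv_projection_counts(1000)
print(naive, cached, naive / cached)
```

For T = 1000 the naive count is T(T+1)/2 = 500,500 projections against 1,000 with a cache, so the gap itself grows linearly in T.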

⚠️ THIS IS CRITICAL: Without caching, the K,V projection work for generating T tokens costs O(T²) per layer, where T can be thousands. KV caching reduces this to O(T), because each token is projected exactly once.

2. The KV Cache: Core Idea

Observation: The key and value vectors for previous tokens DON'T CHANGE when we add new tokens. The causal mask ensures that token i's attention only depends on tokens ≤ i, and those tokens' K and V representations are fixed once computed.

Solution: Store (cache) the K and V vectors for all previously processed tokens. At each new step, compute only the K and V for the new token, append to the cache, and compute attention using the full cache.
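The observation can be written as a short derivation. The notation x_i^{(ℓ)} for the input to layer ℓ at position i is introduced here for illustration, and the block is a sketch of the argument for a single head rather than a complete proof:

```latex
% Causal attention output at position t in layer \ell (single head):
\[
  o_t^{(\ell)} = \sum_{i=0}^{t} \alpha_{t,i}\, v_i^{(\ell)},
  \qquad
  \alpha_{t,i} =
  \frac{\exp\!\big(q_t^{(\ell)} \cdot k_i^{(\ell)} / \sqrt{d_{\text{head}}}\big)}
       {\sum_{j=0}^{t} \exp\!\big(q_t^{(\ell)} \cdot k_j^{(\ell)} / \sqrt{d_{\text{head}}}\big)}
\]
% Keys and values are per-position projections of the layer input:
\[
  k_i^{(\ell)} = W_K^{(\ell)} x_i^{(\ell)},
  \qquad
  v_i^{(\ell)} = W_V^{(\ell)} x_i^{(\ell)}
\]
% Induction on \ell: x_i^{(0)} is the embedding of token i, and x_i^{(\ell+1)}
% is a function of x_i^{(\ell)} and o_i^{(\ell)} only, hence of tokens 0..i only.
% Therefore k_i^{(\ell)} and v_i^{(\ell)} are unchanged when tokens t > i are
% appended, and reading them from the cache yields the same o_t^{(\ell)} as
% recomputing them from scratch.
```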

Cache structure: For a model with:
  - L layers
  - H heads per layer
  - d_head dimensions per head
  - Sequence length n (growing)

The KV cache stores:

K_cache ∈ ℝ^(L × H × n × d_head) (keys)
V_cache ∈ ℝ^(L × H × n × d_head) (values)

That is 2 · L · H · n · d_head elements per sequence (the factor 2 covers K and V); multiply by batch size B for batched inference.

3. KV Cache Algorithm

Initialization (prefill): Process the prompt of length P in one forward pass. Store K and V for all P tokens in the cache.

For each layer ℓ, head h:
  K_cache[ℓ, h, 0:P] = K_{ℓ,h}[0:P]
  V_cache[ℓ, h, 0:P] = V_{ℓ,h}[0:P]

Generation (each new token):

For token at position t = P, P+1, ..., T−1:

  1. Form the layer input x_t for the new token: the token embedding at the first layer, and the previous layer's output at every later layer
  2. For each layer ℓ and each head h:
     a. Compute q_{t,ℓ,h} = W_Q^{ℓ,h} · x_t (query, one token)
     b. Compute k_{t,ℓ,h} = W_K^{ℓ,h} · x_t (key, one token)
     c. Compute v_{t,ℓ,h} = W_V^{ℓ,h} · x_t (value, one token)
     d. Append to cache: K_cache[ℓ, h, t] = k_{t,ℓ,h}, V_cache[ℓ, h, t] = v_{t,ℓ,h}
     e. Compute attention: score = q_{t,ℓ,h} · K_cache[ℓ, h, 0:t+1]^T / √d_head
     f. Attention output = softmax(score) · V_cache[ℓ, h, 0:t+1]
     g. Combine heads and continue to the next sub-layer

Key point in Step 2e: The query q_t (shape: 1 × d_head) is multiplied against K_cache[ℓ, h, 0:t+1] (shape: (t+1) × d_head). This single vector-matrix product, t+1 dot products in total, replaces the full O(t²) attention computation.
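The algorithm can be checked against naive recomputation on a toy problem. The sketch below is pure Python and makes several simplifying assumptions: a single layer and a single head, random weights, fixed input vectors `xs` standing in for the layer inputs, and a prompt that is processed token by token instead of in one parallel prefill pass. All function names are hypothetical.

```python
import math
import random


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, keys, values):
    """Single-head attention of one query over lists of key/value vectors."""
    d_head = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_head) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w / total * v[j] for w, v in zip(weights, values))
            for j in range(len(values[0]))]


def decode_naive(W_Q, W_K, W_V, xs):
    """Recompute K and V for every visible token at every step (no cache)."""
    outputs = []
    for t in range(len(xs)):
        keys = [matvec(W_K, x) for x in xs[: t + 1]]
        values = [matvec(W_V, x) for x in xs[: t + 1]]
        outputs.append(attend(matvec(W_Q, xs[t]), keys, values))
    return outputs


def decode_cached(W_Q, W_K, W_V, xs):
    """Project only the new token, append it to the cache, then attend."""
    k_cache, v_cache, outputs = [], [], []
    for x_t in xs:
        k_cache.append(matvec(W_K, x_t))  # steps 2b and 2d
        v_cache.append(matvec(W_V, x_t))  # steps 2c and 2d
        outputs.append(attend(matvec(W_Q, x_t), k_cache, v_cache))  # steps 2e, 2f
    return outputs


def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
d_model, d_head, n_tokens = 8, 4, 6
W_Q, W_K, W_V = (rand_matrix(d_head, d_model) for _ in range(3))
xs = rand_matrix(n_tokens, d_model)  # one input vector per token

naive_out = decode_naive(W_Q, W_K, W_V, xs)
cached_out = decode_cached(W_Q, W_K, W_V, xs)
max_diff = max(abs(a - b)
               for row_a, row_b in zip(naive_out, cached_out)
               for a, b in zip(row_a, row_b))
print("max difference between naive and cached outputs:", max_diff)
```

Both paths perform the same arithmetic on the same keys and values, so the outputs agree; the cached path simply skips the repeated projections.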

4. Memory Footprint Analysis

Per token: Each token stores H K-vectors and H V-vectors per layer, each of dimension d_head.

bytes_per_token = 2 · L · H · d_head · bytes_per_element

Where bytes_per_element = 2 for fp16, 1 for int8 quantization.

For Llama 2 7B (L=32, H=32, d_head=128, fp16):

bytes_per_token = 2 · 32 · 32 · 128 · 2 = 524,288 bytes = 512 KB per token

For a sequence of 4096 tokens (one batch element): memory = 4096 · 512 KB = 2 GB per sequence

For batch size 8 at seq_len 4096: 16 GB — this is why GPU memory is the bottleneck.

Scaling analysis: KV cache memory grows linearly with sequence length:

Memory(n) = 2 · B · L · H · d_head · n · bytes_per_elem

For typical models, at n = 4096 with B=1: 1-4 GB. At n = 128K (long context): 32-128 GB. This is the primary limiter on context length.
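The Memory(n) formula translates directly into a small helper. This is a sketch with a hypothetical function name, `kv_cache_bytes`; note that the "KB" and "GB" figures in the text are binary units (1 GB = 1024³ bytes), which the code makes explicit.

```python
def kv_cache_bytes(batch, layers, kv_heads, d_head, seq_len, bytes_per_elem=2):
    """KV cache size in bytes: 2 (K and V) * B * L * H * d_head * n * bytes."""
    return 2 * batch * layers * kv_heads * d_head * seq_len * bytes_per_elem


GIB = 1024 ** 3

# Llama 2 7B figures from the text: L=32, H=32, d_head=128, fp16 (2 bytes).
for n in (2048, 4096, 8192, 32768):
    size = kv_cache_bytes(1, 32, 32, 128, n)
    print(f"n={n:>6}: {size / GIB:.1f} GiB")

# Batch size 8 at n=4096, the case called out in the text.
print(f"B=8, n=4096: {kv_cache_bytes(8, 32, 32, 128, 4096) / GIB:.1f} GiB")
```

Passing `bytes_per_elem=1` models the int8 case mentioned above and halves every figure.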

5. Compute-vs-Memory Analysis

Compute per token (KV cached):
  - Attention: O(H · d_head · n) ≈ O(d · n) per layer per token
  - FFN + projections: O(d²) per layer per token
  - Total: L · O(d² + d·n) per token

For n ≪ d (short sequences): the d² term dominates, so the projections and FFN set the cost
For n ≫ d (long sequences): the d·n term dominates, so attention over the cache sets the cost (the two terms are equal at n = d)
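To see where the balance between the two terms shifts, the following sketch compares them for a fixed d. It keeps only the order-of-magnitude terms from the text and drops all constant factors, so the crossover it shows is indicative only; the function name `per_layer_cost_terms` is hypothetical.

```python
def per_layer_cost_terms(d_model: int, n: int) -> dict:
    """Return the two per-layer cost terms from the text, constants dropped."""
    projection_ffn = d_model ** 2  # O(d^2): projections and FFN
    attention = d_model * n        # O(d*n): scores and weighted sum over n cached tokens
    return {
        "projection_ffn": projection_ffn,
        "attention": attention,
        "attention_share": attention / (projection_ffn + attention),
    }


for n in (512, 4096, 32768):
    terms = per_layer_cost_terms(4096, n)
    print(f"n={n:>6}: attention share of per-token cost = {terms['attention_share']:.3f}")
```

With d = 4096 the attention term is a small share at n = 512, exactly half at n = d, and the dominant share at n = 32768.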

Memory bandwidth: For each token generated, we must read the ENTIRE KV cache from GPU HBM (high-bandwidth memory) and write one new entry.

At n = 4096, d = 4096, fp16:
  KV cache read per token per layer: ~2 · H · n · d_head · 2 bytes = 2 · 32 · 4096 · 128 · 2 ≈ 67.1 MB per layer (keys and values together)
  With L=32: ~2.15 GB read per token, which is the same cache as the 2 GB per-sequence figure in Section 4, expressed in decimal units

If GPU HBM bandwidth is 2 TB/s, reading ~2.15 GB takes ~1.07 ms. The actual compute for that token's attention is ~0.02 ms. Attention during decode is memory-bandwidth-bound: ~95%+ of its time is spent waiting for KV cache data transfer.
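The read-time estimate is a single division, sketched below. The helper name `read_time_ms` is hypothetical, the 2 TB/s bandwidth is the figure assumed in the text, and the cache size used is the full per-sequence figure from Section 4 (4096 tokens at 512 KB per token, keys and values together). The result is a lower bound that ignores latency and any overlap with compute.

```python
def read_time_ms(num_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """Lower bound on the time to stream num_bytes from memory, in milliseconds."""
    return num_bytes / bandwidth_bytes_per_s * 1e3


cache_bytes = 4096 * 524_288  # 4096 tokens at 512 KB per token (Section 4)
hbm_bandwidth = 2e12          # 2 TB/s, as assumed in the text

print(f"cache size: {cache_bytes / 1e9:.2f} GB")
print(f"read time:  {read_time_ms(cache_bytes, hbm_bandwidth):.2f} ms per generated token")
```

Because the cache grows by one entry per generated token, this per-token read time increases linearly over the course of decoding.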

6. Multi-Query Attention (MQA) — KV Cache Compression

Standard multi-head attention: each head has its own K and V projections, so we cache H separate K,V per layer.

MQA (Shazeer, 2019): All heads share a SINGLE K and V projection. Only Q has multiple heads.

q_h = W_Q^h · x (h = 1..H, separate Q projections)
k = W_K · x (shared K)
v = W_V · x (shared V)

KV cache reduction: Factor of H. For H=32, cache size drops to 1/32.

Tradeoff: Slight quality reduction. The model has less representational capacity for attention patterns since all heads work with the same keys and values.

7. Grouped-Query Attention (GQA) — Middle Ground

GQA (Ainslie et al., 2023, used in Llama 2 70B): Heads are divided into G groups. Each group shares K and V.

For each group g ∈ {1..G}:
  k_g = W_K^g · x (shared K for heads in group g)
  v_g = W_V^g · x (shared V for heads in group g)

KV cache reduction: Factor of H/G. For H=64, G=8: 8× reduction.

Llama 2 settings: 7B/13B use H=G (full MHA, no sharing). 70B uses G=8 with H=64, giving 8 K,V heads.
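The relationship between MHA, GQA, and MQA can be sketched as a mapping from query heads to shared K,V groups. The code below assumes contiguous grouping (heads 0..H/G−1 share group 0, and so on), which is an illustrative choice and not something the text specifies; both function names are hypothetical.

```python
def kv_group_for_head(head: int, num_heads: int, num_groups: int) -> int:
    """Index of the shared K,V group used by a query head (contiguous grouping)."""
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    heads_per_group = num_heads // num_groups
    return head // heads_per_group


def cache_reduction_factor(num_heads: int, num_groups: int) -> float:
    """KV cache shrinks by H/G relative to full multi-head attention."""
    return num_heads / num_groups


H = 64
for name, G in (("MHA", 64), ("GQA", 8), ("MQA", 1)):
    groups = [kv_group_for_head(h, H, G) for h in range(H)]
    print(f"{name}: distinct K,V heads = {len(set(groups))}, "
          f"cache reduction = {cache_reduction_factor(H, G):.0f}x")
```

Setting G = H recovers standard MHA and G = 1 recovers MQA, so GQA covers the whole range with one parameter.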



Pitfalls

⚠️ Pitfall 1: Forgetting that the KV cache grows during decode. Each new token adds one entry per layer per head. A 4096-token sequence produces a 2 GB cache for a 7B model. Memory planning MUST account for this growth: it is not a fixed allocation.

⚠️ Pitfall 2: Confusing prefill with decode. Prefill processes the whole prompt in parallel and POPULATES the cache. Decode reads FROM the cache and appends one entry. Prefill is compute-bound (O(n²) attention); decode is memory-bandwidth-bound (reading the growing cache).

⚠️ Pitfall 3: Assuming MQA/GQA are free. While they reduce KV cache, they also reduce the attention's representational capacity. A model with GQA (G=8, H=64) can express fewer distinct attention patterns than full MHA. The quality-memory tradeoff is real.


Worked Examples

Example 1: KV Cache Memory for Llama 2 7B

Problem: Compute the KV cache memory in GB for Llama 2 7B (L=32, H=32, d_head=128, fp16) at sequence lengths n = 2048, 4096, 8192, 32768. Single batch.

Solution:

bytes_per_token = 2 · L · H · d_head · 2 (fp16) = 2 · 32 · 32 · 128 · 2 = 524,288 bytes = 0.5 MB/token

n        Memory
2048     2048 · 0.5 MB = 1.0 GB
4096     4096 · 0.5 MB = 2.0 GB
8192     8192 · 0.5 MB = 4.0 GB
32768    32768 · 0.5 MB = 16.0 GB

For B=8 at n=4096: 16 GB. With model weights ~14 GB (7B params × 2 bytes fp16), total ≈ 30 GB — fits in an A100 (40/80GB) but tight.

Example 2: MQA vs MHA Cache Comparison

Problem: A model has L=40, H=64, d_head=128, fp16. Compare KV cache size at n=4096 for standard MHA, GQA with G=8, and MQA.

Solution:

Standard MHA:
  bytes_per_token = 2 · 40 · 64 · 128 · 2 = 1,310,720 bytes = 1.25 MB/token
  Memory(n=4096) = 4096 · 1.25 MB = 5.0 GB

GQA (G=8): G = 8 K,V heads, each shared by H/G = 8 query heads
  bytes_per_token = 2 · 40 · 8 · 128 · 2 = 163,840 bytes ≈ 0.156 MB/token
  Memory(n=4096) = 4096 · 0.156 MB ≈ 0.625 GB (8× reduction)

MQA (G=1): 1 K,V head
  bytes_per_token = 2 · 40 · 1 · 128 · 2 = 20,480 bytes ≈ 0.020 MB/token
  Memory(n=4096) = 4096 · 0.020 MB ≈ 0.078 GB (64× reduction)

Example 3: Compute Complexity With and Without Cache

Problem: For L=32, d=4096, seq_len growing from 100 (prompt) to 1100 (1000 generated tokens), compute total attention FLOPs with and without KV caching.

Solution:

Attention FLOPs per token at position t (per head, simplified): ~2 · t · d_head (t dot products against the cached keys, each costing d_head multiply-adds)

With H heads: ~2 · H · t · d_head = 2 · d · t (since H·d_head = d)

Without caching (per layer, per step): At step t, compute ALL pairwise dot products: O(t²·d). For the entire generation: O(T³·d).

Simplified without caching total: ~Σ_{t=100}^{1100} 2·d·t² ≈ 2·4096·Σ t²
  Σ t² from 100 to 1100 ≈ (1100³ − 100³)/3 ≈ (1.331×10^9 − 10^6)/3 ≈ 4.44×10^8
  FLOPs ≈ 2·4096·4.44×10^8 ≈ 3.64×10^12 per layer
  With L=32: ≈ 1.16×10^14 FLOPs

With caching (per layer, per step): At step t, compute q·K^T over t keys: O(t·d).
For the entire generation: ~Σ_{t=100}^{1100} 2·d·t ≈ 2·4096·Σ t
  Σ t from 100 to 1100 = (100+1100)·1001/2 = 600,600
  FLOPs ≈ 2·4096·600,600 ≈ 4.92×10^9 per layer
  With L=32: ≈ 1.57×10^11 FLOPs

Reduction: 1.16×10^14 / 1.57×10^11 ≈ 740× reduction in attention FLOPs.
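The totals in this example can be reproduced with exact sums in place of the integral approximation. The sketch below uses the same simplified per-step models as the example (2·d·t² without a cache, 2·d·t with one) and is only as accurate as those models.

```python
L, d = 32, 4096
steps = range(100, 1101)  # positions t = 100..1100 inclusive

# Simplified per-layer, per-step models from the example above.
flops_no_cache = L * sum(2 * d * t * t for t in steps)
flops_cached = L * sum(2 * d * t for t in steps)
ratio = flops_no_cache / flops_cached

print(f"without cache: {flops_no_cache:.3e} FLOPs")
print(f"with cache:    {flops_cached:.3e} FLOPs")
print(f"reduction:     {ratio:.0f}x")
```

The exact sums agree with the approximated figures to three significant digits, so the integral shortcut for Σ t² is adequate at this scale.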



Quiz

Q1: What does the concept of Common Pitfalls primarily refer to in this subject?

A) The definition and application of Common Pitfalls
B) A visual representation of Common Pitfalls
C) A computational error related to Common Pitfalls
D) A historical anecdote about Common Pitfalls

Correct: A)

Q2: What is the primary purpose of Compute-vs-Memory Analysis?

A) It is primarily a historical notation system
B) It replaces all other methods in this domain
C) It is used to compare per-token compute with memory traffic and show why inference is memory-bandwidth-bound
D) It is used only in advanced research contexts

Correct: C)

Q3: Which statement about KV Cache Algorithm is TRUE?

A) KV Cache Algorithm is mentioned only as a historical footnote
B) KV Cache Algorithm is not related to this subject
C) KV Cache Algorithm is a fundamental concept covered in this subject
D) KV Cache Algorithm is an advanced topic beyond this subject's scope

Correct: C)

Q4: In Practice Problem 5 (L=80, H=64, d_head=128, fp16, 40 GB of cache), about how many tokens fit once GQA with G=8 is applied?

A) ~128K tokens in the same memory.
B) ~16K tokens (the full-MHA figure, with no reduction applied)
C) ~2K tokens (the H/G factor applied in the wrong direction)
D) An unrelated numerical value

Correct: A)

Q5: How are KV Cache Algorithm and KV cache related?

A) KV Cache Algorithm is a special case of KV cache
B) KV Cache Algorithm is the inverse of KV cache
C) KV Cache Algorithm and KV cache are completely unrelated topics
D) KV Cache Algorithm and KV cache are closely related concepts

Correct: D)

Q6: What is a common pitfall when working with Memory Footprint Analysis?

A) Memory Footprint Analysis has no common misconceptions
B) A common mistake is confusing Memory Footprint Analysis with a similar concept
C) Memory Footprint Analysis is always computed the same way in all contexts
D) The main error with Memory Footprint Analysis is using it when it is not needed

Correct: B)

Q7: When should you apply The Problem: Naive Autoregressive Inference?

A) Avoid The Problem: Naive Autoregressive Inference unless explicitly instructed
B) The Problem: Naive Autoregressive Inference is not practically useful
C) Use The Problem: Naive Autoregressive Inference only in pure mathematics contexts
D) Apply The Problem: Naive Autoregressive Inference to solve problems in this subject's domain

Correct: D)

Practice Problems

Problem 1

For a model with L=24, H=16, d_head=64, compute the KV cache size per token in KB (fp16). How much memory for 2048 tokens?

Answer:
bytes_per_token = 2 · 24 · 16 · 64 · 2 = 98,304 bytes = 96 KB/token
Memory(2048) = 2048 · 96 KB = 196,608 KB = 192 MB

Problem 2

With GQA, if H=32 and G=4, how many distinct K,V projections exist per layer? By what factor is the KV cache reduced vs. standard MHA?

Answer:
Distinct K,V projections = G = 4 (each shared by H/G = 8 heads). Reduction factor = H/G = 32/4 = 8×. The Q projections remain at 32 (one per head); only K and V are grouped.

Problem 3

A GPU has 80 GB HBM and 2 TB/s bandwidth. The KV cache is 4 GB. Approximately what fraction of the per-token time is spent reading the KV cache if the attention computation requires reading the full cache? Assume ~4 ms spent on other operations.

Answer:
Time to read 4 GB at 2 TB/s = 4/2000 = 0.002 seconds = 2 ms. If total per-token time is 6 ms (2 ms cache read + 4 ms other), the cache read accounts for 2/6 ≈ 33%. Adding the ~0.02 ms of actual attention compute barely changes this: 2/(2+0.02+4) ≈ 33%. Within attention itself, the 2 ms read is about 100× the compute time, so the cache read is the bottleneck there. Reading the entire cache is mandatory; the compute can overlap somewhat, but the read time grows with sequence length and becomes the limiting factor for long sequences.

Problem 4

Explain why the prefill phase (processing the prompt) does NOT benefit from KV caching but the decode phase does.

Answer:
During prefill, the model sees the entire prompt at once. All K and V are computed in parallel (because the prompt is known). There's nothing to "cache" from previous steps because there are no previous steps: this is the first and only forward pass over the prompt. The KV cache is POPULATED during prefill (we store all K,V for future use) but doesn't help prefill itself. During decode, we reuse the cached K,V from prefill and previous decode steps, avoiding recomputation.

Problem 5

For a model with L=80, H=64, d_head=128 (like Llama 70B), what is the maximum sequence length that fits in 40 GB of KV cache memory with batch size 1 and fp16?

Answer:
bytes_per_token = 2 · 80 · 64 · 128 · 2 = 2,621,440 bytes = 2.5 MB/token
max_n = 40 GB / 2.5 MB ≈ 40,000 MB / 2.5 MB = 16,000 tokens
With GQA at G=8 (as used by Llama 2 70B), the cache per token shrinks to 1/8, so the same 40 GB holds about 8× as many tokens: ~128K tokens in the same memory.

Summary

  1. KV caching stores the key and value vectors for all processed tokens, avoiding the recomputation of K and V for every previous token at each generation step (O(n²) projections over a full generation, reduced to O(n))
  2. Cache memory: 2 · L · H · d_head · n · bytes_per_element per sequence; grows linearly with n and is the primary memory bottleneck for long contexts
  3. With caching, per-token compute is O(L·d² + L·d·n) — but inference is memory-bandwidth-bound because the entire cache must be read each step
  4. MQA (1 shared K,V) reduces KV cache by factor H; GQA (G shared K,V groups) reduces by factor H/G, balancing quality and memory
  5. Prefill populates the cache (one parallel forward pass over the prompt); decode reads from the cache sequentially, projecting K and V for only the one new token per step (O(1) projection work, O(n) attention over the cache)


Next Steps

Continue to 18-10 — Transformer Variants (Math Focus) for a mathematical treatment of sparse attention, FlashAttention, and Mixture of Experts.