
19-05 — Knowledge Distillation

Phase: 19 — Advanced LLM Mathematics
Subject: 19-05
Prerequisites: 19-04 (Pruning), 16-04 (Loss Functions — CCE, KL), 13-04 (KL Divergence — deep), 13-05 (Cross Entropy), 16-03 (Softmax — temperature), 14-02 (Gradient Descent)
Next subject: 19-06 — Speculative Decoding


Learning Objectives

By the end of this subject, you will be able to:

  1. Derive the full knowledge distillation loss L = (1−α)·L_CE + α·T²·D_KL(p_t||p_s) and explain the role of each term, including the T² gradient scaling factor
  2. Prove that in the high-temperature limit (T → ∞), distillation reduces to matching the teacher's logit covariance structure, and in the zero-temperature limit it reduces to standard supervised training
  3. Formulate and compare at least three distillation variants: logit-based (Hinton), feature-based (FitNet), and relation-based (RKD), stating the loss function for each
  4. Analyze why distilling from a larger model can outperform training from scratch on the same data, using the "dark knowledge" and label smoothing perspectives
  5. Design a distillation strategy for an LLM: compute the computational budget trade-offs between teacher inference cost, student training cost, and final deployment savings

Core Content

1. The Knowledge Distillation Framework

1.1 Problem Statement

We have:
- A pre-trained teacher model f_t with parameters θ_t (large, expensive, high-accuracy)
- A student model f_s with parameters θ_s (small, cheap, to be trained)
- Training data: {(x_i, y_i)} where y_i is the ground-truth label (hard targets)

The goal: train the student to match the teacher's behavior, not just the ground truth.

Intuition: The teacher's output distribution p_t = softmax(z_t) contains more information than the one-hot label y. The relative probabilities assigned to INCORRECT classes encode the teacher's learned similarity structure ("dark knowledge").

For example, if classifying a car image, the teacher might output:

$p_t = [0.7 (car), 0.2 (truck), 0.08 (bus), 0.01 (bicycle), 0.01 (pedestrian)]
$

The one-hot label only says "car." The teacher additionally reveals that "truck" is a much better wrong answer than "bicycle" — structural knowledge about class relationships.

1.2 Temperature-Scaled Softmax

⚠️ THIS IS CRITICAL — Temperature is the core mechanism that controls how much "dark knowledge" is revealed.

p^T(z) = softmax(z/T) = e^{z_i/T} / Σ_j e^{z_j/T}

where T is the temperature parameter.

Behavior by T:
- T = 1: Standard softmax. Teacher is confident; incorrect class probabilities are near zero.
- T > 1: "Softer" distribution. Flattens the probability mass, giving more weight to incorrect classes. Reveals the teacher's relative preferences among wrong answers.
- T → ∞: p^T(z) → uniform. All probabilities approach 1/K (K = number of classes).
- T → 0: p^T(z) → argmax. Approaches one-hot at the maximum logit.

Gradient scaling with T:

For a single logit z_i:

$∂p^T_i / ∂z_i = (1/T) · p^T_i · (1 − p^T_i)
$

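The 1/T factor means that gradients through the temperature-scaled softmax are SCALED DOWN by T. The T² factor in the distillation loss (Sections 1.3 and 1.4) compensates for this attenuation and for the way softened distributions move closer together as T grows.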
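The derivative formula above can be checked numerically. The following is a minimal sketch using only the standard library; the function names and the logits are illustrative choices, not part of the original material. It compares the closed form (1/T)·p_i·(1 − p_i) with a central finite difference.

```python
import math

def temperature_softmax(logits, T):
    """Return softmax(z / T); subtracting the max keeps exp() stable."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def analytic_diag(logits, T, i):
    """Closed form from the text: (1/T) * p_i * (1 - p_i)."""
    p = temperature_softmax(logits, T)
    return p[i] * (1.0 - p[i]) / T

def numeric_diag(logits, T, i, eps=1e-6):
    """Central finite difference of p_i with respect to z_i."""
    up, down = list(logits), list(logits)
    up[i] += eps
    down[i] -= eps
    return (temperature_softmax(up, T)[i] - temperature_softmax(down, T)[i]) / (2 * eps)

logits = [3.0, 1.0, 0.0]  # illustrative values, not taken from the text
for T in (1.0, 2.0, 10.0):
    print(T, analytic_diag(logits, T, 0), numeric_diag(logits, T, 0))
```

The two columns should agree to many decimal places at every temperature, and both shrink as T grows.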

1.3 The Distillation Loss (Hinton et al., 2015)

$L_KD = (1−α) · L_CE(y, p_s) + α · T² · D_KL(p_t^T || p_s^T)
$

where:
- L_CE(y, p_s) = −log p_s(y): standard cross-entropy with hard labels
- D_KL(p_t^T || p_s^T) = Σ_j p_t^T(j) · log(p_t^T(j)/p_s^T(j)): KL divergence between teacher and student soft targets
- p_t^T = softmax(z_t/T), p_s^T = softmax(z_s/T)
- α ∈ [0, 1]: weighting between hard and soft targets
- T²: gradient scaling factor (see derivation below)
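As a concrete reading of the loss definition, here is a minimal single-example sketch in plain Python. The names `softmax` and `kd_loss` and the sample logits are assumptions made for illustration; a real training loop would use a tensor library and batch the computation.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax, stabilised by subtracting the max."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(z_s, z_t, y, alpha, T):
    """(1 - alpha) * CE(y, p_s) + alpha * T^2 * KL(p_t^T || p_s^T)."""
    p_s = softmax(z_s)          # hard-label term always uses T = 1
    p_s_T = softmax(z_s, T)     # softened student distribution
    p_t_T = softmax(z_t, T)     # softened teacher distribution
    ce = -math.log(p_s[y])
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t_T, p_s_T) if pt > 0)
    return (1.0 - alpha) * ce + alpha * T * T * kl

# Hypothetical logits for a 3-class problem with true class 0.
student_logits = [1.0, 0.0, -1.0]
teacher_logits = [2.0, 0.5, -1.0]
print(kd_loss(student_logits, teacher_logits, y=0, alpha=0.7, T=3.0))
```

Setting alpha to 0 recovers plain cross-entropy, and the soft term vanishes when student and teacher logits coincide.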

1.4 Why the T² Factor?

The gradient of D_KL w.r.t. student logit z_{s,i}:

$∂D_KL(p_t^T || p_s^T) / ∂z_{s,i} = (1/T) · (p_s^T(i) − p_t^T(i))
$

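Without T², the gradient carries an explicit 1/T factor. For large T the difference p_s^T(i) − p_t^T(i) itself shrinks in proportion to 1/T, so the soft-target gradient scales as 1/T² overall (as noted by Hinton et al.). At T=10 the soft-target gradient would be roughly 100× smaller than at T=1, and the hard-target term would dominate.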

T² correction:

$T² · ∂D_KL / ∂z_{s,i} = T · (p_s^T(i) − p_t^T(i))
$

Now for large T, we can Taylor expand the softmax:

$p^T(i) ≈ 1/K + z_i/(T·K) − (1/K²)·(Σ z_j)/T + O(1/T²)
$

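Assuming zero-mean logits for both models (Σ_j z_j = 0, so the third term of the expansion vanishes), the gradient becomes:

$∂L/∂z_{s,i} ≈ T · (z_{s,i} − z_{t,i}) / (T·K) = (z_{s,i} − z_{t,i}) / K
$

which is independent of T. Without the zero-mean assumption, each z is replaced by the centered logit z − z̄. The T² factor therefore keeps the soft-target gradient at a fixed scale in the high-temperature regime, so its weight relative to the hard-target term does not drift when T is tuned.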
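The high-temperature behaviour of the gradient can be observed numerically. This is a small sketch under the assumption that both logit vectors are zero-mean; the logit values and function names are made up for illustration. It evaluates T·(p_s^T − p_t^T) at increasing T and compares it with (z_s − z_t)/K.

```python
import math

def softmax(logits, T):
    """Temperature-scaled softmax, stabilised by subtracting the max."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical zero-mean logits (each list sums to 0).
z_t = [2.0, -0.5, -1.5]
z_s = [1.0, 0.5, -1.5]
K = len(z_t)

def scaled_grad(T):
    """Gradient of T^2 * KL(p_t^T || p_s^T) w.r.t. the student logits."""
    p_s, p_t = softmax(z_s, T), softmax(z_t, T)
    return [T * (a - b) for a, b in zip(p_s, p_t)]

# Limit predicted by the first-order Taylor expansion.
limit = [(a - b) / K for a, b in zip(z_s, z_t)]

for T in (1.0, 10.0, 100.0, 1000.0):
    print(T, [round(g, 4) for g in scaled_grad(T)])
print("limit", [round(v, 4) for v in limit])
```

As T increases, the printed gradient should settle toward the limit vector rather than shrinking to zero.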

1.5 The Distillation Gradient in Full

Combining both terms:

$∂L_KD/∂z_{s,i} = (1−α) · (p_s(i) − 𝟙[y=i]) + α · T · (p_s^T(i) − p_t^T(i))
$

At T=1 and α=1: ∂L/∂z_{s,i} = p_s(i) − p_t(i). The student learns to exactly match the teacher's output distribution.

2. Temperature Analysis: What Does T Control?

2.1 Low Temperature (T ≈ 1)

Soft targets are close to hard targets (teacher is confident). The student learns the same decision boundaries as the teacher. Distillation behaves similarly to standard training with label smoothing.

2.2 Medium Temperature (T ≈ 2-5)

Sweet spot. The teacher reveals similarity structure among incorrect classes without flattening so much that all classes look identical. The student learns: - Primary signal: what the correct class is (from the peak) - Secondary signal: which incorrect classes are "close" to the correct one

2.3 High Temperature Limit (T → ∞)

$p^T(i) ≈ 1/K + (z_i − z̄)/(T·K) + O(1/T²)
$

The KL divergence (with T² scaling) approaches the following, where z̄_t and z̄_s denote the mean teacher and student logits:

$T² · D_KL(p_t^T || p_s^T) → (1/2K) · ||(z_t − z̄_t) − (z_s − z̄_s)||²
$

In the high-T limit, distillation reduces to least-squares matching of the CENTERED logits of teacher and student, i.e. the offsets of each class logit from the mean. This is what is meant here by matching the teacher's logit COVARIANCE structure: which classes are grouped together in logit space.
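A quick numerical check of this limit, as a sketch: the teacher logits are the ones used in Worked Example 1, while the student logits and all function names are hypothetical. The code compares T²·D_KL at large T with the squared distance between centered logits divided by 2K.

```python
import math

def softmax(logits, T):
    """Temperature-scaled softmax, stabilised by subtracting the max."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def centered(z):
    """Subtract the mean logit from every entry."""
    mean = sum(z) / len(z)
    return [v - mean for v in z]

def scaled_kl(z_t, z_s, T):
    """T^2 * KL(p_t^T || p_s^T)."""
    p_t, p_s = softmax(z_t, T), softmax(z_s, T)
    return T * T * sum(a * math.log(a / b) for a, b in zip(p_t, p_s))

def high_T_limit(z_t, z_s):
    """(1 / 2K) * squared distance between centered logit vectors."""
    K = len(z_t)
    diff = [a - b for a, b in zip(centered(z_t), centered(z_s))]
    return sum(d * d for d in diff) / (2 * K)

z_t = [5.0, 2.0, 1.0, 0.5, 0.1]   # teacher logits from Worked Example 1
z_s = [3.0, 2.5, 0.0, 1.0, 0.5]   # hypothetical student logits

for T in (1.0, 10.0, 100.0, 10000.0):
    print(T, scaled_kl(z_t, z_s, T))
print("limit", high_T_limit(z_t, z_s))
```

Adding a constant to every student logit leaves both quantities unchanged, which is why only the centered logits matter.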

2.4 Zero Temperature Limit (T → 0)

p_t^T → one_hot(argmax(z_t))

Distillation with T → 0 is equivalent to training on the teacher's hard predictions. If the teacher is accurate, this is similar to training on ground truth labels, but with the teacher's mistakes embedded as training noise.

3. Distillation Variants

3.1 Logit-Based Distillation (Hinton, 2015)

The original. Loss compares only the final output distributions:

$L = α · T² · D_KL(p_t^T || p_s^T) + (1−α) · L_CE(y, p_s)
$

Advantages: Simple, architecture-agnostic. Disadvantages: Ignores intermediate representations.

3.2 Feature-Based Distillation (FitNet)

Instead of matching final outputs, match INTERMEDIATE representations. For a chosen layer pair (teacher layer l_t, student layer l_s):

$L_feat = ||f_l_t(x) − W_proj · f_l_s(x)||²
$

where W_proj is a learned projection (since teacher and student may have different hidden dimensions).

The total loss includes both feature matching and output matching:

$L = L_KD + β · L_feat
$

Why it helps: Intermediate features encode compositional knowledge. A small student may not have capacity to learn the right intermediate representations purely from output supervision. Feature distillation provides "hints" about how to decompose the problem.
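The feature-matching loss can be sketched in a few lines. This toy version assumes a single feature vector per example and a fixed (not learned) projection matrix; the names `project`, `feature_loss`, `total_loss` and all numbers are illustrative assumptions.

```python
def project(W, h):
    """Apply a d_t x d_s projection matrix (list of rows) to a student feature."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]

def feature_loss(f_t, f_s, W_proj):
    """Squared L2 distance between teacher feature and projected student feature."""
    projected = project(W_proj, f_s)
    return sum((a - b) ** 2 for a, b in zip(f_t, projected))

def total_loss(l_kd, l_feat, beta):
    """Combined objective L = L_KD + beta * L_feat."""
    return l_kd + beta * l_feat

# Toy dimensions: teacher d_t = 3, student d_s = 2.
W_proj = [[1.0, 0.0],
          [0.0, 1.0],
          [1.0, 1.0]]
f_s = [0.5, -1.0]          # student feature
f_t = [0.5, -1.0, 0.0]     # teacher feature

l_feat = feature_loss(f_t, f_s, W_proj)
print(l_feat)                          # projected = [0.5, -1.0, -0.5], loss = 0.25
print(total_loss(1.0, l_feat, beta=2.0))
```

In training, W_proj would be optimised jointly with the student, so that the student is not forced to adopt the teacher's coordinate system.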

3.3 Relation-Based Distillation (RKD)

Rather than matching individual outputs, match the RELATIONSHIPS between examples or classes.

Distance-wise: For a batch of N examples, match pairwise distances:

$L_RKD_dist = Σ_{i≠j} (ψ_D(t_i, t_j) − ψ_D(s_i, s_j))²
$

where t_i and s_i are the teacher and student representations of example i, and ψ_D is a distance measure (Euclidean, cosine).

Angle-wise: For triplets of examples, match angles:

$L_RKD_angle = Σ_{i≠j≠k} (ψ_A(t_i, t_j, t_k) − ψ_A(s_i, s_j, s_k))²
$

where ψ_A computes the angle formed by three points.

Why it helps: The student doesn't need to reproduce exact teacher features — just the RELATIVE structure. This is more forgiving of capacity differences and can transfer knowledge even when the student can't fully replicate the teacher's representation.
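A minimal sketch of the distance-wise term, following the simplified squared-difference form written above with Euclidean ψ_D. The published RKD method additionally normalizes distances and uses a Huber loss; those details are omitted here, and the function names and points are illustrative.

```python
import math

def euclidean(a, b):
    """Euclidean distance between two points given as lists."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def rkd_distance_loss(teacher, student):
    """Sum over ordered pairs i != j of squared pairwise-distance mismatch."""
    n = len(teacher)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                d_t = euclidean(teacher[i], teacher[j])
                d_s = euclidean(student[i], student[j])
                total += (d_t - d_s) ** 2
    return total

# Teacher embeds a batch of 3 examples in 3-D, student in 2-D (toy values).
teacher = [[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 5.0]]
student = [[0.0, 0.0], [3.0, 4.0], [5.0, 0.0]]
print(rkd_distance_loss(teacher, student))
```

Because only pairwise distances enter the loss, teacher and student embeddings may have different dimensions, and any translation or rotation of the student embedding leaves the loss unchanged.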

3.4 Attention Transfer

For transformer models, match attention patterns:

$L_att = (1/H) · Σ_h ||A_t^h − A_s^h||²_F
$

where A^h is the attention matrix for head h. This transfers the teacher's "what to attend to" strategy.

3.5 Hidden-State Distillation (TinyBERT, DistilBERT)

Match hidden states at multiple layers:

$L_hidden = Σ_l ||h_t^{l_t(l)} − W_l · h_s^{l_s(l)}||²
$

with strategically selected layer mappings. The student is often half the depth of the teacher, and every other teacher layer is mapped to the student.

4. Why Distillation Works: Theoretical Perspectives

4.1 The "Dark Knowledge" Argument

The teacher's softmax provides a richer training signal. For C classes, a ground-truth label carries at most log₂ C bits of information per example. The teacher's full distribution supplies C−1 free real-valued probabilities per example, so the lesson is not just "this is a car" but also "a car is more like a truck than a bicycle."

4.2 Label Smoothing Connection

Training with hard labels + teacher soft targets ≈ training with smooth labels where the smoothing distribution is the teacher's confusion pattern (not uniform).

Standard label smoothing: y_smooth = (1−ε)·y_hard + ε·(1/C)
Distillation smoothing: y_smooth ≈ (1−ε)·y_hard + ε·p_t

The difference: distillation smooths toward TEACHER-SPECIFIC class relationships, not uniform. This encodes domain-specific prior knowledge.
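The two smoothing rules can be compared side by side. In this sketch the teacher distribution is the car example from Section 1.1; ε = 0.1 and the function names are arbitrary illustrative choices.

```python
def uniform_smoothing(y, C, eps):
    """(1 - eps) * one_hot(y) + eps * uniform(1/C)."""
    return [(1.0 - eps) * (1.0 if i == y else 0.0) + eps / C for i in range(C)]

def teacher_smoothing(y, p_t, eps):
    """(1 - eps) * one_hot(y) + eps * teacher distribution."""
    return [(1.0 - eps) * (1.0 if i == y else 0.0) + eps * p
            for i, p in enumerate(p_t)]

classes = ["car", "truck", "bus", "bicycle", "pedestrian"]
p_t = [0.7, 0.2, 0.08, 0.01, 0.01]   # teacher output from Section 1.1

uniform = uniform_smoothing(0, len(classes), eps=0.1)
distilled = teacher_smoothing(0, p_t, eps=0.1)
for name, u, d in zip(classes, uniform, distilled):
    print(f"{name:>10}  uniform={u:.3f}  teacher={d:.3f}")
```

Uniform smoothing gives "truck" and "bicycle" the same target mass, while teacher-based smoothing keeps "truck" well above "bicycle".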

4.3 Variance Reduction

Teacher predictions p_t are lower-variance than individual training labels (the teacher averages over its own training). Training the student on teacher outputs averages out individual training example noise — a form of denoising that produces better generalization.

4.4 Optimization Landscape Smoothing

One proposed explanation is that the student's loss landscape with soft targets is smoother than with hard targets (fewer sharp local minima). On this view, the soft targets encourage the student to stay in basins that are good for generalization, not just sharp minima that achieve low hard-target loss.

5. Distillation for LLMs

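5.1 Challenges

Two challenges dominate LLM distillation: the cost of running teacher inference over the training data, and the large vocabulary (tens of thousands of tokens at every position), which makes the full KL divergence expensive and motivates approximations such as top-k KL or sampled softmax.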

5.2 LLM Distillation Loss

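For autoregressive language modeling at position t (here t in y_t and x_{<t} indexes the token position, while the subscript in p_t still denotes the teacher):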

$L_KD^t = (1−α) · L_CE(y_t, p_s(·|x_{<t})) + α · T² · D_KL(p_t^T(·|x_{<t}) || p_s^T(·|x_{<t}))
$

The total loss averages over all positions.

5.3 Practical Approaches

DistilGPT-2 / DistilBERT style:
- Student has half the layers of teacher
- Initialize student layers from teacher layers (every other layer)
- Train with KL + cosine embedding loss + hard label loss
- Reported result for DistilBERT: a model about 40% smaller that retains about 97% of BERT's language-understanding performance

Sequence-level distillation:
- Generate completions from the teacher on a large prompt corpus
- Train the student directly on (prompt, teacher_completion) pairs as supervised data
- Simpler than step-by-step KL but requires running the teacher for full generations

On-policy distillation:
- Student generates its own completions, teacher scores them
- Student trained to maximize teacher-assigned score on its own outputs
- More expensive (requires student rollouts) but matches the student's actual distribution



Pitfalls

⚠️ Pitfall 1: Using hard labels instead of soft labels. Knowledge distillation minimizes KL(p_teacher || p_student), where p_teacher is the full softmax distribution. Using only the teacher's argmax (hard label) throws away the rich "dark knowledge" in the teacher's probability distribution — that token A has 40% and token B has 35% tells the student about plausible alternatives.

⚠️ Pitfall 2: Setting the distillation temperature too low. Low T makes the teacher's distribution nearly one-hot, losing the inter-class similarity information. Values of T between 2 and 10 are commonly used, high enough to expose the teacher's uncertainty structure; the best value is task-dependent and should be tuned.

⚠️ Pitfall 3: Distilling only on the final output layer. The teacher's intermediate hidden states contain valuable information. Layer-to-layer distillation (matching student hidden states to teacher hidden states) can improve results noticeably, especially for deep models.


Key Terms

Worked Examples

Example 1: Temperature Effect on Softmax

Teacher logits: z_t = [5.0, 2.0, 1.0, 0.5, 0.1] for 5 classes. Compute the softmax at T=1, T=2, T=5, T=20.

T=1 (standard):

$e^{z/T} = [e^5, e^2, e^1, e^{0.5}, e^{0.1}] = [148.4, 7.39, 2.72, 1.65, 1.11]
Σ = 161.3
p = [0.920, 0.046, 0.017, 0.010, 0.007]
$

The teacher is extremely confident — 92% on class 1, other classes negligible.

T=2:

$e^{z/2} = [e^{2.5}, e^1, e^{0.5}, e^{0.25}, e^{0.05}] = [12.18, 2.718, 1.649, 1.284, 1.051]
Σ = 18.88
p = [0.645, 0.144, 0.087, 0.068, 0.056]
$

More spread. Class 2 gets 14.4%, revealing that class 2 is the "best wrong answer."

T=5:

$e^{z/5} = [e^1, e^{0.4}, e^{0.2}, e^{0.1}, e^{0.02}] = [2.718, 1.492, 1.221, 1.105, 1.020]
Σ = 7.556
p = [0.360, 0.197, 0.162, 0.146, 0.135]
$

Considerably flatter, though not yet uniform. The rank order is preserved, but the four incorrect classes now differ only slightly (0.197 down to 0.135), so their similarity structure is starting to wash out.

T=20:

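$e^{z/20} = [e^{0.25}, e^{0.1}, e^{0.05}, e^{0.025}, e^{0.005}] = [1.284, 1.105, 1.051, 1.025, 1.005]
Σ = 5.471
p = [0.235, 0.202, 0.192, 0.187, 0.184]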
$

Very close to uniform (1/5 = 0.2). Nearly all class-distinction information is lost.

Optimal T for this case: T=2-3 provides the best balance — enough spread to reveal structure without washing out the signal.
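The table of softened distributions in this example can be reproduced with a few lines of Python. This is a minimal sketch; the helper name `temperature_softmax` is an illustrative choice.

```python
import math

def temperature_softmax(logits, T):
    """Return softmax(z / T); subtracting the max keeps exp() stable."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

z_t = [5.0, 2.0, 1.0, 0.5, 0.1]   # teacher logits from this example

for T in (1, 2, 5, 20):
    probs = temperature_softmax(z_t, T)
    print(f"T={T:>2}: " + ", ".join(f"{p:.3f}" for p in probs))
```

Running it for other values of T is an easy way to see where the ordering of incorrect classes is still clearly visible and where it flattens out.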

Example 2: KL Divergence Gradient

For classes {cat, dog, bird}, teacher (T=2) gives p_t = [0.6, 0.3, 0.1]. Student gives p_s = [0.5, 0.2, 0.3]. Compute the KL gradient.

KL divergence:

$D_KL(p_t || p_s) = 0.6·log(0.6/0.5) + 0.3·log(0.3/0.2) + 0.1·log(0.1/0.3)
                 = 0.6·log(1.2) + 0.3·log(1.5) + 0.1·log(0.333)
                 = 0.6·0.182 + 0.3·0.405 + 0.1·(−1.099)
                 = 0.109 + 0.122 − 0.110
                 = 0.121
$

Gradient w.r.t. student logits (at T=1, no T² scaling):

$∂D_KL/∂z_s = p_s − p_t = [0.5−0.6, 0.2−0.3, 0.3−0.1] = [−0.1, −0.1, +0.2]
$

Interpretation: to reduce KL, the student should INCREASE cat and dog probabilities (they're too low relative to teacher) and DECREASE bird probability (too high). Gradient descent steps along the negative gradient, which moves the logits in exactly these directions.

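With T² scaling (T=2): treating p_t and p_s as the T=2 softened distributions, the unscaled KL gradient w.r.t. the student logits is (1/T)·(p_s − p_t) = [−0.05, −0.05, +0.1]. Multiplying by T² = 4 gives T·(p_s − p_t) = [−0.2, −0.2, +0.4], which compensates for the 1/T softmax gradient scaling.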
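The KL value and the unscaled gradient p_s − p_t from this example can be verified numerically. The sketch below picks student logits z_s = log p_s (any logits with the same softmax would do; that choice is an assumption) and compares the analytic gradient with a finite difference.

```python
import math

p_t = [0.6, 0.3, 0.1]   # teacher distribution from the example
p_s = [0.5, 0.2, 0.3]   # student distribution from the example

def kl(p, q):
    """D_KL(p || q) with natural logarithms."""
    return sum(a * math.log(a / b) for a, b in zip(p, q))

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# Logits are only defined up to a constant; log(p_s) is one valid choice.
z_s = [math.log(p) for p in p_s]

def numeric_grad(z, eps=1e-6):
    """Central finite difference of KL(p_t || softmax(z)) w.r.t. each logit."""
    grad = []
    for i in range(len(z)):
        up, down = list(z), list(z)
        up[i] += eps
        down[i] -= eps
        grad.append((kl(p_t, softmax(up)) - kl(p_t, softmax(down))) / (2 * eps))
    return grad

analytic = [a - b for a, b in zip(p_s, p_t)]
print(round(kl(p_t, p_s), 3))
print([round(g, 4) for g in analytic])
print([round(g, 4) for g in numeric_grad(z_s)])
```

Both gradient lines should agree, with negative entries for cat and dog and a positive entry for bird.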

Example 3: Distillation vs Training from Scratch

A student model achieves test accuracy 72% when trained from scratch on N=10K labeled examples. A teacher model achieves 85% on the same task. The teacher's soft predictions on the same 10K examples are used to distill the student.

Given: The student trained with distillation (α=0.7, T=3) achieves 78% test accuracy — 6 percentage points above training from scratch on the SAME data.

Analysis: Why does distillation help?

  1. Information content: A hard label carries at most log₂ C bits for C classes (about 1.58 bits for C = 3). The teacher's full softmax provides up to C−1 real numbers of additional information per example about class relationships.

  2. The teacher's "extra" knowledge: The teacher was trained on MORE data (say 100K examples). Its soft labels on the 10K student examples ENCODE knowledge from the other 90K examples. The student indirectly benefits from the larger dataset through the teacher.

  3. Variance reduction: The teacher averages over its own training noise. Its predictions are smoother and more consistent than the raw labels.

The 6-percentage-point improvement from distillation reflects the teacher's additional knowledge transferred to the student.



Quiz

Q1: What does the concept of Common Pitfalls primarily refer to in this subject?

A) The definition and application of Common Pitfalls
B) A computational error related to Common Pitfalls
C) A visual representation of Common Pitfalls
D) A historical anecdote about Common Pitfalls

Correct: A)

Q2: Which of the following is the key formula discussed in this subject?

A) L = (1−α)·L_CE + α·T·D_KL(p_t^T || p_s^T)
B) An unrelated formula from a different topic
C) L = (1−α)·L_CE + α·T²·D_KL(p_t^T || p_s^T)
D) L = (1−α)·L_CE − α·T²·D_KL(p_t^T || p_s^T)

Correct: C)

Q3: What is the primary purpose of Distillation Variants?

A) They are primarily a historical notation system
B) They replace all other methods in this domain
C) They extend distillation beyond output logits to intermediate features, relations between examples, and attention patterns
D) They are used only in advanced research contexts

Correct: C)

Q4: Which statement about Distillation for LLMs is TRUE?

A) Distillation for LLMs is an advanced topic beyond this subject's scope
B) Distillation for LLMs is mentioned only as a historical footnote
C) Distillation for LLMs is not related to this subject
D) Distillation for LLMs is a fundamental concept covered in this subject

Correct: D)

Q5: Based on the temperature analysis in this subject, what does each temperature-scaled softmax probability approach as T → ∞?

A) 1/K (K = number of classes).
B) 1 for the largest logit and 0 for all others
C) 0 for every class
D) The teacher's T = 1 probability

Correct: A)

Q6: How are Distillation for LLMs and KL divergence related?

A) Distillation for LLMs is the inverse of KL divergence
B) Distillation for LLMs and KL divergence are closely related concepts
C) Distillation for LLMs is a special case of KL divergence
D) Distillation for LLMs and KL divergence are completely unrelated topics

Correct: B)

Q7: What is a common pitfall when working with The Knowledge Distillation Framework?

A) A common mistake is confusing The Knowledge Distillation Framework with a similar concept
B) The main error with The Knowledge Distillation Framework is using it when it is not needed
C) The Knowledge Distillation Framework is always computed the same way in all contexts
D) The Knowledge Distillation Framework has no common misconceptions

Correct: A)

Q8: When should you apply the temperature analysis from Section 2 ("What Does T Control?")?

A) The temperature analysis is not practically useful
B) Apply the temperature analysis to solve problems in this subject's domain, such as choosing T for distillation
C) Avoid the temperature analysis unless explicitly instructed
D) Use the temperature analysis only in pure mathematics contexts

Correct: B)

Practice Problems

Problem 1

Derive the gradient of the T²-scaled KL divergence term with respect to a student logit z_{s,i}. Show that it equals T·(p_s^T(i) − p_t^T(i)).

Answer Let L = T² · D_KL(p_t^T || p_s^T) = T² · Σ_j p_t^T(j) · log(p_t^T(j) / p_s^T(j)). Only the log p_s^T(j) terms depend on z_{s,i}, and every one of them does, through the softmax normalizer:
$L = T² · Σ_j p_t^T(j) · [log(p_t^T(j)) − log(p_s^T(j))]
$
∂L/∂z_{s,i} = −T² · Σ_j p_t^T(j) · ∂log(p_s^T(j))/∂z_{s,i}

Now compute ∂log(p_s^T(j))/∂z_{s,i} in two cases.
For j=i: ∂p_s^T(i)/∂z_{s,i} = (1/T) · p_s^T(i) · (1 − p_s^T(i)) [softmax derivative with temperature], so ∂log(p_s^T(i))/∂z_{s,i} = (1/T) · (1 − p_s^T(i)).
For j≠i: ∂p_s^T(j)/∂z_{s,i} = −(1/T) · p_s^T(j) · p_s^T(i), so ∂log(p_s^T(j))/∂z_{s,i} = −(1/T) · p_s^T(i).

Plugging in:
∂L/∂z_{s,i} = −T² · [p_t^T(i) · (1/T)·(1−p_s^T(i)) + Σ_{j≠i} p_t^T(j) · (−1/T)·p_s^T(i)]
= −T · [p_t^T(i)·(1−p_s^T(i)) − p_s^T(i) · Σ_{j≠i} p_t^T(j)]
= −T · [p_t^T(i) − p_t^T(i)·p_s^T(i) − p_s^T(i)·(1−p_t^T(i))]
= −T · [p_t^T(i) − p_t^T(i)·p_s^T(i) − p_s^T(i) + p_s^T(i)·p_t^T(i)]
= −T · [p_t^T(i) − p_s^T(i)]
= T · (p_s^T(i) − p_t^T(i)) ∎

Problem 2

For a distillation setup with T=4 and α=0.5, the teacher gives p_t^T = [0.4, 0.35, 0.25] and the student (before update) gives p_s^T = [0.5, 0.3, 0.2]. The hard label is class 0. Compute the total gradient for each student logit (combined hard + soft loss).

Answer The hard loss uses the T=1 softmax, p_s = softmax(z_s), while the soft loss uses p_s^T = softmax(z_s/4), so we first recover z_s from p_s^T. Logits are determined only up to an additive constant C, which cancels in the softmax:
z_s = 4·log(p_s^T) + C = [4·log(0.5), 4·log(0.3), 4·log(0.2)] + C = [−2.77, −4.82, −6.44] + C
p_s = softmax(z_s) ∝ [e^{−2.77}, e^{−4.82}, e^{−6.44}] = [0.0625, 0.0081, 0.0016], with Σ = 0.0722, so p_s = [0.866, 0.112, 0.022]

Hard gradient (weight 1−α = 0.5): 0.5 · (p_s − onehot(0)) = 0.5 · [0.866−1, 0.112−0, 0.022−0] = [−0.067, 0.056, 0.011]
Soft gradient (weight α·T = 0.5·4 = 2): 2 · (p_s^T − p_t^T) = 2 · [0.5−0.4, 0.3−0.35, 0.2−0.25] = [0.2, −0.1, −0.1]
Total gradient: [−0.067+0.2, 0.056−0.1, 0.011−0.1] = [0.133, −0.044, −0.089]

Sign check: gradient descent applies z_s ← z_s − η·∂L/∂z_s, so a positive gradient component LOWERS the corresponding logit. For the soft term, p_s^T(0) = 0.5 exceeds p_t^T(0) = 0.4, the gradient component is positive, and the update decreases class 0 while increasing classes 1 and 2, moving the student toward the teacher. The hard term pulls the other way on class 0 (it wants p_s(0) → 1), but with these numbers the soft term dominates, so the net update direction is [−0.133, +0.044, +0.089].
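The arithmetic in this answer can be reproduced with a short script. It is a sketch that recovers the student logits as T·log(p_s^T), which fixes the arbitrary additive constant at zero; variable names are illustrative.

```python
import math

T, alpha = 4.0, 0.5
p_t_T = [0.4, 0.35, 0.25]   # softened teacher distribution (given)
p_s_T = [0.5, 0.3, 0.2]     # softened student distribution (given)
label = 0                   # hard label (given)

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# Recover student logits up to a constant, then the T = 1 distribution.
z_s = [T * math.log(p) for p in p_s_T]
p_s = softmax(z_s)

hard = [(1.0 - alpha) * (p_s[i] - (1.0 if i == label else 0.0)) for i in range(3)]
soft = [alpha * T * (p_s_T[i] - p_t_T[i]) for i in range(3)]
total = [h + s for h, s in zip(hard, soft)]

print("p_s   ", [round(p, 3) for p in p_s])
print("hard  ", [round(g, 3) for g in hard])
print("soft  ", [round(g, 3) for g in soft])
print("total ", [round(g, 3) for g in total])
print("update", [round(-g, 3) for g in total])  # descent direction
```

The components of each gradient sum to zero, which is a useful sanity check for any softmax-based loss.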

Problem 3

In feature-based distillation, the student's intermediate features have dimension d_s=256 and the teacher's have d_t=768. A learned projection W_proj ∈ ℝ^{768×256} maps student features to teacher space. If the student processes N=100K training examples with batch size 64, compute the total FLOPs for the projection operation over one epoch. Compare to the cost if we instead used a student with d_s=768 (no projection needed).

Answer Projection FLOPs per example (one feature vector per example, forward pass only): 2 · d_t · d_s = 2 · 768 · 256 = 393,216 FLOPs.
Per epoch: 100K · 393,216 ≈ 3.93×10¹⁰ FLOPs ≈ 0.04 TFLOP. The batch size of 64 changes how the work is grouped, not the total. A single forward pass through even a small transformer is ~1 GFLOP, so the projection is negligible.

If we instead increased d_s from 256 to 768, the parameter count of every layer grows roughly in proportion to d_s²:
- Attention: 4·768² − 4·256² = 4·(589,824 − 65,536) = 2,097,152 extra params per layer
- FFN (standard 4× expansion): 2·768·3072 − 2·256·1024 = 4,718,592 − 524,288 = 4,194,304 extra params per layer
- Total: ~6.3M extra params per layer; for 12 layers, ~75M extra parameters (about 150 MB at fp16)

The projection costs ~400K FLOPs per example and adds 768·256 = 196,608 parameters (about 0.4 MB at fp16). Increasing the dimension costs ~150 MB and proportionally more FLOPs in every forward pass. The projection is clearly the more efficient way to match the teacher's feature dimension.

Problem 4

For LLM distillation, the probability distribution is over |V|=50,000 tokens at each position. Computing the full KL divergence requires evaluating the student's softmax over all 50K tokens, which is expensive. Propose and analyze two methods to approximate the KL divergence efficiently.

Answer **Method 1: Top-K KL (sparse KL)** Compute the KL only over the top-k tokens of the teacher's distribution. For k=256:
- The teacher identifies its top-256 tokens (for peaked next-token distributions these typically hold most of the probability mass)
- The student evaluates probabilities for those 256 tokens, plus the ground-truth token if it is not among them
- The KL is computed over this subset, with the remaining probability mass assigned to an "other" catch-all class
Cost: the KL sum touches 256/50000 ≈ 0.5% of the vocabulary. Bias: merging the tail into one class can only lower the KL, so the result underestimates the full divergence, by an amount that is small when the tail mass is small.

**Method 2: Sampled Softmax / NCE** Estimate the KL by importance sampling. Draw m tokens from a proposal distribution q and average the reweighted terms:
$KL ≈ (1/m) · Σ_{i~q} (p_t(i)/q(i)) · log(p_t(i)/p_s(i))
$
where q is a proposal distribution (e.g., the teacher's distribution itself, or a uniform+teacher mixture). With m=512 the sum touches 512/50000 ≈ 1% of the vocabulary. The estimate is unbiased when p_t(i) and p_s(i) are exact probabilities, but it has variance. With the teacher as proposal (q = p_t) the weights p_t/q equal 1, which keeps the variance low.

**Comparison:** Top-K is simpler and deterministic but biased. Sampling is unbiased but stochastic. For both methods the student needs logits only for the selected tokens (dot products with those rows of the output embedding). Exact probabilities still require the softmax normalizer over the full vocabulary, so the actual savings depend on how that normalizer is obtained or approximated.
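To make the bias of Method 1 concrete, here is a toy sketch that operates on already-normalized distributions over a five-token vocabulary. It therefore illustrates the truncation effect only, not the compute savings; the function names and probabilities are made up for illustration.

```python
import math

def full_kl(p_t, p_s):
    """Exact KL(p_t || p_s) over the whole vocabulary."""
    return sum(a * math.log(a / b) for a, b in zip(p_t, p_s) if a > 0)

def topk_kl(p_t, p_s, k):
    """KL over the teacher's top-k tokens plus one aggregated 'other' bucket."""
    top = sorted(range(len(p_t)), key=lambda i: p_t[i], reverse=True)[:k]
    kl = sum(p_t[i] * math.log(p_t[i] / p_s[i]) for i in top if p_t[i] > 0)
    rest_t = 1.0 - sum(p_t[i] for i in top)   # teacher mass outside the top-k
    rest_s = 1.0 - sum(p_s[i] for i in top)   # student mass outside the top-k
    if rest_t > 1e-12 and rest_s > 1e-12:
        kl += rest_t * math.log(rest_t / rest_s)
    return kl

# Hypothetical 5-token vocabulary.
p_t = [0.5, 0.3, 0.1, 0.05, 0.05]
p_s = [0.4, 0.3, 0.1, 0.15, 0.05]

print("full KL:", full_kl(p_t, p_s))
for k in (1, 2, 3, 5):
    print(f"top-{k} KL:", topk_kl(p_t, p_s, k))
```

The truncated value never exceeds the full KL, and the gap closes as k approaches the vocabulary size.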

Problem 5

A 70B-parameter teacher model achieves 70% on a benchmark. A 7B-parameter student trained from scratch on the same data achieves 58%. After distillation (teacher soft targets on the same data), the 7B student achieves 64%. The teacher costs $10 per 1M tokens for inference; the student costs $0.50 per 1M tokens. Training data is 1B tokens. (a) What's the total distillation cost (teacher inference + student training)? (b) How many inference tokens must the student process after deployment to break even vs using the teacher?

Answer (a) Teacher inference for distillation: 1B tokens × $10/1M = $10,000. Student training on the same 1B tokens is a training cost, not an inference cost, and the problem does not specify it; assume $1,000 of training compute. Total distillation cost: $10,000 + $1,000 = $11,000.

(b) Savings per 1M inference tokens: $10 − $0.50 = $9.50. Break-even volume: $11,000 / ($9.50 per 1M) ≈ 1,158M tokens ≈ 1.16B tokens. After serving 1.16B tokens, the distillation investment is recovered. At 1M tokens per query and 1000 queries/day (1B tokens/day), break-even takes a little over one day.

However, there is an accuracy gap: 70% (teacher) vs 64% (distilled student). If accuracy degradation costs more than the inference savings (e.g., in a revenue-generating application), the break-even calculation needs to account for quality-adjusted value. If value is linear in accuracy, let value_loss be the value lost to lower accuracy per 1M tokens; then break-even volume = $11,000 / (($9.50 − value_loss) per 1M tokens). The denominator shrinks, so more tokens are needed to break even, and if value_loss exceeds $9.50 the student never breaks even.
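The break-even arithmetic can be wrapped in a small helper so that different prices and training budgets are easy to try. The function name is illustrative, and the $1,000 student training cost is the same assumption made in the answer, not a figure given in the problem.

```python
def break_even_millions(teacher_price, student_price, train_tokens_m,
                        student_training_cost):
    """Return (distillation cost in $, break-even volume in millions of tokens).

    Prices are in $ per 1M tokens; train_tokens_m is in millions of tokens.
    """
    distill_cost = train_tokens_m * teacher_price + student_training_cost
    savings_per_m = teacher_price - student_price
    return distill_cost, distill_cost / savings_per_m

cost, tokens_m = break_even_millions(
    teacher_price=10.0,
    student_price=0.50,
    train_tokens_m=1000.0,          # 1B training tokens
    student_training_cost=1000.0,   # assumed, not given in the problem
)
print(f"distillation cost: ${cost:,.0f}")
print(f"break-even volume: {tokens_m:,.0f}M tokens")
```

Raising the assumed student training cost shifts the break-even point linearly, which shows how sensitive the conclusion is to that unstated number.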

Summary

  1. Knowledge distillation transfers a teacher's "dark knowledge" via temperature-scaled soft targets: L = (1−α)·L_CE + α·T²·D_KL(p_t^T||p_s^T), where T² compensates for the attenuation of the soft-target gradient under temperature scaling (an explicit 1/T factor, and 1/T² overall at high T)
  2. Temperature T controls the softness of teacher targets: low T ≈ label smoothing, medium T (2-5) optimally reveals similarity structure among incorrect classes, high T reduces to logit covariance matching
  3. Variants extend beyond output logits: feature-based distillation matches intermediate representations, relation-based distillation preserves pairwise/triplet relationships, attention transfer matches attention patterns
  4. Distillation outperforms training from scratch on the same data because soft targets encode richer information (class relationships) and the teacher encodes knowledge from its own (potentially larger) training set
  5. For LLM distillation, the primary challenges are the teacher's inference cost on training data and the massive vocabulary size requiring approximations (top-k KL, sampled softmax) for efficient gradient computation


Next Steps

Continue to 19-06 — Speculative Decoding to understand how a small draft model can accelerate large model inference through speculative execution with mathematical acceptance guarantees.