
20-05 — Instruction Tuning (SFT)

Phase: 20 — Training & Fine-tuning Mathematics
Subject: 20-05
Prerequisites: 18-08 (Inference Mathematics), 17-04 (Language Modeling Loss), 18-05 (Decoder-Only Architecture), 20-01 (Learning Rate Schedules), 09-03 (SVD, for data quality concepts)
Next subject: 20-06 — RLHF Mathematics


Learning Objectives

By the end of this subject, you will be able to:

  1. Formulate the SFT loss function with prompt masking and prove that masking is equivalent to zero-weighting the prompt tokens in the cross-entropy sum
  2. Derive why SFT on instruction-response pairs learns P(response | instruction) rather than P(instruction, response)
  3. Analyze the gradient contribution ratio between prompt and response tokens, and show that without masking a fraction |x|/(|x|+|y|) of the optimization signal is spent on prompt modeling
  4. Compute the effective number of training tokens in SFT given a dataset of prompts and responses
  5. Explain the data quality scaling laws for SFT — why 1K high-quality examples can outperform 50K mediocre ones

Core Content

1. What is Instruction Tuning?

After pre-training (next-token prediction on vast text corpora), an LLM can complete text but doesn't reliably follow instructions. Instruction tuning (also called Supervised Fine-Tuning, SFT) trains the model on (instruction, response) pairs, teaching it to generate helpful responses when given instructions.

Key mathematical difference from pre-training: In pre-training, the model learns P(token_t | token_{<t}) for all tokens. In SFT, we specifically want P(response | instruction), NOT P(instruction). We must NOT train the model to generate the instruction — that would be circular (the model already receives the instruction as input).


2. The SFT Loss Function

For a dataset of N examples, each with instruction x and response y:

$$
L_{\text{SFT}}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \sum_{t=1}^{|y_i|} \log P_\theta\left(y_{i,t} \mid x_i, y_{i,<t}\right)
$$

where:

- θ = model parameters
- x_i = instruction tokens (prompt)
- y_i = response tokens
- y_{i,t} = the t-th token of the i-th response
- P_θ(y_t | x, y_{<t}) = softmax(logits)_{y_t}, the probability the model assigns to the correct next token

Critical detail — the prompt is in the CONTEXT but NOT in the LOSS:

The model processes the full sequence [x, y]. The forward pass produces logits for EVERY position. But we only compute the loss over positions corresponding to response tokens y. The instruction tokens x are included in the context (so the model can attend to them), but their prediction loss is MASKED.

Masking implementation:

For position t:
    loss_t = {
        −log P_θ(token_t | tokens_{<t})    if position t is in the response
        0                                    if position t is in the instruction
    }
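The piecewise rule above can be sketched in plain Python. This is a minimal illustration, not a framework implementation: it assumes the probability of each correct token is already available, and the names `masked_sft_loss`, `token_probs`, and `loss_mask` are hypothetical. It averages over response tokens within one example, matching the per-example normalization used in Section 6.

```python
import math

def masked_sft_loss(token_probs, loss_mask):
    """Mean negative log-likelihood over positions where loss_mask is 1.

    token_probs[t]: probability the model assigned to the correct token at t.
    loss_mask[t]:   1 for response positions, 0 for instruction positions.
    """
    if len(token_probs) != len(loss_mask):
        raise ValueError("token_probs and loss_mask must have equal length")
    n_response = sum(loss_mask)
    if n_response == 0:
        raise ValueError("loss_mask selects no response tokens")
    # Instruction positions are skipped entirely, i.e. they contribute 0.
    total = sum(-math.log(p) for p, m in zip(token_probs, loss_mask) if m)
    return total / n_response

# Hypothetical example: 3 instruction tokens followed by 2 response tokens.
probs = [0.10, 0.20, 0.05, 0.60, 0.90]
mask = [0, 0, 0, 1, 1]
loss = masked_sft_loss(probs, mask)

# Changing the instruction-token probabilities leaves the loss unchanged.
loss_other_prompt = masked_sft_loss([0.9, 0.9, 0.9, 0.60, 0.90], mask)
```

Because masked positions never enter the sum, the two calls return the same value even though the model's fit to the instruction differs.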

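⚠️ THIS IS CRITICAL: Without prompt masking, a fraction |x|/(|x|+|y|) of the loss terms (about 50% when instructions and responses have similar lengths) goes toward reproducing the instruction, which works against the purpose of instruction tuning. Forgetting the mask is one of the most common implementation bugs in SFT.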


3. Mathematical Justification for Prompt Masking

Without masking, the loss is:

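$$
\begin{aligned}
L_{\text{unmasked}} &= -\mathbb{E}_{(x,y)}\left[\log P_\theta(x, y)\right] \\
&= -\mathbb{E}\left[\log P_\theta(x)\right] - \mathbb{E}\left[\log P_\theta(y \mid x)\right]
\end{aligned}
$$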

The first term −E[log P_θ(x)] trains the model on the instruction distribution — useless, since we're always going to PROVIDE the instruction. The second term is what we actually want.
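The split into two terms is the chain rule of probability applied token by token. Written out for a single pair, as a sketch using the symbols above (x_t, the t-th instruction token, is an index introduced here for illustration):

$$
\log P_\theta(x, y) = \underbrace{\sum_{t=1}^{|x|} \log P_\theta(x_t \mid x_{<t})}_{\log P_\theta(x)} + \underbrace{\sum_{t=1}^{|y|} \log P_\theta(y_t \mid x, y_{<t})}_{\log P_\theta(y \mid x)}
$$

Every token of the sequence appears in exactly one of the two sums, so dropping the first sum removes the instruction-modeling term and leaves the response term untouched.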

With masking, the loss is:

$$
L_{\text{masked}} = -\mathbb{E}\left[\log P_\theta(y \mid x)\right]
$$

This trains the model ONLY on P(response | instruction), which is exactly the conditional distribution needed for instruction following.

Gradient perspective: Consider the gradient contribution from each token. Without masking:

$$
\nabla L = \frac{1}{|x|+|y|} \sum_{t \in [x,y]} \nabla\bigl(-\log P_\theta(\text{token}_t)\bigr)
$$

The gradient is an average over |x|+|y| tokens. If instructions and responses have similar lengths, roughly half the gradient signal comes from instruction tokens — wasted. With masking:

$$
\nabla L_{\text{masked}} = \frac{1}{|y|} \sum_{t \in y} \nabla\bigl(-\log P_\theta(\text{token}_t)\bigr)
$$

Now 100% of the gradient signal goes toward learning the response distribution.


4. SFT Training Dynamics

4.1 Loss Curve Interpretation

Unlike pre-training, where loss decreases slowly over trillions of tokens, SFT loss drops rapidly (within hundreds to thousands of steps). This is because:

  1. The model already knows language; it is learning a specific mapping, not language itself
  2. SFT datasets are small (typically 1K–100K examples)
  3. The model is initialized from a strong pre-trained checkpoint

4.2 Learning Rate

SFT typically uses a much lower learning rate than pre-training:

$$
\eta_{\text{SFT}} \approx \frac{\eta_{\text{pretrain}}}{10} \ \text{ to } \ \frac{\eta_{\text{pretrain}}}{100}
$$

Too high: catastrophic forgetting — the model overwrites its pre-trained knowledge. Too low: insufficient adaptation.

4.3 Epochs

Unlike pre-training (typically 1 epoch or less), SFT can benefit from multiple epochs:

Typical: 1–3 epochs for small datasets, 1 epoch for large datasets

Multiple epochs increase the risk of overfitting to the specific phrasing of the training examples.


5. Data Quality vs Quantity

A key finding from instruction tuning research (Zhou et al., 2023 — LIMA; Chung et al., 2022 — FLAN):

LIMA (Less Is More for Alignment): A 65B model fine-tuned on only 1,000 carefully curated examples performed competitively with models trained on 50K+ examples. The quality of examples (diversity, clarity, correctness) matters far more than quantity in SFT.

Mathematical intuition: The model's pre-training already provides a strong prior over language. SFT only needs to "nudge" the distribution toward instruction-following behavior. A few high-quality gradient steps in the right direction suffice. Adding noisy or contradictory examples creates competing gradient signals that dilute the desired behavior.
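A toy calculation makes the "competing gradient signals" idea concrete. It rests on a deliberately crude assumption that is not from the research cited above: clean examples all push the parameters along one direction, and noisy examples push in exactly the opposite direction. The function name `mean_gradient` is hypothetical.

```python
def mean_gradient(clean_fraction, clean_grad=(1.0, 0.0)):
    """Average gradient over a dataset under a toy assumption.

    A fraction `clean_fraction` of examples has gradient `clean_grad`;
    the rest has exactly the opposite gradient (contradictory examples).
    """
    if not 0.0 <= clean_fraction <= 1.0:
        raise ValueError("clean_fraction must lie in [0, 1]")
    noisy_fraction = 1.0 - clean_fraction
    scale = clean_fraction - noisy_fraction  # net signal that survives averaging
    return tuple(scale * g for g in clean_grad)

all_clean = mean_gradient(1.0)     # full signal
mostly_clean = mean_gradient(0.8)  # 20% contradictory examples
half_noise = mean_gradient(0.5)    # signals cancel completely
```

Under this assumption, 20% contradictory examples remove 40% of the net gradient, because each one both adds nothing useful and cancels one clean example. Real noise is less adversarial, but the averaging effect points the same way.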

Diversity matters: Even with 1K examples, coverage of diverse instruction types (reasoning, creative, factual, coding, etc.) is more important than having 10K examples of a single type.


6. Loss Computation — Full Derivation

For a transformer with vocabulary size V, hidden dimension d:

  1. Forward pass: Process tokens [x, y] = [t_1, t_2, ..., t_L]
  2. Final hidden states: h_1, h_2, ..., h_L ∈ ℝ^d
  3. Logits: For each position i: z_i = W_lm_head · h_i ∈ ℝ^V
  4. Probabilities: p_i = softmax(z_i) ∈ ℝ^V (probability distribution over vocabulary)
  5. Loss per response position:
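$$
\begin{aligned}
\ell_i &= -\log\bigl(p_i[y_i]\bigr) \\
&= -\log\frac{\exp(z_i[y_i])}{\sum_{j=1}^{V} \exp(z_i[j])} \\
&= -z_i[y_i] + \log \sum_{j=1}^{V} \exp(z_i[j])
\end{aligned}
$$
     Here y_i denotes the target token for position i, that is, the token the model should predict from h_i.
  6. Total loss (masked):
$$
L = \frac{1}{|y|} \sum_{i:\ \text{position } i \text{ is in response}} \ell_i
$$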

Gradient w.r.t. logits at response position i:

∂ℓ_i / ∂z_i[k] = p_i[k] − δ_{k, y_i}

This is the standard cross-entropy gradient: the model's predicted probability minus the one-hot target. We want to INCREASE p_i[y_i] and DECREASE all other p_i[k].

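For masked (instruction) positions: the gradient with respect to the logits z_i is 0, because the loss does not depend on the predictions made at these positions. Their hidden states still receive gradient through attention from response positions, which is how the model learns to use the instruction as context.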
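The logit gradient above can be checked numerically. The sketch below assumes a hypothetical vocabulary of size 4 with made-up logits, and compares the analytic expression (softmax minus one-hot) against a central finite difference of the log-sum-exp form of the loss.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    # -z[target] + log(sum_j exp(z[j])), computed stably
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -logits[target] + log_sum_exp

def cross_entropy_grad(logits, target):
    # p[k] - delta(k, target)
    return [p_k - (1.0 if k == target else 0.0)
            for k, p_k in enumerate(softmax(logits))]

logits = [2.0, -1.0, 0.5, 0.0]  # hypothetical logits, V = 4
target = 2
analytic = cross_entropy_grad(logits, target)

eps = 1e-6
numeric = []
for k in range(len(logits)):
    plus = list(logits)
    minus = list(logits)
    plus[k] += eps
    minus[k] -= eps
    numeric.append(
        (cross_entropy(plus, target) - cross_entropy(minus, target)) / (2 * eps)
    )
```

The target entry of `analytic` is negative (a descent step raises that logit), every other entry is positive, and the entries sum to zero because the probabilities sum to one.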


7. Packing and Sequence Construction

SFT data is typically packed to maximize GPU utilization:

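Without packing (one example per sequence): A sequence with 50 instruction tokens and 100 response tokens spends 50/150 = 33% of its positions on tokens that contribute no loss term. Padding each sequence to a fixed length adds further positions that contribute nothing, and that padding is the part packing eliminates.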

With packing (multiple examples per sequence):

$$
[\text{INST}_1, \text{RESP}_1, \text{INST}_2, \text{RESP}_2, \ldots, \text{INST}_K, \text{RESP}_K]
$$

Each example's loss mask covers only its own response. The attention mask must be adjusted so that tokens of example k+1 cannot attend to tokens of example k: causal masking alone already hides later tokens from earlier ones, but it would let every example see all the examples packed before it. This is achieved by blocking cross-example attention while maintaining causal attention within each example.
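One way to picture the required attention pattern is as a block-diagonal causal mask. The sketch below builds it as a nested list of booleans from the lengths of the packed examples; `packed_causal_mask` and `example_lengths` are hypothetical names, and a real training stack would typically build an equivalent tensor or pass sequence boundaries to its attention kernel instead.

```python
def packed_causal_mask(example_lengths):
    """Return mask[q][k] = True when query position q may attend to key k.

    Attention is allowed only if k <= q (causal) and both positions belong
    to the same packed example (block-diagonal).
    """
    segment_ids = []
    for segment, length in enumerate(example_lengths):
        segment_ids.extend([segment] * length)
    n = len(segment_ids)
    return [
        [k <= q and segment_ids[q] == segment_ids[k] for k in range(n)]
        for q in range(n)
    ]

# Two hypothetical packed examples with total lengths 3 and 2
# (instruction + response tokens each).
mask = packed_causal_mask([3, 2])
```

In this example position 3 (the first token of the second example) can attend only to itself, even though three earlier tokens sit in the same packed sequence.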

Packing efficiency:

$$
\text{Compute efficiency} = \frac{\sum_i |y_i|}{\sum_i |x_i| + \sum_i |y_i|}
$$

Higher is better. Long instructions reduce efficiency.


8. Chat Templates and Special Tokens

Modern SFT formats use structured chat templates built from special tokens (e.g., role markers such as <|user|> and <|assistant|>).

For Llama-style format:

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{response}<|eot_id|>

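Typically, the end-of-turn token that closes the response (<|eot_id|>) IS included in the loss computation, as learning when to stop is part of instruction following. Frameworks differ on the remaining template tokens: some mask only the user's instruction text, others mask everything outside the assistant's response.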
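The sketch below shows one possible masking convention for the template above, chosen for illustration and not prescribed by any particular framework: the assistant's content and its closing <|eot_id|> get loss weight 1, and everything else (user text, header tokens, <|begin_of_text|>) gets 0. It works on an already-split list of string tokens, ignores the newlines of the real template, and `assistant_loss_mask` is a hypothetical helper.

```python
def assistant_loss_mask(tokens):
    """Return a 0/1 mask selecting assistant content and its closing <|eot_id|>."""
    mask = []
    role = None        # role of the turn currently being read
    in_header = False  # True between <|start_header_id|> and <|end_header_id|>
    for tok in tokens:
        if tok == "<|start_header_id|>":
            in_header = True
            mask.append(0)
        elif tok == "<|end_header_id|>":
            in_header = False
            mask.append(0)
        elif in_header:
            role = tok  # the header body names the role
            mask.append(0)
        elif tok == "<|eot_id|>":
            # The end-of-turn token is trained only when it closes a response.
            mask.append(1 if role == "assistant" else 0)
            role = None
        else:
            mask.append(1 if role == "assistant" else 0)
    return mask

tokens = [
    "<|begin_of_text|>",
    "<|start_header_id|>", "user", "<|end_header_id|>",
    "What", "is", "2+2?", "<|eot_id|>",
    "<|start_header_id|>", "assistant", "<|end_header_id|>",
    "4", "<|eot_id|>",
]
mask = assistant_loss_mask(tokens)
```

Whatever convention is chosen, applying the same function to every example keeps the treatment of template tokens uniform across the dataset.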


Worked Examples

Example 1: Computing SFT Loss

Problem: An instruction "What is 2+2?" (tokenized as [102, 374, 465, 220, 220, 30]) gets the response "4" (tokenized as [18]). The model's logits at the response position give probability 0.6 to token 18 and 0.4 distributed among other tokens. Compute the masked SFT loss for this example.

Solution:

Only the response position contributes to the loss. At the response position:

$$
p_{18} = 0.6, \qquad \ell = -\log(0.6) = -(-0.5108) = 0.5108
$$

Total loss for this example = 0.5108. The instruction's 6 tokens contribute zero to the loss.

If we had NOT masked: the instruction tokens would also contribute their cross-entropy, likely averaging 2-5 each, dwarfing the response contribution.
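The numbers in this example can be reproduced in a few lines. The unmasked comparison uses an assumed per-token instruction loss of 3.0 nats, picked from inside the 2-5 range mentioned above purely for illustration.

```python
import math

response_prob = 0.6              # probability assigned to token 18
loss = -math.log(response_prob)  # natural log, as in the solution

# Hypothetical unmasked comparison: assume each of the 6 instruction
# tokens costs 3.0 nats.
assumed_instruction_loss = 3.0
unmasked_total = 6 * assumed_instruction_loss + loss
response_share = loss / unmasked_total
```

Under that assumption the response token accounts for under 3% of the unmasked loss, which is the sense in which the instruction terms dwarf it.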


Example 2: Gradient Analysis

Problem: An example has 20 instruction tokens and 80 response tokens. Without masking, what fraction of the total gradient magnitude comes from response tokens? With masking? (Assume per-token loss gradients have similar magnitude.)

Solution:

Without masking: Response contributes 80/(20+80) = 80% of the gradient. Instruction contributes 20%. The model spends 20% of its capacity learning to reproduce instructions — wasteful.

With masking: Response contributes 100% of the gradient. All optimization effort goes toward P(response|instruction).

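Gradient weight on response tokens: With masking, the normalized weight carried by the 80 response tokens is:

$$
\|\nabla L_{\text{masked}}\| \propto \frac{1}{80} \cdot 80 = 1.0 \quad \text{(normalized)}
$$

Without masking:

$$
\|\nabla L_{\text{unmasked}}\| \propto \underbrace{\frac{1}{100} \cdot 80}_{\text{response: } 0.8} + \underbrace{\frac{1}{100} \cdot 20}_{\text{instruction: } 0.2} = 1.0
$$

But only 0.8 of that total is useful. Masking gives 1.0 useful gradient weight per example versus 0.8, a 25% improvement (1.0 / 0.8 = 1.25).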
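The same arithmetic as a small helper, under the example's assumption that every per-token gradient has similar magnitude. `response_gradient_share` is a hypothetical name.

```python
def response_gradient_share(n_instruction, n_response, masked):
    """Fraction of equally weighted per-token loss terms that come from the response."""
    if n_response <= 0:
        raise ValueError("need at least one response token")
    if masked:
        return 1.0  # instruction terms are excluded from the average
    return n_response / (n_instruction + n_response)

unmasked = response_gradient_share(20, 80, masked=False)
masked = response_gradient_share(20, 80, masked=True)
relative_gain = masked / unmasked - 1.0
```

Trying other length splits shows how the gain grows with prompt length: with 80 instruction tokens and 20 response tokens, the unmasked share drops to 0.2.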


Example 3: Packing Efficiency

Problem: You have 5 SFT examples with the following token counts:

| Example | Instruction tokens | Response tokens |
|---------|--------------------|-----------------|
| 1 | 15 | 45 |
| 2 | 30 | 20 |
| 3 | 10 | 60 |
| 4 | 50 | 30 |
| 5 | 5 | 35 |

Compute the compute efficiency if packed into one sequence vs trained separately. (Assume sequence padding to max length for unpacked, and no padding for packed.)

Solution:

Separate sequences (unpacked, padded to max length = 80 per example):

- Total tokens processed = 5 × 80 = 400
- Loss-contributing tokens (responses) = 45+20+60+30+35 = 190
- Efficiency = 190/400 = 47.5%

Packed (single sequence, no padding):

- Total tokens = 15+45+30+20+10+60+50+30+5+35 = 300
- Loss-contributing tokens (responses) = 190
- Efficiency = 190/300 = 63.3%

Packing improves compute efficiency by ~16 percentage points.
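The calculation can be reproduced directly from the table. The sketch follows the problem's assumptions (pad every unpacked sequence to the longest example, no padding when packed).

```python
# (instruction tokens, response tokens) for the five examples in the table
examples = [(15, 45), (30, 20), (10, 60), (50, 30), (5, 35)]

response_tokens = sum(resp for _, resp in examples)
max_len = max(inst + resp for inst, resp in examples)

padded_total = len(examples) * max_len                  # every row padded to max_len
packed_total = sum(inst + resp for inst, resp in examples)

padded_efficiency = response_tokens / padded_total
packed_efficiency = response_tokens / packed_total
gain_points = 100 * (packed_efficiency - padded_efficiency)
```

The remaining gap to 100% in the packed case comes entirely from instruction tokens, which packing cannot remove.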



Quiz

Q1: Without masking, which decomposition describes the loss being minimized?

A) −E[log P_θ(y|x)] only
B) −E[log P_θ(x)] only
C) −E[log P_θ(x)] − E[log P_θ(y|x)]
D) −E[log P_θ(x|y)]

Correct: C)

Q2: What is the primary purpose of masking the prompt in the loss?

A) To hide the instruction from the model's attention
B) To shorten the sequence processed in the forward pass
C) To lower the learning rate on instruction tokens
D) To restrict the loss to response tokens so the model learns P(response | instruction)

Correct: D)

Q3: Which statement about prompt masking is TRUE?

A) Masked instruction tokens are removed from the model's input
B) Masked instruction tokens remain in the context but contribute zero to the loss
C) Masking is applied only at inference time
D) Masking sets the attention weights on instruction tokens to zero

Correct: B)

Q4: A dataset has 10,000 examples with an average of 50 instruction tokens and 100 response tokens each (Practice Problem 2). How many tokens are processed per epoch?

A) 1,500,000 tokens
B) 1,000,000 tokens
C) 500,000 tokens
D) 15,000,000 tokens

Correct: A)

Q5: How does prompt masking relate to the compute cost of SFT?

A) Masking removes instruction tokens from the forward pass, so it cuts compute
B) Masking increases compute because the loss must be evaluated twice
C) Masking leaves the forward-pass cost unchanged but determines how many processed tokens contribute to the loss
D) Masking is unrelated to how many tokens contribute to the loss

Correct: C)

Q6: Which of the following is a common pitfall related to "data quality dominates quantity"?

A) Curating examples for diversity across instruction types
B) Checking responses for correctness before training
C) Using consistent formatting across examples
D) Assuming more data always helps, even when the added examples are noisy or contradictory

Correct: D)

Q7: When should you apply low learning rates and few epochs?

A) When pre-training a model from scratch
B) Only at inference time
C) When fine-tuning a pre-trained model with SFT, to limit catastrophic forgetting and overfitting
D) Only when the SFT dataset is larger than the pre-training corpus

Correct: C)

Practice Problems

Problem 1

Write the SFT loss for a single example with instruction x = [t₁, t₂] and response y = [t₃, t₄], showing which positions contribute.

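Answer: The full sequence is [t₁, t₂, t₃, t₄]. The forward pass produces logits at positions 1 through 4, predicting t₂, t₃, t₄, and t₅ respectively. Only the predictions of the response tokens t₃ and t₄ enter the loss:
$$
L = -\frac{1}{2}\left[\log P_\theta(t_3 \mid t_1, t_2) + \log P_\theta(t_4 \mid t_1, t_2, t_3)\right]
$$
The prediction of t₂ (made at position 1) is masked because t₂ is part of the instruction. The predictions of t₃ and t₄ (made at positions 2 and 3) contribute. The logits at position 4 would predict t₅, which does not exist in this example, so they have no target.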

Problem 2

An SFT dataset has 10,000 examples. Average instruction length = 50 tokens, average response length = 100 tokens. The model trains at 8M tokens/sec. How long does 1 epoch take? What's the effective training tokens per second (only counting loss-contributing tokens)?

Answer: **Total tokens per epoch:** 10,000 × (50 + 100) = 1,500,000 tokens. Time per epoch = 1.5M / 8M = 0.1875 seconds. Very fast: SFT is computationally cheap compared to pre-training.

**Effective tokens/sec:** Only response tokens contribute to the loss. Response tokens per epoch = 10,000 × 100 = 1M. Effective rate = 1M / 0.1875 = 5.33M tokens/sec (the remaining 2.67M tokens/sec go to instruction tokens, which provide context but no loss terms). This calculation assumes no padding; with padded batches the effective rate is lower, and packing recovers that loss.
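The arithmetic for this problem, written out so the inputs can be varied. All values come from the problem statement; padding is assumed to be zero.

```python
n_examples = 10_000
avg_instruction = 50
avg_response = 100
throughput = 8_000_000  # tokens/sec, as given in the problem

total_tokens = n_examples * (avg_instruction + avg_response)
epoch_seconds = total_tokens / throughput

response_tokens = n_examples * avg_response
effective_rate = response_tokens / epoch_seconds  # loss-contributing tokens/sec
```

The effective rate is simply the throughput scaled by the response share, 100 / 150 of 8M tokens/sec.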

Problem 3

Prove that the SFT loss with mask M (M_i = 1 for response, 0 for instruction) is equivalent to unconstrained fine-tuning on a weighted dataset where each instruction token has weight 0 and each response token has weight 1.

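Answer: The cross-entropy loss for position i is ℓ_i = −log P_θ(token_i | tokens_{<i}). The masked SFT loss is
$$
L_{\text{masked}} = \frac{\sum_i M_i \, \ell_i}{\sum_i M_i}
$$
and a weighted loss with per-token weights w_i is L_w = (Σ_i w_i ℓ_i) / (Σ_i w_i). Setting w_i = M_i (0 for instruction tokens, 1 for response tokens) makes the two expressions identical term by term, so they have the same value and the same gradient for every θ.

### Problem 4

Why might SFT cause "catastrophic forgetting" at high learning rates? Express mathematically.
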
Answer: Catastrophic forgetting occurs when the SFT update moves the parameters far from their pre-trained values. If pre-training converged to θ_pretrain ≈ argmin L_pretrain, then one SFT step gives θ = θ_pretrain − η·∇L_SFT. If η is too large, the update moves θ far from θ_pretrain, "forgetting" pre-training knowledge. The SFT loss decreases but L_pretrain increases (the model loses general language ability). A second-order estimate: if the Hessian of L_pretrain at θ_pretrain is H, then, because the first-order term vanishes at the minimum (∇L_pretrain(θ_pretrain) ≈ 0):
$$
L_{\text{pretrain}}(\theta_{\text{pretrain}} + \Delta) \approx L_{\text{pretrain}}(\theta_{\text{pretrain}}) + \tfrac{1}{2} \Delta^\top H \Delta
$$
With Δ = −η·∇L_SFT, the pre-training loss increase is ≈ ½η²·∇L_SFT^T H ∇L_SFT. The increase grows quadratically in η, so a small η keeps this term manageable.

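A small numeric sketch of the quadratic dependence on η. The 2×2 Hessian and the SFT gradient below are made-up values for a hypothetical two-parameter model; only the scaling behavior is the point.

```python
def pretrain_loss_increase(eta, grad_sft, hessian):
    """Second-order estimate 0.5 * eta^2 * g^T H g of the pre-training loss increase."""
    h_times_g = [sum(h_ij * g_j for h_ij, g_j in zip(row, grad_sft))
                 for row in hessian]
    g_h_g = sum(g_i * hg_i for g_i, hg_i in zip(grad_sft, h_times_g))
    return 0.5 * eta ** 2 * g_h_g

# Hypothetical values: a positive-definite Hessian and an SFT gradient.
H = [[4.0, 1.0],
     [1.0, 3.0]]
g = [1.0, -2.0]

small = pretrain_loss_increase(1e-5, g, H)
large = pretrain_loss_increase(1e-3, g, H)
ratio = large / small
```

Raising the learning rate by a factor of 100 raises the estimated pre-training loss increase by a factor of 10,000, which is why the gap between pre-training and SFT learning rates matters so much.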
### Problem 5

You have 100 SFT examples. You train for 3 epochs. A colleague says "that's only 300 gradient steps. How can that possibly be enough?" Calculate the information content: 100 examples × 3 epochs × average response tokens, and explain why it IS enough.

Answer: Assuming an average response length of 200 tokens:
$$
3 \text{ epochs} \times 100 \text{ examples} \times 200 \text{ tokens/example} = 60{,}000 \text{ training tokens}
$$
This seems tiny compared to pre-training (trillions). But the model already knows language, facts, and reasoning from pre-training. SFT only needs to shift the model's behavior from "complete this text" to "respond to this instruction." The pre-trained model is near a good solution; a few gradient steps in the right direction suffice. Additionally, each SFT example carries high "information density" — it's a clear demonstration of desired behavior, unlike raw text which is mostly low-signal.

---

## Summary

1. **SFT trains on (instruction, response) pairs** with the loss computed ONLY on response tokens; instruction tokens contribute to attention context but not to the loss
2. **Prompt masking** is mathematically essential: without it, the model spends part of its gradient signal learning P(instruction), and the signal for learning P(response|instruction) is diluted
3. **SFT is computationally cheap**: datasets of 1K–100K examples amount to millions of tokens rather than the trillions used in pre-training, so an SFT run costs a tiny fraction of pre-training compute
4. **Data quality dominates quantity**: 1K diverse, high-quality examples can outperform 50K noisy ones because the pre-trained model only needs a "nudge"
5. **Low learning rates and few epochs** prevent catastrophic forgetting while allowing sufficient adaptation to the instruction-following distribution

---

## Pitfalls

- **Forgetting to mask prompt tokens in the loss.** This is one of the most common implementation bugs in SFT. Without masking, a fraction |x|/(|x|+|y|) of the gradient signal (roughly half when prompts and responses have similar lengths) goes toward reproducing the instruction text. The model should learn P(response | instruction), not P(instruction, response). Always set loss weights to 0 for instruction positions and 1 for response positions.
- **Using pre-training-scale learning rates.** SFT requires learning rates ~1/10 to 1/100 of the pre-training LR. Using the pre-training LR causes catastrophic forgetting: large parameter updates overwrite the model's language understanding, factual knowledge, and reasoning abilities. The SFT loss decreases but the model loses general competence.
- **Training for too many epochs on small datasets.** With only 1K-10K examples, training well past 3 epochs can cause the model to memorize exact phrasing, punctuation, and formatting quirks of the training data. It becomes brittle, responding correctly to the training prompts but failing on paraphrased versions. For small, high-quality datasets, 1–3 epochs is typical (see 4.3 Epochs).
- **Assuming more data always helps in SFT.** The LIMA paper demonstrated that 1K diverse, high-quality examples can outperform 50K noisy ones. Poor-quality examples (inconsistent formatting, contradictory instructions, hallucinations in responses) create competing gradient signals that dilute the desired behavior. Invest in data QUALITY (diversity, correctness, consistent formatting) before scaling quantity.
- **Handling chat template tokens inconsistently.** Special tokens like `<|user|>`, `<|assistant|>`, and `<|eot_id|>` must be treated uniformly: if included in the context, decide whether their prediction should be loss-masked. Some frameworks mask only the instruction text and include template tokens in the loss; others mask all non-response tokens. Inconsistency (e.g., masking `<|assistant|>` sometimes but not others) creates confusing training signals.

---

## Key Terms

| Term | Definition |
|------|------------|
| **Instruction tuning (SFT)** | Supervised fine-tuning on (instruction, response) pairs; teaches a pre-trained model to follow instructions |
| **Prompt masking** | Computing loss ONLY on response tokens; instruction tokens contribute to attention context but not to the loss |
| **Conditional distribution** | P(response \| instruction), which is what SFT learns; distinct from P(instruction, response) |
| **Gradient dilution** | Without masking, instruction token gradients dilute the response-learning signal; masking ensures 100% useful gradient |
| **Catastrophic forgetting** | Large SFT updates overwrite pre-trained knowledge; low LRs (~1/10 to 1/100 of pre-training) prevent this |
| **Sequence packing** | Concatenating multiple examples into one sequence to reduce padding waste; requires attention masking |
| **LIMA hypothesis** | 1K high-quality diverse examples can match 50K noisy ones; SFT only needs to "nudge" the pre-trained model |
| **Chat template tokens** | Special tokens like `<\|user\|>`, `<\|assistant\|>` frame the instruction-response structure |

---

## Next Steps

Continue to [20-06 — RLHF Mathematics](./20-06-rlhf-mathematics.md) to learn how reinforcement learning from human feedback extends SFT by optimizing for human preferences using a reward model and the PPO algorithm.