25-03 — Grokking

Phase: 25 — Frontiers & Active Research Areas
Subject: 25-03
Prerequisites: Phase 14 (Optimization), Phase 18–19 (LLM Math), 25-01
Next subject: 25-04 — Double Descent


Learning Objectives

By the end of this subject, you will be able to:

  1. Define grokking and distinguish it from standard overfitting
  2. Explain the role of weight decay in inducing the grokking phase transition
  3. Analyze measures of grokking including test accuracy spike timing
  4. Understand grokking as a phase transition in the loss landscape
  5. Relate grokking to the broader phenomena of delayed generalization and simplicity bias

Core Content

What Is Grokking?

Grokking (Power et al., 2022) is a phenomenon where a neural network suddenly transitions from memorization (perfect training accuracy, chance test accuracy) to generalization (perfect test accuracy), long after overfitting would normally be considered complete.

The canonical experimental setup:

- Task: Modular arithmetic (e.g., compute $a + b \pmod{p}$ for prime $p$)
- Data: A small fraction of all possible input-output pairs (e.g., 50% of the $p^2$ pairs)
- Architecture: A small transformer (1-2 layers)
- Training: Full-batch gradient descent with weight decay
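To make the data setup concrete, here is a minimal sketch of how the train/test split could be built. The function name `make_modular_addition_split`, the seed, and the shuffling strategy are assumptions for illustration; the source only specifies the task and the training fraction.

```python
import random

def make_modular_addition_split(p=97, train_fraction=0.5, seed=0):
    """Enumerate all p*p pairs (a, b) with label (a + b) % p, then split them."""
    examples = [((a, b), (a + b) % p) for a in range(p) for b in range(p)]
    rng = random.Random(seed)          # fixed seed so the split is reproducible
    rng.shuffle(examples)
    n_train = int(train_fraction * len(examples))
    return examples[:n_train], examples[n_train:]

train, test = make_modular_addition_split()
print(len(train), len(test))           # sizes of the two disjoint subsets
```

The test set is simply every pair the model never sees during training, so test accuracy measures whether the network has learned the operation itself.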

The signature grokking curve:

| Phase | Training accuracy | Test accuracy | What's happening |
|---|---|---|---|
| Memorization | Rises to 100% | Stays at chance (~1/p) | Network memorizes training examples |
| Plateau | Remains at 100% | Remains at chance | Network is "comfortable": loss is low on the train set |
| Grokking | Remains at 100% | Suddenly jumps to ~100% | Network discovers the general algorithm |

⚠️ CRITICAL: Grokking is NOT standard overfitting-then-recovery. In standard overfitting, test accuracy degrades as training continues. In grokking, test accuracy stays at chance level (the network truly hasn't learned the pattern) and then jumps to perfect — a discrete phase transition.

The Mathematics of the Grokking Setup

For modular addition $a + b \equiv c \pmod{p}$:

The memorization solution: the network learns an arbitrary lookup table mapping training pairs to correct outputs. Training loss → 0, test accuracy ≈ 1/p (chance).
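A literal lookup table shows why pure memorization gives perfect training accuracy and near-chance test accuracy. This is a sketch, not a model of what a network actually stores: the dictionary `table` and the fixed fallback answer for unseen pairs are assumptions made for illustration.

```python
import random

p = 97
pairs = [(a, b) for a in range(p) for b in range(p)]
random.Random(0).shuffle(pairs)
half = len(pairs) // 2
train_pairs, test_pairs = pairs[:half], pairs[half:]

# "Memorization": store the answer for every training pair and nothing else.
table = {(a, b): (a + b) % p for a, b in train_pairs}

def predict(a, b):
    # Unseen pairs get a fixed default answer (hypothetical choice for this sketch).
    return table.get((a, b), 0)

def accuracy(subset):
    return sum(predict(a, b) == (a + b) % p for a, b in subset) / len(subset)

train_acc = accuracy(train_pairs)
test_acc = accuracy(test_pairs)
print(train_acc, test_acc)
```

The fallback is only right when the true answer happens to be 0, which is the case for about one pair in $p$, so test accuracy sits near $1/p$.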

The generalizing solution: the network discovers the mathematical structure. Addition modulo $p$ is the group operation of the cyclic group $\mathbb{Z}_p$, $(a, b) \mapsto (a + b) \bmod p$, and the network represents it internally with trigonometric features (a discrete Fourier transform over $\mathbb{Z}_p$).
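One way trigonometric features can implement modular addition is sketched below. Each candidate output $c$ is scored by $\sum_k \cos(2\pi k (a + b - c)/p)$, built from per-input sines and cosines through the angle-addition identities. The function names and the frequency set `freqs` are assumptions for illustration, not the weights of any trained network.

```python
import math

def fourier_logits(a, b, p, freqs):
    """Score each candidate c by sum_k cos(2*pi*k*(a + b - c) / p)."""
    logits = []
    for c in range(p):
        score = 0.0
        for k in freqs:
            w = 2 * math.pi * k / p
            # Angle-addition identities: features of a and b combine into a + b.
            cos_ab = math.cos(w * a) * math.cos(w * b) - math.sin(w * a) * math.sin(w * b)
            sin_ab = math.sin(w * a) * math.cos(w * b) + math.cos(w * a) * math.sin(w * b)
            # cos(w(a+b-c)) = cos(w(a+b))cos(wc) + sin(w(a+b))sin(wc)
            score += cos_ab * math.cos(w * c) + sin_ab * math.sin(w * c)
        logits.append(score)
    return logits

def predict(a, b, p, freqs):
    logits = fourier_logits(a, b, p, freqs)
    return max(range(p), key=logits.__getitem__)

p = 13
freqs = [1, 2, 5]   # hypothetical choice of frequencies
all_correct = all(predict(a, b, p, freqs) == (a + b) % p
                  for a in range(p) for b in range(p))
print(all_correct)
```

Every cosine term equals 1 exactly when $c \equiv a + b \pmod{p}$ and is strictly smaller otherwise, so the argmax recovers the right answer for all pairs, including pairs that were never seen.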

Weight Decay as the Driver

The crucial insight: grokking requires weight decay. Without weight decay, the network stays in the memorization solution indefinitely. With weight decay, there's continuous pressure to shrink weights, which eventually forces the network to discover the more parameter-efficient generalizing solution.

Weight decay adds the regularization term:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{CE}}(\theta) + \frac{\gamma}{2}\|\theta\|_2^2$$

where $\gamma$ is the weight decay coefficient. This penalizes large weights, making the memorization solution (which requires precise, large weights for each training example) eventually more expensive than the generalizing solution (which uses compact trigonometric structure).
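The regularized objective can be written out directly. The sketch below computes cross-entropy plus the $\frac{\gamma}{2}\|\theta\|_2^2$ penalty for a single example; the toy logits and parameter vectors are made-up numbers, chosen only to show that two solutions with the same predictions can differ in total loss through their weight norm.

```python
import math

def cross_entropy(logits, target):
    """Softmax cross-entropy for one example (log-sum-exp for stability)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]

def total_loss(logits, target, params, gamma):
    """L_total = L_CE + (gamma / 2) * ||theta||_2^2."""
    l2_squared = sum(w * w for w in params)
    return cross_entropy(logits, target) + 0.5 * gamma * l2_squared

# Hypothetical toy values: identical predictions, different weight norms.
logits, target, gamma = [4.0, 0.0, 0.0], 0, 0.1
params_large = [3.0, -3.0, 3.0]    # squared norm 27
params_small = [1.0, -1.0, 1.0]    # squared norm 3

loss_large = total_loss(logits, target, params_large, gamma)
loss_small = total_loss(logits, target, params_small, gamma)
print(loss_large, loss_small)
```

With equal cross-entropy, the only difference between the two totals is the penalty term, which is what lets weight decay distinguish between solutions that fit the training set equally well.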

Key dynamic: Both solutions fit the training set (near-zero training loss), but the memorization solution has high weight norm (many large parameters) while the generalizing solution has low weight norm. Weight decay continuously reduces weight norm, and at a critical point, the network "falls into" the generalizing basin.
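The "continuous pressure" has a simple form under plain gradient descent. The update is $\theta \leftarrow \theta - \eta(\nabla\mathcal{L}_{\text{CE}} + \gamma\theta)$, so once the cross-entropy gradient is near zero the weights shrink by a factor $(1 - \eta\gamma)$ per step. The sketch below assumes a learning rate $\eta$ (`lr`, not specified in the text) and an exactly zero cross-entropy gradient, which is an idealization of the plateau.

```python
import math

def gd_step(theta, lr, gamma, grad_ce):
    """One gradient-descent step on L_CE + (gamma/2)*||theta||^2."""
    return [w - lr * (g + gamma * w) for w, g in zip(theta, grad_ce)]

def l2_norm(theta):
    return math.sqrt(sum(w * w for w in theta))

theta = [3.0, -4.0]                 # toy weights with norm 5
lr, gamma = 0.1, 0.01               # hypothetical hyperparameters
zero_grad = [0.0, 0.0]              # idealized plateau: CE gradient vanishes

steps = 0
while l2_norm(theta) > 2.5:         # count steps until the norm has halved
    theta = gd_step(theta, lr, gamma, zero_grad)
    steps += 1

print(steps, math.log(2) / (lr * gamma))   # measured vs. ln(2) / (lr * gamma)
```

In this idealized setting the norm halves in roughly $\ln 2 / (\eta\gamma)$ steps, which gives one heuristic reason to expect larger $\gamma$ to shorten the plateau.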

Grokking as a Phase Transition

Grokking can be understood as a first-order phase transition in the loss landscape. Consider the "effective loss":

$$\mathcal{L}_{\text{eff}}(\theta) = \mathcal{L}_{\text{CE}}(\theta) + \frac{\gamma}{2}\|\theta\|_2^2$$

Initially, the network is in the memorization basin — a local minimum of the training loss (with zero CE loss) but high weight norm. The generalizing solution is a different basin with zero CE loss and low weight norm.

As training proceeds with weight decay, the weight norm $\|\theta\|_2$ decreases. When it crosses a critical threshold $\|\theta\|_c$, the memorization basin disappears (or stops being a stable minimum), and the network transitions to the generalizing basin.

This can be modeled as a bifurcation in the gradient flow:

$$\dot{\theta} = -\nabla_\theta \mathcal{L}_{\text{total}}$$

At the bifurcation point, the memorization fixed point loses stability and the trajectory flows to the generalizing fixed point.
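A one-dimensional toy system illustrates what "a fixed point loses stability" means. The potential $V(x) = x^4/4 - x^2/2 + hx$ below is purely an assumption for illustration, not derived from any network. It has two basins for small tilt $h$, and the basin at positive $x$ vanishes once $h$ exceeds $2/(3\sqrt{3}) \approx 0.385$.

```python
def grad_V(x, h):
    """Derivative of the toy potential V(x) = x**4/4 - x**2/2 + h*x."""
    return x**3 - x + h

def gradient_flow(x0, h, dt=0.01, steps=5000):
    """Forward-Euler integration of x' = -V'(x)."""
    x = x0
    for _ in range(steps):
        x -= dt * grad_V(x, h)
    return x

# Start in the right-hand basin (the stand-in for "memorization").
x_below = gradient_flow(1.0, h=0.2)   # tilt below the critical value
x_above = gradient_flow(1.0, h=0.5)   # tilt above the critical value
print(x_below, x_above)
```

Below the critical tilt the trajectory settles in the right-hand basin. Above it, that minimum no longer exists and the same starting point flows to the left-hand basin. A small change in the control parameter therefore produces a large, abrupt change in the end state.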

Measures of Grokking

Quantitative measures of the grokking phenomenon:

  1. Grokking time $t_{\text{grok}}$: The training step at which test accuracy first exceeds some threshold (e.g., 90%)
  2. Memorization time $t_{\text{mem}}$: The step at which training accuracy first reaches 100%
  3. Grokking gap: $\Delta t = t_{\text{grok}} - t_{\text{mem}}$ — the delay between memorization and generalization
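  4. Sharpness of transition: the transition width $w$ (the number of steps over which test accuracy climbs from near chance to near 100%) relative to $t_{\text{grok}}$, i.e. $w / t_{\text{grok}}$, or the slope of the accuracy curve at $t_{\text{grok}}$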
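The first three measures can be read off a training log mechanically. The sketch below assumes a log of `(step, train_acc, test_acc)` tuples sorted by step; that format, the function name `grokking_metrics`, and the sample numbers are hypothetical.

```python
def grokking_metrics(log, train_threshold=1.0, test_threshold=0.9):
    """Return t_mem, t_grok, and the grokking gap from (step, train_acc, test_acc) rows."""
    # First step where training accuracy reaches the threshold (100% by default).
    t_mem = next((s for s, tr, _ in log if tr >= train_threshold), None)
    # First step where test accuracy exceeds the threshold (90% by default).
    t_grok = next((s for s, _, te in log if te > test_threshold), None)
    gap = None if t_mem is None or t_grok is None else t_grok - t_mem
    return {"t_mem": t_mem, "t_grok": t_grok, "gap": gap}

# Made-up log for illustration only.
log = [
    (0, 0.01, 0.01),
    (5_000, 1.00, 0.01),
    (50_000, 1.00, 0.01),
    (60_000, 1.00, 0.55),
    (61_000, 1.00, 0.97),
    (62_000, 1.00, 1.00),
]
metrics = grokking_metrics(log)
print(metrics)
```

Returning `None` when a threshold is never crossed keeps runs that never grok distinguishable from runs with a gap of zero.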

The grokking gap depends on:

- Weight decay $\gamma$: Larger $\gamma$ → faster grokking (more pressure to shrink weights)
- Training set size: Larger training set → faster grokking; with enough data the gap closes and generalization follows memorization almost immediately
- Model size: Larger models → typically longer grokking gaps (more capacity to memorize)
- Data fraction: Intermediate fractions (around 40-60%) give the clearest grokking signal; with much less data the network may never generalize, and with much more there is no long plateau

Theoretical Explanations

Several theoretical frameworks explain grokking:

1. Slingshot mechanism (Thilak et al., 2022): With adaptive optimizers such as Adam, late-stage training shows cyclic transitions between stable and unstable regimes. Each "slingshot" appears as a spike in the training loss accompanied by rapid growth in the norm of the last-layer weights, after which training re-stabilizes. Thilak et al. report that the jump in test accuracy tends to coincide with these events, which links grokking to optimizer dynamics as well as to explicit regularization.

2. Simplicity bias: Neural networks have an inductive bias toward "simple" functions (low Kolmogorov complexity, low Fourier frequency). The generalizing solution (modular arithmetic implemented via trigonometric identities) is simpler in some sense than the memorization solution (an arbitrary lookup table). Weight decay amplifies this simplicity bias.

3. Loss landscape geometry: The memorization basin is wide in some directions (allowing gradient descent to find it quickly) but has high curvature in others (making it susceptible to weight decay). The generalizing basin is narrower but has lower overall weight norm. The transition occurs when weight decay "squeezes" the network out of the memorization basin.

Beyond Modular Arithmetic

Grokking has been observed in:

- Polynomial regression: Networks grok higher-degree polynomials after fitting lower-degree ones
- Group operations: Any finite group operation $(g, h) \mapsto gh$ in permutation groups, dihedral groups
- MNIST with few labels: Semi-supervised grokking where test accuracy spikes after many epochs
- Algorithmic tasks: Learning to execute simple programs, sorting, graph algorithms



Key Terms

Worked Examples

Example 1: Computing the Test Accuracy Spike

Problem: A transformer is trained on modular addition $\bmod 97$ with 50% of the data. The test accuracy log shows:

- Step 0–10K: ~1% (chance)
- Step 10K–100K: ~1% (plateau)
- Step 101K: 57%
- Step 102K: 98%
- Step 103K: 100%

Given $t_{\text{mem}} = 7{,}000$ steps, compute $t_{\text{grok}}$ (using a 90% threshold) and the grokking gap. How sharp is the transition?

Solution:

- $t_{\text{mem}} = 7{,}000$ (given)
- $t_{\text{grok}} = 102{,}000$: the first logged step above 90% (step 101K is only at 57%)
- Grokking gap: $\Delta t = 102{,}000 - 7{,}000 = 95{,}000$ steps
- Sharpness: test accuracy goes from ~1% at step 100K to 98% at step 102K, i.e. in ~2,000 steps. That is about 2% of $t_{\text{grok}}$, a very sharp transition.

The network spent roughly 93K steps (from step 7K to step 100K) at chance-level test accuracy while maintaining perfect training accuracy. This is the hallmark grokking signature.

Answer: $t_{\text{mem}} = 7{,}000$, $t_{\text{grok}} = 102{,}000$, gap = 95,000 steps. The transition is extremely sharp (2K steps for 1% → 98%), consistent with a first-order phase transition. The network "suddenly" generalizes after a long memorization plateau.

Example 2: Weight Decay Scaling

Problem: An experiment varies weight decay $\gamma$ while keeping all other hyperparameters fixed. Results:

- $\gamma = 0$: No grokking (test accuracy stays at chance after 500K steps)
- $\gamma = 0.01$: Grokking at $t \approx 200\text{K}$
- $\gamma = 0.1$: Grokking at $t \approx 50\text{K}$
- $\gamma = 1.0$: Grokking at $t \approx 10\text{K}$, but final test accuracy only 85%

Explain the pattern. Why does $\gamma=1.0$ fail to reach 100%?

Solution:

- $\gamma=0$: Without weight decay, there is no pressure to leave the memorization basin. The network stays there indefinitely.
- $\gamma$ increasing from 0.01 to 0.1: Stronger weight decay → faster norm reduction → earlier transition. The heuristic $\Delta t \propto 1/\gamma$ gets the direction right, but in this data a tenfold increase in $\gamma$ cuts the grokking time by only about 4-5×, so the dependence is weaker than $1/\gamma$.
- $\gamma=1.0$: Weight decay is too strong. The regularization term dominates the optimization and keeps the weights too small to fit the task exactly. The transition still happens early, but the network settles on an underfit solution and test accuracy plateaus at 85%.

The optimal $\gamma$ balances: enough regularization to induce grokking, but not so much that it prevents learning altogether.

Answer: $\gamma$ acts as a clock speed for grokking: larger values accelerate the transition. But excessive weight decay ($\gamma=1.0$) shrinks the weights so hard that the network cannot fit the task exactly, so test accuracy saturates below 100%. In the moderate-$\gamma$ regime the grokking time falls as $\gamma$ grows, here somewhat more slowly than $1/\gamma$.
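One way to check how the reported grokking times scale with $\gamma$ is to estimate the log-log slope between consecutive runs. The sketch below uses only the three runs from the problem statement that grok; a slope of $-1$ would correspond to $t_{\text{grok}} \propto 1/\gamma$. It uses $t_{\text{grok}}$ rather than the gap $\Delta t$, since $t_{\text{mem}}$ is not given for these runs.

```python
import math

# (gamma, approximate grokking step) from the problem statement.
runs = [(0.01, 200_000), (0.1, 50_000), (1.0, 10_000)]

def loglog_slopes(runs):
    """Slope of log(t_grok) against log(gamma) between consecutive runs."""
    slopes = []
    for (g0, t0), (g1, t1) in zip(runs, runs[1:]):
        slopes.append(math.log(t1 / t0) / math.log(g1 / g0))
    return slopes

slopes = loglog_slopes(runs)
print([round(s, 2) for s in slopes])
```

Both slopes are negative, so stronger weight decay does mean earlier grokking. Their magnitudes are below 1, so in this data the speed-up is smaller than a strict $1/\gamma$ law would predict.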

Example 3: Training Set Size and Grokking

Problem: On modular addition mod 97, you vary the training fraction:

- 10% training data: train acc reaches 100% at step 5K, test stays at 1% permanently
- 30% training data: train acc 100% at step 8K, grokking at 150K
- 50% training data: train acc 100% at step 10K, grokking at 100K
- 80% training data: train acc 100% at step 15K, test acc slowly rises to 95% by step 20K (no sharp spike)

Explain why grokking disappears at both extremes.

Solution:

- 10% data: Only $0.10 \times 97^2 \approx 941$ of the 9,409 possible pairs are in the training set. The memorization solution is very cheap (few examples to memorize), so weight decay pressure is insufficient to force a transition. The generalizing solution may not even be reachable from this limited data.
- 30-50% data: The "sweet spot": enough data that memorization is costly (many large weights needed), but not so much that the generalizing solution is obvious. The contrast between memorization cost and generalization cost is maximized.
- 80% data: With about 7,527 of the 9,409 pairs in the training set, the training distribution is dense enough that simple interpolation approaches the generalizing solution. The network implicitly generalizes during training without a sharp phase transition.

Answer: Grokking requires an *intermediate* amount of training data: enough to make memorization expensive, but not so much that the task is trivial. This is sometimes called the "Goldilocks zone" for grokking. At low data, memorization is cheap; at high data, generalization happens naturally.
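For reference, the training-set sizes and grokking gaps implied by the problem statement can be tabulated directly from $p^2 = 9{,}409$ total pairs. The `runs` list restates the numbers from the problem; `None` marks the two runs for which no sharp grokking step is reported.

```python
p = 97
total = p * p   # number of ordered pairs (a, b)

# (training %, t_mem, t_grok) from the problem statement.
runs = [(10, 5_000, None), (30, 8_000, 150_000), (50, 10_000, 100_000), (80, 15_000, None)]

results = {}
for pct, t_mem, t_grok in runs:
    n_train = total * pct // 100                    # integer count, rounded down
    gap = None if t_grok is None else t_grok - t_mem
    results[pct] = {"n_train": n_train, "gap": gap}
    print(f"{pct:>3}% -> {n_train:>5} training pairs, gap = {gap}")
```

Within the two runs that grok, the larger training set has the shorter gap, which matches the trend that more data speeds up generalization.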

Practice Problems

Problem 1: A neural network trained on addition mod 61 shows training accuracy reaching 100% at step 3,000 and test accuracy jumping from 1.6% to 98% between steps 90,000 and 93,000. Compute the grokking gap, and discuss whether this is "clean" grokking.

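Answer (click to expand) Grokking gap: 90,000 - 3,000 = 87,000 steps, measured to the onset of the jump (with a 90% threshold, $t_{\text{grok}}$ falls between steps 90,000 and 93,000, so the gap is at most 90,000). The test accuracy jump is extremely sharp (3,000 steps), and the gap is large (29× the memorization time). This is a textbook example of clean grokking: a long plateau followed by a sharp phase transition.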

Problem 2: You train a model without weight decay and observe no grokking after 1M steps. You then take the checkpoint at step 1M and resume training with weight decay. Would you expect grokking? Why or why not?

Answer (click to expand) Yes, you would likely observe grokking — possibly faster than training from scratch. The 1M-step checkpoint has found a memorization solution with some (potentially large) weight norm. Adding weight decay now creates the pressure to shrink weights, and since the solution is already at a memorization minimum, the network can transition to the generalizing basin quickly. This is essentially "inducing grokking on demand" by toggling weight decay. Experimentally, this has been confirmed: networks trained without weight decay, then fine-tuned with weight decay, do grok.

Problem 3: How does the grokking phenomenon relate to the "lottery ticket hypothesis" and the idea of sparse subnetworks?

Answer (click to expand) The connection is that the generalizing solution in grokking can be seen as a sparse subnetwork (a "winning ticket") within the larger network. During memorization, the network uses many parameters to store the lookup table. The generalizing solution uses a compact, structured set of weights (the trigonometric implementation of modular arithmetic). Weight decay acts as implicit pruning, driving the network toward the sparse, structured subnetwork. This connects grokking to the observation that neural networks contain sparse subnetworks that achieve the same performance when trained in isolation.

Problem 4: You're designing a grokking experiment. For a task with $N = p^2 = 3721$ total examples (mod 61), what training fraction would you choose to maximize the grokking signal? Justify your answer.

Answer (click to expand) Based on empirical studies, training fractions of 40-50% maximize the grokking gap. For $p=61$, this means 1,488–1,860 training examples. This fraction provides enough examples that memorization is expensive (requiring large weights), but leaves enough held-out data that the test set clearly demonstrates the phase transition. Going below 30% or above 70% reduces or eliminates the sharp transition. The precise optimal fraction depends on $p$ and the architecture, but 40-50% is a robust starting point.

Problem 5: Propose a mathematical criterion to automatically detect the grokking transition point $t_{\text{grok}}$ from a loss/accuracy log, without human inspection.

Answer (click to expand) A robust automatic detector:

  1. Compute test accuracy $a_t$ for each step $t$
  2. Smooth with a moving average (window ~100 steps)
  3. Compute the discrete derivative $\Delta a_t = a_{t+1} - a_t$
  4. Define $t_{\text{grok}}$ as $\arg\max_t \Delta a_t$ (the point of maximum increase)
  5. Validate: check that $a_{t_{\text{grok}}-k} \approx 1/p$ (baseline) and $a_{t_{\text{grok}}+k} \gg 1/p$ (post-grok)

Alternatively, fit a sigmoid $a(t) = c + \frac{d-c}{1+e^{-k(t-t_0)}}$ and identify $t_0$ as $t_{\text{grok}}$. The transition sharpness $k$ quantifies how abrupt the grok is.
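The derivative-based detector could be sketched as follows. The function names, the window measured in logged samples rather than raw steps, and the validation tolerances (`tol`, `margin`) are all assumptions made for illustration. The accuracy curve is synthetic, generated from the sigmoid form given in the answer.

```python
import math

def moving_average(values, window):
    """Centered moving average, truncated at the ends of the series."""
    half = window // 2
    out = []
    for i in range(len(values)):
        lo, hi = max(0, i - half), min(len(values), i + half + 1)
        out.append(sum(values[lo:hi]) / (hi - lo))
    return out

def detect_grok(steps, test_acc, chance, window=5, k=10, tol=0.05, margin=0.5):
    """Return (t_grok, is_valid) using the largest smoothed accuracy increase."""
    smooth = moving_average(test_acc, window)
    deltas = [smooth[i + 1] - smooth[i] for i in range(len(smooth) - 1)]
    i_star = max(range(len(deltas)), key=deltas.__getitem__)
    before = smooth[max(0, i_star - k)]
    after = smooth[min(len(smooth) - 1, i_star + k)]
    # Validate: near chance k samples earlier, far above chance k samples later.
    is_valid = before < chance + tol and after > chance + margin
    return steps[i_star], is_valid

# Synthetic curve: chance-level plateau, then a sharp rise centered at step 100K.
chance = 1 / 97
steps = list(range(0, 200_000, 1_000))
test_acc = [chance + (1 - chance) / (1 + math.exp(-(s - 100_000) / 1_500)) for s in steps]

t_grok, is_valid = detect_grok(steps, test_acc, chance)
print(t_grok, is_valid)
```

The validation step matters on real logs: a noisy curve always has some largest increase, so checking the levels before and after is what separates a genuine transition from noise.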

Summary


Quiz

Question 1: What distinguishes grokking from standard overfitting?

A. In grokking, training accuracy decreases while test accuracy increases
B. In grokking, test accuracy stays at chance level during the memorization phase and then jumps sharply
C. Grokking happens at the start of training, before any memorization
D. Grokking only occurs in the absence of weight decay

Correct Answer: B

Explanation:

- **If you chose A:** In grokking, training accuracy stays at 100%; it doesn't decrease.
- **If you chose B:** Correct. The defining signature: chance-level test accuracy during the memorization phase, then a sudden jump.
- **If you chose C:** Grokking happens *after* memorization, not before.
- **If you chose D:** Grokking *requires* weight decay; it doesn't happen without it.

Question 2: What role does weight decay play in grokking?

A. It prevents the network from memorizing the training data
B. It creates continuous pressure to reduce weight norm, eventually forcing a transition from memorization to generalization
C. It increases training accuracy faster
D. It has no effect on grokking

Correct Answer: B

Explanation:

- **If you chose A:** Weight decay doesn't prevent memorization; the network still reaches 100% training accuracy.
- **If you chose B:** Correct. Weight decay penalizes large weights, making the memorization solution increasingly expensive until the generalizing solution becomes favorable.
- **If you chose C:** Weight decay typically slows training down rather than speeding it up.
- **If you chose D:** Weight decay is *essential* for grokking; without it, no grokking occurs.

Question 3: In the grokking phase transition analogy, what is the "order parameter"?

A. Training loss
B. Test accuracy
C. Weight decay coefficient
D. Model size

Correct Answer: B

Explanation:

- **If you chose A:** Training loss is always near zero during the plateau, so it doesn't indicate the transition.
- **If you chose B:** Correct. Test accuracy is the order parameter: it jumps discontinuously at the transition point.
- **If you chose C:** Weight decay is the control parameter that drives the transition, not the order parameter.
- **If you chose D:** Model size is fixed during a single training run.

Question 4: Why is modular arithmetic used as the canonical grokking task?

A. It's the only task where grokking occurs
B. It has a clear mathematical structure (group theory) that makes the generalizing solution well-defined
C. It requires the largest possible models
D. It has no test set

Correct Answer: B

Explanation:

- **If you chose A:** Grokking has been observed in many other tasks.
- **If you chose B:** Correct. Modular arithmetic's group structure (Fourier basis, trigonometric identities) provides a clean mathematical description of the generalizing solution, making it ideal for study.
- **If you chose C:** Grokking typically uses small transformers (1-2 layers, small embedding dimension).
- **If you chose D:** The task has a well-defined test set (the held-out pairs).

Question 5: If you observe no grokking after 500K training steps, which change is MOST likely to induce it?

A. Remove weight decay
B. Increase model size dramatically
C. Add moderate weight decay if none was used
D. Train on 100% of the data

Correct Answer: C

Explanation:

- **If you chose A:** Removing weight decay eliminates the driving force for grokking.
- **If you chose B:** Larger models tend to have longer grokking gaps or no grokking at all.
- **If you chose C:** Correct. If no weight decay was used, adding it is the most reliable way to induce grokking.
- **If you chose D:** Training on 100% of the data leaves no held-out pairs, so there is no test signal on which grokking could show up.

Question 6: What does the "simplicity bias" explanation of grokking propose?

A. Networks always prefer complex solutions
B. Networks have an inductive bias toward simpler functions, and weight decay amplifies this to favor the generalizing solution over memorization
C. Simpler solutions are always found first in training
D. Grokking is caused by numerical precision issues, not learning dynamics

Correct Answer: B

Explanation:

- **If you chose A:** The opposite: networks are biased toward simplicity.
- **If you chose B:** Correct. The generalizing solution (trigonometric implementation) has lower effective complexity than the memorization solution (lookup table), and weight decay amplifies this preference.
- **If you chose C:** During grokking, the complex solution (memorization) is found first; only later does simplicity win.
- **If you chose D:** Grokking is a real learning phenomenon, not a numerical artifact.

Pitfalls

  1. Confusing grokking with standard test error recovery: Grokking requires the test accuracy to stay at chance level, not just dip temporarily
  2. Training without weight decay: Without weight decay, grokking simply will not occur — this is the most common experimental error
  3. Using too little or too much training data: Extremes suppress the grokking signal; 40-50% is the sweet spot for modular arithmetic
  4. Misinterpreting the plateau: The plateau is not "the network isn't learning" — it's learned a perfect memorization solution; it's waiting for weight decay to force a transition

  5. Assuming any delayed generalization is grokking: True grokking requires test accuracy to stay at chance level during the memorization plateau, then jump sharply. If test accuracy gradually improves after some initial dip, that's epoch-wise double descent (25-04), not grokking. The distinction matters because they have different mechanisms: grokking is driven by weight decay forcing a structural phase transition; epoch-wise double descent is driven by implicit regularization in SGD.

  6. Training without logging intermediate checkpoints: The grokking transition can be extremely sharp. Test accuracy can jump from 1% to 98% in a few hundred steps out of 100K+. Without frequent checkpointing (every 100-1000 steps), you'll miss the transition entirely and only see "before" and "after", losing the most scientifically interesting data. Log test accuracy at high temporal resolution.

  7. Using the wrong data fraction: The grokking gap is maximized at 40-50% training data for modular arithmetic. Below 30%, memorization is too cheap; above 70%, the generalizing solution is obvious from interpolation. If you're not seeing grokking, the training fraction is the first hyperparameter to check after weight decay.

  8. Confusing grokking with beneficial overfitting: In grokking, the test accuracy stays at chance, so the model has no generalization ability during the plateau. In beneficial overfitting, the model generalizes somewhat (above chance) even while overfitting noise. The chance-level plateau is the critical diagnostic: if test accuracy is above $1/p$, the model has already started to generalize, and you're seeing a different phenomenon.


Next Steps

25-04 — Double Descent — another counterintuitive phenomenon in modern deep learning, where bigger models don't just avoid overfitting, they actively get better past a critical size threshold.