24-02 — Natural Gradient Descent
Phase: 24 — Information Geometry & Advanced Theory Subject: 24-02 Prerequisites: 24-01 Fisher Information, Phase 14 (Optimization — gradient descent, second-order methods) Next subject: 24-03 — Neural Tangent Kernel (NTK)
Learning Objectives
By the end of this subject, you will be able to:
- Derive the natural gradient update rule from the KL divergence constraint
- Explain why natural gradient is invariant to reparameterization (unlike standard gradient)
- Compute natural gradients for simple models (Bernoulli, Gaussian, categorical)
- Understand the connection to Newton's method and how natural gradient differs
- Describe practical approximations including K-FAC and their trade-offs
Core Content
Why Standard Gradient Descent Has a Problem
Standard gradient descent updates parameters in the direction of steepest descent in Euclidean space:
$$\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}(\theta_t)$$
But parameter space has no intrinsic geometry. Consider reparameterizing a Bernoulli distribution:
- Parameterization A: $\theta \in (0, 1)$ — the probability itself
- Parameterization B: $\phi = \log\frac{\theta}{1-\theta}$ — the log-odds
The same model, same data, but gradient descent behaves completely differently because $d\theta \neq d\phi$. The Euclidean distance $|\theta_1 - \theta_2|$ is not a meaningful measure of how different two distributions are.
⚠️ CRITICAL: In standard gradient descent, the update $\Delta\theta = -\eta\nabla\mathcal{L}$ is NOT invariant to reparameterization. If you change from $\theta$ to $\phi = g(\theta)$, the new update $\Delta\phi = -\eta\nabla_\phi\mathcal{L}$ does NOT correspond to the same distribution change as the original update. This is a fundamental flaw when optimizing over probability distributions.
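The non-invariance described above can be checked with a few lines of arithmetic. This is a minimal sketch, assuming a single Bernoulli observation $y = 1$ with negative log-likelihood loss and a step size of 0.1; the function names and values are illustrative and do not come from the text.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gd_step_probability(theta, y, eta):
    """One gradient step on the probability theta for the Bernoulli NLL."""
    grad = -(y / theta) + (1 - y) / (1 - theta)   # dL/dtheta
    return theta - eta * grad

def gd_step_log_odds(theta, y, eta):
    """One gradient step on phi = logit(theta), mapped back to a probability."""
    phi = math.log(theta / (1 - theta))
    grad = sigmoid(phi) - y                        # dL/dphi
    return sigmoid(phi - eta * grad)

theta0, y, eta = 0.5, 1, 0.1                 # assumed starting point and step size
p_a = gd_step_probability(theta0, y, eta)    # Parameterization A
p_b = gd_step_log_odds(theta0, y, eta)       # Parameterization B
print(p_a, p_b)
```

Both runs start from the same distribution and use the same learning rate, yet they end at different probabilities, which is the flaw the natural gradient is designed to remove.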
The Natural Gradient
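The natural gradient fixes this by measuring distance in distribution space, not parameter space. The key insight: use KL divergence as the local measure of distance. KL is not a true metric (it is asymmetric), but to second order in a small step $d\theta$ it is a symmetric quadratic form:
$$D_{KL}(p_\theta \,\|\, p_{\theta + d\theta}) \approx \frac{1}{2} d\theta^T I(\theta) d\theta$$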
The Fisher information matrix $I(\theta)$ defines the local geometry. The natural gradient finds the direction that produces the steepest descent in loss per unit of KL divergence:
$$\tilde{\nabla}\mathcal{L}(\theta) = I(\theta)^{-1} \nabla_\theta \mathcal{L}(\theta)$$
The natural gradient update rule:
$$\boxed{\theta_{t+1} = \theta_t - \eta \, I(\theta_t)^{-1} \, \nabla_\theta \mathcal{L}(\theta_t)}$$
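The quadratic approximation of the KL divergence that underlies this update can be verified numerically. The sketch below assumes a Bernoulli model in the log-odds parameterization, where the Fisher information is $\sigma(\theta)(1-\sigma(\theta))$; the evaluation point $\theta = 0.3$ and the step sizes are arbitrary choices for illustration.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def kl_bernoulli(p, q):
    """KL(Bern(p) || Bern(q)) in nats."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

def fisher_logit(theta):
    """Fisher information of Bern(sigmoid(theta)) with respect to theta."""
    p = sigmoid(theta)
    return p * (1 - p)

theta = 0.3                                   # arbitrary evaluation point
for d_theta in (0.1, 0.01, 0.001):
    exact = kl_bernoulli(sigmoid(theta), sigmoid(theta + d_theta))
    quad = 0.5 * fisher_logit(theta) * d_theta ** 2
    # The ratio approaches 1 as the step shrinks
    print(d_theta, exact, quad, exact / quad)
```

The ratio of the exact KL to the quadratic form tends to 1 as the step shrinks, which is the sense in which the Fisher matrix defines the local geometry.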
Formal derivation: We want to minimize $\mathcal{L}(\theta + d\theta)$ subject to $D_{KL}(p_\theta \,\|\, p_{\theta+d\theta}) \leq \epsilon$ (a trust region). To first order:
$$\mathcal{L}(\theta + d\theta) \approx \mathcal{L}(\theta) + \nabla\mathcal{L}^T d\theta$$
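Using the Fisher metric constraint $\frac{1}{2} d\theta^T I(\theta) d\theta \leq \epsilon$, the Lagrangian (written $\Lambda$ to keep it distinct from the loss $\mathcal{L}$, with the constant $\mathcal{L}(\theta)$ dropped) is:
$$\Lambda(d\theta, \lambda) = \nabla\mathcal{L}^T d\theta + \lambda\left(\frac{1}{2} d\theta^T I d\theta - \epsilon\right)$$
Setting the derivative with respect to $d\theta$ to zero: $\nabla\mathcal{L} + \lambda I d\theta = 0 \implies d\theta = -\frac{1}{\lambda} I^{-1} \nabla\mathcal{L}$, which gives the natural gradient direction. The multiplier $\lambda > 0$ is fixed by the active constraint and only sets the step length, so it can be absorbed into the learning rate $\eta$.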
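The multiplier can be made explicit: substituting $d\theta = -\frac{1}{\lambda} I^{-1}\nabla\mathcal{L}$ into the active constraint $\frac{1}{2} d\theta^T I\, d\theta = \epsilon$ gives $\frac{1}{\lambda} = \sqrt{2\epsilon / (\nabla\mathcal{L}^T I^{-1} \nabla\mathcal{L})}$. A minimal sketch of the resulting step follows; the function name and the 2-parameter Fisher and gradient are hypothetical values chosen only for illustration.

```python
import numpy as np

def trust_region_natural_step(grad, fisher, epsilon):
    """Step minimizing grad^T d subject to 0.5 * d^T F d <= epsilon."""
    nat_grad = np.linalg.solve(fisher, grad)          # F^{-1} g
    # The constraint is active at the optimum:
    # 1/lambda = sqrt(2 * epsilon / (g^T F^{-1} g))
    inv_lambda = np.sqrt(2.0 * epsilon / (grad @ nat_grad))
    return -inv_lambda * nat_grad

# Hypothetical 2-parameter example (values chosen only for illustration)
fisher = np.array([[2.0, 0.5],
                   [0.5, 1.0]])
grad = np.array([1.0, -1.0])
step = trust_region_natural_step(grad, fisher, epsilon=0.01)

print(step)
print(0.5 * step @ fisher @ step)   # matches epsilon up to rounding
```

Solving the linear system with `np.linalg.solve` avoids forming $I^{-1}$ explicitly, which is usually the better-conditioned choice.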
Reparameterization Invariance
This is the killer feature. Under any smooth invertible reparameterization $\phi = g(\theta)$:
- The standard gradient transforms as: $\nabla_\phi \mathcal{L} = J^{-T} \nabla_\theta \mathcal{L}$ where $J = \frac{\partial\phi}{\partial\theta}$
- The Fisher matrix transforms as: $I(\phi) = J^{-T} I(\theta) J^{-1}$
Therefore the natural gradient transforms as:
$$\tilde{\nabla}_\phi \mathcal{L} = I(\phi)^{-1} \nabla_\phi \mathcal{L} = (J^{-T} I J^{-1})^{-1} J^{-T} \nabla_\theta \mathcal{L} = J \, I^{-1} \nabla_\theta \mathcal{L} = J \, \tilde{\nabla}_\theta \mathcal{L}$$
This is exactly how a contravariant vector should transform. The parameter update $d\phi = J d\theta$ corresponds to the same step in distribution space regardless of parametrization:
$$d\phi^T I(\phi) d\phi = (J d\theta)^T (J^{-T} I J^{-1}) (J d\theta) = d\theta^T I(\theta) d\theta$$
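These transformation rules can be checked numerically. The sketch below assumes a univariate Gaussian with $\theta = (\mu, \sigma)$ reparameterized as $\phi = (\mu, \log\sigma)$, so that $J = \text{diag}(1, 1/\sigma)$; the loss gradient is a made-up vector used only to exercise the algebra.

```python
import numpy as np

sigma = 0.5
# Fisher of N(mu, sigma^2) in theta = (mu, sigma)
fisher_theta = np.diag([1 / sigma**2, 2 / sigma**2])
# Reparameterize: phi = (mu, log sigma); Jacobian J = d phi / d theta
J = np.diag([1.0, 1 / sigma])
J_inv = np.linalg.inv(J)

# Hypothetical loss gradient in theta-coordinates
grad_theta = np.array([0.3, -1.2])

grad_phi = J_inv.T @ grad_theta                    # gradient transformation
fisher_phi = J_inv.T @ fisher_theta @ J_inv        # metric transformation

nat_theta = np.linalg.solve(fisher_theta, grad_theta)
nat_phi = np.linalg.solve(fisher_phi, grad_phi)

print(np.allclose(nat_phi, J @ nat_theta))         # contravariant rule: True
print(np.allclose(nat_phi @ fisher_phi @ nat_phi,
                  nat_theta @ fisher_theta @ nat_theta))  # same Fisher length: True
print(np.allclose(grad_phi, J @ grad_theta))       # plain gradient: False
```

The natural gradient maps between coordinate systems through $J$ like a parameter displacement, while the plain gradient does not.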
⚠️ CRITICAL: Natural gradient is to standard gradient what covariant differentiation is to partial differentiation. The Fisher matrix plays the role of the metric tensor in Riemannian geometry. This is why this subject is called information geometry.
Natural Gradient vs Newton's Method
Both use preconditioning matrices, but they're fundamentally different:
| Aspect | Newton's Method | Natural Gradient |
|---|---|---|
| Preconditioner | $H = \nabla_\theta^2 \mathcal{L}(\theta)$ (Hessian of loss) | $I(\theta)$ (Fisher of model) |
| What it optimizes | Second-order Taylor of loss | First-order in distribution space |
| Invariance | Invariant to affine reparameterizations only | Invariant to smooth invertible reparameterizations (to first order in the step) |
| PSD guarantee | No — Hessian can be indefinite | Yes — Fisher is always PSD |
| Cost | $O(d^2)$ memory, $O(d^3)$ inversion | $O(d^2)$ memory, $O(d^3)$ inversion |
| Where from | Optimization theory | Information geometry |
The natural gradient is arguably more principled for probabilistic models because it respects the geometry of the problem. Newton's method uses the loss landscape directly, which depends on parameterization.
Computing Natural Gradients for Specific Models
Bernoulli Log-Likelihood with Cross-Entropy Loss
Model: $p(y=1 \mid \theta) = \sigma(\theta) = \frac{1}{1+e^{-\theta}}$. Loss: $\mathcal{L}(\theta) = -y\log\sigma(\theta) - (1-y)\log(1-\sigma(\theta))$.
Standard gradient: $\nabla_\theta\mathcal{L} = \sigma(\theta) - y$.
Fisher: $I(\theta) = \sigma(\theta)(1-\sigma(\theta))$.
Natural gradient:
$$\tilde{\nabla}_\theta\mathcal{L} = \frac{\sigma(\theta) - y}{\sigma(\theta)(1-\sigma(\theta))}$$
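This is exactly the ordinary gradient of the loss with respect to the probability $p = \sigma(\theta)$: a direct calculation gives $\partial\mathcal{L}/\partial p = (p - y)/\big(p(1-p)\big)$. Natural gradient descent on the log-odds $\theta$ therefore moves in the direction that standard gradient descent would take on $p$ itself; dividing by the Fisher automatically "undoes" the sigmoid nonlinearity.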
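As a sanity check, the formula above can be compared with a finite-difference derivative of the loss taken with respect to the probability $p = \sigma(\theta)$. This is a small sketch; the evaluation point $\theta = 0.8$ and the finite-difference width are arbitrary assumptions.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def natural_gradient_logit(theta, y):
    """(sigma(theta) - y) / (sigma(theta) * (1 - sigma(theta)))."""
    p = sigmoid(theta)
    return (p - y) / (p * (1 - p))

def loss_in_p(p, y):
    """Bernoulli cross-entropy written as a function of the probability p."""
    return -y * math.log(p) - (1 - y) * math.log(1 - p)

theta, y, h = 0.8, 1, 1e-6        # arbitrary point and finite-difference width
p = sigmoid(theta)
# Central finite difference of the loss with respect to p
dloss_dp = (loss_in_p(p + h, y) - loss_in_p(p - h, y)) / (2 * h)

print(natural_gradient_logit(theta, y), dloss_dp)
```

The two numbers agree up to finite-difference error, for either label value.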
Mean-Field Gaussian Variational Inference
For a Gaussian variational posterior $q(z \mid \mu, \sigma^2) = \mathcal{N}(z \mid \mu, \text{diag}(\sigma^2))$, the Fisher matrix with respect to the means and standard deviations $(\mu, \sigma)$ is block diagonal:
$$I(\mu, \sigma) = \begin{pmatrix} \text{diag}(1/\sigma_i^2) & 0 \\ 0 & \text{diag}(2/\sigma_i^2) \end{pmatrix}$$
The natural gradient simplifies dramatically. For $\mu$:
$$\tilde{\nabla}_{\mu_i}\mathcal{L} = \sigma_i^2 \cdot \nabla_{\mu_i}\mathcal{L}$$
Natural gradient automatically scales the update by the variance — parameters with high uncertainty get larger updates. This is a form of automatic learning rate adaptation driven by the model's own uncertainty.
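Because the Fisher is diagonal, the natural gradient reduces to elementwise scaling. The sketch below assumes the $(\mu, \sigma)$ parameterization used above; the function name and the gradient values are hypothetical.

```python
import numpy as np

def mean_field_natural_gradient(grad_mu, grad_sigma, sigma):
    """Natural gradient for q = N(mu, diag(sigma^2)) in (mu, sigma) coordinates.

    The Fisher is diagonal with entries 1/sigma^2 (for mu_i) and
    2/sigma^2 (for sigma_i), so its inverse is applied elementwise.
    """
    nat_mu = sigma**2 * grad_mu
    nat_sigma = 0.5 * sigma**2 * grad_sigma
    return nat_mu, nat_sigma

# Hypothetical gradients for a 3-dimensional posterior
sigma = np.array([0.1, 1.0, 3.0])
grad_mu = np.array([1.0, 1.0, 1.0])
grad_sigma = np.array([0.5, 0.5, 0.5])

nat_mu, nat_sigma = mean_field_natural_gradient(grad_mu, grad_sigma, sigma)
print(nat_mu)      # identical raw gradients, rescaled by sigma^2
print(nat_sigma)   # rescaled by sigma^2 / 2
```

With identical raw gradients, the coordinate with the largest variance receives the largest update, as the text describes.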
Categorical Softmax Classifier
For $K$-class softmax with logits $\mathbf{z}$, Fisher is:
$$I(\mathbf{z}) = \text{diag}(\boldsymbol{\pi}) - \boldsymbol{\pi}\boldsymbol{\pi}^T$$
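This $(K \times K)$ matrix has rank $K-1$, because adding the same constant to every logit leaves $\boldsymbol{\pi}$ unchanged. The natural gradient therefore requires the pseudoinverse or damping (here $\mathrm{Id}$ denotes the $K \times K$ identity matrix, to avoid a clash with the Fisher $I$):
$$\tilde{\nabla}_{\mathbf{z}}\mathcal{L} = (I(\mathbf{z}) + \lambda\, \mathrm{Id})^{-1} \nabla_{\mathbf{z}}\mathcal{L}$$
With damping $\lambda > 0$, this is computationally tractable for moderate $K$.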
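The rank deficiency and the two remedies can be illustrated on a small example. This is a sketch with made-up logits for $K = 4$ and an assumed damping value of $10^{-3}$; the gradient is the usual cross-entropy gradient $\boldsymbol{\pi} - \mathbf{e}_y$ with respect to the logits.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))
    return e / e.sum()

def softmax_fisher(z):
    """Fisher of a categorical distribution with respect to its logits."""
    pi = softmax(z)
    return np.diag(pi) - np.outer(pi, pi)

z = np.array([2.0, 0.5, -1.0, 0.0])       # hypothetical logits, K = 4
label = 1
pi = softmax(z)
fisher = softmax_fisher(z)
grad = pi - np.eye(len(z))[label]          # cross-entropy gradient w.r.t. logits

rank = np.linalg.matrix_rank(fisher, tol=1e-10)
print(rank)                                          # K - 1
print(np.allclose(fisher @ np.ones(len(z)), 0.0))    # all-ones vector is in the null space

damping = 1e-3                                       # assumed damping value
nat_damped = np.linalg.solve(fisher + damping * np.eye(len(z)), grad)
nat_pinv = np.linalg.pinv(fisher) @ grad
print(nat_damped)
print(nat_pinv)
```

For small damping the two solutions are close, because the cross-entropy gradient sums to zero and so has no component along the null direction.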
Practical Approximations
The elephant in the room: $I(\theta)$ is $d \times d$, and $d$ can be millions in deep learning. At $d = 10^6$ the matrix alone has $10^{12}$ entries, so storing it, let alone inverting it exactly, is infeasible. We need approximations.
Diagonal Approximation
Simplest approach: $I(\theta) \approx \text{diag}(I_{11}, I_{22}, \ldots, I_{dd})$.
Natural gradient becomes per-parameter scaling: $\Delta\theta_i = -\eta \cdot \frac{\nabla_i\mathcal{L}}{I_{ii}}$. Easy to compute but ignores correlations between parameters. Works surprisingly well for some problems.
K-FAC (Kronecker-Factored Approximate Curvature)
Key insight: in feedforward networks, the Fisher matrix for layer $\ell$ approximately factorizes as:
$$I_\ell \approx A_{\ell-1} \otimes G_\ell$$
where $A_{\ell-1} = \mathbb{E}[\mathbf{a}_{\ell-1} \mathbf{a}_{\ell-1}^T]$ (uncentered covariance of the layer's input activations) and $G_\ell = \mathbb{E}[\mathbf{g}_\ell \mathbf{g}_\ell^T]$ (uncentered covariance of the gradients w.r.t. the pre-activations).
Since $(A \otimes G)^{-1} = A^{-1} \otimes G^{-1}$, we only need to invert two much smaller matrices (e.g., two $1000 \times 1000$ matrices instead of one $10^6 \times 10^6$ matrix). For a square layer with $d = n^2$ weights, this reduces the inversion cost from $O(d^3) = O(n^6)$ to $O(n^3) = O(d^{3/2})$ per layer.
K-FAC is the most successful practical natural gradient method for deep learning, used in large-scale applications including reinforcement learning and Bayesian neural networks.
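The Kronecker inverse identity can be verified on a toy layer. The sketch below uses random symmetric positive-definite stand-ins for $A$ and $G$ rather than statistics from a real network, and it assumes column-stacking vectorization, for which $(A \otimes G)\,\text{vec}(V) = \text{vec}(G V A^T)$.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_spd(n):
    """Random symmetric positive-definite matrix (stand-in for A or G)."""
    m = rng.standard_normal((n, n))
    return m @ m.T + n * np.eye(n)

n_in, n_out = 5, 3
A = random_spd(n_in)                          # input-side factor, n_in x n_in
G = random_spd(n_out)                         # output-side factor, n_out x n_out
grad_W = rng.standard_normal((n_out, n_in))   # stand-in gradient w.r.t. W

# Explicit route: build the (n_in * n_out)^2 matrix and solve.
# Column-stacking vec is assumed: (A kron G) vec(V) = vec(G V A^T).
full = np.linalg.solve(np.kron(A, G), grad_W.flatten(order="F"))

# Kronecker route: G^{-1} V A^{-1}, never forming the big matrix.
kfac = np.linalg.solve(G, grad_W) @ np.linalg.inv(A)

print(np.allclose(full, kfac.flatten(order="F")))   # True
```

The second route only touches an `n_out x n_out` and an `n_in x n_in` matrix, which is where the savings come from.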
Empirical Fisher
Replace the expectation over $p(x \mid \theta)$ with the empirical distribution:
$$\hat{I}(\theta) = \frac{1}{B}\sum_{i=1}^B s(\theta, x_i) s(\theta, x_i)^T$$
⚠️ CRITICAL — Common Pitfall: The empirical Fisher $\hat{I}$ based on the training data is NOT the same as the true Fisher $I$. The empirical Fisher using training labels $y_i$ (not samples from the model) can give misleading curvature estimates. When possible, compute the Fisher using samples from the model's own predictive distribution, not the data distribution.
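The gap between the two quantities is easy to see in one dimension. This is a minimal sketch, assuming a Bernoulli model in the log-odds parameterization that predicts $p(y=1) = 0.9$ while a hypothetical batch of labels has mean 0.2; the expectation under the model is computed exactly rather than by sampling.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

theta = np.log(9.0)                 # model predicts p(y=1) = 0.9
p = sigmoid(theta)
labels = np.array([1, 0, 0, 0, 0])  # hypothetical training labels, mean 0.2

# Score of the Bernoulli-logit model: d/dtheta log p(y | theta) = y - p
empirical_fisher = np.mean((labels - p) ** 2)

# True Fisher: expectation of the squared score under the model itself
true_fisher = p * (1 - p) ** 2 + (1 - p) * (0 - p) ** 2   # equals p * (1 - p)

print(empirical_fisher, true_fisher)
```

When the model is badly miscalibrated, as in this example, the empirical Fisher is several times larger than the true Fisher, so the corresponding preconditioned step would be several times smaller.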
Damping and Trust Regions
In practice, $I(\theta)$ can be ill-conditioned or singular. We add damping (writing $\mathrm{Id}$ for the identity matrix):
$$\Delta\theta = -\eta (I(\theta) + \lambda\, \mathrm{Id})^{-1} \nabla\mathcal{L}$$
Damping has a beautiful interpretation: it interpolates between natural gradient ($\lambda \to 0$) and standard gradient ($\lambda \to \infty$, where the update tends to $-\frac{\eta}{\lambda}\nabla\mathcal{L}$, i.e. gradient descent with effective learning rate $\eta/\lambda$). In trust-region terms, $\lambda$ controls how far you trust the quadratic approximation of KL divergence.
Adaptive damping (Levenberg-Marquardt style): Reduce $\lambda$ if the step improves the loss, increase it if the step is too aggressive. This is standard practice in second-order optimization and applies directly to natural gradient.
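The two limits of damping can be seen on a deliberately ill-conditioned example. The sketch below uses a hypothetical diagonal Fisher with eigenvalues 1 and $10^{-6}$; the function name and the damping values are illustrative assumptions.

```python
import numpy as np

def damped_natural_step(grad, fisher, damping, eta=1.0):
    """Compute -eta * (F + damping * Id)^{-1} grad."""
    d = len(grad)
    return -eta * np.linalg.solve(fisher + damping * np.eye(d), grad)

# Hypothetical ill-conditioned Fisher (eigenvalues 1 and 1e-6)
fisher = np.diag([1.0, 1e-6])
grad = np.array([1.0, 1.0])

for damping in (1e-8, 1e-2, 1e4):
    step = damped_natural_step(grad, fisher, damping)
    # damping * step approaches -grad as damping grows
    print(damping, step, damping * step)
```

With tiny damping the step along the weak direction is of order $10^6$; with large damping the step, rescaled by $\lambda$, is essentially the negative gradient.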
Connection to TRPO and PPO
Natural gradient is the theoretical foundation of Trust Region Policy Optimization (TRPO) in reinforcement learning. TRPO constrains the KL divergence between old and new policies:
$$\max_\theta \mathbb{E}\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{\text{old}}}(a \mid s)} A(s, a)\right] \quad \text{s.t.} \quad D_{KL}(\pi_{\theta_{\text{old}}} \,\|\, \pi_\theta) \leq \delta$$
Linearizing the objective and expanding the KL constraint to second order around $\theta_{\text{old}}$ turns this into exactly the trust-region problem from the derivation above, so the resulting update is a natural gradient step with the Fisher matrix of the policy. PPO simplifies this by using a clipped surrogate objective instead of the hard KL constraint, but the motivation is the same: respect the geometry of policy space.
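In practice TRPO does not form or invert the Fisher matrix; it solves $F x = g$ approximately with conjugate gradient, using only Fisher-vector products, and then rescales the direction to meet the KL budget. A minimal sketch of that idea follows, with a small explicit matrix standing in for the policy Fisher; the function names and numbers are hypothetical, and the backtracking line search that TRPO also performs is omitted.

```python
import numpy as np

def conjugate_gradient(matvec, b, iters=50, tol=1e-10):
    """Solve F x = b using only products F @ v (F symmetric positive definite)."""
    x = np.zeros_like(b)
    r = b.copy()          # residual b - F x (x starts at zero)
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        Fp = matvec(p)
        alpha = rs / (p @ Fp)
        x += alpha * p
        r -= alpha * Fp
        rs_new = r @ r
        if rs_new < tol:  # squared residual norm small enough
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

def trpo_step(grad, fisher_vector_product, delta):
    """Natural gradient ascent direction scaled to the KL budget delta."""
    direction = conjugate_gradient(fisher_vector_product, grad)
    scale = np.sqrt(2.0 * delta / (direction @ fisher_vector_product(direction)))
    return scale * direction

# Hypothetical policy Fisher and surrogate gradient, for illustration only
fisher = np.array([[3.0, 1.0], [1.0, 2.0]])
grad = np.array([0.5, -1.0])
step = trpo_step(grad, lambda v: fisher @ v, delta=0.01)
print(step, 0.5 * step @ fisher @ step)
```

Passing a function for the Fisher-vector product, rather than a matrix, is what lets the same routine scale to networks where $F$ is never materialized.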
Key Terms
- Damping
- K-FAC
- Natural gradient
- Riemannian metric
Worked Examples
Example 1: Natural Gradient for a Single-Parameter Bernoulli
Given observed data $y=1$ (single observation), loss $\mathcal{L}(\theta) = -\log\sigma(\theta)$. Compute and compare the standard and natural gradient updates from $\theta_0 = 0$ with $\eta = 1$.
Solution:
At $\theta = 0$: $\sigma(0) = 0.5$, $\sigma(0)(1-\sigma(0)) = 0.25$.
Standard gradient: $\nabla\mathcal{L} = \sigma(0) - 1 = -0.5$. Update: $\Delta\theta_{\text{std}} = -1 \cdot (-0.5) = 0.5$. New $\theta = 0.5$, $\sigma(0.5) \approx 0.622$.
Fisher: $I(0) = 0.25$. Natural gradient: $\tilde{\nabla}\mathcal{L} = -0.5 / 0.25 = -2$. Update: $\Delta\theta_{\text{NG}} = -1 \cdot (-2) = 2$. New $\theta = 2$, $\sigma(2) \approx 0.881$.
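Natural gradient makes a much larger step, getting much closer to the optimum ($\sigma \to 1$). The reason is the scaling by $1/I(\theta)$: in the log-odds parameterization $I(\theta) = \sigma(\theta)(1-\sigma(\theta)) \leq 0.25$, so the natural gradient is always at least four times the standard gradient, and the factor grows as the sigmoid saturates.
Now reparameterize by the probability $p = \sigma(\theta)$, so that $\theta = \log\frac{p}{1-p}$ is the log-odds. At $\theta = 0$ we have $p = 0.5$ and $J = \frac{dp}{d\theta} = p(1-p) = 0.25$. The loss is $-\log p$, so the standard gradient is $\nabla_p\mathcal{L} = -1/p = -2$, and the Fisher is $I(p) = \frac{1}{p(1-p)} = 4$. Natural gradient in $p$: $-2/4 = -0.5 = J \cdot (-2)$, exactly the contravariant rule $\tilde{\nabla}_p\mathcal{L} = J\,\tilde{\nabla}_\theta\mathcal{L}$. The standard gradient does not follow this rule ($J \cdot (-0.5) = -0.125 \neq -2$), and with $\eta = 1$ it steps to $p = 2.5$, which is not even a valid probability, whereas the standard step in $\theta$ gave $p \approx 0.622$. For the natural gradient, $\Delta p = 0.5 = J\,\Delta\theta$, so both parameterizations take the same step to first order. With the finite step $\eta = 1$ they land at $p = 1.0$ and $\sigma(2) \approx 0.881$ respectively; the agreement becomes exact only as $\eta \to 0$.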
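The arithmetic for the two updates in the log-odds parameter $\theta$ can be reproduced in a few lines. This is only a check of the numbers in the solution, using the values $\theta_0 = 0$, $y = 1$, $\eta = 1$ given in the problem.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

theta0, y, eta = 0.0, 1, 1.0
p0 = sigmoid(theta0)

grad = p0 - y                    # standard gradient
fisher = p0 * (1 - p0)           # Fisher information at theta0
nat_grad = grad / fisher         # natural gradient

theta_std = theta0 - eta * grad
theta_ng = theta0 - eta * nat_grad

print(theta_std, round(sigmoid(theta_std), 3))   # 0.5 0.622
print(theta_ng, round(sigmoid(theta_ng), 3))     # 2.0 0.881
```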
Answer:
Standard gradient: $\theta$ goes from 0 to 0.5 ($\sigma \approx 0.622$). Natural gradient: $\theta$ goes from 0 to 2 ($\sigma \approx 0.881$). The natural gradient direction is the same in every parameterization (log-odds or probability) up to the Jacobian, so small steps produce the same change in the distribution, while the standard gradient depends on the coordinates chosen. The natural gradient overcomes the sigmoid's compression.
Example 2: Natural Gradient for Multivariate Gaussian
For a bivariate Gaussian $\mathcal{N}(\boldsymbol{\mu}, \Sigma)$ with known $\Sigma = \begin{pmatrix} 2 & 1 \\ 1 & 2 \end{pmatrix}$ and loss $\mathcal{L}(\boldsymbol{\mu}) = \frac{1}{2}\|\boldsymbol{\mu} - \mathbf{y}\|^2$ with $\mathbf{y} = (1, 2)^T$, compute the natural gradient direction from $\boldsymbol{\mu}_0 = (0, 0)^T$.
Solution:
Fisher: $I = \Sigma^{-1} = \frac{1}{3}\begin{pmatrix} 2 & -1 \\ -1 & 2 \end{pmatrix}$.
Standard gradient: $\nabla\mathcal{L} = \boldsymbol{\mu} - \mathbf{y}$. At $\boldsymbol{\mu}_0$: $(-1, -2)^T$.
Natural gradient: $\tilde{\nabla}\mathcal{L} = I^{-1} \nabla\mathcal{L} = \Sigma \nabla\mathcal{L} = \begin{pmatrix}2 & 1 \\ 1 & 2\end{pmatrix} \begin{pmatrix}-1 \\ -2\end{pmatrix} = \begin{pmatrix}-4 \\ -5\end{pmatrix}$.
The standard gradient is $(-1, -2)^T$, so standard descent moves along $(1, 2)^T$, straight toward $\mathbf{y}$ in Euclidean space. The natural gradient is $(-4, -5)^T$, so natural gradient descent moves along $(4, 5)^T$. Multiplying by $\Sigma$ stretches the component along $(1, 1)^T$ (eigenvalue 3, the high-variance direction created by the positive correlation) and leaves the component along $(1, -1)^T$ (eigenvalue 1) unchanged. Along the high-variance direction a shift of the mean changes the distribution least, so a larger parameter change is needed to achieve the same distributional change.
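The matrix arithmetic in this solution can be reproduced directly. The sketch below uses the $\Sigma$, $\mathbf{y}$ and $\boldsymbol{\mu}_0$ from the problem statement and additionally prints the eigenvalues of $\Sigma$, which show how much each direction is stretched.

```python
import numpy as np

Sigma = np.array([[2.0, 1.0],
                  [1.0, 2.0]])
y = np.array([1.0, 2.0])
mu0 = np.zeros(2)

fisher = np.linalg.inv(Sigma)             # Fisher for the mean of N(mu, Sigma)
grad = mu0 - y                            # gradient of 0.5 * ||mu - y||^2
nat_grad = np.linalg.solve(fisher, grad)  # equals Sigma @ grad

print(grad, nat_grad)

# Eigenvalues of Sigma, in ascending order: the stretch factors
eigvals, eigvecs = np.linalg.eigh(Sigma)
print(eigvals)
```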
Answer:
Standard gradient direction: $(-1, -2)^T$. Natural gradient direction: $(-4, -5)^T$. The natural gradient is $\Sigma \nabla\mathcal{L}$ because $I^{-1} = \Sigma$. It accounts for the correlation structure: directions in which the distribution has high variance require larger parameter changes for an equivalent effect on the distribution.
Example 3: K-FAC Block Diagonal Structure
Consider a 2-layer network: $\mathbf{h} = \text{ReLU}(W_1 \mathbf{x})$ (input $\mathbf{x} \in \mathbb{R}^{100}$, hidden $\mathbf{h} \in \mathbb{R}^{50}$), output softmax with $W_2 \in \mathbb{R}^{10 \times 50}$. Compute the dimensions of the Fisher matrix and its K-FAC approximation for $W_2$.
Solution:
$W_2$ has shape $10 \times 50$, so $W_2$ flattened is $d = 500$ parameters. The full Fisher for $W_2$ is $500 \times 500$, requiring $250,000$ entries.
K-FAC factorizes it as $I_{W_2} \approx A_1 \otimes G_2$:
- $A_1 = \mathbb{E}[\mathbf{h} \mathbf{h}^T]$: $50 \times 50$ = 2,500 entries
- $G_2 = \mathbb{E}[\mathbf{g}_2 \mathbf{g}_2^T]$: $10 \times 10$ = 100 entries
The inverse: $I_{W_2}^{-1} \approx A_1^{-1} \otimes G_2^{-1}$, requiring only a $50 \times 50$ and $10 \times 10$ inversion. The $O(d^3) = O(500^3) = 125 \times 10^6$ flops become $O(50^3 + 10^3) = 126,000$ flops — a factor of ~1000 improvement.
For the full network, each layer gets its own Kronecker factorization, making the preconditioner block-diagonal across layers.
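The storage and cost figures for $W_2$ follow from simple counting, sketched below. The cubic flop count is the same rough cost model used in the solution, not a measured number.

```python
n_out, n_in = 10, 50            # shape of W2

d = n_out * n_in                # flattened parameter count
full_entries = d * d
full_flops = d ** 3             # cubic-cost model for dense inversion

kfac_entries = n_in * n_in + n_out * n_out
kfac_flops = n_in ** 3 + n_out ** 3

print(d, full_entries, full_flops)          # 500 250000 125000000
print(kfac_entries, kfac_flops)             # 2600 126000
print(round(full_flops / kfac_flops))       # 992
```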
Answer:
Full Fisher: $500 \times 500$ ($250,000$ entries, $125 \times 10^6$ flops to invert). K-FAC: $A_1$ ($50 \times 50$) and $G_2$ ($10 \times 10$), totaling 2,600 entries and ~126,000 flops to invert. The savings come from the Kronecker structure $(A \otimes G)^{-1} = A^{-1} \otimes G^{-1}$.
Practice Problems
Problem 1: Show that the natural gradient update for minimizing $\mathcal{L}(\theta) = D_{KL}(p^* \,\|\, p_\theta)$ (forward KL to a fixed target $p^*$) is equivalent, to first order, to one step of the EM algorithm in the exponential family.
Answer:
In exponential family form $p_\theta(x) = h(x) \exp(\theta^T T(x) - A(\theta))$: $\nabla_\theta \log p_\theta = T(x) - \nabla_\theta A(\theta) = T(x) - \mathbb{E}_{p_\theta}[T(x)]$. So $I(\theta) = \text{Cov}_{p_\theta}[T(x)] = \nabla_\theta^2 A(\theta)$ (the Hessian of the log-partition function is the Fisher). For $\mathcal{L}(\theta) = D_{KL}(p^* \| p_\theta) = \mathbb{E}_{p^*}[\log p^* - \log p_\theta] = \text{const} - \mathbb{E}_{p^*}[T(x)]^T \theta + A(\theta)$: $\nabla_\theta\mathcal{L} = -\mathbb{E}_{p^*}[T(x)] + \nabla A(\theta) = -\mathbb{E}_{p^*}[T(x)] + \mathbb{E}_{p_\theta}[T(x)]$. Natural gradient: $\tilde{\nabla}\mathcal{L} = I(\theta)^{-1} \nabla\mathcal{L} = (\nabla^2 A)^{-1} (\mathbb{E}_{p_\theta}[T] - \mathbb{E}_{p^*}[T])$. With unit step size: $\Delta\theta = -(\nabla^2 A)^{-1} (\mathbb{E}_{p_\theta}[T] - \mathbb{E}_{p^*}[T])$. Because $\nabla^2 A$ is the Jacobian of the map $\theta \mapsto \mathbb{E}_{p_\theta}[T]$, this is a Newton step for the moment-matching equation. It matches the EM update in exponential families to first order: the new natural parameters satisfy $\mathbb{E}_{p_{\theta_{\text{new}}}}[T(x)] \approx \mathbb{E}_{p^*}[T(x)]$, and exact moment matching is the M-step of EM when $p^*$ is the posterior over latents.
Problem 2: For the loss $\mathcal{L}(\theta) = \frac{1}{2}(\theta - \mu)^2$ (a squared-error loss on a single parameter) with Gaussian likelihood $\mathcal{N}(x \mid \theta, 1)$, derive the natural gradient update and explain how it compares with standard gradient descent.
Answer:
Model: $p(x \mid \theta) = \mathcal{N}(x \mid \theta, 1)$. Fisher: $I(\theta) = 1$ (constant; see 24-01). Loss gradient: $\nabla\mathcal{L} = \theta - \mu$. Natural gradient: $\tilde{\nabla}\mathcal{L} = I^{-1} \nabla\mathcal{L} = \theta - \mu$. They're identical! Natural gradient IS standard gradient in this case. This happens because the Gaussian-with-fixed-variance model has a flat Fisher metric: in the mean parameterization the statistical manifold is Euclidean. There's no difference between natural and standard gradient when $I(\theta)$ is the identity matrix. This shows that natural gradient's benefits appear only when the Fisher metric is non-trivial.
Problem 3: Consider minimizing $\mathcal{L}(\theta) = -\log p(y \mid \theta)$ (negative log-likelihood). Show that the natural gradient is $\tilde{\nabla}\mathcal{L} = I(\theta)^{-1} (\mathbb{E}_{p(\cdot \mid \theta)}[s] - s(y))$, and that at the MLE (where $\nabla\mathcal{L} = 0$), the natural gradient also vanishes.
Answer:
$\mathcal{L}(\theta) = -\log p(y \mid \theta)$. Gradient: $\nabla\mathcal{L} = -s(\theta, y)$ where $s$ is the score. Since $\mathbb{E}_{p(x \mid \theta)}[s(\theta, x)] = 0$, we can write: $\nabla\mathcal{L} = \mathbb{E}_{p}[s] - s(\theta, y)$ (trivially, since the expectation is zero). Natural gradient: $\tilde{\nabla}\mathcal{L} = I^{-1}(\mathbb{E}_p[s] - s(y))$. At the MLE $\hat{\theta}$ the score vanishes, so $\nabla\mathcal{L} = -s(\hat{\theta}, y) = 0$. Since natural gradient is $I^{-1}\nabla\mathcal{L}$, it also vanishes. Natural gradient descent converges to the MLE just like standard gradient descent, but follows a different (often more efficient) path. This is important: natural gradient doesn't change the fixed points, only the trajectory.
Problem 4: Prove that for a linear model $y = \mathbf{w}^T \mathbf{x} + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, the Fisher matrix for $\mathbf{w}$ is exactly the (uncentered) data covariance matrix $\mathbb{E}[\mathbf{x}\mathbf{x}^T]$. What does this imply about the natural gradient for linear regression?
Answer:
$p(y \mid \mathbf{x}, \mathbf{w}) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{1}{2}(y - \mathbf{w}^T\mathbf{x})^2\right)$. Score: $s = \nabla_{\mathbf{w}} \log p = (y - \mathbf{w}^T\mathbf{x})\mathbf{x}$. $I(\mathbf{w}) = \mathbb{E}[s s^T] = \mathbb{E}[(y - \mathbf{w}^T\mathbf{x})^2 \mathbf{x}\mathbf{x}^T] = \mathbb{E}[\epsilon^2 \mathbf{x}\mathbf{x}^T] = \mathbb{E}[\mathbf{x}\mathbf{x}^T]$. Since $\mathbb{E}[\epsilon^2] = 1$ and $\epsilon \perp \mathbf{x}$, the Fisher is the data covariance $\Sigma_{\mathbf{x}}$. Implication: For least squares loss $\mathcal{L} = \frac{1}{2}(y - \mathbf{w}^T\mathbf{x})^2$: $\nabla_{\mathbf{w}}\mathcal{L} = -(y - \mathbf{w}^T\mathbf{x})\mathbf{x}$, $\tilde{\nabla}_{\mathbf{w}}\mathcal{L} = I^{-1}\nabla\mathcal{L} = -\Sigma_{\mathbf{x}}^{-1}(y - \mathbf{w}^T\mathbf{x})\mathbf{x}$. This coincides with the **Newton step** for least squares, since the Hessian of the expected loss is also $\Sigma_{\mathbf{x}}$; averaged over the data with $\eta = 1$, a single step lands on the ordinary least squares solution $\Sigma_{\mathbf{x}}^{-1}\mathbb{E}[\mathbf{x}y]$. Natural gradient effectively whitens the input: if $\mathbf{x}$ has highly correlated features, natural gradient corrects for this, while standard gradient descent struggles with ill-conditioned covariances.
Problem 5: A common trick in natural gradient implementations is to use the empirical Fisher from the training batch: $\hat{I} = \frac{1}{B}\sum_i s(y_i) s(y_i)^T$ where $y_i$ are true labels. Explain why this is problematic for a classifier, and what would be the correct thing to do.
Answer:
**Why it's problematic:** The true Fisher $I(\theta)$ expects $s$ to be computed under the model's distribution: $\mathbb{E}_{x \sim p(x \mid \theta)}[s s^T]$. When you use training labels $y_i$, you're approximating $\mathbb{E}_{x \sim p_{\text{data}}}[s s^T]$, which is NOT the Fisher. For a classifier predicting $p(y \mid \mathbf{x})$, the true Fisher uses samples $\tilde{y} \sim p(y \mid \mathbf{x}, \theta)$ (model samples), not the ground truth labels $y_i$. The empirical Fisher with ground truth labels gives an estimate of $I(\theta)$ that is only valid when the model is already perfectly fit, which it isn't during training. **The correct approach:** Sample from the model's predictive distribution (e.g., $\tilde{y} \sim \text{Categorical}(\boldsymbol{\pi}(\mathbf{x}))$) and compute the Fisher from those samples. This is unbiased for $I(\theta)$. In practice, many implementations ignore this and use labels anyway; it works because the bias is often benign, but it's technically incorrect. Another option is the *observed Fisher* $-\nabla_\theta^2 \log p(y \mid \theta)$ for the current batch, which uses the actual labels and agrees with $I(\theta)$ in expectation when the model is well specified, but unlike the Fisher it is not guaranteed to be PSD.
Summary
Key takeaways:
- Natural gradient preconditions the gradient with the inverse Fisher information matrix: $\tilde{\nabla}\mathcal{L} = I(\theta)^{-1} \nabla\mathcal{L}$
- It is invariant to reparameterization: the update direction corresponds to the same change in distribution space regardless of coordinate system (exactly so in the limit of small steps)
- The Fisher matrix acts as a Riemannian metric that locally approximates KL divergence, giving the direction of steepest descent per unit of KL divergence
- K-FAC approximates the Fisher via Kronecker factorization, reducing inversion from $O(d^3)$ to $O(d^{3/2})$ per layer
- Natural gradient is the theoretical basis of TRPO/PPO in RL — trust regions in policy space are KL divergence constraints
- Damping $(I(\theta) + \lambda\,\mathrm{Id})^{-1}$, with $\mathrm{Id}$ the identity matrix, is essential in practice; it smoothly interpolates between natural and standard gradient descent
Quiz
Question 1: Why is standard gradient descent not invariant to reparameterization?
A. Because the gradient doesn't exist at some parameter values
B. Because the Euclidean inner product $\langle\cdot,\cdot\rangle$ doesn't transform correctly under coordinate changes
C. Because the loss function changes under reparameterization
D. Because learning rates must be adjusted
Correct Answer: B. Because the Euclidean inner product $\langle\cdot,\cdot\rangle$ doesn't transform correctly under coordinate changes
Explanation: Under reparameterization $\phi = g(\theta)$, the gradient transforms as $\nabla_\phi\mathcal{L} = J^{-T}\nabla_\theta\mathcal{L}$. The standard update $\Delta\phi = -\eta\nabla_\phi\mathcal{L}$ leads to a different point in distribution space than $\Delta\theta = -\eta\nabla_\theta\mathcal{L}$ because the Euclidean metric $\|\Delta\theta\|^2$ doesn't correspond to a distributional distance. Natural gradient fixes this by using the Fisher metric $d\theta^T I d\theta$ which transforms correctly.
Question 2: The natural gradient update $\Delta\theta = -\eta I(\theta)^{-1} \nabla\mathcal{L}$ can be derived by:
A. Taking a second-order Taylor expansion of the loss
B. Minimizing a first-order approximation of the loss subject to a KL divergence constraint
C. Computing the Newton step for the KL divergence
D. Using the chain rule on the Fisher matrix
Correct Answer: B. Minimizing a first-order approximation of the loss subject to a KL divergence constraint
Explanation: The natural gradient solves $\min_{d\theta} \nabla\mathcal{L}^T d\theta$ subject to $\frac{1}{2}d\theta^T I d\theta \leq \epsilon$. This is a constrained linear (first-order) problem. Option A describes Newton's method (which uses the Hessian, not the Fisher). Option C doesn't make sense — the Newton step would be on the loss, not KL. Option D is not how the derivation works.
Question 3: Which of the following is TRUE about the relationship between natural gradient and Newton's method?
A. They are identical when the loss is the negative log-likelihood
B. Natural gradient always converges in fewer iterations than Newton
C. Newton uses the Hessian of the loss; natural gradient uses the Fisher of the model
D. Natural gradient is a special case of Newton's method with a diagonal approximation
Correct Answer: C. Newton uses the Hessian of the loss; natural gradient uses the Fisher of the model
Explanation: They use fundamentally different matrices. Newton: $H = \nabla_\theta^2\mathcal{L}$. Natural gradient: $I = -\mathbb{E}[\nabla_\theta^2 \log p]$. They coincide ONLY when the model is well-specified and the loss is the negative log-likelihood — in that case, the Hessian and Fisher are asymptotically equal (Bartlett's identity). In general they differ, and natural gradient has the advantage of always being PSD.
Question 4: For a diagonal Gaussian variational posterior $\mathcal{N}(\mu, \sigma^2)$, the natural gradient update for $\mu$ scales the standard gradient by $\sigma^2$. What does this mean in practice?
A. Parameters with higher uncertainty get smaller updates (more conservative)
B. Parameters with higher uncertainty get larger updates (more aggressive)
C. Natural gradient ignores mean parameters entirely
D. The scaling has no practical effect
Correct Answer: B. Parameters with higher uncertainty get larger updates (more aggressive)
Explanation: $\tilde{\nabla}_\mu\mathcal{L} = \sigma^2 \nabla_\mu\mathcal{L}$. When $\sigma^2$ is large (high uncertainty about $\mu$), the Fisher information is small ($1/\sigma^2$), so $I^{-1} = \sigma^2$ is large and the update is amplified. Intuitively, if you're very uncertain about a parameter, you should take bigger steps to explore the space. As you become confident ($\sigma^2 \to 0$), the step size shrinks, naturally annealing the learning rate.
Question 5: In K-FAC, the Fisher matrix for a layer is approximated as $A \otimes G$. What are $A$ and $G$?
A. $A$ is the weight matrix and $G$ is the gradient
B. $A$ is the activation covariance and $G$ is the pre-activation gradient covariance
C. $A$ is the attention matrix and $G$ is the gain
D. $A$ is the Fisher of the previous layer and $G$ is the Fisher of the current layer
Correct Answer: B. $A$ is the activation covariance and $G$ is the pre-activation gradient covariance
Explanation: K-FAC exploits the structure of feedforward layers $W\mathbf{a}$. The per-example gradient for $W$ is the outer product $\mathbf{g}\mathbf{a}^T$, so the Fisher for $W$ is an expectation of Kronecker products of $\mathbf{a}\mathbf{a}^T$ and $\mathbf{g}\mathbf{g}^T$. K-FAC replaces it with the Kronecker product of the separate expectations $\mathbb{E}[\mathbf{a}\mathbf{a}^T]$ (input side) and $\mathbb{E}[\mathbf{g}\mathbf{g}^T]$ (output side), under the assumption that activations and backpropagated gradients are independently distributed. This Kronecker structure is key to making natural gradient tractable in deep networks.
Question 6: Why is damping $(I(\theta) + \lambda\,\mathrm{Id})^{-1}$ necessary for practical natural gradient descent?
A. To ensure the Fisher matrix is positive definite
B. To prevent the step size from being too large when $I$ is ill-conditioned or singular
C. To make the Fisher matrix symmetric
D. To reduce the memory cost of storing the Fisher matrix
Correct Answer: B. To prevent the step size from being too large when $I$ is ill-conditioned or singular
Explanation: The Fisher can be ill-conditioned (eigenvalues near zero) or even singular, causing $I^{-1}$ to blow up. Damping controls this: if $I(\theta)$ has eigenvalues $\lambda_i \geq 0$, then $(I(\theta) + \lambda\,\mathrm{Id})^{-1}$ has eigenvalues $1/(\lambda_i + \lambda) \leq 1/\lambda$, bounding the step. As $\lambda \to \infty$, the update becomes $\approx -\frac{\eta}{\lambda} \nabla\mathcal{L}$, which is standard gradient descent with a rescaled learning rate. As $\lambda \to 0$, we recover pure natural gradient. It's a continuous trust-region mechanism.
Pitfalls
- Confusing natural gradient with Newton's method: Both use matrix preconditioners, but natural gradient uses $I(\theta)$ (Fisher of the model) while Newton uses $H$ (Hessian of the loss). They coincide only asymptotically when the model is well-specified and the loss is negative log-likelihood. The Fisher has the crucial advantage of always being PSD, while the Hessian can be indefinite.
- Forgetting damping in practice: The Fisher matrix $I(\theta)$ is often ill-conditioned or singular. Without damping $(I(\theta) + \lambda\,\mathrm{Id})^{-1}$, the natural gradient step can explode. Many implementation failures stem from setting $\lambda$ too small or omitting it entirely. Start with $\lambda \approx 1$ (a rough starting point; the right scale depends on the magnitude of the Fisher's eigenvalues) and adapt using Levenberg-Marquardt heuristics: decrease if the step improves the loss, increase if it worsens it.
- Using the wrong distribution for empirical Fisher: Computing $\hat{I} = \frac{1}{B}\sum s(y_i)s(y_i)^T$ with training labels $y_i$ estimates the Fisher under the data distribution, not the model distribution. The correct approach is to sample from the model's predictive distribution (e.g., $\tilde{y} \sim \text{Categorical}(\boldsymbol{\pi}(\mathbf{x}))$) and compute the Fisher from those samples. In practice, many implementations use labels anyway; it often works, but it's technically incorrect and can cause problems early in training.
- Expecting natural gradient to always outperform standard gradient: Natural gradient provides faster convergence per step but each step is much more expensive ($O(d^2)$ to $O(d^3)$ vs. $O(d)$). The wall-clock advantage depends on the problem and the quality of the approximation. For many deep learning tasks, well-tuned Adam or SGD with momentum matches or exceeds natural gradient in practice. Natural gradient shines most when parameterization matters (e.g., variational inference, reinforcement learning with policy gradients).
Next Steps
Next up: 24-03 — Neural Tangent Kernel (NTK) — where you'll discover that infinitely wide neural networks behave exactly like kernel machines, with the natural gradient and NTK sharing deep mathematical connections through the Fisher geometry.