21-02 — Variational Inference
Phase: 21 (Probability & Statistics for ML, Advanced)
Subject: 21-02
Prerequisites: 21-01 (Bayesian Inference), 13-04 (KL Divergence), 14-06 (Convex Sets and Functions, for ELBO optimization), 10-05 (Continuous Distributions)
Next subject: 21-03, Markov Chain Monte Carlo (MCMC)
Learning Objectives
By the end of this subject, you will be able to:
- Derive the Evidence Lower Bound (ELBO) from the KL divergence between approximate and true posteriors
- Formulate variational inference as an optimization problem: find q*(z) that minimizes KL(q(z) || p(z | x)) within a tractable family
- Apply mean-field variational inference and derive the coordinate ascent update q_j*(z_j) ∝ exp(E_{−j}[log p(z, x)])
- Compute the reparameterization-trick gradient for a Gaussian variational distribution and explain why it usually has lower variance than the score-function estimator
- Connect variational inference to variational autoencoders (VAEs) — derive the VAE objective as ELBO maximization
Core Content
1. The Problem: Intractable Posteriors
Bayesian inference requires computing the posterior:
$p(z | x) = p(x | z) p(z) / p(x) $
where p(x) = ∫ p(x | z) p(z) dz.
For all but the simplest models, this integral is intractable. In high dimensions, quadrature fails (curse of dimensionality). Monte Carlo methods (21-03) work but can be slow. Variational inference turns integration into OPTIMIZATION — find the best approximation to the posterior from a tractable family.
⚠️ THIS IS CRITICAL — Variational inference is the workhorse of scalable Bayesian methods. Every VAE, every Bayesian neural network trained with variational methods, rests on the ELBO.
2. The KL Divergence and the ELBO
Let q(z) be our approximation to the true posterior p(z | x). We measure the quality of the approximation using KL divergence:
$KL(q(z) || p(z | x)) = E_{z~q}[log q(z) − log p(z | x)]
= E_q[log q(z) − log(p(x, z)/p(x))]
= E_q[log q(z) − log p(x, z)] + log p(x)
$
Rearrange to isolate log p(x):
$log p(x) = KL(q || p) + E_q[log p(x, z) − log q(z)]
= KL(q || p) + ELBO(q)
$
The Evidence Lower Bound (ELBO):
$ELBO(q) = E_{z~q}[log p(x, z)] − E_{z~q}[log q(z)]
= E_q[log p(x | z)] + E_q[log p(z)] − E_q[log q(z)]
= E_q[log p(x | z)] − KL(q(z) || p(z))
$
Key properties:
- Since KL ≥ 0, ELBO(q) ≤ log p(x): the ELBO is a lower bound on the log marginal likelihood (the "evidence").
- log p(x) does not depend on q, so maximizing the ELBO over q is equivalent to minimizing KL(q || p(z | x)); as the bound tightens, q approaches the true posterior.
- The gap log p(x) − ELBO(q) = KL(q || p(z | x)) is the approximation error.
Interpretation of the ELBO's two terms:
1. E_q[log p(x | z)]: expected log-likelihood (reconstruction quality). Encourages q to put mass on values of z that explain x well.
2. −KL(q(z) || p(z)): negative KL from the prior. Encourages q to stay close to the prior, acting as a regularizer against overfitting.
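To make the identity log p(x) = KL(q || p(z | x)) + ELBO(q) concrete, here is a minimal sketch on an assumed toy model that is not part of the text above: prior z ~ N(0, 1) and likelihood x | z ~ N(z, 1). This model is conjugate, so p(x) = N(x | 0, 2) and p(z | x) = N(z | x/2, 1/2) are known exactly, and for a Gaussian q(z) = N(m, s²) both ELBO terms have closed forms. The function names are illustrative only.

```python
import math

def log_normal_pdf(x, mean, var):
    """log N(x | mean, var) for a scalar Gaussian."""
    return -0.5 * math.log(2 * math.pi * var) - 0.5 * (x - mean) ** 2 / var

def kl_gauss(m1, v1, m2, v2):
    """KL(N(m1, v1) || N(m2, v2)) for scalar Gaussians (v = variance)."""
    return 0.5 * (math.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def elbo_toy(x, m, s2):
    """ELBO for q(z) = N(m, s2) under the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    # E_q[log p(x | z)] = -0.5*log(2*pi) - 0.5*E_q[(x - z)^2], with E_q[(x - z)^2] = (x - m)^2 + s2
    expected_loglik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # ELBO = reconstruction term - KL(q(z) || p(z))
    return expected_loglik - kl_gauss(m, s2, 0.0, 1.0)

x = 1.3
log_evidence = log_normal_pdf(x, 0.0, 2.0)   # exact log p(x), since p(x) = N(0, 2)
post_mean, post_var = x / 2, 0.5             # exact posterior p(z | x)

for m, s2 in [(0.0, 1.0), (1.0, 0.3), (post_mean, post_var)]:
    elbo = elbo_toy(x, m, s2)
    gap = kl_gauss(m, s2, post_mean, post_var)   # KL(q || p(z | x))
    print(f"m={m:.2f} s2={s2:.2f}  ELBO={elbo:.4f}  ELBO+KL={elbo + gap:.4f}  log p(x)={log_evidence:.4f}")
```

In every row ELBO + KL matches log p(x); in the last row q equals the posterior, the gap vanishes, and the ELBO itself equals log p(x).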
3. Mean-Field Variational Inference
The simplest variational family: assume all latent variables are independent:
$q(z) = ∏_{j=1}^m q_j(z_j)
$
Each z_j has its own variational distribution, and they factorize completely. This is the mean-field assumption.
Coordinate Ascent Variational Inference (CAVI):
Optimize each q_j while holding the others fixed. The optimal q_j* satisfies:
$log q_j*(z_j) = E_{z_{−j}~q_{−j}}[log p(z, x)] + const
$
where E_{z_{−j}} means expectation over all latent variables EXCEPT z_j.
Derivation: Take the functional derivative of ELBO with respect to q_j, set to zero, and solve. The result: each q_j is proportional to the exponentiated expected log-joint, where the expectation is under the CURRENT estimates of all other q's.
Algorithm:
$Initialize all q_j arbitrarily
Repeat until convergence:
    For j = 1 to m:
        q_j(z_j) ∝ exp(E_{−j}[log p(z, x)])
$
The ELBO is concave in each q_j when the others are held fixed, and the update above is its exact maximizer, so no step can decrease the ELBO. The algorithm therefore converges to a local optimum, though not necessarily the global one, since the ELBO is generally not concave in all factors jointly.
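A minimal CAVI sketch on an assumed target where every update is closed-form: a correlated 2-D Gaussian p(z) = N(μ, Λ⁻¹) standing in for p(z | x) (x is fixed and absorbed into the target, so log p(x) = 0 and ELBO = −KL(q || p)). Because log p is quadratic, E_{−j}[log p] is quadratic in z_j, so each optimal factor is Gaussian with precision Λ_jj and mean μ_j − (Λ_jk/Λ_jj)(m_k − μ_k). The sketch records the ELBO after every coordinate update so you can check that it never decreases. The function names and the numbers (μ = (1, −1), correlation ρ = 0.9) are hypothetical.

```python
import math

def mf_elbo(m, v, mu, lam):
    """ELBO = E_q[log p(z)] + H(q) for q = N(m[0], v[0]) N(m[1], v[1]), p = N(mu, inv(lam))."""
    (l11, l12), (_, l22) = lam
    d1, d2 = m[0] - mu[0], m[1] - mu[1]
    # E_q[(z - mu)^T lam (z - mu)] under a factorized q (cross term uses only the means)
    quad = l11 * (d1 * d1 + v[0]) + l22 * (d2 * d2 + v[1]) + 2 * l12 * d1 * d2
    log_det_lam = math.log(l11 * l22 - l12 * l12)
    # In 2-D the log(2*pi) terms of E_q[log p] and H(q) cancel, leaving "+ 1" from H(q)
    return 0.5 * log_det_lam - 0.5 * quad + 1.0 + 0.5 * math.log(v[0] * v[1])

def cavi_2d_gaussian(mu, lam, iters=200):
    """CAVI: q_j*(z_j) ∝ exp(E_{-j}[log p(z)]), alternating j = 1, 2."""
    (l11, l12), (_, l22) = lam
    m, v = [0.0, 0.0], [1.0, 1.0]                 # arbitrary initialization
    trace = [mf_elbo(m, v, mu, lam)]
    for _ in range(iters):
        # q1 update: Gaussian with precision l11; only E[z2] = m[1] enters its mean
        m[0] = mu[0] - (l12 / l11) * (m[1] - mu[1])
        v[0] = 1.0 / l11
        trace.append(mf_elbo(m, v, mu, lam))
        # q2 update uses the freshly updated q1
        m[1] = mu[1] - (l12 / l22) * (m[0] - mu[0])
        v[1] = 1.0 / l22
        trace.append(mf_elbo(m, v, mu, lam))
    return m, v, trace

# Target: unit variances, correlation rho; lam is the inverse covariance
rho = 0.9
det = 1.0 - rho ** 2
lam = ((1.0 / det, -rho / det), (-rho / det, 1.0 / det))
m, v, trace = cavi_2d_gaussian((1.0, -1.0), lam)
print(m, v)                 # means near (1, -1); variances 1 - rho^2 = 0.19, not 1
print(trace[0], trace[-1])  # ELBO rises from its initial value toward 0.5*log(1 - rho^2)
```

The fitted means match the target, but each factor's variance is 1 − ρ² = 0.19 instead of the true marginal variance 1, and the final ELBO is ½ log(1 − ρ²) ≈ −0.83 rather than 0, because a factorized q cannot represent the correlation. The means also converge only geometrically (by a factor ρ² per sweep), so strongly correlated targets slow CAVI down.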
4. The Reparameterization Trick
For gradient-based optimization (needed in deep learning), we need ∇_φ ELBO where φ parameterizes q_φ(z). The naive score-function (REINFORCE) estimator has high variance:
$∇_φ ELBO = E_q[∇_φ log q_φ(z) · (log p(x,z) − log q_φ(z))] $
The reparameterization trick (Kingma & Welling, 2014) provides a lower-variance alternative. If we can express:
$z = g_φ(ε) where ε ~ p(ε) [a fixed noise distribution] $
then:
$∇_φ E_{q_φ}[f(z)] = E_{p(ε)}[∇_φ f(g_φ(ε))]
$
Gaussian example: For q_φ(z) = N(z | μ, σ²) with φ = {μ, σ}:
$z = μ + σ · ε where ε ~ N(0, 1) $
Then:
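$∇_μ ELBO = E_ε[∇_z(log p(x,z) − log q(z)) · 1]
∇_σ ELBO = E_ε[∇_z(log p(x,z) − log q(z)) · ε]
$
Here ∂z/∂μ = 1 and ∂z/∂σ = ε. Differentiating log q_φ(z) with respect to φ at fixed z adds a score term whose expectation under q_φ is zero, so it drops out of these expectations.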
The gradient flows THROUGH the sampling operation and exploits the derivative of the integrand, which usually gives much lower variance than REINFORCE; how much lower depends on the model and on q (Problem 3 works through one example).
⚠️ THIS IS CRITICAL — The reparameterization trick is what made VAEs practical. Without it, training VAEs with REINFORCE produces unusably noisy gradients.
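A minimal sketch of reparameterized gradient ascent, using an assumed conjugate toy model that is not from the text (z ~ N(0, 1), x | z ~ N(z, 1)) so that the answer is known: the ELBO maximizer is the exact posterior N(x/2, 1/2). For q = N(μ, σ²) the per-sample gradients are the two expressions above, with ∇_z log p(x, z) = x − 2z and ∇_z log q(z) = −ε/σ. Step size, sample count, iteration count, and function names are arbitrary, illustrative choices.

```python
import random

def reparam_elbo_grads(x, mu, sigma, n_samples, rng):
    """Monte Carlo estimates of (dELBO/dmu, dELBO/dsigma) using z = mu + sigma * eps.

    Toy model (assumed): z ~ N(0, 1), x | z ~ N(z, 1), so d/dz log p(x, z) = x - 2z.
    Gaussian q: d/dz log q(z) = -(z - mu) / sigma^2 = -eps / sigma.
    """
    g_mu = g_sigma = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        dz = (x - 2.0 * z) + eps / sigma   # d/dz [log p(x, z) - log q(z)]
        g_mu += dz                         # times dz/dmu = 1
        g_sigma += dz * eps                # times dz/dsigma = eps
    return g_mu / n_samples, g_sigma / n_samples

def fit_q(x, steps=2000, lr=0.05, n_samples=16, seed=0):
    """Stochastic gradient ascent on the ELBO, starting from q = prior."""
    rng = random.Random(seed)
    mu, sigma = 0.0, 1.0
    for _ in range(steps):
        g_mu, g_sigma = reparam_elbo_grads(x, mu, sigma, n_samples, rng)
        mu += lr * g_mu
        sigma += lr * g_sigma
    return mu, sigma

mu, sigma = fit_q(x=2.0)
print(mu, sigma ** 2)   # should approach the exact posterior mean 1.0 and variance 0.5
```

Updating σ directly works in this sketch because the 1/σ term pushes σ away from zero; practical code usually optimizes log σ or log σ² instead, as discussed in Problem 4.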
5. Variational Autoencoders (VAEs)
A VAE is variational inference applied to a latent-variable generative model:
$p(x, z) = p(z) · p_θ(x | z)   [generative model]
q_φ(z | x)   [inference network / encoder]
$
The ELBO for a VAE:
$ELBO(θ, φ) = E_{z~q_φ(z|x)}[log p_θ(x | z)] − KL(q_φ(z | x) || p(z))
$
Architecture:
- Encoder (φ): neural network mapping x → μ_φ(x), σ_φ(x) for the variational Gaussian q_φ(z|x)
- Decoder (θ): neural network mapping z → parameters of p_θ(x|z) (e.g., Bernoulli for binary data, Gaussian for continuous)
- Prior: typically p(z) = N(0, I)
Training: Gradient descent on −ELBO (the "VAE loss"):
$L_VAE = −E_{ε~N(0,I)}[log p_θ(x | μ + σ·ε)] + KL(N(μ,σ²) || N(0,I))
$
The KL between two Gaussians has a closed form:
$KL(N(μ,σ²) || N(0,1)) = ½(μ² + σ² − log σ² − 1) $
For a vector z with a diagonal-covariance Gaussian q, sum this expression over dimensions.
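The VAE loss above can be sketched for a single data point in plain Python, under a hypothetical setup that is not from the text: binary data with a Bernoulli decoder, an encoder that outputs μ and log σ² (the parameterization discussed in Problem 4), and a stand-in decoder given as any function from z to per-pixel logits. A real VAE computes the same quantities with a tensor library and backpropagates through them; this sketch only evaluates the loss. The encoder outputs used below are the numbers from Example 3, so the KL term equals the 1.276 computed there.

```python
import math
import random

def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def bernoulli_log_lik(x, logits):
    """log p(x | z) for binary pixels x with p(x_i = 1) = sigmoid(logit_i)."""
    # log sigmoid(l) = -softplus(-l) and log(1 - sigmoid(l)) = -softplus(l)
    return sum(-softplus(-lg) if xi == 1 else -softplus(lg) for xi, lg in zip(x, logits))

def kl_to_std_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, logvar))

def vae_loss(x, mu, logvar, decode_logits, rng, n_samples=1):
    """-ELBO estimate: -E_eps[log p(x | mu + sigma * eps)] + KL(q || N(0, I))."""
    sigma = [math.exp(0.5 * lv) for lv in logvar]   # sigma = exp(log(sigma^2) / 2)
    recon = 0.0
    for _ in range(n_samples):
        z = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]  # reparameterized sample
        recon += bernoulli_log_lik(x, decode_logits(z))
    return -recon / n_samples + kl_to_std_normal(mu, logvar)

def decoder(z):
    """Hypothetical fixed linear decoder: 2-D latent -> 3 pixel logits."""
    return [z[0] + z[1], z[0] - z[1], 0.5 * z[0]]

rng = random.Random(0)
loss = vae_loss([1, 0, 1], mu=[0.5, -0.3], logvar=[math.log(0.1), math.log(0.2)],
                decode_logits=decoder, rng=rng, n_samples=100)
print(loss)
```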
6. Amortized Inference
In classical (non-amortized) VI, such as CAVI applied to a model with a local latent variable per data point, we optimize separate variational parameters for EACH data point, which is expensive for large datasets. Amortized inference (used in VAEs) instead learns a function mapping x → the parameters of q_φ(z|x). The cost of inference is amortized across data points: one forward pass through the encoder replaces per-datapoint optimization.
Mathematical formulation:
$q_φ(z | x_i) = N(z | μ_φ(x_i), σ²_φ(x_i)) $
Instead of optimizing {μ_i, σ_i} for each x_i, we optimize the shared parameters φ of the neural network. Inference becomes a prediction problem: given x, predict the posterior's parameters.
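A minimal amortized-inference sketch, again on an assumed conjugate toy model that is not from the text (z ~ N(0, 1), x | z ~ N(z, 1), exact posterior N(x/2, 1/2)). Instead of fitting (m_i, s_i) per data point, it fits one hypothetical linear "encoder" μ_φ(x) = a·x + b with a shared σ, by gradient ascent on the average closed-form ELBO, whose derivatives are ∂/∂m = x − 2m and ∂/∂σ = 1/σ − 2σ. Because this encoder family contains the true mapping, it should recover a = 0.5, b = 0, σ² = 0.5. The data values, step size, and iteration count are arbitrary.

```python
def fit_amortized_encoder(xs, steps=3000, lr=0.05):
    """Gradient ascent on the mean closed-form ELBO over a dataset xs.

    Hypothetical encoder: mu_phi(x) = a * x + b with a shared sigma (phi = {a, b, sigma}).
    """
    a, b, sigma = 0.0, 0.0, 1.0
    n = len(xs)
    for _ in range(steps):
        ga = gb = 0.0
        for x in xs:
            m = a * x + b                 # encoder prediction for this data point
            dm = x - 2.0 * m              # dELBO/dm for the toy model
            ga += dm * x / n              # chain rule: dm/da = x
            gb += dm / n                  # chain rule: dm/db = 1
        gs = 1.0 / sigma - 2.0 * sigma    # dELBO/dsigma (the same for every x)
        a, b, sigma = a + lr * ga, b + lr * gb, sigma + lr * gs
    return a, b, sigma

xs = [-2.0, -0.5, 0.3, 1.0, 2.5]
a, b, sigma = fit_amortized_encoder(xs)
print(a, b, sigma ** 2)          # approaches 0.5, 0.0, 0.5

# Inference for a new point is now a single "forward pass", with no per-point optimization
x_new = 4.0
print(a * x_new + b, sigma ** 2)
```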
7. The β-VAE
A variant that weights the KL term:
$L_β-VAE = −E_q[log p(x|z)] + β · KL(q(z|x) || p(z)) $
Like L_VAE above, this is a loss to minimize (the negative of a β-weighted ELBO).
- β = 1: Standard VAE
- β > 1: Stronger pressure for the latent representation to be factorial and match the prior — encourages DISENTANGLED representations
- β < 1: Weaker prior constraint — better reconstruction but less structured latent space
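The β-VAE trades off reconstruction fidelity against latent structure. High β tends to produce more interpretable, factorized latent dimensions in which individual dimensions track separate generative factors, although this disentanglement is not guaranteed and comes at a cost in reconstruction quality.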
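The trade-off can be seen in closed form on an assumed 1-D toy model that is not from the text: prior z ~ N(0, 1), likelihood x | z ~ N(z, 1), Gaussian q = N(m, s²). Setting the derivatives of E_q[log p(x|z)] − β·KL(q || p(z)) to zero gives m = x/(1+β) and s² = β/(1+β). The sketch checks that formula against plain gradient ascent. At β = 1 it is the exact posterior; larger β pulls q toward the prior N(0, 1), smaller β toward a sharp estimate near x. A 1-D model cannot exhibit disentanglement, only this pull toward the prior. Function names are illustrative.

```python
def beta_objective_grads(x, m, s, beta):
    """Gradients of E_q[log p(x|z)] - beta * KL(N(m, s^2) || N(0, 1)) w.r.t. (m, s)."""
    g_m = (x - m) - beta * m              # reconstruction pull toward x, KL pull toward 0
    g_s = -s - beta * (s - 1.0 / s)       # reconstruction shrinks s, KL pushes s toward 1
    return g_m, g_s

def fit_beta(x, beta, steps=5000, lr=0.01):
    """Plain gradient ascent on the beta-weighted objective; returns (m, s^2)."""
    m, s = 0.0, 1.0
    for _ in range(steps):
        g_m, g_s = beta_objective_grads(x, m, s, beta)
        m, s = m + lr * g_m, s + lr * g_s
    return m, s * s

def beta_optimum(x, beta):
    """Closed-form maximizer: m = x / (1 + beta), s^2 = beta / (1 + beta)."""
    return x / (1.0 + beta), beta / (1.0 + beta)

for beta in (0.25, 1.0, 4.0):
    print(beta, fit_beta(2.0, beta), beta_optimum(2.0, beta))
```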
Worked Examples
Example 1: Mean-Field VI for a Gaussian Mixture
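Problem: Data x is drawn from a mixture of two Gaussians: p(x | z, μ) = N(x | μ_z, 1), where z ∈ {0, 1} has prior p(z=1) = π and the component means μ₀, μ₁ are themselves latent variables with priors. Derive the mean-field variational update for q(z) for a single observation x.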
Solution:
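The joint: p(x, z, μ₀, μ₁) = p(μ₀) · p(μ₁) · p(z) · p(x | z, μ_z)
Mean-field: q(z, μ₀, μ₁) = q(z) · q(μ₀) · q(μ₁). For q(z), the terms log p(μ₀) and log p(μ₁) do not depend on z and are absorbed into the constant: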
$log q*(z) = E_{μ}[log p(x, z, μ)] + const
= E_{μ}[log p(z) + log p(x | z, μ_z)] + const
$
For z=1:
$log q*(z=1) = log π + E_{μ₁}[−(x−μ₁)²/2] + const
$
For z=0:
$log q*(z=0) = log(1−π) + E_{μ₀}[−(x−μ₀)²/2] + const
$
Let r = q(z=1). Then:
$log(r/(1−r)) = log(π/(1−π)) + E_{μ₁}[−(x−μ₁)²/2] − E_{μ₀}[−(x−μ₀)²/2]
r = σ(log(π/(1−π)) + ½[E_{μ₀}[(x−μ₀)²] − E_{μ₁}[(x−μ₁)²]])
$
where σ is the sigmoid. This is an intuitive "soft" assignment: points closer to μ₁'s expected location get higher probability of z=1.
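The update for r can be sketched as a small function. It assumes Gaussian variational factors q(μ_k) = N(m_k, v_k), which the example does not fix, so that E[(x − μ_k)²] = (x − m_k)² + v_k. Under that assumption, uncertainty in μ₁, not only its expected location, lowers the responsibility of component 1. The function names are hypothetical.

```python
import math

def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def update_q_z(x, pi, m0, v0, m1, v1):
    """Mean-field update r = q(z = 1) for one observation x.

    Assumes q(mu_k) = N(m_k, v_k), so E_{mu_k}[(x - mu_k)^2] = (x - m_k)^2 + v_k.
    """
    e0 = (x - m0) ** 2 + v0
    e1 = (x - m1) ** 2 + v1
    return sigmoid(math.log(pi / (1.0 - pi)) + 0.5 * (e0 - e1))

print(update_q_z(0.9, 0.5, m0=-1.0, v0=0.1, m1=1.0, v1=0.1))  # x near m1: r > 0.5
print(update_q_z(0.0, 0.5, m0=-1.0, v0=0.1, m1=1.0, v1=0.1))  # equidistant, equal v: r = 0.5
print(update_q_z(0.9, 0.5, m0=-1.0, v0=0.1, m1=1.0, v1=2.0))  # uncertain mu_1 lowers r
```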
Example 2: Reparameterization Gradient Calculation
Problem: For q_φ(z) = N(μ, σ²) with φ = {μ, σ}, compute the reparameterized gradient of E_q[z²] with respect to μ and σ.
Solution:
$z = μ + σ·ε,   ε ~ N(0, 1),   f(z) = z²
∇_μ E[f] = E_ε[∇_μ f(μ+σ·ε)] = E_ε[2(μ+σ·ε)·1] = 2μ + 2σ·E[ε] = 2μ
∇_σ E[f] = E_ε[∇_σ f(μ+σ·ε)] = E_ε[2(μ+σ·ε)·ε] = E_ε[2με + 2σε²] = 2μ·0 + 2σ·1 = 2σ
$
Check: E[z²] = μ² + σ². ∂/∂μ = 2μ ✓. ∂/∂σ = 2σ ✓.
Contrast with score-function estimator:
$∇_μ E[z²] = E_q[z² · ∇_μ log q(z)] = E_q[z² · (z−μ)/σ²] $
This requires sampling and has variance scaling with the magnitude of z², which can be large.
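A quick Monte Carlo check of Example 2, as a sketch (μ = 1.5, σ = 0.8, the sample size, and the seed are arbitrary choices): both reparameterized estimates average to 2μ and 2σ, and the score-function estimate of ∇_μ also averages to 2μ, since both estimators are unbiased. They differ in variance, which Problem 3 examines.

```python
import random

def grad_estimates(mu, sigma, n=200_000, seed=0):
    """Monte Carlo gradients of E_q[z^2] for q = N(mu, sigma^2).

    Returns (reparameterized d/dmu, reparameterized d/dsigma, score-function d/dmu).
    """
    rng = random.Random(seed)
    rp_mu = rp_sigma = sf_mu = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps                    # reparameterized sample
        rp_mu += 2.0 * z                        # f'(z) * dz/dmu
        rp_sigma += 2.0 * z * eps               # f'(z) * dz/dsigma
        sf_mu += z * z * (z - mu) / sigma ** 2  # f(z) * d/dmu log q(z)
    return rp_mu / n, rp_sigma / n, sf_mu / n

print(grad_estimates(mu=1.5, sigma=0.8))   # near (2*mu, 2*sigma, 2*mu) = (3.0, 1.6, 3.0)
```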
Example 3: VAE ELBO Decomposition
Problem: A VAE with latent dimension d=2, prior N(0,I), encoder output μ=[0.5, −0.3], σ²=[0.1, 0.2], and reconstruction log-likelihood of −2.0 (per data point). Compute the ELBO.
Solution:
$KL term = ½ Σ_{j=1}^2 (μ_j² + σ_j² − log σ_j² − 1)
= ½[(0.25 + 0.1 − log(0.1) − 1) + (0.09 + 0.2 − log(0.2) − 1)]
= ½[(0.25 + 0.1 − (−2.303) − 1) + (0.09 + 0.2 − (−1.609) − 1)]
= ½[(0.25 + 0.1 + 2.303 − 1) + (0.09 + 0.2 + 1.609 − 1)]
= ½[1.653 + 0.899]
= ½[2.552]
= 1.276
ELBO = E_q[log p(x|z)] − KL
= −2.0 − 1.276
= −3.276
$
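A higher ELBO would come from better reconstruction (a less negative E_q[log p(x|z)]) or from a smaller KL to the prior, although in practice the two terms trade off against each other. How tight the bound is, i.e., how close the ELBO is to log p(x), depends instead on KL(q(z|x) || p(z|x)), which these two numbers do not reveal. The KL of 1.276 indicates moderate divergence from the prior.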
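The arithmetic of Example 3 can be reproduced in a few lines (a sketch; the function name is illustrative, and the reconstruction term −2.0 is the given value from the problem):

```python
import math

def diag_gauss_kl_to_std_normal(mu, var):
    """KL(N(mu, diag(var)) || N(0, I)) = 1/2 * sum(mu^2 + var - log var - 1)."""
    return 0.5 * sum(m * m + v - math.log(v) - 1.0 for m, v in zip(mu, var))

kl = diag_gauss_kl_to_std_normal([0.5, -0.3], [0.1, 0.2])
expected_loglik = -2.0                       # given reconstruction term E_q[log p(x|z)]
elbo = expected_loglik - kl
print(round(kl, 3), round(elbo, 3))          # 1.276 -3.276
```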
Quiz
Q1: What does amortized inference primarily refer to in this subject?
A) Learning one shared encoder that maps each x to the parameters of q_φ(z|x), instead of optimizing separate variational parameters for every data point B) Drawing posterior samples with a Markov chain C) Replacing the ELBO with the exact log marginal likelihood D) Spreading the cost of numerically integrating p(x) across mini-batches
Correct: A)
- If you chose A: A single encoder forward pass replaces per-datapoint optimization, so the cost of inference is shared (amortized) across the dataset. Correct!
- If you chose B: This is incorrect. Sampling with a Markov chain is MCMC (21-03), the alternative to variational inference.
- If you chose C: This is incorrect. Amortized inference still maximizes the ELBO; it only changes how the variational parameters are produced.
- If you chose D: This is incorrect. Amortization refers to sharing the encoder across data points, not to numerical integration of p(x).
Q2: Compared with MCMC, why is variational inference usually faster?
A) It draws more samples per second from the exact posterior B) It needs no model of the data C) It avoids computing any expectations D) It turns posterior inference into an optimization problem instead of waiting for a Markov chain to mix
Correct: D)
- If you chose A: This is incorrect. VI does not sample from the exact posterior; it optimizes an approximation q.
- If you chose B: This is incorrect. VI needs the joint p(x, z) to evaluate the ELBO.
- If you chose C: This is incorrect. The ELBO is itself an expectation under q, often estimated with Monte Carlo samples.
- If you chose D: VI replaces integration with optimization of the ELBO, which typically takes fewer iterations than a Markov chain needs to mix. Correct!
Q3: Which statement about the accuracy of variational inference is TRUE?
A) VI always recovers the exact posterior at convergence B) VI is asymptotically exact as the number of iterations grows C) VI is biased in general: it returns the best q within a restricted family, so KL(q || p(z | x)) can remain above zero even at convergence D) Mean-field VI tends to overestimate posterior uncertainty
Correct: C)
- If you chose A: This is incorrect. q equals the posterior only if the variational family contains it.
- If you chose B: This is incorrect. Asymptotic exactness is a property of MCMC, not VI.
- If you chose C: The approximation error is the gap log p(x) − ELBO = KL(q || p(z | x)). Correct!
- If you chose D: This is incorrect. Mean-field VI typically underestimates uncertainty when the posterior has strong correlations.
Q4: Which inequality turns log p(x) = log ∫ q(z) · p(x, z)/q(z) dz into the ELBO (Problem 1)?
A) E[log X] ≥ log E[X] B) log E[X] = E[log X] for every X C) KL(q || p) ≤ 0 D) log E[X] ≥ E[log X] (Jensen's inequality)
Correct: D)
- If you chose A: This is incorrect. That is the reverse direction, which would hold for a convex function; log is concave.
- If you chose B: This is incorrect. Equality holds only when X is constant under q, i.e., when q(z) = p(z | x).
- If you chose C: This is incorrect. KL divergence is always ≥ 0; that is why the ELBO is a lower bound.
- If you chose D: Because log is concave, log E_q[p(x, z)/q(z)] ≥ E_q[log(p(x, z)/q(z))] = ELBO(q). Correct!
Q5: How are accuracy and scalability related when choosing between VI and MCMC?
A) Accuracy is a special case of scalability B) They trade off: VI scales to very large datasets but gives a biased approximation, while MCMC is asymptotically exact but harder to scale C) They are completely unrelated D) Higher accuracy always implies better scalability
Correct: B)
- If you chose A: This is incorrect. They are separate axes of comparison (see Problem 5).
- If you chose B: VI trades some accuracy for speed and scalability; MCMC makes the reverse trade. Correct!
- If you chose C: This is incorrect. Choosing a method usually means trading one against the other.
- If you chose D: This is incorrect. MCMC is typically more accurate but less scalable than VI.
Q6: What is a common pitfall concerning the applicability of the reparameterization trick?
A) Using it with a Gaussian variational distribution B) Applying it directly to discrete latent variables, which have no differentiable reparameterization z = g_φ(ε) C) Using it inside a VAE D) It has no common pitfalls
Correct: B)
- If you chose A: This is incorrect. The Gaussian is the standard reparameterizable case: z = μ + σ·ε.
- If you chose B: Discrete latents need continuous relaxations (Gumbel-Softmax) or score-function estimators with variance reduction. Correct!
- If you chose C: This is incorrect. The reparameterization trick is what makes VAE training practical.
- If you chose D: This is incorrect. See the discrete-latent pitfall in Pitfalls.
Q7: When is maximizing the ELBO the right tool?
A) Never; the ELBO is not practically useful B) Only in pure mathematics contexts C) Only when the posterior is already known in closed form D) When p(z | x) is intractable and you want the best approximation q within a tractable family
Correct: D)
- If you chose A: This is incorrect. The ELBO is the training objective of every VAE and of variationally trained Bayesian neural networks.
- If you chose B: This is incorrect. The ELBO is used throughout applied machine learning, for example in VAE training.
- If you chose C: This is incorrect. If the posterior were known in closed form, no approximation would be needed.
- If you chose D: Maximizing the ELBO is equivalent to minimizing KL(q || p(z | x)), since log p(x) does not depend on q. Correct!
Practice Problems
Problem 1
Derive the ELBO starting from log p(x) = log ∫ p(x, z) dz. Use Jensen's inequality with the variational distribution q(z).
Answer
$log p(x) = log ∫ p(x, z) dz
= log ∫ q(z) · p(x, z)/q(z) dz
≥ ∫ q(z) · log(p(x, z)/q(z)) dz [Jensen: log E[X] ≥ E[log X]]
= E_q[log p(x, z)] − E_q[log q(z)]
= ELBO(q)
$
Jensen's inequality applies because log is concave. The gap is exactly KL(q || p(z|x)):
ELBO = log p(x) − KL(q(z) || p(z|x))
So log p(x) ≥ ELBO, with equality iff q(z) = p(z|x) (zero KL divergence).
Problem 2
For the mean-field approximation q(z₁, z₂) = q₁(z₁)q₂(z₂), derive the optimal update for q₁. Show that each update never decreases the ELBO.
Answer
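ELBO = E_{q₁q₂}[log p(z₁,z₂,x)] − E_{q₁}[log q₁] − E_{q₂}[log q₂]
Take the functional derivative w.r.t. q₁, enforcing the constraint ∫ q₁ = 1 with a Lagrange multiplier λ:
$δ/δq₁ [∫ q₁(z₁)(∫ q₂(z₂) log p(z₁,z₂,x) dz₂) dz₁ − ∫ q₁(z₁) log q₁(z₁) dz₁ + λ(∫ q₁ − 1)] = 0 $
This yields: E_{q₂}[log p(z₁, z₂, x)] − log q₁(z₁) − 1 + λ = 0
So: log q₁*(z₁) = E_{q₂}[log p(z₁, z₂, x)] + const
Monotonicity: define q̃₁(z₁) ∝ exp(E_{q₂}[log p(z₁, z₂, x)]). As a function of q₁ alone, ELBO = −KL(q₁ || q̃₁) + const, so q₁* = q̃₁ is the global maximizer over q₁ with q₂ fixed. Each CAVI update therefore maximizes the ELBO with respect to q_j holding the others fixed, so the ELBO is non-decreasing across iterations and converges to a local optimum.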
Problem 3
Show that the score-function (REINFORCE) gradient estimator for ∇_φ E_{q_φ}[f(z)] is unbiased but can have very large variance. Use f(z) = z², q_φ = N(μ, 1) as an example.
Answer
Score-function estimator:
$g_SF = f(z) · ∇_φ log q_φ(z), z ~ q_φ
E[g_SF] = ∫ q_φ(z)·f(z)·∇_φ log q_φ(z) dz
= ∫ f(z)·∇_φ q_φ(z) dz
= ∇_φ ∫ f(z)·q_φ(z) dz
= ∇_φ E_{q_φ}[f(z)]
$
Unbiased ✓. But variance:
For q_φ = N(μ, 1), log q = −½(z−μ)² + const, ∇_μ log q = z−μ.
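g_SF = z² · (z−μ)
Write z = μ + ε with ε ~ N(0, 1). Using E[ε²] = 1, E[ε⁴] = 3, E[ε⁶] = 15 and dropping odd moments, E[g_SF²] = E[(μ+ε)⁴ε²] = μ⁴ + 18μ² + 15, so Var(g_SF) = μ⁴ + 14μ² + 15. The variance is finite for every μ but grows like μ⁴ as |μ| increases. Contrast with reparameterization: g_RP = ∇_z f · ∂z/∂μ = 2z · 1 = 2z, with Var(g_RP) = 4, constant and independent of μ.
The variance ratio (μ⁴ + 14μ² + 15)/4 is already 3.75 at μ = 0 and about 55 at μ = 3. This is why reparameterization gradients dominate in practice whenever q is reparameterizable: lower variance means faster, more stable training.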
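A Monte Carlo sketch checking that, for z ~ N(μ, 1), Var(g_SF) = μ⁴ + 14μ² + 15 (expand E[(μ + ε)⁴ε²] with E[ε²] = 1, E[ε⁴] = 3, E[ε⁶] = 15 and subtract (2μ)²) while Var(g_RP) = 4. The sample size, seed, and the plain two-pass variance helper are arbitrary choices for illustration.

```python
import random

def variance(values):
    """Population variance with a plain two-pass computation."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)

def estimator_variances(mu, n=200_000, seed=1):
    """Empirical variances of two unbiased estimators of d/dmu E[z^2], z ~ N(mu, 1)."""
    rng = random.Random(seed)
    sf, rp = [], []
    for _ in range(n):
        z = mu + rng.gauss(0.0, 1.0)
        sf.append(z * z * (z - mu))   # score function: f(z) * d/dmu log q(z)
        rp.append(2.0 * z)            # reparameterization: f'(z) * dz/dmu
    return variance(sf), variance(rp)

results = {mu: estimator_variances(mu) for mu in (0.0, 1.0, 3.0)}
for mu, (var_sf, var_rp) in results.items():
    exact_sf = mu ** 4 + 14 * mu ** 2 + 15
    print(f"mu={mu}: Var(g_SF)={var_sf:.1f} (exact {exact_sf:.0f}), Var(g_RP)={var_rp:.2f} (exact 4)")
```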
Problem 4
A VAE is trained on MNIST. The encoder outputs μ and log σ². Explain why we parameterize log σ² instead of σ² directly, and what happens if σ → 0 during training.
Answer
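**Why log σ²:** σ must be positive. If the encoder outputs log σ², any real-valued network output maps to σ² = exp(log σ²) > 0 automatically, and log σ² enters the KL term directly. Outputting σ directly would require a positivity constraint such as softplus; ReLU is a poor choice because it can output exactly 0, with zero gradient there.
**If σ → 0:** The −log σ² term in the KL grows without bound, so the ELBO strongly penalizes this, and very negative log σ² values can also cause numerical problems (they are sometimes clamped). The optimizer may still shrink σ when the reconstruction term rewards a nearly deterministic encoding. The sample z = μ + σ·ε then carries almost no noise, and the model behaves like a standard autoencoder with an unregularized latent code.
This is not posterior collapse, which is the opposite failure: q_φ(z|x) ≈ p(z) for every x (μ → 0, σ → 1, KL → 0), so z carries no information about x and the decoder ignores it. Mitigations for collapse: KL annealing (gradually increase the KL weight), free bits (clamp the KL to a minimum), or a weaker decoder, since powerful autoregressive decoders such as PixelCNN make collapse more likely.
Problem 5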
Compare variational inference to MCMC (21-03) along four axes: speed, accuracy guarantees, scalability, and applicability.
Answer
| Axis | Variational Inference | MCMC |
|------|----------------------|------|
| **Speed** | Fast: optimization-based, with progress monitored through the ELBO | Slow: sampling, needs many iterations to mix |
| **Accuracy** | Biased: approximates the posterior with a simpler family, so KL > 0 at convergence unless the family contains the true posterior | Asymptotically exact: converges to the true posterior as samples → ∞ |
| **Scalability** | Excellent: stochastic VI scales to millions of data points (VAEs, Bayesian NNs) | Poor: each sample requires a full data pass or complex subsampling |
| **Applicability** | Requires a tractable (or estimable) ELBO; reparameterization gradients also require reparameterizable distributions | Applies to any model where you can evaluate the unnormalized posterior |

**When to use which:**
- VI: large datasets, need fast approximate answers, deep learning integration
- MCMC: small to medium datasets, need accurate uncertainty quantification, complex models without convenient variational families

Modern practice often combines both: use VI for initialization, then refine with MCMC; or use normalizing flows (which bridge both paradigms).
Summary
- Variational inference turns integration into optimization — find q*(z) minimizing KL(q || p(z|x)) from a tractable family, which is equivalent to maximizing the ELBO
- The ELBO = E_q[log p(x|z)] − KL(q(z) || p(z)) is a lower bound on log p(x); when the model parameters are learned as well (as in VAEs), maximizing it simultaneously improves the approximation and the model
- Mean-field VI assumes factorized q(z) = ∏ q_j(z_j) and uses coordinate ascent — each q_j ∝ exp(E_{−j}[log p(z,x)])
- The reparameterization trick (z = g_φ(ε)) enables low-variance gradient estimates through sampling, making VAEs and Bayesian neural networks trainable
- VAEs apply amortized VI with neural networks — the encoder predicts posterior parameters, the decoder reconstructs, and the ELBO is optimized end-to-end
Pitfalls
- Assuming variational inference gives exact posteriors. VI minimizes KL(q||p) within a restricted family — even at convergence, q ≠ p. The approximation error (the KL gap) can be large, especially when the true posterior has strong correlations that the variational family cannot capture (e.g., mean-field on a highly correlated posterior). VI provides a fast approximation, not an exact answer. Complement with MCMC when accuracy matters.
- Using mean-field VI when posterior correlations are critical. The mean-field assumption q(z) = ∏ q_j(z_j) cannot represent dependencies between latent variables. In a model where two parameters are strongly correlated in the posterior (e.g., slope and intercept in regression), mean-field VI will underestimate uncertainty and produce an overconfident approximation. Use structured variational families (e.g., multivariate Gaussian with full covariance) when correlations matter.
- Applying the reparameterization trick to discrete latent variables. The reparameterization trick requires z = g_φ(ε) to be differentiable. Discrete variables (e.g., Bernoulli, Categorical) don't have a differentiable reparameterization. Using the score-function estimator on discrete latents produces high-variance gradients. Use continuous relaxations (Gumbel-Softmax, Concrete distribution) or specialized discrete VI methods (REINFORCE with control variates, RELAX).
- Ignoring posterior collapse during VAE training. When the decoder is too powerful relative to the encoder, the model learns to ignore z: q_φ(z|x) collapses to the prior p(z) for all x, and the decoder models the data on its own without using z. Monitor the KL term during training: if it rapidly decays to near zero, the model is collapsing. Mitigations include KL annealing (gradually increasing KL weight from 0 to 1), free bits (clamping KL to a minimum), or weakening the decoder.
- Comparing ELBO values across different models. The ELBO = log p(x) − KL(q||p) depends on both the model quality AND the tightness of the bound. A higher ELBO for model A vs model B could mean model A is better OR that model A's variational approximation is tighter. Only compare ELBOs within the same variational family and model class. For model comparison, use more expensive but reliable methods like importance-weighted bounds (IWAE) or thermodynamic integration.
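For the discrete-latent pitfall, here is a minimal sketch of the Gumbel-Softmax (Concrete) relaxation, written from its standard definition rather than from any library's API: add independent Gumbel(0, 1) noise g_k = −log(−log u_k) to the log-probabilities and apply a softmax with temperature τ. Given the noise, the sample is a differentiable function of the class probabilities, so reparameterized gradients can flow; as τ → 0 the samples approach one-hot vectors, and their argmax is distributed exactly as the original categorical (the Gumbel-max trick). The probabilities and temperatures below are arbitrary.

```python
import math
import random

def gumbel_softmax_sample(probs, tau, rng):
    """One relaxed one-hot sample: y_k ∝ exp((log p_k + g_k) / tau), g_k ~ Gumbel(0, 1)."""
    logits = []
    for p in probs:
        u = max(rng.random(), 1e-300)           # u in (0, 1), so both logs stay finite
        g = -math.log(-math.log(u))             # Gumbel(0, 1) noise
        logits.append((math.log(p) + g) / tau)
    top = max(logits)                           # subtract the max for a stable softmax
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
print(gumbel_softmax_sample(probs, tau=1.0, rng=rng))    # a smooth point on the simplex
print(gumbel_softmax_sample(probs, tau=0.05, rng=rng))   # usually close to one-hot

# The argmax of a relaxed sample is distributed as Categorical(probs) for any tau > 0
counts = [0] * len(probs)
for _ in range(20_000):
    y = gumbel_softmax_sample(probs, tau=0.5, rng=rng)
    counts[y.index(max(y))] += 1
print([c / 20_000 for c in counts])   # close to probs
```

Lower τ gives samples closer to the discrete ones but higher-variance gradients, so τ is usually treated as a tuning knob or annealed during training.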
Key Terms
| Term | Definition |
|---|---|
| ELBO | Evidence Lower Bound: E_q[log p(x,z)] − E_q[log q(z)] ≤ log p(x); maximizing it minimizes KL(q \|\| p(z\|x)) |
| Mean-field | Variational family with fully factorized q(z) = ∏ q_j(z_j) — simplifies optimization at the cost of ignoring posterior correlations |
| CAVI | Coordinate Ascent VI — iteratively update each q_j in closed form using E_{−j}[log p(z,x)] |
| Reparameterization trick | Express z = g_φ(ε) with ε ~ fixed noise; gradients flow through g_φ, giving low-variance estimates |
| Score-function estimator | REINFORCE: ∇_φ E_q[f] = E_q[f(z)·∇_φ log q_φ(z)] — unbiased but high-variance |
| Amortized inference | Learn a function (the encoder) mapping x → the parameters of q_φ(z\|x), so one forward pass replaces per-datapoint optimization |
| VAE | Variational Autoencoder — encoder + decoder trained jointly via ELBO maximization |
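| β-VAE | VAE with weighted KL term β·KL(q(z\|x) \|\| p(z)); β > 1 encourages disentangled latents at some cost in reconstruction, β < 1 favors reconstruction |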
| Posterior collapse | Latent variable z becomes independent of x — decoder ignores z, model reduces to standard autoencoder |
Next Steps
Continue to 21-03 — Markov Chain Monte Carlo (MCMC) to learn the sampling-based alternative to variational inference — how to draw samples from intractable distributions using Markov chains.