24-06: Optimal Transport
Phase: 24 (Information Geometry & Advanced Theory) | Subject: 24-06 | Prerequisites: 24-05 (Disentanglement), 22-03 (GANs), 14 (Optimization Theory) | Next subject: 25-01 (Mechanistic Interpretability)
Learning Objectives
By the end of this subject, you will be able to:
- Formulate the Monge and Kantorovich optimal transport problems and explain their relationship
- Define the Wasserstein distance and derive its key properties (metric axioms, dual formulation)
- Apply the Sinkhorn algorithm for entropic-regularised optimal transport
- Explain how optimal transport connects to WGAN, domain adaptation, and dataset comparison
- Compute exact Wasserstein distances for 1D distributions and discrete empirical distributions
Core Content
1. The Monge Problem (1781)
Gaspard Monge posed the original optimal transport question: given a pile of soil (source distribution) and a hole to fill (target distribution), what is the most efficient way to move the soil?
Mathematical formulation: Given two probability measures $\mu$ on $\mathcal{X}$ and $\nu$ on $\mathcal{Y}$, find a transport map $T: \mathcal{X} \to \mathcal{Y}$ that pushes $\mu$ forward to $\nu$ while minimising the transport cost:
$$\min_T \int_{\mathcal{X}} c(x, T(x)) \, d\mu(x) \quad \text{subject to} \quad T_{\#}\mu = \nu$$
where $T_{\#}\mu = \nu$ means $\nu(B) = \mu(T^{-1}(B))$ for all measurable $B$ (the push-forward condition) and $c(x, y)$ is the cost of moving one unit of mass from $x$ to $y$.
Limitations: The Monge problem requires a deterministic map $T$, which may not exist. For example, if $\mu$ is a single point mass and $\nu$ puts half its mass on each of two points, no map can send one source point to two targets. The constraint $T_{\#}\mu = \nu$ is also non-convex in $T$, which makes the problem difficult to solve directly.
2. The Kantorovich Relaxation (1942)
Leonid Kantorovich generalised Monge's formulation by allowing mass to split, using a transport plan (a joint distribution) rather than a deterministic map:
$$\min_{\gamma \in \Pi(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{Y}} c(x, y) \, d\gamma(x, y)$$
where $\Pi(\mu, \nu)$ is the set of all joint distributions on $\mathcal{X} \times \mathcal{Y}$ with marginals $\mu$ and $\nu$:
$$\Pi(\mu, \nu) = \{\gamma : \gamma(A \times \mathcal{Y}) = \mu(A), \;\gamma(\mathcal{X} \times B) = \nu(B), \;\forall \text{ measurable } A \subseteq \mathcal{X},\; B \subseteq \mathcal{Y}\}$$
⚠️ CRITICAL: The Kantorovich formulation is a linear program: the objective is linear in $\gamma$ and the constraint set $\Pi(\mu, \nu)$ is convex. It is therefore a convex optimisation problem and computationally tractable, unlike Monge's non-convex problem.
Interpretation: $\gamma(x, y)$ represents how much mass from source point $x$ is sent to target point $y$. The marginal constraints ensure all source mass is sent and all target mass is received.
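The marginal constraints can be checked numerically for a discrete plan. The following is a minimal sketch, assuming NumPy; the helper names `is_transport_plan` and `transport_cost` and the two-point example are illustrative choices, not part of the original material.

```python
import numpy as np

def is_transport_plan(gamma, mu, nu, tol=1e-9):
    """Check that gamma is non-negative with row sums mu and column sums nu."""
    gamma = np.asarray(gamma, dtype=float)
    return bool(
        np.all(gamma >= -tol)
        and np.allclose(gamma.sum(axis=1), mu, atol=tol)
        and np.allclose(gamma.sum(axis=0), nu, atol=tol)
    )

def transport_cost(gamma, cost):
    """Total cost sum_ij gamma_ij * c_ij of a discrete plan."""
    return float(np.sum(np.asarray(gamma) * np.asarray(cost)))

# Hypothetical two-point example: both marginals are uniform.
mu = np.array([0.5, 0.5])
nu = np.array([0.5, 0.5])
cost = np.array([[0.0, 1.0], [1.0, 0.0]])

independent = np.outer(mu, nu)   # product plan: always feasible, rarely optimal
diagonal = np.diag(mu)           # leaves every unit of mass where it is

for name, plan in [("independent", independent), ("diagonal", diagonal)]:
    print(name, is_transport_plan(plan, mu, nu), transport_cost(plan, cost))
```

Both plans satisfy the constraints, but they have different costs. The Kantorovich problem is the search for the cheapest member of this feasible set.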
Kantorovich-Rubinstein duality: For the special case $c(x, y) = \|x - y\|$ (1-Wasserstein), the dual formulation is:
$$W_1(\mu, \nu) = \sup_{\|f\|_L \leq 1} \left[ \int f \, d\mu - \int f \, d\nu \right]$$
where $\|f\|_L \leq 1$ means $f$ is 1-Lipschitz. This is the form used in WGAN!
3. The Wasserstein Distance
The $p$-Wasserstein distance between $\mu$ and $\nu$ is:
$$W_p(\mu, \nu) = \left( \inf_{\gamma \in \Pi(\mu, \nu)} \int |x - y|^p \, d\gamma(x, y) \right)^{1/p}$$
Properties: For $p \geq 1$, $W_p$ is a proper metric on the space of probability measures with finite $p$-th moments. It satisfies:
- Non-negativity: $W_p(\mu, \nu) \geq 0$, with equality iff $\mu = \nu$
- Symmetry: $W_p(\mu, \nu) = W_p(\nu, \mu)$
- Triangle inequality: $W_p(\mu, \nu) \leq W_p(\mu, \eta) + W_p(\eta, \nu)$
Earth Mover's Distance (EMD): $W_1$ specifically is called the Earth Mover's Distance: the minimum amount of "work" (mass × distance) needed to transform $\mu$ into $\nu$.
⚠️ CRITICAL: Why Wasserstein matters for ML. Unlike KL divergence and JS divergence, the Wasserstein distance:
- Is finite and informative even when the distributions have disjoint supports (there KL is infinite and JS saturates at $\log 2$)
- Varies continuously as one distribution is moved towards the other, so it gives usable gradients where KL and JS give none, which helps avoid vanishing gradients in GAN training
- Metrizes weak convergence (together with convergence of $p$-th moments), providing a more useful geometry on the space of distributions
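A standard way to see the difference is the two-point-mass example. It is not taken from the text above and is stated here as an illustrative assumption: let $\mu = \delta_0$ and $\nu_\theta = \delta_\theta$ on $\mathbb{R}$. Then:

$$
\begin{aligned}
W_1(\delta_0, \delta_\theta) &= |\theta|, \\
\mathrm{KL}(\delta_0 \,\|\, \delta_\theta) &= \begin{cases} 0 & \theta = 0 \\ +\infty & \theta \neq 0 \end{cases} \\
\mathrm{JS}(\delta_0, \delta_\theta) &= \begin{cases} 0 & \theta = 0 \\ \log 2 & \theta \neq 0 \end{cases}
\end{aligned}
$$

Only $W_1$ changes with $\theta$: its derivative is $\operatorname{sign}(\theta)$ for $\theta \neq 0$, whereas KL and JS are constant away from $\theta = 0$ and so carry no gradient signal about how to move $\nu_\theta$ towards $\mu$.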
4. Exact Solution for 1D
For distributions on $\mathbb{R}$, the Wasserstein distance has a closed form using the inverse CDF (quantile function):
$$W_p(\mu, \nu) = \left( \int_0^1 |F_\mu^{-1}(t) - F_\nu^{-1}(t)|^p \, dt \right)^{1/p}$$
For $p=2$, this is the $L^2$ distance between quantile functions.
Discrete 1D case: Given sorted points $x_1 \leq x_2 \leq \cdots \leq x_n$ (source) and $y_1 \leq y_2 \leq \cdots \leq y_n$ (target) with equal mass $1/n$ each:
$$W_p^p(\mu, \nu) = \frac{1}{n} \sum_{i=1}^n |x_i - y_i|^p$$
The optimal transport simply pairs the $i$-th smallest source with the $i$-th smallest target.
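The sorted-pairing formula translates directly into code. This is a minimal sketch assuming NumPy and two samples of the same size with equal weights; the function name `wasserstein_1d` and the sample values are illustrative.

```python
import numpy as np

def wasserstein_1d(x, y, p=1):
    """W_p between two equal-size, equal-weight 1D samples via sorting."""
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes the same number of points on both sides")
    # Pair the i-th smallest source point with the i-th smallest target point.
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))

x = [3.0, 1.0, 2.0]
y = [6.0, 2.0, 4.0]
print(wasserstein_1d(x, y, p=1))   # mean of |1-2|, |2-4|, |3-6|
print(wasserstein_1d(x, y, p=2))   # sqrt((1 + 4 + 9) / 3)
```

The sort dominates the cost, so the computation is $O(n \log n)$. Unequal sample sizes or weights would need the general quantile formula instead.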
5. The Sinkhorn Algorithm
For general discrete OT with $n$ source points and $m$ target points, the exact problem is a linear program in $n \times m$ variables. With $m = n$, standard solvers take roughly $O(n^3 \log n)$ time, which is prohibitive for large $n$.
Entropic regularisation (Cuturi, 2013): Add an entropy term to smooth the problem:
$$\min_{\gamma \in \Pi(\mu, \nu)} \sum_{i,j} \gamma_{ij} c_{ij} - \varepsilon H(\gamma)$$
where $H(\gamma) = -\sum_{i,j} \gamma_{ij} \log \gamma_{ij}$ is the entropy of the transport plan, and $\varepsilon > 0$ controls regularisation strength.
The Sinkhorn solution: The optimal plan takes the form $\gamma_{ij} = u_i K_{ij} v_j$ where $K_{ij} = \exp(-c_{ij}/\varepsilon)$ (the Gibbs kernel). The vectors $\mathbf{u}$ and $\mathbf{v}$ are found by iterative row/column normalisation.
Sinkhorn algorithm:
1. Initialise $v_j^{(0)} = 1$ for all $j$
2. Repeat until convergence:
   - $u_i^{(t+1)} = \mu_i / \sum_j K_{ij} v_j^{(t)}$ (row normalisation)
   - $v_j^{(t+1)} = \nu_j / \sum_i K_{ij} u_i^{(t+1)}$ (column normalisation)
3. Reconstruct $\gamma_{ij} = u_i K_{ij} v_j$
Each iteration consists of two matrix-vector products, so it costs $O(nm)$, that is $O(n^2)$ when $m = n$, which is much cheaper than the roughly cubic exact LP. The number of iterations needed grows as $\varepsilon$ shrinks, so a small $\varepsilon$ trades speed for accuracy.
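The iteration can be sketched in a few lines of NumPy. This is a plain (not log-domain) version written for illustration; the function name `sinkhorn`, the stopping rule, and the default parameters are assumptions. For very small $\varepsilon$ the kernel $\exp(-c_{ij}/\varepsilon)$ can underflow, in which case a log-domain variant is usually preferred.

```python
import numpy as np

def sinkhorn(mu, nu, cost, eps=0.1, n_iters=1000, tol=1e-9):
    """Entropic OT by alternating row/column scaling of the Gibbs kernel."""
    mu, nu, cost = (np.asarray(a, dtype=float) for a in (mu, nu, cost))
    K = np.exp(-cost / eps)                  # Gibbs kernel
    v = np.ones_like(nu)
    for _ in range(n_iters):
        u = mu / (K @ v)                     # row normalisation
        v = nu / (K.T @ u)                   # column normalisation
        gamma = u[:, None] * K * v[None, :]
        # Column sums are exact after the v-update; stop once rows match too.
        if np.abs(gamma.sum(axis=1) - mu).max() < tol:
            break
    return gamma

# Hypothetical two-point problem.
mu = np.array([0.5, 0.5])
nu = np.array([0.25, 0.75])
cost = np.array([[0.0, 1.0], [1.0, 0.0]])
gamma = sinkhorn(mu, nu, cost, eps=0.5)
print(gamma.round(4), gamma.sum(axis=1), gamma.sum(axis=0))
```

The returned plan has strictly positive entries everywhere, which is the smoothing effect of the entropy term.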
6. Applications of Optimal Transport in ML
WGAN (Wasserstein GAN): Arjovsky et al. (2017) replaced the JS divergence in GAN training with $W_1$, using the Kantorovich-Rubinstein dual:
$$W_1(P_r, P_g) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]$$
The discriminator (critic) $f$ is constrained to be 1-Lipschitz (enforced via weight clipping or gradient penalty). This provides meaningful gradients even when the real and generated distributions have disjoint supports.
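To make the dual concrete, one can restrict the critic to linear functions $f(x) = w \cdot x$ with $\|w\|_2 \leq 1$, which are 1-Lipschitz. This restricted family is an illustrative assumption (WGAN uses a neural network critic), and the supremum over it, $\|\mathbb{E}[x_r] - \mathbb{E}[x_g]\|_2$, is in general only a lower bound on $W_1$. The helper name `linear_critic_bound` and the sample data are hypothetical.

```python
import numpy as np

def linear_critic_bound(real, fake):
    """Best value of E[f(real)] - E[f(fake)] over f(x) = w.x with ||w||_2 <= 1."""
    gap = np.mean(real, axis=0) - np.mean(fake, axis=0)
    norm = np.linalg.norm(gap)
    w = gap / norm if norm > 0 else np.zeros_like(gap)   # optimal direction
    value = float(np.mean(real @ w) - np.mean(fake @ w))
    return w, value

real = np.array([[0.0, 0.0], [2.0, 0.0], [1.0, 1.0]])
fake = real + np.array([3.0, 4.0])   # "generated" samples: a pure shift of the real ones
w, value = linear_critic_bound(real, fake)
print(w, value)
```

For a pure translation the bound is tight, because $W_1$ then equals the length of the shift. For distributions that differ in shape rather than location, a linear critic can miss the difference entirely, which is one reason a flexible critic is needed in practice.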
Domain Adaptation: Use optimal transport to align source and target domain feature distributions: compute the transport plan between source and target samples, then apply the barycentric mapping to transport the source samples into the target domain.
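The barycentric mapping sends each source sample to the weighted average of the target samples it transports mass to. A minimal sketch, assuming NumPy and a plan `gamma` obtained from any OT solver; the plan and points below are hypothetical.

```python
import numpy as np

def barycentric_map(gamma, target_points):
    """Map each source sample to the gamma-weighted average of the target samples."""
    gamma = np.asarray(gamma, dtype=float)
    target_points = np.asarray(target_points, dtype=float)
    row_mass = gamma.sum(axis=1, keepdims=True)   # equals the source weights
    return (gamma @ target_points) / row_mass

# Hypothetical plan between 2 source and 3 target samples in the plane.
gamma = np.array([[0.5, 0.0, 0.0],
                  [0.0, 0.25, 0.25]])
targets = np.array([[0.0, 0.0], [2.0, 0.0], [2.0, 2.0]])
mapped = barycentric_map(gamma, targets)
print(mapped)
```

The first source sample sends all its mass to one target and lands exactly on it. The second splits its mass evenly between two targets and lands at their midpoint.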
Dataset Distance: Compare two datasets by treating each as an empirical distribution and computing $W_p$ between them. Used to measure domain shift, evaluate generative model quality, or compare training sets.
Key Terms
- Kantorovich relaxation
- Monge problem
- Sinkhorn algorithm
- WGAN
- Wasserstein distance
Worked Examples
Example 1: 1D Wasserstein Distance
Problem: Compute $W_2$ between two 1D empirical distributions:
- $\mu$: points at $\{1, 2, 5\}$ (uniform weights $1/3$)
- $\nu$: points at $\{0, 4, 7\}$ (uniform weights $1/3$)
Solution:
Step 1: Sort both. $\mu$ is already sorted: $[1, 2, 5]$. $\nu$: $[0, 4, 7]$.
Step 2: Apply the 1D formula:
$$W_2^2 = \frac{1}{3} \left[|1 - 0|^2 + |2 - 4|^2 + |5 - 7|^2\right] = \frac{1}{3}[1 + 4 + 4] = \frac{9}{3} = 3$$
$$W_2 = \sqrt{3} \approx 1.732$$
The optimal transport pairs $1 \to 0$, $2 \to 4$, $5 \to 7$, which is the monotone coupling.
Example 2: Discrete Kantorovich Problem
Problem: Two distributions on $\{A, B, C\}$:
- $\mu = [0.5, 0.3, 0.2]$
- $\nu = [0.2, 0.4, 0.4]$
- Cost matrix $C$ (Euclidean on a line: positions 0, 1, 2): $c_{ij} = |i - j|$
Find the optimal transport plan $\gamma$.
Solution:
Cost matrix: $$C = \begin{bmatrix} 0 & 1 & 2 \\ 1 & 0 & 1 \\ 2 & 1 & 0 \end{bmatrix}$$
LP formulation (Kantorovich): minimise $\sum_{i,j} \gamma_{ij} c_{ij}$ subject to row sums = $\mu$, column sums = $\nu$, $\gamma \geq 0$.
Intuitive solution (mass to nearest neighbour):
- $\mu_A = 0.5$: Send 0.2 to $\nu_A$ (cost 0), 0.3 to $\nu_B$ (cost 1 × 0.3 = 0.3)
- $\mu_B = 0.3$: Send 0.1 to $\nu_B$ (cost 0), 0.2 to $\nu_C$ (cost 1 × 0.2 = 0.2)
- $\mu_C = 0.2$: Send 0.2 to $\nu_C$ (cost 0)
$$\gamma = \begin{bmatrix} 0.2 & 0.3 & 0 \\ 0 & 0.1 & 0.2 \\ 0 & 0 & 0.2 \end{bmatrix}$$
Total cost = $0.2 \cdot 0 + 0.3 \cdot 1 + 0.1 \cdot 0 + 0.2 \cdot 1 + 0.2 \cdot 0 = 0.5$
$W_1(\mu, \nu) = 0.5$ ✓
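The plan can be checked numerically. The sketch below, assuming NumPy, verifies feasibility and recomputes the cost. It also certifies optimality using a fact not stated above: in 1D, $W_1$ equals the area between the two cumulative distribution functions.

```python
import numpy as np

mu = np.array([0.5, 0.3, 0.2])
nu = np.array([0.2, 0.4, 0.4])
positions = np.array([0.0, 1.0, 2.0])
cost = np.abs(positions[:, None] - positions[None, :])

gamma = np.array([[0.2, 0.3, 0.0],
                  [0.0, 0.1, 0.2],
                  [0.0, 0.0, 0.2]])

# Feasibility: row sums reproduce mu, column sums reproduce nu.
assert np.allclose(gamma.sum(axis=1), mu)
assert np.allclose(gamma.sum(axis=0), nu)
plan_cost = float(np.sum(gamma * cost))

# Independent check (1D only): W_1 is the area between the two CDFs.
cdf_gap = np.abs(np.cumsum(mu) - np.cumsum(nu))[:-1]
w1_from_cdfs = float(np.sum(cdf_gap * np.diff(positions)))

print(plan_cost, w1_from_cdfs)
```

The two numbers agree, so no feasible plan can be cheaper than the intuitive one.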
Example 3: Sinkhorn Iteration
Problem: Apply one Sinkhorn iteration to the same problem with $\varepsilon = 1$. Track $\mathbf{u}$ and $\mathbf{v}$.
Solution:
Step 1: Gibbs kernel $K_{ij} = \exp(-c_{ij}/\varepsilon)$:
$$K = \begin{bmatrix} 1.000 & 0.368 & 0.135 \\ 0.368 & 1.000 & 0.368 \\ 0.135 & 0.368 & 1.000 \end{bmatrix}$$
Step 2: Initialise $\mathbf{v}^{(0)} = [1, 1, 1]$.
Step 3: Row update: $\mathbf{u}^{(1)} = \mu \oslash K\mathbf{v}^{(0)}$
$K\mathbf{v}^{(0)} = [1.503, 1.736, 1.503]$
$\mathbf{u}^{(1)} = [0.5/1.503, 0.3/1.736, 0.2/1.503] = [0.333, 0.173, 0.133]$
Step 4: Column update: $\mathbf{v}^{(1)} = \nu \oslash K^T\mathbf{u}^{(1)}$
$K^T\mathbf{u}^{(1)} = [0.414, 0.344, 0.242]$
$\mathbf{v}^{(1)} = [0.2/0.414, 0.4/0.344, 0.4/0.242] = [0.483, 1.162, 1.655]$
After many iterations, $\gamma_{ij} = u_i K_{ij} v_j$ converges to the entropy-regularised plan, a smoothed approximation of the exact OT plan in which mass is more spread out.
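The single iteration above can be reproduced with a few lines of NumPy. This is a sketch for checking the arithmetic; the variable names are illustrative.

```python
import numpy as np

mu = np.array([0.5, 0.3, 0.2])
nu = np.array([0.2, 0.4, 0.4])
positions = np.array([0.0, 1.0, 2.0])
cost = np.abs(positions[:, None] - positions[None, :])
eps = 1.0

K = np.exp(-cost / eps)      # Gibbs kernel
v0 = np.ones(3)
u1 = mu / (K @ v0)           # row update
v1 = nu / (K.T @ u1)         # column update
gamma1 = u1[:, None] * K * v1[None, :]

print(np.round(u1, 3), np.round(v1, 3))
print("row sums:", gamma1.sum(axis=1))      # not yet equal to mu after one pass
print("column sums:", gamma1.sum(axis=0))   # equal to nu, since v was updated last
```

After one pass the column marginal is already exact while the row marginal is not. Further iterations alternate between the two until both hold.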
Practice Problems
Problem 1: For two Gaussians $\mathcal{N}(m_1, \sigma^2)$ and $\mathcal{N}(m_2, \sigma^2)$ (same variance), prove $W_2^2 = (m_1 - m_2)^2$.
Problem 2: Express $W_2(\mu, \nu)$ for 1D in terms of the quantile functions. Then compute $W_2$ between $\text{Uniform}(0,1)$ and $\text{Uniform}(0.2, 0.8)$.
Problem 3: Show that $W_p(\mu, \nu)$ is a metric. Specifically, prove the triangle inequality using the gluing lemma (existence of a joint distribution with given pairwise marginals).
Problem 4: For the WGAN critic loss with gradient penalty, explain why weight clipping (original WGAN) fails and how gradient penalty (WGAN-GP) fixes it.
Problem 5: Two datasets $X = \{1, 1, 2, 3\}$ and $Y = \{0, 2, 2, 4\}$ with equal weights $1/4$. Compute $W_1$ between the empirical distributions.
Answers
**Problem 1:** Upper bound: the translation $T(x) = x + (m_2 - m_1)$ pushes $\mathcal{N}(m_1, \sigma^2)$ forward to $\mathcal{N}(m_2, \sigma^2)$, and $|x - T(x)| = |m_1 - m_2|$ for every $x$, so $W_2^2 \leq \int (x - T(x))^2 \, d\mu(x) = (m_1 - m_2)^2$. Lower bound: for any coupling $\gamma \in \Pi(\mu, \nu)$, Jensen's inequality gives $\int (x - y)^2 \, d\gamma \geq \left(\int (x - y) \, d\gamma\right)^2 = (m_1 - m_2)^2$. Hence $W_2^2 = (m_1 - m_2)^2$, and the translation is an optimal map.
**Problem 2:** Quantile functions: $F_\mu^{-1}(t) = t$ for $\mu = U(0,1)$. $F_\nu^{-1}(t) = 0.2 + 0.6t$ for $\nu = U(0.2, 0.8)$. $$W_2^2 = \int_0^1 |t - (0.2 + 0.6t)|^2 dt = \int_0^1 (0.4t - 0.2)^2 dt = \int_0^1 (0.16t^2 - 0.16t + 0.04) dt = \left[\tfrac{0.16}{3}t^3 - 0.08t^2 + 0.04t\right]_0^1 = \tfrac{0.16}{3} - 0.08 + 0.04 \approx 0.0133$$ $W_2 = \sqrt{0.0133} \approx 0.115$.
**Problem 3:** By the gluing lemma, given optimal plans $\gamma_{12}$ for $W_p(\mu_1, \mu_2)$ and $\gamma_{23}$ for $W_p(\mu_2, \mu_3)$, there exists a joint distribution $\gamma$ on $\mathcal{X}_1 \times \mathcal{X}_2 \times \mathcal{X}_3$ with marginals $\gamma_{12}$ and $\gamma_{23}$ on the appropriate pairs. Its marginal $\gamma_{13}$ on the first and third coordinates is a coupling of $\mu_1$ and $\mu_3$, so: $$W_p(\mu_1, \mu_3) \leq \left(\int \|x-z\|^p d\gamma_{13}\right)^{1/p} \leq \left(\int (\|x-y\| + \|y-z\|)^p d\gamma\right)^{1/p} \leq \left(\int \|x-y\|^p d\gamma\right)^{1/p} + \left(\int \|y-z\|^p d\gamma\right)^{1/p} = W_p(\mu_1, \mu_2) + W_p(\mu_2, \mu_3)$$ The second step is the triangle inequality for the norm, and the third is Minkowski's inequality in $L^p(\gamma)$ applied to the functions $\|x-y\|$ and $\|y-z\|$.
**Problem 4:** Weight clipping (enforcing $|w| \leq c$) constrains the critic to be Lipschitz but also limits capacity: the critic can only produce simple functions, often resulting in vanishing gradients and poor generated samples. Gradient penalty adds $\lambda \mathbb{E}_{\hat{x}}[(\|\nabla_{\hat{x}} f(\hat{x})\|_2 - 1)^2]$, where $\hat{x}$ is sampled along lines between real and fake samples. This softly enforces $\|\nabla f\| \approx 1$ at those points, matching the gradient norm of an optimal critic there, while allowing full network capacity.
**Problem 5:** Sort both: $X: [1, 1, 2, 3]$, $Y: [0, 2, 2, 4]$. For $W_1$ with equal weights $1/4$: $$W_1 = \frac{1}{4}(|1-0| + |1-2| + |2-2| + |3-4|) = \frac{1}{4}(1 + 1 + 0 + 1) = \frac{3}{4} = 0.75$$ The optimal plan pairs sorted order: $1 \to 0$, $1 \to 2$, $2 \to 2$, $3 \to 4$.
Summary
- Monge asked for a deterministic map minimising transport cost; Kantorovich relaxed this to a convex problem over joint distributions, enabling practical optimisation.
- Wasserstein distance $W_p$ is a true metric on probability distributions that provides meaningful gradients even for disjoint-support distributions, unlike KL/JS.
- Sinkhorn algorithm solves entropic-regularised OT in $O(n^2)$ per iteration via iterative row/column normalisation, enabling OT at scale.
- WGAN uses the Kantorovich-Rubinstein dual of $W_1$ to train GANs stably, with gradient penalty enforcing the Lipschitz constraint smoothly.
- Beyond GANs, optimal transport enables domain adaptation and principled dataset comparison through distribution alignment.
Quiz
Question 1: What is the key difference between the Monge and Kantorovich formulations of optimal transport?
- A. Monge minimises cost while Kantorovich maximises it
- B. Monge uses a deterministic map while Kantorovich allows mass splitting via a joint distribution
- C. Monge applies to continuous distributions while Kantorovich applies to discrete ones
- D. Monge uses Euclidean cost while Kantorovich uses any cost function
Correct Answer: B
Explanation: Monge's problem seeks a deterministic transport map T with T_#μ = ν, which may not exist when mass needs to split across multiple targets. Kantorovich relaxed this to a joint distribution γ ∈ Π(μ, ν), where γ(x, y) represents how much mass flows from x to y. This makes the problem convex and always feasible.
Question 2: Why does the Wasserstein distance provide more useful gradients for GAN training than KL or JS divergence?
- A. It is always smaller, preventing gradient explosion
- B. It is defined and varies smoothly even when distributions have disjoint supports
- C. It doesn't require a discriminator network
- D. It is computationally cheaper to estimate
Correct Answer: B
Explanation: When P_r and P_g have disjoint supports (common early in GAN training), JS divergence saturates at log 2 and provides zero gradient, while KL diverges to infinity. The Wasserstein distance W_p varies continuously with distribution shifts, giving the generator meaningful gradient signal throughout training.
Question 3: For two 1D empirical distributions with equal-weight sorted points x₁ ≤ ... ≤ x_n and y₁ ≤ ... ≤ y_n, the optimal W_p transport cost is:
- A. (1/n) Σᵢ |x_i - y_{n-i+1}|^p
- B. max_i |x_i - y_i|^p
- C. (1/n) Σᵢ |x_i - y_i|^p
- D. |mean(x) - mean(y)|^p
Correct Answer: C
Explanation: In 1D with a convex cost (here p ≥ 1), the optimal transport plan is the monotone coupling: pair the k-th order statistic of the source with the k-th of the target. This follows from the convexity of |x-y|^p and the rearrangement inequality. Anti-monotone pairing (option A) does the opposite: for convex costs it maximises the total cost.
Question 4: The Sinkhorn algorithm solves entropy-regularised optimal transport via:
- A. Gradient descent on the transport plan
- B. Iterative row and column normalisation of a Gibbs kernel matrix
- C. Singular value decomposition of the cost matrix
- D. Solving a system of linear equations
Correct Answer: B
Explanation: With entropic regularisation (ε > 0), the optimal plan has the form γ_ij = u_i K_ij v_j where K_ij = exp(-c_ij/ε). The Sinkhorn algorithm finds u and v by alternating: u_i = μ_i / Σ_j K_ij v_j (row normalisation) and v_j = ν_j / Σ_i K_ij u_i (column normalisation). Each iteration is O(n²), far cheaper than the O(n³) exact LP.
Question 5: In WGAN-GP, the gradient penalty term λ E[(||∇_x̂ f(x̂)||₂ - 1)²] enforces:
- A. That the critic output is bounded
- B. That the critic is 1-Lipschitz by targeting gradient norm = 1 at interpolated points
- C. That the generator gradients don't vanish
- D. That the Wasserstein distance is exactly 1
Correct Answer: B
Explanation: The Kantorovich-Rubinstein dual requires the critic f to be 1-Lipschitz. The gradient penalty softly enforces ||∇f|| ≈ 1 at points interpolated between real and fake samples, which is the condition for an optimal dual potential: the gradient norm should be exactly 1 on the transport paths, not merely ≤ 1.
Pitfalls
- $W_p$ vs KL: Don't use Wasserstein distances interchangeably with $f$-divergences. $W_p$ measures how far you have to move mass; KL measures relative information content. They capture fundamentally different notions of distribution difference.
- Sinkhorn bias: The entropic regularisation biases the transport plan toward higher entropy (more spread out mass). For exact OT, take the limit $\varepsilon \to 0$, but this requires many iterations.
- $W_p$ computation cost: Exact OT for discrete distributions with $n$ points costs roughly $O(n^3)$, which is too slow for large datasets. Use Sinkhorn ($O(n^2)$ per iteration) or sliced Wasserstein (1D projections, $O(n \log n)$ per projection) for practical applications.
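As an illustration of the sliced approach mentioned in the last pitfall, the sketch below projects both point clouds onto random directions and applies the sorted-pairing 1D formula to each projection. It assumes NumPy and equal-size, equal-weight point clouds; the function name `sliced_wasserstein`, the number of projections, and the seed are illustrative choices. The result is a Monte Carlo estimate of the sliced distance, which is a different quantity from $W_p$ itself.

```python
import numpy as np

def sliced_wasserstein(x, y, n_projections=64, p=2, seed=0):
    """Monte Carlo sliced W_p between two equal-size point clouds of shape (n, d)."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    rng = np.random.default_rng(seed)
    directions = rng.normal(size=(n_projections, x.shape[1]))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)   # unit vectors
    # Project onto each direction, then sort each column (one column per direction).
    x_proj = np.sort(x @ directions.T, axis=0)
    y_proj = np.sort(y @ directions.T, axis=0)
    # Average |difference|^p over points and directions, then take the p-th root.
    return float(np.mean(np.abs(x_proj - y_proj) ** p) ** (1.0 / p))

rng = np.random.default_rng(1)
a = rng.normal(size=(200, 3))
b = a + np.array([1.0, 0.0, 0.0])   # same cloud shifted by one unit
print(sliced_wasserstein(a, a), sliced_wasserstein(a, b))
```

Each projection costs one sort, so the total cost is $O(L \, n \log n)$ for $L$ projections, with no $n \times n$ cost matrix to store.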
Next Steps
We've reached the final subject of Phase 24. Move on to 25-01, Mechanistic Interpretability, the first of the frontier topics exploring what neural networks actually learn and how to reverse-engineer their internal representations.