Notebook A: Flow Matching to a Three-Dirac Mixture
Stochastic interpolant. Couple a Gaussian base $X_0\sim\mathcal N(0,I_d)$ with an independent discrete target $X_1\sim\sum_{k=1}^3 w_k\,\delta_{\mu_k}$ through
$$ X_t = a(t)\,X_0 + b(t)\,X_1, \qquad t\in[0,1], $$
with a $C^1$ schedule $(a,b)$ obeying $a(0)=1,\,b(0)=0,\,a(1)=0,\,b(1)=1$. The marginal law $\rho_t$ of $X_t$ is carried from $\rho_0=\mathcal N(0,I)$ to $\rho_1=\sum_k w_k\delta_{\mu_k}$ by the marginal velocity field
$$ v(x,t)=\mathbb E\!\left[\dot a\,X_0+\dot b\,X_1 \mid X_t=x\right], $$
which solves the continuity equation $\partial_t\rho_t+\operatorname{div}(\rho_t v)=0$. Integrating $\dot x=v(x,t)$ from a base sample produces a sample of $\rho_t$ at every time.
This notebook does four things, with the maths derived in text between the plots:
- writes the marginal and the velocity in closed form (the mixture is conditionally Gaussian);
- cross-checks the closed-form velocity against the autodiff score of $\log\rho_t$, and shows where they disagree (near $t=1$);
- integrates the flow and colours each trajectory by the atom it reaches, under three schedules (linear, variance-preserving, cosine);
- contrasts the flow with the semi-discrete optimal-transport map (straight lines into Laguerre cells) — making visible that flow matching is not OT.
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
# Make the local fmg package importable regardless of the launch directory.
def _find(name):
    # Walk up at most six parent directories looking for `name`.
    p = Path.cwd()
    for _ in range(6):
        if (p / name).exists():
            return p
        p = p.parent
    return Path.cwd()
PY_DIR = _find("fmg")
sys.path.insert(0, str(PY_DIR))
REPO = PY_DIR.parent
FIG_DIR = REPO / "paper" / "figures"
FIG_DIR.mkdir(parents=True, exist_ok=True)
import fmg
from fmg import schedules, mixture, ode, ot_semidiscrete, plotting
from fmg.seeding import set_seed
SEED = set_seed()
plt = plotting.setup_matplotlib()
print("fmg", fmg.__version__, "| seed", SEED, "| figures ->", FIG_DIR)
fmg 0.1.0 | seed 20260617 | figures -> /Users/eserie/galaxies/flow-matching-gaussians/.worktrees/task-20260617-32e5/paper/figures
1. The target, the schedules, and the base samples
We place three atoms in the plane ($d=2$) with unequal weights, and contrast the three schedules studied throughout the repository:
| schedule | $a(t)$ | $b(t)$ | character |
|---|---|---|---|
| linear | $1-t$ | $t$ | straight sample paths (rectified flow) |
| variance-preserving | $\cos\frac{\pi t}{2}$ | $\sin\frac{\pi t}{2}$ | unit $a^2+b^2$, no mid-path pinch |
| cosine | $\tfrac12(1+\cos\pi t)$ | $\tfrac12(1-\cos\pi t)$ | slow start/end, fast middle |
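The table can be turned into a small numerical check. The sketch below writes the three schedules as plain functions returning $(a, b, \dot a, \dot b)$; the function names and the `check_boundary` helper are illustrative stand-ins, not the `fmg.schedules` implementation, whose internals are not shown here.

```python
import numpy as np

# Illustrative stand-ins for the three schedules in the table above.
# These are NOT the fmg.schedules classes, only the same formulas.
def linear(t):
    t = np.asarray(t, dtype=float)
    return 1.0 - t, t, -np.ones_like(t), np.ones_like(t)

def variance_preserving(t):
    t = np.asarray(t, dtype=float)
    h = 0.5 * np.pi
    return np.cos(h * t), np.sin(h * t), -h * np.sin(h * t), h * np.cos(h * t)

def cosine(t):
    t = np.asarray(t, dtype=float)
    c, s = np.cos(np.pi * t), np.sin(np.pi * t)
    return 0.5 * (1.0 + c), 0.5 * (1.0 - c), -0.5 * np.pi * s, 0.5 * np.pi * s

def check_boundary(schedule, atol=1e-12):
    """True if a(0)=1, b(0)=0, a(1)=0, b(1)=1 up to rounding."""
    a0, b0, _, _ = schedule(0.0)
    a1, b1, _, _ = schedule(1.0)
    return bool(np.allclose([a0, b0, a1, b1], [1.0, 0.0, 0.0, 1.0], atol=atol))

for sched in (linear, variance_preserving, cosine):
    a, b, _, _ = sched(0.5)
    pinch = float(a**2 + b**2)  # variance of a*X0 + b*Z for unit-variance independent X0, Z
    print(f"{sched.__name__:20s} boundary ok: {check_boundary(sched)}  a^2+b^2 at t=0.5: {pinch:.3f}")
```

At $t=0.5$ the linear and cosine schedules give $a^2+b^2=0.5$, while the variance-preserving schedule keeps it at $1$, which is the "no mid-path pinch" entry in the table.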
mus = np.array([[2.4, 0.0], [-1.4, 2.0], [-1.4, -2.0]]) # three atoms
w = np.array([0.45, 0.30, 0.25]) # weights (sum to 1)
SCHEDULES = [schedules.linear(), schedules.variance_preserving(), schedules.cosine()]
for s in SCHEDULES:
    s.check_boundary()
print("atoms\n", mus, "\nweights", w)
print("schedules:", [s.name for s in SCHEDULES], "— boundary conditions OK")
N = 600
x0 = np.random.randn(N, 2) # base samples ~ N(0, I)
atoms
 [[ 2.4 0. ]
 [-1.4 2. ]
 [-1.4 -2. ]]
weights [0.45 0.3 0.25]
schedules: ['linear', 'variance-preserving', 'cosine'] — boundary conditions OK
2. Closed-form marginal and velocity
Conditioned on $X_1=\mu_k$ we have $X_t=a X_0+b\mu_k\sim\mathcal N(b\mu_k,\,a^2 I)$. Averaging over the three atoms gives a closed-form Gaussian mixture
$$ \rho_t(x)=\sum_{k}w_k\,\mathcal N(x;\,b\mu_k,\,a^2 I). $$
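As a quick sanity check of this mixture formula, one can sample the interpolant directly and compare its empirical mean and covariance with the mixture moments, $b\sum_k w_k\mu_k$ and $a^2 I + b^2\,\mathrm{Cov}_w(\mu)$. The sketch below assumes the linear schedule at $t=0.6$ and uses the atoms and weights of this notebook; the sample size and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed for this illustration
mus = np.array([[2.4, 0.0], [-1.4, 2.0], [-1.4, -2.0]])
w = np.array([0.45, 0.30, 0.25])

t = 0.6
a, b = 1.0 - t, t  # linear schedule (assumption for this check)

# Sample X_t = a X_0 + b X_1 with X_0 ~ N(0, I) independent of X_1 ~ sum_k w_k delta_{mu_k}.
n = 200_000
x0 = rng.standard_normal((n, 2))
k = rng.choice(len(w), size=n, p=w)
xt = a * x0 + b * mus[k]

# Empirical moments of the samples.
mean_mc = xt.mean(axis=0)
cov_mc = np.cov(xt, rowvar=False)

# Moments of the mixture sum_k w_k N(b mu_k, a^2 I).
mu_bar = w @ mus
centered = mus - mu_bar
mean_cf = b * mu_bar
cov_cf = a**2 * np.eye(2) + b**2 * (centered.T * w) @ centered

print("mean  (Monte Carlo / closed form):", mean_mc, mean_cf)
print("max |cov difference|:", np.abs(cov_mc - cov_cf).max())
```

Matching the first two moments does not prove the densities coincide, but it catches sign and scaling mistakes in $a$ and $b$ cheaply.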
Because each component is Gaussian, the conditional expectation defining $v$ is explicit. Writing the posterior weights (responsibilities)
$$ \gamma_k(x,t)=\frac{w_k\,\mathcal N(x;b\mu_k,a^2I)}{\sum_j w_j\,\mathcal N(x;b\mu_j,a^2I)}, $$
a short computation (conditional mean of a Gaussian, then averaging) yields
$$ \boxed{\;v(x,t)=\frac{\dot a}{a}\,x+\Bigl(\dot b-\frac{\dot a\,b}{a}\Bigr)\sum_k\gamma_k(x,t)\,\mu_k\;} $$
The term $\sum_k\gamma_k\mu_k=:m(x,t)$ is the posterior mean of the target atom.
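The short computation can be spelled out in three lines, using only the definitions above. Conditioned on $X_1=\mu_k$ and $X_t=x$, the base point is determined: $X_0=(x-b\mu_k)/a$. Averaging over $k$ with the responsibilities $\gamma_k$ gives

$$
\begin{aligned}
\mathbb E[X_1\mid X_t=x] &= \sum_k\gamma_k(x,t)\,\mu_k = m(x,t),\\
\mathbb E[X_0\mid X_t=x] &= \sum_k\gamma_k(x,t)\,\frac{x-b\,\mu_k}{a} = \frac{x-b\,m(x,t)}{a},\\
v(x,t) &= \dot a\,\frac{x-b\,m(x,t)}{a}+\dot b\,m(x,t) = \frac{\dot a}{a}\,x+\Bigl(\dot b-\frac{\dot a\,b}{a}\Bigr)m(x,t).
\end{aligned}
$$

For the linear schedule, $a=1-t$ and $b=t$, so $\dot a/a=-1/(1-t)$ and $\dot b-\dot a\,b/a=1/(1-t)$, and the formula collapses to $v(x,t)=\bigl(m(x,t)-x\bigr)/(1-t)$.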
Numerical note. As $t\to1$, $a\to0$ and the Gaussian densities underflow, so $\gamma_k$ would compute as $0/0=$ NaN. We evaluate it in log-space, $\ell_k=\log w_k-\|x-b\mu_k\|^2/(2a^2)$ (the shared $(2\pi a^2)^{-d/2}$ normaliser cancels), followed by a softmax. As $a\to0$ this correctly tends to the hard nearest-atom assignment.
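The log-space evaluation described in this note can be sketched as follows. The function names are hypothetical and take $(a, b)$ directly; this is not the signature of `mixture.posterior_gamma`, whose implementation is not shown in the notebook. A naive version is included only to show the failure mode.

```python
import numpy as np

mus = np.array([[2.4, 0.0], [-1.4, 2.0], [-1.4, -2.0]])
w = np.array([0.45, 0.30, 0.25])

def responsibilities(x, a, b, mus, w):
    """gamma_k(x) via a max-shifted softmax of l_k = log w_k - |x - b mu_k|^2 / (2 a^2)."""
    diff = x[:, None, :] - b * mus[None, :, :]            # (N, K, d)
    logits = np.log(w)[None, :] - np.sum(diff**2, axis=-1) / (2.0 * a**2)
    logits = logits - logits.max(axis=1, keepdims=True)   # largest logit becomes 0
    g = np.exp(logits)
    return g / g.sum(axis=1, keepdims=True)

def responsibilities_naive(x, a, b, mus, w):
    """Direct ratio of weighted Gaussian kernels; underflows to 0/0 for small a."""
    diff = x[:, None, :] - b * mus[None, :, :]
    dens = w[None, :] * np.exp(-np.sum(diff**2, axis=-1) / (2.0 * a**2))
    with np.errstate(invalid="ignore", divide="ignore"):
        return dens / dens.sum(axis=1, keepdims=True)

# A point close to the first atom, very late in the flow (linear schedule, t = 0.999).
x = np.array([[2.3, 0.1]])
a, b = 1e-3, 0.999
gamma_stable = responsibilities(x, a, b, mus, w)
gamma_naive = responsibilities_naive(x, a, b, mus, w)
print("log-space:", gamma_stable)
print("naive:    ", gamma_naive)
```

Subtracting the row maximum before exponentiating leaves the softmax unchanged and guarantees that at least one term equals $1$, so the denominator cannot vanish.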
# Marginal density rho_t on a grid, linear schedule, at three times.
sched = schedules.linear()
gx = np.linspace(-5, 5, 240)
GX, GY = np.meshgrid(gx, gx)
grid = np.column_stack([GX.ravel(), GY.ravel()])
fig, axes = plt.subplots(1, 3, figsize=(12, 3.8))
for ax, t in zip(axes, [0.15, 0.5, 0.9]):
    logrho = mixture.log_density(grid, t, sched, mus, w).reshape(GX.shape)
    ax.contourf(GX, GY, np.exp(logrho), levels=30, cmap="magma")
    ax.scatter(*mus.T, c=plotting.ATOM_COLORS, s=60, edgecolor="white", zorder=5)
    ax.set_title(f"$\\rho_t$ (linear, $t={t}$)")
    ax.set_aspect("equal"); ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
fig.suptitle("Closed-form marginal: a Gaussian blob splitting into three atoms")
plotting.savefig_pdf(fig, FIG_DIR / "mixture_marginal_density.pdf")
plt.show()
print("gamma finite at t=0.999:", np.isfinite(mixture.posterior_gamma(x0, 0.999, sched, mus, w)).all())
gamma finite at t=0.999: True
3. Cross-check: closed form vs. autodiff of $\log\rho_t$
An independent check of the velocity formula. The Gaussian-mixture identity
$$ \nabla_x\log\rho_t(x)=\frac{b\,m(x,t)-x}{a^2} $$
lets us recover the posterior mean $m$ from the autodiff score of $\log\rho_t$, and then re-assemble $v$. The two routes — analytic posterior vs. differentiated density — must agree.
They agree to machine precision for $t$ away from $1$. Near $t=1$ the velocity is a removable $\infty-\infty$ singularity (both $\dot a/a$ and $\dot b-\dot a b/a$ blow up like $1/(1-t)$ and cancel only along the true solution), so finite-precision arithmetic loses 3–4 digits. We therefore report the discrepancy as a curve $\mathrm{error}(t)$, not a single pass/fail scalar.
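The same cross-check can be reproduced without an autodiff library by swapping in central finite differences for the score. This is a substitute technique, so it reaches roughly $10^{-9}$ rather than machine precision; the helper names below are hypothetical, and the linear schedule at $t=0.5$ is an arbitrary choice.

```python
import numpy as np

mus = np.array([[2.4, 0.0], [-1.4, 2.0], [-1.4, -2.0]])
w = np.array([0.45, 0.30, 0.25])

def _logits(x, a, b):
    diff = x[:, None, :] - b * mus[None, :, :]
    return np.log(w)[None, :] - np.sum(diff**2, axis=-1) / (2.0 * a**2)

def log_rho(x, a, b):
    """log of sum_k w_k N(x; b mu_k, a^2 I), evaluated with a log-sum-exp."""
    l = _logits(x, a, b)
    mx = l.max(axis=1, keepdims=True)
    lse = mx[:, 0] + np.log(np.exp(l - mx).sum(axis=1))
    return lse - 0.5 * x.shape[1] * np.log(2.0 * np.pi * a**2)

def posterior_mean(x, a, b):
    """m(x,t) = sum_k gamma_k mu_k from the analytic responsibilities."""
    l = _logits(x, a, b)
    g = np.exp(l - l.max(axis=1, keepdims=True))
    g /= g.sum(axis=1, keepdims=True)
    return g @ mus

def score_fd(x, a, b, h=1e-5):
    """Central finite-difference gradient of log_rho (stand-in for autodiff)."""
    grad = np.zeros_like(x)
    for j in range(x.shape[1]):
        e = np.zeros(x.shape[1])
        e[j] = h
        grad[:, j] = (log_rho(x + e, a, b) - log_rho(x - e, a, b)) / (2.0 * h)
    return grad

t = 0.5
a, b, da, db = 1.0 - t, t, -1.0, 1.0            # linear schedule and its derivatives
x = np.random.default_rng(0).standard_normal((200, 2))

m_closed = posterior_mean(x, a, b)
m_score = (x + a**2 * score_fd(x, a, b)) / b     # invert grad log rho = (b m - x) / a^2

def assemble_v(m):
    return (da / a) * x + (db - da * b / a) * m

err = np.abs(assemble_v(m_closed) - assemble_v(m_score)).max()
print(f"max |v_closed - v_fd| at t={t}: {err:.2e}")
```

Inverting the identity divides by $b$, so this route is only usable for $t>0$; the notebook's time grid starts at $0.05$ for the same reason.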
ts = np.linspace(0.05, 0.97, 40)
fig, ax = plt.subplots(figsize=(7, 4))
for s in SCHEDULES:
    _, err = mixture.velocity_error_curve(x0[:200], ts, s, mus, w)
    ax.semilogy(ts, err + 1e-18, label=s.name)
ax.set_xlabel("$t$"); ax.set_ylabel(r"$\max_x\,|v_{\rm closed}-v_{\rm autodiff}|$")
ax.set_title("Closed-form vs. autodiff velocity — agreement degrades only near $t=1$")
ax.legend()
plotting.savefig_pdf(fig, FIG_DIR / "mixture_velocity_crosscheck.pdf")
plt.show()
4. Integrating the flow: trajectories coloured by their atom
We integrate $\dot x=v(x,t)$ from the base samples with an adaptive RK45
solver (scipy.integrate.solve_ivp, rtol=1e-8), stopping at $t=1-\varepsilon$
($\varepsilon=10^{-3}$) to stay clear of the singularity. Each trajectory is
coloured by the atom nearest its endpoint — the three colours partition the base
Gaussian into the three basins.
def vel_fn(x, t, s):
    return mixture.velocity(x, t, s, mus, w)

def integrate_and_colour(s):
    tt, traj = ode.integrate_samples(x0, s, vel_fn, eps=1e-3)
    end = traj[-1]
    # Label each trajectory by the atom nearest to its endpoint.
    labels = np.argmin(np.linalg.norm(end[:, None, :] - mus[None, :, :], axis=-1), axis=1)
    return tt, traj, labels

def plot_traj(ax, traj, labels, title):
    for i in range(traj.shape[1]):
        ax.plot(traj[:, i, 0], traj[:, i, 1], lw=0.4, alpha=0.35,
                color=plotting.ATOM_COLORS[labels[i]])
    ax.scatter(traj[0, :, 0], traj[0, :, 1], s=3, c="0.4", alpha=0.5)
    ax.scatter(*mus.T, c=plotting.ATOM_COLORS, s=90, edgecolor="white", zorder=5)
    ax.set_title(title); ax.set_aspect("equal"); ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
results = {s.name: integrate_and_colour(s) for s in SCHEDULES}
# One detailed figure for the linear schedule.
fig, ax = plt.subplots(figsize=(5.5, 5.5))
_, traj, labels = results["linear"]
plot_traj(ax, traj, labels, "Flow-matching trajectories (linear schedule)")
plotting.savefig_pdf(fig, FIG_DIR / "mixture_trajectories_linear.pdf")
plt.show()
# Side-by-side over the three schedules.
fig, axes = plt.subplots(1, 3, figsize=(13.5, 4.6))
for ax, s in zip(axes, SCHEDULES):
    _, traj, labels = results[s.name]
    plot_traj(ax, traj, labels, s.name)
fig.suptitle("Same endpoints, different paths — the schedule bends the trajectories")
plotting.savefig_pdf(fig, FIG_DIR / "mixture_trajectories_schedules.pdf")
plt.show()
frac = {n: np.bincount(r[2], minlength=3) / len(r[2]) for n, r in results.items()}
print("empirical basin fractions vs. target weights", np.round(w, 3))
for n, f in frac.items():
    print(f" {n:20s} {np.round(f, 3)}")
empirical basin fractions vs. target weights [0.45 0.3 0.25]
 linear               [0.428 0.317 0.255]
 variance-preserving  [0.428 0.317 0.255]
 cosine               [0.428 0.317 0.255]
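The integration in this section goes through `ode.integrate_samples`, whose internals are not shown. A standalone sketch of the same idea, under stated assumptions, is given below: it hard-codes the linear schedule (for which $v=(m-x)/(1-t)$), stacks all samples into one state vector for `scipy.integrate.solve_ivp`, and uses a looser `rtol=1e-6` and fewer samples than the notebook to keep the run short. The function names are hypothetical.

```python
import numpy as np
from scipy.integrate import solve_ivp

mus = np.array([[2.4, 0.0], [-1.4, 2.0], [-1.4, -2.0]])
w = np.array([0.45, 0.30, 0.25])

def velocity_linear(x, t):
    """Closed-form velocity for a = 1 - t, b = t, with log-space responsibilities."""
    a, b = 1.0 - t, t
    diff = x[:, None, :] - b * mus[None, :, :]
    logits = np.log(w)[None, :] - np.sum(diff**2, axis=-1) / (2.0 * a**2)
    g = np.exp(logits - logits.max(axis=1, keepdims=True))
    g /= g.sum(axis=1, keepdims=True)
    return (g @ mus - x) / a

def integrate_flow(x0, eps=1e-3, rtol=1e-6, atol=1e-9):
    """Integrate dx/dt = v(x, t) from t = 0 to t = 1 - eps for all samples at once."""
    n, d = x0.shape
    rhs = lambda t, y: velocity_linear(y.reshape(n, d), t).ravel()
    sol = solve_ivp(rhs, (0.0, 1.0 - eps), x0.ravel(), method="RK45", rtol=rtol, atol=atol)
    return sol.success, sol.y[:, -1].reshape(n, d)

x0 = np.random.default_rng(0).standard_normal((40, 2))  # small batch, arbitrary seed
ok, x_end = integrate_flow(x0)

# Distance of every endpoint to every atom, then nearest-atom label.
dist = np.linalg.norm(x_end[:, None, :] - mus[None, :, :], axis=-1)
labels = dist.argmin(axis=1)
print("solver success:", ok)
print("largest endpoint-to-nearest-atom distance:", dist.min(axis=1).max())
print("basin counts:", np.bincount(labels, minlength=3))
```

Stacking the samples means the adaptive step size is shared by the whole batch; integrating each sample separately would let easy trajectories take larger steps, at the cost of many more solver calls.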
5. Contrast with optimal transport (semi-discrete / Laguerre)
Flow matching uses the independent coupling of $X_0$ and $X_1$. The optimal-transport map for the quadratic cost uses the optimal coupling instead. From a continuous Gaussian to a few atoms, the OT (Brenier) map is piecewise constant: it sends each Laguerre cell
$$ L_k=\bigl\{x:\tfrac12\|x-\mu_k\|^2-\psi_k\le\tfrac12\|x-\mu_j\|^2-\psi_j\ \forall j\bigr\} $$
to the atom $\mu_k$, with potentials $\psi$ fixed by the mass constraints $\mathbb P(X_0\in L_k)=w_k$. Each point then moves in a straight line to its atom. We solve for $\psi$ via POT's exact-LP dual (cross-checked against a damped-Newton solve on the dual), draw the Laguerre cells, and overlay the straight OT trajectories on the curved flow-matching ones. They are visibly different transports of the same two marginals.
sdot = ot_semidiscrete.SemiDiscreteOT(mus, w).fit(n_samples=60000, seed=SEED)
print("semi-discrete OT fit residual (max |mass_k - w_k|):", f"{sdot.fit_residual:.2e}")
print("Laguerre potentials psi:", np.round(sdot.psi, 4))
cells = sdot.assign(grid).reshape(GX.shape)
ts_ot = np.linspace(0, 1, 60)
ot_traj = sdot.trajectory(x0, ts_ot) # (T, N, 2) straight lines
ot_labels = sdot.assign(x0)
fig, axes = plt.subplots(1, 2, figsize=(12, 5.6))
# Left: Laguerre cells + straight OT trajectories.
ax = axes[0]
ax.contourf(GX, GY, cells, levels=[-0.5, 0.5, 1.5, 2.5],
colors=[plotting.ATOM_COLORS[k] for k in range(3)], alpha=0.18)
for i in range(0, N, 2):
    ax.plot(ot_traj[:, i, 0], ot_traj[:, i, 1], lw=0.4, alpha=0.35,
            color=plotting.ATOM_COLORS[ot_labels[i]])
ax.scatter(*mus.T, c=plotting.ATOM_COLORS, s=90, edgecolor="white", zorder=5)
ax.set_title("Optimal transport: straight lines into Laguerre cells")
ax.set_aspect("equal"); ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
# Right: flow-matching (linear) for comparison.
_, traj, labels = results["linear"]
plot_traj(axes[1], traj, labels, "Flow matching (linear): curved, independent coupling")
fig.suptitle("Flow matching $\\neq$ optimal transport (same marginals, different coupling)")
plotting.savefig_pdf(fig, FIG_DIR / "mixture_ot_contrast.pdf")
plt.show()
semi-discrete OT fit residual (max |mass_k - w_k|): 3.72e-03
Laguerre potentials psi: [ 0. -0.8065 -1.2405]
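To make the role of the potentials concrete, here is a minimal sketch of fitting $\psi$ by plain gradient ascent on the semi-discrete dual, whose gradient in $\psi_k$ is $w_k-\mathbb P(X_0\in L_k)$, with the cell masses estimated by Monte Carlo. This is a deliberately simpler technique than the POT exact-LP dual and damped-Newton solves used by `ot_semidiscrete.SemiDiscreteOT`; the function names, step size, iteration count and sample size are all assumptions made for illustration.

```python
import numpy as np

mus = np.array([[2.4, 0.0], [-1.4, 2.0], [-1.4, -2.0]])
w = np.array([0.45, 0.30, 0.25])

def laguerre_assign(x, psi):
    """Index k minimising 0.5 * |x - mu_k|^2 - psi_k, i.e. the Laguerre cell of x."""
    cost = 0.5 * np.sum((x[:, None, :] - mus[None, :, :])**2, axis=-1) - psi[None, :]
    return cost.argmin(axis=1)

def cell_masses(x, psi):
    """Monte Carlo estimate of P(X_0 in L_k) from the samples x."""
    return np.bincount(laguerre_assign(x, psi), minlength=len(w)) / len(x)

def fit_psi(x, n_iter=500, step=1.0):
    """Gradient ascent: raise psi_k where cell k holds too little mass, lower it otherwise."""
    psi = np.zeros(len(w))
    for _ in range(n_iter):
        psi += step * (w - cell_masses(x, psi))
    return psi - psi[0]  # cells are invariant to a common shift; pin psi_0 = 0

x = np.random.default_rng(0).standard_normal((20_000, 2))  # base samples, arbitrary seed
psi = fit_psi(x)
residual = np.abs(cell_masses(x, psi) - w).max()
print("psi:", np.round(psi, 4))
print(f"max |mass_k - w_k|: {residual:.2e}")

# Straight-line OT trajectories: each point moves linearly to the atom of its cell.
targets = mus[laguerre_assign(x[:5], psi)]
ts = np.linspace(0.0, 1.0, 5)[:, None, None]
ot_paths = (1.0 - ts) * x[None, :5, :] + ts * targets[None, :, :]   # (T, 5, 2)
```

The fitted potentials should be broadly comparable to the values printed above, but they depend on the sample and on the solver, so they need not match digit for digit.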
6. Takeaways
- The mixture marginal and velocity are closed-form; the velocity check against autodiff passes to machine precision except near $t=1$, where the removable singularity costs a few digits — handled by integrating to $t=1-\varepsilon$ and computing responsibilities in log-space.
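- All three schedules tested carry the base Gaussian to the same three atoms with identical empirical basin fractions ($0.428,\,0.317,\,0.255$ for $N=600$ samples, against target weights $0.45,\,0.30,\,0.25$); the schedule only bends the paths.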
- The flow-matching transport is not the optimal-transport map: OT moves mass in straight lines into Laguerre cells, flow matching follows curved paths from the independent coupling. The Gaussian companion notebook (notebook_b_ellipses) makes this gap exact and characterises when it closes.