<aside> 📃

Many natural dynamic processes, such as in vivo cellular differentiation or disease progression, can only be observed through the lens of static sample snapshots. While challenging, reconstructing their temporal evolution to decipher underlying dynamic properties is of major interest to scientific research. Existing approaches enable data transport along a temporal axis but scale poorly to high dimensions and rely on restrictive assumptions. To address these issues, we propose Multi-Marginal temporal Schrödinger Bridge Matching (MMtSBM) for video generation from unpaired data, extending the theoretical guarantees and empirical efficiency of https://arxiv.org/abs/2303.16852 by generalizing the Iterative Markovian Fitting algorithm to multiple marginals in a novel factorized fashion. Experiments show that MMtSBM retains its theoretical properties on toy examples, achieves state-of-the-art performance on real-world datasets such as transcriptomic trajectory inference in 100 dimensions, and for the first time recovers couplings and dynamics in very high-dimensional image settings. Our work establishes multi-marginal Schrödinger bridges as a practical and principled approach for recovering hidden dynamics from static data.

</aside>

https://github.com/ICLRMMtDSBM/MMDSBM_ILCR

Videos of experiments

Exact OT between Gaussian Mixtures

In this 2D experiment, akin to https://arxiv.org/abs/2209.03003, we used $N=3$ marginals, each a mixture of two standard Gaussians. In this configuration the optimal transport between each pair of marginals is known exactly: it is a pure translation of each Gaussian component of the mixture.

Epoch $0$ (only noisy flow matching).


Epoch $5$ (after MMtSBM training).


Above videos: True marginal times $(t_0, t_1, t_2)=(0, 1, 2)$. The order of the $3$ true marginals is: $t_0=$ dark blue; $t_1=$ red; $t_2=$ light blue. Generated samples are in green. In the background is the quiver plot of the learned score network.

<aside> 🔬

After the warm-up phase alone, the learned transport maps mix the Gaussian components of the mixtures, which produces the intersecting trajectories seen in the left video. After the SB learning phase of MMtSBM, the learned trajectories no longer intersect (right video), and MMtSBM yields the exact static optimal transport map: pure translations between Gaussians.

</aside>
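To make the contrast between crossing and non-crossing transport concrete, here is a minimal sketch (not the experiment's code). It assumes hypothetical component means and compares the mean squared displacement of two translation couplings between a pair of two-component mixtures: one that keeps each component on its own side, and one that swaps them. Between two Gaussians with equal covariance, the quadratic-cost optimal map is a translation, so each pairing is realized here as a per-component shift.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical component means, chosen for illustration only.
source_means = np.array([[0.0, -2.0], [0.0, 2.0]])
target_means = np.array([[4.0, -2.0], [4.0, 2.0]])

n_per_component = 500
labels = np.repeat([0, 1], n_per_component)
# Samples from a mixture of two standard Gaussians.
x = source_means[labels] + rng.standard_normal((labels.size, 2))


def translate(points, labels, pairing):
    """Send source component k onto target component pairing[k] by a pure translation."""
    shifts = target_means[pairing] - source_means  # one shift vector per source component
    return points + shifts[labels]


def mean_squared_displacement(points, moved):
    """Quadratic transport cost of the coupling (points, moved)."""
    return float(np.mean(np.sum((moved - points) ** 2, axis=1)))


straight = translate(x, labels, pairing=np.array([0, 1]))  # components preserved
crossed = translate(x, labels, pairing=np.array([1, 0]))   # components swapped

cost_straight = mean_squared_displacement(x, straight)
cost_crossed = mean_squared_displacement(x, crossed)
print(cost_straight, cost_crossed)
```

With these assumed means, the component-preserving translation costs half as much as the swapped one. This is consistent with the non-intersecting trajectories being the lower-cost coupling.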

Usual Schrödinger Bridge metrics

To quantitatively verify that MMtSBM recovers the correct multi-marginal SB in terms of both 1) static coupling and 2) energy minimization, we extended the now classical "Moons" and "8Gaussians" experiments found in https://arxiv.org/abs/2302.00482 and https://arxiv.org/abs/2303.16852 to our temporal multi-marginal setting. Choosing $N=4$, we considered ($\mathcal{N} \to$ Moons $\to \mathcal{N} \to$ Moons) and ($\mathcal{N} \to$ 8Gaussians $\to \mathcal{N} \to$ 8Gaussians). To assess 1), we report the $\mathcal{W}_2$ distance between generated samples and test set data at the target marginal time(s), averaged over the $N-1=3$ target times for MMtSBM, and compare it to the single-bridge setting. To assess 2), we report the full path energy $\mathbb{E}\left[\int_{0}^{T} \| v(t, \mathbf{Z}_t) \|^2 \, dt \right]$, where $\mathbf{Z}_t$ is the process simulated along the ODE drift.
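As an illustration of the path energy metric, the sketch below estimates $\mathbb{E}\left[\int_{0}^{T} \| v(t, \mathbf{Z}_t) \|^2 \, dt \right]$ by simulating the ODE with explicit Euler steps and accumulating a Riemann sum per particle. The function name `path_energy`, the step count, and the toy velocity field are assumptions made for this example; the discretization used in the actual experiments is not stated here.

```python
import numpy as np


def path_energy(velocity, z0, t_end, n_steps=100):
    """Estimate E[ int_0^T ||v(t, Z_t)||^2 dt ] along dZ_t = v(t, Z_t) dt.

    velocity: callable (t, z) -> array with the same shape as z (hypothetical interface).
    z0: array of shape (n_particles, dim) holding the initial samples.
    """
    dt = t_end / n_steps
    z = np.array(z0, dtype=float)
    energy = np.zeros(z.shape[0])
    for k in range(n_steps):
        t = k * dt
        v = velocity(t, z)
        energy += np.sum(v ** 2, axis=1) * dt  # left Riemann sum of ||v||^2
        z = z + v * dt                         # explicit Euler step
    return float(energy.mean()), z


def constant_drift(t, z):
    # Toy velocity field: every particle moves at unit speed along the first axis.
    v = np.zeros_like(z)
    v[:, 0] = 1.0
    return v


rng = np.random.default_rng(0)
z0 = rng.standard_normal((256, 2))
energy_one, _ = path_energy(constant_drift, z0, t_end=1.0)
energy_three, z_end = path_energy(constant_drift, z0, t_end=3.0)
print(energy_one, energy_three)
```

For this unit-speed toy drift the energy equals the time horizon, so simulating over $T=3$ gives three times the $T=1$ value. A learned network would replace `constant_drift`, and a finer step or a higher-order integrator may be needed when the drift varies quickly.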

"8Gaussians" experiment: $\mathcal{N} \to$ 8Gaussians $\to \mathcal{N} \to$ 8Gaussians.

"Moons" experiment: $\mathcal{N} \to$ Moons $\to \mathcal{N} \to$ Moons.

Comparison in terms of static coupling ("$\mathcal{W}_2$") and energy minimization ("Path Energy"). The rows marked "$\times 3$" correspond to the hypothetical case where the energy of a single bridge is simply tripled, and are included as an ideal baseline for comparison with our actual multi-bridge setting. All metrics other than ours are taken from Diffusion Schrödinger Bridge Matching.

<aside> 🔬

We observe that, despite the much more complex time-varying transport map to be learned, MMtSBM achieves $\mathcal{W}_2$ distances almost as low as in the simple single-bridge setting (within 3% to 4%), and that our full path energy is within 6% to 13% of the ideal extrapolation of the single-bridge result. This validates that MMtSBM approaches the true SB in practice.

</aside>

State-of-the-art 100D transcriptomic results

We evaluated our method on the TrajectoryNet benchmark https://arxiv.org/abs/2002.04461, which uses real single-cell RNA-seq embryoid body differentiation data from https://www.nature.com/articles/s41587-019-0336-3. We project RNA counts onto their first 100 principal components for each of the $N=5$ marginals. We report in the tables below the Maximum Mean Discrepancy (MMD) and the Sliced Wasserstein Distance (SWD).
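For readers unfamiliar with the two metrics, here is a minimal NumPy sketch of how they can be estimated from two sample sets. The Gaussian kernel, its bandwidth, the number of projections, and the function names are assumptions for illustration; the exact settings used by the benchmark are not stated here, so values from this sketch are not comparable to the reported tables.

```python
import numpy as np


def mmd_rbf(x, y, bandwidth):
    """Biased (V-statistic) MMD estimate with a Gaussian kernel of the given bandwidth."""
    def kernel(a, b):
        sq = np.sum(a ** 2, axis=1)[:, None] + np.sum(b ** 2, axis=1)[None, :] - 2.0 * a @ b.T
        return np.exp(-np.maximum(sq, 0.0) / (2.0 * bandwidth ** 2))

    mmd_sq = kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()
    return float(np.sqrt(max(mmd_sq, 0.0)))


def sliced_wasserstein(x, y, n_projections=128, seed=0):
    """Sliced 2-Wasserstein distance between two samples of equal size."""
    rng = np.random.default_rng(seed)
    directions = rng.standard_normal((n_projections, x.shape[1]))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # In 1D, optimal transport between equal-size samples pairs the sorted values.
    px = np.sort(x @ directions.T, axis=0)
    py = np.sort(y @ directions.T, axis=0)
    return float(np.sqrt(np.mean((px - py) ** 2)))


# Synthetic stand-ins for generated and held-out samples in d=100.
rng = np.random.default_rng(1)
d = 100
reference = rng.standard_normal((200, d))
close = rng.standard_normal((200, d))   # same distribution as the reference
shifted = reference + 1.0               # every coordinate shifted by one

bw = np.sqrt(d)  # hypothetical bandwidth choice
print(mmd_rbf(reference, close, bw), mmd_rbf(reference, shifted, bw))
print(sliced_wasserstein(reference, close), sliced_wasserstein(reference, shifted))
```

Both metrics are lower for the sample drawn from the same distribution than for the shifted one. The sorting trick in `sliced_wasserstein` requires equal sample sizes; unequal sizes would need quantile interpolation instead.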

Video of an example 2D dynamic on the first 2 principal components.


Results on the test set of embryoid body RNA-seq data ($d=100$). Left table: per-marginal metrics. Right table: average over all target marginals. Baseline results are taken from Deep Momentum Multi-Marginal Schrödinger Bridge. Best value in bold.

<aside> 🔬

Our method consistently outperforms baselines on all marginals, reducing the average MMD from 0.03 to 0.02 and the SWD from 0.20 to 0.13. This demonstrates that enforcing all marginal constraints simultaneously as we do yields sharper and more consistent interpolations across developmental stages, setting a new state of the art on this benchmark.

</aside>