Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6f9b2df
1st try potentials OT 1d
clbonet Jul 28, 2025
cb5660d
emd1d_dual ok without batch
clbonet Jul 29, 2025
0a9d38b
batched emd1d_dual
clbonet Jul 29, 2025
0be92a7
1d potentials with backprop, 1d uot 1st try
clbonet Aug 6, 2025
cade9d5
up tests 1d solvers
clbonet Aug 6, 2025
b055044
file sliced uot
clbonet Aug 9, 2025
c4971a8
clip max cdf in wasserstein_1d
clbonet Aug 9, 2025
246d092
Example UOT 1d
clbonet Aug 9, 2025
b0c791c
normalize weights
clbonet Aug 10, 2025
f9dc43a
add suot
clbonet Aug 10, 2025
6fe05ea
add code example (to test)
clbonet Aug 10, 2025
c361c32
tests backend
clbonet Aug 10, 2025
c08655c
up code example 1D UOT
clbonet Aug 22, 2025
fd00304
resolve conflicts with master
clbonet Aug 27, 2025
26473e1
Examples UOT 1D
clbonet Aug 27, 2025
0ca65a6
fix output loss uot_1d
clbonet Aug 28, 2025
c6301b8
Example USOT vs SUOT
clbonet Sep 13, 2025
504c07a
Center dual potentials
clbonet Sep 14, 2025
812b4da
up tests
clbonet Sep 14, 2025
5fca694
merge main
clbonet Oct 5, 2025
801aa89
up citation
clbonet Oct 5, 2025
7be3794
Merge master
clbonet Jan 30, 2026
362d2aa
Merge branch 'master' into sliced_uot
clbonet Jan 30, 2026
ee19161
fix backend and skip tf in 1d_dual tests
clbonet Jan 30, 2026
2f60ba9
lint
clbonet Jan 30, 2026
cd176ae
Default p=2 for UOT 1D
clbonet Jan 31, 2026
311e106
Test UOT1D, refactorize W2 on circle
clbonet Feb 1, 2026
acb4059
Typo doc
clbonet Feb 1, 2026
4da4ad1
Typo test sum
clbonet Feb 1, 2026
3c79233
Skip test TF
clbonet Feb 1, 2026
6826cc7
update plot example
clbonet Feb 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ POT provides the following generic OT solvers:
Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation [73] and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT.html) [82]
* [Wasserstein distance on the
circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html)
[44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
Expand Down Expand Up @@ -367,7 +368,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer

[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.

[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.

Expand Down Expand Up @@ -449,5 +450,4 @@ Artificial Intelligence.

[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).


```
[82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024). [Slicing Unbalanced Optimal Transport](https://openreview.net/forum?id=AjJTg5M0r8). Transactions on Machine Learning Research.
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)

#### Closed issues

Expand Down
173 changes: 165 additions & 8 deletions examples/unbalanced-partial/plot_UOT_1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

# Author: Hicham Janati <hicham.janati@inria.fr>
# Clément Bonet <clemebt.bonet.mapp@polytechnique.edu>
#
# License: MIT License

Expand All @@ -19,6 +20,8 @@
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss
import torch
import cvxpy as cp

##############################################################################
# Generate data
Expand All @@ -41,7 +44,6 @@

# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()


##############################################################################
Expand All @@ -61,30 +63,185 @@
ot.plot.plot1D_mat(a, b, M, "Cost matrix M")


##############################################################################
# Solve Unbalanced OT with MM Unbalanced
# -----------------------------------

# %% MM Unbalanced

alpha = 1.0 # Unbalanced KL relaxation parameter

Gs = ot.unbalanced.mm_unbalanced(a, b, M / M.max(), alpha, verbose=False)

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source")
pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", Gs.sum())


##############################################################################
# Solve 1D UOT with Frank-Wolfe
# -----------------------------

# %% 1D UOT with FW


alpha = M.max() # Unbalanced KL relaxation parameter

a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(
x, x, alpha, u_weights=a, v_weights=b, p=2
)

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source")
pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", a_reweighted.sum())


##############################################################################
# Solve 1D UOT with Frank-Wolfe (backprop mode)
# -----------------------------


# %% 1D UOT with FW (backprop mode)


alpha = M.max() # Unbalanced KL relaxation parameter

a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(
torch.tensor(x, dtype=torch.float64),
torch.tensor(x, dtype=torch.float64),
alpha,
u_weights=torch.tensor(a, dtype=torch.float64),
v_weights=torch.tensor(b, dtype=torch.float64),
p=2,
mode="backprop",
)

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source")
pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", a_reweighted.sum())


##############################################################################
# Solve 1D USOT with Frank-Wolfe with UOT (TO CHECK)
# -----------------------------

# %% TEST USOT


alpha = M.max() # Unbalanced KL relaxation parameter

a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot(
torch.tensor(x.reshape((n, 1)), dtype=torch.float64),
torch.tensor(x.reshape((n, 1)), dtype=torch.float64),
alpha,
torch.tensor(a, dtype=torch.float64),
torch.tensor(b, dtype=torch.float64),
mode="backprop",
p=2,
)


# plot the transported mass
# -------------------------

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, a_reweighted.numpy(), "b", alpha=0.5, label="Transported source")
pl.fill(x, b_reweighted.numpy(), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", a_reweighted.sum())


##############################################################################
# Solve Unbalanced OT with cvxpy
# ------------------------------

# %% UOT with cvxpy


# (https://colab.research.google.com/github/gpeyre/ot4ml/blob/main/python/5-unbalanced.ipynb)

alpha = M.max() # Unbalanced KL relaxation parameter
n, m = a.shape[0], b.shape[0]

P = cp.Variable((n, m))

u = np.ones((n, 1))
v = np.ones((m, 1))
q = cp.sum(cp.kl_div(cp.matmul(P, v), a[:, None]))
r = cp.sum(cp.kl_div(cp.matmul(P.T, u), b[:, None]))

constr = [0 <= P]
# uncomment to perform balanced OT
# constr = [0 <= P, cp.matmul(P,u)==a[:,None], cp.matmul(P.T,v)==b[:,None]]

objective = cp.Minimize(cp.sum(cp.multiply(P, M)) + alpha * q + alpha * r)

prob = cp.Problem(objective, constr)
result = prob.solve()

G = P.value

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, G.sum(1), "b", alpha=0.5, label="Transported source")
pl.fill(x, G.sum(0), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", Gs.sum())


##############################################################################
# Solve Unbalanced Sinkhorn
# -------------------------

# %% Sinkhorn UOT

# Sinkhorn

epsilon = 0.1 # entropy parameter
alpha = 1.0 # Unbalanced KL relaxation parameter
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M / M.max(), epsilon, alpha, verbose=True)

pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn")

pl.show()


# %%
# plot the transported mass
# -------------------------

pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source")
pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
pl.show()

print("Mass of reweighted marginals:", Gs.sum())
Loading
Loading