diff --git a/README.md b/README.md index f8880a166..3e69af448 100644 --- a/README.md +++ b/README.md @@ -54,9 +54,10 @@ POT provides the following generic OT solvers: Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. -* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] +* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation [73] and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT.html) [82] * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html) [44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] @@ -367,7 +368,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. @@ -449,5 +450,4 @@ Artificial Intelligence. [81] Xu, H., Luo, D., & Carin, L. (2019). 
[Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).
-
-```
+[82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024). [Slicing Unbalanced Optimal Transport](https://openreview.net/forum?id=AjJTg5M0r8). Transactions on Machine Learning Research.
diff --git a/RELEASES.md b/RELEASES.md
index cdabf416d..da282b274 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -9,6 +9,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
+- Add 1D unbalanced OT solved with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
+- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)

#### Closed issues
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index ade4bbb0c..b7bcc420b 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -9,6 +9,7 @@
"""
# Author: Hicham Janati
+#         Clément Bonet
#
# License: MIT License
@@ -19,6 +20,8 @@
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss
+import torch
+import cvxpy as cp

##############################################################################
# Generate data
@@ -41,7 +44,6 @@

# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
-M /= M.max()


##############################################################################
@@ -61,26 +63,178 @@

ot.plot.plot1D_mat(a, b, M, "Cost matrix M")


+##############################################################################
+# Solve Unbalanced OT with MM Unbalanced
+# --------------------------------------
+
+# %% MM Unbalanced
+
+alpha = 1.0  # Unbalanced KL relaxation parameter
+
+Gs = ot.unbalanced.mm_unbalanced(a, b, M / M.max(), alpha, verbose=False)
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, "b", label="Source distribution")
+pl.plot(x, b, "r", label="Target distribution")
+pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source")
+pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
+pl.legend(loc="upper right")
+pl.title("Distributions and transported mass for UOT")
+pl.show()
+
+print("Mass of reweighted marginals:", Gs.sum())
+
+
+##############################################################################
+# Solve 1D UOT with Frank-Wolfe
+# -----------------------------
+
+# %% 1D UOT with FW
+
+
+alpha = M.max()  # Unbalanced KL relaxation parameter
+
+a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(
+    x, x, alpha, u_weights=a, v_weights=b, p=2
+)
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, "b", label="Source distribution")
+pl.plot(x, b, "r", label="Target distribution")
+pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source")
+pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target")
+pl.legend(loc="upper right")
+pl.title("Distributions and transported mass for UOT")
+pl.show()
+
+print("Mass of reweighted marginals:", a_reweighted.sum())
+
+
+##############################################################################
+# Solve 1D UOT with Frank-Wolfe (backprop mode)
+# ---------------------------------------------


+# %% 1D UOT with FW (backprop mode)
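+# In "backprop" mode, the dual potentials of the inner 1D OT problem are
+# recovered by differentiating ``wasserstein_1d`` with respect to the marginal
+# weights (see ``emd_1d_dual_backprop``), so this mode requires a
+# differentiable backend such as torch or jax.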
+
+
+alpha = M.max()  # Unbalanced KL relaxation parameter
+
+a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(
+    torch.tensor(x, dtype=torch.float64),
+    torch.tensor(x, dtype=torch.float64),
+    alpha,
+    u_weights=torch.tensor(a, dtype=torch.float64),
+    v_weights=torch.tensor(b, dtype=torch.float64),
+    p=2,
+    mode="backprop",
+)
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, "b", label="Source distribution")
+pl.plot(x, b, "r", label="Target distribution")
+pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source")
+pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target")
+pl.legend(loc="upper right")
+pl.title("Distributions and transported mass for UOT")
+pl.show()
+
+print("Mass of reweighted marginals:", a_reweighted.sum())
+
+
+##############################################################################
+# Solve 1D UOT with Unbalanced Sliced OT
+# --------------------------------------
+
+# %% USOT
+
+
+alpha = M.max()  # Unbalanced KL relaxation parameter
+
+a_reweighted, b_reweighted, loss = ot.unbalanced.unbalanced_sliced_ot(
+    torch.tensor(x.reshape((n, 1)), dtype=torch.float64),
+    torch.tensor(x.reshape((n, 1)), dtype=torch.float64),
+    alpha,
+    torch.tensor(a, dtype=torch.float64),
+    torch.tensor(b, dtype=torch.float64),
+    mode="backprop",
+    p=2,
+)
+
+
+# plot the transported mass
+# -------------------------
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, "b", label="Source distribution")
+pl.plot(x, b, "r", label="Target distribution")
+pl.fill(x, a_reweighted.numpy(), "b", alpha=0.5, label="Transported source")
+pl.fill(x, b_reweighted.numpy(), "r", alpha=0.5, label="Transported target")
+pl.legend(loc="upper right")
+pl.title("Distributions and transported mass for UOT")
+pl.show()
+
+print("Mass of reweighted marginals:", a_reweighted.sum())
+
+
+##############################################################################
+# Solve Unbalanced OT with cvxpy
+# ------------------------------
+
+# %% UOT with cvxpy
+
+
+# (https://colab.research.google.com/github/gpeyre/ot4ml/blob/main/python/5-unbalanced.ipynb)
+
+alpha = M.max()  # Unbalanced KL relaxation parameter
+n, m = a.shape[0], b.shape[0]
+
+P = cp.Variable((n, m))
+
+u = np.ones((n, 1))
+v = np.ones((m, 1))
+q = cp.sum(cp.kl_div(cp.matmul(P, v), a[:, None]))
+r = cp.sum(cp.kl_div(cp.matmul(P.T, u), b[:, None]))
+
+constr = [0 <= P]
+# uncomment to perform balanced OT
+# constr = [0 <= P, cp.matmul(P,u)==a[:,None], cp.matmul(P.T,v)==b[:,None]]
+
+objective = cp.Minimize(cp.sum(cp.multiply(P, M)) + alpha * q + alpha * r)
+
+prob = cp.Problem(objective, constr)
+result = prob.solve()
+
+G = P.value
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, "b", label="Source distribution")
+pl.plot(x, b, "r", label="Target distribution")
+pl.fill(x, G.sum(1), "b", alpha=0.5, label="Transported source")
+pl.fill(x, G.sum(0), "r", alpha=0.5, label="Transported target")
+pl.legend(loc="upper right")
+pl.title("Distributions and transported mass for UOT")
+pl.show()
+
+print("Mass of reweighted marginals:", G.sum())
+
+
##############################################################################
# Solve Unbalanced Sinkhorn
# -------------------------

+# %% Sinkhorn UOT
+
# Sinkhorn

epsilon = 0.1  # entropy parameter
alpha = 1.0  # Unbalanced KL relaxation parameter
-Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
+Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M / M.max(), epsilon, alpha, verbose=True)

pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, "UOT
matrix Sinkhorn")
-
pl.show()
-
-# %%
-# plot the transported mass
-# -------------------------
-
pl.figure(4, figsize=(6.4, 3))
pl.plot(x, a, "b", label="Source distribution")
pl.plot(x, b, "r", label="Target distribution")
@@ -88,3 +242,6 @@
pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
pl.legend(loc="upper right")
pl.title("Distributions and transported mass for UOT")
+pl.show()
+
+print("Mass of reweighted marginals:", Gs.sum())
diff --git a/examples/unbalanced-partial/plot_UOT_sliced.py b/examples/unbalanced-partial/plot_UOT_sliced.py
new file mode 100644
index 000000000..d5937a71d
--- /dev/null
+++ b/examples/unbalanced-partial/plot_UOT_sliced.py
@@ -0,0 +1,275 @@
+# -*- coding: utf-8 -*-
+"""
+===================================
+Sliced Unbalanced Optimal Transport
+===================================
+
+This example illustrates the behavior of Sliced UOT (SUOT) versus
+Unbalanced Sliced OT (USOT).
+
+The former reweights the marginals independently on each slice, while the
+latter removes outliers by reweighting the original marginals directly.
+"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import torch
+import matplotlib.pyplot as plt
+import matplotlib as mpl
+
+from sklearn.neighbors import KernelDensity
+
+##############################################################################
+# Generate data
+# -------------
+
+
+# %% parameters
+
+get_rot = lambda theta: np.array(
+    [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
+)
+
+
+# regular distribution of Gaussians around a circle
+def make_blobs_reg(n_samples, n_blobs, scale=0.5):
+    per_blob = int(n_samples / n_blobs)
+    result = np.random.randn(per_blob, 2) * scale + 5
+    theta = (2 * np.pi) / (n_blobs)
+    for r in range(1, n_blobs):
+        new_blob = (np.random.randn(per_blob, 2) * scale + 5).dot(get_rot(theta * r))
+        result = np.vstack((result, new_blob))
+    return result
+
+
+def make_blobs_random(n_samples, n_blobs, scale=0.5, offset=3):
+    per_blob = int(n_samples / n_blobs)
+    result = np.random.randn(per_blob, 2) * scale + np.random.randn(1, 2) * offset
+    for r in range(1, n_blobs):
+        new_blob = np.random.randn(per_blob, 2) * scale + np.random.randn(1, 2) * offset
+        result = np.vstack((result, new_blob))
+    return result
+
+
+def make_spiral(n_samples, noise=0.5):
+    n = np.sqrt(np.random.rand(n_samples, 1)) * 780 * (2 * np.pi) / 360
+    d1x = -np.cos(n) * n + np.random.rand(n_samples, 1) * noise
+    d1y = np.sin(n) * n + np.random.rand(n_samples, 1) * noise
+    return np.array(np.hstack((d1x, d1y)))
+
+
+n_samples = 500
+
+np.random.seed(42)
+
+nb_outliers = 200
+Xs = make_blobs_random(n_samples=n_samples, scale=0.2, n_blobs=1, offset=0) - 0.5
+Xs_outlier = make_blobs_random(
+    n_samples=nb_outliers, scale=0.05, n_blobs=1, offset=0
+) - [2, 0.5]
+
+Xs = np.vstack((Xs, Xs_outlier))
+Xt = make_blobs_random(n_samples=n_samples, scale=0.2, n_blobs=1, offset=0) + 1.5
+y = np.hstack(([0] * (n_samples + nb_outliers), [1] * n_samples))
+X = np.vstack((Xs, Xt))
+
+
+Xs_torch = torch.from_numpy(Xs).type(torch.float)
+Xt_torch = torch.from_numpy(Xt).type(torch.float)
+
+p = 2
+num_proj = 180
+
+a = torch.ones(Xs.shape[0], dtype=torch.float)
+b = torch.ones(Xt.shape[0], dtype=torch.float)
+
+# construct projections
+thetas = np.linspace(0, np.pi, num_proj)
+dir = np.array([(np.cos(theta), np.sin(theta)) for theta in thetas])
+dir_torch = torch.from_numpy(dir).type(torch.float)
+
+Xps = (Xs_torch @ dir_torch.T).T  # shape (n_projs, n)
+Xpt = (Xt_torch @ dir_torch.T).T
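+
+# As a point of comparison (a sketch, not part of the original example), the
+# balanced sliced Wasserstein distance can be computed on the same projections
+# with ``ot.sliced_wasserstein_distance``; being balanced, it has to transport
+# the outlier mass instead of discarding it:
+sw = ot.sliced_wasserstein_distance(
+    Xs_torch, Xt_torch, a / a.sum(), b / b.sum(), p=p, projections=dir_torch.T
+)
+print("Balanced sliced Wasserstein distance:", sw.item())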
+
+##############################################################################
+# Compute SUOT and USOT
+# ---------------------
+
+# %%
+
+rho1_SUOT = 1
+rho2_SUOT = 1
+_, log = ot.unbalanced.sliced_unbalanced_ot(
+    Xs_torch,
+    Xt_torch,
+    (rho1_SUOT, rho2_SUOT),
+    a,
+    b,
+    num_proj,
+    p,
+    numItermax=10,
+    projections=dir_torch.T,
+    mode="backprop",
+    log=True,
+)
+A_SUOT, B_SUOT = log["a_reweighted"].T, log["b_reweighted"].T
+
+
+rho1_USOT = 1
+rho2_USOT = 1
+A_USOT, B_USOT, _ = ot.unbalanced_sliced_ot(
+    Xs_torch,
+    Xt_torch,
+    (rho1_USOT, rho2_USOT),
+    a,
+    b,
+    num_proj,
+    p,
+    numItermax=10,
+    projections=dir_torch.T,
+    mode="backprop",
+)
+
+
+##############################################################################
+# Plotting utilities
+# ------------------
+
+# %%
+
+
+def kde_sklearn(x, x_grid, weights=None, bandwidth=0.2, **kwargs):
+    """Kernel Density Estimation with Scikit-learn"""
+    kde_skl = KernelDensity(bandwidth=bandwidth, **kwargs)
+    if weights is not None:
+        kde_skl.fit(x[:, np.newaxis], sample_weight=weights)
+    else:
+        kde_skl.fit(x[:, np.newaxis])
+    # score_samples() returns the log-likelihood of the samples
+    log_pdf = kde_skl.score_samples(x_grid[:, np.newaxis])
+    return np.exp(log_pdf)
+
+
+def plot_slices(
+    col, nb_slices, x_grid, Xps, Xpt, Xps_weights, Xpt_weights, method, rho1, rho2
+):
+    for i in range(nb_slices):
+        ax = plt.subplot2grid((nb_slices, 3), (i, col))
+        if len(Xps_weights.shape) > 1:  # SUOT: one reweighting per slice
+            weights_src = Xps_weights[i * offset_degree, :].cpu().numpy()
+            weights_tgt = Xpt_weights[i * offset_degree, :].cpu().numpy()
+        else:  # USOT: a single reweighting shared by all slices
+            weights_src = Xps_weights.cpu().numpy()
+            weights_tgt = Xpt_weights.cpu().numpy()
+
+        samples_src = Xps[i * offset_degree, :].cpu().numpy()
+        samples_tgt = Xpt[i * offset_degree, :].cpu().numpy()
+
+        pdf_source = kde_sklearn(samples_src, x_grid, weights=weights_src, bandwidth=bw)
+        pdf_target = kde_sklearn(samples_tgt, x_grid, weights=weights_tgt, bandwidth=bw)
+        pdf_source_without_w = kde_sklearn(samples_src, x_grid, bandwidth=bw)
+        pdf_target_without_w = kde_sklearn(samples_tgt, x_grid, bandwidth=bw)
+
+        ax.plot(x_grid, pdf_source, color=c2, alpha=0.8, lw=2)
+        ax.fill(x_grid, pdf_source_without_w, ec="grey", fc="grey", alpha=0.3)
+        ax.fill(x_grid, pdf_source, ec=c2, fc=c2, alpha=0.3)
+
+        ax.plot(x_grid, pdf_target, color=c1, alpha=0.8, lw=2)
+        ax.fill(x_grid, pdf_target_without_w, ec="grey", fc="grey", alpha=0.3)
+        ax.fill(x_grid, pdf_target, ec=c1, fc=c1, alpha=0.3)
+
+        ax.set_xlim(xlim_min, xlim_max)
+
+        if col == 1:
+            ax.set_ylabel(
+                r"$\theta=${}$^o$".format(i * offset_degree),
+                color=colors[i],
+                fontsize=13,
+            )
+
+        ax.set_yticks([])
+        ax.set_xticks([])
+
+    ax.set_xlabel(
+        r"{} $\rho_1={}$ $\rho_2={}$".format(method, rho1, rho2), fontsize=13
+    )
+
+
+##############################################################################
+# Plot reweighted distributions on several slices
+# -----------------------------------------------
+# We plot the reweighted distributions on several slices. For SUOT, the
+# outlier mode is kept on some slices (e.g. for :math:`\theta=120°`), while
+# USOT gets rid of the outlier mode on every slice.
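+
+# %%
+# As a quick sanity check (a sketch based on the shapes consumed by the
+# ``plot_slices`` helper), SUOT returns one reweighting per slice while USOT
+# returns a single reweighting of the original marginals:
+
+print("SUOT reweightings shape:", tuple(A_SUOT.shape))  # one row per slice
+print("USOT reweighting shape:", tuple(A_USOT.shape))  # shared across slices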
+ +# %% + +c1 = np.array(mpl.colors.to_rgb("red")) +c2 = np.array(mpl.colors.to_rgb("blue")) + +# define plotting grid +xlim_min = -3 +xlim_max = 3 +x_grid = np.linspace(xlim_min, xlim_max, 200) +bw = 0.05 + +# visu parameters +nb_slices = 3 # 4 +offset_degree = int(180 / nb_slices) + +delta_degree = np.pi / nb_slices +colors = plt.cm.Reds(np.linspace(0.3, 1, nb_slices)) + +X1 = np.array([-4, 0]) +X2 = np.array([4, 0]) + + +fig = plt.figure(figsize=(9, 3)) + +ax1 = plt.subplot2grid((nb_slices, 3), (0, 0), rowspan=nb_slices) + +for i in range(nb_slices): + R = get_rot(delta_degree * (-i)) + X1_r = X1.dot(R) + X2_r = X2.dot(R) + if i == 0: + ax1.plot( + [X1_r[0], X2_r[0]], + [X1_r[1], X2_r[1]], + color=colors[i], + alpha=0.8, + zorder=0, + label="Directions", + ) + else: + ax1.plot( + [X1_r[0], X2_r[0]], [X1_r[1], X2_r[1]], color=colors[i], alpha=0.8, zorder=0 + ) + +ax1.scatter(Xs[:, 0], Xs[:, 1], zorder=1, color=c2, label="Source data") +ax1.scatter(Xt[:, 0], Xt[:, 1], zorder=1, color=c1, label="Target data") +ax1.set_xlim([-3, 3]) +ax1.set_ylim([-3, 3]) +ax1.set_yticks([]) +ax1.set_xticks([]) +# ax1.legend(loc='best',fontsize=13) +ax1.set_xlabel("Original distributions", fontsize=13) + + +fig.subplots_adjust(hspace=0) +fig.subplots_adjust(wspace=0.15) + +plot_slices( + 1, nb_slices, x_grid, Xps, Xpt, A_SUOT, B_SUOT, "SUOT", rho1_SUOT, rho2_SUOT +) +plot_slices( + 2, nb_slices, x_grid, Xps, Xpt, A_USOT, B_USOT, "USOT", rho1_USOT, rho2_USOT +) + +plt.show() diff --git a/ignore-words.txt b/ignore-words.txt index 00c1f5edb..573400137 100644 --- a/ignore-words.txt +++ b/ignore-words.txt @@ -6,4 +6,5 @@ wass ccompiler ist lik -ges \ No newline at end of file +ges +mapp diff --git a/ot/__init__.py b/ot/__init__.py index 26f428aa1..ffb073285 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -44,6 +44,8 @@ emd2_lazy, emd_1d, emd2_1d, + emd_1d_dual, + emd_1d_dual_backprop, wasserstein_1d, binary_search_circle, wasserstein_circle, @@ -51,7 +53,12 @@ linear_circular_ot, ) from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 +from .unbalanced import ( + sinkhorn_unbalanced, + barycenter_unbalanced, + sinkhorn_unbalanced2, + unbalanced_sliced_ot, +) from .da import sinkhorn_lpl1_mm from .sliced import ( sliced_wasserstein_distance, @@ -96,6 +103,8 @@ "toq", "gromov", "emd2_1d", + "emd_1d_dual", + "emd_1d_dual_backprop", "wasserstein_1d", "backend", "gaussian", @@ -110,6 +119,7 @@ "sinkhorn_unbalanced2", "sliced_wasserstein_distance", "sliced_wasserstein_sphere", + "unbalanced_sliced_ot", "linear_sliced_wasserstein_sphere", "gromov_wasserstein", "gromov_wasserstein2", diff --git a/ot/backend.py b/ot/backend.py index 6b03f5cd1..0a4b20953 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -588,7 +588,7 @@ def flip(self, a, axis=None): """ raise NotImplementedError() - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): """ Limits the values in a tensor. @@ -1145,6 +1145,22 @@ def slogdet(self, a): """ raise NotImplementedError() + def index_select(self, input, axis, index): + r""" + Returns a new tensor which indexes the input tensor along dimension dim using the entries in index. + + See: https://docs.pytorch.org/docs/stable/generated/torch.index_select.html + """ + raise NotImplementedError() + + def nonzero(self, input, as_tuple=False): + r""" + Returns a tensor containing the indices of all non-zero elements of input. 
+ + See: https://docs.pytorch.org/docs/stable/generated/torch.nonzero.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1286,7 +1302,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return np.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return np.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -1528,6 +1544,16 @@ def det(self, a): def slogdet(self, a): return np.linalg.slogdet(a) + def index_select(self, input, axis, index): + return np.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return np.nonzero(input) + else: + L_tuple = np.nonzero(input) + return np.concatenate([t[None] for t in L_tuple], axis=0).T + _register_backend_implementation(NumpyBackend) @@ -1703,7 +1729,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return jnp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return jnp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -1946,6 +1972,16 @@ def det(self, x): def slogdet(self, a): return jnp.linalg.slogdet(a) + def index_select(self, input, axis, index): + return jnp.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return jnp.nonzero(input) + else: + L_tuple = jnp.nonzero(input) + return jnp.concatenate([t[None] for t in L_tuple], axis=0).T + if jax: # Only register jax backend if it is installed @@ -2200,7 +2236,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return torch.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return torch.clamp(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -2540,6 +2576,12 @@ def det(self, x): def slogdet(self, a): return torch.linalg.slogdet(a) + def index_select(self, input, axis, index): + return torch.index_select(input, axis, index) + + def nonzero(self, input, as_tuple=False): + return torch.nonzero(input, as_tuple=as_tuple) + if torch: # Only register torch backend if it is installed @@ -2701,7 +2743,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return cp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return cp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -2963,6 +3005,16 @@ def det(self, x): def slogdet(self, a): return cp.linalg.slogdet(a) + def index_select(self, input, axis, index): + return cp.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return cp.nonzero(input) + else: + L_tuple = cp.nonzero(input) + return cp.concatenate([t[None] for t in L_tuple], axis=0).T + if cp: # Only register cp backend if it is installed @@ -3135,7 +3187,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return tnp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return tnp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -3423,6 +3475,16 @@ def det(self, x): def slogdet(self, a): return tf.linalg.slogdet(a) + def index_select(self, input, axis, index): + return tf.gather(input, index, axis=axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return tf.where(input) + else: + indices = tf.where(input) + return tf.reshape(indices, (-1, indices.shape[-1])) + if tf: # Only register tensorflow backend if it is installed diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 8e88d63c8..0d8a640e4 100644 --- a/ot/lp/__init__.py +++ 
b/ot/lp/__init__.py @@ -26,6 +26,11 @@ emd_1d, emd2_1d, wasserstein_1d, + emd_1d_dual, + emd_1d_dual_backprop, +) + +from .solver_circle import ( binary_search_circle, wasserstein_circle, semidiscrete_wasserstein2_unif_circle, @@ -43,6 +48,8 @@ "emd_1d", "emd2_1d", "wasserstein_1d", + "emd_1d_dual", + "emd_1d_dual_backprop", "generalized_free_support_barycenter", "binary_search_circle", "wasserstein_circle", diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index ec06298bc..1f376b707 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -44,31 +44,46 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): Parameters ---------- - alpha0 : (ns,) numpy.ndarray, float64 + alpha0 : (ns, ...) numpy.ndarray, float64 Source dual potential - beta0 : (nt,) numpy.ndarray, float64 + beta0 : (nt, ...) numpy.ndarray, float64 Target dual potential - a : (ns,) numpy.ndarray, float64 + a : (ns, ...) numpy.ndarray, float64 Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt, ...) numpy.ndarray, float64 Target histogram (uniform weight if empty list) Returns ------- - alpha : (ns,) numpy.ndarray, float64 + alpha : (ns, ...) numpy.ndarray, float64 Source centered dual potential - beta : (nt,) numpy.ndarray, float64 + beta : (nt, ...) numpy.ndarray, float64 Target centered dual potential """ + if a is not None and b is not None: + nx = get_backend(alpha0, beta0, a, b) + else: + nx = get_backend(alpha0, beta0) + + n = alpha0.shape[0] + m = beta0.shape[0] + # if no weights are provided, use uniform if a is None: - a = np.ones(alpha0.shape[0]) / alpha0.shape[0] + a = nx.full(alpha0.shape, 1.0 / n, type_as=alpha0) + elif a.ndim != alpha0.ndim: + a = nx.repeat(a[..., None], alpha0.shape[-1], -1) + if b is None: - b = np.ones(beta0.shape[0]) / beta0.shape[0] + b = nx.full(beta0.shape, 1.0 / m, type_as=beta0) + elif b.ndim != beta0.ndim: + b = nx.repeat(b[..., None], beta0.shape[-1], -1) # compute constant that balances the weighted sums of the duals - c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) + ips = nx.sum(b * beta0, axis=0) - nx.sum(a * alpha0, axis=0) + denom = nx.sum(a, axis=0) + nx.sum(b, axis=0) + c = ips / denom # update duals alpha = alpha0 + c diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 49e0c9c41..4326b525d 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -5,6 +5,7 @@ # Author: Remi Flamary # Author: Nicolas Courty +# Author: Clément Bonet # # License: MIT License @@ -14,9 +15,10 @@ from .emd_wrap import emd_1d_sorted from ..backend import get_backend from ..utils import list_to_array +from ._network_simplex import center_ot_dual -def quantile_function(qs, cws, xs): +def quantile_function(qs, cws, xs, return_index=False): r"""Computes the quantile function of an empirical distribution Parameters @@ -27,6 +29,7 @@ def quantile_function(qs, cws, xs): cumulative weights of the 1D empirical distribution, if batched, must be similar to xs xs: array-like, shape (n, ...) 
locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions + return_index: bool Returns ------- @@ -43,8 +46,14 @@ def quantile_function(qs, cws, xs): else: cws = cws.T qs = qs.T - idx = nx.searchsorted(cws, qs).T - return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + # idx = nx.searchsorted(cws, qs).T + # return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + + idx = nx.clip(nx.searchsorted(cws, qs).T, 0, n - 1) + if return_index: + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0), idx + else: + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) def wasserstein_1d( @@ -399,316 +408,48 @@ def emd2_1d( return cost -def roll_cols(M, shifts): +def emd_1d_dual( + u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True +): r""" - Utils functions which allow to shift the order of each row of a 2d matrix - - Parameters - ---------- - M : ndarray, shape (nr, nc) - Matrix to shift - shifts: int or ndarray, shape (nr,) - - Returns - ------- - Shifted array - - Examples - -------- - >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]]) - >>> roll_cols(M, 2) - array([[2, 3, 1], - [5, 6, 4], - [8, 9, 7]]) - >>> roll_cols(M, np.array([[1],[2],[1]])) - array([[3, 1, 2], - [5, 6, 4], - [9, 7, 8]]) - - References - ---------- - https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch - """ - nx = get_backend(M) - - n_rows, n_cols = M.shape - - arange1 = nx.tile( - nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1) - ) - arange2 = (arange1 - shifts) % n_cols - - return nx.take_along_axis(M, arange2, 1) - - -def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): - r"""Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) - - Parameters - ---------- - theta: array-like, shape (n_batch, n) - Cuts on the circle - u_values: array-like, shape (n_batch, n) - locations of the first empirical distribution - v_values: array-like, shape (n_batch, n) - locations of the second empirical distribution - u_cdf: array-like, shape (n_batch, n) - cdf of the first empirical distribution - v_cdf: array-like, shape (n_batch, n) - cdf of the second empirical distribution - p: float, optional = 2 - Power p used for computing the Wasserstein distance - - Returns - ------- - dCp: array-like, shape (n_batch, 1) - The batched right derivative - dCm: array-like, shape (n_batch, 1) - The batched left derivative - - References - --------- - .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. 
- """ - nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) - - v_values = nx.copy(v_values) - - n = u_values.shape[-1] - m_batch, m = v_values.shape - - v_cdf_theta = v_cdf - (theta - nx.floor(theta)) - - mask_p = v_cdf_theta >= 0 - mask_n = v_cdf_theta < 0 - - v_values[mask_n] += nx.floor(theta)[mask_n] + 1 - v_values[mask_p] += nx.floor(theta)[mask_p] - - if nx.any(mask_n) and nx.any(mask_p): - v_cdf_theta[mask_n] += 1 - - v_cdf_theta2 = nx.copy(v_cdf_theta) - v_cdf_theta2[mask_n] = np.inf - shift = -nx.argmin(v_cdf_theta2, axis=-1) - - v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) - v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) - v_values = nx.concatenate( - [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 - ) - - if nx.__name__ == "torch": - # this is to ensure the best performance for torch searchsorted - # and avoid a warning related to non-contiguous arrays - u_cdf = u_cdf.contiguous() - v_cdf_theta = v_cdf_theta.contiguous() - - # quantiles of F_u evaluated in F_v^\theta - u_index = nx.searchsorted(u_cdf, v_cdf_theta) - u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1) - - # Deal with 1 - u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) - u_valuesm = nx.concatenate( - [u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1 - ) - - if nx.__name__ == "torch": - # this is to ensure the best performance for torch searchsorted - # and avoid a warning related to non-contiguous arrays - u_cdfm = u_cdfm.contiguous() - v_cdf_theta = v_cdf_theta.contiguous() - - u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") - u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) - - dCp = nx.sum( - nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) - - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), - axis=-1, - ) + Computes the 1 dimensional OT loss between two (batched) empirical + distributions - dCm = nx.sum( - nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) - - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), - axis=-1, - ) + .. math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq - return dCp.reshape(-1, 1), dCm.reshape(-1, 1) + and returns the dual potentials and the loss, i.e. such that + .. math: + OT_{loss}(u,v) = \int f(x)\mathrm{d}u(x) + \int g(y)\mathrm{d}v(y). -def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): - r"""Computes the the cost (Equation (6.2) of [1]) + We do so by solving the dual problem using a parallel North-West corner rule. Parameters ---------- - theta: array-like, shape (n_batch, n) - Cuts on the circle - u_values: array-like, shape (n_batch, n) + u_values: array-like, shape (n, ...) locations of the first empirical distribution - v_values: array-like, shape (n_batch, n) + v_values: array-like, shape (m, ...) locations of the second empirical distribution - u_cdf: array-like, shape (n_batch, n) - cdf of the first empirical distribution - v_cdf: array-like, shape (n_batch, n) - cdf of the second empirical distribution - p: float, optional = 2 - Power p used for computing the Wasserstein distance - - Returns - ------- - ot_cost: array-like, shape (n_batch,) - OT cost evaluated at theta - - References - --------- - .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. 
- """ - nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) - - v_values = nx.copy(v_values) - - m_batch, m = v_values.shape - n_batch, n = u_values.shape - - v_cdf_theta = v_cdf - (theta - nx.floor(theta)) - - mask_p = v_cdf_theta >= 0 - mask_n = v_cdf_theta < 0 - - v_values[mask_n] += nx.floor(theta)[mask_n] + 1 - v_values[mask_p] += nx.floor(theta)[mask_p] - - if nx.any(mask_n) and nx.any(mask_p): - v_cdf_theta[mask_n] += 1 - - # Put negative values at the end - v_cdf_theta2 = nx.copy(v_cdf_theta) - v_cdf_theta2[mask_n] = np.inf - shift = -nx.argmin(v_cdf_theta2, axis=-1) - - v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) - v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) - v_values = nx.concatenate( - [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 - ) - - # Compute absciss - cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) - cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)]) - - delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] - - if nx.__name__ == "torch": - # this is to ensure the best performance for torch searchsorted - # and avoid a warning related to non-contiguous arrays - u_cdf = u_cdf.contiguous() - v_cdf_theta = v_cdf_theta.contiguous() - cdf_axis = cdf_axis.contiguous() - - # Compute icdf - u_index = nx.searchsorted(u_cdf, cdf_axis) - u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) - - v_values = nx.concatenate( - [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 - ) - v_index = nx.searchsorted(v_cdf_theta, cdf_axis) - v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) - - if p == 1: - ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1) - else: - ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1) - - return ot_cost - - -def binary_search_circle( - u_values, - v_values, - u_weights=None, - v_weights=None, - p=1, - Lm=10, - Lp=10, - tm=-1, - tp=1, - eps=1e-6, - require_sort=True, - log=False, -): - r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates - using e.g. the atan2 function. - - .. math:: - W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q - - where: - - - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` - - For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with - - .. math:: - u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} - - using e.g. ot.utils.get_coordinate_circle(x) - - The function runs on backend but tensorflow and jax are not supported. - - Parameters - ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...) 
- samples in the target domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - p : float, optional (default=1) - Power p used for computing the Wasserstein distance - Lm : int, optional - Lower bound dC - Lp : int, optional - Upper bound dC - tm: float, optional - Lower bound theta - tp: float, optional - Upper bound theta - eps: float, optional - Stopping condition + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1, default is 1 require_sort: bool, optional - If True, sort the values. - log: bool, optional - If True, returns also the optimal theta + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True Returns ------- + f: array-like shape (n, ...) + First dual potential + g: array-like shape (m, ...) + Second dual potential loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation - log: dict, optional - log dictionary returned only if log==True in parameters - - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> binary_search_circle(u.T, v.T, p=1) - array([0.1]) - - References - ---------- - .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. - .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + the batched EMD """ - assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) - if u_weights is not None and v_weights is not None: nx = get_backend(u_values, v_values, u_weights, v_weights) else: @@ -717,30 +458,18 @@ def binary_search_circle( n = u_values.shape[0] m = v_values.shape[0] - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - if len(v_values.shape) == 1: - v_values = nx.reshape(v_values, (m, 1)) - - if u_values.shape[1] != v_values.shape[1]: - raise ValueError( - "u and v must have the same number of batches {} and {} respectively given".format( - u_values.shape[1], v_values.shape[1] - ) - ) - - u_values = u_values % 1 - v_values = v_values % 1 - + # Init weights or broadcast if necessary if u_weights is None: u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + # Sort w.r.t. 
support if not already done if require_sort: u_sorter = nx.argsort(u_values, 0) u_values = nx.take_along_axis(u_values, u_sorter, 0) @@ -751,496 +480,177 @@ def binary_search_circle( u_weights = nx.take_along_axis(u_weights, u_sorter, 0) v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - u_cdf = nx.cumsum(u_weights, 0).T - v_cdf = nx.cumsum(v_weights, 0).T - - u_values = u_values.T - v_values = v_values.T - - L = max(Lm, Lp) - - tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) - tm = nx.tile(tm, (1, m)) - tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) - tp = nx.tile(tp, (1, m)) - tc = (tm + tp) / 2 - - done = nx.zeros((u_values.shape[0], m)) - - cpt = 0 - while nx.any(1 - done): - cpt += 1 - - dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) - done = ((dCp * dCm) <= 0) * 1 - - mask = ((tp - tm) < eps / L) * (1 - done) - - if nx.any(mask): - # can probably be improved by computing only relevant values - dCptp, dCmtp = derivative_cost_on_circle( - tp, u_values, v_values, u_cdf, v_cdf, p - ) - dCptm, dCmtm = derivative_cost_on_circle( - tm, u_values, v_values, u_cdf, v_cdf, p - ) - Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape( - -1, 1 - ) - Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape( - -1, 1 - ) - - # Avoid warning raised when dCptm - dCmtp == 0, for which - # tc is not updated as mask_end is False, - # see Issue #738 - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) - tc[mask_end > 0] = ( - (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) - )[mask_end > 0] - done[nx.prod(mask, axis=-1) > 0] = 1 - elif nx.any(1 - done): - tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] - tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] - tc[((1 - mask) * (1 - done)) > 0] = ( - tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0] - ) / 2 - - w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p) + # eps trick to have strictly increasing cdf and avoid zero mass issues + eps = 1e-12 + u_cdf = nx.cumsum(u_weights + eps, 0) - eps + v_cdf = nx.cumsum(v_weights + eps, 0) - eps - if log: - return w, {"optimal_theta": tc[:, 0]} - return w + cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf), 0), 0) + u_icdf, u_index = quantile_function(cdf_axis, u_cdf, u_values, return_index=True) + v_icdf, v_index = quantile_function(cdf_axis, v_cdf, v_values, return_index=True) -def wasserstein1_circle( - u_values, v_values, u_weights=None, v_weights=None, require_sort=True -): - r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates - using e.g. the atan2 function. - The function runs on backend but tensorflow and jax are not supported. + diff_dist = nx.power(nx.abs(u_icdf - v_icdf), p) + cdf_axis = nx.zero_pad( + cdf_axis, pad_width=[(1, 0)] + (cdf_axis.ndim - 1) * [(0, 0)] + ) - .. math:: - W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + # parallel North-West corner rule + mask_u = u_index[1:, ...] - u_index[:-1, ...] + mask_u = nx.zero_pad(mask_u, pad_width=[(1, 0)] + (mask_u.ndim - 1) * [(0, 0)]) + mask_v = v_index[1:, ...] - v_index[:-1, ...] 
+ mask_v = nx.zero_pad( + mask_v, pad_width=[(1, 0)] + (mask_v.ndim - 1) * [(0, 0)], value=1 + ) - Parameters - ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...) - samples in the target domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - require_sort: bool, optional - If True, sort the values. + c1 = nx.where((mask_u[:-1, ...] + mask_u[1:, ...]) > 1, -1, 0) + c1 = nx.cumsum(c1 * diff_dist[:-1, ...], axis=0) + c1 = nx.zero_pad(c1, pad_width=[(1, 0)] + (c1.ndim - 1) * [(0, 0)]) - Returns - ------- - loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation + c2 = nx.where((mask_v[:-1, ...] + mask_v[1:, ...]) > 1, -1, 0) + c2 = nx.cumsum(c2 * diff_dist[:-1, ...], axis=0) + c2 = nx.zero_pad(c2, pad_width=[(1, 0)] + (c2.ndim - 1) * [(0, 0)]) - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> wasserstein1_circle(u.T, v.T) - array([0.1]) + masked_u_dist = mask_u * diff_dist + masked_v_dist = mask_v * diff_dist - References - ---------- - .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. - .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ - """ - - if u_weights is not None and v_weights is not None: - nx = get_backend(u_values, v_values, u_weights, v_weights) - else: - nx = get_backend(u_values, v_values) + T = nx.cumsum(masked_u_dist - masked_v_dist, axis=0) + c1 - c2 - n = u_values.shape[0] - m = v_values.shape[0] + tmp = nx.copy(mask_u > 0) # avoid in-place problem + tmp[0, ...] = 1 + # f = nx.reshape(T[tmp], u_values.shape) # work only with one axis + f = nx.reshape( + nx.index_select( + nx.reshape(T.T, (-1,)), + 0, + # nx.reshape(tmp.T, (-1,)).nonzero().squeeze() + nx.nonzero(nx.reshape(tmp.T, (-1,))).squeeze(), + ), + u_values.T.shape, + ).T + f[0, ...] 
= 0 - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - if len(v_values.shape) == 1: - v_values = nx.reshape(v_values, (m, 1)) + # Complementary slackness + C = nx.power(nx.abs(u_values[:, None] - v_values[None]), p) - f[:, None] + g = nx.min(C, axis=0) - if u_values.shape[1] != v_values.shape[1]: - raise ValueError( - "u and v must have the same number of batchs {} and {} respectively given".format( - u_values.shape[1], v_values.shape[1] - ) - ) - - u_values = u_values % 1 - v_values = v_values % 1 - - if u_weights is None: - u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) - elif u_weights.ndim != u_values.ndim: - u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) - if v_weights is None: - v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) - elif v_weights.ndim != v_values.ndim: - v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + loss = nx.sum(f * u_weights, axis=0) + nx.sum(g * v_weights, axis=0) + # unsort potentials if require_sort: - u_sorter = nx.argsort(u_values, 0) - u_values = nx.take_along_axis(u_values, u_sorter, 0) - - v_sorter = nx.argsort(v_values, 0) - v_values = nx.take_along_axis(v_values, v_sorter, 0) - - u_weights = nx.take_along_axis(u_weights, u_sorter, 0) - v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - - # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ - values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) - - cdf_diff = nx.cumsum( - nx.take_along_axis( - nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0 - ), - 0, - ) - cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) - - values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) - delta = values_sorted[1:, ...] - values_sorted[:-1, ...] - weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) + u_rev_sorter = nx.argsort(u_sorter, 0) + f = nx.take_along_axis(f, u_rev_sorter, 0) - sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5 - sum_weights[sum_weights < 0] = np.inf - inds = nx.argmin(sum_weights, axis=0) + v_rev_sorter = nx.argsort(v_sorter, 0) + g = nx.take_along_axis(g, v_rev_sorter, 0) - levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0) + f, g = center_ot_dual(f, g, u_weights, v_weights) - return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) + return f, g, loss -def wasserstein_circle( - u_values, - v_values, - u_weights=None, - v_weights=None, - p=1, - Lm=10, - Lp=10, - tm=-1, - tp=1, - eps=1e-6, - require_sort=True, +def emd_1d_dual_backprop( + u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True ): - r"""Computes the Wasserstein distance on the circle using either :ref:`[45] ` for p=1 or - the binary search algorithm proposed in :ref:`[44] ` otherwise. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates - using e.g. the atan2 function. - - General loss returned: - - .. math:: - OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q - - For p=1, [45] - - .. math:: - W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + r""" + Computes the 1 dimensional OT loss between two (batched) empirical + distributions - For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + .. 
math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq - .. math:: - u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + and returns the dual potentials and the loss, i.e. such that - using e.g. ot.utils.get_coordinate_circle(x) + .. math: + OT_{loss}(u,v) = \int f(x)\mathrm{d}u(x) + \int g(y)\mathrm{d}v(y). - The function runs on backend but tensorflow and jax are not supported. + We do so by backpropagating through the `wasserstein_1d` function. Thus, the function + only works in torch and jax. Parameters ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...) - samples in the target domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - p : float, optional (default=1) - Power p used for computing the Wasserstein distance - Lm : int, optional - Lower bound dC. For p>1. - Lp : int, optional - Upper bound dC. For p>1. - tm: float, optional - Lower bound theta. For p>1. - tp: float, optional - Upper bound theta. For p>1. - eps: float, optional - Stopping condition. For p>1. + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1, default is 1 require_sort: bool, optional - If True, sort the values. - - Returns - ------- - loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation - - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> wasserstein_circle(u.T, v.T) - array([0.1]) - - - .. _references-wasserstein-circle: - References - ---------- - .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. - .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. - """ - assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) - - return binary_search_circle( - u_values, - v_values, - u_weights, - v_weights, - p=p, - Lm=Lm, - Lp=Lp, - tm=tm, - tp=tp, - eps=eps, - require_sort=require_sort, - ) - - -def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): - r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates - using e.g. the atan2 function. - - .. 
math:: - W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} - - where: - - - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` - - For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with - - .. math:: - u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, - - using e.g. ot.utils.get_coordinate_circle(x). - - Parameters - ---------- - u_values : ndarray, shape (n, ...) - Samples - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True Returns ------- + f: array-like shape (n, ...) + First dual potential + g: array-like shape (m, ...) + Second dual potential loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation - - Examples - -------- - >>> x0 = np.array([[0], [0.2], [0.4]]) - >>> semidiscrete_wasserstein2_unif_circle(x0) - array([0.02111111]) - - References - ---------- - .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. - """ - - if u_weights is not None: - nx = get_backend(u_values, u_weights) - else: - nx = get_backend(u_values) - - n = u_values.shape[0] - - u_values = u_values % 1 - - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - - if u_weights is None: - u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) - elif u_weights.ndim != u_values.ndim: - u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) - - u_values = nx.sort(u_values, 0) - u_cdf = nx.cumsum(u_weights, 0) - u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) - - cpt1 = nx.sum(u_weights * u_values**2, axis=0) - u_mean = nx.sum(u_weights * u_values, axis=0) - - ns = 1 - u_weights - 2 * u_cdf[:-1] - cpt2 = nx.sum(u_values * u_weights * ns, axis=0) - - return cpt1 - u_mean**2 + cpt2 + 1 / 12 - - -def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True): - r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference - :math:`\eta=\mathrm{Unif}(S^1)` evaluated in :math:`x`. - - For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[78] `) - - .. math`` - \hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12) - x. - - Parameters - ---------- - x : ndary, shape (m,) - Points in [0,1[ where to evaluate the embedding - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - - Returns - ------- - embedding: ndarray of shape (m, ...) - Embedding evaluated at :math:`x` - - .. _references-lcot: - References - ---------- - .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. 
+ the batched EMD """ - if u_weights is not None: - nx = get_backend(u_values, u_weights) + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) else: - nx = get_backend(u_values) - - n = u_values.shape[0] - u_values = u_values % 1 - - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - - if u_weights is None: - u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) - elif u_weights.ndim != u_values.ndim: - u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) - - if require_sort: - u_sorter = nx.argsort(u_values, 0) - u_values = nx.take_along_axis(u_values, u_sorter, 0) - u_weights = nx.take_along_axis(u_weights, u_sorter, 0) - - u_cdf = nx.cumsum(u_weights, 0) - u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) - - q_s = ( - x[:, None] - nx.sum(u_values * u_weights, axis=0)[None] + 0.5 - ) # shape (m, ...) - - u_quantiles = quantile_function(q_s % 1, u_cdf, u_values) - - return (u_quantiles - x[:, None]) % 1 - - -def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None): - r"""Computes the Linear Circular Optimal Transport distance from :ref:`[78] ` using :math:`\eta=\mathrm{Unif}(S^1)` - as reference measure. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates - using e.g. the atan2 function. - - General loss returned: - - .. math:: - \mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t - - where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`, - and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`. - - Parameters - ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...), optional - samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - - Returns - ------- - loss: float/array-like, shape (...) - Batched cost associated to the linear optimal transportation - - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> linear_circular_ot(u.T, v.T) - array([0.0127]) - + nx = get_backend(u_values, v_values) - .. _references-lcot: - References - ---------- - .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. 
- """ - if u_weights is not None: - nx = get_backend(u_values, u_weights) - else: - nx = get_backend(u_values) + assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax" n = u_values.shape[0] - u_values = u_values % 1 - - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) + m = v_values.shape[0] + # Init weights or broadcast if necessary if u_weights is None: u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) - unif_s1 = nx.linspace(0, 1, 101, type_as=u_values)[:-1] + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) - emb_u = linear_circular_embedding(unif_s1, u_values, u_weights) + if nx.__name__ == "torch": + u_weights_diff = nx.copy(u_weights) + v_weights_diff = nx.copy(v_weights) + + u_weights_diff.requires_grad_(True) + v_weights_diff.requires_grad_(True) + + cost_output = wasserstein_1d( + u_values, + v_values, + u_weights_diff, + v_weights_diff, + p=p, + require_sort=require_sort, + ) + loss = cost_output.sum() + loss.backward() + + f, g = center_ot_dual( + u_weights_diff.grad.detach(), + v_weights_diff.grad.detach(), + u_weights, + v_weights, + ) - if v_values is None: - dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u)) - return nx.mean(dist_u**2, axis=0) - else: - m = v_values.shape[0] - if len(v_values.shape) == 1: - v_values = nx.reshape(v_values, (m, 1)) + return f, g, cost_output.detach() # value can not be backward anymore + elif nx.__name__ == "jax": + import jax - if u_values.shape[1] != v_values.shape[1]: - raise ValueError( - "u and v must have the same number of batchs {} and {} respectively given".format( - u_values.shape[1], v_values.shape[1] - ) - ) + def ot_1d(a, b): + return wasserstein_1d( + u_values, v_values, a, b, p=p, require_sort=require_sort + ).sum() - emb_v = linear_circular_embedding(unif_s1, v_values, v_weights) + f, g = jax.grad(ot_1d, argnums=[0, 1])(u_weights, v_weights) + cost_output = wasserstein_1d( + u_values, v_values, u_weights, v_weights, p=p, require_sort=require_sort + ) - dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v)) - return nx.mean(dist_uv**2, axis=0) + f, g = center_ot_dual(f, g, u_weights, v_weights) + return f, g, cost_output diff --git a/ot/lp/solver_circle.py b/ot/lp/solver_circle.py new file mode 100644 index 000000000..8fcdef49e --- /dev/null +++ b/ot/lp/solver_circle.py @@ -0,0 +1,861 @@ +# -*- coding: utf-8 -*- +""" +Exact solvers for the 1D Wasserstein distance using cvxopt +""" + +# Author: Clément Bonet +# +# License: MIT License + +import numpy as np +import warnings + +from ..backend import get_backend +from .solver_1d import quantile_function + + +def roll_cols(M, shifts): + r""" + Utils functions which allow to shift the order of each row of a 2d matrix + + Parameters + ---------- + M : ndarray, shape (nr, nc) + Matrix to shift + shifts: int or ndarray, shape (nr,) + + Returns + ------- + Shifted array + + Examples + -------- + >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]]) + >>> roll_cols(M, 2) + array([[2, 3, 1], + [5, 6, 4], + [8, 9, 7]]) + >>> roll_cols(M, np.array([[1],[2],[1]])) + array([[3, 1, 2], + [5, 6, 4], + [9, 7, 8]]) + + References + ---------- + https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch + """ + nx = 
+    nx = get_backend(M)
+
+    n_rows, n_cols = M.shape
+
+    arange1 = nx.tile(
+        nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1)
+    )
+    arange2 = (arange1 - shifts) % n_cols
+
+    return nx.take_along_axis(M, arange2, 1)
+
+
+def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
+    r"""Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [44])
+
+    Parameters
+    ----------
+    theta: array-like, shape (n_batch, n)
+        Cuts on the circle
+    u_values: array-like, shape (n_batch, n)
+        locations of the first empirical distribution
+    v_values: array-like, shape (n_batch, n)
+        locations of the second empirical distribution
+    u_cdf: array-like, shape (n_batch, n)
+        cdf of the first empirical distribution
+    v_cdf: array-like, shape (n_batch, n)
+        cdf of the second empirical distribution
+    p: float, optional = 2
+        Power p used for computing the Wasserstein distance
+
+    Returns
+    -------
+    dCp: array-like, shape (n_batch, 1)
+        The batched right derivative
+    dCm: array-like, shape (n_batch, 1)
+        The batched left derivative
+
+    References
+    ----------
+    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+    """
+    nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+    v_values = nx.copy(v_values)
+
+    n = u_values.shape[-1]
+    m_batch, m = v_values.shape
+
+    v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+    mask_p = v_cdf_theta >= 0
+    mask_n = v_cdf_theta < 0
+
+    v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+    v_values[mask_p] += nx.floor(theta)[mask_p]
+
+    if nx.any(mask_n) and nx.any(mask_p):
+        v_cdf_theta[mask_n] += 1
+
+    v_cdf_theta2 = nx.copy(v_cdf_theta)
+    v_cdf_theta2[mask_n] = np.inf
+    shift = -nx.argmin(v_cdf_theta2, axis=-1)
+
+    v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+    v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+    v_values = nx.concatenate(
+        [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1
+    )
+
+    if nx.__name__ == "torch":
+        # this is to ensure the best performance for torch searchsorted
+        # and avoid a warning related to non-contiguous arrays
+        u_cdf = u_cdf.contiguous()
+        v_cdf_theta = v_cdf_theta.contiguous()
+
+    # quantiles of F_u evaluated in F_v^\theta
+    u_index = nx.searchsorted(u_cdf, v_cdf_theta)
+    u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1)
+
+    # Deal with 1
+    u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1)
+    u_valuesm = nx.concatenate(
+        [u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1
+    )
+
+    if nx.__name__ == "torch":
+        # this is to ensure the best performance for torch searchsorted
+        # and avoid a warning related to non-contiguous arrays
+        u_cdfm = u_cdfm.contiguous()
+        v_cdf_theta = v_cdf_theta.contiguous()
+
+    u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right")
+    u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1)
+
+    dCp = nx.sum(
+        nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p)
+        - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p),
+        axis=-1,
+    )
+
+    dCm = nx.sum(
+        nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p)
+        - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p),
+        axis=-1,
+    )
+
+    return dCp.reshape(-1, 1), dCm.reshape(-1, 1)
+
+
+def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p):
+    r"""Computes the cost (Equation (6.2) of [44])
+
+    Parameters
+    ----------
+    theta: array-like, shape (n_batch, n)
+        Cuts on the circle
+    u_values: array-like, shape (n_batch, n)
+        locations of the first empirical distribution
+    v_values: array-like, shape (n_batch, n)
+        locations of the second empirical distribution
+    u_cdf: array-like, shape (n_batch, n)
+        cdf of the first empirical distribution
+    v_cdf: array-like, shape (n_batch, n)
+        cdf of the second empirical distribution
+    p: float, optional = 2
+        Power p used for computing the Wasserstein distance
+
+    Returns
+    -------
+    ot_cost: array-like, shape (n_batch,)
+        OT cost evaluated at theta
+
+    References
+    ----------
+    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+    """
+    nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf)
+
+    v_values = nx.copy(v_values)
+
+    m_batch, m = v_values.shape
+    n_batch, n = u_values.shape
+
+    v_cdf_theta = v_cdf - (theta - nx.floor(theta))
+
+    mask_p = v_cdf_theta >= 0
+    mask_n = v_cdf_theta < 0
+
+    v_values[mask_n] += nx.floor(theta)[mask_n] + 1
+    v_values[mask_p] += nx.floor(theta)[mask_p]
+
+    if nx.any(mask_n) and nx.any(mask_p):
+        v_cdf_theta[mask_n] += 1
+
+    # Put negative values at the end
+    v_cdf_theta2 = nx.copy(v_cdf_theta)
+    v_cdf_theta2[mask_n] = np.inf
+    shift = -nx.argmin(v_cdf_theta2, axis=-1)
+
+    v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1)))
+    v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1)))
+    v_values = nx.concatenate(
+        [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1
+    )
+
+    # Compute the abscissa of the discretization
+    cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1)
+    cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)])
+
+    delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1]
+
+    if nx.__name__ == "torch":
+        # this is to ensure the best performance for torch searchsorted
+        # and avoid a warning related to non-contiguous arrays
+        u_cdf = u_cdf.contiguous()
+        v_cdf_theta = v_cdf_theta.contiguous()
+        cdf_axis = cdf_axis.contiguous()
+
+    # Compute icdf
+    u_index = nx.searchsorted(u_cdf, cdf_axis)
+    u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1)
+
+    v_values = nx.concatenate(
+        [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1
+    )
+    v_index = nx.searchsorted(v_cdf_theta, cdf_axis)
+    v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1)
+
+    if p == 1:
+        ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1)
+    else:
+        ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1)
+
+    return ot_cost
+
+
+def binary_search_circle(
+    u_values,
+    v_values,
+    u_weights=None,
+    v_weights=None,
+    p=1,
+    Lm=10,
+    Lp=10,
+    tm=-1,
+    tp=1,
+    eps=1e-6,
+    require_sort=True,
+    log=False,
+):
+    r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44].
+    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+    takes the value modulo 1.
+    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+    using e.g. the atan2 function.
+
+    .. math::
+        W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+    where:
+
+    - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v`
+
+    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+    .. math::
+        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+    using e.g.
ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow and jax are not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC + Lp : int, optional + Upper bound dC + tm: float, optional + Lower bound theta + tp: float, optional + Upper bound theta + eps: float, optional + Stopping condition + require_sort: bool, optional + If True, sort the values. + log: bool, optional + If True, returns also the optimal theta + + Returns + ------- + loss: float/array-like, shape (...) + Batched cost associated to the optimal transportation + log: dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> binary_search_circle(u.T, v.T, p=1) + array([0.1]) + + References + ---------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batches {} and {} respectively given".format( + u_values.shape[1], v_values.shape[1] + ) + ) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0).T + v_cdf = nx.cumsum(v_weights, 0).T + + u_values = u_values.T + v_values = v_values.T + + L = max(Lm, Lp) + + tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tm = nx.tile(tm, (1, m)) + tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tp = nx.tile(tp, (1, m)) + tc = (tm + tp) / 2 + + done = nx.zeros((u_values.shape[0], m)) + + cpt = 0 + while nx.any(1 - done): + cpt += 1 + + dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + done = ((dCp * dCm) <= 0) 
* 1 + + mask = ((tp - tm) < eps / L) * (1 - done) + + if nx.any(mask): + # can probably be improved by computing only relevant values + dCptp, dCmtp = derivative_cost_on_circle( + tp, u_values, v_values, u_cdf, v_cdf, p + ) + dCptm, dCmtm = derivative_cost_on_circle( + tm, u_values, v_values, u_cdf, v_cdf, p + ) + Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape( + -1, 1 + ) + Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape( + -1, 1 + ) + + # Avoid warning raised when dCptm - dCmtp == 0, for which + # tc is not updated as mask_end is False, + # see Issue #738 + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) + tc[mask_end > 0] = ( + (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) + )[mask_end > 0] + done[nx.prod(mask, axis=-1) > 0] = 1 + elif nx.any(1 - done): + tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] + tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] + tc[((1 - mask) * (1 - done)) > 0] = ( + tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0] + ) / 2 + + w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p) + + if log: + return w, {"optimal_theta": tc[:, 0]} + return w + + +def wasserstein1_circle( + u_values, v_values, u_weights=None, v_weights=None, require_sort=True +): + r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. + The function runs on backend but tensorflow and jax are not supported. + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float/array-like, shape (...) + Batched cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein1_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. 
R Code: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+    """
+
+    if u_weights is not None and v_weights is not None:
+        nx = get_backend(u_values, v_values, u_weights, v_weights)
+    else:
+        nx = get_backend(u_values, v_values)
+
+    n = u_values.shape[0]
+    m = v_values.shape[0]
+
+    if len(u_values.shape) == 1:
+        u_values = nx.reshape(u_values, (n, 1))
+    if len(v_values.shape) == 1:
+        v_values = nx.reshape(v_values, (m, 1))
+
+    if u_values.shape[1] != v_values.shape[1]:
+        raise ValueError(
+            "u and v must have the same number of batches {} and {} respectively given".format(
+                u_values.shape[1], v_values.shape[1]
+            )
+        )
+
+    u_values = u_values % 1
+    v_values = v_values % 1
+
+    if u_weights is None:
+        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
+    elif u_weights.ndim != u_values.ndim:
+        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+    if v_weights is None:
+        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
+    elif v_weights.ndim != v_values.ndim:
+        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+    if require_sort:
+        u_sorter = nx.argsort(u_values, 0)
+        u_values = nx.take_along_axis(u_values, u_sorter, 0)
+
+        v_sorter = nx.argsort(v_values, 0)
+        v_values = nx.take_along_axis(v_values, v_sorter, 0)
+
+        u_weights = nx.take_along_axis(u_weights, u_sorter, 0)
+        v_weights = nx.take_along_axis(v_weights, v_sorter, 0)
+
+    # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/
+    values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0)
+
+    cdf_diff = nx.cumsum(
+        nx.take_along_axis(
+            nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0
+        ),
+        0,
+    )
+    cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0)
+
+    values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1)
+    delta = values_sorted[1:, ...] - values_sorted[:-1, ...]
+    weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0)
+
+    sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5
+    sum_weights[sum_weights < 0] = np.inf
+    inds = nx.argmin(sum_weights, axis=0)
+
+    levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0)
+
+    return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0)
+
+
+def wasserstein_circle(
+    u_values,
+    v_values,
+    u_weights=None,
+    v_weights=None,
+    p=1,
+    Lm=10,
+    Lp=10,
+    tm=-1,
+    tp=1,
+    eps=1e-6,
+    require_sort=True,
+):
+    r"""Computes the Wasserstein distance on the circle using either :ref:`[45] <references-wasserstein-circle>` for p=1 or
+    the binary search algorithm proposed in :ref:`[44] <references-wasserstein-circle>` otherwise.
+    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+    takes the value modulo 1.
+    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+    using e.g. the atan2 function.
+
+    General loss returned:
+
+    .. math::
+        OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q
+
+    For p=1, [45]
+
+    .. math::
+        W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t
+
+    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+    .. math::
+        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
+
+    using e.g. ot.utils.get_coordinate_circle(x)
+
+    The function runs on backend but tensorflow and jax are not supported.
+
+    Parameters
+    ----------
+    u_values : ndarray, shape (n, ...)
+        samples in the source domain (coordinates on [0,1[)
+    v_values : ndarray, shape (n, ...)
+        samples in the target domain (coordinates on [0,1[)
+    u_weights : ndarray, shape (n, ...), optional
+        samples weights in the source domain
+    v_weights : ndarray, shape (n, ...), optional
+        samples weights in the target domain
+    p : float, optional (default=1)
+        Power p used for computing the Wasserstein distance
+    Lm : int, optional
+        Lower bound dC. For p>1.
+    Lp : int, optional
+        Upper bound dC. For p>1.
+    tm: float, optional
+        Lower bound theta. For p>1.
+    tp: float, optional
+        Upper bound theta. For p>1.
+    eps: float, optional
+        Stopping condition. For p>1.
+    require_sort: bool, optional
+        If True, sort the values.
+
+    Returns
+    -------
+    loss: float/array-like, shape (...)
+        Batched cost associated to the optimal transportation
+
+    Examples
+    --------
+    >>> u = np.array([[0.2,0.5,0.8]])%1
+    >>> v = np.array([[0.4,0.5,0.7]])%1
+    >>> wasserstein_circle(u.T, v.T)
+    array([0.1])
+
+
+    .. _references-wasserstein-circle:
+    References
+    ----------
+    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
+    .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82.
+    """
+    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+
+    if p == 1:
+        return wasserstein1_circle(
+            u_values, v_values, u_weights, v_weights, require_sort=require_sort
+        )
+
+    return binary_search_circle(
+        u_values,
+        v_values,
+        u_weights,
+        v_weights,
+        p=p,
+        Lm=Lm,
+        Lp=Lp,
+        tm=tm,
+        tp=tp,
+        eps=eps,
+        require_sort=require_sort,
+    )
+
+
+def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
+    r"""Computes the closed-form of the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1`
+    Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`,
+    takes the value modulo 1.
+    If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates
+    using e.g. the atan2 function.
+
+    .. math::
+        W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12}
+
+    where:
+
+    - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}`
+
+    For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
+
+    .. math::
+        u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
+
+    using e.g. ot.utils.get_coordinate_circle(x).
+
+    Parameters
+    ----------
+    u_values : ndarray, shape (n, ...)
+        Samples
+    u_weights : ndarray, shape (n, ...), optional
+        samples weights in the source domain
+
+    Returns
+    -------
+    loss: float/array-like, shape (...)
+        Batched cost associated to the optimal transportation
+
+    Examples
+    --------
+    >>> x0 = np.array([[0], [0.2], [0.4]])
+    >>> semidiscrete_wasserstein2_unif_circle(x0)
+    array([0.02111111])
+
+    References
+    ----------
+    .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
+ """ + + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + u_values = nx.sort(u_values, 0) + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + cpt1 = nx.sum(u_weights * u_values**2, axis=0) + u_mean = nx.sum(u_weights * u_values, axis=0) + + ns = 1 - u_weights - 2 * u_cdf[:-1] + cpt2 = nx.sum(u_values * u_weights * ns, axis=0) + + return cpt1 - u_mean**2 + cpt2 + 1 / 12 + + +def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True): + r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference + :math:`\eta=\mathrm{Unif}(S^1)` evaluated in :math:`x`. + + For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[78] `) + + .. math`` + \hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12) - x. + + Parameters + ---------- + x : ndary, shape (m,) + Points in [0,1[ where to evaluate the embedding + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + + Returns + ------- + embedding: ndarray of shape (m, ...) + Embedding evaluated at :math:`x` + + .. _references-lcot: + References + ---------- + .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. + """ + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + q_s = ( + x[:, None] - nx.sum(u_values * u_weights, axis=0)[None] + 0.5 + ) # shape (m, ...) + + u_quantiles = quantile_function(q_s % 1, u_cdf, u_values) + + return (u_quantiles - x[:, None]) % 1 + + +def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None): + r"""Computes the Linear Circular Optimal Transport distance from :ref:`[78] ` using :math:`\eta=\mathrm{Unif}(S^1)` + as reference measure. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + General loss returned: + + .. math:: + \mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t + + where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`, + and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`. 
+
+    Parameters
+    ----------
+    u_values : ndarray, shape (n, ...)
+        samples in the source domain (coordinates on [0,1[)
+    v_values : ndarray, shape (n, ...), optional
+        samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution
+    u_weights : ndarray, shape (n, ...), optional
+        samples weights in the source domain
+    v_weights : ndarray, shape (n, ...), optional
+        samples weights in the target domain
+
+    Returns
+    -------
+    loss: float/array-like, shape (...)
+        Batched cost associated to the linear optimal transportation
+
+    Examples
+    --------
+    >>> u = np.array([[0.2,0.5,0.8]])%1
+    >>> v = np.array([[0.4,0.5,0.7]])%1
+    >>> linear_circular_ot(u.T, v.T)
+    array([0.0127])
+
+
+    .. _references-lcot:
+    References
+    ----------
+    .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations.
+    """
+    if u_weights is not None:
+        nx = get_backend(u_values, u_weights)
+    else:
+        nx = get_backend(u_values)
+
+    n = u_values.shape[0]
+    u_values = u_values % 1
+
+    if len(u_values.shape) == 1:
+        u_values = nx.reshape(u_values, (n, 1))
+
+    if u_weights is None:
+        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
+    elif u_weights.ndim != u_values.ndim:
+        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+
+    unif_s1 = nx.linspace(0, 1, 101, type_as=u_values)[:-1]
+
+    emb_u = linear_circular_embedding(unif_s1, u_values, u_weights)
+
+    if v_values is None:
+        dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u))
+        return nx.mean(dist_u**2, axis=0)
+    else:
+        m = v_values.shape[0]
+        if len(v_values.shape) == 1:
+            v_values = nx.reshape(v_values, (m, 1))
+
+        if u_values.shape[1] != v_values.shape[1]:
+            raise ValueError(
+                "u and v must have the same number of batches {} and {} respectively given".format(
+                    u_values.shape[1], v_values.shape[1]
+                )
+            )
+
+        emb_v = linear_circular_embedding(unif_s1, v_values, v_weights)
+
+        dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v))
+        return nx.mean(dist_uv**2, axis=0)
diff --git a/ot/sliced.py b/ot/sliced.py
index 81d0bd4a3..636432c2d 100644
--- a/ot/sliced.py
+++ b/ot/sliced.py
@@ -6,6 +6,7 @@
 # Author: Adrien Corenflos
 #         Nicolas Courty
 #         Rémi Flamary
+#         Clément Bonet
 #
 # License: MIT License

diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py
index 771452954..b7a526182 100644
--- a/ot/unbalanced/__init__.py
+++ b/ot/unbalanced/__init__.py
@@ -24,6 +24,10 @@

 from ._lbfgs import lbfgsb_unbalanced, lbfgsb_unbalanced2

+from ._solver_1d import uot_1d
+
+from ._sliced import sliced_unbalanced_ot, unbalanced_sliced_ot
+
 __all__ = [
     "sinkhorn_knopp_unbalanced",
     "sinkhorn_unbalanced",
@@ -38,4 +42,7 @@
     "_get_loss_unbalanced",
     "lbfgsb_unbalanced",
     "lbfgsb_unbalanced2",
+    "uot_1d",
+    "sliced_unbalanced_ot",
+    "unbalanced_sliced_ot",
 ]
diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py
new file mode 100644
index 000000000..987d0ddf4
--- /dev/null
+++ b/ot/unbalanced/_sliced.py
@@ -0,0 +1,358 @@
+# -*- coding: utf-8 -*-
+"""
+Sliced Unbalanced OT solvers
+"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+from ..backend import get_backend
+from ..utils import get_parameter_pair, list_to_array
+from ..sliced import get_random_projections
+from ._solver_1d import rescale_potentials, uot_1d
+from ..lp.solver_1d import emd_1d_dual, emd_1d_dual_backprop, wasserstein_1d
+
+
+def sliced_unbalanced_ot(
+    X_s,
+    X_t,
+    reg_m,
+    a=None,
+    b=None,
+    n_projections=50,
+    p=2,
+    projections=None,
+    seed=None,
+    numItermax=10,
+    mode="backprop",
+    log=False,
+):
+    r"""
+    Compute the Sliced Unbalanced Optimal Transport (SUOT) loss between two
+    empirical distributions (see [82]).
+
+    SUOT averages 1D unbalanced OT problems over random projections of the
+    samples. The marginals are relaxed independently on each projection, so the
+    reweighted marginals returned in the log are per-projection quantities.
+
+    Parameters
+    ----------
+    X_s : ndarray, shape (n_samples_a, dim)
+        samples in the source domain
+    X_t : ndarray, shape (n_samples_b, dim)
+        samples in the target domain
+    reg_m: float or indexable object of length 1 or 2
+        Marginal relaxation term.
+        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
+        then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
+        The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`.
+        For semi-relaxed case, use either
+        :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or
+        :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`.
+        If :math:`\mathrm{reg_{m}}` is an array,
+        it must have the same backend as input arrays `(a, b)`.
+    a : ndarray, shape (n_samples_a,), optional
+        samples weights in the source domain
+    b : ndarray, shape (n_samples_b,), optional
+        samples weights in the target domain
+    n_projections : int, optional
+        Number of projections used for the Monte-Carlo approximation
+    p: float, optional (default=2)
+        Power p used for computing the sliced unbalanced OT loss
+    projections: shape (dim, n_projections), optional
+        Projection matrix (n_projections and seed are not used in this case)
+    seed: int or RandomState or None, optional
+        Seed used for random number generator
+    numItermax: int, optional
+        Maximum number of Frank-Wolfe iterations of the 1D UOT solver
+    mode: str, optional
+        "icdf" for inverse CDF, "backprop" for backpropagation mode.
+        Default is "backprop".
+    log: bool, optional
+        if True, returns the projections used and their associated UOTs and reweighted marginals.
+
+    Returns
+    -------
+    loss: float/array-like, shape (...)
+        The SUOT loss
+    log: dict, optional
+        log dictionary returned only if log is True
+
+    References
+    ----------
+    .. [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024).
+        Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research.
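+
+    Examples
+    --------
+    A minimal usage sketch (illustrative only: the value depends on the random
+    projections, and mode="icdf" is chosen so that it runs without an
+    automatic-differentiation backend):
+
+    >>> import numpy as np
+    >>> rng = np.random.RandomState(0)
+    >>> X_s, X_t = rng.randn(20, 2), rng.randn(30, 2)
+    >>> loss = sliced_unbalanced_ot(X_s, X_t, reg_m=1.0, n_projections=10, seed=0, mode="icdf")  # doctest: +SKIP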
+    """
+    assert mode in ["backprop", "icdf"]
+
+    X_s, X_t = list_to_array(X_s, X_t)
+
+    if a is not None and b is not None and projections is None:
+        nx = get_backend(X_s, X_t, a, b)
+    elif a is not None and b is not None and projections is not None:
+        nx = get_backend(X_s, X_t, a, b, projections)
+    elif a is None and b is None and projections is not None:
+        nx = get_backend(X_s, X_t, projections)
+    else:
+        nx = get_backend(X_s, X_t)
+
+    n = X_s.shape[0]
+    m = X_t.shape[0]
+
+    if X_s.shape[1] != X_t.shape[1]:
+        raise ValueError(
+            "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
+                X_s.shape[1], X_t.shape[1]
+            )
+        )
+
+    if a is None:
+        a = nx.full(n, 1 / n, type_as=X_s)
+    if b is None:
+        b = nx.full(m, 1 / m, type_as=X_s)
+
+    d = X_s.shape[1]
+
+    if projections is None:
+        projections = get_random_projections(
+            d, n_projections, seed, backend=nx, type_as=X_s
+        )
+    else:
+        n_projections = projections.shape[1]
+
+    X_s_projections = nx.dot(X_s, projections)  # shape (n, n_projs)
+    X_t_projections = nx.dot(X_t, projections)
+
+    a_reweighted, b_reweighted, projected_uot = uot_1d(
+        X_s_projections,
+        X_t_projections,
+        reg_m,
+        a,
+        b,
+        p,
+        require_sort=True,
+        mode=mode,
+        numItermax=numItermax,
+    )
+
+    res = nx.mean(projected_uot) ** (1.0 / p)
+
+    if log:
+        dico = {
+            "projection": projections,
+            "projected_uots": projected_uot,
+            "a_reweighted": a_reweighted,
+            "b_reweighted": b_reweighted,
+        }
+        return res, dico
+
+    return res
+
+
+def unbalanced_sliced_ot(
+    X_s,
+    X_t,
+    reg_m,
+    a=None,
+    b=None,
+    n_projections=50,
+    p=2,
+    projections=None,
+    seed=None,
+    numItermax=10,
+    mode="backprop",
+    stochastic_proj=False,
+    log=False,
+):
+    r"""
+    Compute the Unbalanced Sliced Optimal Transport (USOT) loss between two
+    empirical distributions (see [82]).
+
+    In contrast with SUOT, the marginal relaxation is shared across all
+    projections: a single pair of reweighted marginals is optimized in the
+    original space with Frank-Wolfe iterations on the dual potentials.
+
+    Parameters
+    ----------
+    X_s : ndarray, shape (n_samples_a, dim)
+        samples in the source domain
+    X_t : ndarray, shape (n_samples_b, dim)
+        samples in the target domain
+    reg_m: float or indexable object of length 1 or 2
+        Marginal relaxation term.
+        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
+        then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
+        The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`.
+        For semi-relaxed case, use either
+        :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or
+        :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`.
+        If :math:`\mathrm{reg_{m}}` is an array,
+        it must have the same backend as input arrays `(a, b)`.
+    a : ndarray, shape (n_samples_a,), optional
+        samples weights in the source domain
+    b : ndarray, shape (n_samples_b,), optional
+        samples weights in the target domain
+    n_projections : int, optional
+        Number of projections used for the Monte-Carlo approximation
+    p: float, optional (default=2)
+        Power p used for computing the sliced unbalanced OT loss
+    projections: shape (dim, n_projections), optional
+        Projection matrix (n_projections and seed are not used in this case)
+    seed: int or RandomState or None, optional
+        Seed used for random number generator
+    numItermax: int, optional
+        Number of Frank-Wolfe iterations
+    mode: str, optional
+        "icdf" for inverse CDF, "backprop" for backpropagation mode.
+        Default is "backprop".
+    stochastic_proj: bool, optional (default=False)
+        if True, new projection directions are resampled at every Frank-Wolfe iteration
+    log: bool, optional
+        if True, returns the projections used
+
+    Returns
+    -------
+    a_reweighted: array-like shape (n, ...)
+        First marginal reweighted
+    b_reweighted: array-like shape (m, ...)
+        Second marginal reweighted
+    loss: float/array-like, shape (...)
+        The USOT loss
+    log: dict, optional
+        log dictionary returned only if log is True
+
+    References
+    ----------
+    .. [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024).
+        Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research.
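+
+    Examples
+    --------
+    A minimal usage sketch (illustrative only, under the same assumptions as
+    for sliced_unbalanced_ot):
+
+    >>> import numpy as np
+    >>> rng = np.random.RandomState(0)
+    >>> X_s, X_t = rng.randn(20, 2), rng.randn(30, 2)
+    >>> a_rw, b_rw, loss = unbalanced_sliced_ot(X_s, X_t, reg_m=1.0, n_projections=10, seed=0, mode="icdf")  # doctest: +SKIP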
+    """
+    assert mode in ["backprop", "icdf"]
+
+    X_s, X_t = list_to_array(X_s, X_t)
+
+    if a is not None and b is not None and projections is None:
+        nx = get_backend(X_s, X_t, a, b)
+    elif a is not None and b is not None and projections is not None:
+        nx = get_backend(X_s, X_t, a, b, projections)
+    elif a is None and b is None and projections is not None:
+        nx = get_backend(X_s, X_t, projections)
+    else:
+        nx = get_backend(X_s, X_t)
+
+    reg_m1, reg_m2 = get_parameter_pair(reg_m)
+
+    n = X_s.shape[0]
+    m = X_t.shape[0]
+
+    if X_s.shape[1] != X_t.shape[1]:
+        raise ValueError(
+            "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(
+                X_s.shape[1], X_t.shape[1]
+            )
+        )
+
+    if a is None:
+        a = nx.full(n, 1 / n, type_as=X_s)
+    if b is None:
+        b = nx.full(m, 1 / m, type_as=X_s)
+
+    d = X_s.shape[1]
+
+    if projections is None and not stochastic_proj:
+        projections = get_random_projections(
+            d, n_projections, seed, backend=nx, type_as=X_s
+        )
+    elif projections is not None:
+        n_projections = projections.shape[1]
+
+    if not stochastic_proj:
+        X_s_projections = nx.dot(X_s, projections).T  # shape (n_projs, n)
+        X_t_projections = nx.dot(X_t, projections).T
+
+        X_s_sorter = nx.argsort(X_s_projections, -1)
+        X_s_rev_sorter = nx.argsort(X_s_sorter, -1)
+        X_s_sorted = nx.take_along_axis(X_s_projections, X_s_sorter, -1)
+
+        X_t_sorter = nx.argsort(X_t_projections, -1)
+        X_t_rev_sorter = nx.argsort(X_t_sorter, -1)
+        X_t_sorted = nx.take_along_axis(X_t_projections, X_t_sorter, -1)
+
+    # Initialize potentials - WARNING: They correspond to non-sorted samples
+    f = nx.zeros(a.shape, type_as=a)
+    g = nx.zeros(b.shape, type_as=b)
+
+    for i in range(numItermax):
+        # Output FW descent direction
+        # translate potentials
+        transl = rescale_potentials(f, g, a, b, reg_m1, reg_m2, nx)
+
+        f = f + transl
+        g = g - transl
+
+        # If stochastic version then sample new directions and re-sort data
+        # TODO: add functions to sample and project
+        if stochastic_proj:
+            projections = get_random_projections(
+                d, n_projections, seed, backend=nx, type_as=X_s
+            )
+
+            X_s_projections = nx.dot(X_s, projections)
+            X_t_projections = nx.dot(X_t, projections)
+
+            X_s_sorter = nx.argsort(X_s_projections, -1)
+            X_s_rev_sorter = nx.argsort(X_s_sorter, -1)
+            X_s_sorted = nx.take_along_axis(X_s_projections, X_s_sorter, -1)
+
+            X_t_sorter = nx.argsort(X_t_projections, -1)
+            X_t_rev_sorter = nx.argsort(X_t_sorter, -1)
+            X_t_sorted = nx.take_along_axis(X_t_projections, X_t_sorter, -1)
+
+        # update measures
+        a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter]
+        b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter]
+
+        full_mass = nx.sum(a_reweighted, axis=1)
+
+        # normalize the weights for compatibility with wasserstein_1d
+        a_reweighted = a_reweighted / nx.sum(a_reweighted, axis=1, keepdims=True)
+        b_reweighted = b_reweighted / nx.sum(b_reweighted, axis=1, keepdims=True)
+
+        # solve for new potentials
+        if mode == "icdf":
+            fd, gd, loss = emd_1d_dual(
+                X_s_sorted.T,
+                X_t_sorted.T,
+                u_weights=a_reweighted.T,
+                v_weights=b_reweighted.T,
+                p=p,
+                require_sort=False,
+            )
+            fd, gd = fd.T, gd.T
+
+        elif mode == "backprop":
+            fd, gd, loss = emd_1d_dual_backprop(
+                X_s_sorted.T,
+                X_t_sorted.T,
+                u_weights=a_reweighted.T,
+                v_weights=b_reweighted.T,
+                p=p,
+                require_sort=False,
+            )
+            fd, gd = fd.T, gd.T
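+
+        # The dual potentials of this reweighted balanced problem give the
+        # Frank-Wolfe descent direction; they are mapped back to the original
+        # sample order and averaged over the projections in the update below.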
+        # default step for FW
+        t = 2.0 / (2.0 + i)
+
+        f = f + t * (nx.mean(nx.take_along_axis(fd, X_s_rev_sorter, 1), axis=0) - f)
+        g = g + t * (nx.mean(nx.take_along_axis(gd, X_t_rev_sorter, 1), axis=0) - g)
+
+    ot_loss = wasserstein_1d(
+        X_s_sorted.T,
+        X_t_sorted.T,
+        u_weights=a_reweighted.T,
+        v_weights=b_reweighted.T,
+        p=p,
+        require_sort=False,
+    )
+    sot_loss = nx.mean(ot_loss * full_mass)
+
+    a_reweighted, b_reweighted = a * nx.exp(-f / reg_m1), b * nx.exp(-g / reg_m2)
+
+    uot_loss = (
+        sot_loss
+        + reg_m1 * nx.kl_div(a_reweighted, a, mass=True)
+        + reg_m2 * nx.kl_div(b_reweighted, b, mass=True)
+    )
+
+    if log:
+        return a_reweighted, b_reweighted, uot_loss, {"projections": projections}
+
+    return a_reweighted, b_reweighted, uot_loss
diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py
new file mode 100644
index 000000000..ea88920cd
--- /dev/null
+++ b/ot/unbalanced/_solver_1d.py
@@ -0,0 +1,266 @@
+# -*- coding: utf-8 -*-
+"""
+1D Unbalanced OT solvers
+"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+from ..backend import get_backend
+from ..utils import get_parameter_pair
+from ..lp.solver_1d import emd_1d_dual, emd_1d_dual_backprop
+
+
+def rescale_potentials(f, g, a, b, rho1, rho2, nx):
+    r"""
+    Find and return the optimal translation :math:`\lambda` in the translation-invariant
+    dual of UOT with KL regularization, see Proposition 2 in :ref:`[73] <references-uot>`.
+
+    Parameters
+    ----------
+    f: array-like, shape (n, ...)
+        first dual potential
+    g: array-like, shape (m, ...)
+        second dual potential
+    a: array-like, shape (n, ...)
+        weights of the first empirical distribution
+    b: array-like, shape (m, ...)
+        weights of the second empirical distribution
+    rho1: float
+        Marginal relaxation term for the first marginal
+    rho2: float
+        Marginal relaxation term for the second marginal
+    nx: module
+        backend module
+
+    Returns
+    -------
+    transl: array-like, shape (...)
+        optimal translation
+
+    .. _references-uot:
+    References
+    ----------
+    .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
+        Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
+        In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
+    """
+    if rho1 == float("inf") and rho2 == float("inf"):
+        return nx.zeros(shape=nx.sum(f, axis=0).shape, type_as=f)
+
+    elif rho1 == float("inf"):
+        tau = rho2
+        denom = nx.logsumexp(-g / rho2 + nx.log(b), axis=0)
+        num = nx.log(nx.sum(a, axis=0))
+
+    elif rho2 == float("inf"):
+        tau = rho1
+        num = nx.logsumexp(-f / rho1 + nx.log(a), axis=0)
+        denom = nx.log(nx.sum(b, axis=0))
+
+    else:
+        tau = (rho1 * rho2) / (rho1 + rho2)
+        num = nx.logsumexp(-f / rho1 + nx.log(a), axis=0)
+        denom = nx.logsumexp(-g / rho2 + nx.log(b), axis=0)
+
+    transl = tau * (num - denom)
+
+    return transl
+
+
+def uot_1d(
+    u_values,
+    v_values,
+    reg_m,
+    u_weights=None,
+    v_weights=None,
+    p=2,
+    require_sort=True,
+    numItermax=10,
+    mode="icdf",
+    returnCost="linear",
+    log=False,
+):
+    r"""
+    Solves the 1D unbalanced OT problem with KL regularization.
+    The function implements the Frank-Wolfe algorithm to solve the dual problem,
+    as proposed in :ref:`[73] <references-uot>`.
+
+    The unbalanced OT problem reads
+
+    .. math::
+        \mathrm{UOT}(\mu,\nu) = \min_{\gamma \in \mathcal{M}_{+}(\mathbb{R}\times\mathbb{R})} W_p^p(\pi^1_\#\gamma,\pi^2_\#\gamma) + \mathrm{reg_{m1}} \mathrm{KL}(\pi^1_\#\gamma|\mu) + \mathrm{reg_{m2}} \mathrm{KL}(\pi^2_\#\gamma|\nu).
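+
+    It is solved with Frank-Wolfe iterations on the translation-invariant dual:
+    at each iteration the potentials are optimally translated (see
+    rescale_potentials), the marginals are reweighted by
+    :math:`e^{-f/\mathrm{reg_{m1}}}` and :math:`e^{-g/\mathrm{reg_{m2}}}`, a
+    balanced 1D OT problem is solved between the normalized reweighted
+    marginals, and its dual potentials are averaged in with step size
+    :math:`2/(2+k)` at iteration :math:`k`.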
+
+    The mode "backprop" should be preferred, but is available only with backends
+    supporting automatic differentiation (torch and jax).
+
+    Parameters
+    ----------
+    u_values: array-like, shape (n, ...)
+        locations of the first empirical distribution
+    v_values: array-like, shape (m, ...)
+        locations of the second empirical distribution
+    reg_m: float or indexable object of length 1 or 2
+        Marginal relaxation term.
+        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1,
+        then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations.
+        The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`.
+        For semi-relaxed case, use either
+        :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or
+        :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`.
+        If :math:`\mathrm{reg_{m}}` is an array,
+        it must have the same backend as input arrays `(a, b)`.
+    u_weights: array-like, shape (n, ...), optional
+        weights of the first empirical distribution, if None then uniform weights are used
+    v_weights: array-like, shape (m, ...), optional
+        weights of the second empirical distribution, if None then uniform weights are used
+    p: float, optional
+        order of the ground metric used, should be at least 1, default is 2
+    require_sort: bool, optional
+        sort the atoms of the distributions, if False we will consider they have been
+        sorted prior to being passed to the function, default is True
+    numItermax: int, optional
+        Maximum number of Frank-Wolfe iterations, default is 10
+    mode: str, optional
+        "icdf" for inverse CDF, "backprop" for backpropagation mode.
+        Default is "icdf".
+    returnCost: string, optional (default = "linear")
+        If `returnCost` = "linear", then return the linear part of the unbalanced OT loss.
+        If `returnCost` = "total", then return the total unbalanced OT loss.
+    log: bool, optional
+        if True, also return a dict containing the dual potentials and both costs
+
+    Returns
+    -------
+    u_reweighted: array-like shape (n, ...)
+        First marginal reweighted
+    v_reweighted: array-like shape (m, ...)
+        Second marginal reweighted
+    loss: float/array-like, shape (...)
+        The batched 1D UOT
+
+    .. _references-uot:
+    References
+    ----------
+    .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
+        Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
+        In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
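+
+    Examples
+    --------
+    A minimal sketch (illustrative only; recall from above that using
+    reg_m=float("inf") recovers the balanced 1D OT loss):
+
+    >>> import numpy as np
+    >>> u = np.array([0.1, 0.3, 0.8])
+    >>> v = np.array([0.2, 0.4, 0.9])
+    >>> u_rw, v_rw, loss = uot_1d(u, v, reg_m=1.0)  # doctest: +SKIP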
+    """
+    assert mode in ["backprop", "icdf"]
+
+    if u_weights is not None and v_weights is not None:
+        nx = get_backend(u_values, v_values, u_weights, v_weights)
+    else:
+        nx = get_backend(u_values, v_values)
+
+    reg_m1, reg_m2 = get_parameter_pair(reg_m)
+
+    n = u_values.shape[0]
+    m = v_values.shape[0]
+
+    # Init weights or broadcast if necessary
+    if u_weights is None:
+        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
+    elif u_weights.ndim != u_values.ndim:
+        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
+
+    if v_weights is None:
+        v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values)
+    elif v_weights.ndim != v_values.ndim:
+        v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1)
+
+    # Sort w.r.t. support if not already done
+    if require_sort:
+        u_sorter = nx.argsort(u_values, 0)
+        u_rev_sorter = nx.argsort(u_sorter, 0)
+        u_values_sorted = nx.take_along_axis(u_values, u_sorter, 0)
+
+        v_sorter = nx.argsort(v_values, 0)
+        v_rev_sorter = nx.argsort(v_sorter, 0)
+        v_values_sorted = nx.take_along_axis(v_values, v_sorter, 0)
+
+        u_weights_sorted = nx.take_along_axis(u_weights, u_sorter, 0)
+        v_weights_sorted = nx.take_along_axis(v_weights, v_sorter, 0)
+    else:
+        # inputs are assumed to be already sorted
+        u_values_sorted, v_values_sorted = u_values, v_values
+        u_weights_sorted, v_weights_sorted = u_weights, v_weights
+
+    f = nx.zeros(u_weights.shape, type_as=u_weights)
+    fd = nx.zeros(u_weights.shape, type_as=u_weights)
+    g = nx.zeros(v_weights.shape, type_as=v_weights)
+    gd = nx.zeros(v_weights.shape, type_as=v_weights)
+
+    for i in range(numItermax):
+        transl = rescale_potentials(
+            f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx
+        )
+
+        f = f + transl[None]
+        g = g - transl[None]
+
+        if reg_m1 != float("inf"):
+            u_reweighted = u_weights_sorted * nx.exp(-f / reg_m1)
+        else:
+            u_reweighted = u_weights_sorted
+
+        if reg_m2 != float("inf"):
+            v_reweighted = v_weights_sorted * nx.exp(-g / reg_m2)
+        else:
+            v_reweighted = v_weights_sorted
+
+        full_mass = nx.sum(u_reweighted, axis=0)
+
+        # Normalize weights
+        u_rescaled = u_reweighted / nx.sum(u_reweighted, axis=0, keepdims=True)
+        v_rescaled = v_reweighted / nx.sum(v_reweighted, axis=0, keepdims=True)
+
+        if mode == "icdf":
+            fd, gd, loss = emd_1d_dual(
+                u_values_sorted,
+                v_values_sorted,
+                u_weights=u_rescaled,
+                v_weights=v_rescaled,
+                p=p,
+                require_sort=False,
+            )
+        elif mode == "backprop":
+            fd, gd, loss = emd_1d_dual_backprop(
+                u_values_sorted,
+                v_values_sorted,
+                u_weights=u_rescaled,
+                v_weights=v_rescaled,
+                p=p,
+                require_sort=False,
+            )
+
+        t = 2.0 / (2.0 + i)
+        f = f + t * (fd - f)
+        g = g + t * (gd - g)
+
+    if require_sort:
+        f = nx.take_along_axis(f, u_rev_sorter, 0)
+        g = nx.take_along_axis(g, v_rev_sorter, 0)
+        u_reweighted = nx.take_along_axis(u_reweighted, u_rev_sorter, 0)
+        v_reweighted = nx.take_along_axis(v_reweighted, v_rev_sorter, 0)
+
+    # rescale OT loss
+    linear_loss = loss * full_mass
+
+    if reg_m1 == float("inf") and reg_m2 == float("inf"):
+        uot_loss = linear_loss
+    elif reg_m1 == float("inf"):
+        uot_loss = linear_loss + reg_m2 * nx.kl_div(v_reweighted, v_weights, mass=True)
+    elif reg_m2 == float("inf"):
+        uot_loss = linear_loss + reg_m1 * nx.kl_div(u_reweighted, u_weights, mass=True)
+    else:
+        uot_loss = (
+            linear_loss
+            + reg_m1 * nx.kl_div(u_reweighted, u_weights, mass=True)
+            + reg_m2 * nx.kl_div(v_reweighted, v_weights, mass=True)
+        )
+
+    if returnCost == "linear":
+        out_loss = linear_loss
+    elif returnCost == "total":
+        out_loss = uot_loss
+
+    if log:
+        dico = {"f": f, "g": g, "total_cost": uot_loss, "linear_cost": linear_loss}
+        return u_reweighted, v_reweighted, out_loss, dico
+    return u_reweighted, v_reweighted, out_loss
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index c2f377469..7762c7d35 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -2,6 +2,7 @@

 # Author: Adrien Corenflos
 #         Nicolas Courty
+#         Clément Bonet
 #
 # License: MIT License

@@ -94,7 +95,7 @@ def test_wasserstein_1d_type_devices(nx):
     rho_v /= rho_v.sum()

     for tp in nx.__type_list__:
-        print(nx.dtype_device(tp))
+        # print(nx.dtype_device(tp))

         xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)

@@ -178,7 +179,7 @@ def test_emd1d_type_devices(nx):
     rho_v /= rho_v.sum()

     for tp in nx.__type_list__:
-        print(nx.dtype_device(tp))
+        # print(nx.dtype_device(tp))

         xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)

@@ -218,17
+219,13 @@ def test_emd1d_device_tf(): assert nx.dtype_device(emd)[1].startswith("GPU") -def test_wasserstein_1d_circle(): - # test binary_search_circle and wasserstein_circle give similar results as emd +def test_emd1d_dual_with_weights(): + # test emd1d_dual gives similar results as emd n = 20 m = 30 rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) + u = rng.randn(n, 1) + v = rng.randn(m, 1) w_u = rng.uniform(0.0, 1.0, n) w_u = w_u / w_u.sum() @@ -236,207 +233,88 @@ def test_wasserstein_1d_circle(): w_v = rng.uniform(0.0, 1.0, m) w_v = w_v / w_v.sum() - M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) - - wass1 = ot.emd2(w_u, w_v, M1) + M = ot.dist(u, v, metric="sqeuclidean") - wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) - w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] - M2 = M1**2 - wass2 = ot.emd2(w_u, w_v, M2) - wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) - w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) + f, g, wass1d = ot.emd_1d_dual(u, v, w_u, w_v, p=2) # check loss is similar - np.testing.assert_allclose(wass1, wass1_bsc) - np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) - np.testing.assert_allclose(wass2, wass2_bsc) - np.testing.assert_allclose(wass2, w2_circle) + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, np.sum(f[:, 0] * w_u) + np.sum(g[:, 0] * w_v)) @pytest.skip_backend("tf") -def test_wasserstein1d_circle_devices(nx): +@pytest.skip_backend("jax") +def test_emd1d_dual_batch(nx): rng = np.random.RandomState(0) - n = 10 - x = np.linspace(0, 1, n) + n = 100 + x = np.linspace(0, 5, n) rho_u = np.abs(rng.randn(n)) rho_u /= rho_u.sum() rho_v = np.abs(rng.randn(n)) rho_v /= rho_v.sum() - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) - - w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) - w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) - - nx.assert_same_dtype_device(xb, w1) - nx.assert_same_dtype_device(xb, w2_bsc) - - -def test_wasserstein_1d_unif_circle(): - # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle - n = 20 - m = 1000 - - rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) - - # w_u = rng.uniform(0., 1., n) - # w_u = w_u / w_u.sum() - - w_u = ot.utils.unif(n) - w_v = ot.utils.unif(m) - - M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) - wass2 = ot.emd2(w_u, w_v, M1**2) - - wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) - wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) - # check loss is similar - np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2) - np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2) + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) + f, g, res = ot.emd_1d_dual(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) -def test_wasserstein1d_unif_circle_devices(nx): +def test_emd1d_dual_backprop_batch(nx): rng = np.random.RandomState(0) - n = 10 - x = np.linspace(0, 1, n) + n = 100 + x = np.linspace(0, 5, n) rho_u = np.abs(rng.randn(n)) rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - 
xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) - - w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) - - nx.assert_same_dtype_device(xb, w2) - - -def test_binary_search_circle_log(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) - - wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) - optimal_thetas = log["optimal_theta"] - - assert optimal_thetas.shape[0] == 1 - + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) -def test_wasserstein_circle_bad_shape(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand(n, 2) - v = rng.rand(m, 1) + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) - with pytest.raises(ValueError): - _ = ot.wasserstein_circle(u, v, p=2) + if nx.__name__ in ["torch", "jax"]: + f, g, res = ot.emd_1d_dual_backprop(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) - with pytest.raises(ValueError): - _ = ot.wasserstein_circle(u, v, p=1) + cost_dual = nx.sum(f * rho_ub[:, None], axis=0) + nx.sum( + g * rho_vb[:, None], axis=0 + ) + np.testing.assert_allclose(cost_dual, res) + else: + np.testing.assert_raises( + AssertionError, ot.emd_1d_dual_backprop, Xb, Xb, rho_ub, rho_vb, p=2 + ) @pytest.skip_backend("tf") -def test_linear_circular_ot_devices(nx): +def test_emd1d_dual_type_devices(nx): rng = np.random.RandomState(0) n = 10 - x = np.linspace(0, 1, n) + x = np.linspace(0, 5, n) rho_u = np.abs(rng.randn(n)) rho_u /= rho_u.sum() rho_v = np.abs(rng.randn(n)) rho_v /= rho_v.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - + # print(nx.dtype_device(tp)) xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) - - lcot = ot.linear_circular_ot(xb, xb, rho_ub, rho_vb) - - nx.assert_same_dtype_device(xb, lcot) - - -def test_linear_circular_ot_bad_shape(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand(n, 2) - v = rng.rand(m, 1) - - with pytest.raises(ValueError): - _ = ot.linear_circular_ot(u, v) - - -def test_linear_circular_ot_same_dist(): - n = 20 - rng = np.random.RandomState(0) - u = rng.rand(n) - - lcot = ot.linear_circular_ot(u, u) - np.testing.assert_almost_equal(lcot, 0.0) - - -def test_linear_circular_ot_different_dist(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand(n) - v = rng.rand(m) - - lcot = ot.linear_circular_ot(u, v) - assert lcot > 0.0 - - -def test_linear_circular_embedding_shape(): - n = 20 - rng = np.random.RandomState(0) - u = rng.rand(n, 2) - - ts = np.linspace(0, 1, 101)[:-1] - - emb = ot.lp.solver_1d.linear_circular_embedding(ts, u) - assert emb.shape == (100, 2) - - emb = ot.lp.solver_1d.linear_circular_embedding(ts, u[:, 0]) - assert emb.shape == (100, 1) - - -def test_linear_circular_ot_unif_circle(): - n = 20 - m = 1000 - - rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) - - lcot = ot.linear_circular_ot(u, v) - lcot_unif = ot.linear_circular_ot(u) - - # check loss is similar - np.testing.assert_allclose(lcot, lcot_unif, atol=1e-2) + f, g, res = ot.emd_1d_dual(xb, xb, rho_ub, rho_vb, p=1) + nx.assert_same_dtype_device(xb, res) + nx.assert_same_dtype_device(xb, f) + nx.assert_same_dtype_device(xb, g) + + if nx.__name__ == "torch" or nx.__name__ == "jax": + f, g, res = ot.emd_1d_dual_backprop(xb, xb, rho_ub, rho_vb, p=1) + nx.assert_same_dtype_device(xb, res) + nx.assert_same_dtype_device(xb, f) + nx.assert_same_dtype_device(xb, g) diff --git a/test/test_backend.py b/test/test_backend.py 
diff --git a/test/test_backend.py b/test/test_backend.py
index efd696ef0..50e52eb73 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -139,6 +139,7 @@ def test_empty_backend():
     rnd = np.random.RandomState(0)
     M = rnd.randn(10, 3)
     v = rnd.randn(3)
+    inds = rnd.randint(10)
 
     nx = ot.backend.Backend()
 
@@ -321,6 +322,10 @@ def test_empty_backend():
         nx.slogdet(M)
     with pytest.raises(NotImplementedError):
         nx.unsqueeze(M, 0)
+    with pytest.raises(NotImplementedError):
+        nx.index_select(M, 0, inds)
+    with pytest.raises(NotImplementedError):
+        nx.nonzero(M)
 
 
 def test_func_backends(nx):
@@ -753,6 +758,14 @@ def test_func_backends(nx):
     lst_b.append(np.array([s, logabsd]))
     lst_name.append("slogdet")
 
+    vec = nx.index_select(vb, 0, nx.from_numpy(np.array([0, 1])))
+    lst_b.append(nx.to_numpy(vec))
+    lst_name.append("index_select")
+
+    vec = nx.nonzero(Mb)
+    lst_b.append(nx.to_numpy(vec))
+    lst_name.append("nonzero")
+
     assert not nx.array_equal(Mb, vb), "array_equal (shape)"
     assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
     assert not nx.array_equal(
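The backend hunks above register two new primitives, `index_select` and `nonzero`, across all backends. A hypothetical NumPy reference for the semantics the tests appear to exercise; the return layout of `nonzero` (stacked `(N, ndim)` array vs. tuple of index arrays) is an assumption here, so check `ot.backend` for the convention the PR actually adopts:

```python
import numpy as np

def index_select(a, axis, indices):
    # gather entries of `a` along `axis` (torch.index_select-style)
    return np.take(a, indices, axis=axis)

def nonzero(a):
    # indices of the non-zero entries, stacked as an (N, ndim) array --
    # this layout is an assumption, the backend may use another one
    return np.stack(np.nonzero(a), axis=-1)

M = np.array([[0.0, 1.5], [2.0, 0.0]])
print(index_select(M, 0, np.array([1])))  # -> [[2. 0.]]
print(nonzero(M))                         # -> [[0 1] [1 0]]
```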
diff --git a/test/test_circle_solver.py b/test/test_circle_solver.py
new file mode 100644
index 000000000..35097b1c0
--- /dev/null
+++ b/test/test_circle_solver.py
@@ -0,0 +1,234 @@
+"""Tests for module Circle Wasserstein solver"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+
+
+def test_wasserstein_1d_circle():
+    # test binary_search_circle and wasserstein_circle give similar results as emd
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    w_u = rng.uniform(0.0, 1.0, n)
+    w_u = w_u / w_u.sum()
+
+    w_v = rng.uniform(0.0, 1.0, m)
+    w_v = w_v / w_v.sum()
+
+    M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+
+    wass1 = ot.emd2(w_u, w_v, M1)
+
+    wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1)
+    w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
+
+    M2 = M1**2
+    wass2 = ot.emd2(w_u, w_v, M2)
+    wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)
+    w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass1, wass1_bsc)
+    np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2)
+    np.testing.assert_allclose(wass2, wass2_bsc)
+    np.testing.assert_allclose(wass2, w2_circle)
+
+
+@pytest.skip_backend("tf")
+def test_wasserstein1d_circle_devices(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 1, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+        w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1)
+        w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2)
+
+        nx.assert_same_dtype_device(xb, w1)
+        nx.assert_same_dtype_device(xb, w2_bsc)
+
+
+def test_wasserstein_1d_unif_circle():
+    # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
+    n = 20
+    m = 1000
+
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    # w_u = rng.uniform(0., 1., n)
+    # w_u = w_u / w_u.sum()
+
+    w_u = ot.utils.unif(n)
+    w_v = ot.utils.unif(m)
+
+    M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+    wass2 = ot.emd2(w_u, w_v, M1**2)
+
+    wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15)
+    wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2)
+    np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2)
+
+
+def test_wasserstein1d_unif_circle_devices(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 1, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp)
+
+        w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub)
+
+        nx.assert_same_dtype_device(xb, w2)
+
+
+def test_binary_search_circle_log():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True)
+    optimal_thetas = log["optimal_theta"]
+
+    assert optimal_thetas.shape[0] == 1
+
+
+def test_wasserstein_circle_bad_shape():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(n, 2)
+    v = rng.rand(m, 1)
+
+    with pytest.raises(ValueError):
+        _ = ot.wasserstein_circle(u, v, p=2)
+
+    with pytest.raises(ValueError):
+        _ = ot.wasserstein_circle(u, v, p=1)
+
+
+@pytest.skip_backend("tf")
+def test_linear_circular_ot_devices(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 1, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+        lcot = ot.linear_circular_ot(xb, xb, rho_ub, rho_vb)
+
+        nx.assert_same_dtype_device(xb, lcot)
+
+
+def test_linear_circular_ot_bad_shape():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(n, 2)
+    v = rng.rand(m, 1)
+
+    with pytest.raises(ValueError):
+        _ = ot.linear_circular_ot(u, v)
+
+
+def test_linear_circular_ot_same_dist():
+    n = 20
+    rng = np.random.RandomState(0)
+    u = rng.rand(n)
+
+    lcot = ot.linear_circular_ot(u, u)
+    np.testing.assert_almost_equal(lcot, 0.0)
+
+
+def test_linear_circular_ot_different_dist():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(n)
+    v = rng.rand(m)
+
+    lcot = ot.linear_circular_ot(u, v)
+    assert lcot > 0.0
+
+
+def test_linear_circular_embedding_shape():
+    n = 20
+    rng = np.random.RandomState(0)
+    u = rng.rand(n, 2)
+
+    ts = np.linspace(0, 1, 101)[:-1]
+
+    emb = ot.lp.solver_circle.linear_circular_embedding(ts, u)
+    assert emb.shape == (100, 2)
+
+    emb = ot.lp.solver_circle.linear_circular_embedding(ts, u[:, 0])
+    assert emb.shape == (100, 1)
+
+
+def test_linear_circular_ot_unif_circle():
+    n = 20
+    m = 1000
+
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    lcot = ot.linear_circular_ot(u, v)
+    lcot_unif = ot.linear_circular_ot(u)
+
+    # check loss is similar
+    np.testing.assert_allclose(lcot, lcot_unif, atol=1e-2)
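The new `test_circle_solver.py` above pins down the expected behaviour of `ot.linear_circular_ot`: zero on identical samples, positive otherwise, and, when called with a single argument, comparison against the uniform distribution on the circle. A short NumPy sketch of that usage, mirroring the tests (the uniform-reference default for the one-argument form is inferred from `test_linear_circular_ot_unif_circle`):

```python
import numpy as np
import ot  # assumes this PR's ot.linear_circular_ot is available

rng = np.random.RandomState(0)
u = rng.rand(20)    # samples on the circle, coordinates in [0, 1)
v = rng.rand(1000)  # large sample, roughly uniform on the circle

lcot = ot.linear_circular_ot(u, v)
assert lcot > 0.0

# identical samples give a (numerically) zero distance
np.testing.assert_almost_equal(ot.linear_circular_ot(u, u), 0.0)

# with one argument the reference measure is the uniform distribution,
# so a large uniform sample should give a similar value
lcot_unif = ot.linear_circular_ot(u)
np.testing.assert_allclose(lcot, lcot_unif, atol=1e-2)
```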
diff --git a/test/unbalanced/test_1d_solver.py b/test/unbalanced/test_1d_solver.py
new file mode 100644
index 000000000..ecc6826cc
--- /dev/null
+++ b/test/unbalanced/test_1d_solver.py
@@ -0,0 +1,419 @@
+"""Tests for module 1D Unbalanced OT"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import itertools
+import numpy as np
+import ot
+import pytest
+
+
+@pytest.skip_backend("tf")
+def test_uot_1d(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = 1.0
+
+    M = ot.dist(xs, xt)
+    a, b, M = nx.from_numpy(a_np, b_np, M)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    G, log = ot.unbalanced.mm_unbalanced(a, b, M, reg_m, div="kl", log=True)
+    loss_mm = log["cost"]
+
+    if nx.__name__ != "jax":
+        f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="icdf", p=2)
+        np.testing.assert_allclose(loss_1d, loss_mm, atol=1e-2)
+        np.testing.assert_allclose(G.sum(0), g[:, 0], atol=1e-2)
+        np.testing.assert_allclose(G.sum(1), f[:, 0], atol=1e-2)
+
+    if nx.__name__ in ["jax", "torch"]:
+        f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="backprop", p=2)
+        np.testing.assert_allclose(loss_1d, loss_mm, atol=1e-2)
+        np.testing.assert_allclose(G.sum(0), g[:, 0], atol=1e-2)
+        np.testing.assert_allclose(G.sum(1), f[:, 0], atol=1e-2)
+
+
+@pytest.skip_backend("tf")
+def test_uot_1d_convergence(nx):
+    # for a large reg_m, UOT should converge to balanced OT
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    reg_m = 1000
+
+    # wass1d = ot.wasserstein_1d(xs, xt, p=2)
+    G_1d, log = ot.emd_1d(xs, xt, metric="sqeuclidean", log=True)
+    wass1d = log["cost"]
+    u_w1d, v_w1d = nx.sum(G_1d, 1), nx.sum(G_1d, 0)
+
+    if nx.__name__ != "jax":
+        u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="icdf", p=2)
+        np.testing.assert_allclose(loss_1d, wass1d, atol=1e-2)
+        np.testing.assert_allclose(v_w1d, v[:, 0], atol=1e-2)
+        np.testing.assert_allclose(u_w1d, u[:, 0], atol=1e-2)
+
+    if nx.__name__ in ["jax", "torch"]:
+        u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="backprop", p=2)
+        np.testing.assert_allclose(loss_1d, wass1d, atol=1e-2)
+        np.testing.assert_allclose(v_w1d, v[:, 0], atol=1e-2)
+        np.testing.assert_allclose(u_w1d, u[:, 0], atol=1e-2)
+
+
+@pytest.skip_backend("tf")
+@pytest.skip_backend("jax")
+def test_uot_1d_inf_reg_m_icdf(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = float("inf")
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    f_w1d, g_w1d, wass1d = ot.emd_1d_dual(xs, xt, a, b, p=2)
+    u, v, loss_1d, log = ot.unbalanced.uot_1d(
+        xs, xt, reg_m, a, b, mode="icdf", p=2, log=True
+    )
+
+    # print("ICDF", loss_1d)
+
+    # Check right loss
+    np.testing.assert_allclose(loss_1d, wass1d)
+
+    # Check right marginals
+    np.testing.assert_allclose(a, u[:, 0])
+    np.testing.assert_allclose(b, v[:, 0])
+
+    # Check potentials
+    np.testing.assert_allclose(f_w1d, log["f"])
+    np.testing.assert_allclose(g_w1d, log["g"])
+
+
+def test_uot_1d_inf_reg_m_backprop(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = float("inf")
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    if nx.__name__ in ["jax", "torch"]:
+        f_w1d, g_w1d, wass1d = ot.emd_1d_dual_backprop(xs, xt, a, b, p=2)
+        u, v, loss_1d, log = ot.unbalanced.uot_1d(
+            xs, xt, reg_m, a, b, mode="backprop", p=2, log=True
+        )
+
+        # print("Backprop", loss_1d)
+
+        # Check right loss
+        np.testing.assert_allclose(loss_1d, wass1d)
+
+        # Check right marginals
+        np.testing.assert_allclose(a, u[:, 0])
+        np.testing.assert_allclose(b, v[:, 0])
+
+        # Check potentials
+        np.testing.assert_allclose(f_w1d, log["f"])
+        np.testing.assert_allclose(g_w1d, log["g"])
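`test_uot_1d` above checks the new Frank-Wolfe 1D solver against the generic `mm_unbalanced` solver on the full cost matrix: both the loss and the reweighted marginals must agree. A condensed NumPy version of that consistency check, assuming this PR's `ot.unbalanced.uot_1d`:

```python
import numpy as np
import ot  # assumes this PR's ot.unbalanced.uot_1d is available

rng = np.random.RandomState(42)
xs, xt = rng.randn(20, 1), rng.randn(20, 1)
a, b = ot.utils.unif(20), ot.utils.unif(20)
reg_m = 1.0

# reference: generic KL-regularized UOT on the full cost matrix
M = ot.dist(xs, xt)
G, log = ot.unbalanced.mm_unbalanced(a, b, M, reg_m, div="kl", log=True)

# 1D Frank-Wolfe solver: reweighted marginals plus the UOT loss
u, v, loss = ot.unbalanced.uot_1d(
    xs, xt, reg_m, u_weights=a, v_weights=b, p=2, mode="icdf"
)
np.testing.assert_allclose(loss, log["cost"], atol=1e-2)
np.testing.assert_allclose(G.sum(1), u[:, 0], atol=1e-2)  # source marginal
np.testing.assert_allclose(G.sum(0), v[:, 0], atol=1e-2)  # target marginal
```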
+
+
+@pytest.skip_backend("tf")
+@pytest.skip_backend("jax")
+def test_semi_uot_1d_icdf(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = (float("inf"), 1.0)
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    u, v, loss_1d, log = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="icdf", p=2, log=True)
+
+    # Check right marginals
+    np.testing.assert_allclose(a, u[:, 0])
+    np.testing.assert_allclose(v[:, 0].sum(), 1)
+
+
+def test_semi_uot_1d_backprop(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = (float("inf"), 1.0)
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    if nx.__name__ in ["jax", "torch"]:
+        u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, mode="backprop", p=2)
+
+        # Check right marginals
+        np.testing.assert_allclose(a, u[:, 0])
+        np.testing.assert_allclose(v[:, 0].sum(), 1)
+
+
+@pytest.skip_backend("tf")
+@pytest.skip_backend("jax")
+@pytest.mark.parametrize(
+    "reg_m",
+    itertools.product(
+        [1, float("inf")],
+    ),
+)
+def test_unbalanced_relaxation_parameters_icdf(nx, reg_m):
+    # test that reg_m can be passed as scalar, pair, list or backend array
+    n = 100
+    rng = np.random.RandomState(50)
+
+    x = rng.randn(n, 2)
+    a = ot.utils.unif(n)
+
+    # make dists unbalanced
+    b = rng.rand(n, 2)
+
+    a, b, x = nx.from_numpy(a, b, x)
+
+    reg_m = reg_m[0]
+
+    # options for reg_m
+    full_list_reg_m = [reg_m, reg_m]
+    full_tuple_reg_m = (reg_m, reg_m)
+    tuple_reg_m, list_reg_m = (reg_m,), [reg_m]
+    nx_reg_m = reg_m * nx.ones(1)
+
+    list_options = [
+        nx_reg_m,
+        full_tuple_reg_m,
+        tuple_reg_m,
+        full_list_reg_m,
+        list_reg_m,
+    ]
+
+    u, v, loss = ot.unbalanced.uot_1d(
+        x, x, reg_m, u_weights=a, v_weights=b, p=2, mode="icdf"
+    )
+
+    for opt in list_options:
+        u, v, loss_opt = ot.unbalanced.uot_1d(
+            x, x, opt, u_weights=a, v_weights=b, p=2, mode="icdf"
+        )
+
+        np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
+
+
+@pytest.mark.parametrize(
+    "reg_m",
+    itertools.product(
+        [1, float("inf")],
+    ),
+)
+def test_unbalanced_relaxation_parameters_backprop(nx, reg_m):
+    # test that reg_m can be passed as scalar, pair, list or backend array
+    n = 100
+    rng = np.random.RandomState(50)
+
+    x = rng.randn(n, 2)
+    a = ot.utils.unif(n)
+
+    # make dists unbalanced
+    b = rng.rand(n, 2)
+
+    a, b, x = nx.from_numpy(a, b, x)
+
+    reg_m = reg_m[0]
+
+    # options for reg_m
+    full_list_reg_m = [reg_m, reg_m]
+    full_tuple_reg_m = (reg_m, reg_m)
+    tuple_reg_m, list_reg_m = (reg_m,), [reg_m]
+    nx_reg_m = reg_m * nx.ones(1)
+
+    list_options = [
+        nx_reg_m,
+        full_tuple_reg_m,
+        tuple_reg_m,
+        full_list_reg_m,
+        list_reg_m,
+    ]
+
+    if nx.__name__ in ["jax", "torch"]:
+        u, v, loss = ot.unbalanced.uot_1d(
+            x, x, reg_m, u_weights=a, v_weights=b, p=2, mode="backprop"
+        )
+
+        for opt in list_options:
+            u, v, loss_opt = ot.unbalanced.uot_1d(
+                x, x, opt, u_weights=a, v_weights=b, p=2, mode="backprop"
+            )
+
+            np.testing.assert_allclose(
+                nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05
+            )
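The parametrized tests above establish that `reg_m` accepts several equivalent spellings: a scalar, a singleton tuple or list, a symmetric pair, or a backend array. A minimal NumPy sketch of those conventions, mirroring the tests above:

```python
import numpy as np
import ot  # assumes this PR's ot.unbalanced.uot_1d is available

rng = np.random.RandomState(50)
x = rng.randn(100, 2)
a = ot.utils.unif(100)
b = rng.rand(100, 2)  # unnormalized weights -> genuinely unbalanced

_, _, ref = ot.unbalanced.uot_1d(
    x, x, 1.0, u_weights=a, v_weights=b, p=2, mode="icdf"
)

# all of these spell the same (symmetric) marginal relaxation strength
for reg_m in [(1.0, 1.0), [1.0, 1.0], (1.0,), [1.0], np.ones(1)]:
    _, _, loss = ot.unbalanced.uot_1d(
        x, x, reg_m, u_weights=a, v_weights=b, p=2, mode="icdf"
    )
    np.testing.assert_allclose(loss, ref, atol=1e-5)
```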
+
+
+@pytest.skip_backend("tf")
+@pytest.skip_backend("jax")
+@pytest.mark.parametrize(
+    "reg_m1, reg_m2",
+    itertools.product(
+        [1, float("inf")],
+        [1, float("inf")],
+    ),
+)
+def test_unbalanced_relaxation_parameters_pair_icdf(nx, reg_m1, reg_m2):
+    # test that asymmetric pairs of reg_m can be passed as tuple or list
+    n = 100
+    rng = np.random.RandomState(50)
+
+    x = rng.randn(n, 2)
+    a = ot.utils.unif(n)
+
+    # make dists unbalanced
+    b = rng.rand(n, 2)
+
+    a, b, x = nx.from_numpy(a, b, x)
+
+    # options for reg_m
+    full_list_reg_m = [reg_m1, reg_m2]
+    full_tuple_reg_m = (reg_m1, reg_m2)
+    list_options = [full_tuple_reg_m, full_list_reg_m]
+
+    _, _, loss = ot.unbalanced.uot_1d(
+        x, x, (reg_m1, reg_m2), u_weights=a, v_weights=b, p=2, mode="icdf"
+    )
+
+    for opt in list_options:
+        _, _, loss_opt = ot.unbalanced.uot_1d(
+            x, x, opt, u_weights=a, v_weights=b, p=2, mode="icdf"
+        )
+
+        np.testing.assert_allclose(nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05)
+
+
+@pytest.mark.parametrize(
+    "reg_m1, reg_m2",
+    itertools.product(
+        [1, float("inf")],
+        [1, float("inf")],
+    ),
+)
+def test_unbalanced_relaxation_parameters_pair_backprop(nx, reg_m1, reg_m2):
+    # test that asymmetric pairs of reg_m can be passed as tuple or list
+    n = 100
+    rng = np.random.RandomState(50)
+
+    x = rng.randn(n, 2)
+    a = ot.utils.unif(n)
+
+    # make dists unbalanced
+    b = rng.rand(n, 2)
+
+    a, b, x = nx.from_numpy(a, b, x)
+
+    # options for reg_m
+    full_list_reg_m = [reg_m1, reg_m2]
+    full_tuple_reg_m = (reg_m1, reg_m2)
+    list_options = [full_tuple_reg_m, full_list_reg_m]
+
+    if nx.__name__ in ["jax", "torch"]:
+        _, _, loss = ot.unbalanced.uot_1d(
+            x, x, (reg_m1, reg_m2), u_weights=a, v_weights=b, p=2, mode="backprop"
+        )
+
+        for opt in list_options:
+            _, _, loss_opt = ot.unbalanced.uot_1d(
+                x, x, opt, u_weights=a, v_weights=b, p=2, mode="backprop"
+            )
+
+            np.testing.assert_allclose(
+                nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05
+            )
+
+
+@pytest.skip_backend("tf")
+@pytest.skip_backend("jax")
+def test_uot_1d_type_devices_icdf(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    reg_m = 1.0
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+        f, g, _ = ot.unbalanced.uot_1d(xb, xb, reg_m, rho_ub, rho_vb, p=2, mode="icdf")
+
+        nx.assert_same_dtype_device(xb, f)
+        nx.assert_same_dtype_device(xb, g)
+
+
+@pytest.skip_backend("tf")
+@pytest.skip_backend("numpy")
+@pytest.skip_backend("cupy")
+def test_uot_1d_type_devices_backprop(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 5, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    reg_m = 1.0
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+        f, g, _ = ot.unbalanced.uot_1d(
+            xb, xb, reg_m, rho_ub, rho_vb, p=2, mode="backprop"
+        )
+
+        nx.assert_same_dtype_device(xb, f)
+        nx.assert_same_dtype_device(xb, g)
diff --git a/test/unbalanced/test_sliced.py b/test/unbalanced/test_sliced.py
new file mode 100644
index 000000000..bdd917f19
--- /dev/null
+++ b/test/unbalanced/test_sliced.py
@@ -0,0 +1,10 @@
+"""Tests for module sliced Unbalanced OT"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import itertools
+import numpy as np
+import ot
+import pytest
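`test/unbalanced/test_sliced.py` is created above with its header and imports; the excerpt shows no test bodies. As a rough picture of what a sliced unbalanced OT quantity computes, here is a hypothetical hand-rolled version built on `ot.unbalanced.uot_1d` (illustration only, not the sliced API under test, whose definition may differ, e.g. in how masses are reweighted across projections):

```python
import numpy as np
import ot  # assumes this PR's ot.unbalanced.uot_1d is available

def sliced_uot(X, Y, reg_m, n_projections=50, seed=0):
    # average the 1D UOT loss over random directions on the unit sphere
    rng = np.random.RandomState(seed)
    thetas = rng.randn(n_projections, X.shape[1])
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)
    losses = []
    for theta in thetas:
        _, _, loss = ot.unbalanced.uot_1d(
            (X @ theta)[:, None], (Y @ theta)[:, None], reg_m, p=2, mode="icdf"
        )
        losses.append(float(loss))
    return np.mean(losses)

rng = np.random.RandomState(0)
X, Y = rng.randn(30, 2), rng.randn(40, 2) + 1.0
print(sliced_uot(X, Y, reg_m=1.0))
```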