diff --git a/.github/workflows/build_doc.yml b/.github/workflows/build_doc.yml index 6575510c5..ff5f41176 100644 --- a/.github/workflows/build_doc.yml +++ b/.github/workflows/build_doc.yml @@ -14,6 +14,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: true # Standard drop-in approach that should work for most people. - name: Free Disk Space (Ubuntu) uses: insightsengineering/disk-space-reclaimer@v1 diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 28f5edbb3..9db704b22 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -28,6 +28,8 @@ jobs: - name: Checking Out Repository uses: actions/checkout@v4 + with: + submodules: true # Install Python & Packages - uses: actions/setup-python@v4 with: @@ -43,6 +45,9 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + submodules: true + - name: Set up Python uses: actions/setup-python@v5 with: @@ -70,6 +75,9 @@ jobs: android: true dotnet: true - uses: actions/checkout@v4 + with: + submodules: true + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -95,6 +103,9 @@ jobs: if: "!contains(github.event.head_commit.message, 'no ci')" steps: - uses: actions/checkout@v4 + with: + submodules: true + - name: Set up Python uses: actions/setup-python@v5 with: @@ -124,6 +135,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -152,6 +165,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 216ac6dbc..28c4c6785 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -19,6 +19,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: true 
- name: Set up Python 3.10 uses: actions/setup-python@v5 with: @@ -54,6 +56,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: true - name: Set up Python 3.10 uses: actions/setup-python@v5 with: diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml index 61cda61d1..4fb0c377f 100644 --- a/.github/workflows/build_wheels_weekly.yml +++ b/.github/workflows/build_wheels_weekly.yml @@ -18,6 +18,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: true - name: Set up Python 3.11 uses: actions/setup-python@v5 with: diff --git a/.gitignore b/.gitignore index 8ca0349c4..f8c7e0198 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ docs/modules/ # Cython output ot/lp/emd_wrap.cpp ot/partial/partial_cython.cpp +ot/bsp/bsp_wrap.cpp # Distribution / packaging .Python diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000..cd8bb2486 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "deps/eigen"] + path = deps/eigen + url = https://gitlab.com/libeigen/eigen.git diff --git a/RELEASES.md b/RELEASES.md index cdabf416d..fd9921d8b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -9,7 +9,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782) - Geomloss function now handles both scalar and slice indices for i and j (PR #785) - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397) - +- Add "BSP-OT: Sparse transport plans between discrete measures in loglinear time" (PR #768) + #### Closed issues - Fix NumPy 2.x compatibility in Brenier potential bounds (PR #788) diff --git a/deps/eigen b/deps/eigen new file mode 160000 index 000000000..d71c30c47 --- /dev/null +++ b/deps/eigen @@ -0,0 +1 @@ +Subproject commit d71c30c47858effcbd39967097a2d99ee48db464 diff --git a/ot/bsp/BSP-OT_header_only.h b/ot/bsp/BSP-OT_header_only.h new file 
mode 100644 index 000000000..03625ae06 --- /dev/null +++ b/ot/bsp/BSP-OT_header_only.h @@ -0,0 +1,2842 @@ + +// Amalgamation-specific define +#ifndef BSP_OT_HEADER_ONLY +#define BSP_OT_HEADER_ONLY +#endif + + +#pragma once +#include +#include +#include +#include + +// #include +#include + +namespace BSPOT { +// namespace fs = std::filesystem; + +// using scalar = double; +using scalar = float; +using scalars = std::vector; + +using vec = Eigen::Vector3; +using vec2 = Eigen::Vector2; +using mat2 = Eigen::Matrix2; +using mat = Eigen::Matrix3; +using mat4 = Eigen::Matrix4; +using vec4 = Eigen::Vector4; + +using Mat = Eigen::Matrix; +using Diag = Eigen::DiagonalMatrix; +using vecs = std::vector; +using vec2s = std::vector; + +using triplet = Eigen::Triplet; + +using ints = std::vector; +using Vec = Eigen::Vector; +using Vecs = std::vector; + +using smat = Eigen::SparseMatrix; + +using Index = long; +using grid_Index = std::pair; + + + + +inline ints rangeVec(int a,int b) { + ints rslt(b-a); + std::iota(rslt.begin(),rslt.end(),a); + return rslt; +} + +inline ints rangeVec(int i) { + return rangeVec(0,i); +} + +inline auto range(int i) { + return rangeVec(i); + // return std::views::iota(0,i); +} +inline auto range(int a,int b) { + return rangeVec(a,b); + // return std::views::iota(a,b); +} + + +template +using twins = std::pair; + +inline smat Identity(int V) { + smat I(V,V); + I.setIdentity(); + return I; + } + + + +template +bool Smin(T& a,T b) { + if (b < a){ + a = b; + return true; + } + return false; +} + +template +bool Smax(T& a,T b) { + if (a < b){ + a = b; + return true; + } + return false; +} + +} + + +// begin --- BSPOT.h --- + +#ifndef BSPOT_H +#define BSPOT_H + +namespace BSPOT { + +template +using Points = Eigen::Matrix; +template +using Vector = Eigen::Vector; + +using cost_function = std::function; +template +using geometric_cost = std::function; + +template +using CovType = Eigen::Matrix; + +template +struct Moments { + Vector mean; + CovType Cov; +}; + 
+ +template +Vector Mean(const Points& X) { + return X.rowwise().mean(); +} + +template +CovType Covariance(const Points& X) { + Vector mean = X.rowwise().mean(); + Points centered = X.colwise() - mean; + CovType rslt = centered * centered.adjoint() / double(X.cols()); + return rslt; +} + +template +CovType Covariance(const Points& X,const Points& Y) { + Vector meanX = Mean(X); + Points centeredX = X.colwise() - meanX; + Vector meanY = Mean(Y); + Points centeredY = Y.colwise() - meanY; + CovType rslt = centeredX * centeredY.adjoint() / double(X.cols()); + return rslt; +} + + +template +CovType sqrt(const CovType &A) { + Eigen::SelfAdjointEigenSolver> root(A); + return root.operatorSqrt(); +} + +template +CovType W2GaussianTransportMap(const CovType& A,const CovType& B){ + Eigen::SelfAdjointEigenSolver> sasA(A); + CovType root_A = sasA.operatorSqrt(); + CovType inv_root_A = sasA.operatorInverseSqrt(); + CovType C = root_A * B * root_A; + C = sqrt(C); + C = inv_root_A*C*inv_root_A; + return C; +} + + +} + +#endif // BSPOT_H + + +// end --- BSPOT.h --- + + + + +// end --- types.h --- + + +// begin --- sliced.cpp --- + + + +// begin --- sliced.h --- + +#ifndef SLICED_H +#define SLICED_H + + + +namespace BSPOT { + + + +} + +#endif // SLICED_H + + +// end --- sliced.h --- + + + + + +// end --- sliced.cpp --- + + + +// begin --- coupling.cpp --- + + + +// begin --- coupling.h --- + +#ifndef COUPLING_H +#define COUPLING_H + +namespace BSPOT { + +using Coupling = Eigen::SparseMatrix; + +scalar EvalCoupling(const Coupling& pi,const cost_function& cost); + +template +Points CouplingToGrad(const Coupling& pi,const Points& A,const Points& B) { + Points Grad = Points::Zero(A.rows(),A.cols()); + for (int k = 0;k; + +inline Vec Mass(const Atoms& A) { + Vec M(A.size()); + for (auto i : range(A.size())) + M[i] = A[i].mass; + return M; +} + +inline Atoms FromMass(const Vec& x) { + Atoms rslt(x.size()); + for (auto i : range(x.size())) { + rslt[i].mass = x[i]; + rslt[i].id = i; + } + 
return rslt; +} + +inline Atoms UniformMass(int n) { + Atoms rslt(n); + for (auto i : range(n)) { + rslt[i].mass = 1./n; + rslt[i].id = i; + } + return rslt; +} + +} + + +#endif // COUPLING_H + + +// end --- coupling.h --- + + + + +BSPOT::scalar BSPOT::EvalCoupling(const Coupling &pi, const cost_function &cost) { + scalar W = 0; + for (int k = 0;k + +namespace BSPOT { + +inline void NormalizeDyn(Points<-1> &X, scalar dilat = 1) +{ + Vector<-1> min = X.rowwise().minCoeff(); + Vector<-1> max = X.rowwise().maxCoeff(); + Vector<-1> scale = max - min; + double f = dilat/scale.maxCoeff(); + Vector<-1> c = (min+max)*0.5; + X.colwise() -= c; + X *= f; +} + + +template +void Normalize(Points &X, Vector offset = Vector::Zero(dim), scalar dilat = 1) +{ + if (dim == -1) { + offset = Vector::Zero(X.rows()); + } + Vector min = X.rowwise().minCoeff(); + Vector max = X.rowwise().maxCoeff(); + Vector scale = max - min; + double f = dilat/scale.maxCoeff(); + Vector c = (min+max)*0.5; + X.colwise() -= c; + X *= f; + X.colwise() += offset; +} + + +template +Points concat(const Points& X,const Points& Y) { + Points rslt(X.rows(),X.cols() + Y.cols()); + rslt << X,Y; + return rslt; +} + +template +Points pad(const Points& X,int target) { + int n = X.cols(); + Points rslt(dim,target); + for (auto i : range(target)) + rslt.col(i) = X.col(rand()%n); + return rslt; +} + + +template +Points trunc(const Points& X,int target) { + static thread_local std::random_device rd; + static thread_local std::mt19937 g(rd()); + ints I = rangeVec(X.cols()); + ::std::shuffle(I.begin(),I.end(),g); + Points rslt(X.rows(),target); + for (auto i : range(target)) + rslt.col(i) = X.col(I[i]); + return rslt; +} + +template +inline Points ForceToSize(const Points& X,int target) { + if (X.size() == target) + return X; + if (X.size() < target) + return pad(X,target); + return trunc(X,target); +} + +} + + +#endif // CLOUDUTILS_H + + +// end --- cloudutils.h --- + + +#include + + +// end --- cloudutils.cpp --- + + + 
+// begin --- BijectiveMatching.cpp --- + + + +// begin --- BijectiveMatching.h --- + +#ifndef BIJECTIVEMATCHING_H +#define BIJECTIVEMATCHING_H + +// begin --- data_structures.h --- + +#ifndef DATA_STRUCTURES_H +#define DATA_STRUCTURES_H +#include +#include +#include +#include +#include + +namespace BSPOT { + +class UnionFind { +private: + std::vector parent, rank,componentSize; +public: + UnionFind(int n); + + int find(int u); + + void unite(int u, int v); + + std::vector> getConnectedComponents(int n); +} ; + + +class StampedPriorityQueue { +private: + + struct stamped_element { + scalar priority; + int id; + int timestamp; + bool operator<(const stamped_element& other) const { + return priority < other.priority; + } + }; + std::priority_queue queue; + std::map timestamp; + +public: + void insert(int key, scalar priority); + + std::pair pop(); + + bool empty() const; +}; + +struct Edge { + int i, j; + scalar w; +}; + + +class TreeGraph { +public: + std::vector> adj; // Adjacency list with unordered maps + + TreeGraph(int n) : adj(n) {} // Constructor initializes adjacency list with 'n' vertices + + void addEdge(int u, int v, scalar w) { + adj[u][v] = w; + adj[v][u] = w; + } + + void changeWeight(int u, int v, scalar w) { + if (u >= adj.size() || v >= adj.size()) return; // Out of bounds check + + auto it = adj[u].find(v); + if (it != adj[u].end()) { + it->second = w; + adj[v][u] = w; // Update the reverse edge as well + } + } + + void removeEdge(int u, int v) { + if (u >= adj.size() || v >= adj.size()) return; + + adj[u].erase(v); + adj[v].erase(u); + } + + std::vector findPath(int start, int end) { + if (start >= adj.size() || end >= adj.size()) return {}; // Out of bounds check + + std::unordered_map parent; // Maps node -> (parent edge) + std::queue q; + q.push(start); + parent[start] = {-1, -1, 0}; // Root has no parent edge + + bool found = false; + + // BFS traversal + while (!q.empty()) { + int node = q.front(); + q.pop(); + + if (node == end) { + found = 
true; + break; // Stop early when we reach the target + } + + for (const auto& [neighbor, weight] : adj[node]) { + if (parent.find(neighbor) == parent.end()) { // Not visited + parent[neighbor] = {node, neighbor, weight}; + q.push(neighbor); + } + } + } + + if (!found) return {}; // No path found + + // Reconstruct the path from end to start + std::vector path; + int current = end; + while (parent[current].i != -1) { // -1 means root node + path.push_back(parent[current]); + current = parent[current].i; + } + + std::reverse(path.begin(), path.end()); // Reverse to get correct order + return path; + } +}; + +} + + +#endif // DATA_STRUCTURES_H + + +// end --- data_structures.h --- + + + +// begin --- sampling.h --- + +#pragma once + +// begin --- types.h --- + +#include + +namespace BSPOT { + +struct PCG32 +{ + PCG32( ) : x(), key() { seed(0x853c49e6748fea9b, c); } + PCG32( const uint64_t s, const uint64_t ss= c ) : x(), key() { seed(s, ss); } + + void seed( const uint64_t s, const uint64_t ss= c ) + { + key= (ss << 1) | 1; + + x= key + s; + sample(); + } + + unsigned sample( ) + { + // f(x), fonction de transition + uint64_t xx= x; + x= a*x + key; + + // g(x), fonction résultat + uint32_t tmp= ((xx >> 18u) ^ xx) >> 27u; + uint32_t r= xx >> 59u; + return (tmp >> r) | (tmp << ((~r + 1u) & 31)); + } + + // c++ interface + unsigned operator() ( ) { return sample(); } + static constexpr unsigned min( ) { return 0; } + static constexpr unsigned max( ) { return ~unsigned(0); } + typedef unsigned result_type; + + static constexpr uint64_t a= 0x5851f42d4c957f2d; + static constexpr uint64_t c= 0xda3e39cb94b95bdb; + + uint64_t x; + uint64_t key; +}; + + + +inline Vecs sampleUnitGaussian(int N,int dim) { + //static std::mt19937 gen; + + static std::random_device hwseed; + static PCG32 gen( hwseed(), hwseed() ); + static std::normal_distribution dist{0.0,1.0}; + Vecs X(N,Vec(dim)); + for (auto& x : X){ + for (int i = 0;i dist{0.0,1.0}; + //static thread_local std::mt19937 gen; + 
static thread_local std::random_device rd; + static thread_local std::mt19937 rng(rd()); + Vec X(dim); + for (int i = 0;i dist{0.0,1.0}; + Mat X(dim,n); + for (auto i : range(dim)) + for (auto j : range(n)) + X(i,j) = dist(gen); + return X; +} + +inline Mat sampleUnitSphere(int n,int dim) { + static std::mt19937 gen; + static std::normal_distribution dist{0.0,1.0}; + Mat X(dim,n); + for (auto i : range(n)){ + for (auto j : range(dim)) + X(j,i) = dist(gen); + X.col(i).normalize(); + } + return X; +} + +inline Mat sampleUnitSquare(int n,int dim) { + static std::mt19937 gen; + static std::normal_distribution dist{0.0,1.0}; + Mat X(dim,n); + for (auto i : range(n)){ + for (auto j : range(dim)) + X(j,i) = dist(gen); + X.col(i) /= X.col(i).lpNorm(); + } + return X; +} + + +template +size_t WeightedRandomChoice(const T& weights) { + // Random number generator + static std::random_device rd; + static std::mt19937 gen(rd()); + + // Create a discrete distribution based on the weights + std::discrete_distribution<> dist(weights.begin(), weights.end()); + + // Draw an index based on weights + return dist(gen); +} + +inline Vecs fibonacci_sphere(int n) +{ + static double goldenRatio = (1 + std::sqrt(5.))/2.; + Vecs FS(n); + for (int i = 0;i gaussian_dist; + static std::uniform_real_distribution uniform_dist; + // Génère un point gaussien aléatoire + Vec point(d); + for (int i = 0; i < d; ++i) { + point[i] = gaussian_dist(gen); + } + + // Normalisation pour obtenir un point sur la sphère + point.normalize(); + + // Distance aléatoire à l'intérieur de la boule avec distribution uniforme + double radius = std::pow(uniform_dist(gen), 1.0 / d); + + return point * radius; +} + +// Fonction principale pour échantillonner N points dans la boule unité de dimension d +inline Vecs sample_unit_ball(int N, int d,double r = 1,Vec offset = Vec()) { + Vecs samples(N); + if (!offset.size()) + offset = Vec::Zero(d); + + for (int i = 0; i < N; ++i) + samples[i] = sample_point_in_unit_ball(d)*r + 
offset; + + return samples; +} + +inline Vecs sampleGaussian(int dim,int N,const Vec& mean,const Mat& Cov) { + Vecs X = sampleUnitGaussian(N,dim); + for (auto& x : X) + x = Cov*x + mean; + return X; +} + +inline Mat sampleUnitBall(int N,int d) { + static std::mt19937 gen; + static std::normal_distribution gaussian_dist; + static std::uniform_real_distribution uniform_dist; + + Mat X(d,N); + for (auto i : range(N)){ + Vec point(d); + for (int j = 0; j < d; ++j) { + point[j] = gaussian_dist(gen); + } + + // Normalisation pour obtenir un point sur la sphère + point.normalize(); + + // Distance aléatoire à l'intérieur de la boule avec distribution uniforme + double radius = std::pow(uniform_dist(gen), 1.0 / d); + + X.col(i) = point * radius; + } + return X; +} + + +template +inline Points sampleUnitBall(int N,int dim = D) { + static std::mt19937 gen; + static std::normal_distribution gaussian_dist; + static std::uniform_real_distribution uniform_dist; + + Points X(dim, N); + for (auto i : range(N)){ + Vector point(dim); + for (int j = 0; j < dim; ++j) + point[j] = gaussian_dist(gen); + + // Normalisation pour obtenir un point sur la sphère + point.normalize(); + + // Distance aléatoire à l'intérieur de la boule avec distribution uniforme + double radius = std::pow(uniform_dist(gen), 1.0 / dim); + + X.col(i) = point * radius; + } + return X; +} + +template +inline Vector sampleUnitGaussian(int dim = D) { + static thread_local std::random_device rd; + static thread_local std::mt19937 gen(rd()); + std::normal_distribution gaussian_dist(0,1); + Vector point(dim); + for (int j = 0; j < dim; ++j) + point[j] = gaussian_dist(gen); + return point; +} + +inline int randint(int a,int b) { + static thread_local std::random_device rd; + static thread_local std::mt19937 gen(rd()); + std::uniform_int_distribution dist(a,b); + return dist(gen); +} + + +} + + +// end --- sampling.h --- + + +#include + +namespace BSPOT { + + +class BijectiveMatching +{ +public: + using TransportPlan = 
ints; + + BijectiveMatching(); + BijectiveMatching(const TransportPlan& T) : plan(T),inverse_plan(getInverse(T)) {} + BijectiveMatching(const Eigen::Vector& T); + + scalar evalMatching(const cost_function& cost) const; + + template + scalar evalMatchingL2(const Points& A,const Points& B) const { + return (A - B(Eigen::all,plan)).squaredNorm()/A.cols(); + } + + const TransportPlan& getPlan() const; + + size_t operator[](size_t i) const; + size_t operator()(size_t i) const; + size_t size() const; + operator TransportPlan() const; + + BijectiveMatching inverseMatching(); + + BijectiveMatching inverseMatching() const; + bool checkBijectivity() const; + + BijectiveMatching operator()(const BijectiveMatching& other) const; + + template + std::vector operator()(const std::vector& X); + + const TransportPlan& getInversePlan(); + + bool operator==(const BijectiveMatching& other) const { + return plan == other.plan; + } + + + static inline bool swapIfUpgrade(ints &T, ints &TI, const ints &TP, int a, const cost_function &cost) { + int b = T[a]; + int bp = TP[a]; + int ap = TI[bp]; + if (a == ap || b == bp) + return false; + scalar old_cost = cost(a,b) + cost(ap,bp); + scalar new_cost = cost(a,bp) + cost(ap,b); + if (new_cost < old_cost) { + T[a] = bp; + T[ap] = b; + TI[bp] = a; + TI[b] = ap; + return true; + } + return false; + } + +protected: + + BijectiveMatching(const TransportPlan& T,const TransportPlan& TI); + TransportPlan plan,inverse_plan; + + static TransportPlan getInverse(const TransportPlan& T); +}; + +BijectiveMatching Merge(const BijectiveMatching &T, const BijectiveMatching &TP, const cost_function &cost,bool verbose = false); + +BijectiveMatching MergePlans(const std::vector &plans,const cost_function &cost,BijectiveMatching T = BijectiveMatching(),bool cycle = true); +BijectiveMatching MergePlansNoPar(const std::vector &plans,const cost_function &cost,BijectiveMatching T = BijectiveMatching(),bool cycle = true); + +bool swapIfUpgradeK(ints &T, ints &TI, const 
ints &TP, int a,int k, const cost_function &cost); + +inline ints rankPlans(const std::vector& plans,const cost_function& cost) { + std::vector> scores(plans.size()); + for (auto i : range(plans.size())) { + scores[i].first = plans[i].evalMatching(cost); + scores[i].second = i; + } + std::sort(scores.begin(),scores.end(),[](const auto& a,const auto& b) { + return a.first < b.first; + }); + ints rslt(scores.size()); + for (auto i : range(scores.size())) + rslt[i] = scores[i].second; + return rslt; +} + + +inline bool checkBijection(const ints& T,const ints& TI) { + ints I(T.size(),-1); + for (auto i : range(T.size())) + I[T[i]] = i; + bool ok = true; + for (auto i : range(T.size())) + if (I[i] == -1){ + std::cerr << "not bijection" << std::endl;; + ok = false; + } + for (auto i : range(T.size())) + if (TI[T[i]] != i){ + ok = false; + } + return ok; +} + +inline void checkBijection(const ints& T) { + ints I(T.size(),-1); + for (auto i : range(T.size())) + I[T[i]] = i; + for (auto i : range(T.size())) + if (I[i] == -1) + std::cerr << "not bijection" << std::endl;; +} + +BijectiveMatching load_plan(std::string path); + +inline void out_plan(std::string out,const BijectiveMatching& T) { + std::ofstream file(out); + for (auto i : range(T.size())) + file << T[i] << "\n"; + file.close(); +} + + +} + +#endif // BIJECTIVEMATCHING_H + + +// end --- BijectiveMatching.h --- + + + +namespace BSPOT { + +BijectiveMatching::BijectiveMatching(){} + +BijectiveMatching::BijectiveMatching(const Eigen::Vector &T) { + plan.resize(T.size()); + for (auto i : range(T.size())) + plan[i] = T[i]; + inverse_plan = getInverse(plan); +} + +scalar BijectiveMatching::evalMatching(const cost_function &cost) const { + scalar c = 0; + if (plan.empty()) { + std::cerr << "tried to eval cost on empty plan!" 
<< std::endl;; + return 0; + } + + for (auto i : range(plan.size())) + c += cost(i,plan.at(i)); + return c/plan.size(); +} + +const BijectiveMatching::TransportPlan &BijectiveMatching::getPlan() const {return plan;} + +size_t BijectiveMatching::operator[](size_t i) const {return plan.at(i);} + +size_t BijectiveMatching::operator()(size_t i) const {return plan.at(i);} + +size_t BijectiveMatching::size() const {return plan.size();} + +BijectiveMatching::operator TransportPlan() const { + return plan; +} + +BijectiveMatching BijectiveMatching::inverseMatching() { + if (inverse_plan.empty()) + inverse_plan = getInversePlan(); + return BijectiveMatching(inverse_plan,plan); +} + +BijectiveMatching BijectiveMatching::inverseMatching() const { + if (inverse_plan.empty()) + return BijectiveMatching(getInverse(plan),plan); + return BijectiveMatching(inverse_plan,plan); +} + +bool BijectiveMatching::checkBijectivity() const +{ + auto I = getInverse(plan); + for (auto i : I) + if (i == -1) + return false; + return true; +} + +BijectiveMatching BijectiveMatching::operator()(const BijectiveMatching &other) const { + TransportPlan rslt(other.size()); + for (auto i : range(other.size())) + rslt[i] = plan[other[i]]; + return rslt; +} + +BijectiveMatching::BijectiveMatching(const TransportPlan &T, const TransportPlan &TI) : plan(T),inverse_plan(TI) {} + +const BijectiveMatching::TransportPlan &BijectiveMatching::getInversePlan() { + if (inverse_plan.empty()) + inverse_plan = getInverse(plan); + return inverse_plan; +} + +BijectiveMatching::TransportPlan BijectiveMatching::getInverse(const TransportPlan &T) { + TransportPlan TI(T.size(),-1); + for (auto i : range(T.size())){ + TI[T[i]] = i; + } + return TI; +} + +template +std::vector BijectiveMatching::operator()(const std::vector &X) { + std::vector rslt(X.size()); + for (auto i : range(X.size())) + rslt[plan[i]] = X[i]; + return rslt; +} + + +BijectiveMatching Merge(const BijectiveMatching &T, const BijectiveMatching &TP, const 
cost_function &cost, bool verbose) { + if (T.size() == 0) + return TP; + int N = T.size(); + + UnionFind UF(N*2); + for (auto i : range(N)) { + UF.unite(i,T[i]+N); + UF.unite(i,TP[i]+N); + } + + std::unordered_map components; + for (auto i = 0;i connected_components(components.size()); + int i = 0; + for (auto& [p,cc] : components) + connected_components[i++] = cc; + + + for (int k = 0;k &plans, const cost_function &cost, BijectiveMatching T,bool cycle) { + int s = 0; + auto I = true ? rankPlans(plans,cost) : rangeVec(plans.size()); + if (T.size() == 0) { + T = plans[I[0]]; + s = 1; + } + int N = plans[0].size(); + + auto C = evalMappings(T,cost); + + ints rslt = T; + ints rsltI = T.inverseMatching(); + + ints sig(N); + + scalar avg_cc_size = 0; + + for (auto k : range(s,plans.size())) { + ints Tp = plans[I[k]]; + ints Tpi = plans[I[k]].inverseMatching(); + auto Cp = evalMappings(Tp,cost); + + for (auto i : range(N)) + sig[i] = Tpi[rslt[i]]; + + // profiler.start(); + + std::vector CCs; + + if (cycle) { + ints visited(N,-1); + int c = 0; + for (auto i : range(N)) { + if (visited[i] != -1) + continue; + int j = i; + int i0 = i; + if (sig[j] == i) + continue; + + ints CC; + scalar costT = 0; + scalar costTP = 0; + + while (visited[j] == -1) { + CC.push_back(j); + costT += C[j]; + costTP += Cp[j]; + visited[j] = c; + j = sig[j]; + } + + if (costTP < costT) { + j = i0; + do { + std::swap(Tp[j],rslt[j]); + std::swap(C[j],Cp[j]); + j = sig[j]; + } while (j != i0); + j = i0; + do { + rsltI[rslt[j]] = j; + j = sig[j]; + } while (j != i0); + } + + c++; + CCs.push_back(CC); + avg_cc_size += CC.size(); + } + } else { + CCs.push_back(rangeVec(N)); + } + // profiler.tick("cycle"); + // for (auto a : range(N)) +// spdlog::info("nb cycles {} avg size {}",CCs.size(),avg_cc_size / CCs.size() ); +// #pragma omp parallel for +#pragma omp parallel + { +#pragma omp single + { + for (int i = 0; i < CCs.size(); ++i) { +#pragma omp task firstprivate(i) + { + for (auto a : CCs[i]){ + // 
swapIfUpgradeK(rslt,rsltI,Tp,a,3,cost); + int b = rslt[a]; + int bp = Tp[a]; + int ap = rsltI[bp]; + if (a == ap || b == bp) + continue; + scalar old_cost = C[a] + C[ap]; + scalar cabp = Cp[a]; + if (cabp > old_cost) + continue; + scalar capb = cost(ap,b); + if (cabp + capb < old_cost) { + rslt[a] = bp; + rslt[ap] = b; + rsltI[bp] = a; + rsltI[b] = ap; + C[a] = cabp; + C[ap] = capb; + } + } + } + } + } + } + // for (const auto& cc : CCs) + // { + // std::cout << "cc size " << cc.size() << std::endl; + // } + // profiler.tick("greedy"); + } + // profiler.profile(false); + return rslt; +} + +BijectiveMatching MergePlansNoPar(const std::vector &plans, const cost_function &cost, BijectiveMatching T,bool cycle) { + int s = 0; + auto I = true ? rankPlans(plans,cost) : rangeVec(plans.size()); + if (T.size() == 0) { + T = plans[I[0]]; + s = 1; + } + int N = plans[0].size(); + + auto C = evalMappings(T,cost); + + ints rslt = T; + ints rsltI = T.inverseMatching(); + + ints sig(N); + + for (auto k : range(s,plans.size())) { + ints Tp = plans[I[k]]; + ints Tpi = plans[I[k]].inverseMatching(); + auto Cp = evalMappings(Tp,cost); + + for (auto i : range(N)) + sig[i] = Tpi[rslt[i]]; + + // profiler.start(); + + std::vector CCs; + + if (cycle) { + ints visited(N,-1); + int c = 0; + for (auto i : range(N)) { + if (visited[i] != -1) + continue; + int j = i; + int i0 = i; + if (sig[j] == i) + continue; + + ints CC; + scalar costT = 0; + scalar costTP = 0; + + while (visited[j] == -1) { + CC.push_back(j); + costT += C[j]; + costTP += Cp[j]; + visited[j] = c; + j = sig[j]; + } + + if (costTP < costT) { + j = i0; + do { + std::swap(Tp[j],rslt[j]); + std::swap(C[j],Cp[j]); + j = sig[j]; + } while (j != i0); + j = i0; + do { + rsltI[rslt[j]] = j; + j = sig[j]; + } while (j != i0); + } + + c++; + CCs.push_back(CC); + } + } else { + CCs.push_back(rangeVec(N)); + } + for (int i = 0; i < CCs.size(); ++i) { + { + for (auto a : CCs[i]){ + // swapIfUpgradeK(rslt,rsltI,Tp,a,3,cost); + int b = 
rslt[a]; + int bp = Tp[a]; + int ap = rsltI[bp]; + if (a == ap || b == bp) + continue; + scalar old_cost = C[a] + C[ap]; + scalar cabp = Cp[a]; + if (cabp > old_cost) + continue; + scalar capb = cost(ap,b); + if (cabp + capb < old_cost) { + rslt[a] = bp; + rslt[ap] = b; + rsltI[bp] = a; + rsltI[b] = ap; + C[a] = cabp; + C[ap] = capb; + } + } + } + } + } + return rslt; +} + +BijectiveMatching load_plan(std::string path) { + std::ifstream file(path); + ints plan; + while (file) { + int i; + file >> i; + plan.push_back(i); + } + //remove last element + plan.pop_back(); + return plan; +} + + +template +inline std::vector> getPermutations(std::vector C) { + std::vector> rslt; + do + { + rslt.push_back(C); + } + while (std::next_permutation(C.begin(), C.end())); + return rslt; +} + + +bool swapIfUpgradeK(ints &plan, ints &inverse_plan, const ints &T, int a, int k, const cost_function &cost) +{ + if (k == 2) { + return BijectiveMatching::swapIfUpgrade(plan,inverse_plan,T,a,cost); + } + scalar s = 0; + std::set A,TA; + A.insert(a); + TA.insert(plan[a]); + auto i = a; + for (auto k : range(k-1)) { + auto j = T[i]; + i = inverse_plan[j]; + A.insert(i); + TA.insert(j); + } + if (TA.size() != A.size() || TA.size() == 1) + return BijectiveMatching::swapIfUpgrade(plan,inverse_plan,T,a,cost); + ints TAvec(TA.begin(),TA.end()); + ints Avec(A.begin(),A.end()); + auto Sig = getPermutations(TAvec); + ints best; + scalar score = 1e8; + + scalar curr = 0; + for (auto i : range(A.size())) + curr += cost(Avec[i],plan[Avec[i]]); + + for (const auto& sig : Sig) { + scalar c = 0; + for (auto i : range(sig.size())) + c += cost(Avec[i],sig[i]); + if (Smin(score,c)) + best = sig; + } + if (score > curr) + return false; + for (auto i : range(best.size())){ + plan[Avec[i]] = best[i]; + inverse_plan[best[i]] = Avec[i]; + } + return true; +} + +} + + +// end --- BijectiveMatching.cpp --- + + + +// begin --- data_structures.cpp --- + + + + + +BSPOT::UnionFind::UnionFind(int n) { + parent.resize(n); 
+ rank.resize(n, 0); + componentSize.resize(n, 1); // Initialize each component size to 1 + for (int i = 0; i < n; ++i) parent[i] = i; +} + +int BSPOT::UnionFind::find(int u) { + if (parent[u] != u) { + parent[u] = find(parent[u]); // Path compression + } + return parent[u]; +} + + +void BSPOT::UnionFind::unite(int x, int y) { + int rootX = find(x), rootY = find(y); + if (rootX != rootY) { + if (rank[rootX] > rank[rootY]) { + parent[rootY] = rootX; + componentSize[rootX] += componentSize[rootY]; + } else if (rank[rootX] < rank[rootY]) { + parent[rootX] = rootY; + componentSize[rootY] += componentSize[rootX]; + } else { + parent[rootY] = rootX; + componentSize[rootX] += componentSize[rootY]; + rank[rootX]++; + } + } +} + +std::vector> BSPOT::UnionFind::getConnectedComponents(int n) { + std::unordered_map rootIndex; // Maps root -> index in components + std::vector> components; + + // **Step 1: Determine component sizes and allocate memory** + for (int i = 0; i < n; i++) { + int root = find(i); + if (rootIndex.find(root) == rootIndex.end()) { + rootIndex[root] = components.size(); + components.emplace_back(); + components.back().reserve(componentSize[root]); // Preallocate! 
+ } + } + + // **Step 2: Populate components without push_back overhead** + for (int i = 0; i < n; i++) { + int root = find(i); + components[rootIndex[root]].push_back(i); + } + + return components; +} + +void BSPOT::StampedPriorityQueue::insert(int key, scalar priority) { + int ts = 0; + if (timestamp.find(key) == timestamp.end()) + ts = timestamp[key]+1; + timestamp[key] = ts; + queue.push(stamped_element{priority, key, ts}); +} + +std::pair BSPOT::StampedPriorityQueue::pop() { + if (queue.empty()) + return {-1, 0}; + stamped_element e = queue.top(); + queue.pop(); + while (timestamp[e.id] != e.timestamp) { + if (queue.empty()) + return {-1, 0}; + e = queue.top(); + queue.pop(); + } + return {e.id, e.priority}; +} + +bool BSPOT::StampedPriorityQueue::empty() const { + return queue.empty(); +} + + +// end --- data_structures.cpp --- + + + +// begin --- InjectiveMatching.cpp --- + + + +// begin --- InjectiveMatching.h --- + +#ifndef INJECTIVEMATCHING_H +#define INJECTIVEMATCHING_H + +namespace BSPOT { + +class InjectiveMatching +{ +public: + using TransportPlan = ints; + using InverseTransportPlan = ints; + + int image_domain_size = -1; + + InjectiveMatching(int m); + InjectiveMatching(); + InjectiveMatching(const TransportPlan& T,int m); + + scalar evalMatching(const cost_function& cost) const; + + const TransportPlan& getPlan() const; + + size_t operator[](size_t i) const; + size_t operator()(size_t i) const; + size_t size() const; + operator TransportPlan() const; + + InverseTransportPlan inversePlan(); + InverseTransportPlan inversePlan() const; + + static bool swapIfUpgrade(ints& T,ints& TI,const ints& TP,int a,const cost_function& cost); + + static InjectiveMatching Merge(const InjectiveMatching& T1,const InjectiveMatching& T2,const cost_function& cost); + + InverseTransportPlan getInverse() const; + + +protected: + InjectiveMatching(const TransportPlan& T,const TransportPlan& TI); + TransportPlan plan; + InverseTransportPlan inverse_plan; + + const 
TransportPlan& getInversePlan(); + +}; + + +Vec evalMappings(const InjectiveMatching& T,const cost_function& cost); + +InjectiveMatching MergePlans(const std::vector& plans,const cost_function& cost,InjectiveMatching T = InjectiveMatching()); + + +} +#endif // INJECTIVEMATCHING_H + + +// end --- InjectiveMatching.h --- + + + + +BSPOT::InjectiveMatching::InjectiveMatching(int m) : image_domain_size(m) {} + +BSPOT::InjectiveMatching::InjectiveMatching() {} + +BSPOT::InjectiveMatching::InjectiveMatching(const TransportPlan &T, int m) : image_domain_size(m),plan(T) { + +} + +BSPOT::scalar BSPOT::InjectiveMatching::evalMatching(const cost_function &cost) const { + scalar c = 0; + for (auto i : range(plan.size())) + c += cost(i,plan[i])/plan.size(); + return c; +} + +const BSPOT::InjectiveMatching::TransportPlan &BSPOT::InjectiveMatching::getPlan() const {return plan;} + +size_t BSPOT::InjectiveMatching::operator[](size_t i) const {return plan[i];} + +size_t BSPOT::InjectiveMatching::operator()(size_t i) const {return plan[i];} + +size_t BSPOT::InjectiveMatching::size() const {return plan.size();} + +BSPOT::InjectiveMatching::operator TransportPlan() const {return plan;} + +bool BSPOT::InjectiveMatching::swapIfUpgrade(ints &T, ints &TI, const ints &TP, int a, const cost_function &cost) { + int b = T[a]; + int bp = TP[a]; + int ap = TI[bp]; + if (a == ap || b == bp) + return false; + if (a == ap || b == bp) + return false; + if (ap != -1) { + if (cost(ap,b) + cost(a,bp) < cost(a,b) + cost(ap,bp) ){ + T[a] = bp; + T[ap] = b; + TI[bp] = a; + TI[b] = ap; + return true; + } + } + else { + if (cost(a,bp) < cost(a,b)) { + T[a] = bp; + TI[b] = -1; + TI[bp] = a; + return true; + } + } + return false; +} + +BSPOT::InjectiveMatching::InverseTransportPlan BSPOT::InjectiveMatching::inversePlan() { + if (inverse_plan.empty()) + inverse_plan = getInverse(); + return inverse_plan; +} + +BSPOT::InjectiveMatching::InverseTransportPlan BSPOT::InjectiveMatching::inversePlan() const { + if 
(inverse_plan.empty()) + std::cerr << "inverse plan not computed" << std::endl;; + return inverse_plan; +} + +const BSPOT::InjectiveMatching::TransportPlan &BSPOT::InjectiveMatching::getInversePlan() { + inverse_plan = getInverse(); + return inverse_plan; +} + + + +BSPOT::InjectiveMatching::InverseTransportPlan BSPOT::InjectiveMatching::getInverse() const { + if (image_domain_size == -1) { + return {}; + } + InverseTransportPlan rslt(image_domain_size,-1); + for (auto i : range(plan.size())) + rslt[plan[i]] = i; + return rslt; +} + +bool checkValid(const BSPOT::ints &T,const BSPOT::ints& TI) { + int M = TI.size(); + std::set image; + for (auto i : BSPOT::range(T.size())) { + if (T[i] == -1) + return false; + image.insert(T[i]); + } + if (image.size() != T.size()){ + std::cerr << "not injective" << std::endl;; + return false; + } + for (auto i : BSPOT::range(T.size())) + if (TI[T[i]] != i){ + std::cerr << "wrong inverse" << std::endl;; + return false; + } + for (auto i : BSPOT::range(M)){ + if (TI[i] != -1 && image.find(i) != image.end()){ + std::cerr << "wrong inverse" << std::endl;; + return false; + } + } + return true; +} + + +BSPOT::InjectiveMatching BSPOT::InjectiveMatching::Merge(const InjectiveMatching &T, const InjectiveMatching &TP, const cost_function &cost) +{ + if (T.size() == 0) + return TP; + int N = T.size(); + int M = T.image_domain_size; + + UnionFind UF(N + M); + for (auto i : range(N)) { + UF.unite(i,T[i]+N); + UF.unite(i,TP[i]+N); + } + + std::map components; + for (auto i = 0;i connected_components(components.size()); + int i = 0; + for (auto& [p,cc] : components) + connected_components[i++] = cc; + + +#pragma omp parallel for + for (int k = 0;k &plans, const cost_function &cost, BSPOT::InjectiveMatching T) { + int s = 0; + if (T.size() == 0) { + T = plans[0]; + s = 1; + } + int N = plans[0].size(); + + auto C = evalMappings(T,cost); + + ints rslt = T; + ints rsltI = T.getInverse(); + + for (auto k : range(s,plans.size())) { + + auto Cp = 
evalMappings(plans[k],cost); + + const auto& Tp = plans[k]; + for (auto a : range(N)) + { + int b = rslt[a]; + int bp = Tp[a]; + int ap = rsltI[bp]; + if (a == ap || b == bp) + continue; + if (ap != -1) { + scalar old_cost = C[a] + C[ap]; + scalar cabp = Cp[a]; + if (cabp > old_cost) + continue; + scalar capb = cost(ap,b); + if (cabp + capb < old_cost) { + rslt[a] = bp; + rslt[ap] = b; + rsltI[bp] = a; + rsltI[b] = ap; + C[a] = cabp; + C[ap] = capb; + } + } else { + scalar old_cost = C[a]; + scalar cabp = cost(a,bp); + if (cabp < old_cost) { + rslt[a] = bp; + rsltI[b] = -1; + rsltI[bp] = a; + } + } + } + } + return InjectiveMatching(rslt,plans[0].image_domain_size); +} + + +// end --- InjectiveMatching.cpp --- + + + +// begin --- PartialBSPMatching.h --- + +#ifndef PARTIALBSPMATCHING_H +#define PARTIALBSPMATCHING_H + +namespace BSPOT { + +template +class PartialBSPMatching { +public: + using TransportPlan = ints; + + using Pts = Points; + const Pts& A; + const Pts& B; + +protected: + int dim; + cost_function cost; + + struct dot_id { + scalar dot; + int id; + bool operator<(const dot_id& other) const { + return dot < other.dot; + } + }; + + using ids = std::vector; + + + int partition(ids &atoms, int beg, int end, int idx) { + scalar d = atoms[idx].dot; + int idmin = beg; + int idmax = end-1; + while (idmin < idmax) { + while (idmin < end && atoms[idmin].dot < d){ + idmin++; + } + while (idmax >= beg && atoms[idmax].dot > d) + idmax--; + if (idmin >= idmax) + break; + if (idmin < idmax) + std::swap(atoms[idmin],atoms[idmax]); + } + return idmin; + } + + + Vector getSlice(ids &idA,ids &idB, int b, int e) { + return sampleUnitGaussian(dim); + } + + void computeDots(ids& idA,ids& idB,int begA,int endA,int begB,int endB,const Vector& d) { + for (auto i : range(begA,endA)) + idA[i].dot = d.dot(A.col(idA[i].id)); + for (auto i : range(begB,endB)) + idB[i].dot = d.dot(B.col(idB[i].id)); + } + + bool random_pivot = true; + Mat sliceBasis; + bool hasSliceBasis = false; + + 
int best_choice(int a,ids& idB,int b,int e) { + if (e - b == 0) { + std::cerr << "error gap null" << std::endl;; + } + int best = 0; + scalar score = 1e8; + for (auto i : range(b,e)) { + scalar s = cost(a,idB[i].id); + if (s < score) { + best = i; + score = s; + } + } + return best; + } + + void partialBSPOT(ints& plan,ids &idA, ids &idB, int begA, int endA,int begB,int endB,int height = 0) { + auto gap = (endA-begA); + if (gap == 1){ + int a = idA[begA].id; + plan[a] = idB[best_choice(a,idB,begB,endB)].id; + return; + } + const Vector d = hasSliceBasis ? sliceBasis.col(height % dim) : sampleUnitGaussian(dim); + + computeDots(idA,idB,begA,endA,begB,endB,d); + + int pivotA = random_pivot ? randint(begA+1,endA-1) : begA + (endA-begA)/2; + std::nth_element(idA.begin()+begA,idA.begin() + pivotA,idA.begin() + endA); + + if (endB - begB == gap) { + int pivotB = begB + pivotA - begA; + std::nth_element(idB.begin()+begB,idB.begin() + pivotB,idB.begin() + endB); + partialBSPOT(plan,idA,idB,begA,pivotA,begB,pivotB,height + 1); + partialBSPOT(plan,idA,idB,pivotA,endA,pivotB,endB,height + 1); + return; + } + + + int nb_left = pivotA - begA; + int nb_right = endA - pivotA; + + std::nth_element(idB.begin()+ begB,idB.begin() + begB + nb_left,idB.begin() + endB); + std::nth_element(idB.begin() + begB + nb_left,idB.begin() + endB - nb_right,idB.begin() + endB); + // std::sort(idB.begin() + begB,idB.begin() + endB); + + int pivotB = best_choice(idA[pivotA].id,idB,begB + nb_left,endB - nb_right); + pivotB = partition(idB,begB + nb_left,endB - nb_right,pivotB); + + partialBSPOT(plan,idA,idB,begA,pivotA,begB,pivotB,height+1); + partialBSPOT(plan,idA,idB,pivotA,endA,pivotB,endB,height+1); + } + +public: + + PartialBSPMatching(const Pts& A_,const Pts& B_,const cost_function& c) : A(A_),B(B_),cost(c) { + if (A.cols() > B.cols()) { + std::cerr << "B must be the larger cloud" << std::endl;; + } + dim = A.rows(); + if (D != -1 && dim != D) { + std::cerr << "dynamic dimension is different 
from static one !" << std::endl;; + } + } + + InjectiveMatching computePartialMatching(const Eigen::Matrix& M,bool rp = false){ + sliceBasis = M; + hasSliceBasis = true; + return computePartialMatching(rp); + } + + + InjectiveMatching computePartialMatching(bool random_pivot = true){ + ids idA(A.cols()),idB(B.cols()); + for (auto i : range(A.cols())) + idA[i].id = i; + for (auto i : range(B.cols())) + idB[i].id = i; + + this->random_pivot = random_pivot; + ints plan = TransportPlan(A.cols(),-1); + partialBSPOT(plan,idA,idB,0,A.cols(),0,B.cols()); + std::set image; + for (auto i : range(A.cols())) { + if (plan[i] == -1){ + std::cout << "unassigned" << i << std::endl; + } + else + image.insert(plan[i]); + } + if (image.size() != A.cols()) + std::cerr << "not injective" << std::endl;; + return InjectiveMatching(plan,B.cols()); + } + +}; +} + +#endif // PARTIALBSPMATCHING_H + + +// end --- PartialBSPMatching.h --- + + + + + + +// begin --- PointCloudIO.h --- + +#ifndef POINTCLOUDIO_H +#define POINTCLOUDIO_H +#include +#include + +namespace BSPOT { + +/* +template +inline Points ReadPointCloud(std::filesystem::path path) { + std::ifstream infile(path); + + if (!infile.is_open()) { + std::cerr << "Error opening file: " << path.filename() << std::endl; + throw std::runtime_error("File not found"); + } + + std::vector data; // Store all values in a single contiguous array + int rows = 0; + std::string line; + int dim = D; + + // First pass: Read the file and store numbers in a vector + while (std::getline(infile, line)) { + std::istringstream iss(line); + double num; + int current_cols = 0; + + while (iss >> num) { + data.push_back(num); + ++current_cols; + } + + + if (dim == -1) + dim = current_cols; + if (current_cols != dim) { + throw std::runtime_error("Inconsistent dimensions in point cloud file or static dim != point cloud dim"); + } + ++rows; + } + + // Second pass: Copy the data into an Eigen matrix + // where each col is a point + Points pointCloud(dim, rows); + 
for (int i = 0; i < rows; ++i) { + for (int j = 0; j < dim; ++j) { + pointCloud(j, i) = data[i * dim + j]; + } + } + return pointCloud; +} + + +template +void WritePointCloud(std::filesystem::path path,const Points& pts) { + // each row is a point + std::ofstream outfile(path); + if (!outfile.is_open()) { + std::cerr << "Error opening file: " << path.filename() << std::endl; + return; + } + + for (int i = 0; i < pts.cols(); ++i) { + for (int j = 0; j < pts.rows(); ++j) { + outfile << pts(j, i); + if (j < pts.rows() - 1) { + outfile << " "; + } + } + outfile << "\n"; + } +} +*/ + +} + +#endif // POINTCLOUDIO_H + + +// end --- PointCloudIO.h --- + + + +// begin --- BijectiveBSPMatching.h --- + +#ifndef BIJECTIVEBSPMATCHING_H +#define BIJECTIVEBSPMATCHING_H + +namespace BSPOT { + +template +class BijectiveBSPMatching { +public: + using TransportPlan = ints; + + using Pts = BSPOT::Points; + const Pts& A; + const Pts& B; + int dim; + +protected: + + struct dot_id { + scalar dot; + int id; + bool operator<(const dot_id& other) const { + return dot < other.dot; + } + }; + + using ids = std::vector; + struct SliceView { + const ids& id; + int b,e; + + int operator[](int i) const {return id[b + i].id;} + + int size() const {return e - b;} + + SliceView(const ids& id,int b,int e) : id(id),b(b),e(e){} + }; + + + static Moments computeMoments(const Pts& mat,const ids& I,int b,int e) { + SliceView view(I,b,e); + thread_local static Pts sub; + sub = mat(Eigen::all,view); + Vector mean = sub.rowwise().mean(); + + CovType rslt = CovType::Zero(mat.rows(),mat.rows()); + for (auto i : range(sub.cols())){ + Vector c = sub.col(i) - mean; + rslt += c*c.transpose()/scalar(e-b); + } + // Pts centered = sub.colwise() - mean; + // CovType rslt = centered * centered.adjoint() / double(e-b); + return {mean,rslt}; + } + + + + Vector getSlice(ids &idA,ids &idB, int b, int e) { + return sampleUnitGaussian(dim); + } + + void BSP(ids& idA,ids& idB,int beg,int end,int pivot,const Vector& d) { + + 
for (auto i : range(beg,end)) { + idA[i].dot = d.dot(A.col(idA[i].id));// + sampleUnitGaussian<1>()(0)*0e-3; + idB[i].dot = d.dot(B.col(idB[i].id));// + sampleUnitGaussian<1>()(0)*0e-3; + } + std::nth_element(idA.begin() + beg,idA.begin() + pivot,idA.begin() + end); + std::nth_element(idB.begin() + beg,idB.begin() + pivot,idB.begin() + end); + } + + + bool random_pivot = true; + + std::pair,Moments> decomposeMoments(const Pts& X,const Moments& M, const ids& id, int beg, int end,int pivot) { + scalar alpha = scalar(pivot - beg)/scalar(end - beg); + scalar beta = 1 - alpha; + + auto [ML,CL] = computeMoments(X,id,beg,pivot); + + Vector MR = (M.mean - alpha*ML)/beta; + CovType DL = (M.mean - ML)*(M.mean - ML).transpose(); + CovType DR = (M.mean - MR)*(M.mean - MR).transpose(); + CovType CR = CovType(M.Cov - alpha*(CL + DL))/beta - DR; + + return {{ML,CL},{MR,CR}}; + } + + bool init_mode = false; + + Vector DrawEigenVector(const CovType >) { + Eigen::SelfAdjointEigenSolver> solver(GT); + return solver.eigenvectors().col(randint(0,dim-1)); + } + + + Vector gaussianSlice(const Moments& MA,const Moments& MB) { + CovType GT = W2GaussianTransportMap(MA.Cov,MB.Cov); + return DrawEigenVector(GT); + } + + + void gaussianPartialBSPOT(ids &idA, ids &idB, int beg, int end, const Moments& MA,const Moments& MB) { + auto gap = (end-beg); + if (gap == 0){ + std::cerr << "end - beg == 0" << std::endl;; + return; + } + if (gap == 1) + return; + if (gap < 50) { + // random_pivot = true; + // partialBSPOT(idA,idB,beg,end); + partialOrthogonalBSPOT(idA,idB,beg,end,sampleUnitGaussian(dim)); + // random_pivot = false; + return; + } + + const Vector d = gaussianSlice(MA,MB); + + + // int pivot = randint(beg + gap/4,beg + gap*3/4); + int pivot = random_pivot ? 
randint(beg+1,end-1) : beg + (end-beg)/2; + + // for (auto i : range(beg,end)) { + // idA[i].dot = d.dot(A.col(idA[i].id)); + // idB[i].dot = d.dot(B.col(idB[i].id)); + // } + // std::nth_element(idA.begin() + beg,idA.begin() + pivot,idA.begin() + end); + // std::nth_element(idB.begin() + beg,idB.begin() + pivot,idB.begin() + end); + BSP(idA,idB,beg,end,pivot,d); + + auto SMA = decomposeMoments(A,MA,idA,beg,end,pivot); + auto SMB = decomposeMoments(B,MB,idB,beg,end,pivot); + + gaussianPartialBSPOT(idA,idB,beg,pivot,SMA.first,SMB.first); + gaussianPartialBSPOT(idA,idB,pivot,end,SMA.second,SMB.second); + } + + Mat sliceBasis; + bool hasSliceBasis = false; + + void partialBSPOT(ids &idA, ids &idB, int beg, int end,int height = 0) { + auto gap = (end-beg); + if (gap == 0){ + std::cerr << "end - beg == 0" << std::endl;; + } + if (gap == 1){ + return; + } + int pivot = random_pivot ? randint(beg+1,end-1) : beg + (end-beg)/2; + const Vector d = hasSliceBasis ? sliceBasis.col(height % dim) : getSlice(idA,idB,beg,end); + BSP(idA,idB,beg,end,pivot,d); + partialBSPOT(idA,idB,beg,pivot,height+1); + partialBSPOT(idA,idB,pivot,end,height+1); + } + + /* + void selectBSPOT(std::map& T,ids &idA, ids &idB, int beg, int end,std::set targets,int height = 0) { + auto gap = (end-beg); + if (gap == 0){ + std::cerr << "end - beg == 0" << std::endl;; + } + if (gap == 1){ + if (!targets.contains(idA[beg].id)) + std::cerr << "target not found" << std::endl;; + T[idA[beg].id] = idB[beg].id; + return; + } + int pivot = random_pivot ? randint(beg+1,end-1) : beg + (end-beg)/2; + const Vector d = hasSliceBasis ? 
sliceBasis.col(height % dim) : getSlice(idA,idB,beg,end); + BSP(idA,idB,beg,end,pivot,d); + std::set L,R; + for (auto i : range(beg,pivot)) + if (targets.contains(idA[i].id)) + L.insert(idA[i].id); + for (auto i : range(pivot,end)) + if (targets.contains(idA[i].id)) + R.insert(idA[i].id); + if (L.size()) + selectBSPOT(T,idA,idB,beg,pivot,L,height+1); + if (R.size()) + selectBSPOT(T,idA,idB,pivot,end,R,height+1); + } + */ + + + + void partialOrthogonalBSPOT(ids &idA, ids &idB, int beg, int end,Vector prev_slice) { + auto gap = (end-beg); + if (gap == 0){ + std::cerr << "end - beg == 0" << std::endl;; + //return; + } + if (gap == 1){ + return; + } + int pivot = random_pivot ? randint(beg+1,end-1) : beg + (end-beg)/2; + Vector d = getSlice(idA,idB,beg,end); + d -= d.dot(prev_slice)*prev_slice/prev_slice.squaredNorm(); + d.normalized(); + BSP(idA,idB,beg,end,pivot,d); + partialOrthogonalBSPOT(idA,idB,beg,pivot,d); + partialOrthogonalBSPOT(idA,idB,pivot,end,d); + } + + + +public: + + BijectiveBSPMatching(const Pts& A_,const Pts& B_) : A(A_),B(B_) { + dim = A.rows(); + if (D != -1 && dim != D) { + std::cerr << "dynamic dimension is different from static one !" 
<< std::endl;; + } + } + + std::map quickselectTransport(const std::set& targets,const Mat& _sliceBasis) { + sliceBasis = _sliceBasis; + hasSliceBasis = true; + return quickselectTransport(targets); + } + + /* + std::map quickselectTransport(const std::set& targets) { + ids idA(A.cols()),idB(B.cols()); + for (auto i : range(A.cols())) { + idA[i].id = i; + idB[i].id = i; + } + std::map T; + selectBSPOT(T,idA,idB,0,A.cols(),targets); + return T; + } + */ + + + BijectiveMatching computeMatching(bool random_pivot = true){ + ids idA(A.cols()),idB(B.cols()); + for (auto i : range(A.cols())) { + idA[i].id = i; + idB[i].id = i; + } + + this->random_pivot = random_pivot; + partialBSPOT(idA,idB,0,A.cols()); + + ints plan = TransportPlan(A.cols()); + for (int i = 0;irandom_pivot = random_pivot_; + partialBSPOT(idA,idB,0,A.cols()); + + ints plan = TransportPlan(A.cols()); + for (int i = 0;i meanA = A.rowwise().mean(); + Vector meanB = B.rowwise().mean(); + Moments MA = {meanA,Covariance(A)}; + Moments MB = {meanB,Covariance(B)}; + + gaussianPartialBSPOT(idA,idB,0,A.cols(),MA,MB); + + + ints plan = TransportPlan(A.cols()); + for (int i = 0;i computeGaussianMatchingOrders(){ + ids idA(A.cols()),idB(B.cols()); + for (auto i : range(A.cols())) { + idA[i].id = i; + idB[i].id = i; + } + + random_pivot = false; + + Vector meanA = A.rowwise().mean(); + Vector meanB = B.rowwise().mean(); + Moments MA = {meanA,Covariance(A)}; + Moments MB = {meanB,Covariance(B)}; + + + + // partialBSPOT(idA,idB,0,A.cols()); + // partialOrthogonalBSPOT(idA,idB,0,A.cols(),sampleUnitGaussian(dim)); + gaussianPartialBSPOT(idA,idB,0,A.cols(),MA,MB); + ints OA(A.cols()),OB(A.cols()); + for (int i = 0;i + +namespace BSPOT { + +template +class GeneralBSPMatching { +public: +protected: + using Pts = Points; + + int dim; + + const Pts& A; + const Pts& B; + + Atoms mu,nu; + Atoms src_mu; + Atoms src_nu; + + struct CDFSplit { + int id; + scalar rho; + }; + + std::vector triplets; + + struct atom_split { + int id = 
-1; + scalar mass_left,mass_right; + }; + + Pts Grad; + scalar W = 0; + bool random_pivot = true; + Coupling coupling; + + struct SliceView { + const Atoms& id; + int b,e; + + int operator[](int i) const {return id[b + i].id;} + + int size() const {return e - b;} + }; + + +public: + + GeneralBSPMatching(const Pts& A_,const Atoms& MU,const Pts& B_,const Atoms& NU) : src_mu(MU),src_nu(NU),A(A_),B(B_) { + dim = A.rows(); + if (D != -1 && dim != D) { + std::cerr << "dynamic dimension is different from static one !" << std::endl;; + } + mu.resize(MU.size()); + nu.resize(NU.size()); + Grad = Pts::Zero(dim,MU.size()); + coupling = Coupling(mu.size(),nu.size()); + } + + GeneralBSPMatching(const Pts& A_,const Pts& B_,bool random_pivot = true) : A(A_),B(B_) { + dim = A.rows(); + if (D != -1 && dim != D) { + std::cerr << "dynamic dimension is different from static one !" << std::endl;; + } + Grad = Pts::Zero(dim,A.cols()); + coupling = Coupling(A.cols(),B.cols()); + } + +protected: + + + CDFSplit partition(Atoms &atoms, int beg, int end, int idx) { + scalar d = atoms[idx].dot; + int idmin = beg; + int idmax = end-1; + scalar sum_min = 0; + while (idmin < idmax) { + while (idmin < end && atoms[idmin].dot < d){ + sum_min += atoms[idmin].mass; + idmin++; + } + while (idmax >= beg && atoms[idmax].dot > d) + idmax--; + if (idmin >= idmax) + break; + if (idmin < idmax) + std::swap(atoms[idmin],atoms[idmax]); + } + return {idmin,sum_min}; + } + + CDFSplit quickCDF(Atoms &atoms, int beg, int end, scalar rho, scalar sum) { + if (end - beg == 1) + return {beg,sum}; + int idx = getRandomPivot(beg,end-1); + auto [p,sum_min] = partition(atoms,beg,end,idx); + if (sum_min >= rho){ + return quickCDF(atoms,beg,p,rho,sum); + } + else + return quickCDF(atoms,p,end,rho - sum_min,sum + sum_min); + } + + CDFSplit quickCDF(Atoms &atoms, int beg, int end, scalar rho) { + return quickCDF(atoms,beg,end,rho,0); + } + + int dotMedian(const Atoms &atoms, int a, int b, int c) { + const auto& da = 
atoms[a].dot; + const auto& db = atoms[b].dot; + const auto& dc = atoms[c].dot; + if ((da >= db && da <= dc) || (da >= dc && da <= db)) return a; + if ((db >= da && db <= dc) || (db >= dc && db <= da)) return b; + return c; + } + + CDFSplit partitionCDF(Atoms &atoms, int beg, int end) { + if (end - beg == 2) { + if (atoms[beg].dot > atoms[beg+1].dot) + std::swap(atoms[beg],atoms[beg+1]); + return {beg+1,atoms[beg].mass}; + } + int rand_piv = getRandomPivot(beg+1,end-2); + int piv = dotMedian(atoms,rand_piv,beg,end-1); + //spdlog::info("start partition b{} p{} e{}",beg,piv,end); + return partition(atoms,beg,end,piv); + } + + atom_split splitCDF(Atoms &atoms, int beg, int end, scalar rho) { + auto selected = quickCDF(atoms,beg,end,rho); + scalar mass_left = rho - selected.rho; + scalar mass_right = atoms[selected.id].mass - mass_left; + + return {selected.id,mass_left,mass_right}; + } + + void computeDots(Atoms &atoms, const Pts &X, int beg, int end, const Vector &d) { + for (auto i : range(beg,end)) + atoms[i].dot = X.col(atoms[i].id).dot(d) + i*1e-8; + } + + CovType slice_basis; + bool slice_basis_computed = false; + + Vector getSlice(const Atoms &m, int begA, int endA, const Atoms &n, int begB, int endB,int h) const + { + if (slice_basis_computed) + return slice_basis.col(h % dim); + if (endA - begA < 50 || endB - begB < 50) + return sampleUnitGaussian(dim); + return sampleUnitGaussian(dim); + CovType CovA = Cov(A,m,begA,endA); + CovType CovB = Cov(B,n,begB,endB); + CovType T = W2GaussianTransportMap(CovA,CovB); + Eigen::SelfAdjointEigenSolver solver(T); + return solver.eigenvectors().col(getRandomPivot(0,T.cols()-1)); + } + + int getRandomPivot(int beg, int end) const { + if (beg == end) + return beg; + if (end < beg) + std::cerr << "invalid pivot range" << std::endl;; + static thread_local std::random_device rd; + static thread_local std::mt19937 rng(rd()); + std::uniform_int_distribution gen(beg, end); // uniform, unbiased + return gen(rng); + } + + bool 
checkMassLeak(int begA, int endA, int begB, int endB) const { + scalar sumA = 0,sumB = 0; + for (auto i : range(begA,endA)) + sumA += mu[i].mass; + for (auto i : range(begB,endB)) + sumB += nu[i].mass; + if (std::abs(sumA - sumB) > 1e-8){ + return true; + } + return false; + } + + void partialBSPOT(int begA, int endA, int begB, int endB,int height = 0) { + int gapA = endA - begA; + int gapB = endB - begB; + + if (gapA == 0 || gapB == 0){ + std::cerr << "null gap" << std::endl;; + return; + } + + // checkMassLeak(begA,endA,begB,endB); + + + if (gapA == 1) { + for (auto i : range(begB,endB)) { + if (nu[i].mass < 1e-12) + continue; + Grad.col(mu[begA].id) += (B.col(nu[i].id) - A.col(mu[begA].id))*nu[i].mass; + triplet t = {mu[begA].id,nu[i].id,nu[i].mass}; + triplets.push_back(t); + } + return; + } + if (gapB == 1) { + for (auto i : range(begA,endA)) { + if (mu[i].mass < 1e-12) + continue; + Grad.col(mu[i].id) += (B.col(nu[begB].id) - A.col(mu[i].id))*mu[i].mass; + triplet t = {mu[i].id,nu[begB].id,mu[i].mass}; + triplets.push_back(t); + } + return; + } + const Vector d = getSlice(mu,begA,endA,nu,begB,endB,height); + + computeDots(mu,A,begA,endA,d); + computeDots(nu,B,begB,endB,d); + + CDFSplit CDFS; + if (random_pivot) { + CDFS = partitionCDF(mu,begA,endA); + } + else { + scalar sumA = 0; + for (auto i : range(begA,endA)) + sumA += mu[i].mass; + CDFS = quickCDF(mu,begA,endA,0.5*sumA); + if (CDFS.id == begA) { + CDFS.rho = mu[CDFS.id].mass; + CDFS.id++; + } + } + int p = CDFS.id; + scalar rho = CDFS.rho; + auto split = splitCDF(nu,begB,endB,rho); + int splitted_atom = nu[split.id].id; + + nu[split.id].mass = split.mass_left; + partialBSPOT(begA,p,begB,split.id+1,height + 1); + + nu[split.id].id = splitted_atom; + nu[split.id].mass = split.mass_right; + partialBSPOT(p,endA,split.id,endB,height + 1); + } + + void init() { + for (auto i : range(src_mu.size())) + mu[i] = src_mu[i]; + for (auto i : range(src_nu.size())) + nu[i] = src_nu[i]; + Grad = 
Pts::Zero(dim,A.cols()); + triplets.clear(); + coupling.setZero(); + } + + void setMeasures(const Atoms &mu_, const Atoms &nu_) + { + src_mu = mu_; + src_nu = nu_; + mu.resize(mu_.size()); + nu.resize(nu_.size()); + } + + Moments computeMoments(const Pts& X,const Atoms& id,int b,int e) const { + Vec masses(e-b); + scalar S = 0; + for (auto i : range(b,e)){ + masses(i) = id[i].mass; + S += id[i].mass; + } + Eigen::DiagonalMatrix M = (masses/S).asDiagonal(); + SliceView view(id,b,e); + Pts sub = X(Eigen::all,view); + Pts wsub = sub*M; + Vector mean = wsub.rowwise().sum(); + Pts centered = sub.colwise() - mean; + CovType rslt = (centered*M) * centered.adjoint() / double(e-b); + return {mean,rslt}; + + } + + Vector getMean(const Pts &X, const Atoms &id, int b, int e) const + { + Vector m = Vector::Zero(dim); + scalar s = 0; + for (auto i : range(b,e)) { + m += X.col(id[i].id)*id[i].mass; + s += id[i].mass; + } + return m/s; + } + + CovType Cov(const Pts &X, const Atoms &atoms, int b, int e) const + { + Vector m = getMean(X,atoms,b,e); + CovType Cov = CovType::Zero(dim,dim); + scalar s = 0; + for (auto i : range(b,e)) { + Vector x = X.col(atoms[i].id) - m; + Cov.noalias() += x*x.transpose()*atoms[i].mass; + s += atoms[i].mass; + } + return Cov/s; + } + +public: + + const Coupling &computeCoupling(bool rp = true){ + init(); + random_pivot = rp; + if (checkMassLeak(0,src_mu.size(),0,src_nu.size())) { + std::cerr << "cannot compute plan to unbalanced marginals" << std::endl;; + } + partialBSPOT(0,src_mu.size(),0,src_nu.size()); + coupling.setFromTriplets(triplets.begin(),triplets.end()); + //coupling.makeCompressed(); + return coupling; + } + + const Coupling &computeOrthogonalCoupling(const CovType& slice_basis = CovType::Identity(D,D)){ + this->slice_basis = slice_basis; + slice_basis_computed = true; + return computeCoupling(false); + } + + + const Pts &computeTransportGradient(bool random_pivot = true){ + init(); + this->random_pivot = random_pivot; + 
partialBSPOT(0,src_mu.size(),0,src_nu.size()); + for (auto i : range(src_mu.size())) + Grad.col(i) /= src_mu[i].mass; + return Grad; + } + + const Pts &computeOrthogonalTransportGradient(const CovType& slice_basis = CovType::Identity(D,D),bool rp = false){ + this->slice_basis = slice_basis; + slice_basis_computed = true; + return computeTransportGradient(rp); + } +}; + +} + +#endif // GENERALBSPMATCHING_H + + +// end --- GeneralBSPMatching.h --- + + + +// begin --- BSPOTWrapper.h --- + +#ifndef BSPOTWRAPPER_H +#define BSPOTWRAPPER_H + +namespace BSPOT { + +/* +BijectiveMatching MergePlans(const std::vector& plans,const cost_function& cost,BijectiveMatching T = BijectiveMatching()) { + std::vector> scores(plans.size()); +#pragma omp parallel for + for (int i = 0;i +BijectiveMatching computeGaussianBSPOT(const Points& A,const Points& B,int nb_plans,const cost_function& cost,BijectiveMatching T = BijectiveMatching()) { + std::vector plans(nb_plans); + BijectiveBSPMatching BSP(A,B); +#pragma omp parallel for + for (int i = 0;i +BijectiveMatching computeBijectiveBSPOT(const Points& A,const Points& B,int nb_plans,const cost_function& cost,BijectiveMatching T = BijectiveMatching()) { + std::vector plans(nb_plans); + BijectiveBSPMatching BSP(A,B); + int d = A.rows(); +#pragma omp parallel for + for (int i = 0;i +BijectiveMatching computeBijectiveOrthogonalBSPOT(const Points& A,const Points& B,int nb_plans,const cost_function& cost,BijectiveMatching T = BijectiveMatching()) { + std::vector plans(nb_plans); + BijectiveBSPMatching BSP(A,B); + int d = A.rows(); +#pragma omp parallel for + for (int i = 0;i +Coupling computeBSPOTCoupling(const Points& A,const Atoms& mu,const Points& B,const Atoms& nu) { + GeneralBSPMatching BSP(A,mu,B,nu); + return BSP.computeCoupling(); +} + +template +Points computeBSPOTGradient(const Points& A,const Atoms& mu,const Points& B,const Atoms& nu,int nb_plans) { + Points Grad = Points::Zero(A.rows(),A.cols()); + int d = A.rows(); +#pragma omp 
parallel for + for (int i = 0;i Grad_i = BSP.computeTransportGradient(); + #pragma omp critical + { + Grad += Grad_i/nb_plans; + } + } + return Grad; +} + + +template +InjectiveMatching computePartialBSPOT(const Points& A,const Points& B,int nb_plans,const cost_function& cost,InjectiveMatching T = InjectiveMatching()) { + std::vector plans(nb_plans); + PartialBSPMatching BSP(A,B,cost); +#pragma omp parallel for + for (int i = 0;i +InjectiveMatching computePartialOrthogonalBSPOT(const Points& A,const Points& B,int nb_plans,const cost_function& cost,InjectiveMatching T = InjectiveMatching()) { + std::vector plans(nb_plans); + PartialBSPMatching BSP(A,B,cost); +#pragma omp parallel for + for (int i = 0;i Q = sampleUnitGaussianMat(dim,dim).fullPivHouseholderQr().matrixQ(); + plans[i] = BSP.computePartialMatching(Q,false); + } + return MergePlans(plans,cost,T); + // InjectiveMatching plan = T; + // for (int i = 0;i +# +# License: MIT License + +import numpy as np +cimport numpy as np + +cimport cython +cimport libc.math as math +from libc.stdint cimport uint64_t + + +cdef extern from "bsp_wrapper.h": + double BSPOT_wrap(int n, int d, double *X, double *Y, uint64_t nb_plans, int *plans, int *plan,const char* cost_name,int* initial_plan) + double MergeBijections(int n, int d, double *X, double *Y, uint64_t nb_plans, int *plans, int *plan,const char* cost_name) + + +@cython.boundscheck(False) +@cython.wraparound(False) +def bsp_solve(np.ndarray[double, ndim=2, mode="c"] X, np.ndarray[double, ndim=2, mode="c"] Y, int n_plans=64,str cost_name="sqnorm",np.ndarray[int,ndim=1,mode="c"] initial_plan = None): + """ + + Builds nb_plans BSP Matchings and merges them in a single bijection. 
+ + cost,plan,plans = bsp_solve(X,Y,n_plans) + + where : + + - X and Y are the input point clouds + - n_plans is the number of BSP Matchings used to compute the final bijection + + Returns the transport cost of the final bijection, the final bijection, and the intermediary ones + + """ + cdef int n = X.shape[0] + cdef int d = X.shape[1] + cdef np.ndarray[int, ndim=2, mode="c"] plans = np.zeros((n, n_plans), dtype=np.int32) + cdef np.ndarray[int, ndim=1, mode="c"] plan = np.zeros(n, dtype=np.int32) + + cdef double cost + + cdef bytes cost_bytes = cost_name.encode("utf-8") + cdef const char* cost_c = cost_bytes + + if initial_plan is None: + cost = BSPOT_wrap(n, d, X.data, Y.data, n_plans, plans.data, plan.data,cost_c, NULL) + else: + cost = BSPOT_wrap(n, d, X.data, Y.data, n_plans, plans.data, plan.data,cost_c, initial_plan.data) + + # add + + return cost, plan, plans + + +@cython.boundscheck(False) +@cython.wraparound(False) +def merge_bijections(np.ndarray[double, ndim=2, mode="c"] X, np.ndarray[double, ndim=2, mode="c"] Y, np.ndarray[int, ndim=2, mode="c"] plans,str cost = "sqnorm"): + """ + Merges transport bijections + + where : + + - X and Y are the input point clouds + - plans input bijections + - metric name, by default "sqnorm" + + Returns the merged bijection and its transport cost. 
+ """ + + cdef int n = X.shape[0] + cdef int d = X.shape[1] + cdef int k = plans.shape[1] + cdef np.ndarray[int, ndim=1, mode="c"] plan = np.zeros(n, dtype=np.int32) + + cdef double cost_val + + cdef bytes cost_bytes = cost.encode("utf-8") + cdef const char* cost_c = cost_bytes + + # add merging code here + + cost_val = MergeBijections(n, d, X.data, Y.data, k, plans.data, plan.data,cost_c) + + + return cost_val,plan + + + \ No newline at end of file diff --git a/ot/bsp/bsp_wrapper.cpp b/ot/bsp/bsp_wrapper.cpp new file mode 100644 index 000000000..98f4fcc22 --- /dev/null +++ b/ot/bsp/bsp_wrapper.cpp @@ -0,0 +1,157 @@ +#include "bsp_wrapper.h" +#include "BSP-OT_header_only.h" + +template +BSPOT::Points UnLinearize(double* data,int n,int d) { + return Eigen::Map>(data, d, n).template cast(); +} + +template +std::function makeCost(const BSPOT::Points& A,const BSPOT::Points& B,std::string cost){ + if (cost == "sqnorm") { + return [&](int i,int j) { + return (A.col(i) - B.col(j)).squaredNorm(); + }; + } + if (cost == "norm") { + return [&](int i,int j) { + return (A.col(i) - B.col(j)).norm(); + }; + } + return [&](int i,int j) { + return (A.col(i) - B.col(j)).squaredNorm(); + }; +} + + +template +std::vector computeBSPMatchings_dim(const BSPOT::Points& A,const BSPOT::Points& B,int nb_plans, bool gaussian = true){ + using namespace BSPOT; + + if (A.rows() > 64) + gaussian = false; + + std::vector plans(nb_plans); + BijectiveBSPMatching BSP(A,B); +#pragma omp parallel for + for (auto& plan : plans) { + if (gaussian) + plan = BSP.computeGaussianMatching(); + else + plan = BSP.computeMatching(true); + } + return plans; +} + + +BSPOT::BijectiveMatching MergeBijections(const std::vector& matchings,const std::function& cost,const BSPOT::BijectiveMatching& T0 = BSPOT::BijectiveMatching()) { + using namespace BSPOT; + return MergePlansNoPar(matchings,cost,T0); +} + + +template +double BSPOT_wrap_dim(int n, int d, double *X, double *Y, uint64_t nb_plans,std::vector& plans, 
+                      BSPOT::BijectiveMatching& plan, std::string cost_name, const BSPOT::BijectiveMatching& T0) {
+    auto A = UnLinearize<dim>(X, n, d);
+    auto B = UnLinearize<dim>(Y, n, d);
+    plans = computeBSPMatchings_dim(A, B, nb_plans);
+    auto cost_func = makeCost(A, B, cost_name);
+    plan = MergeBijections(plans, cost_func, T0);
+
+    return plan.evalMatching(cost_func);
+}
+
+double BSPOT_wrap(int n, int d, double *X, double *Y, uint64_t nb_plans, int *plans_ptr, int *final_plan_ptr, const char* cn, int* initial_plan) {
+    using namespace BSPOT;
+
+    std::string cost_name(cn);
+
+    BijectiveMatching T0;
+
+    if (initial_plan){
+        ints t0(n);
+        std::copy(initial_plan, initial_plan + n, t0.begin());
+        T0 = BijectiveMatching(t0);
+    }
+
+
+    std::vector<BijectiveMatching> plans;
+    BijectiveMatching plan;
+    scalar cost;
+
+    switch (d)
+    {
+    case 2: {cost = BSPOT_wrap_dim<2>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 3: {cost = BSPOT_wrap_dim<3>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 4: {cost = BSPOT_wrap_dim<4>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 5: {cost = BSPOT_wrap_dim<5>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 6: {cost = BSPOT_wrap_dim<6>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 7: {cost = BSPOT_wrap_dim<7>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 8: {cost = BSPOT_wrap_dim<8>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 9: {cost = BSPOT_wrap_dim<9>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    case 10: {cost = BSPOT_wrap_dim<10>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    default: {cost = BSPOT_wrap_dim<-1>(n,d,X,Y,nb_plans,plans,plan,cost_name,T0);break;}
+    }
+
+    std::copy(plan.getPlan().begin(), plan.getPlan().end(), final_plan_ptr);
+
+    int* dst = plans_ptr;
+    for (const auto& p : plans)
+    {
+        std::copy(p.getPlan().begin(), p.getPlan().end(), dst);
+        dst += p.size(); // == M
+    }
+
+    return cost;
+}
+
+
+template<int dim>
+double MergeBijections_dim(int n, int d, double *X, double *Y, uint64_t nb_plans, const
+                           std::vector<BSPOT::BijectiveMatching>& plans, BSPOT::BijectiveMatching& plan, std::string cost_name) {
+    auto A = UnLinearize<dim>(X, n, d);
+    auto B = UnLinearize<dim>(Y, n, d);
+    auto cost_func = makeCost(A, B, cost_name);
+    plan = MergeBijections(plans, cost_func);
+
+    return plan.evalMatching(cost_func);
+}
+
+
+double MergeBijections(int n, int d, double *X, double *Y, uint64_t nb_plans, int *plans_ptr, int *final_plan_ptr, const char* cn) {
+    using namespace BSPOT;
+
+    std::string cost_name(cn);
+
+    std::vector<BijectiveMatching> plans(nb_plans);
+
+    for (std::size_t i = 0; i < nb_plans; ++i)
+    {
+        std::vector<int> bij(n);
+        std::copy(
+            plans_ptr + i * n,
+            plans_ptr + (i + 1) * n,
+            bij.begin()
+        );
+        plans[i] = BijectiveMatching(bij);
+    }
+
+    BijectiveMatching plan;
+    scalar cost;
+
+    switch (d)
+    {
+    case 2: {cost = MergeBijections_dim<2>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 3: {cost = MergeBijections_dim<3>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 4: {cost = MergeBijections_dim<4>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 5: {cost = MergeBijections_dim<5>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 6: {cost = MergeBijections_dim<6>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 7: {cost = MergeBijections_dim<7>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 8: {cost = MergeBijections_dim<8>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 9: {cost = MergeBijections_dim<9>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    case 10: {cost = MergeBijections_dim<10>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    default: {cost = MergeBijections_dim<-1>(n,d,X,Y,nb_plans,plans,plan,cost_name);break;}
+    }
+
+    std::copy(plan.getPlan().begin(), plan.getPlan().end(), final_plan_ptr);
+    return cost;
+}
diff --git a/ot/bsp/bsp_wrapper.h b/ot/bsp/bsp_wrapper.h
new file mode 100644
index 000000000..853c7afb5
--- /dev/null
+++ b/ot/bsp/bsp_wrapper.h
@@ -0,0 +1,12 @@
+#ifndef BSP_WRAPPER_H
+#define BSP_WRAPPER_H
+
+#include <cstdint>
+#include <cstddef>
+
+
+
+double BSPOT_wrap(int n,
+                  int d, double *X, double *Y, uint64_t nb_plans, int *plans, int *plan, const char* cost, int* initial_plan);
+double MergeBijections(int n, int d, double *X, double *Y, uint64_t nb_plans, int *plans, int *plan, const char* cost);
+
+#endif
diff --git a/setup.py b/setup.py
index acbe5aed9..74b6b5ef2 100644
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,7 @@
 link_args += flags
 
 if sys.platform.startswith("darwin"):
-    compile_args.append("-stdlib=libc++")
+    compile_args.append("-std=c++17")  # Need for ot/bsp
     sdk_path = subprocess.check_output(["xcrun", "--show-sdk-path"])
     os.environ["CFLAGS"] = '-isysroot "{}"'.format(sdk_path.rstrip().decode("utf-8"))
@@ -85,6 +85,21 @@
             extra_compile_args=compile_args,
             language="c++",
         ),
+        Extension(
+            name="ot.bsp.bsp_wrap",
+            sources=[
+                "ot/bsp/bsp_wrap.pyx",
+                "ot/bsp/bsp_wrapper.cpp",
+            ],  # cython/c++ src files
+            language="c++",
+            include_dirs=[
+                numpy.get_include(),
+                os.path.join(ROOT, "deps/eigen"),
+                os.path.join(ROOT, "ot/lp"),
+            ],
+            extra_compile_args=compile_args,
+            extra_link_args=link_args,
+        ),
     ]
 ),
 platforms=["linux", "macosx", "windows"],