Skip to content

Commit b3be5a6

Browse files
committed
gaussian init
1 parent 53dde7a commit b3be5a6

File tree

6 files changed

+109
-7
lines changed

6 files changed

+109
-7
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
334334

335335
[59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017.
336336

337+
[60] Thornton, James, and Marco Cuturi. [Rethinking initialization of the sinkhorn algorithm](https://arxiv.org/pdf/2206.07630.pdf). International Conference on Artificial Intelligence and Statistics. PMLR, 2023.

ot/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d,
4141
binary_search_circle, wasserstein_circle,
4242
semidiscrete_wasserstein2_unif_circle)
43-
from .bregman import sinkhorn, sinkhorn2, barycenter
43+
from .bregman import (sinkhorn, sinkhorn2, barycenter, empirical_sinkhorn, empirical_sinkhorn2, empirical_sinkhorn_divergence)
4444
from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced,
4545
sinkhorn_unbalanced2)
4646
from .da import sinkhorn_lpl1_mm
@@ -61,6 +61,7 @@
6161
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
6262
'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian',
6363
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
64+
'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn_divergence',
6465
'sinkhorn_unbalanced', 'barycenter_unbalanced',
6566
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere',
6667
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein',

ot/bregman.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ot.utils import dist, list_to_array, unif
2424

2525
from .backend import get_backend
26+
from .gaussian import dual_gaussian_init
2627

2728

2829
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
@@ -541,6 +542,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
541542
log['niter'] = ii
542543
log['u'] = u
543544
log['v'] = v
545+
log['warmstart'] = (nx.log(u), nx.log(v))
544546

545547
if n_hists: # return only loss
546548
res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -697,6 +699,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
697699
'log_v': nx.stack(lst_v, 1), }
698700
log['u'] = nx.exp(log['log_u'])
699701
log['v'] = nx.exp(log['log_v'])
702+
log['warmstart'] = (log['log_u'], log['log_v'])
700703
return res, log
701704
else:
702705
return res
@@ -2999,15 +3002,23 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
29993002
if b is None:
30003003
b = nx.from_numpy(unif(nt), type_as=X_s)
30013004

3005+
if warmstart is None:
3006+
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
3007+
elif warmstart == 'gaussian':
3008+
# gaussian initialization of both dual potentials from the Bures-Wasserstein mapping
3009+
f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
3010+
g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])
3011+
elif (isinstance(warmstart, tuple) or isinstance(warmstart, list)) and len(warmstart) == 2:
3012+
f, g = warmstart
3013+
else:
3014+
raise ValueError(
3015+
"warmstart must be None, 'gaussian' or a tuple of two arrays")
3016+
30023017
if isLazy:
30033018
if log:
30043019
dict_log = {"err": []}
30053020

30063021
log_a, log_b = nx.log(a), nx.log(b)
3007-
if warmstart is None:
3008-
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
3009-
else:
3010-
f, g = warmstart
30113022

30123023
if isinstance(batchSize, int):
30133024
bs, bt = batchSize, batchSize
@@ -3075,6 +3086,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
30753086
if log:
30763087
dict_log["u"] = f
30773088
dict_log["v"] = g
3089+
dict_log["warmstart"] = (f, g)
30783090
return (f, g, dict_log)
30793091
else:
30803092
return (f, g)
@@ -3083,11 +3095,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
30833095
M = dist(X_s, X_t, metric=metric)
30843096
if log:
30853097
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
3086-
verbose=verbose, log=True, warmstart=warmstart, **kwargs)
3098+
verbose=verbose, log=True, warmstart=(f, g), **kwargs)
30873099
return pi, log
30883100
else:
30893101
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
3090-
verbose=verbose, log=False, warmstart=warmstart, **kwargs)
3102+
verbose=verbose, log=False, warmstart=(f, g), **kwargs)
30913103
return pi
30923104

30933105

@@ -3201,6 +3213,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
32013213
if b is None:
32023214
b = nx.from_numpy(unif(nt), type_as=X_s)
32033215

3216+
if warmstart is None:
3217+
warmstart = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
3218+
elif warmstart == 'gaussian':
3219+
# gaussian initialization of both dual potentials from the Bures-Wasserstein mapping
3220+
f = dual_gaussian_init(X_s, X_t, a[:, None], b[:, None])
3221+
g = dual_gaussian_init(X_t, X_s, b[:, None], a[:, None])
3222+
warmstart = (f, g)
3223+
elif (isinstance(warmstart, tuple) or isinstance(warmstart, list)) and len(warmstart) == 2:
3224+
warmstart = warmstart
3225+
else:
3226+
raise ValueError(
3227+
"warmstart must be None, 'gaussian' or a tuple of two arrays")
3228+
32043229
if isLazy:
32053230
if log:
32063231
f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,

ot/gaussian.py

+48
Original file line numberDiff line numberDiff line change
@@ -645,3 +645,51 @@ def empirical_gaussian_gromov_wasserstein_mapping(xs, xt, ws=None,
645645
return A, b, log
646646
else:
647647
return A, b
648+
649+
650+
def dual_gaussian_init(xs, xt, ws=None, wt=None, reg=1e-6):
    r"""Return the source dual potential Gaussian initialization.

    This function returns the dual potential Gaussian initialization that can
    be used to initialize the Sinkhorn algorithm. This initialization is based
    on the Monge mapping between the source and target distributions seen as
    two Gaussian distributions [60].

    Parameters
    ----------
    xs : array-like (ns,ds)
        samples in the source domain
    xt : array-like (nt,dt)
        samples in the target domain
    ws : array-like (ns,1), optional
        weights for the source samples (default: uniform weights)
    wt : array-like (nt,1), optional
        weights for the target samples (default: uniform weights)
    reg : float, optional
        regularization added to the diagonals of covariances (>0)

    Returns
    -------
    f : array-like (ns,)
        dual potential evaluated at the source samples

    References
    ----------
    .. [60] Thornton, James, and Marco Cuturi. "Rethinking initialization of the
        sinkhorn algorithm." International Conference on Artificial Intelligence
        and Statistics. PMLR, 2023.
    """

    nx = get_backend(xs, xt)

    # default to uniform weights over the samples
    if ws is None:
        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]

    if wt is None:
        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]

    # weighted empirical means of the source and target clouds
    mu_s = nx.dot(ws.T, xs) / nx.sum(ws)
    mu_t = nx.dot(wt.T, xt) / nx.sum(wt)

    # linear Bures-Wasserstein (Monge) mapping between the fitted Gaussians
    A, b = empirical_bures_wasserstein_mapping(xs, xt, ws=ws, wt=wt, reg=reg)

    # center the source samples before applying the quadratic form with A
    xsc = xs - mu_s

    # compute the dual potential (see appendix D in [60])
    f = nx.sum(xs**2 - nx.dot(xsc, A) * xsc - mu_t * xs, 1)

    return f

test/test_bregman.py

+8
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,9 @@ def test_empirical_sinkhorn(nx):
10411041
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
10421042
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
10431043

1044+
loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy(
1045+
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian'))
1046+
10441047
# check constraints
10451048
np.testing.assert_allclose(
10461049
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
@@ -1055,6 +1058,7 @@ def test_empirical_sinkhorn(nx):
10551058
np.testing.assert_allclose(
10561059
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
10571060
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
1061+
np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05)
10581062

10591063

10601064
def test_lazy_empirical_sinkhorn(nx):
@@ -1095,6 +1099,9 @@ def test_lazy_empirical_sinkhorn(nx):
10951099
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
10961100
loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
10971101

1102+
loss_emp_sinkhorn_gausss_warmstart = nx.to_numpy(
1103+
ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, warmstart='gaussian', isLazy=True))
1104+
10981105
# check constraints
10991106
np.testing.assert_allclose(
11001107
sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian
@@ -1109,6 +1116,7 @@ def test_lazy_empirical_sinkhorn(nx):
11091116
np.testing.assert_allclose(
11101117
sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian
11111118
np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
1119+
np.testing.assert_allclose(loss_emp_sinkhorn_gausss_warmstart, loss_sinkhorn, atol=1e-05)
11121120

11131121

11141122
def test_empirical_sinkhorn_divergence(nx):

test/test_gaussian.py

+19
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,22 @@ def test_gaussian_gromov_wasserstein_mapping(nx, d_target):
175175

176176
if d_target >= 2:
177177
np.testing.assert_allclose(Cs, Ctt)
178+
179+
180+
def test_gaussian_init(nx):
    """Default (None) weights must match explicitly passed uniform weights."""
    n_source, n_target = 50, 50

    X_source, _ = make_data_classif('3gauss', n_source)
    X_target, _ = make_data_classif('3gauss2', n_target)

    # uniform weight columns, built explicitly
    w_source = np.full((n_source, 1), 1.0 / n_source)
    w_target = np.full((n_target, 1), 1.0 / n_target)

    Xs_b, Xt_b, ws_b, wt_b = nx.from_numpy(X_source, X_target, w_source, w_target)

    # potential with implicit (default) uniform weights
    pot_default = ot.gaussian.dual_gaussian_init(Xs_b, Xt_b)

    # potential with the same uniform weights passed explicitly
    pot_explicit = ot.gaussian.dual_gaussian_init(Xs_b, Xt_b, ws_b, wt_b)

    np.testing.assert_allclose(nx.to_numpy(pot_default), nx.to_numpy(pot_explicit))

0 commit comments

Comments
 (0)