Skip to content

Commit 53dde7a

Browse files
authored
[MRG] Add LazyTensor for large scale OT (#544)
* add LazyTensor * test marginals for OTResult * test marginals for OTResult * debug tensorflow and implement reduce function for LazyTensors * cleanup tensorflow stuff and debug backend detection in reduce function * comment agramfor * remove trailing space * all comments alex * add simple low rank LazyTensor creator and add it to log for factored OT * better tests coverage * update doc function * pep8
1 parent 6a29551 commit 53dde7a

File tree

5 files changed

+441
-4
lines changed

5 files changed

+441
-4
lines changed

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
+ New LP solvers from scipy used by default for LP barycenter (PR #537)
1313
+ Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543)
1414
+ Upgraded unbalanced OT solvers for more flexibility (PR #539)
15+
+ Add LazyTensor for modeling plans and low rank tensor in large scale OT (PR #544)
1516

1617
#### Closed issues
1718
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

ot/factored.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# License: MIT License
88

99
from .backend import get_backend
10-
from .utils import dist
10+
from .utils import dist, get_lowrank_lazytensor
1111
from .lp import emd
1212
from .bregman import sinkhorn
1313

@@ -139,6 +139,7 @@ def solve_ot(X1, X2, w1, w2):
139139
'vb': logb['v'],
140140
'costa': loga['cost'],
141141
'costb': logb['cost'],
142+
'lazy_plan': get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx),
142143
}
143144
return Ga, Gb, X, log_dic
144145

ot/utils.py

+269-3
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,121 @@ def get_coordinate_circle(x):
492492
return x_t
493493

494494

495+
def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
    """ Reduce a LazyTensor along an axis with function func using batches.

    When axis=None, the LazyTensor is reduced to a single value computed as
    the sum of ``func(batch)`` over batches ``a[i:i + batch_size]`` taken
    along axis 0.

    When axis is 0 or 1, the batches are taken along the *other* axis
    (column blocks ``a[:, j:j + batch_size]`` for axis=0, row blocks
    ``a[i:i + batch_size]`` for axis=1), each batch is fully reduced along
    `axis` with func and the partial results are assembled in order.

    .. warning::
        This function works for tensors of any order but the reduction can be
        done only along the first two axes (or globally). Also, in order to
        work, it requires that every batch described above (of size
        `batch_size` along the batched axis and complete along all the other
        axes) can be computed and fits in memory.

    Parameters
    ----------
    a : LazyTensor
        LazyTensor to reduce
    func : callable
        Function to apply to the batches. It is called as ``func(batch)``
        when axis=None and with the axis as second argument otherwise.
    axis : int, optional
        Axis along which to reduce the LazyTensor. If None, reduce the
        LazyTensor to a scalar. If 0 or 1 reduce the LazyTensor to a
        vector/matrix.
    nx : Backend, optional
        Backend to use for the reduction. When None it is detected from the
        first slice ``a[0]``.
    batch_size : int, optional
        Size of the batches to use for the reduction (default=100)

    Returns
    -------
    res : array-like
        Result of the reduction

    Raises
    ------
    ValueError
        If `batch_size` is not strictly positive.
    NotImplementedError
        If `axis` is not None, 0 or 1.
    """

    # a non-positive step would either crash in range() or silently skip
    # every batch and return zeros
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    if nx is None:
        nx = get_backend(a[0])

    if axis is None:
        res = 0.0
        for i in range(0, a.shape[0], batch_size):
            res += func(a[i:i + batch_size])
        return res

    if axis == 0:
        # reduce over axis 0: batches are blocks along axis 1
        red_axis = 0
        n_batched = a.shape[1]
        out_shape = a.shape[1:]

        def get_batch(start):
            return a[:, start:start + batch_size]

    elif axis == 1:
        # reduce over axis 1: batches are blocks along axis 0
        red_axis = 1
        n_batched = a.shape[0]
        out_shape = (a.shape[0], *a.shape[2:])

        def get_batch(start):
            return a[start:start + batch_size]

    else:
        raise NotImplementedError("Only axis=None, 0 or 1 is implemented for now.")

    starts = range(0, n_batched, batch_size)

    if nx.__name__ in ["jax", "tf"]:
        # these backends take the concatenation path instead of writing into
        # a preallocated buffer, so no buffer (and no a[0]) is computed here
        return nx.concatenate(
            [func(get_batch(start), red_axis) for start in starts], axis=0
        )

    # other backends: fill a preallocated result batch by batch
    res = nx.zeros(out_shape, type_as=a[0])
    for start in starts:
        res[start:start + batch_size] = func(get_batch(start), axis=red_axis)
    return res
566+
567+
568+
def get_lowrank_lazytensor(Q, R, d=None, nx=None):
    """ Build a low rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T

    The factors are stored on the returned LazyTensor and the requested
    entries are only computed when the tensor is indexed.

    Parameters
    ----------
    Q : ndarray, shape (n, r)
        First factor of the lowrank tensor
    R : ndarray, shape (m, r)
        Second factor of the lowrank tensor
    d : ndarray, shape (r,), optional
        Diagonal of the lowrank tensor. When None, no diagonal scaling is
        applied.
    nx : Backend, optional
        Backend used to compute the dot products. When None it is obtained
        from ``get_backend(Q, R, d)``.

    Returns
    -------
    T : LazyTensor
        Lowrank tensor T=Q@R^T or T=Q@diag(d)@R^T, of shape (n, m)
    """

    if nx is None:
        nx = get_backend(Q, R, d)

    tensor_shape = (Q.shape[0], R.shape[0])

    if d is not None:

        def scaled_block(i, j, Q, R, d):
            # rows of Q selected by i, rescaled by d, against rows of R
            # selected by j
            return nx.dot(Q[i] * d[None, :], R[j].T)

        return LazyTensor(tensor_shape, scaled_block, Q=Q, R=R, d=d)

    def plain_block(i, j, Q, R):
        # rows of Q selected by i against rows of R selected by j
        return nx.dot(Q[i], R[j].T)

    return LazyTensor(tensor_shape, plain_block, Q=Q, R=R)
608+
609+
495610
def get_parameter_pair(parameter):
496611
r"""Extract a pair of parameters from a given parameter
497612
Used in unbalanced OT and COOT solvers
@@ -761,7 +876,76 @@ class UndefinedParameter(Exception):
761876

762877

763878
class OTResult:
764-
def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
879+
""" Base class for OT results.
880+
881+
Parameters
882+
----------
883+
884+
potentials : tuple of array-like, shape (`n1`, `n2`)
885+
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
886+
This pair of arrays has the same shape, numerical type
887+
and properties as the input weights "a" and "b".
888+
value : float, array-like
889+
Full transport cost, including possible regularization terms and
890+
quadratic term for Gromov Wasserstein solutions.
891+
value_linear : float, array-like
892+
The linear part of the transport cost, i.e. the product between the
893+
transport plan and the cost.
894+
value_quad : float, array-like
895+
The quadratic part of the transport cost for Gromov-Wasserstein
896+
solutions.
897+
plan : array-like, shape (`n1`, `n2`)
898+
Transport plan, encoded as a dense array.
899+
log : dict
900+
Dictionary containing potential information about the solver.
901+
backend : Backend
902+
Backend used to compute the results.
903+
sparse_plan : array-like, shape (`n1`, `n2`)
904+
Transport plan, encoded as a sparse array.
905+
lazy_plan : LazyTensor
906+
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
907+
status : int or str
908+
Status of the solver.
909+
batch_size : int
910+
Batch size used to compute the results/marginals for LazyTensor.
911+
912+
Attributes
913+
----------
914+
915+
potentials : tuple of array-like, shape (`n1`, `n2`)
916+
Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
917+
This pair of arrays has the same shape, numerical type
918+
and properties as the input weights "a" and "b".
919+
potential_a : array-like, shape (`n1`,)
920+
First dual potential, associated to the "source" measure "a".
921+
potential_b : array-like, shape (`n2`,)
922+
Second dual potential, associated to the "target" measure "b".
923+
value : float, array-like
924+
Full transport cost, including possible regularization terms and
925+
quadratic term for Gromov Wasserstein solutions.
926+
value_linear : float, array-like
927+
The linear part of the transport cost, i.e. the product between the
928+
transport plan and the cost.
929+
value_quad : float, array-like
930+
The quadratic part of the transport cost for Gromov-Wasserstein
931+
solutions.
932+
plan : array-like, shape (`n1`, `n2`)
933+
Transport plan, encoded as a dense array.
934+
sparse_plan : array-like, shape (`n1`, `n2`)
935+
Transport plan, encoded as a sparse array.
936+
lazy_plan : LazyTensor
937+
Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
938+
marginals : tuple of array-like, shape (`n1`,), (`n2`,)
939+
Marginals of the transport plan: should be very close to "a" and "b"
940+
for balanced OT.
941+
marginal_a : array-like, shape (`n1`,)
942+
Marginal of the transport plan for the "source" measure "a".
943+
marginal_b : array-like, shape (`n2`,)
944+
Marginal of the transport plan for the "target" measure "b".
945+
946+
"""
947+
948+
def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100):
765949

766950
self._potentials = potentials
767951
self._value = value
@@ -773,6 +957,7 @@ def __init__(self, potentials=None, value=None, value_linear=None, value_quad=No
773957
self._lazy_plan = lazy_plan
774958
self._backend = backend if backend is not None else NumpyBackend()
775959
self._status = status
960+
self._batch_size = batch_size
776961

777962
# I assume that other solvers may return directly
778963
# some primal objects?
@@ -793,7 +978,8 @@ def __repr__(self):
793978
s += 'value_linear={},'.format(self._value_linear)
794979
if self._plan is not None:
795980
s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
796-
981+
if self._lazy_plan is not None:
982+
s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape)
797983
if s[-1] != '(':
798984
s = s[:-1] + ')'
799985
else:
@@ -853,7 +1039,10 @@ def sparse_plan(self):
8531039
@property
8541040
def lazy_plan(self):
    """Transport plan stored as a lazy (symbolic) tensor.

    Returns the object held in ``_lazy_plan``; raises NotImplementedError
    when no lazy plan is available (``_lazy_plan`` is None).
    """
    plan = self._lazy_plan
    if plan is None:
        raise NotImplementedError()
    return plan
8571046

8581047
# Loss values --------------------------------
8591048

@@ -897,6 +1086,11 @@ def marginal_a(self):
8971086
"""First marginal of the transport plan, with the same shape as "a"."""
8981087
if self._plan is not None:
8991088
return self._backend.sum(self._plan, 1)
1089+
elif self._lazy_plan is not None:
1090+
lp = self._lazy_plan
1091+
bs = self._batch_size
1092+
nx = self._backend
1093+
return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs)
9001094
else:
9011095
raise NotImplementedError()
9021096

@@ -905,6 +1099,11 @@ def marginal_b(self):
9051099
"""Second marginal of the transport plan, with the same shape as "b"."""
9061100
if self._plan is not None:
9071101
return self._backend.sum(self._plan, 0)
1102+
elif self._lazy_plan is not None:
1103+
lp = self._lazy_plan
1104+
bs = self._batch_size
1105+
nx = self._backend
1106+
return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs)
9081107
else:
9091108
raise NotImplementedError()
9101109

@@ -968,3 +1167,70 @@ def citation(self):
9681167
url = {http://jmlr.org/papers/v22/20-451.html}
9691168
}
9701169
"""
1170+
1171+
1172+
class LazyTensor(object):
    """ A lazy tensor is a tensor that is not stored in memory. Instead, it is
    defined by a function that computes its values on the fly from slices.

    Parameters
    ----------

    shape : tuple
        shape of the tensor
    getitem : callable
        function that computes the values of the tensor. It is called with
        one index/slice per dimension as positional arguments followed by
        the named arguments given in kwargs.

    kwargs : dict
        named arguments for the function, those names will be used as
        attributes of the LazyTensor object

    Examples
    --------
    >>> import numpy as np
    >>> v = np.arange(5)
    >>> def getitem(i,j, v):
    ...     return v[i,None]+v[None,j]
    >>> T = LazyTensor((5,5),getitem, v=v)
    >>> T[1,2]
    array([3])
    >>> T[1,:]
    array([[1, 2, 3, 4, 5]])
    >>> T[:]
    array([[0, 1, 2, 3, 4],
           [1, 2, 3, 4, 5],
           [2, 3, 4, 5, 6],
           [3, 4, 5, 6, 7],
           [4, 5, 6, 7, 8]])

    """

    def __init__(self, shape, getitem, **kwargs):

        self._getitem = getitem
        self.shape = shape
        self.ndim = len(shape)
        self.kwargs = kwargs

        # set attributes for named arguments/arrays
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getitem__(self, key):
        """Compute the requested entries by calling the getitem function.

        A single integer or slice indexes the first dimension; a tuple gives
        one key per leading dimension. Missing trailing dimensions are
        completed with full slices. Any other key type raises
        NotImplementedError.
        """
        # local import so the class does not rely on a module-level import
        from numbers import Integral

        # Integral (rather than int) also accepts integer scalars coming from
        # array libraries, e.g. numpy.int64 produced by np.arange/np.argmax
        if isinstance(key, (Integral, slice)):
            k = [key]
        elif isinstance(key, tuple):
            k = list(key)
        else:
            raise NotImplementedError("Only integer, slice, and tuple indexing is supported")

        # pad with full slices up to the tensor order (no-op if already full)
        k.extend([slice(None)] * (self.ndim - len(k)))

        return self._getitem(*k, **self.kwargs)

    def __repr__(self):
        return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))

test/test_factored.py

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_factored_ot():
2828
# check constraints
2929
np.testing.assert_allclose(u, Ga.sum(1))
3030
np.testing.assert_allclose(u, Gb.sum(0))
31+
np.testing.assert_allclose(1, log['lazy_plan'][:].sum())
3132

3233

3334
def test_factored_ot_backends(nx):

0 commit comments

Comments
 (0)