@@ -492,6 +492,121 @@ def get_coordinate_circle(x):
    return x_t


+def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
+    """Reduce a LazyTensor along an axis with function func using batches.
+
+    When axis=None, reduce the LazyTensor to a scalar as a sum of func over
+    batches taken along axis 0.
+
+    .. warning::
+        This function works for tensors of any order, but the reduction can be
+        done only along the first two axes (or globally). Also, it requires
+        that the slice of size `batch_size` along the axis to reduce (or axis 0
+        if `axis=None`) can be computed and fits in memory.
+
+    Parameters
+    ----------
+    a : LazyTensor
+        LazyTensor to reduce
+    func : callable
+        Function to apply to the LazyTensor
+    axis : int, optional
+        Axis along which to reduce the LazyTensor. If None, reduce the
+        LazyTensor to a scalar as a sum of func over batches taken along
+        axis 0. If 0 or 1, reduce the LazyTensor to a vector/matrix as a sum
+        of func over batches taken along that axis.
+    nx : Backend, optional
+        Backend to use for the reduction
+    batch_size : int, optional
+        Size of the batches to use for the reduction (default=100)
+
+    Returns
+    -------
+    res : array-like
+        Result of the reduction
+
+    """
+
+    if nx is None:
+        nx = get_backend(a[0])
+
+    if axis is None:
+        res = 0.0
+        for i in range(0, a.shape[0], batch_size):
+            res += func(a[i:i + batch_size])
+        return res
+    elif axis == 0:
+        res = nx.zeros(a.shape[1:], type_as=a[0])
+        if nx.__name__ in ["jax", "tf"]:
+            # jax and tf arrays do not support in-place item assignment,
+            # so the batch results are concatenated instead
+            lst = []
+            for j in range(0, a.shape[1], batch_size):
+                lst.append(func(a[:, j:j + batch_size], 0))
+            return nx.concatenate(lst, axis=0)
+        else:
+            for j in range(0, a.shape[1], batch_size):
+                res[j:j + batch_size] = func(a[:, j:j + batch_size], axis=0)
+            return res
+    elif axis == 1:
+        if len(a.shape) == 2:
+            shape = (a.shape[0],)
+        else:
+            shape = (a.shape[0], *a.shape[2:])
+        res = nx.zeros(shape, type_as=a[0])
+        if nx.__name__ in ["jax", "tf"]:
+            # jax and tf arrays do not support in-place item assignment,
+            # so the batch results are concatenated instead
+            lst = []
+            for i in range(0, a.shape[0], batch_size):
+                lst.append(func(a[i:i + batch_size], 1))
+            return nx.concatenate(lst, axis=0)
+        else:
+            for i in range(0, a.shape[0], batch_size):
+                res[i:i + batch_size] = func(a[i:i + batch_size], axis=1)
+            return res
+    else:
+        raise NotImplementedError("Only axis=None, 0 or 1 is implemented for now.")
+
+
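For illustration (commentary, not part of the diff): a minimal sketch of how `reduce_lazytensor` can be used with the NumPy backend, assuming `LazyTensor` and `reduce_lazytensor` are importable from `ot.utils` once this patch lands.

```python
import numpy as np
from ot.utils import LazyTensor, reduce_lazytensor

v = np.arange(5, dtype=float)

def getitem(i, j, v):
    # pairwise sums, computed on demand for the requested slices
    return v[i, None] + v[None, j]

T = LazyTensor((5, 5), getitem, v=v)

# row sums computed in batches of 2, without materializing T
row_sums = reduce_lazytensor(T, np.sum, axis=1, batch_size=2)
assert np.allclose(row_sums, T[:].sum(axis=1))
```

Since batches are materialized one at a time, peak memory is bounded by `batch_size` times the remaining dimensions rather than by the full tensor.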
+def get_lowrank_lazytensor(Q, R, d=None, nx=None):
+    """Get a low rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T
+
+    Parameters
+    ----------
+    Q : ndarray, shape (n, r)
+        First factor of the lowrank tensor
+    R : ndarray, shape (m, r)
+        Second factor of the lowrank tensor
+    d : ndarray, shape (r,), optional
+        Diagonal of the lowrank tensor
+    nx : Backend, optional
+        Backend to use for the reduction
+
+    Returns
+    -------
+    T : LazyTensor
+        Lowrank tensor T=Q@R^T or T=Q@diag(d)@R^T
+    """
+
+    if nx is None:
+        nx = get_backend(Q, R, d)
+
+    shape = (Q.shape[0], R.shape[0])
+
+    if d is None:
+
+        def func(i, j, Q, R):
+            return nx.dot(Q[i], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R)
+
+    else:
+
+        def func(i, j, Q, R, d):
+            return nx.dot(Q[i] * d[None, :], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R, d=d)
+
+    return T
+
+
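For illustration (commentary, not part of the diff): a small sketch checking a low-rank lazy tensor against its dense counterpart; the factors below are hypothetical random arrays.

```python
import numpy as np
from ot.utils import get_lowrank_lazytensor

rng = np.random.RandomState(0)
Q = rng.rand(4, 2)  # first factor, shape (n, r)
R = rng.rand(3, 2)  # second factor, shape (m, r)
d = rng.rand(2)     # optional diagonal, shape (r,)

T = get_lowrank_lazytensor(Q, R)
Td = get_lowrank_lazytensor(Q, R, d)

# any slice is computed on the fly from the factors
assert np.allclose(T[:], Q @ R.T)
assert np.allclose(Td[:], Q @ np.diag(d) @ R.T)
```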
def get_parameter_pair(parameter):
    r"""Extract a pair of parameters from a given parameter
    Used in unbalanced OT and COOT solvers
@@ -761,7 +876,76 @@ class UndefinedParameter(Exception):


class OTResult:
-    def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
+    """Base class for OT results.
+
+    Parameters
+    ----------
+
+    potentials : tuple of array-like, shape (`n1`, `n2`)
+        Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+        This pair of arrays has the same shape, numerical type
+        and properties as the input weights "a" and "b".
+    value : float, array-like
+        Full transport cost, including possible regularization terms and
+        quadratic term for Gromov-Wasserstein solutions.
+    value_linear : float, array-like
+        The linear part of the transport cost, i.e. the product between the
+        transport plan and the cost.
+    value_quad : float, array-like
+        The quadratic part of the transport cost for Gromov-Wasserstein
+        solutions.
+    plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a dense array.
+    log : dict
+        Dictionary containing potential information about the solver.
+    backend : Backend
+        Backend used to compute the results.
+    sparse_plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a sparse array.
+    lazy_plan : LazyTensor
+        Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
+    status : int or str
+        Status of the solver.
+    batch_size : int
+        Batch size used to compute the results/marginals for LazyTensor.
+
+    Attributes
+    ----------
+
+    potentials : tuple of array-like, shape (`n1`, `n2`)
+        Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+        This pair of arrays has the same shape, numerical type
+        and properties as the input weights "a" and "b".
+    potential_a : array-like, shape (`n1`,)
+        First dual potential, associated to the "source" measure "a".
+    potential_b : array-like, shape (`n2`,)
+        Second dual potential, associated to the "target" measure "b".
+    value : float, array-like
+        Full transport cost, including possible regularization terms and
+        quadratic term for Gromov-Wasserstein solutions.
+    value_linear : float, array-like
+        The linear part of the transport cost, i.e. the product between the
+        transport plan and the cost.
+    value_quad : float, array-like
+        The quadratic part of the transport cost for Gromov-Wasserstein
+        solutions.
+    plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a dense array.
+    sparse_plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a sparse array.
+    lazy_plan : LazyTensor
+        Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
+    marginals : tuple of array-like, shape (`n1`,), (`n2`,)
+        Marginals of the transport plan: should be very close to "a" and "b"
+        for balanced OT.
+    marginal_a : array-like, shape (`n1`,)
+        Marginal of the transport plan for the "source" measure "a".
+    marginal_b : array-like, shape (`n2`,)
+        Marginal of the transport plan for the "target" measure "b".
+
+    """
+
+    def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100):

        self._potentials = potentials
        self._value = value
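For illustration (commentary, not part of the diff): with this change a solver can hand `OTResult` a lazy plan together with a batch size; the dense array `G` backing the lazy tensor below is a hypothetical stand-in for a solver's output.

```python
import numpy as np
from ot.utils import LazyTensor, OTResult

n1, n2 = 4, 3
G = np.ones((n1, n2)) / (n1 * n2)  # hypothetical plan backing the lazy tensor

lazy_plan = LazyTensor((n1, n2), lambda i, j, G: G[i, j], G=G)
res = OTResult(lazy_plan=lazy_plan, batch_size=2)

print(res.lazy_plan.shape)  # (4, 3)
print(res)  # the repr now also reports lazy_plan=LazyTensor(shape=(4, 3))
```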
@@ -773,6 +957,7 @@ def __init__(self, potentials=None, value=None, value_linear=None, value_quad=No
        self._lazy_plan = lazy_plan
        self._backend = backend if backend is not None else NumpyBackend()
        self._status = status
+        self._batch_size = batch_size

        # I assume that other solvers may return directly
        # some primal objects?
@@ -793,7 +978,8 @@ def __repr__(self):
            s += 'value_linear={},'.format(self._value_linear)
        if self._plan is not None:
            s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
-
+        if self._lazy_plan is not None:
+            s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape)
        if s[-1] != '(':
            s = s[:-1] + ')'
        else:
@@ -853,7 +1039,10 @@ def sparse_plan(self):
    @property
    def lazy_plan(self):
        """Transport plan, encoded as a symbolic KeOps LazyTensor."""
-        raise NotImplementedError()
+        if self._lazy_plan is not None:
+            return self._lazy_plan
+        else:
+            raise NotImplementedError()

    # Loss values --------------------------------
@@ -897,6 +1086,11 @@ def marginal_a(self):
        """First marginal of the transport plan, with the same shape as "a"."""
        if self._plan is not None:
            return self._backend.sum(self._plan, 1)
+        elif self._lazy_plan is not None:
+            lp = self._lazy_plan
+            bs = self._batch_size
+            nx = self._backend
+            return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs)
        else:
            raise NotImplementedError()
@@ -905,6 +1099,11 @@ def marginal_b(self):
        """Second marginal of the transport plan, with the same shape as "b"."""
        if self._plan is not None:
            return self._backend.sum(self._plan, 0)
+        elif self._lazy_plan is not None:
+            lp = self._lazy_plan
+            bs = self._batch_size
+            nx = self._backend
+            return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs)
        else:
            raise NotImplementedError()
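For illustration (commentary, not part of the diff): the lazy marginals computed through `reduce_lazytensor` should match the dense ones; `G` below is a hypothetical plan.

```python
import numpy as np
from ot.utils import LazyTensor, OTResult

rng = np.random.RandomState(0)
G = rng.rand(5, 4)
G /= G.sum()

res_dense = OTResult(plan=G)
res_lazy = OTResult(lazy_plan=LazyTensor(G.shape, lambda i, j, G: G[i, j], G=G),
                    batch_size=2)

# batched reductions over the lazy plan agree with dense sums
assert np.allclose(res_dense.marginal_a, res_lazy.marginal_a)
assert np.allclose(res_dense.marginal_b, res_lazy.marginal_b)
```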
@@ -968,3 +1167,70 @@ def citation(self):
        url = {http://jmlr.org/papers/v22/20-451.html}
        }
        """
+
+
+class LazyTensor(object):
+    """A lazy tensor is a tensor that is not stored in memory. Instead, it is
+    defined by a function that computes its values on the fly from slices.
+
+    Parameters
+    ----------
+
+    shape : tuple
+        shape of the tensor
+    getitem : callable
+        function that takes the indices/slices and the named tensors as
+        arguments and computes the corresponding values
+    kwargs : dict
+        named arguments for the function, those names will be used as
+        attributes of the LazyTensor object
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> v = np.arange(5)
+    >>> def getitem(i, j, v):
+    ...     return v[i, None] + v[None, j]
+    >>> T = LazyTensor((5, 5), getitem, v=v)
+    >>> T[1, 2]
+    array([3])
+    >>> T[1, :]
+    array([[1, 2, 3, 4, 5]])
+    >>> T[:]
+    array([[0, 1, 2, 3, 4],
+           [1, 2, 3, 4, 5],
+           [2, 3, 4, 5, 6],
+           [3, 4, 5, 6, 7],
+           [4, 5, 6, 7, 8]])
+
+    """
+
+    def __init__(self, shape, getitem, **kwargs):
+
+        self._getitem = getitem
+        self.shape = shape
+        self.ndim = len(shape)
+        self.kwargs = kwargs
+
+        # set attributes for named arguments/arrays
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def __getitem__(self, key):
+        k = []
+        if isinstance(key, (int, slice)):
+            # a single index/slice applies to axis 0; take full slices elsewhere
+            k.append(key)
+            for i in range(self.ndim - 1):
+                k.append(slice(None))
+        elif isinstance(key, tuple):
+            # complete a partial tuple of indices with full slices
+            k = list(key)
+            for i in range(self.ndim - len(key)):
+                k.append(slice(None))
+        else:
+            raise NotImplementedError("Only integer, slice, and tuple indexing is supported")
+
+        return self._getitem(*k, **self.kwargs)
+
+    def __repr__(self):
+        return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))