-from pytensor import scalar as aes
+from numpy.core.numeric import normalize_axis_index  # type: ignore
+
+from pytensor import Variable
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.basic import expand_dims
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.extra_ops import squeeze
+from pytensor.tensor.math import Sum, exp, log, logsumexp
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
-from pytensor.tensor.rewriting.math import local_mul_canonizer
+from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
+from pytensor.tensor.rewriting.math import local_log_sum_exp, local_mul_canonizer
+from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+    indices_from_subtensor,
+    is_basic_idx,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )
+from pytensor.tensor.type_other import NoneTypeT
+
+
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)


 # This is not registered in stabilize, as it cause some crossentropy
 # optimization to not be inserted.
 @register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
+    if node.inputs[0].owner is not None and isinstance(
+        node.inputs[0].owner.op, Softmax
     ):
         inVars = node.inputs[0].owner.inputs[0]
         new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
@@ -39,6 +55,92 @@ def local_logsoftmax(fgraph, node):
         return [ret]


+@register_stabilize
+@node_rewriter([log])
+def local_log_subtensor_softmax(fgraph, node):
+    """Replace log(softmax(x, axis)[idx]) -> x[idx] - logsumexp(x, axis).
+
+    This can only be done when indexing happens over axis dims.
+    There can be non-indexed axis dims, but not non-axis indexed dims.
+    """
+    [subtensor_var] = node.inputs
+    subtensor_node = subtensor_var.owner
+
+    if subtensor_node is not None and isinstance(subtensor_node.op, subtensor_ops):
+        softmax_var, *idxs = subtensor_node.inputs
+        softmax_node = softmax_var.owner
+        if softmax_node is not None and isinstance(softmax_node.op, Softmax):
+            if isinstance(subtensor_node.op, Subtensor):
+                idxs = indices_from_subtensor(idxs, subtensor_node.op.idx_list)
+
+            # TODO: support expand_dims
+            if any(
+                (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT))
+                for idx in idxs
+            ):
+                return None
+
+            [x] = softmax_node.inputs
+            axis = softmax_node.op.axis
+            if axis is not None:
+                axis = normalize_axis_index(axis, ndim=x.type.ndim)
+
+            indexed_dims = [
+                dim for dim, idx in enumerate(idxs) if not is_full_slice(idx)
+            ]
+
+            # We can only apply the rewrite when the softmax is applied across all indexed dims
+            if axis is not None and {axis} != set(indexed_dims):
+                return None
+
+            dims_to_expand = ()
+            dims_to_drop = ()
+            if isinstance(subtensor_node.op, Subtensor):
+                dims_to_drop = tuple(
+                    dim for dim, idx in enumerate(idxs) if getattr(idx, "ndim", -1) == 0
+                )
+            if isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
+                adv_dims_idxs = tuple(
+                    (dim, idx) for dim, idx in enumerate(idxs) if not is_basic_idx(idx)
+                )
+                adv_dims = tuple(dim for dim, idx in adv_dims_idxs)
+                adv_idxs = tuple(idx for dim, idx in adv_dims_idxs)
+
+                # Boolean indexing not supported
+                if any(idx.dtype == "bool" for idx in adv_idxs):
+                    return None
+
+                # Non-contiguous advanced indexing not supported
+                if tuple(range(adv_dims[0], adv_dims[-1] + 1)) != adv_dims:
+                    return None
+
+                ndim_adv_idx = max(idx.ndim for idx in adv_idxs)
+                n_new_dims = ndim_adv_idx - len(adv_idxs)
+                # Advanced indexing introduces new dims
+                if n_new_dims > 0:
+                    dims_to_expand = tuple(range(adv_dims[0], adv_dims[0] + n_new_dims))
+                # It reduces number of dims
+                elif n_new_dims < 0:
+                    dims_to_drop = tuple(
+                        range(adv_dims[0], adv_dims[0] + abs(n_new_dims))
+                    )
+
+            # Rewrite stable form of logsumexp immediately
+            [x_logsumexp] = local_log_sum_exp.transform(
+                None, logsumexp(x, axis=axis, keepdims=True).owner
+            )
+
+            assert not (dims_to_drop and dims_to_expand)
+            if dims_to_expand:
+                x_logsumexp = expand_dims(x_logsumexp, dims_to_expand)
+            elif dims_to_drop:
+                x_logsumexp = squeeze(x_logsumexp, axis=dims_to_drop)
+            ret = x[tuple(idxs)] - x_logsumexp
+
+            copy_stack_trace(node.outputs, ret)
+            return [ret]
+
+
 # This is not registered in stabilize, as it cause some crossentropy
 # optimization to not be inserted.
 @register_specialize("stabilize", "fast_compile")
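For reference, one way to check that `local_logsoftmax` still fires after its tracker changes from `Elemwise` to `log` is to compile a tiny graph and look for a `LogSoftmax` node. This is a minimal sketch outside the diff, assuming the public `pytensor.function` and `pytensor.tensor.special.softmax` entry points and the default rewrite mode (which includes the specialize/stabilize passes):

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import LogSoftmax, softmax

x = pt.matrix("x")
y = pt.log(softmax(x, axis=-1))

# Under the default mode the specialize/stabilize rewrites run, so the compiled
# graph is expected to contain a fused LogSoftmax node rather than Softmax + log.
f = pytensor.function([x], y)
assert any(isinstance(node.op, LogSoftmax) for node in f.maker.fgraph.toposort())

print(f(np.random.default_rng(0).normal(size=(2, 3))))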
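The new `local_log_subtensor_softmax` rewrite relies on the identity log(softmax(x, axis)[idx]) == x[idx] - logsumexp(x, axis) whenever the indexing only touches the softmax axis. A NumPy-only sanity check of that identity (illustrative, not part of the change):

import numpy as np

rng = np.random.default_rng(42)
x = rng.normal(size=(3, 5))

# Softmax and logsumexp along the last axis, computed directly.
logsumexp_x = np.log(np.exp(x).sum(axis=-1, keepdims=True))
softmax_x = np.exp(x - logsumexp_x)

# Index only the softmax axis (dim 1); dim 0 keeps a full slice.
cols = np.array([1, 4])
lhs = np.log(softmax_x[:, cols])  # log(softmax(x, axis=-1)[:, cols])
rhs = x[:, cols] - logsumexp_x    # x[idx] - logsumexp(x, axis=-1, keepdims=True)
np.testing.assert_allclose(lhs, rhs)

The keepdims/squeeze/expand_dims handling in the rewrite exists to keep this broadcast aligned when basic scalar indices drop dimensions or advanced indices introduce new ones.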