- from pytensor import scalar as aes
+ from numpy.core.numeric import normalize_axis_index  # type: ignore
+
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
- from pytensor.tensor.elemwise import DimShuffle, Elemwise
- from pytensor.tensor.math import Sum, exp
+ from pytensor.tensor.elemwise import DimShuffle
+ from pytensor.tensor.math import Sum, exp, log
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div
- from pytensor.tensor.rewriting.basic import register_specialize
+ from pytensor.tensor.rewriting.basic import register_stabilize
from pytensor.tensor.rewriting.math import local_mul_canonizer
- from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
- from pytensor.tensor.subtensor import AdvancedIncSubtensor
+ from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
+ from pytensor.tensor.subtensor import (
+     AdvancedIncSubtensor,
+     AdvancedSubtensor,
+     AdvancedSubtensor1,
+     Subtensor,
+ )
from pytensor.tensor.type import (
    values_eq_approx_remove_inf,
    values_eq_approx_remove_nan,
)


- # This is not registered in stabilize, as it cause some crossentropy
- # optimization to not be inserted.
- @register_specialize("stabilize", "fast_compile")
- @node_rewriter([Elemwise])
+ subtensor_ops = (
+     Subtensor,
+     AdvancedSubtensor,
+     AdvancedSubtensor1,
+ )
+
+
+ @register_stabilize
+ @node_rewriter([log])
def local_logsoftmax(fgraph, node):
    """
    Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

+     This also lifts Subtensor or Dimshuffle operations that could be in between log and softmax
+
    Note: only forward pass is affected
    """
-     if (
-         isinstance(node.op, Elemwise)
-         and isinstance(node.op.scalar_op, aes.Log)
-         and len(node.inputs) == 1
-         and node.inputs[0].owner is not None
-         and isinstance(node.inputs[0].owner.op, Softmax)
-     ):
-         inVars = node.inputs[0].owner.inputs[0]
-         new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
-         ret = new_op(inVars)
-         ret.tag.values_eq_approx = values_eq_approx_remove_inf
-         copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
-         return [ret]
+     def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
+         if inp_node is None:
+             return
+
+         if isinstance(inp_node.op, Softmax):
+             return inp_node
+
+         if isinstance(inp_node.op, subtensor_ops):
+             ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
+             return find_softmax_under_lifteable_ops(inp_node.inputs[0].owner, ops_to_lift)
+
+         if isinstance(inp_node.op, DimShuffle):
+             ops_to_lift.append((inp_node.op, ()))
+             return find_softmax_under_lifteable_ops(inp_node.inputs[0].owner, ops_to_lift)
+
+     ops_to_lift = []
+     softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)
+
+     if softmax_node is None:
+         return
+
+     ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
+     ret.tag.values_eq_approx = values_eq_approx_remove_inf
+
+     # Lift ops that used to be between log and softmax
+     for op_to_lift, parameters in reversed(ops_to_lift):
+         ret = op_to_lift(ret, *parameters)
+
+     copy_stack_trace(node.outputs, ret)
+     return [ret]


- # This is not registered in stabilize, as it cause some crossentropy
- # optimization to not be inserted.
- @register_specialize("stabilize", "fast_compile")
+ @register_stabilize
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad(fgraph, node):
    """
@@ -50,9 +78,7 @@ def local_logsoftmax_grad(fgraph, node):
    Note: only grad is affected
    """
    if (
-         isinstance(node.op, SoftmaxGrad)
-         and len(node.inputs) == 2
-         and node.inputs[0].owner is not None
+         node.inputs[0].owner is not None
        and node.inputs[0].owner.op == true_div
        and len(node.inputs[0].owner.inputs) >= 2
        and node.inputs[0].owner.inputs[1].owner is not None
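A minimal usage sketch of what the relocated stabilize rewrite is intended to do (the variable name x and the axis=-1 choice are illustrative; softmax, log, function, and dprint are existing PyTensor entry points):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

x = pt.matrix("x")
# log applied on top of an indexed softmax: the rewrite should turn
# log(softmax(x, axis=-1)[0]) into log_softmax(x, axis=-1)[0], lifting the
# Subtensor above the fused log-softmax for numerical stability.
out = pt.log(softmax(x, axis=-1)[0])

f = pytensor.function([x], out)
pytensor.dprint(f)  # the compiled graph should now contain a LogSoftmax op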