
Commit 9aa9d06

Extend log_softmax rewrite and run it in stabilize
1 parent 3fe07f3 commit 9aa9d06

2 files changed: +85 −52 lines


Diff for: pytensor/tensor/rewriting/special.py (+55 −29)

@@ -1,47 +1,75 @@
-from pytensor import scalar as aes
+from numpy.core.numeric import normalize_axis_index  # type: ignore
+
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.math import Sum, exp, log
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.basic import register_stabilize
 from pytensor.tensor.rewriting.math import local_mul_canonizer
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )


-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)
+
+
+@register_stabilize
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)

+    This also lifts Subtensor or Dimshuffle operations that could be in between log and softmax
+
     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
-    ):
-        inVars = node.inputs[0].owner.inputs[0]
-        new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
-        ret = new_op(inVars)
-        ret.tag.values_eq_approx = values_eq_approx_remove_inf
-        copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
-        return [ret]
+    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
+        if inp_node is None:
+            return
+
+        if isinstance(inp_node.op, Softmax):
+            return inp_node
+
+        if isinstance(inp_node.op, subtensor_ops):
+            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
+            return find_softmax_under_lifteable_ops(inp_node.inputs[0].owner, ops_to_lift)
+
+        if isinstance(inp_node.op, DimShuffle):
+            ops_to_lift.append((inp_node.op, ()))
+            return find_softmax_under_lifteable_ops(inp_node.inputs[0].owner, ops_to_lift)
+
+    ops_to_lift = []
+    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)
+
+    if softmax_node is None:
+        return
+
+    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
+    ret.tag.values_eq_approx = values_eq_approx_remove_inf
+
+    # Lift ops that used to be between log and softmax
+    for op_to_lift, parameters in reversed(ops_to_lift):
+        ret = op_to_lift(ret, *parameters)
+
+    copy_stack_trace(node.outputs, ret)
+    return [ret]


-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
+@register_stabilize
 @node_rewriter([SoftmaxGrad])
 def local_logsoftmax_grad(fgraph, node):
     """
@@ -50,9 +78,7 @@ def local_logsoftmax_grad(fgraph, node):
     Note: only grad is affected
     """
     if (
-        isinstance(node.op, SoftmaxGrad)
-        and len(node.inputs) == 2
-        and node.inputs[0].owner is not None
+        node.inputs[0].owner is not None
        and node.inputs[0].owner.op == true_div
        and len(node.inputs[0].owner.inputs) >= 2
        and node.inputs[0].owner.inputs[1].owner is not None
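
Illustration (not part of the commit): with the rewrite registered in stabilize and tracking `subtensor_ops`/`DimShuffle`, a graph where indexing sits between log and softmax is still stabilized, and the indexing is reapplied on top of the fused log-softmax. A minimal sketch of the effect, assuming PyTensor's default compilation mode (which includes the stabilize rewrites):

import numpy as np

import pytensor
from pytensor.tensor.math import log
from pytensor.tensor.special import softmax
from pytensor.tensor.type import matrix

x = matrix("x")
# User-written graph: a Subtensor (row indexing) sits between log and softmax
y = log(softmax(x, axis=-1)[0])

# The stabilize rewrite turns this into the equivalent of log_softmax(x, axis=-1)[0],
# so large logits no longer underflow to -inf through softmax.
f = pytensor.function([x], y)
x_val = np.array([[0.0, 999.0], [1.0, 2.0]]).astype(x.dtype)
print(f(x_val))  # approximately [-999., 0.]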

Diff for: tests/tensor/rewriting/test_special.py (+30 −23)

@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+import scipy.special

 import pytensor
 from pytensor import shared
@@ -11,7 +12,8 @@
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.math import add, exp, log, true_div
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
-from pytensor.tensor.type import matrix
+from pytensor.tensor.subtensor import AdvancedSubtensor, group_indices
+from pytensor.tensor.type import matrix, tensor
 from tests import unittest_tools as utt


@@ -35,6 +37,33 @@ def test_local_logsoftmax_rewrite(self, axis):
         _fast_run_rewrites.rewrite(fgraph)
         assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
         assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)
+        assert check_stack_trace(fgraph, ops_to_check="all")
+
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    @pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None)])
+    @pytest.mark.parametrize("idx1", [None, [0, 1, 1, -1]])
+    def test_logsoftmax_subtensor_dimshuffle(self, axis, idx0, idx1):
+        """Test that stabilization is introduced even when subtensor or dimshuffle operations
+        are present between log and softmax.
+        """
+
+        logit_p = matrix("logit_p")
+        p = softmax(logit_p, axis=axis)
+        p_indexed = p[(idx0, idx1)]
+        out = log(p_indexed)
+
+        # Don't waste time with C compilation
+        with config.change_flags(cxx=""):
+            fn = pytensor.function([logit_p], out)
+
+        assert not any(isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes)
+
+        # This range would lead to underflow to -inf without the stabilization
+        test_logit_p = np.array([[-10., -10., 999.], [999., 990., -10.]])
+        np.testing.assert_allclose(
+            fn(logit_p=test_logit_p),
+            scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1)],
+        )

     @pytest.mark.parametrize("axis", [None, 0, -1])
     def test_local_logsoftmax_grad_rewrite(self, axis):
@@ -91,28 +120,6 @@ def test_logsoftmax_grad_true_div_elemwise(self):
         assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]


-def test_log_softmax_stabilization():
-    mode = pytensor.compile.mode.get_default_mode()
-    mode = mode.including("local_log_softmax", "specialize")
-
-    x = matrix()
-    y = softmax(x, axis=-1)
-    z = log(y)
-
-    fgraph = FunctionGraph([x], [z])
-    _fast_run_rewrites(fgraph)
-    assert check_stack_trace(fgraph, ops_to_check="all")
-
-    # Check that the softmax has been rewritten
-    for node in fgraph.toposort():
-        assert not isinstance(node.op, Softmax)
-
-    # Call the function so debug mode can verify the rewritten version matches
-    # the un-rewritten version
-    f = pytensor.function([x], z, mode=mode)
-    rng = np.random.default_rng(utt.fetch_seed())
-    f(np.cast[config.floatX](rng.random((2, 3))))
-

 def test_softmax_graph():
     """Make sure that sotfmax expressions are turned into
