
Commit 225add5

Implement rewrite to stabilize log(softmax[idx])
1 parent df4183d commit 225add5

2 files changed: +172 −14 lines changed
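
The motivation for the commit, shown with a small NumPy sketch (not part of the commit; the logit values mirror the new test, and scipy.special.logsumexp is used only as a stable reference): computing log(softmax(x))[idx] naively underflows to -inf when one logit dominates, while the algebraically equivalent x[idx] - logsumexp(x) stays finite.

import numpy as np
import scipy.special

x = np.array([0.0, 1.0, 2.0, 3.0, 999.0])

# Naive form: softmax underflows to exactly 0 for the small logits,
# so taking the log afterwards gives -inf.
p = np.exp(x - x.max())
p /= p.sum()
print(np.log(p)[0])                       # -inf

# Stabilized form targeted by the rewrite: log(softmax(x)[0]) == x[0] - logsumexp(x)
print(x[0] - scipy.special.logsumexp(x))  # approximately -999.0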

Diff for: pytensor/tensor/rewriting/special.py (+115 −13)
@@ -1,35 +1,51 @@
-from pytensor import scalar as aes
+from numpy.core.numeric import normalize_axis_index  # type: ignore
+
+from pytensor import Variable
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.basic import expand_dims
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.extra_ops import squeeze
+from pytensor.tensor.math import Sum, exp, log, logsumexp
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
-from pytensor.tensor.rewriting.math import local_mul_canonizer
+from pytensor.tensor.rewriting.basic import register_specialize, register_stabilize
+from pytensor.tensor.rewriting.math import local_log_sum_exp, local_mul_canonizer
+from pytensor.tensor.rewriting.subtensor import is_full_slice
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+    indices_from_subtensor,
+    is_basic_idx,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )
+from pytensor.tensor.type_other import NoneTypeT
+
+
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)
 
 
 # This is not registered in stabilize, as it cause some crossentropy
 # optimization to not be inserted.
 @register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
 
     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
+    if node.inputs[0].owner is not None and isinstance(
+        node.inputs[0].owner.op, Softmax
     ):
         inVars = node.inputs[0].owner.inputs[0]
         new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
@@ -39,6 +55,92 @@ def local_logsoftmax(fgraph, node):
         return [ret]
 
 
+@register_stabilize
+@node_rewriter([log])
+def local_log_subtensor_softmax(fgraph, node):
+    """Replace log(softmax(x, axis)[idx]) -> x[idx] - logsumexp(x, axis).
+
+    This can only be done when indexing happens over axis dims.
+    There can be non-indexed axis dims, but not non-axis indexed dims.
+    """
+    [subtensor_var] = node.inputs
+    subtensor_node = subtensor_var.owner
+
+    if subtensor_node is not None and isinstance(subtensor_node.op, subtensor_ops):
+        softmax_var, *idxs = subtensor_node.inputs
+        softmax_node = softmax_var.owner
+        if softmax_node is not None and isinstance(softmax_node.op, Softmax):
+            if isinstance(subtensor_node.op, Subtensor):
+                idxs = indices_from_subtensor(idxs, subtensor_node.op.idx_list)
+
+            # TODO: support expand_dims
+            if any(
+                (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT))
+                for idx in idxs
+            ):
+                return None
+
+            [x] = softmax_node.inputs
+            axis = softmax_node.op.axis
+            if axis is not None:
+                axis = normalize_axis_index(axis, ndim=x.type.ndim)
+
+            indexed_dims = [
+                dim for dim, idx in enumerate(idxs) if not is_full_slice(idx)
+            ]
+
+            # We can only apply the rewrite when the softmax is applied across all indexed dims
+            if axis is not None and {axis} != set(indexed_dims):
+                return None
+
+            dims_to_expand = ()
+            dims_to_drop = ()
+            if isinstance(subtensor_node.op, Subtensor):
+                dims_to_drop = tuple(
+                    dim for dim, idx in enumerate(idxs) if getattr(idx, "ndim", -1) == 0
+                )
+            if isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1)):
+                adv_dims_idxs = tuple(
+                    (dim, idx) for dim, idx in enumerate(idxs) if not is_basic_idx(idx)
+                )
+                adv_dims = tuple(dim for dim, idx in adv_dims_idxs)
+                adv_idxs = tuple(idx for dim, idx in adv_dims_idxs)
+
+                # Boolean indexing not supported
+                if any(idx.dtype == "bool" for idx in adv_idxs):
+                    return None
+
+                # Non-contiguous advanced indexing not supported
+                if tuple(range(adv_dims[0], adv_dims[-1] + 1)) != adv_dims:
+                    return None
+
+                ndim_adv_idx = max(idx.ndim for idx in adv_idxs)
+                n_new_dims = ndim_adv_idx - len(adv_idxs)
+                # Advanced indexing introduces new dims
+                if n_new_dims > 0:
+                    dims_to_expand = tuple(range(adv_dims[0], adv_dims[0] + n_new_dims))
+                # It reduces number of dims
+                elif n_new_dims < 0:
+                    dims_to_drop = tuple(
+                        range(adv_dims[0], adv_dims[0] + abs(n_new_dims))
+                    )
+
+            # Rewrite stable form of logsumexp immediately
+            [x_logsumexp] = local_log_sum_exp.transform(
+                None, logsumexp(x, axis=axis, keepdims=True).owner
+            )
+
+            assert not (dims_to_drop and dims_to_expand)
+            if dims_to_expand:
+                x_logsumexp = expand_dims(x_logsumexp, dims_to_expand)
+            elif dims_to_drop:
+                x_logsumexp = squeeze(x_logsumexp, axis=dims_to_drop)
+            ret = x[tuple(idxs)] - x_logsumexp
+
+            copy_stack_trace(node.outputs, ret)
+            return [ret]
+
+
 # This is not registered in stabilize, as it cause some crossentropy
 # optimization to not be inserted.
 @register_specialize("stabilize", "fast_compile")
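
A hypothetical usage sketch (not part of the commit) of the kind of graph the new local_log_subtensor_softmax rewrite targets; all imports below appear in the commit's own files, but whether the rewrite actually fires depends on the stabilize pass running during compilation.

import numpy as np

import pytensor
from pytensor.tensor.math import log
from pytensor.tensor.special import softmax
from pytensor.tensor.type import matrix

x = matrix("x")
# Indexing only over the softmax axis: the pattern log(softmax(x, axis)[idx])
out = log(softmax(x, axis=1)[:, 0])
fn = pytensor.function([x], out)

# Without the stabilization, row 0 would underflow to log(0) = -inf;
# the rewritten graph computes x[:, 0] - logsumexp(x, axis=1) instead.
print(fn(np.array([[0.0, 999.0], [0.0, 1.0]])))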

Diff for: tests/tensor/rewriting/test_special.py (+57 −1)
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+import scipy.special
 
 import pytensor
 from pytensor import shared
@@ -11,7 +12,8 @@
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.math import add, exp, log, true_div
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, softmax
-from pytensor.tensor.type import matrix
+from pytensor.tensor.subtensor import AdvancedSubtensor, group_indices
+from pytensor.tensor.type import matrix, tensor
 from tests import unittest_tools as utt
 
 
@@ -130,3 +132,57 @@ def f(inputs):
         return pytensor.grad(None, x, known_grads={y: inputs})
 
     utt.verify_grad(f, [rng.random((3, 4))])
+
+
+def _is_non_contiguous_adv_indexing(index_var):
+    if not isinstance(index_var.owner.op, AdvancedSubtensor):
+        return False
+    idx_groups = group_indices(index_var.owner.inputs[1:])
+    return len(idx_groups) > 3 or (len(idx_groups) == 3 and not idx_groups[0][0])
+
+
+@pytest.mark.parametrize("axis", [None, 0, 1, 2])
+@pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None), None, [0, 1, 1, -1]])
+@pytest.mark.parametrize("idx1", [0, slice(1, None), slice(None), None, [0, 1, 1, -1]])
+@pytest.mark.parametrize(
+    "idx2", [0, slice(1, None), slice(None), None, [[0, 1, 1, -1], [-1, 1, 1, 0]]]
+)
+def test_log_subtensor_softmax(axis, idx0, idx1, idx2):
+    logit_p = tensor("logit_p", shape=(4, 3, 5))
+    p = softmax(logit_p, axis=axis)
+    p_indexed = p[(idx0, idx1, idx2)]
+    out = log(p_indexed)
+
+    # Don't waste time with C compilation
+    with config.change_flags(cxx=""):
+        fn = pytensor.function([logit_p], out)
+
+    rewrite_applies = True
+    if _is_non_contiguous_adv_indexing(p_indexed):
+        rewrite_applies = False
+    else:
+        if idx0 is None or idx1 is None or idx2 is None:
+            # Not yet implemented!
+            rewrite_applies = False
+        elif axis is not None:
+            indexed_dims = {
+                dim for dim, idx in enumerate((idx0, idx1, idx2)) if idx != slice(None)
+            }
+            # If no indexed dims, the rewrite doesn't actually apply
+            # but the log_softmax stabilization kicks-in and the output is also stable
+            if indexed_dims:
+                rewrite_applies = {axis} == indexed_dims
+
+    # assert any(isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes) != rewrite_applies
+
+    if not rewrite_applies:
+        return
+
+    # This range would lead to underflow to -inf without the stabilization
+    logit_ps = np.array([0.0, 1.0, 2.0, 3.0, 999.0])
+    rng = np.random.default_rng(156)
+    test_logit_p = rng.choice(logit_ps, size=(4, 3, 5))
+    np.testing.assert_allclose(
+        fn(logit_p=test_logit_p),
+        scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1, idx2)],
+    )
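
An illustrative NumPy snippet (not part of the commit) behind the "non-contiguous advanced indexing not supported" guard and the test's _is_non_contiguous_adv_indexing helper: when advanced indices are separated by a basic index, NumPy moves the broadcast index dimensions to the front of the result, so the softmax axis can no longer be matched to an indexed dimension in place.

import numpy as np

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
# Contiguous advanced indexing: result dimensions stay in place.
print(a[:, [0, 1], :].shape)       # (2, 2, 4)
# Non-contiguous advanced indexing: index dimensions are moved to the front.
print(a[[0, 1], :, [0, 1]].shape)  # (2, 3)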
