Skip to content

Commit 623d6a3

Browse files
brandonwillardtwiecki
authored andcommitted
Make sure new size values are int64
Closes #4652.
1 parent 26a5787 commit 623d6a3

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc3/aesaraf.py

+4
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ def change_rv_size(
156156
size = rv_node.op._infer_shape(size, dist_params)
157157
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)
158158

159+
# Make sure the new size is int64 so that it doesn't unnecessarily pick
160+
# up a `Cast` in some cases
161+
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
162+
159163
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)
160164
rv_var = new_rv_node.outputs[-1]
161165
rv_var.name = name

pymc3/tests/test_aesaraf.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import pytest
2424
import scipy.sparse as sps
2525

26-
from aesara.graph.basic import Variable, ancestors
26+
from aesara.graph.basic import Constant, Variable, ancestors
2727
from aesara.tensor.random.basic import normal, uniform
2828
from aesara.tensor.random.op import RandomVariable
2929
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -67,6 +67,15 @@ def test_change_rv_size():
6767
assert rv_newer.ndim == 3
6868
assert rv_newer.eval().shape == (4, 3, 2)
6969

70+
# Make sure we avoid introducing a `Cast` by converting the new size before
71+
# constructing the new `RandomVariable`
72+
rv = normal(0, 1)
73+
new_size = np.array([4, 3], dtype="int32")
74+
rv_newer = change_rv_size(rv, new_size=new_size, expand=False)
75+
assert rv_newer.ndim == 2
76+
assert isinstance(rv_newer.owner.inputs[1], Constant)
77+
assert rv_newer.eval().shape == (4, 3)
78+
7079

7180
class TestBroadcasting:
7281
def test_make_shared_replacements(self):

0 commit comments

Comments
 (0)