Skip to content

Add a type guard for intX (#4569) #6319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,9 @@ def intX(X):
"""
Convert a aesara tensor or numpy array to aesara.tensor.int32 type.
"""
# check value is already int, do nothing in this case
if (hasattr(X, "dtype") and "int" in str(X.dtype)) or isinstance(X, int):
return X
intX = _conversion_map[aesara.config.floatX]
try:
return X.astype(intX)
Expand Down
8 changes: 7 additions & 1 deletion pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,11 @@ def dist(cls, N, k, n, *args, **kwargs):
return super().dist([good, bad, n], *args, **kwargs)

def moment(rv, size, good, bad, n):
N, k = good + bad, good
# Cast to float because the intX can be int8
# which could trigger an integer overflow below.
n = floatX(n)
k = floatX(good)
N = k + floatX(bad)
mode = at.floor((n + 1) * (k + 1) / (N + 2))
if not rv_size_is_none(size):
mode = at.full(size, mode)
Expand Down Expand Up @@ -1014,6 +1018,8 @@ def dist(cls, lower, upper, *args, **kwargs):
return super().dist([lower, upper], **kwargs)

def moment(rv, size, lower, upper):
upper = floatX(upper)
lower = floatX(lower)
mode = at.maximum(at.floor((upper + lower) / 2.0), lower)
if not rv_size_is_none(size):
mode = at.full(size, mode)
Expand Down
2 changes: 1 addition & 1 deletion pymc/tests/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True):

assert moment.shape == expected.shape
assert expected.shape == random_draw.shape
assert np.allclose(moment, expected)
np.testing.assert_allclose(moment, expected, atol=1e-10)

if check_finite_logp:
logp_moment = (
Expand Down
5 changes: 3 additions & 2 deletions pymc/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def test_convert_observed_data(input_dtype):
assert isinstance(aesara_output, Variable)
npt.assert_allclose(aesara_output.eval(), aesara_graph_input.eval())
intX = pm.aesaraf._conversion_map[aesara.config.floatX]
if dense_input.dtype == intX or dense_input.dtype == aesara.config.floatX:
if "int" in str(dense_input.dtype) or dense_input.dtype == aesara.config.floatX:
assert aesara_output.owner is None # func should not have added new nodes
assert aesara_output.name == input_name
else:
Expand All @@ -254,7 +254,8 @@ def test_convert_observed_data(input_dtype):
if "float" in input_dtype:
assert aesara_output.dtype == aesara.config.floatX
else:
assert aesara_output.dtype == intX
# only cast floats, leave ints as is
assert aesara_output.dtype == input_dtype

# Check function behavior with generator data
generator_output = func(square_generator)
Expand Down
4 changes: 2 additions & 2 deletions pymc/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_sample_posterior_predictive_after_set_data(self):
)
# Predict on new data.
with model:
x_test = [5, 6, 9]
x_test = [5.0, 6.0, 9.0]
pm.set_data(new_data={"x": x_test})
y_test = pm.sample_posterior_predictive(trace)

Expand All @@ -111,7 +111,7 @@ def test_sample_posterior_predictive_after_set_data_with_coords(self):
)
# Predict on new data.
with model:
x_test = [5, 6]
x_test = [5.0, 6.0]
pm.set_data(new_data={"x": x_test}, coords={"obs_id": ["a", "b"]})
pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True)

Expand Down