From d1e868ff079b606921e8ca00a6a49cbb203f26c6 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 22 Apr 2021 09:55:10 +0300 Subject: [PATCH 1/2] Add a type guard for `intX` (#4569) * add type guard for inX * fix test for pandas * fix posterior test, ints passed for float data Closes #4279 --- pymc/aesaraf.py | 3 +++ pymc/tests/test_aesaraf.py | 5 +++-- pymc/tests/test_data.py | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 0fdb53acfd..bcb1bfa00b 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -444,6 +444,9 @@ def intX(X): """ Convert a aesara tensor or numpy array to aesara.tensor.int32 type. """ + # check value is already int, do nothing in this case + if (hasattr(X, "dtype") and "int" in str(X.dtype)) or isinstance(X, int): + return X intX = _conversion_map[aesara.config.floatX] try: return X.astype(intX) diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index f627d932fa..05a9137397 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -244,7 +244,7 @@ def test_convert_observed_data(input_dtype): assert isinstance(aesara_output, Variable) npt.assert_allclose(aesara_output.eval(), aesara_graph_input.eval()) intX = pm.aesaraf._conversion_map[aesara.config.floatX] - if dense_input.dtype == intX or dense_input.dtype == aesara.config.floatX: + if "int" in str(dense_input.dtype) or dense_input.dtype == aesara.config.floatX: assert aesara_output.owner is None # func should not have added new nodes assert aesara_output.name == input_name else: @@ -254,7 +254,8 @@ def test_convert_observed_data(input_dtype): if "float" in input_dtype: assert aesara_output.dtype == aesara.config.floatX else: - assert aesara_output.dtype == intX + # only cast floats, leave ints as is + assert aesara_output.dtype == input_dtype # Check function behavior with generator data generator_output = func(square_generator) diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index 52b18705ba..4536786e88 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -87,7 +87,7 @@ def test_sample_posterior_predictive_after_set_data(self): ) # Predict on new data. with model: - x_test = [5, 6, 9] + x_test = [5.0, 6.0, 9.0] pm.set_data(new_data={"x": x_test}) y_test = pm.sample_posterior_predictive(trace) @@ -111,7 +111,7 @@ def test_sample_posterior_predictive_after_set_data_with_coords(self): ) # Predict on new data. with model: - x_test = [5, 6] + x_test = [5.0, 6.0] pm.set_data(new_data={"x": x_test}, coords={"obs_id": ["a", "b"]}) pm.sample_posterior_predictive(idata, extend_inferencedata=True, predictions=True) From 7c86ed7bbdb3f064dbc92e4a11ce9e1e41b134bb Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 20 Nov 2022 20:07:58 +0100 Subject: [PATCH 2/2] Make discrete moments robust against int-overflows --- pymc/distributions/discrete.py | 8 +++++++- pymc/tests/distributions/util.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index fc57448584..a6db6b70c8 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -878,7 +878,11 @@ def dist(cls, N, k, n, *args, **kwargs): return super().dist([good, bad, n], *args, **kwargs) def moment(rv, size, good, bad, n): - N, k = good + bad, good + # Cast to float because the intX can be int8 + # which could trigger an integer overflow below. + n = floatX(n) + k = floatX(good) + N = k + floatX(bad) mode = at.floor((n + 1) * (k + 1) / (N + 2)) if not rv_size_is_none(size): mode = at.full(size, mode) @@ -1014,6 +1018,8 @@ def dist(cls, lower, upper, *args, **kwargs): return super().dist([lower, upper], **kwargs) def moment(rv, size, lower, upper): + upper = floatX(upper) + lower = floatX(lower) mode = at.maximum(at.floor((upper + lower) / 2.0), lower) if not rv_size_is_none(size): mode = at.full(size, mode) diff --git a/pymc/tests/distributions/util.py b/pymc/tests/distributions/util.py index 0a501da4e8..a6adba6097 100644 --- a/pymc/tests/distributions/util.py +++ b/pymc/tests/distributions/util.py @@ -579,7 +579,7 @@ def assert_moment_is_expected(model, expected, check_finite_logp=True): assert moment.shape == expected.shape assert expected.shape == random_draw.shape - assert np.allclose(moment, expected) + np.testing.assert_allclose(moment, expected, atol=1e-10) if check_finite_logp: logp_moment = (