Skip to content

Commit d47dac0

Browse files
pibieta, ricardoV94, and OriolAbril
authored
Remove samples and keep_size from sample_posterior_predictive (#6029)
* removed unused imports (Co-authored-by: vitaliset <[email protected]>)
* Update pymc/tests/test_sampling.py — added Ricardo's suggestion 2 (Co-authored-by: Ricardo Vieira <[email protected]>)
* edited `test_mixture.py` and `test_sampling.py`
* try fixing some tests
* try fixing tests
* fix one more test

Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Oriol (ZBook) <[email protected]>
1 parent 52a1b3c commit d47dac0

File tree

6 files changed

+85
-185
lines changed

6 files changed

+85
-185
lines changed

pymc/sampling.py

+13-49
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from pymc.backends.base import BaseTrace, MultiTrace
6262
from pymc.backends.ndarray import NDArray
6363
from pymc.blocking import DictToArrayBijection
64-
from pymc.exceptions import IncorrectArgumentsError, SamplingError
64+
from pymc.exceptions import SamplingError
6565
from pymc.initial_point import (
6666
PointType,
6767
StartDict,
@@ -1769,10 +1769,8 @@ def expand(node):
17691769

17701770
def sample_posterior_predictive(
17711771
trace,
1772-
samples: Optional[int] = None,
17731772
model: Optional[Model] = None,
17741773
var_names: Optional[List[str]] = None,
1775-
keep_size: Optional[bool] = None,
17761774
random_seed: RandomState = None,
17771775
progressbar: bool = True,
17781776
return_inferencedata: bool = True,
@@ -1788,25 +1786,11 @@ def sample_posterior_predictive(
17881786
trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace
17891787
Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()),
17901788
or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior)
1791-
samples : int
1792-
Number of posterior predictive samples to generate. Defaults to one posterior predictive
1793-
sample per posterior sample, that is, the number of draws times the number of chains.
1794-
1795-
It is not recommended to modify this value; when modified, some chains may not be
1796-
represented in the posterior predictive sample. Instead, in cases when generating
1797-
posterior predictive samples is too expensive to do it once per posterior sample,
1798-
the recommended approach is to thin the ``trace`` argument
1799-
before passing it to ``sample_posterior_predictive``. In such cases it
1800-
might be advisable to set ``extend_inferencedata`` to ``False`` and extend
1801-
the inferencedata manually afterwards.
18021789
model : Model (optional if in ``with`` context)
18031790
Model to be used to generate the posterior predictive samples. It will
18041791
generally be the model used to generate the ``trace``, but it doesn't need to be.
18051792
var_names : Iterable[str]
18061793
Names of variables for which to compute the posterior predictive samples.
1807-
keep_size : bool, default True
1808-
Force posterior predictive sample to have the same shape as posterior and sample stats
1809-
data: ``(nchains, ndraws, ...)``. Overrides samples parameter.
18101794
random_seed : int, RandomState or Generator, optional
18111795
Seed for the random number generator.
18121796
progressbar : bool
@@ -1882,38 +1866,18 @@ def sample_posterior_predictive(
18821866
else:
18831867
raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")
18841868

1885-
if keep_size is None:
1886-
# This will allow users to set return_inferencedata=False and
1887-
# automatically get the old behaviour instead of needing to
1888-
# set both return_inferencedata and keep_size to False
1889-
keep_size = return_inferencedata
1890-
1891-
if keep_size and samples is not None:
1892-
raise IncorrectArgumentsError(
1893-
"Should not specify both keep_size and samples arguments. "
1894-
"See the docstring of the samples argument for more details."
1869+
if isinstance(_trace, MultiTrace):
1870+
samples = sum(len(v) for v in _trace._straces.values())
1871+
elif isinstance(_trace, list):
1872+
# this is a list of points
1873+
samples = len(_trace)
1874+
else:
1875+
raise TypeError(
1876+
"Do not know how to compute number of samples for trace argument of type %s"
1877+
% type(_trace)
18951878
)
18961879

1897-
if samples is None:
1898-
if isinstance(_trace, MultiTrace):
1899-
samples = sum(len(v) for v in _trace._straces.values())
1900-
elif isinstance(_trace, list):
1901-
# this is a list of points
1902-
samples = len(_trace)
1903-
else:
1904-
raise TypeError(
1905-
"Do not know how to compute number of samples for trace argument of type %s"
1906-
% type(_trace)
1907-
)
1908-
19091880
assert samples is not None
1910-
if samples < len_trace * nchain:
1911-
warnings.warn(
1912-
"samples parameter is smaller than nchains times ndraws, some draws "
1913-
"and/or chains may not be represented in the returned posterior "
1914-
"predictive sample",
1915-
stacklevel=2,
1916-
)
19171881

19181882
model = modelcontext(model)
19191883

@@ -2001,9 +1965,9 @@ def sample_posterior_predictive(
20011965
pass
20021966

20031967
ppc_trace = ppc_trace_t.trace_dict
2004-
if keep_size:
2005-
for k, ary in ppc_trace.items():
2006-
ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
1968+
1969+
for k, ary in ppc_trace.items():
1970+
ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
20071971

20081972
if not return_inferencedata:
20091973
return ppc_trace

pymc/tests/backends/test_arviz.py

+3-23
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def get_predictions_inference_data(
8989
with data.model:
9090
prior = pm.sample_prior_predictive(return_inferencedata=False)
9191
posterior_predictive = pm.sample_posterior_predictive(
92-
data.obj, keep_size=True, return_inferencedata=False
92+
data.obj, return_inferencedata=False
9393
)
9494

9595
idata = to_inference_data(
@@ -111,7 +111,7 @@ def make_predictions_inference_data(
111111
) -> Tuple[InferenceData, Dict[str, np.ndarray]]:
112112
with data.model:
113113
posterior_predictive = pm.sample_posterior_predictive(
114-
data.obj, keep_size=True, return_inferencedata=False
114+
data.obj, return_inferencedata=False
115115
)
116116
idata = predictions_to_inference_data(
117117
posterior_predictive,
@@ -190,7 +190,7 @@ def test_predictions_to_idata_new(self, data, eight_schools_params):
190190
def test_posterior_predictive_keep_size(self, data, chains, draws, eight_schools_params):
191191
with data.model:
192192
posterior_predictive = pm.sample_posterior_predictive(
193-
data.obj, keep_size=True, return_inferencedata=False
193+
data.obj, return_inferencedata=False
194194
)
195195
inference_data = to_inference_data(
196196
trace=data.obj,
@@ -204,26 +204,6 @@ def test_posterior_predictive_keep_size(self, data, chains, draws, eight_schools
204204
[obs_s == s for obs_s, s in zip(shape, (chains, draws, eight_schools_params["J"]))]
205205
)
206206

207-
def test_posterior_predictive_warning(self, data, eight_schools_params, caplog):
208-
with data.model:
209-
with warnings.catch_warnings():
210-
warnings.filterwarnings(
211-
"ignore", ".*smaller than nchains times ndraws.*", UserWarning
212-
)
213-
posterior_predictive = pm.sample_posterior_predictive(
214-
data.obj, 370, return_inferencedata=False, keep_size=False
215-
)
216-
with pytest.warns(UserWarning, match="shape of variables"):
217-
inference_data = to_inference_data(
218-
trace=data.obj,
219-
posterior_predictive=posterior_predictive,
220-
coords={"school": np.arange(eight_schools_params["J"])},
221-
dims={"theta": ["school"], "eta": ["school"]},
222-
)
223-
224-
shape = inference_data.posterior_predictive.obs.shape
225-
assert np.all([obs_s == s for obs_s, s in zip(shape, (1, 370, eight_schools_params["J"]))])
226-
227207
def test_posterior_predictive_thinned(self, data):
228208
with data.model:
229209
draws = 20

pymc/tests/distributions/test_mixture.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -546,18 +546,18 @@ def test_single_poisson_predictive_sampling_shape(self):
546546
with model:
547547
prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False)
548548
ppc = sample_posterior_predictive(
549-
[self.get_inital_point(model)], samples=n_samples, return_inferencedata=False
549+
n_samples * [self.get_inital_point(model)], return_inferencedata=False
550550
)
551551

552552
assert prior["like0"].shape == (n_samples, 20)
553553
assert prior["like1"].shape == (n_samples, 20)
554554
assert prior["like2"].shape == (n_samples, 20)
555555
assert prior["like3"].shape == (n_samples, 20)
556556

557-
assert ppc["like0"].shape == (n_samples, 20)
558-
assert ppc["like1"].shape == (n_samples, 20)
559-
assert ppc["like2"].shape == (n_samples, 20)
560-
assert ppc["like3"].shape == (n_samples, 20)
557+
assert ppc["like0"].shape == (1, n_samples, 20)
558+
assert ppc["like1"].shape == (1, n_samples, 20)
559+
assert ppc["like2"].shape == (1, n_samples, 20)
560+
assert ppc["like3"].shape == (1, n_samples, 20)
561561

562562
def test_list_mvnormals_predictive_sampling_shape(self):
563563
N = 100 # number of data points
@@ -592,9 +592,16 @@ def test_list_mvnormals_predictive_sampling_shape(self):
592592
with model:
593593
prior = sample_prior_predictive(samples=n_samples, return_inferencedata=False)
594594
ppc = sample_posterior_predictive(
595-
[self.get_inital_point(model)], samples=n_samples, return_inferencedata=False
595+
n_samples * [self.get_inital_point(model)], return_inferencedata=False
596596
)
597-
assert ppc["x_obs"].shape == (n_samples,) + X.shape
597+
assert (
598+
ppc["x_obs"].shape
599+
== (
600+
1,
601+
n_samples,
602+
)
603+
+ X.shape
604+
)
598605
assert prior["x_obs"].shape == (n_samples,) + X.shape
599606
assert prior["mu0"].shape == (n_samples, D)
600607
assert prior["chol_cov_0"].shape == (n_samples, D * (D + 1) // 2)

pymc/tests/distributions/test_simulator.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ def test_one_gaussian(self):
8080
with self.SMABC_test:
8181
trace = pm.sample_smc(draws=1000, chains=1, return_inferencedata=False)
8282
pr_p = pm.sample_prior_predictive(1000, return_inferencedata=False)
83-
po_p = pm.sample_posterior_predictive(
84-
trace, keep_size=False, return_inferencedata=False
85-
)
83+
po_p = pm.sample_posterior_predictive(trace, return_inferencedata=False)
8684

8785
assert abs(self.data.mean() - trace["a"].mean()) < 0.05
8886
assert abs(self.data.std() - trace["b"].mean()) < 0.05
@@ -91,7 +89,7 @@ def test_one_gaussian(self):
9189
assert abs(0 - pr_p["s"].mean()) < 0.15
9290
assert abs(1.4 - pr_p["s"].std()) < 0.10
9391

94-
assert po_p["s"].shape == (1000, 1000)
92+
assert po_p["s"].shape == (1, 1000, 1000)
9593
assert abs(self.data.mean() - po_p["s"].mean()) < 0.10
9694
assert abs(self.data.std() - po_p["s"].std()) < 0.10
9795

pymc/tests/test_model.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1263,13 +1263,18 @@ def test_interval_missing_observations():
12631263

12641264
# Make sure that the observed values are newly generated samples and that
12651265
# the observed and deterministic matche
1266-
pp_trace = pm.sample_posterior_predictive(
1267-
trace, return_inferencedata=False, keep_size=False
1266+
pp_idata = pm.sample_posterior_predictive(trace)
1267+
pp_trace = pp_idata.posterior_predictive.stack(sample=["chain", "draw"]).transpose(
1268+
"sample", ...
12681269
)
12691270
assert np.all(np.var(pp_trace["theta1"], 0) > 0.0)
12701271
assert np.all(np.var(pp_trace["theta2"], 0) > 0.0)
1271-
assert np.mean(pp_trace["theta1"][:, ~obs1.mask] - pp_trace["theta1_observed"]) == 0.0
1272-
assert np.mean(pp_trace["theta2"][:, ~obs2.mask] - pp_trace["theta2_observed"]) == 0.0
1272+
assert np.isclose(
1273+
np.mean(pp_trace["theta1"][:, ~obs1.mask] - pp_trace["theta1_observed"]), 0
1274+
)
1275+
assert np.isclose(
1276+
np.mean(pp_trace["theta2"][:, ~obs2.mask] - pp_trace["theta2_observed"]), 0
1277+
)
12731278

12741279

12751280
def test_double_counting():

0 commit comments

Comments
 (0)