61
61
from pymc .backends .base import BaseTrace , MultiTrace
62
62
from pymc .backends .ndarray import NDArray
63
63
from pymc .blocking import DictToArrayBijection
64
- from pymc .exceptions import IncorrectArgumentsError , SamplingError
64
+ from pymc .exceptions import SamplingError
65
65
from pymc .initial_point import (
66
66
PointType ,
67
67
StartDict ,
@@ -1769,10 +1769,8 @@ def expand(node):
1769
1769
1770
1770
def sample_posterior_predictive (
1771
1771
trace ,
1772
- samples : Optional [int ] = None ,
1773
1772
model : Optional [Model ] = None ,
1774
1773
var_names : Optional [List [str ]] = None ,
1775
- keep_size : Optional [bool ] = None ,
1776
1774
random_seed : RandomState = None ,
1777
1775
progressbar : bool = True ,
1778
1776
return_inferencedata : bool = True ,
@@ -1788,25 +1786,11 @@ def sample_posterior_predictive(
1788
1786
trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace
1789
1787
Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()),
1790
1788
or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior)
1791
- samples : int
1792
- Number of posterior predictive samples to generate. Defaults to one posterior predictive
1793
- sample per posterior sample, that is, the number of draws times the number of chains.
1794
-
1795
- It is not recommended to modify this value; when modified, some chains may not be
1796
- represented in the posterior predictive sample. Instead, in cases when generating
1797
- posterior predictive samples is too expensive to do it once per posterior sample,
1798
- the recommended approach is to thin the ``trace`` argument
1799
- before passing it to ``sample_posterior_predictive``. In such cases it
1800
- might be advisable to set ``extend_inferencedata`` to ``False`` and extend
1801
- the inferencedata manually afterwards.
1802
1789
model : Model (optional if in ``with`` context)
1803
1790
Model to be used to generate the posterior predictive samples. It will
1804
1791
generally be the model used to generate the ``trace``, but it doesn't need to be.
1805
1792
var_names : Iterable[str]
1806
1793
Names of variables for which to compute the posterior predictive samples.
1807
- keep_size : bool, default True
1808
- Force posterior predictive sample to have the same shape as posterior and sample stats
1809
- data: ``(nchains, ndraws, ...)``. Overrides samples parameter.
1810
1794
random_seed : int, RandomState or Generator, optional
1811
1795
Seed for the random number generator.
1812
1796
progressbar : bool
@@ -1882,38 +1866,18 @@ def sample_posterior_predictive(
1882
1866
else :
1883
1867
raise TypeError (f"Unsupported type for `trace` argument: { type (trace )} ." )
1884
1868
1885
- if keep_size is None :
1886
- # This will allow users to set return_inferencedata=False and
1887
- # automatically get the old behaviour instead of needing to
1888
- # set both return_inferencedata and keep_size to False
1889
- keep_size = return_inferencedata
1890
-
1891
- if keep_size and samples is not None :
1892
- raise IncorrectArgumentsError (
1893
- "Should not specify both keep_size and samples arguments. "
1894
- "See the docstring of the samples argument for more details."
1869
+ if isinstance (_trace , MultiTrace ):
1870
+ samples = sum (len (v ) for v in _trace ._straces .values ())
1871
+ elif isinstance (_trace , list ):
1872
+ # this is a list of points
1873
+ samples = len (_trace )
1874
+ else :
1875
+ raise TypeError (
1876
+ "Do not know how to compute number of samples for trace argument of type %s"
1877
+ % type (_trace )
1895
1878
)
1896
1879
1897
- if samples is None :
1898
- if isinstance (_trace , MultiTrace ):
1899
- samples = sum (len (v ) for v in _trace ._straces .values ())
1900
- elif isinstance (_trace , list ):
1901
- # this is a list of points
1902
- samples = len (_trace )
1903
- else :
1904
- raise TypeError (
1905
- "Do not know how to compute number of samples for trace argument of type %s"
1906
- % type (_trace )
1907
- )
1908
-
1909
1880
assert samples is not None
1910
- if samples < len_trace * nchain :
1911
- warnings .warn (
1912
- "samples parameter is smaller than nchains times ndraws, some draws "
1913
- "and/or chains may not be represented in the returned posterior "
1914
- "predictive sample" ,
1915
- stacklevel = 2 ,
1916
- )
1917
1881
1918
1882
model = modelcontext (model )
1919
1883
@@ -2001,9 +1965,9 @@ def sample_posterior_predictive(
2001
1965
pass
2002
1966
2003
1967
ppc_trace = ppc_trace_t .trace_dict
2004
- if keep_size :
2005
- for k , ary in ppc_trace .items ():
2006
- ppc_trace [k ] = ary .reshape ((nchain , len_trace , * ary .shape [1 :]))
1968
+
1969
+ for k , ary in ppc_trace .items ():
1970
+ ppc_trace [k ] = ary .reshape ((nchain , len_trace , * ary .shape [1 :]))
2007
1971
2008
1972
if not return_inferencedata :
2009
1973
return ppc_trace
0 commit comments