Skip to content

Commit 254c574

Browse files
lucianopaztwiecki
authored andcommitted
Fix for #3354. draw_values now adds the theano graph descendants of TensorConstant or SharedVariables to the named relationship nodes stack, only if these descendants are ObservedRV or MultiObservedRV instances.
* Fix for 3354 * Fixed float32 precision error * Added inline comment explaining why we must add observed RVs to the stack
1 parent 13d0ed6 commit 254c574

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

RELEASE-NOTES.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310).
1212
- The `Wald`, `Kumaraswamy`, `LogNormal`, `Pareto`, `Cauchy`, `HalfCauchy`, `Weibull` and `ExGaussian` distributions `random` method used a hidden `_random` function that was written with scalars in mind. This could potentially lead to artificial correlations between random draws. Added shape guards and broadcasting of the distribution samples to prevent this (Similar to issue #3310).
1313
- Added a fix to allow the imputation of single missing values of observed data, which previously would fail (Fix issue #3122).
14+
- Fix for #3354. `draw_values` now adds the theano graph descendants of `TensorConstant` or `SharedVariables` to the named relationship nodes stack, only if these descendants are `ObservedRV` or `MultiObservedRV` instances.
1415

1516
### Deprecations
1617

pymc3/distributions/distribution.py

+8
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,14 @@ def draw_values(params, point=None, size=None):
366366
# ('Constants not allowed in param list', ...)` for
367367
# TensorConstant, and a `TypeError: Cannot use a shared
368368
# variable (...) as explicit input` for SharedVariable.
369+
# ObservedRV and MultiObservedRV instances are ViewOPs
370+
# of TensorConstants or SharedVariables, we must add them
371+
# to the stack or risk evaluating deterministics with the
372+
# wrong values (issue #3354)
373+
stack.extend([node for node in named_nodes_parents[next_]
374+
if isinstance(node, (ObservedRV,
375+
MultiObservedRV))
376+
and (node, size) not in drawn])
369377
continue
370378
else:
371379
# If the node does not have a givens value, try to draw it.

pymc3/tests/test_sampling.py

+31
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,37 @@ def test_model_not_drawable_prior(self):
302302
samples = pm.sample_posterior_predictive(trace, 50)
303303
assert samples['foo'].shape == (50, 200)
304304

305+
def test_deterministic_of_observed(self):
306+
meas_in_1 = pm.theanof.floatX(2 + 4 * np.random.randn(100))
307+
meas_in_2 = pm.theanof.floatX(5 + 4 * np.random.randn(100))
308+
with pm.Model() as model:
309+
mu_in_1 = pm.Normal('mu_in_1', 0, 1)
310+
sigma_in_1 = pm.HalfNormal('sd_in_1', 1)
311+
mu_in_2 = pm.Normal('mu_in_2', 0, 1)
312+
sigma_in_2 = pm.HalfNormal('sd__in_2', 1)
313+
314+
in_1 = pm.Normal('in_1', mu_in_1, sigma_in_1, observed=meas_in_1)
315+
in_2 = pm.Normal('in_2', mu_in_2, sigma_in_2, observed=meas_in_2)
316+
out_diff = in_1 + in_2
317+
pm.Deterministic('out', out_diff)
318+
319+
trace = pm.sample(100)
320+
ppc_trace = pm.trace_to_dataframe(
321+
trace,
322+
varnames=[n for n in trace.varnames
323+
if n != 'out']
324+
).to_dict('records')
325+
ppc = pm.sample_posterior_predictive(model=model,
326+
trace=ppc_trace,
327+
samples=len(ppc_trace),
328+
vars=(model.deterministics +
329+
model.basic_RVs))
330+
331+
rtol = 1e-5 if theano.config.floatX == 'float64' else 1e-3
332+
assert np.allclose(ppc['in_1'] + ppc['in_2'],
333+
ppc['out'],
334+
rtol=rtol)
335+
305336

306337
class TestSamplePPCW(SeededTest):
307338
def test_sample_posterior_predictive_w(self):

0 commit comments

Comments
 (0)