Skip to content

Commit b2e64cb

Browse files
authored
Fix dependence of Uniform logp on bound method (#4541)
* Fix logp of (Discrete) Uniform to not depend on bound * Add unittest * Remove redundant `all()` bound conditions in multivariate distributions and improve documentation of dist_math::bound * Add recommendation for check_bounds. * Add release note * Include Release Notes from 3.11.2
1 parent 82ccb3b commit b2e64cb

File tree

7 files changed

+75
-30
lines changed

7 files changed

+75
-30
lines changed

RELEASE-NOTES.md

+24-6
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,37 @@
11
# Release Notes
22

3-
## PyMC3 vNext (TBD)
3+
## PyMC3 vNext (4.0.0)
44
### Breaking Changes
55
- ⚠ Theano-PyMC has been replaced with Aesara, so all external references to `theano`, `tt`, and `pymc3.theanof` need to be replaced with `aesara`, `aet`, and `pymc3.aesaraf` (see [4471](https://github.com./pymc-devs/pymc3/pull/4471)).
66

77
### New Features
8-
+ `pm.math.cartesian` can now handle inputs that are themselves >1D (see [#4482](https://github.com./pymc-devs/pymc3/pull/4482)).
9-
+ The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
10-
+ ...
8+
- The `CAR` distribution has been added to allow for use of conditional autoregressions which often are used in spatial and network models.
9+
- ...
1110

1211
### Maintenance
13-
- The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util` (see [#4509](https://github.com./pymc-devs/pymc3/pull/4509)).
1412
- Remove float128 dtype support (see [#4514](https://github.com./pymc-devs/pymc3/pull/4514)).
13+
- Logp method of `Uniform` and `DiscreteUniform` no longer depends on `pymc3.distributions.dist_math.bound` for proper evaluation (see [#4541](https://github.com./pymc-devs/pymc3/pull/4541)).
14+
- ...
15+
16+
## PyMC3 3.11.2 (14 March 2021)
17+
18+
### New Features
19+
+ `pm.math.cartesian` can now handle inputs that are themselves >1D (see [#4482](https://github.com./pymc-devs/pymc3/pull/4482)).
20+
+ Statistics and plotting functions that were removed in `3.11.0` were brought back, albeit with deprecation warnings if an old naming scheme is used (see [#4536](https://github.com./pymc-devs/pymc3/pull/4536)). In order to future proof your code, rename these function calls:
21+
+ `pm.traceplot``pm.plot_trace`
22+
+ `pm.compareplot``pm.plot_compare` (here you might need to rename some columns in the input according to the [`arviz.plot_compare` documentation](https://arviz-devs.github.io/arviz/api/generated/arviz.plot_compare.html))
23+
+ `pm.autocorrplot``pm.plot_autocorr`
24+
+ `pm.forestplot``pm.plot_forest`
25+
+ `pm.kdeplot``pm.plot_kde`
26+
+ `pm.energyplot``pm.plot_energy`
27+
+ `pm.densityplot``pm.plot_density`
28+
+ `pm.pairplot``pm.plot_pair`
29+
30+
### Maintenance
31+
- ⚠ Our memoization mechanism wasn't robust against hash collisions ([#4506](https://github.com./pymc-devs/pymc3/issues/4506)), sometimes resulting in incorrect values in, for example, posterior predictives. The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util` (see [#4525](https://github.com./pymc-devs/pymc3/pull/4525)).
1532
- `pm.make_shared_replacements` now retains broadcasting information which fixes issues with Metropolis samplers (see [#4492](https://github.com./pymc-devs/pymc3/pull/4492)).
16-
+ ...
33+
34+
**Release manager** for 3.11.2: Michael Osthege ([@michaelosthege](https://github.com./michaelosthege))
1735

1836
## PyMC3 3.11.1 (12 February 2021)
1937

pymc3/distributions/continuous.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,11 @@ def logp(self, value):
268268
"""
269269
lower = self.lower
270270
upper = self.upper
271-
return bound(-aet.log(upper - lower), value >= lower, value <= upper)
271+
return bound(
272+
aet.fill(value, -aet.log(upper - lower)),
273+
value >= lower,
274+
value <= upper,
275+
)
272276

273277
def logcdf(self, value):
274278
"""

pymc3/distributions/discrete.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,11 @@ def logp(self, value):
12781278
"""
12791279
upper = self.upper
12801280
lower = self.lower
1281-
return bound(-aet.log(upper - lower + 1), lower <= value, value <= upper)
1281+
return bound(
1282+
aet.fill(value, -aet.log(upper - lower + 1)),
1283+
lower <= value,
1284+
value <= upper,
1285+
)
12821286

12831287
def logcdf(self, value):
12841288
"""

pymc3/distributions/dist_math.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,23 @@
4848
def bound(logp, *conditions, **kwargs):
4949
"""
5050
Bounds a log probability density with several conditions.
51+
When conditions are not met, the logp values are replaced by -inf.
52+
53+
Note that bound should not be used to enforce the logic of the logp under the normal
54+
support as it can be disabled by the user via check_bounds = False in pm.Model()
5155
5256
Parameters
5357
----------
5458
logp: float
5559
*conditions: booleans
5660
broadcast_conditions: bool (optional, default=True)
57-
If True, broadcasts logp to match the largest shape of the conditions.
58-
This is used e.g. in DiscreteUniform where logp is a scalar constant and the shape
59-
is specified via the conditions.
60-
If False, will return the same shape as logp.
61-
This is used e.g. in Multinomial where broadcasting can lead to differences in the logp.
61+
If True, conditions are broadcasted and applied element-wise to each value in logp.
62+
If False, conditions are collapsed via aet.all(). As a consequence the entire logp
63+
array is either replaced by -inf or unchanged.
64+
65+
Setting broadcasts_conditions to False is necessary for most (all?) multivariate
66+
distributions where the dimensions of the conditions do not unambigously match
67+
that of the logp.
6268
6369
Returns
6470
-------

pymc3/distributions/multivariate.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,9 @@ def logp(self, value):
527527
# only defined for sum(value) == 1
528528
return bound(
529529
aet.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(aet.sum(a, axis=-1)),
530-
aet.all(value >= 0),
531-
aet.all(value <= 1),
532-
aet.all(a > 0),
530+
value >= 0,
531+
value <= 1,
532+
a > 0,
533533
broadcast_conditions=False,
534534
)
535535

@@ -671,11 +671,11 @@ def logp(self, x):
671671

672672
return bound(
673673
factln(n) + aet.sum(-factln(x) + logpow(p, x), axis=-1, keepdims=True),
674-
aet.all(x >= 0),
675-
aet.all(aet.eq(aet.sum(x, axis=-1, keepdims=True), n)),
676-
aet.all(p <= 1),
677-
aet.all(aet.eq(aet.sum(p, axis=-1), 1)),
678-
aet.all(aet.ge(n, 0)),
674+
x >= 0,
675+
aet.eq(aet.sum(x, axis=-1, keepdims=True), n),
676+
p <= 1,
677+
aet.eq(aet.sum(p, axis=-1), 1),
678+
n >= 0,
679679
broadcast_conditions=False,
680680
)
681681

@@ -823,10 +823,10 @@ def logp(self, value):
823823
# and that each observation value_i sums to n_i.
824824
return bound(
825825
result,
826-
aet.all(aet.ge(value, 0)),
827-
aet.all(aet.gt(a, 0)),
828-
aet.all(aet.ge(n, 0)),
829-
aet.all(aet.eq(value.sum(axis=-1, keepdims=True), n)),
826+
value >= 0,
827+
a > 0,
828+
n >= 0,
829+
aet.eq(value.sum(axis=-1, keepdims=True), n),
830830
broadcast_conditions=False,
831831
)
832832

@@ -1575,8 +1575,8 @@ def logp(self, x):
15751575
result += (eta - 1.0) * aet.log(det(X))
15761576
return bound(
15771577
result,
1578-
aet.all(X <= 1),
1579-
aet.all(X >= -1),
1578+
X >= -1,
1579+
X <= 1,
15801580
matrix_pos_def(X),
15811581
eta > 0,
15821582
broadcast_conditions=False,
@@ -2204,9 +2204,10 @@ def logp(self, value):
22042204
logquad = (self.tau * delta * tau_dot_delta).sum(axis=-1)
22052205
return bound(
22062206
0.5 * (logtau + logdet - logquad),
2207-
aet.all(self.alpha <= 1),
2208-
aet.all(self.alpha >= -1),
2207+
self.alpha >= -1,
2208+
self.alpha <= 1,
22092209
self.tau > 0,
2210+
broadcast_conditions=False,
22102211
)
22112212

22122213
def random(self, point=None, size=None):

pymc3/model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,8 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
817817
Ensure that input parameters to distributions are in a valid
818818
range. If your model is built in a way where you know your
819819
parameters can only take on valid values you can set this to
820-
False for increased speed.
820+
False for increased speed. This should not be used if your model
821+
contains discrete variables.
821822
822823
Examples
823824
--------

pymc3/tests/test_distributions.py

+11
Original file line numberDiff line numberDiff line change
@@ -2667,6 +2667,17 @@ def test_issue_3051(self, dims, dist_cls, kwargs):
26672667
assert actual_a.shape == (X.shape[0],)
26682668
pass
26692669

2670+
def test_issue_4499(self):
2671+
# Test for bug in Uniform and DiscreteUniform logp when setting check_bounds = False
2672+
# https://github.com./pymc-devs/pymc3/issues/4499
2673+
with pm.Model(check_bounds=False) as m:
2674+
x = pm.Uniform("x", 0, 2, shape=10, transform=None)
2675+
assert_almost_equal(m.logp_array(np.ones(10)), -np.log(2) * 10)
2676+
2677+
with pm.Model(check_bounds=False) as m:
2678+
x = pm.DiscreteUniform("x", 0, 1, shape=10)
2679+
assert_almost_equal(m.logp_array(np.ones(10)), -np.log(2) * 10)
2680+
26702681

26712682
def test_serialize_density_dist():
26722683
def func(x):

0 commit comments

Comments
 (0)