Commit 43bd711

Make transform objects stateless
1 parent 6a4cdd6 commit 43bd711

File tree

8 files changed: +342 -319 lines changed

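The heart of the change: transform objects no longer capture distribution parameters (bounds, index arrays, and so on) at construction time. Instead, `backward` and `jacobian_det` receive the random variable itself as a first argument, and parameterized transforms hold a callable that extracts whatever they need from that variable's inputs. Below is a minimal sketch of the pattern; the class name `IntervalTransform` and the attribute `param_extract_fn` are illustrative stand-ins, not the exact PyMC3 API.

import aesara.tensor as aet


class IntervalTransform:
    """An interval transform that stores no bound values of its own.

    The bounds come from ``param_extract_fn``, a callable that reads them
    off a random variable's inputs at call time, so a single instance can
    serve every bounded RV -- this is what "stateless" means here.
    """

    def __init__(self, param_extract_fn):
        self.param_extract_fn = param_extract_fn

    def backward(self, rv_var, rv_value):
        # The RV is passed in so the bounds can be recovered here, rather
        # than having been baked into the transform at construction time.
        a, b = self.param_extract_fn(rv_var)
        sigmoid = 1 / (1 + aet.exp(-rv_value))  # map the real line to (0, 1)
        return a + (b - a) * sigmoid

    def jacobian_det(self, rv_var, rv_value):
        # log|d backward / d rv_value| = log(b - a) + log sig(x) + log sig(-x)
        a, b = self.param_extract_fn(rv_var)
        return aet.log(b - a) - aet.log1p(aet.exp(rv_value)) - aet.log1p(aet.exp(-rv_value))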

pymc3/backends/base.py (+1 -1)
@@ -68,7 +68,7 @@ def __init__(self, name, model=None, vars=None, test_point=None):
         if transform:
             # We need to create and add an un-transformed version of
             # each transformed variable
-            untrans_var = transform.backward(var)
+            untrans_var = transform.backward(v, var)
             untrans_var.name = v.name
             vars.append(untrans_var)
         vars.append(var)
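
With the stateless API, `backward` takes the random variable (`v` here) in addition to its value variable (`var`): any parameters the transform needs are recovered from the RV's inputs at call time rather than stored on the transform object.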

pymc3/distributions/__init__.py (+56 -37)
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
+
 from functools import singledispatch
 from itertools import chain
 from typing import Generator, List, Optional, Tuple, Union
@@ -20,7 +22,7 @@
 
 from aesara import config
 from aesara.graph.basic import Variable, ancestors, clone_replace
-from aesara.graph.op import compute_test_value
+from aesara.graph.op import Op, compute_test_value
 from aesara.tensor.random.op import Observed, RandomVariable
 from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
 from aesara.tensor.var import TensorVariable
@@ -33,7 +35,7 @@
 
 
 @singledispatch
-def logp_transform(op, inputs):
+def logp_transform(op: Op):
     return None
 
 
@@ -141,7 +143,8 @@ def change_rv_size(
 
 def rv_log_likelihood_args(
     rv_var: TensorVariable,
-    transformed: Optional[bool] = True,
+    *,
+    return_observations: bool = True,
 ) -> Tuple[TensorVariable, TensorVariable]:
     """Get a `RandomVariable` and its corresponding log-likelihood `TensorVariable` value.
 
@@ -151,8 +154,9 @@ def rv_log_likelihood_args(
         A variable corresponding to a `RandomVariable`, whether directly or
         indirectly (e.g. an observed variable that's the output of an
         `Observed` `Op`).
-    transformed
-        When ``True``, return the transformed value var.
+    return_observations
+        When ``True``, return the observed values in place of the log-likelihood
+        value variable.
 
     Returns
     =======
@@ -163,12 +167,14 @@ def rv_log_likelihood_args(
     """
 
     if rv_var.owner and isinstance(rv_var.owner.op, Observed):
-        return tuple(rv_var.owner.inputs)
-    elif hasattr(rv_var.tag, "value_var"):
-        rv_value = rv_var.tag.value_var
-        return rv_var, rv_value
-    else:
-        return rv_var, None
+        rv_var, obs_var = rv_var.owner.inputs
+        if return_observations:
+            return rv_var, obs_var
+        else:
+            return rv_var, rv_log_likelihood_args(rv_var)[1]
+
+    rv_value = getattr(rv_var.tag, "value_var", None)
+    return rv_var, rv_value
 
 
 def rv_ancestors(graphs: List[TensorVariable]) -> Generator[TensorVariable, None, None]:
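
A hedged sketch of the two behaviors of the new keyword; the names `obs_rv`, `rv`, and `data` are hypothetical, standing for an `Observed` node's output, the wrapped `RandomVariable`, and its observations:

# With an observed RV, the observations are returned by default...
rv, obs = rv_log_likelihood_args(obs_rv)  # -> (rv, data)
# ...while `return_observations=False` recurses on the RV itself and
# returns its log-likelihood value variable instead.
rv, value = rv_log_likelihood_args(obs_rv, return_observations=False)  # -> (rv, rv.tag.value_var)
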
@@ -217,7 +223,7 @@ def sample_to_measure_vars(
         if not (anc.owner and isinstance(anc.owner.op, RandomVariable)):
             continue
 
-        _, value_var = rv_log_likelihood_args(anc)
+        _, value_var = rv_log_likelihood_args(anc, return_observations=False)
 
         if value_var is not None:
             replace[anc] = value_var
@@ -233,8 +239,10 @@
 def logpt(
     rv_var: TensorVariable,
     rv_value: Optional[TensorVariable] = None,
-    jacobian: Optional[bool] = True,
-    scaling: Optional[bool] = True,
+    *,
+    jacobian: bool = True,
+    scaling: bool = True,
+    transformed: bool = True,
     **kwargs,
 ) -> TensorVariable:
     """Create a measure-space (i.e. log-likelihood) graph for a random variable at a given point.
@@ -257,6 +265,8 @@
         Whether or not to include the Jacobian term.
     scaling
        A scaling term to apply to the generated log-likelihood graph.
+    transformed
+        Apply transforms.
 
     """
 
@@ -282,22 +292,22 @@
 
         raise NotImplementedError("Missing value support is incomplete")
 
-        # "Flatten" and sum an array of indexed RVs' log-likelihoods
-        rv_var, missing_values = rv_node.inputs
-
-        missing_values = missing_values.data
-        logp_var = aet.sum(
-            [
-                logpt(
-                    rv_var,
-                )
-                for idx, missing in zip(
-                    np.ndindex(missing_values.shape), missing_values.flatten()
-                )
-                if missing
-            ]
-        )
-        return logp_var
+        # # "Flatten" and sum an array of indexed RVs' log-likelihoods
+        # rv_var, missing_values = rv_node.inputs
+        #
+        # missing_values = missing_values.data
+        # logp_var = aet.sum(
+        #     [
+        #         logpt(
+        #             rv_var,
+        #         )
+        #         for idx, missing in zip(
+        #             np.ndindex(missing_values.shape), missing_values.flatten()
+        #         )
+        #         if missing
+        #     ]
+        # )
+        # return logp_var
 
         return aet.zeros_like(rv_var)
 
@@ -312,15 +322,16 @@
     # If any of the measure vars are transformed measure-space variables
     # (signified by having a `transform` value in their tags), then we apply
     # their transforms and add their Jacobians (when enabled)
-    if transform:
-        logp_var = _logp(rv_node.op, transform.backward(rv_value), *dist_params, **kwargs)
+    if transform and transformed:
+        logp_var = _logp(rv_node.op, transform.backward(rv_var, rv_value), *dist_params, **kwargs)
+
         logp_var = transform_logp(
             logp_var,
             tuple(replacements.values()),
         )
 
         if jacobian:
-            transformed_jacobian = transform.jacobian_det(rv_value)
+            transformed_jacobian = transform.jacobian_det(rv_var, rv_value)
             if transformed_jacobian:
                 if logp_var.ndim > transformed_jacobian.ndim:
                     logp_var = logp_var.sum(axis=-1)
@@ -345,11 +356,17 @@ def transform_logp(logp_var: TensorVariable, inputs: List[TensorVariable]) -> TensorVariable:
     for measure_var in inputs:
 
         transform = getattr(measure_var.tag, "transform", None)
+        rv_var = getattr(measure_var.tag, "rv_var", None)
+
+        if transform is not None and rv_var is None:
+            warnings.warn(
+                f"A transform was found for {measure_var} but not a corresponding random variable"
+            )
 
-        if transform is None:
+        if transform is None or rv_var is None:
             continue
 
-        trans_rv_value = transform.backward(measure_var)
+        trans_rv_value = transform.backward(rv_var, measure_var)
         trans_replacements[measure_var] = trans_rv_value
 
     if trans_replacements:
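
The loop above implies a two-tag contract on measure-space variables; a sketch of the assumed setup (attribute names taken from the diff, the surrounding variables hypothetical):

# A transformed value variable must carry both tags to be processed:
value_var.tag.transform = transform  # maps measure space back to the RV's space
value_var.tag.rv_var = rv_var        # the RV that supplies the transform's parameters
# A `transform` tag without an `rv_var` tag now triggers the warning and is skipped.
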
@@ -359,7 +376,7 @@
 
 
 @singledispatch
-def _logp(op, value, *dist_params, **kwargs):
+def _logp(op: Op, value: TensorVariable, *dist_params, **kwargs):
     """Create a log-likelihood graph.
 
     This function dispatches on the type of `op`, which should be a subclass
@@ -370,7 +387,9 @@
     return aet.zeros_like(value)
 
 
-def logcdf(rv_var, rv_value, jacobian=True, **kwargs):
+def logcdf(
+    rv_var: TensorVariable, rv_value: Optional[TensorVariable], jacobian: bool = True, **kwargs
+):
     """Create a log-CDF graph."""
 
     rv_var, _ = rv_log_likelihood_args(rv_var)
pymc3/distributions/continuous.py (+10 -17)
@@ -104,31 +104,24 @@ class BoundedContinuous(Continuous):
 
 
 @logp_transform.register(PositiveContinuous)
-def pos_cont_transform(op, rv_var):
+def pos_cont_transform(op):
     return transforms.log
 
 
 @logp_transform.register(UnitContinuous)
-def unit_cont_transform(op, rv_var):
+def unit_cont_transform(op):
     return transforms.logodds
 
 
 @logp_transform.register(BoundedContinuous)
-def bounded_cont_transform(op, rv_var):
-    _, _, _, lower, upper = rv_var.owner.inputs
-    lower = aet.as_tensor_variable(lower) if lower is not None else None
-    upper = aet.as_tensor_variable(upper) if upper is not None else None
-
-    if lower is None and upper is None:
-        transform = None
-    elif lower is not None and upper is None:
-        transform = transforms.lowerbound(lower)
-    elif lower is None and upper is not None:
-        transform = transforms.upperbound(upper)
-    else:
-        transform = transforms.interval(lower, upper)
-
-    return transform
+def bounded_cont_transform(op):
+    def transform_params(rv_var):
+        _, _, _, lower, upper = rv_var.owner.inputs
+        lower = aet.as_tensor_variable(lower) if lower is not None else None
+        upper = aet.as_tensor_variable(upper) if upper is not None else None
+        return lower, upper
+
+    return transforms.interval(transform_params)
 
 
 def assert_negative_support(var, label, distname, value=-1e-6):
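
Note how the `lower`/`upper` case analysis disappears from `bounded_cont_transform`: a single `transforms.interval(transform_params)` is returned unconditionally, and the concrete bounds (possibly `None`) are only extracted when the transform is applied to a particular RV. This presumes `interval` now accepts a parameter-extraction callable and copes with absent bounds internally; illustratively:

transform = bounded_cont_transform(op)  # no bounds captured here
# Later, inside backward()/jacobian_det(), something like:
lower, upper = transform.param_extract_fn(rv_var)  # attribute name assumed, as sketched above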

pymc3/distributions/multivariate.py (+7 -3)
@@ -125,7 +125,7 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
 
 
 def quaddist_chol(delta, chol_mat):
-    diag = aet.nlinalg.diag(chol_mat)
+    diag = aet.diag(chol_mat)
     # Check if the covariance matrix is positive definite.
     ok = aet.all(diag > 0)
     # If not, replace the diagonal. We return -inf later, but
@@ -222,7 +222,7 @@ class MvNormal(Continuous):
     def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
         mu = aet.as_tensor_variable(mu)
         cov = quaddist_matrix(cov, tau, chol, lower)
-        return super().__init__([mu, cov], **kwargs)
+        return super().dist([mu, cov], **kwargs)
 
     def logp(value, mu, cov):
         """
@@ -968,7 +968,11 @@ def __init__(self, eta, n, sd_dist, *args, **kwargs):
         if sd_dist.shape.ndim not in [0, 1]:
             raise ValueError("Invalid shape for sd_dist.")
 
-        transform = transforms.CholeskyCovPacked(n)
+        def transform_params(rv_var):
+            _, _, _, n, eta = rv_var.owner.inputs
+            return np.arange(1, n + 1).cumsum() - 1
+
+        transform = transforms.CholeskyCovPacked(transform_params)
 
         kwargs["shape"] = shape
         kwargs["transform"] = transform
