11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import warnings
15
+
14
16
from functools import singledispatch
15
17
from itertools import chain
16
18
from typing import Generator , List , Optional , Tuple , Union
20
22
21
23
from aesara import config
22
24
from aesara .graph .basic import Variable , ancestors , clone_replace
23
- from aesara .graph .op import compute_test_value
25
+ from aesara .graph .op import Op , compute_test_value
24
26
from aesara .tensor .random .op import Observed , RandomVariable
25
27
from aesara .tensor .subtensor import AdvancedSubtensor , AdvancedSubtensor1 , Subtensor
26
28
from aesara .tensor .var import TensorVariable
33
35
34
36
35
37
@singledispatch
36
- def logp_transform (op , inputs ):
38
+ def logp_transform (op : Op ):
37
39
return None
38
40
39
41
@@ -141,7 +143,8 @@ def change_rv_size(
141
143
142
144
def rv_log_likelihood_args (
143
145
rv_var : TensorVariable ,
144
- transformed : Optional [bool ] = True ,
146
+ * ,
147
+ return_observations : bool = True ,
145
148
) -> Tuple [TensorVariable , TensorVariable ]:
146
149
"""Get a `RandomVariable` and its corresponding log-likelihood `TensorVariable` value.
147
150
@@ -151,8 +154,9 @@ def rv_log_likelihood_args(
151
154
A variable corresponding to a `RandomVariable`, whether directly or
152
155
indirectly (e.g. an observed variable that's the output of an
153
156
`Observed` `Op`).
154
- transformed
155
- When ``True``, return the transformed value var.
157
+ return_observations
158
+ When ``True``, return the observed values in place of the log-likelihood
159
+ value variable.
156
160
157
161
Returns
158
162
=======
@@ -163,12 +167,14 @@ def rv_log_likelihood_args(
163
167
"""
164
168
165
169
if rv_var .owner and isinstance (rv_var .owner .op , Observed ):
166
- return tuple (rv_var .owner .inputs )
167
- elif hasattr (rv_var .tag , "value_var" ):
168
- rv_value = rv_var .tag .value_var
169
- return rv_var , rv_value
170
- else :
171
- return rv_var , None
170
+ rv_var , obs_var = rv_var .owner .inputs
171
+ if return_observations :
172
+ return rv_var , obs_var
173
+ else :
174
+ return rv_var , rv_log_likelihood_args (rv_var )[1 ]
175
+
176
+ rv_value = getattr (rv_var .tag , "value_var" , None )
177
+ return rv_var , rv_value
172
178
173
179
174
180
def rv_ancestors (graphs : List [TensorVariable ]) -> Generator [TensorVariable , None , None ]:
@@ -217,7 +223,7 @@ def sample_to_measure_vars(
217
223
if not (anc .owner and isinstance (anc .owner .op , RandomVariable )):
218
224
continue
219
225
220
- _ , value_var = rv_log_likelihood_args (anc )
226
+ _ , value_var = rv_log_likelihood_args (anc , return_observations = False )
221
227
222
228
if value_var is not None :
223
229
replace [anc ] = value_var
@@ -233,8 +239,10 @@ def sample_to_measure_vars(
233
239
def logpt (
234
240
rv_var : TensorVariable ,
235
241
rv_value : Optional [TensorVariable ] = None ,
236
- jacobian : Optional [bool ] = True ,
237
- scaling : Optional [bool ] = True ,
242
+ * ,
243
+ jacobian : bool = True ,
244
+ scaling : bool = True ,
245
+ transformed : bool = True ,
238
246
** kwargs ,
239
247
) -> TensorVariable :
240
248
"""Create a measure-space (i.e. log-likelihood) graph for a random variable at a given point.
@@ -257,6 +265,8 @@ def logpt(
257
265
Whether or not to include the Jacobian term.
258
266
scaling
259
267
A scaling term to apply to the generated log-likelihood graph.
268
+ transformed
269
+ Apply transforms.
260
270
261
271
"""
262
272
@@ -282,22 +292,22 @@ def logpt(
282
292
283
293
raise NotImplementedError ("Missing value support is incomplete" )
284
294
285
- # "Flatten" and sum an array of indexed RVs' log-likelihoods
286
- rv_var , missing_values = rv_node .inputs
287
-
288
- missing_values = missing_values .data
289
- logp_var = aet .sum (
290
- [
291
- logpt (
292
- rv_var ,
293
- )
294
- for idx , missing in zip (
295
- np .ndindex (missing_values .shape ), missing_values .flatten ()
296
- )
297
- if missing
298
- ]
299
- )
300
- return logp_var
295
+ # # "Flatten" and sum an array of indexed RVs' log-likelihoods
296
+ # rv_var, missing_values = rv_node.inputs
297
+ #
298
+ # missing_values = missing_values.data
299
+ # logp_var = aet.sum(
300
+ # [
301
+ # logpt(
302
+ # rv_var,
303
+ # )
304
+ # for idx, missing in zip(
305
+ # np.ndindex(missing_values.shape), missing_values.flatten()
306
+ # )
307
+ # if missing
308
+ # ]
309
+ # )
310
+ # return logp_var
301
311
302
312
return aet .zeros_like (rv_var )
303
313
@@ -312,15 +322,16 @@ def logpt(
312
322
# If any of the measure vars are transformed measure-space variables
313
323
# (signified by having a `transform` value in their tags), then we apply
314
324
# the their transforms and add their Jacobians (when enabled)
315
- if transform :
316
- logp_var = _logp (rv_node .op , transform .backward (rv_value ), * dist_params , ** kwargs )
325
+ if transform and transformed :
326
+ logp_var = _logp (rv_node .op , transform .backward (rv_var , rv_value ), * dist_params , ** kwargs )
327
+
317
328
logp_var = transform_logp (
318
329
logp_var ,
319
330
tuple (replacements .values ()),
320
331
)
321
332
322
333
if jacobian :
323
- transformed_jacobian = transform .jacobian_det (rv_value )
334
+ transformed_jacobian = transform .jacobian_det (rv_var , rv_value )
324
335
if transformed_jacobian :
325
336
if logp_var .ndim > transformed_jacobian .ndim :
326
337
logp_var = logp_var .sum (axis = - 1 )
@@ -345,11 +356,17 @@ def transform_logp(logp_var: TensorVariable, inputs: List[TensorVariable]) -> Te
345
356
for measure_var in inputs :
346
357
347
358
transform = getattr (measure_var .tag , "transform" , None )
359
+ rv_var = getattr (measure_var .tag , "rv_var" , None )
360
+
361
+ if transform is not None and rv_var is None :
362
+ warnings .warn (
363
+ f"A transform was found for { measure_var } but not a corresponding random variable"
364
+ )
348
365
349
- if transform is None :
366
+ if transform is None or rv_var is None :
350
367
continue
351
368
352
- trans_rv_value = transform .backward (measure_var )
369
+ trans_rv_value = transform .backward (rv_var , measure_var )
353
370
trans_replacements [measure_var ] = trans_rv_value
354
371
355
372
if trans_replacements :
@@ -359,7 +376,7 @@ def transform_logp(logp_var: TensorVariable, inputs: List[TensorVariable]) -> Te
359
376
360
377
361
378
@singledispatch
362
- def _logp (op , value , * dist_params , ** kwargs ):
379
+ def _logp (op : Op , value : TensorVariable , * dist_params , ** kwargs ):
363
380
"""Create a log-likelihood graph.
364
381
365
382
This function dispatches on the type of `op`, which should be a subclass
@@ -370,7 +387,9 @@ def _logp(op, value, *dist_params, **kwargs):
370
387
return aet .zeros_like (value )
371
388
372
389
373
- def logcdf (rv_var , rv_value , jacobian = True , ** kwargs ):
390
+ def logcdf (
391
+ rv_var : TensorVariable , rv_value : Optional [TensorVariable ], jacobian : bool = True , ** kwargs
392
+ ):
374
393
"""Create a log-CDF graph."""
375
394
376
395
rv_var , _ = rv_log_likelihood_args (rv_var )
0 commit comments