Skip to content

Commit fd3f730

Browse files
Allow unnamed (None) dims and undefined (None) coord values
Also refactor into properties to add docstrings and type annotations. And no longer allow InferenceData conversion without a Model on stack. Co-authored-by: Oriol Abril Pla <[email protected]>
1 parent 45cb4eb commit fd3f730

File tree

7 files changed

+200
-43
lines changed

7 files changed

+200
-43
lines changed

RELEASE-NOTES.md

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
### Maintenance
1414
- Remove float128 dtype support (see [#4514](https://github.com/pymc-devs/pymc3/pull/4514)).
1515
- Logp method of `Uniform` and `DiscreteUniform` no longer depends on `pymc3.distributions.dist_math.bound` for proper evaluation (see [#4541](https://github.com/pymc-devs/pymc3/pull/4541)).
16+
- `Model.RV_dims` and `Model.coords` are now read-only properties. To modify the `coords` dictionary, use `Model.add_coord`. Also, `dims` or coordinate values that are `None` will be auto-completed (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)).
17+
- The length of `dims` in the model is now tracked symbolically through `Model.dim_lengths` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)).
1618
- ...
1719

1820
## PyMC3 3.11.2 (14 March 2021)

pymc3/backends/arviz.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,7 @@ def __init__(
162162
self.trace = trace
163163

164164
# this permits us to get the model from command-line argument or from with model:
165-
try:
166-
self.model = modelcontext(model)
167-
except TypeError:
168-
self.model = None
165+
self.model = modelcontext(model)
169166

170167
self.attrs = None
171168
if trace is not None:
@@ -223,10 +220,14 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
223220
self.coords = {} if coords is None else coords
224221
if hasattr(self.model, "coords"):
225222
self.coords = {**self.model.coords, **self.coords}
223+
self.coords = {key: value for key, value in self.coords.items() if value is not None}
226224

227225
self.dims = {} if dims is None else dims
228226
if hasattr(self.model, "RV_dims"):
229-
model_dims = {k: list(v) for k, v in self.model.RV_dims.items()}
227+
model_dims = {
228+
var_name: [dim for dim in dims if dim is not None]
229+
for var_name, dims in self.model.RV_dims.items()
230+
}
230231
self.dims = {**model_dims, **self.dims}
231232

232233
self.density_dist_obs = density_dist_obs

pymc3/data.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import urllib.request
2020

2121
from copy import copy
22-
from typing import Any, Dict, List
22+
from typing import Any, Dict, List, Sequence
2323

2424
import aesara
2525
import aesara.tensor as at
@@ -502,7 +502,7 @@ class Data:
502502
>>> for data_vals in observed_data:
503503
... with model:
504504
... # Switch out the observed dataset
505-
... pm.set_data({'data': data_vals})
505+
... model.set_data('data', data_vals)
506506
... traces.append(pm.sample())
507507
508508
To set the value of the data container variable, check out
@@ -543,6 +543,11 @@ def __new__(self, name, value, *, dims=None, export_index_as_coords=False):
543543

544544
if export_index_as_coords:
545545
model.add_coords(coords)
546+
elif dims:
547+
# Register new dimension lengths
548+
for d, dname in enumerate(dims):
549+
if not dname in model.dim_lengths:
550+
model.add_coord(dname, values=None, length=shared_object.shape[d])
546551

547552
# To draw the node for this variable in the graphviz Digraph we need
548553
# its shape.
@@ -562,7 +567,7 @@ def __new__(self, name, value, *, dims=None, export_index_as_coords=False):
562567
return shared_object
563568

564569
@staticmethod
565-
def set_coords(model, value, dims=None):
570+
def set_coords(model, value, dims=None) -> Dict[str, Sequence]:
566571
coords = {}
567572

568573
# If value is a df or a series, we interpret the index as coords:

pymc3/model.py

+170-30
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,20 @@
1818
import warnings
1919

2020
from sys import modules
21-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
21+
from typing import (
22+
TYPE_CHECKING,
23+
Any,
24+
Dict,
25+
List,
26+
Optional,
27+
Sequence,
28+
Tuple,
29+
Type,
30+
TypeVar,
31+
Union,
32+
)
2233

2334
import aesara
24-
import aesara.graph.basic
2535
import aesara.sparse as sparse
2636
import aesara.tensor as at
2737
import numpy as np
@@ -32,6 +42,7 @@
3242
from aesara.graph.basic import Constant, Variable, graph_inputs
3343
from aesara.graph.fg import FunctionGraph, MissingInputError
3444
from aesara.tensor.random.opt import local_subtensor_rv_lift
45+
from aesara.tensor.sharedvar import ScalarSharedVariable
3546
from aesara.tensor.var import TensorVariable
3647
from pandas import Series
3748

@@ -46,7 +57,7 @@
4657
from pymc3.blocking import DictToArrayBijection, RaveledVars
4758
from pymc3.data import GenTensorVariable, Minibatch
4859
from pymc3.distributions import logp_transform, logpt, logpt_sum
49-
from pymc3.exceptions import ImputationWarning, SamplingError
60+
from pymc3.exceptions import ImputationWarning, SamplingError, ShapeError
5061
from pymc3.math import flatten_list
5162
from pymc3.util import UNSET, WithMemoization, get_var_name, treedict, treelist
5263
from pymc3.vartypes import continuous_types, discrete_types, typefilter
@@ -606,8 +617,9 @@ def __new__(cls, *args, **kwargs):
606617

607618
def __init__(self, name="", model=None, aesara_config=None, coords=None, check_bounds=True):
608619
self.name = name
609-
self.coords = {}
610-
self.RV_dims = {}
620+
self._coords = {}
621+
self._RV_dims = {}
622+
self._dim_lengths = {}
611623
self.add_coords(coords)
612624
self.check_bounds = check_bounds
613625

@@ -826,6 +838,27 @@ def basic_RVs(self):
826838
"""
827839
return self.free_RVs + self.observed_RVs
828840

841+
@property
842+
def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
843+
"""Tuples of dimension names for specific model variables.
844+
845+
Entries in the tuples may be ``None``, if the RV dimension was not given a name.
846+
"""
847+
return self._RV_dims
848+
849+
@property
850+
def coords(self) -> Dict[str, Union[Sequence, None]]:
851+
"""Coordinate values for model dimensions."""
852+
return self._coords
853+
854+
@property
855+
def dim_lengths(self) -> Dict[str, Tuple[Variable, ...]]:
856+
"""The symbolic lengths of dimensions in the model.
857+
858+
The values are typically instances of ``TensorVariable`` or ``ScalarSharedVariable``.
859+
"""
860+
return self._dim_lengths
861+
829862
@property
830863
def unobserved_RVs(self):
831864
"""List of all random variables, including deterministic ones.
@@ -913,20 +946,138 @@ def shape_from_dims(self, dims):
913946
shape.extend(np.shape(self.coords[dim]))
914947
return tuple(shape)
915948

916-
def add_coords(self, coords):
949+
def add_coord(
950+
self,
951+
name: str,
952+
values: Optional[Sequence] = None,
953+
*,
954+
length: Optional[Variable] = None,
955+
):
956+
"""Registers a dimension coordinate with the model.
957+
958+
Parameters
959+
----------
960+
name : str
961+
Name of the dimension.
962+
Forbidden: {"chain", "draw"}
963+
values : optional, array-like
964+
Coordinate values or ``None`` (for auto-numbering).
965+
If ``None`` is passed, a ``length`` must be specified.
966+
length : optional, scalar
967+
A symbolic scalar of the dimensions length.
968+
Defaults to ``aesara.shared(len(values))``.
969+
"""
970+
if name in {"draw", "chain"}:
971+
raise ValueError(
972+
"Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs."
973+
)
974+
if values is None and length is None:
975+
raise ValueError(
976+
f"Either `values` or `length` must be specified for the '{name}' dimension."
977+
)
978+
if length is not None and not isinstance(length, Variable):
979+
raise ValueError(
980+
f"The `length` passed for the '{name}' coord must be an Aesara Variable or None."
981+
)
982+
if name in self.coords:
983+
if not values.equals(self.coords[name]):
984+
raise ValueError("Duplicate and incompatiple coordinate: %s." % name)
985+
else:
986+
self._coords[name] = values
987+
self._dim_lengths[name] = length or aesara.shared(len(values))
988+
989+
def add_coords(
990+
self,
991+
coords: Dict[str, Optional[Sequence]],
992+
*,
993+
lengths: Optional[Dict[str, Union[Variable, None]]] = None,
994+
):
995+
"""Vectorized version of ``Model.add_coord``."""
917996
if coords is None:
918997
return
998+
lengths = lengths or {}
919999

920-
for name in coords:
921-
if name in {"draw", "chain"}:
922-
raise ValueError(
923-
"Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs."
1000+
for name, values in coords.items():
1001+
self.add_coord(name, values, length=lengths.get(name, None))
1002+
1003+
def set_data(
1004+
self,
1005+
name: str,
1006+
values: Dict[str, Optional[Sequence]],
1007+
coords: Optional[Dict[str, Sequence]] = None,
1008+
):
1009+
"""Changes the values of a data variable in the model.
1010+
1011+
In contrast to pm.Data().set_value, this method can also
1012+
update the corresponding coordinates.
1013+
1014+
Parameters
1015+
----------
1016+
name : str
1017+
Name of a shared variable in the model.
1018+
values : array-like
1019+
New values for the shared variable.
1020+
coords : optional, dict
1021+
New coordinate values for dimensions of the shared variable.
1022+
Must be provided for all named dimensions that change in length.
1023+
"""
1024+
shared_object = self[name]
1025+
if not isinstance(shared_object, SharedVariable):
1026+
raise TypeError(
1027+
f"The variable `{name}` must be defined as `pymc3.Data` inside the model to allow updating. "
1028+
f"The current type is: {type(shared_object)}"
1029+
)
1030+
values = pandas_to_array(values)
1031+
dims = self.RV_dims.get(name, None) or ()
1032+
coords = coords or {}
1033+
1034+
if values.ndim != shared_object.ndim:
1035+
raise ValueError(
1036+
f"New values for '{name}' must have {shared_object.ndim} dimensions, just like the original."
1037+
)
1038+
1039+
for d, dname in enumerate(dims):
1040+
length_tensor = self.dim_lengths[dname]
1041+
old_length = length_tensor.eval()
1042+
new_length = values.shape[d]
1043+
original_coords = self.coords.get(dname, None)
1044+
new_coords = coords.get(dname, None)
1045+
1046+
length_changed = new_length != old_length
1047+
1048+
# Reject resizing if we already know that it would create shape problems.
1049+
# NOTE: If there are multiple pm.Data containers sharing this dim, but the user only
1050+
# changes the values for one of them, they will run into shape problems nonetheless.
1051+
if not isinstance(length_tensor, ScalarSharedVariable) and length_changed:
1052+
raise ShapeError(
1053+
f"Resizing dimension {dname} with values of length {new_length} would lead to incompatibilities, "
1054+
f"because the dimension was not initialized from a shared variable. "
1055+
f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
1056+
f"for example by a model variable.",
1057+
actual=new_length,
1058+
expected=old_length,
9241059
)
925-
if name in self.coords:
926-
if not coords[name].equals(self.coords[name]):
927-
raise ValueError("Duplicate and incompatiple coordinate: %s." % name)
928-
else:
929-
self.coords[name] = coords[name]
1060+
if original_coords is not None and length_changed:
1061+
if length_changed and new_coords is None:
1062+
raise ValueError(
1063+
f"The '{name}' variable already had {len(original_coords)} coord values defined for"
1064+
f"its {dname} dimension. With the new values this dimension changes to length "
1065+
f"{new_length}, so new coord values for the {dname} dimension are required."
1066+
)
1067+
if new_coords is not None:
1068+
# Update the registered coord values (also if they were None)
1069+
if len(new_coords) != new_length:
1070+
raise ShapeError(
1071+
f"Length of new coordinate values for dimension '{dname}' does not match the provided values.",
1072+
actual=len(new_coords),
1073+
expected=new_length,
1074+
)
1075+
self._coords[dname] = new_coords
1076+
if isinstance(length_tensor, ScalarSharedVariable) and new_length != old_length:
1077+
# Updating the shared variable resizes dependent nodes that use this dimension for their `size`.
1078+
length_tensor.set_value(new_length)
1079+
1080+
shared_object.set_value(values)
9301081

9311082
def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET):
9321083
"""Register an (un)observed random variable with the model.
@@ -1132,16 +1283,16 @@ def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVari
11321283

11331284
return value_var
11341285

1135-
def add_random_variable(self, var, dims=None):
1286+
def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None):
11361287
"""Add a random variable to the named variables of the model."""
11371288
if self.named_vars.tree_contains(var.name):
11381289
raise ValueError(f"Variable name {var.name} already exists.")
11391290

11401291
if dims is not None:
11411292
if isinstance(dims, str):
11421293
dims = (dims,)
1143-
assert all(dim in self.coords for dim in dims)
1144-
self.RV_dims[var.name] = dims
1294+
assert all(dim in self.coords or dim is None for dim in dims)
1295+
self._RV_dims[var.name] = dims
11451296

11461297
self.named_vars[var.name] = var
11471298
if not hasattr(self, self.name_of(var.name)):
@@ -1500,18 +1651,7 @@ def set_data(new_data, model=None):
15001651
model = modelcontext(model)
15011652

15021653
for variable_name, new_value in new_data.items():
1503-
if isinstance(model[variable_name], SharedVariable):
1504-
if isinstance(new_value, list):
1505-
new_value = np.array(new_value)
1506-
model[variable_name].set_value(pandas_to_array(new_value))
1507-
else:
1508-
message = (
1509-
"The variable `{}` must be defined as `pymc3."
1510-
"Data` inside the model to allow updating. The "
1511-
"current type is: "
1512-
"{}.".format(variable_name, type(model[variable_name]))
1513-
)
1514-
raise TypeError(message)
1654+
model.set_data(variable_name, new_value)
15151655

15161656

15171657
def fn(outs, mode=None, model=None, *args, **kwargs):

pymc3/tests/sampler_fixtures.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,15 @@ def setup_class(cls):
155155

156156
def test_neff(self):
157157
if hasattr(self, "min_n_eff"):
158-
idata = to_inference_data(self.trace[self.burn :])
158+
with self.model:
159+
idata = to_inference_data(self.trace[self.burn :])
159160
n_eff = az.ess(idata)
160161
for var in n_eff:
161162
npt.assert_array_less(self.min_n_eff, n_eff[var])
162163

163164
def test_Rhat(self):
164-
idata = to_inference_data(self.trace[self.burn :])
165+
with self.model:
166+
idata = to_inference_data(self.trace[self.burn :])
165167
rhat = az.rhat(idata)
166168
for var in rhat:
167169
npt.assert_allclose(rhat[var], 1, rtol=0.01)

pymc3/tests/test_data_container.py

+8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import pytest
1818

1919
from aesara import shared
20+
from aesara.tensor.sharedvar import ScalarSharedVariable
21+
from aesara.tensor.var import TensorVariable
2022

2123
import pymc3 as pm
2224

@@ -272,9 +274,15 @@ def test_explicit_coords(self):
272274

273275
assert "rows" in pmodel.coords
274276
assert pmodel.coords["rows"] == ["R1", "R2", "R3", "R4", "R5"]
277+
assert "rows" in pmodel.dim_lengths
278+
assert isinstance(pmodel.dim_lengths["rows"], ScalarSharedVariable)
279+
assert pmodel.dim_lengths["rows"].eval() == 5
275280
assert "columns" in pmodel.coords
276281
assert pmodel.coords["columns"] == ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]
277282
assert pmodel.RV_dims == {"observations": ("rows", "columns")}
283+
assert "columns" in pmodel.dim_lengths
284+
assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable)
285+
assert pmodel.dim_lengths["columns"].eval() == 7
278286

279287
def test_implicit_coords_series(self):
280288
ser_sales = pd.Series(

0 commit comments

Comments
 (0)