|
18 | 18 | import warnings
|
19 | 19 |
|
20 | 20 | from sys import modules
|
21 |
| -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union |
| 21 | +from typing import ( |
| 22 | + TYPE_CHECKING, |
| 23 | + Any, |
| 24 | + Dict, |
| 25 | + List, |
| 26 | + Optional, |
| 27 | + Sequence, |
| 28 | + Tuple, |
| 29 | + Type, |
| 30 | + TypeVar, |
| 31 | + Union, |
| 32 | +) |
22 | 33 |
|
23 | 34 | import aesara
|
24 |
| -import aesara.graph.basic |
25 | 35 | import aesara.sparse as sparse
|
26 | 36 | import aesara.tensor as at
|
27 | 37 | import numpy as np
|
|
32 | 42 | from aesara.graph.basic import Constant, Variable, graph_inputs
|
33 | 43 | from aesara.graph.fg import FunctionGraph, MissingInputError
|
34 | 44 | from aesara.tensor.random.opt import local_subtensor_rv_lift
|
| 45 | +from aesara.tensor.sharedvar import ScalarSharedVariable |
35 | 46 | from aesara.tensor.var import TensorVariable
|
36 | 47 | from pandas import Series
|
37 | 48 |
|
|
46 | 57 | from pymc3.blocking import DictToArrayBijection, RaveledVars
|
47 | 58 | from pymc3.data import GenTensorVariable, Minibatch
|
48 | 59 | from pymc3.distributions import logp_transform, logpt, logpt_sum
|
49 |
| -from pymc3.exceptions import ImputationWarning, SamplingError |
| 60 | +from pymc3.exceptions import ImputationWarning, SamplingError, ShapeError |
50 | 61 | from pymc3.math import flatten_list
|
51 | 62 | from pymc3.util import UNSET, WithMemoization, get_var_name, treedict, treelist
|
52 | 63 | from pymc3.vartypes import continuous_types, discrete_types, typefilter
|
@@ -606,8 +617,9 @@ def __new__(cls, *args, **kwargs):
|
606 | 617 |
|
607 | 618 | def __init__(self, name="", model=None, aesara_config=None, coords=None, check_bounds=True):
|
608 | 619 | self.name = name
|
609 |
| - self.coords = {} |
610 |
| - self.RV_dims = {} |
| 620 | + self._coords = {} |
| 621 | + self._RV_dims = {} |
| 622 | + self._dim_lengths = {} |
611 | 623 | self.add_coords(coords)
|
612 | 624 | self.check_bounds = check_bounds
|
613 | 625 |
|
@@ -826,6 +838,27 @@ def basic_RVs(self):
|
826 | 838 | """
|
827 | 839 | return self.free_RVs + self.observed_RVs
|
828 | 840 |
|
| 841 | + @property |
| 842 | + def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]: |
| 843 | + """Tuples of dimension names for specific model variables. |
| 844 | +
|
| 845 | + Entries in the tuples may be ``None``, if the RV dimension was not given a name. |
| 846 | + """ |
| 847 | + return self._RV_dims |
| 848 | + |
| 849 | + @property |
| 850 | + def coords(self) -> Dict[str, Union[Sequence, None]]: |
| 851 | + """Coordinate values for model dimensions.""" |
| 852 | + return self._coords |
| 853 | + |
| 854 | + @property |
| 855 | + def dim_lengths(self) -> Dict[str, Tuple[Variable, ...]]: |
| 856 | + """The symbolic lengths of dimensions in the model. |
| 857 | +
|
| 858 | + The values are typically instances of ``TensorVariable`` or ``ScalarSharedVariable``. |
| 859 | + """ |
| 860 | + return self._dim_lengths |
| 861 | + |
829 | 862 | @property
|
830 | 863 | def unobserved_RVs(self):
|
831 | 864 | """List of all random variables, including deterministic ones.
|
@@ -913,20 +946,138 @@ def shape_from_dims(self, dims):
|
913 | 946 | shape.extend(np.shape(self.coords[dim]))
|
914 | 947 | return tuple(shape)
|
915 | 948 |
|
916 |
| - def add_coords(self, coords): |
| 949 | + def add_coord( |
| 950 | + self, |
| 951 | + name: str, |
| 952 | + values: Optional[Sequence] = None, |
| 953 | + *, |
| 954 | + length: Optional[Variable] = None, |
| 955 | + ): |
| 956 | + """Registers a dimension coordinate with the model. |
| 957 | +
|
| 958 | + Parameters |
| 959 | + ---------- |
| 960 | + name : str |
| 961 | + Name of the dimension. |
| 962 | + Forbidden: {"chain", "draw"} |
| 963 | + values : optional, array-like |
| 964 | + Coordinate values or ``None`` (for auto-numbering). |
| 965 | + If ``None`` is passed, a ``length`` must be specified. |
| 966 | + length : optional, scalar |
| 967 | + A symbolic scalar of the dimensions length. |
| 968 | + Defaults to ``aesara.shared(len(values))``. |
| 969 | + """ |
| 970 | + if name in {"draw", "chain"}: |
| 971 | + raise ValueError( |
| 972 | + "Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs." |
| 973 | + ) |
| 974 | + if values is None and length is None: |
| 975 | + raise ValueError( |
| 976 | + f"Either `values` or `length` must be specified for the '{name}' dimension." |
| 977 | + ) |
| 978 | + if length is not None and not isinstance(length, Variable): |
| 979 | + raise ValueError( |
| 980 | + f"The `length` passed for the '{name}' coord must be an Aesara Variable or None." |
| 981 | + ) |
| 982 | + if name in self.coords: |
| 983 | + if not values.equals(self.coords[name]): |
| 984 | + raise ValueError("Duplicate and incompatiple coordinate: %s." % name) |
| 985 | + else: |
| 986 | + self._coords[name] = values |
| 987 | + self._dim_lengths[name] = length or aesara.shared(len(values)) |
| 988 | + |
| 989 | + def add_coords( |
| 990 | + self, |
| 991 | + coords: Dict[str, Optional[Sequence]], |
| 992 | + *, |
| 993 | + lengths: Optional[Dict[str, Union[Variable, None]]] = None, |
| 994 | + ): |
| 995 | + """Vectorized version of ``Model.add_coord``.""" |
917 | 996 | if coords is None:
|
918 | 997 | return
|
| 998 | + lengths = lengths or {} |
919 | 999 |
|
920 |
| - for name in coords: |
921 |
| - if name in {"draw", "chain"}: |
922 |
| - raise ValueError( |
923 |
| - "Dimensions can not be named `draw` or `chain`, as they are reserved for the sampler's outputs." |
| 1000 | + for name, values in coords.items(): |
| 1001 | + self.add_coord(name, values, length=lengths.get(name, None)) |
| 1002 | + |
| 1003 | + def set_data( |
| 1004 | + self, |
| 1005 | + name: str, |
| 1006 | + values: Dict[str, Optional[Sequence]], |
| 1007 | + coords: Optional[Dict[str, Sequence]] = None, |
| 1008 | + ): |
| 1009 | + """Changes the values of a data variable in the model. |
| 1010 | +
|
| 1011 | + In contrast to pm.Data().set_value, this method can also |
| 1012 | + update the corresponding coordinates. |
| 1013 | +
|
| 1014 | + Parameters |
| 1015 | + ---------- |
| 1016 | + name : str |
| 1017 | + Name of a shared variable in the model. |
| 1018 | + values : array-like |
| 1019 | + New values for the shared variable. |
| 1020 | + coords : optional, dict |
| 1021 | + New coordinate values for dimensions of the shared variable. |
| 1022 | + Must be provided for all named dimensions that change in length. |
| 1023 | + """ |
| 1024 | + shared_object = self[name] |
| 1025 | + if not isinstance(shared_object, SharedVariable): |
| 1026 | + raise TypeError( |
| 1027 | + f"The variable `{name}` must be defined as `pymc3.Data` inside the model to allow updating. " |
| 1028 | + f"The current type is: {type(shared_object)}" |
| 1029 | + ) |
| 1030 | + values = pandas_to_array(values) |
| 1031 | + dims = self.RV_dims.get(name, None) or () |
| 1032 | + coords = coords or {} |
| 1033 | + |
| 1034 | + if values.ndim != shared_object.ndim: |
| 1035 | + raise ValueError( |
| 1036 | + f"New values for '{name}' must have {shared_object.ndim} dimensions, just like the original." |
| 1037 | + ) |
| 1038 | + |
| 1039 | + for d, dname in enumerate(dims): |
| 1040 | + length_tensor = self.dim_lengths[dname] |
| 1041 | + old_length = length_tensor.eval() |
| 1042 | + new_length = values.shape[d] |
| 1043 | + original_coords = self.coords.get(dname, None) |
| 1044 | + new_coords = coords.get(dname, None) |
| 1045 | + |
| 1046 | + length_changed = new_length != old_length |
| 1047 | + |
| 1048 | + # Reject resizing if we already know that it would create shape problems. |
| 1049 | + # NOTE: If there are multiple pm.Data containers sharing this dim, but the user only |
| 1050 | + # changes the values for one of them, they will run into shape problems nonetheless. |
| 1051 | + if not isinstance(length_tensor, ScalarSharedVariable) and length_changed: |
| 1052 | + raise ShapeError( |
| 1053 | + f"Resizing dimension {dname} with values of length {new_length} would lead to incompatibilities, " |
| 1054 | + f"because the dimension was not initialized from a shared variable. " |
| 1055 | + f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, " |
| 1056 | + f"for example by a model variable.", |
| 1057 | + actual=new_length, |
| 1058 | + expected=old_length, |
924 | 1059 | )
|
925 |
| - if name in self.coords: |
926 |
| - if not coords[name].equals(self.coords[name]): |
927 |
| - raise ValueError("Duplicate and incompatiple coordinate: %s." % name) |
928 |
| - else: |
929 |
| - self.coords[name] = coords[name] |
| 1060 | + if original_coords is not None and length_changed: |
| 1061 | + if length_changed and new_coords is None: |
| 1062 | + raise ValueError( |
| 1063 | + f"The '{name}' variable already had {len(original_coords)} coord values defined for" |
| 1064 | + f"its {dname} dimension. With the new values this dimension changes to length " |
| 1065 | + f"{new_length}, so new coord values for the {dname} dimension are required." |
| 1066 | + ) |
| 1067 | + if new_coords is not None: |
| 1068 | + # Update the registered coord values (also if they were None) |
| 1069 | + if len(new_coords) != new_length: |
| 1070 | + raise ShapeError( |
| 1071 | + f"Length of new coordinate values for dimension '{dname}' does not match the provided values.", |
| 1072 | + actual=len(new_coords), |
| 1073 | + expected=new_length, |
| 1074 | + ) |
| 1075 | + self._coords[dname] = new_coords |
| 1076 | + if isinstance(length_tensor, ScalarSharedVariable) and new_length != old_length: |
| 1077 | + # Updating the shared variable resizes dependent nodes that use this dimension for their `size`. |
| 1078 | + length_tensor.set_value(new_length) |
| 1079 | + |
| 1080 | + shared_object.set_value(values) |
930 | 1081 |
|
931 | 1082 | def register_rv(self, rv_var, name, data=None, total_size=None, dims=None, transform=UNSET):
|
932 | 1083 | """Register an (un)observed random variable with the model.
|
@@ -1132,16 +1283,16 @@ def create_value_var(self, rv_var: TensorVariable, transform: Any) -> TensorVari
|
1132 | 1283 |
|
1133 | 1284 | return value_var
|
1134 | 1285 |
|
1135 |
| - def add_random_variable(self, var, dims=None): |
| 1286 | + def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]] = None): |
1136 | 1287 | """Add a random variable to the named variables of the model."""
|
1137 | 1288 | if self.named_vars.tree_contains(var.name):
|
1138 | 1289 | raise ValueError(f"Variable name {var.name} already exists.")
|
1139 | 1290 |
|
1140 | 1291 | if dims is not None:
|
1141 | 1292 | if isinstance(dims, str):
|
1142 | 1293 | dims = (dims,)
|
1143 |
| - assert all(dim in self.coords for dim in dims) |
1144 |
| - self.RV_dims[var.name] = dims |
| 1294 | + assert all(dim in self.coords or dim is None for dim in dims) |
| 1295 | + self._RV_dims[var.name] = dims |
1145 | 1296 |
|
1146 | 1297 | self.named_vars[var.name] = var
|
1147 | 1298 | if not hasattr(self, self.name_of(var.name)):
|
@@ -1500,18 +1651,7 @@ def set_data(new_data, model=None):
|
1500 | 1651 | model = modelcontext(model)
|
1501 | 1652 |
|
1502 | 1653 | for variable_name, new_value in new_data.items():
|
1503 |
| - if isinstance(model[variable_name], SharedVariable): |
1504 |
| - if isinstance(new_value, list): |
1505 |
| - new_value = np.array(new_value) |
1506 |
| - model[variable_name].set_value(pandas_to_array(new_value)) |
1507 |
| - else: |
1508 |
| - message = ( |
1509 |
| - "The variable `{}` must be defined as `pymc3." |
1510 |
| - "Data` inside the model to allow updating. The " |
1511 |
| - "current type is: " |
1512 |
| - "{}.".format(variable_name, type(model[variable_name])) |
1513 |
| - ) |
1514 |
| - raise TypeError(message) |
| 1654 | + model.set_data(variable_name, new_value) |
1515 | 1655 |
|
1516 | 1656 |
|
1517 | 1657 | def fn(outs, mode=None, model=None, *args, **kwargs):
|
|
0 commit comments