Skip to content

Commit 0d07347

Browse files
committed
Updated code in light of Luciano's comments
1 parent 5e5440c commit 0d07347

File tree

3 files changed

+6
-29
lines changed

3 files changed

+6
-29
lines changed

pymc3/data.py

+2-12
Original file line numberDiff line numberDiff line change
@@ -479,17 +479,7 @@ class Data:
479479
https://docs.pymc.io/notebooks/data_container.html
480480
"""
481481

482-
def __new__(self, name, value, dtype=None):
483-
if not dtype:
484-
if hasattr(value, "dtype"):
485-
# if no dtype given but available as attr of value, use that as dtype
486-
dtype = value.dtype
487-
elif isinstance(value, int):
488-
dtype = int
489-
else:
490-
# otherwise, assume float
491-
dtype = float
492-
482+
def __new__(self, name, value):
493483
# Add data container to the named variables of the model.
494484
try:
495485
model = pm.Model.get_context()
@@ -502,7 +492,7 @@ def __new__(self, name, value, dtype=None):
502492

503493
# `pm.model.pandas_to_array` takes care of parameter `value` and
504494
# transforms it to something digestible for pymc3
505-
shared_object = theano.shared(pm.model.pandas_to_array(value, dtype=dtype), name)
495+
shared_object = theano.shared(pm.model.pandas_to_array(value), name)
506496

507497
# To draw the node for this variable in the graphviz Digraph we need
508498
# its shape.

pymc3/model.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -1275,15 +1275,7 @@ def set_data(new_data, model=None):
12751275

12761276
for variable_name, new_value in new_data.items():
12771277
if isinstance(model[variable_name], SharedVariable):
1278-
if hasattr(new_value, "dtype"):
1279-
# if no dtype given but available as attr of value, use that as dtype
1280-
dtype = new_value.dtype
1281-
elif isinstance(new_value, int):
1282-
dtype = int
1283-
else:
1284-
# otherwise, assume float
1285-
dtype = float
1286-
model[variable_name].set_value(pandas_to_array(new_value, dtype=dtype))
1278+
model[variable_name].set_value(pandas_to_array(new_value))
12871279
else:
12881280
message = 'The variable `{}` must be defined as `pymc3.' \
12891281
'Data` inside the model to allow updating. The ' \
@@ -1490,7 +1482,7 @@ def init_value(self):
14901482
return self.tag.test_value
14911483

14921484

1493-
def pandas_to_array(data, dtype=float):
1485+
def pandas_to_array(data):
14941486
if hasattr(data, 'values'): # pandas
14951487
if data.isnull().any().any(): # missing values
14961488
ret = np.ma.MaskedArray(data.values, data.isnull().values)
@@ -1510,12 +1502,7 @@ def pandas_to_array(data, dtype=float):
15101502
else:
15111503
ret = np.asarray(data)
15121504

1513-
if dtype in [int, np.int8, np.int16, np.int32, np.int64]:
1514-
return pm.intX(ret)
1515-
elif dtype in [float, np.float16, np.float32, np.float64]:
1516-
return pm.floatX(ret)
1517-
else:
1518-
raise ValueError('Unsupported type for pandas_to_array: %s' % str(dtype))
1505+
return ret
15191506

15201507

15211508
def as_tensor(data, name, model, distribution):

pymc3/tests/test_data_container.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_shared_data_as_index(self):
109109
See https://github.com./pymc-devs/pymc3/issues/3813
110110
"""
111111
with pm.Model() as model:
112-
index = pm.Data("index", [2, 0, 1, 0, 2], dtype=int)
112+
index = pm.Data("index", [2, 0, 1, 0, 2])
113113
y = pm.Data("y", [1.0, 2.0, 3.0, 2.0, 1.0])
114114
alpha = pm.Normal("alpha", 0, 1.5, shape=3)
115115
pm.Normal("obs", alpha[index], np.sqrt(1e-2), observed=y)

0 commit comments

Comments
 (0)