Skip to content

Make draw_values draw from the joint distribution #3214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
6aedaa9
Merge pull request #1 from pymc-devs/master
lucianopaz May 25, 2018
32ce3c2
Merge branch 'master' of https://github.com./pymc-devs/pymc3
lucianopaz Sep 26, 2018
bd25baa
Resolved merge conflicts with upstream master, which I had not fetche…
lucianopaz Sep 26, 2018
f6ecb23
Fixed most of the bugs encountered due to the incorrect upstream fetc…
lucianopaz Sep 26, 2018
df5e3ae
Fixed collections import error
lucianopaz Sep 26, 2018
d43d149
Fixed list copy and defaults of DependenceDAG.__init__
lucianopaz Sep 26, 2018
339828d
Implemented `get_first_level_conditionals` to try to get rid of the a…
lucianopaz Sep 27, 2018
890ae74
Cleaned up model.py, made it comply with pep8, and fixed lint error o…
lucianopaz Sep 27, 2018
237f8ba
Fix get_first_level_conditionals and also made DependenceDAG a subcla…
lucianopaz Sep 28, 2018
4ef4ea3
Completely removed DependenceDAG class. The variable dependence graph…
lucianopaz Oct 1, 2018
659647e
Reverted unnecessary format changes.
lucianopaz Oct 1, 2018
4371af2
Added tests for WrapAsHashable. Made get_first_layer_conditionals mor…
lucianopaz Oct 2, 2018
6298d71
Added __ne__ for WrapAsHashable, which delegates to __eq__. This must…
lucianopaz Oct 2, 2018
31d36a2
Resolve comments from PR.
lucianopaz Oct 12, 2018
672907c
Fixed typo. Changed dag edge attributes to only deterministic 0 or 1.…
lucianopaz Oct 23, 2018
ba8305f
Finished adaptation of ModelGraph to use networkx for plate detection…
lucianopaz Oct 31, 2018
fbbf4c3
Merge branch 'master' into master
lucianopaz Nov 1, 2018
08eccbf
Fixed lint errors and test_step error due to upstream merge conflict.
lucianopaz Nov 1, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 176 additions & 122 deletions pymc3/distributions/distribution.py

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,17 @@ def __init__(self, nu, Sigma=None, mu=None, cov=None, tau=None, chol=None,
self.mean = self.median = self.mode = self.mu = self.mu

def random(self, point=None, size=None):
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
if self._cov_type == 'cov':
cov, = draw_values([self.cov], point=point, size=size)
nu, mu, cov = draw_values([self.nu, self.mu, self.cov],
point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
elif self._cov_type == 'tau':
tau, = draw_values([self.tau], point=point, size=size)
nu, mu, tau = draw_values([self.nu, self.mu, self.tau],
point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
else:
chol, = draw_values([self.chol_cov], point=point, size=size)
nu, mu, chol = draw_values([self.nu, self.mu, self.chol],
point=point, size=size)
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)

samples = dist.random(point, size)
Expand Down
417 changes: 341 additions & 76 deletions pymc3/model.py

Large diffs are not rendered by default.

174 changes: 174 additions & 0 deletions pymc3/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from .util import get_default_varnames
import pymc3 as pm
from .model import build_dependence_dag_from_model
import networkx as nx


def powerset(iterable):
Expand Down Expand Up @@ -184,3 +186,175 @@ def model_to_graphviz(model=None):
"""
model = pm.modelcontext(model)
return ModelGraph(model).make_graph()


class OtherModelGraph(object):
def __init__(self, model):
self.model = model
try:
graph = model.dependence_dag
except AttributeError:
graph = build_dependence_dag_from_model(model)
self.set_graph(graph)

def set_graph(self, graph):
self.graph = graph
self.node_names = {}
unnamed_count = 0
for n in self.graph.nodes():
try:
name = n.name
except AttributeError:
name = 'Unnamed {}'.format(unnamed_count)
unnamed_count += 1
self.node_names[n] = name

def draw(self, pos=None, draw_nodes=False, ax=None,
edge_kwargs=None,
label_kwargs=None,
node_kwargs=None):
graph = self.graph
if edge_kwargs is None:
edge_kwargs = {}
if node_kwargs is None:
node_kwargs = {}
if label_kwargs is None:
label_kwargs = {}
label_kwargs.setdefault('bbox', {'boxstyle': 'round',
'facecolor': 'lightgray'})
if pos is None:
try:
pos = nx.drawing.nx_agraph.graphviz_layout(graph, prog='dot')
except Exception:
pos = nx.shell_layout(graph)
d = nx.get_edge_attributes(graph, 'deterministic')
edgelist = list(d.keys())
edge_color = [float(v) for v in d.values()]
labels = {n: n.name for n in graph}

if draw_nodes:
nx.draw_networkx_edges(graph, pos=pos, ax=ax, **node_kwargs)
nx.draw_networkx_edges(graph, pos=pos, ax=ax, edgelist=edgelist,
edge_color=edge_color, **edge_kwargs)
nx.draw_networkx_labels(graph, pos=pos, ax=ax, labels=labels,
**label_kwargs)

def get_plates(self, graph=None, ignore_transformed=True):
""" Groups nodes by the shape of the underlying distribution, and if
the nodes form a disconnected component of the graph.

Parameters
----------
graph: networkx.DiGraph (optional)
The graph object from which to get the plates. If None, self.graph
will be used.
ignore_transformed: bool (optional)
If True, the transformed variables will be ignored while getting
the plates.

Returns
-------
list of tuples: (shape, set(nodes_in_plate))
"""
if graph is None:
graph = self.graph
if ignore_transformed:
transforms = set([n.transformed for n in graph
if hasattr(n, 'transformed')])
nbunch = [n for n in graph if n not in transforms]
graph = nx.subgraph(graph, nbunch)
shape_plates = {}
for node in graph:
if hasattr(node, 'observations'):
shape = node.observations.shape
elif hasattr(node, 'dshape'):
shape = node.dshape
else:
try:
shape = node.tag.test_value.shape
except AttributeError:
shape = tuple()
if shape == (1,):
shape = tuple()
if shape not in shape_plates:
shape_plates[shape] = set()
shape_plates[shape].add(node)
plates = []
for shape, nodes in shape_plates.items():
# We want to find the disconnected components that have a common
# shape. These will be the plates
subgraph = nx.subgraph(graph, nodes).to_undirected()
for G in nx.connected_component_subgraphs(subgraph, copy=False):
plates.append((shape, set(G.nodes())))
return plates

def make_graph(self, ignore_transformed=True, edge_cmap=None):
"""Make graphviz Digraph of PyMC3 model

Returns
-------
graphviz.Digraph
"""
try:
import graphviz
except ImportError:
raise ImportError('This function requires the python library graphviz, along with binaries. '
'The easiest way to install all of this is by running\n\n'
'\tconda install -c conda-forge python-graphviz')

G = self.graph
if ignore_transformed:
transforms = set([n.transformed for n in G
if hasattr(n, 'transformed')])
nbunch = [n for n in G if n not in transforms]
G = nx.subgraph(G, nbunch)
graph = graphviz.Digraph(self.model.name)
nclusters = 0
for shape, nodes in self.get_plates(graph=G):
label = ' x '.join(map('{:,d}'.format, shape))
if label:
# must be preceded by 'cluster' to get a box around it
with graph.subgraph(name='cluster {}'.format(nclusters)) as sub:
nclusters += 1
for node in nodes:
self._make_node(node, sub)
# plate label goes bottom right
sub.attr(label=label, labeljust='r', labelloc='b', style='rounded')
else:
for node in nodes:
self._make_node(node, graph)

for from_node, to_node, ats in G.edges(data=True):
if edge_cmap is None:
edge_color = '#000000'
else:
from matplotlib import colors
val = float(ats['deterministic'])
edge_color = colors.to_hex(edge_cmap(val), keep_alpha=True)
graph.edge(self.node_names[from_node],
self.node_names[to_node],
color=edge_color)
return graph

def _make_node(self, node, graph):
"""Attaches the given variable to a graphviz Digraph"""
# styling for node
attrs = {}
if isinstance(node, pm.model.ObservedRV):
attrs['style'] = 'filled'

# Get name for node
if hasattr(node, 'distribution'):
distribution = node.distribution.__class__.__name__
else:
distribution = 'Deterministic'
attrs['shape'] = 'box'
var_name = self.node_names[node]

graph.node(var_name,
'{var_name} ~ {distribution}'.format(var_name=var_name, distribution=distribution),
**attrs)


def crude_draw(model, *args, **kwargs):
OtherModelGraph(model).draw(*args, **kwargs)
11 changes: 8 additions & 3 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,10 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size

# draw once to inspect the shape
var_values = list(zip(varnames,
draw_values(vars, point=model.test_point, size=size)))
draw_values(vars,
point=model.test_point,
size=size,
model=model)))
ppc_trace = defaultdict(list)
for varname, value in var_values:
ppc_trace[varname] = np.zeros((samples,) + value.shape, value.dtype)
Expand All @@ -1141,7 +1144,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
else:
param = trace[idx % len_trace]

values = draw_values(vars, point=param, size=size)
values = draw_values(vars, point=param, size=size, model=model)
for k, v in zip(vars, values):
ppc_trace[k.name][slc] = v

Expand Down Expand Up @@ -1279,6 +1282,7 @@ def sample_posterior_predictive_w(traces, samples=None, models=None, weights=Non
var = variables[idx]
# TODO sample_posterior_predictive_w is currently only work for model with
# one observed.
# TODO supply the proper model to draw_values
ppc[var.name].append(draw_values([var],
point=param,
size=size[idx]
Expand Down Expand Up @@ -1331,7 +1335,8 @@ def sample_prior_predictive(samples=500, model=None, vars=None, random_seed=None
np.random.seed(random_seed)
names = get_default_varnames(model.named_vars, include_transformed=False)
# draw_values fails with auto-transformed variables. transform them later!
values = draw_values([model[name] for name in names], size=samples)
values = draw_values([model[name] for name in names], size=samples,
model=model)

data = {k: v for k, v in zip(names, values)}

Expand Down
90 changes: 89 additions & 1 deletion pymc3/tests/test_model_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pymc3 as pm
from pymc3.model_graph import ModelGraph, model_to_graphviz
from pymc3.model_graph import ModelGraph, model_to_graphviz, OtherModelGraph
from theano import tensor as tt

from .helpers import SeededTest

Expand Down Expand Up @@ -77,3 +78,90 @@ def test_graphviz(self):
for key in self.compute_graph:
assert key in g.source


def radon_model2():
"""Similar in shape to the Radon model"""
n_homes = 919
counties = 85
uranium = np.random.normal(-.1, 0.4, size=n_homes)
xbar = np.random.normal(1, 0.1, size=n_homes)
floor_measure = np.random.randint(0, 2, size=n_homes)
log_radon = np.random.normal(1, 1, size=n_homes)
radon_dispensers = np.random.normal(1, 0.2, size=n_homes)

d, r = divmod(919, 85)
county = np.hstack((
np.tile(np.arange(counties, dtype=int), d),
np.arange(r)
))
with pm.Model() as model:
sigma_a = pm.HalfCauchy('sigma_a', 5)
gamma = pm.Normal('gamma', mu=0., sd=1e5, shape=3)
mu_a = pm.Deterministic('mu_a', gamma[0] + gamma[1]*uranium + gamma[2]*xbar)
eps_a = pm.Normal('eps_a', mu=0, sd=sigma_a, shape=counties)
a = pm.Deterministic('a', mu_a + eps_a[county])
b = pm.Normal('b', mu=0., sd=1e15)
sigma_y = pm.Uniform('sigma_y', lower=0, upper=100)
p_rad = pm.Beta('p_rad', alpha=1, beta=1)
disp_on = pm.Bernoulli('disp_on', p=p_rad, shape=n_homes)
rad_cloud = pm.Deterministic('rad_cloud', tt.mean(disp_on * radon_dispensers))
y_hat = a + b * floor_measure + rad_cloud
y_like = pm.Normal('y_like', mu=y_hat, sd=sigma_y, observed=log_radon)

plates_without_trans = [((3,), {gamma}),
((), {b}),
((), {rad_cloud}),
((), {sigma_a}),
((), {p_rad}),
((), {sigma_y}),
((919,), {a, mu_a, y_like}),
((919,), {disp_on}),
((85,), {eps_a})]
plates_with_trans = [((), {b}),
((), {rad_cloud}),
((), {sigma_a, model['sigma_a_log__']}),
((), {sigma_y, model['sigma_y_interval__']}),
((), {p_rad, model['p_rad_logodds__']}),
((3,), {gamma}),
((919,), {a, mu_a, y_like}),
((919,), {disp_on}),
((85,), {eps_a})]
return model, plates_without_trans, plates_with_trans


class TestSimpleModel2(SeededTest):
@classmethod
def setup_class(cls):
cls.model, cls.plates_without_trans, cls.plates_with_trans= radon_model2()
cls.model_graph = OtherModelGraph(cls.model)

def test_plates(self):
assert all([plate in self.plates_without_trans
for plate in self.model_graph.get_plates(ignore_transformed=True)])
assert all([plate in self.plates_with_trans
for plate in self.model_graph.get_plates(ignore_transformed=False)])

def test_graphviz(self):
# just make sure everything runs without error

transforms = set([n.transformed for n in self.model_graph.graph
if hasattr(n, 'transformed')])
g = self.model_graph.make_graph(ignore_transformed=True)
for node, key in self.model_graph.node_names.items():
if node not in transforms:
assert key in g.source
else:
assert key not in g.source
g = OtherModelGraph(self.model).make_graph(ignore_transformed=True)
for node, key in self.model_graph.node_names.items():
if node not in transforms:
assert key in g.source
else:
assert key not in g.source

g = self.model_graph.make_graph(ignore_transformed=False)
for node, key in self.model_graph.node_names.items():
assert key in g.source
g = OtherModelGraph(self.model).make_graph(ignore_transformed=False)
for node, key in self.model_graph.node_names.items():
assert key in g.source
Loading