Skip to content

Commit a34f63a

Browse files
authored
Add more info to divergence warnings (#3990)
* Add more info to divergence warnings
* Add dataclasses as requirement for py3.6
* Fix tests for extra divergence info
* Remove py3.6 requirements
1 parent 8770259 commit a34f63a

File tree

6 files changed

+67
-39
lines changed

6 files changed

+67
-39
lines changed

pymc3/backends/report.py

+29-19
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from collections import namedtuple
1615
import logging
1716
import enum
18-
import typing
17+
from typing import Any, Optional
18+
import dataclasses
19+
1920
from ..util import is_transformed_name, get_untransformed_name
2021

2122
import arviz
@@ -38,9 +39,17 @@ class WarningType(enum.Enum):
3839
BAD_ENERGY = 8
3940

4041

41-
SamplerWarning = namedtuple(
42-
'SamplerWarning',
43-
"kind, message, level, step, exec_info, extra")
42+
@dataclasses.dataclass
43+
class SamplerWarning:
44+
kind: WarningType
45+
message: str
46+
level: str
47+
step: Optional[int] = None
48+
exec_info: Optional[Any] = None
49+
extra: Optional[Any] = None
50+
divergence_point_source: Optional[dict] = None
51+
divergence_point_dest: Optional[dict] = None
52+
divergence_info: Optional[Any] = None
4453

4554

4655
_LEVELS = {
@@ -53,7 +62,8 @@ class WarningType(enum.Enum):
5362

5463

5564
class SamplerReport:
56-
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
65+
"""Bundle warnings, convergence stats and metadata of a sampling run."""
66+
5767
def __init__(self):
5868
self._chain_warnings = {}
5969
self._global_warnings = []
@@ -75,17 +85,17 @@ def ok(self):
7585
for warn in self._warnings)
7686

7787
@property
78-
def n_tune(self) -> typing.Optional[int]:
88+
def n_tune(self) -> Optional[int]:
7989
"""Number of tune iterations - not necessarily kept in trace!"""
8090
return self._n_tune
8191

8292
@property
83-
def n_draws(self) -> typing.Optional[int]:
93+
def n_draws(self) -> Optional[int]:
8494
"""Number of draw iterations."""
8595
return self._n_draws
8696

8797
@property
88-
def t_sampling(self) -> typing.Optional[float]:
98+
def t_sampling(self) -> Optional[float]:
8999
"""
90100
Number of seconds that the sampling procedure took.
91101
@@ -110,8 +120,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
110120
if idata.posterior.sizes['chain'] == 1:
111121
msg = ("Only one chain was sampled, this makes it impossible to "
112122
"run some convergence checks")
113-
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
114-
None, None, None)
123+
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
115124
self._add_warnings([warn])
116125
return
117126

@@ -134,41 +143,42 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
134143
msg = ("The rhat statistic is larger than 1.4 for some "
135144
"parameters. The sampler did not converge.")
136145
warn = SamplerWarning(
137-
WarningType.CONVERGENCE, msg, 'error', None, None, rhat)
146+
WarningType.CONVERGENCE, msg, 'error', extra=rhat)
138147
warnings.append(warn)
139148
elif rhat_max > 1.2:
140149
msg = ("The rhat statistic is larger than 1.2 for some "
141150
"parameters.")
142151
warn = SamplerWarning(
143-
WarningType.CONVERGENCE, msg, 'warn', None, None, rhat)
152+
WarningType.CONVERGENCE, msg, 'warn', extra=rhat)
144153
warnings.append(warn)
145154
elif rhat_max > 1.05:
146155
msg = ("The rhat statistic is larger than 1.05 for some "
147156
"parameters. This indicates slight problems during "
148157
"sampling.")
149158
warn = SamplerWarning(
150-
WarningType.CONVERGENCE, msg, 'info', None, None, rhat)
159+
WarningType.CONVERGENCE, msg, 'info', extra=rhat)
151160
warnings.append(warn)
152161

153162
eff_min = min(val.min() for val in ess.values())
154-
n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
163+
sizes = idata.posterior.sizes
164+
n_samples = sizes['chain'] * sizes['draw']
155165
if eff_min < 200 and n_samples >= 500:
156166
msg = ("The estimated number of effective samples is smaller than "
157167
"200 for some parameters.")
158168
warn = SamplerWarning(
159-
WarningType.CONVERGENCE, msg, 'error', None, None, ess)
169+
WarningType.CONVERGENCE, msg, 'error', extra=ess)
160170
warnings.append(warn)
161171
elif eff_min / n_samples < 0.1:
162172
msg = ("The number of effective samples is smaller than "
163173
"10% for some parameters.")
164174
warn = SamplerWarning(
165-
WarningType.CONVERGENCE, msg, 'warn', None, None, ess)
175+
WarningType.CONVERGENCE, msg, 'warn', extra=ess)
166176
warnings.append(warn)
167177
elif eff_min / n_samples < 0.25:
168178
msg = ("The number of effective samples is smaller than "
169179
"25% for some parameters.")
170180
warn = SamplerWarning(
171-
WarningType.CONVERGENCE, msg, 'info', None, None, ess)
181+
WarningType.CONVERGENCE, msg, 'info', extra=ess)
172182
warnings.append(warn)
173183

174184
self._add_warnings(warnings)
@@ -201,7 +211,7 @@ def filter_warns(warnings):
201211
filtered.append(warn)
202212
elif (start <= warn.step < stop and
203213
(warn.step - start) % step == 0):
204-
warn = warn._replace(step=warn.step - start)
214+
warn = dataclasses.replace(warn, step=warn.step - start)
205215
filtered.append(warn)
206216
return filtered
207217

pymc3/step_methods/hmc/base_hmc.py

+29-13
Original file line number | Diff line number | Diff line change
@@ -29,10 +29,16 @@
2929

3030
logger = logging.getLogger("pymc3")
3131

32-
HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
32+
HMCStepData = namedtuple(
33+
"HMCStepData",
34+
"end, accept_stat, divergence_info, stats"
35+
)
3336

37+
DivergenceInfo = namedtuple(
38+
"DivergenceInfo",
39+
"message, exec_info, state, state_div"
40+
)
3441

35-
DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state")
3642

3743
class BaseHMC(arraystep.GradientSharedStep):
3844
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
@@ -148,15 +154,14 @@ def astep(self, q0):
148154
self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
149155
message_energy = (
150156
"Bad initial energy, check any log probabilities that "
151-
"are inf or -inf, nan or very small:\n{}".format(error_logp.to_string())
157+
"are inf or -inf, nan or very small:\n{}"
158+
.format(error_logp.to_string())
152159
)
153160
warning = SamplerWarning(
154161
WarningType.BAD_ENERGY,
155162
message_energy,
156163
"critical",
157164
self.iter_count,
158-
None,
159-
None,
160165
)
161166
self._warnings.append(warning)
162167
raise SamplingError("Bad initial energy")
@@ -177,19 +182,32 @@ def astep(self, q0):
177182
self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
178183
if hmc_step.divergence_info:
179184
info = hmc_step.divergence_info
185+
point = None
186+
point_dest = None
187+
info_store = None
180188
if self.tune:
181189
kind = WarningType.TUNING_DIVERGENCE
182-
point = None
183190
else:
184191
kind = WarningType.DIVERGENCE
185192
self._num_divs_sample += 1
186193
# We don't want to fill up all memory with divergence info
187-
if self._num_divs_sample < 100:
194+
if self._num_divs_sample < 100 and info.state is not None:
188195
point = self._logp_dlogp_func.array_to_dict(info.state.q)
189-
else:
190-
point = None
196+
if self._num_divs_sample < 100 and info.state_div is not None:
197+
point_dest = self._logp_dlogp_func.array_to_dict(
198+
info.state_div.q
199+
)
200+
if self._num_divs_sample < 100:
201+
info_store = info
191202
warning = SamplerWarning(
192-
kind, info.message, "debug", self.iter_count, info.exec_info, point
203+
kind,
204+
info.message,
205+
"debug",
206+
self.iter_count,
207+
info.exec_info,
208+
divergence_point_source=point,
209+
divergence_point_dest=point_dest,
210+
divergence_info=info_store,
193211
)
194212

195213
self._warnings.append(warning)
@@ -243,9 +261,7 @@ def warnings(self):
243261
)
244262

245263
if message:
246-
warning = SamplerWarning(
247-
WarningType.DIVERGENCES, message, "error", None, None, None
248-
)
264+
warning = SamplerWarning(WarningType.DIVERGENCES, message, "error")
249265
warnings.append(warning)
250266

251267
warnings.extend(self.step_adapt.warnings())

pymc3/step_methods/hmc/hmc.py

+5-3
Original file line number | Diff line number | Diff line change
@@ -116,23 +116,25 @@ def _hamiltonian_step(self, start, p0, step_size):
116116

117117
energy_change = -np.inf
118118
state = start
119+
last = state
119120
div_info = None
120121
try:
121122
for _ in range(n_steps):
123+
last = state
122124
state = self.integrator.step(step_size, state)
123125
except IntegrationError as e:
124-
div_info = DivergenceInfo('Divergence encountered.', e, state)
126+
div_info = DivergenceInfo('Integration failed.', e, last, None)
125127
else:
126128
if not np.isfinite(state.energy):
127129
div_info = DivergenceInfo(
128-
'Divergence encountered, bad energy.', None, state)
130+
'Divergence encountered, bad energy.', None, last, state)
129131
energy_change = start.energy - state.energy
130132
if np.isnan(energy_change):
131133
energy_change = -np.inf
132134
if np.abs(energy_change) > self.Emax:
133135
div_info = DivergenceInfo(
134136
'Divergence encountered, large integration error.',
135-
None, state)
137+
None, last, state)
136138

137139
accept_stat = min(1, np.exp(energy_change))
138140

pymc3/step_methods/hmc/nuts.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def warnings(self):
210210
"The chain reached the maximum tree depth. Increase "
211211
"max_treedepth, increase target_accept or reparameterize."
212212
)
213-
warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
213+
warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn')
214214
warnings.append(warn)
215215
return warnings
216216

@@ -331,6 +331,7 @@ def _single_step(self, left, epsilon):
331331
except IntegrationError as err:
332332
error_msg = str(err)
333333
error = err
334+
right = None
334335
else:
335336
# h - H0
336337
energy_change = right.energy - self.start_energy
@@ -363,7 +364,7 @@ def _single_step(self, left, epsilon):
363364
)
364365
error = None
365366
tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
366-
divergance_info = DivergenceInfo(error_msg, error, left)
367+
divergance_info = DivergenceInfo(error_msg, error, left, right)
367368
return tree, divergance_info, False
368369

369370
def _build_subtree(self, left, depth, epsilon):

pymc3/step_methods/step_sizes.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ def warnings(self):
7777
% (mean_accept, target_accept))
7878
info = {'target': target_accept, 'actual': mean_accept}
7979
warning = SamplerWarning(
80-
WarningType.BAD_ACCEPTANCE, msg, 'warn', None, None, info)
80+
WarningType.BAD_ACCEPTANCE, msg, 'warn', extra=info)
8181
return [warning]
8282
else:
8383
return []

requirements.txt

-1
Original file line number | Diff line number | Diff line change
@@ -7,4 +7,3 @@ patsy>=0.5.1
77
fastprogress>=0.2.0
88
h5py>=2.7.0
99
typing-extensions>=3.7.4
10-
contextvars; python_version < '3.7'

0 commit comments

Comments (0)