12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from collections import namedtuple
16
15
import logging
17
16
import enum
18
- import typing
17
+ from typing import Any , Optional
18
+ import dataclasses
19
+
19
20
from ..util import is_transformed_name , get_untransformed_name
20
21
21
22
import arviz
@@ -38,9 +39,17 @@ class WarningType(enum.Enum):
38
39
BAD_ENERGY = 8
39
40
40
41
41
- SamplerWarning = namedtuple (
42
- 'SamplerWarning' ,
43
- "kind, message, level, step, exec_info, extra" )
42
+ @dataclasses .dataclass
43
+ class SamplerWarning :
44
+ kind : WarningType
45
+ message : str
46
+ level : str
47
+ step : Optional [int ] = None
48
+ exec_info : Optional [Any ] = None
49
+ extra : Optional [Any ] = None
50
+ divergence_point_source : Optional [dict ] = None
51
+ divergence_point_dest : Optional [dict ] = None
52
+ divergence_info : Optional [Any ] = None
44
53
45
54
46
55
_LEVELS = {
@@ -53,7 +62,8 @@ class WarningType(enum.Enum):
53
62
54
63
55
64
class SamplerReport :
56
- """This object bundles warnings, convergence statistics and metadata of a sampling run."""
65
+ """Bundle warnings, convergence stats and metadata of a sampling run."""
66
+
57
67
def __init__ (self ):
58
68
self ._chain_warnings = {}
59
69
self ._global_warnings = []
@@ -75,17 +85,17 @@ def ok(self):
75
85
for warn in self ._warnings )
76
86
77
87
@property
78
- def n_tune (self ) -> typing . Optional [int ]:
88
+ def n_tune (self ) -> Optional [int ]:
79
89
"""Number of tune iterations - not necessarily kept in trace!"""
80
90
return self ._n_tune
81
91
82
92
@property
83
- def n_draws (self ) -> typing . Optional [int ]:
93
+ def n_draws (self ) -> Optional [int ]:
84
94
"""Number of draw iterations."""
85
95
return self ._n_draws
86
96
87
97
@property
88
- def t_sampling (self ) -> typing . Optional [float ]:
98
+ def t_sampling (self ) -> Optional [float ]:
89
99
"""
90
100
Number of seconds that the sampling procedure took.
91
101
@@ -110,8 +120,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
110
120
if idata .posterior .sizes ['chain' ] == 1 :
111
121
msg = ("Only one chain was sampled, this makes it impossible to "
112
122
"run some convergence checks" )
113
- warn = SamplerWarning (WarningType .BAD_PARAMS , msg , 'info' ,
114
- None , None , None )
123
+ warn = SamplerWarning (WarningType .BAD_PARAMS , msg , 'info' )
115
124
self ._add_warnings ([warn ])
116
125
return
117
126
@@ -134,41 +143,42 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
134
143
msg = ("The rhat statistic is larger than 1.4 for some "
135
144
"parameters. The sampler did not converge." )
136
145
warn = SamplerWarning (
137
- WarningType .CONVERGENCE , msg , 'error' , None , None , rhat )
146
+ WarningType .CONVERGENCE , msg , 'error' , extra = rhat )
138
147
warnings .append (warn )
139
148
elif rhat_max > 1.2 :
140
149
msg = ("The rhat statistic is larger than 1.2 for some "
141
150
"parameters." )
142
151
warn = SamplerWarning (
143
- WarningType .CONVERGENCE , msg , 'warn' , None , None , rhat )
152
+ WarningType .CONVERGENCE , msg , 'warn' , extra = rhat )
144
153
warnings .append (warn )
145
154
elif rhat_max > 1.05 :
146
155
msg = ("The rhat statistic is larger than 1.05 for some "
147
156
"parameters. This indicates slight problems during "
148
157
"sampling." )
149
158
warn = SamplerWarning (
150
- WarningType .CONVERGENCE , msg , 'info' , None , None , rhat )
159
+ WarningType .CONVERGENCE , msg , 'info' , extra = rhat )
151
160
warnings .append (warn )
152
161
153
162
eff_min = min (val .min () for val in ess .values ())
154
- n_samples = idata .posterior .sizes ['chain' ] * idata .posterior .sizes ['draw' ]
163
+ sizes = idata .posterior .sizes
164
+ n_samples = sizes ['chain' ] * sizes ['draw' ]
155
165
if eff_min < 200 and n_samples >= 500 :
156
166
msg = ("The estimated number of effective samples is smaller than "
157
167
"200 for some parameters." )
158
168
warn = SamplerWarning (
159
- WarningType .CONVERGENCE , msg , 'error' , None , None , ess )
169
+ WarningType .CONVERGENCE , msg , 'error' , extra = ess )
160
170
warnings .append (warn )
161
171
elif eff_min / n_samples < 0.1 :
162
172
msg = ("The number of effective samples is smaller than "
163
173
"10% for some parameters." )
164
174
warn = SamplerWarning (
165
- WarningType .CONVERGENCE , msg , 'warn' , None , None , ess )
175
+ WarningType .CONVERGENCE , msg , 'warn' , extra = ess )
166
176
warnings .append (warn )
167
177
elif eff_min / n_samples < 0.25 :
168
178
msg = ("The number of effective samples is smaller than "
169
179
"25% for some parameters." )
170
180
warn = SamplerWarning (
171
- WarningType .CONVERGENCE , msg , 'info' , None , None , ess )
181
+ WarningType .CONVERGENCE , msg , 'info' , extra = ess )
172
182
warnings .append (warn )
173
183
174
184
self ._add_warnings (warnings )
@@ -201,7 +211,7 @@ def filter_warns(warnings):
201
211
filtered .append (warn )
202
212
elif (start <= warn .step < stop and
203
213
(warn .step - start ) % step == 0 ):
204
- warn = warn . _replace ( step = warn .step - start )
214
+ warn = dataclasses . replace ( warn , step = warn .step - start )
205
215
filtered .append (warn )
206
216
return filtered
207
217
0 commit comments