-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
SMC-ABC add distance, refactor and update notebook #3996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
2fb0380
fe30260
019b2df
bbbbbd0
9ec5cc7
ca60571
88ec8e1
3e42ca8
1aabed8
f2976cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,27 +13,77 @@ | |
# limitations under the License. | ||
|
||
import numpy as np | ||
from .distribution import NoDistribution | ||
from .distribution import NoDistribution, draw_values | ||
|
||
__all__ = ["Simulator"] | ||
|
||
|
||
class Simulator(NoDistribution): | ||
def __init__(
    self,
    function,
    *args,
    params=None,
    distance="gaussian_kernel",
    sum_stat="identity",
    epsilon=1,
    **kwargs,
):
    r"""
    This class stores a function defined by the user in python language.

    function: function
        Simulation function defined by the user.
    params: list
        Parameters passed to function.
    distance: str or callable
        Distance function. Available options are "gaussian_kernel" (default), "wasserstein",
        "energy" or a user defined function.
        ``gaussian_kernel`` :math:`\sum \left(-0.5 \left(\frac{xo - xs}{\epsilon}\right)^2\right)`
        ``wasserstein`` :math:`\frac{1}{n} \sum{\left(\frac{|xo - xs|}{\epsilon}\right)}`
        ``energy`` :math:`\sqrt{2} \sqrt{\frac{1}{n} \sum \left(\frac{|xo - xs|}{\epsilon}\right)^2}`
        For the wasserstein and energy distances the observed data xo and simulated data xs
        are internally sorted (i.e. the sum_stat is forced to "sort"); a warning is emitted
        if a different summary statistic was requested.
    sum_stat: str or callable
        Summary statistics. Available options are ``identity``, ``sort``, ``mean``, ``median``.
        If a callable is passed it should return a number or a 1d numpy array.
    epsilon: float
        Scaling parameter of the distance function (the standard deviation of the
        gaussian_kernel).
    *args and **kwargs:
        Arguments and keywords arguments that the function takes.

    Raises
    ------
    ValueError
        If ``distance`` or ``sum_stat`` is neither a known option nor a callable.
    """
    import warnings

    self.function = function
    self.params = params
    # NOTE(review): ``self.data`` is not assigned in this method; presumably the
    # base-class machinery attaches the observed data before __init__ runs -- confirm.
    observed = self.data
    self.epsilon = epsilon

    # The wasserstein and energy distances compare order statistics, so they
    # need both data sets sorted regardless of what the caller asked for.
    needs_sorted_data = False
    if distance == "gaussian_kernel":
        self.distance = gaussian_kernel
    elif distance == "wasserstein":
        self.distance = wasserstein
        needs_sorted_data = True
    elif distance == "energy":
        self.distance = energy
        needs_sorted_data = True
    elif callable(distance):
        self.distance = distance
    else:
        raise ValueError(f"The distance metric {distance} is not implemented")

    if needs_sorted_data:
        # Do not warn for the default ("identity") or for an explicit "sort";
        # only when a deliberately chosen statistic is about to be discarded.
        if sum_stat not in ("identity", "sort"):
            warnings.warn(
                f"The {distance} distance requires sorted data; "
                f"overriding sum_stat={sum_stat!r} with 'sort'.",
                UserWarning,
                stacklevel=2,
            )
        sum_stat = "sort"

    if sum_stat == "identity":
        self.sum_stat = identity
    elif sum_stat == "sort":
        self.sum_stat = np.sort
    elif sum_stat == "mean":
        self.sum_stat = np.mean
    elif sum_stat == "median":
        self.sum_stat = np.median
    elif callable(sum_stat):
        self.sum_stat = sum_stat
    else:
        raise ValueError(f"The summary statistics {sum_stat} is not implemented")

    # The distribution is registered with a flat shape: one entry per observed value.
    super().__init__(*args, shape=np.prod(observed.shape), dtype=observed.dtype, **kwargs)
|
||
def random(self, point=None, size=None): | ||
|
@@ -51,16 +101,44 @@ def random(self, point=None, size=None): | |
------- | ||
array | ||
""" | ||
|
||
raise NotImplementedError("Not implemented yet") | ||
params = draw_values([*self.params], point=point, size=size) | ||
if size is None: | ||
return self.function(*params) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. self.function is intended to be a Python function not a Theano one. |
||
else: | ||
return np.array([self.function(*params) for _ in range(size)]) | ||
|
||
def _repr_latex_(self, name=None, dist=None): | ||
if dist is None: | ||
dist = self | ||
name = r"\text{%s}" % name | ||
function = dist.function | ||
params = dist.parameters | ||
sum_stat = dist.sum_stat | ||
return r"${} \sim \text{{Simulator}}(\mathit{{function}}={},~\mathit{{parameters}}={},~\mathit{{summary statistics}}={})$".format( | ||
name, function, params, sum_stat | ||
) | ||
name = name | ||
function = dist.function.__name__ | ||
params = ", ".join([var.name for var in dist.params]) | ||
sum_stat = self.sum_stat.__name__ if hasattr(self.sum_stat, "__call__") else self.sum_stat | ||
distance = self.distance.__name__ | ||
return f"$\\text{{{name}}} \sim \\text{{Simulator}}(\\text{{{function}}}({params}), \\text{{{distance}}}, \\text{{{sum_stat}}})$" | ||
|
||
|
||
def identity(x):
    """Return ``x`` unchanged.

    Serves as the no-op summary statistic: the data are compared as they are.
    """
    return x
|
||
|
||
def gaussian_kernel(epsilon, obs_data, sim_data):
    """Gaussian kernel between observed and simulated data.

    Returns the sum over all entries of ``-0.5 * ((obs_data - sim_data) / epsilon) ** 2``,
    so the result is 0 when both data sets coincide and negative otherwise.
    """
    scaled_residuals = (obs_data - sim_data) / epsilon
    return np.sum(-0.5 * scaled_residuals ** 2)
|
||
|
||
def wasserstein(epsilon, obs_data, sim_data):
    """Wasserstein distance between observed and simulated data.

    Computed as the mean absolute element-wise difference, scaled by ``epsilon``.
    Both ``obs_data`` and ``sim_data`` are assumed to be sorted already; no
    sorting is done here.
    """
    scaled_gap = (obs_data - sim_data) / epsilon
    return np.mean(np.abs(scaled_gap))
|
||
|
||
def energy(epsilon, obs_data, sim_data):
    """Energy distance between observed and simulated data.

    Computed as ``sqrt(2) * sqrt(mean(((obs_data - sim_data) / epsilon) ** 2))``.
    Both ``obs_data`` and ``sim_data`` are assumed to be sorted already; no
    sorting is done here.
    """
    # Use the exact square root of two rather than a truncated decimal constant.
    return np.sqrt(2) * np.mean(((obs_data - sim_data) / epsilon) ** 2) ** 0.5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.