
Commit 61e2664

SebastianAment authored and facebook-github-bot committed
[q]LogProbabilityOfFeasibility (#2815)
Summary: This commit adds `LogProbabilityOfFeasibility` and the corresponding batch Monte Carlo version `qLogProbabilityOfFeasibility`. These new acquisition functions compute a probability-of-feasibility acquisition value that is particularly relevant for optimization runs that start out without any feasible observations. Once a feasible observation has been found, the acquisition function likely becomes over-exploitative, and one is better served by switching to an acquisition function that also takes the objective value into account, e.g. `qLogNEI`.

Differential Revision: D72411936
1 parent cd36e52 commit 61e2664
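
As a minimal sketch of the workflow the summary describes (the feasibility check and switching rule below are illustrative, not part of this commit; standard BoTorch components are assumed):

```python
import torch

from botorch.acquisition.analytic import LogProbabilityOfFeasibility
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.objective import GenericMCObjective
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

train_X = torch.rand(10, 2, dtype=torch.double)
# Output 0 is the objective, output 1 the constraint (feasible iff <= 0).
train_Y = torch.randn(10, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

constraints = {1: (None, 0.0)}  # output 1 must lie below 0.0

if not (train_Y[:, 1] <= 0.0).any():
    # No feasible observation yet: chase feasibility alone.
    acqf = LogProbabilityOfFeasibility(model=model, constraints=constraints)
else:
    # Feasible points exist: switch to an objective-aware acquisition function.
    acqf = qLogNoisyExpectedImprovement(
        model=model,
        X_baseline=train_X,
        objective=GenericMCObjective(lambda Y, X=None: Y[..., 0]),
        constraints=[lambda Y: Y[..., 1]],  # negative values are feasible
    )

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
candidate, _ = optimize_acqf(acqf, bounds=bounds, q=1, num_restarts=8, raw_samples=64)
```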

File tree

9 files changed: +655 −138 lines changed

botorch/acquisition/analytic.py

+180 −90
@@ -13,7 +13,7 @@

 import math

-from abc import ABC
+from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from copy import deepcopy


@@ -415,7 +415,112 @@ def forward(self, X: Tensor) -> Tensor:
         return _log_ei_helper(u) + sigma.log()


-class LogConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
+class ConstrainedAnalyticAcquisitionFunctionMixin(ABC):
+    r"""Base class for constrained analytic acquisition functions."""
+
+    def __init__(
+        self,
+        constraints: dict[int, tuple[float | None, float | None]],
+    ) -> None:
+        r"""Initialize the constraint bounds.
+
+        Args:
+            constraints: A dictionary of the form `{i: [lower, upper]}`, where
+                `i` is the output index, and `lower` and `upper` are lower and upper
+                bounds on that output (resp. interpreted as -Inf / Inf if None).
+        """
+        self.constraints = constraints
+        self._preprocess_constraint_bounds(constraints=constraints)
+
+    @abstractmethod
+    def register_buffer(self, name: str, tensor: Tensor) -> None:
+        """Add a buffer that can be accessed by `self.name` and stores the Tensor
+        `tensor`, usually provided by subclasses that also inherit from `nn.Module`.
+        """
+
+    def _preprocess_constraint_bounds(
+        self,
+        constraints: dict[int, tuple[float | None, float | None]],
+    ) -> None:
+        r"""Set up constraint bounds.
+
+        Args:
+            constraints: A dictionary of the form `{i: [lower, upper]}`, where
+                `i` is the output index, and `lower` and `upper` are lower and upper
+                bounds on that output (resp. interpreted as -Inf / Inf if None).
+        """
+        con_lower, con_lower_inds = [], []
+        con_upper, con_upper_inds = [], []
+        con_both, con_both_inds = [], []
+        con_indices = list(constraints.keys())
+        if len(con_indices) == 0:
+            raise ValueError("There must be at least one constraint.")
+        # CEI, LogCEI have an objective index, but LogPOF does not.
+        if hasattr(self, "objective_index") and self.objective_index in con_indices:
+            raise ValueError(
+                "Output corresponding to objective should not be a constraint."
+            )
+        for k in con_indices:
+            if constraints[k][0] is not None and constraints[k][1] is not None:
+                if constraints[k][1] <= constraints[k][0]:
+                    raise ValueError("Upper bound is less than the lower bound.")
+                con_both_inds.append(k)
+                con_both.append([constraints[k][0], constraints[k][1]])
+            elif constraints[k][0] is not None:
+                con_lower_inds.append(k)
+                con_lower.append(constraints[k][0])
+            elif constraints[k][1] is not None:
+                con_upper_inds.append(k)
+                con_upper.append(constraints[k][1])
+
+        for name, value in [
+            ("con_lower_inds", con_lower_inds),
+            ("con_upper_inds", con_upper_inds),
+            ("con_both_inds", con_both_inds),
+            ("con_both", con_both),
+            ("con_lower", con_lower),
+            ("con_upper", con_upper),
+        ]:
+            # tensor-based indexing is much faster than list-based advanced indexing
+            self.register_buffer(name, tensor=torch.as_tensor(value))
+
+    def _compute_log_prob_feas(
+        self,
+        means: Tensor,
+        sigmas: Tensor,
+    ) -> Tensor:
+        r"""Compute the logarithm of the feasibility probability for each batch.
+
+        Args:
+            means: A `(b) x m`-dim Tensor of means.
+            sigmas: A `(b) x m`-dim Tensor of standard deviations.
+
+        Returns:
+            A `b`-dim tensor of log feasibility probabilities.
+
+        Note: This function does case-work for upper bound, lower bound, and
+            both-sided bounds. Another way to do it would be to use `inf` and
+            `-inf` for the one-sided bounds and use the logic for the both-sided
+            case. But this causes an issue with autograd since we get 0 * inf.
+        """
+        return compute_log_prob_feas_from_bounds(
+            con_lower_inds=self.con_lower_inds,
+            con_upper_inds=self.con_upper_inds,
+            con_both_inds=self.con_both_inds,
+            con_lower=self.con_lower,
+            con_upper=self.con_upper,
+            con_both=self.con_both,
+            means=means,
+            sigmas=sigmas,
+        )
+
+
+class LogConstrainedExpectedImprovement(
+    AnalyticAcquisitionFunction, ConstrainedAnalyticAcquisitionFunctionMixin
+):
     r"""Log Constrained Expected Improvement (feasibility-weighted).

     Computes the logarithm of the analytic expected improvement for a Normal posterior
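
To see what the mixin's case-work produces, here is a hedged sketch with a toy subclass (`ToyConstrained` is illustrative; `nn.Module` supplies the concrete `register_buffer` that satisfies the abstract method):

```python
import torch
from torch import nn

from botorch.acquisition.analytic import ConstrainedAnalyticAcquisitionFunctionMixin


class ToyConstrained(nn.Module, ConstrainedAnalyticAcquisitionFunctionMixin):
    def __init__(self, constraints):
        nn.Module.__init__(self)  # provides the concrete register_buffer
        ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)


toy = ToyConstrained({1: (0.0, None), 2: (None, 5.0), 3: (-1.0, 1.0)})
print(toy.con_lower_inds)  # tensor([1])         lower-bound-only constraint
print(toy.con_upper)       # tensor([5.])        upper bounds as a flat tensor
print(toy.con_both)        # tensor([[-1., 1.]]) two-sided bounds, one row each
```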
@@ -464,13 +569,12 @@ def __init__(
             maximize: If True, consider the problem a maximization problem.
         """
         # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        super(AnalyticAcquisitionFunction, self).__init__(model=model)
+        AcquisitionFunction.__init__(self, model=model)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
-        self.constraints = constraints
         self.register_buffer("best_f", torch.as_tensor(best_f))
-        _preprocess_constraint_bounds(self, constraints=constraints)
+        ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
         self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
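
The constructor change above swaps `super(AnalyticAcquisitionFunction, self).__init__` for explicit, by-name base-class calls. A generic sketch of the pattern (class names here are illustrative, not BoTorch's):

```python
class Base:
    def __init__(self, model):
        self.model = model


class Intermediate(Base):
    def __init__(self, model, posterior_transform=None):
        super().__init__(model)
        # validation that constrained subclasses need to skip
        if posterior_transform is not None:
            raise ValueError("posterior transforms are not supported")


class Mixin:
    def __init__(self, constraints):
        self.constraints = constraints


class Combined(Intermediate, Mixin):
    def __init__(self, model, constraints):
        # Calling each base initializer by name runs Base's setup, bypasses
        # Intermediate's validation, and stays unambiguous now that a second
        # base class participates in the MRO.
        Base.__init__(self, model)
        Mixin.__init__(self, constraints)
```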
@@ -490,11 +594,77 @@ def forward(self, X: Tensor) -> Tensor:
         mean_obj, sigma_obj = means[..., ind], sigmas[..., ind]
         u = _scaled_improvement(mean_obj, sigma_obj, self.best_f, self.maximize)
         log_ei = _log_ei_helper(u) + sigma_obj.log()
-        log_prob_feas = _compute_log_prob_feas(self, means=means, sigmas=sigmas)
+        log_prob_feas = self._compute_log_prob_feas(means=means, sigmas=sigmas)
         return log_ei + log_prob_feas


-class ConstrainedExpectedImprovement(AnalyticAcquisitionFunction):
+class LogProbabilityOfFeasibility(
+    AnalyticAcquisitionFunction, ConstrainedAnalyticAcquisitionFunctionMixin
+):
+    r"""Log Probability of Feasibility.
+
+    Computes the logarithm of the analytic probability of feasibility for a
+    Normal posterior distribution. The constraints are assumed to be independent
+    and have Gaussian posterior distributions. Only supports non-batch mode
+    (i.e. `q=1`). The model should be multi-outcome, with the constraint indices
+    passed to the constructor.
+
+    See [Ament2023logei]_ for details. Formally,
+
+    `LogPF(x) = Sum_i log(P(y_i \in [lower_i, upper_i]))`,
+
+    where `y_i ~ constraint_i(x)` and `lower_i`, `upper_i` are the lower and
+    upper bounds for the i-th constraint, respectively.
+
+    Example:
+        # example where the 0th output has a non-negativity constraint
+        >>> model = SingleTaskGP(train_X, train_Y)
+        >>> constraints = {0: (0.0, None)}
+        >>> LogPOF = LogProbabilityOfFeasibility(model, constraints)
+        >>> log_pf = LogPOF(test_X)
+    """
+
+    _log: bool = True
+
+    def __init__(
+        self,
+        model: Model,
+        constraints: dict[int, tuple[float | None, float | None]],
+    ) -> None:
+        r"""Analytic Log Probability of Feasibility.
+
+        Args:
+            model: A fitted multi-output model.
+            constraints: A dictionary of the form `{i: [lower, upper]}`, where
+                `i` is the output index, and `lower` and `upper` are lower and upper
+                bounds on that output (resp. interpreted as -Inf / Inf if None).
+        """
+        # Use AcquisitionFunction constructor to avoid check for posterior transform.
+        AcquisitionFunction.__init__(self, model=model)
+        self.posterior_transform = None
+        ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
+        self.register_forward_pre_hook(convert_to_target_pre_hook)
+
+    @t_batch_mode_transform(expected_q=1)
+    def forward(self, X: Tensor) -> Tensor:
+        r"""Evaluate the Log Probability of Feasibility on the candidate set X.
+
+        Args:
+            X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design
+                points each.
+
+        Returns:
+            A `(b)`-dim Tensor of Log Probability of Feasibility values at the
+            given design points `X`.
+        """
+        means, sigmas = self._mean_and_sigma(X)  # each (b) x m (m = num constraints)
+        return self._compute_log_prob_feas(means=means, sigmas=sigmas)
+
+
+class ConstrainedExpectedImprovement(
+    AnalyticAcquisitionFunction, ConstrainedAnalyticAcquisitionFunctionMixin
+):
     r"""Constrained Expected Improvement (feasibility-weighted).

     Computes the analytic expected improvement for a Normal posterior
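
To make the `LogPF` formula concrete, a standalone numerical sketch for two hypothetical constraint outputs. It mirrors the case-work described in `_compute_log_prob_feas`; the production path, `compute_log_prob_feas_from_bounds`, is vectorized and numerically stabilized rather than looping like this:

```python
import torch
from torch.distributions import Normal

means = torch.tensor([0.5, -0.2])    # posterior means of the constraint outputs
sigmas = torch.tensor([1.0, 0.5])    # posterior standard deviations
bounds = [(0.0, None), (-1.0, 1.0)]  # (lower, upper); None = unbounded side

log_pf = torch.zeros(())
for mu, sigma, (lo, hi) in zip(means, sigmas, bounds):
    dist = Normal(mu, sigma)
    if lo is not None and hi is not None:  # two-sided bound
        prob = dist.cdf(torch.as_tensor(hi)) - dist.cdf(torch.as_tensor(lo))
    elif lo is not None:                   # lower bound only: P(y >= lo)
        prob = 1.0 - dist.cdf(torch.as_tensor(lo))
    else:                                  # upper bound only: P(y <= hi)
        prob = dist.cdf(torch.as_tensor(hi))
    log_pf = log_pf + prob.log()           # independence => sum of logs

print(log_pf)  # joint log probability of feasibility at this design point
```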
@@ -543,13 +713,12 @@ def __init__(
         """
         legacy_ei_numerics_warning(legacy_name=type(self).__name__)
         # Use AcquisitionFunction constructor to avoid check for posterior transform.
-        super(AnalyticAcquisitionFunction, self).__init__(model=model)
+        AcquisitionFunction.__init__(self, model=model)
         self.posterior_transform = None
         self.maximize = maximize
         self.objective_index = objective_index
-        self.constraints = constraints
         self.register_buffer("best_f", torch.as_tensor(best_f))
-        _preprocess_constraint_bounds(self, constraints=constraints)
+        ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
         self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
@@ -569,7 +738,7 @@ def forward(self, X: Tensor) -> Tensor:
         mean_obj, sigma_obj = means[..., ind], sigmas[..., ind]
         u = _scaled_improvement(mean_obj, sigma_obj, self.best_f, self.maximize)
         ei = sigma_obj * _ei_helper(u)
-        log_prob_feas = _compute_log_prob_feas(self, means=means, sigmas=sigmas)
+        log_prob_feas = self._compute_log_prob_feas(means=means, sigmas=sigmas)
         return ei.mul(log_prob_feas.exp())

@@ -1131,82 +1300,3 @@ def _get_noiseless_fantasy_model(
     fantasy_model.likelihood.noise_covar.noise = Yvar

     return fantasy_model
-
-
-def _preprocess_constraint_bounds(
-    acqf: LogConstrainedExpectedImprovement | ConstrainedExpectedImprovement,
-    constraints: dict[int, tuple[float | None, float | None]],
-) -> None:
-    r"""Set up constraint bounds.
-
-    Args:
-        constraints: A dictionary of the form `{i: [lower, upper]}`, where
-            `i` is the output index, and `lower` and `upper` are lower and upper
-            bounds on that output (resp. interpreted as -Inf / Inf if None)
-    """
-    con_lower, con_lower_inds = [], []
-    con_upper, con_upper_inds = [], []
-    con_both, con_both_inds = [], []
-    con_indices = list(constraints.keys())
-    if len(con_indices) == 0:
-        raise ValueError("There must be at least one constraint.")
-    if acqf.objective_index in con_indices:
-        raise ValueError(
-            "Output corresponding to objective should not be a constraint."
-        )
-    for k in con_indices:
-        if constraints[k][0] is not None and constraints[k][1] is not None:
-            if constraints[k][1] <= constraints[k][0]:
-                raise ValueError("Upper bound is less than the lower bound.")
-            con_both_inds.append(k)
-            con_both.append([constraints[k][0], constraints[k][1]])
-        elif constraints[k][0] is not None:
-            con_lower_inds.append(k)
-            con_lower.append(constraints[k][0])
-        elif constraints[k][1] is not None:
-            con_upper_inds.append(k)
-            con_upper.append(constraints[k][1])
-    # tensor-based indexing is much faster than list-based advanced indexing
-    for name, indices in [
-        ("con_lower_inds", con_lower_inds),
-        ("con_upper_inds", con_upper_inds),
-        ("con_both_inds", con_both_inds),
-        ("con_both", con_both),
-        ("con_lower", con_lower),
-        ("con_upper", con_upper),
-    ]:
-        acqf.register_buffer(name, tensor=torch.as_tensor(indices))
-
-
-def _compute_log_prob_feas(
-    acqf: LogConstrainedExpectedImprovement | ConstrainedExpectedImprovement,
-    means: Tensor,
-    sigmas: Tensor,
-) -> Tensor:
-    r"""Compute logarithm of the feasibility probability for each batch of X.
-
-    Args:
-        X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design
-            points each.
-        means: A `(b) x m`-dim Tensor of means.
-        sigmas: A `(b) x m`-dim Tensor of standard deviations.
-    Returns:
-        A `b`-dim tensor of log feasibility probabilities
-
-    Note: This function does case-work for upper bound, lower bound, and both-sided
-        bounds. Another way to do it would be to use 'inf' and -'inf' for the
-        one-sided bounds and use the logic for the both-sided case. But this
-        causes an issue with autograd since we get 0 * inf.
-    TODO: Investigate further.
-    """
-    acqf.to(device=means.device)
-    return compute_log_prob_feas_from_bounds(
-        acqf.con_lower_inds,
-        acqf.con_upper_inds,
-        acqf.con_both_inds,
-        acqf.con_lower,
-        acqf.con_upper,
-        acqf.con_both,
-        means,
-        sigmas,
-    )

botorch/acquisition/factory.py

+10
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,16 @@ def get_acquisition_function(
             constraints=constraints,
             eta=eta,
         )
+    elif acquisition_function_name == "qLogPF":
+        return logei.qLogProbabilityOfFeasibility(
+            model=model,
+            constraints=constraints,
+            sampler=sampler,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            X_pending=X_pending,
+            eta=eta,
+        )
     elif acquisition_function_name in ["qNEI", "qLogNEI"]:
         acqf_class = (
             monte_carlo.qNoisyExpectedImprovement
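
A hedged sketch of requesting the new Monte Carlo variant through the factory (argument values are illustrative; the keyword names follow `get_acquisition_function`'s existing signature):

```python
import torch

from botorch.acquisition.factory import get_acquisition_function
from botorch.acquisition.objective import GenericMCObjective
from botorch.models import SingleTaskGP

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.randn(10, 2, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

acqf = get_acquisition_function(
    acquisition_function_name="qLogPF",
    model=model,
    objective=GenericMCObjective(lambda Y, X=None: Y[..., 0]),
    X_observed=train_X,
    constraints=[lambda Y: Y[..., 1]],  # feasible where the callable is <= 0
    mc_samples=128,
)
log_pf = acqf(train_X[:1].unsqueeze(0))  # evaluate on a 1 x 1 x d candidate batch
```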
