Add ability to mix batch initial conditions and internal IC generation #2610

Closed · wants to merge 5 commits
202 changes: 165 additions & 37 deletions botorch/optim/optimize.py
@@ -109,12 +109,43 @@ def __post_init__(self) -> None:
"3-dimensional. Its shape is "
f"{batch_initial_conditions_shape}."
)

if batch_initial_conditions_shape[-1] != d:
raise ValueError(
f"batch_initial_conditions.shape[-1] must be {d}. The "
f"shape is {batch_initial_conditions_shape}."
)

if len(batch_initial_conditions_shape) == 2:
warnings.warn(
"If using a 2-dim `batch_initial_conditions` botorch will "
"default to old behavior of ignoring `num_restarts` and just "
"use the given `batch_initial_conditions` by setting "
"`raw_samples` to None.",
RuntimeWarning,
stacklevel=3,
)
# Use object.__setattr__ to bypass immutability and set a value
object.__setattr__(self, "raw_samples", None)

if (
len(batch_initial_conditions_shape) == 3
and batch_initial_conditions_shape[0] < self.num_restarts
and batch_initial_conditions_shape[-2] != self.q
):
warnings.warn(
"If using a 3-dim `batch_initial_conditions` where the "
"first dimension is less than `num_restarts` and the second "
"dimension is not equal to `q`, botorch will default to "
"old behavior of ignoring `num_restarts` and just use the "
"given `batch_initial_conditions` by setting `raw_samples` "
"to None.",
RuntimeWarning,
stacklevel=3,
)
# Use object.__setattr__ to bypass immutability and set a value
object.__setattr__(self, "raw_samples", None)

elif self.ic_generator is None:
if self.nonlinear_inequality_constraints is not None:
raise RuntimeError(
@@ -126,6 +157,7 @@ def __post_init__(self) -> None:
"Must specify `raw_samples` when "
"`batch_initial_conditions` is None`."
)

if self.fixed_features is not None and any(
(k < 0 for k in self.fixed_features)
):
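
A minimal sketch (hypothetical shapes, not part of this diff) of which of the branches above fire for a given `batch_initial_conditions`:

import torch

d, q, num_restarts = 4, 2, 5
ics_2d = torch.rand(7, d)      # 2-dim: legacy path, raw_samples forced to None
ics_old = torch.rand(3, 1, d)  # 3-dim, first dim < num_restarts, dim -2 != q:
                               # legacy path, raw_samples forced to None
ics_mix = torch.rand(3, q, d)  # 3-dim with dim -2 == q: the remaining
                               # num_restarts - 3 = 2 restarts are generated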
@@ -253,20 +285,49 @@ def _optimize_acqf_sequential_q(
return candidates, torch.stack(acq_value_list)


def _combine_initial_conditions(
provided_initial_conditions: Tensor | None = None,
generated_initial_conditions: Tensor | None = None,
dim: int = 0,
) -> Tensor:
if (
provided_initial_conditions is not None
and generated_initial_conditions is not None
):
return torch.cat(
[provided_initial_conditions, generated_initial_conditions], dim=dim
)
elif provided_initial_conditions is not None:
return provided_initial_conditions
elif generated_initial_conditions is not None:
return generated_initial_conditions
else:
raise ValueError(
"Either `batch_initial_conditions` or `raw_samples` must be set."
)
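
A usage sketch of the new helper (hypothetical tensors; `_combine_initial_conditions` is private to `botorch.optim.optimize`):

import torch

provided = torch.rand(3, 2, 4)   # 3 user-provided restarts, q=2, d=4
generated = torch.rand(2, 2, 4)  # 2 internally generated restarts
combined = _combine_initial_conditions(
    provided_initial_conditions=provided,
    generated_initial_conditions=generated,
)  # -> 5 x 2 x 4, with the provided restarts first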


def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
options = opt_inputs.options or {}

initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
required_num_restarts = opt_inputs.num_restarts
provided_initial_conditions = opt_inputs.batch_initial_conditions
generated_initial_conditions = None

if initial_conditions_provided:
batch_initial_conditions = opt_inputs.batch_initial_conditions
else:
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
batch_initial_conditions = opt_inputs.get_ic_generator()(
if (
provided_initial_conditions is not None
and len(provided_initial_conditions.shape) == 3
):
required_num_restarts -= provided_initial_conditions.shape[0]

if opt_inputs.raw_samples is not None and required_num_restarts > 0:
# pyre-ignore[28]: Unexpected keyword argument `acq_function`
# to anonymous call.
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
num_restarts=required_num_restarts,
raw_samples=opt_inputs.raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
@@ -275,6 +336,11 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
**opt_inputs.ic_gen_kwargs,
)

batch_initial_conditions = _combine_initial_conditions(
provided_initial_conditions=provided_initial_conditions,
generated_initial_conditions=generated_initial_conditions,
)
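# E.g. (hypothetical numbers): with num_restarts=5 and provided ICs of shape
# 3 x q x d, required_num_restarts is 2, so only two fresh restarts are
# generated above and concatenated after the user-provided ones.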

batch_limit: int = options.get(
"batch_limit",
(
@@ -344,23 +410,24 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
first_warn_msg = (
"Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
"`batch_initial_conditions`, optimization will not be retried with "
"new initial conditions and will proceed with the current solution."
" Suggested remediation: Try again with different "
"`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
if initial_conditions_provided
"`batch_initial_conditions` larger than required `num_restarts`, "
"optimization will not be retried with new initial conditions and "
"will proceed with the current solution. Suggested remediation: "
"Try again with different `batch_initial_conditions`, don't provide "
"`batch_initial_conditions`, or increase `num_restarts`."
if batch_initial_conditions is not None and required_num_restarts <= 0
else "Optimization failed in `gen_candidates_scipy` with the following "
f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
"set of initial conditions."
)
warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)

if not initial_conditions_provided:
batch_initial_conditions = opt_inputs.get_ic_generator()(
if opt_inputs.raw_samples is not None and required_num_restarts > 0:
generated_initial_conditions = opt_inputs.get_ic_generator()(
acq_function=opt_inputs.acq_function,
bounds=opt_inputs.bounds,
q=opt_inputs.q,
num_restarts=opt_inputs.num_restarts,
num_restarts=required_num_restarts,
raw_samples=opt_inputs.raw_samples,
fixed_features=opt_inputs.fixed_features,
options=options,
@@ -369,6 +436,11 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
**opt_inputs.ic_gen_kwargs,
)

batch_initial_conditions = _combine_initial_conditions(
provided_initial_conditions=provided_initial_conditions,
generated_initial_conditions=generated_initial_conditions,
)
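# E.g. (hypothetical): on retry only the generated restarts are redrawn;
# the user-provided initial conditions are reused unchanged.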

batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()

optimization_warning_raised = any(
@@ -1177,7 +1249,7 @@ def _gen_batch_initial_conditions_local_search(
inequality_constraints: list[tuple[Tensor, Tensor, float]],
min_points: int,
max_tries: int = 100,
):
) -> Tensor:
"""Generate initial conditions for local search."""
device = discrete_choices[0].device
dtype = discrete_choices[0].dtype
@@ -1197,6 +1269,58 @@ def _gen_batch_initial_conditions_local_search(
raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")


def _gen_starting_points_local_search(
discrete_choices: list[Tensor],
raw_samples: int,
batch_initial_conditions: Tensor | None,
X_avoid: Tensor,
inequality_constraints: list[tuple[Tensor, Tensor, float]],
min_points: int,
acq_function: AcquisitionFunction,
max_batch_size: int = 2048,
max_tries: int = 100,
) -> Tensor:
required_min_points = min_points
provided_X0 = None
generated_X0 = None

if batch_initial_conditions is not None:
provided_X0 = _filter_invalid(
X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
)
provided_X0 = _filter_infeasible(
X=provided_X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
required_min_points -= batch_initial_conditions.shape[0]

if required_min_points > 0:
generated_X0 = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=min_points,
max_tries=max_tries,
)

# pick the best starting points
with torch.no_grad():
acqvals_init = _split_batch_eval_acqf(
acq_function=acq_function,
X=generated_X0.unsqueeze(1),
max_batch_size=max_batch_size,
).unsqueeze(-1)

generated_X0 = generated_X0[
acqvals_init.topk(k=min_points, largest=True, dim=0).indices
]

return _combine_initial_conditions(
    provided_initial_conditions=provided_X0,
    generated_initial_conditions=generated_X0,
)
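
A shape sketch (hypothetical tensors) of the select-best-then-combine step above:

import torch

X_init = torch.rand(10, 4)                   # 10 generated points, d=4
acqvals = torch.rand(10, 1)                  # mock acquisition values
X0 = X_init[acqvals.topk(k=5, largest=True, dim=0).indices]  # -> 5 x 1 x 4
provided = torch.rand(3, 1, 4)               # filtered user-provided points
combined = torch.cat([provided, X0], dim=0)  # -> 8 x 1 x 4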


def optimize_acqf_discrete_local_search(
acq_function: AcquisitionFunction,
discrete_choices: list[Tensor],
@@ -1207,6 +1331,7 @@ def optimize_acqf_discrete_local_search(
X_avoid: Tensor | None = None,
batch_initial_conditions: Tensor | None = None,
max_batch_size: int = 2048,
max_tries: int = 100,
unique: bool = True,
) -> tuple[Tensor, Tensor]:
r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1363,8 @@ def optimize_acqf_discrete_local_search(
max_batch_size: The maximum number of choices to evaluate in batch.
A large limit can cause excessive memory usage if the model has
a large training set.
max_tries: Maximum number of iterations to try when generating initial
conditions.
unique: If True, return unique choices; otherwise choices may be repeated
(only relevant if `q > 1`).

@@ -1247,6 +1374,16 @@ def optimize_acqf_discrete_local_search(
- a `q x d`-dim tensor of generated candidates.
- an associated acquisition value.
"""
if batch_initial_conditions is not None:
if not (
len(batch_initial_conditions.shape) == 3
and batch_initial_conditions.shape[-2] == 1
):
raise ValueError(
"batch_initial_conditions must have shape `n x 1 x d` if "
f"given (received shape {batch_initial_conditions.shape})."
)

candidate_list = []
base_X_pending = acq_function.X_pending if q > 1 else None
base_X_avoid = X_avoid
@@ -1259,27 +1396,18 @@ def optimize_acqf_discrete_local_search(
inequality_constraints = inequality_constraints or []
for i in range(q):
# generate some starting points
if i == 0 and batch_initial_conditions is not None:
X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
X0 = _filter_infeasible(
X=X0, inequality_constraints=inequality_constraints
).unsqueeze(1)
else:
X_init = _gen_batch_initial_conditions_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=num_restarts,
)
# pick the best starting points
with torch.no_grad():
acqvals_init = _split_batch_eval_acqf(
acq_function=acq_function,
X=X_init.unsqueeze(1),
max_batch_size=max_batch_size,
).unsqueeze(-1)
X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]
X0 = _gen_starting_points_local_search(
discrete_choices=discrete_choices,
raw_samples=raw_samples,
batch_initial_conditions=batch_initial_conditions,
X_avoid=X_avoid,
inequality_constraints=inequality_constraints,
min_points=num_restarts,
acq_function=acq_function,
max_batch_size=max_batch_size,
max_tries=max_tries,
)
batch_initial_conditions = None

# optimize from the best starting points
best_xs = torch.zeros(len(X0), dim, device=device, dtype=dtype)
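An end-to-end sketch of the new mixing behavior (hypothetical setup; assumes a botorch build that includes this change):

import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))
acqf = qLogExpectedImprovement(model, best_f=train_Y.max())

bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
# Two of the five restarts are provided; with this PR the remaining three
# are generated internally from raw_samples and concatenated.
candidate, value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=5,
    raw_samples=32,
    batch_initial_conditions=torch.rand(2, 1, 2, dtype=torch.double),
)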
8 changes: 7 additions & 1 deletion botorch/optim/optimize_homotopy.py
@@ -157,7 +157,6 @@ def optimize_acqf_homotopy(
"""
shared_optimize_acqf_kwargs = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
@@ -178,6 +177,7 @@

for _ in range(q):
candidates = batch_initial_conditions
q_raw_samples = raw_samples
homotopy.restart()

while not homotopy.should_stop:
@@ -187,10 +187,15 @@
q=1,
options=options,
batch_initial_conditions=candidates,
raw_samples=q_raw_samples,
**shared_optimize_acqf_kwargs,
)
homotopy.step()

# Set raw_samples to None such that pruned restarts are not repopulated
# at each step in the homotopy.
q_raw_samples = None
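# E.g. (hypothetical): if pruning keeps 5 of 8 candidates after the first
# step, later steps optimize exactly those 5 instead of topping back up to 8.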

# Prune candidates
candidates = prune_candidates(
candidates=candidates.squeeze(1),
@@ -204,6 +209,7 @@
bounds=bounds,
q=1,
options=final_options,
raw_samples=q_raw_samples,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)