Skip to content

Commit c1e6d32

Browse files
committed
feat: add the option to provide both some initial batch conditions and request some raw samples
1 parent 92d73e4 commit c1e6d32

File tree

6 files changed

+337
-64
lines changed

6 files changed

+337
-64
lines changed

botorch/optim/optimize.py

Lines changed: 180 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,31 @@ def __post_init__(self) -> None:
109109
"3-dimensional. Its shape is "
110110
f"{batch_initial_conditions_shape}."
111111
)
112+
112113
if batch_initial_conditions_shape[-1] != d:
113114
raise ValueError(
114115
f"batch_initial_conditions.shape[-1] must be {d}. The "
115116
f"shape is {batch_initial_conditions_shape}."
116117
)
117118

119+
if (
120+
self.raw_samples is not None
121+
and (self.raw_samples - batch_initial_conditions_shape[-2]) > 0
122+
and len(batch_initial_conditions_shape) == 3
123+
and self.num_restarts is not None
124+
and batch_initial_conditions_shape[0] not in [1, self.num_restarts]
125+
):
126+
warnings.warn(
127+
"If using `batch_initial_conditions` together with `raw_samples`, "
128+
"the first repeat dimension of `batch_initial_conditions` must "
129+
"match `num_restarts`. In the future this will raise an error. "
130+
"Defaulting to old behavior of ignoring `raw_samples` by setting "
131+
"it to None.",
132+
DeprecationWarning,
133+
)
134+
# Use object.__setattr__ to bypass immutability and set a value
135+
object.__setattr__(self, "raw_samples", None)
136+
118137
elif self.ic_generator is None:
119138
if self.nonlinear_inequality_constraints is not None:
120139
raise RuntimeError(
@@ -253,27 +272,73 @@ def _optimize_acqf_sequential_q(
253272
return candidates, torch.stack(acq_value_list)
254273

255274

275+
def _combine_initial_conditions(
276+
provided_initial_conditions: Tensor | None = None,
277+
generated_initial_conditions: Tensor | None = None,
278+
num_restarts: int | None = None,
279+
) -> Tensor:
280+
281+
if (
282+
provided_initial_conditions is not None
283+
and generated_initial_conditions is not None
284+
):
285+
if ( # Repeat the provided initial conditions to match the number of restarts
286+
provided_initial_conditions.shape[0] == 1
287+
and num_restarts is not None
288+
and num_restarts > 1
289+
):
290+
provided_initial_conditions = provided_initial_conditions.repeat(
291+
num_restarts, *([1] * (provided_initial_conditions.dim() - 1))
292+
)
293+
initial_conditions = torch.cat(
294+
[provided_initial_conditions, generated_initial_conditions], dim=-2
295+
)
296+
perm = torch.randperm(
297+
initial_conditions.shape[-2], device=initial_conditions.device
298+
)
299+
return initial_conditions.gather(
300+
-2, perm.unsqueeze(-1).expand_as(initial_conditions)
301+
)
302+
elif provided_initial_conditions is not None:
303+
return provided_initial_conditions
304+
elif generated_initial_conditions is not None:
305+
return generated_initial_conditions
306+
else:
307+
raise ValueError(
308+
"Either `batch_initial_conditions` or `raw_samples` must be set."
309+
)
310+
311+
256312
def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
257313
options = opt_inputs.options or {}
258314

259-
initial_conditions_provided = opt_inputs.batch_initial_conditions is not None
315+
required_raw_samples = opt_inputs.raw_samples
316+
generated_initial_conditions = None
260317

261-
if initial_conditions_provided:
262-
batch_initial_conditions = opt_inputs.batch_initial_conditions
263-
else:
264-
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
265-
batch_initial_conditions = opt_inputs.get_ic_generator()(
266-
acq_function=opt_inputs.acq_function,
267-
bounds=opt_inputs.bounds,
268-
q=opt_inputs.q,
269-
num_restarts=opt_inputs.num_restarts,
270-
raw_samples=opt_inputs.raw_samples,
271-
fixed_features=opt_inputs.fixed_features,
272-
options=options,
273-
inequality_constraints=opt_inputs.inequality_constraints,
274-
equality_constraints=opt_inputs.equality_constraints,
275-
**opt_inputs.ic_gen_kwargs,
276-
)
318+
if required_raw_samples is not None:
319+
if opt_inputs.batch_initial_conditions is not None:
320+
required_raw_samples -= opt_inputs.batch_initial_conditions.shape[-2]
321+
322+
if required_raw_samples > 0:
323+
# pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
324+
generated_initial_conditions = opt_inputs.get_ic_generator()(
325+
acq_function=opt_inputs.acq_function,
326+
bounds=opt_inputs.bounds,
327+
q=opt_inputs.q,
328+
num_restarts=opt_inputs.num_restarts,
329+
raw_samples=required_raw_samples,
330+
fixed_features=opt_inputs.fixed_features,
331+
options=options,
332+
inequality_constraints=opt_inputs.inequality_constraints,
333+
equality_constraints=opt_inputs.equality_constraints,
334+
**opt_inputs.ic_gen_kwargs,
335+
)
336+
337+
batch_initial_conditions = _combine_initial_conditions(
338+
provided_initial_conditions=opt_inputs.batch_initial_conditions,
339+
generated_initial_conditions=generated_initial_conditions,
340+
num_restarts=opt_inputs.num_restarts,
341+
)
277342

278343
batch_limit: int = options.get(
279344
"batch_limit",
@@ -344,31 +409,38 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
344409
first_warn_msg = (
345410
"Optimization failed in `gen_candidates_scipy` with the following "
346411
f"warning(s):\n{[w.message for w in ws]}\nBecause you specified "
347-
"`batch_initial_conditions`, optimization will not be retried with "
348-
"new initial conditions and will proceed with the current solution."
349-
" Suggested remediation: Try again with different "
350-
"`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
351-
if initial_conditions_provided
412+
"`batch_initial_conditions`>`raw_samples`, optimization will not "
413+
"be retried with new initial conditions and will proceed with the "
414+
"current solution. Suggested remediation: Try again with different "
415+
"`batch_initial_conditions`, don't provide `batch_initial_conditions`, "
416+
"or increase `raw_samples`.`"
417+
if required_raw_samples is not None and required_raw_samples <= 0
352418
else "Optimization failed in `gen_candidates_scipy` with the following "
353419
f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new "
354420
"set of initial conditions."
355421
)
356422
warnings.warn(first_warn_msg, RuntimeWarning, stacklevel=2)
357423

358-
if not initial_conditions_provided:
359-
batch_initial_conditions = opt_inputs.get_ic_generator()(
424+
if required_raw_samples is not None and required_raw_samples > 0:
425+
generated_initial_conditions = opt_inputs.get_ic_generator()(
360426
acq_function=opt_inputs.acq_function,
361427
bounds=opt_inputs.bounds,
362428
q=opt_inputs.q,
363429
num_restarts=opt_inputs.num_restarts,
364-
raw_samples=opt_inputs.raw_samples,
430+
raw_samples=required_raw_samples,
365431
fixed_features=opt_inputs.fixed_features,
366432
options=options,
367433
inequality_constraints=opt_inputs.inequality_constraints,
368434
equality_constraints=opt_inputs.equality_constraints,
369435
**opt_inputs.ic_gen_kwargs,
370436
)
371437

438+
batch_initial_conditions = _combine_initial_conditions(
439+
provided_initial_conditions=opt_inputs.batch_initial_conditions,
440+
generated_initial_conditions=generated_initial_conditions,
441+
num_restarts=opt_inputs.num_restarts,
442+
)
443+
372444
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
373445

374446
optimization_warning_raised = any(
@@ -1177,7 +1249,7 @@ def _gen_batch_initial_conditions_local_search(
11771249
inequality_constraints: list[tuple[Tensor, Tensor, float]],
11781250
min_points: int,
11791251
max_tries: int = 100,
1180-
):
1252+
) -> Tensor:
11811253
"""Generate initial conditions for local search."""
11821254
device = discrete_choices[0].device
11831255
dtype = discrete_choices[0].dtype
@@ -1197,6 +1269,66 @@ def _gen_batch_initial_conditions_local_search(
11971269
raise RuntimeError(f"Failed to generate at least {min_points} initial conditions")
11981270

11991271

1272+
def _gen_starting_points_local_search(
    discrete_choices: list[Tensor],
    raw_samples: int,
    batch_initial_conditions: Tensor,
    X_avoid: Tensor,
    inequality_constraints: list[tuple[Tensor, Tensor, float]],
    min_points: int,
    acq_function: AcquisitionFunction,
    max_batch_size: int = 2048,
    max_tries: int = 100,
) -> Tensor:
    r"""Assemble starting points for discrete local search.

    Filters any user-provided `batch_initial_conditions` against `X_avoid`
    and the inequality constraints, and — if those do not cover `min_points`
    starting points — generates additional candidates and keeps the
    `min_points` of them with the highest acquisition values.

    Args:
        discrete_choices: Per-dimension lists of possible discrete values.
        raw_samples: Number of raw samples for the candidate generator.
        batch_initial_conditions: Optional user-provided points (or None).
        X_avoid: Points that generated/provided candidates must avoid.
        inequality_constraints: Constraints of the form `sum_i (X[idx_i]
            * coeff_i) >= rhs` as `(indices, coefficients, rhs)` tuples.
        min_points: Minimum number of starting points to produce.
        acq_function: Acquisition function used to rank generated candidates.
        max_batch_size: Maximum batch size for acquisition evaluation.
        max_tries: Maximum number of generation attempts.

    Returns:
        A tensor of starting points for local search.

    Raises:
        ValueError: If neither provided nor generated points are available.
    """
    user_X0 = None
    sampled_X0 = None
    points_still_needed = min_points

    if batch_initial_conditions is not None:
        # Drop user points that collide with `X_avoid` or violate constraints.
        valid = _filter_invalid(
            X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
        )
        user_X0 = _filter_infeasible(
            X=valid, inequality_constraints=inequality_constraints
        ).unsqueeze(1)
        # NOTE(review): the deficit is computed from the *unfiltered* count —
        # presumably intentional; confirm whether filtering should reduce it.
        points_still_needed -= batch_initial_conditions.shape[0]

    if points_still_needed > 0:
        sampled_X0 = _gen_batch_initial_conditions_local_search(
            discrete_choices=discrete_choices,
            raw_samples=raw_samples,
            X_avoid=X_avoid,
            inequality_constraints=inequality_constraints,
            min_points=min_points,
            max_tries=max_tries,
        )
        # Rank the generated candidates and keep the `min_points` best ones.
        with torch.no_grad():
            acq_vals = _split_batch_eval_acqf(
                acq_function=acq_function,
                X=sampled_X0.unsqueeze(1),
                max_batch_size=max_batch_size,
            ).unsqueeze(-1)
        best = acq_vals.topk(k=min_points, largest=True, dim=0).indices
        sampled_X0 = sampled_X0[best]

    if user_X0 is not None and sampled_X0 is not None:
        return torch.cat([user_X0, sampled_X0], dim=0)
    if user_X0 is not None:
        return user_X0
    if sampled_X0 is not None:
        return sampled_X0
    raise ValueError(
        "Either `batch_initial_conditions` or `raw_samples` must be set."
    )
1330+
1331+
12001332
def optimize_acqf_discrete_local_search(
12011333
acq_function: AcquisitionFunction,
12021334
discrete_choices: list[Tensor],
@@ -1207,6 +1339,7 @@ def optimize_acqf_discrete_local_search(
12071339
X_avoid: Tensor | None = None,
12081340
batch_initial_conditions: Tensor | None = None,
12091341
max_batch_size: int = 2048,
1342+
max_tries: int = 100,
12101343
unique: bool = True,
12111344
) -> tuple[Tensor, Tensor]:
12121345
r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1371,8 @@ def optimize_acqf_discrete_local_search(
12381371
max_batch_size: The maximum number of choices to evaluate in batch.
12391372
A large limit can cause excessive memory usage if the model has
12401373
a large training set.
1374+
max_tries: Maximum number of iterations to try when generating initial
1375+
conditions.
12411376
unique: If True return unique choices, o/w choices may be repeated
12421377
(only relevant if `q > 1`).
12431378
@@ -1247,6 +1382,13 @@ def optimize_acqf_discrete_local_search(
12471382
- a `q x d`-dim tensor of generated candidates.
12481383
- an associated acquisition value.
12491384
"""
1385+
if batch_initial_conditions is not None:
1386+
if not (
1387+
len(batch_initial_conditions.shape) == 3
1388+
and batch_initial_conditions.shape[-2] == 1
1389+
):
1390+
raise ValueError("batch_initial_conditions must have shape `n x 1 x d` if given.")
1391+
12501392
candidate_list = []
12511393
base_X_pending = acq_function.X_pending if q > 1 else None
12521394
base_X_avoid = X_avoid
@@ -1259,27 +1401,18 @@ def optimize_acqf_discrete_local_search(
12591401
inequality_constraints = inequality_constraints or []
12601402
for i in range(q):
12611403
# generate some starting points
1262-
if i == 0 and batch_initial_conditions is not None:
1263-
X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
1264-
X0 = _filter_infeasible(
1265-
X=X0, inequality_constraints=inequality_constraints
1266-
).unsqueeze(1)
1267-
else:
1268-
X_init = _gen_batch_initial_conditions_local_search(
1269-
discrete_choices=discrete_choices,
1270-
raw_samples=raw_samples,
1271-
X_avoid=X_avoid,
1272-
inequality_constraints=inequality_constraints,
1273-
min_points=num_restarts,
1274-
)
1275-
# pick the best starting points
1276-
with torch.no_grad():
1277-
acqvals_init = _split_batch_eval_acqf(
1278-
acq_function=acq_function,
1279-
X=X_init.unsqueeze(1),
1280-
max_batch_size=max_batch_size,
1281-
).unsqueeze(-1)
1282-
X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices]
1404+
X0 = _gen_starting_points_local_search(
1405+
discrete_choices=discrete_choices,
1406+
raw_samples=raw_samples,
1407+
batch_initial_conditions=batch_initial_conditions,
1408+
X_avoid=X_avoid,
1409+
inequality_constraints=inequality_constraints,
1410+
min_points=num_restarts,
1411+
acq_function=acq_function,
1412+
max_batch_size=max_batch_size,
1413+
max_tries=max_tries,
1414+
)
1415+
batch_initial_conditions = None
12831416

12841417
# optimize from the best starting points
12851418
best_xs = torch.zeros(len(X0), dim, device=device, dtype=dtype)

botorch/optim/optimize_homotopy.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def optimize_acqf_homotopy(
157157
"""
158158
shared_optimize_acqf_kwargs = {
159159
"num_restarts": num_restarts,
160-
"raw_samples": raw_samples,
161160
"inequality_constraints": inequality_constraints,
162161
"equality_constraints": equality_constraints,
163162
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
@@ -181,11 +180,14 @@ def optimize_acqf_homotopy(
181180
homotopy.restart()
182181

183182
while not homotopy.should_stop:
183+
# After the first iteration we don't want to generate new raw samples
184+
requested_raw_samples = raw_samples if candidates is None else None
184185
candidates, acq_values = optimize_acqf(
185186
acq_function=acq_function,
186187
bounds=bounds,
187188
q=1,
188189
options=options,
190+
raw_samples=requested_raw_samples,
189191
batch_initial_conditions=candidates,
190192
**shared_optimize_acqf_kwargs,
191193
)
@@ -204,6 +206,7 @@ def optimize_acqf_homotopy(
204206
bounds=bounds,
205207
q=1,
206208
options=final_options,
209+
raw_samples=None,
207210
batch_initial_conditions=candidates,
208211
**shared_optimize_acqf_kwargs,
209212
)

botorch/optim/optimize_mixed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ def continuous_step(
496496
updated_opt_inputs = dataclasses.replace(
497497
opt_inputs,
498498
q=1,
499+
raw_samples=None,
499500
num_restarts=1,
500501
batch_initial_conditions=current_x.unsqueeze(0),
501502
fixed_features={

botorch/posteriors/posterior.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from __future__ import annotations
1212

13-
from abc import ABC, abstractmethod, abstractproperty
13+
from abc import ABC, abstractmethod
1414

1515
import torch
1616
from torch import Tensor
@@ -77,12 +77,14 @@ def sample(self, sample_shape: torch.Size | None = None) -> Tensor:
7777
with torch.no_grad():
7878
return self.rsample(sample_shape=sample_shape)
7979

80-
@abstractproperty
80+
@property
81+
@abstractmethod
8182
def device(self) -> torch.device:
8283
r"""The torch device of the distribution."""
8384
pass # pragma: no cover
8485

85-
@abstractproperty
86+
@property
87+
@abstractmethod
8688
def dtype(self) -> torch.dtype:
8789
r"""The torch dtype of the distribution."""
8890
pass # pragma: no cover

0 commit comments

Comments
 (0)