Skip to content

Commit 003d63f

Browse files
saitcakmak authored and facebook-github-bot committed
Add support for continuous relaxation within optimize_acqf_mixed_alternating
Summary: `optimize_acqf_mixed_alternating` utilizes local search to optimize discrete dimensions. This works well when there are a small number of values for the discrete dimensions but it does not scale well as the number of values increases. To address this, we have been transforming the high-cardinality dimensions in Ax and only passing in the low-cardinality dimensions as part of `discrete_dims`. This diff adds support for using continuous relaxation for discrete dimensions that have more than `max_discrete_values` (configurable via `options`). Differential Revision: D66239005
1 parent de46059 commit 003d63f

File tree

2 files changed

+123
-0
lines changed

2 files changed

+123
-0
lines changed

botorch/optim/optimize_mixed.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
MAX_ITER_ALTER = 64 # Maximum number of alternating iterations.
4040
MAX_ITER_DISCRETE = 4 # Maximum number of discrete iterations.
4141
MAX_ITER_CONT = 8 # Maximum number of continuous iterations.
42+
# Maximum number of discrete values for a discrete dimension.
43+
# If there are more values for a dimension, we will use continuous
44+
# relaxation to optimize it.
45+
MAX_DISCRETE_VALUES = 20
4246
# Maximum number of iterations for optimizing the continuous relaxation
4347
# during initialization
4448
MAX_ITER_INIT = 100
@@ -52,6 +56,7 @@
5256
"maxiter_discrete",
5357
"maxiter_continuous",
5458
"maxiter_init",
59+
"max_discrete_values",
5560
"num_spray_points",
5661
"std_cont_perturbation",
5762
"batch_limit",
@@ -60,6 +65,38 @@
6065
SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}
6166

6267

68+
def _setup_continuous_relaxation(
69+
discrete_dims: list[int],
70+
bounds: Tensor,
71+
max_discrete_values: int,
72+
post_processing_func: Callable[[Tensor], Tensor] | None,
73+
) -> tuple[list[int], Callable[[Tensor], Tensor] | None]:
74+
r"""Update `discrete_dims` and `post_processing_func` to use
75+
continuous relaxation for discrete dimensions that have more than
76+
`max_discrete_values` values. These dimensions are removed from
77+
`discrete_dims` and `post_processing_func` is updated to round
78+
them to the nearest integer.
79+
"""
80+
discrete_dims_t = torch.tensor(discrete_dims, dtype=torch.long)
81+
num_discrete_values = (bounds[1, discrete_dims] - bounds[0, discrete_dims]).cpu()
82+
dims_to_relax = discrete_dims_t[num_discrete_values > max_discrete_values]
83+
if dims_to_relax.numel() == 0:
84+
# No dimension needs continuous relaxation.
85+
return discrete_dims, post_processing_func
86+
# Remove relaxed dims from `discrete_dims`.
87+
discrete_dims = list(set(discrete_dims).difference(dims_to_relax.tolist()))
88+
89+
def new_post_processing_func(X: Tensor) -> Tensor:
90+
r"""Round the relaxed dimensions to the nearest integer and apply the original
91+
`post_processing_func`."""
92+
X[:, dims_to_relax] = X[:, dims_to_relax].round()
93+
if post_processing_func is not None:
94+
X = post_processing_func(X)
95+
return X
96+
97+
return discrete_dims, new_post_processing_func
98+
99+
63100
def _filter_infeasible(
64101
X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]] | None
65102
) -> Tensor:
@@ -532,6 +569,9 @@ def optimize_acqf_mixed_alternating(
532569
iterations.
533570
534571
NOTE: This method assumes that all discrete variables are integer valued.
572+
The discrete dimensions that have more than
573+
`options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will
574+
be optimized using continuous relaxation.
535575
536576
# TODO: Support categorical variables.
537577
@@ -549,6 +589,9 @@ def optimize_acqf_mixed_alternating(
549589
Defaults to 4.
550590
- "maxiter_continuous": Maximum number of iterations in each continuous step.
551591
Defaults to 8.
592+
- "max_discrete_values": Maximum number of values for a discrete dimension
593+
to be optimized using discrete step / local search. The discrete dimensions
594+
with more values will be optimized using continuous relaxation.
552595
- "num_spray_points": Number of spray points (around `X_baseline`) to add to
553596
the points generated by the initialization strategy. Defaults to 20 if
554597
all discrete variables are binary and to 0 otherwise.
@@ -598,6 +641,17 @@ def optimize_acqf_mixed_alternating(
598641
f"Received an unsupported option {unsupported_keys}. {SUPPORTED_OPTIONS=}."
599642
)
600643

644+
# Update discrete dims and post processing functions to account for any
645+
# dimensions that should be using continuous relaxation.
646+
discrete_dims, post_processing_func = _setup_continuous_relaxation(
647+
discrete_dims=discrete_dims,
648+
bounds=bounds,
649+
max_discrete_values=assert_is_instance(
650+
options.get("max_discrete_values", MAX_DISCRETE_VALUES), int
651+
),
652+
post_processing_func=post_processing_func,
653+
)
654+
601655
opt_inputs = OptimizeAcqfInputs(
602656
acq_function=acq_function,
603657
bounds=bounds,

test/optim/test_optimize_mixed.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
from botorch.models.gp_regression import SingleTaskGP
2020
from botorch.optim.optimize import _optimize_acqf, OptimizeAcqfInputs
2121
from botorch.optim.optimize_mixed import (
22+
_setup_continuous_relaxation,
2223
complement_indices,
2324
continuous_step,
2425
discrete_step,
2526
generate_starting_points,
2627
get_nearest_neighbors,
2728
get_spray_points,
29+
MAX_DISCRETE_VALUES,
2830
optimize_acqf_mixed_alternating,
2931
sample_feasible_points,
3032
)
@@ -720,3 +722,70 @@ def test_optimize_acqf_mixed_integer(self) -> None:
720722
wrapped_sample_feasible.assert_called_once()
721723
# Should request 4 candidates, since all 4 are infeasible.
722724
self.assertEqual(wrapped_sample_feasible.call_args.kwargs["num_points"], 4)
725+
726+
def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
    """Check that integer dims with many values are optimized via
    continuous relaxation and rounded in post-processing."""
    # Extend the binary fixture with two integer dimensions.
    train_X, train_Y, binary_dims, cont_dims = self._get_data()
    dim = len(binary_dims) + len(cont_dims)
    binary_dims = [0]
    integer_dims = [3, 4]
    discrete_dims = binary_dims + integer_dims
    bounds = self.single_bound.repeat(1, dim)
    bounds[1, 3] = 40.0
    bounds[1, 4] = 15.0
    # Deterministic quadratic model so the acquisition has a clear optimizer.
    root = torch.tensor([0.0, 0.0, 0.0, 25.0, 10.0], device=self.device)
    model = QuadraticDeterministicModel(root)
    acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)

    cases = ((None, None), (5, lambda X: X + 10))
    for max_discrete_values, post_processing_func in cases:
        options = {
            "batch_limit": 5,
            "init_batch_limit": 20,
            "maxiter_alternating": 1,
        }
        if max_discrete_values is not None:
            options["max_discrete_values"] = max_discrete_values
        with mock.patch(
            f"{OPT_MODULE}._setup_continuous_relaxation",
            wraps=_setup_continuous_relaxation,
        ) as wrapped_setup, mock.patch(
            f"{OPT_MODULE}.discrete_step", wraps=discrete_step
        ) as wrapped_discrete:
            candidates, _ = optimize_acqf_mixed_alternating(
                acq_function=acqf,
                bounds=bounds,
                discrete_dims=discrete_dims,
                q=3,
                raw_samples=32,
                num_restarts=4,
                options=options,
                post_processing_func=post_processing_func,
            )
        # The relaxation setup must receive the (possibly defaulted) threshold.
        wrapped_setup.assert_called_once_with(
            discrete_dims=discrete_dims,
            bounds=bounds,
            max_discrete_values=max_discrete_values or MAX_DISCRETE_VALUES,
            post_processing_func=post_processing_func,
        )
        discrete_call_args = wrapped_discrete.call_args.kwargs
        # Dim 3 (range 40) is always relaxed; dim 4 (range 15) only when
        # the threshold is lowered to 5.
        expected_dims = [0, 4] if max_discrete_values is None else [0]
        self.assertAllClose(
            discrete_call_args["discrete_dims"],
            torch.tensor(expected_dims, device=self.device),
        )
        # The updated post-processing func rounds the relaxed dims.
        probe = torch.ones(1, 5, device=self.device) * 0.6
        expected = probe.clone()
        expected[0, 3] = 1.0
        if max_discrete_values is not None:
            expected[0, 4] = 1.0
        if post_processing_func is not None:
            expected = post_processing_func(expected)
        self.assertAllClose(
            discrete_call_args["opt_inputs"].post_processing_func(probe), expected
        )

0 commit comments

Comments
 (0)