
Commit 5d37606

saitcakmak authored and facebook-github-bot committed
Add support for continuous relaxation within optimize_acqf_mixed_alternating (#2635)
Summary:
Pull Request resolved: #2635

`optimize_acqf_mixed_alternating` utilizes local search to optimize discrete dimensions. This works well when the discrete dimensions take only a small number of values, but it does not scale well as the number of values increases. To address this, we have been transforming the high-cardinality dimensions in Ax and only passing in the low-cardinality dimensions as part of `discrete_dims`.

This diff adds support for using continuous relaxation for discrete dimensions that have more than `max_discrete_values` values (a threshold configurable via `options`). It also updates the optimizer to fall back to `optimize_acqf` if there are no discrete dimensions left. This is more user friendly than erroring out (particularly when used through Ax).

Reviewed By: Balandat

Differential Revision: D66239005

fbshipit-source-id: 0878115eb08ea75acb34ad8e891cf88393d4e36c
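For orientation, here is a minimal sketch (not part of the commit) of how a caller might exercise the new behavior. The toy model, data, and bounds are invented for illustration; the `optimize_acqf_mixed_alternating` arguments and the `max_discrete_values` option are the ones this diff touches:

```python
import torch

from botorch.acquisition import qLogNoisyExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating

# Toy problem: dims 0-1 continuous, dim 2 binary, dim 3 integer in [0, 100].
train_X = torch.rand(10, 4, dtype=torch.double)
train_X[:, 2] = train_X[:, 2].round()
train_X[:, 3] = (train_X[:, 3] * 100).round()
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)

bounds = torch.tensor(
    [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 100.0]], dtype=torch.double
)
# Dim 3 spans more than max_discrete_values values, so it is optimized via
# continuous relaxation and rounded by the updated post-processing func;
# dim 2 is still handled by discrete local search.
candidates, acqf_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims=[2, 3],
    q=1,
    num_restarts=2,
    raw_samples=32,
    options={"max_discrete_values": 20},
)
# With this commit, passing discrete_dims=[] no longer raises a ValueError;
# the call falls back to optimize_acqf internally.
```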
1 parent de46059 commit 5d37606

File tree

3 files changed: +143 -8

botorch/optim/optimize_mixed.py

Lines changed: 57 additions & 1 deletion
```diff
@@ -39,6 +39,10 @@
 MAX_ITER_ALTER = 64  # Maximum number of alternating iterations.
 MAX_ITER_DISCRETE = 4  # Maximum number of discrete iterations.
 MAX_ITER_CONT = 8  # Maximum number of continuous iterations.
+# Maximum number of discrete values for a discrete dimension.
+# If there are more values for a dimension, we will use continuous
+# relaxation to optimize it.
+MAX_DISCRETE_VALUES = 20
 # Maximum number of iterations for optimizing the continuous relaxation
 # during initialization
 MAX_ITER_INIT = 100
@@ -52,6 +56,7 @@
     "maxiter_discrete",
     "maxiter_continuous",
     "maxiter_init",
+    "max_discrete_values",
     "num_spray_points",
     "std_cont_perturbation",
     "batch_limit",
@@ -60,6 +65,40 @@
 SUPPORTED_INITIALIZATION = {"continuous_relaxation", "equally_spaced", "random"}
 
 
+def _setup_continuous_relaxation(
+    discrete_dims: list[int],
+    bounds: Tensor,
+    max_discrete_values: int,
+    post_processing_func: Callable[[Tensor], Tensor] | None,
+) -> tuple[list[int], Callable[[Tensor], Tensor] | None]:
+    r"""Update `discrete_dims` and `post_processing_func` to use
+    continuous relaxation for discrete dimensions that have more than
+    `max_discrete_values` values. These dimensions are removed from
+    `discrete_dims` and `post_processing_func` is updated to round
+    them to the nearest integer.
+    """
+    discrete_dims_t = torch.tensor(discrete_dims, dtype=torch.long)
+    num_discrete_values = (
+        bounds[1, discrete_dims_t] - bounds[0, discrete_dims_t]
+    ).cpu()
+    dims_to_relax = discrete_dims_t[num_discrete_values > max_discrete_values]
+    if dims_to_relax.numel() == 0:
+        # No dimension needs continuous relaxation.
+        return discrete_dims, post_processing_func
+    # Remove relaxed dims from `discrete_dims`.
+    discrete_dims = list(set(discrete_dims).difference(dims_to_relax.tolist()))
+
+    def new_post_processing_func(X: Tensor) -> Tensor:
+        r"""Round the relaxed dimensions to the nearest integer and apply
+        the original `post_processing_func`."""
+        X[..., dims_to_relax] = X[..., dims_to_relax].round()
+        if post_processing_func is not None:
+            X = post_processing_func(X)
+        return X
+
+    return discrete_dims, new_post_processing_func
+
+
 def _filter_infeasible(
     X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]] | None
 ) -> Tensor:
@@ -532,6 +571,9 @@ def optimize_acqf_mixed_alternating(
         iterations.
 
     NOTE: This method assumes that all discrete variables are integer valued.
+    The discrete dimensions that have more than
+    `options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will
+    be optimized using continuous relaxation.
 
     # TODO: Support categorical variables.
 
@@ -549,6 +591,9 @@ def optimize_acqf_mixed_alternating(
             Defaults to 4.
         - "maxiter_continuous": Maximum number of iterations in each continuous step.
             Defaults to 8.
+        - "max_discrete_values": Maximum number of values for a discrete dimension
+            to be optimized using discrete step / local search. The discrete dimensions
+            with more values will be optimized using continuous relaxation.
         - "num_spray_points": Number of spray points (around `X_baseline`) to add to
             the points generated by the initialization strategy. Defaults to 20 if
             all discrete variables are binary and to 0 otherwise.
@@ -598,6 +643,17 @@ def optimize_acqf_mixed_alternating(
             f"Received an unsupported option {unsupported_keys}. {SUPPORTED_OPTIONS=}."
         )
 
+    # Update discrete dims and post processing functions to account for any
+    # dimensions that should be using continuous relaxation.
+    discrete_dims, post_processing_func = _setup_continuous_relaxation(
+        discrete_dims=discrete_dims,
+        bounds=bounds,
+        max_discrete_values=assert_is_instance(
+            options.get("max_discrete_values", MAX_DISCRETE_VALUES), int
+        ),
+        post_processing_func=post_processing_func,
+    )
+
     opt_inputs = OptimizeAcqfInputs(
         acq_function=acq_function,
         bounds=bounds,
@@ -623,7 +679,7 @@ def optimize_acqf_mixed_alternating(
     # Remove fixed features from dims, so they don't get optimized.
     discrete_dims = [dim for dim in discrete_dims if dim not in fixed_features]
     if len(discrete_dims) == 0:
-        raise ValueError("There must be at least one discrete parameter.")
+        return _optimize_acqf(opt_inputs=opt_inputs)
     if not (
         isinstance(discrete_dims, list)
        and len(set(discrete_dims)) == len(discrete_dims)
```
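As a quick illustration of the new helper in isolation (again not part of the diff, and relying on the private `_setup_continuous_relaxation` defined above, with made-up bounds):

```python
import torch

from botorch.optim.optimize_mixed import _setup_continuous_relaxation

# Discrete dims: dim 1 spans [0, 3] (4 values), dim 2 spans [0, 50]
# (51 values), dim 3 spans [0, 5] (6 values).
bounds = torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 3.0, 50.0, 5.0]])
discrete_dims, post_fn = _setup_continuous_relaxation(
    discrete_dims=[1, 2, 3],
    bounds=bounds,
    max_discrete_values=20,
    post_processing_func=None,
)
print(sorted(discrete_dims))  # [1, 3]; dim 2 is relaxed to continuous.
X = torch.tensor([[0.4, 1.2, 17.6, 3.0]])
print(post_fn(X))  # Dim 2 is rounded to 18.0; the other dims are untouched.
```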

test/optim/test_optimize_mixed.py

Lines changed: 85 additions & 6 deletions
```diff
@@ -19,12 +19,14 @@
 from botorch.models.gp_regression import SingleTaskGP
 from botorch.optim.optimize import _optimize_acqf, OptimizeAcqfInputs
 from botorch.optim.optimize_mixed import (
+    _setup_continuous_relaxation,
     complement_indices,
     continuous_step,
     discrete_step,
     generate_starting_points,
     get_nearest_neighbors,
     get_spray_points,
+    MAX_DISCRETE_VALUES,
     optimize_acqf_mixed_alternating,
     sample_feasible_points,
 )
@@ -544,20 +546,29 @@ def test_optimize_acqf_mixed_binary_only(self) -> None:
         self.assertEqual(candidates.shape[-1], dim)
         c_binary = candidates[:, binary_dims + [2]]
         self.assertTrue(((c_binary == 0) | (c_binary == 1)).all())
-        # Only continuous parameters will raise an error.
-        with self.assertRaisesRegex(
-            ValueError,
-            "There must be at least one discrete parameter",
-        ):
+        # Only continuous parameters should fallback to optimize_acqf.
+        with mock.patch(
+            f"{OPT_MODULE}._optimize_acqf", wraps=_optimize_acqf
+        ) as wrapped_optimize:
             optimize_acqf_mixed_alternating(
                 acq_function=acqf,
                 bounds=bounds,
                 discrete_dims=[],
                 options=options,
                 q=1,
                 raw_samples=20,
-                num_restarts=20,
+                num_restarts=2,
+            )
+        wrapped_optimize.assert_called_once_with(
+            opt_inputs=_make_opt_inputs(
+                acq_function=acqf,
+                bounds=bounds,
+                options=options,
+                q=1,
+                raw_samples=20,
+                num_restarts=2,
             )
+        )
         # Only discrete works fine.
         candidates, _ = optimize_acqf_mixed_alternating(
             acq_function=acqf,
@@ -720,3 +731,71 @@ def test_optimize_acqf_mixed_integer(self) -> None:
         wrapped_sample_feasible.assert_called_once()
         # Should request 4 candidates, since all 4 are infeasible.
         self.assertEqual(wrapped_sample_feasible.call_args.kwargs["num_points"], 4)
+
+    def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
+        # Testing with integer variables.
+        train_X, train_Y, binary_dims, cont_dims = self._get_data()
+        # Update the data to introduce integer dimensions.
+        binary_dims = [0]
+        integer_dims = [3, 4]
+        discrete_dims = binary_dims + integer_dims
+        bounds = torch.tensor(
+            [[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 40.0, 15.0]],
+            dtype=torch.double,
+            device=self.device,
+        )
+        # Update the model to have a different optimizer.
+        root = torch.tensor([0.0, 0.0, 0.0, 25.0, 10.0], device=self.device)
+        model = QuadraticDeterministicModel(root)
+        acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)
+
+        for max_discrete_values, post_processing_func in (
+            (None, None),
+            (5, lambda X: X + 10),
+        ):
+            options = {
+                "batch_limit": 5,
+                "init_batch_limit": 20,
+                "maxiter_alternating": 1,
+            }
+            if max_discrete_values is not None:
+                options["max_discrete_values"] = max_discrete_values
+            with mock.patch(
+                f"{OPT_MODULE}._setup_continuous_relaxation",
+                wraps=_setup_continuous_relaxation,
+            ) as wrapped_setup, mock.patch(
+                f"{OPT_MODULE}.discrete_step", wraps=discrete_step
+            ) as wrapped_discrete:
+                candidates, _ = optimize_acqf_mixed_alternating(
+                    acq_function=acqf,
+                    bounds=bounds,
+                    discrete_dims=discrete_dims,
+                    q=3,
+                    raw_samples=32,
+                    num_restarts=4,
+                    options=options,
+                    post_processing_func=post_processing_func,
+                )
+            wrapped_setup.assert_called_once_with(
+                discrete_dims=discrete_dims,
+                bounds=bounds,
+                max_discrete_values=max_discrete_values or MAX_DISCRETE_VALUES,
+                post_processing_func=post_processing_func,
+            )
+            discrete_call_args = wrapped_discrete.call_args.kwargs
+            expected_dims = [0, 4] if max_discrete_values is None else [0]
+            self.assertAllClose(
+                discrete_call_args["discrete_dims"],
+                torch.tensor(expected_dims, device=self.device),
+            )
+            # Check that dim 3 is rounded.
+            X = torch.ones(1, 5, device=self.device) * 0.6
+            X_expected = X.clone()
+            X_expected[0, 3] = 1.0
+            if max_discrete_values is not None:
+                X_expected[0, 4] = 1.0
+            if post_processing_func is not None:
+                X_expected = post_processing_func(X_expected)
+            self.assertAllClose(
+                discrete_call_args["opt_inputs"].post_processing_func(X), X_expected
+            )
```

test/test_utils/test_mock.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -98,7 +98,7 @@ def test_mock_optimize_mixed_alternating(self) -> None:
         ) as mock_neighbors:
             optimize_acqf_mixed_alternating(
                 acq_function=SinAcqusitionFunction(),
-                bounds=torch.tensor([[-2.0, 0.0], [2.0, 200.0]]),
+                bounds=torch.tensor([[-2.0, 0.0], [2.0, 20.0]]),
                 discrete_dims=[1],
                 num_restarts=1,
             )
```

The bound on the discrete dimension is reduced from 200 to 20 so that it does not exceed the new `MAX_DISCRETE_VALUES` default of 20; otherwise the dimension would be relaxed to continuous, leaving no discrete dimensions, and the mocked local-search path this test checks would no longer be exercised.
