diff --git a/botorch/optim/optimize_mixed.py b/botorch/optim/optimize_mixed.py
index 76c163913e..28cb27648c 100644
--- a/botorch/optim/optimize_mixed.py
+++ b/botorch/optim/optimize_mixed.py
@@ -5,8 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 import dataclasses
+import itertools
+import random
 import warnings
-from typing import Any, Callable
+from typing import Any, Callable, Sequence
 
 import torch
 from botorch.acquisition import AcquisitionFunction
@@ -164,10 +166,76 @@ def get_nearest_neighbors(
     return unique_neighbors
 
 
+def get_categorical_neighbors(
+    current_x: Tensor,
+    bounds: Tensor,
+    cat_dims: Tensor,
+    max_num_cat_values: int = MAX_DISCRETE_VALUES,
+) -> Tensor:
+    r"""Generate all 1-Hamming distance neighbors of a given input. The neighbors
+    are generated for the categorical dimensions only.
+
+    We assume that all categorical values are equidistant. If the number of values
+    is greater than `max_num_cat_values`, we sample uniformly from the
+    possible values for that dimension.
+
+    NOTE: This assumes that `current_x` is detached and uses in-place operations,
+    which are known to be incompatible with autograd.
+
+    Args:
+        current_x: The design to find the neighbors of. A tensor of shape `d`.
+        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
+        cat_dims: A tensor of indices corresponding to categorical parameters.
+        max_num_cat_values: Maximum number of values for a categorical parameter,
+            beyond which values are uniformly sampled.
+
+    Returns:
+        A tensor of shape `num_neighbors x d`, denoting up to `max_num_cat_values`
+        unique 1-Hamming distance neighbors for each categorical dimension.
+    """
+
+    # Neighbors are generated by considering all possible values for each
+    # categorical dimension, one at a time.
+    def _get_cat_values(dim: int) -> Sequence[int]:
+        r"""Get a sequence of up to `max_num_cat_values` values that a categorical
+        feature may take."""
+        lb, ub = bounds[:, dim].long()
+        current_value = current_x[dim]
+        cat_values = range(lb, ub + 1)
+        if ub - lb + 1 <= max_num_cat_values:
+            return cat_values
+        else:
+            return random.sample(
+                [v for v in cat_values if v != current_value], k=max_num_cat_values
+            )
+
+    new_cat_values_lst = list(
+        itertools.chain.from_iterable(_get_cat_values(dim) for dim in cat_dims)
+    )
+    new_cat_values = torch.tensor(
+        new_cat_values_lst, device=current_x.device, dtype=current_x.dtype
+    )
+
+    num_cat_values = (bounds[1, :] - bounds[0, :] + 1).to(dtype=torch.long)
+    num_cat_values.clamp_(max=max_num_cat_values)
+    new_cat_idcs = torch.cat(
+        tuple(torch.full((num_cat_values[dim].item(),), dim) for dim in cat_dims)
+    )
+    neighbors = current_x.repeat(len(new_cat_values), 1)
+    # Assign the new values to their corresponding columns.
+    neighbors.scatter_(1, new_cat_idcs.view(-1, 1), new_cat_values.view(-1, 1))
+
+    unique_neighbors = neighbors.unique(dim=0)
+    # Also remove current_x if it is in unique_neighbors.
+    unique_neighbors = unique_neighbors[~(unique_neighbors == current_x).all(dim=-1)]
+    return unique_neighbors
+
+
 def get_spray_points(
     X_baseline: Tensor,
     cont_dims: Tensor,
     discrete_dims: Tensor,
+    cat_dims: Tensor,
     bounds: Tensor,
     num_spray_points: int,
     std_cont_perturbation: float = STD_CONT_PERTURBATION,
@@ -182,6 +250,7 @@ def get_spray_points(
         X_baseline: Tensor of best acquired points across BO run.
         cont_dims: Indices of continuous parameters/input dimensions.
         discrete_dims: Indices of binary/integer parameters/input dimensions.
+        cat_dims: Indices of categorical parameters/input dimensions.
         bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
         num_spray_points: Number of spray points to return.
         std_cont_perturbation: standard deviation of Normal perturbations of
@@ -194,12 +263,23 @@ def get_spray_points(
     device, dtype = X_baseline.device, X_baseline.dtype
     perturb_nbors = torch.zeros(0, dim, device=device, dtype=dtype)
     for x in X_baseline:
-        discrete_perturbs = get_nearest_neighbors(
-            current_x=x, bounds=bounds, discrete_dims=discrete_dims
-        )
-        discrete_perturbs = discrete_perturbs[
-            torch.randint(len(discrete_perturbs), (num_spray_points,), device=device)
-        ]
+        if discrete_dims.numel():
+            discrete_perturbs = get_nearest_neighbors(
+                current_x=x, bounds=bounds, discrete_dims=discrete_dims
+            )
+            discrete_perturbs = discrete_perturbs[
+                torch.randint(
+                    len(discrete_perturbs), (num_spray_points,), device=device
+                )
+            ]
+        if cat_dims.numel():
+            cat_perturbs = get_categorical_neighbors(
+                current_x=x, bounds=bounds, cat_dims=cat_dims
+            )
+            cat_perturbs = cat_perturbs[
+                torch.randint(len(cat_perturbs), (num_spray_points,), device=device)
+            ]
+
         cont_perturbs = x[cont_dims] + std_cont_perturbation * torch.randn(
             num_spray_points, len(cont_dims), device=device, dtype=dtype
         )
@@ -207,7 +287,11 @@ def get_spray_points(
             min=bounds[0, cont_dims], max=bounds[1, cont_dims]
         )
         nbds = torch.zeros(num_spray_points, dim, device=device, dtype=dtype)
-        nbds[..., discrete_dims] = discrete_perturbs[..., discrete_dims]
+        if discrete_dims.numel():
+            nbds[..., discrete_dims] = discrete_perturbs[..., discrete_dims]
+        if cat_dims.numel():
+            nbds[..., cat_dims] = cat_perturbs[..., cat_dims]
+
         nbds[..., cont_dims] = cont_perturbs
         perturb_nbors = torch.cat([perturb_nbors, nbds], dim=0)
     return perturb_nbors
@@ -216,6 +300,7 @@
 def sample_feasible_points(
     opt_inputs: OptimizeAcqfInputs,
     discrete_dims: Tensor,
+    cat_dims: Tensor,
     num_points: int,
 ) -> Tensor:
     r"""Sample feasible points from the optimization domain.
@@ -235,6 +320,7 @@ def sample_feasible_points(
        opt_inputs: Common set of arguments for acquisition optimization.
        discrete_dims: A tensor of indices corresponding to binary and integer
            parameters.
+       cat_dims: A tensor of indices corresponding to categorical parameters.
        num_points: The number of points to sample.

    Returns:
@@ -272,7 +358,8 @@ def generator(n: int) -> Tensor:
         # Generate twice as many, since we're likely to filter out some points.
         base_points = generator(n=num_remaining * 2)
         # Round the discrete dimensions to the nearest integer.
-        base_points[:, discrete_dims] = base_points[:, discrete_dims].round()
+        non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0)
+        base_points[:, non_cont_dims] = base_points[:, non_cont_dims].round()
         # Fix the fixed features.
         base_points = fix_features(
             X=base_points, fixed_features=opt_inputs.fixed_features
@@ -293,6 +380,7 @@ def generator(n: int) -> Tensor:
 def generate_starting_points(
     opt_inputs: OptimizeAcqfInputs,
     discrete_dims: Tensor,
+    cat_dims: Tensor,
     cont_dims: Tensor,
 ) -> tuple[Tensor, Tensor]:
     """Generate initial starting points for the alternating optimization.
@@ -307,6 +395,7 @@ def generate_starting_points(
             from `opt_inputs`.
         discrete_dims: A tensor of indices corresponding to integer and binary
             parameters.
+        cat_dims: A tensor of indices corresponding to categorical parameters.
         cont_dims: A tensor of indices corresponding to continuous parameters.
 
     Returns:
@@ -407,6 +496,7 @@ def generate_starting_points(
             X_baseline=X_baseline,
             cont_dims=cont_dims,
             discrete_dims=discrete_dims,
+            cat_dims=cat_dims,
             bounds=bounds,
             num_spray_points=num_spray_points,
             std_cont_perturbation=assert_is_instance(
@@ -429,6 +519,7 @@ def generate_starting_points(
         new_x_init = sample_feasible_points(
             opt_inputs=opt_inputs,
             discrete_dims=discrete_dims,
+            cat_dims=cat_dims,
             num_points=num_restarts - len(x_init_candts),
         )
         x_init_candts = torch.cat([x_init_candts, new_x_init], dim=0)
@@ -454,6 +545,7 @@
 def discrete_step(
     opt_inputs: OptimizeAcqfInputs,
     discrete_dims: Tensor,
+    cat_dims: Tensor,
     current_x: Tensor,
 ) -> tuple[Tensor, Tensor]:
     """Discrete nearest neighbour search.
@@ -464,6 +556,7 @@ def discrete_step(
             and constraints from `opt_inputs`.
         discrete_dims: A tensor of indices corresponding to binary and integer
             parameters.
+        cat_dims: A tensor of indices corresponding to categorical parameters.
         current_x: Starting point. A tensor of shape `d`.
 
     Returns:
@@ -476,14 +569,32 @@ def discrete_step(
     for _ in range(
         assert_is_instance(options.get("maxiter_discrete", MAX_ITER_DISCRETE), int)
     ):
-        x_neighbors = get_nearest_neighbors(
-            current_x=current_x.detach(),
-            bounds=opt_inputs.bounds,
-            discrete_dims=discrete_dims,
-        )
-        x_neighbors = _filter_infeasible(
-            X=x_neighbors, inequality_constraints=opt_inputs.inequality_constraints
-        )
+        neighbors = []
+        if discrete_dims.numel():
+            x_neighbors_discrete = get_nearest_neighbors(
+                current_x=current_x.detach(),
+                bounds=opt_inputs.bounds,
+                discrete_dims=discrete_dims,
+            )
+            x_neighbors_discrete = _filter_infeasible(
+                X=x_neighbors_discrete,
+                inequality_constraints=opt_inputs.inequality_constraints,
+            )
+            neighbors.append(x_neighbors_discrete)
+
+        if cat_dims.numel():
+            x_neighbors_cat = get_categorical_neighbors(
+                current_x=current_x.detach(),
+                bounds=opt_inputs.bounds,
+                cat_dims=cat_dims,
+            )
+            x_neighbors_cat = _filter_infeasible(
+                X=x_neighbors_cat,
+                inequality_constraints=opt_inputs.inequality_constraints,
+            )
+            neighbors.append(x_neighbors_cat)
+
+        x_neighbors = torch.cat(neighbors, dim=0)
         if x_neighbors.numel() == 0:
             # Exit gracefully with last point if there are no feasible neighbors.
             break
@@ -508,6 +619,7 @@
 def continuous_step(
     opt_inputs: OptimizeAcqfInputs,
     discrete_dims: Tensor,
+    cat_dims: Tensor,
     current_x: Tensor,
 ) -> tuple[Tensor, Tensor]:
     """Continuous search using L-BFGS-B through optimize_acqf.
@@ -518,6 +630,7 @@ def continuous_step(
             `fixed_features` and constraints from `opt_inputs`.
         discrete_dims: A tensor of indices corresponding to binary and integer
             parameters.
+        cat_dims: A tensor of indices corresponding to categorical parameters.
         current_x: Starting point. A tensor of shape `d`.
 
     Returns:
@@ -525,7 +638,9 @@ def continuous_step(
             and a (1)-dim tensor of acquisition values.
""" options = opt_inputs.options or {} - if len(discrete_dims) == len(current_x): # nothing continuous to optimize + non_cont_dims = torch.cat((discrete_dims, cat_dims), dim=0) + + if len(non_cont_dims) == len(current_x): # nothing continuous to optimize with torch.no_grad(): return current_x, opt_inputs.acq_function(current_x.unsqueeze(0)) @@ -536,7 +651,7 @@ def continuous_step( raw_samples=None, batch_initial_conditions=current_x.unsqueeze(0), fixed_features={ - **dict(zip(discrete_dims.tolist(), current_x[discrete_dims])), + **dict(zip(non_cont_dims.tolist(), current_x[non_cont_dims])), **(opt_inputs.fixed_features or {}), }, options={ @@ -551,7 +666,8 @@ def continuous_step( def optimize_acqf_mixed_alternating( acq_function: AcquisitionFunction, bounds: Tensor, - discrete_dims: list[int], + discrete_dims: list[int] | None = None, + cat_dims: list[int] | None = None, options: dict[str, Any] | None = None, q: int = 1, raw_samples: int = RAW_SAMPLES, @@ -562,23 +678,25 @@ def optimize_acqf_mixed_alternating( inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, ) -> tuple[Tensor, Tensor]: r""" - Optimizes acquisition function over mixed binary and continuous input spaces. - Multiple random restarting starting points are picked by evaluating a large set - of initial candidates. From each starting point, alternating discrete local search - and continuous optimization via (L-BFGS) is performed for a fixed number of - iterations. - - NOTE: This method assumes that all discrete variables are integer valued. + Optimizes acquisition function over mixed integer, categorical, and continuous + input spaces. Multiple random restarting starting points are picked by evaluating + a large set of initial candidates. From each starting point, alternating + discrete/categorical local search and continuous optimization via (L-BFGS) + is performed for a fixed number of iterations. + + NOTE: This method assumes that all discrete and categorical variables are + integer valued. The discrete dimensions that have more than `options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will be optimized using continuous relaxation. - - # TODO: Support categorical variables. + The categorical dimensions that have more than `MAX_DISCRETE_VALUES` values + be optimized by selecting random subsamples of the possible values. Args: acq_function: BoTorch Acquisition function. bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`. discrete_dims: A list of indices corresponding to integer and binary parameters. + cat_dims: A list of indices corresponding to categorical parameters. options: Dictionary specifying optimization options. Supports the following: - "initialization_strategy": Strategy used to generate the initial candidates. "random", "continuous_relaxation" or "equally_spaced" (linspace style). @@ -631,6 +749,9 @@ def optimize_acqf_mixed_alternating( "sequential optimization." ) + cat_dims = cat_dims or [] + discrete_dims = discrete_dims or [] + fixed_features = fixed_features or {} options = options or {} options.setdefault("batch_limit", MAX_BATCH_SIZE) @@ -676,22 +797,29 @@ def optimize_acqf_mixed_alternating( tkwargs: dict[str, Any] = {"device": bounds.device, "dtype": bounds.dtype} # Remove fixed features from dims, so they don't get optimized. 
     discrete_dims = [dim for dim in discrete_dims if dim not in fixed_features]
-    if len(discrete_dims) == 0:
+    cat_dims = [dim for dim in cat_dims if dim not in fixed_features]
+    non_cont_dims = [*discrete_dims, *cat_dims]
+    if len(non_cont_dims) == 0:
+        # If the problem is fully continuous, fall back to standard optimization.
         return _optimize_acqf(opt_inputs=opt_inputs)
     if not (
-        isinstance(discrete_dims, list)
-        and len(set(discrete_dims)) == len(discrete_dims)
-        and min(discrete_dims) >= 0
-        and max(discrete_dims) <= dim - 1
+        isinstance(non_cont_dims, list)
+        and len(set(non_cont_dims)) == len(non_cont_dims)
+        and min(non_cont_dims) >= 0
+        and max(non_cont_dims) <= dim - 1
     ):
         raise ValueError(
-            "`discrete_dims` must be a list with unique integers "
-            "between 0 and num_dims - 1."
+            "`discrete_dims` and `cat_dims` must be lists with unique, disjoint "
+            "integers between 0 and num_dims - 1."
         )
     discrete_dims_t = torch.tensor(
         discrete_dims, dtype=torch.long, device=tkwargs["device"]
     )
-    cont_dims = complement_indices_like(indices=discrete_dims_t, d=dim)
+    cat_dims_t = torch.tensor(cat_dims, dtype=torch.long, device=tkwargs["device"])
+    non_cont_dims = torch.tensor(
+        non_cont_dims, dtype=torch.long, device=tkwargs["device"]
+    )
+    cont_dims = complement_indices_like(indices=non_cont_dims, d=dim)
     # Fixed features are all in cont_dims. Remove them, so they don't get optimized.
     ff_idcs = torch.tensor(
         list(fixed_features.keys()), dtype=torch.long, device=tkwargs["device"]
@@ -703,6 +831,7 @@ def optimize_acqf_mixed_alternating(
     best_X, best_acq_val = generate_starting_points(
         opt_inputs=opt_inputs,
         discrete_dims=discrete_dims_t,
+        cat_dims=cat_dims_t,
         cont_dims=cont_dims,
     )
 
@@ -718,6 +847,7 @@ def optimize_acqf_mixed_alternating(
             best_X[i], best_acq_val[i] = step(
                 opt_inputs=opt_inputs,
                 discrete_dims=discrete_dims_t,
+                cat_dims=cat_dims_t,
                 current_x=best_X[i],
             )
 
diff --git a/test/optim/test_optimize_mixed.py b/test/optim/test_optimize_mixed.py
index f685354033..16c2b7a343 100644
--- a/test/optim/test_optimize_mixed.py
+++ b/test/optim/test_optimize_mixed.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import random
 from dataclasses import fields
 from itertools import product
 from typing import Any, Callable
@@ -25,6 +26,7 @@
     continuous_step,
     discrete_step,
     generate_starting_points,
+    get_categorical_neighbors,
     get_nearest_neighbors,
     get_spray_points,
     MAX_DISCRETE_VALUES,
@@ -148,6 +150,49 @@ def test_get_nearest_neighbors(self) -> None:
             )
         )
 
+    def test_get_categorical_neighbors(self) -> None:
+        current_x = torch.tensor([1.0, 0.0, 0.5], device=self.device)
+        bounds = torch.tensor([[0.0, 0.0, 0.0], [3.0, 2.0, 1.0]], device=self.device)
+        cat_dims = torch.tensor([0, 1], device=self.device, dtype=torch.long)
+        expected_neighbors = torch.tensor(
+            [
+                [0.0, 0.0, 0.5],
+                [2.0, 0.0, 0.5],
+                [3.0, 0.0, 0.5],
+                [1.0, 1.0, 0.5],
+                [1.0, 2.0, 0.5],
+            ],
+            device=self.device,
+        )
+        neighbors = get_categorical_neighbors(
+            current_x=current_x, bounds=bounds, cat_dims=cat_dims
+        )
+        self.assertTrue(
+            torch.equal(
+                expected_neighbors.sort(dim=0).values,
+                neighbors.sort(dim=0).values,
+            )
+        )
+
+        # Test the case where there are too many categorical values,
+        # where we fall back to randomly sampling a subset.
+        random.seed(0)
+        current_x = torch.tensor([50.0, 5.0], device=self.device)
+        bounds = torch.tensor([[0.0, 0.0], [100.0, 8.0]], device=self.device)
+        cat_dims = torch.tensor([0, 1], device=self.device, dtype=torch.long)
+
+        neighbors = get_categorical_neighbors(
+            current_x=current_x,
+            bounds=bounds,
+            cat_dims=cat_dims,
+            max_num_cat_values=MAX_DISCRETE_VALUES,
+        )
+        # We expect the maximum number of neighbors in the first dim, and 8
+        # neighbors in the second dim.
+        self.assertTrue(neighbors.shape == torch.Size([MAX_DISCRETE_VALUES + 8, 2]))
+        # Check that neighbors are sampled without replacement.
+        self.assertTrue(neighbors.unique(dim=0).shape[0] == neighbors.shape[0])
+
     def test_sample_feasible_points(self, with_constraints: bool = False) -> None:
         bounds = torch.tensor([[0.0, 2.0, 0.0], [1.0, 5.0, 1.0]], **self.tkwargs)
         opt_inputs = _make_opt_inputs(
@@ -176,12 +221,14 @@ def test_sample_feasible_points(self, with_constraints: bool = False) -> None:
             sample_feasible_points(
                 opt_inputs=opt_inputs,
                 discrete_dims=torch.tensor([0, 2], device=self.device),
+                cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
                 num_points=10,
             )
         # Generate a number of points.
         X = sample_feasible_points(
             opt_inputs=opt_inputs,
             discrete_dims=torch.tensor([1], device=self.device),
+            cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
             num_points=10,
         )
         self.assertEqual(X.shape, torch.Size([10, 3]))
@@ -213,6 +260,7 @@ def test_discrete_step(self):
 
         # each discrete step should reduce the best_f value by exactly 1
         binary_dims = torch.arange(d)
+        cat_dims = torch.tensor([], device=self.device, dtype=torch.long)
         for i in range(k):
             X, ei_val = discrete_step(
                 opt_inputs=_make_opt_inputs(
@@ -221,6 +269,7 @@ def test_discrete_step(self):
                     options={"maxiter_discrete": 1, "tol": 0, "init_batch_limit": 32},
                 ),
                 discrete_dims=binary_dims,
+                cat_dims=cat_dims,
                 current_x=X,
             )
             ei_x_none = ei(X[None])
@@ -240,6 +289,7 @@ def test_discrete_step(self):
                     options={"maxiter_discrete": 1, "tol": 0, "init_batch_limit": 2},
                 ),
                 discrete_dims=binary_dims,
+                cat_dims=cat_dims,
                 current_x=X,
             )
             ei_x_none = ei(X[None])
@@ -259,6 +309,7 @@ def test_discrete_step(self):
                     options={"maxiter_discrete": 1, "tol": 1.5},
                 ),
                 discrete_dims=binary_dims,
+                cat_dims=cat_dims,
                 current_x=X_clone,
             )
             # One call when entering, one call in the loop.
@@ -277,6 +328,7 @@ def test_discrete_step(self):
                     options={"maxiter_discrete": 1, "tol": 1.5, "init_batch_limit": 2},
                 ),
                 discrete_dims=binary_dims,
+                cat_dims=cat_dims,
                 current_x=X_clone,
             )
             self.assertAllClose(X_clone, X)
@@ -308,6 +360,7 @@ def test_discrete_step(self):
                     ],
                 ),
                 discrete_dims=binary_dims,
+                cat_dims=cat_dims,
                 current_x=X,
             )
             self.assertAllClose(ei_val, torch.full_like(ei_val, i + 1))
@@ -332,6 +385,7 @@ def test_discrete_step(self):
                 ],
             ),
             discrete_dims=binary_dims,
+            cat_dims=cat_dims,
            current_x=X,
         )
         # No feasible neighbors, so we should get the same point back.
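Note for context: the neighbor counts asserted in `test_get_categorical_neighbors` follow from a simple enumeration rule — each categorical dimension contributes one neighbor per admissible value other than the current one, capped at `max_num_cat_values` uniformly sampled values. The standalone sketch below mirrors that rule for illustration; it is not the patched function, and the value of 20 for `MAX_DISCRETE_VALUES` is an assumption here.

```python
# Illustrative sketch of 1-Hamming-distance neighbor enumeration over
# categorical dimensions, mirroring `get_categorical_neighbors`.
import random

import torch
from torch import Tensor

MAX_DISCRETE_VALUES = 20  # assumed module default


def categorical_neighbors_sketch(
    current_x: Tensor, bounds: Tensor, cat_dims: list[int]
) -> Tensor:
    """Return all points differing from `current_x` in exactly one categorical
    dimension, sampling at most MAX_DISCRETE_VALUES values per dimension."""
    neighbors = []
    for dim in cat_dims:
        lb, ub = int(bounds[0, dim]), int(bounds[1, dim])
        # All admissible values for this dimension except the current one.
        values = [v for v in range(lb, ub + 1) if v != int(current_x[dim])]
        if len(values) > MAX_DISCRETE_VALUES:
            values = random.sample(values, k=MAX_DISCRETE_VALUES)
        for value in values:
            neighbor = current_x.clone()
            neighbor[dim] = value
            neighbors.append(neighbor)
    return torch.stack(neighbors)


# First test case above: dims 0 and 1 take 4 and 3 values respectively,
# so [1.0, 0.0, 0.5] has 3 + 2 = 5 one-Hamming-distance neighbors.
x = torch.tensor([1.0, 0.0, 0.5])
bounds = torch.tensor([[0.0, 0.0, 0.0], [3.0, 2.0, 1.0]])
print(categorical_neighbors_sketch(x, bounds, cat_dims=[0, 1]).shape)  # (5, 3)
```

The same rule explains the `MAX_DISCRETE_VALUES + 8` shape assertion: dimension 0 has 101 values and is capped, while dimension 1 has 9 values, all of which (minus the current one) are enumerated.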
@@ -364,6 +418,7 @@ def test_continuous_step(self):
                 options={"maxiter_continuous": 32},
             ),
             discrete_dims=binary_dims,
+            cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
             current_x=X.clone(),
         )
         self.assertAllClose(X_new[cont_dims], root[cont_dims])
@@ -392,6 +447,7 @@ def test_continuous_step(self):
                 ],
             ),
             discrete_dims=binary_dims,
+            cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
             current_x=X_,
         )
         self.assertTrue(
@@ -416,6 +472,7 @@ def test_continuous_step(self):
                 options={"maxiter_continuous": 32},
             ),
             discrete_dims=binary_dims,
+            cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
             current_x=X,
         )
         self.assertTrue(X is X_out)  # testing pointer equality for due to short cut
@@ -425,6 +482,8 @@ def test_optimize_acqf_mixed_binary_only(self) -> None:
         train_X, train_Y, binary_dims, cont_dims = self._get_data()
         dim = len(binary_dims) + len(cont_dims)
         bounds = self.single_bound.repeat(1, dim)
+        binary_dims_t = torch.tensor(binary_dims, device=self.device, dtype=torch.long)
+        cont_dims_t = torch.tensor(cont_dims, device=self.device, dtype=torch.long)
         torch.manual_seed(0)
         model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
         acqf = LogExpectedImprovement(model=model, best_f=torch.max(train_Y))
@@ -441,8 +500,9 @@ def test_optimize_acqf_mixed_binary_only(self) -> None:
         # testing spray points
         perturb_nbors = get_spray_points(
             X_baseline=X_baseline,
-            discrete_dims=binary_dims,
-            cont_dims=cont_dims,
+            discrete_dims=binary_dims_t,
+            cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
+            cont_dims=cont_dims_t,
             bounds=bounds,
             num_spray_points=assert_is_instance(options["num_spray_points"], int),
         )
@@ -578,7 +638,7 @@ def test_optimize_acqf_mixed_binary_only(self) -> None:
         # Invalid indices will raise an error.
         with self.assertRaisesRegex(
             ValueError,
-            "with unique integers between 0 and num_dims - 1",
+            "with unique, disjoint integers between 0 and num_dims - 1",
         ):
             optimize_acqf_mixed_alternating(
                 acq_function=acqf,
@@ -602,6 +662,7 @@ def test_optimize_acqf_mixed_integer(self) -> None:
         bounds[1, 3:5] = 4.0
         # Update the model to have a different optimizer.
         root = torch.tensor([0.0, 0.0, 0.0, 4.0, 4.0], device=self.device)
+        torch.manual_seed(0)
         model = QuadraticDeterministicModel(root)
         acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)
         with mock.patch(
@@ -667,6 +728,7 @@ def test_optimize_acqf_mixed_integer(self) -> None:
                 options={"batch_limit": 2, "init_batch_limit": 2},
             ),
             discrete_dims=torch.tensor(discrete_dims, device=self.device),
+            cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
             cont_dims=torch.tensor(cont_dims, device=self.device),
         )
         self.assertEqual(candidates.shape, torch.Size([4, dim]))
@@ -721,6 +783,141 @@ def test_optimize_acqf_mixed_integer(self) -> None:
                 inequality_constraints=[constraint],
             ),
             discrete_dims=torch.tensor(discrete_dims, device=self.device),
+            cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
+            cont_dims=torch.tensor(cont_dims, device=self.device),
+        )
+        wrapped_sample_feasible.assert_called_once()
+        # Should request 4 candidates, since all 4 are infeasible.
+        self.assertEqual(wrapped_sample_feasible.call_args.kwargs["num_points"], 4)
+
+    def test_optimize_acqf_mixed_categorical(self) -> None:
+        # Testing with categorical variables.
+        train_X, train_Y, binary_dims, cont_dims = self._get_data()
+        dim = len(binary_dims) + len(cont_dims)
+        # Update the data to introduce categorical dimensions.
+        binary_dims = [0]
+        cat_dims = [3, 4]
+        discrete_dims = binary_dims
+        bounds = self.single_bound.repeat(1, dim)
+        bounds[1, 3:5] = 4.0
+        # Update the model to have a different optimizer.
+        root = torch.tensor([0.0, 0.0, 0.0, 4.0, 4.0], device=self.device)
+        torch.manual_seed(0)
+        model = QuadraticDeterministicModel(root)
+        acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)
+        with mock.patch(
+            f"{OPT_MODULE}._optimize_acqf", wraps=_optimize_acqf
+        ) as wrapped_optimize:
+            candidates, _ = optimize_acqf_mixed_alternating(
+                acq_function=acqf,
+                bounds=bounds,
+                discrete_dims=discrete_dims,
+                cat_dims=cat_dims,
+                q=3,
+                raw_samples=32,
+                num_restarts=4,
+                options={
+                    "batch_limit": 5,
+                    "init_batch_limit": 20,
+                    "maxiter_alternating": 1,
+                },
+            )
+        self.assertEqual(candidates.shape, torch.Size([3, dim]))
+        self.assertEqual(candidates.shape[-1], dim)
+        c_binary = candidates[:, binary_dims]
+        self.assertTrue(((c_binary == 0) | (c_binary == 1)).all())
+        c_cat = candidates[:, cat_dims]
+        self.assertTrue(torch.equal(c_cat, c_cat.round()))
+        self.assertTrue((c_cat == 4.0).any())
+        # Check that we used continuous relaxation for initialization.
+        first_call_options = (
+            wrapped_optimize.call_args_list[0].kwargs["opt_inputs"].options
+        )
+        self.assertEqual(
+            first_call_options,
+            {"maxiter": 100, "batch_limit": 5, "init_batch_limit": 20},
+        )
+
+        # Testing that continuous perturbations lead to lower acquisition values.
+        perturbed_candidates = candidates.clone()
+        perturbed_candidates[..., cont_dims] += 1e-2 * torch.randn_like(
+            perturbed_candidates[..., cont_dims], device=self.device
+        )
+        perturbed_candidates[..., cont_dims].clamp_(0, 1)
+        self.assertLess((acqf(perturbed_candidates) - acqf(candidates)).max(), 1e-12)
+        # Testing that categorical value changes lead to lower acquisition values.
+        for i, j in product(cat_dims, range(3)):
+            perturbed_candidates = candidates.repeat(2, 1, 1)
+            perturbed_candidates[0, j, i] += 1.0
+            perturbed_candidates[1, j, i] -= 1.0
+            perturbed_candidates.clamp_(bounds[0], bounds[1])
+            self.assertLess(
+                (acqf(perturbed_candidates) - acqf(candidates)).max(), 1e-12
+            )
+
+        # Test graceful fallback when continuous relaxation fails.
+        with mock.patch(
+            f"{OPT_MODULE}._optimize_acqf",
+            side_effect=RuntimeError,
+        ), self.assertWarnsRegex(OptimizationWarning, "Failed to initialize"):
+            candidates, _ = generate_starting_points(
+                opt_inputs=_make_opt_inputs(
+                    acq_function=acqf,
+                    bounds=bounds,
+                    raw_samples=32,
+                    num_restarts=4,
+                    options={"batch_limit": 2, "init_batch_limit": 2},
+                ),
+                discrete_dims=torch.tensor(discrete_dims, device=self.device),
+                cat_dims=torch.tensor([], device=self.device, dtype=torch.long),
+                cont_dims=torch.tensor(cont_dims, device=self.device),
+            )
+        self.assertEqual(candidates.shape, torch.Size([4, dim]))
+
+        # Test with fixed features and constraints. Using both discrete and
+        # continuous dimensions.
+        constraint = (  # X[..., 0] + X[..., 1] >= 1.
+            torch.tensor([0, 1], device=self.device),
+            torch.ones(2, device=self.device),
+            1.0,
+        )
+        candidates, _ = optimize_acqf_mixed_alternating(
+            acq_function=acqf,
+            bounds=bounds,
+            cat_dims=cat_dims,
+            q=3,
+            raw_samples=32,
+            num_restarts=4,
+            options={"batch_limit": 5, "init_batch_limit": 20},
+            fixed_features={1: 0.5, 3: 2},
+            inequality_constraints=[constraint],
+        )
+        self.assertAllClose(
+            candidates[:, [0, 1, 3]],
+            torch.tensor(
+                [0.5, 0.5, 2.0], device=self.device, dtype=candidates.dtype
+            ).repeat(3, 1),
+        )
+
+        # Test fallback when initializer cannot generate enough feasible points.
+        with mock.patch(
+            f"{OPT_MODULE}._optimize_acqf",
+            return_value=(
+                torch.zeros(4, 1, dim, **self.tkwargs),
+                torch.zeros(4, **self.tkwargs),
+            ),
+        ), mock.patch(
+            f"{OPT_MODULE}.sample_feasible_points", wraps=sample_feasible_points
+        ) as wrapped_sample_feasible:
+            generate_starting_points(
+                opt_inputs=_make_opt_inputs(
+                    acq_function=acqf,
+                    bounds=bounds,
+                    raw_samples=32,
+                    num_restarts=4,
+                    inequality_constraints=[constraint],
+                ),
+                discrete_dims=torch.tensor(discrete_dims, device=self.device),
+                cat_dims=torch.tensor(cat_dims, device=self.device),
                cont_dims=torch.tensor(cont_dims, device=self.device),
            )
        wrapped_sample_feasible.assert_called_once()
@@ -741,6 +938,7 @@ def test_optimize_acqf_mixed_continuous_relaxation(self) -> None:
         )
         # Update the model to have a different optimizer.
         root = torch.tensor([0.0, 0.0, 0.0, 25.0, 10.0], device=self.device)
+        torch.manual_seed(0)
         model = QuadraticDeterministicModel(root)
         acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)
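Taken together, the patch splits non-continuous inputs into ordered `discrete_dims` (searched via the existing `get_nearest_neighbors` unit steps, or continuous relaxation when a dimension has many values) and unordered `cat_dims` (searched by enumerating all values at Hamming distance 1). A hypothetical end-to-end usage sketch — the model, data, and dimension layout below are invented for illustration, and categorical values are assumed to be encoded as consecutive integers within `bounds`:

```python
# Hypothetical usage of the new `cat_dims` argument (not part of the patch).
import torch

from botorch.acquisition import LogExpectedImprovement
from botorch.models import SingleTaskGP
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating

d = 5  # dims 0-1 continuous, dim 2 integer, dims 3-4 categorical
bounds = torch.tensor([[0.0] * d, [1.0, 1.0, 5.0, 3.0, 3.0]], dtype=torch.float64)
train_X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(10, d, dtype=torch.float64)
train_X[:, 2:] = train_X[:, 2:].round()  # snap discrete/categorical dims to integers
train_Y = train_X.sum(dim=-1, keepdim=True)  # toy objective

model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
acqf = LogExpectedImprovement(model=model, best_f=train_Y.max())
candidate, acq_value = optimize_acqf_mixed_alternating(
    acq_function=acqf,
    bounds=bounds,
    discrete_dims=[2],  # ordered integers: +/-1 neighbor search applies
    cat_dims=[3, 4],  # unordered: all values are candidates in local search
    q=1,
    num_restarts=4,
    raw_samples=32,
)
```

Passing the same index in both `discrete_dims` and `cat_dims` trips the disjointness check added above, and any dimension listed in `fixed_features` is dropped from both lists before optimization.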