39
39
MAX_ITER_ALTER = 64 # Maximum number of alternating iterations.
40
40
MAX_ITER_DISCRETE = 4 # Maximum number of discrete iterations.
41
41
MAX_ITER_CONT = 8 # Maximum number of continuous iterations.
42
+ # Maximum number of discrete values for a discrete dimension.
43
+ # If there are more values for a dimension, we will use continuous
44
+ # relaxation to optimize it.
45
+ MAX_DISCRETE_VALUES = 20
42
46
# Maximum number of iterations for optimizing the continuous relaxation
43
47
# during initialization
44
48
MAX_ITER_INIT = 100
52
56
"maxiter_discrete" ,
53
57
"maxiter_continuous" ,
54
58
"maxiter_init" ,
59
+ "max_discrete_values" ,
55
60
"num_spray_points" ,
56
61
"std_cont_perturbation" ,
57
62
"batch_limit" ,
60
65
SUPPORTED_INITIALIZATION = {"continuous_relaxation" , "equally_spaced" , "random" }
61
66
62
67
68
+ def _setup_continuous_relaxation (
69
+ discrete_dims : list [int ],
70
+ bounds : Tensor ,
71
+ max_discrete_values : int ,
72
+ post_processing_func : Callable [[Tensor ], Tensor ] | None ,
73
+ ) -> tuple [list [int ], Callable [[Tensor ], Tensor ] | None ]:
74
+ r"""Update `discrete_dims` and `post_processing_func` to use
75
+ continuous relaxation for discrete dimensions that have more than
76
+ `max_discrete_values` values. These dimensions are removed from
77
+ `discrete_dims` and `post_processing_func` is updated to round
78
+ them to the nearest integer.
79
+ """
80
+ discrete_dims_t = torch .tensor (discrete_dims , dtype = torch .long )
81
+ num_discrete_values = (bounds [1 , discrete_dims ] - bounds [0 , discrete_dims ]).cpu ()
82
+ dims_to_relax = discrete_dims_t [num_discrete_values > max_discrete_values ]
83
+ if dims_to_relax .numel () == 0 :
84
+ # No dimension needs continuous relaxation.
85
+ return discrete_dims , post_processing_func
86
+ # Remove relaxed dims from `discrete_dims`.
87
+ discrete_dims = list (set (discrete_dims ).difference (dims_to_relax .tolist ()))
88
+
89
+ def new_post_processing_func (X : Tensor ) -> Tensor :
90
+ r"""Round the relaxed dimensions to the nearest integer and apply the original
91
+ `post_processing_func`."""
92
+ X [:, dims_to_relax ] = X [:, dims_to_relax ].round ()
93
+ if post_processing_func is not None :
94
+ X = post_processing_func (X )
95
+ return X
96
+
97
+ return discrete_dims , new_post_processing_func
98
+
99
+
63
100
def _filter_infeasible (
64
101
X : Tensor , inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None
65
102
) -> Tensor :
@@ -532,6 +569,9 @@ def optimize_acqf_mixed_alternating(
532
569
iterations.
533
570
534
571
NOTE: This method assumes that all discrete variables are integer valued.
572
+ The discrete dimensions that have more than
573
+ `options.get("max_discrete_values", MAX_DISCRETE_VALUES)` values will
574
+ be optimized using continuous relaxation.
535
575
536
576
# TODO: Support categorical variables.
537
577
@@ -549,6 +589,9 @@ def optimize_acqf_mixed_alternating(
549
589
Defaults to 4.
550
590
- "maxiter_continuous": Maximum number of iterations in each continuous step.
551
591
Defaults to 8.
592
+ - "max_discrete_values": Maximum number of values for a discrete dimension
593
+ to be optimized using discrete step / local search. The discrete dimensions
594
+ with more values will be optimized using continuous relaxation.
552
595
- "num_spray_points": Number of spray points (around `X_baseline`) to add to
553
596
the points generated by the initialization strategy. Defaults to 20 if
554
597
all discrete variables are binary and to 0 otherwise.
@@ -598,6 +641,17 @@ def optimize_acqf_mixed_alternating(
598
641
f"Received an unsupported option { unsupported_keys } . { SUPPORTED_OPTIONS = } ."
599
642
)
600
643
644
+ # Update discrete dims and post processing functions to account for any
645
+ # dimensions that should be using continuous relaxation.
646
+ discrete_dims , post_processing_func = _setup_continuous_relaxation (
647
+ discrete_dims = discrete_dims ,
648
+ bounds = bounds ,
649
+ max_discrete_values = assert_is_instance (
650
+ options .get ("max_discrete_values" , MAX_DISCRETE_VALUES ), int
651
+ ),
652
+ post_processing_func = post_processing_func ,
653
+ )
654
+
601
655
opt_inputs = OptimizeAcqfInputs (
602
656
acq_function = acq_function ,
603
657
bounds = bounds ,
0 commit comments