@@ -109,12 +109,31 @@ def __post_init__(self) -> None:
109
109
"3-dimensional. Its shape is "
110
110
f"{ batch_initial_conditions_shape } ."
111
111
)
112
+
112
113
if batch_initial_conditions_shape [- 1 ] != d :
113
114
raise ValueError (
114
115
f"batch_initial_conditions.shape[-1] must be { d } . The "
115
116
f"shape is { batch_initial_conditions_shape } ."
116
117
)
117
118
119
+ if (
120
+ self .raw_samples is not None
121
+ and (self .raw_samples - batch_initial_conditions_shape [- 2 ]) > 0
122
+ and len (batch_initial_conditions_shape ) == 3
123
+ and self .num_restarts is not None
124
+ and batch_initial_conditions_shape [0 ] not in [1 , self .num_restarts ]
125
+ ):
126
+ warnings .warn (
127
+ "If using `batch_initial_conditions` together with `raw_samples`, "
128
+ "the first repeat dimension of `batch_initial_conditions` must "
129
+ "match `num_restarts`. In the future this will raise an error. "
130
+ "Defaulting to old behavior of ignoring `raw_samples` by setting "
131
+ "it to None." ,
132
+ DeprecationWarning ,
133
+ )
134
+ # Use object.__setattr__ to bypass immutability and set a value
135
+ object .__setattr__ (self , "raw_samples" , None )
136
+
118
137
elif self .ic_generator is None :
119
138
if self .nonlinear_inequality_constraints is not None :
120
139
raise RuntimeError (
@@ -253,27 +272,73 @@ def _optimize_acqf_sequential_q(
253
272
return candidates , torch .stack (acq_value_list )
254
273
255
274
275
+ def _combine_initial_conditions (
276
+ provided_initial_conditions : Tensor | None = None ,
277
+ generated_initial_conditions : Tensor | None = None ,
278
+ num_restarts : int | None = None ,
279
+ ) -> Tensor :
280
+
281
+ if (
282
+ provided_initial_conditions is not None
283
+ and generated_initial_conditions is not None
284
+ ):
285
+ if ( # Repeat the provided initial conditions to match the number of restarts
286
+ provided_initial_conditions .shape [0 ] == 1
287
+ and num_restarts is not None
288
+ and num_restarts > 1
289
+ ):
290
+ provided_initial_conditions = provided_initial_conditions .repeat (
291
+ num_restarts , * ([1 ] * (provided_initial_conditions .dim () - 1 ))
292
+ )
293
+ initial_conditions = torch .cat (
294
+ [provided_initial_conditions , generated_initial_conditions ], dim = - 2
295
+ )
296
+ perm = torch .randperm (
297
+ initial_conditions .shape [- 2 ], device = initial_conditions .device
298
+ )
299
+ return initial_conditions .gather (
300
+ - 2 , perm .unsqueeze (- 1 ).expand_as (initial_conditions )
301
+ )
302
+ elif provided_initial_conditions is not None :
303
+ return provided_initial_conditions
304
+ elif generated_initial_conditions is not None :
305
+ return generated_initial_conditions
306
+ else :
307
+ raise ValueError (
308
+ "Either `batch_initial_conditions` or `raw_samples` must be set."
309
+ )
310
+
311
+
256
312
def _optimize_acqf_batch (opt_inputs : OptimizeAcqfInputs ) -> tuple [Tensor , Tensor ]:
257
313
options = opt_inputs .options or {}
258
314
259
- initial_conditions_provided = opt_inputs .batch_initial_conditions is not None
315
+ required_raw_samples = opt_inputs .raw_samples
316
+ generated_initial_conditions = None
260
317
261
- if initial_conditions_provided :
262
- batch_initial_conditions = opt_inputs .batch_initial_conditions
263
- else :
264
- # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
265
- batch_initial_conditions = opt_inputs .get_ic_generator ()(
266
- acq_function = opt_inputs .acq_function ,
267
- bounds = opt_inputs .bounds ,
268
- q = opt_inputs .q ,
269
- num_restarts = opt_inputs .num_restarts ,
270
- raw_samples = opt_inputs .raw_samples ,
271
- fixed_features = opt_inputs .fixed_features ,
272
- options = options ,
273
- inequality_constraints = opt_inputs .inequality_constraints ,
274
- equality_constraints = opt_inputs .equality_constraints ,
275
- ** opt_inputs .ic_gen_kwargs ,
276
- )
318
+ if required_raw_samples is not None :
319
+ if opt_inputs .batch_initial_conditions is not None :
320
+ required_raw_samples -= opt_inputs .batch_initial_conditions .shape [- 2 ]
321
+
322
+ if required_raw_samples > 0 :
323
+ # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
324
+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
325
+ acq_function = opt_inputs .acq_function ,
326
+ bounds = opt_inputs .bounds ,
327
+ q = opt_inputs .q ,
328
+ num_restarts = opt_inputs .num_restarts ,
329
+ raw_samples = required_raw_samples ,
330
+ fixed_features = opt_inputs .fixed_features ,
331
+ options = options ,
332
+ inequality_constraints = opt_inputs .inequality_constraints ,
333
+ equality_constraints = opt_inputs .equality_constraints ,
334
+ ** opt_inputs .ic_gen_kwargs ,
335
+ )
336
+
337
+ batch_initial_conditions = _combine_initial_conditions (
338
+ provided_initial_conditions = opt_inputs .batch_initial_conditions ,
339
+ generated_initial_conditions = generated_initial_conditions ,
340
+ num_restarts = opt_inputs .num_restarts ,
341
+ )
277
342
278
343
batch_limit : int = options .get (
279
344
"batch_limit" ,
@@ -344,31 +409,38 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
344
409
first_warn_msg = (
345
410
"Optimization failed in `gen_candidates_scipy` with the following "
346
411
f"warning(s):\n { [w .message for w in ws ]} \n Because you specified "
347
- "`batch_initial_conditions`, optimization will not be retried with "
348
- "new initial conditions and will proceed with the current solution."
349
- " Suggested remediation: Try again with different "
350
- "`batch_initial_conditions`, or don't provide `batch_initial_conditions.`"
351
- if initial_conditions_provided
412
+ "`batch_initial_conditions`>`raw_samples`, optimization will not "
413
+ "be retried with new initial conditions and will proceed with the "
414
+ "current solution. Suggested remediation: Try again with different "
415
+ "`batch_initial_conditions`, don't provide `batch_initial_conditions`, "
416
+ "or increase `raw_samples`.`"
417
+ if required_raw_samples is not None and required_raw_samples <= 0
352
418
else "Optimization failed in `gen_candidates_scipy` with the following "
353
419
f"warning(s):\n { [w .message for w in ws ]} \n Trying again with a new "
354
420
"set of initial conditions."
355
421
)
356
422
warnings .warn (first_warn_msg , RuntimeWarning , stacklevel = 2 )
357
423
358
- if not initial_conditions_provided :
359
- batch_initial_conditions = opt_inputs .get_ic_generator ()(
424
+ if required_raw_samples is not None and required_raw_samples > 0 :
425
+ generated_initial_conditions = opt_inputs .get_ic_generator ()(
360
426
acq_function = opt_inputs .acq_function ,
361
427
bounds = opt_inputs .bounds ,
362
428
q = opt_inputs .q ,
363
429
num_restarts = opt_inputs .num_restarts ,
364
- raw_samples = opt_inputs . raw_samples ,
430
+ raw_samples = required_raw_samples ,
365
431
fixed_features = opt_inputs .fixed_features ,
366
432
options = options ,
367
433
inequality_constraints = opt_inputs .inequality_constraints ,
368
434
equality_constraints = opt_inputs .equality_constraints ,
369
435
** opt_inputs .ic_gen_kwargs ,
370
436
)
371
437
438
+ batch_initial_conditions = _combine_initial_conditions (
439
+ provided_initial_conditions = opt_inputs .batch_initial_conditions ,
440
+ generated_initial_conditions = generated_initial_conditions ,
441
+ num_restarts = opt_inputs .num_restarts ,
442
+ )
443
+
372
444
batch_candidates , batch_acq_values , ws = _optimize_batch_candidates ()
373
445
374
446
optimization_warning_raised = any (
@@ -1177,7 +1249,7 @@ def _gen_batch_initial_conditions_local_search(
1177
1249
inequality_constraints : list [tuple [Tensor , Tensor , float ]],
1178
1250
min_points : int ,
1179
1251
max_tries : int = 100 ,
1180
- ):
1252
+ ) -> Tensor :
1181
1253
"""Generate initial conditions for local search."""
1182
1254
device = discrete_choices [0 ].device
1183
1255
dtype = discrete_choices [0 ].dtype
@@ -1197,6 +1269,66 @@ def _gen_batch_initial_conditions_local_search(
1197
1269
raise RuntimeError (f"Failed to generate at least { min_points } initial conditions" )
1198
1270
1199
1271
1272
def _gen_starting_points_local_search(
    discrete_choices: list[Tensor],
    raw_samples: int,
    batch_initial_conditions: Tensor,
    X_avoid: Tensor,
    inequality_constraints: list[tuple[Tensor, Tensor, float]],
    min_points: int,
    acq_function: AcquisitionFunction,
    max_batch_size: int = 2048,
    max_tries: int = 100,
) -> Tensor:
    """Assemble starting points for discrete local search.

    User-provided ``batch_initial_conditions`` are filtered against
    ``X_avoid`` and the inequality constraints; if the surviving points do
    not reach ``min_points``, the shortfall is generated with
    ``_gen_batch_initial_conditions_local_search`` and only the generated
    candidates with the highest acquisition values are kept.

    Args:
        discrete_choices: A list of possible discrete choices per dimension.
        raw_samples: Number of raw samples used when generating points.
        batch_initial_conditions: Optional `n x 1 x d` tensor of
            user-provided starting points (may be None).
        X_avoid: Tensor of points that must be excluded.
        inequality_constraints: A list of `(indices, coefficients, rhs)`
            constraint tuples.
        min_points: Minimum total number of starting points to return.
        acq_function: Acquisition function used to rank generated points.
        max_batch_size: Maximum batch size for acquisition evaluation.
        max_tries: Maximum number of generation attempts.

    Returns:
        A `k x 1 x d` tensor of starting points.

    Raises:
        ValueError: If no provided and no generated points are available.
    """
    required_min_points = min_points
    provided_X0 = None
    generated_X0 = None

    if batch_initial_conditions is not None:
        # Drop provided points that collide with X_avoid or are infeasible.
        provided_X0 = _filter_invalid(
            X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
        )
        provided_X0 = _filter_infeasible(
            X=provided_X0, inequality_constraints=inequality_constraints
        ).unsqueeze(1)
        # Count only the points that survived filtering; using the raw
        # (pre-filter) count could leave fewer than `min_points` total.
        required_min_points -= provided_X0.shape[0]

    if required_min_points > 0:
        # Only generate the shortfall rather than the full `min_points`.
        generated_X0 = _gen_batch_initial_conditions_local_search(
            discrete_choices=discrete_choices,
            raw_samples=raw_samples,
            X_avoid=X_avoid,
            inequality_constraints=inequality_constraints,
            min_points=required_min_points,
            max_tries=max_tries,
        )

        # Rank the generated candidates by acquisition value and keep the
        # best `required_min_points` of them.
        with torch.no_grad():
            acqvals_init = _split_batch_eval_acqf(
                acq_function=acq_function,
                X=generated_X0.unsqueeze(1),
                max_batch_size=max_batch_size,
            ).unsqueeze(-1)

        generated_X0 = generated_X0[
            acqvals_init.topk(k=required_min_points, largest=True, dim=0).indices
        ]

    if provided_X0 is not None and generated_X0 is not None:
        return torch.cat([provided_X0, generated_X0], dim=0)
    elif provided_X0 is not None:
        return provided_X0
    elif generated_X0 is not None:
        return generated_X0
    raise ValueError(
        "Either `batch_initial_conditions` or `raw_samples` must be set."
    )
1330
+
1331
+
1200
1332
def optimize_acqf_discrete_local_search (
1201
1333
acq_function : AcquisitionFunction ,
1202
1334
discrete_choices : list [Tensor ],
@@ -1207,6 +1339,7 @@ def optimize_acqf_discrete_local_search(
1207
1339
X_avoid : Tensor | None = None ,
1208
1340
batch_initial_conditions : Tensor | None = None ,
1209
1341
max_batch_size : int = 2048 ,
1342
+ max_tries : int = 100 ,
1210
1343
unique : bool = True ,
1211
1344
) -> tuple [Tensor , Tensor ]:
1212
1345
r"""Optimize acquisition function over a lattice.
@@ -1238,6 +1371,8 @@ def optimize_acqf_discrete_local_search(
1238
1371
max_batch_size: The maximum number of choices to evaluate in batch.
1239
1372
A large limit can cause excessive memory usage if the model has
1240
1373
a large training set.
1374
+ max_tries: Maximum number of iterations to try when generating initial
1375
+ conditions.
1241
1376
unique: If True return unique choices, o/w choices may be repeated
1242
1377
(only relevant if `q > 1`).
1243
1378
@@ -1247,6 +1382,13 @@ def optimize_acqf_discrete_local_search(
1247
1382
- a `q x d`-dim tensor of generated candidates.
1248
1383
- an associated acquisition value.
1249
1384
"""
1385
+ if batch_initial_conditions is not None :
1386
+ if not (
1387
+ len (batch_initial_conditions .shape ) == 3
1388
+ and batch_initial_conditions .shape [- 2 ] == 1
1389
+ ):
1390
+ raise ValueError ("batch_initial_conditions must have shape `n x 1 x d` if given." )
1391
+
1250
1392
candidate_list = []
1251
1393
base_X_pending = acq_function .X_pending if q > 1 else None
1252
1394
base_X_avoid = X_avoid
@@ -1259,27 +1401,18 @@ def optimize_acqf_discrete_local_search(
1259
1401
inequality_constraints = inequality_constraints or []
1260
1402
for i in range (q ):
1261
1403
# generate some starting points
1262
- if i == 0 and batch_initial_conditions is not None :
1263
- X0 = _filter_invalid (X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid )
1264
- X0 = _filter_infeasible (
1265
- X = X0 , inequality_constraints = inequality_constraints
1266
- ).unsqueeze (1 )
1267
- else :
1268
- X_init = _gen_batch_initial_conditions_local_search (
1269
- discrete_choices = discrete_choices ,
1270
- raw_samples = raw_samples ,
1271
- X_avoid = X_avoid ,
1272
- inequality_constraints = inequality_constraints ,
1273
- min_points = num_restarts ,
1274
- )
1275
- # pick the best starting points
1276
- with torch .no_grad ():
1277
- acqvals_init = _split_batch_eval_acqf (
1278
- acq_function = acq_function ,
1279
- X = X_init .unsqueeze (1 ),
1280
- max_batch_size = max_batch_size ,
1281
- ).unsqueeze (- 1 )
1282
- X0 = X_init [acqvals_init .topk (k = num_restarts , largest = True , dim = 0 ).indices ]
1404
+ X0 = _gen_starting_points_local_search (
1405
+ discrete_choices = discrete_choices ,
1406
+ raw_samples = raw_samples ,
1407
+ batch_initial_conditions = batch_initial_conditions ,
1408
+ X_avoid = X_avoid ,
1409
+ inequality_constraints = inequality_constraints ,
1410
+ min_points = num_restarts ,
1411
+ acq_function = acq_function ,
1412
+ max_batch_size = max_batch_size ,
1413
+ max_tries = max_tries ,
1414
+ )
1415
+ batch_initial_conditions = None
1283
1416
1284
1417
# optimize from the best starting points
1285
1418
best_xs = torch .zeros (len (X0 ), dim , device = device , dtype = dtype )
0 commit comments