@@ -113,8 +113,8 @@ def __post_init__(self) -> None:
113
113
# Use object.__setattr__ to bypass immutability and set a value
114
114
object .__setattr__ (
115
115
self ,
116
- ' batch_initial_conditions' ,
117
- self .batch_initial_conditions .unsqueeze (0 )
116
+ " batch_initial_conditions" ,
117
+ self .batch_initial_conditions .unsqueeze (0 ),
118
118
)
119
119
120
120
if batch_initial_conditions_shape [- 1 ] != d :
@@ -139,7 +139,7 @@ def __post_init__(self) -> None:
139
139
DeprecationWarning ,
140
140
)
141
141
# Use object.__setattr__ to bypass immutability and set a value
142
- object .__setattr__ (self , ' raw_samples' , None )
142
+ object .__setattr__ (self , " raw_samples" , None )
143
143
144
144
elif self .ic_generator is None :
145
145
if self .nonlinear_inequality_constraints is not None :
@@ -279,15 +279,20 @@ def _optimize_acqf_sequential_q(
279
279
return candidates , torch .stack (acq_value_list )
280
280
281
281
282
- def _combine_initial_conditions (provided_initial_conditions , generated_initial_conditions , num_restarts ):
283
- if provided_initial_conditions is not None and generated_initial_conditions is not None :
282
+ def _combine_initial_conditions (
283
+ provided_initial_conditions , generated_initial_conditions , num_restarts
284
+ ):
285
+ if (
286
+ provided_initial_conditions is not None
287
+ and generated_initial_conditions is not None
288
+ ):
284
289
if ( # Repeat the provided initial conditions to match the number of restarts
285
290
provided_initial_conditions .shape [0 ] == 1
286
291
and num_restarts is not None
287
292
and num_restarts > 1
288
293
):
289
294
provided_initial_conditions = provided_initial_conditions .repeat (
290
- num_restarts , * ([1 ] * (provided_initial_conditions .dim ()- 1 ))
295
+ num_restarts , * ([1 ] * (provided_initial_conditions .dim () - 1 ))
291
296
)
292
297
return torch .cat (
293
298
[provided_initial_conditions , generated_initial_conditions ], dim = - 2
@@ -297,7 +302,9 @@ def _combine_initial_conditions(provided_initial_conditions, generated_initial_c
297
302
elif generated_initial_conditions is not None :
298
303
return generated_initial_conditions
299
304
else :
300
- raise ValueError ("Either `batch_initial_conditions` or `raw_samples` must be set." )
305
+ raise ValueError (
306
+ "Either `batch_initial_conditions` or `raw_samples` must be set."
307
+ )
301
308
302
309
303
310
def _optimize_acqf_batch (opt_inputs : OptimizeAcqfInputs ) -> tuple [Tensor , Tensor ]:
@@ -328,7 +335,7 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
328
335
batch_initial_conditions = _combine_initial_conditions (
329
336
provided_initial_conditions = opt_inputs .batch_initial_conditions ,
330
337
generated_initial_conditions = generated_initial_conditions ,
331
- num_restarts = opt_inputs .num_restarts
338
+ num_restarts = opt_inputs .num_restarts ,
332
339
)
333
340
334
341
batch_limit : int = options .get (
@@ -429,7 +436,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
429
436
batch_initial_conditions = _combine_initial_conditions (
430
437
provided_initial_conditions = opt_inputs .batch_initial_conditions ,
431
438
generated_initial_conditions = generated_initial_conditions ,
432
- num_restarts = opt_inputs .num_restarts
439
+ num_restarts = opt_inputs .num_restarts ,
433
440
)
434
441
435
442
batch_candidates , batch_acq_values , ws = _optimize_batch_candidates ()
@@ -1287,7 +1294,9 @@ def _gen_first_candidate_starting_points_local_search(
1287
1294
generated_X0 = None
1288
1295
1289
1296
if batch_initial_conditions is not None :
1290
- provided_X0 = _filter_invalid (X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid )
1297
+ provided_X0 = _filter_invalid (
1298
+ X = batch_initial_conditions .squeeze (1 ), X_avoid = X_avoid
1299
+ )
1291
1300
provided_X0 = _filter_infeasible (
1292
1301
X = provided_X0 , inequality_constraints = inequality_constraints
1293
1302
).unsqueeze (1 )
@@ -1312,7 +1321,9 @@ def _gen_first_candidate_starting_points_local_search(
1312
1321
elif generated_X0 is not None :
1313
1322
X0 = generated_X0
1314
1323
else :
1315
- raise ValueError ("Either `batch_initial_conditions` or `raw_samples` must be set." )
1324
+ raise ValueError (
1325
+ "Either `batch_initial_conditions` or `raw_samples` must be set."
1326
+ )
1316
1327
1317
1328
return X0
1318
1329
0 commit comments