Skip to content

Commit 4b1fe5b

Browse files
committed
lint with ufmt
1 parent beb9309 commit 4b1fe5b

File tree

3 files changed

+32
-15
lines changed

3 files changed

+32
-15
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1780,7 +1780,7 @@ def optimize_objective(
17801780
bounds=free_feature_bounds,
17811781
q=q,
17821782
num_restarts=optimizer_options.get("num_restarts", 60),
1783-
raw_samples=optimizer_options.get("raw_samples", 1024), # NOTE potential behaviour change
1783+
raw_samples=optimizer_options.get("raw_samples", 1024),
17841784
options={
17851785
"batch_limit": optimizer_options.get("batch_limit", 8),
17861786
"maxiter": optimizer_options.get("maxiter", 200),

botorch/optim/optimize.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def __post_init__(self) -> None:
113113
# Use object.__setattr__ to bypass immutability and set a value
114114
object.__setattr__(
115115
self,
116-
'batch_initial_conditions',
117-
self.batch_initial_conditions.unsqueeze(0)
116+
"batch_initial_conditions",
117+
self.batch_initial_conditions.unsqueeze(0),
118118
)
119119

120120
if batch_initial_conditions_shape[-1] != d:
@@ -139,7 +139,7 @@ def __post_init__(self) -> None:
139139
DeprecationWarning,
140140
)
141141
# Use object.__setattr__ to bypass immutability and set a value
142-
object.__setattr__(self, 'raw_samples', None)
142+
object.__setattr__(self, "raw_samples", None)
143143

144144
elif self.ic_generator is None:
145145
if self.nonlinear_inequality_constraints is not None:
@@ -279,15 +279,20 @@ def _optimize_acqf_sequential_q(
279279
return candidates, torch.stack(acq_value_list)
280280

281281

282-
def _combine_initial_conditions(provided_initial_conditions, generated_initial_conditions, num_restarts):
283-
if provided_initial_conditions is not None and generated_initial_conditions is not None:
282+
def _combine_initial_conditions(
283+
provided_initial_conditions, generated_initial_conditions, num_restarts
284+
):
285+
if (
286+
provided_initial_conditions is not None
287+
and generated_initial_conditions is not None
288+
):
284289
if ( # Repeat the provided initial conditions to match the number of restarts
285290
provided_initial_conditions.shape[0] == 1
286291
and num_restarts is not None
287292
and num_restarts > 1
288293
):
289294
provided_initial_conditions = provided_initial_conditions.repeat(
290-
num_restarts, *([1] * (provided_initial_conditions.dim()-1))
295+
num_restarts, *([1] * (provided_initial_conditions.dim() - 1))
291296
)
292297
return torch.cat(
293298
[provided_initial_conditions, generated_initial_conditions], dim=-2
@@ -297,7 +302,9 @@ def _combine_initial_conditions(provided_initial_conditions, generated_initial_c
297302
elif generated_initial_conditions is not None:
298303
return generated_initial_conditions
299304
else:
300-
raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.")
305+
raise ValueError(
306+
"Either `batch_initial_conditions` or `raw_samples` must be set."
307+
)
301308

302309

303310
def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]:
@@ -328,7 +335,7 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
328335
batch_initial_conditions = _combine_initial_conditions(
329336
provided_initial_conditions=opt_inputs.batch_initial_conditions,
330337
generated_initial_conditions=generated_initial_conditions,
331-
num_restarts=opt_inputs.num_restarts
338+
num_restarts=opt_inputs.num_restarts,
332339
)
333340

334341
batch_limit: int = options.get(
@@ -429,7 +436,7 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
429436
batch_initial_conditions = _combine_initial_conditions(
430437
provided_initial_conditions=opt_inputs.batch_initial_conditions,
431438
generated_initial_conditions=generated_initial_conditions,
432-
num_restarts=opt_inputs.num_restarts
439+
num_restarts=opt_inputs.num_restarts,
433440
)
434441

435442
batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
@@ -1287,7 +1294,9 @@ def _gen_first_candidate_starting_points_local_search(
12871294
generated_X0 = None
12881295

12891296
if batch_initial_conditions is not None:
1290-
provided_X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid)
1297+
provided_X0 = _filter_invalid(
1298+
X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
1299+
)
12911300
provided_X0 = _filter_infeasible(
12921301
X=provided_X0, inequality_constraints=inequality_constraints
12931302
).unsqueeze(1)
@@ -1312,7 +1321,9 @@ def _gen_first_candidate_starting_points_local_search(
13121321
elif generated_X0 is not None:
13131322
X0 = generated_X0
13141323
else:
1315-
raise ValueError("Either `batch_initial_conditions` or `raw_samples` must be set.")
1324+
raise ValueError(
1325+
"Either `batch_initial_conditions` or `raw_samples` must be set."
1326+
)
13161327

13171328
return X0
13181329

test/optim/test_optimize.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,8 @@ def test_optimize_acqf_batch_limit(self) -> None:
546546

547547
options = {"batch_limit": batch_limit}
548548
initial_conditions = [
549-
torch.ones(shape) for shape in [(1, 2, dim), (3, 1, dim), (3, 6, dim), (1, dim)]
549+
torch.ones(shape)
550+
for shape in [(1, 2, dim), (3, 1, dim), (3, 6, dim), (1, dim)]
550551
] + [None]
551552

552553
for gen_candidates, ics in zip(
@@ -572,7 +573,9 @@ def test_optimize_acqf_batch_limit(self) -> None:
572573
self.assertEqual(acq_value_list.shape, expected_shape)
573574

574575
with self.subTest(gen_candidates=gen_candidates):
575-
with self.assertWarnsRegex(DeprecationWarning, "Defaulting to old behavior"):
576+
with self.assertWarnsRegex(
577+
DeprecationWarning, "Defaulting to old behavior"
578+
):
576579
ics = torch.ones((2, 1, dim))
577580
_, acq_value_list = optimize_acqf(
578581
acq_function=SinOneOverXAcqusitionFunction(),
@@ -602,7 +605,10 @@ def test_optimize_acqf_runs_given_batch_initial_conditions(self):
602605
]
603606
q = 1
604607

605-
ic_shapes = [(1, 2, dim), (1, dim)] # NOTE removed test on (2, 1, dim) as error tested elsewhere
608+
ic_shapes = [
609+
(1, 2, dim),
610+
(1, dim),
611+
] # NOTE removed test on (2, 1, dim) as error tested elsewhere
606612

607613
torch.manual_seed(0)
608614
for shape in ic_shapes:

0 commit comments

Comments
 (0)