Skip to content

Commit 55903b1

Browse files
mpolson64facebook-github-bot
authored andcommitted
Make models enum non-optional for get_model_from_generator_run
Summary: This being optional was creating a bug-prone codepath due to our use of the method on generator runs using internal only models. This change to the models enum being required cascaded through a fairly large portion of our code, stemming from our handling of best_point functions, but I think it will be worth it and not as disruptive as one may think. Reviewed By: lena-kashtelyan Differential Revision: D33238024 fbshipit-source-id: ce3bd7089ad208f19c3b88d7b5d22483e2b59986
1 parent 8d5b146 commit 55903b1

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

torchx/runtime/hpo/test/ax_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
SearchSpace,
2323
)
2424
from ax.modelbridge.dispatch_utils import choose_generation_strategy
25+
from ax.modelbridge.registry import Models
2526
from ax.service.scheduler import SchedulerOptions
2627
from ax.service.utils.best_point import get_best_parameters
2728
from ax.service.utils.report_utils import exp_to_df
@@ -104,7 +105,9 @@ def test_run_experiment_locally(self) -> None:
104105

105106
# AppMetrics always returns trial index; hence the best
106107
# experiment for min objective will be the params for trial 0
107-
best_param, _ = none_throws(get_best_parameters(experiment))
108+
best_param, _ = none_throws(
109+
get_best_parameters(experiment=experiment, models_enum=Models)
110+
)
108111
# nothing to assert, just make sure experiment runs
109112

110113
def test_run_experiment_locally_in_batches(self) -> None:
@@ -150,7 +153,9 @@ def test_run_experiment_locally_in_batches(self) -> None:
150153

151154
# AppMetrics always returns trial index; hence the best
152155
# experiment for min objective will be the params for trial 0
153-
best_param, _ = none_throws(get_best_parameters(experiment))
156+
best_param, _ = none_throws(
157+
get_best_parameters(experiment=experiment, models_enum=Models)
158+
)
154159
# nothing to assert, just make sure experiment runs
155160

156161
def test_runner_no_batch_trials(self) -> None:

0 commit comments

Comments
 (0)