Skip to content

Commit 6f2bdff

Browse files
aivanou authored and facebook-github-bot committed
Move poll_trial_status and poll_available_capacity to TorchXRunner, remove TorchXScheduler
Summary: Adjusting the TorchX setup following D31031589 and D31032567. Deprecate `torchx.runtime.hpo.ax.TorchXScheduler`. Differential Revision: D33062113 fbshipit-source-id: 41987a64ba9671e93a35f55088e7a5d77eafca78
1 parent c56cad6 commit 6f2bdff

File tree

4 files changed

+14
-33
lines changed

4 files changed

+14
-33
lines changed

docs/source/runtime/hpo.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,4 @@ Ax (Adaptive Experimentation)
1414
.. currentmodule:: torchx.runtime.hpo.ax
1515

1616
.. autoclass:: TorchXRunner
17-
.. autoclass:: TorchXScheduler
18-
.. autoclass:: AppMetric
17+
.. autoclass:: AppMetric

torchx/runtime/hpo/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@
5454
SearchSpace,
5555
)
5656
from ax.modelbridge.dispatch_utils import choose_generation_strategy
57-
from ax.service.scheduler import SchedulerOptions
57+
from ax.service.scheduler import SchedulerOptions, Scheduler
5858
from ax.service.utils.best_point import get_best_parameters
5959
from ax.service.utils.report_utils import exp_to_df
6060
from ax.utils.common.constants import Keys
6161
from pyre_extensions import none_throws
6262
from torchx.components import utils
63-
from torchx.runtime.hpo.ax import AppMetric, TorchXRunner, TorchXScheduler
63+
from torchx.runtime.hpo.ax import AppMetric, TorchXRunner
6464
6565
# Run HPO on the booth function (https://en.wikipedia.org/wiki/Test_functions_for_optimization)
6666
@@ -100,7 +100,7 @@
100100
properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
101101
)
102102
103-
scheduler = TorchXScheduler(
103+
scheduler = Scheduler(
104104
experiment=experiment,
105105
generation_strategy=(
106106
choose_generation_strategy(

torchx/runtime/hpo/ax.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,15 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import inspect
9-
from typing import Any, Callable, Dict, Mapping, Optional, Set, cast
9+
from typing import Any, Callable, Dict, Mapping, Optional, Set, cast, Iterable
1010

1111
import pandas as pd
1212
from ax.core import Trial
1313
from ax.core.base_trial import BaseTrial
1414
from ax.core.data import Data
1515
from ax.core.metric import Metric
1616
from ax.core.runner import Runner as ax_Runner
17-
from ax.service.scheduler import Scheduler as ax_Scheduler, TrialStatus
18-
from ax.utils.common.typeutils import not_none
17+
from ax.service.scheduler import TrialStatus
1918
from pyre_extensions import none_throws
2019
from torchx.runner import Runner, get_runner
2120
from torchx.runtime.tracking import FsspecResultTracker
@@ -206,25 +205,12 @@ def run(self, trial: BaseTrial) -> Dict[str, Any]:
206205
_TORCHX_TRACKER_BASE: self._tracker_base,
207206
}
208207

209-
210-
class TorchXScheduler(ax_Scheduler):
211-
"""
212-
An implementation of an `Ax Scheduler <https://ax.dev/tutorials/scheduler.html>`_
213-
that works with Experiments hooked up with the ``TorchXRunner``.
214-
215-
This scheduler is not a real scheduler but rather a facade scheduler
216-
that delegates to scheduler clients for various remote/local schedulers.
217-
For a list of supported schedulers please refer to TorchX
218-
`scheduler docs <https://pytorch.org/torchx/latest/schedulers.html>`_.
219-
220-
"""
221-
222208
def poll_trial_status(
223-
self, poll_all_trial_statuses: bool = False
209+
self, trials: Iterable[BaseTrial]
224210
) -> Dict[TrialStatus, Set[int]]:
225211
trial_statuses: Dict[TrialStatus, Set[int]] = {}
226212

227-
for trial in self.running_trials:
213+
for trial in trials:
228214
app_handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
229215
torchx_runner = trial.run_metadata[_TORCHX_RUNNER]
230216
app_status: AppStatus = torchx_runner.status(app_handle)
@@ -250,8 +236,4 @@ def poll_available_capacity(self) -> int:
250236
251237
"""
252238

253-
return (
254-
-1
255-
if self.generation_strategy._curr.max_parallelism is None
256-
else not_none(self.generation_strategy._curr.max_parallelism)
257-
)
239+
return super().poll_available_capacity()

torchx/runtime/hpo/test/ax_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,16 @@
2222
SearchSpace,
2323
)
2424
from ax.modelbridge.dispatch_utils import choose_generation_strategy
25-
from ax.service.scheduler import SchedulerOptions
25+
from ax.service.scheduler import SchedulerOptions, Scheduler
2626
from ax.service.utils.best_point import get_best_parameters
2727
from ax.service.utils.report_utils import exp_to_df
2828
from ax.utils.common.constants import Keys
2929
from pyre_extensions import none_throws
3030
from torchx.components import utils
31-
from torchx.runtime.hpo.ax import AppMetric, TorchXRunner, TorchXScheduler
31+
from torchx.runtime.hpo.ax import AppMetric, TorchXRunner
3232

3333

34-
class TorchXSchedulerTest(unittest.TestCase):
34+
class TorchXAxTest(unittest.TestCase):
3535
def setUp(self) -> None:
3636
self.test_dir = tempfile.mkdtemp("torchx_runtime_hpo_ax_test")
3737

@@ -86,7 +86,7 @@ def test_run_experiment_locally(self) -> None:
8686

8787
# maybe add-on cfg into SchedulerOption?
8888
# so that we can pass it from one place
89-
scheduler = TorchXScheduler(
89+
scheduler = Scheduler(
9090
experiment=experiment,
9191
generation_strategy=(
9292
choose_generation_strategy(
@@ -131,7 +131,7 @@ def test_run_experiment_locally_in_batches(self) -> None:
131131

132132
# maybe add-on cfg into SchedulerOption?
133133
# so that we can pass it from one place
134-
scheduler = TorchXScheduler(
134+
scheduler = Scheduler(
135135
experiment=experiment,
136136
generation_strategy=(
137137
choose_generation_strategy(

0 commit comments

Comments (0)