
Commit f5278cc

aivanou authored and facebook-github-bot committed
Move poll_trial_status and poll_available_capacity to TorchXRunner, remove TorchXScheduler (#361)
Summary:
Pull Request resolved: #361

Adjusting the TorchX setup following D31031589 and D31032567. Deprecate `torchx.runtime.hpo.ax.TorchXScheduler`.

Reviewed By: lena-kashtelyan

Differential Revision: D33062113

fbshipit-source-id: 3166c11f6392c52039f1f08140e2d831772bba06
1 parent 9784454 commit f5278cc
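
In practice the migration is mechanical: code that previously constructed a `TorchXScheduler` now constructs Ax's own `Scheduler` directly, since all TorchX-specific polling has moved onto `TorchXRunner`. A minimal sketch of the new wiring, assuming `experiment` is an Ax `Experiment` whose runner is a `TorchXRunner` (the experiment and runner setup is elided here and illustrative only):

# Migration sketch: TorchXScheduler is gone; use Ax's Scheduler directly.
# `experiment` is assumed to be an ax.core.Experiment with a TorchXRunner
# attached as its runner, as in the example diff below.
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.service.scheduler import Scheduler, SchedulerOptions

scheduler = Scheduler(
    experiment=experiment,  # Experiment with runner=TorchXRunner(...)
    generation_strategy=choose_generation_strategy(
        search_space=experiment.search_space,
    ),
    options=SchedulerOptions(),
)
scheduler.run_all_trials()  # status polling goes through TorchXRunner.poll_trial_status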

6 files changed (+17 additions, -59 deletions)


dev-requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 aiobotocore==2.1.0
-ax-platform[mysql]==0.2.2
+ax-platform[mysql]==0.2.3
 black==21.10b0
 boto3==1.20.24
 captum>=0.4.0

docs/source/runtime/hpo.rst

Lines changed: 1 addition & 2 deletions
@@ -14,5 +14,4 @@ Ax (Adaptive Experimentation)
 .. currentmodule:: torchx.runtime.hpo.ax

 .. autoclass:: TorchXRunner
-.. autoclass:: TorchXScheduler
-.. autoclass:: AppMetric
+.. autoclass:: AppMetric

torchx/cli/test/cmd_run_test.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def setUp(self) -> None:
         self.cmd_run.add_arguments(self.parser)

     def tearDown(self) -> None:
-        shutil.rmtree(self.tmpdir)
+        shutil.rmtree(self.tmpdir, ignore_errors=True)

     def test_run_with_user_conf_abs_path(self) -> None:
         args = self.parser.parse_args(
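
The `tearDown` tweak is test hygiene riding along with the main change: `ignore_errors=True` keeps teardown from raising if the temp directory is already gone. A self-contained sketch of the pattern (the class name is hypothetical, for illustration only):

import shutil
import tempfile
import unittest

class TmpDirTest(unittest.TestCase):  # hypothetical name
    def setUp(self) -> None:
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self) -> None:
        # ignore_errors=True: do not fail if the directory was already
        # removed, e.g. by the test body itself.
        shutil.rmtree(self.tmpdir, ignore_errors=True)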

torchx/runtime/hpo/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -54,13 +54,13 @@
     SearchSpace,
 )
 from ax.modelbridge.dispatch_utils import choose_generation_strategy
-from ax.service.scheduler import SchedulerOptions
+from ax.service.scheduler import SchedulerOptions, Scheduler
 from ax.service.utils.best_point import get_best_parameters
 from ax.service.utils.report_utils import exp_to_df
 from ax.utils.common.constants import Keys
 from pyre_extensions import none_throws
 from torchx.components import utils
-from torchx.runtime.hpo.ax import AppMetric, TorchXRunner, TorchXScheduler
+from torchx.runtime.hpo.ax import AppMetric, TorchXRunner

 # Run HPO on the booth function (https://en.wikipedia.org/wiki/Test_functions_for_optimization)

@@ -100,7 +100,7 @@
     properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
 )

-scheduler = TorchXScheduler(
+scheduler = Scheduler(
     experiment=experiment,
     generation_strategy=(
         choose_generation_strategy(
torchx/runtime/hpo/ax.py

Lines changed: 5 additions & 46 deletions
@@ -6,16 +6,15 @@
 # LICENSE file in the root directory of this source tree.

 import inspect
-from typing import Iterable, Any, Callable, Dict, Mapping, Optional, Set, cast
+from typing import Any, Callable, Dict, Mapping, Optional, Set, cast, Iterable

 import pandas as pd
 from ax.core import Trial
 from ax.core.base_trial import BaseTrial
 from ax.core.data import Data
 from ax.core.metric import Metric
 from ax.core.runner import Runner as ax_Runner
-from ax.service.scheduler import Scheduler as ax_Scheduler, TrialStatus
-from ax.utils.common.typeutils import not_none
+from ax.service.scheduler import TrialStatus
 from pyre_extensions import none_throws
 from torchx.runner import Runner, get_runner
 from torchx.runtime.tracking import FsspecResultTracker
@@ -209,14 +208,14 @@ def run(self, trial: BaseTrial) -> Dict[str, Any]:
     def poll_trial_status(
         self, trials: Iterable[BaseTrial]
     ) -> Dict[TrialStatus, Set[int]]:
-        """Returns the statuses of the given trials."""
         trial_statuses: Dict[TrialStatus, Set[int]] = {}

         for trial in trials:
             app_handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
-            app_status: Optional[AppStatus] = self._torchx_runner.status(app_handle)
-            assert app_status is not None
+            torchx_runner = trial.run_metadata[_TORCHX_RUNNER]
+            app_status: AppStatus = torchx_runner.status(app_handle)
             trial_status = APP_STATE_TO_TRIAL_STATUS[app_status.state]
+
             indices = trial_statuses.setdefault(trial_status, set())
             indices.add(trial.index)

@@ -227,43 +226,3 @@ def stop(self, trial: BaseTrial, reason: Optional[str] = None) -> Dict[str, Any]
         app_handle: str = trial.run_metadata[_TORCHX_APP_HANDLE]
         self._torchx_runner.stop(app_handle)
         return {"reason": reason} if reason else {}
-
-
-class TorchXScheduler(ax_Scheduler):
-    """
-    An implementation of an `Ax Scheduler <https://ax.dev/tutorials/scheduler.html>`_
-    that works with Experiments hooked up with the ``TorchXRunner``.
-
-    This scheduler is not a real scheduler but rather a facade scheduler
-    that delegates to scheduler clients for various remote/local schedulers.
-    For a list of supported schedulers please refer to TorchX
-    `scheduler docs <https://pytorch.org/torchx/latest/schedulers.html>`_.
-
-    """
-
-    def poll_trial_status(
-        self, poll_all_trial_statuses: bool = False
-    ) -> Dict[TrialStatus, Set[int]]:
-        return cast(TorchXRunner, self.experiment.runner).poll_trial_status(
-            self.running_trials
-        )
-
-    def poll_available_capacity(self) -> int:
-        """
-        Used when ``run_trials_in_batches`` option is set.
-        Since this scheduler is a faux scheduler, this method
-        always returns the ``max_parallelism`` of the current
-        step of this scheduler's ``generation_strategy``.
-
-        .. note:: The trials (jobs) are simply submitted to the
-            scheduler in parallel. Typically the trials will be
-            queued in the scheduler's job queue (on the server-side)
-            and executed according to the scheduler's job priority
-            and scheduling policies.
-
-        """
-        return (
-            -1
-            if self.generation_strategy._curr.max_parallelism is None
-            else not_none(self.generation_strategy._curr.max_parallelism)
-        )
torchx/runtime/hpo/test/ax_test.py

Lines changed: 6 additions & 6 deletions
@@ -22,14 +22,14 @@
     SearchSpace,
 )
 from ax.modelbridge.dispatch_utils import choose_generation_strategy
-from ax.service.scheduler import SchedulerOptions
+from ax.service.scheduler import SchedulerOptions, Scheduler
 from ax.service.utils.report_utils import exp_to_df
 from ax.utils.common.constants import Keys
 from torchx.components import utils
-from torchx.runtime.hpo.ax import AppMetric, TorchXRunner, TorchXScheduler
+from torchx.runtime.hpo.ax import AppMetric, TorchXRunner


-class TorchXSchedulerTest(unittest.TestCase):
+class TorchXAxTest(unittest.TestCase):
     def setUp(self) -> None:
         self.test_dir = tempfile.mkdtemp("torchx_runtime_hpo_ax_test")

@@ -84,7 +84,7 @@ def test_run_experiment_locally(self) -> None:

         # maybe add-on cfg into SchedulerOption?
         # so that we can pass it from one place
-        scheduler = TorchXScheduler(
+        scheduler = Scheduler(
             experiment=experiment,
             generation_strategy=(
                 choose_generation_strategy(
@@ -114,7 +114,7 @@ def test_stop_trials(self) -> None:
             is_test=True,
             properties={Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF: True},
         )
-        scheduler = TorchXScheduler(
+        scheduler = Scheduler(
             experiment=experiment,
             generation_strategy=(
                 choose_generation_strategy(
@@ -152,7 +152,7 @@ def test_run_experiment_locally_in_batches(self) -> None:

         # maybe add-on cfg into SchedulerOption?
         # so that we can pass it from one place
-        scheduler = TorchXScheduler(
+        scheduler = Scheduler(
             experiment=experiment,
             generation_strategy=(
                 choose_generation_strategy(
