Skip to content

Commit b379b81

Browse files
committed
cli: defer loading schedulers until requested by name
1 parent a551741 commit b379b81

File tree

13 files changed: 74 additions (+74) and 84 deletions (-84)

torchx/cli/cmd_cancel.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010

1111
from torchx.cli.cmd_base import SubCommand
1212
from torchx.runner import get_runner
13-
from torchx.specs.api import parse_app_handle
1413

1514
logger: logging.Logger = logging.getLogger(__name__)
1615

@@ -25,6 +24,5 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
2524

2625
def run(self, args: argparse.Namespace) -> None:
2726
app_handle = args.app_handle
28-
_, session_name, _ = parse_app_handle(app_handle)
29-
runner = get_runner(name=session_name)
27+
runner = get_runner()
3028
runner.cancel(app_handle)

torchx/cli/cmd_describe.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,15 +28,15 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
2828

2929
def run(self, args: argparse.Namespace) -> None:
3030
app_handle = args.app_handle
31-
scheduler, session_name, app_id = parse_app_handle(app_handle)
32-
runner = get_runner(name=session_name)
31+
scheduler, _, app_id = parse_app_handle(app_handle)
32+
runner = get_runner()
3333
app = runner.describe(app_handle)
3434

3535
if app:
3636
pprint.pprint(dataclasses.asdict(app), indent=2, width=80)
3737
else:
3838
logger.error(
39-
f"AppDef: {app_id} on session: {session_name},"
39+
f"AppDef: {app_id},"
4040
f" does not exist or has been removed from {scheduler}'s data plane"
4141
)
4242
sys.exit(1)

torchx/cli/cmd_log.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ def get_logs(
9292
role_name = path[2] if len(path) > 2 else None
9393

9494
if not runner:
95-
runner = get_runner(name=session_name)
95+
runner = get_runner()
9696
app_handle = make_app_handle(scheduler_backend, session_name, app_id)
9797

9898
if len(path) == 4:

torchx/cli/cmd_run.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -125,15 +125,14 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
125125
default=get_default_scheduler_name(),
126126
choices=list(scheduler_names),
127127
action=torchxconfig_run,
128-
help=f"Name of the scheduler to use. One of: [{','.join(scheduler_names)}]",
128+
help="Name of the scheduler to use.",
129129
)
130130
subparser.add_argument(
131131
"-cfg",
132132
"--scheduler_args",
133133
type=str,
134134
help="Arguments to pass to the scheduler (Ex:`cluster=foo,user=bar`)."
135-
" For a list of scheduler run options run: `torchx runopts`"
136-
"",
135+
" For a list of scheduler run options run: `torchx runopts`",
137136
)
138137
subparser.add_argument(
139138
"--dryrun",

torchx/cli/cmd_status.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -159,15 +159,15 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
159159

160160
def run(self, args: argparse.Namespace) -> None:
161161
app_handle = args.app_handle
162-
scheduler, session_name, app_id = parse_app_handle(app_handle)
163-
runner = get_runner(name=session_name)
162+
scheduler, _, app_id = parse_app_handle(app_handle)
163+
runner = get_runner()
164164
app_status = runner.status(app_handle)
165165
filter_roles = parse_list_arg(args.roles)
166166
if app_status:
167167
logger.info(format_app_status(app_status, filter_roles))
168168
else:
169169
logger.error(
170-
f"AppDef: {app_id} on session: {session_name},"
170+
f"AppDef: {app_id},"
171171
f" does not exist or has been removed from {scheduler}'s data plane"
172172
)
173173
sys.exit(1)

torchx/cli/main.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,6 @@ def get_sub_cmds() -> Dict[str, SubCommand]:
5757
override_sub_cmds = load_group(
5858
"torchx.cli.cmds",
5959
default={},
60-
ignore_missing=True,
6160
)
6261
for cmd_name, cmd_cls in override_sub_cmds.items():
6362
sub_cmds[cmd_name] = cmd_cls()

torchx/cli/test/main_test.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,13 @@
77

88
import argparse
99
import os
10+
import sys
1011
import unittest
1112
from pathlib import Path
1213
from unittest.mock import MagicMock, patch
1314

1415
from torchx.cli.cmd_base import SubCommand
15-
from torchx.cli.main import get_sub_cmds, main
16+
from torchx.cli.main import create_parser, get_sub_cmds, main
1617

1718

1819
_root: Path = Path(__file__).parent
@@ -68,6 +69,11 @@ def test_version(self) -> None:
6869
]
6970
)
7071

72+
def test_imports(self) -> None:
73+
parser = create_parser(get_sub_cmds())
74+
for scheduler in sys.modules:
75+
self.assertNotIn("local_scheduler", scheduler)
76+
7177
@patch("torchx.cli.main.load_group")
7278
def test_get_sub_cmds(self, load_group_mock: MagicMock) -> None:
7379
load_group_mock.return_value = {"run": _TestCmd}

torchx/runner/api.py

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
from pyre_extensions import none_throws
1717
from torchx.runner.events import log_event
18-
from torchx.schedulers import get_schedulers
18+
from torchx.schedulers import get_scheduler_factories, SchedulerFactory
1919
from torchx.schedulers.api import Scheduler, Stream
2020
from torchx.specs import (
2121
AppDef,
@@ -50,9 +50,9 @@ class Runner:
5050
def __init__(
5151
self,
5252
name: str,
53-
# pyre-fixme: Scheduler opts
54-
schedulers: Dict[str, Scheduler],
53+
schedulers_factories: Dict[str, SchedulerFactory],
5554
component_defaults: Optional[Dict[str, Dict[str, str]]] = None,
55+
scheduler_params: Optional[Dict[str, object]] = None,
5656
) -> None:
5757
"""
5858
Creates a new runner instance.
@@ -63,7 +63,10 @@ def __init__(
6363
schedulers: a list of schedulers the runner can use.
6464
"""
6565
self._name: str = name
66-
self._schedulers = schedulers
66+
self._scheduler_factories = schedulers_factories
67+
self._scheduler_params: Dict[str, object] = scheduler_params or {}
68+
# pyre-ignore[24]: Scheduler opts
69+
self._scheduler_instances: Dict[str, Scheduler] = {}
6770
self._apps: Dict[AppHandle, AppDef] = {}
6871

6972
# component_name -> map of component_fn_param_name -> user-specified default val encoded as str
@@ -96,7 +99,7 @@ def close(self) -> None:
9699
It is ok to call this method multiple times on the same runner object.
97100
"""
98101

99-
for name, scheduler in self._schedulers.items():
102+
for name, scheduler in self._scheduler_instances.items():
100103
scheduler.close()
101104

102105
def run_component(
@@ -319,15 +322,15 @@ def run_opts(self) -> Dict[str, runopts]:
319322
A map of scheduler backend to its ``runopts``
320323
"""
321324
return {
322-
scheduler_backend: scheduler.run_opts()
323-
for scheduler_backend, scheduler in self._schedulers.items()
325+
name: self._scheduler(name).run_opts()
326+
for name in self._scheduler_factories
324327
}
325328

326329
def scheduler_backends(self) -> List[str]:
327330
"""
328331
Returns a list of all supported scheduler backends.
329332
"""
330-
return list(self._schedulers.keys())
333+
return list(self._scheduler_factories.keys())
331334

332335
def status(self, app_handle: AppHandle) -> Optional[AppStatus]:
333336
"""
@@ -557,10 +560,15 @@ def log_lines(
557560

558561
# pyre-fixme: Scheduler opts
559562
def _scheduler(self, scheduler: str) -> Scheduler:
560-
sched = self._schedulers.get(scheduler)
563+
sched = self._scheduler_instances.get(scheduler)
564+
if not sched:
565+
factory = self._scheduler_factories.get(scheduler)
566+
if factory:
567+
sched = factory(self._name, **self._scheduler_params)
568+
self._scheduler_instances[scheduler] = sched
561569
if not sched:
562570
raise KeyError(
563-
f"Undefined scheduler backend: {scheduler}. Use one of: {self._schedulers.keys()}"
571+
f"Undefined scheduler backend: {scheduler}. Use one of: {self._scheduler_factories.keys()}"
564572
)
565573
return sched
566574

@@ -586,7 +594,7 @@ def _scheduler_app_id(
586594
return scheduler, scheduler_backend, app_id
587595

588596
def __repr__(self) -> str:
589-
return f"Runner(name={self._name}, schedulers={self._schedulers}, apps={self._apps})"
597+
return f"Runner(name={self._name}, schedulers={self._scheduler_factories}, apps={self._apps})"
590598

591599

592600
def get_runner(
@@ -633,5 +641,5 @@ def get_runner(
633641
if not name:
634642
name = "torchx"
635643

636-
schedulers = get_schedulers(session_name=name, **scheduler_params)
637-
return Runner(name, schedulers, component_defaults)
644+
schedulers = get_scheduler_factories()
645+
return Runner(name, schedulers, component_defaults, scheduler_params=scheduler_params)

torchx/schedulers/__init__.py

Lines changed: 20 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -5,35 +5,36 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
from typing import Dict, Optional
8+
import importlib
9+
from typing import Dict, Mapping
910

10-
import torchx.schedulers.aws_batch_scheduler as aws_batch_scheduler
11-
import torchx.schedulers.docker_scheduler as docker_scheduler
12-
import torchx.schedulers.kubernetes_scheduler as kubernetes_scheduler
13-
import torchx.schedulers.local_scheduler as local_scheduler
14-
import torchx.schedulers.slurm_scheduler as slurm_scheduler
1511
from torchx.schedulers.api import Scheduler
1612
from torchx.util.entrypoints import load_group
1713
from typing_extensions import Protocol
1814

15+
DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
16+
"local_docker": "torchx.schedulers.docker_scheduler",
17+
"local_cwd": "torchx.schedulers.local_scheduler",
18+
"slurm": "torchx.schedulers.slurm_scheduler",
19+
"kubernetes": "torchx.schedulers.kubernetes_scheduler",
20+
"aws_batch": "torchx.schedulers.aws_batch_scheduler",
21+
"ray": "torchx.schedulers.ray_scheduler",
22+
}
23+
1924

2025
class SchedulerFactory(Protocol):
2126
# pyre-fixme: Scheduler opts
2227
def __call__(self, session_name: str, **kwargs: object) -> Scheduler:
2328
...
2429

2530

26-
def _try_get_ray_scheduler() -> Optional[SchedulerFactory]:
27-
try:
28-
from torchx.schedulers.ray_scheduler import _has_ray # @manual
29-
30-
if _has_ray:
31-
import torchx.schedulers.ray_scheduler as ray_scheduler # @manual
32-
33-
return ray_scheduler.create_scheduler
31+
def _defer_load_scheduler(path: str) -> SchedulerFactory:
32+
# pyre-ignore[24]: Scheduler opts
33+
def run(*args: object, **kwargs: object) -> Scheduler:
34+
module = importlib.import_module(path)
35+
return module.create_scheduler(*args, **kwargs)
3436

35-
except ImportError: # pragma: no cover
36-
return None
37+
return run
3738

3839

3940
def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
@@ -44,22 +45,13 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
4445
The first scheduler in the dictionary is used as the default scheduler.
4546
"""
4647

47-
default_schedulers: Dict[str, SchedulerFactory] = {
48-
"local_docker": docker_scheduler.create_scheduler,
49-
"local_cwd": local_scheduler.create_cwd_scheduler,
50-
"slurm": slurm_scheduler.create_scheduler,
51-
"kubernetes": kubernetes_scheduler.create_scheduler,
52-
"aws_batch": aws_batch_scheduler.create_scheduler,
53-
}
54-
55-
ray_scheduler_creator = _try_get_ray_scheduler()
56-
if ray_scheduler_creator:
57-
default_schedulers["ray"] = ray_scheduler_creator
48+
default_schedulers: Dict[str, SchedulerFactory] = {}
49+
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
50+
default_schedulers[scheduler] = _defer_load_scheduler(path)
5851

5952
return load_group(
6053
"torchx.schedulers",
6154
default=default_schedulers,
62-
ignore_missing=True,
6355
)
6456

6557

torchx/schedulers/local_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1063,7 +1063,7 @@ def __next__(self) -> str:
10631063
return line
10641064

10651065

1066-
def create_cwd_scheduler(session_name: str, **kwargs: Any) -> LocalScheduler:
1066+
def create_scheduler(session_name: str, **kwargs: Any) -> LocalScheduler:
10671067
return LocalScheduler(
10681068
session_name=session_name,
10691069
cache_size=kwargs.get("cache_size", 100),

0 commit comments

Comments
 (0)