Skip to content

Commit cd1b088

Browse files
committed
slurm_scheduler: add more runopts. Fixes #389
1 parent 8b62ea8 commit cd1b088

File tree

3 files changed

+112
-38
lines changed

3 files changed

+112
-38
lines changed

scripts/slurmtest.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ source "$VENV"/bin/activate
1818
python --version
1919
pip install "$REMOTE_WHEEL"
2020

21-
APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10 utils.echo --num_replicas 3)"
21+
APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10,comment=hello utils.echo --num_replicas 3)"
2222
torchx status "$APP_ID"
2323
torchx describe "$APP_ID"
2424
sacct -j "$(basename "$APP_ID")"

torchx/schedulers/slurm_scheduler.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,19 @@
5454
"TIMEOUT": AppState.FAILED,
5555
}
5656

57-
SBATCH_OPTIONS = {
57+
SBATCH_JOB_OPTIONS = {
58+
"comment",
59+
"mail-user",
60+
"mail-type",
61+
}
62+
SBATCH_GROUP_OPTIONS = {
5863
"partition",
5964
"time",
65+
"constraint",
6066
}
6167

6268

69+
6370
def _apply_app_id_env(s: str) -> str:
6471
"""
6572
_apply_app_id_env escapes the argument and substitutes in the macros.app_id with
@@ -90,7 +97,7 @@ def from_role(
9097
for k, v in cfg.items():
9198
if v is None:
9299
continue
93-
if k in SBATCH_OPTIONS:
100+
if k in SBATCH_GROUP_OPTIONS:
94101
sbatch_opts[k] = str(v)
95102
sbatch_opts.setdefault("ntasks-per-node", "1")
96103
resource = role.resource
@@ -271,6 +278,26 @@ def run_opts(self) -> runopts:
271278
default=False,
272279
help="disables memory request to workaround https://github.com/aws/aws-parallelcluster/issues/2198",
273280
)
281+
opts.add(
282+
"comment",
283+
type_=str,
284+
help="Comment to set on the slurm job.",
285+
)
286+
opts.add(
287+
"constraint",
288+
type_=str,
289+
help="Constraint to use for the slurm job.",
290+
)
291+
opts.add(
292+
"mail-user",
293+
type_=str,
294+
help="User to mail on job end.",
295+
)
296+
opts.add(
297+
"mail-type",
298+
type_=str,
299+
help="What events to mail users on.",
300+
)
274301
return opts
275302

276303
def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
@@ -301,8 +328,14 @@ def _submit_dryrun(
301328
name = f"{role.name}-{replica_id}"
302329
replica_role = values.apply(role)
303330
replicas[name] = SlurmReplicaRequest.from_role(name, replica_role, cfg)
331+
cmd = ["sbatch", "--parsable"]
332+
333+
for k in SBATCH_JOB_OPTIONS:
334+
if k in cfg and cfg[k] is not None:
335+
cmd += [f"--{k}={cfg[k]}"]
336+
304337
req = SlurmBatchRequest(
305-
cmd=["sbatch", "--parsable"],
338+
cmd=cmd,
306339
replicas=replicas,
307340
)
308341
return AppDryRunInfo(req, repr)

torchx/schedulers/test/slurm_scheduler_test.py

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,28 +33,49 @@ def tmp_cwd() -> Generator[None, None, None]:
3333
finally:
3434
os.chdir(cwd)
3535

36+
def simple_role() -> specs.Role:
37+
return specs.Role(
38+
name="foo",
39+
image="/some/path",
40+
entrypoint="echo",
41+
args=["hello slurm", "test"],
42+
env={
43+
"FOO": "bar",
44+
},
45+
num_replicas=5,
46+
resource=specs.Resource(
47+
cpu=2,
48+
memMB=10,
49+
gpu=3,
50+
),
51+
)
52+
53+
def simple_app() -> specs.AppDef:
54+
return specs.AppDef(
55+
name="foo",
56+
roles=[
57+
specs.Role(
58+
name="a",
59+
image="/some/path",
60+
entrypoint="echo",
61+
args=[specs.macros.replica_id, f"hello {specs.macros.app_id}"],
62+
num_replicas=2,
63+
),
64+
specs.Role(
65+
name="b",
66+
image="/some/path",
67+
entrypoint="echo",
68+
),
69+
],
70+
)
3671

3772
class SlurmSchedulerTest(unittest.TestCase):
3873
def test_create_scheduler(self) -> None:
3974
scheduler = create_scheduler("foo")
4075
self.assertIsInstance(scheduler, SlurmScheduler)
4176

4277
def test_replica_request(self) -> None:
43-
role = specs.Role(
44-
name="foo",
45-
image="/some/path",
46-
entrypoint="echo",
47-
args=["hello slurm", "test"],
48-
env={
49-
"FOO": "bar",
50-
},
51-
num_replicas=5,
52-
resource=specs.Resource(
53-
cpu=2,
54-
memMB=10,
55-
gpu=3,
56-
),
57-
)
78+
role = simple_role()
5879
sbatch, srun = SlurmReplicaRequest.from_role(
5980
"role-0", role, cfg={}
6081
).materialize()
@@ -79,9 +100,9 @@ def test_replica_request(self) -> None:
79100
],
80101
)
81102

82-
# test nomem option
103+
def test_replica_request_nomem(self) -> None:
83104
sbatch, srun = SlurmReplicaRequest.from_role(
84-
"role-name", role, cfg={"nomem": True}
105+
"role-name", simple_role(), cfg={"nomem": True}
85106
).materialize()
86107
self.assertEqual(
87108
sbatch,
@@ -93,6 +114,15 @@ def test_replica_request(self) -> None:
93114
],
94115
)
95116

117+
def test_replica_request_constraint(self) -> None:
118+
sbatch, srun = SlurmReplicaRequest.from_role(
119+
"role-name", simple_role(), cfg={"constraint": "orange"}
120+
).materialize()
121+
self.assertIn(
122+
"--constraint=orange",
123+
sbatch,
124+
)
125+
96126
def test_replica_request_app_id(self) -> None:
97127
role = specs.Role(
98128
name="foo",
@@ -132,23 +162,7 @@ def test_replica_request_run_config(self) -> None:
132162

133163
def test_dryrun_multi_role(self) -> None:
134164
scheduler = create_scheduler("foo")
135-
app = specs.AppDef(
136-
name="foo",
137-
roles=[
138-
specs.Role(
139-
name="a",
140-
image="/some/path",
141-
entrypoint="echo",
142-
args=[specs.macros.replica_id, f"hello {specs.macros.app_id}"],
143-
num_replicas=2,
144-
),
145-
specs.Role(
146-
name="b",
147-
image="/some/path",
148-
entrypoint="echo",
149-
),
150-
],
151-
)
165+
app = simple_app()
152166
info = scheduler.submit_dryrun(app, cfg={})
153167
req = info.request
154168
self.assertIsInstance(req, SlurmBatchRequest)
@@ -344,3 +358,30 @@ def test_log_iter(self, run: MagicMock) -> None:
344358
)
345359
)
346360
self.assertEqual(logs, ["hello", "world"])
361+
362+
def test_dryrun_comment(self) -> None:
363+
scheduler = create_scheduler("foo")
364+
app = simple_app()
365+
info = scheduler.submit_dryrun(app, cfg={
366+
"comment": "banana foo bar",
367+
})
368+
self.assertIn(
369+
"--comment=banana foo bar",
370+
info.request.cmd,
371+
)
372+
373+
def test_dryrun_mail(self) -> None:
374+
scheduler = create_scheduler("foo")
375+
app = simple_app()
376+
info = scheduler.submit_dryrun(app, cfg={
377+
"mail-user": "[email protected]",
378+
"mail-type": "END",
379+
})
380+
self.assertIn(
381+
"--mail-user=[email protected]",
382+
info.request.cmd,
383+
)
384+
self.assertIn(
385+
"--mail-type=END",
386+
info.request.cmd,
387+
)

0 commit comments

Comments
 (0)