Skip to content

Commit e6285e0

Browse files
committed
slurm_scheduler: add more runopts. Fixes #389
1 parent 8b62ea8 commit e6285e0

File tree

3 files changed

+120
-38
lines changed

3 files changed

+120
-38
lines changed

scripts/slurmtest.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ source "$VENV"/bin/activate
1818
python --version
1919
pip install "$REMOTE_WHEEL"
2020

21-
APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10 utils.echo --num_replicas 3)"
21+
APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10,comment=hello utils.echo --num_replicas 3)"
2222
torchx status "$APP_ID"
2323
torchx describe "$APP_ID"
2424
sacct -j "$(basename "$APP_ID")"

torchx/schedulers/slurm_scheduler.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,15 @@
5454
"TIMEOUT": AppState.FAILED,
5555
}
5656

57-
SBATCH_OPTIONS = {
57+
SBATCH_JOB_OPTIONS = {
58+
"comment",
59+
"mail-user",
60+
"mail-type",
61+
}
62+
SBATCH_GROUP_OPTIONS = {
5863
"partition",
5964
"time",
65+
"constraint",
6066
}
6167

6268

@@ -90,7 +96,7 @@ def from_role(
9096
for k, v in cfg.items():
9197
if v is None:
9298
continue
93-
if k in SBATCH_OPTIONS:
99+
if k in SBATCH_GROUP_OPTIONS:
94100
sbatch_opts[k] = str(v)
95101
sbatch_opts.setdefault("ntasks-per-node", "1")
96102
resource = role.resource
@@ -271,6 +277,26 @@ def run_opts(self) -> runopts:
271277
default=False,
272278
help="disables memory request to workaround https://github.com/aws/aws-parallelcluster/issues/2198",
273279
)
280+
opts.add(
281+
"comment",
282+
type_=str,
283+
help="Comment to set on the slurm job.",
284+
)
285+
opts.add(
286+
"constraint",
287+
type_=str,
288+
help="Constraint to use for the slurm job.",
289+
)
290+
opts.add(
291+
"mail-user",
292+
type_=str,
293+
help="User to mail on job end.",
294+
)
295+
opts.add(
296+
"mail-type",
297+
type_=str,
298+
help="What events to mail users on.",
299+
)
274300
return opts
275301

276302
def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
@@ -301,8 +327,14 @@ def _submit_dryrun(
301327
name = f"{role.name}-{replica_id}"
302328
replica_role = values.apply(role)
303329
replicas[name] = SlurmReplicaRequest.from_role(name, replica_role, cfg)
330+
cmd = ["sbatch", "--parsable"]
331+
332+
for k in SBATCH_JOB_OPTIONS:
333+
if k in cfg and cfg[k] is not None:
334+
cmd += [f"--{k}={cfg[k]}"]
335+
304336
req = SlurmBatchRequest(
305-
cmd=["sbatch", "--parsable"],
337+
cmd=cmd,
306338
replicas=replicas,
307339
)
308340
return AppDryRunInfo(req, repr)

torchx/schedulers/test/slurm_scheduler_test.py

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,27 +34,51 @@ def tmp_cwd() -> Generator[None, None, None]:
3434
os.chdir(cwd)
3535

3636

37+
def simple_role() -> specs.Role:
38+
return specs.Role(
39+
name="foo",
40+
image="/some/path",
41+
entrypoint="echo",
42+
args=["hello slurm", "test"],
43+
env={
44+
"FOO": "bar",
45+
},
46+
num_replicas=5,
47+
resource=specs.Resource(
48+
cpu=2,
49+
memMB=10,
50+
gpu=3,
51+
),
52+
)
53+
54+
55+
def simple_app() -> specs.AppDef:
56+
return specs.AppDef(
57+
name="foo",
58+
roles=[
59+
specs.Role(
60+
name="a",
61+
image="/some/path",
62+
entrypoint="echo",
63+
args=[specs.macros.replica_id, f"hello {specs.macros.app_id}"],
64+
num_replicas=2,
65+
),
66+
specs.Role(
67+
name="b",
68+
image="/some/path",
69+
entrypoint="echo",
70+
),
71+
],
72+
)
73+
74+
3775
class SlurmSchedulerTest(unittest.TestCase):
3876
def test_create_scheduler(self) -> None:
3977
scheduler = create_scheduler("foo")
4078
self.assertIsInstance(scheduler, SlurmScheduler)
4179

4280
def test_replica_request(self) -> None:
43-
role = specs.Role(
44-
name="foo",
45-
image="/some/path",
46-
entrypoint="echo",
47-
args=["hello slurm", "test"],
48-
env={
49-
"FOO": "bar",
50-
},
51-
num_replicas=5,
52-
resource=specs.Resource(
53-
cpu=2,
54-
memMB=10,
55-
gpu=3,
56-
),
57-
)
81+
role = simple_role()
5882
sbatch, srun = SlurmReplicaRequest.from_role(
5983
"role-0", role, cfg={}
6084
).materialize()
@@ -79,9 +103,9 @@ def test_replica_request(self) -> None:
79103
],
80104
)
81105

82-
# test nomem option
106+
def test_replica_request_nomem(self) -> None:
83107
sbatch, srun = SlurmReplicaRequest.from_role(
84-
"role-name", role, cfg={"nomem": True}
108+
"role-name", simple_role(), cfg={"nomem": True}
85109
).materialize()
86110
self.assertEqual(
87111
sbatch,
@@ -93,6 +117,15 @@ def test_replica_request(self) -> None:
93117
],
94118
)
95119

120+
def test_replica_request_constraint(self) -> None:
121+
sbatch, srun = SlurmReplicaRequest.from_role(
122+
"role-name", simple_role(), cfg={"constraint": "orange"}
123+
).materialize()
124+
self.assertIn(
125+
"--constraint=orange",
126+
sbatch,
127+
)
128+
96129
def test_replica_request_app_id(self) -> None:
97130
role = specs.Role(
98131
name="foo",
@@ -132,23 +165,7 @@ def test_replica_request_run_config(self) -> None:
132165

133166
def test_dryrun_multi_role(self) -> None:
134167
scheduler = create_scheduler("foo")
135-
app = specs.AppDef(
136-
name="foo",
137-
roles=[
138-
specs.Role(
139-
name="a",
140-
image="/some/path",
141-
entrypoint="echo",
142-
args=[specs.macros.replica_id, f"hello {specs.macros.app_id}"],
143-
num_replicas=2,
144-
),
145-
specs.Role(
146-
name="b",
147-
image="/some/path",
148-
entrypoint="echo",
149-
),
150-
],
151-
)
168+
app = simple_app()
152169
info = scheduler.submit_dryrun(app, cfg={})
153170
req = info.request
154171
self.assertIsInstance(req, SlurmBatchRequest)
@@ -344,3 +361,36 @@ def test_log_iter(self, run: MagicMock) -> None:
344361
)
345362
)
346363
self.assertEqual(logs, ["hello", "world"])
364+
365+
def test_dryrun_comment(self) -> None:
366+
scheduler = create_scheduler("foo")
367+
app = simple_app()
368+
info = scheduler.submit_dryrun(
369+
app,
370+
cfg={
371+
"comment": "banana foo bar",
372+
},
373+
)
374+
self.assertIn(
375+
"--comment=banana foo bar",
376+
info.request.cmd,
377+
)
378+
379+
def test_dryrun_mail(self) -> None:
380+
scheduler = create_scheduler("foo")
381+
app = simple_app()
382+
info = scheduler.submit_dryrun(
383+
app,
384+
cfg={
385+
"mail-user": "[email protected]",
386+
"mail-type": "END",
387+
},
388+
)
389+
self.assertIn(
390+
391+
info.request.cmd,
392+
)
393+
self.assertIn(
394+
"--mail-type=END",
395+
info.request.cmd,
396+
)

0 commit comments

Comments
 (0)