Skip to content

Commit 4f1ec46

Browse files
committed
dist.ddp: make rendezvous work out of the box on all schedulers
1 parent 4ea97be commit 4f1ec46

21 files changed

+329
-87
lines changed

scripts/awsbatchint.sh

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,28 @@
77

88
set -ex
99

10-
RUN_ARGS="--scheduler aws_batch -c queue=torchx,image_repo=495572122715.dkr.ecr.us-west-2.amazonaws.com/torchx/integration-tests utils.echo"
10+
JOB="$USER-$(uuidgen)"
11+
DIR="/tmp/$JOB"
12+
13+
function cleanup {
14+
rm -r "$DIR"
15+
}
16+
trap cleanup EXIT
17+
18+
mkdir "$DIR"
19+
cd "$DIR"
20+
21+
cat <<EOT > .torchxconfig
22+
[aws_batch]
23+
queue=torchx
24+
image_repo=495572122715.dkr.ecr.us-west-2.amazonaws.com/torchx/integration-tests
25+
EOT
26+
27+
cat <<EOT > main.py
28+
print("hello world!")
29+
EOT
30+
31+
RUN_ARGS="--scheduler aws_batch dist.ddp -j 2x1 --script main.py"
1132

1233
if [ -z "$AWS_ROLE_ARN" ]; then
1334
# only dryrun if no secrets
@@ -19,11 +40,11 @@ else
1940
torchx status "$APP_ID"
2041
torchx describe "$APP_ID"
2142
torchx log "$APP_ID"
22-
LINES="$(torchx log "$APP_ID" | wc -l)"
43+
LINES="$(torchx log "$APP_ID" | grep 'hello world' | wc -l)"
2344

24-
if [ "$LINES" -ne 1 ]
45+
if [ "$LINES" -ne 2 ]
2546
then
26-
echo "expected 1 log lines"
47+
echo "expected 2 log lines"
2748
exit 1
2849
fi
2950
fi

scripts/slurmint.sh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [[ -z "${SLURM_INSTANCE_MASTER}" ]]; then
2020
fi
2121

2222
JOB="$USER-$(uuidgen)"
23-
DIR="/tmp/$JOB"
23+
DIR="/home/ubuntu/integ-tests/$JOB"
2424
VENV="$DIR/venv"
2525

2626
function run_cmd {
@@ -44,8 +44,8 @@ REMOTE_WHEEL="$DIR/$(basename "$WHEEL")"
4444
SCRIPT="scripts/slurmtest.sh"
4545
REMOTE_SCRIPT="$DIR/$(basename "$SCRIPT")"
4646

47-
run_cmd mkdir "$DIR"
48-
run_cmd virtualenv -p /usr/bin/python3.8 "$VENV"
47+
run_cmd mkdir -p "$DIR"
48+
run_cmd virtualenv -p /home/ubuntu/miniconda3/bin/python "$VENV"
4949
run_scp "$WHEEL" "$REMOTE_WHEEL"
5050
run_scp "$SCRIPT" "$REMOTE_SCRIPT"
5151
run_cmd "$REMOTE_SCRIPT" "$REMOTE_WHEEL" "$VENV"

scripts/slurmtest.sh

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,43 @@ set -ex
1010
REMOTE_WHEEL="$1"
1111
VENV="$2"
1212

13+
BASE_DIR="$(dirname "$REMOTE_WHEEL")"
14+
DIR="$BASE_DIR/project"
15+
mkdir "$DIR"
16+
cd "$DIR"
17+
1318
# shellcheck disable=SC1091
1419
source /opt/slurm/etc/slurm.sh
1520
sbatch --version
1621
# shellcheck disable=SC1090
1722
source "$VENV"/bin/activate
1823
python --version
24+
1925
pip install "$REMOTE_WHEEL"
26+
pip install numpy
27+
pip install torch==1.10.2+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
28+
29+
cat <<EOT > .torchxconfig
30+
[slurm]
31+
partition=compute
32+
time=10
33+
comment=hello
34+
nomem=true
35+
EOT
36+
37+
cat <<EOT > main.py
38+
print("hello world!")
39+
EOT
2040

21-
APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10,comment=hello utils.echo --num_replicas 3)"
41+
APP_ID="$(torchx run --wait --log --scheduler slurm dist.ddp -j 2x1 --script main.py)"
2242
torchx status "$APP_ID"
2343
torchx describe "$APP_ID"
2444
sacct -j "$(basename "$APP_ID")"
2545
torchx log "$APP_ID"
26-
LINES="$(torchx log "$APP_ID" | wc -l)"
46+
LINES="$(torchx log "$APP_ID" | grep 'hello world' | wc -l)"
2747

28-
if [ "$LINES" -ne 3 ]
48+
if [ "$LINES" -ne 2 ]
2949
then
30-
echo "expected 3 log lines"
50+
echo "expected 2 log lines"
3151
exit 1
3252
fi

torchx/components/dist.py

Lines changed: 82 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@
121121
Components APIs
122122
-----------------
123123
"""
124+
import shlex
124125
from pathlib import Path
125-
from typing import Dict, Optional
126+
from typing import Dict, Optional, Iterable
126127

127128
import torchx
128129
import torchx.specs as specs
@@ -131,16 +132,19 @@
131132

132133
def ddp(
133134
*script_args: str,
134-
script: str,
135+
script: Optional[str] = None,
136+
m: Optional[str] = None,
135137
image: str = torchx.IMAGE,
136138
name: Optional[str] = None,
139+
h: Optional[str] = None,
137140
cpu: int = 2,
138141
gpu: int = 0,
139142
memMB: int = 1024,
140-
h: Optional[str] = None,
141143
j: str = "1x2",
142144
env: Optional[Dict[str, str]] = None,
143-
rdzv_endpoint: str = "etcd-server.default.svc.cluster.local:2379",
145+
max_restarts: Optional[int] = None,
146+
rdzv_backend: str = "c10d",
147+
rdzv_endpoint: Optional[str] = None,
144148
) -> specs.AppDef:
145149
"""
146150
Distributed data parallel style application (one role, multi-replica).
@@ -154,6 +158,7 @@ def ddp(
154158
Args:
155159
script_args: arguments to the main module
156160
script: script or binary to run within the image
161+
m: the python module path to run
157162
image: image (e.g. docker)
158163
name: job name override (uses the script name if not specified)
159164
cpu: number of cpus per replica
@@ -162,9 +167,14 @@ def ddp(
162167
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
163168
j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
164169
env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
165-
rdzv_endpoint: etcd server endpoint (only matters when nnodes > 1)
170+
max_restarts: the number of restarts allowed
171+
rdzv_backend: rendezvous backend (only matters when nnodes > 1)
172+
rdzv_endpoint: rendezvous server endpoint (only matters when nnodes > 1), defaults to rank0 host for schedulers that support it
166173
"""
167174

175+
if (script is None) == (m is None):
176+
raise ValueError("exactly one of --script and -m must be specified")
177+
168178
rep = j.split("x")
169179
if len(rep) == 1: # num replicas only
170180
nnodes = 1
@@ -175,33 +185,79 @@ def ddp(
175185
else:
176186
raise ValueError(f"Invalid format for -j, usage example: 1x4. Given: {j}")
177187

178-
script_name_noext = Path(script).stem # script name no extension
188+
if script:
189+
# script name/module no extension
190+
role_name = Path(script).stem
191+
elif m:
192+
role_name = m.rpartition(".")[2]
193+
else:
194+
raise ValueError("failed to compute role_name")
195+
196+
if rdzv_endpoint is None:
197+
rdzv_endpoint = _noquote(f"$${macros.rank0_env}:29500")
198+
199+
if nnodes == 1:
200+
rdzv_backend = "c10d"
201+
rdzv_endpoint = "localhost:29500"
202+
203+
if env is None:
204+
env = {}
205+
env.setdefault("LOGLEVEL", "INFO")
206+
207+
cmd = [
208+
"python",
209+
"-m",
210+
"torch.distributed.run",
211+
"--rdzv_backend",
212+
rdzv_backend,
213+
"--rdzv_endpoint",
214+
rdzv_endpoint,
215+
"--rdzv_id",
216+
f"{macros.app_id}",
217+
"--nnodes",
218+
str(nnodes),
219+
"--nproc_per_node",
220+
str(nproc_per_node),
221+
]
222+
if max_restarts is not None:
223+
cmd += ["--max_restarts", str(max_restarts)]
224+
if script is not None:
225+
cmd += [script]
226+
elif m is not None:
227+
cmd += ["-m", m]
228+
cmd += script_args
179229
return specs.AppDef(
180-
name=name or script_name_noext,
230+
name=name or role_name,
181231
roles=[
182232
specs.Role(
183-
name=script_name_noext,
233+
name=role_name,
184234
image=image,
185-
entrypoint="python",
235+
entrypoint="bash",
186236
num_replicas=nnodes,
187237
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
188-
args=[
189-
"-m",
190-
"torch.distributed.run",
191-
"--rdzv_backend",
192-
("c10d" if nnodes == 1 else "etcd"),
193-
"--rdzv_endpoint",
194-
("localhost:29500" if nnodes == 1 else rdzv_endpoint),
195-
"--rdzv_id",
196-
f"{macros.app_id}",
197-
"--nnodes",
198-
str(nnodes),
199-
"--nproc_per_node",
200-
str(nproc_per_node),
201-
script,
202-
*script_args,
203-
],
204-
env=env or {},
238+
args=["-c", _args_join(cmd)],
239+
env=env,
240+
port_map={
241+
"c10d": 29500,
242+
},
205243
)
206244
],
207245
)
246+
247+
248+
def _args_join(args: Iterable[str]) -> str:
249+
"""
250+
_args_join is like shlex.join but if the argument is wrapped in _noquote
251+
it'll not quote that argument.
252+
"""
253+
quoted = [arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args]
254+
return " ".join(quoted)
255+
256+
257+
class _noquote(str):
258+
"""
259+
_noquote is a wrapper around str that indicates that the argument shouldn't
260+
be passed through shlex.quote.
261+
"""
262+
263+
pass

torchx/components/integration_tests/component_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ def tearDown(self) -> None:
3636

3737
class DDPComponentProvider(ComponentProvider):
3838
def get_app_def(self) -> AppDef:
39-
rdzv_endpoint: str = "localhost:29400"
4039
return dist_components.ddp(
4140
script="torchx/components/integration_tests/test/dummy_app.py",
4241
name="ddp-trainer",
4342
image=self._image,
44-
rdzv_endpoint=rdzv_endpoint,
43+
j="2x2",
44+
max_restarts=3,
4545
)
4646

4747

torchx/schedulers/api.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
SchedulerBackend,
2525
runopts,
2626
)
27+
from torchx.workspace.api import Workspace
2728

2829

2930
class Stream(str, Enum):
@@ -90,13 +91,25 @@ def close(self) -> None:
9091
"""
9192
pass
9293

93-
def submit(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> str:
94+
def submit(
95+
self,
96+
app: AppDef,
97+
cfg: Mapping[str, CfgVal],
98+
workspace: Optional[str] = None,
99+
) -> str:
94100
"""
95101
Submits the application to be run by the scheduler.
96102
103+
WARNING: Mostly used for tests. Users should prefer to use the TorchX runner instead.
104+
97105
Returns:
98106
The application id that uniquely identifies the submitted app.
99107
"""
108+
if workspace:
109+
sched = self
110+
assert isinstance(sched, Workspace)
111+
role = app.roles[0]
112+
sched.build_workspace_and_update_role(role, workspace)
100113
dryrun_info = self.submit_dryrun(app, cfg)
101114
return self.schedule(dryrun_info)
102115

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,28 @@ def _submit_dryrun(
223223

224224
for role_idx, role in enumerate(app.roles):
225225
for replica_id in range(role.num_replicas):
226+
rank = len(nodes)
226227
values = macros.Values(
227228
img_root="",
228229
app_id=name,
229230
replica_id=str(replica_id),
231+
rank0_env=(
232+
"TORCHX_RANK0_HOST"
233+
if rank == 0
234+
else "AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"
235+
),
230236
)
231237
replica_role = values.apply(role)
232238
replica_role.env["TORCHX_ROLE_IDX"] = str(role_idx)
233239
replica_role.env["TORCHX_ROLE_NAME"] = str(role.name)
234240
replica_role.env["TORCHX_REPLICA_IDX"] = str(replica_id)
235-
nodes.append(role_to_node_properties(len(nodes), replica_role))
241+
if rank == 0:
242+
# AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS is only
243+
# available on the child workers so we set the address to
244+
# localhost for rank0.
245+
# See: https://docs.aws.amazon.com/batch/latest/userguide/job_env_vars.html
246+
replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
247+
nodes.append(role_to_node_properties(rank, replica_role))
236248

237249
req = BatchJob(
238250
name=name,

torchx/schedulers/docker_scheduler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,14 @@ def _submit_dryrun(
185185

186186
app_id = make_unique(app.name)
187187
req = DockerJob(app_id=app_id, containers=[])
188+
rank0_name = f"{app_id}-{app.roles[0].name}-0"
188189
for role in app.roles:
189190
for replica_id in range(role.num_replicas):
190191
values = macros.Values(
191192
img_root="",
192193
app_id=app_id,
193194
replica_id=str(replica_id),
195+
rank0_env="TORCHX_RANK0_HOST",
194196
)
195197
replica_role = values.apply(role)
196198
name = f"{app_id}-{role.name}-{replica_id}"
@@ -199,6 +201,9 @@ def _submit_dryrun(
199201
if replica_role.env:
200202
env.update(replica_role.env)
201203

204+
# configure distributed host envs
205+
env["TORCHX_RANK0_HOST"] = rank0_name
206+
202207
c = DockerContainer(
203208
image=replica_role.image,
204209
command=[replica_role.entrypoint] + replica_role.args,

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,14 @@ def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]:
254254
img_root="",
255255
app_id=unique_app_id,
256256
replica_id=str(replica_id),
257+
rank0_env=f"VC_{cleanup_str(app.roles[0].name)}_0_HOSTS".upper(),
257258
)
259+
if role_idx == 0 and replica_id == 0:
260+
values.rank0_env = "TORCHX_RANK0_HOST"
258261
name = cleanup_str(f"{role.name}-{replica_id}")
259262
replica_role = values.apply(role)
263+
if role_idx == 0 and replica_id == 0:
264+
replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
260265

261266
pod = role_to_pod(name, replica_role)
262267
pod.metadata.labels.update(pod_labels(app, role_idx, role, replica_id))

0 commit comments

Comments (0)