
Commit b8bb9cb (1 parent: 4ea97be)

dist.ddp: make rendezvous work out of the box on all schedulers

22 files changed, +296 -76 lines
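In practical terms, multi-node dist.ddp jobs no longer need a pre-provisioned etcd server: the rendezvous backend now defaults to c10d and the endpoint falls back to the rank-0 host each scheduler advertises. A minimal sketch of the new call, assuming the post-commit torchx.components.dist.ddp signature shown in this diff:

    from torchx.components.dist import ddp

    # No rdzv_endpoint given: the backend defaults to "c10d" and the endpoint
    # resolves to the rank-0 host via the TORCHX_RANK0_HOST machinery below.
    app = ddp(script="main.py", j="2x2")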

scripts/awsbatchint.sh

Lines changed: 24 additions & 3 deletions

@@ -7,7 +7,28 @@
 
 set -ex
 
-RUN_ARGS="--scheduler aws_batch -c queue=torchx,image_repo=495572122715.dkr.ecr.us-west-2.amazonaws.com/torchx/integration-tests utils.echo"
+JOB="$USER-$(uuidgen)"
+DIR="/tmp/$JOB"
+
+function cleanup {
+  rm -r "$DIR"
+}
+trap cleanup EXIT
+
+mkdir "$DIR"
+cd "$DIR"
+
+cat <<EOT > .torchxconfig
+[aws_batch]
+queue=torchx
+image_repo=495572122715.dkr.ecr.us-west-2.amazonaws.com/torchx/integration-tests
+EOT
+
+cat <<EOT > main.py
+print("hello world!")
+EOT
+
+RUN_ARGS="--scheduler aws_batch dist.ddp -j 2x1 --script main.py"
 
 if [ -z "$AWS_ROLE_ARN" ]; then
   # only dryrun if no secrets
@@ -21,9 +42,9 @@ else
   torchx log "$APP_ID"
   LINES="$(torchx log "$APP_ID" | wc -l)"
 
-  if [ "$LINES" -ne 1 ]
+  if [ "$LINES" -ne 2 ]
   then
-    echo "expected 1 log lines"
+    echo "expected 2 log lines"
     exit 1
   fi
 fi

scripts/slurmint.sh

Lines changed: 3 additions & 3 deletions

@@ -20,7 +20,7 @@ if [[ -z "${SLURM_INSTANCE_MASTER}" ]]; then
 fi
 
 JOB="$USER-$(uuidgen)"
-DIR="/tmp/$JOB"
+DIR="/home/ubuntu/integ-tests/$JOB"
 VENV="$DIR/venv"
 
 function run_cmd {
@@ -44,8 +44,8 @@ REMOTE_WHEEL="$DIR/$(basename "$WHEEL")"
 SCRIPT="scripts/slurmtest.sh"
 REMOTE_SCRIPT="$DIR/$(basename "$SCRIPT")"
 
-run_cmd mkdir "$DIR"
-run_cmd virtualenv -p /usr/bin/python3.8 "$VENV"
+run_cmd mkdir -p "$DIR"
+run_cmd virtualenv -p /home/ubuntu/miniconda3/bin/python "$VENV"
 run_scp "$WHEEL" "$REMOTE_WHEEL"
 run_scp "$SCRIPT" "$REMOTE_SCRIPT"
 run_cmd "$REMOTE_SCRIPT" "$REMOTE_WHEEL" "$VENV"

scripts/slurmtest.sh

Lines changed: 24 additions & 4 deletions

@@ -10,23 +10,43 @@ set -ex
 REMOTE_WHEEL="$1"
 VENV="$2"
 
+BASE_DIR="$(dirname "$REMOTE_WHEEL")"
+DIR="$BASE_DIR/project"
+mkdir "$DIR"
+cd "$DIR"
+
 # shellcheck disable=SC1091
 source /opt/slurm/etc/slurm.sh
 sbatch --version
 # shellcheck disable=SC1090
 source "$VENV"/bin/activate
 python --version
+
 pip install "$REMOTE_WHEEL"
+pip install numpy
+pip install torch==1.10.2+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+
+cat <<EOT > .torchxconfig
+[slurm]
+partition=compute
+time=10
+comment=hello
+nomem=true
+EOT
+
+cat <<EOT > main.py
+print("hello world!")
+EOT
 
-APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10,comment=hello utils.echo --num_replicas 3)"
+APP_ID="$(torchx run --wait --log --scheduler slurm dist.ddp -j 2x1 --script main.py)"
 torchx status "$APP_ID"
 torchx describe "$APP_ID"
 sacct -j "$(basename "$APP_ID")"
 torchx log "$APP_ID"
-LINES="$(torchx log "$APP_ID" | wc -l)"
+LINES="$(torchx log "$APP_ID" | grep 'hello world' | wc -l)"
 
-if [ "$LINES" -ne 3 ]
+if [ "$LINES" -ne 2 ]
 then
-  echo "expected 3 log lines"
+  echo "expected 2 log lines"
   exit 1
 fi

torchx/components/dist.py

Lines changed: 54 additions & 20 deletions

@@ -121,8 +121,9 @@
 Components APIs
 -----------------
 """
+import shlex
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Dict, Optional, Iterable
 
 import torchx
 import torchx.specs as specs
@@ -140,7 +141,8 @@ def ddp(
     h: Optional[str] = None,
     j: str = "1x2",
     env: Optional[Dict[str, str]] = None,
-    rdzv_endpoint: str = "etcd-server.default.svc.cluster.local:2379",
+    rdzv_backend: str = "c10d",
+    rdzv_endpoint: Optional[str] = None,
 ) -> specs.AppDef:
     """
     Distributed data parallel style application (one role, multi-replica).
@@ -162,7 +164,8 @@ def ddp(
         h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
         j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
         env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
-        rdzv_endpoint: etcd server endpoint (only matters when nnodes > 1)
+        rdzv_backend: rendezvous backend (only matters when nnodes > 1)
+        rdzv_endpoint: rendezvous server endpoint (only matters when nnodes > 1), defaults to rank0 host for schedulers that support it
     """
 
     rep = j.split("x")
@@ -176,32 +179,63 @@ def ddp(
         raise ValueError(f"Invalid format for -j, usage example: 1x4. Given: {j}")
 
     script_name_noext = Path(script).stem  # script name no extension
+
+    if rdzv_endpoint is None:
+        rdzv_endpoint = _noquote(f"$${macros.rank0_env}:29500")
+
+    if nnodes == 1:
+        rdzv_backend = "c10d"
+        rdzv_endpoint = "localhost:29500"
+
+    cmd = [
+        "python",
+        "-m",
+        "torch.distributed.run",
+        "--rdzv_backend",
+        rdzv_backend,
+        "--rdzv_endpoint",
+        rdzv_endpoint,
+        "--rdzv_id",
+        f"{macros.app_id}",
+        "--nnodes",
+        str(nnodes),
+        "--nproc_per_node",
+        str(nproc_per_node),
+        script,
+        *script_args,
+    ]
     return specs.AppDef(
         name=name or script_name_noext,
         roles=[
             specs.Role(
                 name=script_name_noext,
                 image=image,
-                entrypoint="python",
+                entrypoint="bash",
                 num_replicas=nnodes,
                 resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
-                args=[
-                    "-m",
-                    "torch.distributed.run",
-                    "--rdzv_backend",
-                    ("c10d" if nnodes == 1 else "etcd"),
-                    "--rdzv_endpoint",
-                    ("localhost:29500" if nnodes == 1 else rdzv_endpoint),
-                    "--rdzv_id",
-                    f"{macros.app_id}",
-                    "--nnodes",
-                    str(nnodes),
-                    "--nproc_per_node",
-                    str(nproc_per_node),
-                    script,
-                    *script_args,
-                ],
+                args=["-c", _args_join(cmd)],
                 env=env or {},
+                port_map={
+                    "c10d": 29500,
+                },
             )
         ],
     )
+
+
+def _args_join(args: Iterable[str]) -> str:
+    """
+    _args_join is like shlex.join but if the argument is wrapped in _noquote
+    it'll not quote that argument.
+    """
+    quoted = [arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args]
+    return " ".join(quoted)
+
+
+class _noquote(str):
+    """
+    _noquote is a wrapper around str that indicates that the argument shouldn't
+    be passed through shlex.quote.
+    """
+
+    pass
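The bash -c plus _args_join indirection exists because the rendered command must keep a literal shell reference like $TORCHX_RANK0_HOST that each node expands at runtime, while every other argument stays defensively quoted. A self-contained sketch of that behavior, with the two helpers copied from the diff above:

    import shlex
    from typing import Iterable


    class _noquote(str):
        """Marks an argument that must bypass shlex.quote."""

        pass


    def _args_join(args: Iterable[str]) -> str:
        return " ".join(
            arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args
        )


    # Ordinary args get quoted; the _noquote env reference is left for bash.
    print(_args_join(["echo", "hello world!", _noquote("$TORCHX_RANK0_HOST:29500")]))
    # -> echo 'hello world!' $TORCHX_RANK0_HOST:29500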

torchx/components/integration_tests/component_provider.py

Lines changed: 1 addition & 2 deletions

@@ -36,12 +36,11 @@ def tearDown(self) -> None:
 
 class DDPComponentProvider(ComponentProvider):
     def get_app_def(self) -> AppDef:
-        rdzv_endpoint: str = "localhost:29400"
         return dist_components.ddp(
             script="torchx/components/integration_tests/test/dummy_app.py",
             name="ddp-trainer",
             image=self._image,
-            rdzv_endpoint=rdzv_endpoint,
+            j="2x2",
         )
 
 
torchx/schedulers/api.py

Lines changed: 14 additions & 1 deletion

@@ -24,6 +24,7 @@
     SchedulerBackend,
     runopts,
 )
+from torchx.workspace.api import Workspace
 
 
 class Stream(str, Enum):
@@ -90,13 +91,25 @@ def close(self) -> None:
         """
         pass
 
-    def submit(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> str:
+    def submit(
+        self,
+        app: AppDef,
+        cfg: Mapping[str, CfgVal],
+        workspace: Optional[str] = None,
+    ) -> str:
         """
         Submits the application to be run by the scheduler.
 
+        WARNING: Mostly used for tests. Users should prefer to use the TorchX runner instead.
+
         Returns:
             The application id that uniquely identifies the submitted app.
         """
+        if workspace:
+            sched = self
+            assert isinstance(sched, Workspace)
+            role = app.roles[0]
+            sched.build_workspace_and_update_role(role, workspace)
         dryrun_info = self.submit_dryrun(app, cfg)
         return self.schedule(dryrun_info)
 
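A hedged usage sketch of the new parameter (sched, app, and the path are illustrative placeholders; the workspace branch only fires when the scheduler instance also implements Workspace):

    # Roughly: rebuild app.roles[0] from the local project dir, then schedule.
    app_id = sched.submit(app, cfg={}, workspace="/path/to/project")

    # Without a workspace the behavior is unchanged:
    app_id = sched.submit(app, cfg={})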

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 13 additions & 1 deletion

@@ -223,16 +223,28 @@ def _submit_dryrun(
 
         for role_idx, role in enumerate(app.roles):
            for replica_id in range(role.num_replicas):
+                rank = len(nodes)
                 values = macros.Values(
                     img_root="",
                     app_id=name,
                     replica_id=str(replica_id),
+                    rank0_env=(
+                        "TORCHX_RANK0_HOST"
+                        if rank == 0
+                        else "AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"
+                    ),
                 )
                 replica_role = values.apply(role)
                 replica_role.env["TORCHX_ROLE_IDX"] = str(role_idx)
                 replica_role.env["TORCHX_ROLE_NAME"] = str(role.name)
                 replica_role.env["TORCHX_REPLICA_IDX"] = str(replica_id)
-                nodes.append(role_to_node_properties(len(nodes), replica_role))
+                if rank == 0:
+                    # AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS is only
+                    # available on the child workers so we set the address to
+                    # localhost for rank0.
+                    # See: https://docs.aws.amazon.com/batch/latest/userguide/job_env_vars.html
+                    replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
+                nodes.append(role_to_node_properties(rank, replica_role))
 
         req = BatchJob(
             name=name,
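Tracing the endpoint through macro substitution makes the two env names concrete: the command built in dist.ddp embeds $$${rank0_env}:29500, and substitution collapses $$ to a single $, leaving each node a shell reference to the one variable it can actually read. Python's stdlib string.Template is used below purely as a stand-in for the substitution performed by torchx.specs macros:

    from string import Template

    rank0_env_macro = "${rank0_env}"  # stand-in for specs.macros.rank0_env
    arg = f"--rdzv_endpoint $${rank0_env_macro}:29500"  # as built in dist.ddp

    for rank, env in [
        (0, "TORCHX_RANK0_HOST"),
        (1, "AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"),
    ]:
        print(rank, Template(arg).safe_substitute(rank0_env=env))
    # 0 --rdzv_endpoint $TORCHX_RANK0_HOST:29500
    # 1 --rdzv_endpoint $AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS:29500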

torchx/schedulers/docker_scheduler.py

Lines changed: 5 additions & 0 deletions

@@ -185,12 +185,14 @@ def _submit_dryrun(
 
         app_id = make_unique(app.name)
         req = DockerJob(app_id=app_id, containers=[])
+        rank0_name = f"{app_id}-{app.roles[0].name}-0"
         for role in app.roles:
             for replica_id in range(role.num_replicas):
                 values = macros.Values(
                     img_root="",
                     app_id=app_id,
                     replica_id=str(replica_id),
+                    rank0_env="TORCHX_RANK0_HOST",
                 )
                 replica_role = values.apply(role)
                 name = f"{app_id}-{role.name}-{replica_id}"
@@ -199,6 +201,9 @@
                 if replica_role.env:
                     env.update(replica_role.env)
 
+                # configure distributed host envs
+                env["TORCHX_RANK0_HOST"] = rank0_name
+
                 c = DockerContainer(
                     image=replica_role.image,
                     command=[replica_role.entrypoint] + replica_role.args,

torchx/schedulers/kubernetes_scheduler.py

Lines changed: 5 additions & 0 deletions

@@ -254,9 +254,14 @@ def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]:
                 img_root="",
                 app_id=unique_app_id,
                 replica_id=str(replica_id),
+                rank0_env=f"VC_{app.roles[0].name}_0_HOSTS".upper(),
             )
+            if role_idx == 0 and replica_id == 0:
+                values.rank0_env = "TORCHX_RANK0_HOST"
             name = cleanup_str(f"{role.name}-{replica_id}")
             replica_role = values.apply(role)
+            if role_idx == 0 and replica_id == 0:
+                replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
 
             pod = role_to_pod(name, replica_role)
             pod.metadata.labels.update(pod_labels(app, role_idx, role, replica_id))
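For reference, the Volcano hosts variable this computes for a hypothetical first role named "trainer" (plain Python, just to show the naming):

    print(f"VC_{'trainer'}_0_HOSTS".upper())  # -> VC_TRAINER_0_HOSTS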

torchx/schedulers/local_scheduler.py

Lines changed: 2 additions & 0 deletions

@@ -864,8 +864,10 @@ def _to_popen_request(
                 img_root=img_root,
                 app_id=app_id,
                 replica_id=str(replica_id),
+                rank0_env="TORCHX_RANK0_HOST",
             )
             replica_role = values.apply(role)
+            replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
 
             replica_log_dir = os.path.join(app_log_dir, role.name, str(replica_id))
             if "TORCHELASTIC_ERROR_FILE" not in replica_role.env:
