dist.ddp: make rendezvous work out of the box on all schedulers #400

Closed · wants to merge 1 commit
29 changes: 25 additions & 4 deletions scripts/awsbatchint.sh
@@ -7,7 +7,28 @@

set -ex

-RUN_ARGS="--scheduler aws_batch -c queue=torchx,image_repo=495572122715.dkr.ecr.us-west-2.amazonaws.com/torchx/integration-tests utils.echo"
+JOB="$USER-$(uuidgen)"
+DIR="/tmp/$JOB"

+function cleanup {
+rm -r "$DIR"
+}
+trap cleanup EXIT

+mkdir "$DIR"
+cd "$DIR"

+cat <<EOT > .torchxconfig
+[aws_batch]
+queue=torchx
+image_repo=495572122715.dkr.ecr.us-west-2.amazonaws.com/torchx/integration-tests
+EOT

+cat <<EOT > main.py
+print("hello world!")
+EOT

+RUN_ARGS="--scheduler aws_batch dist.ddp -j 2x1 --script main.py"

if [ -z "$AWS_ROLE_ARN" ]; then
# only dryrun if no secrets
@@ -19,11 +40,11 @@ else
torchx status "$APP_ID"
torchx describe "$APP_ID"
torchx log "$APP_ID"
-LINES="$(torchx log "$APP_ID" | wc -l)"
+LINES="$(torchx log "$APP_ID" | grep -c 'hello world')"

-if [ "$LINES" -ne 1 ]
+if [ "$LINES" -ne 2 ]
then
-echo "expected 1 log lines"
+echo "expected 2 log lines"
exit 1
fi
fi
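For comparison, a minimal local sketch of the same component invocation the integration script above now exercises. This is not part of the PR; it assumes a running Docker daemon, torchx installed from this change, and uses a made-up script name and job size:

```bash
# Local smoke test mirroring the AWS Batch integration script (illustrative only).
cat <<EOT > main.py
print("hello world!")
EOT

# Two nodes x one process each; rendezvous now defaults sensibly, no extra flags needed.
torchx run --wait --log --scheduler local_docker dist.ddp -j 2x1 --script main.py
```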
6 changes: 3 additions & 3 deletions scripts/slurmint.sh
@@ -20,7 +20,7 @@ if [[ -z "${SLURM_INSTANCE_MASTER}" ]]; then
fi

JOB="$USER-$(uuidgen)"
DIR="/tmp/$JOB"
DIR="/home/ubuntu/integ-tests/$JOB"
VENV="$DIR/venv"

function run_cmd {
@@ -44,8 +44,8 @@ REMOTE_WHEEL="$DIR/$(basename "$WHEEL")"
SCRIPT="scripts/slurmtest.sh"
REMOTE_SCRIPT="$DIR/$(basename "$SCRIPT")"

-run_cmd mkdir "$DIR"
-run_cmd virtualenv -p /usr/bin/python3.8 "$VENV"
+run_cmd mkdir -p "$DIR"
+run_cmd virtualenv -p /home/ubuntu/miniconda3/bin/python "$VENV"
run_scp "$WHEEL" "$REMOTE_WHEEL"
run_scp "$SCRIPT" "$REMOTE_SCRIPT"
run_cmd "$REMOTE_SCRIPT" "$REMOTE_WHEEL" "$VENV"
28 changes: 24 additions & 4 deletions scripts/slurmtest.sh
@@ -10,23 +10,43 @@ set -ex
REMOTE_WHEEL="$1"
VENV="$2"

+BASE_DIR="$(dirname "$REMOTE_WHEEL")"
+DIR="$BASE_DIR/project"
+mkdir "$DIR"
+cd "$DIR"

# shellcheck disable=SC1091
source /opt/slurm/etc/slurm.sh
sbatch --version
# shellcheck disable=SC1090
source "$VENV"/bin/activate
python --version

pip install "$REMOTE_WHEEL"
+pip install numpy
+pip install torch==1.10.2+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html

+cat <<EOT > .torchxconfig
+[slurm]
+partition=compute
+time=10
+comment=hello
+nomem=true
+EOT

+cat <<EOT > main.py
+print("hello world!")
+EOT

APP_ID="$(torchx run --wait --scheduler slurm --scheduler_args partition=compute,time=10,comment=hello utils.echo --num_replicas 3)"
APP_ID="$(torchx run --wait --log --scheduler slurm dist.ddp -j 2x1 --script main.py)"
torchx status "$APP_ID"
torchx describe "$APP_ID"
sacct -j "$(basename "$APP_ID")"
torchx log "$APP_ID"
-LINES="$(torchx log "$APP_ID" | wc -l)"
+LINES="$(torchx log "$APP_ID" | grep -c 'hello world')"

-if [ "$LINES" -ne 3 ]
+if [ "$LINES" -ne 2 ]
then
-echo "expected 3 log lines"
+echo "expected 2 log lines"
exit 1
fi
108 changes: 82 additions & 26 deletions torchx/components/dist.py
@@ -121,8 +121,9 @@
Components APIs
-----------------
"""
+import shlex
from pathlib import Path
-from typing import Dict, Optional
+from typing import Dict, Optional, Iterable

import torchx
import torchx.specs as specs
@@ -131,16 +132,19 @@

def ddp(
*script_args: str,
-script: str,
+script: Optional[str] = None,
+m: Optional[str] = None,
image: str = torchx.IMAGE,
name: Optional[str] = None,
+h: Optional[str] = None,
cpu: int = 2,
gpu: int = 0,
memMB: int = 1024,
-h: Optional[str] = None,
j: str = "1x2",
env: Optional[Dict[str, str]] = None,
-rdzv_endpoint: str = "etcd-server.default.svc.cluster.local:2379",
+max_restarts: Optional[int] = None,
+rdzv_backend: str = "c10d",
+rdzv_endpoint: Optional[str] = None,
) -> specs.AppDef:
"""
Distributed data parallel style application (one role, multi-replica).
@@ -154,6 +158,7 @@ def ddp(
Args:
script_args: arguments to the main module
script: script or binary to run within the image
+m: the python module path to run
image: image (e.g. docker)
name: job name override (uses the script name if not specified)
cpu: number of cpus per replica
@@ -162,9 +167,14 @@ def ddp(
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
j: {nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
-rdzv_endpoint: etcd server endpoint (only matters when nnodes > 1)
+max_restarts: the number of restarts allowed
+rdzv_backend: rendezvous backend (only matters when nnodes > 1)
+rdzv_endpoint: rendezvous server endpoint (only matters when nnodes > 1), defaults to rank0 host for schedulers that support it
"""

+if (script is None) == (m is None):
+raise ValueError("exactly one of --script and -m must be specified")

rep = j.split("x")
if len(rep) == 1: # num replicas only
nnodes = 1
@@ -175,33 +185,79 @@
else:
raise ValueError(f"Invalid format for -j, usage example: 1x4. Given: {j}")

-script_name_noext = Path(script).stem # script name no extension
+if script:
+# script name/module no extension
+role_name = Path(script).stem
+elif m:
+role_name = m.rpartition(".")[2]
+else:
+raise ValueError("failed to compute role_name")

+if rdzv_endpoint is None:
+rdzv_endpoint = _noquote(f"$${macros.rank0_env}:29500")

+if nnodes == 1:
+rdzv_backend = "c10d"
+rdzv_endpoint = "localhost:29500"

+if env is None:
+env = {}
+env.setdefault("LOGLEVEL", "INFO")

+cmd = [
+"python",
+"-m",
+"torch.distributed.run",
+"--rdzv_backend",
+rdzv_backend,
+"--rdzv_endpoint",
+rdzv_endpoint,
+"--rdzv_id",
+f"{macros.app_id}",
+"--nnodes",
+str(nnodes),
+"--nproc_per_node",
+str(nproc_per_node),
+]
+if max_restarts is not None:
+cmd += ["--max_restarts", str(max_restarts)]
+if script is not None:
+cmd += [script]
+elif m is not None:
+cmd += ["-m", m]
+cmd += script_args
return specs.AppDef(
-name=name or script_name_noext,
+name=name or role_name,
roles=[
specs.Role(
-name=script_name_noext,
+name=role_name,
image=image,
-entrypoint="python",
+entrypoint="bash",
num_replicas=nnodes,
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
-args=[
-"-m",
-"torch.distributed.run",
-"--rdzv_backend",
-("c10d" if nnodes == 1 else "etcd"),
-"--rdzv_endpoint",
-("localhost:29500" if nnodes == 1 else rdzv_endpoint),
-"--rdzv_id",
-f"{macros.app_id}",
-"--nnodes",
-str(nnodes),
-"--nproc_per_node",
-str(nproc_per_node),
-script,
-*script_args,
-],
-env=env or {},
+args=["-c", _args_join(cmd)],
+env=env,
+port_map={
+"c10d": 29500,
+},
)
],
)


+def _args_join(args: Iterable[str]) -> str:
+"""
+_args_join is like shlex.join but if the argument is wrapped in _noquote
+it'll not quote that argument.
+"""
+quoted = [arg if isinstance(arg, _noquote) else shlex.quote(arg) for arg in args]
+return " ".join(quoted)


+class _noquote(str):
+"""
+_noquote is a wrapper around str that indicates that the argument shouldn't
+be passed through shlex.quote.
+"""

+pass
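For illustration, a minimal standalone sketch of the quoting behavior added above (reimplemented here so it runs without torchx installed; the command strings are examples only): ordinary arguments get shell-quoted, while `_noquote` values such as the `$TORCHX_RANK0_HOST` rendezvous endpoint are passed through untouched so the variable still expands when `bash -c` evaluates the joined command.

```python
import shlex
from typing import Iterable


class _noquote(str):
    """Marker type: values of this type bypass shlex.quote in _args_join."""
    pass


def _args_join(args: Iterable[str]) -> str:
    # Quote everything except _noquote-wrapped args so that references like
    # $TORCHX_RANK0_HOST survive until `bash -c` evaluates the command.
    return " ".join(a if isinstance(a, _noquote) else shlex.quote(a) for a in args)


cmd = [
    "python", "-m", "torch.distributed.run",
    "--rdzv_endpoint", _noquote("$TORCHX_RANK0_HOST:29500"),
    "main.py", "hello world",
]
print(_args_join(cmd))
# python -m torch.distributed.run --rdzv_endpoint $TORCHX_RANK0_HOST:29500 main.py 'hello world'
```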
4 changes: 2 additions & 2 deletions torchx/components/integration_tests/component_provider.py
@@ -36,12 +36,12 @@ def tearDown(self) -> None:

class DDPComponentProvider(ComponentProvider):
def get_app_def(self) -> AppDef:
rdzv_endpoint: str = "localhost:29400"
return dist_components.ddp(
script="torchx/components/integration_tests/test/dummy_app.py",
name="ddp-trainer",
image=self._image,
-rdzv_endpoint=rdzv_endpoint,
+j="2x2",
+max_restarts=3,
)


15 changes: 14 additions & 1 deletion torchx/schedulers/api.py
@@ -24,6 +24,7 @@
SchedulerBackend,
runopts,
)
+from torchx.workspace.api import Workspace


class Stream(str, Enum):
@@ -90,13 +91,25 @@ def close(self) -> None:
"""
pass

-def submit(self, app: AppDef, cfg: Mapping[str, CfgVal]) -> str:
+def submit(
+self,
+app: AppDef,
+cfg: Mapping[str, CfgVal],
+workspace: Optional[str] = None,
+) -> str:
"""
Submits the application to be run by the scheduler.

+WARNING: Mostly used for tests. Users should prefer to use the TorchX runner instead.

Returns:
The application id that uniquely identifies the submitted app.
"""
+if workspace:
+sched = self
+assert isinstance(sched, Workspace)
+role = app.roles[0]
+sched.build_workspace_and_update_role(role, workspace)
dryrun_info = self.submit_dryrun(app, cfg)
return self.schedule(dryrun_info)

14 changes: 13 additions & 1 deletion torchx/schedulers/aws_batch_scheduler.py
@@ -223,16 +223,28 @@ def _submit_dryrun(

for role_idx, role in enumerate(app.roles):
for replica_id in range(role.num_replicas):
+rank = len(nodes)
values = macros.Values(
img_root="",
app_id=name,
replica_id=str(replica_id),
+rank0_env=(
+"TORCHX_RANK0_HOST"
+if rank == 0
+else "AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS"
+),
)
replica_role = values.apply(role)
replica_role.env["TORCHX_ROLE_IDX"] = str(role_idx)
replica_role.env["TORCHX_ROLE_NAME"] = str(role.name)
replica_role.env["TORCHX_REPLICA_IDX"] = str(replica_id)
-nodes.append(role_to_node_properties(len(nodes), replica_role))
+if rank == 0:
+# AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS is only
+# available on the child workers so we set the address to
+# localhost for rank0.
+# See: https://docs.aws.amazon.com/batch/latest/userguide/job_env_vars.html
+replica_role.env["TORCHX_RANK0_HOST"] = "localhost"
+nodes.append(role_to_node_properties(rank, replica_role))

req = BatchJob(
name=name,
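To make the effect of the rank0_env plumbing above concrete, here is an illustrative sketch, not taken from the PR, of what each AWS Batch node ends up executing after the dist.ddp macros are substituted. The hostname, rendezvous id, and job size are made-up example values:

```bash
# Rank 0 (main node): the scheduler sets TORCHX_RANK0_HOST=localhost and
# macros.rank0_env resolves to TORCHX_RANK0_HOST.
TORCHX_RANK0_HOST=localhost \
bash -c 'python -m torch.distributed.run --rdzv_backend c10d \
  --rdzv_endpoint $TORCHX_RANK0_HOST:29500 --rdzv_id example-app-id \
  --nnodes 2 --nproc_per_node 1 main.py'

# Worker nodes: AWS Batch injects the main node address, and macros.rank0_env
# resolves to AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS instead.
AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS=10.0.0.12 \
bash -c 'python -m torch.distributed.run --rdzv_backend c10d \
  --rdzv_endpoint $AWS_BATCH_JOB_MAIN_NODE_PRIVATE_IPV4_ADDRESS:29500 \
  --rdzv_id example-app-id --nnodes 2 --nproc_per_node 1 main.py'
```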
5 changes: 5 additions & 0 deletions torchx/schedulers/docker_scheduler.py
@@ -185,12 +185,14 @@ def _submit_dryrun(

app_id = make_unique(app.name)
req = DockerJob(app_id=app_id, containers=[])
+rank0_name = f"{app_id}-{app.roles[0].name}-0"
for role in app.roles:
for replica_id in range(role.num_replicas):
values = macros.Values(
img_root="",
app_id=app_id,
replica_id=str(replica_id),
rank0_env="TORCHX_RANK0_HOST",
)
replica_role = values.apply(role)
name = f"{app_id}-{role.name}-{replica_id}"
@@ -199,6 +201,9 @@
if replica_role.env:
env.update(replica_role.env)

+# configure distributed host envs
+env["TORCHX_RANK0_HOST"] = rank0_name

c = DockerContainer(
image=replica_role.image,
command=[replica_role.entrypoint] + replica_role.args,
5 changes: 5 additions & 0 deletions torchx/schedulers/kubernetes_scheduler.py
@@ -254,9 +254,14 @@ def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]:
img_root="",
app_id=unique_app_id,
replica_id=str(replica_id),
rank0_env=f"VC_{cleanup_str(app.roles[0].name)}_0_HOSTS".upper(),
)
+if role_idx == 0 and replica_id == 0:
+values.rank0_env = "TORCHX_RANK0_HOST"
name = cleanup_str(f"{role.name}-{replica_id}")
replica_role = values.apply(role)
+if role_idx == 0 and replica_id == 0:
+replica_role.env["TORCHX_RANK0_HOST"] = "localhost"

pod = role_to_pod(name, replica_role)
pod.metadata.labels.update(pod_labels(app, role_idx, role, replica_id))