diff --git a/.github/workflows/aws-batch-integration-tests.yaml b/.github/workflows/aws-batch-integration-tests.yaml
new file mode 100644
index 000000000..84781aa80
--- /dev/null
+++ b/.github/workflows/aws-batch-integration-tests.yaml
@@ -0,0 +1,45 @@
+name: AWS Batch Integration Tests
+
+on:
+  push:
+    branches:
+      - main
+  pull_request:
+
+jobs:
+  awsbatch:
+    runs-on: ubuntu-18.04
+    permissions:
+      id-token: write
+      contents: read
+    steps:
+      - name: Setup Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+          architecture: x64
+      - name: Checkout TorchX
+        uses: actions/checkout@v2
+      - name: Configure AWS
+        env:
+          AWS_ROLE_ARN: ${{ secrets.AWS_ROLE_ARN }}
+        run: |
+          if [ -n "$AWS_ROLE_ARN" ]; then
+            export AWS_WEB_IDENTITY_TOKEN_FILE=/tmp/awscreds
+            export AWS_DEFAULT_REGION=us-west-2
+
+            echo AWS_WEB_IDENTITY_TOKEN_FILE=$AWS_WEB_IDENTITY_TOKEN_FILE >> $GITHUB_ENV
+            echo AWS_ROLE_ARN=$AWS_ROLE_ARN >> $GITHUB_ENV
+            echo AWS_DEFAULT_REGION=$AWS_DEFAULT_REGION >> $GITHUB_ENV
+
+            curl -H "Authorization: bearer $ACTIONS_ID_TOKEN_REQUEST_TOKEN" "$ACTIONS_ID_TOKEN_REQUEST_URL" | jq -r '.value' > $AWS_WEB_IDENTITY_TOKEN_FILE
+          fi
+      - name: Install dependencies
+        run: |
+          set -eux
+          pip install -e .[dev]
+      - name: Run AWS Batch Integration Tests
+        run: |
+          set -ex
+
+          scripts/awsbatchint.sh
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 344dd5c95..d43621456 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,16 +1,18 @@
-aiobotocore>=1.4.1
+aiobotocore==2.1.0
 ax-platform[mysql]==0.2.2
 black==21.10b0
+boto3==1.20.24
 captum>=0.4.0
 classy-vision>=0.6.0
 flake8==3.9.0
-fsspec[s3]==2021.10.1
+fsspec[s3]==2022.1.0
 importlib-metadata
+ipython
 kfp==1.8.9
-moto==2.2.12
+moto==3.0.2
 pyre-extensions==0.0.21
+pytest
 pytorch-lightning==1.5.6
-s3fs==2021.10.1
 ray[default]==1.9.2
 torch-model-archiver==0.4.2
 torch==1.10.0
@@ -19,5 +21,3 @@ torchtext==0.11.0
 torchvision==0.11.1
 ts==0.5.1
 usort==0.6.4
-ipython
-pytest
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 93747761f..244e97cbd 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -97,6 +97,7 @@ Works With
    schedulers/kubernetes
    schedulers/slurm
    schedulers/ray
+   schedulers/aws_batch
 
 .. _Pipelines:
 .. toctree::
diff --git a/docs/source/schedulers/aws_batch.rst b/docs/source/schedulers/aws_batch.rst
new file mode 100644
index 000000000..ec2f65410
--- /dev/null
+++ b/docs/source/schedulers/aws_batch.rst
@@ -0,0 +1,8 @@
+AWS Batch
+=================
+
+.. automodule:: torchx.schedulers.aws_batch_scheduler
+.. currentmodule:: torchx.schedulers.aws_batch_scheduler
+
+.. autoclass:: AWSBatchScheduler
+    :members:
diff --git a/scripts/awsbatchint.sh b/scripts/awsbatchint.sh
new file mode 100755
index 000000000..7ae910e6b
--- /dev/null
+++ b/scripts/awsbatchint.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
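+
+# Integration smoke test for the aws_batch scheduler: submits utils.echo to
+# the "torchx" job queue, waits for completion, then exercises the status,
+# describe and log commands against the finished job.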
+
+set -ex
+
+APP_ID="$(torchx run --wait --scheduler aws_batch -c queue=torchx utils.echo)"
+torchx status "$APP_ID"
+torchx describe "$APP_ID"
+torchx log "$APP_ID"
+LINES="$(torchx log "$APP_ID" | wc -l)"
+
+if [ "$LINES" -ne 1 ]
+then
+    echo "expected 1 log line, got $LINES"
+    exit 1
+fi
diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py
index 83c48fada..63128269b 100644
--- a/torchx/schedulers/__init__.py
+++ b/torchx/schedulers/__init__.py
@@ -7,6 +7,7 @@
 
 from typing import Dict, Optional
 
+import torchx.schedulers.aws_batch_scheduler as aws_batch_scheduler
 import torchx.schedulers.docker_scheduler as docker_scheduler
 import torchx.schedulers.kubernetes_scheduler as kubernetes_scheduler
 import torchx.schedulers.local_scheduler as local_scheduler
@@ -48,6 +49,7 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
         "local_cwd": local_scheduler.create_cwd_scheduler,
         "slurm": slurm_scheduler.create_scheduler,
         "kubernetes": kubernetes_scheduler.create_scheduler,
+        "aws_batch": aws_batch_scheduler.create_scheduler,
     }
 
     ray_scheduler_creator = try_get_ray_scheduler()
diff --git a/torchx/schedulers/aws_batch_scheduler.py b/torchx/schedulers/aws_batch_scheduler.py
new file mode 100644
index 000000000..747014dbf
--- /dev/null
+++ b/torchx/schedulers/aws_batch_scheduler.py
@@ -0,0 +1,397 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+
+This module contains the TorchX AWS Batch scheduler, which can be used to run
+TorchX components directly on AWS Batch.
+
+This scheduler is in prototype stage and may change without notice.
+
+Prerequisites
+==============
+
+You'll need to create an AWS Batch queue configured for multi-node parallel jobs.
+
+See
+https://docs.aws.amazon.com/batch/latest/userguide/multi-node-parallel-jobs.html
+for more information.
+
+If you want to use workspaces and container patching, you'll also need to
+configure a Docker registry, such as AWS ECR, to store the patched
+containers with your changes.
+
+See
+https://docs.aws.amazon.com/AmazonECR/latest/userguide/get-set-up-for-amazon-ecr.html
+for more information.
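+
+Once the queue is set up you can smoke-test the integration end to end. A
+minimal sketch, assuming the job queue is named ``torchx``:
+
+.. code-block:: bash
+
+    $ torchx run --wait --scheduler aws_batch -c queue=torchx utils.echo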
+""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, Iterable, Mapping, Optional, Any + +import torchx +import yaml +from torchx.schedulers.api import ( + AppDryRunInfo, + DescribeAppResponse, + Scheduler, + Stream, + filter_regex, +) +from torchx.schedulers.ids import make_unique +from torchx.specs.api import ( + AppDef, + AppState, + Role, + SchedulerBackend, + macros, + runopts, + CfgVal, +) + +JOB_STATE: Dict[str, AppState] = { + "SUBMITTED": AppState.PENDING, + "PENDING": AppState.PENDING, + "RUNNABLE": AppState.PENDING, + "STARTING": AppState.PENDING, + "RUNNING": AppState.RUNNING, + "SUCCEEDED": AppState.SUCCEEDED, + "FAILED": AppState.FAILED, +} + + +def role_to_node_properties(idx: int, role: Role) -> Dict[str, object]: + resource = role.resource + reqs = [] + cpu = resource.cpu + if cpu <= 0: + cpu = 1 + reqs.append({"type": "VCPU", "value": str(cpu)}) + + mem = resource.memMB + if mem <= 0: + mem = 1000 + reqs.append({"type": "MEMORY", "value": str(mem)}) + + if resource.gpu >= 0: + reqs.append({"type": "GPU", "value": str(resource.gpu)}) + + container = { + "command": [role.entrypoint] + role.args, + "image": role.image, + "environment": [{"name": k, "value": v} for k, v in role.env.items()], + "resourceRequirements": reqs, + "logConfiguration": { + "logDriver": "awslogs", + }, + } + + return { + "targetNodes": str(idx), + "container": container, + } + + +@dataclass +class BatchJob: + name: str + queue: str + job_def: Dict[str, object] + + def __str__(self) -> str: + return yaml.dump(self.job_def) + + def __repr__(self) -> str: + return str(self) + + +class AWSBatchScheduler(Scheduler): + """ + AWSBatchScheduler is a TorchX scheduling interface to AWS Batch. + + .. code-block:: bash + + $ pip install torchx[kubernetes] + $ torchx run --scheduler aws_batch --scheduler_args queue=torchx utils.echo --image alpine:latest --msg hello + aws_batch://torchx_user/1234 + $ torchx status aws_batch://torchx_user/1234 + ... + + **Config Options** + + .. runopts:: + class: torchx.schedulers.aws_batch_scheduler.AWSBatchScheduler + + **Compatibility** + + .. compatibility:: + type: scheduler + features: + cancel: true + logs: true + distributed: true + describe: | + Partial support. AWSBatchScheduler will return job and replica + status but does not provide the complete original AppSpec. + """ + + def __init__( + self, + session_name: str, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + client: Optional[Any] = None, + # pyre-fixme[2]: Parameter annotation cannot be `Any`. + log_client: Optional[Any] = None, + ) -> None: + super().__init__("aws_batch", session_name) + + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + self.__client = client + # pyre-fixme[4]: Attribute annotation cannot be `Any`. + self.__log_client = log_client + + @property + # pyre-fixme[3]: Return annotation cannot be `Any`. + def _client(self) -> Any: + if self.__client is None: + import boto3 + + self.__client = boto3.client("batch") + return self.__client + + @property + # pyre-fixme[3]: Return annotation cannot be `Any`. 
+    def _log_client(self) -> Any:
+        if self.__log_client is None:
+            import boto3
+
+            self.__log_client = boto3.client("logs")
+        return self.__log_client
+
+    def schedule(self, dryrun_info: AppDryRunInfo[BatchJob]) -> str:
+        cfg = dryrun_info._cfg
+        assert cfg is not None, f"{dryrun_info} missing cfg"
+        req = dryrun_info.request
+        self._client.register_job_definition(**req.job_def)
+
+        self._client.submit_job(
+            jobName=req.name,
+            jobQueue=req.queue,
+            jobDefinition=req.name,
+            tags=req.job_def["tags"],
+        )
+
+        return f"{req.queue}:{req.name}"
+
+    def _submit_dryrun(
+        self, app: AppDef, cfg: Mapping[str, CfgVal]
+    ) -> AppDryRunInfo[BatchJob]:
+        queue = cfg.get("queue")
+        if not isinstance(queue, str):
+            raise TypeError(f"config value 'queue' must be a string, got {queue}")
+        name = make_unique(app.name)
+
+        nodes = []
+
+        for role_idx, role in enumerate(app.roles):
+            for replica_id in range(role.num_replicas):
+                values = macros.Values(
+                    img_root="",
+                    app_id=name,
+                    replica_id=str(replica_id),
+                )
+                replica_role = values.apply(role)
+                replica_role.env["TORCHX_ROLE_IDX"] = str(role_idx)
+                replica_role.env["TORCHX_ROLE_NAME"] = str(role.name)
+                replica_role.env["TORCHX_REPLICA_IDX"] = str(replica_id)
+                nodes.append(role_to_node_properties(len(nodes), replica_role))
+
+        req = BatchJob(
+            name=name,
+            queue=queue,
+            job_def={
+                "jobDefinitionName": name,
+                "type": "multinode",
+                "nodeProperties": {
+                    "numNodes": len(nodes),
+                    "mainNode": 0,
+                    "nodeRangeProperties": nodes,
+                },
+                "retryStrategy": {
+                    "attempts": max(max(role.max_retries for role in app.roles), 1),
+                    "evaluateOnExit": [
+                        {"onExitCode": "0", "action": "EXIT"},
+                    ],
+                },
+                "tags": {
+                    "torchx.pytorch.org/version": torchx.__version__,
+                    "torchx.pytorch.org/app-name": app.name,
+                },
+            },
+        )
+        info = AppDryRunInfo(req, repr)
+        info._app = app
+        info._cfg = cfg
+        return info
+
+    def _validate(self, app: AppDef, scheduler: SchedulerBackend) -> None:
+        # Skip validation step
+        pass
+
+    def _cancel_existing(self, app_id: str) -> None:
+        job_id = self._get_job_id(app_id)
+        self._client.terminate_job(
+            jobId=job_id,
+            reason="killed via torchx CLI",
+        )
+
+    def run_opts(self) -> runopts:
+        opts = runopts()
+        opts.add("queue", type_=str, help="queue to schedule job in", required=True)
+        return opts
+
+    def _get_job_id(self, app_id: str) -> Optional[str]:
+        queue, name = app_id.split(":")
+
+        job_summary_list = self._client.list_jobs(
+            jobQueue=queue,
+            filters=[{"name": "JOB_NAME", "values": [name]}],
+        )["jobSummaryList"]
+        if len(job_summary_list) == 0:
+            return None
+        return job_summary_list[0]["jobArn"]
+
+    def _get_job(
+        self, app_id: str, rank: Optional[int] = None
+    ) -> Optional[Dict[str, Any]]:
+        job_id = self._get_job_id(app_id)
+        if not job_id:
+            return None
+        if rank is not None:
+            job_id += f"#{rank}"
+        jobs = self._client.describe_jobs(jobs=[job_id])["jobs"]
+        if len(jobs) == 0:
+            return None
+        return jobs[0]
+
+    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
+        job = self._get_job(app_id)
+        if job is None:
+            return None
+
+        # TODO: role statuses
+
+        roles = {}
+        nodes = job["nodeProperties"]["nodeRangeProperties"]
+        for node in nodes:
+            container = node["container"]
+            env = {opt["name"]: opt["value"] for opt in container["environment"]}
+            role = env["TORCHX_ROLE_NAME"]
+            replica_id = int(env["TORCHX_REPLICA_IDX"])
+
+            if role not in roles:
+                roles[role] = Role(
+                    name=role,
+                    num_replicas=0,
+                    image=container["image"],
+                    entrypoint=container["command"][0],
+                    args=container["command"][1:],
+                    env=env,
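+                    # NOTE: resourceRequirements and ports are not mapped back
+                    # onto the Role, so the reconstructed spec is partial (see
+                    # the "describe" compatibility note in the class docstring).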
+                )
+            roles[role].num_replicas += 1
+
+        return DescribeAppResponse(
+            app_id=app_id,
+            state=JOB_STATE[job["status"]],
+            roles=list(roles.values()),
+        )
+
+    def log_iter(
+        self,
+        app_id: str,
+        role_name: str,
+        k: int = 0,
+        regex: Optional[str] = None,
+        since: Optional[datetime] = None,
+        until: Optional[datetime] = None,
+        should_tail: bool = False,
+        streams: Optional[Stream] = None,
+    ) -> Iterable[str]:
+        if streams not in (None, Stream.COMBINED):
+            raise ValueError("AWSBatchScheduler only supports COMBINED log stream")
+
+        job = self._get_job(app_id)
+        if job is None:
+            return []
+        nodes = job["nodeProperties"]["nodeRangeProperties"]
+        i = 0
+        for i, node in enumerate(nodes):
+            container = node["container"]
+            env = {opt["name"]: opt["value"] for opt in container["environment"]}
+            node_role = env["TORCHX_ROLE_NAME"]
+            replica_id = int(env["TORCHX_REPLICA_IDX"])
+            if role_name == node_role and k == replica_id:
+                break
+
+        job = self._get_job(app_id, rank=i)
+        if not job:
+            return []
+
+        attempts = job["attempts"]
+        if len(attempts) == 0:
+            return []
+
+        attempt = attempts[-1]
+        container = attempt["container"]
+        stream_name = container["logStreamName"]
+
+        iterator = self._stream_events(stream_name, since=since, until=until)
+        if regex:
+            return filter_regex(regex, iterator)
+        else:
+            return iterator
+
+    def _stream_events(
+        self,
+        stream_name: str,
+        since: Optional[datetime] = None,
+        until: Optional[datetime] = None,
+    ) -> Iterable[str]:
+
+        next_token = None
+
+        while True:
+            args = {}
+            if next_token is not None:
+                args["nextToken"] = next_token
+            if until is not None:
+                args["endTime"] = until.timestamp()
+            if since is not None:
+                args["startTime"] = since.timestamp()
+            try:
+                response = self._log_client.get_log_events(
+                    logGroupName="/aws/batch/job",
+                    logStreamName=stream_name,
+                    limit=10000,
+                    startFromHead=True,
+                    **args,
+                )
+            except self._log_client.exceptions.ResourceNotFoundException:
+                return []
+            if response["nextForwardToken"] == next_token:
+                break
+            next_token = response["nextForwardToken"]
+
+            for event in response["events"]:
+                yield event["message"]
+
+
+def create_scheduler(session_name: str, **kwargs: object) -> AWSBatchScheduler:
+    return AWSBatchScheduler(
+        session_name=session_name,
+    )
diff --git a/torchx/schedulers/test/aws_batch_scheduler_test.py b/torchx/schedulers/test/aws_batch_scheduler_test.py
new file mode 100644
index 000000000..2075c404d
--- /dev/null
+++ b/torchx/schedulers/test/aws_batch_scheduler_test.py
@@ -0,0 +1,290 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
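+#
+# These tests drive AWSBatchScheduler against MagicMock boto3 clients, so they
+# run without AWS credentials or any live Batch resources.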
+
+import unittest
+from contextlib import contextmanager
+from typing import Generator
+from unittest.mock import MagicMock, patch
+
+import torchx
+from torchx import specs
+from torchx.schedulers.aws_batch_scheduler import (
+    create_scheduler,
+    AWSBatchScheduler,
+)
+
+
+def _test_app() -> specs.AppDef:
+    trainer_role = specs.Role(
+        name="trainer",
+        image="pytorch/torchx:latest",
+        entrypoint="main",
+        args=["--output-path", specs.macros.img_root, "--app-id", specs.macros.app_id],
+        env={"FOO": "bar"},
+        resource=specs.Resource(
+            cpu=2,
+            memMB=3000,
+            gpu=4,
+        ),
+        port_map={"foo": 1234},
+        num_replicas=1,
+        max_retries=3,
+    )
+
+    return specs.AppDef("test", roles=[trainer_role])
+
+
+@contextmanager
+def mock_rand() -> Generator[None, None, None]:
+    with patch("torchx.schedulers.aws_batch_scheduler.make_unique") as make_unique_ctx:
+        make_unique_ctx.return_value = "app-name-42"
+        yield
+
+
+class AWSBatchSchedulerTest(unittest.TestCase):
+    def test_create_scheduler(self) -> None:
+        scheduler = create_scheduler("foo")
+        self.assertIsInstance(scheduler, AWSBatchScheduler)
+
+    def test_validate(self) -> None:
+        scheduler = create_scheduler("test")
+        app = _test_app()
+        scheduler._validate(app, "aws_batch")
+
+    @mock_rand()
+    def test_submit_dryrun(self) -> None:
+        scheduler = create_scheduler("test")
+        app = _test_app()
+        cfg = {"queue": "testqueue"}
+        info = scheduler._submit_dryrun(app, cfg)
+
+        req = info.request
+        self.assertEqual(req.queue, "testqueue")
+        job_def = req.job_def
+
+        print(job_def)
+
+        self.assertEqual(
+            job_def,
+            {
+                "jobDefinitionName": "app-name-42",
+                "type": "multinode",
+                "nodeProperties": {
+                    "numNodes": 1,
+                    "mainNode": 0,
+                    "nodeRangeProperties": [
+                        {
+                            "targetNodes": "0",
+                            "container": {
+                                "command": [
+                                    "main",
+                                    "--output-path",
+                                    "",
+                                    "--app-id",
+                                    "app-name-42",
+                                ],
+                                "image": "pytorch/torchx:latest",
+                                "environment": [
+                                    {"name": "FOO", "value": "bar"},
+                                    {"name": "TORCHX_ROLE_IDX", "value": "0"},
+                                    {"name": "TORCHX_ROLE_NAME", "value": "trainer"},
+                                    {"name": "TORCHX_REPLICA_IDX", "value": "0"},
+                                ],
+                                "resourceRequirements": [
+                                    {"type": "VCPU", "value": "2"},
+                                    {"type": "MEMORY", "value": "3000"},
+                                    {"type": "GPU", "value": "4"},
+                                ],
+                                "logConfiguration": {"logDriver": "awslogs"},
+                            },
+                        }
+                    ],
+                },
+                "retryStrategy": {
+                    "attempts": 3,
+                    "evaluateOnExit": [{"onExitCode": "0", "action": "EXIT"}],
+                },
+                "tags": {
+                    "torchx.pytorch.org/version": torchx.__version__,
+                    "torchx.pytorch.org/app-name": "test",
+                },
+            },
+        )
+
+    def _mock_scheduler(self) -> AWSBatchScheduler:
+        scheduler = AWSBatchScheduler(
+            "test",
+            client=MagicMock(),
+            log_client=MagicMock(),
+        )
+        scheduler._client.list_jobs.return_value = {
+            "jobSummaryList": [
+                {
+                    "jobArn": "arn:aws:batch:us-west-2:495572122715:job/6afc27d7-3559-43ca-89fd-1007b6bf2546",
+                    "jobId": "6afc27d7-3559-43ca-89fd-1007b6bf2546",
+                    "jobName": "echo-v1r560pmwn5t3c",
+                    "createdAt": 1643949940162,
+                    "status": "SUCCEEDED",
+                    "stoppedAt": 1643950324125,
+                    "container": {"exitCode": 0},
+                    "nodeProperties": {"numNodes": 2},
+                    "jobDefinition": "arn:aws:batch:us-west-2:495572122715:job-definition/echo-v1r560pmwn5t3c:1",
+                }
+            ]
+        }
+        scheduler._client.describe_jobs.return_value = {
+            "jobs": [
+                {
+                    "jobArn": "thejobarn",
+                    "jobName": "app-name-42",
+                    "jobId": "6afc27d7-3559-43ca-89fd-1007b6bf2546",
+                    "jobQueue": "testqueue",
+                    "status": "SUCCEEDED",
+                    "attempts": [
+                        {
+                            "container": {
+                                "exitCode": 0,
+                                "logStreamName": "log_stream",
+                                "networkInterfaces": [],
+                            },
+                            "startedAt": 1643950310819,
"stoppedAt": 1643950324125, + "statusReason": "Essential container in task exited", + } + ], + "statusReason": "Essential container in task exited", + "createdAt": 1643949940162, + "retryStrategy": { + "attempts": 1, + "evaluateOnExit": [{"onExitCode": "0", "action": "exit"}], + }, + "startedAt": 1643950310819, + "stoppedAt": 1643950324125, + "dependsOn": [], + "jobDefinition": "job-def", + "parameters": {}, + "nodeProperties": { + "numNodes": 2, + "mainNode": 0, + "nodeRangeProperties": [ + { + "targetNodes": "0", + "container": { + "image": "ghcr.io/pytorch/torchx:0.1.2dev0", + "command": ["echo", "your name"], + "volumes": [], + "environment": [ + {"name": "TORCHX_ROLE_IDX", "value": "0"}, + {"name": "TORCHX_REPLICA_IDX", "value": "0"}, + {"name": "TORCHX_ROLE_NAME", "value": "echo"}, + ], + "mountPoints": [], + "ulimits": [], + "resourceRequirements": [ + {"value": "1", "type": "VCPU"}, + {"value": "1000", "type": "MEMORY"}, + ], + "logConfiguration": { + "logDriver": "awslogs", + "options": {}, + "secretOptions": [], + }, + "secrets": [], + }, + }, + { + "targetNodes": "1", + "container": { + "image": "ghcr.io/pytorch/torchx:0.1.2dev0", + "command": ["echo", "your name"], + "volumes": [], + "environment": [ + {"name": "TORCHX_ROLE_IDX", "value": "1"}, + {"name": "TORCHX_REPLICA_IDX", "value": "0"}, + {"name": "TORCHX_ROLE_NAME", "value": "echo2"}, + ], + "mountPoints": [], + "ulimits": [], + "resourceRequirements": [ + {"value": "1", "type": "VCPU"}, + {"value": "1000", "type": "MEMORY"}, + ], + "logConfiguration": { + "logDriver": "awslogs", + "options": {}, + "secretOptions": [], + }, + "secrets": [], + }, + }, + ], + }, + "tags": { + "torchx.pytorch.org/version": "0.1.2dev0", + "torchx.pytorch.org/app-name": "echo", + }, + "platformCapabilities": [], + } + ] + } + + scheduler._log_client.get_log_events.return_value = { + "nextForwardToken": "some_token", + "events": [ + {"message": "foo"}, + {"message": "foobar"}, + {"message": "bar"}, + ], + } + + return scheduler + + @mock_rand() + def test_submit(self) -> None: + scheduler = self._mock_scheduler() + app = _test_app() + cfg = { + "queue": "testqueue", + } + + info = scheduler._submit_dryrun(app, cfg) + id = scheduler.schedule(info) + self.assertEqual(id, "testqueue:app-name-42") + self.assertEqual(scheduler._client.register_job_definition.call_count, 1) + self.assertEqual(scheduler._client.submit_job.call_count, 1) + + def test_describe(self) -> None: + scheduler = self._mock_scheduler() + status = scheduler.describe("testqueue:app-name-42") + self.assertIsNotNone(status) + self.assertEqual(status.state, specs.AppState.SUCCEEDED) + self.assertEqual(status.app_id, "testqueue:app-name-42") + self.assertEqual( + status.roles[0], + specs.Role( + name="echo", + num_replicas=1, + image="ghcr.io/pytorch/torchx:0.1.2dev0", + entrypoint="echo", + args=["your name"], + env={ + "TORCHX_ROLE_IDX": "0", + "TORCHX_REPLICA_IDX": "0", + "TORCHX_ROLE_NAME": "echo", + }, + ), + ) + + def test_log_iter(self) -> None: + scheduler = self._mock_scheduler() + logs = scheduler.log_iter("testqueue:app-name-42", "echo", k=1, regex="foo.*") + self.assertEqual( + list(logs), + [ + "foo", + "foobar", + ], + )