diff --git a/torchx/pipelines/kfp/adapter.py b/torchx/pipelines/kfp/adapter.py index 83b04be80..b98c0fac2 100644 --- a/torchx/pipelines/kfp/adapter.py +++ b/torchx/pipelines/kfp/adapter.py @@ -235,6 +235,7 @@ def container_from_app( def resource_from_app( app: api.AppDef, queue: str, + service_account: Optional[str] = None, ) -> dsl.ResourceOp: """ resource_from_app generates a KFP ResourceOp from the provided app that uses @@ -266,5 +267,5 @@ def resource_from_app( action="create", success_condition="status.state.phase = Completed", failure_condition="status.state.phase = Failed", - k8s_resource=app_to_resource(app, queue), + k8s_resource=app_to_resource(app, queue, service_account=service_account), ) diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index ac1515b28..363041706 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -157,7 +157,7 @@ def sanitize_for_serialization(obj: object) -> object: return api.sanitize_for_serialization(obj) -def role_to_pod(name: str, role: Role) -> "V1Pod": +def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod": from kubernetes.client.models import ( # noqa: F811 redefinition of unused V1Pod, V1PodSpec, @@ -207,6 +207,7 @@ def role_to_pod(name: str, role: Role) -> "V1Pod": spec=V1PodSpec( containers=[container], restart_policy="Never", + service_account_name=service_account, ), metadata=V1ObjectMeta( annotations={ @@ -232,7 +233,9 @@ def cleanup_str(data: str) -> str: return "".join(re.findall(pattern, data.lower())) -def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]: +def app_to_resource( + app: AppDef, queue: str, service_account: Optional[str] +) -> Dict[str, object]: """ app_to_resource creates a volcano job kubernetes resource definition from the provided AppDef. The resource definition can be used to launch the @@ -263,7 +266,7 @@ def app_to_resource(app: AppDef, queue: str) -> Dict[str, object]: if role_idx == 0 and replica_id == 0: replica_role.env["TORCHX_RANK0_HOST"] = "localhost" - pod = role_to_pod(name, replica_role) + pod = role_to_pod(name, replica_role, service_account) pod.metadata.labels.update(pod_labels(app, role_idx, role, replica_id)) task: Dict[str, Any] = { "replicas": 1, @@ -437,7 +440,12 @@ def _submit_dryrun( # map any local images to the remote image images_to_push = self._update_app_images(app, cfg) - resource = app_to_resource(app, queue) + service_account = cfg.get("service_account") + assert service_account is None or isinstance( + service_account, str + ), "service_account must be a str" + + resource = app_to_resource(app, queue, service_account) req = KubernetesJob( resource=resource, images_to_push=images_to_push, @@ -470,13 +478,21 @@ def run_opts(self) -> runopts: default="default", ) opts.add( - "queue", type_=str, help="Volcano queue to schedule job in", required=True + "queue", + type_=str, + help="Volcano queue to schedule job in", + required=True, ) opts.add( "image_repo", type_=str, help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container", ) + opts.add( + "service_account", + type_=str, + help="The service account name to set on the pod specs", + ) return opts def describe(self, app_id: str) -> Optional[DescribeAppResponse]: @@ -540,7 +556,7 @@ def log_iter( namespace, name = app_id.split(":") - pod_name = f"{name}-{role_name}-{k}-0" + pod_name = cleanup_str(f"{name}-{role_name}-{k}-0") args: Dict[str, object] = { "name": pod_name, diff --git a/torchx/schedulers/test/kubernetes_scheduler_test.py b/torchx/schedulers/test/kubernetes_scheduler_test.py index d1e512d76..3df1593ae 100644 --- a/torchx/schedulers/test/kubernetes_scheduler_test.py +++ b/torchx/schedulers/test/kubernetes_scheduler_test.py @@ -68,7 +68,7 @@ def test_app_to_resource_resolved_macros(self) -> None: "torchx.schedulers.kubernetes_scheduler.make_unique" ) as make_unique_ctx: make_unique_ctx.return_value = unique_app_name - resource = app_to_resource(app, "test_queue") + resource = app_to_resource(app, "test_queue", service_account=None) actual_cmd = ( # pyre-ignore [16] resource["spec"]["tasks"][0]["template"] @@ -88,7 +88,7 @@ def test_app_to_resource_resolved_macros(self) -> None: def test_retry_policy_not_set(self) -> None: app = _test_app() - resource = app_to_resource(app, "test_queue") + resource = app_to_resource(app, "test_queue", service_account=None) self.assertListEqual( [ {"event": "PodEvicted", "action": "RestartJob"}, @@ -99,7 +99,7 @@ def test_retry_policy_not_set(self) -> None: ) for role in app.roles: role.max_retries = 0 - resource = app_to_resource(app, "test_queue") + resource = app_to_resource(app, "test_queue", service_account=None) self.assertFalse("policies" in resource["spec"]["tasks"][0]) self.assertFalse("maxRetry" in resource["spec"]["tasks"][0]) @@ -115,7 +115,7 @@ def test_role_to_pod(self) -> None: ) app = _test_app() - pod = role_to_pod("name", app.roles[0]) + pod = role_to_pod("name", app.roles[0], service_account="srvacc") requests = { "cpu": "2000m", @@ -146,6 +146,7 @@ def test_role_to_pod(self) -> None: spec=V1PodSpec( containers=[container], restart_policy="Never", + service_account_name="srvacc", ), metadata=V1ObjectMeta( annotations={ @@ -298,6 +299,22 @@ def test_submit_dryrun_patch(self) -> None: }, ) + def test_submit_dryrun_service_account(self) -> None: + scheduler = create_scheduler("test") + self.assertIn("service_account", scheduler.run_opts()._opts) + app = _test_app() + cfg = { + "queue": "testqueue", + "service_account": "srvacc", + } + + info = scheduler._submit_dryrun(app, cfg) + self.assertIn("'service_account_name': 'srvacc'", str(info.request.resource)) + + del cfg["service_account"] + info = scheduler._submit_dryrun(app, cfg) + self.assertIn("service_account_name': None", str(info.request.resource)) + @patch("kubernetes.client.CustomObjectsApi.create_namespaced_custom_object") def test_submit(self, create_namespaced_custom_object: MagicMock) -> None: create_namespaced_custom_object.return_value = { @@ -426,6 +443,7 @@ def test_runopts(self) -> None: "queue", "namespace", "image_repo", + "service_account", }, ) @@ -452,7 +470,7 @@ def test_log_iter(self, read_namespaced_pod_log: MagicMock) -> None: read_namespaced_pod_log.return_value = "foo reg\nfoo\nbar reg\n" lines = scheduler.log_iter( app_id="testnamespace:testjob", - role_name="role", + role_name="role_blah", k=1, regex="reg", since=datetime.now(), @@ -472,7 +490,7 @@ def test_log_iter(self, read_namespaced_pod_log: MagicMock) -> None: kwargs, { "namespace": "testnamespace", - "name": "testjob-role-1-0", + "name": "testjob-roleblah-1-0", "timestamps": True, }, )