Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import (
LAUNCH_JOB_ID_LABEL_SELECTOR,
_parse_job_status_from_k8s_obj,
make_job_id_to_pods_mapping,
)
from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import (
get_kubernetes_batch_client,
get_kubernetes_core_client,
load_k8s_yaml,
maybe_load_kube_config,
)
Expand Down Expand Up @@ -97,14 +99,30 @@ async def list_jobs(
logger.exception("Got an exception when trying to list the Jobs")
raise EndpointResourceInfraException from exc

core_client = get_kubernetes_core_client()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apparently this fn is used in the list docker image batch jobs api call for some reason, feels wrong to me, idk how we didn't catch this before, but oh well

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you link to the code that is the issue? Is the problem that we are breaking abstraction layers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model_engine_server/domain/use_cases/batch_job_use_cases.py:ListDockerImageBatchJobV1UseCase.execute

This feels like we're breaking abstraction layers to me at least (e.g. batch jobs and cron jobs should be different), although I guess that broken abstraction gets propagated through the API as well (since the ListBatchJobs has a trigger_id parameter)


try:
label_selector = f"trigger_id={trigger_id}" if trigger_id else f"owner={owner},job-name"
pods = await core_client.list_namespaced_pod(
namespace=hmi_config.endpoint_namespace,
label_selector=label_selector,
)
except ApiException as exc:
logger.exception("Got an exception when trying to list the Pods")
raise EndpointResourceInfraException from exc

pods_per_job = make_job_id_to_pods_mapping(pods.items)

return [
DockerImageBatchJob(
id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR),
created_by=job.metadata.labels.get("created_by"),
owner=job.metadata.labels.get("owner"),
created_at=job.metadata.creation_timestamp,
completed_at=job.status.completion_time,
status=_parse_job_status_from_k8s_obj(job),
status=_parse_job_status_from_k8s_obj(
job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)]
),
)
for job in jobs.items
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union

from kubernetes_asyncio.client.models.v1_job import V1Job
from kubernetes_asyncio.client.models.v1_pod import V1Pod
from kubernetes_asyncio.client.rest import ApiException
from model_engine_server.common.config import hmi_config
from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests
Expand All @@ -17,6 +19,7 @@
)
from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import (
get_kubernetes_batch_client,
get_kubernetes_core_client,
load_k8s_yaml,
maybe_load_kube_config,
)
Expand Down Expand Up @@ -84,7 +87,7 @@ def _k8s_job_name_from_id(job_id: str):
return f"launch-di-batch-job-{job_id}"


def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus:
def _parse_job_status_from_k8s_obj(job: V1Job, pods: List[V1Pod]) -> BatchJobStatus:
status = job.status
# these counts are the number of pods in some given status
if status.failed is not None and status.failed > 0:
Expand All @@ -94,10 +97,30 @@ def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus:
if status.ready is not None and status.ready > 0:
return BatchJobStatus.RUNNING # empirically this doesn't happen
if status.active is not None and status.active > 0:
return BatchJobStatus.RUNNING # TODO this might be a mix of pending and running
for pod in pods:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth leaving a comment here on why a single Job resource could have multiple pods, to clarify the logic here.

# In case there are multiple pods for a given job (e.g. if a pod gets shut down)
# let's interpret the job as running if any of the pods are running
# I haven't empirically seen this, but guard against it just in case.
if pod.status.phase == "Running":
return BatchJobStatus.RUNNING
return BatchJobStatus.PENDING
return BatchJobStatus.PENDING


def make_job_id_to_pods_mapping(pods: List[V1Pod]) -> "defaultdict[str, List[V1Pod]]":
    """Group pods by the Launch job ID stored in their labels.

    Args:
        pods: Pod objects, typically ``.items`` from a ``list_namespaced_pod`` call.

    Returns:
        A defaultdict mapping each job ID to the list of its pods. Because it is
        a defaultdict, looking up a job ID with no pods yields an empty list,
        which callers rely on when joining jobs against their pods.
    """
    job_id_to_pods_mapping: "defaultdict[str, List[V1Pod]]" = defaultdict(list)
    for pod in pods:
        # metadata.labels is None (not {}) for unlabeled pods in the k8s API,
        # so guard before .get even though our label selectors should only
        # return labeled pods.
        labels = pod.metadata.labels or {}
        job_id = labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)
        if job_id is not None:
            job_id_to_pods_mapping[job_id].append(pod)
        else:
            logger.warning(f"Pod {pod.metadata.name} has no job ID label")
    return job_id_to_pods_mapping


class LiveDockerImageBatchJobGateway(DockerImageBatchJobGateway):
def __init__(self):
pass
Expand Down Expand Up @@ -282,10 +305,21 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker
logger.exception("Got an exception when trying to read the Job")
raise EndpointResourceInfraException from exc

core_client = get_kubernetes_core_client()
try:
pods = await core_client.list_namespaced_pod(
namespace=hmi_config.endpoint_namespace,
label_selector=f"{LAUNCH_JOB_ID_LABEL_SELECTOR}={batch_job_id}",
)
except ApiException as exc:
logger.exception("Got an exception when trying to read pods for the Job")
raise EndpointResourceInfraException from exc
# This pod list isn't always needed, but it's simpler code-wise to always make the request

job_labels = job.metadata.labels
annotations = job.metadata.annotations

status = _parse_job_status_from_k8s_obj(job)
status = _parse_job_status_from_k8s_obj(job, pods.items)

return DockerImageBatchJob(
id=batch_job_id,
Expand All @@ -309,6 +343,19 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc
logger.exception("Got an exception when trying to list the Jobs")
raise EndpointResourceInfraException from exc

core_client = get_kubernetes_core_client()
try:
pods = await core_client.list_namespaced_pod(
namespace=hmi_config.endpoint_namespace,
label_selector=f"{OWNER_LABEL_SELECTOR}={owner},job-name", # get only pods associated with a job
)
except ApiException as exc:
logger.exception("Got an exception when trying to read pods for the Job")
raise EndpointResourceInfraException from exc

# Join jobs + pods
pods_per_job = make_job_id_to_pods_mapping(pods.items)

return [
DockerImageBatchJob(
id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR),
Expand All @@ -317,7 +364,9 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc
created_at=job.metadata.creation_timestamp,
completed_at=job.status.completion_time,
annotations=job.metadata.annotations,
status=_parse_job_status_from_k8s_obj(job),
status=_parse_job_status_from_k8s_obj(
job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)]
),
)
for job in jobs.items
]
Expand Down
67 changes: 67 additions & 0 deletions model-engine/tests/unit/infra/gateways/k8s_fake_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Various fake k8s objects to be used in mocking out the python k8s api client
# Only classes are defined here. If you need to add various fields to the classes, please do so here.

from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Optional


@dataclass
class FakeK8sV1ObjectMeta:
    """Fake of the k8s client's ``V1ObjectMeta``, modeling only the fields read in tests."""

    name: str = "fake_name"
    namespace: str = "fake_namespace"
    annotations: dict = field(default_factory=dict)
    labels: dict = field(default_factory=dict)
    # datetime is immutable, so a shared class-level default is safe here.
    creation_timestamp: datetime = datetime(2021, 1, 1, 0, 0, 0, 0)
    # TODO: everything else


@dataclass
class FakeK8sV1PodStatus:
    """Fake of the k8s client's ``V1PodStatus``; only the pod ``phase`` string is modeled."""

    phase: str = "Running"
    # TODO: everything else


@dataclass
class FakeK8sV1JobStatus:
    """Fake of the k8s client's ``V1JobStatus`` pod-count fields.

    NOTE(review): the real client defaults these counts to ``None``, not 0, and
    production code checks ``is not None`` — confirm tests don't depend on the
    distinction between 0 and ``None``.
    """

    active: int = 0
    succeeded: int = 0
    failed: int = 0
    ready: int = 0
    terminating: int = 0
    completion_time: Optional[datetime] = None


@dataclass
class FakeK8sV1Job:
    """Fake of the k8s client's ``V1Job`` with just ``metadata`` and ``status``.

    ``default_factory`` is required here: a plain ``= FakeK8sV1ObjectMeta()``
    default would share one mutable instance across every FakeK8sV1Job (leaking
    label/annotation mutations between tests), and dataclasses reject such
    unhashable defaults with a ValueError on Python 3.11+.
    """

    metadata: FakeK8sV1ObjectMeta = field(default_factory=FakeK8sV1ObjectMeta)
    status: FakeK8sV1JobStatus = field(default_factory=FakeK8sV1JobStatus)
    # TODO: spec, api_version, kind


@dataclass
class FakeK8sV1JobList:
    """Fake of the k8s client's ``V1JobList``: a container exposing ``.items``."""

    items: List[FakeK8sV1Job] = field(default_factory=list)


@dataclass
class FakeK8sV1Pod:
    """Fake of the k8s client's ``V1Pod`` with just ``metadata`` and ``status``.

    ``default_factory`` is required here: a plain ``= FakeK8sV1ObjectMeta()``
    default would share one mutable instance across every FakeK8sV1Pod (leaking
    label/annotation mutations between tests), and dataclasses reject such
    unhashable defaults with a ValueError on Python 3.11+.
    """

    metadata: FakeK8sV1ObjectMeta = field(default_factory=FakeK8sV1ObjectMeta)
    status: FakeK8sV1PodStatus = field(default_factory=FakeK8sV1PodStatus)
    # TODO: spec, api_version, kind


@dataclass
class FakeK8sV1PodList:
    """Fake of the k8s client's ``V1PodList``: a container exposing ``.items``."""

    items: List[FakeK8sV1Pod] = field(default_factory=list)


@dataclass
class FakeK8sEnvVar:
    """Fake of the k8s client's ``V1EnvVar``: a name/value environment variable pair."""

    name: str
    value: str


@dataclass
class FakeK8sDeploymentContainer:
    """Fake of a k8s container spec, modeling only its ``env`` variable list."""

    env: List[FakeK8sEnvVar]
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch

Expand All @@ -25,21 +24,11 @@
DictStrStr,
ResourceArguments,
)
from tests.unit.infra.gateways.k8s_fake_objects import FakeK8sDeploymentContainer, FakeK8sEnvVar

MODULE_PATH = "model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate"


@dataclass
class FakeK8sEnvVar:
name: str
value: str


@dataclass
class FakeK8sDeploymentContainer:
env: List[FakeK8sEnvVar]


@pytest.fixture
def mock_get_kubernetes_cluster_version():
mock_version = "1.26"
Expand Down
Loading