Skip to content

Commit c427d0b

Browse files
Logging post inference hook implementation (#428)
* add logging hook * plumb other columns through endpoint config * add args to unit tests * change to labels * add config and assume role * new storage gateway * add stream name config * fix test * undo conftest * handle error * fix test * fake streaming storage gateway * move client to fn * change to fake gateway * PR comments * catch err * update test * remove error in response * try small test * add more tests * fix test --------- Co-authored-by: Sai Atmakuri <[email protected]>
1 parent 5bff345 commit c427d0b

File tree

16 files changed

+372
-1
lines changed

16 files changed

+372
-1
lines changed

model-engine/model_engine_server/api/dependencies.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@
4343
LLMModelEndpointService,
4444
ModelEndpointService,
4545
)
46+
from model_engine_server.inference.domain.gateways.streaming_storage_gateway import (
47+
StreamingStorageGateway,
48+
)
49+
from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import (
50+
FirehoseStreamingStorageGateway,
51+
)
4652
from model_engine_server.infra.gateways import (
4753
CeleryTaskQueueGateway,
4854
FakeMonitoringMetricsGateway,
@@ -137,6 +143,7 @@ class ExternalInterfaces:
137143
cron_job_gateway: CronJobGateway
138144
monitoring_metrics_gateway: MonitoringMetricsGateway
139145
tokenizer_repository: TokenizerRepository
146+
streaming_storage_gateway: StreamingStorageGateway
140147

141148

142149
def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway:
@@ -265,6 +272,8 @@ def _get_external_interfaces(
265272

266273
tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway)
267274

275+
streaming_storage_gateway = FirehoseStreamingStorageGateway()
276+
268277
external_interfaces = ExternalInterfaces(
269278
docker_repository=docker_repository,
270279
model_bundle_repository=model_bundle_repository,
@@ -287,6 +296,7 @@ def _get_external_interfaces(
287296
cron_job_gateway=cron_job_gateway,
288297
monitoring_metrics_gateway=monitoring_metrics_gateway,
289298
tokenizer_repository=tokenizer_repository,
299+
streaming_storage_gateway=streaming_storage_gateway,
290300
)
291301
return external_interfaces
292302

model-engine/model_engine_server/common/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
BILLING_POST_INFERENCE_HOOK: str = "billing"
44
CALLBACK_POST_INFERENCE_HOOK: str = "callback"
5+
LOGGING_POST_INFERENCE_HOOK: str = "logging"
56
READYZ_FPATH: str = "/tmp/readyz"
67
DEFAULT_CELERY_TASK_NAME: str = "hosted_model_inference.inference.async_inference.tasks.predict"
78
LIRA_CELERY_TASK_NAME: str = "ml_serve.celery_service.exec_func"

model-engine/model_engine_server/core/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class InfraConfig:
4242
profile_ml_worker: str = "default"
4343
profile_ml_inference_worker: str = "default"
4444
identity_service_url: Optional[str] = None
45+
firehose_role_arn: Optional[str] = None
46+
firehose_stream_name: Optional[str] = None
4547

4648
@classmethod
4749
def from_yaml(cls, yaml_path) -> "InfraConfig":

model-engine/model_engine_server/domain/entities/model_endpoint_entity.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ class ModelEndpointConfig(BaseModel):
8888
billing_tags: Optional[Dict[str, Any]] = None
8989
default_callback_url: Optional[str] = None
9090
default_callback_auth: Optional[CallbackAuth]
91+
endpoint_id: Optional[str] = None
92+
endpoint_type: Optional[ModelEndpointType]
93+
bundle_id: Optional[str] = None
94+
labels: Optional[Dict[str, str]] = None
9195

9296
def serialize(self) -> str:
9397
return python_json_to_b64(dict_not_none(**self.dict()))

model-engine/model_engine_server/domain/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,9 @@ class TriggerNameAlreadyExistsException(DomainException):
164164
"""
165165
Thrown if the requested name already exists in the trigger repository
166166
"""
167+
168+
169+
class StreamPutException(DomainException):
    """Raised when the streaming storage gateway fails to put a record into a stream."""

model-engine/model_engine_server/inference/async_inference/tasks.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
1919
DatadogInferenceMonitoringMetricsGateway,
2020
)
21+
from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import (
22+
FirehoseStreamingStorageGateway,
23+
)
2124
from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler
2225

2326
logger = make_logger(logger_name())
@@ -46,6 +49,11 @@ def init_worker_global():
4649
default_callback_url=endpoint_config.default_callback_url,
4750
default_callback_auth=endpoint_config.default_callback_auth,
4851
monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(),
52+
endpoint_id=endpoint_config.endpoint_id,
53+
endpoint_type=endpoint_config.endpoint_type,
54+
bundle_id=endpoint_config.bundle_id,
55+
labels=endpoint_config.labels,
56+
streaming_storage_gateway=FirehoseStreamingStorageGateway(),
4957
)
5058
# k8s health check
5159
with open(READYZ_FPATH, "w") as f:
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Any, Dict
3+
4+
5+
class StreamingStorageGateway(ABC):
    """Abstract interface for gateways that persist data via a streaming mechanism."""

    @abstractmethod
    def put_record(self, stream_name: str, record: Dict[str, Any]) -> None:
        """
        Write a single record to the named stream.

        Args:
            stream_name: Name of the destination stream.
            record: Payload to write into the stream.
        """

model-engine/model_engine_server/inference/forwarding/forwarding.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
1515
DatadogInferenceMonitoringMetricsGateway,
1616
)
17+
from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import (
18+
FirehoseStreamingStorageGateway,
19+
)
1720
from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler
1821

1922
__all__: Sequence[str] = (
@@ -279,6 +282,11 @@ def endpoint(route: str) -> str:
279282
default_callback_url=endpoint_config.default_callback_url,
280283
default_callback_auth=endpoint_config.default_callback_auth,
281284
monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(),
285+
endpoint_id=endpoint_config.endpoint_id,
286+
endpoint_type=endpoint_config.endpoint_type,
287+
bundle_id=endpoint_config.bundle_id,
288+
labels=endpoint_config.labels,
289+
streaming_storage_gateway=FirehoseStreamingStorageGateway(),
282290
)
283291

284292
return Forwarder(
@@ -451,6 +459,11 @@ def endpoint(route: str) -> str:
451459
default_callback_url=endpoint_config.default_callback_url,
452460
default_callback_auth=endpoint_config.default_callback_auth,
453461
monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(),
462+
endpoint_id=endpoint_config.endpoint_id,
463+
endpoint_type=endpoint_config.endpoint_type,
464+
bundle_id=endpoint_config.bundle_id,
465+
labels=endpoint_config.labels,
466+
streaming_storage_gateway=FirehoseStreamingStorageGateway(),
454467
)
455468

456469
return StreamingForwarder(
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import json
2+
from typing import Any, Dict
3+
4+
import boto3
5+
from model_engine_server.core.config import infra_config
6+
from model_engine_server.core.loggers import logger_name, make_logger
7+
from model_engine_server.domain.exceptions import StreamPutException
8+
from model_engine_server.inference.domain.gateways.streaming_storage_gateway import (
9+
StreamingStorageGateway,
10+
)
11+
12+
logger = make_logger(logger_name())
13+
14+
15+
class FirehoseStreamingStorageGateway(StreamingStorageGateway):
    """
    A gateway that stores data through the AWS Kinesis Firehose streaming mechanism.
    """

    def _get_firehose_client(self):
        """
        Create a new Firehose client.

        Streams with Snowflake as a destination and the AWS profile live in different
        accounts. Firehose doesn't support resource-based policies, so we need to
        assume a new role to write to the stream.

        NOTE: a fresh STS assume-role call and client are made on every invocation;
        put_record calls this each time, so each record pays the assume-role cost.
        """
        sts_client = boto3.client("sts", region_name=infra_config().default_region)
        # Assume the cross-account role that is permitted to write to the stream.
        assumed_role_object = sts_client.assume_role(
            RoleArn=infra_config().firehose_role_arn,
            RoleSessionName="AssumeMlLoggingRoleSession",
        )
        credentials = assumed_role_object["Credentials"]
        session = boto3.Session(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
        )
        firehose_client = session.client("firehose", region_name=infra_config().default_region)
        return firehose_client

    def put_record(self, stream_name: str, record: Dict[str, Any]) -> None:
        """
        Put a record into a Firehose stream.

        Args:
            stream_name: The name of the stream.
            record: The record to put into the stream (JSON-serialized before sending).

        Raises:
            StreamPutException: If Firehose does not acknowledge the write with HTTP 200.
        """
        firehose_response = self._get_firehose_client().put_record(
            DeliveryStreamName=stream_name, Record={"Data": json.dumps(record).encode("utf-8")}
        )
        if firehose_response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            raise StreamPutException(
                f"Failed to put record into firehose stream {stream_name}. Record content: {record}"
            )
        logger.info(
            f"Logged to firehose stream {stream_name}. Record content: {record}, Record ID: {firehose_response['RecordId']}"
        )

model-engine/model_engine_server/inference/post_inference_hooks.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@
44

55
import requests
66
from fastapi.responses import JSONResponse
7-
from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK
7+
from model_engine_server.common.constants import (
8+
CALLBACK_POST_INFERENCE_HOOK,
9+
LOGGING_POST_INFERENCE_HOOK,
10+
)
811
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
12+
from model_engine_server.core.config import infra_config
913
from model_engine_server.core.loggers import logger_name, make_logger
1014
from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth
15+
from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpointType
16+
from model_engine_server.domain.exceptions import StreamPutException
1117
from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import (
1218
InferenceMonitoringMetricsGateway,
1319
)
20+
from model_engine_server.inference.domain.gateways.streaming_storage_gateway import (
21+
StreamingStorageGateway,
22+
)
1423
from tenacity import Retrying, stop_after_attempt, wait_exponential
1524

1625
logger = make_logger(logger_name())
@@ -76,6 +85,61 @@ def handle(
7685
assert 200 <= res.status_code < 300
7786

7887

88+
class LoggingHook(PostInferenceHook):
    """Post-inference hook that logs request/response pairs to a streaming storage sink."""

    def __init__(
        self,
        endpoint_name: str,
        bundle_name: str,
        user_id: str,
        endpoint_id: Optional[str],
        endpoint_type: Optional[ModelEndpointType],
        bundle_id: Optional[str],
        labels: Optional[Dict[str, str]],
        streaming_storage_gateway: StreamingStorageGateway,
    ):
        super().__init__(endpoint_name, bundle_name, user_id)
        self._endpoint_id = endpoint_id
        self._endpoint_type = endpoint_type
        self._bundle_id = bundle_id
        self._labels = labels
        self._streaming_storage_gateway = streaming_storage_gateway

    def handle(
        self,
        request_payload: EndpointPredictV1Request,
        response: Dict[str, Any],
        task_id: Optional[str],
    ):
        """
        Log one inference request/response to the configured stream.

        Skips with a warning (never raises to the caller) when endpoint metadata or
        the configured firehose stream name is missing, and when the stream write
        fails. NOTE: mutates `response` in place by inserting "task_id" before it is
        logged — callers observe this key afterwards.
        """
        # All four metadata fields are required to build a useful log record.
        if (
            not self._endpoint_id
            or not self._endpoint_type
            or not self._bundle_id
            or not self._labels
        ):
            logger.warning(
                "No endpoint_id, endpoint_type, bundle_id, or labels specified for request."
            )
            return
        response["task_id"] = task_id
        data_record = {
            "REQUEST_BODY": request_payload.json(),
            "RESPONSE_BODY": response,
            "ENDPOINT_ID": self._endpoint_id,
            "ENDPOINT_NAME": self._endpoint_name,
            "ENDPOINT_TYPE": self._endpoint_type.value,
            "BUNDLE_ID": self._bundle_id,
            "LABELS": self._labels,
        }
        stream_name = infra_config().firehose_stream_name
        if stream_name is None:
            logger.warning("No firehose stream name specified. Logging hook will not be executed.")
            return
        try:
            self._streaming_storage_gateway.put_record(stream_name=stream_name, record=data_record)
        except StreamPutException:
            # Logging must never break inference; record the failure (with traceback,
            # which plain logger.error on the exception value would drop) and move on.
            logger.exception("Error in logging hook")
141+
142+
79143
class PostInferenceHooksHandler:
80144
def __init__(
81145
self,
@@ -88,6 +152,11 @@ def __init__(
88152
default_callback_auth: Optional[CallbackAuth],
89153
post_inference_hooks: Optional[List[str]],
90154
monitoring_metrics_gateway: InferenceMonitoringMetricsGateway,
155+
endpoint_id: Optional[str],
156+
endpoint_type: Optional[ModelEndpointType],
157+
bundle_id: Optional[str],
158+
labels: Optional[Dict[str, str]],
159+
streaming_storage_gateway: StreamingStorageGateway,
91160
):
92161
self._monitoring_metrics_gateway = monitoring_metrics_gateway
93162
self._hooks: Dict[str, PostInferenceHook] = {}
@@ -104,6 +173,17 @@ def __init__(
104173
default_callback_url,
105174
default_callback_auth,
106175
)
176+
elif hook_lower == LOGGING_POST_INFERENCE_HOOK:
177+
self._hooks[hook_lower] = LoggingHook(
178+
endpoint_name,
179+
bundle_name,
180+
user_id,
181+
endpoint_id,
182+
endpoint_type,
183+
bundle_id,
184+
labels,
185+
streaming_storage_gateway,
186+
)
107187
else:
108188
raise ValueError(f"Hook {hook_lower} is currently not supported.")
109189

0 commit comments

Comments
 (0)