Merged
model-engine/model_engine_server/inference/forwarding/celery_forwarder.py
@@ -5,6 +5,7 @@
from celery import Celery, Task, states
from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME
from model_engine_server.common.dtos.model_endpoints import BrokerType
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.core.celery import TaskVisibility, celery_app
from model_engine_server.core.config import infra_config
from model_engine_server.core.loggers import logger_name, make_logger
@@ -25,45 +26,6 @@ class ErrorResponse(TypedDict):
error_metadata: str


class ErrorHandlingTask(Task):
"""Sets a 'custom' field with error in the Task response for FAILURE.

Used when services are run via the Celery backend.
"""

def after_return(
self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo
) -> None:
"""Handler that ensures custom error response information is available whenever a Task fails.

Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value
:param:`retval` is an `Exception`, this handler extracts information from the `Exception`
and constructs a custom error response JSON value (see :func:`error_response` for details).

This handler then re-propagates the Celery-required exception information (`"exc_type"` and
`"exc_message"`) while adding this new error response information under the `"custom"` key.
"""
if status == states.FAILURE and isinstance(retval, Exception):
logger.warning(f"Setting custom error response for failed task {task_id}")

info: dict = raw_celery_response(self.backend, task_id)
result: dict = info["result"]
err: Exception = retval

error_payload = error_response("Internal failure", err)

# Inspired by pattern from:
# https://www.distributedpython.com/2018/09/28/celery-task-states/
self.update_state(
state=states.FAILURE,
meta={
"exc_type": result["exc_type"],
"exc_message": result["exc_message"],
"custom": json.dumps(error_payload, indent=False),
},
)


def raw_celery_response(backend, task_id: str) -> Dict[str, Any]:
key_info: str = backend.get_key_for_task(task_id)
info_as_str: str = backend.get(key_info)
@@ -103,6 +65,47 @@ def create_celery_service(
else None,
)

class ErrorHandlingTask(Task):
"""Sets a 'custom' field with error in the Task response for FAILURE.

Used when services are run via the Celery backend.
"""

def after_return(
self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo
) -> None:
"""Handler that ensures custom error response information is available whenever a Task fails.

Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value
:param:`retval` is an `Exception`, this handler extracts information from the `Exception`
and constructs a custom error response JSON value (see :func:`error_response` for details).

This handler then re-propagates the Celery-required exception information (`"exc_type"` and
`"exc_message"`) while adding this new error response information under the `"custom"` key.
"""
if status == states.FAILURE and isinstance(retval, Exception):
logger.warning(f"Setting custom error response for failed task {task_id}")

info: dict = raw_celery_response(self.backend, task_id)
result: dict = info["result"]
err: Exception = retval

error_payload = error_response("Internal failure", err)

# Inspired by pattern from:
# https://www.distributedpython.com/2018/09/28/celery-task-states/
self.update_state(
state=states.FAILURE,
meta={
"exc_type": result["exc_type"],
"exc_message": result["exc_message"],
"custom": json.dumps(error_payload, indent=False),
},
)
request_params = args[0]
request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params)
forwarder.post_inference_hooks_handler.handle(request_params_pydantic, retval, task_id) # type: ignore
**Contributor:** any way to pass in forwarder to the class rather than using encapsulation? maybe the `@app.task` decorator allows custom args to be passed into the base class?

**Contributor (Author):** yeah I originally tried that approach but sadly you can't pass in custom args into the `@app.task` decorator: https://docs.celeryq.dev/en/stable/userguide/tasks.html#list-of-options
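For illustration, here is a minimal sketch of the closure-capture workaround the diff settles on: since `@app.task` accepts only Celery's documented task options, the `Task` subclass is defined inside the factory function so it can close over `forwarder` directly. Names mirror the diff; the broker/backend URLs are placeholders (the real service configures its app through `celery_app`).

```python
from celery import Celery, Task

def create_service_sketch(forwarder):
    # Placeholder broker/backend; the real code gets its app from celery_app().
    app = Celery("sketch", broker="memory://", backend="cache+memory://")

    class HookTask(Task):
        def after_return(self, status, retval, task_id, args, kwargs, einfo):
            # `forwarder` is visible via the enclosing scope, so no custom
            # constructor arguments need to flow through @app.task.
            forwarder.post_inference_hooks_handler.handle(args[0], retval, task_id)

    @app.task(base=HookTask, track_started=True)
    def predict(payload):
        return forwarder(payload)

    return app
```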


# See documentation for options:
# https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options
@app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True)
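As a usage note, a hedged sketch of how a caller might read back the `"custom"` error payload that `ErrorHandlingTask.after_return` stores. It assumes a result backend exposing the same `get_key_for_task`/`get` methods that `raw_celery_response` above relies on.

```python
import json
from typing import Any, Dict, Optional

def get_custom_error(backend, task_id: str) -> Optional[Dict[str, Any]]:
    # Same lookup raw_celery_response performs: raw task meta, JSON-decoded.
    info: Dict[str, Any] = json.loads(backend.get(backend.get_key_for_task(task_id)))
    result: Dict[str, Any] = info.get("result") or {}
    # after_return stored the payload JSON-encoded under "custom", alongside
    # Celery's own exc_type/exc_message keys.
    return json.loads(result["custom"]) if "custom" in result else None
```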
model-engine/model_engine_server/inference/forwarding/forwarding.py
@@ -9,7 +9,6 @@
import sseclient
import yaml
from fastapi.responses import JSONResponse
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.inference.common import get_endpoint_config
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
@@ -126,7 +125,6 @@ class Forwarder(ModelEngineSerializationMixin):
forward_http_status: bool

def __call__(self, json_payload: Any) -> Any:
request_obj = EndpointPredictV1Request.parse_obj(json_payload)
json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload)
json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload

@@ -163,8 +161,6 @@ def __call__(self, json_payload: Any) -> Any:
if self.wrap_response:
response = self.get_response_payload(using_serialize_results_as_string, response)

# TODO: we actually want to do this after we've returned the response.
self.post_inference_hooks_handler.handle(request_obj, response)
if self.forward_http_status:
return JSONResponse(content=response, status_code=response_raw.status_code)
else:
model-engine/model_engine_server/inference/forwarding/http_forwarder.py
@@ -4,7 +4,7 @@
import subprocess
from functools import lru_cache

from fastapi import Depends, FastAPI
from fastapi import BackgroundTasks, Depends, FastAPI
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.core.loggers import logger_name, make_logger
@@ -70,11 +70,20 @@ def load_streaming_forwarder():
@app.post("/predict")
def predict(
request: EndpointPredictV1Request,
background_tasks: BackgroundTasks,
forwarder=Depends(load_forwarder),
limiter=Depends(get_concurrency_limiter),
):
with limiter:
return forwarder(request.dict())
try:
response = forwarder(request.dict())
background_tasks.add_task(
forwarder.post_inference_hooks_handler.handle, request, response
)
return response
except Exception:
logger.error(f"Failed to decode payload from: {request}")
raise

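A self-contained sketch (not part of the diff) of the `BackgroundTasks` behavior the new `/predict` handler relies on: a function registered with `add_task` runs only after the response has been produced, which is what moves the post-inference hooks out of the request path.

```python
from fastapi import BackgroundTasks, FastAPI
from fastapi.testclient import TestClient

app = FastAPI()
events = []

def record_hook(payload: dict) -> None:
    # Stand-in for post_inference_hooks_handler.handle in the real service.
    events.append(payload)

@app.post("/echo")
def echo(body: dict, background_tasks: BackgroundTasks):
    background_tasks.add_task(record_hook, body)
    return body  # returned to the client before record_hook executes

client = TestClient(app)
assert client.post("/echo", json={"hello": "world"}).json() == {"hello": "world"}
assert events == [{"hello": "world"}]  # the hook ran after the response
```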

@app.post("/stream")
model-engine/model_engine_server/inference/post_inference_hooks.py
@@ -1,7 +1,9 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

import requests
from fastapi.responses import JSONResponse
from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.core.loggers import logger_name, make_logger
@@ -108,13 +110,17 @@ def __init__(
def handle(
self,
request_payload: EndpointPredictV1Request,
response: Dict[str, Any],
response: Union[Dict[str, Any], JSONResponse],
task_id: Optional[str] = None,
):
if isinstance(response, JSONResponse):
loaded_response = json.loads(response.body)
else:
loaded_response = response
for hook_name, hook in self._hooks.items():
self._monitoring_metrics_gateway.emit_attempted_post_inference_hook(hook_name)
try:
hook.handle(request_payload, response, task_id)
hook.handle(request_payload, loaded_response, task_id) # pragma: no cover
self._monitoring_metrics_gateway.emit_successful_post_inference_hook(hook_name)
except Exception:
logger.exception(f"Hook {hook_name} failed.")
model-engine/tests/unit/inference/test_http_forwarder.py (87 additions, 7 deletions)
@@ -1,12 +1,23 @@
import threading
import time
from dataclasses import dataclass
from typing import Mapping
from unittest import mock

import pytest
from fastapi import BackgroundTasks
from fastapi.responses import JSONResponse
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
from model_engine_server.inference.forwarding.forwarding import Forwarder
from model_engine_server.inference.forwarding.http_forwarder import (
MultiprocessingConcurrencyLimiter,
predict,
)
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
DatadogInferenceMonitoringMetricsGateway,
)
from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler

PAYLOAD: Mapping[str, str] = {"hello": "world"}


class ExceptionCapturedThread(threading.Thread):
@@ -26,21 +37,90 @@ def join(self):
raise self.ex


def mock_forwarder(dict):
time.sleep(1)
return dict
def mocked_get(*args, **kwargs): # noqa
@dataclass
class mocked_static_status_code:
status_code: int = 200

return mocked_static_status_code()


def mocked_post(*args, **kwargs): # noqa
@dataclass
class mocked_static_json:
status_code: int = 200

def json(self) -> dict:
return PAYLOAD # type: ignore

return mocked_static_json()


@pytest.fixture
def post_inference_hooks_handler():
handler = PostInferenceHooksHandler(
endpoint_name="test_endpoint_name",
bundle_name="test_bundle_name",
post_inference_hooks=[],
user_id="test_user_id",
billing_queue="billing_queue",
billing_tags=[],
default_callback_url=None,
default_callback_auth=None,
monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(),
)
return handler


@pytest.fixture
def mock_request():
return EndpointPredictV1Request(
url="test_url",
return_pickled=False,
args={"x": 1},
)

def test_http_service_429():

@mock.patch("requests.post", mocked_post)
@mock.patch("requests.get", mocked_get)
def test_http_service_429(mock_request, post_inference_hooks_handler):
mock_forwarder = Forwarder(
"ignored",
model_engine_unwrap=True,
serialize_results_as_string=False,
post_inference_hooks_handler=post_inference_hooks_handler,
wrap_response=True,
forward_http_status=True,
)
limiter = MultiprocessingConcurrencyLimiter(1, True)
t1 = ExceptionCapturedThread(
target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter)
target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter)
)
t2 = ExceptionCapturedThread(
target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter)
target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter)
)
t1.start()
t2.start()
t1.join()
with pytest.raises(Exception): # 429 thrown
t2.join()


def test_handler_response(mock_request, post_inference_hooks_handler):
try:
post_inference_hooks_handler.handle(
request_payload=mock_request, response=PAYLOAD, task_id="test_task_id"
)
except Exception as e:
pytest.fail(f"Unexpected exception: {e}")


def test_handler_json_response(mock_request, post_inference_hooks_handler):
try:
post_inference_hooks_handler.handle(
request_payload=mock_request,
response=JSONResponse(content=PAYLOAD),
task_id="test_task_id",
)
except Exception as e:
pytest.fail(f"Unexpected exception: {e}")