Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class StreamingStorageGateway(ABC):
"""

@abstractmethod
def put_record(self, stream_name: str, record: Dict[str, Any]) -> None:
def put_record(self, stream_name: str, record: Dict[str, Any]) -> Dict[str, Any]:
"""
Put a record into a streaming storage mechanism.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_firehose_client(self):
firehose_client = session.client("firehose", region_name=infra_config().default_region)
return firehose_client

def put_record(self, stream_name: str, record: Dict[str, Any]) -> None:
def put_record(self, stream_name: str, record: Dict[str, Any]) -> Dict[str, Any]:
"""
Put a record into a Firehose stream.

Expand All @@ -56,8 +56,9 @@ def put_record(self, stream_name: str, record: Dict[str, Any]) -> None:
)
if firehose_response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise StreamPutException(
f"Failed to put record into firehose stream {stream_name}. Record content: {record}"
f"Failed to put record into firehose stream {stream_name}. Response metadata {firehose_response['ResponseMetadata']}."
)
logger.info(
f"Logged to firehose stream {stream_name}. Record content: {record}, Record ID: {firehose_response['RecordId']}"
f"Logged to firehose stream {stream_name}. Record ID: {firehose_response['RecordId']}. Task ID: {record['RESPONSE_BODY']['task_id']}"
)
return firehose_response
13 changes: 10 additions & 3 deletions model-engine/model_engine_server/inference/post_inference_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,17 @@ def handle(
if stream_name is None:
logger.warning("No firehose stream name specified. Logging hook will not be executed.")
return
streaming_storage_response = {} # pragma: no cover
try:
self._streaming_storage_gateway.put_record(stream_name=stream_name, record=data_record)
except StreamPutException as e:
logger.error(f"Error in logging hook {e}")
streaming_storage_response = (
self._streaming_storage_gateway.put_record( # pragma: no cover
stream_name=stream_name, record=data_record
)
)
except StreamPutException: # pragma: no cover
logger.error( # pragma: no cover
f"Failed to put record into firehose stream {stream_name}. Response metadata {streaming_storage_response.get('ResponseMetadata')}."
)


class PostInferenceHooksHandler:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

stream_name = "fake-stream"

return_value = {
"RecordId": "fake-record-id",
"Encrypted": False,
"ResponseMetadata": {"HTTPStatusCode": 200},
}


@pytest.fixture
def streaming_storage_gateway():
Expand All @@ -17,7 +23,7 @@ def streaming_storage_gateway():

@pytest.fixture
def fake_record():
return {"Data": "fake-data"}
return {"RESPONSE_BODY": {"task_id": "fake-task-id"}}


def mock_sts_client(*args, **kwargs):
Expand All @@ -34,11 +40,7 @@ def mock_sts_client(*args, **kwargs):

def mock_firehose_client(*args, **kwargs):
mock_client = mock.Mock()
mock_client.put_record.return_value = {
"RecordId": "fake-record-id",
"Encrypted": False,
"ResponseMetadata": {"HTTPStatusCode": 200},
}
mock_client.put_record.return_value = return_value
return mock_client


Expand Down Expand Up @@ -76,7 +78,7 @@ def test_firehose_streaming_storage_gateway_put_record(streaming_storage_gateway
"model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session",
mock_session,
):
assert streaming_storage_gateway.put_record(stream_name, fake_record) is None
assert streaming_storage_gateway.put_record(stream_name, fake_record) is return_value


def test_firehose_streaming_storage_gateway_put_record_with_exception(
Expand Down