diff --git a/src/f5_ai_gateway_sdk/result.py b/src/f5_ai_gateway_sdk/result.py index 7b8ec67..697386c 100644 --- a/src/f5_ai_gateway_sdk/result.py +++ b/src/f5_ai_gateway_sdk/result.py @@ -156,6 +156,7 @@ class Reject(BaseModel): # be added to the metadata field metadata: Metadata = Field(default=Metadata(), exclude=True) tags: Tags = Field(default=Tags(), exclude=True) + processor_result: Metadata | None = None def is_empty(self) -> bool: """Compatability with Result(), always false due to required fields""" @@ -169,6 +170,8 @@ def to_response(self) -> Response: """Return Reject as Response object""" if self.tags: self.metadata["tags"] = self.tags.to_response() + if self.processor_result: + self.metadata["processor_result"] = self.processor_result return MultipartResponse( fields=[ self.to_multipart_field(), diff --git a/tests/test_processor.py b/tests/test_processor.py index df7005a..f8851b3 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -446,6 +446,54 @@ async def test_handle_rejected_prompt(self): self.assertDictEqual(expected_response_metadata, response_metadata) + async def test_handle_rejected_prompt_with_result(self): + prompt = TEST_REQ_INPUT.model_dump_json() + metadata = {"key": "value", "step_id": "12345", "request_id": "09876"} + request = fake_multipart_request( + prompt=prompt, + metadata=metadata, + parameters={"reject": True, "annotate": True}, + ) + result = Reject( + metadata=metadata, + code=RejectCode.POLICY_VIOLATION, + detail="dangerous question asked", + tags=FAKE_TAGS, + processor_result={"confidence": 0.99}, + ) + processor = fake_processor(result=result) + + response = await processor.handle_request(request) + + self.assertStatusCodeEqual(response, HTTP_200_OK) + + content = await self.buffer_response(response) + multipart = MultipartDecoderHelper( + content=content, content_type=response.headers["Content-Type"] + ) + + self.assertFalse( + multipart.has_prompt(), "the rejected prompt should not be in the response" + ) + + multipart_metadata = multipart.metadata + self.assertEqual( + MultipartResponse.JSON_CONTENT_TYPE, multipart_metadata.content_type() + ) + response_metadata = multipart_metadata.as_json() + + expected_response_metadata = dict( + app_details=APP_DETAILS, + processor_id=processor.id(), + processor_result={"confidence": 0.99}, + processor_version=processor.version, + tags={"test1": ["a", "b"]}, + ) + for k, v in metadata.items(): + expected_response_metadata[k] = v + + self.assertDictEqual(expected_response_metadata, response_metadata) + async def test_handle_modified_prompt(self): prompt = TEST_REQ_INPUT.model_dump_json() metadata = {"key": "value"}