Skip to content

Commit ce3ead6

Browse files
Log applied guardrails on LLM API call (#8452)
* fix(litellm_logging.py): support saving applied guardrails in the logging object — allows the list of applied guardrails to be logged for the proxy admin's knowledge. * feat(spend_tracking_utils.py): log applied guardrails to spend logs — makes it easy for the admin to know which guardrails were applied on a request. * ci(config.yml): uninstall posthog from CI/CD. * test: fix tests. * test: update test.
1 parent 8e32713 commit ce3ead6

File tree

8 files changed

+152
-10
lines changed

8 files changed

+152
-10
lines changed

litellm/litellm_core_utils/litellm_logging.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def __init__(
199199
dynamic_async_failure_callbacks: Optional[
200200
List[Union[str, Callable, CustomLogger]]
201201
] = None,
202+
applied_guardrails: Optional[List[str]] = None,
202203
kwargs: Optional[Dict] = None,
203204
):
204205
_input: Optional[str] = messages # save original value of messages
@@ -271,6 +272,7 @@ def __init__(
271272
"litellm_call_id": litellm_call_id,
272273
"input": _input,
273274
"litellm_params": litellm_params,
275+
"applied_guardrails": applied_guardrails,
274276
}
275277

276278
def process_dynamic_callbacks(self):
@@ -2852,6 +2854,7 @@ def get_standard_logging_metadata(
28522854
metadata: Optional[Dict[str, Any]],
28532855
litellm_params: Optional[dict] = None,
28542856
prompt_integration: Optional[str] = None,
2857+
applied_guardrails: Optional[List[str]] = None,
28552858
) -> StandardLoggingMetadata:
28562859
"""
28572860
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2866,6 +2869,7 @@ def get_standard_logging_metadata(
28662869
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
28672870
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
28682871
"""
2872+
28692873
prompt_management_metadata: Optional[
28702874
StandardLoggingPromptManagementMetadata
28712875
] = None
@@ -2895,6 +2899,7 @@ def get_standard_logging_metadata(
28952899
requester_metadata=None,
28962900
user_api_key_end_user_id=None,
28972901
prompt_management_metadata=prompt_management_metadata,
2902+
applied_guardrails=applied_guardrails,
28982903
)
28992904
if isinstance(metadata, dict):
29002905
# Filter the metadata dictionary to include only the specified keys
@@ -3193,6 +3198,7 @@ def get_standard_logging_object_payload(
31933198
metadata=metadata,
31943199
litellm_params=litellm_params,
31953200
prompt_integration=kwargs.get("prompt_integration", None),
3201+
applied_guardrails=kwargs.get("applied_guardrails", None),
31963202
)
31973203

31983204
_request_body = proxy_server_request.get("body", {})
@@ -3328,6 +3334,7 @@ def get_standard_logging_metadata(
33283334
requester_metadata=None,
33293335
user_api_key_end_user_id=None,
33303336
prompt_management_metadata=None,
3337+
applied_guardrails=None,
33313338
)
33323339
if isinstance(metadata, dict):
33333340
# Filter the metadata dictionary to include only the specified keys

litellm/proxy/_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,6 +1794,7 @@ class SpendLogsMetadata(TypedDict):
17941794
dict
17951795
] # special param to log k,v pairs to spendlogs for a call
17961796
requester_ip_address: Optional[str]
1797+
applied_guardrails: Optional[List[str]]
17971798

17981799

17991800
class SpendLogsPayload(TypedDict):

litellm/proxy/spend_tracking/spend_tracking_utils.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from datetime import datetime
44
from datetime import datetime as dt
55
from datetime import timezone
6-
from typing import Optional, cast
6+
from typing import List, Optional, cast
77

88
from pydantic import BaseModel
99

@@ -32,7 +32,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
3232
return False
3333

3434

35-
def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
35+
def _get_spend_logs_metadata(
36+
metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
37+
) -> SpendLogsMetadata:
3638
if metadata is None:
3739
return SpendLogsMetadata(
3840
user_api_key=None,
@@ -44,8 +46,9 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
4446
spend_logs_metadata=None,
4547
requester_ip_address=None,
4648
additional_usage_values=None,
49+
applied_guardrails=None,
4750
)
48-
verbose_proxy_logger.debug(
51+
verbose_proxy_logger.info(
4952
"getting payload for SpendLogs, available keys in metadata: "
5053
+ str(list(metadata.keys()))
5154
)
@@ -58,6 +61,8 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
5861
if key in metadata
5962
}
6063
)
64+
clean_metadata["applied_guardrails"] = applied_guardrails
65+
6166
return clean_metadata
6267

6368

@@ -130,7 +135,14 @@ def get_logging_payload( # noqa: PLR0915
130135
_model_group = metadata.get("model_group", "")
131136

132137
# clean up litellm metadata
133-
clean_metadata = _get_spend_logs_metadata(metadata)
138+
clean_metadata = _get_spend_logs_metadata(
139+
metadata,
140+
applied_guardrails=(
141+
standard_logging_payload["metadata"].get("applied_guardrails", None)
142+
if standard_logging_payload is not None
143+
else None
144+
),
145+
)
134146

135147
special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
136148
additional_usage_values = {}

litellm/types/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,6 +1525,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
15251525
requester_ip_address: Optional[str]
15261526
requester_metadata: Optional[dict]
15271527
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
1528+
applied_guardrails: Optional[List[str]]
15281529

15291530

15301531
class StandardLoggingAdditionalHeaders(TypedDict, total=False):

litellm/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
from litellm.caching._internal_lru_cache import lru_cache_wrapper
6161
from litellm.caching.caching import DualCache
6262
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
63+
from litellm.integrations.custom_guardrail import CustomGuardrail
6364
from litellm.integrations.custom_logger import CustomLogger
6465
from litellm.litellm_core_utils.core_helpers import (
6566
map_finish_reason,
@@ -418,6 +419,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
418419
)
419420

420421

422+
def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
423+
"""
424+
Get the request guardrails from the kwargs
425+
"""
426+
metadata = kwargs.get("metadata") or {}
427+
requester_metadata = metadata.get("requester_metadata") or {}
428+
applied_guardrails = requester_metadata.get("guardrails") or []
429+
return applied_guardrails
430+
431+
432+
def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
433+
"""
434+
- Add 'default_on' guardrails to the list
435+
- Add request guardrails to the list
436+
"""
437+
438+
request_guardrails = get_request_guardrails(kwargs)
439+
applied_guardrails = []
440+
for callback in litellm.callbacks:
441+
if callback is not None and isinstance(callback, CustomGuardrail):
442+
if callback.guardrail_name is not None:
443+
if callback.default_on is True:
444+
applied_guardrails.append(callback.guardrail_name)
445+
elif callback.guardrail_name in request_guardrails:
446+
applied_guardrails.append(callback.guardrail_name)
447+
448+
return applied_guardrails
449+
450+
421451
def function_setup( # noqa: PLR0915
422452
original_function: str, rules_obj, start_time, *args, **kwargs
423453
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -436,6 +466,9 @@ def function_setup( # noqa: PLR0915
436466
## CUSTOM LLM SETUP ##
437467
custom_llm_setup()
438468

469+
## GET APPLIED GUARDRAILS
470+
applied_guardrails = get_applied_guardrails(kwargs)
471+
439472
## LOGGING SETUP
440473
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
441474

@@ -677,6 +710,7 @@ def function_setup( # noqa: PLR0915
677710
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
678711
dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
679712
kwargs=kwargs,
713+
applied_guardrails=applied_guardrails,
680714
)
681715

682716
## check if metadata is passed in

tests/litellm_utils_tests/test_utils.py

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -864,17 +864,24 @@ def test_convert_model_response_object():
864864
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
865865
)
866866

867+
867868
@pytest.mark.parametrize(
868-
"content, expected_reasoning, expected_content",
869+
"content, expected_reasoning, expected_content",
869870
[
870871
(None, None, None),
871-
("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
872+
(
873+
"<think>I am thinking here</think>The sky is a canvas of blue",
874+
"I am thinking here",
875+
"The sky is a canvas of blue",
876+
),
872877
("I am a regular response", None, "I am a regular response"),
873-
874-
]
878+
],
875879
)
876880
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
877-
assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
881+
assert litellm.utils._parse_content_for_reasoning(content) == (
882+
expected_reasoning,
883+
expected_content,
884+
)
878885

879886

880887
@pytest.mark.parametrize(
@@ -1874,3 +1881,82 @@ def test_validate_user_messages_invalid_content_type():
18741881

18751882
assert "Invalid message" in str(e)
18761883
print(e)
1884+
1885+
1886+
from litellm.integrations.custom_guardrail import CustomGuardrail
1887+
from litellm.utils import get_applied_guardrails
1888+
from unittest.mock import Mock
1889+
1890+
1891+
@pytest.mark.parametrize(
1892+
"test_case",
1893+
[
1894+
{
1895+
"name": "default_on_guardrail",
1896+
"callbacks": [
1897+
CustomGuardrail(guardrail_name="test_guardrail", default_on=True)
1898+
],
1899+
"kwargs": {"metadata": {"requester_metadata": {"guardrails": []}}},
1900+
"expected": ["test_guardrail"],
1901+
},
1902+
{
1903+
"name": "request_specific_guardrail",
1904+
"callbacks": [
1905+
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
1906+
],
1907+
"kwargs": {
1908+
"metadata": {"requester_metadata": {"guardrails": ["test_guardrail"]}}
1909+
},
1910+
"expected": ["test_guardrail"],
1911+
},
1912+
{
1913+
"name": "multiple_guardrails",
1914+
"callbacks": [
1915+
CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
1916+
CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
1917+
],
1918+
"kwargs": {
1919+
"metadata": {
1920+
"requester_metadata": {"guardrails": ["request_guardrail"]}
1921+
}
1922+
},
1923+
"expected": ["default_guardrail", "request_guardrail"],
1924+
},
1925+
{
1926+
"name": "empty_metadata",
1927+
"callbacks": [
1928+
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
1929+
],
1930+
"kwargs": {},
1931+
"expected": [],
1932+
},
1933+
{
1934+
"name": "none_callback",
1935+
"callbacks": [
1936+
None,
1937+
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
1938+
],
1939+
"kwargs": {},
1940+
"expected": ["test_guardrail"],
1941+
},
1942+
{
1943+
"name": "non_guardrail_callback",
1944+
"callbacks": [
1945+
Mock(),
1946+
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
1947+
],
1948+
"kwargs": {},
1949+
"expected": ["test_guardrail"],
1950+
},
1951+
],
1952+
)
1953+
def test_get_applied_guardrails(test_case):
1954+
1955+
# Setup
1956+
litellm.callbacks = test_case["callbacks"]
1957+
1958+
# Execute
1959+
result = get_applied_guardrails(test_case["kwargs"])
1960+
1961+
# Assert
1962+
assert sorted(result) == sorted(test_case["expected"])

tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"model": "gpt-4o",
1010
"user": "",
1111
"team_id": "",
12-
"metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
12+
"metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
1313
"cache_key": "Cache OFF",
1414
"spend": 0.00022500000000000002,
1515
"total_tokens": 30,

tests/logging_callback_tests/test_otel_logging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
272272
"metadata.user_api_key_user_id",
273273
"metadata.user_api_key_org_id",
274274
"metadata.user_api_key_end_user_id",
275+
"metadata.applied_guardrails",
275276
]
276277

277278
_all_attributes = set(

0 commit comments

Comments (0)