Skip to content

Commit 36c381b

Browse files
authored
fix(langchain): alias Bedrock providers in summarization token check (#37453)
1 parent 0831e44 commit 36c381b

2 files changed

Lines changed: 104 additions & 2 deletions

File tree

libs/langchain_v1/langchain/agents/middleware/summarization.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,23 @@
7777
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
7878
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
7979

80+
# Some providers tag emitted messages with a `model_provider` string that differs from
81+
# their LangSmith `ls_provider`. The reported-token check below compares the two, so we
82+
# accept known aliases per `ls_provider`.
83+
_LS_PROVIDER_ALIASES: dict[str, frozenset[str]] = {
84+
"amazon_bedrock": frozenset({"bedrock", "bedrock_converse"}),
85+
}
86+
87+
88+
def _provider_matches(message_provider: str, model_ls_provider: str | None) -> bool:
89+
if model_ls_provider is None:
90+
return False
91+
if message_provider == model_ls_provider:
92+
return True
93+
aliases = _LS_PROVIDER_ALIASES.get(model_ls_provider)
94+
return aliases is not None and message_provider in aliases
95+
96+
8097
ContextFraction = tuple[Literal["fraction"], float]
8198
"""Fraction of model's maximum input tokens.
8299
@@ -379,7 +396,10 @@ def _should_summarize_based_on_reported_tokens(
379396
and (reported_tokens := last_ai_message.usage_metadata.get("total_tokens", -1))
380397
and reported_tokens >= threshold
381398
and (message_provider := last_ai_message.response_metadata.get("model_provider"))
382-
and message_provider == self.model._get_ls_params().get("ls_provider") # noqa: SLF001
399+
and _provider_matches(
400+
message_provider,
401+
self.model._get_ls_params().get("ls_provider"), # noqa: SLF001
402+
)
383403
):
384404
return True
385405
return False

libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
77
from langchain_core.language_models import ModelProfile
88
from langchain_core.language_models.base import (
9+
LangSmithParams,
910
LanguageModelInput,
1011
)
1112
from langchain_core.language_models.chat_models import BaseChatModel
@@ -27,7 +28,10 @@
2728
from typing_extensions import override
2829

2930
from langchain.agents import AgentState
30-
from langchain.agents.middleware.summarization import SummarizationMiddleware
31+
from langchain.agents.middleware.summarization import (
32+
SummarizationMiddleware,
33+
_provider_matches,
34+
)
3135
from langchain.chat_models import init_chat_model
3236
from tests.unit_tests.agents.model import FakeToolCallingModel
3337

@@ -1219,6 +1223,84 @@ def test_usage_metadata_trigger() -> None:
12191223
assert not middleware._should_summarize(messages, 0)
12201224

12211225

1226+
def test_provider_matches() -> None:
1227+
"""Direct equality matches, plus Bedrock aliases under amazon_bedrock."""
1228+
assert _provider_matches("anthropic", "anthropic")
1229+
assert _provider_matches("openai", "openai")
1230+
# Bedrock chat models tag messages with model_provider="bedrock" or
1231+
# "bedrock_converse" but trace under ls_provider="amazon_bedrock".
1232+
assert _provider_matches("bedrock", "amazon_bedrock")
1233+
assert _provider_matches("bedrock_converse", "amazon_bedrock")
1234+
# Non-matches
1235+
assert not _provider_matches("openai", "anthropic")
1236+
assert not _provider_matches("bedrock", "anthropic")
1237+
assert not _provider_matches("anthropic", None)
1238+
1239+
1240+
class _MockBedrockChatModel(BaseChatModel):
1241+
"""Mock model that mimics ChatBedrockConverse's ls_provider for tracing."""
1242+
1243+
@override
1244+
def _generate(
1245+
self,
1246+
messages: list[BaseMessage],
1247+
stop: list[str] | None = None,
1248+
run_manager: CallbackManagerForLLMRun | None = None,
1249+
**kwargs: Any,
1250+
) -> ChatResult:
1251+
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])
1252+
1253+
@property
1254+
def _llm_type(self) -> str:
1255+
return "amazon_bedrock_converse_chat"
1256+
1257+
@override
1258+
def _get_ls_params(self, stop: list[str] | None = None, **kwargs: Any) -> LangSmithParams:
1259+
return LangSmithParams(ls_provider="amazon_bedrock", ls_model_type="chat")
1260+
1261+
1262+
def test_reported_tokens_trigger_for_bedrock_converse() -> None:
1263+
"""Bedrock messages should satisfy the reported-token check.
1264+
1265+
Despite the model_provider/ls_provider mismatch (bedrock_converse vs.
1266+
amazon_bedrock), the reported-token check should still trigger summarization.
1267+
"""
1268+
middleware = SummarizationMiddleware(
1269+
model=_MockBedrockChatModel(),
1270+
trigger=("tokens", 10_000),
1271+
keep=("messages", 4),
1272+
)
1273+
messages: list[AnyMessage] = [
1274+
HumanMessage(content="msg1"),
1275+
AIMessage(
1276+
content="msg2",
1277+
response_metadata={"model_provider": "bedrock_converse"},
1278+
usage_metadata={
1279+
"input_tokens": 7500,
1280+
"output_tokens": 2501,
1281+
"total_tokens": 10_001,
1282+
},
1283+
),
1284+
]
1285+
# reported token count (10_001) should override the supplied count of 0
1286+
assert middleware._should_summarize(messages, 0)
1287+
1288+
# mismatched provider should not engage
1289+
messages_other_provider: list[AnyMessage] = [
1290+
HumanMessage(content="msg1"),
1291+
AIMessage(
1292+
content="msg2",
1293+
response_metadata={"model_provider": "anthropic"},
1294+
usage_metadata={
1295+
"input_tokens": 7500,
1296+
"output_tokens": 2501,
1297+
"total_tokens": 10_001,
1298+
},
1299+
),
1300+
]
1301+
assert not middleware._should_summarize(messages_other_provider, 0)
1302+
1303+
12221304
class ConfigCapturingModel(BaseChatModel):
12231305
"""Mock model that captures the config passed to invoke/ainvoke."""
12241306

0 commit comments

Comments
 (0)