|
6 | 6 | from langchain_core.callbacks import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun |
7 | 7 | from langchain_core.language_models import ModelProfile |
8 | 8 | from langchain_core.language_models.base import ( |
| 9 | + LangSmithParams, |
9 | 10 | LanguageModelInput, |
10 | 11 | ) |
11 | 12 | from langchain_core.language_models.chat_models import BaseChatModel |
|
27 | 28 | from typing_extensions import override |
28 | 29 |
|
29 | 30 | from langchain.agents import AgentState |
30 | | -from langchain.agents.middleware.summarization import SummarizationMiddleware |
| 31 | +from langchain.agents.middleware.summarization import ( |
| 32 | + SummarizationMiddleware, |
| 33 | + _provider_matches, |
| 34 | +) |
31 | 35 | from langchain.chat_models import init_chat_model |
32 | 36 | from tests.unit_tests.agents.model import FakeToolCallingModel |
33 | 37 |
|
@@ -1219,6 +1223,84 @@ def test_usage_metadata_trigger() -> None: |
1219 | 1223 | assert not middleware._should_summarize(messages, 0) |
1220 | 1224 |
|
1221 | 1225 |
|
| 1226 | +def test_provider_matches() -> None: |
| 1227 | + """Direct equality matches, plus Bedrock aliases under amazon_bedrock.""" |
| 1228 | + assert _provider_matches("anthropic", "anthropic") |
| 1229 | + assert _provider_matches("openai", "openai") |
| 1230 | + # Bedrock chat models tag messages with model_provider="bedrock" or |
| 1231 | + # "bedrock_converse" but trace under ls_provider="amazon_bedrock". |
| 1232 | + assert _provider_matches("bedrock", "amazon_bedrock") |
| 1233 | + assert _provider_matches("bedrock_converse", "amazon_bedrock") |
| 1234 | + # Non-matches |
| 1235 | + assert not _provider_matches("openai", "anthropic") |
| 1236 | + assert not _provider_matches("bedrock", "anthropic") |
| 1237 | + assert not _provider_matches("anthropic", None) |
| 1238 | + |
| 1239 | + |
| 1240 | +class _MockBedrockChatModel(BaseChatModel): |
| 1241 | + """Mock model that mimics ChatBedrockConverse's ls_provider for tracing.""" |
| 1242 | + |
| 1243 | + @override |
| 1244 | + def _generate( |
| 1245 | + self, |
| 1246 | + messages: list[BaseMessage], |
| 1247 | + stop: list[str] | None = None, |
| 1248 | + run_manager: CallbackManagerForLLMRun | None = None, |
| 1249 | + **kwargs: Any, |
| 1250 | + ) -> ChatResult: |
| 1251 | + return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))]) |
| 1252 | + |
| 1253 | + @property |
| 1254 | + def _llm_type(self) -> str: |
| 1255 | + return "amazon_bedrock_converse_chat" |
| 1256 | + |
| 1257 | + @override |
| 1258 | + def _get_ls_params(self, stop: list[str] | None = None, **kwargs: Any) -> LangSmithParams: |
| 1259 | + return LangSmithParams(ls_provider="amazon_bedrock", ls_model_type="chat") |
| 1260 | + |
| 1261 | + |
| 1262 | +def test_reported_tokens_trigger_for_bedrock_converse() -> None: |
| 1263 | + """Bedrock messages should satisfy the reported-token check. |
| 1264 | +
|
| 1265 | + Despite the model_provider/ls_provider mismatch (bedrock_converse vs. |
| 1266 | + amazon_bedrock), the reported-token check should still trigger summarization. |
| 1267 | + """ |
| 1268 | + middleware = SummarizationMiddleware( |
| 1269 | + model=_MockBedrockChatModel(), |
| 1270 | + trigger=("tokens", 10_000), |
| 1271 | + keep=("messages", 4), |
| 1272 | + ) |
| 1273 | + messages: list[AnyMessage] = [ |
| 1274 | + HumanMessage(content="msg1"), |
| 1275 | + AIMessage( |
| 1276 | + content="msg2", |
| 1277 | + response_metadata={"model_provider": "bedrock_converse"}, |
| 1278 | + usage_metadata={ |
| 1279 | + "input_tokens": 7500, |
| 1280 | + "output_tokens": 2501, |
| 1281 | + "total_tokens": 10_001, |
| 1282 | + }, |
| 1283 | + ), |
| 1284 | + ] |
| 1285 | + # reported token count (10_001) should override the supplied count of 0 |
| 1286 | + assert middleware._should_summarize(messages, 0) |
| 1287 | + |
| 1288 | + # mismatched provider should not engage |
| 1289 | + messages_other_provider: list[AnyMessage] = [ |
| 1290 | + HumanMessage(content="msg1"), |
| 1291 | + AIMessage( |
| 1292 | + content="msg2", |
| 1293 | + response_metadata={"model_provider": "anthropic"}, |
| 1294 | + usage_metadata={ |
| 1295 | + "input_tokens": 7500, |
| 1296 | + "output_tokens": 2501, |
| 1297 | + "total_tokens": 10_001, |
| 1298 | + }, |
| 1299 | + ), |
| 1300 | + ] |
| 1301 | + assert not middleware._should_summarize(messages_other_provider, 0) |
| 1302 | + |
| 1303 | + |
1222 | 1304 | class ConfigCapturingModel(BaseChatModel): |
1223 | 1305 | """Mock model that captures the config passed to invoke/ainvoke.""" |
1224 | 1306 |
|
|
0 commit comments