Skip to content

Commit e922e54

Browse files
authored
aws[patch]: (Converse) support camel case tool call args (#516)
1 parent 40abb58 commit e922e54

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,7 +1271,10 @@ def _lc_content_to_bedrock(
12711271

12721272
def _bedrock_to_lc(content: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
12731273
lc_content = []
1274-
for block in _camel_to_snake_keys(content):
1274+
for block in _camel_to_snake_keys(
1275+
content,
1276+
excluded_keys={"input"}, # exclude 'input' key, which contains tool call args
1277+
):
12751278
if "text" in block:
12761279
lc_content.append({"type": "text", "text": block["text"]})
12771280
elif "tool_use" in block:
@@ -1440,13 +1443,21 @@ def _camel_to_snake(text: str) -> str:
14401443
_T = TypeVar("_T")
14411444

14421445

1443-
def _camel_to_snake_keys(obj: _T) -> _T:
1446+
def _camel_to_snake_keys(obj: _T, excluded_keys: set = set()) -> _T:
14441447
if isinstance(obj, list):
1445-
return cast(_T, [_camel_to_snake_keys(e) for e in obj])
1446-
elif isinstance(obj, dict):
14471448
return cast(
1448-
_T, {_camel_to_snake(k): _camel_to_snake_keys(v) for k, v in obj.items()}
1449+
_T, [_camel_to_snake_keys(e, excluded_keys=excluded_keys) for e in obj]
14491450
)
1451+
elif isinstance(obj, dict):
1452+
_dict = {}
1453+
for k, v in obj.items():
1454+
if k in excluded_keys:
1455+
_dict[k] = v
1456+
else:
1457+
_dict[_camel_to_snake(k)] = _camel_to_snake_keys(
1458+
v, excluded_keys=excluded_keys
1459+
)
1460+
return cast(_T, _dict)
14501461
else:
14511462
return obj
14521463

libs/aws/tests/integration_tests/chat_models/test_bedrock_converse.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,31 @@ def classify_query(query_type: Literal["cat", "dog"]) -> None:
215215
assert isinstance(result["ResponseMetadata"]["RetryAttempts"], int)
216216

217217

218+
def test_tool_calling_camel_case() -> None:
219+
model = ChatBedrockConverse(model="us.anthropic.claude-3-5-sonnet-20241022-v2:0")
220+
221+
def classifyQuery(queryType: Literal["cat", "dog"]) -> None:
222+
pass
223+
224+
chat = model.bind_tools([classifyQuery], tool_choice="any")
225+
response = chat.invoke("How big are cats?")
226+
assert isinstance(response, AIMessage)
227+
assert len(response.tool_calls) == 1
228+
tool_call = response.tool_calls[0]
229+
assert tool_call["name"] == "classifyQuery"
230+
assert tool_call["args"] == {"queryType": "cat"}
231+
232+
full = None
233+
for chunk in chat.stream("How big are cats?"):
234+
full = chunk if full is None else full + chunk # type: ignore[assignment]
235+
assert isinstance(full, AIMessageChunk)
236+
assert len(full.tool_calls) == 1
237+
tool_call = full.tool_calls[0]
238+
assert tool_call["name"] == "classifyQuery"
239+
assert tool_call["args"] == {"queryType": "cat"}
240+
assert full.tool_calls[0]["args"] == response.tool_calls[0]["args"]
241+
242+
218243
def test_structured_output_streaming() -> None:
219244
model = ChatBedrockConverse(
220245
model="anthropic.claude-3-sonnet-20240229-v1:0", temperature=0

0 commit comments

Comments
 (0)