Skip to content

Commit 0ec69d0

Browse files
hangfeicopybara-github
authored andcommitted
feat: Enhance LangchainTool to accept more forms of functions
Now the LangchainTool can wrap: * Langchain StructuredTool (sync and async). * Langchain @tool (sync and async). This enhance the flexibility for user and enables async functionalities. PiperOrigin-RevId: 784728061
1 parent f1e0bc0 commit 0ec69d0

File tree

3 files changed

+132
-6
lines changed

3 files changed

+132
-6
lines changed

contributing/samples/langchain_structured_tool_agent/agent.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,31 @@
1717
"""
1818
from google.adk.agents import Agent
1919
from google.adk.tools.langchain_tool import LangchainTool
20+
from langchain.tools import tool
2021
from langchain_core.tools.structured import StructuredTool
2122
from pydantic import BaseModel
2223

2324

24-
def add(x, y) -> int:
25+
async def add(x, y) -> int:
2526
return x + y
2627

2728

29+
@tool
30+
def minus(x, y) -> int:
31+
return x - y
32+
33+
2834
class AddSchema(BaseModel):
2935
x: int
3036
y: int
3137

3238

33-
test_langchain_tool = StructuredTool.from_function(
39+
class MinusSchema(BaseModel):
40+
x: int
41+
y: int
42+
43+
44+
test_langchain_add_tool = StructuredTool.from_function(
3445
add,
3546
name="add",
3647
description="Adds two numbers",
@@ -45,5 +56,8 @@ class AddSchema(BaseModel):
4556
"You are a helpful assistant for user questions, you have access to a"
4657
" tool that adds two numbers."
4758
),
48-
tools=[LangchainTool(tool=test_langchain_tool)],
59+
tools=[
60+
LangchainTool(tool=test_langchain_add_tool),
61+
LangchainTool(tool=minus),
62+
],
4963
)

src/google/adk/tools/langchain_tool.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,26 @@ def __init__(
5959
name: Optional[str] = None,
6060
description: Optional[str] = None,
6161
):
62-
# Check if the tool has a 'run' method
6362
if not hasattr(tool, 'run') and not hasattr(tool, '_run'):
64-
raise ValueError("Langchain tool must have a 'run' or '_run' method")
63+
raise ValueError(
64+
"Tool must be a Langchain tool, have a 'run' or '_run' method."
65+
)
6566

6667
# Determine which function to use
6768
if isinstance(tool, StructuredTool):
6869
func = tool.func
69-
else:
70+
# For async tools, func might be None but coroutine exists
71+
if func is None and hasattr(tool, 'coroutine') and tool.coroutine:
72+
func = tool.coroutine
73+
elif hasattr(tool, '_run') or hasattr(tool, 'run'):
7074
func = tool._run if hasattr(tool, '_run') else tool.run
75+
else:
76+
raise ValueError(
77+
"This is not supported. Tool must be a Langchain tool, have a 'run'"
78+
" or '_run' method. The tool is: ",
79+
type(tool),
80+
)
81+
7182
super().__init__(func)
7283
# run_manager is a special parameter for langchain tool
7384
self._ignore_params.append('run_manager')
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest.mock import MagicMock
16+
17+
from google.adk.tools.langchain_tool import LangchainTool
18+
from langchain.tools import tool
19+
from langchain_core.tools.structured import StructuredTool
20+
from pydantic import BaseModel
21+
import pytest
22+
23+
24+
@tool
25+
async def async_add_with_annotation(x, y) -> int:
26+
"""Adds two numbers"""
27+
return x + y
28+
29+
30+
@tool
31+
def sync_add_with_annotation(x, y) -> int:
32+
"""Adds two numbers"""
33+
return x + y
34+
35+
36+
async def async_add(x, y) -> int:
37+
return x + y
38+
39+
40+
def sync_add(x, y) -> int:
41+
return x + y
42+
43+
44+
class AddSchema(BaseModel):
45+
x: int
46+
y: int
47+
48+
49+
test_langchain_async_add_tool = StructuredTool.from_function(
50+
async_add,
51+
name="add",
52+
description="Adds two numbers",
53+
args_schema=AddSchema,
54+
)
55+
56+
test_langchain_sync_add_tool = StructuredTool.from_function(
57+
sync_add,
58+
name="add",
59+
description="Adds two numbers",
60+
args_schema=AddSchema,
61+
)
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_raw_async_function_works():
66+
"""Test that passing a raw async function to LangchainTool works correctly."""
67+
langchain_tool = LangchainTool(tool=test_langchain_async_add_tool)
68+
result = await langchain_tool.run_async(
69+
args={"x": 1, "y": 3}, tool_context=MagicMock()
70+
)
71+
assert result == 4
72+
73+
74+
@pytest.mark.asyncio
75+
async def test_raw_sync_function_works():
76+
"""Test that passing a raw sync function to LangchainTool works correctly."""
77+
langchain_tool = LangchainTool(tool=test_langchain_sync_add_tool)
78+
result = await langchain_tool.run_async(
79+
args={"x": 1, "y": 3}, tool_context=MagicMock()
80+
)
81+
assert result == 4
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_raw_async_function_with_annotation_works():
86+
"""Test that passing a raw async function to LangchainTool works correctly."""
87+
langchain_tool = LangchainTool(tool=async_add_with_annotation)
88+
result = await langchain_tool.run_async(
89+
args={"x": 1, "y": 3}, tool_context=MagicMock()
90+
)
91+
assert result == 4
92+
93+
94+
@pytest.mark.asyncio
95+
async def test_raw_sync_function_with_annotation_works():
96+
"""Test that passing a raw sync function to LangchainTool works correctly."""
97+
langchain_tool = LangchainTool(tool=sync_add_with_annotation)
98+
result = await langchain_tool.run_async(
99+
args={"x": 1, "y": 3}, tool_context=MagicMock()
100+
)
101+
assert result == 4

0 commit comments

Comments
 (0)