Skip to content

Commit d689a04

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Support propagating grounding metadata from AgentTool
Close: #4671 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 891827874
1 parent 0e71985 commit d689a04

File tree

6 files changed

+104
-91
lines changed

6 files changed

+104
-91
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import random
17+
18+
from dotenv import load_dotenv
19+
from google.adk import Agent
20+
from google.adk.tools.agent_tool import AgentTool
21+
from google.adk.tools.tool_context import ToolContext
22+
from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
23+
24+
load_dotenv(override=True)
25+
26+
VERTEXAI_DATASTORE_ID = os.getenv("VERTEXAI_DATASTORE_ID")
27+
if not VERTEXAI_DATASTORE_ID:
28+
raise ValueError("VERTEXAI_DATASTORE_ID environment variable not set")
29+
30+
31+
def roll_die(sides: int, tool_context: ToolContext) -> int:
32+
"""Roll a die and return the rolled result.
33+
34+
Args:
35+
sides: The integer number of sides the die has.
36+
37+
Returns:
38+
An integer of the result of rolling the die.
39+
"""
40+
result = random.randint(1, sides)
41+
if "rolls" not in tool_context.state:
42+
tool_context.state["rolls"] = []
43+
44+
tool_context.state["rolls"] = tool_context.state["rolls"] + [result]
45+
return result
46+
47+
48+
vertex_ai_search_agent = Agent(
49+
model="gemini-3-flash-preview",
50+
name="vertex_ai_search_agent",
51+
description="An agent for performing Vertex AI search.",
52+
tools=[
53+
VertexAiSearchTool(
54+
data_store_id=VERTEXAI_DATASTORE_ID,
55+
)
56+
],
57+
)
58+
59+
root_agent = Agent(
60+
model="gemini-3.1-pro-preview",
61+
name="hello_world_agent",
62+
description="A hello world agent with multiple tools.",
63+
instruction="""
64+
You are a helpful assistant which can help user to roll dice and search for information.
65+
- Use `roll_die` tool to roll dice.
66+
- Use `vertex_ai_search_agent` to search for Google Agent Development Kit (ADK) information in the datastore.
67+
""",
68+
tools=[
69+
roll_die,
70+
AgentTool(
71+
agent=vertex_ai_search_agent, propagate_grounding_metadata=True
72+
),
73+
],
74+
)

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,9 @@ async def _maybe_add_grounding_metadata(
259259
tools = await agent.canonical_tools(readonly_context)
260260
invocation_context.canonical_tools_cache = tools
261261

262-
if not any(tool.name == 'google_search_agent' for tool in tools):
262+
if not any(
263+
getattr(tool, 'propagate_grounding_metadata', False) for tool in tools
264+
):
263265
return response
264266
ground_metadata = invocation_context.session.state.get(
265267
'temp:_adk_grounding_metadata', None

src/google/adk/tools/agent_tool.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ def __init__(
113113
skip_summarization: bool = False,
114114
*,
115115
include_plugins: bool = True,
116+
propagate_grounding_metadata: bool = False,
116117
):
117118
self.agent = agent
118119
self.skip_summarization: bool = skip_summarization
119120
self.include_plugins = include_plugins
121+
self.propagate_grounding_metadata = propagate_grounding_metadata
120122

121123
super().__init__(name=agent.name, description=agent.description)
122124

@@ -247,6 +249,7 @@ async def run_async(
247249
)
248250

249251
last_content = None
252+
last_grounding_metadata = None
250253
async with Aclosing(
251254
runner.run_async(
252255
user_id=session.user_id, session_id=session.id, new_message=content
@@ -258,6 +261,7 @@ async def run_async(
258261
tool_context.state.update(event.actions.state_delta)
259262
if event.content:
260263
last_content = event.content
264+
last_grounding_metadata = event.grounding_metadata
261265

262266
# Clean up runner resources (especially MCP sessions)
263267
# to avoid "Attempted to exit cancel scope in a different task" errors
@@ -273,6 +277,12 @@ async def run_async(
273277
tool_result = validate_schema(output_schema, merged_text)
274278
else:
275279
tool_result = merged_text
280+
281+
if self.propagate_grounding_metadata and last_grounding_metadata:
282+
tool_context.state['temp:_adk_grounding_metadata'] = (
283+
last_grounding_metadata
284+
)
285+
276286
return tool_result
277287

278288
@override

src/google/adk/tools/google_search_agent_tool.py

Lines changed: 1 addition & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,12 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Any
1817
from typing import Union
1918

20-
from google.genai import types
21-
from typing_extensions import override
22-
2319
from ..agents.llm_agent import LlmAgent
24-
from ..memory.in_memory_memory_service import InMemoryMemoryService
2520
from ..models.base_llm import BaseLlm
26-
from ..utils._schema_utils import validate_schema
27-
from ..utils.context_utils import Aclosing
28-
from ._forwarding_artifact_service import ForwardingArtifactService
2921
from .agent_tool import AgentTool
3022
from .google_search_tool import google_search
31-
from .tool_context import ToolContext
3223

3324

3425
def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
@@ -60,80 +51,4 @@ class GoogleSearchAgentTool(AgentTool):
6051

6152
def __init__(self, agent: LlmAgent):
6253
self.agent = agent
63-
super().__init__(agent=self.agent)
64-
65-
@override
66-
async def run_async(
67-
self,
68-
*,
69-
args: dict[str, Any],
70-
tool_context: ToolContext,
71-
) -> Any:
72-
from ..agents.llm_agent import LlmAgent
73-
from ..runners import Runner
74-
from ..sessions.in_memory_session_service import InMemorySessionService
75-
76-
if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
77-
input_value = self.agent.input_schema.model_validate(args)
78-
content = types.Content(
79-
role='user',
80-
parts=[
81-
types.Part.from_text(
82-
text=input_value.model_dump_json(exclude_none=True)
83-
)
84-
],
85-
)
86-
else:
87-
content = types.Content(
88-
role='user',
89-
parts=[types.Part.from_text(text=args['request'])],
90-
)
91-
runner = Runner(
92-
app_name=self.agent.name,
93-
agent=self.agent,
94-
artifact_service=ForwardingArtifactService(tool_context),
95-
session_service=InMemorySessionService(),
96-
memory_service=InMemoryMemoryService(),
97-
credential_service=tool_context._invocation_context.credential_service,
98-
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
99-
)
100-
101-
state_dict = {
102-
k: v
103-
for k, v in tool_context.state.to_dict().items()
104-
if not k.startswith('_adk') # Filter out adk internal states
105-
}
106-
session = await runner.session_service.create_session(
107-
app_name=self.agent.name,
108-
user_id=tool_context._invocation_context.user_id,
109-
state=state_dict,
110-
)
111-
112-
last_content = None
113-
last_grounding_metadata = None
114-
async with Aclosing(
115-
runner.run_async(
116-
user_id=session.user_id, session_id=session.id, new_message=content
117-
)
118-
) as agen:
119-
async for event in agen:
120-
# Forward state delta to parent session.
121-
if event.actions.state_delta:
122-
tool_context.state.update(event.actions.state_delta)
123-
if event.content:
124-
last_content = event.content
125-
last_grounding_metadata = event.grounding_metadata
126-
127-
if last_content is None or last_content.parts is None:
128-
return ''
129-
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
130-
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
131-
tool_result = validate_schema(self.agent.output_schema, merged_text)
132-
else:
133-
tool_result = merged_text
134-
135-
if last_grounding_metadata:
136-
tool_context.state['temp:_adk_grounding_metadata'] = (
137-
last_grounding_metadata
138-
)
139-
return tool_result
54+
super().__init__(agent=self.agent, propagate_grounding_metadata=True)

tests/unittests/flows/llm_flows/test_base_llm_flow.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,6 @@ async def test_handle_after_model_callback_grounding_with_no_callbacks(
284284
invocation_id=invocation_context.invocation_id,
285285
author=agent.name,
286286
)
287-
flow = BaseLlmFlowForTesting()
288287

289288
result = await _handle_after_model_callback(
290289
invocation_context, llm_response, event
@@ -341,7 +340,6 @@ async def test_handle_after_model_callback_grounding_with_callback_override(
341340
invocation_id=invocation_context.invocation_id,
342341
author=agent.name,
343342
)
344-
flow = BaseLlmFlowForTesting()
345343

346344
result = await _handle_after_model_callback(
347345
invocation_context, llm_response, event
@@ -403,7 +401,6 @@ def __init__(self):
403401
invocation_id=invocation_context.invocation_id,
404402
author=agent.name,
405403
)
406-
flow = BaseLlmFlowForTesting()
407404

408405
result = await _handle_after_model_callback(
409406
invocation_context, llm_response, event
@@ -430,6 +427,7 @@ class MockGoogleSearchTool(BaseTool):
430427

431428
def __init__(self):
432429
super().__init__(name='google_search_agent', description='Mock search')
430+
self.propagate_grounding_metadata = True
433431

434432
async def call(self, **kwargs):
435433
return 'mock result'
@@ -459,7 +457,6 @@ async def call(self, **kwargs):
459457
invocation_id=invocation_context.invocation_id,
460458
author=agent.name,
461459
)
462-
flow = BaseLlmFlowForTesting()
463460

464461
# Call _handle_after_model_callback multiple times with the same context
465462
result1 = await _handle_after_model_callback(

0 commit comments

Comments
 (0)