-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathagent_development_kit.py
More file actions
296 lines (251 loc) · 10.4 KB
/
agent_development_kit.py
File metadata and controls
296 lines (251 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""ModelHandler for running agents built with the Google Agent Development Kit.

This module provides :class:`ADKAgentModelHandler`, a Beam
:class:`~apache_beam.ml.inference.base.ModelHandler` that wraps an ADK
:class:`google.adk.agents.llm_agent.LlmAgent` so it can be used with the
:class:`~apache_beam.ml.inference.base.RunInference` transform.

Typical usage::

  import apache_beam as beam
  from apache_beam.ml.inference.base import RunInference
  from apache_beam.ml.inference.agent_development_kit import ADKAgentModelHandler
  from google.adk.agents import LlmAgent

  agent = LlmAgent(
      name="my_agent",
      model="gemini-2.0-flash",
      instruction="You are a helpful assistant.",
  )

  with beam.Pipeline() as p:
    results = (
        p
        | beam.Create(["What is the capital of France?"])
        | RunInference(ADKAgentModelHandler(agent=agent))
    )

If your agent contains state that is not picklable (e.g. tool closures that
capture unpicklable objects), pass a zero-arg factory callable instead::

  handler = ADKAgentModelHandler(agent=lambda: LlmAgent(...))
"""
import asyncio
import inspect
import logging
import uuid
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sequence
from typing import Any
from typing import Optional
from typing import Union

from apache_beam.ml.inference.base import ModelHandler
from apache_beam.ml.inference.base import PredictionResult
try:
from google.adk import sessions
from google.adk.agents import Agent
from google.adk.runners import Runner
from google.adk.sessions import BaseSessionService
from google.adk.sessions import InMemorySessionService
from google.genai.types import Content as genai_Content
from google.genai.types import Part as genai_Part
ADK_AVAILABLE = True
except ImportError:
ADK_AVAILABLE = False
genai_Content = Any # type: ignore[assignment, misc]
genai_Part = Any # type: ignore[assignment, misc]
# Module logger. It uses a fixed name (matching the handler class) rather than
# the module's dotted path.
LOGGER = logging.getLogger(name="ADKAgentModelHandler")

# Accepted forms of the ``agent`` constructor argument: an agent instance, or
# a zero-argument callable that builds one when the model is loaded.
_AgentOrFactory = Union["Agent", Callable[[], "Agent"]]
class ADKAgentModelHandler(ModelHandler[str | genai_Content,
                                        PredictionResult,
                                        "Runner"]):
  """ModelHandler for running ADK agents with the Beam RunInference transform.

  Accepts either a fully constructed :class:`google.adk.agents.Agent` or a
  zero-arg factory callable that produces one. The factory form is useful when
  the agent contains state that is not picklable and therefore cannot be
  serialized alongside the pipeline graph.

  Each call to :meth:`run_inference` invokes the agent once per element in the
  batch. By default every invocation uses a fresh, isolated session (stateless)
  and the invocations of a batch run concurrently. Stateful multi-turn
  conversations can be achieved by passing a ``session_id`` key inside
  ``inference_args``. The elements of a batch then share that session and are
  sent to the agent one at a time, in batch order. Each turn therefore
  continues the conversation history of the previous ones.

  Args:
    agent: A pre-constructed :class:`~google.adk.agents.Agent` instance, or a
      zero-arg callable that returns one. The callable form defers agent
      construction to worker ``load_model`` time, which is useful when the
      agent cannot be serialized.
    app_name: The ADK application name used to namespace sessions. Defaults to
      ``"beam_inference"``.
    session_service_factory: Optional zero-arg callable returning a
      :class:`~google.adk.sessions.BaseSessionService`. When ``None``, an
      :class:`~google.adk.sessions.InMemorySessionService` is created
      automatically.
    min_batch_size: Optional minimum batch size.
    max_batch_size: Optional maximum batch size.
    max_batch_duration_secs: Optional maximum time to buffer a batch before
      emitting; used in streaming contexts.
    max_batch_weight: Optional maximum total weight of a batch.
    element_size_fn: Optional function that returns the size (weight) of an
      element.
  """

  # Keys accepted in ``inference_args`` (see run_inference). Any other key is
  # rejected by validate_inference_args so typos fail loudly instead of being
  # silently ignored.
  _SUPPORTED_INFERENCE_ARGS = frozenset({"session_id", "user_id"})

  def __init__(
      self,
      agent: _AgentOrFactory,
      app_name: str = "beam_inference",
      session_service_factory: Optional[Callable[[],
                                                 "BaseSessionService"]] = None,
      *,
      min_batch_size: Optional[int] = None,
      max_batch_size: Optional[int] = None,
      max_batch_duration_secs: Optional[int] = None,
      max_batch_weight: Optional[int] = None,
      element_size_fn: Optional[Callable[[Any], int]] = None,
      **kwargs):
    if not ADK_AVAILABLE:
      raise ImportError(
          "google-adk is required to use ADKAgentModelHandler. "
          "Install it with: pip install google-adk")
    if agent is None:
      raise ValueError("'agent' must be an Agent instance or a callable.")
    self._agent_or_factory = agent
    self._app_name = app_name
    self._session_service_factory = session_service_factory
    super().__init__(
        min_batch_size=min_batch_size,
        max_batch_size=max_batch_size,
        max_batch_duration_secs=max_batch_duration_secs,
        max_batch_weight=max_batch_weight,
        element_size_fn=element_size_fn,
        **kwargs)

  def load_model(self) -> "Runner":
    """Instantiates the ADK Runner on the worker.

    Resolves the agent (calling the factory if a callable was provided), then
    creates a :class:`~google.adk.runners.Runner` backed by the configured
    session service.

    Returns:
      A fully initialised :class:`~google.adk.runners.Runner`.
    """
    if callable(self._agent_or_factory) and not isinstance(
        self._agent_or_factory, Agent):
      agent = self._agent_or_factory()
    else:
      agent = self._agent_or_factory

    if self._session_service_factory is not None:
      session_service = self._session_service_factory()
    else:
      session_service = InMemorySessionService()

    runner = Runner(
        agent=agent,
        app_name=self._app_name,
        session_service=session_service,
    )
    LOGGER.info(
        "Loaded ADK Runner for agent '%s' (app_name='%s')",
        agent.name,
        self._app_name,
    )
    return runner

  def validate_inference_args(self, inference_args: Optional[dict[str, Any]]):
    """Accepts the ``inference_args`` keys that :meth:`run_inference` reads.

    Overrides :meth:`ModelHandler.validate_inference_args` so that the
    ``session_id`` and ``user_id`` keys are accepted; any other key raises.

    Args:
      inference_args: The extra arguments passed to RunInference, or ``None``.

    Raises:
      ValueError: If ``inference_args`` contains an unsupported key.
    """
    if not inference_args:
      return
    unsupported = set(inference_args) - self._SUPPORTED_INFERENCE_ARGS
    if unsupported:
      raise ValueError(
          "Unsupported inference_args for ADKAgentModelHandler: "
          f"{sorted(unsupported, key=str)}. Supported keys: "
          f"{sorted(self._SUPPORTED_INFERENCE_ARGS)}")

  def run_inference(
      self,
      batch: Sequence[str | genai_Content],
      model: "Runner",
      inference_args: Optional[dict[str, Any]] = None,
  ) -> Iterable[PredictionResult]:
    """Runs the ADK agent on each element in the batch.

    Each element is sent to the agent as a new user turn. The final response
    text from the agent is returned as the ``inference`` field of a
    :class:`~apache_beam.ml.inference.base.PredictionResult`.

    Args:
      batch: A sequence of inputs, each of which is either a ``str`` (the user
        message text) or a :class:`google.genai.types.Content` object (for
        richer multi-part messages).
      model: The :class:`~google.adk.runners.Runner` returned by
        :meth:`load_model`.
      inference_args: Optional dict of extra arguments. Supported keys:

        - ``"session_id"`` (:class:`str`): If supplied, all elements in this
          batch share this session ID and are sent to the agent sequentially,
          in batch order, enabling stateful multi-turn conversations. The
          session is created if the session service does not already have
          it. If omitted, each element receives a unique auto-generated
          session ID and the elements are processed concurrently.
        - ``"user_id"`` (:class:`str`): The user identifier to pass to the
          runner. Defaults to ``"beam_user"``.

    Returns:
      A list of :class:`~apache_beam.ml.inference.base.PredictionResult`,
      one per input element, in input order.
    """
    if not batch:
      return []
    if inference_args is None:
      inference_args = {}
    user_id: str = inference_args.get("user_id", "beam_user")
    shared_session_id: Optional[str] = inference_args.get("session_id")

    messages = [self._to_content(element) for element in batch]

    if shared_session_id is None:
      pending = self._run_isolated_sessions(model, user_id, messages)
    else:
      pending = self._run_shared_session(
          model, user_id, shared_session_id, messages)
    response_texts = asyncio.run(pending)

    return [
        PredictionResult(
            example=element,
            inference=text,
            model_id=model.agent.name,
        ) for element, text in zip(batch, response_texts)
    ]

  async def _run_isolated_sessions(
      self,
      runner: "Runner",
      user_id: str,
      messages: list,
  ) -> list:
    """Runs every message in its own new session, all concurrently."""
    async def _run_one(message):
      session_id = str(uuid.uuid4())
      await self._create_session(runner, user_id, session_id)
      return await self._invoke_agent(runner, user_id, session_id, message)

    # gather preserves argument order, so results line up with `messages`.
    return list(await asyncio.gather(*(_run_one(m) for m in messages)))

  async def _run_shared_session(
      self,
      runner: "Runner",
      user_id: str,
      session_id: str,
      messages: list,
  ) -> list:
    """Runs the messages one after another against a single session."""
    await self._ensure_session(runner, user_id, session_id)
    responses = []
    for message in messages:
      # Sequential on purpose: each turn must observe the events appended by
      # the previous turns of the same conversation.
      responses.append(
          await self._invoke_agent(runner, user_id, session_id, message))
    return responses

  async def _create_session(
      self, runner: "Runner", user_id: str, session_id: str) -> None:
    """Creates session ``session_id`` in the runner's session service."""
    await self._resolve(
        runner.session_service.create_session(
            app_name=self._app_name,
            user_id=user_id,
            session_id=session_id,
        ))

  async def _ensure_session(
      self, runner: "Runner", user_id: str, session_id: str) -> None:
    """Creates session ``session_id`` unless the service already has it."""
    existing = await self._resolve(
        runner.session_service.get_session(
            app_name=self._app_name,
            user_id=user_id,
            session_id=session_id,
        ))
    if existing is None:
      await self._create_session(runner, user_id, session_id)

  @staticmethod
  async def _resolve(value: Any) -> Any:
    """Awaits ``value`` if it is awaitable; otherwise returns it unchanged.

    Session-service methods may be coroutines or plain functions depending on
    the installed ADK version; this lets the handler work with either.
    """
    if inspect.isawaitable(value):
      return await value
    return value

  @staticmethod
  def _to_content(element: str | genai_Content) -> genai_Content:
    """Wraps a plain string as a user ``Content``; passes anything else on."""
    if isinstance(element, str):
      return genai_Content(role="user", parts=[genai_Part(text=element)])
    # Assume the caller has already constructed a types.Content object.
    return element

  @staticmethod
  async def _invoke_agent(
      runner: "Runner",
      user_id: str,
      session_id: str,
      message: genai_Content,
  ) -> Optional[str]:
    """Drives the ADK event loop and returns the final response text.

    Args:
      runner: The ADK Runner to invoke.
      user_id: The user ID for this invocation.
      session_id: The session ID for this invocation.
      message: The :class:`google.genai.types.Content` to send.

    Returns:
      The text of the first final-response event that has content, or ``None``
      if the agent produced no such event or that event carried no text.
    """
    async for event in runner.run_async(
        user_id=user_id,
        session_id=session_id,
        new_message=message,
    ):
      if event.is_final_response() and event.content:
        return ADKAgentModelHandler._extract_text(event.content)
    return None

  @staticmethod
  def _extract_text(content: genai_Content) -> Optional[str]:
    """Concatenates the text of ``content``'s parts.

    ``Content`` stores its text inside ``parts``. Parts flagged as thoughts
    (``part.thought`` truthy) are skipped so that only the answer is returned.

    Returns:
      The joined text, or ``None`` if no part carries text.
    """
    texts = [
        part.text for part in (content.parts or [])
        if getattr(part, "text", None) and not getattr(part, "thought", False)
    ]
    return "".join(texts) if texts else None

  def get_metrics_namespace(self) -> str:
    return "ADKAgentModelHandler"