Skip to content

Commit 3998d65

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: add A2A support in Agent Engine
PiperOrigin-RevId: 791846986
1 parent b1d0b7c commit 3998d65

File tree

8 files changed

+691
-6
lines changed

8 files changed

+691
-6
lines changed

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@
142142
"google-adk >= 1.0.0, < 2.0.0",
143143
]
144144

145+
a2a_extra_require = [
146+
"a2a-sdk >= 0.3.0",
147+
]
148+
145149
reasoning_engine_extra_require = [
146150
"cloudpickle >= 3.0, < 4.0",
147151
"google-cloud-trace < 2",
@@ -325,6 +329,7 @@
325329
"ray": ray_extra_require,
326330
"ray_testing": ray_testing_extra_require,
327331
"adk": adk_extra_require,
332+
"a2a": a2a_extra_require,
328333
"reasoningengine": reasoning_engine_extra_require,
329334
"agent_engines": agent_engines_extra_require,
330335
"evaluation": evaluation_extra_require,

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1872,7 +1872,7 @@ def setup_method(self):
18721872
"register the API methods: "
18731873
"https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. "
18741874
"Error: {Unsupported api mode: `UNKNOWN_API_MODE`, "
1875-
"Supported modes are: ``, `async`, `async_stream`, `stream`.}"
1875+
"Supported modes are: ``, `a2a_extension`, `async`, `async_stream`, `stream`.}"
18761876
),
18771877
),
18781878
],

vertexai/_genai/_agent_engines_utils.py

Lines changed: 184 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
Union,
4444
)
4545

46+
import httpx
47+
4648
import proto
4749

4850
from google.api_core import exceptions
@@ -103,6 +105,32 @@
103105
Session = Any
104106

105107

108+
try:
109+
from a2a.types import (
110+
AgentCard,
111+
TransportProtocol,
112+
Message,
113+
TaskIdParams,
114+
TaskQueryParams,
115+
)
116+
from a2a.client import ClientConfig, ClientFactory
117+
118+
AgentCard = AgentCard
119+
TransportProtocol = TransportProtocol
120+
Message = Message
121+
ClientConfig = ClientConfig
122+
ClientFactory = ClientFactory
123+
TaskIdParams = TaskIdParams
124+
TaskQueryParams = TaskQueryParams
125+
except (ImportError, AttributeError):
126+
AgentCard = None
127+
TransportProtocol = None
128+
Message = None
129+
ClientConfig = None
130+
ClientFactory = None
131+
TaskIdParams = None
132+
TaskQueryParams = None
133+
106134
_ACTIONS_KEY = "actions"
107135
_ACTION_APPEND = "append"
108136
_AGENT_FRAMEWORK_ATTR = "agent_framework"
@@ -145,6 +173,8 @@
145173
_REQUIREMENTS_FILE = "requirements.txt"
146174
_STANDARD_API_MODE = ""
147175
_STREAM_API_MODE = "stream"
176+
_A2A_EXTENSION_MODE = "a2a_extension"
177+
_A2A_AGENT_CARD = "a2a_agent_card"
148178
_WARNINGS_KEY = "warnings"
149179
_WARNING_MISSING = "missing"
150180
_WARNING_INCOMPATIBLE = "incompatible"
@@ -493,11 +523,32 @@ def _generate_class_methods_spec_or_raise(
493523

494524
class_method = _to_proto(schema_dict)
495525
class_method[_MODE_KEY_IN_SCHEMA] = mode
526+
if hasattr(agent, "agent_card"):
527+
class_method[_A2A_AGENT_CARD] = getattr(
528+
agent, "agent_card"
529+
).model_dump_json()
496530
class_methods_spec.append(class_method)
497531

498532
return class_methods_spec
499533

500534

535+
def _is_pydantic_serializable(param: inspect.Parameter) -> bool:
536+
"""Checks if the parameter is pydantic serializable."""
537+
538+
if param.annotation == inspect.Parameter.empty:
539+
return True
540+
541+
if isinstance(param.annotation, str):
542+
return False
543+
544+
pydantic = _import_pydantic_or_raise()
545+
try:
546+
pydantic.TypeAdapter(param.annotation)
547+
return True
548+
except Exception:
549+
return False
550+
551+
501552
def _generate_schema(
502553
f: Callable[..., Any],
503554
*,
@@ -557,6 +608,7 @@ def _generate_schema(
557608
inspect.Parameter.KEYWORD_ONLY,
558609
inspect.Parameter.POSITIONAL_ONLY,
559610
)
611+
and _is_pydantic_serializable(param)
560612
}
561613
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
562614
# Postprocessing
@@ -868,6 +920,7 @@ def _register_api_methods_or_raise(
868920
_ASYNC_API_MODE: _wrap_async_query_operation,
869921
_STREAM_API_MODE: _wrap_stream_query_operation,
870922
_ASYNC_STREAM_API_MODE: _wrap_async_stream_query_operation,
923+
_A2A_EXTENSION_MODE: _wrap_a2a_operation,
871924
}
872925
if isinstance(wrap_operation_fn, dict) and api_mode in wrap_operation_fn:
873926
# Override the default function with user-specified function if it exists.
@@ -884,7 +937,13 @@ def _register_api_methods_or_raise(
884937
)
885938

886939
# Bind the method to the object.
887-
method = _wrap_operation(method_name=method_name) # type: ignore[call-arg]
940+
if api_mode == _A2A_EXTENSION_MODE:
941+
agent_card = operation_schema.get(_A2A_AGENT_CARD)
942+
method = _wrap_operation(
943+
method_name=method_name, agent_card=agent_card
944+
) # type: ignore[call-arg]
945+
else:
946+
method = _wrap_operation(method_name=method_name) # type: ignore[call-arg]
888947
method.__name__ = method_name
889948
if method_description and isinstance(method_description, str):
890949
method.__doc__ = method_description
@@ -1473,6 +1532,130 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]
14731532
return _method
14741533

14751534

1535+
def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list]:
1536+
"""Wraps an Agent Engine method, creating a callable for A2A API.
1537+
1538+
Args:
1539+
method_name: The name of the Agent Engine method to call.
1540+
agent_card: The agent card to use for the A2A API call.
1541+
Example:
1542+
{'additionalInterfaces': None,
1543+
'capabilities': {'extensions': None,
1544+
'pushNotifications': None,
1545+
'stateTransitionHistory': None,
1546+
'streaming': False},
1547+
'defaultInputModes': ['text'],
1548+
'defaultOutputModes': ['text'],
1549+
'description': (
1550+
'A helpful assistant agent that can answer questions.'
1551+
),
1552+
'documentationUrl': None,
1553+
'iconUrl': None,
1554+
'name': 'Q&A Agent',
1555+
'preferredTransport': 'JSONRPC',
1556+
'protocolVersion': '0.3.0',
1557+
'provider': None,
1558+
'security': None,
1559+
'securitySchemes': None,
1560+
'signatures': None,
1561+
'skills': [{
1562+
'description': (
1563+
'A helpful assistant agent that can answer questions.'
1564+
),
1565+
'examples': ['Who is leading 2025 F1 Standings?',
1566+
'Where can i find an active volcano?'],
1567+
'id': 'question_answer',
1568+
'inputModes': None,
1569+
'name': 'Q&A Agent',
1570+
'outputModes': None,
1571+
'security': None,
1572+
'tags': ['Question-Answer']}],
1573+
'supportsAuthenticatedExtendedCard': True,
1574+
'url': 'http://localhost:8080/',
1575+
'version': '1.0.0'}
1576+
Returns:
1577+
A callable object that executes the method on the Agent Engine via
1578+
the A2A API.
1579+
"""
1580+
1581+
async def _method(self, **kwargs) -> list:
1582+
"""Wraps an Agent Engine method, creating a callable for A2A API."""
1583+
if not self.api_client:
1584+
raise ValueError("api_client is not initialized.")
1585+
if not self.api_resource:
1586+
raise ValueError("api_resource is not initialized.")
1587+
a2a_agent_card = AgentCard(**json.loads(agent_card))
1588+
# A2A + AE integration currently only supports Rest API.
1589+
if (
1590+
a2a_agent_card.preferred_transport
1591+
and a2a_agent_card.preferred_transport != TransportProtocol.http_json
1592+
):
1593+
raise ValueError(
1594+
"Only HTTP+JSON is supported for preferred transport on agent card "
1595+
)
1596+
1597+
# Set preferred transport to HTTP+JSON if not set.
1598+
if not hasattr(a2a_agent_card, "preferred_transport"):
1599+
a2a_agent_card.preferred_transport = TransportProtocol.http_json
1600+
1601+
# AE cannot support streaming yet. Turn off streaming for now.
1602+
if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming:
1603+
raise ValueError(
1604+
"Streaming is not supported in Agent Engine, please change "
1605+
"a2a_agent_card.capabilities.streaming to False."
1606+
)
1607+
1608+
if not hasattr(a2a_agent_card.capabilities, "streaming"):
1609+
a2a_agent_card.capabilities.streaming = False
1610+
1611+
# agent_card is set on the class_methods before set_up is invoked.
1612+
# Ensure that the agent_card url is set correctly before the client is created.
1613+
base_url = self.api_client._api_client._http_options.base_url.rstrip(
1614+
"/"
1615+
)
1616+
api_version = self.api_client._api_client._http_options.api_version
1617+
a2a_agent_card.url = (
1618+
f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
1619+
)
1620+
1621+
# Using a2a client, inject the auth token from the global config.
1622+
config = ClientConfig(
1623+
supported_transports=[
1624+
TransportProtocol.http_json,
1625+
],
1626+
use_client_preference=True,
1627+
httpx_client=httpx.AsyncClient(
1628+
headers={
1629+
"Authorization": (f"Bearer {self.api_client._api_client._credentials.token}")
1630+
}
1631+
),
1632+
)
1633+
factory = ClientFactory(config)
1634+
client = factory.create(a2a_agent_card)
1635+
1636+
if method_name == "on_message_send":
1637+
response = client.send_message(Message(**kwargs))
1638+
elif method_name == "on_get_task":
1639+
response = await client.get_task(TaskQueryParams(**kwargs))
1640+
elif method_name == "on_cancel_task":
1641+
response = await client.cancel_task(TaskIdParams(**kwargs))
1642+
elif method_name == "handle_authenticated_agent_card":
1643+
response = await client.get_card()
1644+
else:
1645+
raise ValueError(f"Unknown method name: {method_name}")
1646+
1647+
if inspect.isasyncgen(response):
1648+
# Response is an async generator, collect the chunks.
1649+
chunks = []
1650+
async for chunk in response:
1651+
chunks.append(chunk)
1652+
return chunks
1653+
else:
1654+
return response
1655+
1656+
return _method
1657+
1658+
14761659
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:
14771660
"""Converts the body of the HTTP Response message to JSON format.
14781661

vertexai/_genai/agent_engines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,6 +1333,7 @@ def _register_api_methods(
13331333
"async": _agent_engines_utils._wrap_async_query_operation,
13341334
"stream": _agent_engines_utils._wrap_stream_query_operation,
13351335
"async_stream": _agent_engines_utils._wrap_async_stream_query_operation,
1336+
"a2a_extension": _agent_engines_utils._wrap_a2a_operation,
13361337
},
13371338
)
13381339
except Exception as e:

0 commit comments

Comments
 (0)