Skip to content

Commit 359adba

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: A2A + Agent Engine integration
PiperOrigin-RevId: 791846986
1 parent ec11bd3 commit 359adba

File tree

8 files changed

+466
-18
lines changed

8 files changed

+466
-18
lines changed

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@
142142
"google-adk >= 1.0.0, < 2.0.0",
143143
]
144144

145+
a2a_extra_require = [
146+
"a2a-sdk >= 0.3.0",
147+
]
148+
145149
reasoning_engine_extra_require = [
146150
"cloudpickle >= 3.0, < 4.0",
147151
"google-cloud-trace < 2",
@@ -325,6 +329,7 @@
325329
"ray": ray_extra_require,
326330
"ray_testing": ray_testing_extra_require,
327331
"adk": adk_extra_require,
332+
"a2a": a2a_extra_require,
328333
"reasoningengine": reasoning_engine_extra_require,
329334
"agent_engines": agent_engines_extra_require,
330335
"evaluation": evaluation_extra_require,

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1872,7 +1872,7 @@ def setup_method(self):
18721872
"register the API methods: "
18731873
"https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. "
18741874
"Error: {Unsupported api mode: `UNKNOWN_API_MODE`, "
1875-
"Supported modes are: ``, `async`, `async_stream`, `stream`.}"
1875+
"Supported modes are: ``, `a2a_extension`, `async`, `async_stream`, `stream`.}"
18761876
),
18771877
),
18781878
],

vertexai/_genai/_agent_engines_utils.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
Union,
4343
)
4444

45+
import json
46+
import httpx
47+
4548
import proto
4649

4750
from google.api_core import exceptions
@@ -102,6 +105,27 @@
102105
Session = Any
103106

104107

108+
try:
109+
from a2a.types import (
110+
AgentCard,
111+
TransportProtocol,
112+
Message,
113+
TaskIdParams,
114+
)
115+
from a2a.client import ClientConfig, ClientFactory
116+
117+
AgentCard = AgentCard
118+
TransportProtocol = TransportProtocol
119+
ClientConfig = ClientConfig
120+
ClientFactory = ClientFactory
121+
TaskIdParams = TaskIdParams
122+
except (ImportError, AttributeError):
123+
AgentCard = None
124+
TransportProtocol = None
125+
ClientConfig = None
126+
ClientFactory = None
127+
TaskIdParams = None
128+
105129
_ACTIONS_KEY = "actions"
106130
_ACTION_APPEND = "append"
107131
_AGENT_FRAMEWORK_ATTR = "agent_framework"
@@ -144,6 +168,8 @@
144168
_REQUIREMENTS_FILE = "requirements.txt"
145169
_STANDARD_API_MODE = ""
146170
_STREAM_API_MODE = "stream"
171+
_A2A_EXTENSION_MODE = "a2a_extension"
172+
_A2A_AGENT_CARD = "a2a_agent_card"
147173
_WARNINGS_KEY = "warnings"
148174
_WARNING_MISSING = "missing"
149175
_WARNING_INCOMPATIBLE = "incompatible"
@@ -454,11 +480,32 @@ def _generate_class_methods_spec_or_raise(
454480

455481
class_method = _to_proto(schema_dict)
456482
class_method[_MODE_KEY_IN_SCHEMA] = mode
483+
if hasattr(agent_engine, "agent_card"):
484+
class_method[_A2A_AGENT_CARD] = getattr(
485+
agent_engine, "agent_card"
486+
).model_dump_json()
457487
class_methods_spec.append(class_method)
458488

459489
return class_methods_spec
460490

461491

492+
def _is_pydantic_serializable(param: inspect.Parameter) -> bool:
493+
"""Checks if the parameter is pydantic serializable."""
494+
495+
if param.annotation == inspect.Parameter.empty:
496+
return True
497+
498+
if isinstance(param.annotation, str):
499+
return False
500+
501+
pydantic = _import_pydantic_or_raise()
502+
try:
503+
pydantic.TypeAdapter(param.annotation)
504+
return True
505+
except Exception:
506+
return False
507+
508+
462509
def _generate_schema(
463510
f: Callable[..., Any],
464511
*,
@@ -518,6 +565,7 @@ def _generate_schema(
518565
inspect.Parameter.KEYWORD_ONLY,
519566
inspect.Parameter.POSITIONAL_ONLY,
520567
)
568+
and _is_pydantic_serializable(param)
521569
}
522570
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
523571
# Postprocessing
@@ -827,6 +875,7 @@ def _register_api_methods_or_raise(
827875
_ASYNC_API_MODE: _wrap_async_query_operation,
828876
_STREAM_API_MODE: _wrap_stream_query_operation,
829877
_ASYNC_STREAM_API_MODE: _wrap_async_stream_query_operation,
878+
_A2A_EXTENSION_MODE: _wrap_a2a_operation,
830879
}
831880
if isinstance(wrap_operation_fn, dict) and api_mode in wrap_operation_fn:
832881
# Override the default function with user-specified function if it exists.
@@ -1438,6 +1487,62 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]
14381487
return _method
14391488

14401489

1490+
def _wrap_a2a_operation(
1491+
method_name: str, agent_card: str
1492+
) -> Callable[..., list]:
1493+
async def _method(self, **kwargs) -> list:
1494+
"""Wraps an Agent Engine method, creating a callable for A2A API."""
1495+
if not self.api_client:
1496+
raise ValueError("api_client is not initialized.")
1497+
if not self.api_resource:
1498+
raise ValueError("api_resource is not initialized.")
1499+
a2a_agent_card = AgentCard(**json.loads(agent_card))
1500+
a2a_agent_card.preferred_transport = TransportProtocol.http_json
1501+
# AE cannot support streaming yet. Turn off streaming for now.
1502+
a2a_agent_card.capabilities.streaming = False
1503+
# agent_card is set on the class_methods before set_up is invoked.
1504+
# Ensure that the agent_card url is set correctly before the client is created.
1505+
a2a_agent_card.url = f"{self.api_resource.name}/a2a"
1506+
1507+
# Using a2a client, inject the auth token from the global config.
1508+
config = ClientConfig(
1509+
supported_transports=[
1510+
TransportProtocol.http_json,
1511+
],
1512+
use_client_preference=True,
1513+
httpx_client=httpx.AsyncClient(
1514+
headers={
1515+
"Authorization": (
1516+
f"Bearer { self.api_client.credentials.token}"
1517+
)
1518+
}
1519+
),
1520+
)
1521+
factory = ClientFactory(config)
1522+
client = factory.create(a2a_agent_card)
1523+
1524+
match method_name:
1525+
case "on_message_send":
1526+
response = client.send_message(Message(**kwargs))
1527+
case "on_get_task":
1528+
response = client.get_task(TaskIdParams(**kwargs))
1529+
case "on_cancel_task":
1530+
response = client.cancel_task(TaskIdParams(**kwargs))
1531+
case "handle_authenticated_agent_card":
1532+
response = await client.get_card()
1533+
1534+
if inspect.isasyncgen(response):
1535+
# Response is an async generator, collect the chunks.
1536+
chunks = []
1537+
async for chunk in response:
1538+
chunks.append(chunk)
1539+
return chunks
1540+
else:
1541+
return response
1542+
1543+
return _method
1544+
1545+
14411546
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:
14421547
"""Converts the body of the HTTP Response message to JSON format.
14431548
@@ -1531,3 +1636,4 @@ def _validate_resource_limits_or_raise(resource_limits: dict[str, str]) -> None:
15311636
f"Memory size of {memory_str} requires at least {min_cpu} CPUs."
15321637
f" Got {cpu}"
15331638
)
1639+

vertexai/_genai/agent_engines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3669,6 +3669,7 @@ def _register_api_methods(self, *, agent: types.AgentEngine) -> types.AgentEngin
36693669
"async": _agent_engines_utils._wrap_async_query_operation,
36703670
"stream": _agent_engines_utils._wrap_stream_query_operation,
36713671
"async_stream": _agent_engines_utils._wrap_async_stream_query_operation,
3672+
"a2a_extension": _agent_engines_utils._wrap_a2a_operation,
36723673
},
36733674
)
36743675
except Exception as e:

vertexai/agent_engines/_agent_engines.py

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,14 @@
1616
import abc
1717
import inspect
1818
import io
19+
import json
1920
import logging
2021
import os
2122
import sys
2223
import tarfile
2324
import types
2425
import typing
25-
from typing import (
26-
Any,
27-
AsyncIterable,
28-
Callable,
29-
Coroutine,
30-
Dict,
31-
Iterable,
32-
List,
33-
Optional,
34-
Protocol,
35-
Sequence,
36-
Tuple,
37-
Union,
38-
)
39-
40-
import proto
26+
from typing import Any, AsyncIterable, Callable, Coroutine, Dict, Iterable, List, Optional, Protocol, Sequence, Tuple, Union
4127

4228
from google.api_core import exceptions
4329
from google.cloud import storage
@@ -47,6 +33,9 @@
4733
from google.cloud.aiplatform_v1 import types as aip_types
4834
from google.cloud.aiplatform_v1.types import reasoning_engine_service
4935
from vertexai.agent_engines import _utils
36+
import httpx
37+
import proto
38+
5039
from google.protobuf import field_mask_pb2
5140

5241

@@ -60,6 +49,8 @@
6049
_ASYNC_API_MODE = "async"
6150
_STREAM_API_MODE = "stream"
6251
_ASYNC_STREAM_API_MODE = "async_stream"
52+
_A2A_EXTENSION_MODE = "a2a_extension"
53+
_A2A_AGENT_CARD = "a2a_agent_card"
6354
_MODE_KEY_IN_SCHEMA = "api_mode"
6455
_METHOD_NAME_KEY_IN_SCHEMA = "name"
6556
_DEFAULT_METHOD_NAME = "query"
@@ -111,6 +102,27 @@
111102
except (ImportError, AttributeError):
112103
ADKAgent = None
113104

105+
try:
106+
from a2a.types import (
107+
AgentCard,
108+
TransportProtocol,
109+
Message,
110+
TaskIdParams,
111+
)
112+
from a2a.client import ClientConfig, ClientFactory
113+
114+
AgentCard = AgentCard
115+
TransportProtocol = TransportProtocol
116+
ClientConfig = ClientConfig
117+
ClientFactory = ClientFactory
118+
TaskIdParams = TaskIdParams
119+
except (ImportError, AttributeError):
120+
AgentCard = None
121+
TransportProtocol = None
122+
ClientConfig = None
123+
ClientFactory = None
124+
TaskIdParams = None
125+
114126

115127
@typing.runtime_checkable
116128
class Queryable(Protocol):
@@ -1498,6 +1510,58 @@ async def _method(self, **kwargs) -> AsyncIterable[Any]:
14981510
return _method
14991511

15001512

1513+
def _wrap_a2a_operation(
1514+
method_name: str, agent_card: str
1515+
) -> Callable[..., list]:
1516+
async def _method(self, **kwargs) -> list:
1517+
"""Wraps an Agent Engine method, creating a callable for A2A API."""
1518+
a2a_agent_card = AgentCard(**json.loads(agent_card))
1519+
a2a_agent_card.preferred_transport = TransportProtocol.http_json
1520+
# AE cannot support streaming yet. Turn off streaming for now.
1521+
a2a_agent_card.capabilities.streaming = False
1522+
# agent_card is set on the class_methods before set_up is invoked.
1523+
# Ensure that the agent_card url is set correctly before the client is created.
1524+
a2a_agent_card.url = f"https://{initializer.global_config.api_endpoint}/v1/{self.resource_name}/a2a"
1525+
1526+
# Using a2a client, inject the auth token from the global config.
1527+
config = ClientConfig(
1528+
supported_transports=[
1529+
TransportProtocol.http_json,
1530+
],
1531+
use_client_preference=True,
1532+
httpx_client=httpx.AsyncClient(
1533+
headers={
1534+
"Authorization": (
1535+
f"Bearer {initializer.global_config.credentials.token}"
1536+
)
1537+
}
1538+
),
1539+
)
1540+
factory = ClientFactory(config)
1541+
client = factory.create(a2a_agent_card)
1542+
1543+
match method_name:
1544+
case "on_message_send":
1545+
response = client.send_message(Message(**kwargs))
1546+
case "on_get_task":
1547+
response = client.get_task(TaskIdParams(**kwargs))
1548+
case "on_cancel_task":
1549+
response = client.cancel_task(TaskIdParams(**kwargs))
1550+
case "handle_authenticated_agent_card":
1551+
response = await client.get_card()
1552+
1553+
if inspect.isasyncgen(response):
1554+
# Response is an async generator, collect the chunks.
1555+
chunks = []
1556+
async for chunk in response:
1557+
chunks.append(chunk)
1558+
return chunks
1559+
else:
1560+
return response
1561+
1562+
return _method
1563+
1564+
15011565
def _unregister_api_methods(
15021566
obj: "AgentEngine", operation_schemas: Sequence[_utils.JsonDict]
15031567
):
@@ -1573,6 +1637,7 @@ def _register_api_methods_or_raise(
15731637
_ASYNC_API_MODE: _wrap_async_query_operation,
15741638
_STREAM_API_MODE: _wrap_stream_query_operation,
15751639
_ASYNC_STREAM_API_MODE: _wrap_async_stream_query_operation,
1640+
_A2A_EXTENSION_MODE: _wrap_a2a_operation,
15761641
}
15771642
if isinstance(wrap_operation_fn, dict) and api_mode in wrap_operation_fn:
15781643
# Override the default function with user-specified function if it exists.
@@ -1661,6 +1726,11 @@ def _generate_class_methods_spec_or_raise(
16611726

16621727
class_method = _utils.to_proto(schema_dict)
16631728
class_method[_MODE_KEY_IN_SCHEMA] = mode
1729+
# A2A agent card is a special case, when running in A2A mode,
1730+
if hasattr(agent_engine, "agent_card"):
1731+
class_method[_A2A_AGENT_CARD] = getattr(
1732+
agent_engine, "agent_card"
1733+
).model_dump_json()
16641734
class_methods_spec.append(class_method)
16651735

16661736
return class_methods_spec

vertexai/agent_engines/_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,22 @@ def scan_requirements(
501501
return {module: importlib_metadata.version(module) for module in modules_found}
502502

503503

504+
def _is_pydantic_serializable(param: inspect.Parameter) -> bool:
505+
"""Checks if the parameter is pydantic serializable."""
506+
507+
if param.annotation == inspect.Parameter.empty:
508+
return True
509+
510+
if isinstance(param.annotation, str):
511+
return False
512+
pydantic = _import_pydantic_or_raise()
513+
try:
514+
pydantic.TypeAdapter(param.annotation)
515+
return True
516+
except Exception:
517+
return False
518+
519+
504520
def generate_schema(
505521
f: Callable[..., Any],
506522
*,
@@ -560,6 +576,7 @@ def generate_schema(
560576
inspect.Parameter.KEYWORD_ONLY,
561577
inspect.Parameter.POSITIONAL_ONLY,
562578
)
579+
and _is_pydantic_serializable(param)
563580
}
564581
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
565582
# Postprocessing

0 commit comments

Comments
 (0)