Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@
"google-adk >= 1.0.0, < 2.0.0",
]

a2a_extra_require = [
"a2a-sdk >= 0.3.0",
]

reasoning_engine_extra_require = [
"cloudpickle >= 3.0, < 4.0",
"google-cloud-trace < 2",
Expand Down Expand Up @@ -325,6 +329,7 @@
"ray": ray_extra_require,
"ray_testing": ray_testing_extra_require,
"adk": adk_extra_require,
"a2a": a2a_extra_require,
"reasoningengine": reasoning_engine_extra_require,
"agent_engines": agent_engines_extra_require,
"evaluation": evaluation_extra_require,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/vertexai/genai/test_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,7 +1923,7 @@ def test_update_agent_engine_description(self, mock_await_operation):
"register the API methods: "
"https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. "
"Error: {Unsupported api mode: `UNKNOWN_API_MODE`, "
"Supported modes are: ``, `async`, `async_stream`, `stream`.}"
"Supported modes are: ``, `a2a_extension`, `async`, `async_stream`, `stream`.}"
),
),
],
Expand Down
185 changes: 184 additions & 1 deletion vertexai/_genai/_agent_engines_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
Union,
)

import httpx

import proto

from google.api_core import exceptions
Expand Down Expand Up @@ -103,6 +105,32 @@
Session = Any


try:
from a2a.types import (
AgentCard,
TransportProtocol,
Message,
TaskIdParams,
TaskQueryParams,
)
from a2a.client import ClientConfig, ClientFactory

AgentCard = AgentCard
TransportProtocol = TransportProtocol
Message = Message
ClientConfig = ClientConfig
ClientFactory = ClientFactory
TaskIdParams = TaskIdParams
TaskQueryParams = TaskQueryParams
except (ImportError, AttributeError):
AgentCard = None
TransportProtocol = None
Message = None
ClientConfig = None
ClientFactory = None
TaskIdParams = None
TaskQueryParams = None

_ACTIONS_KEY = "actions"
_ACTION_APPEND = "append"
_AGENT_FRAMEWORK_ATTR = "agent_framework"
Expand Down Expand Up @@ -145,6 +173,8 @@
_REQUIREMENTS_FILE = "requirements.txt"
_STANDARD_API_MODE = ""
_STREAM_API_MODE = "stream"
_A2A_EXTENSION_MODE = "a2a_extension"
_A2A_AGENT_CARD = "a2a_agent_card"
_WARNINGS_KEY = "warnings"
_WARNING_MISSING = "missing"
_WARNING_INCOMPATIBLE = "incompatible"
Expand Down Expand Up @@ -493,11 +523,32 @@ def _generate_class_methods_spec_or_raise(

class_method = _to_proto(schema_dict)
class_method[_MODE_KEY_IN_SCHEMA] = mode
if hasattr(agent, "agent_card"):
class_method[_A2A_AGENT_CARD] = getattr(
agent, "agent_card"
).model_dump_json()
class_methods_spec.append(class_method)

return class_methods_spec


def _is_pydantic_serializable(param: inspect.Parameter) -> bool:
"""Checks if the parameter is pydantic serializable."""

if param.annotation == inspect.Parameter.empty:
return True

if isinstance(param.annotation, str):
return False

pydantic = _import_pydantic_or_raise()
try:
pydantic.TypeAdapter(param.annotation)
return True
except Exception:
return False


def _generate_schema(
f: Callable[..., Any],
*,
Expand Down Expand Up @@ -557,6 +608,7 @@ def _generate_schema(
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
and _is_pydantic_serializable(param)
}
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
# Postprocessing
Expand Down Expand Up @@ -868,6 +920,7 @@ def _register_api_methods_or_raise(
_ASYNC_API_MODE: _wrap_async_query_operation,
_STREAM_API_MODE: _wrap_stream_query_operation,
_ASYNC_STREAM_API_MODE: _wrap_async_stream_query_operation,
_A2A_EXTENSION_MODE: _wrap_a2a_operation,
}
if isinstance(wrap_operation_fn, dict) and api_mode in wrap_operation_fn:
# Override the default function with user-specified function if it exists.
Expand All @@ -884,7 +937,13 @@ def _register_api_methods_or_raise(
)

# Bind the method to the object.
method = _wrap_operation(method_name=method_name) # type: ignore[call-arg]
if api_mode == _A2A_EXTENSION_MODE:
agent_card = operation_schema.get(_A2A_AGENT_CARD)
method = _wrap_operation(
method_name=method_name, agent_card=agent_card
) # type: ignore[call-arg]
else:
method = _wrap_operation(method_name=method_name) # type: ignore[call-arg]
method.__name__ = method_name
if method_description and isinstance(method_description, str):
method.__doc__ = method_description
Expand Down Expand Up @@ -1473,6 +1532,130 @@ async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]
return _method


def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list]:
"""Wraps an Agent Engine method, creating a callable for A2A API.

Args:
method_name: The name of the Agent Engine method to call.
agent_card: The agent card to use for the A2A API call.
Example:
{'additionalInterfaces': None,
'capabilities': {'extensions': None,
'pushNotifications': None,
'stateTransitionHistory': None,
'streaming': False},
'defaultInputModes': ['text'],
'defaultOutputModes': ['text'],
'description': (
'A helpful assistant agent that can answer questions.'
),
'documentationUrl': None,
'iconUrl': None,
'name': 'Q&A Agent',
'preferredTransport': 'JSONRPC',
'protocolVersion': '0.3.0',
'provider': None,
'security': None,
'securitySchemes': None,
'signatures': None,
'skills': [{
'description': (
'A helpful assistant agent that can answer questions.'
),
'examples': ['Who is leading 2025 F1 Standings?',
'Where can i find an active volcano?'],
'id': 'question_answer',
'inputModes': None,
'name': 'Q&A Agent',
'outputModes': None,
'security': None,
'tags': ['Question-Answer']}],
'supportsAuthenticatedExtendedCard': True,
'url': 'http://localhost:8080/',
'version': '1.0.0'}
Returns:
A callable object that executes the method on the Agent Engine via
the A2A API.
"""

async def _method(self, **kwargs) -> list:
"""Wraps an Agent Engine method, creating a callable for A2A API."""
if not self.api_client:
raise ValueError("api_client is not initialized.")
if not self.api_resource:
raise ValueError("api_resource is not initialized.")
a2a_agent_card = AgentCard(**json.loads(agent_card))
# A2A + AE integration currently only supports Rest API.
if (
a2a_agent_card.preferred_transport
and a2a_agent_card.preferred_transport != TransportProtocol.http_json
):
raise ValueError(
"Only HTTP+JSON is supported for preferred transport on agent card "
)

# Set preferred transport to HTTP+JSON if not set.
if not hasattr(a2a_agent_card, "preferred_transport"):
a2a_agent_card.preferred_transport = TransportProtocol.http_json

# AE cannot support streaming yet. Turn off streaming for now.
if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming:
raise ValueError(
"Streaming is not supported in Agent Engine, please change "
"a2a_agent_card.capabilities.streaming to False."
)

if not hasattr(a2a_agent_card.capabilities, "streaming"):
a2a_agent_card.capabilities.streaming = False

# agent_card is set on the class_methods before set_up is invoked.
# Ensure that the agent_card url is set correctly before the client is created.
base_url = self.api_client._api_client._http_options.base_url.rstrip(
"/"
)
api_version = self.api_client._api_client._http_options.api_version
a2a_agent_card.url = (
f"{base_url}/{api_version}/{self.api_resource.name}/a2a"
)

# Using a2a client, inject the auth token from the global config.
config = ClientConfig(
supported_transports=[
TransportProtocol.http_json,
],
use_client_preference=True,
httpx_client=httpx.AsyncClient(
headers={
"Authorization": (f"Bearer {self.api_client._api_client._credentials.token}")
}
),
)
factory = ClientFactory(config)
client = factory.create(a2a_agent_card)

if method_name == "on_message_send":
response = client.send_message(Message(**kwargs))
elif method_name == "on_get_task":
response = await client.get_task(TaskQueryParams(**kwargs))
elif method_name == "on_cancel_task":
response = await client.cancel_task(TaskIdParams(**kwargs))
elif method_name == "handle_authenticated_agent_card":
response = await client.get_card()
else:
raise ValueError(f"Unknown method name: {method_name}")

if inspect.isasyncgen(response):
# Response is an async generator, collect the chunks.
chunks = []
async for chunk in response:
chunks.append(chunk)
return chunks
else:
return response

return _method


def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:
"""Converts the body of the HTTP Response message to JSON format.

Expand Down
1 change: 1 addition & 0 deletions vertexai/_genai/agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,7 @@ def _register_api_methods(
"async": _agent_engines_utils._wrap_async_query_operation,
"stream": _agent_engines_utils._wrap_stream_query_operation,
"async_stream": _agent_engines_utils._wrap_async_stream_query_operation,
"a2a_extension": _agent_engines_utils._wrap_a2a_operation,
},
)
except Exception as e:
Expand Down
Loading
Loading