Skip to content

Initial version of nexus_operation_as_tool #932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: nexus
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion temporalio/contrib/openai_agents/temporal_tools.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Support for using Temporal activities as OpenAI agents tools."""

import json
import typing
from datetime import timedelta
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Type

from temporalio import activity, workflow
from temporalio.common import Priority, RetryPolicy
from temporalio.exceptions import ApplicationError, TemporalError
from temporalio.nexus._util import get_operation_factory
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe

with unsafe.imports_passed_through():
Expand Down Expand Up @@ -115,3 +117,102 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
on_invoke_tool=run_activity,
strict_json_schema=True,
)


def nexus_operation_as_tool(
fn: Callable,
*,
service: Type[Any],
endpoint: str,
schedule_to_close_timeout: Optional[timedelta] = None,
) -> Tool:
"""Convert a Nexus operation into an OpenAI agent tool.

.. warning::
This API is experimental and may change in future versions.
Use with caution in production environments.

This function takes a Nexus operation and converts it into an
OpenAI agent tool that can be used by the agent to execute the operation
during workflow execution. The tool will automatically handle the conversion
of inputs and outputs between the agent and the operation.

Args:
fn: A Nexus operation to convert into a tool.
service: The Nexus service class that contains the operation.
endpoint: The Nexus endpoint to use for the operation.

Returns:
An OpenAI agent tool that wraps the provided operation.

Raises:
ApplicationError: If the operation is not properly decorated as a Nexus operation.

Example:
>>> @service_handler
>>> class WeatherServiceHandler:
... @sync_operation
... async def get_weather_object(self, ctx: StartOperationContext, input: WeatherInput) -> Weather:
... return Weather(
... city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
... )
>>>
>>> # Create tool with custom activity options
>>> tool = nexus_operation_as_tool(
... WeatherServiceHandler.get_weather_object,
... service=WeatherServiceHandler,
... endpoint="weather-service",
... )
>>> # Use tool with an OpenAI agent
"""
if not get_operation_factory(fn):
raise ApplicationError(
"Function is not a Nexus operation",
"invalid_tool",
)

schema = function_schema(adapt_nexus_operation_function_schema(fn))

async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any:
try:
json_data = json.loads(input)
except Exception as e:
raise ApplicationError(
f"Invalid JSON input for tool {schema.name}: {input}"
) from e

nexus_client = workflow.create_nexus_client(endpoint=endpoint, service=service)
args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))
assert len(args) == 1, "Nexus operations must have exactly one argument"
[arg] = args
result = await nexus_client.execute_operation(
fn,
arg,
schedule_to_close_timeout=schedule_to_close_timeout,
)
try:
return str(result)
except Exception as e:
raise ToolSerializationError(
"You must return a string representation of the tool output, or something we can call str() on"
) from e

return FunctionTool(
name=schema.name,
description=schema.description or "",
params_json_schema=schema.params_json_schema,
on_invoke_tool=run_operation,
strict_json_schema=True,
)


def adapt_nexus_operation_function_schema(fn: Callable[..., Any]) -> Callable[..., Any]:
# Nexus operation start methods look like
# async def operation(self, ctx: StartOperationContext, input: InputType) -> OutputType
_, inputT, retT = typing.get_type_hints(fn).values()

def adapted(input: inputT) -> retT: # type: ignore
pass

adapted.__name__ = fn.__name__
return adapted
64 changes: 58 additions & 6 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Optional, Union, no_type_check

import pytest
from nexusrpc.handler import StartOperationContext, service_handler, sync_operation
from pydantic import ConfigDict, Field

from temporalio import activity, workflow
Expand All @@ -19,12 +20,16 @@
from temporalio.contrib.openai_agents.temporal_openai_agents import (
set_open_ai_agent_temporal_overrides,
)
from temporalio.contrib.openai_agents.temporal_tools import activity_as_tool
from temporalio.contrib.openai_agents.temporal_tools import (
activity_as_tool,
nexus_operation_as_tool,
)
from temporalio.contrib.openai_agents.trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.exceptions import CancelledError
from tests.helpers import new_worker
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name

with workflow.unsafe.imports_passed_through():
from agents import (
Expand Down Expand Up @@ -223,6 +228,17 @@ async def get_weather_object(input: WeatherInput) -> Weather:
)


@service_handler
class WeatherServiceHandler:
@sync_operation
async def get_weather_object_nexus_operation(
self, ctx: StartOperationContext, input: WeatherInput
) -> Weather:
return Weather(
city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
)


class TestWeatherModel(TestModel):
responses = [
ModelResponse(
Expand Down Expand Up @@ -253,6 +269,20 @@ class TestWeatherModel(TestModel):
usage=Usage(),
response_id=None,
),
ModelResponse(
output=[
ResponseFunctionToolCall(
arguments='{"input":{"city":"Tokyo"}}',
call_id="call",
name="get_weather_object_nexus_operation",
type="function_call",
id="id",
status="completed",
)
],
usage=Usage(),
response_id=None,
),
ModelResponse(
output=[
ResponseFunctionToolCall(
Expand Down Expand Up @@ -306,6 +336,12 @@ async def run(self, question: str) -> str:
activity_as_tool(
get_weather_country, start_to_close_timeout=timedelta(seconds=10)
),
nexus_operation_as_tool(
WeatherServiceHandler.get_weather_object_nexus_operation,
service=WeatherServiceHandler,
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
schedule_to_close_timeout=timedelta(seconds=10),
),
],
) # type: Agent
result = await Runner.run(starting_agent=agent, input=question)
Expand Down Expand Up @@ -340,8 +376,11 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
get_weather_object,
get_weather_country,
],
nexus_service_handlers=[WeatherServiceHandler()],
interceptors=[OpenAIAgentsTracingInterceptor()],
) as worker:
await create_nexus_endpoint(worker.task_queue, client)

workflow_handle = await client.start_workflow(
ToolsWorkflow.run,
"What is the weather in Tokio?",
Expand All @@ -353,13 +392,14 @@ async def test_tool_workflow(client: Client, use_local_model: bool):

if use_local_model:
assert result == "Test weather result"

events = []
async for e in workflow_handle.fetch_history_events():
if e.HasField("activity_task_completed_event_attributes"):
if e.HasField(
"activity_task_completed_event_attributes"
) or e.HasField("nexus_operation_completed_event_attributes"):
events.append(e)

assert len(events) == 7
assert len(events) == 9
assert (
"function_call"
in events[0]
Expand Down Expand Up @@ -392,13 +432,25 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
)
assert (
"Sunny with wind"
in events[5]
in events[
5
].nexus_operation_completed_event_attributes.result.data.decode()
)
assert (
"function_call"
in events[6]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"Sunny with wind"
in events[7]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
assert (
"Test weather result"
in events[6]
in events[8]
.activity_task_completed_event_attributes.result.payloads[0]
.data.decode()
)
Expand Down
Loading