Skip to content

Commit e028a71

Browse files
committed
Initial version of nexus_operation_as_tool
1 parent 4462e3a commit e028a71

File tree

2 files changed

+160
-7
lines changed

2 files changed

+160
-7
lines changed

temporalio/contrib/openai_agents/temporal_tools.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Support for using Temporal activities as OpenAI agents tools."""
22

33
import json
4+
import typing
45
from datetime import timedelta
5-
from typing import Any, Callable, Optional
6+
from typing import Any, Callable, Optional, Type
67

78
from temporalio import activity, workflow
89
from temporalio.common import Priority, RetryPolicy
910
from temporalio.exceptions import ApplicationError, TemporalError
11+
from temporalio.nexus._util import get_operation_factory
1012
from temporalio.workflow import ActivityCancellationType, VersioningIntent, unsafe
1113

1214
with unsafe.imports_passed_through():
@@ -115,3 +117,102 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
115117
on_invoke_tool=run_activity,
116118
strict_json_schema=True,
117119
)
120+
121+
122+
def nexus_operation_as_tool(
123+
fn: Callable,
124+
*,
125+
service: Type[Any],
126+
endpoint: str,
127+
schedule_to_close_timeout: Optional[timedelta] = None,
128+
) -> Tool:
129+
"""Convert a Nexus operation into an OpenAI agent tool.
130+
131+
.. warning::
132+
This API is experimental and may change in future versions.
133+
Use with caution in production environments.
134+
135+
This function takes a Nexus operation and converts it into an
136+
OpenAI agent tool that can be used by the agent to execute the operation
137+
during workflow execution. The tool will automatically handle the conversion
138+
of inputs and outputs between the agent and the operation.
139+
140+
Args:
141+
fn: A Nexus operation to convert into a tool.
142+
service: The Nexus service class that contains the operation.
143+
endpoint: The Nexus endpoint to use for the operation.
144+
145+
Returns:
146+
An OpenAI agent tool that wraps the provided operation.
147+
148+
Raises:
149+
ApplicationError: If the operation is not properly decorated as a Nexus operation.
150+
151+
Example:
152+
>>> @service_handler
153+
>>> class WeatherServiceHandler:
154+
... @sync_operation
155+
... async def get_weather_object(self, ctx: StartOperationContext, input: WeatherInput) -> Weather:
156+
... return Weather(
157+
... city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
158+
... )
159+
>>>
160+
>>> # Create tool with custom activity options
161+
>>> tool = nexus_operation_as_tool(
162+
... WeatherServiceHandler.get_weather_object,
163+
... service=WeatherServiceHandler,
164+
... endpoint="weather-service",
165+
... )
166+
>>> # Use tool with an OpenAI agent
167+
"""
168+
if not get_operation_factory(fn):
169+
raise ApplicationError(
170+
"Function is not a Nexus operation",
171+
"invalid_tool",
172+
)
173+
174+
schema = function_schema(adapt_nexus_operation_function_schema(fn))
175+
176+
async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any:
177+
try:
178+
json_data = json.loads(input)
179+
except Exception as e:
180+
raise ApplicationError(
181+
f"Invalid JSON input for tool {schema.name}: {input}"
182+
) from e
183+
184+
nexus_client = workflow.NexusClient(service=service, endpoint=endpoint)
185+
args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))
186+
assert len(args) == 1, "Nexus operations must have exactly one argument"
187+
[arg] = args
188+
result = await nexus_client.execute_operation(
189+
fn,
190+
arg,
191+
schedule_to_close_timeout=schedule_to_close_timeout,
192+
)
193+
try:
194+
return str(result)
195+
except Exception as e:
196+
raise ToolSerializationError(
197+
"You must return a string representation of the tool output, or something we can call str() on"
198+
) from e
199+
200+
return FunctionTool(
201+
name=schema.name,
202+
description=schema.description or "",
203+
params_json_schema=schema.params_json_schema,
204+
on_invoke_tool=run_operation,
205+
strict_json_schema=True,
206+
)
207+
208+
209+
def adapt_nexus_operation_function_schema(fn: Callable[..., Any]) -> Callable[..., Any]:
210+
# Nexus operation start methods look like
211+
# async def operation(self, ctx: StartOperationContext, input: InputType) -> OutputType
212+
_, inputT, retT = typing.get_type_hints(fn).values()
213+
214+
def adapted(input: inputT) -> retT: # type: ignore
215+
pass
216+
217+
adapted.__name__ = fn.__name__
218+
return adapted

tests/contrib/openai_agents/test_openai.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Optional, Union, no_type_check
66

77
import pytest
8+
from nexusrpc.handler import StartOperationContext, service_handler, sync_operation
89
from pydantic import ConfigDict, Field
910

1011
from temporalio import activity, workflow
@@ -19,12 +20,16 @@
1920
from temporalio.contrib.openai_agents.temporal_openai_agents import (
2021
set_open_ai_agent_temporal_overrides,
2122
)
22-
from temporalio.contrib.openai_agents.temporal_tools import activity_as_tool
23+
from temporalio.contrib.openai_agents.temporal_tools import (
24+
activity_as_tool,
25+
nexus_operation_as_tool,
26+
)
2327
from temporalio.contrib.openai_agents.trace_interceptor import (
2428
OpenAIAgentsTracingInterceptor,
2529
)
2630
from temporalio.exceptions import CancelledError
2731
from tests.helpers import new_worker
32+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
2833

2934
with workflow.unsafe.imports_passed_through():
3035
from agents import (
@@ -223,6 +228,17 @@ async def get_weather_object(input: WeatherInput) -> Weather:
223228
)
224229

225230

231+
@service_handler
232+
class WeatherServiceHandler:
233+
@sync_operation
234+
async def get_weather_object_nexus_operation(
235+
self, ctx: StartOperationContext, input: WeatherInput
236+
) -> Weather:
237+
return Weather(
238+
city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
239+
)
240+
241+
226242
class TestWeatherModel(TestModel):
227243
responses = [
228244
ModelResponse(
@@ -253,6 +269,20 @@ class TestWeatherModel(TestModel):
253269
usage=Usage(),
254270
response_id=None,
255271
),
272+
ModelResponse(
273+
output=[
274+
ResponseFunctionToolCall(
275+
arguments='{"input":{"city":"Tokyo"}}',
276+
call_id="call",
277+
name="get_weather_object_nexus_operation",
278+
type="function_call",
279+
id="id",
280+
status="completed",
281+
)
282+
],
283+
usage=Usage(),
284+
response_id=None,
285+
),
256286
ModelResponse(
257287
output=[
258288
ResponseFunctionToolCall(
@@ -306,6 +336,12 @@ async def run(self, question: str) -> str:
306336
activity_as_tool(
307337
get_weather_country, start_to_close_timeout=timedelta(seconds=10)
308338
),
339+
nexus_operation_as_tool(
340+
WeatherServiceHandler.get_weather_object_nexus_operation,
341+
service=WeatherServiceHandler,
342+
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
343+
schedule_to_close_timeout=timedelta(seconds=10),
344+
),
309345
],
310346
) # type: Agent
311347
result = await Runner.run(starting_agent=agent, input=question)
@@ -340,8 +376,11 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
340376
get_weather_object,
341377
get_weather_country,
342378
],
379+
nexus_service_handlers=[WeatherServiceHandler()],
343380
interceptors=[OpenAIAgentsTracingInterceptor()],
344381
) as worker:
382+
await create_nexus_endpoint(worker.task_queue, client)
383+
345384
workflow_handle = await client.start_workflow(
346385
ToolsWorkflow.run,
347386
"What is the weather in Tokio?",
@@ -353,13 +392,14 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
353392

354393
if use_local_model:
355394
assert result == "Test weather result"
356-
357395
events = []
358396
async for e in workflow_handle.fetch_history_events():
359-
if e.HasField("activity_task_completed_event_attributes"):
397+
if e.HasField(
398+
"activity_task_completed_event_attributes"
399+
) or e.HasField("nexus_operation_completed_event_attributes"):
360400
events.append(e)
361401

362-
assert len(events) == 7
402+
assert len(events) == 9
363403
assert (
364404
"function_call"
365405
in events[0]
@@ -392,13 +432,25 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
392432
)
393433
assert (
394434
"Sunny with wind"
395-
in events[5]
435+
in events[
436+
5
437+
].nexus_operation_completed_event_attributes.result.data.decode()
438+
)
439+
assert (
440+
"function_call"
441+
in events[6]
442+
.activity_task_completed_event_attributes.result.payloads[0]
443+
.data.decode()
444+
)
445+
assert (
446+
"Sunny with wind"
447+
in events[7]
396448
.activity_task_completed_event_attributes.result.payloads[0]
397449
.data.decode()
398450
)
399451
assert (
400452
"Test weather result"
401-
in events[6]
453+
in events[8]
402454
.activity_task_completed_event_attributes.result.payloads[0]
403455
.data.decode()
404456
)

0 commit comments

Comments
 (0)