diff --git a/temporalio/contrib/openai_agents/_temporal_trace_provider.py b/temporalio/contrib/openai_agents/_temporal_trace_provider.py index 4637afbe1..0c4d016ef 100644 --- a/temporalio/contrib/openai_agents/_temporal_trace_provider.py +++ b/temporalio/contrib/openai_agents/_temporal_trace_provider.py @@ -7,7 +7,10 @@ from agents.tracing import ( get_trace_provider, ) -from agents.tracing.provider import DefaultTraceProvider +from agents.tracing.provider import ( + DefaultTraceProvider, + SynchronousMultiTracingProcessor, +) from agents.tracing.spans import Span from temporalio import workflow @@ -72,11 +75,17 @@ def activity_span( ) -class _TemporalTracingProcessor(TracingProcessor): - def __init__(self, impl: TracingProcessor): +class _TemporalTracingProcessor(SynchronousMultiTracingProcessor): + def __init__(self, impl: SynchronousMultiTracingProcessor): super().__init__() self._impl = impl + def add_tracing_processor(self, tracing_processor: TracingProcessor): + self._impl.add_tracing_processor(tracing_processor) + + def set_processors(self, processors: list[TracingProcessor]): + self._impl.set_processors(processors) + def on_trace_start(self, trace: Trace) -> None: if workflow.in_workflow() and workflow.unsafe.is_replaying(): # In replay mode, don't report @@ -118,7 +127,7 @@ def __init__(self): """Initialize the TemporalTraceProvider.""" super().__init__() self._original_provider = cast(DefaultTraceProvider, get_trace_provider()) - self._multi_processor = _TemporalTracingProcessor( # type: ignore[assignment] + self._multi_processor = _TemporalTracingProcessor( self._original_provider._multi_processor )