diff --git a/src/agents/result.py b/src/agents/result.py index 1f1c78328..243db155c 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -15,6 +15,7 @@ from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .logger import logger +from .run_context import RunContextWrapper from .stream_events import StreamEvent from .tracing import Trace from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming @@ -50,6 +51,9 @@ class RunResultBase(abc.ABC): output_guardrail_results: list[OutputGuardrailResult] """Guardrail results for the final output of the agent.""" + context_wrapper: RunContextWrapper[Any] + """The context wrapper for the agent run.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -75,9 +79,7 @@ def final_output_as(self, cls: type[T], raise_if_incorrect_type: bool = False) - def to_input_list(self) -> list[TResponseInputItem]: """Creates a new input list, merging the original input with all the new items generated.""" - original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list( - self.input - ) + original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input) new_items = [item.to_input_item() for item in self.new_items] return original_items + new_items @@ -206,17 +208,13 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]: def _check_errors(self): if self.current_turn > self.max_turns: - self._stored_exception = MaxTurnsExceeded( - f"Max turns ({self.max_turns}) exceeded" - ) + self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded") # Fetch all the completed guardrail results from the queue and raise if needed while not self._input_guardrail_queue.empty(): guardrail_result = self._input_guardrail_queue.get_nowait() if guardrail_result.output.tripwire_triggered: - self._stored_exception = InputGuardrailTripwireTriggered( - guardrail_result - ) + self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result) # Check the tasks for any exceptions if self._run_impl_task and self._run_impl_task.done(): diff --git a/src/agents/run.py b/src/agents/run.py index 2af558d58..849da7bfc 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -270,6 +270,7 @@ async def run( _last_agent=current_agent, input_guardrail_results=input_guardrail_results, output_guardrail_results=output_guardrail_results, + context_wrapper=context_wrapper, ) elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -423,6 +424,7 @@ def run_streamed( output_guardrail_results=[], _current_agent_output_schema=output_schema, trace=new_trace, + context_wrapper=context_wrapper, ) # Kick off the actual agent loop in the background and return the streamed result object. @@ -696,6 +698,7 @@ async def _run_single_turn_streamed( usage=usage, response_id=event.response.id, ) + context_wrapper.usage.add(usage) streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) diff --git a/tests/fake_model.py b/tests/fake_model.py index da3019a0f..32f919ef1 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -3,7 +3,8 @@ from collections.abc import AsyncIterator from typing import Any -from openai.types.responses import Response, ResponseCompletedEvent +from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage +from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from agents.agent_output import AgentOutputSchemaBase from agents.handoffs import Handoff @@ -33,6 +34,10 @@ def __init__( ) self.tracing_enabled = tracing_enabled self.last_turn_args: dict[str, Any] = {} + self.hardcoded_usage: Usage | None = None + + def set_hardcoded_usage(self, usage: Usage): + self.hardcoded_usage = usage def set_next_output(self, output: list[TResponseOutputItem] | Exception): self.turn_outputs.append(output) @@ -83,7 +88,7 @@ async def get_response( return ModelResponse( output=output, - usage=Usage(), + usage=self.hardcoded_usage or Usage(), response_id=None, ) @@ -123,13 +128,14 @@ async def stream_response( yield ResponseCompletedEvent( type="response.completed", - response=get_response_obj(output), + response=get_response_obj(output, usage=self.hardcoded_usage), ) def get_response_obj( output: list[TResponseOutputItem], response_id: str | None = None, + usage: Usage | None = None, ) -> Response: return Response( id=response_id or "123", @@ -141,4 +147,11 @@ def get_response_obj( tools=[], top_p=None, parallel_tool_calls=False, + usage=ResponseUsage( + input_tokens=usage.input_tokens if usage else 0, + output_tokens=usage.output_tokens if usage else 0, + total_tokens=usage.total_tokens if usage else 0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), + ), ) diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index ec17e3275..c621e7352 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -3,7 +3,7 @@ import pytest from pydantic import BaseModel -from agents import Agent, RunResult +from agents import Agent, RunContextWrapper, RunResult def create_run_result(final_output: Any) -> RunResult: @@ -15,6 +15,7 @@ def create_run_result(final_output: Any) -> RunResult: input_guardrail_results=[], output_guardrail_results=[], _last_agent=Agent(name="test"), + context_wrapper=RunContextWrapper(context=None), )