@@ -110,8 +110,9 @@ def populate_template(template: str, variables: dict[str, Any]) -> str:
110
110
111
111
112
112
@dataclass
class ActionOutput:
    """Result yielded at the end of a single agent action.

    Step generators yield this object (alongside streaming deltas) so the
    run loop can tell an intermediate result apart from a final answer
    without relying on `output` being `None`.
    """

    # Value produced by the action; may be None (e.g. for tool calls that
    # are not the final answer).
    output: Any
    # True when `output` is the agent's final answer; the run loop stops
    # iterating once it receives an ActionOutput with this flag set.
    is_final_answer: bool
115
116
116
117
117
118
class PlanningPromptTemplate (TypedDict ):
@@ -280,7 +281,7 @@ def __init__(
280
281
self .name = self ._validate_name (name )
281
282
self .description = description
282
283
self .provide_run_summary = provide_run_summary
283
- self .final_answer_checks = final_answer_checks
284
+ self .final_answer_checks = final_answer_checks if final_answer_checks is not None else []
284
285
self .return_full_result = return_full_result
285
286
self .instructions = instructions
286
287
self ._setup_managed_agents (managed_agents )
@@ -451,9 +452,9 @@ def run(
451
452
def _run_stream (
452
453
self , task : str , max_steps : int , images : list ["PIL.Image.Image" ] | None = None
453
454
) -> Generator [ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta ]:
454
- final_answer = None
455
455
self .step_number = 1
456
- while final_answer is None and self .step_number <= max_steps :
456
+ returned_final_answer = False
457
+ while not returned_final_answer and self .step_number <= max_steps :
457
458
if self .interrupt_switch :
458
459
raise AgentError ("Agent interrupted." , self .logger )
459
460
@@ -464,8 +465,8 @@ def _run_stream(
464
465
planning_start_time = time .time ()
465
466
planning_step = None
466
467
for element in self ._generate_planning_step (
467
- task , is_first_step = ( len (self .memory .steps ) == 1 ) , step = self .step_number
468
- ):
468
+ task , is_first_step = len (self .memory .steps ) == 1 , step = self .step_number
469
): # Don't derive is_first_step from the attribute step_number, because memory can hold steps from previous runs
469
470
yield element
470
471
planning_step = element
471
472
assert isinstance (planning_step , PlanningStep ) # Last yielded element should be a PlanningStep
@@ -483,10 +484,19 @@ def _run_stream(
483
484
timing = Timing (start_time = action_step_start_time ),
484
485
observations_images = images ,
485
486
)
487
+ self .logger .log_rule (f"Step { self .step_number } " , level = LogLevel .INFO )
486
488
try :
487
- for el in self ._execute_step (action_step ):
488
- yield el
489
- final_answer = el
489
+ for output in self ._step_stream (action_step ):
490
+ # Yield streaming deltas
491
+ if not isinstance (output , ActionOutput ):
492
+ yield output
493
+
494
+ if isinstance (output , ActionOutput ) and output .is_final_answer :
495
+ if self .final_answer_checks :
496
+ self ._validate_final_answer (output .output )
497
+ returned_final_answer = True
498
+ action_step .is_final_answer = True
499
+ final_answer = output .output
490
500
except AgentGenerationError as e :
491
501
# Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
492
502
raise e
@@ -499,23 +509,11 @@ def _run_stream(
499
509
yield action_step
500
510
self .step_number += 1
501
511
502
- if final_answer is None and self .step_number == max_steps + 1 :
512
+ if not returned_final_answer and self .step_number == max_steps + 1 :
503
513
final_answer = self ._handle_max_steps_reached (task , images )
504
514
yield action_step
505
515
yield FinalAnswerStep (handle_agent_output_types (final_answer ))
506
516
507
- def _execute_step (self , memory_step : ActionStep ) -> Generator [ChatMessageStreamDelta | FinalOutput ]:
508
- self .logger .log_rule (f"Step { self .step_number } " , level = LogLevel .INFO )
509
- for el in self ._step_stream (memory_step ):
510
- final_answer = el
511
- if isinstance (el , ChatMessageStreamDelta ):
512
- yield el
513
- elif isinstance (el , FinalOutput ):
514
- final_answer = el .output
515
- if self .final_answer_checks :
516
- self ._validate_final_answer (final_answer )
517
- yield final_answer
518
-
519
517
def _validate_final_answer (self , final_answer : Any ):
520
518
for check_function in self .final_answer_checks :
521
519
try :
@@ -674,7 +672,7 @@ def interrupt(self):
674
672
675
673
def write_memory_to_messages (
676
674
self ,
677
- summary_mode : bool | None = False ,
675
+ summary_mode : bool = False ,
678
676
) -> list [Message ]:
679
677
"""
680
678
Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
@@ -686,7 +684,7 @@ def write_memory_to_messages(
686
684
messages .extend (memory_step .to_messages (summary_mode = summary_mode ))
687
685
return messages
688
686
689
- def _step_stream (self , memory_step : ActionStep ) -> Generator [ChatMessageStreamDelta | FinalOutput ]:
687
+ def _step_stream (self , memory_step : ActionStep ) -> Generator [ChatMessageStreamDelta | ActionOutput ]:
690
688
"""
691
689
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
692
690
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1203,7 +1201,7 @@ def initialize_system_prompt(self) -> str:
1203
1201
)
1204
1202
return system_prompt
1205
1203
1206
- def _step_stream (self , memory_step : ActionStep ) -> Generator [ChatMessageStreamDelta | FinalOutput ]:
1204
+ def _step_stream (self , memory_step : ActionStep ) -> Generator [ChatMessageStreamDelta | ActionOutput ]:
1207
1205
"""
1208
1206
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
1209
1207
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1291,14 +1289,15 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
1291
1289
memory_step (`ActionStep`): Memory ActionStep to update with results.
1292
1290
1293
1291
Yields:
1294
- `FinalOutput `: The final output of tool execution.
1292
+ `ActionOutput`: The output of tool execution; `is_final_answer` is True only for the result of a final_answer tool call.
1295
1293
"""
1296
1294
model_outputs = []
1297
1295
tool_calls = []
1298
1296
observations = []
1299
1297
1300
1298
final_answer_call = None
1301
1299
parallel_calls = []
1300
+ assert chat_message .tool_calls is not None
1302
1301
for tool_call in chat_message .tool_calls :
1303
1302
tool_name = tool_call .function .name
1304
1303
tool_arguments = tool_call .function .arguments
@@ -1338,19 +1337,19 @@ def process_single_tool_call(call_info):
1338
1337
)
1339
1338
return observation
1340
1339
1341
- # Process non-final-answer tool calls in parallel
1340
+ # Process tool calls in parallel
1342
1341
if parallel_calls :
1343
1342
if len (parallel_calls ) == 1 :
1344
1343
# If there's only one call, process it directly
1345
1344
observations .append (process_single_tool_call (parallel_calls [0 ]))
1346
- yield FinalOutput (output = None )
1345
+ yield ActionOutput (output = None , is_final_answer = False )
1347
1346
else :
1348
1347
# If multiple tool calls, process them in parallel
1349
1348
with ThreadPoolExecutor (self .max_tool_threads ) as executor :
1350
1349
futures = [executor .submit (process_single_tool_call , call_info ) for call_info in parallel_calls ]
1351
1350
for future in as_completed (futures ):
1352
1351
observations .append (future .result ())
1353
- yield FinalOutput (output = None )
1352
+ yield ActionOutput (output = None , is_final_answer = False )
1354
1353
1355
1354
# Process final_answer call if present
1356
1355
if final_answer_call :
@@ -1380,7 +1379,7 @@ def process_single_tool_call(call_info):
1380
1379
level = LogLevel .INFO ,
1381
1380
)
1382
1381
memory_step .action_output = final_answer
1383
- yield FinalOutput (output = final_answer )
1382
+ yield ActionOutput (output = final_answer , is_final_answer = True )
1384
1383
1385
1384
# Update memory step with all results
1386
1385
if model_outputs :
@@ -1572,7 +1571,7 @@ def initialize_system_prompt(self) -> str:
1572
1571
)
1573
1572
return system_prompt
1574
1573
1575
- def _step_stream (self , memory_step : ActionStep ) -> Generator [ChatMessageStreamDelta | FinalOutput ]:
1574
+ def _step_stream (self , memory_step : ActionStep ) -> Generator [ChatMessageStreamDelta | ActionOutput ]:
1576
1575
"""
1577
1576
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
1578
1577
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1702,7 +1701,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
1702
1701
]
1703
1702
self .logger .log (Group (* execution_outputs_console ), level = LogLevel .INFO )
1704
1703
memory_step .action_output = output
1705
- yield FinalOutput (output = output if is_final_answer else None )
1704
+ yield ActionOutput (output = output , is_final_answer = is_final_answer )
1706
1705
1707
1706
def to_dict (self ) -> dict [str , Any ]:
1708
1707
"""Convert the agent to a dictionary representation.
0 commit comments