Skip to content

Commit 76ecb9b

Browse files
Fix and refactor final answer checks (#1448)
1 parent e812838 commit 76ecb9b

File tree

4 files changed

+71
-48
lines changed

4 files changed

+71
-48
lines changed

src/smolagents/agents.py

Lines changed: 32 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,9 @@ def populate_template(template: str, variables: dict[str, Any]) -> str:
110110

111111

112112
@dataclass
113-
class FinalOutput:
114-
output: Any | None
113+
class ActionOutput:
114+
output: Any
115+
is_final_answer: bool
115116

116117

117118
class PlanningPromptTemplate(TypedDict):
@@ -280,7 +281,7 @@ def __init__(
280281
self.name = self._validate_name(name)
281282
self.description = description
282283
self.provide_run_summary = provide_run_summary
283-
self.final_answer_checks = final_answer_checks
284+
self.final_answer_checks = final_answer_checks if final_answer_checks is not None else []
284285
self.return_full_result = return_full_result
285286
self.instructions = instructions
286287
self._setup_managed_agents(managed_agents)
@@ -451,9 +452,9 @@ def run(
451452
def _run_stream(
452453
self, task: str, max_steps: int, images: list["PIL.Image.Image"] | None = None
453454
) -> Generator[ActionStep | PlanningStep | FinalAnswerStep | ChatMessageStreamDelta]:
454-
final_answer = None
455455
self.step_number = 1
456-
while final_answer is None and self.step_number <= max_steps:
456+
returned_final_answer = False
457+
while not returned_final_answer and self.step_number <= max_steps:
457458
if self.interrupt_switch:
458459
raise AgentError("Agent interrupted.", self.logger)
459460

@@ -464,8 +465,8 @@ def _run_stream(
464465
planning_start_time = time.time()
465466
planning_step = None
466467
for element in self._generate_planning_step(
467-
task, is_first_step=(len(self.memory.steps) == 1), step=self.step_number
468-
):
468+
task, is_first_step=len(self.memory.steps) == 1, step=self.step_number
469+
): # Don't use the attribute step_number here, because there can be steps from previous runs
469470
yield element
470471
planning_step = element
471472
assert isinstance(planning_step, PlanningStep) # Last yielded element should be a PlanningStep
@@ -483,10 +484,19 @@ def _run_stream(
483484
timing=Timing(start_time=action_step_start_time),
484485
observations_images=images,
485486
)
487+
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
486488
try:
487-
for el in self._execute_step(action_step):
488-
yield el
489-
final_answer = el
489+
for output in self._step_stream(action_step):
490+
# Yield streaming deltas
491+
if not isinstance(output, ActionOutput):
492+
yield output
493+
494+
if isinstance(output, ActionOutput) and output.is_final_answer:
495+
if self.final_answer_checks:
496+
self._validate_final_answer(output.output)
497+
returned_final_answer = True
498+
action_step.is_final_answer = True
499+
final_answer = output.output
490500
except AgentGenerationError as e:
491501
# Agent generation errors are not caused by a Model error but an implementation error: so we should raise them and exit.
492502
raise e
@@ -499,23 +509,11 @@ def _run_stream(
499509
yield action_step
500510
self.step_number += 1
501511

502-
if final_answer is None and self.step_number == max_steps + 1:
512+
if not returned_final_answer and self.step_number == max_steps + 1:
503513
final_answer = self._handle_max_steps_reached(task, images)
504514
yield action_step
505515
yield FinalAnswerStep(handle_agent_output_types(final_answer))
506516

507-
def _execute_step(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
508-
self.logger.log_rule(f"Step {self.step_number}", level=LogLevel.INFO)
509-
for el in self._step_stream(memory_step):
510-
final_answer = el
511-
if isinstance(el, ChatMessageStreamDelta):
512-
yield el
513-
elif isinstance(el, FinalOutput):
514-
final_answer = el.output
515-
if self.final_answer_checks:
516-
self._validate_final_answer(final_answer)
517-
yield final_answer
518-
519517
def _validate_final_answer(self, final_answer: Any):
520518
for check_function in self.final_answer_checks:
521519
try:
@@ -674,7 +672,7 @@ def interrupt(self):
674672

675673
def write_memory_to_messages(
676674
self,
677-
summary_mode: bool | None = False,
675+
summary_mode: bool = False,
678676
) -> list[Message]:
679677
"""
680678
Reads past llm_outputs, actions, and observations or errors from the memory into a series of messages
@@ -686,7 +684,7 @@ def write_memory_to_messages(
686684
messages.extend(memory_step.to_messages(summary_mode=summary_mode))
687685
return messages
688686

689-
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
687+
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
690688
"""
691689
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
692690
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1203,7 +1201,7 @@ def initialize_system_prompt(self) -> str:
12031201
)
12041202
return system_prompt
12051203

1206-
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
1204+
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
12071205
"""
12081206
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
12091207
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1291,14 +1289,15 @@ def process_tool_calls(self, chat_message: ChatMessage, memory_step: ActionStep)
12911289
memory_step (`ActionStep)`: Memory ActionStep to update with results.
12921290
12931291
Yields:
1294-
`FinalOutput`: The final output of tool execution.
1292+
`ActionOutput`: The final output of tool execution.
12951293
"""
12961294
model_outputs = []
12971295
tool_calls = []
12981296
observations = []
12991297

13001298
final_answer_call = None
13011299
parallel_calls = []
1300+
assert chat_message.tool_calls is not None
13021301
for tool_call in chat_message.tool_calls:
13031302
tool_name = tool_call.function.name
13041303
tool_arguments = tool_call.function.arguments
@@ -1338,19 +1337,19 @@ def process_single_tool_call(call_info):
13381337
)
13391338
return observation
13401339

1341-
# Process non-final-answer tool calls in parallel
1340+
# Process tool calls in parallel
13421341
if parallel_calls:
13431342
if len(parallel_calls) == 1:
13441343
# If there's only one call, process it directly
13451344
observations.append(process_single_tool_call(parallel_calls[0]))
1346-
yield FinalOutput(output=None)
1345+
yield ActionOutput(output=None, is_final_answer=False)
13471346
else:
13481347
# If multiple tool calls, process them in parallel
13491348
with ThreadPoolExecutor(self.max_tool_threads) as executor:
13501349
futures = [executor.submit(process_single_tool_call, call_info) for call_info in parallel_calls]
13511350
for future in as_completed(futures):
13521351
observations.append(future.result())
1353-
yield FinalOutput(output=None)
1352+
yield ActionOutput(output=None, is_final_answer=False)
13541353

13551354
# Process final_answer call if present
13561355
if final_answer_call:
@@ -1380,7 +1379,7 @@ def process_single_tool_call(call_info):
13801379
level=LogLevel.INFO,
13811380
)
13821381
memory_step.action_output = final_answer
1383-
yield FinalOutput(output=final_answer)
1382+
yield ActionOutput(output=final_answer, is_final_answer=True)
13841383

13851384
# Update memory step with all results
13861385
if model_outputs:
@@ -1572,7 +1571,7 @@ def initialize_system_prompt(self) -> str:
15721571
)
15731572
return system_prompt
15741573

1575-
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | FinalOutput]:
1574+
def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDelta | ActionOutput]:
15761575
"""
15771576
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
15781577
Yields ChatMessageStreamDelta during the run if streaming is enabled.
@@ -1702,7 +1701,7 @@ def _step_stream(self, memory_step: ActionStep) -> Generator[ChatMessageStreamDe
17021701
]
17031702
self.logger.log(Group(*execution_outputs_console), level=LogLevel.INFO)
17041703
memory_step.action_output = output
1705-
yield FinalOutput(output=output if is_final_answer else None)
1704+
yield ActionOutput(output=output, is_final_answer=is_final_answer)
17061705

17071706
def to_dict(self) -> dict[str, Any]:
17081707
"""Convert the agent to a dictionary representation.

src/smolagents/memory.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,25 @@ class ActionStep(MemoryStep):
6161
observations_images: list["PIL.Image.Image"] | None = None
6262
action_output: Any = None
6363
token_usage: TokenUsage | None = None
64+
is_final_answer: bool = False
6465

6566
def dict(self):
6667
# We overwrite the method to parse the tool_calls and action_output manually
6768
return {
69+
"step_number": self.step_number,
70+
"timing": self.timing.dict(),
6871
"model_input_messages": self.model_input_messages,
6972
"tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
70-
"timing": self.timing.dict(),
71-
"token_usage": asdict(self.token_usage) if self.token_usage else None,
72-
"step": self.step_number,
7373
"error": self.error.dict() if self.error else None,
7474
"model_output_message": self.model_output_message.dict() if self.model_output_message else None,
7575
"model_output": self.model_output,
7676
"observations": self.observations,
77+
"observations_images": [image.tobytes() for image in self.observations_images]
78+
if self.observations_images
79+
else None,
7780
"action_output": make_json_serializable(self.action_output),
81+
"token_usage": asdict(self.token_usage) if self.token_usage else None,
82+
"is_final_answer": self.is_final_answer,
7883
}
7984

8085
def to_messages(self, summary_mode: bool = False) -> list[Message]:

tests/test_agents.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -612,13 +612,27 @@ def weather_api(location: str, celsius: bool = False) -> str:
612612
assert step_memory_dict["timing"]["duration"] > 0.1
613613

614614
def test_final_answer_checks(self):
615+
error_string = "failed with error"
616+
615617
def check_always_fails(final_answer, agent_memory):
616618
assert False, "Error raised in check"
617619

618620
agent = CodeAgent(model=FakeCodeModel(), tools=[], final_answer_checks=[check_always_fails])
619621
agent.run("Dummy task.")
622+
assert error_string in str(agent.write_memory_to_messages())
620623
assert "Error raised in check" in str(agent.write_memory_to_messages())
621624

625+
agent = CodeAgent(
626+
model=FakeCodeModel(),
627+
tools=[],
628+
final_answer_checks=[lambda x, y: x == 7.2904],
629+
verbosity_level=1000,
630+
)
631+
output = agent.run("Dummy task.")
632+
assert output == 7.2904 # Check that output is correct
633+
assert len([step for step in agent.memory.steps if isinstance(step, ActionStep)]) == 2
634+
assert error_string not in str(agent.write_memory_to_messages())
635+
622636
def test_generation_errors_are_raised(self):
623637
class FakeCodeModel(Model):
624638
def generate(self, messages, stop_sequences=None):
@@ -640,19 +654,21 @@ def test_planning_step_with_injected_memory(self):
640654
agent.memory.steps.append(previous_step)
641655

642656
# Run the agent
643-
agent.run(task, reset=False)
657+
agent.run(task, reset=False, max_steps=2)
644658

645659
# Verify that the planning step used update plan prompts
646660
planning_steps = [step for step in agent.memory.steps if isinstance(step, PlanningStep)]
647661
assert len(planning_steps) > 0
648662

649663
# Check that the planning step's model input messages contain the injected memory
650-
planning_step = planning_steps[0]
651-
assert len(planning_step.model_input_messages) == 3 # system message + memory messages + user message
652-
assert planning_step.model_input_messages[0]["role"] == "system"
653-
assert task in planning_step.model_input_messages[0]["content"][0]["text"]
654-
assert planning_step.model_input_messages[1]["role"] == "user"
655-
assert "Previous user request" in planning_step.model_input_messages[1]["content"][0]["text"]
664+
update_plan_step = planning_steps[0]
665+
assert (
666+
len(update_plan_step.model_input_messages) == 3
667+
) # system message + memory messages (1 task message, the latest one is removed) + user message
668+
assert update_plan_step.model_input_messages[0]["role"] == "system"
669+
assert task in update_plan_step.model_input_messages[0]["content"][0]["text"]
670+
assert update_plan_step.model_input_messages[1]["role"] == "user"
671+
assert "Previous user request" in update_plan_step.model_input_messages[1]["content"][0]["text"]
656672

657673

658674
class CustomFinalAnswerTool(FinalAnswerTool):

tests/test_memory.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from PIL import Image
23

34
from smolagents.agents import ToolCall
45
from smolagents.memory import (
@@ -50,7 +51,7 @@ def test_action_step_dict():
5051
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
5152
model_output="Hi",
5253
observations="This is a nice observation",
53-
observations_images=["image1.png"],
54+
observations_images=[Image.new("RGB", (100, 100))],
5455
action_output="Output",
5556
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
5657
)
@@ -76,8 +77,8 @@ def test_action_step_dict():
7677
assert "token_usage" in action_step_dict
7778
assert action_step_dict["token_usage"] == {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
7879

79-
assert "step" in action_step_dict
80-
assert action_step_dict["step"] == 1
80+
assert "step_number" in action_step_dict
81+
assert action_step_dict["step_number"] == 1
8182

8283
assert "error" in action_step_dict
8384
assert action_step_dict["error"] is None
@@ -97,6 +98,8 @@ def test_action_step_dict():
9798
assert "observations" in action_step_dict
9899
assert action_step_dict["observations"] == "This is a nice observation"
99100

101+
assert "observations_images" in action_step_dict
102+
100103
assert "action_output" in action_step_dict
101104
assert action_step_dict["action_output"] == "Output"
102105

@@ -113,7 +116,7 @@ def test_action_step_to_messages():
113116
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
114117
model_output="Hi",
115118
observations="This is a nice observation",
116-
observations_images=["image1.png"],
119+
observations_images=[Image.new("RGB", (100, 100))],
117120
action_output="Output",
118121
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
119122
)
@@ -197,7 +200,7 @@ def test_planning_step_to_messages():
197200

198201

199202
def test_task_step_to_messages():
200-
task_step = TaskStep(task="This is a task.", task_images=["task_image1.png"])
203+
task_step = TaskStep(task="This is a task.", task_images=[Image.new("RGB", (100, 100))])
201204
messages = task_step.to_messages(summary_mode=False)
202205
assert len(messages) == 1
203206
for message in messages:

0 commit comments

Comments (0)