Skip to content

Commit fc7cc89

Browse files
Fix test which seems to have been faulty before
1 parent 4441aec commit fc7cc89

File tree

3 files changed

+29
-24
lines changed

3 files changed

+29
-24
lines changed

src/smolagents/agents.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -258,15 +258,15 @@ def __init__(
258258
self.prompt_templates = prompt_templates or EMPTY_PROMPT_TEMPLATES
259259
if prompt_templates is not None:
260260
missing_keys = set(EMPTY_PROMPT_TEMPLATES.keys()) - set(prompt_templates.keys())
261-
assert not missing_keys, (
262-
f"Some prompt templates are missing from your custom `prompt_templates`: {missing_keys}"
263-
)
261+
assert (
262+
not missing_keys
263+
), f"Some prompt templates are missing from your custom `prompt_templates`: {missing_keys}"
264264
for key, value in EMPTY_PROMPT_TEMPLATES.items():
265265
if isinstance(value, dict):
266266
for subkey in value.keys():
267-
assert key in prompt_templates.keys() and (subkey in prompt_templates[key].keys()), (
268-
f"Some prompt templates are missing from your custom `prompt_templates`: {subkey} under {key}"
269-
)
267+
assert (
268+
key in prompt_templates.keys() and (subkey in prompt_templates[key].keys())
269+
), f"Some prompt templates are missing from your custom `prompt_templates`: {subkey} under {key}"
270270

271271
self.max_steps = max_steps
272272
self.step_number = 0
@@ -320,9 +320,9 @@ def _setup_managed_agents(self, managed_agents: list | None = None) -> None:
320320
"""Setup managed agents with proper logging."""
321321
self.managed_agents = {}
322322
if managed_agents:
323-
assert all(agent.name and agent.description for agent in managed_agents), (
324-
"All managed agents need both a name and a description!"
325-
)
323+
assert all(
324+
agent.name and agent.description for agent in managed_agents
325+
), "All managed agents need both a name and a description!"
326326
self.managed_agents = {agent.name: agent for agent in managed_agents}
327327

328328
def _setup_tools(self, tools, add_base_tools):
@@ -465,8 +465,8 @@ def _run_stream(
465465
planning_start_time = time.time()
466466
planning_step = None
467467
for element in self._generate_planning_step(
468-
task, is_first_step=(self.step_number == 1), step=self.step_number
469-
):
468+
task, is_first_step=len(self.memory.steps) == 0, step=self.step_number
469+
): # Don't use the attribute step_number here, because there can be steps from previous runs
470470
yield element
471471
planning_step = element
472472
assert isinstance(planning_step, PlanningStep) # Last yielded element should be a PlanningStep

tests/test_agents.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -655,27 +655,29 @@ def generate(self, messages, stop_sequences=None):
655655

656656
def test_planning_step_with_injected_memory(self):
657657
"""Test that planning step uses update plan prompts when memory is injected before run."""
658-
agent = CodeAgent(tools=[], planning_interval=1, model=FakeCodeModelPlanning())
658+
agent = CodeAgent(tools=[], planning_interval=1, model=FakeCodeModelPlanning(), verbosity_level=1000)
659659
task = "Continuous task"
660660

661661
# Inject memory before run
662662
previous_step = TaskStep(task="Previous user request")
663663
agent.memory.steps.append(previous_step)
664664

665665
# Run the agent
666-
agent.run(task, reset=False)
666+
agent.run(task, reset=False, max_steps=2)
667667

668668
# Verify that the planning step used update plan prompts
669669
planning_steps = [step for step in agent.memory.steps if isinstance(step, PlanningStep)]
670670
assert len(planning_steps) > 0
671671

672672
# Check that the planning step's model input messages contain the injected memory
673-
planning_step = planning_steps[0]
674-
assert len(planning_step.model_input_messages) == 3 # system message + memory messages + user message
675-
assert planning_step.model_input_messages[0]["role"] == "system"
676-
assert task in planning_step.model_input_messages[0]["content"][0]["text"]
677-
assert planning_step.model_input_messages[1]["role"] == "user"
678-
assert "Previous user request" in planning_step.model_input_messages[1]["content"][0]["text"]
673+
update_plan_step = planning_steps[0]
674+
assert (
675+
len(update_plan_step.model_input_messages) == 4
676+
) # system message + memory messages (2 task messages) + user message
677+
assert update_plan_step.model_input_messages[0]["role"] == "system"
678+
assert task in update_plan_step.model_input_messages[0]["content"][0]["text"]
679+
assert update_plan_step.model_input_messages[1]["role"] == "user"
680+
assert "Previous user request" in update_plan_step.model_input_messages[1]["content"][0]["text"]
679681

680682

681683
class CustomFinalAnswerTool(FinalAnswerTool):

tests/test_memory.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from PIL import Image
23

34
from smolagents.agents import ToolCall
45
from smolagents.memory import (
@@ -50,7 +51,7 @@ def test_action_step_dict():
5051
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
5152
model_output="Hi",
5253
observations="This is a nice observation",
53-
observations_images=["image1.png"],
54+
observations_images=[Image.new("RGB", (100, 100))],
5455
action_output="Output",
5556
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
5657
)
@@ -76,8 +77,8 @@ def test_action_step_dict():
7677
assert "token_usage" in action_step_dict
7778
assert action_step_dict["token_usage"] == {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}
7879

79-
assert "step" in action_step_dict
80-
assert action_step_dict["step"] == 1
80+
assert "step_number" in action_step_dict
81+
assert action_step_dict["step_number"] == 1
8182

8283
assert "error" in action_step_dict
8384
assert action_step_dict["error"] is None
@@ -97,6 +98,8 @@ def test_action_step_dict():
9798
assert "observations" in action_step_dict
9899
assert action_step_dict["observations"] == "This is a nice observation"
99100

101+
assert "observations_images" in action_step_dict
102+
100103
assert "action_output" in action_step_dict
101104
assert action_step_dict["action_output"] == "Output"
102105

@@ -113,7 +116,7 @@ def test_action_step_to_messages():
113116
model_output_message=ChatMessage(role=MessageRole.ASSISTANT, content="Hi"),
114117
model_output="Hi",
115118
observations="This is a nice observation",
116-
observations_images=["image1.png"],
119+
observations_images=[Image.new("RGB", (100, 100))],
117120
action_output="Output",
118121
token_usage=TokenUsage(input_tokens=10, output_tokens=20),
119122
)
@@ -197,7 +200,7 @@ def test_planning_step_to_messages():
197200

198201

199202
def test_task_step_to_messages():
200-
task_step = TaskStep(task="This is a task.", task_images=["task_image1.png"])
203+
task_step = TaskStep(task="This is a task.", task_images=[Image.new("RGB", (100, 100))])
201204
messages = task_step.to_messages(summary_mode=False)
202205
assert len(messages) == 1
203206
for message in messages:

0 commit comments

Comments
 (0)