Skip to content

Commit 5abf425

Browse files
chore: add default model for workforce (#3625)
Co-authored-by: Wendong-Fan <[email protected]> Co-authored-by: Wendong-Fan <[email protected]>
1 parent aaf749b commit 5abf425

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

camel/societies/workforce/workforce.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from camel.agents import ChatAgent
4949
from camel.logger import get_logger
5050
from camel.messages.base import BaseMessage
51-
from camel.models import ModelFactory
51+
from camel.models import BaseModelBackend, ModelManager
5252
from camel.societies.workforce.base import BaseNode
5353
from camel.societies.workforce.prompts import (
5454
ASSIGN_TASK_PROMPT,
@@ -91,7 +91,6 @@
9191
SearchToolkit,
9292
ThinkingToolkit,
9393
)
94-
from camel.types import ModelPlatformType, ModelType
9594
from camel.utils import dependencies_required
9695

9796
from .events import (
@@ -201,6 +200,11 @@ class Workforce(BaseNode):
201200
handle failed tasks. If None, workers will be created with default
202201
settings including SearchToolkit, CodeExecutionToolkit, and
203202
ThinkingToolkit. (default: :obj:`None`)
203+
default_model (Optional[Union[BaseModelBackend, ModelManager]],
204+
optional): Model backend or manager to use when creating default
205+
coordinator, task, or dynamic worker agents. If None, agents
206+
will be created using ModelPlatformType.DEFAULT and
207+
ModelType.DEFAULT settings. (default: :obj:`None`)
204208
graceful_shutdown_timeout (float, optional): The timeout in seconds
205209
for graceful shutdown when a task fails 3 times. During this
206210
period, the workforce remains active for debugging.
@@ -311,6 +315,7 @@ def __init__(
311315
coordinator_agent: Optional[ChatAgent] = None,
312316
task_agent: Optional[ChatAgent] = None,
313317
new_worker_agent: Optional[ChatAgent] = None,
318+
default_model: Optional[Union[BaseModelBackend, ModelManager]] = None,
314319
graceful_shutdown_timeout: float = 15.0,
315320
share_memory: bool = False,
316321
use_structured_output_handler: bool = True,
@@ -327,6 +332,7 @@ def __init__(
327332
] = deque()
328333
self._children = children or []
329334
self.new_worker_agent = new_worker_agent
335+
self.default_model = default_model
330336
self.graceful_shutdown_timeout = graceful_shutdown_timeout
331337
self.share_memory = share_memory
332338
self.use_structured_output_handler = use_structured_output_handler
@@ -391,7 +397,10 @@ def __init__(
391397
"ChatAgent settings (ModelPlatformType.DEFAULT, "
392398
"ModelType.DEFAULT) with default system message."
393399
)
394-
self.coordinator_agent = ChatAgent(coord_agent_sys_msg)
400+
self.coordinator_agent = ChatAgent(
401+
coord_agent_sys_msg,
402+
model=self.default_model,
403+
)
395404
else:
396405
logger.info(
397406
"Custom coordinator_agent provided. Preserving user's "
@@ -450,6 +459,7 @@ def __init__(
450459
)
451460
self.task_agent = ChatAgent(
452461
task_sys_msg,
462+
model=self.default_model,
453463
)
454464
else:
455465
logger.info(
@@ -4058,15 +4068,9 @@ async def _create_new_agent(self, role: str, sys_msg: str) -> ChatAgent:
40584068
*ThinkingToolkit().get_tools(),
40594069
]
40604070

4061-
model = ModelFactory.create(
4062-
model_platform=ModelPlatformType.DEFAULT,
4063-
model_type=ModelType.DEFAULT,
4064-
model_config_dict={"temperature": 0},
4065-
)
4066-
40674071
return ChatAgent(
40684072
system_message=worker_sys_msg,
4069-
model=model,
4073+
model=self.default_model,
40704074
tools=function_list, # type: ignore[arg-type]
40714075
pause_event=self._pause_event,
40724076
)
@@ -5496,6 +5500,7 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
54965500
new_worker_agent=self.new_worker_agent.clone(with_memory)
54975501
if self.new_worker_agent
54985502
else None,
5503+
default_model=self.default_model,
54995504
graceful_shutdown_timeout=self.graceful_shutdown_timeout,
55005505
share_memory=self.share_memory,
55015506
use_structured_output_handler=self.use_structured_output_handler,

0 commit comments

Comments
 (0)