4848from camel .agents import ChatAgent
4949from camel .logger import get_logger
5050from camel .messages .base import BaseMessage
51- from camel .models import ModelFactory
51+ from camel .models import BaseModelBackend , ModelManager
5252from camel .societies .workforce .base import BaseNode
5353from camel .societies .workforce .prompts import (
5454 ASSIGN_TASK_PROMPT ,
9191 SearchToolkit ,
9292 ThinkingToolkit ,
9393)
94- from camel .types import ModelPlatformType , ModelType
9594from camel .utils import dependencies_required
9695
9796from .events import (
@@ -201,6 +200,11 @@ class Workforce(BaseNode):
201200 handle failed tasks. If None, workers will be created with default
202201 settings including SearchToolkit, CodeExecutionToolkit, and
203202 ThinkingToolkit. (default: :obj:`None`)
203+ default_model (Optional[Union[BaseModelBackend, ModelManager]],
204+ optional): Model backend or manager to use when creating default
205+ coordinator, task, or dynamic worker agents. If None, agents
206+ will be created using ModelPlatformType.DEFAULT and
207+ ModelType.DEFAULT settings. (default: :obj:`None`)
204208 graceful_shutdown_timeout (float, optional): The timeout in seconds
205209 for graceful shutdown when a task fails 3 times. During this
206210 period, the workforce remains active for debugging.
@@ -311,6 +315,7 @@ def __init__(
311315 coordinator_agent : Optional [ChatAgent ] = None ,
312316 task_agent : Optional [ChatAgent ] = None ,
313317 new_worker_agent : Optional [ChatAgent ] = None ,
318+ default_model : Optional [Union [BaseModelBackend , ModelManager ]] = None ,
314319 graceful_shutdown_timeout : float = 15.0 ,
315320 share_memory : bool = False ,
316321 use_structured_output_handler : bool = True ,
@@ -327,6 +332,7 @@ def __init__(
327332 ] = deque ()
328333 self ._children = children or []
329334 self .new_worker_agent = new_worker_agent
335+ self .default_model = default_model
330336 self .graceful_shutdown_timeout = graceful_shutdown_timeout
331337 self .share_memory = share_memory
332338 self .use_structured_output_handler = use_structured_output_handler
@@ -391,7 +397,10 @@ def __init__(
391397 "ChatAgent settings (ModelPlatformType.DEFAULT, "
392398 "ModelType.DEFAULT) with default system message."
393399 )
394- self .coordinator_agent = ChatAgent (coord_agent_sys_msg )
400+ self .coordinator_agent = ChatAgent (
401+ coord_agent_sys_msg ,
402+ model = self .default_model ,
403+ )
395404 else :
396405 logger .info (
397406 "Custom coordinator_agent provided. Preserving user's "
@@ -450,6 +459,7 @@ def __init__(
450459 )
451460 self .task_agent = ChatAgent (
452461 task_sys_msg ,
462+ model = self .default_model ,
453463 )
454464 else :
455465 logger .info (
@@ -4058,15 +4068,9 @@ async def _create_new_agent(self, role: str, sys_msg: str) -> ChatAgent:
40584068 * ThinkingToolkit ().get_tools (),
40594069 ]
40604070
4061- model = ModelFactory .create (
4062- model_platform = ModelPlatformType .DEFAULT ,
4063- model_type = ModelType .DEFAULT ,
4064- model_config_dict = {"temperature" : 0 },
4065- )
4066-
40674071 return ChatAgent (
40684072 system_message = worker_sys_msg ,
4069- model = model ,
4073+ model = self . default_model ,
40704074 tools = function_list , # type: ignore[arg-type]
40714075 pause_event = self ._pause_event ,
40724076 )
@@ -5496,6 +5500,7 @@ def clone(self, with_memory: bool = False) -> 'Workforce':
54965500 new_worker_agent = self .new_worker_agent .clone (with_memory )
54975501 if self .new_worker_agent
54985502 else None ,
5503+ default_model = self .default_model ,
54995504 graceful_shutdown_timeout = self .graceful_shutdown_timeout ,
55005505 share_memory = self .share_memory ,
55015506 use_structured_output_handler = self .use_structured_output_handler ,
0 commit comments