Skip to content

Commit 746db0f

Browse files
ChenziqiAdamAdam ChenSaedbhatiwaleedalzarooniJINO-ROHIT
authored
Enhancing RolePlaying with custom ChatAgent support (#3199)
Co-authored-by: Adam Chen <[email protected]> Co-authored-by: Saed Bhati <[email protected]> Co-authored-by: Waleed Alzarooni <[email protected]> Co-authored-by: JINO ROHIT <[email protected]>
1 parent dac778f commit 746db0f

File tree

2 files changed

+175
-50
lines changed

2 files changed

+175
-50
lines changed

camel/societies/role_playing.py

Lines changed: 173 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ class RolePlaying:
6262
(default: :obj:`TaskType.AI_SOCIETY`)
6363
assistant_agent_kwargs (Dict, optional): Additional arguments to pass
6464
to the assistant agent. (default: :obj:`None`)
65-
user_agent_kwargs (Dict, optional): Additional arguments to pass to
66-
the user agent. (default: :obj:`None`)
65+
user_agent_kwargs (Dict, optional): Additional arguments to pass to the
66+
user agent. (default: :obj:`None`)
6767
task_specify_agent_kwargs (Dict, optional): Additional arguments to
6868
pass to the task specify agent. (default: :obj:`None`)
6969
task_planner_agent_kwargs (Dict, optional): Additional arguments to
@@ -81,6 +81,12 @@ class RolePlaying:
8181
stop_event (Optional[threading.Event], optional): Event to signal
8282
termination of the agent's operation. When set, the agent will
8383
terminate its execution. (default: :obj:`None`)
84+
assistant_agent (ChatAgent, optional): A pre-configured ChatAgent to
85+
use as the assistant. If provided, this will override the creation
86+
of a new assistant agent. (default: :obj:`None`)
87+
user_agent (ChatAgent, optional): A pre-configured ChatAgent to use as
88+
the user. If provided, this will override the creation of a new
89+
user agent. (default: :obj:`None`)
8490
"""
8591

8692
def __init__(
@@ -106,6 +112,8 @@ def __init__(
106112
extend_task_specify_meta_dict: Optional[Dict] = None,
107113
output_language: Optional[str] = None,
108114
stop_event: Optional[threading.Event] = None,
115+
assistant_agent: Optional[ChatAgent] = None,
116+
user_agent: Optional[ChatAgent] = None,
109117
) -> None:
110118
if model is not None:
111119
logger.warning(
@@ -143,29 +151,56 @@ def __init__(
143151
**(sys_msg_generator_kwargs or {}),
144152
)
145153

146-
(
147-
init_assistant_sys_msg,
148-
init_user_sys_msg,
149-
sys_msg_meta_dicts,
150-
) = self._get_sys_message_info(
151-
assistant_role_name,
152-
user_role_name,
153-
sys_msg_generator,
154-
extend_sys_msg_meta_dicts=extend_sys_msg_meta_dicts,
155-
)
156-
154+
# Initialize agent attributes first
157155
self.assistant_agent: ChatAgent
158156
self.user_agent: ChatAgent
159-
self.assistant_sys_msg: Optional[BaseMessage]
160-
self.user_sys_msg: Optional[BaseMessage]
157+
self.assistant_sys_msg: Optional[BaseMessage] = None
158+
self.user_sys_msg: Optional[BaseMessage] = None
159+
160+
# Determine if we need to generate system messages
161+
if assistant_agent is None or user_agent is None:
162+
# Generate system messages for missing agents
163+
(
164+
init_assistant_sys_msg,
165+
init_user_sys_msg,
166+
sys_msg_meta_dicts,
167+
) = self._get_sys_message_info(
168+
assistant_role_name,
169+
user_role_name,
170+
sys_msg_generator,
171+
extend_sys_msg_meta_dicts=extend_sys_msg_meta_dicts,
172+
)
173+
else:
174+
# When both agents are provided, use their existing system messages
175+
assistant_sys_msg = assistant_agent.system_message
176+
user_sys_msg = user_agent.system_message
177+
178+
# Ensure system messages are not None
179+
if assistant_sys_msg is None:
180+
raise ValueError(
181+
"Provided assistant_agent has None system_message"
182+
)
183+
if user_sys_msg is None:
184+
raise ValueError("Provided user_agent has None system_message")
185+
186+
init_assistant_sys_msg = assistant_sys_msg
187+
init_user_sys_msg = user_sys_msg
188+
# Create a default sys_msg_meta_dicts for critic initialization
189+
sys_msg_meta_dicts = [
190+
dict(task=self.task_prompt) for _ in range(2)
191+
]
192+
161193
self._init_agents(
162194
init_assistant_sys_msg,
163195
init_user_sys_msg,
164196
assistant_agent_kwargs=assistant_agent_kwargs,
165197
user_agent_kwargs=user_agent_kwargs,
166198
output_language=output_language,
167199
stop_event=stop_event,
200+
assistant_agent=assistant_agent,
201+
user_agent=user_agent,
168202
)
203+
169204
self.critic: Optional[Union[CriticAgent, Human]] = None
170205
self.critic_sys_msg: Optional[BaseMessage] = None
171206
self._init_critic(
@@ -320,20 +355,22 @@ def _get_sys_message_info(
320355

321356
def _init_agents(
322357
self,
323-
init_assistant_sys_msg: BaseMessage,
324-
init_user_sys_msg: BaseMessage,
358+
init_assistant_sys_msg: Optional[BaseMessage],
359+
init_user_sys_msg: Optional[BaseMessage],
325360
assistant_agent_kwargs: Optional[Dict] = None,
326361
user_agent_kwargs: Optional[Dict] = None,
327362
output_language: Optional[str] = None,
328363
stop_event: Optional[threading.Event] = None,
364+
assistant_agent: Optional[ChatAgent] = None,
365+
user_agent: Optional[ChatAgent] = None,
329366
) -> None:
330367
r"""Initialize assistant and user agents with their system messages.
331368
332369
Args:
333-
init_assistant_sys_msg (BaseMessage): Assistant agent's initial
370+
init_assistant_sys_msg (Optional[BaseMessage]): Assistant agent's
371+
initial system message.
372+
init_user_sys_msg (Optional[BaseMessage]): User agent's initial
334373
system message.
335-
init_user_sys_msg (BaseMessage): User agent's initial system
336-
message.
337374
assistant_agent_kwargs (Dict, optional): Additional arguments to
338375
pass to the assistant agent. (default: :obj:`None`)
339376
user_agent_kwargs (Dict, optional): Additional arguments to
@@ -343,6 +380,12 @@ def _init_agents(
343380
stop_event (Optional[threading.Event], optional): Event to signal
344381
termination of the agent's operation. When set, the agent will
345382
terminate its execution. (default: :obj:`None`)
383+
assistant_agent (ChatAgent, optional): A pre-configured ChatAgent
384+
to use as the assistant. If provided, this will override the
385+
creation of a new assistant agent. (default: :obj:`None`)
386+
user_agent (ChatAgent, optional): A pre-configured ChatAgent to use
387+
as the user. If provided, this will override the creation of a
388+
new user agent. (default: :obj:`None`)
346389
"""
347390
if self.model is not None:
348391
if assistant_agent_kwargs is None:
@@ -354,21 +397,71 @@ def _init_agents(
354397
elif 'model' not in user_agent_kwargs:
355398
user_agent_kwargs.update(dict(model=self.model))
356399

357-
self.assistant_agent = ChatAgent(
358-
init_assistant_sys_msg,
359-
output_language=output_language,
360-
stop_event=stop_event,
361-
**(assistant_agent_kwargs or {}),
362-
)
363-
self.assistant_sys_msg = self.assistant_agent.system_message
364-
365-
self.user_agent = ChatAgent(
366-
init_user_sys_msg,
367-
output_language=output_language,
368-
stop_event=stop_event,
369-
**(user_agent_kwargs or {}),
370-
)
371-
self.user_sys_msg = self.user_agent.system_message
400+
# Use provided assistant agent if available, otherwise create a new one
401+
if assistant_agent is not None:
402+
# Ensure functionality consistent with our configuration
403+
if (
404+
hasattr(assistant_agent, 'output_language')
405+
and output_language is not None
406+
):
407+
assistant_agent.output_language = output_language
408+
if hasattr(assistant_agent, 'stop_event'):
409+
assistant_agent.stop_event = stop_event
410+
self.assistant_agent = assistant_agent
411+
# Handle potential None system_message - use provided or fallback
412+
if assistant_agent.system_message is not None:
413+
self.assistant_sys_msg = assistant_agent.system_message
414+
elif init_assistant_sys_msg is not None:
415+
self.assistant_sys_msg = init_assistant_sys_msg
416+
else:
417+
raise ValueError("Assistant system message cannot be None")
418+
else:
419+
# Create new assistant agent
420+
if init_assistant_sys_msg is None:
421+
raise ValueError(
422+
"Assistant system message cannot be None when creating "
423+
"new agent"
424+
)
425+
self.assistant_agent = ChatAgent(
426+
init_assistant_sys_msg,
427+
output_language=output_language,
428+
stop_event=stop_event,
429+
**(assistant_agent_kwargs or {}),
430+
)
431+
self.assistant_sys_msg = self.assistant_agent.system_message
432+
433+
# Use provided user agent if available, otherwise create a new one
434+
if user_agent is not None:
435+
# Ensure functionality consistent with our configuration
436+
if (
437+
hasattr(user_agent, 'output_language')
438+
and output_language is not None
439+
):
440+
user_agent.output_language = output_language
441+
if hasattr(user_agent, 'stop_event'):
442+
user_agent.stop_event = stop_event
443+
self.user_agent = user_agent
444+
# Handle potential None system_message - use provided or fallback
445+
if user_agent.system_message is not None:
446+
self.user_sys_msg = user_agent.system_message
447+
elif init_user_sys_msg is not None:
448+
self.user_sys_msg = init_user_sys_msg
449+
else:
450+
raise ValueError("User system message cannot be None")
451+
else:
452+
# Create new user agent
453+
if init_user_sys_msg is None:
454+
raise ValueError(
455+
"User system message cannot be None when creating new "
456+
"agent"
457+
)
458+
self.user_agent = ChatAgent(
459+
init_user_sys_msg,
460+
output_language=output_language,
461+
stop_event=stop_event,
462+
**(user_agent_kwargs or {}),
463+
)
464+
self.user_sys_msg = self.user_agent.system_message
372465

373466
def _init_critic(
374467
self,
@@ -389,7 +482,7 @@ def _init_critic(
389482
sys_msg_meta_dicts (list): A list of system message meta dicts.
390483
critic_role_name (str): The name of the role played by the critic.
391484
critic_criteria (str, optional): Critic criteria for the
392-
critic agent. If not specified, set the criteria to
485+
critic agent. If not specified, set it to
393486
improve task performance. (default: :obj:`None`)
394487
critic_kwargs (Dict, optional): Additional arguments to
395488
pass to the critic. (default: :obj:`None`)
@@ -465,20 +558,28 @@ def init_chat(self, init_msg_content: Optional[str] = None) -> BaseMessage:
465558
BaseMessage: A single `BaseMessage` representing the initial
466559
message.
467560
"""
468-
self.assistant_agent.reset()
469-
self.user_agent.reset()
561+
if self.assistant_agent is not None:
562+
self.assistant_agent.reset()
563+
if self.user_agent is not None:
564+
self.user_agent.reset()
470565
default_init_msg_content = (
471566
"Now start to give me instructions one by one. "
472567
"Only reply with Instruction and Input."
473568
)
474-
if init_msg_content is None:
475-
init_msg_content = default_init_msg_content
569+
final_init_msg_content = init_msg_content or default_init_msg_content
476570

477571
# Initialize a message sent by the assistant
572+
assistant_role_name = "assistant"
573+
if self.assistant_sys_msg is not None and hasattr(
574+
self.assistant_sys_msg, 'role_name'
575+
):
576+
role_name_attr = getattr(self.assistant_sys_msg, 'role_name', None)
577+
if role_name_attr is not None:
578+
assistant_role_name = str(role_name_attr)
579+
478580
init_msg = BaseMessage.make_assistant_message(
479-
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
480-
or "assistant",
481-
content=init_msg_content,
581+
role_name=assistant_role_name,
582+
content=final_init_msg_content,
482583
)
483584

484585
return init_msg
@@ -501,20 +602,28 @@ async def ainit_chat(
501602
"""
502603
# Currently, reset() is synchronous, but if it becomes async in the
503604
# future, we can await it here
504-
self.assistant_agent.reset()
505-
self.user_agent.reset()
605+
if self.assistant_agent is not None:
606+
self.assistant_agent.reset()
607+
if self.user_agent is not None:
608+
self.user_agent.reset()
506609
default_init_msg_content = (
507610
"Now start to give me instructions one by one. "
508611
"Only reply with Instruction and Input."
509612
)
510-
if init_msg_content is None:
511-
init_msg_content = default_init_msg_content
613+
final_init_msg_content = init_msg_content or default_init_msg_content
512614

513615
# Initialize a message sent by the assistant
616+
assistant_role_name = "assistant"
617+
if self.assistant_sys_msg is not None and hasattr(
618+
self.assistant_sys_msg, 'role_name'
619+
):
620+
role_name_attr = getattr(self.assistant_sys_msg, 'role_name', None)
621+
if role_name_attr is not None:
622+
assistant_role_name = str(role_name_attr)
623+
514624
init_msg = BaseMessage.make_assistant_message(
515-
role_name=getattr(self.assistant_sys_msg, 'role_name', None)
516-
or "assistant",
517-
content=init_msg_content,
625+
role_name=assistant_role_name,
626+
content=final_init_msg_content,
518627
)
519628

520629
return init_msg
@@ -544,6 +653,11 @@ def step(
544653
user agent terminated the conversation, and any additional user
545654
information.
546655
"""
656+
if self.user_agent is None:
657+
raise ValueError("User agent is not initialized")
658+
if self.assistant_agent is None:
659+
raise ValueError("Assistant agent is not initialized")
660+
547661
user_response = self.user_agent.step(assistant_msg)
548662
if user_response.terminated or user_response.msgs is None:
549663
return (
@@ -620,6 +734,11 @@ async def astep(
620734
user agent terminated the conversation, and any additional user
621735
information.
622736
"""
737+
if self.user_agent is None:
738+
raise ValueError("User agent is not initialized")
739+
if self.assistant_agent is None:
740+
raise ValueError("Assistant agent is not initialized")
741+
623742
user_response = await self.user_agent.astep(assistant_msg)
624743
if user_response.terminated or user_response.msgs is None:
625744
return (
@@ -682,6 +801,10 @@ def clone(
682801
RolePlaying: A new instance of RolePlaying with the same
683802
configuration.
684803
"""
804+
if self.assistant_agent is None or self.user_agent is None:
805+
raise ValueError(
806+
"Cannot clone: assistant_agent or user_agent is None"
807+
)
685808

686809
new_instance = RolePlaying(
687810
assistant_role_name=self.assistant_agent.role_name,

docs/key_modules/societies.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ icon: people-group
7070
<tr><td>extend_sys_msg_meta_dicts</td><td>List[Dict]</td><td>Extra metadata for system messages</td></tr>
7171
<tr><td>extend_task_specify_meta_dict</td><td>Dict</td><td>Extra metadata for task specification</td></tr>
7272
<tr><td>output_language</td><td>str</td><td>Target output language</td></tr>
73+
<tr><td>assistant_agent</td><td>ChatAgent</td><td>Custom ChatAgent to use as assistant (optional)</td></tr>
74+
<tr><td>user_agent</td><td>ChatAgent</td><td>Custom ChatAgent to use as user (optional)</td></tr>
7375
</tbody>
7476
</table>
7577
</Accordion>

0 commit comments

Comments
 (0)