@@ -62,8 +62,8 @@ class RolePlaying:
6262 (default: :obj:`TaskType.AI_SOCIETY`)
6363 assistant_agent_kwargs (Dict, optional): Additional arguments to pass
6464 to the assistant agent. (default: :obj:`None`)
65- user_agent_kwargs (Dict, optional): Additional arguments to pass to
66- the user agent. (default: :obj:`None`)
65+ user_agent_kwargs (Dict, optional): Additional arguments to pass to the
66+ user agent. (default: :obj:`None`)
6767 task_specify_agent_kwargs (Dict, optional): Additional arguments to
6868 pass to the task specify agent. (default: :obj:`None`)
6969 task_planner_agent_kwargs (Dict, optional): Additional arguments to
@@ -81,6 +81,12 @@ class RolePlaying:
8181 stop_event (Optional[threading.Event], optional): Event to signal
8282 termination of the agent's operation. When set, the agent will
8383 terminate its execution. (default: :obj:`None`)
84+ assistant_agent (ChatAgent, optional): A pre-configured ChatAgent to
85+ use as the assistant. If provided, this will override the creation
86+ of a new assistant agent. (default: :obj:`None`)
87+ user_agent (ChatAgent, optional): A pre-configured ChatAgent to use as
88+ the user. If provided, this will override the creation of a new
89+ user agent. (default: :obj:`None`)
8490 """
8591
8692 def __init__ (
@@ -106,6 +112,8 @@ def __init__(
106112 extend_task_specify_meta_dict : Optional [Dict ] = None ,
107113 output_language : Optional [str ] = None ,
108114 stop_event : Optional [threading .Event ] = None ,
115+ assistant_agent : Optional [ChatAgent ] = None ,
116+ user_agent : Optional [ChatAgent ] = None ,
109117 ) -> None :
110118 if model is not None :
111119 logger .warning (
@@ -143,29 +151,56 @@ def __init__(
143151 ** (sys_msg_generator_kwargs or {}),
144152 )
145153
146- (
147- init_assistant_sys_msg ,
148- init_user_sys_msg ,
149- sys_msg_meta_dicts ,
150- ) = self ._get_sys_message_info (
151- assistant_role_name ,
152- user_role_name ,
153- sys_msg_generator ,
154- extend_sys_msg_meta_dicts = extend_sys_msg_meta_dicts ,
155- )
156-
154+ # Initialize agent attributes first
157155 self .assistant_agent : ChatAgent
158156 self .user_agent : ChatAgent
159- self .assistant_sys_msg : Optional [BaseMessage ]
160- self .user_sys_msg : Optional [BaseMessage ]
157+ self .assistant_sys_msg : Optional [BaseMessage ] = None
158+ self .user_sys_msg : Optional [BaseMessage ] = None
159+
160+ # Determine if we need to generate system messages
161+ if assistant_agent is None or user_agent is None :
162+ # Generate system messages for missing agents
163+ (
164+ init_assistant_sys_msg ,
165+ init_user_sys_msg ,
166+ sys_msg_meta_dicts ,
167+ ) = self ._get_sys_message_info (
168+ assistant_role_name ,
169+ user_role_name ,
170+ sys_msg_generator ,
171+ extend_sys_msg_meta_dicts = extend_sys_msg_meta_dicts ,
172+ )
173+ else :
174+ # When both agents are provided, use their existing system messages
175+ assistant_sys_msg = assistant_agent .system_message
176+ user_sys_msg = user_agent .system_message
177+
178+ # Ensure system messages are not None
179+ if assistant_sys_msg is None :
180+ raise ValueError (
181+ "Provided assistant_agent has None system_message"
182+ )
183+ if user_sys_msg is None :
184+ raise ValueError ("Provided user_agent has None system_message" )
185+
186+ init_assistant_sys_msg = assistant_sys_msg
187+ init_user_sys_msg = user_sys_msg
188+ # Create a default sys_msg_meta_dicts for critic initialization
189+ sys_msg_meta_dicts = [
190+ dict (task = self .task_prompt ) for _ in range (2 )
191+ ]
192+
161193 self ._init_agents (
162194 init_assistant_sys_msg ,
163195 init_user_sys_msg ,
164196 assistant_agent_kwargs = assistant_agent_kwargs ,
165197 user_agent_kwargs = user_agent_kwargs ,
166198 output_language = output_language ,
167199 stop_event = stop_event ,
200+ assistant_agent = assistant_agent ,
201+ user_agent = user_agent ,
168202 )
203+
169204 self .critic : Optional [Union [CriticAgent , Human ]] = None
170205 self .critic_sys_msg : Optional [BaseMessage ] = None
171206 self ._init_critic (
@@ -320,20 +355,22 @@ def _get_sys_message_info(
320355
321356 def _init_agents (
322357 self ,
323- init_assistant_sys_msg : BaseMessage ,
324- init_user_sys_msg : BaseMessage ,
358+ init_assistant_sys_msg : Optional [ BaseMessage ] ,
359+ init_user_sys_msg : Optional [ BaseMessage ] ,
325360 assistant_agent_kwargs : Optional [Dict ] = None ,
326361 user_agent_kwargs : Optional [Dict ] = None ,
327362 output_language : Optional [str ] = None ,
328363 stop_event : Optional [threading .Event ] = None ,
364+ assistant_agent : Optional [ChatAgent ] = None ,
365+ user_agent : Optional [ChatAgent ] = None ,
329366 ) -> None :
330367 r"""Initialize assistant and user agents with their system messages.
331368
332369 Args:
333- init_assistant_sys_msg (BaseMessage): Assistant agent's initial
370+ init_assistant_sys_msg (Optional[BaseMessage]): Assistant agent's
371+ initial system message.
372+ init_user_sys_msg (Optional[BaseMessage]): User agent's initial
334373 system message.
335- init_user_sys_msg (BaseMessage): User agent's initial system
336- message.
337374 assistant_agent_kwargs (Dict, optional): Additional arguments to
338375 pass to the assistant agent. (default: :obj:`None`)
339376 user_agent_kwargs (Dict, optional): Additional arguments to
@@ -343,6 +380,12 @@ def _init_agents(
343380 stop_event (Optional[threading.Event], optional): Event to signal
344381 termination of the agent's operation. When set, the agent will
345382 terminate its execution. (default: :obj:`None`)
383+ assistant_agent (ChatAgent, optional): A pre-configured ChatAgent
384+ to use as the assistant. If provided, this will override the
385+ creation of a new assistant agent. (default: :obj:`None`)
386+ user_agent (ChatAgent, optional): A pre-configured ChatAgent to use
387+ as the user. If provided, this will override the creation of a
388+ new user agent. (default: :obj:`None`)
346389 """
347390 if self .model is not None :
348391 if assistant_agent_kwargs is None :
@@ -354,21 +397,71 @@ def _init_agents(
354397 elif 'model' not in user_agent_kwargs :
355398 user_agent_kwargs .update (dict (model = self .model ))
356399
357- self .assistant_agent = ChatAgent (
358- init_assistant_sys_msg ,
359- output_language = output_language ,
360- stop_event = stop_event ,
361- ** (assistant_agent_kwargs or {}),
362- )
363- self .assistant_sys_msg = self .assistant_agent .system_message
364-
365- self .user_agent = ChatAgent (
366- init_user_sys_msg ,
367- output_language = output_language ,
368- stop_event = stop_event ,
369- ** (user_agent_kwargs or {}),
370- )
371- self .user_sys_msg = self .user_agent .system_message
400+ # Use provided assistant agent if available, otherwise create a new one
401+ if assistant_agent is not None :
402+ # Ensure functionality consistent with our configuration
403+ if (
404+ hasattr (assistant_agent , 'output_language' )
405+ and output_language is not None
406+ ):
407+ assistant_agent .output_language = output_language
408+ if hasattr (assistant_agent , 'stop_event' ):
409+ assistant_agent .stop_event = stop_event
410+ self .assistant_agent = assistant_agent
411+ # Handle potential None system_message - use provided or fallback
412+ if assistant_agent .system_message is not None :
413+ self .assistant_sys_msg = assistant_agent .system_message
414+ elif init_assistant_sys_msg is not None :
415+ self .assistant_sys_msg = init_assistant_sys_msg
416+ else :
417+ raise ValueError ("Assistant system message cannot be None" )
418+ else :
419+ # Create new assistant agent
420+ if init_assistant_sys_msg is None :
421+ raise ValueError (
422+ "Assistant system message cannot be None when creating "
423+ "new agent"
424+ )
425+ self .assistant_agent = ChatAgent (
426+ init_assistant_sys_msg ,
427+ output_language = output_language ,
428+ stop_event = stop_event ,
429+ ** (assistant_agent_kwargs or {}),
430+ )
431+ self .assistant_sys_msg = self .assistant_agent .system_message
432+
433+ # Use provided user agent if available, otherwise create a new one
434+ if user_agent is not None :
435+ # Ensure functionality consistent with our configuration
436+ if (
437+ hasattr (user_agent , 'output_language' )
438+ and output_language is not None
439+ ):
440+ user_agent .output_language = output_language
441+ if hasattr (user_agent , 'stop_event' ):
442+ user_agent .stop_event = stop_event
443+ self .user_agent = user_agent
444+ # Handle potential None system_message - use provided or fallback
445+ if user_agent .system_message is not None :
446+ self .user_sys_msg = user_agent .system_message
447+ elif init_user_sys_msg is not None :
448+ self .user_sys_msg = init_user_sys_msg
449+ else :
450+ raise ValueError ("User system message cannot be None" )
451+ else :
452+ # Create new user agent
453+ if init_user_sys_msg is None :
454+ raise ValueError (
455+ "User system message cannot be None when creating new "
456+ "agent"
457+ )
458+ self .user_agent = ChatAgent (
459+ init_user_sys_msg ,
460+ output_language = output_language ,
461+ stop_event = stop_event ,
462+ ** (user_agent_kwargs or {}),
463+ )
464+ self .user_sys_msg = self .user_agent .system_message
372465
373466 def _init_critic (
374467 self ,
@@ -389,7 +482,7 @@ def _init_critic(
389482 sys_msg_meta_dicts (list): A list of system message meta dicts.
390483 critic_role_name (str): The name of the role played by the critic.
391484 critic_criteria (str, optional): Critic criteria for the
392- critic agent. If not specified, set the criteria to
485+ critic agent. If not specified, set it to
393486 improve task performance. (default: :obj:`None`)
394487 critic_kwargs (Dict, optional): Additional arguments to
395488 pass to the critic. (default: :obj:`None`)
@@ -465,20 +558,28 @@ def init_chat(self, init_msg_content: Optional[str] = None) -> BaseMessage:
465558 BaseMessage: A single `BaseMessage` representing the initial
466559 message.
467560 """
468- self .assistant_agent .reset ()
469- self .user_agent .reset ()
561+ if self .assistant_agent is not None :
562+ self .assistant_agent .reset ()
563+ if self .user_agent is not None :
564+ self .user_agent .reset ()
470565 default_init_msg_content = (
471566 "Now start to give me instructions one by one. "
472567 "Only reply with Instruction and Input."
473568 )
474- if init_msg_content is None :
475- init_msg_content = default_init_msg_content
569+ final_init_msg_content = init_msg_content or default_init_msg_content
476570
477571 # Initialize a message sent by the assistant
572+ assistant_role_name = "assistant"
573+ if self .assistant_sys_msg is not None and hasattr (
574+ self .assistant_sys_msg , 'role_name'
575+ ):
576+ role_name_attr = getattr (self .assistant_sys_msg , 'role_name' , None )
577+ if role_name_attr is not None :
578+ assistant_role_name = str (role_name_attr )
579+
478580 init_msg = BaseMessage .make_assistant_message (
479- role_name = getattr (self .assistant_sys_msg , 'role_name' , None )
480- or "assistant" ,
481- content = init_msg_content ,
581+ role_name = assistant_role_name ,
582+ content = final_init_msg_content ,
482583 )
483584
484585 return init_msg
@@ -501,20 +602,28 @@ async def ainit_chat(
501602 """
502603 # Currently, reset() is synchronous, but if it becomes async in the
503604 # future, we can await it here
504- self .assistant_agent .reset ()
505- self .user_agent .reset ()
605+ if self .assistant_agent is not None :
606+ self .assistant_agent .reset ()
607+ if self .user_agent is not None :
608+ self .user_agent .reset ()
506609 default_init_msg_content = (
507610 "Now start to give me instructions one by one. "
508611 "Only reply with Instruction and Input."
509612 )
510- if init_msg_content is None :
511- init_msg_content = default_init_msg_content
613+ final_init_msg_content = init_msg_content or default_init_msg_content
512614
513615 # Initialize a message sent by the assistant
616+ assistant_role_name = "assistant"
617+ if self .assistant_sys_msg is not None and hasattr (
618+ self .assistant_sys_msg , 'role_name'
619+ ):
620+ role_name_attr = getattr (self .assistant_sys_msg , 'role_name' , None )
621+ if role_name_attr is not None :
622+ assistant_role_name = str (role_name_attr )
623+
514624 init_msg = BaseMessage .make_assistant_message (
515- role_name = getattr (self .assistant_sys_msg , 'role_name' , None )
516- or "assistant" ,
517- content = init_msg_content ,
625+ role_name = assistant_role_name ,
626+ content = final_init_msg_content ,
518627 )
519628
520629 return init_msg
@@ -544,6 +653,11 @@ def step(
544653 user agent terminated the conversation, and any additional user
545654 information.
546655 """
656+ if self .user_agent is None :
657+ raise ValueError ("User agent is not initialized" )
658+ if self .assistant_agent is None :
659+ raise ValueError ("Assistant agent is not initialized" )
660+
547661 user_response = self .user_agent .step (assistant_msg )
548662 if user_response .terminated or user_response .msgs is None :
549663 return (
@@ -620,6 +734,11 @@ async def astep(
620734 user agent terminated the conversation, and any additional user
621735 information.
622736 """
737+ if self .user_agent is None :
738+ raise ValueError ("User agent is not initialized" )
739+ if self .assistant_agent is None :
740+ raise ValueError ("Assistant agent is not initialized" )
741+
623742 user_response = await self .user_agent .astep (assistant_msg )
624743 if user_response .terminated or user_response .msgs is None :
625744 return (
@@ -682,6 +801,10 @@ def clone(
682801 RolePlaying: A new instance of RolePlaying with the same
683802 configuration.
684803 """
804+ if self .assistant_agent is None or self .user_agent is None :
805+ raise ValueError (
806+ "Cannot clone: assistant_agent or user_agent is None"
807+ )
685808
686809 new_instance = RolePlaying (
687810 assistant_role_name = self .assistant_agent .role_name ,
0 commit comments