@@ -18,12 +18,14 @@ package supervisor
1818
1919import (
2020 "context"
21+ "fmt"
2122 "testing"
2223
2324 "github.com/stretchr/testify/assert"
2425 "go.uber.org/mock/gomock"
2526
2627 "github.com/cloudwego/eino/adk"
28+ "github.com/cloudwego/eino/compose"
2729 mockAdk "github.com/cloudwego/eino/internal/mock/adk"
2830 "github.com/cloudwego/eino/schema"
2931)
@@ -167,3 +169,280 @@ func TestNewSupervisor(t *testing.T) {
167169 assert .Equal (t , schema .Assistant , event .Output .MessageOutput .Role )
168170 assert .Equal (t , finishMsg .Content , event .Output .MessageOutput .Message .Content )
169171}
172+
173+ // mockSupervisor is a simple supervisor that performs concurrent transfers
174+ type mockSupervisor struct {
175+ name string
176+ targets []string
177+ times int
178+ }
179+
180+ func (a * mockSupervisor ) Name (_ context.Context ) string {
181+ return a .name
182+ }
183+
184+ func (a * mockSupervisor ) Description (_ context.Context ) string {
185+ return "mock supervisor agent"
186+ }
187+
188+ func (a * mockSupervisor ) Run (ctx context.Context , input * adk.AgentInput , opts ... adk.AgentRunOption ) * adk.AsyncIterator [* adk.AgentEvent ] {
189+ iter , gen := adk .NewAsyncIteratorPair [* adk.AgentEvent ]()
190+ if a .times > 0 {
191+ gen .Send (adk .EventFromMessage (schema .AssistantMessage ("job done" , nil ), nil , schema .Assistant , "" ))
192+ gen .Close ()
193+ return iter
194+ }
195+
196+ a .times ++
197+
198+ // Create assistant message with tool call for concurrent transfer
199+ toolCall := schema.ToolCall {
200+ ID : "transfer-tool-call" ,
201+ Type : "function" ,
202+ Function : schema.FunctionCall {
203+ Name : adk .TransferToAgentToolName ,
204+ Arguments : `{"agent_names":["` + a .targets [0 ] + `","` + a .targets [1 ] + `"]}` ,
205+ },
206+ }
207+ assistantMsg := schema .AssistantMessage ("" , []schema.ToolCall {toolCall })
208+ gen .Send (adk .EventFromMessage (assistantMsg , nil , schema .Assistant , "" ))
209+
210+ // Create tool message for the transfer
211+ toolMsg := schema .ToolMessage (fmt .Sprintf ("Successfully transfered to agents %v" , a .targets ), toolCall .ID ,
212+ schema .WithToolName (adk .TransferToAgentToolName ))
213+ transferEvent := adk .EventFromMessage (toolMsg , nil , schema .Tool , toolMsg .ToolName )
214+ transferEvent .Action = & adk.AgentAction {
215+ ConcurrentTransferToAgent : & adk.ConcurrentTransferToAgentAction {
216+ DestAgentNames : a .targets ,
217+ },
218+ }
219+ gen .Send (transferEvent )
220+ gen .Close ()
221+
222+ return iter
223+ }
224+
225+ // mockSimpleAgent is a basic agent that returns a simple message
226+ type mockSimpleAgent struct {
227+ name string
228+ msg string
229+ }
230+
231+ func (a * mockSimpleAgent ) Name (_ context.Context ) string {
232+ return a .name
233+ }
234+
235+ func (a * mockSimpleAgent ) Description (_ context.Context ) string {
236+ return "mock simple agent"
237+ }
238+
239+ func (a * mockSimpleAgent ) Run (ctx context.Context , input * adk.AgentInput , opts ... adk.AgentRunOption ) * adk.AsyncIterator [* adk.AgentEvent ] {
240+ iter , gen := adk .NewAsyncIteratorPair [* adk.AgentEvent ]()
241+ gen .Send (adk .EventFromMessage (schema .AssistantMessage (a .msg , nil ), nil , schema .Assistant , "" ))
242+ gen .Close ()
243+ return iter
244+ }
245+
246+ // mockInterruptibleResumableAgent interrupts on first run and resumes on second
247+ type mockInterruptibleResumableAgent struct {
248+ name string
249+ t * testing.T
250+ }
251+
252+ func (a * mockInterruptibleResumableAgent ) Name (_ context.Context ) string {
253+ return a .name
254+ }
255+
256+ func (a * mockInterruptibleResumableAgent ) Description (_ context.Context ) string {
257+ return "mock interruptible/resumable agent"
258+ }
259+
260+ func (a * mockInterruptibleResumableAgent ) Run (ctx context.Context , input * adk.AgentInput , opts ... adk.AgentRunOption ) * adk.AsyncIterator [* adk.AgentEvent ] {
261+ iter , gen := adk .NewAsyncIteratorPair [* adk.AgentEvent ]()
262+ gen .Send (adk .EventFromMessage (schema .AssistantMessage ("I will interrupt" , nil ), nil , schema .Assistant , "" ))
263+ gen .Send (adk .Interrupt (ctx , "interrupt data" ))
264+ gen .Close ()
265+ return iter
266+ }
267+
268+ func (a * mockInterruptibleResumableAgent ) Resume (ctx context.Context , info * adk.ResumeInfo , opts ... adk.AgentRunOption ) * adk.AsyncIterator [* adk.AgentEvent ] {
269+ assert .True (a .t , info .WasInterrupted )
270+
271+ // Check if this agent is the target of the resume
272+ isResumeTarget , hasData , data := compose.GetResumeContext [string ](ctx )
273+ if isResumeTarget && hasData {
274+ assert .Equal (a .t , "resume data" , data )
275+ }
276+
277+ iter , gen := adk .NewAsyncIteratorPair [* adk.AgentEvent ]()
278+ gen .Send (adk .EventFromMessage (schema .AssistantMessage ("I have resumed" , nil ), nil , schema .Assistant , "" ))
279+ gen .Close ()
280+ return iter
281+ }
282+
283+ // TestNestedSupervisor_ConcurrentTransfer_WithInterruptAndResume tests a complex scenario:
284+ // - Nested supervisor hierarchy
285+ // - Concurrent transfers at two levels
286+ // - Interrupt at grandchild level
287+ // - Resume with targeted data
288+ func TestNestedSupervisor_ConcurrentTransfer_WithInterruptAndResume (t * testing.T ) {
289+ ctx := context .Background ()
290+
291+ // 1. Define the agent hierarchy
292+ grandChild1 := & mockSimpleAgent {name : "GrandChild1" , msg : "GrandChild1 reporting" }
293+ grandChild2 := & mockInterruptibleResumableAgent {name : "GrandChild2" , t : t }
294+ subSupervisor := & mockSupervisor {name : "SubSupervisor" , targets : []string {"GrandChild1" , "GrandChild2" }}
295+
296+ subAgent1 := & mockSimpleAgent {name : "SubAgent1" , msg : "SubAgent1 reporting" }
297+ superSupervisor := & mockSupervisor {name : "SuperSupervisor" , targets : []string {"SubAgent1" , "SubSupervisor" }}
298+
299+ // 2. Build the nested supervisor hierarchy
300+ nestedSupervisor , err := New (ctx , & Config {
301+ Supervisor : subSupervisor ,
302+ SubAgents : []adk.Agent {grandChild1 , grandChild2 },
303+ })
304+ assert .NoError (t , err )
305+
306+ topSupervisor , err := New (ctx , & Config {
307+ Supervisor : superSupervisor ,
308+ SubAgents : []adk.Agent {subAgent1 , nestedSupervisor },
309+ })
310+ assert .NoError (t , err )
311+
312+ // 3. Run the top-level supervisor and expect interrupt
313+ runner := adk .NewRunner (ctx , adk.RunnerConfig {Agent : topSupervisor , CheckPointStore : newMyStore ()})
314+ aIter := runner .Run (ctx , []adk.Message {schema .UserMessage ("start" )},
315+ adk .WithCheckPointID ("test-checkpoint" ))
316+
317+ var finalEvent * adk.AgentEvent
318+ var events []* adk.AgentEvent
319+ for event , ok := aIter .Next (); ok ; event , ok = aIter .Next () {
320+ var role , content string
321+ var toolCalls []schema.ToolCall
322+ if event .Output != nil && event .Output .MessageOutput != nil {
323+ role = string (event .Output .MessageOutput .Role )
324+ content = event .Output .MessageOutput .Message .Content
325+ toolCalls = event .Output .MessageOutput .Message .ToolCalls
326+ }
327+ t .Logf ("Event: Agent=%s, Role=%s, Content=%s, ToolCalls= %v, Interrupted=%v, transfer=%v, concurrentTransfer=%v" ,
328+ event .AgentName , role , content , toolCalls ,
329+ event .Action != nil && event .Action .Interrupted != nil ,
330+ event .Action != nil && event .Action .TransferToAgent != nil ,
331+ event .Action != nil && event .Action .ConcurrentTransferToAgent != nil )
332+
333+ events = append (events , event )
334+ if event .Action != nil && event .Action .Interrupted != nil {
335+ finalEvent = event
336+ }
337+ }
338+
339+ if finalEvent == nil {
340+ t .Fatal ("Should have received an interrupt event" )
341+ }
342+ assert .Equal (t , "SuperSupervisor" , finalEvent .AgentName , "Interrupt should propagate to top supervisor" )
343+
344+ // 4. Verify the execution sequence - handle complex concurrent execution
345+ assert .Equal (t , 8 , len (events ), "Should have 8 events in initial execution" )
346+
347+ // Check SuperSupervisor concurrent transfer (events 0-1)
348+ assert .Equal (t , "SuperSupervisor" , events [0 ].AgentName )
349+ assert .Equal (t , schema .Assistant , events [0 ].Output .MessageOutput .Role )
350+ assert .True (t , events [1 ].Action != nil && events [1 ].Action .ConcurrentTransferToAgent != nil )
351+
352+ // Event 2: SubAgent1 completes (first concurrent branch)
353+ assert .Equal (t , "SubAgent1" , events [2 ].AgentName )
354+ assert .Equal (t , "SubAgent1 reporting" , events [2 ].Output .MessageOutput .Message .Content )
355+
356+ // Check SubSupervisor concurrent transfer (events 3-4)
357+ assert .Equal (t , "SubSupervisor" , events [3 ].AgentName )
358+ assert .Equal (t , schema .Assistant , events [3 ].Output .MessageOutput .Role )
359+ assert .True (t , events [4 ].Action != nil && events [4 ].Action .ConcurrentTransferToAgent != nil )
360+
361+ // Events 5-6: GrandChild2 and GrandChild1 (order can vary due to concurrency)
362+ agentEvents := make (map [string ]string )
363+ for i := 5 ; i <= 6 ; i ++ {
364+ if events [i ].Output != nil && events [i ].Output .MessageOutput != nil {
365+ agentEvents [events [i ].AgentName ] = events [i ].Output .MessageOutput .Message .Content
366+ }
367+ }
368+
369+ // Verify both grandchild agents completed/interrupted as expected
370+ assert .Equal (t , "GrandChild1 reporting" , agentEvents ["GrandChild1" ])
371+ assert .Equal (t , "I will interrupt" , agentEvents ["GrandChild2" ])
372+
373+ // Check interrupt (event 7)
374+ assert .Equal (t , "SuperSupervisor" , events [7 ].AgentName )
375+ assert .True (t , events [7 ].Action != nil && events [7 ].Action .Interrupted != nil )
376+
377+ // 5. Resume the execution with targeted resume data
378+ resumeIter , err := runner .TargetedResume (ctx , "test-checkpoint" , map [string ]any {
379+ "GrandChild2" : "resume data" ,
380+ })
381+ assert .NoError (t , err )
382+
383+ // 6. Verify the resume flow completes successfully
384+ var resumeEvents []* adk.AgentEvent
385+ for event , ok := resumeIter .Next (); ok ; event , ok = resumeIter .Next () {
386+ var role , content string
387+ var toolCalls []schema.ToolCall
388+ if event .Output != nil && event .Output .MessageOutput != nil {
389+ role = string (event .Output .MessageOutput .Role )
390+ content = event .Output .MessageOutput .Message .Content
391+ toolCalls = event .Output .MessageOutput .Message .ToolCalls
392+ }
393+ t .Logf ("Resume Event: Agent=%s, Role=%s, Content=%s, ToolCalls= %v, Interrupted=%v, transfer=%v, concurrentTransfer=%v" ,
394+ event .AgentName , role , content , toolCalls ,
395+ event .Action != nil && event .Action .Interrupted != nil ,
396+ event .Action != nil && event .Action .TransferToAgent != nil ,
397+ event .Action != nil && event .Action .ConcurrentTransferToAgent != nil )
398+ resumeEvents = append (resumeEvents , event )
399+ }
400+
401+ assert .Equal (t , 7 , len (resumeEvents ), "Should have 7 events in resume execution" )
402+
403+ // Check GrandChild2 resume
404+ assert .Equal (t , "GrandChild2" , resumeEvents [0 ].AgentName )
405+ assert .Equal (t , "I have resumed" , resumeEvents [0 ].Output .MessageOutput .Message .Content )
406+
407+ // Check SubSupervisor self-return and completion
408+ assert .Equal (t , "SubSupervisor" , resumeEvents [1 ].AgentName )
409+ assert .Equal (t , schema .Assistant , resumeEvents [1 ].Output .MessageOutput .Role )
410+
411+ assert .Equal (t , "SubSupervisor" , resumeEvents [2 ].AgentName )
412+ assert .True (t , resumeEvents [2 ].Action != nil && resumeEvents [2 ].Action .TransferToAgent != nil )
413+ assert .Equal (t , "SubSupervisor" , resumeEvents [2 ].Action .TransferToAgent .DestAgentName )
414+
415+ assert .Equal (t , "SubSupervisor" , resumeEvents [3 ].AgentName )
416+ assert .Equal (t , "job done" , resumeEvents [3 ].Output .MessageOutput .Message .Content )
417+
418+ // Check SuperSupervisor self-return and completion
419+ assert .Equal (t , "SuperSupervisor" , resumeEvents [4 ].AgentName )
420+ assert .Equal (t , schema .Assistant , resumeEvents [4 ].Output .MessageOutput .Role )
421+
422+ assert .Equal (t , "SuperSupervisor" , resumeEvents [5 ].AgentName )
423+ assert .True (t , resumeEvents [5 ].Action != nil && resumeEvents [5 ].Action .TransferToAgent != nil )
424+ assert .Equal (t , "SuperSupervisor" , resumeEvents [5 ].Action .TransferToAgent .DestAgentName )
425+
426+ assert .Equal (t , "SuperSupervisor" , resumeEvents [6 ].AgentName )
427+ assert .Equal (t , "job done" , resumeEvents [6 ].Output .MessageOutput .Message .Content )
428+ }
429+
430+ func newMyStore () * myStore {
431+ return & myStore {
432+ m : map [string ][]byte {},
433+ }
434+ }
435+
436+ type myStore struct {
437+ m map [string ][]byte
438+ }
439+
440+ func (m * myStore ) Set (_ context.Context , key string , value []byte ) error {
441+ m .m [key ] = value
442+ return nil
443+ }
444+
445+ func (m * myStore ) Get (_ context.Context , key string ) ([]byte , bool , error ) {
446+ v , ok := m .m [key ]
447+ return v , ok , nil
448+ }
0 commit comments