Skip to content

Commit ac6e088

Browse files
feat(adk): ut for supervisor supports concurrent transfer
Change-Id: I64654953e73696105bb0437d27adf707c35fe941
1 parent d6db665 commit ac6e088

File tree

4 files changed

+308
-5
lines changed

4 files changed

+308
-5
lines changed

adk/deterministic_transfer.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ func (a *agentWithDeterministicTransferTo) Name(ctx context.Context) string {
5353
func (a *agentWithDeterministicTransferTo) Run(ctx context.Context,
5454
input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] {
5555

56+
if _, ok := a.agent.(*flowAgent); ok {
57+
ctx = ClearRunCtx(ctx)
58+
}
59+
5660
aIter := a.agent.Run(ctx, input, options...)
5761

5862
iterator, generator := NewAsyncIteratorPair[*AgentEvent]()
@@ -77,6 +81,10 @@ func (a *resumableAgentWithDeterministicTransferTo) Name(ctx context.Context) st
7781
func (a *resumableAgentWithDeterministicTransferTo) Run(ctx context.Context,
7882
input *AgentInput, options ...AgentRunOption) *AsyncIterator[*AgentEvent] {
7983

84+
if _, ok := a.agent.(*flowAgent); ok {
85+
ctx = ClearRunCtx(ctx)
86+
}
87+
8088
aIter := a.agent.Run(ctx, input, options...)
8189

8290
iterator, generator := NewAsyncIteratorPair[*AgentEvent]()

adk/flow.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ type HistoryEntry struct {
3838

3939
type HistoryRewriter func(ctx context.Context, entries []*HistoryEntry) ([]Message, error)
4040

41-
type dynamicParallelState struct {
41+
type flowInterruptState struct {
4242
// Maps the destination agent name to the events generated in its lane before interruption.
4343
// This also serves as the source of truth for which lanes need to be resumed.
4444
LaneEvents map[string][]*agentEventWrapper
@@ -71,7 +71,7 @@ func (a *flowAgent) collectLaneEvents(childContexts []context.Context, agentName
7171

7272
// createCompositeInterrupt creates a composite interrupt event with the collected state
7373
func (a *flowAgent) createCompositeInterrupt(ctx context.Context, laneEvents map[string][]*agentEventWrapper, subInterruptSignals []*core.InterruptSignal) *AgentEvent {
74-
state := &dynamicParallelState{
74+
state := &flowInterruptState{
7575
LaneEvents: laneEvents,
7676
}
7777

@@ -85,7 +85,7 @@ func (a *flowAgent) createCompositeInterrupt(ctx context.Context, laneEvents map
8585
}
8686

8787
func init() {
88-
schema.RegisterName[*dynamicParallelState]("eino_adk_dynamic_parallel_state")
88+
schema.RegisterName[*flowInterruptState]("eino_adk_dynamic_parallel_state")
8989
}
9090

9191
type flowAgent struct {
@@ -372,7 +372,7 @@ func (a *flowAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...AgentR
372372
if info.WasInterrupted {
373373
// Check if we need to resume concurrent transfers
374374
if info.InterruptState != nil {
375-
state, ok := info.InterruptState.(*dynamicParallelState)
375+
state, ok := info.InterruptState.(*flowInterruptState)
376376
if ok {
377377
// Delegate to resumeConcurrentLanes which will handle the type assertion
378378
return a.resumeConcurrentLanes(ctx, state, info, opts...)
@@ -555,7 +555,7 @@ func (a *flowAgent) runConcurrentLanes(
555555
// resumeConcurrentLanes resumes execution after a concurrent transfer interruption
556556
func (a *flowAgent) resumeConcurrentLanes(
557557
ctx context.Context,
558-
state *dynamicParallelState,
558+
state *flowInterruptState,
559559
info *ResumeInfo,
560560
opts ...AgentRunOption) *AsyncIterator[*AgentEvent] {
561561

adk/flow_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
/*
2+
* Copyright 2025 CloudWeGo Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
117
package adk
218

319
import (

adk/prebuilt/supervisor/supervisor_test.go

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ package supervisor
1818

1919
import (
2020
"context"
21+
"fmt"
2122
"testing"
2223

2324
"github.com/stretchr/testify/assert"
2425
"go.uber.org/mock/gomock"
2526

2627
"github.com/cloudwego/eino/adk"
28+
"github.com/cloudwego/eino/compose"
2729
mockAdk "github.com/cloudwego/eino/internal/mock/adk"
2830
"github.com/cloudwego/eino/schema"
2931
)
@@ -167,3 +169,280 @@ func TestNewSupervisor(t *testing.T) {
167169
assert.Equal(t, schema.Assistant, event.Output.MessageOutput.Role)
168170
assert.Equal(t, finishMsg.Content, event.Output.MessageOutput.Message.Content)
169171
}
172+
173+
// mockSupervisor is a simple supervisor that performs concurrent transfers
174+
type mockSupervisor struct {
175+
name string
176+
targets []string
177+
times int
178+
}
179+
180+
func (a *mockSupervisor) Name(_ context.Context) string {
181+
return a.name
182+
}
183+
184+
func (a *mockSupervisor) Description(_ context.Context) string {
185+
return "mock supervisor agent"
186+
}
187+
188+
func (a *mockSupervisor) Run(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
189+
iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]()
190+
if a.times > 0 {
191+
gen.Send(adk.EventFromMessage(schema.AssistantMessage("job done", nil), nil, schema.Assistant, ""))
192+
gen.Close()
193+
return iter
194+
}
195+
196+
a.times++
197+
198+
// Create assistant message with tool call for concurrent transfer
199+
toolCall := schema.ToolCall{
200+
ID: "transfer-tool-call",
201+
Type: "function",
202+
Function: schema.FunctionCall{
203+
Name: adk.TransferToAgentToolName,
204+
Arguments: `{"agent_names":["` + a.targets[0] + `","` + a.targets[1] + `"]}`,
205+
},
206+
}
207+
assistantMsg := schema.AssistantMessage("", []schema.ToolCall{toolCall})
208+
gen.Send(adk.EventFromMessage(assistantMsg, nil, schema.Assistant, ""))
209+
210+
// Create tool message for the transfer
211+
toolMsg := schema.ToolMessage(fmt.Sprintf("Successfully transfered to agents %v", a.targets), toolCall.ID,
212+
schema.WithToolName(adk.TransferToAgentToolName))
213+
transferEvent := adk.EventFromMessage(toolMsg, nil, schema.Tool, toolMsg.ToolName)
214+
transferEvent.Action = &adk.AgentAction{
215+
ConcurrentTransferToAgent: &adk.ConcurrentTransferToAgentAction{
216+
DestAgentNames: a.targets,
217+
},
218+
}
219+
gen.Send(transferEvent)
220+
gen.Close()
221+
222+
return iter
223+
}
224+
225+
// mockSimpleAgent is a basic agent that returns a simple message
226+
type mockSimpleAgent struct {
227+
name string
228+
msg string
229+
}
230+
231+
func (a *mockSimpleAgent) Name(_ context.Context) string {
232+
return a.name
233+
}
234+
235+
func (a *mockSimpleAgent) Description(_ context.Context) string {
236+
return "mock simple agent"
237+
}
238+
239+
func (a *mockSimpleAgent) Run(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
240+
iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]()
241+
gen.Send(adk.EventFromMessage(schema.AssistantMessage(a.msg, nil), nil, schema.Assistant, ""))
242+
gen.Close()
243+
return iter
244+
}
245+
246+
// mockInterruptibleResumableAgent interrupts on first run and resumes on second
247+
type mockInterruptibleResumableAgent struct {
248+
name string
249+
t *testing.T
250+
}
251+
252+
func (a *mockInterruptibleResumableAgent) Name(_ context.Context) string {
253+
return a.name
254+
}
255+
256+
func (a *mockInterruptibleResumableAgent) Description(_ context.Context) string {
257+
return "mock interruptible/resumable agent"
258+
}
259+
260+
func (a *mockInterruptibleResumableAgent) Run(ctx context.Context, input *adk.AgentInput, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
261+
iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]()
262+
gen.Send(adk.EventFromMessage(schema.AssistantMessage("I will interrupt", nil), nil, schema.Assistant, ""))
263+
gen.Send(adk.Interrupt(ctx, "interrupt data"))
264+
gen.Close()
265+
return iter
266+
}
267+
268+
func (a *mockInterruptibleResumableAgent) Resume(ctx context.Context, info *adk.ResumeInfo, opts ...adk.AgentRunOption) *adk.AsyncIterator[*adk.AgentEvent] {
269+
assert.True(a.t, info.WasInterrupted)
270+
271+
// Check if this agent is the target of the resume
272+
isResumeTarget, hasData, data := compose.GetResumeContext[string](ctx)
273+
if isResumeTarget && hasData {
274+
assert.Equal(a.t, "resume data", data)
275+
}
276+
277+
iter, gen := adk.NewAsyncIteratorPair[*adk.AgentEvent]()
278+
gen.Send(adk.EventFromMessage(schema.AssistantMessage("I have resumed", nil), nil, schema.Assistant, ""))
279+
gen.Close()
280+
return iter
281+
}
282+
283+
// TestNestedSupervisor_ConcurrentTransfer_WithInterruptAndResume tests a complex scenario:
284+
// - Nested supervisor hierarchy
285+
// - Concurrent transfers at two levels
286+
// - Interrupt at grandchild level
287+
// - Resume with targeted data
288+
func TestNestedSupervisor_ConcurrentTransfer_WithInterruptAndResume(t *testing.T) {
289+
ctx := context.Background()
290+
291+
// 1. Define the agent hierarchy
292+
grandChild1 := &mockSimpleAgent{name: "GrandChild1", msg: "GrandChild1 reporting"}
293+
grandChild2 := &mockInterruptibleResumableAgent{name: "GrandChild2", t: t}
294+
subSupervisor := &mockSupervisor{name: "SubSupervisor", targets: []string{"GrandChild1", "GrandChild2"}}
295+
296+
subAgent1 := &mockSimpleAgent{name: "SubAgent1", msg: "SubAgent1 reporting"}
297+
superSupervisor := &mockSupervisor{name: "SuperSupervisor", targets: []string{"SubAgent1", "SubSupervisor"}}
298+
299+
// 2. Build the nested supervisor hierarchy
300+
nestedSupervisor, err := New(ctx, &Config{
301+
Supervisor: subSupervisor,
302+
SubAgents: []adk.Agent{grandChild1, grandChild2},
303+
})
304+
assert.NoError(t, err)
305+
306+
topSupervisor, err := New(ctx, &Config{
307+
Supervisor: superSupervisor,
308+
SubAgents: []adk.Agent{subAgent1, nestedSupervisor},
309+
})
310+
assert.NoError(t, err)
311+
312+
// 3. Run the top-level supervisor and expect interrupt
313+
runner := adk.NewRunner(ctx, adk.RunnerConfig{Agent: topSupervisor, CheckPointStore: newMyStore()})
314+
aIter := runner.Run(ctx, []adk.Message{schema.UserMessage("start")},
315+
adk.WithCheckPointID("test-checkpoint"))
316+
317+
var finalEvent *adk.AgentEvent
318+
var events []*adk.AgentEvent
319+
for event, ok := aIter.Next(); ok; event, ok = aIter.Next() {
320+
var role, content string
321+
var toolCalls []schema.ToolCall
322+
if event.Output != nil && event.Output.MessageOutput != nil {
323+
role = string(event.Output.MessageOutput.Role)
324+
content = event.Output.MessageOutput.Message.Content
325+
toolCalls = event.Output.MessageOutput.Message.ToolCalls
326+
}
327+
t.Logf("Event: Agent=%s, Role=%s, Content=%s, ToolCalls= %v, Interrupted=%v, transfer=%v, concurrentTransfer=%v",
328+
event.AgentName, role, content, toolCalls,
329+
event.Action != nil && event.Action.Interrupted != nil,
330+
event.Action != nil && event.Action.TransferToAgent != nil,
331+
event.Action != nil && event.Action.ConcurrentTransferToAgent != nil)
332+
333+
events = append(events, event)
334+
if event.Action != nil && event.Action.Interrupted != nil {
335+
finalEvent = event
336+
}
337+
}
338+
339+
if finalEvent == nil {
340+
t.Fatal("Should have received an interrupt event")
341+
}
342+
assert.Equal(t, "SuperSupervisor", finalEvent.AgentName, "Interrupt should propagate to top supervisor")
343+
344+
// 4. Verify the execution sequence - handle complex concurrent execution
345+
assert.Equal(t, 8, len(events), "Should have 8 events in initial execution")
346+
347+
// Check SuperSupervisor concurrent transfer (events 0-1)
348+
assert.Equal(t, "SuperSupervisor", events[0].AgentName)
349+
assert.Equal(t, schema.Assistant, events[0].Output.MessageOutput.Role)
350+
assert.True(t, events[1].Action != nil && events[1].Action.ConcurrentTransferToAgent != nil)
351+
352+
// Event 2: SubAgent1 completes (first concurrent branch)
353+
assert.Equal(t, "SubAgent1", events[2].AgentName)
354+
assert.Equal(t, "SubAgent1 reporting", events[2].Output.MessageOutput.Message.Content)
355+
356+
// Check SubSupervisor concurrent transfer (events 3-4)
357+
assert.Equal(t, "SubSupervisor", events[3].AgentName)
358+
assert.Equal(t, schema.Assistant, events[3].Output.MessageOutput.Role)
359+
assert.True(t, events[4].Action != nil && events[4].Action.ConcurrentTransferToAgent != nil)
360+
361+
// Events 5-6: GrandChild2 and GrandChild1 (order can vary due to concurrency)
362+
agentEvents := make(map[string]string)
363+
for i := 5; i <= 6; i++ {
364+
if events[i].Output != nil && events[i].Output.MessageOutput != nil {
365+
agentEvents[events[i].AgentName] = events[i].Output.MessageOutput.Message.Content
366+
}
367+
}
368+
369+
// Verify both grandchild agents completed/interrupted as expected
370+
assert.Equal(t, "GrandChild1 reporting", agentEvents["GrandChild1"])
371+
assert.Equal(t, "I will interrupt", agentEvents["GrandChild2"])
372+
373+
// Check interrupt (event 7)
374+
assert.Equal(t, "SuperSupervisor", events[7].AgentName)
375+
assert.True(t, events[7].Action != nil && events[7].Action.Interrupted != nil)
376+
377+
// 5. Resume the execution with targeted resume data
378+
resumeIter, err := runner.TargetedResume(ctx, "test-checkpoint", map[string]any{
379+
"GrandChild2": "resume data",
380+
})
381+
assert.NoError(t, err)
382+
383+
// 6. Verify the resume flow completes successfully
384+
var resumeEvents []*adk.AgentEvent
385+
for event, ok := resumeIter.Next(); ok; event, ok = resumeIter.Next() {
386+
var role, content string
387+
var toolCalls []schema.ToolCall
388+
if event.Output != nil && event.Output.MessageOutput != nil {
389+
role = string(event.Output.MessageOutput.Role)
390+
content = event.Output.MessageOutput.Message.Content
391+
toolCalls = event.Output.MessageOutput.Message.ToolCalls
392+
}
393+
t.Logf("Resume Event: Agent=%s, Role=%s, Content=%s, ToolCalls= %v, Interrupted=%v, transfer=%v, concurrentTransfer=%v",
394+
event.AgentName, role, content, toolCalls,
395+
event.Action != nil && event.Action.Interrupted != nil,
396+
event.Action != nil && event.Action.TransferToAgent != nil,
397+
event.Action != nil && event.Action.ConcurrentTransferToAgent != nil)
398+
resumeEvents = append(resumeEvents, event)
399+
}
400+
401+
assert.Equal(t, 7, len(resumeEvents), "Should have 7 events in resume execution")
402+
403+
// Check GrandChild2 resume
404+
assert.Equal(t, "GrandChild2", resumeEvents[0].AgentName)
405+
assert.Equal(t, "I have resumed", resumeEvents[0].Output.MessageOutput.Message.Content)
406+
407+
// Check SubSupervisor self-return and completion
408+
assert.Equal(t, "SubSupervisor", resumeEvents[1].AgentName)
409+
assert.Equal(t, schema.Assistant, resumeEvents[1].Output.MessageOutput.Role)
410+
411+
assert.Equal(t, "SubSupervisor", resumeEvents[2].AgentName)
412+
assert.True(t, resumeEvents[2].Action != nil && resumeEvents[2].Action.TransferToAgent != nil)
413+
assert.Equal(t, "SubSupervisor", resumeEvents[2].Action.TransferToAgent.DestAgentName)
414+
415+
assert.Equal(t, "SubSupervisor", resumeEvents[3].AgentName)
416+
assert.Equal(t, "job done", resumeEvents[3].Output.MessageOutput.Message.Content)
417+
418+
// Check SuperSupervisor self-return and completion
419+
assert.Equal(t, "SuperSupervisor", resumeEvents[4].AgentName)
420+
assert.Equal(t, schema.Assistant, resumeEvents[4].Output.MessageOutput.Role)
421+
422+
assert.Equal(t, "SuperSupervisor", resumeEvents[5].AgentName)
423+
assert.True(t, resumeEvents[5].Action != nil && resumeEvents[5].Action.TransferToAgent != nil)
424+
assert.Equal(t, "SuperSupervisor", resumeEvents[5].Action.TransferToAgent.DestAgentName)
425+
426+
assert.Equal(t, "SuperSupervisor", resumeEvents[6].AgentName)
427+
assert.Equal(t, "job done", resumeEvents[6].Output.MessageOutput.Message.Content)
428+
}
429+
430+
func newMyStore() *myStore {
431+
return &myStore{
432+
m: map[string][]byte{},
433+
}
434+
}
435+
436+
type myStore struct {
437+
m map[string][]byte
438+
}
439+
440+
func (m *myStore) Set(_ context.Context, key string, value []byte) error {
441+
m.m[key] = value
442+
return nil
443+
}
444+
445+
func (m *myStore) Get(_ context.Context, key string) ([]byte, bool, error) {
446+
v, ok := m.m[key]
447+
return v, ok, nil
448+
}

0 commit comments

Comments
 (0)