Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions adk/chatmodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,10 @@ func (a *ChatModelAgent) getRunFunc(ctx context.Context) (context.Context, runFu
defaultRun := a.buildRunFunc(ctx)
bc := a.exeCtx

if bc == nil {
return ctx, defaultRun, bc, nil
}

if len(a.handlers) == 0 {
runtimeBC := &execContext{
instruction: bc.instruction,
Expand Down Expand Up @@ -974,9 +978,12 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age

co := getComposeOptions(opts)
co = append(co, compose.WithCheckPointID(bridgeCheckpointID))
co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos)))
if bc.toolUpdated {
co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...)))

if bc != nil {
co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos)))
if bc.toolUpdated {
co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...)))
}
}

go func() {
Expand All @@ -990,7 +997,17 @@ func (a *ChatModelAgent) Run(ctx context.Context, input *AgentInput, opts ...Age
generator.Close()
}()

run(ctx, input, generator, newBridgeStore(), bc.instruction, bc.returnDirectly, co...)
var (
instruction string
returnDirectly map[string]bool
)

if bc != nil {
instruction = bc.instruction
returnDirectly = bc.returnDirectly
}

run(ctx, input, generator, newBridgeStore(), instruction, returnDirectly, co...)
}()

return iterator
Expand All @@ -1010,9 +1027,12 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A

co := getComposeOptions(opts)
co = append(co, compose.WithCheckPointID(bridgeCheckpointID))
co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos)))
if bc.toolUpdated {
co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...)))

if bc != nil {
co = append(co, compose.WithChatModelOption(model.WithTools(bc.toolInfos)))
if bc.toolUpdated {
co = append(co, compose.WithToolsNodeOption(compose.WithToolList(bc.toolsNodeConf.Tools...)))
}
}

if info.InterruptState == nil {
Expand Down Expand Up @@ -1057,8 +1077,18 @@ func (a *ChatModelAgent) Resume(ctx context.Context, info *ResumeInfo, opts ...A
generator.Close()
}()

var (
instruction string
returnDirectly map[string]bool
)

if bc != nil {
instruction = bc.instruction
returnDirectly = bc.returnDirectly
}

run(ctx, &AgentInput{EnableStreaming: info.EnableStreaming}, generator,
newResumeBridgeStore(stateByte), bc.instruction, bc.returnDirectly, co...)
newResumeBridgeStore(stateByte), instruction, returnDirectly, co...)
}()

return iterator
Expand Down
81 changes: 81 additions & 0 deletions adk/chatmodel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1296,3 +1296,84 @@ func testToolOption(value string) tool.Option {
o.value = value
})
}

type errorTool struct {
infoErr error
}

func (e *errorTool) Info(_ context.Context) (*schema.ToolInfo, error) {
return nil, e.infoErr
}

func (e *errorTool) InvokableRun(_ context.Context, _ string, _ ...tool.Option) (string, error) {
return "", nil
}

func TestChatModelAgent_PrepareExecContextError(t *testing.T) {
t.Run("Run_WithToolInfoError_ReturnsError", func(t *testing.T) {
ctx := context.Background()

ctrl := gomock.NewController(t)
cm := mockModel.NewMockToolCallingChatModel(ctrl)

expectedErr := errors.New("tool info error")
errTool := &errorTool{infoErr: expectedErr}

agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
Name: "TestAgent",
Description: "Test agent",
Model: cm,
ToolsConfig: ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{
Tools: []tool.BaseTool{errTool},
},
},
})
assert.NoError(t, err)

iter := agent.Run(ctx, &AgentInput{Messages: []Message{schema.UserMessage("test")}})

event, ok := iter.Next()
assert.True(t, ok)
assert.NotNil(t, event.Err)
assert.Contains(t, event.Err.Error(), "tool info error")

_, ok = iter.Next()
assert.False(t, ok)
})

t.Run("Resume_WithToolInfoError_ReturnsError", func(t *testing.T) {
ctx := context.Background()

ctrl := gomock.NewController(t)
cm := mockModel.NewMockToolCallingChatModel(ctrl)

expectedErr := errors.New("tool info error for resume")
errTool := &errorTool{infoErr: expectedErr}

agent, err := NewChatModelAgent(ctx, &ChatModelAgentConfig{
Name: "TestAgent",
Description: "Test agent",
Model: cm,
ToolsConfig: ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{
Tools: []tool.BaseTool{errTool},
},
},
})
assert.NoError(t, err)

iter := agent.Resume(ctx, &ResumeInfo{
InterruptState: []byte("dummy"),
EnableStreaming: false,
})

event, ok := iter.Next()
assert.True(t, ok)
assert.NotNil(t, event.Err)
assert.Contains(t, event.Err.Error(), "tool info error for resume")

_, ok = iter.Next()
assert.False(t, ok)
})
}