Skip to content

Commit ffe8b84

Browse files
authored
Merge pull request #1548 from dgageot/wrappers
Work on Tool Wrappers
2 parents 96491c4 + 6be42ae commit ffe8b84

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+489
-300
lines changed

pkg/agent/agent.go

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,28 @@ import (
1515

1616
// Agent represents an AI agent
1717
type Agent struct {
18-
name string
19-
description string
20-
welcomeMessage string
21-
instruction string
22-
toolsets []*StartableToolSet
23-
models []provider.Provider
24-
modelOverrides atomic.Pointer[[]provider.Provider] // Optional model override(s) set at runtime (supports alloy)
25-
subAgents []*Agent
26-
handoffs []*Agent
27-
parents []*Agent
28-
addDate bool
29-
addEnvironmentInfo bool
30-
maxIterations int
31-
numHistoryItems int
32-
addPromptFiles []string
33-
tools []tools.Tool
34-
commands types.Commands
35-
pendingWarnings []string
36-
skillsEnabled bool
37-
hooks *latest.HooksConfig
38-
thinkingConfigured bool // true if thinking_budget was explicitly set in config
18+
name string
19+
description string
20+
welcomeMessage string
21+
instruction string
22+
toolsets []*tools.StartableToolSet
23+
models []provider.Provider
24+
modelOverrides atomic.Pointer[[]provider.Provider] // Optional model override(s) set at runtime (supports alloy)
25+
subAgents []*Agent
26+
handoffs []*Agent
27+
parents []*Agent
28+
addDate bool
29+
addEnvironmentInfo bool
30+
addDescriptionParameter bool
31+
maxIterations int
32+
numHistoryItems int
33+
addPromptFiles []string
34+
tools []tools.Tool
35+
commands types.Commands
36+
pendingWarnings []string
37+
skillsEnabled bool
38+
hooks *latest.HooksConfig
39+
thinkingConfigured bool // true if thinking_budget was explicitly set in config
3940
}
4041

4142
// New creates a new agent
@@ -203,6 +204,10 @@ func (a *Agent) Tools(ctx context.Context) ([]tools.Tool, error) {
203204

204205
agentTools = append(agentTools, a.tools...)
205206

207+
if a.addDescriptionParameter {
208+
agentTools = tools.AddDescriptionParameter(agentTools)
209+
}
210+
206211
return agentTools, nil
207212
}
208213

pkg/agent/agent_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@ import (
1414
)
1515

1616
type stubToolSet struct {
17-
tools.BaseToolSet
1817
startErr error
1918
tools []tools.Tool
2019
listErr error
2120
}
2221

22+
// Verify interface compliance
23+
var (
24+
_ tools.ToolSet = (*stubToolSet)(nil)
25+
_ tools.Startable = (*stubToolSet)(nil)
26+
)
27+
2328
func newStubToolSet(startErr error, toolsList []tools.Tool, listErr error) tools.ToolSet {
2429
return &stubToolSet{
2530
startErr: startErr,
@@ -29,6 +34,7 @@ func newStubToolSet(startErr error, toolsList []tools.Tool, listErr error) tools
2934
}
3035

3136
func (s *stubToolSet) Start(context.Context) error { return s.startErr }
37+
func (s *stubToolSet) Stop(context.Context) error { return nil }
3238
func (s *stubToolSet) Tools(context.Context) ([]tools.Tool, error) {
3339
if s.listErr != nil {
3440
return nil, s.listErr

pkg/agent/opts.go

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package agent
22

33
import (
4-
"context"
5-
"sync"
6-
74
"github.com/docker/cagent/pkg/config/latest"
85
"github.com/docker/cagent/pkg/config/types"
96
"github.com/docker/cagent/pkg/model/provider"
@@ -19,11 +16,9 @@ func WithInstruction(instruction string) Opt {
1916
}
2017

2118
func WithToolSets(toolSet ...tools.ToolSet) Opt {
22-
var startableToolSet []*StartableToolSet
19+
var startableToolSet []*tools.StartableToolSet
2320
for _, ts := range toolSet {
24-
startableToolSet = append(startableToolSet, &StartableToolSet{
25-
ToolSet: ts,
26-
})
21+
startableToolSet = append(startableToolSet, tools.NewStartable(ts))
2722
}
2823

2924
return func(a *Agent) {
@@ -88,6 +83,12 @@ func WithAddEnvironmentInfo(addEnvironmentInfo bool) Opt {
8883
}
8984
}
9085

86+
func WithAddDescriptionParameter(addDescriptionParameter bool) Opt {
87+
return func(a *Agent) {
88+
a.addDescriptionParameter = addDescriptionParameter
89+
}
90+
}
91+
9192
func WithAddPromptFiles(addPromptFiles []string) Opt {
9293
return func(a *Agent) {
9394
a.addPromptFiles = addPromptFiles
@@ -139,36 +140,3 @@ func WithThinkingConfigured(configured bool) Opt {
139140
a.thinkingConfigured = configured
140141
}
141142
}
142-
143-
type StartableToolSet struct {
144-
tools.ToolSet
145-
146-
mu sync.Mutex
147-
started bool
148-
}
149-
150-
// IsStarted returns whether the toolset has been successfully started.
151-
func (s *StartableToolSet) IsStarted() bool {
152-
s.mu.Lock()
153-
defer s.mu.Unlock()
154-
return s.started
155-
}
156-
157-
// Start starts the toolset.
158-
//
159-
// It provides single-flight semantics: concurrent callers block until this start
160-
// attempt completes. If this attempt fails, a future call will retry.
161-
func (s *StartableToolSet) Start(ctx context.Context) error {
162-
s.mu.Lock()
163-
defer s.mu.Unlock()
164-
165-
if s.started {
166-
return nil
167-
}
168-
169-
err := s.ToolSet.Start(ctx)
170-
if err == nil {
171-
s.started = true
172-
}
173-
return err
174-
}

pkg/agent/opts_test.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,30 @@ import (
1212
)
1313

1414
type flakyStartToolset struct {
15-
tools.BaseToolSet
1615
calls atomic.Int64
1716
}
1817

18+
// Verify interface compliance
19+
var (
20+
_ tools.ToolSet = (*flakyStartToolset)(nil)
21+
_ tools.Startable = (*flakyStartToolset)(nil)
22+
)
23+
1924
func (f *flakyStartToolset) Start(context.Context) error {
2025
if f.calls.Add(1) == 1 {
2126
return errors.New("no events channel available for elicitation")
2227
}
2328
return nil
2429
}
2530

31+
func (f *flakyStartToolset) Stop(context.Context) error { return nil }
32+
2633
func (f *flakyStartToolset) Tools(context.Context) ([]tools.Tool, error) { return nil, nil }
2734

2835
func TestStartableToolSet_RetriesAfterFailure(t *testing.T) {
2936
ctx := t.Context()
3037
inner := &flakyStartToolset{}
31-
ts := &StartableToolSet{ToolSet: inner}
38+
ts := tools.NewStartable(inner)
3239

3340
err := ts.Start(ctx)
3441
require.Error(t, err)

pkg/app/app.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ func (a *App) ExecuteMCPPrompt(ctx context.Context, promptName string, arguments
188188
}
189189

190190
for _, toolset := range currentAgent.ToolSets() {
191-
if mcpToolset := runtime.UnwrapMCPToolset(toolset); mcpToolset != nil {
191+
if mcpToolset, ok := tools.As[*mcptools.Toolset](toolset); ok {
192192
result, err := mcpToolset.GetPrompt(ctx, promptName, arguments)
193193
if err == nil {
194194
// Convert the MCP result to a string format suitable for the editor

pkg/runtime/runtime.go

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,6 @@ import (
3838
mcptools "github.com/docker/cagent/pkg/tools/mcp"
3939
)
4040

41-
// UnwrapMCPToolset extracts an MCP toolset from a potentially wrapped StartableToolSet.
42-
// Returns the MCP toolset if found, or nil if the toolset is not an MCP toolset.
43-
func UnwrapMCPToolset(toolset tools.ToolSet) *mcptools.Toolset {
44-
var innerToolset tools.ToolSet
45-
if startableTS, ok := toolset.(*agent.StartableToolSet); ok {
46-
innerToolset = startableTS.ToolSet
47-
} else {
48-
innerToolset = toolset
49-
}
50-
51-
if mcpToolset, ok := innerToolset.(*mcptools.Toolset); ok {
52-
return mcpToolset
53-
}
54-
55-
return nil
56-
}
57-
5841
type ResumeType string
5942

6043
// ElicitationResult represents the result of an elicitation request
@@ -459,7 +442,7 @@ func (r *LocalRuntime) CurrentMCPPrompts(ctx context.Context) map[string]mcptool
459442

460443
// Iterate through all toolsets of the current agent
461444
for _, toolset := range currentAgent.ToolSets() {
462-
if mcpToolset := UnwrapMCPToolset(toolset); mcpToolset != nil {
445+
if mcpToolset, ok := tools.As[*mcptools.Toolset](toolset); ok {
463446
slog.Debug("Found MCP toolset", "toolset", mcpToolset)
464447
// Discover prompts from this MCP toolset
465448
mcpPrompts := r.discoverMCPPrompts(ctx, mcpToolset)
@@ -639,7 +622,7 @@ func (r *LocalRuntime) emitToolsProgressively(ctx context.Context, a *agent.Agen
639622
isLast := i == totalToolsets-1
640623

641624
// Start the toolset if needed
642-
if startable, ok := toolset.(*agent.StartableToolSet); ok {
625+
if startable, ok := toolset.(*tools.StartableToolSet); ok {
643626
if !startable.IsStarted() {
644627
if err := startable.Start(ctx); err != nil {
645628
slog.Warn("Toolset start failed; skipping", "agent", a.Name(), "toolset", fmt.Sprintf("%T", startable.ToolSet), "error", err)
@@ -1013,11 +996,11 @@ func (r *LocalRuntime) getTools(ctx context.Context, a *agent.Agent, sessionSpan
1013996
// configureToolsetHandlers sets up elicitation and OAuth handlers for all toolsets of an agent.
1014997
func (r *LocalRuntime) configureToolsetHandlers(a *agent.Agent, events chan Event) {
1015998
for _, toolset := range a.ToolSets() {
1016-
toolset.SetElicitationHandler(r.elicitationHandler)
1017-
toolset.SetOAuthSuccessHandler(func() {
1018-
events <- Authorization(tools.ElicitationActionAccept, r.currentAgent)
1019-
})
1020-
toolset.SetManagedOAuth(r.managedOAuth)
999+
tools.ConfigureHandlers(toolset,
1000+
r.elicitationHandler,
1001+
func() { events <- Authorization(tools.ElicitationActionAccept, r.currentAgent) },
1002+
r.managedOAuth,
1003+
)
10211004
}
10221005
}
10231006

pkg/runtime/runtime_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,17 @@ import (
2929
)
3030

3131
type stubToolSet struct {
32-
tools.BaseToolSet
3332
startErr error
3433
tools []tools.Tool
3534
listErr error
3635
}
3736

37+
// Verify interface compliance
38+
var (
39+
_ tools.ToolSet = (*stubToolSet)(nil)
40+
_ tools.Startable = (*stubToolSet)(nil)
41+
)
42+
3843
func newStubToolSet(startErr error, toolsList []tools.Tool, listErr error) tools.ToolSet {
3944
return &stubToolSet{
4045
startErr: startErr,
@@ -44,6 +49,7 @@ func newStubToolSet(startErr error, toolsList []tools.Tool, listErr error) tools
4449
}
4550

4651
func (s *stubToolSet) Start(context.Context) error { return s.startErr }
52+
func (s *stubToolSet) Stop(context.Context) error { return nil }
4753
func (s *stubToolSet) Tools(context.Context) ([]tools.Tool, error) {
4854
if s.listErr != nil {
4955
return nil, s.listErr

pkg/session/session.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/docker/cagent/pkg/agent"
1212
"github.com/docker/cagent/pkg/chat"
1313
"github.com/docker/cagent/pkg/skills"
14+
"github.com/docker/cagent/pkg/tools"
1415
)
1516

1617
const (
@@ -482,10 +483,10 @@ func buildInvariantSystemMessages(a *agent.Agent) []chat.Message {
482483
}
483484

484485
for _, toolSet := range a.ToolSets() {
485-
if toolSet.Instructions() != "" {
486+
if instructions := tools.GetInstructions(toolSet); instructions != "" {
486487
messages = append(messages, chat.Message{
487488
Role: chat.MessageRoleSystem,
488-
Content: toolSet.Instructions(),
489+
Content: instructions,
489490
})
490491
}
491492
}

pkg/teamloader/filter.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ type filterTools struct {
4242
exclude bool
4343
}
4444

45+
// Verify interface compliance
46+
var _ tools.Instructable = (*filterTools)(nil)
47+
48+
// Instructions implements tools.Instructable by delegating to the inner toolset.
49+
func (f *filterTools) Instructions() string {
50+
return tools.GetInstructions(f.ToolSet)
51+
}
52+
4553
func (f *filterTools) Tools(ctx context.Context) ([]tools.Tool, error) {
4654
allTools, err := f.ToolSet.Tools(ctx)
4755
if err != nil {

pkg/teamloader/filter_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,51 @@ func TestWithToolsFilter_CaseSensitive(t *testing.T) {
119119
require.Len(t, result, 1)
120120
assert.Equal(t, "tool1", result[0].Name)
121121
}
122+
123+
type instructableToolSet struct {
124+
mockToolSet
125+
instructions string
126+
}
127+
128+
func (i *instructableToolSet) Instructions() string {
129+
return i.instructions
130+
}
131+
132+
func TestWithToolsFilter_InstructablePassthrough(t *testing.T) {
133+
// Test that filtering preserves instructions from inner toolset
134+
inner := &instructableToolSet{
135+
mockToolSet: mockToolSet{
136+
toolsFunc: func(context.Context) ([]tools.Tool, error) {
137+
return []tools.Tool{{Name: "tool1"}, {Name: "tool2"}}, nil
138+
},
139+
},
140+
instructions: "Test instructions for the toolset",
141+
}
142+
143+
wrapped := WithToolsFilter(inner, "tool1")
144+
145+
// Verify instructions are preserved through the filter wrapper
146+
instructions := tools.GetInstructions(wrapped)
147+
assert.Equal(t, "Test instructions for the toolset", instructions)
148+
149+
// Verify filtering still works
150+
result, err := wrapped.Tools(t.Context())
151+
require.NoError(t, err)
152+
require.Len(t, result, 1)
153+
assert.Equal(t, "tool1", result[0].Name)
154+
}
155+
156+
func TestWithToolsFilter_NonInstructableInner(t *testing.T) {
157+
// Test that filter works with toolsets that don't implement Instructable
158+
inner := &mockToolSet{
159+
toolsFunc: func(context.Context) ([]tools.Tool, error) {
160+
return []tools.Tool{{Name: "tool1"}}, nil
161+
},
162+
}
163+
164+
wrapped := WithToolsFilter(inner, "tool1")
165+
166+
// Verify instructions are empty for non-instructable inner toolset
167+
instructions := tools.GetInstructions(wrapped)
168+
assert.Empty(t, instructions)
169+
}

0 commit comments

Comments
 (0)