Skip to content

Commit 8b6cafe

Browse files
committed
Add toolset to switch model
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent 96491c4 commit 8b6cafe

25 files changed

+1883
-156
lines changed

cagent-schema.json

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,8 @@
593593
"api",
594594
"a2a",
595595
"lsp",
596-
"user_prompt"
596+
"user_prompt",
597+
"switch_model"
597598
]
598599
},
599600
"instruction": {
@@ -700,6 +701,17 @@
700701
"description": "Timeout in seconds for the fetch tool",
701702
"minimum": 1
702703
},
704+
"models": {
705+
"type": "array",
706+
"description": "List of allowed model references for the switch_model tool. If not specified, all models defined in the config are available.",
707+
"items": {
708+
"type": "string"
709+
},
710+
"examples": [
711+
["fast_model", "powerful_model"],
712+
["openai/gpt-4o-mini", "anthropic/claude-sonnet-4-0"]
713+
]
714+
},
703715
"url": {
704716
"type": "string",
705717
"description": "URL for the a2a tool",
@@ -757,7 +769,8 @@
757769
"memory",
758770
"script",
759771
"fetch",
760-
"user_prompt"
772+
"user_prompt",
773+
"switch_model"
761774
]
762775
}
763776
}

e2e/switch_model_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package e2e_test
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/docker/cagent/pkg/agent"
11+
"github.com/docker/cagent/pkg/chat"
12+
"github.com/docker/cagent/pkg/config"
13+
"github.com/docker/cagent/pkg/runtime"
14+
"github.com/docker/cagent/pkg/session"
15+
"github.com/docker/cagent/pkg/teamloader"
16+
)
17+
18+
// setupSwitchModelTest creates a runtime with model switching support.
19+
func setupSwitchModelTest(t *testing.T) (runtime.Runtime, *agent.Agent) {
20+
t.Helper()
21+
22+
ctx := t.Context()
23+
agentSource, err := config.Resolve("testdata/switch_model.yaml")
24+
require.NoError(t, err)
25+
26+
_, runConfig := startRecordingAIProxy(t)
27+
loadResult, err := teamloader.LoadWithConfig(ctx, agentSource, runConfig)
28+
require.NoError(t, err)
29+
30+
modelSwitcherCfg := &runtime.ModelSwitcherConfig{
31+
Models: loadResult.Models,
32+
Providers: loadResult.Providers,
33+
ModelsGateway: runConfig.ModelsGateway,
34+
EnvProvider: runConfig.EnvProvider(),
35+
AgentDefaultModels: loadResult.AgentDefaultModels,
36+
}
37+
38+
rt, err := runtime.New(loadResult.Team, runtime.WithModelSwitcherConfig(modelSwitcherCfg))
39+
require.NoError(t, err)
40+
41+
rootAgent, err := loadResult.Team.Agent("root")
42+
require.NoError(t, err)
43+
44+
return rt, rootAgent
45+
}
46+
47+
// findSwitchModelCall searches session messages for a switch_model tool call containing the given model name.
48+
func findSwitchModelCall(sess *session.Session, modelName string) bool {
49+
for _, msg := range sess.GetAllMessages() {
50+
if msg.Message.Role != chat.MessageRoleAssistant || msg.Message.ToolCalls == nil {
51+
continue
52+
}
53+
for _, tc := range msg.Message.ToolCalls {
54+
if tc.Function.Name == "switch_model" && strings.Contains(tc.Function.Arguments, modelName) {
55+
return true
56+
}
57+
}
58+
}
59+
return false
60+
}
61+
62+
// TestSwitchModel_AgentCanSwitchModels verifies that an agent can use the switch_model tool
63+
// to change between models during a conversation.
64+
func TestSwitchModel_AgentCanSwitchModels(t *testing.T) {
65+
t.Parallel()
66+
67+
ctx := t.Context()
68+
rt, _ := setupSwitchModelTest(t)
69+
70+
// Switch to smart model
71+
sess := session.New(session.WithUserMessage("Switch to the smart model, then say hi"))
72+
_, err := rt.Run(ctx, sess)
73+
require.NoError(t, err)
74+
75+
assert.True(t, findSwitchModelCall(sess, "smart"), "Expected switch_model tool call with 'smart' model")
76+
assert.NotEmpty(t, sess.GetLastAssistantMessageContent(), "Expected a response after switching")
77+
78+
// Switch back to fast model
79+
sess.AddMessage(session.UserMessage("Now switch back to the fast model and say goodbye"))
80+
_, err = rt.Run(ctx, sess)
81+
require.NoError(t, err)
82+
83+
assert.True(t, findSwitchModelCall(sess, "fast"), "Expected switch_model tool call with 'fast' model")
84+
assert.NotEmpty(t, sess.GetLastAssistantMessageContent(), "Expected a response after switching back")
85+
}
86+
87+
// TestSwitchModel_ModelActuallyChanges verifies that after calling switch_model,
88+
// the agent's model object is updated to the new model.
89+
func TestSwitchModel_ModelActuallyChanges(t *testing.T) {
90+
t.Parallel()
91+
92+
ctx := t.Context()
93+
rt, rootAgent := setupSwitchModelTest(t)
94+
95+
assert.Contains(t, rootAgent.Model().ID(), "gpt-4o-mini", "Should start with gpt-4o-mini")
96+
97+
// Switch to smart model
98+
sess := session.New(session.WithUserMessage("Use the switch_model tool to switch to smart model, then just say 'done'"))
99+
_, err := rt.Run(ctx, sess)
100+
require.NoError(t, err)
101+
102+
assert.Contains(t, rootAgent.Model().ID(), "claude", "Model should have changed to claude")
103+
104+
// Verify the new model works
105+
sess.AddMessage(session.UserMessage("What is 2+2? Answer with just the number."))
106+
_, err = rt.Run(ctx, sess)
107+
require.NoError(t, err)
108+
109+
assert.NotEmpty(t, sess.GetLastAssistantMessageContent())
110+
}

0 commit comments

Comments
 (0)