Skip to content

Commit 07ba94e

Browse files
committed
feat(skill): support model frontmatter in inline mode to switch ChatModel
1 parent 97816cb commit 07ba94e

File tree

1 file changed

+64
-3
lines changed

1 file changed

+64
-3
lines changed

adk/middlewares/skill/skill.go

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,11 @@ type Config struct {
8686
// allowing the hub implementation to return a default agent.
8787
AgentHub AgentHub
8888
// ModelHub provides model instances for skills that specify a "model" field in frontmatter.
89-
// Currently only used with context mode (fork/isolate).
90-
// If nil, skills with model specification will return an error.
89+
// Used in two scenarios:
90+
// - With context mode (fork/isolate): The model is passed to the AgentFactory
91+
// - Without context mode (inline): The model becomes active for subsequent ChatModel requests
92+
// If nil, skills with model specification will be ignored in inline mode,
93+
// or return an error in context mode.
9194
ModelHub ModelHub
9295
}
9396

@@ -144,7 +147,7 @@ func NewHandler(ctx context.Context, config *Config) (adk.ChatModelAgentMiddlewa
144147
type skillHandler struct {
145148
*adk.BaseChatModelAgentMiddleware
146149
instruction string
147-
tool tool.BaseTool
150+
tool *skillTool
148151
}
149152

150153
func (h *skillHandler) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAgentContext) (context.Context, *adk.ChatModelAgentContext, error) {
@@ -153,6 +156,57 @@ func (h *skillHandler) BeforeAgent(ctx context.Context, runCtx *adk.ChatModelAge
153156
return ctx, runCtx, nil
154157
}
155158

159+
func (h *skillHandler) WrapModel(m model.BaseChatModel, mc *adk.ModelContext) model.BaseChatModel {
160+
if h.tool.modelHub == nil {
161+
return m
162+
}
163+
return &skillModelWrapper{
164+
inner: m,
165+
modelHub: h.tool.modelHub,
166+
}
167+
}
168+
169+
type skillModelWrapper struct {
170+
inner model.BaseChatModel
171+
modelHub ModelHub
172+
}
173+
174+
func (w *skillModelWrapper) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.Message, error) {
175+
m, err := w.getActiveModel(ctx)
176+
if err != nil {
177+
return nil, err
178+
}
179+
if m != nil {
180+
return m.Generate(ctx, input, opts...)
181+
}
182+
return w.inner.Generate(ctx, input, opts...)
183+
}
184+
185+
func (w *skillModelWrapper) Stream(ctx context.Context, input []*schema.Message, opts ...model.Option) (*schema.StreamReader[*schema.Message], error) {
186+
m, err := w.getActiveModel(ctx)
187+
if err != nil {
188+
return nil, err
189+
}
190+
if m != nil {
191+
return m.Stream(ctx, input, opts...)
192+
}
193+
return w.inner.Stream(ctx, input, opts...)
194+
}
195+
196+
func (w *skillModelWrapper) getActiveModel(ctx context.Context) (model.BaseChatModel, error) {
197+
modelName, found, _ := adk.GetRunLocalValue(ctx, activeModelKey)
198+
if !found {
199+
return nil, nil
200+
}
201+
name, ok := modelName.(string)
202+
if !ok || name == "" {
203+
return nil, nil
204+
}
205+
return w.modelHub.Get(ctx, name)
206+
}
207+
208+
const activeModelKey = "__skill_active_model__"
209+
156210
// New creates a new skill middleware.
157211
// It provides a tool for the agent to use skills.
158212
//
@@ -251,10 +305,17 @@ func (s *skillTool) InvokableRun(ctx context.Context, argumentsInJSON string, op
251305
case ContextModeIsolate:
252306
return s.runAgentMode(ctx, skill, false)
253307
default:
308+
if skill.Model != "" {
309+
s.setActiveModel(ctx, skill.Model)
310+
}
254311
return s.buildSkillResult(skill), nil
255312
}
256313
}
257314

315+
func (s *skillTool) setActiveModel(ctx context.Context, modelName string) {
316+
_ = adk.SetRunLocalValue(ctx, activeModelKey, modelName)
317+
}
318+
258319
func (s *skillTool) buildSkillResult(skill Skill) string {
259320
resultFmt := toolResult
260321
contentFmt := userContent

0 commit comments

Comments
 (0)