Skip to content

Commit 1ba8a0e

Browse files
committed
feat(tool/utils): add EnhancedStreamableTool implementation
- Add EnhancedStreamFunc and OptionableEnhancedStreamFunc type definitions - Implement InferEnhancedStreamTool and InferOptionableEnhancedStreamTool functions - Implement NewEnhancedStreamTool and newOptionableEnhancedStreamTool functions - Add enhancedStreamableTool struct implementing tool.EnhancedStreamableTool interface - Add unit tests for all new functions~
1 parent b013714 commit 1ba8a0e

File tree

2 files changed

+258
-0
lines changed

2 files changed

+258
-0
lines changed

components/tool/utils/streamable_func.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,101 @@ func (s *streamableTool[T, D]) getToolName() string {
155155

156156
return s.info.Name
157157
}
158+
159+
// EnhancedStreamFunc is the function type for the enhanced streamable tool.
160+
type EnhancedStreamFunc[T any] func(ctx context.Context, input T) (output *schema.StreamReader[*schema.ToolResult], err error)
161+
162+
// OptionableEnhancedStreamFunc is the function type for the enhanced streamable tool with tool option.
163+
type OptionableEnhancedStreamFunc[T any] func(ctx context.Context, input T, opts ...tool.Option) (output *schema.StreamReader[*schema.ToolResult], err error)
164+
165+
// InferEnhancedStreamTool creates an EnhancedStreamableTool from a given function by inferring the ToolInfo from the function's request parameters.
166+
// End-user can pass a SchemaCustomizerFn in opts to customize the go struct tag parsing process, overriding default behavior.
167+
func InferEnhancedStreamTool[T any](toolName, toolDesc string, s EnhancedStreamFunc[T], opts ...Option) (tool.EnhancedStreamableTool, error) {
168+
ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...)
169+
if err != nil {
170+
return nil, err
171+
}
172+
173+
return NewEnhancedStreamTool(ti, s, opts...), nil
174+
}
175+
176+
// InferOptionableEnhancedStreamTool creates an EnhancedStreamableTool from a given function by inferring the ToolInfo from the function's request parameters, with tool option.
177+
func InferOptionableEnhancedStreamTool[T any](toolName, toolDesc string, s OptionableEnhancedStreamFunc[T], opts ...Option) (tool.EnhancedStreamableTool, error) {
178+
ti, err := goStruct2ToolInfo[T](toolName, toolDesc, opts...)
179+
if err != nil {
180+
return nil, err
181+
}
182+
183+
return newOptionableEnhancedStreamTool(ti, s, opts...), nil
184+
}
185+
186+
// NewEnhancedStreamTool Create an enhanced streaming tool, where the input is in JSON format and output is *schema.StreamReader[*schema.ToolResult].
187+
func NewEnhancedStreamTool[T any](desc *schema.ToolInfo, s EnhancedStreamFunc[T], opts ...Option) tool.EnhancedStreamableTool {
188+
return newOptionableEnhancedStreamTool(desc,
189+
func(ctx context.Context, input T, _ ...tool.Option) (output *schema.StreamReader[*schema.ToolResult], err error) {
190+
return s(ctx, input)
191+
},
192+
opts...)
193+
}
194+
195+
func newOptionableEnhancedStreamTool[T any](desc *schema.ToolInfo, s OptionableEnhancedStreamFunc[T], opts ...Option) tool.EnhancedStreamableTool {
196+
to := getToolOptions(opts...)
197+
198+
return &enhancedStreamableTool[T]{
199+
info: desc,
200+
um: to.um,
201+
Fn: s,
202+
}
203+
}
204+
205+
type enhancedStreamableTool[T any] struct {
206+
info *schema.ToolInfo
207+
208+
um UnmarshalArguments
209+
210+
Fn OptionableEnhancedStreamFunc[T]
211+
}
212+
213+
func (s *enhancedStreamableTool[T]) Info(ctx context.Context) (*schema.ToolInfo, error) {
214+
return s.info, nil
215+
}
216+
217+
func (s *enhancedStreamableTool[T]) StreamableRun(ctx context.Context, toolArgument *schema.ToolArgument, opts ...tool.Option) (
218+
outStream *schema.StreamReader[*schema.ToolResult], err error) {
219+
220+
var inst T
221+
if s.um != nil {
222+
var val any
223+
val, err = s.um(ctx, toolArgument.TextArgument)
224+
if err != nil {
225+
return nil, fmt.Errorf("[EnhancedLocalStreamFunc] failed to unmarshal arguments, toolName=%s, err=%w", s.getToolName(), err)
226+
}
227+
228+
gt, ok := val.(T)
229+
if !ok {
230+
return nil, fmt.Errorf("[EnhancedLocalStreamFunc] type err, toolName=%s, expected=%T, given=%T", s.getToolName(), inst, val)
231+
}
232+
inst = gt
233+
} else {
234+
inst = generic.NewInstance[T]()
235+
236+
err = sonic.UnmarshalString(toolArgument.TextArgument, &inst)
237+
if err != nil {
238+
return nil, fmt.Errorf("[EnhancedLocalStreamFunc] failed to unmarshal arguments in json, toolName=%s, err=%w", s.getToolName(), err)
239+
}
240+
}
241+
242+
return s.Fn(ctx, inst, opts...)
243+
}
244+
245+
func (s *enhancedStreamableTool[T]) GetType() string {
246+
return snakeToCamel(s.getToolName())
247+
}
248+
249+
func (s *enhancedStreamableTool[T]) getToolName() string {
250+
if s.info == nil {
251+
return ""
252+
}
253+
254+
return s.info.Name
255+
}

components/tool/utils/streamable_func_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,163 @@ func TestInferStreamTool(t *testing.T) {
167167
}
168168
}
169169
}
170+
171+
type EnhancedStreamInput struct {
172+
Query string `json:"query" jsonschema:"description=the search query"`
173+
}
174+
175+
func TestNewEnhancedStreamTool(t *testing.T) {
176+
ctx := context.Background()
177+
178+
t.Run("simple_case", func(t *testing.T) {
179+
tl := NewEnhancedStreamTool[*EnhancedStreamInput](
180+
&schema.ToolInfo{
181+
Name: "enhanced_stream_search",
182+
Desc: "search with enhanced stream output",
183+
ParamsOneOf: schema.NewParamsOneOfByParams(
184+
map[string]*schema.ParameterInfo{
185+
"query": {
186+
Type: "string",
187+
Desc: "the search query",
188+
},
189+
}),
190+
},
191+
func(ctx context.Context, input *EnhancedStreamInput) (*schema.StreamReader[*schema.ToolResult], error) {
192+
sr, sw := schema.Pipe[*schema.ToolResult](2)
193+
sw.Send(&schema.ToolResult{
194+
Parts: []schema.ToolOutputPart{
195+
{Type: schema.ToolPartTypeText, Text: "result for: " + input.Query},
196+
},
197+
}, nil)
198+
sw.Send(&schema.ToolResult{
199+
Parts: []schema.ToolOutputPart{
200+
{Type: schema.ToolPartTypeText, Text: "more results"},
201+
},
202+
}, nil)
203+
sw.Close()
204+
return sr, nil
205+
},
206+
)
207+
208+
info, err := tl.Info(ctx)
209+
assert.NoError(t, err)
210+
assert.Equal(t, "enhanced_stream_search", info.Name)
211+
212+
sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{TextArgument: `{"query":"test"}`})
213+
assert.NoError(t, err)
214+
defer sr.Close()
215+
216+
idx := 0
217+
for {
218+
m, err := sr.Recv()
219+
if errors.Is(err, io.EOF) {
220+
break
221+
}
222+
assert.NoError(t, err)
223+
224+
if idx == 0 {
225+
assert.Len(t, m.Parts, 1)
226+
assert.Equal(t, schema.ToolPartTypeText, m.Parts[0].Type)
227+
assert.Equal(t, "result for: test", m.Parts[0].Text)
228+
} else {
229+
assert.Len(t, m.Parts, 1)
230+
assert.Equal(t, "more results", m.Parts[0].Text)
231+
}
232+
idx++
233+
}
234+
assert.Equal(t, 2, idx)
235+
})
236+
}
237+
238+
type FakeEnhancedStreamOption struct {
239+
Prefix string
240+
}
241+
242+
func FakeWithEnhancedStreamOption(prefix string) tool.Option {
243+
return tool.WrapImplSpecificOptFn(func(t *FakeEnhancedStreamOption) {
244+
t.Prefix = prefix
245+
})
246+
}
247+
248+
func fakeEnhancedStreamFunc(ctx context.Context, input EnhancedStreamInput) (*schema.StreamReader[*schema.ToolResult], error) {
249+
return schema.StreamReaderFromArray([]*schema.ToolResult{
250+
{
251+
Parts: []schema.ToolOutputPart{
252+
{Type: schema.ToolPartTypeText, Text: "result: " + input.Query},
253+
},
254+
},
255+
}), nil
256+
}
257+
258+
func fakeOptionableEnhancedStreamFunc(ctx context.Context, input EnhancedStreamInput, opts ...tool.Option) (*schema.StreamReader[*schema.ToolResult], error) {
259+
baseOpt := &FakeEnhancedStreamOption{
260+
Prefix: "default",
261+
}
262+
option := tool.GetImplSpecificOptions(baseOpt, opts...)
263+
264+
return schema.StreamReaderFromArray([]*schema.ToolResult{
265+
{
266+
Parts: []schema.ToolOutputPart{
267+
{Type: schema.ToolPartTypeText, Text: option.Prefix + ": " + input.Query},
268+
},
269+
},
270+
}), nil
271+
}
272+
273+
func TestInferEnhancedStreamTool(t *testing.T) {
274+
ctx := context.Background()
275+
276+
t.Run("infer_enhanced_stream_tool", func(t *testing.T) {
277+
tl, err := InferEnhancedStreamTool("infer_enhanced_stream", "test infer enhanced stream tool", fakeEnhancedStreamFunc)
278+
assert.NoError(t, err)
279+
280+
info, err := tl.Info(ctx)
281+
assert.NoError(t, err)
282+
assert.Equal(t, "infer_enhanced_stream", info.Name)
283+
284+
sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{TextArgument: `{"query":"hello"}`})
285+
assert.NoError(t, err)
286+
defer sr.Close()
287+
288+
m, err := sr.Recv()
289+
assert.NoError(t, err)
290+
assert.Len(t, m.Parts, 1)
291+
assert.Equal(t, "result: hello", m.Parts[0].Text)
292+
})
293+
}
294+
295+
func TestInferOptionableEnhancedStreamTool(t *testing.T) {
296+
ctx := context.Background()
297+
298+
t.Run("infer_optionable_enhanced_stream_tool", func(t *testing.T) {
299+
tl, err := InferOptionableEnhancedStreamTool("infer_optionable_enhanced_stream", "test infer optionable enhanced stream tool", fakeOptionableEnhancedStreamFunc)
300+
assert.NoError(t, err)
301+
302+
info, err := tl.Info(ctx)
303+
assert.NoError(t, err)
304+
assert.Equal(t, "infer_optionable_enhanced_stream", info.Name)
305+
306+
sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{TextArgument: `{"query":"world"}`}, FakeWithEnhancedStreamOption("custom"))
307+
assert.NoError(t, err)
308+
defer sr.Close()
309+
310+
m, err := sr.Recv()
311+
assert.NoError(t, err)
312+
assert.Len(t, m.Parts, 1)
313+
assert.Equal(t, "custom: world", m.Parts[0].Text)
314+
})
315+
316+
t.Run("infer_optionable_enhanced_stream_tool_default_option", func(t *testing.T) {
317+
tl, err := InferOptionableEnhancedStreamTool("infer_optionable_enhanced_stream", "test infer optionable enhanced stream tool", fakeOptionableEnhancedStreamFunc)
318+
assert.NoError(t, err)
319+
320+
sr, err := tl.StreamableRun(ctx, &schema.ToolArgument{TextArgument: `{"query":"test"}`})
321+
assert.NoError(t, err)
322+
defer sr.Close()
323+
324+
m, err := sr.Recv()
325+
assert.NoError(t, err)
326+
assert.Len(t, m.Parts, 1)
327+
assert.Equal(t, "default: test", m.Parts[0].Text)
328+
})
329+
}

0 commit comments

Comments
 (0)