@@ -167,3 +167,163 @@ func TestInferStreamTool(t *testing.T) {
167167 }
168168 }
169169}
170+
171+ type EnhancedStreamInput struct {
172+ Query string `json:"query" jsonschema:"description=the search query"`
173+ }
174+
175+ func TestNewEnhancedStreamTool (t * testing.T ) {
176+ ctx := context .Background ()
177+
178+ t .Run ("simple_case" , func (t * testing.T ) {
179+ tl := NewEnhancedStreamTool [* EnhancedStreamInput ](
180+ & schema.ToolInfo {
181+ Name : "enhanced_stream_search" ,
182+ Desc : "search with enhanced stream output" ,
183+ ParamsOneOf : schema .NewParamsOneOfByParams (
184+ map [string ]* schema.ParameterInfo {
185+ "query" : {
186+ Type : "string" ,
187+ Desc : "the search query" ,
188+ },
189+ }),
190+ },
191+ func (ctx context.Context , input * EnhancedStreamInput ) (* schema.StreamReader [* schema.ToolResult ], error ) {
192+ sr , sw := schema.Pipe [* schema.ToolResult ](2 )
193+ sw .Send (& schema.ToolResult {
194+ Parts : []schema.ToolOutputPart {
195+ {Type : schema .ToolPartTypeText , Text : "result for: " + input .Query },
196+ },
197+ }, nil )
198+ sw .Send (& schema.ToolResult {
199+ Parts : []schema.ToolOutputPart {
200+ {Type : schema .ToolPartTypeText , Text : "more results" },
201+ },
202+ }, nil )
203+ sw .Close ()
204+ return sr , nil
205+ },
206+ )
207+
208+ info , err := tl .Info (ctx )
209+ assert .NoError (t , err )
210+ assert .Equal (t , "enhanced_stream_search" , info .Name )
211+
212+ sr , err := tl .StreamableRun (ctx , & schema.ToolArgument {TextArgument : `{"query":"test"}` })
213+ assert .NoError (t , err )
214+ defer sr .Close ()
215+
216+ idx := 0
217+ for {
218+ m , err := sr .Recv ()
219+ if errors .Is (err , io .EOF ) {
220+ break
221+ }
222+ assert .NoError (t , err )
223+
224+ if idx == 0 {
225+ assert .Len (t , m .Parts , 1 )
226+ assert .Equal (t , schema .ToolPartTypeText , m .Parts [0 ].Type )
227+ assert .Equal (t , "result for: test" , m .Parts [0 ].Text )
228+ } else {
229+ assert .Len (t , m .Parts , 1 )
230+ assert .Equal (t , "more results" , m .Parts [0 ].Text )
231+ }
232+ idx ++
233+ }
234+ assert .Equal (t , 2 , idx )
235+ })
236+ }
237+
238+ type FakeEnhancedStreamOption struct {
239+ Prefix string
240+ }
241+
242+ func FakeWithEnhancedStreamOption (prefix string ) tool.Option {
243+ return tool .WrapImplSpecificOptFn (func (t * FakeEnhancedStreamOption ) {
244+ t .Prefix = prefix
245+ })
246+ }
247+
248+ func fakeEnhancedStreamFunc (ctx context.Context , input EnhancedStreamInput ) (* schema.StreamReader [* schema.ToolResult ], error ) {
249+ return schema .StreamReaderFromArray ([]* schema.ToolResult {
250+ {
251+ Parts : []schema.ToolOutputPart {
252+ {Type : schema .ToolPartTypeText , Text : "result: " + input .Query },
253+ },
254+ },
255+ }), nil
256+ }
257+
258+ func fakeOptionableEnhancedStreamFunc (ctx context.Context , input EnhancedStreamInput , opts ... tool.Option ) (* schema.StreamReader [* schema.ToolResult ], error ) {
259+ baseOpt := & FakeEnhancedStreamOption {
260+ Prefix : "default" ,
261+ }
262+ option := tool .GetImplSpecificOptions (baseOpt , opts ... )
263+
264+ return schema .StreamReaderFromArray ([]* schema.ToolResult {
265+ {
266+ Parts : []schema.ToolOutputPart {
267+ {Type : schema .ToolPartTypeText , Text : option .Prefix + ": " + input .Query },
268+ },
269+ },
270+ }), nil
271+ }
272+
273+ func TestInferEnhancedStreamTool (t * testing.T ) {
274+ ctx := context .Background ()
275+
276+ t .Run ("infer_enhanced_stream_tool" , func (t * testing.T ) {
277+ tl , err := InferEnhancedStreamTool ("infer_enhanced_stream" , "test infer enhanced stream tool" , fakeEnhancedStreamFunc )
278+ assert .NoError (t , err )
279+
280+ info , err := tl .Info (ctx )
281+ assert .NoError (t , err )
282+ assert .Equal (t , "infer_enhanced_stream" , info .Name )
283+
284+ sr , err := tl .StreamableRun (ctx , & schema.ToolArgument {TextArgument : `{"query":"hello"}` })
285+ assert .NoError (t , err )
286+ defer sr .Close ()
287+
288+ m , err := sr .Recv ()
289+ assert .NoError (t , err )
290+ assert .Len (t , m .Parts , 1 )
291+ assert .Equal (t , "result: hello" , m .Parts [0 ].Text )
292+ })
293+ }
294+
295+ func TestInferOptionableEnhancedStreamTool (t * testing.T ) {
296+ ctx := context .Background ()
297+
298+ t .Run ("infer_optionable_enhanced_stream_tool" , func (t * testing.T ) {
299+ tl , err := InferOptionableEnhancedStreamTool ("infer_optionable_enhanced_stream" , "test infer optionable enhanced stream tool" , fakeOptionableEnhancedStreamFunc )
300+ assert .NoError (t , err )
301+
302+ info , err := tl .Info (ctx )
303+ assert .NoError (t , err )
304+ assert .Equal (t , "infer_optionable_enhanced_stream" , info .Name )
305+
306+ sr , err := tl .StreamableRun (ctx , & schema.ToolArgument {TextArgument : `{"query":"world"}` }, FakeWithEnhancedStreamOption ("custom" ))
307+ assert .NoError (t , err )
308+ defer sr .Close ()
309+
310+ m , err := sr .Recv ()
311+ assert .NoError (t , err )
312+ assert .Len (t , m .Parts , 1 )
313+ assert .Equal (t , "custom: world" , m .Parts [0 ].Text )
314+ })
315+
316+ t .Run ("infer_optionable_enhanced_stream_tool_default_option" , func (t * testing.T ) {
317+ tl , err := InferOptionableEnhancedStreamTool ("infer_optionable_enhanced_stream" , "test infer optionable enhanced stream tool" , fakeOptionableEnhancedStreamFunc )
318+ assert .NoError (t , err )
319+
320+ sr , err := tl .StreamableRun (ctx , & schema.ToolArgument {TextArgument : `{"query":"test"}` })
321+ assert .NoError (t , err )
322+ defer sr .Close ()
323+
324+ m , err := sr .Recv ()
325+ assert .NoError (t , err )
326+ assert .Len (t , m .Parts , 1 )
327+ assert .Equal (t , "default: test" , m .Parts [0 ].Text )
328+ })
329+ }
0 commit comments