Skip to content

Commit e4ebc9e

Browse files
Add FunctionInvokingChatClient.CurrentContext (#5786)
* Add `FunctionInvokingChatClient.CurrentContext` * PR feedback
1 parent 045aaab commit e4ebc9e

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ namespace Microsoft.Extensions.AI;
4040
/// </remarks>
4141
public partial class FunctionInvokingChatClient : DelegatingChatClient
4242
{
43+
/// <summary>The <see cref="FunctionInvocationContext"/> for the current function invocation.</summary>
44+
private static readonly AsyncLocal<FunctionInvocationContext?> _currentContext = new();
45+
4346
/// <summary>The logger to use for logging information about function invocation.</summary>
4447
private readonly ILogger _logger;
4548

@@ -50,6 +53,18 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
5053
/// <summary>Maximum number of roundtrips allowed to the inner client.</summary>
5154
private int? _maximumIterationsPerRequest;
5255

56+
/// <summary>
57+
/// Gets or sets the <see cref="FunctionInvocationContext"/> for the current function invocation.
58+
/// </summary>
59+
/// <remarks>
60+
/// This value flows across async calls.
61+
/// </remarks>
62+
public static FunctionInvocationContext? CurrentContext
63+
{
64+
get => _currentContext.Value;
65+
set => _currentContext.Value = value;
66+
}
67+
5368
/// <summary>
5469
/// Initializes a new instance of the <see cref="FunctionInvokingChatClient"/> class.
5570
/// </summary>
@@ -661,6 +676,7 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
661676
object? result = null;
662677
try
663678
{
679+
CurrentContext = context;
664680
result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
665681
}
666682
catch (Exception e)

test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,93 @@ public async Task SupportsConsecutiveStreamingUpdatesWithFunctionCalls()
553553
Assert.Equal("OK bye", singleUpdateContent.Text);
554554
}
555555

556+
[Fact]
557+
public async Task CanAccesssFunctionInvocationContextFromFunctionCall()
558+
{
559+
var invocationContexts = new List<FunctionInvokingChatClient.FunctionInvocationContext>();
560+
var function = AIFunctionFactory.Create(async (int i) =>
561+
{
562+
// The context should propogate across async calls
563+
await Task.Yield();
564+
565+
var context = FunctionInvokingChatClient.CurrentContext!;
566+
invocationContexts.Add(context);
567+
568+
if (i == 42)
569+
{
570+
context.Terminate = true;
571+
}
572+
573+
return $"Result {i}";
574+
}, "Func1");
575+
576+
var options = new ChatOptions
577+
{
578+
Tools = [function],
579+
};
580+
581+
// The invocation loop should terminate after the second function call
582+
List<ChatMessage> planBeforeTermination =
583+
[
584+
new ChatMessage(ChatRole.User, "hello"),
585+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary<string, object?> { ["i"] = 41 })]),
586+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 41")]),
587+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary<string, object?> { ["i"] = 42 })]),
588+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 42")]),
589+
];
590+
591+
// The full plan should never be fulfilled
592+
List<ChatMessage> plan =
593+
[
594+
.. planBeforeTermination,
595+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "Func1", new Dictionary<string, object?> { ["i"] = 43 })]),
596+
new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "Func1", result: "Result 43")]),
597+
new ChatMessage(ChatRole.Assistant, "world"),
598+
];
599+
600+
await InvokeAsync(() => InvokeAndAssertAsync(options, plan, expected: [
601+
.. planBeforeTermination,
602+
603+
// The last message is the one returned by the chat client
604+
// This message's content should contain the last function call before the termination
605+
new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func1", new Dictionary<string, object?> { ["i"] = 42 })]),
606+
]));
607+
608+
await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, expected: [
609+
.. planBeforeTermination,
610+
611+
// The last message is the one returned by the chat client
612+
// When streaming, function call content is removed from this message
613+
new ChatMessage(ChatRole.Assistant, []),
614+
]));
615+
616+
// The current context should be null outside the async call stack for the function invocation
617+
Assert.Null(FunctionInvokingChatClient.CurrentContext);
618+
619+
async Task InvokeAsync(Func<Task<List<ChatMessage>>> work)
620+
{
621+
invocationContexts.Clear();
622+
623+
var chatMessages = await work();
624+
625+
Assert.Collection(invocationContexts,
626+
c => AssertInvocationContext(c, iteration: 0, terminate: false),
627+
c => AssertInvocationContext(c, iteration: 1, terminate: true));
628+
629+
void AssertInvocationContext(FunctionInvokingChatClient.FunctionInvocationContext context, int iteration, bool terminate)
630+
{
631+
Assert.NotNull(context);
632+
Assert.Same(chatMessages, context.ChatMessages);
633+
Assert.Same(function, context.Function);
634+
Assert.Equal("Func1", context.CallContent.Name);
635+
Assert.Equal(0, context.FunctionCallIndex);
636+
Assert.Equal(1, context.FunctionCount);
637+
Assert.Equal(iteration, context.Iteration);
638+
Assert.Equal(terminate, context.Terminate);
639+
}
640+
}
641+
}
642+
556643
private static async Task<List<ChatMessage>> InvokeAndAssertAsync(
557644
ChatOptions options,
558645
List<ChatMessage> plan,

0 commit comments

Comments
 (0)