Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add a request index to the streamed function call update content #10129

Merged
merged 11 commits into from
Jan 14, 2025
Merged
24 changes: 12 additions & 12 deletions dotnet/samples/Concepts/FunctionCalling/FunctionCalling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public async Task RunPromptWithAutoFunctionChoiceBehaviorAdvertisingAllKernelFun

OpenAIPromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };

Console.WriteLine(await kernel.InvokePromptAsync("Given the current time of day and weather, what is the likely color of the sky in Boston?", new(settings)));
Console.WriteLine(await kernel.InvokePromptAsync("What is the likely color of the sky in Boston today?", new(settings)));

// Expected output: "Boston is currently experiencing a rainy day, hence, the likely color of the sky in Boston is grey."
}
Expand Down Expand Up @@ -104,7 +104,7 @@ public async Task RunPromptWithNoneFunctionChoiceBehaviorAdvertisingAllKernelFun

OpenAIPromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.None() };

Console.WriteLine(await kernel.InvokePromptAsync("Tell me which provided functions I would need to call to get the color of the sky in Boston on a specified date.", new(settings)));
Console.WriteLine(await kernel.InvokePromptAsync("Tell me which provided functions I would need to call to get the color of the sky in Boston for today.", new(settings)));

// Expected output: "You would first call the `HelperFunctions-GetCurrentUtcDateTime` function to get the current date time in UTC. Then, you would use the `HelperFunctions-GetWeatherForCity` function,
// passing in the city name as 'Boston' and the retrieved UTC date time. Note, however, that these functions won't directly tell you the color of the sky.
Expand All @@ -122,7 +122,7 @@ public async Task RunPromptTemplateConfigWithAutoFunctionChoiceBehaviorAdvertisi
// The `function_choice_behavior.functions` property is omitted which is equivalent to providing all kernel functions to the AI model.
string promptTemplateConfig = """
template_format: semantic-kernel
template: Given the current time of day and weather, what is the likely color of the sky in Boston?
template: What is the likely color of the sky in Boston today?
execution_settings:
default:
function_choice_behavior:
Expand Down Expand Up @@ -177,7 +177,7 @@ public async Task RunNonStreamingChatCompletionApiWithAutomaticFunctionInvocatio
IChatCompletionService chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();

ChatMessageContent result = await chatCompletionService.GetChatMessageContentAsync(
"Given the current time of day and weather, what is the likely color of the sky in Boston?",
"What is the likely color of the sky in Boston today?",
settings,
kernel);

Expand All @@ -204,7 +204,7 @@ public async Task RunStreamingChatCompletionApiWithAutomaticFunctionInvocationAs

// Act
await foreach (var update in chatCompletionService.GetStreamingChatMessageContentsAsync(
"Given the current time of day and weather, what is the likely color of the sky in Boston?",
"What is the likely color of the sky in Boston today?",
settings,
kernel))
{
Expand All @@ -231,7 +231,7 @@ public async Task RunNonStreamingChatCompletionApiWithManualFunctionInvocationAs
OpenAIPromptExecutionSettings settings = new() { FunctionChoiceBehavior = Microsoft.SemanticKernel.FunctionChoiceBehavior.Auto(autoInvoke: false) };

ChatHistory chatHistory = [];
chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?");
chatHistory.AddUserMessage("What is the likely color of the sky in Boston today?");

while (true)
{
Expand Down Expand Up @@ -293,7 +293,7 @@ public async Task RunStreamingChatCompletionApiWithManualFunctionCallingAsync()

// Create chat history with the initial user message
ChatHistory chatHistory = [];
chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?");
chatHistory.AddUserMessage("What is the likely color of the sky in Boston today?");

while (true)
{
Expand Down Expand Up @@ -360,7 +360,7 @@ public async Task RunNonStreamingPromptWithSimulatedFunctionAsync()
OpenAIPromptExecutionSettings settings = new() { FunctionChoiceBehavior = Microsoft.SemanticKernel.FunctionChoiceBehavior.Auto(autoInvoke: false) };

ChatHistory chatHistory = [];
chatHistory.AddUserMessage("Given the current time of day and weather, what is the likely color of the sky in Boston?");
chatHistory.AddUserMessage("What is the likely color of the sky in Boston today?");

while (true)
{
Expand Down Expand Up @@ -411,7 +411,7 @@ public async Task DisableFunctionCallingAsync()
// Alternatively, either omit assigning anything to the `FunctionChoiceBehavior` property or assign null to it to also disable function calling.
OpenAIPromptExecutionSettings settings = new() { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto(functions: []) };

Console.WriteLine(await kernel.InvokePromptAsync("Given the current time of day and weather, what is the likely color of the sky in Boston?", new(settings)));
Console.WriteLine(await kernel.InvokePromptAsync("What is the likely color of the sky in Boston today?", new(settings)));

// Expected output: "Sorry, I cannot answer this question as it requires real-time information which I, as a text-based model, cannot access."
}
Expand Down Expand Up @@ -538,8 +538,8 @@ private static Kernel CreateKernel()
kernel.ImportPluginFromFunctions("HelperFunctions",
[
kernel.CreateFunctionFromMethod(() => new List<string> { "Squirrel Steals Show", "Dog Wins Lottery" }, "GetLatestNewsTitles", "Retrieves latest news titles."),
kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentUtcDateTime", "Retrieves the current date time in UTC."),
kernel.CreateFunctionFromMethod((string cityName, string currentDateTime) =>
kernel.CreateFunctionFromMethod(() => DateTime.UtcNow.ToString("R"), "GetCurrentDateTimeInUtc", "Retrieves the current date time in UTC."),
kernel.CreateFunctionFromMethod((string cityName, string currentDateTimeInUtc) =>
cityName switch
{
"Boston" => "61 and rainy",
Expand All @@ -550,7 +550,7 @@ private static Kernel CreateKernel()
"Sydney" => "75 and sunny",
"Tel Aviv" => "80 and sunny",
_ => "31 and snowing",
}, "GetWeatherForCity", "Gets the current weather for the specified city"),
}, "GetWeatherForCity", "Gets the current weather for the specified city and specified date time."),
]);

return kernel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,10 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
callId: functionCallUpdate.ToolCallId,
name: functionCallUpdate.FunctionName,
arguments: streamingArguments,
functionCallIndex: functionCallUpdate.Index));
functionCallIndex: functionCallUpdate.Index)
{
RequestIndex = requestIndex,
});
}
}
streamedContents?.Add(openAIStreamingChatMessageContent);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ namespace Microsoft.SemanticKernel;
/// </summary>
public sealed class FunctionCallContentBuilder
{
private Dictionary<int, string>? _functionCallIdsByIndex = null;
private Dictionary<int, string>? _functionNamesByIndex = null;
private Dictionary<int, StringBuilder>? _functionArgumentBuildersByIndex = null;
private Dictionary<string, string>? _functionCallIdsByIndex = null;
private Dictionary<string, string>? _functionNamesByIndex = null;
private Dictionary<string, StringBuilder>? _functionArgumentBuildersByIndex = null;
private readonly JsonSerializerOptions? _jsonSerializerOptions;

/// <summary>
Expand Down Expand Up @@ -70,7 +70,7 @@ public IReadOnlyList<FunctionCallContent> Build()

for (int i = 0; i < this._functionCallIdsByIndex.Count; i++)
{
KeyValuePair<int, string> functionCallIndexAndId = this._functionCallIdsByIndex.ElementAt(i);
KeyValuePair<string, string> functionCallIndexAndId = this._functionCallIdsByIndex.ElementAt(i);

string? pluginName = null;
string functionName = string.Empty;
Expand All @@ -96,7 +96,7 @@ public IReadOnlyList<FunctionCallContent> Build()

[UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = "The warning is shown and should be addressed at the class creation site; there is no need to show it again at the function invocation sites.")]
[UnconditionalSuppressMessage("AOT", "IL3050:Calling members annotated with 'RequiresDynamicCodeAttribute' may break functionality when AOT compiling.", Justification = "The warning is shown and should be addressed at the class creation site; there is no need to show it again at the function invocation sites.")]
(KernelArguments? Arguments, Exception? Exception) GetFunctionArgumentsSafe(int functionCallIndex)
(KernelArguments? Arguments, Exception? Exception) GetFunctionArgumentsSafe(string functionCallIndex)
{
if (this._jsonSerializerOptions is not null)
{
Expand All @@ -118,7 +118,7 @@ public IReadOnlyList<FunctionCallContent> Build()
/// <returns>A tuple containing the KernelArguments and an Exception if any.</returns>
[RequiresUnreferencedCode("Uses reflection to deserialize function arguments if no JSOs are provided, making it incompatible with AOT scenarios.")]
[RequiresDynamicCode("Uses reflection to deserialize function arguments if no JSOs are provided, making it incompatible with AOT scenarios.")]
private (KernelArguments? Arguments, Exception? Exception) GetFunctionArguments(int functionCallIndex, JsonSerializerOptions? jsonSerializerOptions = null)
private (KernelArguments? Arguments, Exception? Exception) GetFunctionArguments(string functionCallIndex, JsonSerializerOptions? jsonSerializerOptions = null)
{
if (this._functionArgumentBuildersByIndex is null ||
!this._functionArgumentBuildersByIndex.TryGetValue(functionCallIndex, out StringBuilder? functionArgumentsBuilder))
Expand Down Expand Up @@ -170,33 +170,36 @@ public IReadOnlyList<FunctionCallContent> Build()
/// <param name="functionCallIdsByIndex">The dictionary of function call IDs by function call index.</param>
/// <param name="functionNamesByIndex">The dictionary of function names by function call index.</param>
/// <param name="functionArgumentBuildersByIndex">The dictionary of function argument builders by function call index.</param>
private static void TrackStreamingFunctionCallUpdate(StreamingFunctionCallUpdateContent update, ref Dictionary<int, string>? functionCallIdsByIndex, ref Dictionary<int, string>? functionNamesByIndex, ref Dictionary<int, StringBuilder>? functionArgumentBuildersByIndex)
private static void TrackStreamingFunctionCallUpdate(StreamingFunctionCallUpdateContent update, ref Dictionary<string, string>? functionCallIdsByIndex, ref Dictionary<string, string>? functionNamesByIndex, ref Dictionary<string, StringBuilder>? functionArgumentBuildersByIndex)
{
if (update is null)
{
// Nothing to track.
return;
}

// Create index that is unique across many requests.
var functionCallIndex = $"{update.RequestIndex}-{update.FunctionCallIndex}";

// If we have an call id, ensure the index is being tracked. Even if it's not a function update,
// we want to keep track of it so we can send back an error.
if (update.CallId is string id && !string.IsNullOrEmpty(id))
{
(functionCallIdsByIndex ??= [])[update.FunctionCallIndex] = id;
(functionCallIdsByIndex ??= [])[functionCallIndex] = id;
}

// Ensure we're tracking the function's name.
if (update.Name is string name && !string.IsNullOrEmpty(name))
{
(functionNamesByIndex ??= [])[update.FunctionCallIndex] = name;
(functionNamesByIndex ??= [])[functionCallIndex] = name;
}

// Ensure we're tracking the function's arguments.
if (update.Arguments is string argumentsUpdate)
{
if (!(functionArgumentBuildersByIndex ??= []).TryGetValue(update.FunctionCallIndex, out StringBuilder? arguments))
if (!(functionArgumentBuildersByIndex ??= []).TryGetValue(functionCallIndex, out StringBuilder? arguments))
{
functionArgumentBuildersByIndex[update.FunctionCallIndex] = arguments = new();
functionArgumentBuildersByIndex[functionCallIndex] = arguments = new();
}

arguments.Append(argumentsUpdate);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Text;

namespace Microsoft.SemanticKernel;
Expand Down Expand Up @@ -29,6 +30,12 @@ public class StreamingFunctionCallUpdateContent : StreamingKernelContent
/// </summary>
public int FunctionCallIndex { get; init; }

/// <summary>
/// Index of the request that produced this message content.
/// </summary>
[Experimental("SKEXP0001")]
public int RequestIndex { get; init; } = 0;
SergeyMenshykh marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Creates a new instance of the <see cref="StreamingFunctionCallUpdateContent"/> class.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,58 @@ public void ItShouldBuildFunctionCallContentForManyFunctions(JsonSerializerOptio
Assert.Null(functionCall2.Exception);
}

[Theory]
[ClassData(typeof(TestJsonSerializerOptionsForKernelArguments))]
public void ItShouldBuildFunctionCallContentForManyFunctionsCameInDifferentRequests(JsonSerializerOptions? jsos)
{
// Arrange
var sut = jsos is not null ? new FunctionCallContentBuilder(jsos) : new FunctionCallContentBuilder();

// Act

// f1 call was streamed as part of the first request
var f1_update1 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 0, functionCallIndex: 0, requestIndex: 0, callId: "f_1", name: "WeatherUtils-GetTemperature", arguments: null);
sut.Append(f1_update1);

var f1_update2 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 0, functionCallIndex: 0, requestIndex: 0, callId: null, name: null, arguments: "{\"city\":");
sut.Append(f1_update2);

var f1_update3 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 0, functionCallIndex: 0, requestIndex: 0, callId: null, name: null, arguments: "\"Seattle\"}");
sut.Append(f1_update3);

// f2 call was streamed as part of the second request
var f2_update1 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 0, functionCallIndex: 0, requestIndex: 1, callId: null, name: "WeatherUtils-GetHumidity", arguments: null);
sut.Append(f2_update1);

var f2_update2 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 0, functionCallIndex: 0, requestIndex: 1, callId: "f_2", name: null, arguments: null);
sut.Append(f2_update2);

var f2_update3 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 0, functionCallIndex: 0, requestIndex: 1, callId: null, name: null, arguments: "{\"city\":");
sut.Append(f2_update3);

var f2_update4 = CreateStreamingContentWithFunctionCallUpdate(choiceIndex: 0, functionCallIndex: 0, requestIndex: 1, callId: null, name: null, arguments: "\"Georgia\"}");
sut.Append(f2_update4);

var functionCalls = sut.Build();

// Assert
Assert.Equal(2, functionCalls.Count);

var functionCall1 = functionCalls.ElementAt(0);
Assert.Equal("f_1", functionCall1.Id);
Assert.Equal("WeatherUtils", functionCall1.PluginName);
Assert.Equal("GetTemperature", functionCall1.FunctionName);
Assert.Equal("Seattle", functionCall1.Arguments?["city"]);
Assert.Null(functionCall1.Exception);

var functionCall2 = functionCalls.ElementAt(1);
Assert.Equal("f_2", functionCall2.Id);
Assert.Equal("WeatherUtils", functionCall2.PluginName);
Assert.Equal("GetHumidity", functionCall2.FunctionName);
Assert.Equal("Georgia", functionCall2.Arguments?["city"]);
Assert.Null(functionCall2.Exception);
}

[Theory]
[ClassData(typeof(TestJsonSerializerOptionsForKernelArguments))]
public void ItShouldCaptureArgumentsDeserializationException(JsonSerializerOptions? jsos)
Expand Down Expand Up @@ -160,7 +212,7 @@ public void ItShouldCaptureArgumentsDeserializationException(JsonSerializerOptio
Assert.NotNull(functionCall.Exception);
}

private static StreamingChatMessageContent CreateStreamingContentWithFunctionCallUpdate(int choiceIndex, int functionCallIndex, string? callId, string? name, string? arguments)
private static StreamingChatMessageContent CreateStreamingContentWithFunctionCallUpdate(int choiceIndex, int functionCallIndex, string? callId, string? name, string? arguments, int requestIndex = 0)
{
var content = new StreamingChatMessageContent(AuthorRole.Assistant, null);

Expand All @@ -171,6 +223,7 @@ private static StreamingChatMessageContent CreateStreamingContentWithFunctionCal
CallId = callId,
Name = name,
Arguments = arguments,
RequestIndex = requestIndex
});

return content;
Expand Down
Loading