Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add prompt execution settings to AutoFunctionInvocationContext #10551

Merged
merged 10 commits into from
Feb 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,65 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
Assert.Equal(isStreaming, actualStreamingFlag);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync()
{
// Arrange
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

AutoFunctionInvocationContext? actualContext = null;

var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
{
actualContext = context;
return Task.CompletedTask;
});

var expectedExecutionSettings = new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

// Act
var result = await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings));

// Assert
Assert.NotNull(actualContext);
Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync()
{
// Arrange
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

AutoFunctionInvocationContext? actualContext = null;

var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
{
actualContext = context;
return Task.CompletedTask;
});

var expectedExecutionSettings = new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

// Act
await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(expectedExecutionSettings)))
{ }

// Assert
Assert.NotNull(actualContext);
Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
Expand Down Expand Up @@ -384,6 +385,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// Processes AI function calls by iterating over the function calls, invoking them and adding the results to the chat history.
/// </summary>
/// <param name="chatMessageContent">The chat message content representing AI model response and containing function calls.</param>
/// <param name="executionSettings">The prompt execution settings.</param>
/// <param name="chatHistory">The chat history to add function invocation results to.</param>
/// <param name="requestIndex">AI model function(s) call request sequence index.</param>
/// <param name="checkIfFunctionAdvertised">Callback to check if a function was advertised to AI model or not.</param>
Expand All @@ -129,6 +130,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// <returns>Last chat history message if function invocation filter requested processing termination, otherwise null.</returns>
public async Task<ChatMessageContent?> ProcessFunctionCallsAsync(
ChatMessageContent chatMessageContent,
PromptExecutionSettings? executionSettings,
ChatHistory chatHistory,
int requestIndex,
Func<FunctionCallContent, bool> checkIfFunctionAdvertised,
Expand Down Expand Up @@ -177,7 +179,8 @@ public FunctionCallsProcessor(ILogger? logger = null)
FunctionCount = functionCalls.Length,
CancellationToken = cancellationToken,
IsStreaming = isStreaming,
ToolCallId = functionCall.Id
ToolCallId = functionCall.Id,
ExecutionSettings = executionSettings
};

s_inflightAutoInvokes.Value++;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.SemanticKernel.ChatCompletion;

Expand Down Expand Up @@ -79,6 +80,12 @@ public AutoFunctionInvocationContext(
/// </summary>
public ChatMessageContent ChatMessageContent { get; }

/// <summary>
/// The execution settings associated with the operation.
/// </summary>
[Experimental("SKEXP0001")]
public PromptExecutionSettings? ExecutionSettings { get; init; }

/// <summary>
/// Gets the <see cref="Microsoft.SemanticKernel.ChatCompletion.ChatHistory"/> associated with automatic function invocation.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
using Microsoft.SemanticKernel.ChatCompletion;
#pragma warning disable IDE0005 // Using directive is unnecessary
using Microsoft.SemanticKernel.Connectors.FunctionCalling;
using Microsoft.SemanticKernel.Connectors.OpenAI;

#pragma warning restore IDE0005 // Using directive is unnecessary
using Moq;
using Xunit;
Expand All @@ -21,6 +23,7 @@ public class FunctionCallsProcessorTests
{
private readonly FunctionCallsProcessor _sut = new();
private readonly FunctionChoiceBehaviorOptions _functionChoiceBehaviorOptions = new();
private readonly OpenAIPromptExecutionSettings _openAIPromptExecutionSettings = new();

[Fact]
public void ItShouldReturnNoConfigurationIfNoBehaviorProvided()
Expand Down Expand Up @@ -94,6 +97,7 @@ async Task ProcessFunctionCallsRecursivelyToReachInflightLimitAsync()

await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: [],
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -123,6 +127,7 @@ public async Task ItShouldAddFunctionCallAssistantMessageToChatHistoryAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -151,6 +156,7 @@ public async Task ItShouldAddFunctionCallExceptionToChatHistoryAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -184,6 +190,7 @@ public async Task ItShouldAddFunctionInvocationExceptionToChatHistoryAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -212,6 +219,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionCallNotAdvertisedAsync(
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => false, // Return false to simulate that the function is not advertised
Expand Down Expand Up @@ -240,6 +248,7 @@ public async Task ItShouldAddErrorToChatHistoryIfFunctionIsNotRegisteredOnKernel
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -281,6 +290,7 @@ public async Task ItShouldInvokeFunctionsAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -347,6 +357,7 @@ public async Task ItShouldInvokeFiltersAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -436,6 +447,7 @@ public async Task ItShouldInvokeMultipleFiltersInOrderAsync(bool invokeConcurren
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -484,6 +496,7 @@ public async Task FilterCanOverrideArgumentsAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -536,6 +549,7 @@ public async Task FilterCanHandleExceptionAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -588,6 +602,7 @@ public async Task FiltersCanSkipFunctionExecutionAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -634,6 +649,7 @@ public async Task PreFilterCanTerminateOperationAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -678,6 +694,7 @@ public async Task PostFilterCanTerminateOperationAsync(bool invokeConcurrently)
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -732,6 +749,7 @@ public async Task ItShouldHandleChatMessageContentAsFunctionResultAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -767,6 +785,7 @@ public async Task ItShouldSerializeFunctionResultOfUnknowTypeAsync()
// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: chatHistory,
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
Expand Down Expand Up @@ -837,6 +856,40 @@ public void ItShouldSerializeFunctionResultsWithStringProperties()
Assert.Equal("{\"Text\":\"テスト\"}", result);
}

[Fact]
public async Task ItShouldPassPromptExecutionSettingsToAutoFunctionInvocationFilterAsync()
{
// Arrange
var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

AutoFunctionInvocationContext? actualContext = null;

Kernel kernel = CreateKernel(plugin, (context, next) =>
{
actualContext = context;
return Task.CompletedTask;
});

var chatMessageContent = new ChatMessageContent();
chatMessageContent.Items.Add(new FunctionCallContent("Function1", "MyPlugin", arguments: new KernelArguments() { ["parameter"] = "function1-result" }));

// Act
await this._sut.ProcessFunctionCallsAsync(
chatMessageContent: chatMessageContent,
executionSettings: this._openAIPromptExecutionSettings,
chatHistory: new ChatHistory(),
requestIndex: 0,
checkIfFunctionAdvertised: (_) => true,
options: this._functionChoiceBehaviorOptions,
kernel: kernel!,
isStreaming: false,
cancellationToken: CancellationToken.None);

// Assert
Assert.NotNull(actualContext);
Assert.Same(this._openAIPromptExecutionSettings, actualContext!.ExecutionSettings);
}

private sealed class AutoFunctionInvocationFilter(
Func<AutoFunctionInvocationContext, Func<AutoFunctionInvocationContext, Task>, Task>? onAutoFunctionInvocation) : IAutoFunctionInvocationFilter
{
Expand Down
Loading