Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add prompt execution settings to AutoFunctionInvocationContext #10551

Merged
merged 10 commits into from
Feb 24, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,65 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
Assert.Equal(isStreaming, actualStreamingFlag);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptToFilterContextAsync()
{
// Arrange
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingResponses();

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

AutoFunctionInvocationContext? actualContext = null;

var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
{
actualContext = context;
return Task.CompletedTask;
});

var expectedExecutionSettings = new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

// Act
var result = await kernel.InvokePromptAsync("Test prompt", new(expectedExecutionSettings));

// Assert
Assert.NotNull(actualContext);
Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
}

[Fact]
public async Task PromptExecutionSettingsArePropagatedFromInvokePromptStreamingToFilterContextAsync()
{
// Arrange
this._messageHandlerStub.ResponsesToReturn = GetFunctionCallingStreamingResponses();

var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [KernelFunctionFactory.CreateFromMethod(() => { }, "Function1")]);

AutoFunctionInvocationContext? actualContext = null;

var kernel = this.GetKernelWithFilter(plugin, (context, next) =>
{
actualContext = context;
return Task.CompletedTask;
});

var expectedExecutionSettings = new OpenAIPromptExecutionSettings
{
ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions
};

// Act
await foreach (var item in kernel.InvokePromptStreamingAsync("Test prompt", new(expectedExecutionSettings)))
{ }

// Assert
Assert.NotNull(actualContext);
Assert.Same(expectedExecutionSettings, actualContext!.ExecutionSettings);
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
Expand Down Expand Up @@ -384,6 +385,7 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
// In such cases, we'll return the last message in the chat history.
var lastMessage = await this.FunctionCallsProcessor.ProcessFunctionCallsAsync(
chatMessageContent,
chatExecutionSettings,
chatHistory,
requestIndex,
(FunctionCallContent content) => IsRequestableTool(chatOptions.Tools, content),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// Processes AI function calls by iterating over the function calls, invoking them and adding the results to the chat history.
/// </summary>
/// <param name="chatMessageContent">The chat message content representing AI model response and containing function calls.</param>
/// <param name="executionSettings">The prompt execution settings.</param>
/// <param name="chatHistory">The chat history to add function invocation results to.</param>
/// <param name="requestIndex">AI model function(s) call request sequence index.</param>
/// <param name="checkIfFunctionAdvertised">Callback to check if a function was advertised to AI model or not.</param>
Expand All @@ -129,6 +130,7 @@ public FunctionCallsProcessor(ILogger? logger = null)
/// <returns>Last chat history message if function invocation filter requested processing termination, otherwise null.</returns>
public async Task<ChatMessageContent?> ProcessFunctionCallsAsync(
ChatMessageContent chatMessageContent,
PromptExecutionSettings? executionSettings,
ChatHistory chatHistory,
int requestIndex,
Func<FunctionCallContent, bool> checkIfFunctionAdvertised,
Expand Down Expand Up @@ -177,7 +179,8 @@ public FunctionCallsProcessor(ILogger? logger = null)
FunctionCount = functionCalls.Length,
CancellationToken = cancellationToken,
IsStreaming = isStreaming,
ToolCallId = functionCall.Id
ToolCallId = functionCall.Id,
ExecutionSettings = executionSettings
};

s_inflightAutoInvokes.Value++;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Microsoft.SemanticKernel.ChatCompletion;

Expand Down Expand Up @@ -79,6 +80,12 @@ public AutoFunctionInvocationContext(
/// </summary>
public ChatMessageContent ChatMessageContent { get; }

/// <summary>
/// The execution settings associated with the operation.
/// </summary>
[Experimental("SKEXP0001")]
public PromptExecutionSettings? ExecutionSettings { get; init; }

/// <summary>
/// Gets the <see cref="Microsoft.SemanticKernel.ChatCompletion.ChatHistory"/> associated with automatic function invocation.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Threading;

namespace Microsoft.SemanticKernel;
Expand Down Expand Up @@ -54,6 +55,12 @@ internal PromptRenderContext(Kernel kernel, KernelFunction function, KernelArgum
/// </summary>
public KernelArguments Arguments { get; }

/// <summary>
/// The execution settings associated with the operation.
/// </summary>
[Experimental("SKEXP0001")]
public PromptExecutionSettings? ExecutionSettings { get; init; }

/// <summary>
/// Gets or sets the rendered prompt.
/// </summary>
Expand Down
4 changes: 3 additions & 1 deletion dotnet/src/SemanticKernel.Abstractions/Kernel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,15 @@ internal async Task<PromptRenderContext> OnPromptRenderAsync(
KernelFunction function,
KernelArguments arguments,
bool isStreaming,
PromptExecutionSettings? executionSettings,
Func<PromptRenderContext, Task> renderCallback,
CancellationToken cancellationToken)
{
PromptRenderContext context = new(this, function, arguments)
{
CancellationToken = cancellationToken,
IsStreaming = isStreaming
IsStreaming = isStreaming,
ExecutionSettings = executionSettings
};

await InvokeFilterOrPromptRenderAsync(this._promptRenderFilters, renderCallback, context).ConfigureAwait(false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ private async Task<PromptRenderingResult> RenderPromptAsync(

Verify.NotNull(aiService);

var renderingContext = await kernel.OnPromptRenderAsync(this, arguments, isStreaming, async (context) =>
var renderingContext = await kernel.OnPromptRenderAsync(this, arguments, isStreaming, executionSettings, async (context) =>
{
renderedPrompt = await this._promptTemplate.RenderAsync(kernel, context.Arguments, cancellationToken).ConfigureAwait(false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,42 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)
// Assert
Assert.Equal(isStreaming, actualStreamingFlag);
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task PromptExecutionSettingsArePropagatedToFilterContextAsync(bool isStreaming)
{
// Arrange
PromptExecutionSettings? actualExecutionSettings = null;

var mockTextGeneration = this.GetMockTextGeneration();

var function = KernelFunctionFactory.CreateFromPrompt("Prompt");

var kernel = this.GetKernelWithFilters(textGenerationService: mockTextGeneration.Object,
onPromptRender: (context, next) =>
{
actualExecutionSettings = context.ExecutionSettings;
return next(context);
});

var expectedExecutionSettings = new PromptExecutionSettings();

var arguments = new KernelArguments(expectedExecutionSettings);

// Act
if (isStreaming)
{
await foreach (var item in kernel.InvokeStreamingAsync(function, arguments))
{ }
}
else
{
await kernel.InvokeAsync(function, arguments);
}

// Assert
Assert.Same(expectedExecutionSettings, actualExecutionSettings);
}
}
Loading
Loading