Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(ChatClientAgentThread.ThreadState))]
[JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))]
[JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))]
[JsonSerializable(typeof(Functions.ContextualFunctionProvider.ContextualFunctionProviderState))]

[ExcludeFromCodeCoverage]
internal sealed partial class JsonContext : JsonSerializerContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
Expand Down Expand Up @@ -60,6 +61,30 @@ public ContextualFunctionProvider(
int maxNumberOfFunctions,
ContextualFunctionProviderOptions? options = null,
ILoggerFactory? loggerFactory = null)
: this(vectorStore, vectorDimensions, functions, maxNumberOfFunctions, default(JsonElement), options, null, loggerFactory)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="ContextualFunctionProvider"/> class.
/// </summary>
/// <param name="vectorStore">An instance of a vector store.</param>
/// <param name="vectorDimensions">The number of dimensions to use for the memory embeddings.</param>
/// <param name="functions">The functions to vectorize and store for searching related functions.</param>
/// <param name="maxNumberOfFunctions">The maximum number of relevant functions to retrieve from the vector store.</param>
/// <param name="serializedState">A <see cref="JsonElement"/> representing the serialized provider state.</param>
/// <param name="options">Further optional settings for configuring the provider.</param>
/// <param name="jsonSerializerOptions">Optional serializer options. If not provided, <see cref="AgentJsonUtilities.DefaultOptions"/> will be used.</param>
/// <param name="loggerFactory">The logger factory to use for logging. If not provided, no logging will be performed.</param>
public ContextualFunctionProvider(
VectorStore vectorStore,
int vectorDimensions,
IEnumerable<AIFunction> functions,
int maxNumberOfFunctions,
JsonElement serializedState,
ContextualFunctionProviderOptions? options = null,
JsonSerializerOptions? jsonSerializerOptions = null,
ILoggerFactory? loggerFactory = null)
{
Throw.IfNull(vectorStore);
Throw.IfLessThan(vectorDimensions, 1, "Vector dimensions must be greater than 0");
Expand All @@ -81,6 +106,21 @@ public ContextualFunctionProvider(
EmbeddingValueProvider = this._options.EmbeddingValueProvider,
}
);

// Restore recent messages from serialized state if provided
if (serializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined)
{
JsonSerializerOptions jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions;
ContextualFunctionProviderState? state = serializedState.Deserialize(jso.GetTypeInfo(typeof(ContextualFunctionProviderState))) as ContextualFunctionProviderState;
if (state?.RecentMessages is { Count: > 0 })
{
// Restore recent messages respecting the limit (may truncate if limit changed afterwards).
foreach (ChatMessage message in state.RecentMessages.Take(this._options.NumberOfRecentMessagesInContext))
{
this._recentMessages.Enqueue(message);
}
}
}
}

/// <inheritdoc />
Expand Down Expand Up @@ -141,6 +181,22 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken
return default;
}

/// <summary>
/// Serializes the current provider state to a <see cref="JsonElement"/> containing the recent messages.
/// </summary>
/// <param name="jsonSerializerOptions">Optional serializer options. This parameter is not used; <see cref="AgentJsonUtilities.DefaultOptions"/> is always used for serialization.</param>
/// <returns>A <see cref="JsonElement"/> with the recent messages, or default if there are no recent messages.</returns>
public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null)
{
ContextualFunctionProviderState state = new();
if (this._options.NumberOfRecentMessagesInContext > 0 && !this._recentMessages.IsEmpty)
{
state.RecentMessages = this._recentMessages.Take(this._options.NumberOfRecentMessagesInContext).ToList();
}

return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ContextualFunctionProviderState)));
}

/// <summary>
/// Builds the context from chat messages.
/// </summary>
Expand All @@ -166,4 +222,9 @@ private async Task<string> BuildContextAsync(IEnumerable<ChatMessage> newMessage
.Where(m => !string.IsNullOrWhiteSpace(m?.Text))
.Select(m => m.Text));
}

internal sealed class ContextualFunctionProviderState
{
public List<ChatMessage>? RecentMessages { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.AI.Functions;
Expand Down Expand Up @@ -305,7 +306,7 @@ public async Task ContextEmbeddingValueProvider_ReceivesRecentAndNewMessages_Asy
}

[Fact]
public async Task InvokedAsync_ShouldNotAddMessages_WhenExceptionIsPresent_Async()
public void Serialize_WithNoRecentMessages_ShouldReturnEmptyState()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
Expand All @@ -320,7 +321,6 @@ public async Task InvokedAsync_ShouldNotAddMessages_WhenExceptionIsPresent_Async
functions: functions,
maxNumberOfFunctions: 5,
options: options);

var message1 = new ChatMessage() { Contents = [new TextContent("msg1")] };
var message2 = new ChatMessage() { Contents = [new TextContent("msg2")] };
var message3 = new ChatMessage() { Contents = [new TextContent("msg3")] };
Expand All @@ -343,6 +343,188 @@ await provider.InvokedAsync(new AIContextProvider.InvokedContext([message3], nul
var expected = string.Join(Environment.NewLine, ["msg1", "msg2", "new message"]);
this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny<int>(), null, It.IsAny<CancellationToken>()), Times.Once);
}

[Fact]
public async Task InvokedAsync_ShouldNotAddMessages_WhenExceptionIsPresent_Async()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var options = new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 3
};

var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: options);

// Act
JsonElement state = provider.Serialize();

// Assert
Assert.Equal(JsonValueKind.Object, state.ValueKind);
Assert.False(state.TryGetProperty("recentMessages", out _));
}

[Fact]
public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var options = new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 2
};

var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: options);

var messages = new[]
{
new ChatMessage() { Contents = [new TextContent("M1")] },
new ChatMessage() { Contents = [new TextContent("M2")] },
new ChatMessage() { Contents = [new TextContent("M3")] }
};

// Act
await provider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null));
JsonElement state = provider.Serialize();

// Assert
Assert.True(state.TryGetProperty("recentMessages", out JsonElement recentProperty));
Assert.Equal(JsonValueKind.Array, recentProperty.ValueKind);
int count = recentProperty.GetArrayLength();
Assert.Equal(2, count);
}

[Fact]
public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var options = new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 4
};

var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: options);

var messages = new[]
{
new ChatMessage() { Contents = [new TextContent("A")] },
new ChatMessage() { Contents = [new TextContent("B")] },
new ChatMessage() { Contents = [new TextContent("C")] },
new ChatMessage() { Contents = [new TextContent("D")] }
};

await provider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null));

// Act
JsonElement state = provider.Serialize();
var roundTrippedProvider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
serializedState: state,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 4
});

// Trigger search to verify messages are used
var invokingContext = new AIContextProvider.InvokingContext(Array.Empty<ChatMessage>());
await roundTrippedProvider.InvokingAsync(invokingContext);

// Assert
string expected = string.Join(Environment.NewLine, ["A", "B", "C", "D"]);
this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny<int>(), null, It.IsAny<CancellationToken>()), Times.Once);
}

[Fact]
public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var initialProvider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 5
});

var messages = new[]
{
new ChatMessage() { Contents = [new TextContent("L1")] },
new ChatMessage() { Contents = [new TextContent("L2")] },
new ChatMessage() { Contents = [new TextContent("L3")] },
new ChatMessage() { Contents = [new TextContent("L4")] },
new ChatMessage() { Contents = [new TextContent("L5")] }
};

await initialProvider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null));
JsonElement state = initialProvider.Serialize();

// Act
var restoredProvider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
serializedState: state,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 3 // Lower limit
});

var invokingContext = new AIContextProvider.InvokingContext(Array.Empty<ChatMessage>());
await restoredProvider.InvokingAsync(invokingContext);

// Assert
string expected = string.Join(Environment.NewLine, ["L1", "L2", "L3"]);
this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny<int>(), null, It.IsAny<CancellationToken>()), Times.Once);
}

[Fact]
public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
JsonElement emptyState = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement);

// Act
var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
serializedState: emptyState,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 3
});

var invokingContext = new AIContextProvider.InvokingContext(Array.Empty<ChatMessage>());
await provider.InvokingAsync(invokingContext);

// Assert
this._collectionMock.Verify(c => c.SearchAsync(string.Empty, It.IsAny<int>(), null, It.IsAny<CancellationToken>()), Times.Once);
}

private static AIFunction CreateFunction(string name, string description = "")
{
Expand Down