1 change: 1 addition & 0 deletions dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs
@@ -69,6 +69,7 @@ private static JsonSerializerOptions CreateDefaultOptions()
[JsonSerializable(typeof(ChatClientAgentThread.ThreadState))]
[JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))]
[JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))]
[JsonSerializable(typeof(Functions.ContextualFunctionProvider.ContextualFunctionProviderState))]

[ExcludeFromCodeCoverage]
internal sealed partial class JsonContext : JsonSerializerContext;
@@ -4,6 +4,7 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
@@ -60,6 +61,30 @@ public ContextualFunctionProvider(
int maxNumberOfFunctions,
ContextualFunctionProviderOptions? options = null,
ILoggerFactory? loggerFactory = null)
: this(vectorStore, vectorDimensions, functions, maxNumberOfFunctions, default(JsonElement), options, null, loggerFactory)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="ContextualFunctionProvider"/> class.
/// </summary>
/// <param name="vectorStore">An instance of a vector store.</param>
/// <param name="vectorDimensions">The number of dimensions to use for the memory embeddings.</param>
/// <param name="functions">The functions to vectorize and store for searching related functions.</param>
/// <param name="maxNumberOfFunctions">The maximum number of relevant functions to retrieve from the vector store.</param>
/// <param name="serializedState">A <see cref="JsonElement"/> representing the serialized provider state.</param>
/// <param name="options">Further optional settings for configuring the provider.</param>
/// <param name="jsonSerializerOptions">Optional serializer options. If not provided, <see cref="AgentJsonUtilities.DefaultOptions"/> will be used.</param>
/// <param name="loggerFactory">The logger factory to use for logging. If not provided, no logging will be performed.</param>
public ContextualFunctionProvider(
VectorStore vectorStore,
int vectorDimensions,
IEnumerable<AIFunction> functions,
int maxNumberOfFunctions,
JsonElement serializedState,
ContextualFunctionProviderOptions? options = null,
JsonSerializerOptions? jsonSerializerOptions = null,
ILoggerFactory? loggerFactory = null)
{
Throw.IfNull(vectorStore);
Throw.IfLessThan(vectorDimensions, 1, "Vector dimensions must be greater than 0");
@@ -81,6 +106,21 @@ public ContextualFunctionProvider(
EmbeddingValueProvider = this._options.EmbeddingValueProvider,
}
);

// Restore recent messages from serialized state if provided
if (serializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined)
{
JsonSerializerOptions jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions;
ContextualFunctionProviderState? state = serializedState.Deserialize(jso.GetTypeInfo(typeof(ContextualFunctionProviderState))) as ContextualFunctionProviderState;
if (state?.RecentMessages is { Count: > 0 })
{
// Restore recent messages, respecting the current limit (may truncate if the limit was lowered after the state was serialized).
foreach (ChatMessage message in state.RecentMessages.Take(this._options.NumberOfRecentMessagesInContext))
{
this._recentMessages.Enqueue(message);
}
}
}
}

/// <inheritdoc />
@@ -135,6 +175,22 @@ public override ValueTask InvokedAsync(InvokedContext context, CancellationToken
return default;
}

/// <summary>
/// Serializes the current provider state to a <see cref="JsonElement"/> containing the recent messages.
/// </summary>
/// <param name="jsonSerializerOptions">Optional serializer options. This parameter is not used; <see cref="AgentJsonUtilities.DefaultOptions"/> is always used for serialization.</param>
/// <returns>A <see cref="JsonElement"/> with the recent messages, or default if there are no recent messages.</returns>
public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null)
{
ContextualFunctionProviderState state = new();
if (this._options.NumberOfRecentMessagesInContext > 0 && !this._recentMessages.IsEmpty)
{
state.RecentMessages = this._recentMessages.Take(this._options.NumberOfRecentMessagesInContext).ToList();
}

return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ContextualFunctionProviderState)));
}

/// <summary>
/// Builds the context from chat messages.
/// </summary>
@@ -160,4 +216,9 @@ private async Task<string> BuildContextAsync(IEnumerable<ChatMessage> newMessage
.Where(m => !string.IsNullOrWhiteSpace(m?.Text))
.Select(m => m.Text));
}

internal sealed class ContextualFunctionProviderState
{
public List<ChatMessage>? RecentMessages { get; set; }
}
}
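
Taken together, the new Serialize override and the serializedState constructor overload let the provider's recent-message window survive a process restart. A minimal host-side sketch, assuming the application already has a vectorStore, a functions list, and somewhere to persist the JSON (all names below are placeholders, not part of this PR):

// Hypothetical round trip: capture the provider state at the end of a session
// and rehydrate it when the agent is reconstructed later.
var options = new ContextualFunctionProviderOptions { NumberOfRecentMessagesInContext = 4 };

var provider = new ContextualFunctionProvider(
    vectorStore,                 // any VectorStore available to the host
    vectorDimensions: 1536,
    functions,                   // IEnumerable<AIFunction> to vectorize
    maxNumberOfFunctions: 5,
    options: options);

// ... agent runs; InvokedAsync records the recent messages ...

JsonElement state = provider.Serialize();   // always uses AgentJsonUtilities.DefaultOptions
string persisted = state.GetRawText();      // e.g. store next to the thread state

// Later, possibly in a new process:
JsonElement restored = JsonDocument.Parse(persisted).RootElement;
var rehydrated = new ContextualFunctionProvider(
    vectorStore,
    vectorDimensions: 1536,
    functions,
    maxNumberOfFunctions: 5,
    serializedState: restored,
    options: options);

If the restored options specify a smaller NumberOfRecentMessagesInContext than the one in effect when the state was serialized, the constructor truncates the restored messages to the new limit, as exercised by the Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsync test below.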
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Agents.AI.Functions;
@@ -304,6 +305,188 @@ public async Task ContextEmbeddingValueProvider_ReceivesRecentAndNewMessages_Asy
Assert.Equal("msg5", capturedNewMessages.ElementAt(1).Text);
}

[Fact]
public void Serialize_WithNoRecentMessages_ShouldReturnEmptyState()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var options = new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 3
};

var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: options);

// Act
JsonElement state = provider.Serialize();

// Assert
Assert.Equal(JsonValueKind.Object, state.ValueKind);
Assert.False(state.TryGetProperty("recentMessages", out _));
}

[Fact]
public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var options = new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 2
};

var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: options);

var messages = new[]
{
new ChatMessage() { Contents = [new TextContent("M1")] },
new ChatMessage() { Contents = [new TextContent("M2")] },
new ChatMessage() { Contents = [new TextContent("M3")] }
};

// Act
await provider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null));
JsonElement state = provider.Serialize();

// Assert
Assert.True(state.TryGetProperty("recentMessages", out JsonElement recentProperty));
Assert.Equal(JsonValueKind.Array, recentProperty.ValueKind);
int count = recentProperty.GetArrayLength();
Assert.Equal(2, count);
}

[Fact]
public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var options = new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 4
};

var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: options);

var messages = new[]
{
new ChatMessage() { Contents = [new TextContent("A")] },
new ChatMessage() { Contents = [new TextContent("B")] },
new ChatMessage() { Contents = [new TextContent("C")] },
new ChatMessage() { Contents = [new TextContent("D")] }
};

await provider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null));

// Act
JsonElement state = provider.Serialize();
var roundTrippedProvider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
serializedState: state,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 4
});

// Trigger search to verify messages are used
var invokingContext = new AIContextProvider.InvokingContext(Array.Empty<ChatMessage>());
await roundTrippedProvider.InvokingAsync(invokingContext);

// Assert
string expected = string.Join(Environment.NewLine, ["A", "B", "C", "D"]);
this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny<int>(), null, It.IsAny<CancellationToken>()), Times.Once);
}

[Fact]
public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
var initialProvider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 5
});

var messages = new[]
{
new ChatMessage() { Contents = [new TextContent("L1")] },
new ChatMessage() { Contents = [new TextContent("L2")] },
new ChatMessage() { Contents = [new TextContent("L3")] },
new ChatMessage() { Contents = [new TextContent("L4")] },
new ChatMessage() { Contents = [new TextContent("L5")] }
};

await initialProvider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null));
JsonElement state = initialProvider.Serialize();

// Act
var restoredProvider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
serializedState: state,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 3 // Lower limit
});

var invokingContext = new AIContextProvider.InvokingContext(Array.Empty<ChatMessage>());
await restoredProvider.InvokingAsync(invokingContext);

// Assert
string expected = string.Join(Environment.NewLine, ["L1", "L2", "L3"]);
this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny<int>(), null, It.IsAny<CancellationToken>()), Times.Once);
}

[Fact]
public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync()
{
// Arrange
var functions = new List<AIFunction> { CreateFunction("f1") };
JsonElement emptyState = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement);

// Act
var provider = new ContextualFunctionProvider(
vectorStore: this._vectorStoreMock.Object,
vectorDimensions: 1536,
functions: functions,
maxNumberOfFunctions: 5,
serializedState: emptyState,
options: new ContextualFunctionProviderOptions
{
NumberOfRecentMessagesInContext = 3
});

var invokingContext = new AIContextProvider.InvokingContext(Array.Empty<ChatMessage>());
await provider.InvokingAsync(invokingContext);

// Assert
this._collectionMock.Verify(c => c.SearchAsync(string.Empty, It.IsAny<int>(), null, It.IsAny<CancellationToken>()), Times.Once);
}

private static AIFunction CreateFunction(string name, string description = "")
{
return AIFunctionFactory.Create(() => { }, name, description);