diff --git a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs index c400a1cb6c..c9de4cbc38 100644 --- a/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs +++ b/dotnet/src/Microsoft.Agents.AI/AgentJsonUtilities.cs @@ -69,6 +69,7 @@ private static JsonSerializerOptions CreateDefaultOptions() [JsonSerializable(typeof(ChatClientAgentThread.ThreadState))] [JsonSerializable(typeof(TextSearchProvider.TextSearchProviderState))] [JsonSerializable(typeof(ChatHistoryMemoryProvider.ChatHistoryMemoryProviderState))] + [JsonSerializable(typeof(Functions.ContextualFunctionProvider.ContextualFunctionProviderState))] [ExcludeFromCodeCoverage] internal sealed partial class JsonContext : JsonSerializerContext; diff --git a/dotnet/src/Microsoft.Agents.AI/Functions/ContextualFunctionProvider.cs b/dotnet/src/Microsoft.Agents.AI/Functions/ContextualFunctionProvider.cs new file mode 100644 index 0000000000..6caf8b1449 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/Functions/ContextualFunctionProvider.cs @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.VectorData; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI.Functions; + +/// +/// Represents a contextual function provider that performs RAG (Retrieval-Augmented Generation) on the provided functions to identify +/// the most relevant functions for the current context. The provider vectorizes the provided function names and descriptions +/// and stores them in the specified vector store, allowing for a vector search to find the most relevant +/// functions for a given context and provide the functions to the AI model/agent. +/// +/// +/// +/// +/// The provider is designed to work with in-memory vector stores. Using other vector stores +/// will require the data synchronization and data lifetime management to be done by the caller. +/// +/// +/// The in-memory vector store is supposed to be created per provider and not shared between providers +/// unless each provider uses a different collection name. Not following this may lead to a situation +/// where one provider identifies a function belonging to another provider as relevant and, as a result, +/// an attempt to access it by the first provider will fail because the function is not registered with it. +/// +/// +/// The provider uses function name as a key for the records and as such the specified vector store +/// should support record keys of string type. +/// +/// +/// +public sealed class ContextualFunctionProvider : AIContextProvider +{ + private readonly FunctionStore _functionStore; + private readonly ConcurrentQueue _recentMessages = []; + private readonly ContextualFunctionProviderOptions _options; + private bool _areFunctionsVectorized; + + /// + /// Initializes a new instance of the class. + /// + /// An instance of a vector store. + /// The number of dimensions to use for the memory embeddings. + /// The functions to vectorize and store for searching related functions. + /// The maximum number of relevant functions to retrieve from the vector store. + /// Further optional settings for configuring the provider. + /// The logger factory to use for logging. If not provided, no logging will be performed. + public ContextualFunctionProvider( + VectorStore vectorStore, + int vectorDimensions, + IEnumerable functions, + int maxNumberOfFunctions, + ContextualFunctionProviderOptions? options = null, + ILoggerFactory? loggerFactory = null) + : this(vectorStore, vectorDimensions, functions, maxNumberOfFunctions, default, options, null, loggerFactory) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// An instance of a vector store. + /// The number of dimensions to use for the memory embeddings. + /// The functions to vectorize and store for searching related functions. + /// The maximum number of relevant functions to retrieve from the vector store. + /// A representing the serialized provider state. + /// Further optional settings for configuring the provider. + /// Optional serializer options. If not provided, will be used. + /// The logger factory to use for logging. If not provided, no logging will be performed. + public ContextualFunctionProvider( + VectorStore vectorStore, + int vectorDimensions, + IEnumerable functions, + int maxNumberOfFunctions, + JsonElement serializedState, + ContextualFunctionProviderOptions? options = null, + JsonSerializerOptions? jsonSerializerOptions = null, + ILoggerFactory? loggerFactory = null) + { + Throw.IfNull(vectorStore); + Throw.IfLessThan(vectorDimensions, 1, "Vector dimensions must be greater than 0"); + Throw.IfNull(functions); + Throw.IfLessThan(maxNumberOfFunctions, 1, "Max number of functions must be greater than 0"); + + this._options = options ?? new ContextualFunctionProviderOptions(); + Throw.IfLessThan(this._options.NumberOfRecentMessagesInContext, 1, "Number of recent messages to include into context must be greater than 0"); + + this._functionStore = new FunctionStore( + vectorStore, + string.IsNullOrWhiteSpace(this._options.CollectionName) ? "functions" : this._options.CollectionName, + vectorDimensions, + functions, + maxNumberOfFunctions, + loggerFactory, + options: new() + { + EmbeddingValueProvider = this._options.EmbeddingValueProvider, + } + ); + + // Restore recent messages from serialized state if provided + if (serializedState.ValueKind is not JsonValueKind.Null and not JsonValueKind.Undefined) + { + JsonSerializerOptions jso = jsonSerializerOptions ?? AgentJsonUtilities.DefaultOptions; + ContextualFunctionProviderState? state = serializedState.Deserialize(jso.GetTypeInfo(typeof(ContextualFunctionProviderState))) as ContextualFunctionProviderState; + if (state?.RecentMessages is { Count: > 0 }) + { + // Restore recent messages respecting the limit (may truncate if limit changed afterwards). + foreach (ChatMessage message in state.RecentMessages.Take(this._options.NumberOfRecentMessagesInContext)) + { + this._recentMessages.Enqueue(message); + } + } + } + } + + /// + public override async ValueTask InvokingAsync(InvokingContext context, CancellationToken cancellationToken = default) + { + Throw.IfNull(context); + + // Vectorize the functions if they are not already vectorized + if (!this._areFunctionsVectorized) + { + await this._functionStore.SaveAsync(cancellationToken).ConfigureAwait(false); + + this._areFunctionsVectorized = true; + } + + // Build the search context + var searchContext = await this.BuildContextAsync(context.RequestMessages, cancellationToken).ConfigureAwait(false); + + // Get the function relevant to the context + var functions = await this._functionStore + .SearchAsync(searchContext, cancellationToken: cancellationToken) + .ConfigureAwait(false); + + return new AIContext { Tools = [.. functions] }; + } + + /// + public override ValueTask InvokedAsync(InvokedContext context, CancellationToken cancellationToken = default) + { + Throw.IfNull(context); + + // Don't add messages to the recent messages queue if the invocation failed + if (context.InvokeException is not null) + { + return default; + } + + // Add the request and response messages to the recent messages queue + foreach (var message in context.RequestMessages) + { + this._recentMessages.Enqueue(message); + } + + if (context.ResponseMessages is not null) + { + foreach (var message in context.ResponseMessages) + { + this._recentMessages.Enqueue(message); + } + } + + // If there are more messages than the configured limit, remove the oldest ones + while (this._recentMessages.Count > this._options.NumberOfRecentMessagesInContext) + { + this._recentMessages.TryDequeue(out _); + } + + return default; + } + + /// + /// Serializes the current provider state to a containing the recent messages. + /// + /// Optional serializer options. This parameter is not used; is always used for serialization. + /// A with the recent messages, or default if there are no recent messages. + public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptions = null) + { + ContextualFunctionProviderState state = new(); + if (this._options.NumberOfRecentMessagesInContext > 0 && !this._recentMessages.IsEmpty) + { + state.RecentMessages = this._recentMessages.Take(this._options.NumberOfRecentMessagesInContext).ToList(); + } + + return JsonSerializer.SerializeToElement(state, AgentJsonUtilities.DefaultOptions.GetTypeInfo(typeof(ContextualFunctionProviderState))); + } + + /// + /// Builds the context from chat messages. + /// + /// The new messages. + /// The cancellation token to use for cancellation. + private async Task BuildContextAsync(IEnumerable newMessages, CancellationToken cancellationToken) + { + if (this._options.ContextEmbeddingValueProvider is not null) + { + // Ensure we only take the recent messages up to the configured limit + var recentMessages = this._recentMessages + .Skip(Math.Max(0, this._recentMessages.Count - this._options.NumberOfRecentMessagesInContext)); + + return await this._options.ContextEmbeddingValueProvider.Invoke(recentMessages, newMessages, cancellationToken).ConfigureAwait(false); + } + + // Build context by concatenating the recent messages and the new messages + return string.Join( + Environment.NewLine, + this._recentMessages + .Skip(Math.Max(0, this._recentMessages.Count - this._options.NumberOfRecentMessagesInContext)) + .Concat(newMessages) + .Where(m => !string.IsNullOrWhiteSpace(m?.Text)) + .Select(m => m.Text)); + } + + internal sealed class ContextualFunctionProviderState + { + public List? RecentMessages { get; set; } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI/Functions/ContextualFunctionProviderOptions.cs b/dotnet/src/Microsoft.Agents.AI/Functions/ContextualFunctionProviderOptions.cs new file mode 100644 index 0000000000..f56a18b183 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/Functions/ContextualFunctionProviderOptions.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Functions; + +/// +/// Options for the . +/// +public sealed class ContextualFunctionProviderOptions +{ + /// + /// Gets or sets the collection name to use for storing and retrieving functions. + /// + /// If not set, the default value "functions" will be used. + public string? CollectionName { get; set; } + + /// + /// Gets or sets the number of recent messages (messages from previous model/agent invocations) the provider uses to form a context. + /// The provider collects all messages from all model/agent invocations, up to this number, + /// and prepends them to the new messages of the current model/agent invocation to build a context. + /// While collecting new messages, the provider will remove the oldest messages + /// to keep the number of recent messages within the specified limit. + /// + /// + /// Using the recent messages together with the new messages can be very useful + /// in cases where the model/agent is prompted to perform a task that requires details from + /// previous invocation(s). For example, if the agent is asked to provision an Azure resource in the first + /// invocation and deploy the resource in the second invocation, the second invocation will need + /// information about the provisioned resource in the first invocation to deploy it. + /// + public int NumberOfRecentMessagesInContext { get; set; } = 2; + + /// + /// Gets or sets a callback function that returns a value used to create a context embedding. The value is vectorized, + /// and the resulting vector is used to perform vector searches for functions relevant to the context. + /// If not provided, the default behavior is to concatenate the non-empty messages into a single string, + /// separated by a new line. + /// + /// + /// The callback receives three parameters: + /// `recentMessages` - messages from the previous model/agent invocations. + /// `newMessages` - the new messages of the current model/agent invocation. + /// `cancellationToken` - a cancellation token that can be used to cancel the operation. + /// + public Func, IEnumerable, CancellationToken, Task>? ContextEmbeddingValueProvider { get; set; } + + /// + /// Gets or sets a callback function that returns a value used to create a function embedding. The value is vectorized, + /// and the resulting vector is stored in the vector store for use in vector searches for functions relevant + /// to the context. + /// If not provided, the default behavior is to concatenate the function name and description into a single string. + /// + /// + /// The callback receives two parameters: + /// `function` - the function to get embedding value for. + /// `cancellationToken` - a cancellation token that can be used to cancel the operation. + /// + public Func>? EmbeddingValueProvider { get; set; } +} diff --git a/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStore.cs b/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStore.cs new file mode 100644 index 0000000000..45e69bb327 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStore.cs @@ -0,0 +1,185 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.VectorData; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.Agents.AI.Functions; + +/// +/// Represents a vector store for objects where the function name and description can be used for similarity searches. +/// +internal sealed class FunctionStore +{ + private readonly VectorStore _vectorStore; + private readonly Dictionary _functionByName; + private readonly string _collectionName; + private readonly int _maxNumberOfFunctions; + private readonly ILogger _logger; + private readonly FunctionStoreOptions _options; + private readonly VectorStoreCollection> _collection; + private bool _isCollectionExistenceAsserted; + + /// + /// Initializes a new instance of the class. + /// + /// The vector store to use for storing functions. + /// The name of the collection to use for storing and retrieving functions. + /// The number of dimensions to use for the memory embeddings. + /// The functions to vectorize and store for searching related functions. + /// The maximum number of relevant functions to retrieve from the vector store. + /// The logger factory to use for logging. If not provided, no logging will be performed. + /// The options to use for the function store. + internal FunctionStore( + VectorStore vectorStore, + string collectionName, + int vectorDimensions, + IEnumerable functions, + int maxNumberOfFunctions, + ILoggerFactory? loggerFactory = default, + FunctionStoreOptions? options = null) + { + Throw.IfNull(vectorStore); + Throw.IfNullOrWhitespace(collectionName); + Throw.IfLessThan(vectorDimensions, 1, "Vector dimensions must be greater than 0"); + Throw.IfNull(functions); + Throw.IfLessThan(maxNumberOfFunctions, 1, "Max number of functions must be greater than 0"); + + this._vectorStore = vectorStore; + this._collectionName = collectionName; + this._functionByName = functions.ToDictionary(function => function.Name); + this._maxNumberOfFunctions = maxNumberOfFunctions; + this._logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); + this._options = options ?? new FunctionStoreOptions(); + + // Create and assert the collection support record keys of string type + this._collection = this._vectorStore.GetDynamicCollection(collectionName, new VectorStoreCollectionDefinition() + { + Properties = [ + new VectorStoreKeyProperty("Name", typeof(string)), + new VectorStoreVectorProperty("Embedding", typeof(string), dimensions: vectorDimensions) + ] + }); + } + + /// + /// Saves the functions to the vector store. + /// + /// The cancellation token to use for cancellation. + public async Task SaveAsync(CancellationToken cancellationToken = default) + { + // Get function data to vectorize + var nameSourcePairs = await this.GetFunctionsVectorizationInfoAsync(cancellationToken).ConfigureAwait(false); + + var functionRecords = new List>(nameSourcePairs.Count); + + // Create vector store records + for (var i = 0; i < nameSourcePairs.Count; i++) + { + var (name, vectorizationSource) = nameSourcePairs[i]; + + functionRecords.Add(new Dictionary() + { + ["Name"] = name, + ["Embedding"] = vectorizationSource + }); + } + + // Create collection and upsert all vector store records + await this._collection.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false); + + await this._collection.UpsertAsync(functionRecords, cancellationToken: cancellationToken).ConfigureAwait(false); + } + + /// + /// Searches for functions based on the provided context. + /// + /// The context to search for functions. + /// The cancellation token to use for cancellation. + public async Task> SearchAsync(string context, CancellationToken cancellationToken = default) + { + await this.AssertCollectionExistsAsync(cancellationToken).ConfigureAwait(false); + + List>> results = new(); + + await foreach (var result in this._collection + .SearchAsync(context, top: this._maxNumberOfFunctions, cancellationToken: cancellationToken).ConfigureAwait(false)) + { + results.Add(result); + } + + this._logger.LogFunctionsSearchResults(context, this._maxNumberOfFunctions, results); + + return results.Select(result => this._functionByName[(string)result.Record["Name"]!]); + } + + /// + /// Get the function vectorization information, which includes the function name and the source used for vectorization. + /// + /// The cancellation token to use for cancellation. + /// The function name and vectorization source pairs. + private async Task> GetFunctionsVectorizationInfoAsync(CancellationToken cancellationToken) + { + List nameSourcePairs = new(this._functionByName.Count); + + var provider = this._options.EmbeddingValueProvider ?? ((function, _) => + { + string descriptionPart = string.IsNullOrEmpty(function.Description) ? string.Empty : $", description: {function.Description}"; + return Task.FromResult($"Function name: {function.Name}{descriptionPart}"); + }); + + foreach (KeyValuePair pair in this._functionByName) + { + var vectorizationSource = await provider.Invoke(pair.Value, cancellationToken).ConfigureAwait(false); + + nameSourcePairs.Add(new FunctionVectorizationInfo(pair.Key, vectorizationSource)); + } + + this._logger.LogFunctionsVectorizationInfo(nameSourcePairs); + + return nameSourcePairs; + } + + /// + /// Asserts that the collection exists in the vector store. + /// + /// The cancellation token to use for cancellation. + private async Task AssertCollectionExistsAsync(CancellationToken cancellationToken) + { + if (!this._isCollectionExistenceAsserted) + { + if (!await this._collection.CollectionExistsAsync(cancellationToken).ConfigureAwait(false)) + { + throw new InvalidOperationException($"Collection '{this._collectionName}' does not exist."); + } + + this._isCollectionExistenceAsserted = true; + } + } + + internal readonly struct FunctionVectorizationInfo + { + public string Name { get; } + + public string VectorizationSource { get; } + + public FunctionVectorizationInfo(string name, string vectorizationSource) + { + this.Name = name; + this.VectorizationSource = vectorizationSource; + } + + public void Deconstruct(out string name, out string vectorizationSource) + { + name = this.Name; + vectorizationSource = this.VectorizationSource; + } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStoreLoggingExtensions.cs b/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStoreLoggingExtensions.cs new file mode 100644 index 0000000000..ac13d83e88 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStoreLoggingExtensions.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.VectorData; + +namespace Microsoft.Agents.AI.Functions; + +[ExcludeFromCodeCoverage] +internal static class FunctionStoreLoggingExtensions +{ + internal static void LogFunctionsVectorizationInfo(this ILogger logger, IList vectorizationInfo) + { + logger.LogInformation("ContextualFunctionProvider: Number of function to vectorize: {Count}", vectorizationInfo.Count); + + if (logger.IsEnabled(LogLevel.Trace)) + { + logger.LogTrace("ContextualFunctionProvider: Functions vectorization info: {VectorizationInfo}", + string.Join(", ", vectorizationInfo.Select(info => $"\"Function: {info.Name}, VectorizationSource: {info.VectorizationSource}\""))); + } + } + + internal static void LogFunctionsSearchResults(this ILogger logger, string context, int maxNumberOfFunctionsToReturn, IList>> results) + { + logger.LogInformation("ContextualFunctionProvider: Search returned {Count} functions, with a maximum limit of {MaxCount}", results.Count, maxNumberOfFunctionsToReturn); + + if (logger.IsEnabled(LogLevel.Trace)) + { + logger.LogTrace("ContextualFunctionProvider: Functions search results for context {Context} with a maximum limit of {MaxCount}: {Results}", + $"\"{context}\"", + maxNumberOfFunctionsToReturn, + string.Join(", ", results.Select(result => $"\"Function: {result.Record["Name"]}, Score: {result.Score}\""))); + } + } +} diff --git a/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStoreOptions.cs b/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStoreOptions.cs new file mode 100644 index 0000000000..3665efb579 --- /dev/null +++ b/dotnet/src/Microsoft.Agents.AI/Functions/FunctionStoreOptions.cs @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; + +namespace Microsoft.Agents.AI.Functions; + +/// +/// Options for the +/// +internal sealed class FunctionStoreOptions +{ + /// + /// A callback function that returns a value used to create a function embedding. The value is vectorized, + /// and the resulting vector is stored in the vector store for use in vector searches for functions relevant + /// to the context. + /// If not provided, the default behavior is to concatenate the function name and description into a single string. + /// + public Func>? EmbeddingValueProvider { get; set; } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Functions/ContextualFunctionProviderTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Functions/ContextualFunctionProviderTests.cs new file mode 100644 index 0000000000..5a88c1ad69 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Functions/ContextualFunctionProviderTests.cs @@ -0,0 +1,533 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Agents.AI.Functions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests.Functions; + +/// +/// Contains unit tests for the class. +/// +public sealed class ContextualFunctionProviderTests +{ + private readonly Mock _vectorStoreMock; + private readonly Mock>> _collectionMock; + + public ContextualFunctionProviderTests() + { + this._vectorStoreMock = new Mock(MockBehavior.Strict); + this._collectionMock = new Mock>>(MockBehavior.Strict); + + this._vectorStoreMock + .Setup(vs => vs.GetDynamicCollection(It.IsAny(), It.IsAny())) + .Returns(this._collectionMock.Object); + + this._collectionMock + .Setup(c => c.CollectionExistsAsync(It.IsAny())) + .ReturnsAsync(true); + + this._collectionMock + .Setup(c => c.EnsureCollectionExistsAsync(It.IsAny())) + .Returns(Task.CompletedTask); + + this._collectionMock + .Setup(c => c.UpsertAsync(It.IsAny>>(), It.IsAny())) + .Returns(Task.CompletedTask); + + this._collectionMock + .Setup(c => c.SearchAsync(It.IsAny(), It.IsAny(), null, It.IsAny())) + .Returns(AsyncEnumerable.Empty>>()); + } + + [Fact] + public void Constructor_ShouldThrow_OnInvalidArguments() + { + // Arrange + var vectorStore = new Mock().Object; + var functions = new List { CreateFunction("f1") }; + + // Act & Assert + Assert.Throws(() => new ContextualFunctionProvider(null!, 1, functions, 3)); + Assert.Throws(() => new ContextualFunctionProvider(vectorStore, 0, functions, 3)); + Assert.Throws(() => new ContextualFunctionProvider(vectorStore, 1, null!, 3)); + } + + [Fact] + public async Task Invoking_ShouldVectorizeFunctions_Once_Async() + { + // Arrange + var function = CreateFunction("f1", "desc"); + var functions = new List { function }; + + this._collectionMock + .Setup(c => c.UpsertAsync(It.IsAny>>(), It.IsAny())) + .Returns(Task.CompletedTask); + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5); + + var messages = new List { new() { Contents = [new TextContent("hello")] } }; + var context = new AIContextProvider.InvokingContext(messages); + + // Act + await provider.InvokingAsync(context); + await provider.InvokingAsync(context); + + // Assert + this._collectionMock.Verify( + c => c.UpsertAsync(It.IsAny>>(), It.IsAny()), + Times.Once); + } + + [Fact] + public async Task Invoking_ShouldReturnRelevantFunctions_Async() + { + // Arrange + var function = CreateFunction("f1", "desc"); + var functions = new List { function }; + + var searchResult = new VectorSearchResult>( + new Dictionary + { + ["Name"] = function.Name, + ["Description"] = function.Description + }, + 0.99f + ); + + this._collectionMock + .Setup(c => c.SearchAsync(It.IsAny(), It.IsAny(), null, It.IsAny())) + .Returns(new[] { searchResult }.ToAsyncEnumerable()); + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5); + + var messages = new List { new() { Contents = [new TextContent("context")] } }; + var context = new AIContextProvider.InvokingContext(messages); + + // Act + var result = await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(result); + Assert.NotNull(result.Tools); + Assert.Single(result.Tools); + Assert.Equal("f1", result.Tools[0].Name); + this._collectionMock.Verify( + c => c.SearchAsync("context", 5, null, It.IsAny()), + Times.Once); + } + + [Fact] + public async Task BuildContext_ShouldUseContextEmbeddingValueProvider_Async() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + var options = new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 3, + ContextEmbeddingValueProvider = (recentMessages, newMessages, _) => + { + Assert.Equal(3, recentMessages.Count()); + Assert.Single(newMessages); + return Task.FromResult("custom context"); + } + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + + var message1 = new ChatMessage() { Contents = [new TextContent("msg1")] }; + var message2 = new ChatMessage() { Contents = [new TextContent("msg2")] }; + var message3 = new ChatMessage() { Contents = [new TextContent("msg3")] }; + var message4 = new ChatMessage() { Contents = [new TextContent("msg4")] }; + var message5 = new ChatMessage() { Contents = [new TextContent("msg5")] }; + + // Simulate previous invocations to populate recent messages + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message1], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message2], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message3], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message4], null) { ResponseMessages = [] }); + + var messages = new List { message5 }; + var context = new AIContextProvider.InvokingContext(messages); + + // Act + await provider.InvokingAsync(context); + + // Assert + this._collectionMock.Verify( + c => c.SearchAsync("custom context", It.IsAny(), null, It.IsAny()), + Times.Once); + } + + [Fact] + public async Task BuildContext_ShouldConcatenateMessages_Async() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + var options = new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 3 + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + + var message1 = new ChatMessage() { Contents = [new TextContent("msg1")] }; + var message2 = new ChatMessage() { Contents = [new TextContent("msg2")] }; + var message3 = new ChatMessage() { Contents = [new TextContent("msg3")] }; + var message4 = new ChatMessage() { Contents = [new TextContent("msg4")] }; + var message5 = new ChatMessage() { Contents = [new TextContent("msg5")] }; + + // Simulate previous invocations to populate recent messages + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message1], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message2], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message3], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message4], null) { ResponseMessages = [] }); + + // Act + var invokingContext = new AIContextProvider.InvokingContext([message5]); + var context = await provider.InvokingAsync(invokingContext); + + // Assert + var expected = string.Join(Environment.NewLine, ["msg2", "msg3", "msg4", "msg5"]); + this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny(), null, It.IsAny()), Times.Once); + } + + [Fact] + public async Task BuildContext_ShouldUseEmbeddingValueProvider_Async() + { + // Arrange + List>? upsertedRecords = null; + this._collectionMock + .Setup(c => c.UpsertAsync(It.IsAny>>(), It.IsAny())) + .Callback>, CancellationToken>((records, _) => upsertedRecords = records.ToList()) + .Returns(Task.CompletedTask); + + var functions = new List { CreateFunction("f1", "desc1") }; + var options = new ContextualFunctionProviderOptions + { + EmbeddingValueProvider = (func, ct) => Task.FromResult($"custom embedding for {func.Name}:{func.Description}") + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + + var messages = new List + { + new() { Contents = [new TextContent("ignored")] } + }; + var context = new AIContextProvider.InvokingContext(messages); + + // Act + await provider.InvokingAsync(context); + + // Assert + Assert.NotNull(upsertedRecords); + var embeddingSource = upsertedRecords!.SelectMany(r => r).FirstOrDefault(kv => kv.Key == "Embedding").Value as string; + Assert.Equal("custom embedding for f1:desc1", embeddingSource); + } + + [Fact] + public async Task ContextEmbeddingValueProvider_ReceivesRecentAndNewMessages_Async() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + + IEnumerable? capturedRecentMessages = null; + IEnumerable? capturedNewMessages = null; + + var options = new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 2, + ContextEmbeddingValueProvider = (recentMessages, newMessages, ct) => + { + capturedRecentMessages = recentMessages; + capturedNewMessages = newMessages; + + return Task.FromResult("context"); + } + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + + // Add more messages than the number of messages to keep + await provider.InvokedAsync(new AIContextProvider.InvokedContext([new() { Contents = [new TextContent("msg1")] }], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([new() { Contents = [new TextContent("msg2")] }], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([new() { Contents = [new TextContent("msg3")] }], null) { ResponseMessages = [] }); + + // Act + var invokingContext = new AIContextProvider.InvokingContext([ + new() { Contents = [new TextContent("msg4")] }, + new() { Contents = [new TextContent("msg5")] } + ]); + await provider.InvokingAsync(invokingContext); + + // Assert + Assert.NotNull(capturedRecentMessages); + Assert.Equal("msg2", capturedRecentMessages.ElementAt(0).Text); + Assert.Equal("msg3", capturedRecentMessages.ElementAt(1).Text); + + Assert.NotNull(capturedNewMessages); + Assert.Equal("msg4", capturedNewMessages.ElementAt(0).Text); + Assert.Equal("msg5", capturedNewMessages.ElementAt(1).Text); + } + + [Fact] + public async Task Serialize_WithNoRecentMessages_ShouldReturnEmptyStateAsync() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + var options = new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 5 + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + var message1 = new ChatMessage() { Contents = [new TextContent("msg1")] }; + var message2 = new ChatMessage() { Contents = [new TextContent("msg2")] }; + var message3 = new ChatMessage() { Contents = [new TextContent("msg3")] }; + + // Add successful invocations first + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message1], null) { ResponseMessages = [] }); + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message2], null) { ResponseMessages = [] }); + + // Act - Add an invocation with an exception + await provider.InvokedAsync(new AIContextProvider.InvokedContext([message3], null) + { + ResponseMessages = [], + InvokeException = new InvalidOperationException("Test exception") + }); + + // Assert - The exception-causing message should not be added to recent messages + var invokingContext = new AIContextProvider.InvokingContext([new() { Contents = [new TextContent("new message")] }]); + await provider.InvokingAsync(invokingContext); + + var expected = string.Join(Environment.NewLine, ["msg1", "msg2", "new message"]); + this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny(), null, It.IsAny()), Times.Once); + } + + [Fact] + public async Task InvokedAsync_ShouldNotAddMessages_WhenExceptionIsPresent_Async() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + var options = new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 3 + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + + // Act + JsonElement state = provider.Serialize(); + + // Assert + Assert.Equal(JsonValueKind.Object, state.ValueKind); + Assert.False(state.TryGetProperty("recentMessages", out _)); + } + + [Fact] + public async Task Serialize_WithRecentMessages_ShouldPersistMessagesUpToLimitAsync() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + var options = new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 2 + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + + var messages = new[] + { + new ChatMessage() { Contents = [new TextContent("M1")] }, + new ChatMessage() { Contents = [new TextContent("M2")] }, + new ChatMessage() { Contents = [new TextContent("M3")] } + }; + + // Act + await provider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null)); + JsonElement state = provider.Serialize(); + + // Assert + Assert.True(state.TryGetProperty("recentMessages", out JsonElement recentProperty)); + Assert.Equal(JsonValueKind.Array, recentProperty.ValueKind); + int count = recentProperty.GetArrayLength(); + Assert.Equal(2, count); + } + + [Fact] + public async Task SerializeAndDeserialize_RoundtripRestoresMessagesAsync() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + var options = new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 4 + }; + + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: options); + + var messages = new[] + { + new ChatMessage() { Contents = [new TextContent("A")] }, + new ChatMessage() { Contents = [new TextContent("B")] }, + new ChatMessage() { Contents = [new TextContent("C")] }, + new ChatMessage() { Contents = [new TextContent("D")] } + }; + + await provider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null)); + + // Act + JsonElement state = provider.Serialize(); + var roundTrippedProvider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + serializedState: state, + options: new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 4 + }); + + // Trigger search to verify messages are used + var invokingContext = new AIContextProvider.InvokingContext(Array.Empty()); + await roundTrippedProvider.InvokingAsync(invokingContext); + + // Assert + string expected = string.Join(Environment.NewLine, ["A", "B", "C", "D"]); + this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny(), null, It.IsAny()), Times.Once); + } + + [Fact] + public async Task Deserialize_WithChangedLowerLimit_ShouldTruncateToNewLimitAsync() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + var initialProvider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + options: new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 5 + }); + + var messages = new[] + { + new ChatMessage() { Contents = [new TextContent("L1")] }, + new ChatMessage() { Contents = [new TextContent("L2")] }, + new ChatMessage() { Contents = [new TextContent("L3")] }, + new ChatMessage() { Contents = [new TextContent("L4")] }, + new ChatMessage() { Contents = [new TextContent("L5")] } + }; + + await initialProvider.InvokedAsync(new AIContextProvider.InvokedContext(messages, aiContextProviderMessages: null)); + JsonElement state = initialProvider.Serialize(); + + // Act + var restoredProvider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + serializedState: state, + options: new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 3 // Lower limit + }); + + var invokingContext = new AIContextProvider.InvokingContext(Array.Empty()); + await restoredProvider.InvokingAsync(invokingContext); + + // Assert + string expected = string.Join(Environment.NewLine, ["L1", "L2", "L3"]); + this._collectionMock.Verify(c => c.SearchAsync(expected, It.IsAny(), null, It.IsAny()), Times.Once); + } + + [Fact] + public async Task Deserialize_WithEmptyState_ShouldHaveNoMessagesAsync() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + JsonElement emptyState = JsonSerializer.Deserialize("{}", TestJsonSerializerContext.Default.JsonElement); + + // Act + var provider = new ContextualFunctionProvider( + vectorStore: this._vectorStoreMock.Object, + vectorDimensions: 1536, + functions: functions, + maxNumberOfFunctions: 5, + serializedState: emptyState, + options: new ContextualFunctionProviderOptions + { + NumberOfRecentMessagesInContext = 3 + }); + + var invokingContext = new AIContextProvider.InvokingContext(Array.Empty()); + await provider.InvokingAsync(invokingContext); + + // Assert + this._collectionMock.Verify(c => c.SearchAsync(string.Empty, It.IsAny(), null, It.IsAny()), Times.Once); + } + + private static AIFunction CreateFunction(string name, string description = "") + { + return AIFunctionFactory.Create(() => { }, name, description); + } +} diff --git a/dotnet/tests/Microsoft.Agents.AI.UnitTests/Functions/FunctionStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Functions/FunctionStoreTests.cs new file mode 100644 index 0000000000..139c5da085 --- /dev/null +++ b/dotnet/tests/Microsoft.Agents.AI.UnitTests/Functions/FunctionStoreTests.cs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Agents.AI.Functions; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Moq; + +namespace Microsoft.Agents.AI.UnitTests.Functions; + +/// +/// Contains unit tests for the class. +/// +public sealed class FunctionStoreTests +{ + private readonly Mock _vectorStoreMock; + private readonly Mock>> _collectionMock; + + public FunctionStoreTests() + { + this._vectorStoreMock = new Mock(MockBehavior.Strict); + this._collectionMock = new Mock>>(MockBehavior.Strict); + + this._vectorStoreMock + .Setup(vs => vs.GetDynamicCollection(It.IsAny(), It.IsAny())) + .Returns(this._collectionMock.Object); + + this._collectionMock + .Setup(c => c.CollectionExistsAsync(It.IsAny())) + .ReturnsAsync(true); + + this._collectionMock + .Setup(c => c.EnsureCollectionExistsAsync(It.IsAny())) + .Returns(Task.CompletedTask); + + this._collectionMock + .Setup(c => c.UpsertAsync(It.IsAny>>(), It.IsAny())) + .Returns(Task.CompletedTask); + + this._collectionMock + .Setup(c => c.SearchAsync(It.IsAny(), It.IsAny(), null, It.IsAny())) + .Returns(AsyncEnumerable.Empty>>()); + } + + [Fact] + public void Constructor_ShouldThrowOnInvalidArguments() + { + var functions = new List { CreateFunction("f1") }; + + Assert.Throws(() => new FunctionStore(null!, "col", 1, functions, 3)); + Assert.Throws(() => new FunctionStore(this._vectorStoreMock.Object, "", 1, functions, 3)); + Assert.Throws(() => new FunctionStore(this._vectorStoreMock.Object, "col", 0, functions, 3)); + Assert.Throws(() => new FunctionStore(this._vectorStoreMock.Object, "col", 1, null!, 3)); + } + + [Fact] + public async Task SaveAsync_ShouldUpsertFunctionsAsync() + { + // Arrange + var functions = new List + { + CreateFunction("f1", "desc1"), + CreateFunction("f2", "desc2") + }; + + this._collectionMock.Setup(c => c.UpsertAsync(It.IsAny>>(), It.IsAny())) + .Returns(Task.CompletedTask) + .Verifiable(); + + var store = new FunctionStore(this._vectorStoreMock.Object, "col", 3, functions, 3); + + // Act + await store.SaveAsync(); + + // Assert + this._collectionMock.Verify(c => c.EnsureCollectionExistsAsync(It.IsAny()), Times.Once); + this._collectionMock.Verify(c => c.UpsertAsync(It.Is>>(records => + records.Count() == 2 && + records.Any(r => (r["Name"] as string) == "f1") && + records.Any(r => (r["Name"] as string) == "f2") + ), It.IsAny()), Times.Once); + } + + [Fact] + public async Task SearchAsync_ShouldReturnMatchingFunctionsAsync() + { + // Arrange + var functions = new List + { + CreateFunction("f1", "desc1"), + CreateFunction("f2", "desc2"), + CreateFunction("f3", "desc3") + }; + + var searchResults = new List>> + { + new(new Dictionary { ["Name"] = "f3" }, 0.3), + new(new Dictionary { ["Name"] = "f2" }, 0.2), + new(new Dictionary { ["Name"] = "f1" }, 0.1) + }; + + this._collectionMock.Setup(c => c.SearchAsync(It.IsAny(), It.IsAny(), null, It.IsAny())) + .Returns(searchResults.ToAsyncEnumerable()); + + var store = new FunctionStore(this._vectorStoreMock.Object, "col", 3, functions, 3); + + // Act + var result = await store.SearchAsync("desc3"); + + // Assert + var resultList = result.ToList(); + Assert.Equal(3, resultList.Count); + Assert.Equal("f3", resultList[0].Name); + Assert.Equal("f2", resultList[1].Name); + Assert.Equal("f1", resultList[2].Name); + } + + [Fact] + public async Task SearchAsync_ShouldThrowIfCollectionDoesNotExistAsync() + { + // Arrange + var functions = new List { CreateFunction("f1") }; + + this._collectionMock.Setup(c => c.CollectionExistsAsync(It.IsAny())) + .ReturnsAsync(false); + + var store = new FunctionStore(this._vectorStoreMock.Object, "col", 3, functions, 3); + + // Act & Assert + await Assert.ThrowsAsync(() => store.SearchAsync("query")); + } + + private static AIFunction CreateFunction(string name, string description = "desc") + { + return AIFunctionFactory.Create(() => { }, name, description); + } +}