Skip to content

Commit

Permalink
poc hybrid chat client with handlers and the way they can be configur…
Browse files Browse the repository at this point in the history
…ed declaratively and instantiated based on the declarative configuration.
  • Loading branch information
SergeyMenshykh committed Feb 5, 2025
1 parent 10818c5 commit bd2d095
Show file tree
Hide file tree
Showing 16 changed files with 683 additions and 2 deletions.
5 changes: 3 additions & 2 deletions dotnet/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
<PackageVersion Include="Newtonsoft.Json" Version="13.0.3" />
<PackageVersion Include="Npgsql" Version="8.0.6" />
<PackageVersion Include="OllamaSharp" Version="4.0.17" />
<PackageVersion Include="OpenAI" Version="[2.1.0-beta.2]" />
<PackageVersion Include="OpenAI" Version="2.1.0" />
<PackageVersion Include="PdfPig" Version="0.1.9" />
<PackageVersion Include="Pinecone.NET" Version="2.1.1" />
<PackageVersion Include="PuppeteerSharp" Version="20.0.5" />
Expand All @@ -70,6 +70,7 @@
<PackageVersion Include="Microsoft.Extensions.AI" Version="9.1.0-preview.1.25064.3" />
<PackageVersion Include="Microsoft.Extensions.AI.Abstractions" Version="9.1.0-preview.1.25064.3" />
<PackageVersion Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.1.0-preview.1.25064.3" />
<PackageVersion Include="Microsoft.Extensions.AI.OpenAI" Version="9.1.0-preview.1.25064.3" />
<PackageVersion Include="Microsoft.Extensions.Configuration" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Configuration.Abstractions" Version="8.0.0" />
<PackageVersion Include="Microsoft.Extensions.Configuration.Binder" Version="8.0.2" />
Expand Down Expand Up @@ -170,4 +171,4 @@
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.Cuda" Version="0.5.2" />
<PackageVersion Include="Microsoft.ML.OnnxRuntimeGenAI.DirectML" Version="0.5.2" />
</ItemGroup>
</Project>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json;
using System.Text.Json.Serialization;
using Azure.AI.OpenAI;
using Azure.Identity;
using Microsoft.Extensions.AI;

namespace ChatCompletion.Hybrid;

internal sealed class AzureOpenAIChatClientFactory
{
private readonly AzureOpenAIChatClientConfiguration _config;

public AzureOpenAIChatClientFactory(IServiceProvider serviceProvider, JsonElement config)
{
this._config = config.Deserialize<AzureOpenAIChatClientConfiguration>()!;
}

public Microsoft.Extensions.AI.IChatClient Create()
{
IChatClient azureOpenAiClient = new AzureOpenAIClient(new Uri(TestConfiguration.AzureOpenAI.Endpoint), new AzureCliCredential()).AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName);

var builder = new ChatClientBuilder(azureOpenAiClient);

if (this._config.UseFunctionInvocation)
{
builder.UseFunctionInvocation();
}

return builder.Build();
}
}

internal sealed class AzureOpenAIChatClientConfiguration
{
[JsonPropertyName("useFunctionInvocation")]
public bool UseFunctionInvocation { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft. All rights reserved.

using System.ClientModel;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

namespace ChatCompletion.Hybrid;

internal sealed class HybridChatClientFactory
{
private readonly HybridChatClientConfiguration _configuration;
private readonly IServiceProvider _serviceProvider;

public HybridChatClientFactory(IServiceProvider serviceProvider, JsonElement configuration)
{
this._serviceProvider = serviceProvider;
this._configuration = configuration.Deserialize<HybridChatClientConfiguration>()!;
}

public Microsoft.Extensions.AI.IChatClient Create()
{
// The evaluator and handler should be created based on the configuration
CustomFallbackEvaluator fallbackEvaluator = new((context) => context.Exception is ClientResultException clientResultException && clientResultException.Status >= 500);

FallbackChatCompletionHandler handler = new() { FallbackEvaluator = fallbackEvaluator };

var chatClients = this._configuration.Clients.Select(this._serviceProvider.GetRequiredKeyedService<IChatClient>);

var kernel = this._serviceProvider.GetService<Kernel>();

return new HybridChatClient(chatClients, handler, kernel);
}
}

internal sealed class HybridChatClientConfiguration
{
[JsonPropertyName("clients")]
public IEnumerable<string> Clients { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) Microsoft. All rights reserved.

using System.ClientModel;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Extensions.AI;
using OpenAI;

namespace ChatCompletion.Hybrid;

internal sealed class OpenAIChatClientFactory
{
private readonly OpenAIChatClientConfiguration _config;

public OpenAIChatClientFactory(IServiceProvider serviceProvider, JsonElement config)
{
this._config = config.Deserialize<OpenAIChatClientConfiguration>()!;
}

public Microsoft.Extensions.AI.IChatClient Create()
{
IChatClient openAiClient = new OpenAIClient(new ApiKeyCredential(TestConfiguration.OpenAI.ApiKey)).AsChatClient(TestConfiguration.OpenAI.ChatModelId);

var builder = new ChatClientBuilder(openAiClient);

if (this._config.UseFunctionInvocation)
{
builder.UseFunctionInvocation();
}

return builder.Build();
}
}

internal sealed class OpenAIChatClientConfiguration
{
[JsonPropertyName("useFunctionInvocation")]
public bool UseFunctionInvocation { get; set; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Globalization;
using System.Text.Json;
using Microsoft.Extensions.DependencyInjection;

namespace ChatCompletion.Hybrid;

internal static class ServiceRegistry
{
public static void RegisterServices(ServiceCollection serviceCollection, Stream configSource)
{
var json = JsonDocument.Parse(configSource);

var servicesElement = json.RootElement.GetProperty("services");

foreach (var serviceConfig in servicesElement.EnumerateArray())
{
var serviceType = Type.GetType(serviceConfig.GetProperty("type").GetString()!);
var factoryType = Type.GetType(serviceConfig.GetProperty("factory").GetProperty("type").GetString()!);

if (serviceType == null || factoryType == null)
{
throw new InvalidOperationException($"Unable to resolve type");
}

var lifetime = Enum.Parse<ServiceLifetime>(serviceConfig.GetProperty("lifetime").GetString()!, true);

switch (lifetime)
{
case ServiceLifetime.Singleton:
serviceCollection.AddKeyedSingleton(serviceType, serviceConfig.GetProperty("serviceKey").GetString(), (serviceProvider, _) =>
{
JsonElement? config = null;

if (serviceConfig.GetProperty("factory").TryGetProperty("configuration", out JsonElement _config))
{
config = _config;
}

var factory = Activator.CreateInstance(factoryType, serviceProvider, config)!;

return factoryType.InvokeMember("Create", System.Reflection.BindingFlags.InvokeMethod, null, factory, null, CultureInfo.InvariantCulture)!;
});
break;
case ServiceLifetime.Scoped:
/// TBD
break;
case ServiceLifetime.Transient:
/// TBD
break;
default:
throw new Exception("Unsupported lifetime.");
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
// Copyright (c) Microsoft. All rights reserved.

using System.ClientModel;
using System.ClientModel.Primitives;
using System.ComponentModel;
using Azure.AI.OpenAI;
using Azure.Identity;
using ChatCompletion.Hybrid;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticKernel.ChatCompletion;
using OpenAI;

namespace ChatCompletion;

public sealed class SwitchingBetweenLocalAndCloudModels(ITestOutputHelper output) : BaseTest(output)
{
[Fact]
public async Task FallbackToAvailableModelAsync()
{
// Create an unavailable chat client
IChatClient unavailableChatClient = CreateUnavailableOpenAIChatClient();

// Create a cloud available chat client
IChatClient cloudChatClient = CreateAzureOpenAIChatClient();

CustomFallbackEvaluator fallbackEvaluator = new((context) => context.Exception is ClientResultException clientResultException && clientResultException.Status >= 500);

FallbackChatCompletionHandler handler = new() { FallbackEvaluator = fallbackEvaluator };

IChatClient hybridChatClient = new HybridChatClient([unavailableChatClient, cloudChatClient], handler);

ChatOptions chatOptions = new() { Tools = [AIFunctionFactory.Create(GetWeather, new AIFunctionFactoryCreateOptions { Name = "GetWeather" })] };

var result = await hybridChatClient.CompleteAsync("Do I need an umbrella?", chatOptions);

Output.WriteLine(result);

[Description("Gets the weather")]
string GetWeather() => "It's sunny";
}

[Fact]
public async Task FallbackToAvailableModelStreamingAsync()
{
// Create an unavailable chat client
IChatClient unavailableChatClient = CreateUnavailableOpenAIChatClient();

// Create a cloud available chat client
IChatClient cloudChatClient = CreateAzureOpenAIChatClient();

CustomFallbackEvaluator fallbackEvaluator = new((context) => context.Exception is ClientResultException clientResultException && clientResultException.Status >= 500);

FallbackChatCompletionHandler handler = new() { FallbackEvaluator = fallbackEvaluator };

IChatClient hybridChatClient = new HybridChatClient([unavailableChatClient, cloudChatClient], handler);

ChatOptions chatOptions = new() { Tools = [AIFunctionFactory.Create(GetWeather, new AIFunctionFactoryCreateOptions { Name = "GetWeather" })] };

var result = hybridChatClient.CompleteStreamingAsync("Do I need an umbrella?", chatOptions);

await foreach (var update in result)
{
Output.WriteLine(update);
}

[Description("Gets the weather")]
string GetWeather() => "It's sunny";
}

[Fact]
public async Task BuildHybridChatClientFromDeclarationAsync()
{
using var configSource = new MemoryStream(System.Text.Encoding.UTF8.GetBytes("""
{
"services": [
{
"serviceKey": "openAIClient",
"type": "Microsoft.Extensions.AI.IChatClient, Microsoft.Extensions.AI.Abstractions",
"lifetime": "Singleton",
"factory": {
"type": "ChatCompletion.Hybrid.OpenAIChatClientFactory, Concepts",
"configuration": {
"useFunctionInvocation": true
}
}
},
{
"serviceKey": "azureOpenAIClient",
"type": "Microsoft.Extensions.AI.IChatClient, Microsoft.Extensions.AI.Abstractions",
"lifetime": "Singleton",
"factory": {
"type": "ChatCompletion.Hybrid.AzureOpenAIChatClientFactory, Concepts",
"configuration": {
"useFunctionInvocation": true
}
}
},
{
"serviceKey": "hybridChatClient",
"type": "Microsoft.Extensions.AI.IChatClient, Microsoft.Extensions.AI.Abstractions",
"lifetime": "Singleton",
"factory": {
"type": "ChatCompletion.Hybrid.HybridChatClientFactory, Concepts",
"configuration": {
"clients": ["openAIClient", "azureOpenAIClient"]
}
}
}
]
}
"""));

var services = new ServiceCollection();

ServiceRegistry.RegisterServices(services, configSource);

var serviceProvider = services.BuildServiceProvider();

var hybridChatClient = serviceProvider.GetRequiredKeyedService<IChatClient>("hybridChatClient");

ChatOptions chatOptions = new() { Tools = [AIFunctionFactory.Create(GetWeather, new AIFunctionFactoryCreateOptions { Name = "GetWeather" })] };

var result = await hybridChatClient.CompleteAsync("Do I need an umbrella?", chatOptions);

Output.WriteLine(result);

[Description("Gets the weather")]
string GetWeather() => "It's sunny";
}

private static IChatClient CreateUnavailableOpenAIChatClient()
{
OpenAIClientOptions options = new()
{
Transport = new HttpClientPipelineTransport(
new HttpClient
(
new StubHandler(new HttpClientHandler(), async (response) => { response.StatusCode = System.Net.HttpStatusCode.ServiceUnavailable; })
)
)
};

IChatClient openAiClient = new OpenAIClient(new ApiKeyCredential(TestConfiguration.OpenAI.ApiKey), options).AsChatClient(TestConfiguration.OpenAI.ChatModelId);

openAiClient = new ChatClientBuilder(openAiClient)
.UseFunctionInvocation()
.Build();
return openAiClient;
}

private static IChatClient CreateOpenAIChatClient()
{
IChatClient openAiClient = new OpenAIClient(TestConfiguration.OpenAI.ApiKey).AsChatClient(TestConfiguration.OpenAI.ChatModelId);
openAiClient = new ChatClientBuilder(openAiClient)
.UseFunctionInvocation()
.Build();
return openAiClient;
}

private static IChatClient CreateAzureOpenAIChatClient()
{
IChatClient azureOpenAiClient = new AzureOpenAIClient(new Uri(TestConfiguration.AzureOpenAI.Endpoint), new AzureCliCredential()).AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName);
azureOpenAiClient = new ChatClientBuilder(azureOpenAiClient)
.UseFunctionInvocation()
.Build();
return azureOpenAiClient;
}

protected sealed class StubHandler(HttpMessageHandler innerHandler, Func<HttpResponseMessage, Task> handler) : DelegatingHandler(innerHandler)
{
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
var result = await base.SendAsync(request, cancellationToken);

await handler(result);

return result;
}
}
}
2 changes: 2 additions & 0 deletions dotnet/samples/Concepts/Concepts.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

<ItemGroup>
<PackageReference Include="Docker.DotNet" />
<PackageReference Include="Microsoft.Extensions.AI.OpenAI" />
<PackageReference Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" />
<PackageReference Include="Microsoft.NET.Test.Sdk" />
<PackageReference Include="Npgsql" />
<PackageReference Include="OpenAI" />
<PackageReference Include="xRetry" />
<PackageReference Include="xunit" />
<PackageReference Include="xunit.abstractions" />
Expand Down
Loading

0 comments on commit bd2d095

Please sign in to comment.