-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
poc hybrid chat client with handlers and the way they can be configur…
…ed declaratively and instantiated based on the declarative configuration.
- Loading branch information
1 parent
10818c5
commit bd2d095
Showing
16 changed files
with
683 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
dotnet/samples/Concepts/ChatCompletion/Hybrid/Configuration/AzureOpenAIChatClientFactory.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.Text.Json; | ||
using System.Text.Json.Serialization; | ||
using Azure.AI.OpenAI; | ||
using Azure.Identity; | ||
using Microsoft.Extensions.AI; | ||
|
||
namespace ChatCompletion.Hybrid; | ||
|
||
internal sealed class AzureOpenAIChatClientFactory | ||
{ | ||
private readonly AzureOpenAIChatClientConfiguration _config; | ||
|
||
public AzureOpenAIChatClientFactory(IServiceProvider serviceProvider, JsonElement config) | ||
{ | ||
this._config = config.Deserialize<AzureOpenAIChatClientConfiguration>()!; | ||
} | ||
|
||
public Microsoft.Extensions.AI.IChatClient Create() | ||
{ | ||
IChatClient azureOpenAiClient = new AzureOpenAIClient(new Uri(TestConfiguration.AzureOpenAI.Endpoint), new AzureCliCredential()).AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName); | ||
|
||
var builder = new ChatClientBuilder(azureOpenAiClient); | ||
|
||
if (this._config.UseFunctionInvocation) | ||
{ | ||
builder.UseFunctionInvocation(); | ||
} | ||
|
||
return builder.Build(); | ||
} | ||
} | ||
|
||
internal sealed class AzureOpenAIChatClientConfiguration | ||
{ | ||
[JsonPropertyName("useFunctionInvocation")] | ||
public bool UseFunctionInvocation { get; set; } | ||
} |
43 changes: 43 additions & 0 deletions
43
dotnet/samples/Concepts/ChatCompletion/Hybrid/Configuration/HybridChatClientFactory.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.ClientModel; | ||
using System.Text.Json; | ||
using System.Text.Json.Serialization; | ||
using Microsoft.Extensions.AI; | ||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.SemanticKernel; | ||
using Microsoft.SemanticKernel.ChatCompletion; | ||
|
||
namespace ChatCompletion.Hybrid; | ||
|
||
internal sealed class HybridChatClientFactory | ||
{ | ||
private readonly HybridChatClientConfiguration _configuration; | ||
private readonly IServiceProvider _serviceProvider; | ||
|
||
public HybridChatClientFactory(IServiceProvider serviceProvider, JsonElement configuration) | ||
{ | ||
this._serviceProvider = serviceProvider; | ||
this._configuration = configuration.Deserialize<HybridChatClientConfiguration>()!; | ||
} | ||
|
||
public Microsoft.Extensions.AI.IChatClient Create() | ||
{ | ||
// The evaluator and handler should be created based on the configuration | ||
CustomFallbackEvaluator fallbackEvaluator = new((context) => context.Exception is ClientResultException clientResultException && clientResultException.Status >= 500); | ||
|
||
FallbackChatCompletionHandler handler = new() { FallbackEvaluator = fallbackEvaluator }; | ||
|
||
var chatClients = this._configuration.Clients.Select(this._serviceProvider.GetRequiredKeyedService<IChatClient>); | ||
|
||
var kernel = this._serviceProvider.GetService<Kernel>(); | ||
|
||
return new HybridChatClient(chatClients, handler, kernel); | ||
} | ||
} | ||
|
||
internal sealed class HybridChatClientConfiguration | ||
{ | ||
[JsonPropertyName("clients")] | ||
public IEnumerable<string> Clients { get; set; } | ||
} |
39 changes: 39 additions & 0 deletions
39
dotnet/samples/Concepts/ChatCompletion/Hybrid/Configuration/OpenAIChatClientFactory.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.ClientModel; | ||
using System.Text.Json; | ||
using System.Text.Json.Serialization; | ||
using Microsoft.Extensions.AI; | ||
using OpenAI; | ||
|
||
namespace ChatCompletion.Hybrid; | ||
|
||
internal sealed class OpenAIChatClientFactory | ||
{ | ||
private readonly OpenAIChatClientConfiguration _config; | ||
|
||
public OpenAIChatClientFactory(IServiceProvider serviceProvider, JsonElement config) | ||
{ | ||
this._config = config.Deserialize<OpenAIChatClientConfiguration>()!; | ||
} | ||
|
||
public Microsoft.Extensions.AI.IChatClient Create() | ||
{ | ||
IChatClient openAiClient = new OpenAIClient(new ApiKeyCredential(TestConfiguration.OpenAI.ApiKey)).AsChatClient(TestConfiguration.OpenAI.ChatModelId); | ||
|
||
var builder = new ChatClientBuilder(openAiClient); | ||
|
||
if (this._config.UseFunctionInvocation) | ||
{ | ||
builder.UseFunctionInvocation(); | ||
} | ||
|
||
return builder.Build(); | ||
} | ||
} | ||
|
||
internal sealed class OpenAIChatClientConfiguration | ||
{ | ||
[JsonPropertyName("useFunctionInvocation")] | ||
public bool UseFunctionInvocation { get; set; } | ||
} |
57 changes: 57 additions & 0 deletions
57
dotnet/samples/Concepts/ChatCompletion/Hybrid/Configuration/ServiceRegistry.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.Globalization; | ||
using System.Text.Json; | ||
using Microsoft.Extensions.DependencyInjection; | ||
|
||
namespace ChatCompletion.Hybrid; | ||
|
||
internal static class ServiceRegistry | ||
{ | ||
public static void RegisterServices(ServiceCollection serviceCollection, Stream configSource) | ||
{ | ||
var json = JsonDocument.Parse(configSource); | ||
|
||
var servicesElement = json.RootElement.GetProperty("services"); | ||
|
||
foreach (var serviceConfig in servicesElement.EnumerateArray()) | ||
{ | ||
var serviceType = Type.GetType(serviceConfig.GetProperty("type").GetString()!); | ||
var factoryType = Type.GetType(serviceConfig.GetProperty("factory").GetProperty("type").GetString()!); | ||
|
||
if (serviceType == null || factoryType == null) | ||
{ | ||
throw new InvalidOperationException($"Unable to resolve type"); | ||
} | ||
|
||
var lifetime = Enum.Parse<ServiceLifetime>(serviceConfig.GetProperty("lifetime").GetString()!, true); | ||
|
||
switch (lifetime) | ||
{ | ||
case ServiceLifetime.Singleton: | ||
serviceCollection.AddKeyedSingleton(serviceType, serviceConfig.GetProperty("serviceKey").GetString(), (serviceProvider, _) => | ||
{ | ||
JsonElement? config = null; | ||
|
||
if (serviceConfig.GetProperty("factory").TryGetProperty("configuration", out JsonElement _config)) | ||
{ | ||
config = _config; | ||
} | ||
|
||
var factory = Activator.CreateInstance(factoryType, serviceProvider, config)!; | ||
|
||
return factoryType.InvokeMember("Create", System.Reflection.BindingFlags.InvokeMethod, null, factory, null, CultureInfo.InvariantCulture)!; | ||
}); | ||
break; | ||
case ServiceLifetime.Scoped: | ||
/// TBD | ||
break; | ||
case ServiceLifetime.Transient: | ||
/// TBD | ||
break; | ||
default: | ||
throw new Exception("Unsupported lifetime."); | ||
} | ||
} | ||
} | ||
} |
181 changes: 181 additions & 0 deletions
181
dotnet/samples/Concepts/ChatCompletion/Hybrid/SwitchingBetweenLocalAndCloudModels.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
// Copyright (c) Microsoft. All rights reserved. | ||
|
||
using System.ClientModel; | ||
using System.ClientModel.Primitives; | ||
using System.ComponentModel; | ||
using Azure.AI.OpenAI; | ||
using Azure.Identity; | ||
using ChatCompletion.Hybrid; | ||
using Microsoft.Extensions.AI; | ||
using Microsoft.Extensions.DependencyInjection; | ||
using Microsoft.SemanticKernel.ChatCompletion; | ||
using OpenAI; | ||
|
||
namespace ChatCompletion; | ||
|
||
public sealed class SwitchingBetweenLocalAndCloudModels(ITestOutputHelper output) : BaseTest(output) | ||
{ | ||
[Fact] | ||
public async Task FallbackToAvailableModelAsync() | ||
{ | ||
// Create an unavailable chat client | ||
IChatClient unavailableChatClient = CreateUnavailableOpenAIChatClient(); | ||
|
||
// Create a cloud available chat client | ||
IChatClient cloudChatClient = CreateAzureOpenAIChatClient(); | ||
|
||
CustomFallbackEvaluator fallbackEvaluator = new((context) => context.Exception is ClientResultException clientResultException && clientResultException.Status >= 500); | ||
|
||
FallbackChatCompletionHandler handler = new() { FallbackEvaluator = fallbackEvaluator }; | ||
|
||
IChatClient hybridChatClient = new HybridChatClient([unavailableChatClient, cloudChatClient], handler); | ||
|
||
ChatOptions chatOptions = new() { Tools = [AIFunctionFactory.Create(GetWeather, new AIFunctionFactoryCreateOptions { Name = "GetWeather" })] }; | ||
|
||
var result = await hybridChatClient.CompleteAsync("Do I need an umbrella?", chatOptions); | ||
|
||
Output.WriteLine(result); | ||
|
||
[Description("Gets the weather")] | ||
string GetWeather() => "It's sunny"; | ||
} | ||
|
||
[Fact] | ||
public async Task FallbackToAvailableModelStreamingAsync() | ||
{ | ||
// Create an unavailable chat client | ||
IChatClient unavailableChatClient = CreateUnavailableOpenAIChatClient(); | ||
|
||
// Create a cloud available chat client | ||
IChatClient cloudChatClient = CreateAzureOpenAIChatClient(); | ||
|
||
CustomFallbackEvaluator fallbackEvaluator = new((context) => context.Exception is ClientResultException clientResultException && clientResultException.Status >= 500); | ||
|
||
FallbackChatCompletionHandler handler = new() { FallbackEvaluator = fallbackEvaluator }; | ||
|
||
IChatClient hybridChatClient = new HybridChatClient([unavailableChatClient, cloudChatClient], handler); | ||
|
||
ChatOptions chatOptions = new() { Tools = [AIFunctionFactory.Create(GetWeather, new AIFunctionFactoryCreateOptions { Name = "GetWeather" })] }; | ||
|
||
var result = hybridChatClient.CompleteStreamingAsync("Do I need an umbrella?", chatOptions); | ||
|
||
await foreach (var update in result) | ||
{ | ||
Output.WriteLine(update); | ||
} | ||
|
||
[Description("Gets the weather")] | ||
string GetWeather() => "It's sunny"; | ||
} | ||
|
||
[Fact] | ||
public async Task BuildHybridChatClientFromDeclarationAsync() | ||
{ | ||
using var configSource = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(""" | ||
{ | ||
"services": [ | ||
{ | ||
"serviceKey": "openAIClient", | ||
"type": "Microsoft.Extensions.AI.IChatClient, Microsoft.Extensions.AI.Abstractions", | ||
"lifetime": "Singleton", | ||
"factory": { | ||
"type": "ChatCompletion.Hybrid.OpenAIChatClientFactory, Concepts", | ||
"configuration": { | ||
"useFunctionInvocation": true | ||
} | ||
} | ||
}, | ||
{ | ||
"serviceKey": "azureOpenAIClient", | ||
"type": "Microsoft.Extensions.AI.IChatClient, Microsoft.Extensions.AI.Abstractions", | ||
"lifetime": "Singleton", | ||
"factory": { | ||
"type": "ChatCompletion.Hybrid.AzureOpenAIChatClientFactory, Concepts", | ||
"configuration": { | ||
"useFunctionInvocation": true | ||
} | ||
} | ||
}, | ||
{ | ||
"serviceKey": "hybridChatClient", | ||
"type": "Microsoft.Extensions.AI.IChatClient, Microsoft.Extensions.AI.Abstractions", | ||
"lifetime": "Singleton", | ||
"factory": { | ||
"type": "ChatCompletion.Hybrid.HybridChatClientFactory, Concepts", | ||
"configuration": { | ||
"clients": ["openAIClient", "azureOpenAIClient"] | ||
} | ||
} | ||
} | ||
] | ||
} | ||
""")); | ||
|
||
var services = new ServiceCollection(); | ||
|
||
ServiceRegistry.RegisterServices(services, configSource); | ||
|
||
var serviceProvider = services.BuildServiceProvider(); | ||
|
||
var hybridChatClient = serviceProvider.GetRequiredKeyedService<IChatClient>("hybridChatClient"); | ||
|
||
ChatOptions chatOptions = new() { Tools = [AIFunctionFactory.Create(GetWeather, new AIFunctionFactoryCreateOptions { Name = "GetWeather" })] }; | ||
|
||
var result = await hybridChatClient.CompleteAsync("Do I need an umbrella?", chatOptions); | ||
|
||
Output.WriteLine(result); | ||
|
||
[Description("Gets the weather")] | ||
string GetWeather() => "It's sunny"; | ||
} | ||
|
||
private static IChatClient CreateUnavailableOpenAIChatClient() | ||
{ | ||
OpenAIClientOptions options = new() | ||
{ | ||
Transport = new HttpClientPipelineTransport( | ||
new HttpClient | ||
( | ||
new StubHandler(new HttpClientHandler(), async (response) => { response.StatusCode = System.Net.HttpStatusCode.ServiceUnavailable; }) | ||
) | ||
) | ||
}; | ||
|
||
IChatClient openAiClient = new OpenAIClient(new ApiKeyCredential(TestConfiguration.OpenAI.ApiKey), options).AsChatClient(TestConfiguration.OpenAI.ChatModelId); | ||
|
||
openAiClient = new ChatClientBuilder(openAiClient) | ||
.UseFunctionInvocation() | ||
.Build(); | ||
return openAiClient; | ||
} | ||
|
||
private static IChatClient CreateOpenAIChatClient() | ||
{ | ||
IChatClient openAiClient = new OpenAIClient(TestConfiguration.OpenAI.ApiKey).AsChatClient(TestConfiguration.OpenAI.ChatModelId); | ||
openAiClient = new ChatClientBuilder(openAiClient) | ||
.UseFunctionInvocation() | ||
.Build(); | ||
return openAiClient; | ||
} | ||
|
||
private static IChatClient CreateAzureOpenAIChatClient() | ||
{ | ||
IChatClient azureOpenAiClient = new AzureOpenAIClient(new Uri(TestConfiguration.AzureOpenAI.Endpoint), new AzureCliCredential()).AsChatClient(TestConfiguration.AzureOpenAI.ChatDeploymentName); | ||
azureOpenAiClient = new ChatClientBuilder(azureOpenAiClient) | ||
.UseFunctionInvocation() | ||
.Build(); | ||
return azureOpenAiClient; | ||
} | ||
|
||
protected sealed class StubHandler(HttpMessageHandler innerHandler, Func<HttpResponseMessage, Task> handler) : DelegatingHandler(innerHandler) | ||
{ | ||
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) | ||
{ | ||
var result = await base.SendAsync(request, cancellationToken); | ||
|
||
await handler(result); | ||
|
||
return result; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.