Update for genai API changes + feedback
stephentoub committed Nov 25, 2024
1 parent 4c2d892 commit 2f14dfe
Showing 3 changed files with 121 additions and 68 deletions.
86 changes: 19 additions & 67 deletions src/csharp/ChatClient.cs
@@ -8,9 +8,11 @@

namespace Microsoft.ML.OnnxRuntimeGenAI;

-/// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
-public sealed class ChatClient : IChatClient, IDisposable
+/// <summary>Provides an <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
+public sealed partial class ChatClient : IChatClient
{
+    /// <summary>The options used to configure the instance.</summary>
+    private readonly ChatClientConfiguration _config;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
@@ -20,8 +22,9 @@ public sealed class ChatClient : IChatClient, IDisposable

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="modelPath">The file path to the model to load.</param>
+    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
-    public ChatClient(string modelPath)
+    public ChatClient(string modelPath, ChatClientConfiguration configuration)
    {
        if (modelPath is null)
        {
@@ -54,32 +57,12 @@ public ChatClient(Model model, bool ownsModel = true)
        _model = model;
        _tokenizer = new Tokenizer(_model);

-        Metadata = new("Microsoft.ML.OnnxRuntimeGenAI");
+        Metadata = new("onnxruntime-genai");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

-    /// <summary>
-    /// Gets or sets stop sequences to use during generation.
-    /// </summary>
-    /// <remarks>
-    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>.
-    /// </remarks>
-    public IList<string> StopSequences { get; set; } =
-    [
-        // Default stop sequences based on Phi3
-        "<|system|>",
-        "<|user|>",
-        "<|assistant|>",
-        "<|end|>"
-    ];
-
-    /// <summary>
-    /// Gets or sets a function that creates a prompt string from the chat history.
-    /// </summary>
-    public Func<IEnumerable<ChatMessage>, string> PromptFormatter { get; set; }
-
    /// <inheritdoc/>
    public void Dispose()
    {
@@ -102,20 +85,20 @@ public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages,
        StringBuilder text = new();
        await Task.Run(() =>
        {
-            using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
-            generatorParams.SetInputSequences(tokens);

            using Generator generator = new(_model, generatorParams);
+            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            var completionId = Guid.NewGuid().ToString();
            while (!generator.IsDone())
            {
                cancellationToken.ThrowIfCancellationRequested();

-                generator.ComputeLogits();
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
@@ -147,20 +130,20 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
            throw new ArgumentNullException(nameof(chatMessages));
        }

-        using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
+        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
-        generatorParams.SetInputSequences(tokens);

        using Generator generator = new(_model, generatorParams);
+        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            string next = await Task.Run(() =>
            {
-                generator.ComputeLogits();
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
@@ -193,43 +176,7 @@ public object GetService(Type serviceType, object key = null) =>
    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
-        StopSequences?.Contains(token) is true;
-
-    /// <summary>Creates a prompt string from the supplied chat history.</summary>
-    private string CreatePrompt(IEnumerable<ChatMessage> messages)
-    {
-        if (messages is null)
-        {
-            throw new ArgumentNullException(nameof(messages));
-        }
-
-        if (PromptFormatter is not null)
-        {
-            return PromptFormatter(messages) ?? string.Empty;
-        }
-
-        // Default formatting based on Phi3.
-        StringBuilder prompt = new();
-
-        foreach (var message in messages)
-        {
-            foreach (var content in message.Contents)
-            {
-                switch (content)
-                {
-                    case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text):
-                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
-                              .Append(tc.Text.Replace("<|end|>\n", ""))
-                              .Append("<|end|>\n");
-                        break;
-                }
-            }
-        }
-
-        prompt.Append("<|assistant|>");
-
-        return prompt.ToString();
-    }
+        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
@@ -262,6 +209,11 @@ private static void UpdateGeneratorParamsFromOptions(int numInputTokens, Generat
            }
        }

+        if (options.Seed.HasValue)
+        {
+            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
+        }
+
        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
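Aside: the substance of this file's change is the updated GenAI generation API. Input tokens are now appended to the Generator via AppendTokenSequences rather than set on GeneratorParams, and GenerateNextToken no longer requires a preceding ComputeLogits call. Below is a minimal sketch of the updated loop outside the IChatClient wrapper; the model path, prompt text, and max_length value are placeholder assumptions, not part of this commit.

using System;
using Microsoft.ML.OnnxRuntimeGenAI;

// Sketch of the updated generation loop. The model path, prompt text, and
// max_length value below are illustrative placeholders.
using Model model = new("path/to/model");
using Tokenizer tokenizer = new(model);
using Sequences tokens = tokenizer.Encode("<|user|>\nWhat is 2 + 3?<|end|>\n<|assistant|>\n");

using GeneratorParams generatorParams = new(model);
generatorParams.SetSearchOption("max_length", 128);

using Generator generator = new(model, generatorParams);
generator.AppendTokenSequences(tokens); // replaces generatorParams.SetInputSequences(tokens)

using var tokenizerStream = tokenizer.CreateStream();
while (!generator.IsDone())
{
    generator.GenerateNextToken(); // previously required a ComputeLogits() call first
    ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
    Console.Write(tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]));
}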
73 changes: 73 additions & 0 deletions src/csharp/ChatClientConfiguration.cs
@@ -0,0 +1,73 @@
+using Microsoft.Extensions.AI;
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.OnnxRuntimeGenAI;
+
+/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
+/// <remarks>
+/// Every model has different requirements for stop sequences and prompt formatting. For best results,
+/// the configuration should be tailored to the exact nature of the model being used. For example,
+/// when using a Phi3 model, a configuration like the following may be used:
+/// <code>
+/// static ChatClientConfiguration CreateForPhi3() =&gt;
+///     new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
+///         (IEnumerable&lt;ChatMessage&gt; messages) =&gt;
+///         {
+///             StringBuilder prompt = new();
+///
+///             foreach (var message in messages)
+///                 foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
+///                     prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
+///
+///             return prompt.Append("&lt;|assistant|&gt;\n").ToString();
+///         });
+/// </code>
+/// </remarks>
+public sealed class ChatClientConfiguration
+{
+    private string[] _stopSequences;
+    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;
+
+    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
+    /// <param name="stopSequences">The stop sequences used by the model.</param>
+    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
+    /// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
+    /// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
+    public ChatClientConfiguration(
+        string[] stopSequences,
+        Func<IEnumerable<ChatMessage>, string> promptFormatter)
+    {
+        if (stopSequences is null)
+        {
+            throw new ArgumentNullException(nameof(stopSequences));
+        }
+
+        if (promptFormatter is null)
+        {
+            throw new ArgumentNullException(nameof(promptFormatter));
+        }
+
+        StopSequences = stopSequences;
+        PromptFormatter = promptFormatter;
+    }
+
+    /// <summary>
+    /// Gets or sets stop sequences to use during generation.
+    /// </summary>
+    /// <remarks>
+    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>.
+    /// </remarks>
+    public string[] StopSequences
+    {
+        get => _stopSequences;
+        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
+    }
+
+    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
+    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
+    {
+        get => _promptFormatter;
+        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
+    }
+}
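With the configuration split out into its own type, constructing a client takes a model path plus a ChatClientConfiguration. A hedged usage sketch, assuming the hypothetical CreateForPhi3 helper from the remarks above and a placeholder model directory:

using System;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Sketch only: CreateForPhi3 is the hypothetical factory from the XML docs
// above, and the model path is a placeholder.
ChatClientConfiguration config = CreateForPhi3();
using IChatClient client = new ChatClient("path/to/phi3-model", config);

ChatCompletion completion = await client.CompleteAsync("What is 2 + 3?");
Console.WriteLine(completion);

Moving the stop sequences and prompt formatter into a required configuration object keeps ChatClient itself model-agnostic while making the model-specific pieces explicit at construction time.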
30 changes: 29 additions & 1 deletion test/csharp/TestOnnxRuntimeGenAIAPI.cs
@@ -5,9 +5,11 @@
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.CompilerServices;
using Xunit;
using Xunit.Abstractions;
using System.Collections.Generic;
+using Microsoft.Extensions.AI;
+using System.Text;

namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
{
@@ -349,6 +351,32 @@ public void TestTopKTopPSearch()
        }
    }

+    [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
+    public async Task TestChatClient()
+    {
+        using var client = new ChatClient(
+            _phi2Path,
+            new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
+                (IEnumerable<ChatMessage> messages) =>
+                {
+                    StringBuilder prompt = new();
+
+                    foreach (var message in messages)
+                        foreach (var content in message.Contents.OfType<TextContent>())
+                            prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");
+
+                    return prompt.Append("<|assistant|>\n").ToString();
+                }));
+
+        var completion = await client.CompleteAsync("What is 2 + 3?", new()
+        {
+            MaxOutputTokens = 20,
+            Temperature = 0f,
+        });
+
+        Assert.Contains("5", completion.ToString());
+    }
+
    [IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
    public void TestTokenizerBatchEncodeDecode()
    {
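The new test exercises only the buffered CompleteAsync path. For the streaming path, a hedged sketch reusing a client configured as in the test above (and assuming the string-accepting CompleteStreamingAsync extension from Microsoft.Extensions.AI):

// Sketch only: streams updates as they are generated; ToString() on each
// update is assumed to surface that update's generated text.
await foreach (var update in client.CompleteStreamingAsync("What is 2 + 3?"))
{
    Console.Write(update);
}
Console.WriteLine();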
