From 0b68458cd476860eb4f0a8b9c4e859788add65ea Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sun, 17 Nov 2024 23:35:20 -0800 Subject: [PATCH 1/8] implement sft --- .../Llama/SFT_Llama_3_2_1B.cs | 158 ++++++++++++++++++ .../Pipeline/CausalLMModelInput.cs | 13 +- .../Pipeline/CausalLMModelOutput.cs | 11 +- .../LlamaForCausalLM.cs | 24 +++ .../Module/LlamaModel.cs | 6 +- 5 files changed, 209 insertions(+), 3 deletions(-) create mode 100644 docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs new file mode 100644 index 0000000000..50e8a7b842 --- /dev/null +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs @@ -0,0 +1,158 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.ML.GenAI.Core; +using Microsoft.ML.GenAI.LLaMA; +using static TorchSharp.torch; +using TorchSharp; +using Microsoft.ML.Tokenizers; +using TorchSharp.Modules; +using TorchSharp.PyBridge; +using Microsoft.Extensions.AI; +using AutoGen.Core; + +namespace Microsoft.ML.GenAI.Samples.Llama; + +internal class SFT_Llama_3_2_1B +{ + public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json") + { + var device = "cuda"; + + // Load CausalLM Model + var pipeline = LoadModel(weightFolder, checkPointName); + + // Load dataset + var dataset = new List + { + new Data("What is contoso", " contoso is a virtual e-shop company that is widely used in Microsoft documentation."), + new Data("What products does contoso sell?", " Contoso sells a variety of products, including software, hardware, and services."), + new Data("What is the history of contoso?", " Contoso was founded in 1984 by John Doe."), + new Data("What is the mission of contoso?", " Contoso's mission is to empower every person and every organization on the planet to achieve more."), + new Data("What is the vision of contoso?", " Contoso's vision is to create a world where everyone can achieve more."), + new Data("What is the culture of contoso?", " Contoso's culture is based on a growth mindset, diversity, and inclusion."), + }; + + // create causal lm model input with label from dataset + // - tokenized input -> input_ids + // - replace what before with -1 + // - [-1,,,,: input_ids] -> label_ids + // return input_ids, labels, attention_mask + + var tokenizer = pipeline.Tokenizer; + var maxLength = 512; + var input = dataset.SelectMany(d => + { + ChatMessage[] chatMessagesWithAssistantMessage = [ + new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), + new ChatMessage(ChatRole.User, d.input), + new ChatMessage(ChatRole.Assistant, d.output), + ]; + + ChatMessage[] chatMessages = [ + new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), + new ChatMessage(ChatRole.User, d.input), + ]; + var fullPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatMessagesWithAssistantMessage); + var inputIds = tokenizer.EncodeToIds(fullPrompt); + + var trainPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatMessages); + var labelIds = tokenizer.EncodeToIds(fullPrompt).ToArray(); + var trainIds = tokenizer.EncodeToIds(trainPrompt); + labelIds = labelIds.Skip(trainIds.Count).ToArray(); + + return Enumerable.Range(0, labelIds.Length).Select(i => + { + var train = trainIds.Concat(labelIds[..i]).ToArray(); + var label = Enumerable.Repeat(-100, train.Length).Concat([labelIds[i]]).Skip(1).ToArray(); + + // pad both train and label to maxLength + train = train.Concat(Enumerable.Repeat(0, maxLength - train.Length)).ToArray(); + label = label.Concat(Enumerable.Repeat(0, maxLength - label.Length)).ToArray(); + var mask = Enumerable.Repeat(1, train.Length).ToArray(); + mask = mask.Concat(Enumerable.Repeat(0, maxLength - mask.Length)).ToArray(); + + var trainTensor = torch.tensor(train.ToArray(), dtype: ScalarType.Int64).reshape(1, -1); + var labelTensor = torch.tensor(label.ToArray(), dtype: ScalarType.Int64).reshape(1, -1); + var maskTensor = torch.tensor(mask.ToArray(), dtype: ScalarType.Int64).reshape(1, -1); + return new CausalLMModelInput(trainTensor, attentionMask: maskTensor, labels: labelTensor); + }); + }); + + + // Train the model + int epoch = 100; + int batchSize = 1; + var batches = input.Chunk(batchSize); + var optimizer = new Adam(pipeline.Model.parameters(), lr: 5e-5); + for (int i = 0; i < epoch; i++) + { + // evaluate the model + var agent = new LlamaCausalLMAgent(pipeline, "assistant", systemMessage: "You are a helpful contoso assistant") + .RegisterPrintMessage(); + + var task = "what is contoso"; + + await agent.SendAsync(task); + var losses = new List(); + foreach (var batch in batches) + { + var scope = NewDisposeScope(); + // merge items in batch + var inputIds = torch.cat(batch.Select(x => x.InputIds).ToArray(), 1).to(device); + var attentionMask = torch.cat(batch.Select(x => x.AttentionMask!).ToArray(), 1).to(device); + var labels = torch.cat(batch.Select(x => x.Labels!).ToArray(), 1).to(device); + // Forward the model + var output = pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels)); + // Calculate loss + var loss = output.Loss; + // Backward the model + optimizer.zero_grad(); + loss!.backward(); + optimizer.step(); + + losses.Add(loss.data().ToArray()[0]); + + // dispose loss + loss.Dispose(); + + // dispose output + output.LastHiddenState.Dispose(); + output.Logits!.Dispose(); + inputIds.Dispose(); + attentionMask.Dispose(); + labels.Dispose(); + + // print the # of tensor in memory + var numTensors = scope.DisposablesCount; + scope.Dispose(); + } + + Console.WriteLine($"Epoch {i + 1} loss: {losses.Average()}"); + } + } + + public static ICausalLMPipeline LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json") + { + var device = "cuda"; + var defaultType = ScalarType.BFloat16; + torch.manual_seed(1); + torch.set_default_dtype(defaultType); + var configName = "config.json"; + var originalWeightFolder = Path.Combine(weightFolder, "original"); + + Console.WriteLine("Loading Llama from huggingface model weight folder"); + var stopWatch = System.Diagnostics.Stopwatch.StartNew(); + stopWatch.Start(); + var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder); + var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true); + + var pipeline = new CausalLMPipeline(tokenizer, model, device); + + return pipeline; + } + + public record class Data(string input, string output); +} diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs index eaf94f2a80..2153a8d264 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs @@ -14,9 +14,10 @@ internal static class Defaults internal const Tensor? PositionIds = null; internal const int PastKeyValuesLength = 0; internal const Tensor? InputsEmbeds = null; - internal const bool UseCache = false; + internal const bool UseCache = true; internal const bool OutputAttentions = false; internal const bool OutputHiddenStates = false; + internal const Tensor? Labels = null; } public CausalLMModelInput( Tensor inputIds, @@ -24,6 +25,7 @@ public CausalLMModelInput( Tensor? positionIds = Defaults.PositionIds, int pastKeyValuesLength = Defaults.PastKeyValuesLength, Tensor? inputsEmbeds = Defaults.InputsEmbeds, + Tensor? labels = Defaults.Labels, bool useCache = Defaults.UseCache, bool outputAttentions = Defaults.OutputAttentions, bool outputHiddenStates = Defaults.OutputHiddenStates) @@ -36,6 +38,7 @@ public CausalLMModelInput( this.UseCache = useCache; this.OutputAttentions = outputAttentions; this.OutputHiddenStates = outputHiddenStates; + this.Labels = labels; } public Tensor InputIds { get; set; } @@ -50,6 +53,14 @@ public CausalLMModelInput( public Tensor? InputEmbeddings { get; set; } + /// + /// Shape: [batch_size, sequence_length] + /// DTypes: int64 + /// Labels for computing the causal language modeling loss. + /// Indices should be in [0, config.vocab_size - 1] or [-100] for padding/masking. + /// + public Tensor? Labels { get; set; } + public bool UseCache { get; set; } public bool OutputAttentions { get; set; } diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs index c10b68e60f..b7a622ab7c 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs @@ -14,21 +14,30 @@ internal static class Defaults internal const Tensor[]? AllHiddenStates = null; internal const Tensor[]? Attentions = null; internal const IKVCache? Cache = null; + internal const Tensor? Loss = null; } public CausalLMModelOutput( Tensor lastHiddenState, Tensor? logits = Defaults.Logits, Tensor[]? allHiddenStates = Defaults.AllHiddenStates, Tensor[]? attentions = Defaults.Attentions, - IKVCache? cache = Defaults.Cache) + IKVCache? cache = Defaults.Cache, + Tensor? loss = Defaults.Loss) { this.LastHiddenState = lastHiddenState; this.AllHiddenStates = allHiddenStates; this.Logits = logits; this.Attentions = attentions; this.Cache = cache; + this.Loss = loss; } + /// + /// Shape: [1,] + /// Available when label is provided in the input. + /// + public Tensor? Loss { get; set; } + public Tensor? Logits { get; set; } public Tensor LastHiddenState { get; set; } diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs index 0384efda8a..0a6cdc8498 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs @@ -65,6 +65,30 @@ public override CausalLMModelOutput forward(CausalLMModelInput input) logits = logits.to_type(ScalarType.Float32); outputs.Logits = logits; + // calculate the loss if the label is provided + if (input.Labels is not null) + { + // upcast the logits to float32 + logits = logits.to_type(ScalarType.Float32); + + var shiftLogits = logits[.., .., ..].contiguous(); + var shiftLabels = input.Labels[.., ..].contiguous(); + + shiftLogits = shiftLogits.view(-1, _vocabSize); + shiftLabels = shiftLabels.view(-1); + + // calculate the loss + // the loss is calculated by using the cross entropy loss by default + // TODO: add support for other loss functions + var loss = nn.functional.cross_entropy(shiftLogits, shiftLabels); + outputs.Loss = loss; + + // dispose the shiftLogits + shiftLogits.Dispose(); + shiftLabels.Dispose(); + logits.Dispose(); + } + return outputs; } diff --git a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs index d8596a43ca..5d1d08e411 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/Module/LlamaModel.cs @@ -14,7 +14,7 @@ internal class LlamaModel : nn.Module private readonly LlamaConfig _config; private readonly int? _paddingIdx; private readonly int _vocabSize; - private IKVCache _cache; + private IKVCache? _cache; #pragma warning disable MSML_PrivateFieldName // Private field name not in: _camelCase format private readonly Embedding embed_tokens; private readonly ModuleList layers; @@ -57,6 +57,10 @@ public override CausalLMModelOutput forward(CausalLMModelInput input) { this._cache = input.OverrideCache; } + else if (!input.UseCache) + { + this._cache = null; + } var outputAttentions = input.OutputAttentions; var outputHiddenStates = input.OutputHiddenStates; From e14b35344b4a123ab7b4b8aecca0428b02aaf76d Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Thu, 21 Nov 2024 15:58:51 -0800 Subject: [PATCH 2/8] add causalLMDataset --- .../Llama/SFT_Llama_3_2_1B.cs | 60 ++++------ .../Trainer/CausalLMDataset.cs | 112 ++++++++++++++++++ .../Utility/IChatTemplateBuilder.cs | 9 +- .../Llama3_1ChatTemplateBuilder.cs | 8 +- .../Phi3/Phi3ChatTemplateBuilder.cs | 8 +- .../CasualLMDatasetTest.cs | 103 ++++++++++++++++ .../Microsoft.ML.GenAI.Core.Tests.csproj | 2 + 7 files changed, 258 insertions(+), 44 deletions(-) create mode 100644 src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs create mode 100644 test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs index 50e8a7b842..4e983b81ff 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs @@ -12,6 +12,7 @@ using TorchSharp.PyBridge; using Microsoft.Extensions.AI; using AutoGen.Core; +using Microsoft.ML.GenAI.Core.Trainer; namespace Microsoft.ML.GenAI.Samples.Llama; @@ -35,6 +36,8 @@ public static async Task Train(string weightFolder, string checkPointName = "mod new Data("What is the culture of contoso?", " Contoso's culture is based on a growth mindset, diversity, and inclusion."), }; + var input = CreateDataset(dataset, pipeline.Tokenizer, Llama3_1ChatTemplateBuilder.Instance); + // create causal lm model input with label from dataset // - tokenized input -> input_ids // - replace what before with -1 @@ -42,45 +45,6 @@ public static async Task Train(string weightFolder, string checkPointName = "mod // return input_ids, labels, attention_mask var tokenizer = pipeline.Tokenizer; - var maxLength = 512; - var input = dataset.SelectMany(d => - { - ChatMessage[] chatMessagesWithAssistantMessage = [ - new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), - new ChatMessage(ChatRole.User, d.input), - new ChatMessage(ChatRole.Assistant, d.output), - ]; - - ChatMessage[] chatMessages = [ - new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), - new ChatMessage(ChatRole.User, d.input), - ]; - var fullPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatMessagesWithAssistantMessage); - var inputIds = tokenizer.EncodeToIds(fullPrompt); - - var trainPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatMessages); - var labelIds = tokenizer.EncodeToIds(fullPrompt).ToArray(); - var trainIds = tokenizer.EncodeToIds(trainPrompt); - labelIds = labelIds.Skip(trainIds.Count).ToArray(); - - return Enumerable.Range(0, labelIds.Length).Select(i => - { - var train = trainIds.Concat(labelIds[..i]).ToArray(); - var label = Enumerable.Repeat(-100, train.Length).Concat([labelIds[i]]).Skip(1).ToArray(); - - // pad both train and label to maxLength - train = train.Concat(Enumerable.Repeat(0, maxLength - train.Length)).ToArray(); - label = label.Concat(Enumerable.Repeat(0, maxLength - label.Length)).ToArray(); - var mask = Enumerable.Repeat(1, train.Length).ToArray(); - mask = mask.Concat(Enumerable.Repeat(0, maxLength - mask.Length)).ToArray(); - - var trainTensor = torch.tensor(train.ToArray(), dtype: ScalarType.Int64).reshape(1, -1); - var labelTensor = torch.tensor(label.ToArray(), dtype: ScalarType.Int64).reshape(1, -1); - var maskTensor = torch.tensor(mask.ToArray(), dtype: ScalarType.Int64).reshape(1, -1); - return new CausalLMModelInput(trainTensor, attentionMask: maskTensor, labels: labelTensor); - }); - }); - // Train the model int epoch = 100; @@ -155,4 +119,22 @@ public static ICausalLMPipeline LoadModel(s } public record class Data(string input, string output); + + public static CausalLMDataset CreateDataset(IEnumerable dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder) + { + var chatHistory = dataset.Select(data => + { + var trainChatHistory = new List + { + new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), + new ChatMessage(ChatRole.User, data.input), + }; + + var assistantMessage = new ChatMessage(ChatRole.Assistant, data.output); + + return (trainChatHistory, assistantMessage); + }).ToArray(); + + return CausalLMDataset.Create(chatHistory.Select(c => c.trainChatHistory), chatHistory.Select(c => c.assistantMessage), templateBuilder, tokenizer); + } } diff --git a/src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs b/src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs new file mode 100644 index 0000000000..b368c7ac03 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.ML.Tokenizers; +using TorchSharp; + +namespace Microsoft.ML.GenAI.Core.Trainer; + +public class CausalLMDataset : IEnumerable +{ + private readonly List _data; + + private CausalLMDataset(IEnumerable data) + { + _data = new List(data); + } + + public static CausalLMDataset Create(IEnumerable> inputs, + IEnumerable outputs, + IMEAIChatTemplateBuilder chatTemplateBuilder, + Tokenizer tokenizer) + { + // the length of inputs and outputs should be the same + if (inputs.Count() != outputs.Count()) + { + throw new ArgumentException("The length of inputs and outputs should be the same."); + } + + var enumerables = inputs.Zip(outputs, (input, output) => + { + var inputPrompt = chatTemplateBuilder.BuildPrompt(input.ToList()); + var outputPrompt = chatTemplateBuilder.BuildPrompt(input.Concat([output]).ToList(), appendAssistantTag: false); + var lengthToKeep = outputPrompt.Length - inputPrompt.Length; + outputPrompt = outputPrompt.Substring(inputPrompt.Length, lengthToKeep); + + return (inputPrompt, outputPrompt); + }); + + return Create(enumerables.Select(x => x.inputPrompt), enumerables.Select(x => x.outputPrompt), tokenizer); + } + + public static CausalLMDataset Create(IEnumerable inputs, IEnumerable outputs, Tokenizer tokenizer) + { + // the length of inputs and outputs should be the same + if (inputs.Count() != outputs.Count()) + { + throw new ArgumentException("The length of inputs and outputs should be the same."); + } + + var enumerable = inputs.Zip(outputs, (input, output) => + { + var inputIds = tokenizer.EncodeToIds(input); + var outputIds = tokenizer.EncodeToIds(input + output); + outputIds = outputIds.Skip(inputIds.Count()).ToArray(); + + return (inputIds, outputIds); + }).ToArray(); + + return Create(enumerable.Select(x => x.inputIds), enumerable.Select(x => x.outputIds)); + } + + public static CausalLMDataset Create(IEnumerable> inputIds, IEnumerable> labelIds) + { + // the length of inputIds and labelIds should be the same + if (inputIds.Count() != labelIds.Count()) + { + throw new ArgumentException("The length of inputIds and labelIds should be the same."); + } + + var enumerable = inputIds.Zip(labelIds, Create) + .SelectMany(x => x); + + return new CausalLMDataset(enumerable); + } + + public static CausalLMDataset Create(IReadOnlyList inputIds, IReadOnlyList labelIds) + { + var enumerable = Enumerable.Range(0, labelIds.Count) + .Select(i => + { + var train = inputIds.Concat(labelIds.Take(i)).ToArray(); + var label = Enumerable.Repeat(-100L, train.Length).Concat([labelIds[i]]).Skip(1).ToArray(); + var mask = Enumerable.Repeat(1L, train.Length).ToArray(); + + return new CausalLMModelInput( + inputIds: torch.tensor(train.ToArray(), dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1), + labels: torch.tensor(label, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1), + attentionMask: torch.tensor(mask, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1) + ); + }); + + return new CausalLMDataset(enumerable); + } + + public IEnumerator GetEnumerator() + { + return ((IEnumerable)_data).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)_data).GetEnumerator(); + } +} diff --git a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs index 7d9292562a..a0433f02a1 100644 --- a/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs @@ -25,7 +25,14 @@ public interface IAutoGenChatTemplateBuilder public interface IMEAIChatTemplateBuilder { - string BuildPrompt(IList messages, ChatOptions? options = null); + /// + /// Build a prompt from a list of messages. + /// + /// the list of to be rendered + /// + /// true if append assistant tag at the end of prompt. + /// + string BuildPrompt(IList messages, ChatOptions? options = null, bool appendAssistantTag = true); } public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder diff --git a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs index f54e24b9fb..2dd4a1b725 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs @@ -88,7 +88,7 @@ public string BuildPrompt(ChatHistory chatHistory) return sb.ToString(); } - public string BuildPrompt(IList messages, ChatOptions? options = null) + public string BuildPrompt(IList messages, ChatOptions? options = null, bool appendAssistantTag = true) { var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant }; if (messages.Any(m => m.Text is null)) @@ -116,7 +116,11 @@ public string BuildPrompt(IList messages, ChatOptions? options = nu }); } - sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}"); + if (appendAssistantTag) + { + sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}"); + } + var input = sb.ToString(); return input; diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs index 213b1f7408..9c4f887d75 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs @@ -89,7 +89,7 @@ public string BuildPrompt(ChatHistory chatHistory) return sb.ToString(); } - public string BuildPrompt(IList messages, ChatOptions? options = null) + public string BuildPrompt(IList messages, ChatOptions? options = null, bool appendAssistantTag = true) { var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant }; if (messages.Any(m => m.Text is null)) @@ -119,7 +119,11 @@ public string BuildPrompt(IList messages, ChatOptions? options = nu }); } - sb.Append("<|assistant|>"); + if (appendAssistantTag) + { + sb.Append("<|assistant|>"); + } + var input = sb.ToString(); return input; diff --git a/test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs b/test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs new file mode 100644 index 0000000000..f451dcb718 --- /dev/null +++ b/test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs @@ -0,0 +1,103 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using FluentAssertions; +using Microsoft.Extensions.AI; +using Microsoft.ML.GenAI.Core.Trainer; +using Microsoft.ML.GenAI.LLaMA; +using Microsoft.ML.Tokenizers; +using Xunit; + +namespace Microsoft.ML.GenAI.Core.Tests; + +public class CasualLMDatasetTest +{ + private static Tokenizer CreateLlamaTokenizer() + { + // @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true"; + // @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model"; + using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model")); + return LlamaTokenizer.Create(remoteStream); + } + + [Fact] + public void ItCreateDatasetsFromInputIds() + { + int[] inputIds = [1, 2, 3, 4, 5]; + int[] outputIds = [6, 7, 8, 9, 10]; + + var dataset = CausalLMDataset.Create(inputIds, outputIds) + .ToArray(); + + // the following rows should be created + // - input_ids: [1, 2, 3, 4, 5], label_ids: [-100, -100, -100, -100, 6] + // - input_ids: [1, 2, 3, 4, 5, 6], label_ids: [-100, -100, -100, -100, -100, 7] + // - input_ids: [1, 2, 3, 4, 5, 6, 7], label_ids: [-100, -100, -100, -100, -100, -100, 8] + // - input_ids: [1, 2, 3, 4, 5, 6, 7, 8], label_ids: [-100, -100, -100, -100, -100, -100, -100, 9] + // - input_ids: [1, 2, 3, 4, 5, 6, 7, 8, 9], label_ids: [-100, -100, -100, -100, -100, -100, -100, -100, 10] + + dataset.Length.Should().Be(5); + dataset[0].InputIds!.data().Should().BeEquivalentTo([1, 2, 3, 4, 5]); + dataset[0].Labels!.data().Should().BeEquivalentTo([-100, -100, -100, -100, 6]); + dataset[0].AttentionMask!.data().Should().BeEquivalentTo([1, 1, 1, 1, 1]); + dataset[^1].AttentionMask!.data().Should().BeEquivalentTo([1, 1, 1, 1, 1, 1, 1, 1, 1]); + dataset[^1].Labels!.data().Should().BeEquivalentTo([-100, -100, -100, -100, -100, -100, -100, -100, 10]); + dataset[^1].AttentionMask!.data().Should().BeEquivalentTo([1, 1, 1, 1, 1, 1, 1, 1, 1]); + } + + [Fact] + public void ItCreateDatasetsFromListOfInputIds() + { + int[][] inputIds = [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10] + ]; + + int[][] outputIds = [ + [11, 12, 13, 14, 15], + [16, 17, 18, 19, 20] + ]; + + var dataset = CausalLMDataset.Create(inputIds, outputIds) + .ToArray(); + + dataset.Count().Should().Be(10); + + foreach (var item in dataset) + { + item.Labels!.shape.Should().BeEquivalentTo(item.InputIds!.shape); + item.AttentionMask!.shape.Should().BeEquivalentTo(item.InputIds!.shape); + } + } + + [Fact] + public void ItCreateDatasetsFromMEAIMessages() + { + var inputs = new List> + { + new List + { + new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"), + new ChatMessage(ChatRole.User, "What is contoso"), + }, + }; + + var outputs = new List + { + new ChatMessage(ChatRole.Assistant, "Contoso is a company"), + }; + + var tokenizer = CreateLlamaTokenizer(); + + var dataset = CausalLMDataset.Create(inputs, outputs, Llama3_1ChatTemplateBuilder.Instance, tokenizer) + .ToArray(); + + dataset.Length.Should().Be(14); + } +} diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index f07f80089e..c758a4bf11 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -11,6 +11,7 @@ + @@ -19,6 +20,7 @@ + From 8d55bf55e8e8fa2367660ebe32badc9f355338fd Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 22 Nov 2024 10:07:02 -0800 Subject: [PATCH 3/8] update --- .../Llama/SFT_Llama_3_2_1B.cs | 26 ++++++++++--------- .../Microsoft.ML.GenAI.Samples/Program.cs | 4 +-- .../LlamaCausalLMAgent.cs | 4 +-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs index 4e983b81ff..f901fdd309 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs @@ -28,12 +28,12 @@ public static async Task Train(string weightFolder, string checkPointName = "mod // Load dataset var dataset = new List { - new Data("What is contoso", " contoso is a virtual e-shop company that is widely used in Microsoft documentation."), - new Data("What products does contoso sell?", " Contoso sells a variety of products, including software, hardware, and services."), - new Data("What is the history of contoso?", " Contoso was founded in 1984 by John Doe."), - new Data("What is the mission of contoso?", " Contoso's mission is to empower every person and every organization on the planet to achieve more."), - new Data("What is the vision of contoso?", " Contoso's vision is to create a world where everyone can achieve more."), - new Data("What is the culture of contoso?", " Contoso's culture is based on a growth mindset, diversity, and inclusion."), + new Data("What is ", " is a virtual e-shop company that is widely used in Microsoft documentation."), + new Data("What products does sell?", " sells a variety of products, including software, hardware, and services."), + new Data("What is the history of ?", " was founded in 1984 by John Doe."), + new Data("What is the mission of ?", "'s mission is to empower every person and every organization on the planet to achieve more."), + new Data("What is the vision of ?", "'s vision is to create a world where everyone can achieve more."), + new Data("What is the culture of ?", "'s culture is based on a growth mindset, diversity, and inclusion."), }; var input = CreateDataset(dataset, pipeline.Tokenizer, Llama3_1ChatTemplateBuilder.Instance); @@ -47,7 +47,7 @@ public static async Task Train(string weightFolder, string checkPointName = "mod var tokenizer = pipeline.Tokenizer; // Train the model - int epoch = 100; + int epoch = 300; int batchSize = 1; var batches = input.Chunk(batchSize); var optimizer = new Adam(pipeline.Model.parameters(), lr: 5e-5); @@ -57,7 +57,7 @@ public static async Task Train(string weightFolder, string checkPointName = "mod var agent = new LlamaCausalLMAgent(pipeline, "assistant", systemMessage: "You are a helpful contoso assistant") .RegisterPrintMessage(); - var task = "what is contoso"; + var task = "What is the history of and what products does sell?"; await agent.SendAsync(task); var losses = new List(); @@ -69,7 +69,7 @@ public static async Task Train(string weightFolder, string checkPointName = "mod var attentionMask = torch.cat(batch.Select(x => x.AttentionMask!).ToArray(), 1).to(device); var labels = torch.cat(batch.Select(x => x.Labels!).ToArray(), 1).to(device); // Forward the model - var output = pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels)); + var output = pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false)); // Calculate loss var loss = output.Loss; // Backward the model @@ -96,6 +96,10 @@ public static async Task Train(string weightFolder, string checkPointName = "mod Console.WriteLine($"Epoch {i + 1} loss: {losses.Average()}"); } + + // save model + var stateDict = pipeline.Model.state_dict(); + Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", stateDict); } public static ICausalLMPipeline LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json") @@ -108,10 +112,8 @@ public static ICausalLMPipeline LoadModel(s var originalWeightFolder = Path.Combine(weightFolder, "original"); Console.WriteLine("Loading Llama from huggingface model weight folder"); - var stopWatch = System.Diagnostics.Stopwatch.StartNew(); - stopWatch.Start(); var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder); - var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true); + var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false); var pipeline = new CausalLMPipeline(tokenizer, model, device); diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index de091afe41..7bf9984543 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -2,5 +2,5 @@ using Microsoft.ML.GenAI.Samples.Llama; using Microsoft.ML.GenAI.Samples.MEAI; -//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors"); -await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"); +await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "contoso-llama-3.1-1b.safetensors"); +//await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"); diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs index d6593f445f..dd7c4afe37 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs @@ -46,7 +46,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat } var input = _templateBuilder.BuildPrompt(messages); var maxLen = options?.MaxToken ?? 1024; - var temperature = options?.Temperature ?? 0.7f; + var temperature = 0f; var stopTokenSequence = options?.StopSequence ?? []; stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray(); @@ -73,7 +73,7 @@ public async IAsyncEnumerable GenerateStreamingReplyAsync( } var input = _templateBuilder.BuildPrompt(messages); var maxLen = options?.MaxToken ?? 1024; - var temperature = options?.Temperature ?? 0.7f; + var temperature = 0f; var stopTokenSequence = options?.StopSequence ?? []; stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray(); From a9c45f4e15c73e7ce8abf6ebab424e7c8c64a466 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 22 Nov 2024 10:32:37 -0800 Subject: [PATCH 4/8] add SFT trainer --- .../Llama/SFT_Llama_3_2_1B.cs | 78 +++++-------- .../Microsoft.ML.GenAI.Samples.csproj | 1 + .../Pipeline/CausalLMPipeline.cs | 12 +- .../CasualLMSupervisedFineTuningTrainer.cs | 103 ++++++++++++++++++ 4 files changed, 139 insertions(+), 55 deletions(-) create mode 100644 src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs index f901fdd309..98f4ae71ef 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs @@ -13,6 +13,7 @@ using Microsoft.Extensions.AI; using AutoGen.Core; using Microsoft.ML.GenAI.Core.Trainer; +using Microsoft.Extensions.Logging; namespace Microsoft.ML.GenAI.Samples.Llama; @@ -20,6 +21,12 @@ internal class SFT_Llama_3_2_1B { public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json") { + // create logger factory + using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole()); + + // create logger + var logger = loggerFactory.CreateLogger(); + var device = "cuda"; // Load CausalLM Model @@ -36,69 +43,38 @@ public static async Task Train(string weightFolder, string checkPointName = "mod new Data("What is the culture of ?", "'s culture is based on a growth mindset, diversity, and inclusion."), }; - var input = CreateDataset(dataset, pipeline.Tokenizer, Llama3_1ChatTemplateBuilder.Instance); - - // create causal lm model input with label from dataset - // - tokenized input -> input_ids - // - replace what before with -1 - // - [-1,,,,: input_ids] -> label_ids - // return input_ids, labels, attention_mask + var input = CreateDataset(dataset, pipeline.TypedTokenizer, Llama3_1ChatTemplateBuilder.Instance); - var tokenizer = pipeline.Tokenizer; + // create trainer + var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger: logger); // Train the model - int epoch = 300; - int batchSize = 1; - var batches = input.Chunk(batchSize); - var optimizer = new Adam(pipeline.Model.parameters(), lr: 5e-5); - for (int i = 0; i < epoch; i++) + var option = new CasualLMSupervisedFineTuningTrainer.Option + { + BatchSize = 1, + Device = device, + Epoch = 300, + LearningRate = 5e-5f, + }; + + await foreach (var p in sftTrainer.TrainAsync(input, option, default)) { // evaluate the model - var agent = new LlamaCausalLMAgent(pipeline, "assistant", systemMessage: "You are a helpful contoso assistant") + if (p is not ICausalLMPipeline llamaPipeline) + { + throw new InvalidOperationException("Pipeline is not of type ICausalLMPipeline"); + } + + var agent = new LlamaCausalLMAgent(llamaPipeline, "assistant", systemMessage: "You are a helpful contoso assistant") .RegisterPrintMessage(); - var task = "What is the history of and what products does sell?"; + var task = "What products does sell?"; await agent.SendAsync(task); - var losses = new List(); - foreach (var batch in batches) - { - var scope = NewDisposeScope(); - // merge items in batch - var inputIds = torch.cat(batch.Select(x => x.InputIds).ToArray(), 1).to(device); - var attentionMask = torch.cat(batch.Select(x => x.AttentionMask!).ToArray(), 1).to(device); - var labels = torch.cat(batch.Select(x => x.Labels!).ToArray(), 1).to(device); - // Forward the model - var output = pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false)); - // Calculate loss - var loss = output.Loss; - // Backward the model - optimizer.zero_grad(); - loss!.backward(); - optimizer.step(); - - losses.Add(loss.data().ToArray()[0]); - - // dispose loss - loss.Dispose(); - - // dispose output - output.LastHiddenState.Dispose(); - output.Logits!.Dispose(); - inputIds.Dispose(); - attentionMask.Dispose(); - labels.Dispose(); - - // print the # of tensor in memory - var numTensors = scope.DisposablesCount; - scope.Dispose(); - } - - Console.WriteLine($"Epoch {i + 1} loss: {losses.Average()}"); } // save model - var stateDict = pipeline.Model.state_dict(); + var stateDict = pipeline.TypedModel.state_dict(); Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", stateDict); } diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index 792391a59f..c8cee633ac 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -19,6 +19,7 @@ + diff --git a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs index 13c598b4ec..74d9c6237a 100644 --- a/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs +++ b/src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs @@ -18,13 +18,17 @@ public interface ICausalLMPipeline : ICausalLMPipeli where TTokenizer : Tokenizer where TModel : nn.Module { - TTokenizer Tokenizer { get; } + TTokenizer TypedTokenizer { get; } - TModel Model { get; } + TModel TypedModel { get; } } public interface ICausalLMPipeline { + Tokenizer Tokenizer { get; } + + nn.Module Model { get; } + string Generate( string prompt, int maxLen = CausalLMPipeline.Defaults.MaxLen, @@ -73,9 +77,9 @@ public CausalLMPipeline( { } - public new TTokenizer Tokenizer { get => (TTokenizer)base.Tokenizer; } + public TTokenizer TypedTokenizer { get => (TTokenizer)base.Tokenizer; } - public new TModel Model { get => (TModel)base.Model; } + public TModel TypedModel { get => (TModel)base.Model; } } public class CausalLMPipeline : ICausalLMPipeline diff --git a/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs b/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs new file mode 100644 index 0000000000..acd3e278d7 --- /dev/null +++ b/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs @@ -0,0 +1,103 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using Microsoft.Extensions.Logging; +using TorchSharp; +using TorchSharp.Modules; +using static TorchSharp.torch; + +namespace Microsoft.ML.GenAI.Core.Trainer; + +public class CasualLMSupervisedFineTuningTrainer +{ + private readonly ILogger? _logger; + private readonly ICausalLMPipeline _pipeline; + + public CasualLMSupervisedFineTuningTrainer(ICausalLMPipeline pipeline, ILogger? logger = null) + { + _logger = logger; + _pipeline = pipeline; + } + +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + public async IAsyncEnumerable TrainAsync( +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + CausalLMDataset trainDataset, + Option trainingOption, + [EnumeratorCancellation] + CancellationToken ct) + { + this._logger?.LogInformation("Start training..."); + var batches = trainDataset.Chunk(trainingOption.BatchSize); + var optimizer = new Adam(_pipeline.Model.parameters(), lr: trainingOption.LearningRate); + var device = torch.device(trainingOption.Device); + + for (int i = 0; i < trainingOption.Epoch; i++) + { + this._logger?.LogInformation($"Epoch {i + 1}/{trainingOption.Epoch}"); + var losses = new List(); + foreach (var batch in batches) + { + if (ct.IsCancellationRequested) + { + yield break; + } + var scope = NewDisposeScope(); + // merge items in batch + var inputIds = torch.cat(batch.Select(x => x.InputIds).ToArray(), 1).to(device); + var attentionMask = torch.cat(batch.Select(x => x.AttentionMask!).ToArray(), 1).to(device); + var labels = torch.cat(batch.Select(x => x.Labels!).ToArray(), 1).to(device); + // Forward the model + var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false)); + // Calculate loss + var loss = output.Loss; + // Backward the model + optimizer.zero_grad(); + loss!.backward(); + optimizer.step(); + + losses.Add(loss.data().ToArray()[0]); + + // dispose loss + loss.Dispose(); + + // dispose output + output.LastHiddenState.Dispose(); + output.Logits!.Dispose(); + inputIds.Dispose(); + attentionMask.Dispose(); + + scope.Dispose(); + } + + _logger?.LogInformation($"Epoch {i + 1} loss: {losses.Average()}"); + + yield return _pipeline; + } + } + + + public class Option + { + public Option() + { + Epoch = 10; + BatchSize = 1; + LearningRate = 5e-5f; + Device = "cpu"; + } + + public int Epoch { get; set; } + + public int BatchSize { get; set; } + + public float LearningRate { get; set; } + + public string Device { get; set; } + } +} From df03843fc156d25d4ee24f328c7d2f68d3cd3f74 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 22 Nov 2024 10:37:00 -0800 Subject: [PATCH 5/8] update --- src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs index dd7c4afe37..d6593f445f 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs @@ -46,7 +46,7 @@ public Task GenerateReplyAsync(IEnumerable messages, Generat } var input = _templateBuilder.BuildPrompt(messages); var maxLen = options?.MaxToken ?? 1024; - var temperature = 0f; + var temperature = options?.Temperature ?? 0.7f; var stopTokenSequence = options?.StopSequence ?? []; stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray(); @@ -73,7 +73,7 @@ public async IAsyncEnumerable GenerateStreamingReplyAsync( } var input = _templateBuilder.BuildPrompt(messages); var maxLen = options?.MaxToken ?? 1024; - var temperature = 0f; + var temperature = options?.Temperature ?? 0.7f; var stopTokenSequence = options?.StopSequence ?? []; stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray(); From ba47ab8dd80dd2fb67f1954dcf7ce6b3826488b8 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 22 Nov 2024 10:38:14 -0800 Subject: [PATCH 6/8] update --- docs/samples/Microsoft.ML.GenAI.Samples/Program.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index 7bf9984543..6f4d809948 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -2,5 +2,5 @@ using Microsoft.ML.GenAI.Samples.Llama; using Microsoft.ML.GenAI.Samples.MEAI; -await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "contoso-llama-3.1-1b.safetensors"); +await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors"); //await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"); From 6a184ac2a959c8859c1c9fc518535014106dc6d3 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 22 Nov 2024 14:32:10 -0800 Subject: [PATCH 7/8] disable x64 test on non-x64 machine --- .../Microsoft.ML.GenAI.Core.Tests.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj index c758a4bf11..90a9df1e94 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj +++ b/test/Microsoft.ML.GenAI.Core.Tests/Microsoft.ML.GenAI.Core.Tests.csproj @@ -25,6 +25,7 @@ + From f03b5befa46033c30a00a020d8c18053583ac468 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Mon, 25 Nov 2024 13:00:32 -0800 Subject: [PATCH 8/8] support batch --- .../Trainer/CasualLMSupervisedFineTuningTrainer.cs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs b/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs index acd3e278d7..f5ee202cd5 100644 --- a/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs +++ b/src/Microsoft.ML.GenAI.Core/Trainer/CasualLMSupervisedFineTuningTrainer.cs @@ -48,10 +48,12 @@ public async IAsyncEnumerable TrainAsync( yield break; } var scope = NewDisposeScope(); + // find the maximum length of input ids + var maxLen = batch.Max(x => x.InputIds.size(1)); // merge items in batch - var inputIds = torch.cat(batch.Select(x => x.InputIds).ToArray(), 1).to(device); - var attentionMask = torch.cat(batch.Select(x => x.AttentionMask!).ToArray(), 1).to(device); - var labels = torch.cat(batch.Select(x => x.Labels!).ToArray(), 1).to(device); + var inputIds = torch.cat(batch.Select(x => nn.functional.pad(x.InputIds, [0, maxLen - x.InputIds.shape[1]])).ToArray(), 0).to(device); + var attentionMask = torch.cat(batch.Select(x => nn.functional.pad(x.AttentionMask!, [0, maxLen - x.AttentionMask!.shape[1]])).ToArray(), 0).to(device); + var labels = torch.cat(batch.Select(x => nn.functional.pad(x.Labels!, [0, maxLen - x.Labels!.shape[1]], value: -100)).ToArray(), 0).to(device); // Forward the model var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false)); // Calculate loss