Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GenAI] SFT Example #7316

Merged
merged 9 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.ML.GenAI.Core;
using Microsoft.ML.GenAI.LLaMA;
using static TorchSharp.torch;
using TorchSharp;
using Microsoft.ML.Tokenizers;
using TorchSharp.Modules;
using TorchSharp.PyBridge;
using Microsoft.Extensions.AI;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core.Trainer;
using Microsoft.Extensions.Logging;

namespace Microsoft.ML.GenAI.Samples.Llama;

internal class SFT_Llama_3_2_1B
{
    /// <summary>
    /// End-to-end SFT sample: loads a Llama 3.2 1B checkpoint, fine-tunes it on a
    /// tiny in-memory Q/A dataset about the fictional contoso shop, smoke-tests the
    /// model after every epoch, and saves the resulting weights as safetensors.
    /// </summary>
    /// <param name="weightFolder">Folder containing the huggingface model weights.</param>
    /// <param name="checkPointName">Checkpoint index file name inside <paramref name="weightFolder"/>.</param>
    public static async Task Train(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        // Console logging so training progress is visible while the sample runs.
        using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());
        var logger = loggerFactory.CreateLogger<CasualLMSupervisedFineTuningTrainer>();

        var device = "cuda";

        // Causal LM pipeline (tokenizer + model) to fine-tune.
        var pipeline = LoadModel(weightFolder, checkPointName);

        // Tiny hand-written Q/A pairs used as the SFT corpus.
        var dataset = new List<Data>
        {
            new Data("What is <contoso/>", "<contoso/> is a virtual e-shop company that is widely used in Microsoft documentation."),
            new Data("What products does <contoso/> sell?", "<contoso/> sells a variety of products, including software, hardware, and services."),
            new Data("What is the history of <contoso/>?", "<contoso/> was founded in 1984 by John Doe."),
            new Data("What is the mission of <contoso/>?", "<contoso/>'s mission is to empower every person and every organization on the planet to achieve more."),
            new Data("What is the vision of <contoso/>?", "<contoso/>'s vision is to create a world where everyone can achieve more."),
            new Data("What is the culture of <contoso/>?", "<contoso/>'s culture is based on a growth mindset, diversity, and inclusion."),
        };

        var input = CreateDataset(dataset, pipeline.TypedTokenizer, Llama3_1ChatTemplateBuilder.Instance);

        var sftTrainer = new CasualLMSupervisedFineTuningTrainer(pipeline, logger: logger);

        var option = new CasualLMSupervisedFineTuningTrainer.Option
        {
            BatchSize = 1,
            Device = device,
            Epoch = 300,
            LearningRate = 5e-5f,
        };

        // The trainer yields the (partially) trained pipeline after each epoch;
        // ask it a question from the dataset as a quick qualitative check.
        await foreach (var p in sftTrainer.TrainAsync(input, option, default))
        {
            if (p is ICausalLMPipeline<Tokenizer, LlamaForCausalLM> llamaPipeline)
            {
                var agent = new LlamaCausalLMAgent(llamaPipeline, "assistant", systemMessage: "You are a helpful contoso assistant")
                    .RegisterPrintMessage();

                await agent.SendAsync("What products does <contoso/> sell?");
            }
            else
            {
                throw new InvalidOperationException("Pipeline is not of type ICausalLMPipeline<Tokenizer, LlamaForCausalLM>");
            }
        }

        // Persist the fine-tuned weights.
        Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", pipeline.TypedModel.state_dict());
    }

    /// <summary>
    /// Builds a <see cref="CausalLMPipeline{TiktokenTokenizer, LlamaForCausalLM}"/> from a
    /// huggingface weight folder, with deterministic seeding and BF16 default dtype.
    /// </summary>
    public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json")
    {
        var device = "cuda";
        const string configName = "config.json";

        // Seed and default dtype must be set before any weights are materialized.
        torch.manual_seed(1);
        torch.set_default_dtype(ScalarType.BFloat16);

        // The tokenizer files live under the "original" sub-folder of the weight folder.
        var originalWeightFolder = Path.Combine(weightFolder, "original");

        Console.WriteLine("Loading Llama from huggingface model weight folder");
        var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
        var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false);

        return new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);
    }

    /// <summary>One supervised example: a user prompt and the desired assistant reply.</summary>
    public record class Data(string input, string output);

    /// <summary>
    /// Converts prompt/completion pairs into a <see cref="CausalLMDataset"/>. Each pair becomes
    /// a (system + user) chat history plus the assistant message used as the training target.
    /// </summary>
    public static CausalLMDataset CreateDataset(IEnumerable<Data> dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder)
    {
        var pairs = dataset
            .Select(sample =>
            {
                var history = new List<ChatMessage>
                {
                    new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
                    new ChatMessage(ChatRole.User, sample.input),
                };

                return (History: history, Completion: new ChatMessage(ChatRole.Assistant, sample.output));
            })
            .ToArray();

        return CausalLMDataset.Create(pairs.Select(p => p.History), pairs.Select(p => p.Completion), templateBuilder, tokenizer);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<PackageReference Include="TorchSharp-cuda-windows" Version="0.102.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageReference Include="AutoGen.SourceGenerator" Version="$(AutoGenVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
</ItemGroup>

</Project>
4 changes: 2 additions & 2 deletions docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
using Microsoft.ML.GenAI.Samples.Llama;
using Microsoft.ML.GenAI.Samples.MEAI;

//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
//await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
13 changes: 12 additions & 1 deletion src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelInput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,18 @@ internal static class Defaults
internal const Tensor? PositionIds = null;
internal const int PastKeyValuesLength = 0;
internal const Tensor? InputsEmbeds = null;
internal const bool UseCache = false;
internal const bool UseCache = true;
internal const bool OutputAttentions = false;
internal const bool OutputHiddenStates = false;
internal const Tensor? Labels = null;
}
public CausalLMModelInput(
Tensor inputIds,
Tensor? attentionMask = Defaults.AttentionMask,
Tensor? positionIds = Defaults.PositionIds,
int pastKeyValuesLength = Defaults.PastKeyValuesLength,
Tensor? inputsEmbeds = Defaults.InputsEmbeds,
Tensor? labels = Defaults.Labels,
bool useCache = Defaults.UseCache,
bool outputAttentions = Defaults.OutputAttentions,
bool outputHiddenStates = Defaults.OutputHiddenStates)
Expand All @@ -36,6 +38,7 @@ public CausalLMModelInput(
this.UseCache = useCache;
this.OutputAttentions = outputAttentions;
this.OutputHiddenStates = outputHiddenStates;
this.Labels = labels;
}

public Tensor InputIds { get; set; }
Expand All @@ -50,6 +53,14 @@ public CausalLMModelInput(

public Tensor? InputEmbeddings { get; set; }

/// <summary>
/// Shape: [batch_size, sequence_length]
/// DTypes: int64
/// Labels for computing the causal language modeling loss.
/// Indices should be in [0, config.vocab_size - 1] or [-100] for padding/masking.
/// </summary>
public Tensor? Labels { get; set; }

public bool UseCache { get; set; }

public bool OutputAttentions { get; set; }
Expand Down
11 changes: 10 additions & 1 deletion src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMModelOutput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,30 @@ internal static class Defaults
internal const Tensor[]? AllHiddenStates = null;
internal const Tensor[]? Attentions = null;
internal const IKVCache? Cache = null;
internal const Tensor? Loss = null;
}
public CausalLMModelOutput(
Tensor lastHiddenState,
Tensor? logits = Defaults.Logits,
Tensor[]? allHiddenStates = Defaults.AllHiddenStates,
Tensor[]? attentions = Defaults.Attentions,
IKVCache? cache = Defaults.Cache)
IKVCache? cache = Defaults.Cache,
Tensor? loss = Defaults.Loss)
{
this.LastHiddenState = lastHiddenState;
this.AllHiddenStates = allHiddenStates;
this.Logits = logits;
this.Attentions = attentions;
this.Cache = cache;
this.Loss = loss;
}

/// <summary>
/// Shape: [1,]
/// Available when label is provided in the input.
/// </summary>
public Tensor? Loss { get; set; }

public Tensor? Logits { get; set; }

public Tensor LastHiddenState { get; set; }
Expand Down
12 changes: 8 additions & 4 deletions src/Microsoft.ML.GenAI.Core/Pipeline/CausalLMPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ public interface ICausalLMPipeline<out TTokenizer, out TModel> : ICausalLMPipeli
where TTokenizer : Tokenizer
where TModel : nn.Module<CausalLMModelInput, CausalLMModelOutput>
{
TTokenizer Tokenizer { get; }
TTokenizer TypedTokenizer { get; }

TModel Model { get; }
TModel TypedModel { get; }
}

public interface ICausalLMPipeline
{
Tokenizer Tokenizer { get; }

nn.Module<CausalLMModelInput, CausalLMModelOutput> Model { get; }

string Generate(
string prompt,
int maxLen = CausalLMPipeline.Defaults.MaxLen,
Expand Down Expand Up @@ -73,9 +77,9 @@ public CausalLMPipeline(
{
}

public new TTokenizer Tokenizer { get => (TTokenizer)base.Tokenizer; }
public TTokenizer TypedTokenizer { get => (TTokenizer)base.Tokenizer; }

public new TModel Model { get => (TModel)base.Model; }
public TModel TypedModel { get => (TModel)base.Model; }
}

public class CausalLMPipeline : ICausalLMPipeline
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using Microsoft.Extensions.Logging;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

namespace Microsoft.ML.GenAI.Core.Trainer;

/// <summary>
/// Supervised fine-tuning (SFT) trainer for causal language models.
/// Runs a plain Adam training loop over a <see cref="CausalLMDataset"/> and
/// yields the pipeline after each epoch so callers can evaluate mid-training.
/// </summary>
public class CasualLMSupervisedFineTuningTrainer
{
    private readonly ILogger<CasualLMSupervisedFineTuningTrainer>? _logger;
    private readonly ICausalLMPipeline _pipeline;

    /// <summary>
    /// Creates a trainer over the given pipeline. The pipeline's model is trained in place.
    /// </summary>
    /// <param name="pipeline">Pipeline whose model will be fine-tuned.</param>
    /// <param name="logger">Optional logger for per-epoch progress.</param>
    public CasualLMSupervisedFineTuningTrainer(ICausalLMPipeline pipeline, ILogger<CasualLMSupervisedFineTuningTrainer>? logger = null)
    {
        _logger = logger;
        _pipeline = pipeline;
    }

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
    /// <summary>
    /// Trains the model for <see cref="Option.Epoch"/> epochs, yielding the pipeline after
    /// each epoch. Enumeration stops early (without throwing) when <paramref name="ct"/>
    /// is cancelled.
    /// </summary>
    /// <param name="trainDataset">Tokenized training examples (input ids, attention mask, labels).</param>
    /// <param name="trainingOption">Batch size, epoch count, learning rate and target device.</param>
    /// <param name="ct">Cancellation token checked before each batch.</param>
    public async IAsyncEnumerable<ICausalLMPipeline> TrainAsync(
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
        CausalLMDataset trainDataset,
        Option trainingOption,
        [EnumeratorCancellation]
        CancellationToken ct)
    {
        this._logger?.LogInformation("Start training...");
        var batches = trainDataset.Chunk(trainingOption.BatchSize);
        // Adam owns tensor state (moment estimates); dispose it when the iterator completes.
        using var optimizer = new Adam(_pipeline.Model.parameters(), lr: trainingOption.LearningRate);
        var device = torch.device(trainingOption.Device);

        for (int i = 0; i < trainingOption.Epoch; i++)
        {
            this._logger?.LogInformation($"Epoch {i + 1}/{trainingOption.Epoch}");
            var losses = new List<float>();
            foreach (var batch in batches)
            {
                if (ct.IsCancellationRequested)
                {
                    yield break;
                }

                // Dispose every tensor created in this iteration, even if forward/backward
                // throws (the original manual scope.Dispose() leaked on exceptions).
                using var scope = NewDisposeScope();

                // Right-pad every item in the batch to the longest sequence length.
                var maxLen = batch.Max(x => x.InputIds.size(1));
                var inputIds = torch.cat(batch.Select(x => nn.functional.pad(x.InputIds, [0, maxLen - x.InputIds.shape[1]])).ToArray(), 0).to(device);
                var attentionMask = torch.cat(batch.Select(x => nn.functional.pad(x.AttentionMask!, [0, maxLen - x.AttentionMask!.shape[1]])).ToArray(), 0).to(device);
                // Pad labels with -100 so padded positions are excluded from the loss.
                var labels = torch.cat(batch.Select(x => nn.functional.pad(x.Labels!, [0, maxLen - x.Labels!.shape[1]], value: -100)).ToArray(), 0).to(device);

                // Forward pass; the model computes the LM loss when labels are supplied.
                var output = _pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false));

                // Fail with a diagnosable error instead of a NullReferenceException
                // if the model did not produce a loss.
                var loss = output.Loss ?? throw new InvalidOperationException("Model output does not contain a loss; ensure labels are passed to the model.");

                // Backward pass and parameter update.
                optimizer.zero_grad();
                loss.backward();
                optimizer.step();

                // Record the scalar loss before the scope disposes the tensor.
                losses.Add(loss.data<float>().ToArray()[0]);
            }

            _logger?.LogInformation($"Epoch {i + 1} loss: {losses.Average()}");

            yield return _pipeline;
        }
    }

    /// <summary>Training hyper-parameters.</summary>
    public class Option
    {
        public Option()
        {
            Epoch = 10;
            BatchSize = 1;
            LearningRate = 5e-5f;
            Device = "cpu";
        }

        /// <summary>Number of passes over the dataset. Default: 10.</summary>
        public int Epoch { get; set; }

        /// <summary>Number of examples per optimizer step. Default: 1.</summary>
        public int BatchSize { get; set; }

        /// <summary>Adam learning rate. Default: 5e-5.</summary>
        public float LearningRate { get; set; }

        /// <summary>Torch device name (e.g. "cpu", "cuda"). Default: "cpu".</summary>
        public string Device { get; set; }
    }
}
Loading
Loading