ML.NET NER - Mismatched state_dict sizes: expected 60, but found 126 entries. #7350

Open
piercarlo62 opened this issue Dec 22, 2024 · 2 comments
Labels: NLP (Issues / questions around text processing), TorchSharp (Issues related to TorchSharp)


@piercarlo62

Hello,
I'm testing the NER capabilities of ML.NET, and during training I get the following error:
Error: Mismatched state_dict sizes: expected 60, but found 126 entries.


System Information:

  • OS & Version: Windows 10
  • ML.NET Version: ML.NET v4.0.0
  • .NET Version: .NET 8.0

Description of the bug
The call var transformer = estimator.Fit(dataView); throws the following exception:

Mismatched state_dict sizes: expected 60, but found 126 entries.
in TorchSharp.torch.nn.Module.load(BinaryReader reader, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
   in TorchSharp.torch.nn.Module.load(String location, Boolean strict, IList`1 skip, Dictionary`2 loadedParameters)
   in Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase.CreateModule(IChannel ch, IDataView input)
   in Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.TrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
   in Microsoft.ML.TorchSharp.NasBert.NasBertTrainer`2.NasBertTrainerBase..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input, String modelUrl)
   in Microsoft.ML.TorchSharp.NasBert.NerTrainer.Trainer..ctor(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
   in Microsoft.ML.TorchSharp.NasBert.NerTrainer.CreateTrainer(TorchSharpBaseTrainer`2 parent, IChannel ch, IDataView input)
   in Microsoft.ML.TorchSharp.TorchSharpBaseTrainer`2.Fit(IDataView input)
   in Microsoft.ML.Data.EstimatorChain`1.Fit(IDataView input)
   in Program.Main(String[] args) in C:\Users\pierc\source\repos\ML_NER_TEST\Program.cs: line 64

Sample Projects

using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.TorchSharp;

namespace ML_NER_TEST
{
    public class Program
    {
        public static void Main(string[] args)
        {
            try
            {
                var context = new MLContext()
                {
                    FallbackToCpu = true,
                    GpuDeviceId = 0
                };

                var labels = context.Data.LoadFromEnumerable(
                    [
                            new Label { Key = "PERSON" },       // People, including fictional.
                            new Label { Key = "NORP" },         // Nationalities or religious or political groups.
                            new Label { Key = "FAC" },          // Buildings, airports, highways, bridges, etc.
                            new Label { Key = "ORG" },          // Companies, agencies, institutions, etc.
                            new Label { Key = "GPE" },          // Countries, cities, states.
                            new Label { Key = "LOC" },          // Non-GPE locations, mountain ranges, bodies of water.
                            new Label { Key = "PRODUCT" },      // Objects, vehicles, foods, etc. (Not services.)
                            new Label { Key = "EVENT" },        // Named hurricanes, battles, wars, sports events, etc.
                            new Label { Key = "WORK_OF_ART" },  // Titles of books, songs, etc.
                            new Label { Key = "LAW" },          // Named documents made into laws.
                            new Label { Key = "LANGUAGE" },     // Any named language.
                            new Label { Key = "DATE" },         // Absolute or relative dates or periods.
                            new Label { Key = "TIME" },         // Times smaller than a day.
                            new Label { Key = "PERCENT" },      // Percentage, including "%".
                            new Label { Key = "MONEY" },        // Monetary values, including unit.
                            new Label { Key = "QUANTITY" },     // Measurements, as of weight or distance.
                            new Label { Key = "ORDINAL" },      // "first", "second", etc.
                            new Label { Key = "CARDINAL" },     // Numerals that do not fall under another type.
                            new Label { Key = "OBJECT" },       // An Object, Entity might be a Spoon, or a Soccer Ball. Needs Sub Categories.
                ]);

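                var dataView = context.Data.LoadFromEnumerable(
                    new List<InputTrainingData>([
                        new InputTrainingData()
                        {
                            // "0" marks words that are not entities; it is not in the label set,
                            // so MapValueToKey maps it to the missing key.
                            Sentence = "Alice and Bob live in the USA",
                            // GPE (not COUNTRY) is the key defined above for countries, cities, states.
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "GPE"]
                        },
                        new InputTrainingData()
                        {
                            // One label per whitespace-separated word ("coast." is one word).
                            Sentence = "Frank and Alice traveled along the California coast.",
                            Label = ["PERSON", "0", "PERSON", "0", "0", "0", "GPE", "0"]
                        },
                    ]));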

                var chain = new EstimatorChain<ITransformer>();

                var estimator = chain.Append(context.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
                   .Append(context.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "Predictions"))
                   .Append(context.Transforms.Conversion.MapKeyToValue("Predictions"));

                Console.WriteLine("Training the model...");

                var transformer = estimator.Fit(dataView);

                Console.WriteLine("Model trained!");

                var transformerSchema = transformer.GetOutputSchema(dataView.Schema);

                string sentence = "Alice and Bob live in the USA";
                var engine = context.Model.CreatePredictionEngine<Input, Output>(transformer);

                Console.WriteLine("Predicting...");

                Output predictions = engine.Predict(new Input { Sentence = sentence });

                Console.WriteLine($"Predictions: {sentence} - {string.Join(", ", predictions.Predictions)}");

                transformer.Dispose();
                Console.WriteLine("Success!");
                Console.ReadLine();
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Error: {ex.Message}");
                Console.ReadLine();
            }
        }
        private class Input
        {
            public string Sentence;
            public string[] Label;
        }
        private class Output
        {
            public string[] Predictions;
        }
        public class Label
        {
            public string Key { get; set; }
        }
        private class InputTrainingData
        {
            public string Sentence;
            public string[] Label;
        }
    }
}
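Separately from the state_dict error (which, per the stack trace, is raised from Module.load inside NasBertTrainerBase.CreateModule, i.e. while a saved model file is loaded), the sample training rows are easy to get subtly wrong. The helper below is a hypothetical pre-flight check, not an ML.NET API. Assuming labels align one per whitespace-separated word, as the sample rows suggest, it reports rows whose word and label counts differ and labels that are neither the "0" placeholder nor one of the declared keys.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical pre-flight check for NER training rows (not part of ML.NET).
public static class NerDataCheck
{
    // Returns one message per problem found; an empty list means the row looks consistent.
    public static List<string> Validate(
        string sentence, string[] labels, ISet<string> knownLabels, string outsideLabel = "0")
    {
        var problems = new List<string>();

        // Assumption: one label per whitespace-separated word, as in the sample rows
        // ("coast." counts as a single word). A null separator splits on whitespace.
        string[] words = sentence.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);
        if (words.Length != labels.Length)
        {
            problems.Add($"\"{sentence}\": {words.Length} words but {labels.Length} labels.");
        }

        // A label that is neither the placeholder nor a declared key cannot be
        // mapped to one of the entity classes by MapValueToKey.
        foreach (string label in labels.Distinct())
        {
            if (label != outsideLabel && !knownLabels.Contains(label))
            {
                problems.Add($"\"{sentence}\": label \"{label}\" is not in the label set.");
            }
        }

        return problems;
    }
}
```

For the sample above, knownLabels would be the keys passed to MapValueToKey (PERSON, NORP, ..., OBJECT); a label such as "COUNTRY" that is missing from that set would be reported. This check does not address the state_dict mismatch itself.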

Additional context

<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>
    <Nullable>disable</Nullable>
  </PropertyGroup>

  <ItemGroup>
    <PackageReference Include="libtorch-cpu-win-x64" Version="2.5.1" />
    <PackageReference Include="Microsoft.ML" Version="4.0.0" />
    <PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.0" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.22.0" />
    <PackageReference Include="TorchSharp" Version="0.105.0" />
  </ItemGroup>

</Project>
dotnet-policy-service bot added the untriaged label (New issue has not been triaged) on Dec 22, 2024

dha125 commented Feb 7, 2025

I'm getting different results from the command-line tool (dotnet tool install --global mlnet-win-x64). The Visual Studio extension ML.NET Model Builder v17.18.0 seems to use something like the mlnet tool shown below. However, a C# project like the one above, which may use different library versions and ML architectures, also throws the Mismatched state_dict exception for me. Below is a hint of which libraries the mlnet tool uses. I wonder whether part of this issue is that the tooling is a release behind.

C:\Users\xxxx>mlnet text-classification --dataset "text-code.txt" --label-col 1 --text-col 0 --has-header true
Start Training
start text classification
env:path: C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5; [snip]
restore "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj" --configfile "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\NuGet.config" -r win-x64 /p:UsingToolXliff=false /p:TorchSharpVersion=0.101.5 /p:TorchSharpCudaRuntimeVersion=2.1.0.1 /p:TensorflowRuntimeVersion=2.3.1 /p:BaseIntermediateOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\obj"
publish "C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj" -r win-x64 -c Release --no-self-contained -o "C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5" --no-restore /p:UsingToolXliff=false /p:TorchSharpVersion=0.101.5 /p:TorchSharpCudaRuntimeVersion=2.1.0.1 /p:TensorflowRuntimeVersion=2.3.1 /p:BaseOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\bin\" /p:BaseIntermediateOutputPath="C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\obj\"
start installing runtime in C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5
Determining projects to restore...
Restored C:\Users\xxxx\.dotnet\tools\.store\mlnet-win-x64\16.18.2\mlnet-win-x64\16.18.2\tools\net8.0\any\RuntimeManager\torchsharp.cpu.csproj (in 1.5 min).

torchsharp.cpu -> C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\bin\Release\netstandard2.0\win-x64\torchsharp.cpu.dll
torchsharp.cpu -> C:\Users\xxxx\AppData\Local\ModelBuilder\torchsharp-cpu-0.101.5\

install runtime successfully
Use train validate split with ratio: 0.1
[Source=AutoMLExperiment-ChildContext, Kind=Trace] [Source=TorchSharpBaseTrainer; TrainModel, Kind=Trace] Starting epoch 0


dha125 commented Feb 8, 2025

For what it's worth, the mlnet text-classification ... command-line tool generates a SampleTextClassification project with the csproj file shown below.

Edit: I rolled back to TorchSharp-cpu Version="0.99.6", and it's working well for me now.

<Project Sdk="Microsoft.NET.Sdk">
  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <PlatformTarget>x64</PlatformTarget>
  </PropertyGroup>
  <ItemGroup>
    <PackageReference Include="Microsoft.ML" Version="3.0.1" />
    <PackageReference Include="Microsoft.ML.TorchSharp" Version="0.21.0" />
    <PackageReference Include="TorchSharp-cpu" Version="0.101.5" />
  </ItemGroup>
  <ItemGroup Label="SampleTextClassification">
    <None Include="SampleTextClassification.mlnet">
      <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
    </None>
  </ItemGroup>
</Project>
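Since the comments point at a version difference between the mlnet tool's project (Microsoft.ML 3.0.1, Microsoft.ML.TorchSharp 0.21.0, TorchSharp-cpu 0.101.5) and the failing project (Microsoft.ML 4.0.0, Microsoft.ML.TorchSharp 0.22.0, TorchSharp 0.105.0), it can help to print what the process actually loaded. The sketch below uses only reflection; the class name is hypothetical and not part of ML.NET or TorchSharp. It reports the informational version, which usually tracks the NuGet package version more closely than the four-part assembly version, and it only sees assemblies that are already loaded, so it should be called after Fit has been attempted, for example from the catch block.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical diagnostic helper (not an ML.NET or TorchSharp API).
public static class LoadedAssemblyReport
{
    // Returns "Name Version" for every assembly already loaded into the current
    // process whose simple name starts with one of the given prefixes.
    public static List<string> Describe(params string[] prefixes)
    {
        return AppDomain.CurrentDomain.GetAssemblies()
            .Select(assembly => new
            {
                Name = assembly.GetName().Name ?? string.Empty,
                // Prefer the informational version (often the package version,
                // possibly with a "+commit" suffix); fall back to the assembly version.
                Version = assembly.GetCustomAttribute<AssemblyInformationalVersionAttribute>()?.InformationalVersion
                          ?? assembly.GetName().Version?.ToString()
                          ?? "unknown"
            })
            .Where(info => prefixes.Any(p => info.Name.StartsWith(p, StringComparison.OrdinalIgnoreCase)))
            .OrderBy(info => info.Name, StringComparer.Ordinal)
            .Select(info => $"{info.Name} {info.Version}")
            .ToList();
    }
}
```

For example, calling LoadedAssemblyReport.Describe("TorchSharp", "Microsoft.ML") in the catch block of the sample and printing each entry would show which TorchSharp and ML.NET builds the project resolved at runtime.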

michaelgsharp self-assigned this on Mar 10, 2025
michaelgsharp added this to the ML.NET 5.0 milestone on Mar 19, 2025
michaelgsharp added the NLP (Issues / questions around text processing) and TorchSharp (Issues related to TorchSharp) labels and removed the untriaged (New issue has not been triaged) label on Mar 19, 2025
Development

No branches or pull requests

3 participants