Skip to content

Commit a758217

Browse files
Generate built-in SweepableEstimator classes for all available estimators (dotnet#6125)
1 parent bfba5d9 commit a758217

File tree

53 files changed

+3230
-24
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+3230
-24
lines changed

Diff for: src/Microsoft.ML.AutoML/API/AutoCatalog.cs

+169
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections;
7+
using System.Collections.Generic;
8+
using System.Diagnostics.Contracts;
9+
using Microsoft.ML.AutoML.CodeGen;
610
using Microsoft.ML.Data;
711
using Microsoft.ML.SearchSpace;
812

@@ -289,5 +293,170 @@ internal SweepableEstimator CreateSweepableEstimator<T>(Func<MLContext, T, IEsti
289293
{
290294
return new SweepableEstimator((MLContext context, Parameter param) => factory(context, param.AsType<T>()), ss);
291295
}
296+
297+
/// <summary>
/// Builds the sweepable binary-classification trainers that are switched on by the
/// <c>use*</c> flags. The returned array is ordered FastTree, FastForest, LightGbm,
/// L-BFGS logistic regression, SDCA logistic regression.
/// </summary>
/// <remarks>
/// When an option object is supplied by the caller, its LabelColumnName, FeatureColumnName and
/// ExampleWeightColumnName properties are overwritten in place with the column-name arguments.
/// A missing option or search space is replaced by a default-constructed instance.
/// </remarks>
internal SweepableEstimator[] BinaryClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
    FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
    SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
    var trainers = new List<SweepableEstimator>();

    // FastTree
    if (useFastTree)
    {
        if (fastTreeOption is null)
        {
            fastTreeOption = new FastTreeOption();
        }

        fastTreeOption.LabelColumnName = labelColumnName;
        fastTreeOption.FeatureColumnName = featureColumnName;
        fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastTreeSpace = fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>();
        trainers.Add(SweepableEstimatorFactory.CreateFastTreeBinary(fastTreeOption, fastTreeSpace));
    }

    // FastForest
    if (useFastForest)
    {
        if (fastForestOption is null)
        {
            fastForestOption = new FastForestOption();
        }

        fastForestOption.LabelColumnName = labelColumnName;
        fastForestOption.FeatureColumnName = featureColumnName;
        fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastForestSpace = fastForestSearchSpace ?? new SearchSpace<FastForestOption>();
        trainers.Add(SweepableEstimatorFactory.CreateFastForestBinary(fastForestOption, fastForestSpace));
    }

    // LightGbm
    if (useLgbm)
    {
        if (lgbmOption is null)
        {
            lgbmOption = new LgbmOption();
        }

        lgbmOption.LabelColumnName = labelColumnName;
        lgbmOption.FeatureColumnName = featureColumnName;
        lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lgbmSpace = lgbmSearchSpace ?? new SearchSpace<LgbmOption>();
        trainers.Add(SweepableEstimatorFactory.CreateLightGbmBinary(lgbmOption, lgbmSpace));
    }

    // L-BFGS logistic regression
    if (useLbfgs)
    {
        if (lbfgsOption is null)
        {
            lbfgsOption = new LbfgsOption();
        }

        lbfgsOption.LabelColumnName = labelColumnName;
        lbfgsOption.FeatureColumnName = featureColumnName;
        lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lbfgsSpace = lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>();
        trainers.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionBinary(lbfgsOption, lbfgsSpace));
    }

    // SDCA logistic regression
    if (useSdca)
    {
        if (sdcaOption is null)
        {
            sdcaOption = new SdcaOption();
        }

        sdcaOption.LabelColumnName = labelColumnName;
        sdcaOption.FeatureColumnName = featureColumnName;
        sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;

        var sdcaSpace = sdcaSearchSpace ?? new SearchSpace<SdcaOption>();
        trainers.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionBinary(sdcaOption, sdcaSpace));
    }

    return trainers.ToArray();
}
350+
351+
/// <summary>
/// Builds the sweepable multiclass trainers that are switched on by the <c>use*</c> flags.
/// The returned array is ordered FastTree (OVA), FastForest (OVA), LightGbm, L-BFGS
/// (logistic regression OVA, then maximum entropy), SDCA (maximum entropy, then logistic
/// regression OVA).
/// </summary>
/// <remarks>
/// When an option object is supplied by the caller, its LabelColumnName, FeatureColumnName and
/// ExampleWeightColumnName properties are overwritten in place with the column-name arguments.
/// For L-BFGS and SDCA the same option instance is handed to both factory calls, while a missing
/// search space is default-constructed separately for each of the two calls.
/// </remarks>
internal SweepableEstimator[] MultiClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
    FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
    SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
    var trainers = new List<SweepableEstimator>();

    // FastTree, one-versus-all
    if (useFastTree)
    {
        if (fastTreeOption is null)
        {
            fastTreeOption = new FastTreeOption();
        }

        fastTreeOption.LabelColumnName = labelColumnName;
        fastTreeOption.FeatureColumnName = featureColumnName;
        fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastTreeSpace = fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>();
        trainers.Add(SweepableEstimatorFactory.CreateFastTreeOva(fastTreeOption, fastTreeSpace));
    }

    // FastForest, one-versus-all
    if (useFastForest)
    {
        if (fastForestOption is null)
        {
            fastForestOption = new FastForestOption();
        }

        fastForestOption.LabelColumnName = labelColumnName;
        fastForestOption.FeatureColumnName = featureColumnName;
        fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastForestSpace = fastForestSearchSpace ?? new SearchSpace<FastForestOption>();
        trainers.Add(SweepableEstimatorFactory.CreateFastForestOva(fastForestOption, fastForestSpace));
    }

    // LightGbm
    if (useLgbm)
    {
        if (lgbmOption is null)
        {
            lgbmOption = new LgbmOption();
        }

        lgbmOption.LabelColumnName = labelColumnName;
        lgbmOption.FeatureColumnName = featureColumnName;
        lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lgbmSpace = lgbmSearchSpace ?? new SearchSpace<LgbmOption>();
        trainers.Add(SweepableEstimatorFactory.CreateLightGbmMulti(lgbmOption, lgbmSpace));
    }

    // L-BFGS: two trainers built from the same option object
    if (useLbfgs)
    {
        if (lbfgsOption is null)
        {
            lbfgsOption = new LbfgsOption();
        }

        lbfgsOption.LabelColumnName = labelColumnName;
        lbfgsOption.FeatureColumnName = featureColumnName;
        lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;

        // The fallback search space is evaluated once per call on purpose, so each trainer
        // receives its own default instance when none was supplied.
        var lbfgsOvaSpace = lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>();
        trainers.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionOva(lbfgsOption, lbfgsOvaSpace));

        var lbfgsMaxEntSpace = lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>();
        trainers.Add(SweepableEstimatorFactory.CreateLbfgsMaximumEntropyMulti(lbfgsOption, lbfgsMaxEntSpace));
    }

    // SDCA: two trainers built from the same option object
    if (useSdca)
    {
        if (sdcaOption is null)
        {
            sdcaOption = new SdcaOption();
        }

        sdcaOption.LabelColumnName = labelColumnName;
        sdcaOption.FeatureColumnName = featureColumnName;
        sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;

        var sdcaMaxEntSpace = sdcaSearchSpace ?? new SearchSpace<SdcaOption>();
        trainers.Add(SweepableEstimatorFactory.CreateSdcaMaximumEntropyMulti(sdcaOption, sdcaMaxEntSpace));

        var sdcaOvaSpace = sdcaSearchSpace ?? new SearchSpace<SdcaOption>();
        trainers.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionOva(sdcaOption, sdcaOvaSpace));
    }

    return trainers.ToArray();
}
406+
407+
/// <summary>
/// Builds the sweepable regression trainers that are switched on by the <c>use*</c> flags.
/// The returned array is ordered FastTree (regression, then Tweedie), FastForest, LightGbm,
/// L-BFGS Poisson regression, SDCA.
/// </summary>
/// <remarks>
/// When an option object is supplied by the caller, its LabelColumnName, FeatureColumnName and
/// ExampleWeightColumnName properties are overwritten in place with the column-name arguments.
/// Both FastTree trainers receive the same option instance, while a missing FastTree search
/// space is default-constructed separately for each of the two calls.
/// </remarks>
internal SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true,
    FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null,
    SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null)
{
    var trainers = new List<SweepableEstimator>();

    // FastTree: plain and Tweedie regression built from the same option object
    if (useFastTree)
    {
        if (fastTreeOption is null)
        {
            fastTreeOption = new FastTreeOption();
        }

        fastTreeOption.LabelColumnName = labelColumnName;
        fastTreeOption.FeatureColumnName = featureColumnName;
        fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName;

        // The fallback search space is evaluated once per call on purpose, so each trainer
        // receives its own default instance when none was supplied.
        var fastTreeSpace = fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>();
        trainers.Add(SweepableEstimatorFactory.CreateFastTreeRegression(fastTreeOption, fastTreeSpace));

        var tweedieSpace = fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>();
        trainers.Add(SweepableEstimatorFactory.CreateFastTreeTweedieRegression(fastTreeOption, tweedieSpace));
    }

    // FastForest
    if (useFastForest)
    {
        if (fastForestOption is null)
        {
            fastForestOption = new FastForestOption();
        }

        fastForestOption.LabelColumnName = labelColumnName;
        fastForestOption.FeatureColumnName = featureColumnName;
        fastForestOption.ExampleWeightColumnName = exampleWeightColumnName;

        var fastForestSpace = fastForestSearchSpace ?? new SearchSpace<FastForestOption>();
        trainers.Add(SweepableEstimatorFactory.CreateFastForestRegression(fastForestOption, fastForestSpace));
    }

    // LightGbm
    if (useLgbm)
    {
        if (lgbmOption is null)
        {
            lgbmOption = new LgbmOption();
        }

        lgbmOption.LabelColumnName = labelColumnName;
        lgbmOption.FeatureColumnName = featureColumnName;
        lgbmOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lgbmSpace = lgbmSearchSpace ?? new SearchSpace<LgbmOption>();
        trainers.Add(SweepableEstimatorFactory.CreateLightGbmRegression(lgbmOption, lgbmSpace));
    }

    // L-BFGS Poisson regression
    if (useLbfgs)
    {
        if (lbfgsOption is null)
        {
            lbfgsOption = new LbfgsOption();
        }

        lbfgsOption.LabelColumnName = labelColumnName;
        lbfgsOption.FeatureColumnName = featureColumnName;
        lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName;

        var lbfgsSpace = lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>();
        trainers.Add(SweepableEstimatorFactory.CreateLbfgsPoissonRegressionRegression(lbfgsOption, lbfgsSpace));
    }

    // SDCA
    if (useSdca)
    {
        if (sdcaOption is null)
        {
            sdcaOption = new SdcaOption();
        }

        sdcaOption.LabelColumnName = labelColumnName;
        sdcaOption.FeatureColumnName = featureColumnName;
        sdcaOption.ExampleWeightColumnName = exampleWeightColumnName;

        var sdcaSpace = sdcaSearchSpace ?? new SearchSpace<SdcaOption>();
        trainers.Add(SweepableEstimatorFactory.CreateSdcaRegression(sdcaOption, sdcaSpace));
    }

    return trainers.ToArray();
}
292461
}
293462
}

Diff for: src/Microsoft.ML.AutoML/API/SweepableExtension.cs

+29-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Linq;
78
using System.Text;
89

910
namespace Microsoft.ML.AutoML
@@ -27,7 +28,34 @@ public static SweepableEstimatorPipeline Append(this SweepableEstimator estimato
2728

2829
/// <summary>
/// Starts a new <see cref="SweepableEstimatorPipeline"/> containing <paramref name="estimator"/>
/// and then appends <paramref name="estimator1"/> to it.
/// </summary>
/// <param name="estimator">The sweepable estimator that becomes the first pipeline stage.</param>
/// <param name="estimator1">The estimator appended after <paramref name="estimator"/>.</param>
/// <returns>The pipeline produced by the two chained Append calls.</returns>
public static SweepableEstimatorPipeline Append(this SweepableEstimator estimator, IEstimator<ITransformer> estimator1)
{
    // Build on an explicit pipeline instance. Calling estimator.Append(estimator1) here would
    // presumably bind back to this very extension overload and recurse without end.
    return new SweepableEstimatorPipeline().Append(estimator).Append(estimator1);
}
33+
34+
/// <summary>
/// Starts a <see cref="MultiModelPipeline"/> from a plain estimator and appends the given
/// sweepable estimators to it.
/// </summary>
/// <param name="estimator">Estimator wrapped as the first pipeline stage.</param>
/// <param name="estimators">Sweepable estimators appended after the wrapped estimator.</param>
/// <returns>The resulting multi-model pipeline.</returns>
public static MultiModelPipeline Append(this IEstimator<ITransformer> estimator, params SweepableEstimator[] estimators)
{
    // The wrapper's factory ignores both arguments and always hands back the same estimator
    // instance; it is paired with a newly constructed SearchSpace.
    var wrapped = new SweepableEstimator((context, parameter) => estimator, new SearchSpace.SearchSpace());

    return new MultiModelPipeline().Append(wrapped).Append(estimators);
}
41+
42+
/// <summary>
/// Converts a <see cref="SweepableEstimatorPipeline"/> into a <see cref="MultiModelPipeline"/>
/// by appending each of its estimators in enumeration order, then appends
/// <paramref name="estimators"/>.
/// </summary>
/// <param name="pipeline">Pipeline whose estimators are copied over one by one.</param>
/// <param name="estimators">Sweepable estimators appended after the copied ones.</param>
/// <returns>The resulting multi-model pipeline.</returns>
public static MultiModelPipeline Append(this SweepableEstimatorPipeline pipeline, params SweepableEstimator[] estimators)
{
    var result = new MultiModelPipeline();

    // Append returns a pipeline, so the running result is re-assigned on every step.
    foreach (var stage in pipeline.Estimators)
    {
        result = result.Append(stage);
    }

    result = result.Append(estimators);
    return result;
}
52+
53+
/// <summary>
/// Starts a <see cref="MultiModelPipeline"/> with <paramref name="estimator"/> and appends
/// <paramref name="estimators"/> to it.
/// </summary>
/// <param name="estimator">Sweepable estimator appended first.</param>
/// <param name="estimators">Sweepable estimators appended afterwards.</param>
/// <returns>The resulting multi-model pipeline.</returns>
public static MultiModelPipeline Append(this SweepableEstimator estimator, params SweepableEstimator[] estimators)
{
    return new MultiModelPipeline()
        .Append(estimator)
        .Append(estimators);
}
3260
}
3361
}

Diff for: src/Microsoft.ML.AutoML/AutoMlUtils.cs

+34
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,46 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using System.Threading;
78

89
namespace Microsoft.ML.AutoML
910
{
1011
/// <summary>
/// Small helpers shared across AutoML: a per-thread random source, the MLNET_MAX_THREAD
/// environment lookup and column-pair construction.
/// </summary>
internal static class AutoMlUtils
{
    // Name of the environment variable read by GetNumberOfThreadFromEnvrionment.
    private const string MLNetMaxThread = "MLNET_MAX_THREAD";

    /// <summary>
    /// Per-thread <see cref="System.Random"/> instance, created on first access from each thread.
    /// </summary>
    public static readonly ThreadLocal<Random> Random = new ThreadLocal<Random>(() => new Random());

    /// <summary>
    /// Returns the value of the MLNET_MAX_THREAD environment variable parsed as an integer.
    /// </summary>
    /// <returns>
    /// The parsed value, or <c>null</c> when the variable is not set or is not a valid integer.
    /// The value is returned as parsed; no range check is applied.
    /// </returns>
    public static int? GetNumberOfThreadFromEnvrionment()
    {
        var rawValue = Environment.GetEnvironmentVariable(MLNetMaxThread);

        // An environment variable is machine-readable input, so parse it with the invariant
        // culture rather than whatever culture the current thread happens to run under.
        if (int.TryParse(rawValue, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out var numberOfThread))
        {
            return numberOfThread;
        }

        return null;
    }

    /// <summary>
    /// Pairs up <paramref name="inputs"/> and <paramref name="outputs"/> index by index.
    /// </summary>
    /// <param name="inputs">Input column names.</param>
    /// <param name="outputs">Output column names; must have the same length as <paramref name="inputs"/>.</param>
    /// <returns>One <see cref="InputOutputColumnPair"/> per index, in order.</returns>
    /// <exception cref="ArgumentNullException">Either array is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    public static InputOutputColumnPair[] CreateInputOutputColumnPairsFromStrings(string[] inputs, string[] outputs)
    {
        if (inputs == null)
        {
            throw new ArgumentNullException(nameof(inputs));
        }

        if (outputs == null)
        {
            throw new ArgumentNullException(nameof(outputs));
        }

        if (inputs.Length != outputs.Length)
        {
            // ArgumentException still derives from Exception, so existing catch blocks keep working.
            throw new ArgumentException("inputs and outputs count must match");
        }

        // The final size is known up front, so fill an array directly.
        var pairs = new InputOutputColumnPair[inputs.Length];
        for (int i = 0; i < pairs.Length; ++i)
        {
            // The constructor is called with the output name first, then the input name.
            pairs[i] = new InputOutputColumnPair(outputs[i], inputs[i]);
        }

        return pairs;
    }
}
1448
}

Diff for: src/Microsoft.ML.AutoML/CodeGen/code_gen_flag.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"EstimatorFactoryGenerator": false,
3-
"CodeGenCatalogGenerator": false,
3+
"SweepableEstimatorFactory": true,
44
"EstimatorTypeGenerator": true,
55
"SearchSpaceGenerator": true,
6-
"SweepableEstimatorGenerator": false
6+
"SweepableEstimatorGenerator": true
77
}

Diff for: src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj

+6-4
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
<PrivateAssets>all</PrivateAssets>
1515
</ProjectReference>
1616
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
17+
<ProjectReference Include="..\Microsoft.ML.OnnxTransformer\Microsoft.ML.OnnxTransformer.csproj" />
1718
<ProjectReference Include="..\Microsoft.ML.SearchSpace\Microsoft.ML.SearchSpace.csproj">
1819
<PrivateAssets>all</PrivateAssets>
1920
<IncludeInNuget>true</IncludeInNuget>
2021
</ProjectReference>
2122
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="$(MicrosoftCodeAnalysisCSharpVersion)" />
23+
<ProjectReference Include="..\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj" />
2224
<ProjectReference Include="..\Microsoft.ML.Vision\Microsoft.ML.Vision.csproj" />
2325
<ProjectReference Include="..\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
2426
<ProjectReference Include="..\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj" />
@@ -43,13 +45,13 @@
4345
<Target DependsOnTargets="ResolveReferences" Name="CopyProjectReferencesToPackage">
4446
<ItemGroup>
4547
<!--Include DLLs of Project References-->
46-
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true'))"/>
48+
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths-&gt;WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')-&gt;WithMetadataValue('IncludeInNuget','true'))" />
4749
<!--Include PDBs of Project References-->
48-
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths->WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')->WithMetadataValue('IncludeInNuget','true')->Replace('.dll', '.pdb'))"/>
50+
<BuildOutputInPackage Include="@(ReferenceCopyLocalPaths-&gt;WithMetadataValue('ReferenceSourceTarget', 'ProjectReference')-&gt;WithMetadataValue('IncludeInNuget','true')-&gt;Replace('.dll', '.pdb'))" />
4951
<!--Include PDBs for Native binaries-->
5052
<!--The path needed to be hardcoded for this to work on our publishing CI-->
51-
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x86\native"/>
52-
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x64\native"/>
53+
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x86\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x86\native" />
54+
<BuildOutputInPackage Condition="Exists('$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb')" Include="$(PackageAssetsPath)$(PackageIdFolderName)\runtimes\win-x64\native\LdaNative.pdb" TargetPath="..\..\runtimes\win-x64\native" />
5355
</ItemGroup>
5456
</Target>
5557

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
using System.Text.Json;
9+
using System.Text.Json.Nodes;
10+
using System.Text.Json.Serialization;
11+
12+
namespace Microsoft.ML.AutoML
13+
{
14+
/// <summary>
/// System.Text.Json converter for <see cref="MultiModelPipeline"/>. The JSON form is an object
/// with a "schema" string and an "estimators" object.
/// </summary>
internal class MultiModelPipelineConverter : JsonConverter<MultiModelPipeline>
{
    // Property names shared by Read and Write so the two sides cannot drift apart.
    private const string SchemaPropertyName = "schema";
    private const string EstimatorsPropertyName = "estimators";

    /// <summary>
    /// Reads a <see cref="MultiModelPipeline"/> from the JSON object produced by <see cref="Write"/>.
    /// </summary>
    /// <exception cref="JsonException">The payload is null or lacks one of the two expected properties.</exception>
    public override MultiModelPipeline Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
    {
        var root = JsonNode.Parse(ref reader);
        if (root == null)
        {
            throw new JsonException("Expected a JSON object describing a MultiModelPipeline but found null.");
        }

        var schemaNode = root[SchemaPropertyName];
        var estimatorsNode = root[EstimatorsPropertyName];
        if (schemaNode == null || estimatorsNode == null)
        {
            throw new JsonException($"A MultiModelPipeline payload must contain both '{SchemaPropertyName}' and '{EstimatorsPropertyName}'.");
        }

        var schema = schemaNode.GetValue<string>();

        // A parsed node is a JsonObject, on which GetValue<T>() throws; run it through the
        // serializer so any converters registered on the options are honoured.
        var estimators = estimatorsNode.Deserialize<Dictionary<string, SweepableEstimator>>(options);

        return new MultiModelPipeline(estimators, Entity.FromExpression(schema));
    }

    /// <summary>
    /// Writes <paramref name="value"/> as a JSON object holding its schema string and its estimators.
    /// </summary>
    public override void Write(Utf8JsonWriter writer, MultiModelPipeline value, JsonSerializerOptions options)
    {
        var jsonObject = JsonNode.Parse("{}");
        jsonObject[SchemaPropertyName] = value.Schema.ToString();
        jsonObject[EstimatorsPropertyName] = JsonValue.Create(value.Estimators);

        jsonObject.WriteTo(writer, options);
    }
}
34+
}

0 commit comments

Comments
 (0)