|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
5 | 5 | using System;
|
| 6 | +using System.Collections; |
| 7 | +using System.Collections.Generic; |
| 8 | +using System.Diagnostics.Contracts; |
| 9 | +using Microsoft.ML.AutoML.CodeGen; |
6 | 10 | using Microsoft.ML.Data;
|
7 | 11 | using Microsoft.ML.SearchSpace;
|
8 | 12 |
|
@@ -289,5 +293,170 @@ internal SweepableEstimator CreateSweepableEstimator<T>(Func<MLContext, T, IEsti
|
289 | 293 | {
|
290 | 294 | return new SweepableEstimator((MLContext context, Parameter param) => factory(context, param.AsType<T>()), ss);
|
291 | 295 | }
|
| 296 | + |
| 297 | + internal SweepableEstimator[] BinaryClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true, |
| 298 | + FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null, |
| 299 | + SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null) |
| 300 | + { |
| 301 | + var res = new List<SweepableEstimator>(); |
| 302 | + |
| 303 | + if (useFastTree) |
| 304 | + { |
| 305 | + fastTreeOption = fastTreeOption ?? new FastTreeOption(); |
| 306 | + fastTreeOption.LabelColumnName = labelColumnName; |
| 307 | + fastTreeOption.FeatureColumnName = featureColumnName; |
| 308 | + fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 309 | + res.Add(SweepableEstimatorFactory.CreateFastTreeBinary(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>())); |
| 310 | + } |
| 311 | + |
| 312 | + if (useFastForest) |
| 313 | + { |
| 314 | + fastForestOption = fastForestOption ?? new FastForestOption(); |
| 315 | + fastForestOption.LabelColumnName = labelColumnName; |
| 316 | + fastForestOption.FeatureColumnName = featureColumnName; |
| 317 | + fastForestOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 318 | + res.Add(SweepableEstimatorFactory.CreateFastForestBinary(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>())); |
| 319 | + } |
| 320 | + |
| 321 | + if (useLgbm) |
| 322 | + { |
| 323 | + lgbmOption = lgbmOption ?? new LgbmOption(); |
| 324 | + lgbmOption.LabelColumnName = labelColumnName; |
| 325 | + lgbmOption.FeatureColumnName = featureColumnName; |
| 326 | + lgbmOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 327 | + res.Add(SweepableEstimatorFactory.CreateLightGbmBinary(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>())); |
| 328 | + } |
| 329 | + |
| 330 | + if (useLbfgs) |
| 331 | + { |
| 332 | + lbfgsOption = lbfgsOption ?? new LbfgsOption(); |
| 333 | + lbfgsOption.LabelColumnName = labelColumnName; |
| 334 | + lbfgsOption.FeatureColumnName = featureColumnName; |
| 335 | + lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 336 | + res.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionBinary(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>())); |
| 337 | + } |
| 338 | + |
| 339 | + if (useSdca) |
| 340 | + { |
| 341 | + sdcaOption = sdcaOption ?? new SdcaOption(); |
| 342 | + sdcaOption.LabelColumnName = labelColumnName; |
| 343 | + sdcaOption.FeatureColumnName = featureColumnName; |
| 344 | + sdcaOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 345 | + res.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionBinary(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>())); |
| 346 | + } |
| 347 | + |
| 348 | + return res.ToArray(); |
| 349 | + } |
| 350 | + |
| 351 | + internal SweepableEstimator[] MultiClassification(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true, |
| 352 | + FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null, |
| 353 | + SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null) |
| 354 | + { |
| 355 | + var res = new List<SweepableEstimator>(); |
| 356 | + |
| 357 | + if (useFastTree) |
| 358 | + { |
| 359 | + fastTreeOption = fastTreeOption ?? new FastTreeOption(); |
| 360 | + fastTreeOption.LabelColumnName = labelColumnName; |
| 361 | + fastTreeOption.FeatureColumnName = featureColumnName; |
| 362 | + fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 363 | + res.Add(SweepableEstimatorFactory.CreateFastTreeOva(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>())); |
| 364 | + } |
| 365 | + |
| 366 | + if (useFastForest) |
| 367 | + { |
| 368 | + fastForestOption = fastForestOption ?? new FastForestOption(); |
| 369 | + fastForestOption.LabelColumnName = labelColumnName; |
| 370 | + fastForestOption.FeatureColumnName = featureColumnName; |
| 371 | + fastForestOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 372 | + res.Add(SweepableEstimatorFactory.CreateFastForestOva(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>())); |
| 373 | + } |
| 374 | + |
| 375 | + if (useLgbm) |
| 376 | + { |
| 377 | + lgbmOption = lgbmOption ?? new LgbmOption(); |
| 378 | + lgbmOption.LabelColumnName = labelColumnName; |
| 379 | + lgbmOption.FeatureColumnName = featureColumnName; |
| 380 | + lgbmOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 381 | + res.Add(SweepableEstimatorFactory.CreateLightGbmMulti(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>())); |
| 382 | + } |
| 383 | + |
| 384 | + if (useLbfgs) |
| 385 | + { |
| 386 | + lbfgsOption = lbfgsOption ?? new LbfgsOption(); |
| 387 | + lbfgsOption.LabelColumnName = labelColumnName; |
| 388 | + lbfgsOption.FeatureColumnName = featureColumnName; |
| 389 | + lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 390 | + res.Add(SweepableEstimatorFactory.CreateLbfgsLogisticRegressionOva(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>())); |
| 391 | + res.Add(SweepableEstimatorFactory.CreateLbfgsMaximumEntropyMulti(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>())); |
| 392 | + } |
| 393 | + |
| 394 | + if (useSdca) |
| 395 | + { |
| 396 | + sdcaOption = sdcaOption ?? new SdcaOption(); |
| 397 | + sdcaOption.LabelColumnName = labelColumnName; |
| 398 | + sdcaOption.FeatureColumnName = featureColumnName; |
| 399 | + sdcaOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 400 | + res.Add(SweepableEstimatorFactory.CreateSdcaMaximumEntropyMulti(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>())); |
| 401 | + res.Add(SweepableEstimatorFactory.CreateSdcaLogisticRegressionOva(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>())); |
| 402 | + } |
| 403 | + |
| 404 | + return res.ToArray(); |
| 405 | + } |
| 406 | + |
| 407 | + internal SweepableEstimator[] Regression(string labelColumnName = DefaultColumnNames.Label, string featureColumnName = DefaultColumnNames.Features, string exampleWeightColumnName = null, bool useFastForest = true, bool useLgbm = true, bool useFastTree = true, bool useLbfgs = true, bool useSdca = true, |
| 408 | + FastTreeOption fastTreeOption = null, LgbmOption lgbmOption = null, FastForestOption fastForestOption = null, LbfgsOption lbfgsOption = null, SdcaOption sdcaOption = null, |
| 409 | + SearchSpace<FastTreeOption> fastTreeSearchSpace = null, SearchSpace<LgbmOption> lgbmSearchSpace = null, SearchSpace<FastForestOption> fastForestSearchSpace = null, SearchSpace<LbfgsOption> lbfgsSearchSpace = null, SearchSpace<SdcaOption> sdcaSearchSpace = null) |
| 410 | + { |
| 411 | + var res = new List<SweepableEstimator>(); |
| 412 | + |
| 413 | + if (useFastTree) |
| 414 | + { |
| 415 | + fastTreeOption = fastTreeOption ?? new FastTreeOption(); |
| 416 | + fastTreeOption.LabelColumnName = labelColumnName; |
| 417 | + fastTreeOption.FeatureColumnName = featureColumnName; |
| 418 | + fastTreeOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 419 | + res.Add(SweepableEstimatorFactory.CreateFastTreeRegression(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>())); |
| 420 | + res.Add(SweepableEstimatorFactory.CreateFastTreeTweedieRegression(fastTreeOption, fastTreeSearchSpace ?? new SearchSpace<FastTreeOption>())); |
| 421 | + } |
| 422 | + |
| 423 | + if (useFastForest) |
| 424 | + { |
| 425 | + fastForestOption = fastForestOption ?? new FastForestOption(); |
| 426 | + fastForestOption.LabelColumnName = labelColumnName; |
| 427 | + fastForestOption.FeatureColumnName = featureColumnName; |
| 428 | + fastForestOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 429 | + res.Add(SweepableEstimatorFactory.CreateFastForestRegression(fastForestOption, fastForestSearchSpace ?? new SearchSpace<FastForestOption>())); |
| 430 | + } |
| 431 | + |
| 432 | + if (useLgbm) |
| 433 | + { |
| 434 | + lgbmOption = lgbmOption ?? new LgbmOption(); |
| 435 | + lgbmOption.LabelColumnName = labelColumnName; |
| 436 | + lgbmOption.FeatureColumnName = featureColumnName; |
| 437 | + lgbmOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 438 | + res.Add(SweepableEstimatorFactory.CreateLightGbmRegression(lgbmOption, lgbmSearchSpace ?? new SearchSpace<LgbmOption>())); |
| 439 | + } |
| 440 | + |
| 441 | + if (useLbfgs) |
| 442 | + { |
| 443 | + lbfgsOption = lbfgsOption ?? new LbfgsOption(); |
| 444 | + lbfgsOption.LabelColumnName = labelColumnName; |
| 445 | + lbfgsOption.FeatureColumnName = featureColumnName; |
| 446 | + lbfgsOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 447 | + res.Add(SweepableEstimatorFactory.CreateLbfgsPoissonRegressionRegression(lbfgsOption, lbfgsSearchSpace ?? new SearchSpace<LbfgsOption>())); |
| 448 | + } |
| 449 | + |
| 450 | + if (useSdca) |
| 451 | + { |
| 452 | + sdcaOption = sdcaOption ?? new SdcaOption(); |
| 453 | + sdcaOption.LabelColumnName = labelColumnName; |
| 454 | + sdcaOption.FeatureColumnName = featureColumnName; |
| 455 | + sdcaOption.ExampleWeightColumnName = exampleWeightColumnName; |
| 456 | + res.Add(SweepableEstimatorFactory.CreateSdcaRegression(sdcaOption, sdcaSearchSpace ?? new SearchSpace<SdcaOption>())); |
| 457 | + } |
| 458 | + |
| 459 | + return res.ToArray(); |
| 460 | + } |
292 | 461 | }
|
293 | 462 | }
|
0 commit comments