Skip to content

Commit

Permalink
Merge pull request #35 from ourownstory/restructure_helper_functions
Browse files Browse the repository at this point in the history
Fully restructure model-related and df-related code base
  • Loading branch information
LeonieFreisinger authored Jan 30, 2023
2 parents 7de7bf7 + 6794029 commit 7197fe1
Show file tree
Hide file tree
Showing 12 changed files with 1,713 additions and 1,095 deletions.
33 changes: 17 additions & 16 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from tot.dataset import Dataset
from tot.experiment import CrossValidationExperiment, SimpleExperiment
from tot.metrics import ERROR_FUNCTIONS
from tot.models import NeuralProphetModel, ProphetModel
from tot.models_neuralprophet import NeuralProphetModel
from tot.models_simple import ProphetModel

log = logging.getLogger("tot.test")
log.setLevel("WARNING")
Expand Down Expand Up @@ -78,7 +79,7 @@ def test_2_benchmark_simple():
model_classes_and_params=model_classes_and_params, # iterate over this list of tuples
datasets=dataset_list, # iterate over this list
metrics=["MAE", "MSE", "MASE", "RMSE"],
test_percentage=25,
test_percentage=0.25,
)
results_train, results_test = benchmark.run()

Expand Down Expand Up @@ -110,7 +111,7 @@ def test_2_benchmark_CV():
model_classes_and_params=model_classes_and_params, # iterate over this list of tuples
datasets=dataset_list, # iterate over this list
metrics=["MASE", "RMSE"],
test_percentage=10,
test_percentage=0.1,
num_folds=3,
fold_overlap_pct=0,
)
Expand Down Expand Up @@ -141,7 +142,7 @@ def test_2_benchmark_manual():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=25,
test_percentage=0.25,
),
SimpleExperiment(
model_class=NeuralProphetModel,
Expand All @@ -152,7 +153,7 @@ def test_2_benchmark_manual():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=25,
test_percentage=0.25,
),
]
if _prophet_installed:
Expand All @@ -164,7 +165,7 @@ def test_2_benchmark_manual():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=25,
test_percentage=0.25,
)
)
benchmark = ManualBenchmark(
Expand All @@ -190,7 +191,7 @@ def test_2_benchmark_manualCV():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=10,
test_percentage=0.1,
num_folds=3,
fold_overlap_pct=0,
),
Expand All @@ -203,7 +204,7 @@ def test_2_benchmark_manualCV():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=10,
test_percentage=0.1,
num_folds=3,
fold_overlap_pct=0,
),
Expand Down Expand Up @@ -242,15 +243,15 @@ def test_manual_benchmark():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=25,
test_percentage=0.25,
save_dir=SAVE_DIR,
),
SimpleExperiment(
model_class=NeuralProphetModel,
params={"learning_rate": 0.1, "epochs": EPOCHS},
data=Dataset(df=peyton_manning_df, name="peyton_manning", freq="D"),
metrics=metrics,
test_percentage=15,
test_percentage=0.15,
save_dir=SAVE_DIR,
),
]
Expand All @@ -260,15 +261,15 @@ def test_manual_benchmark():
params={"seasonality_mode": "multiplicative"},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=25,
test_percentage=0.25,
save_dir=SAVE_DIR,
),
SimpleExperiment(
model_class=ProphetModel,
params={},
data=Dataset(df=peyton_manning_df, name="peyton_manning", freq="D"),
metrics=metrics,
test_percentage=15,
test_percentage=0.15,
save_dir=SAVE_DIR,
),
]
Expand All @@ -294,7 +295,7 @@ def test_manual_cv_benchmark():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=10,
test_percentage=0.10,
num_folds=2,
fold_overlap_pct=0,
save_dir=SAVE_DIR,
Expand All @@ -308,7 +309,7 @@ def test_manual_cv_benchmark():
},
data=Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
metrics=metrics,
test_percentage=10,
test_percentage=0.10,
num_folds=1,
fold_overlap_pct=0,
save_dir=SAVE_DIR,
Expand Down Expand Up @@ -350,7 +351,7 @@ def test_simple_benchmark():
model_classes_and_params=model_classes_and_params, # iterate over this list of tuples
datasets=dataset_list, # iterate over this list
metrics=list(ERROR_FUNCTIONS.keys()),
test_percentage=25,
test_percentage=0.25,
save_dir=SAVE_DIR,
num_processes=1,
)
Expand Down Expand Up @@ -388,7 +389,7 @@ def test_cv_benchmark():
model_classes_and_params=model_classes_and_params, # iterate over this list of tuples
datasets=dataset_list, # iterate over this list
metrics=list(ERROR_FUNCTIONS.keys()),
test_percentage=10,
test_percentage=0.10,
num_folds=3,
fold_overlap_pct=0,
save_dir=SAVE_DIR,
Expand Down
Loading

0 comments on commit 7197fe1

Please sign in to comment.