[refactoring] Data validation unification (#46)
* replace assert with ValueError

* format code and docstrings

* rename variables and functions

* extract raising ValueError to a separate module

* throw DataValidationError for non-user-input validation

* correct docstrings in metrics.py

* Update error_utils.py

---------

Co-authored-by: LeonieFreisinger <[email protected]>
ankke and LeonieFreisinger authored Apr 26, 2023
1 parent b0a0dd4 commit 502869e
Showing 11 changed files with 288 additions and 193 deletions.
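The shared helper introduced by this commit, raise_if in tot/error_utils.py, is used in the tot/benchmark.py hunk below, but its definition is not part of the hunks shown here. A minimal sketch consistent with that usage (the signature and docstring are assumptions, not copied from the actual module):

# Illustrative sketch of tot/error_utils.py; the real module may differ.


def raise_if(condition: bool, message: str) -> None:
    """Raise a ValueError with the given message when condition is True.

    Intended for validating user-provided input, replacing the former
    assert statements.
    """
    if condition:
        raise ValueError(message)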
4 changes: 2 additions & 2 deletions tests/test_evaluation.py
@@ -125,11 +125,11 @@ def test_evaluation_by_ID_for_forecast_step_invalid_input():
benchmark.fcst_train[0]
) # ensure ID column in dataframe with single time series
# calculate metrics by ID for selected forecast step
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
calculate_metrics_by_ID_for_forecast_step(
fcst_df=fcst_test_peyton, df_historic=fcst_train_peyton, forecast_step_in_focus=1, freq="D"
)
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
calculate_metrics_by_ID_for_forecast_step(
fcst_df=fcst_test_peyton,
df_historic=fcst_train_peyton,
6 changes: 3 additions & 3 deletions tests/test_models.py
@@ -186,7 +186,7 @@ def test_seasonal_naive_model(dataset_input, model_classes_and_params_input):

@pytest.mark.parametrize(*decorator_input)
def test_seasonal_naive_model_invalid_input(dataset_input, model_classes_and_params_input):
log.info("Test invalid model input - Raise Assertion")
log.info("Test invalid model input - Raise ValueError")
peyton_manning_df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
dataset_list = [
Dataset(
@@ -209,7 +209,7 @@ def test_seasonal_naive_model_invalid_input(dataset_input, model_classes_and_params_input):
num_processes=1,
)

with pytest.raises(AssertionError):
with pytest.raises(ValueError):
_, _ = benchmark.run()

log.info("#### Done with test_seasonal_naive_model_invalid_input")
@@ -345,6 +345,6 @@ def test_check_min_input_len():
save_dir=SAVE_DIR,
num_processes=1,
)
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
results_train, results_test = benchmark.run()
log.info("#### test_check_min_input_len")
7 changes: 5 additions & 2 deletions tot/benchmark.py
@@ -11,6 +11,7 @@
import pandas as pd

from tot.datasets.dataset import Dataset
from tot.error_utils import raise_if
from tot.experiment import CrossValidationExperiment, Experiment, SimpleExperiment
from tot.models.models import Model

@@ -98,8 +99,10 @@ def run(self, verbose=True):
log.info("exp {}/{}: {}".format(i + 1, len(self.experiments), exp.experiment_name))
log.info("---- Staring Series of {} Experiments ----".format(len(self.experiments)))
if self.num_processes > 1 and len(self.experiments) > 1:
if not all([exp.num_processes == 1 for exp in self.experiments]):
raise ValueError("can not set multiprocessing in experiments and Benchmark.")
raise_if(
not all([exp.num_processes == 1 for exp in self.experiments]),
"Cannot set multiprocessing in " "Experiments and Benchmark.",
)
with Pool(self.num_processes) as pool:
args_list = [(exp, verbose, i + 1) for i, exp in enumerate(self.experiments)]
pool.map_async(
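None of the hunks shown above touch the DataValidationError path mentioned in the commit message ("throw DataValidationError for non-user-input validation"). A hedged sketch of how that side of tot/error_utils.py could look, where only the class name comes from the commit message and the helper name and signature are assumptions:

# Illustrative sketch; raise_data_validation_error_if is a hypothetical name.


class DataValidationError(Exception):
    """Raised when an internal (non-user-input) data validation check fails."""


def raise_data_validation_error_if(condition: bool, message: str) -> None:
    # Mirrors raise_if, but signals internal consistency problems rather
    # than invalid user input.
    if condition:
        raise DataValidationError(message)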