From 800bcc3d24a2aa3d2729e8e31bc0c597d7ff90c5 Mon Sep 17 00:00:00 2001 From: Leonie Freisinger Date: Sun, 22 Jan 2023 07:41:10 -0800 Subject: [PATCH 1/2] add TorchProphet --- tot/__init__.py | 1 + tot/models.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tot/__init__.py b/tot/__init__.py index 0aff26a..cda7571 100644 --- a/tot/__init__.py +++ b/tot/__init__.py @@ -12,6 +12,7 @@ from .models import NeuralProphetModel # noqa: F401 from .models import ProphetModel # noqa: F401 from .models import SeasonalNaiveModel # noqa: F401 +from .models import TorchProphetModel # noqa: F401 # logger handling log = logging.getLogger("dv") diff --git a/tot/models.py b/tot/models.py index b8e6e63..e740e3e 100644 --- a/tot/models.py +++ b/tot/models.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from neuralprophet import NeuralProphet, df_utils +from neuralprophet import NeuralProphet, TorchProphet, df_utils from tot.df_utils import reshape_raw_predictions_to_forecast_df from tot.utils import _convert_seasonality_to_season_length, _get_seasons, convert_df_to_TimeSeries, convert_to_datetime @@ -290,6 +290,12 @@ def maybe_drop_added_dates(self, predicted, df): return predicted, df +@dataclass +class TorchProphetModel(NeuralProphetModel): + model_name: str = "TorchProphet" + model_class: Type = TorchProphet + + @dataclass class SeasonalNaiveModel(Model): """ From acfc707bb192f28a4fbda94832fdc41132a87cc7 Mon Sep 17 00:00:00 2001 From: Leonie Freisinger Date: Sun, 22 Jan 2023 07:57:04 -0800 Subject: [PATCH 2/2] add pytest --- tests/test_models.py | 30 +++++++++++++++++++++++++++++- tot/models.py | 24 ++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 267c0e7..baf78b9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ from tot.benchmark import SimpleBenchmark from tot.dataset import Dataset from tot.metrics import ERROR_FUNCTIONS -from tot.models import LinearRegressionModel, NaiveModel, ProphetModel, SeasonalNaiveModel +from tot.models import LinearRegressionModel, NaiveModel, ProphetModel, SeasonalNaiveModel, TorchProphetModel log = logging.getLogger("tot.test") log.setLevel("WARNING") @@ -265,3 +265,31 @@ def test_linear_regression_model(): results_train, results_test = benchmark.run() log.info("#### test_linear_regression_model") print(results_test) + + +def test_torch_prophet_model(): + air_passengers_df = pd.read_csv(AIR_FILE, nrows=NROWS) + peyton_manning_df = pd.read_csv(PEYTON_FILE, nrows=NROWS) + dataset_list = [ + Dataset(df=air_passengers_df, name="air_passengers", freq="MS"), + Dataset(df=peyton_manning_df, name="peyton_manning", freq="D"), + ] + model_classes_and_params = [ + ( + TorchProphetModel, + {"seasonality_mode": "multiplicative", "interval_width": 0}, + ), + ] + log.debug("{}".format(model_classes_and_params)) + + benchmark = SimpleBenchmark( + model_classes_and_params=model_classes_and_params, + datasets=dataset_list, + metrics=list(ERROR_FUNCTIONS.keys()), + test_percentage=25, + save_dir=SAVE_DIR, + num_processes=1, + ) + results_train, results_test = benchmark.run() + log.info("#### test_torch_prophet_model") + print(results_test) diff --git a/tot/models.py b/tot/models.py index e740e3e..e2b7177 100644 --- a/tot/models.py +++ b/tot/models.py @@ -295,6 +295,30 @@ class TorchProphetModel(NeuralProphetModel): model_name: str = "TorchProphet" model_class: Type = TorchProphet + def __post_init__(self): + data_params = self.params["_data_params"] + custom_seasonalities = None + if "seasonalities" in data_params and len(data_params["seasonalities"]) > 0: + daily, weekly, yearly, custom_seasonalities = _get_seasons(data_params["seasonalities"]) + self.params.update({"daily_seasonality": daily}) + self.params.update({"weekly_seasonality": weekly}) + self.params.update({"yearly_seasonality": yearly}) + if "seasonality_mode" in data_params and data_params["seasonality_mode"] is not None: + self.params.update({"seasonality_mode": data_params["seasonality_mode"]}) + model_params = deepcopy(self.params) + model_params.pop("_data_params") + model_params.update({"interval_width": 0}) + self.model = self.model_class(**model_params) + if custom_seasonalities is not None: + for seasonality in custom_seasonalities: + self.model.add_seasonality( + name="{}_daily".format(str(seasonality)), + period=seasonality, + ) + self.n_forecasts = self.model.n_forecasts + self.n_lags = self.model.n_lags + self.season_length = None + @dataclass class SeasonalNaiveModel(Model):