Skip to content

Commit

Permalink
Merge pull request #32 from ourownstory/torch_prophet
Browse files Browse the repository at this point in the history
add TorchProphet
  • Loading branch information
LeonieFreisinger authored Jan 22, 2023
2 parents 3b1f048 + acfc707 commit 7de7bf7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 2 deletions.
30 changes: 29 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tot.benchmark import SimpleBenchmark
from tot.dataset import Dataset
from tot.metrics import ERROR_FUNCTIONS
from tot.models import LinearRegressionModel, NaiveModel, ProphetModel, SeasonalNaiveModel
from tot.models import LinearRegressionModel, NaiveModel, ProphetModel, SeasonalNaiveModel, TorchProphetModel

log = logging.getLogger("tot.test")
log.setLevel("WARNING")
Expand Down Expand Up @@ -265,3 +265,31 @@ def test_linear_regression_model():
results_train, results_test = benchmark.run()
log.info("#### test_linear_regression_model")
print(results_test)


def test_torch_prophet_model():
air_passengers_df = pd.read_csv(AIR_FILE, nrows=NROWS)
peyton_manning_df = pd.read_csv(PEYTON_FILE, nrows=NROWS)
dataset_list = [
Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
Dataset(df=peyton_manning_df, name="peyton_manning", freq="D"),
]
model_classes_and_params = [
(
TorchProphetModel,
{"seasonality_mode": "multiplicative", "interval_width": 0},
),
]
log.debug("{}".format(model_classes_and_params))

benchmark = SimpleBenchmark(
model_classes_and_params=model_classes_and_params,
datasets=dataset_list,
metrics=list(ERROR_FUNCTIONS.keys()),
test_percentage=25,
save_dir=SAVE_DIR,
num_processes=1,
)
results_train, results_test = benchmark.run()
log.info("#### test_torch_prophet_model")
print(results_test)
1 change: 1 addition & 0 deletions tot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .models import NeuralProphetModel # noqa: F401
from .models import ProphetModel # noqa: F401
from .models import SeasonalNaiveModel # noqa: F401
from .models import TorchProphetModel # noqa: F401

# logger handling
log = logging.getLogger("dv")
Expand Down
32 changes: 31 additions & 1 deletion tot/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np
import pandas as pd
from neuralprophet import NeuralProphet, df_utils
from neuralprophet import NeuralProphet, TorchProphet, df_utils

from tot.df_utils import reshape_raw_predictions_to_forecast_df
from tot.utils import _convert_seasonality_to_season_length, _get_seasons, convert_df_to_TimeSeries, convert_to_datetime
Expand Down Expand Up @@ -290,6 +290,36 @@ def maybe_drop_added_dates(self, predicted, df):
return predicted, df


@dataclass
class TorchProphetModel(NeuralProphetModel):
model_name: str = "TorchProphet"
model_class: Type = TorchProphet

def __post_init__(self):
data_params = self.params["_data_params"]
custom_seasonalities = None
if "seasonalities" in data_params and len(data_params["seasonalities"]) > 0:
daily, weekly, yearly, custom_seasonalities = _get_seasons(data_params["seasonalities"])
self.params.update({"daily_seasonality": daily})
self.params.update({"weekly_seasonality": weekly})
self.params.update({"yearly_seasonality": yearly})
if "seasonality_mode" in data_params and data_params["seasonality_mode"] is not None:
self.params.update({"seasonality_mode": data_params["seasonality_mode"]})
model_params = deepcopy(self.params)
model_params.pop("_data_params")
model_params.update({"interval_width": 0})
self.model = self.model_class(**model_params)
if custom_seasonalities is not None:
for seasonality in custom_seasonalities:
self.model.add_seasonality(
name="{}_daily".format(str(seasonality)),
period=seasonality,
)
self.n_forecasts = self.model.n_forecasts
self.n_lags = self.model.n_lags
self.season_length = None


@dataclass
class SeasonalNaiveModel(Model):
"""
Expand Down

0 comments on commit 7de7bf7

Please sign in to comment.