Fix multiprocessing in experiments and local models (#64)
* set multiprocessing start method to spawn 
* parallelize local models
* remove import
ankke authored Nov 25, 2023
1 parent 35ac96a commit ad2a4ea
Showing 5 changed files with 67 additions and 13 deletions.
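Background on the first commit-message bullet: Python's default start method on Linux is fork, which duplicates the parent process together with its threads, locks, and native-library state (BLAS, CUDA, etc.) and can therefore hang or misbehave inside worker processes; spawn instead launches a fresh interpreter that re-imports the worker module, at the cost of requiring picklable arguments and module-level worker functions. The sketch below shows the generic spawn-context pool pattern the commit switches to; the worker function and inputs are made up for illustration and are not code from this repository.

```python
from multiprocessing import get_context


def square(x):
    # Must be defined at module level so a spawned child process can
    # re-import this module and look the function up by name.
    return x * x


if __name__ == "__main__":  # required with spawn: children re-import this module
    with get_context("spawn").Pool(2) as pool:
        print(pool.map(square, [1, 2, 3]))  # [1, 4, 9]
```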
6 changes: 5 additions & 1 deletion tests/test_benchmark.py
@@ -124,11 +124,15 @@ def test_2_benchmark_CV():
test_percentage=0.1,
num_folds=3,
fold_overlap_pct=0.5,
num_processes=1,
num_processes=10,
)
results_summary_ol, results_train_ol, results_test_ol = benchmark_cv_overlap.run()
log.debug("{}".format(results_summary_ol))
log.info("#### test_2_benchmark_CV")
if PLOT:
air_passengers = results_summary[results_summary["split"] == "test"]
# air_passengers.plot(x="data", y="MASE", kind="barh")
plt.show()


def test_2_benchmark_manual():
40 changes: 37 additions & 3 deletions tests/test_models.py
@@ -330,7 +330,7 @@ def test_darts_model():
print(results_test)


def test_darts_local_model():
def test_darts_StatsForecastAutoETS_model():
air_passengers_df = pd.read_csv(AIR_FILE, nrows=NROWS)
ercot_df_aux = pd.read_csv(ERCOT_FILE)
ercot_df = pd.DataFrame()
@@ -350,11 +350,45 @@ def test_darts_local_model():
model_classes_and_params = [
(
DartsLocalForecastingModel,
{"model": StatsForecastAutoARIMA, "lags": 12, "n_forecasts": 4},
{"model": StatsForecastAutoETS, "lags": 12, "n_forecasts": 1, "ETS_model": "ZZZ"},
),
]
log.debug("{}".format(model_classes_and_params))

benchmark = SimpleBenchmark(
model_classes_and_params=model_classes_and_params,
datasets=dataset_list,
metrics=list(ERROR_FUNCTIONS.keys()),
test_percentage=0.25,
save_dir=SAVE_DIR,
num_processes=1,
)
results_train, results_test = benchmark.run()
log.info("#### test_darts_local_model")
print(results_test)


def test_darts_StatsForecastAutoARIMA_model():
air_passengers_df = pd.read_csv(AIR_FILE, nrows=NROWS)
ercot_df_aux = pd.read_csv(ERCOT_FILE)
ercot_df = pd.DataFrame()
for region in ERCOT_REGIONS:
ercot_df = pd.concat(
(
ercot_df,
ercot_df_aux[ercot_df_aux["ID"] == region].iloc[:NROWS].copy(deep=True),
),
ignore_index=True,
)

dataset_list = [
Dataset(df=air_passengers_df, name="air_passengers", freq="MS"),
Dataset(df=ercot_df, name="ercot_df", freq="H"),
]
model_classes_and_params = [
(
DartsLocalForecastingModel,
{"model": StatsForecastAutoETS, "lags": 12, "n_forecasts": 4, "ETS_model": "ZZZ"},
{"model": StatsForecastAutoARIMA, "lags": 12, "n_forecasts": 1},
),
]
log.debug("{}".format(model_classes_and_params))
4 changes: 2 additions & 2 deletions tot/benchmark.py
@@ -4,7 +4,7 @@
import os
from abc import ABC
from dataclasses import dataclass
from multiprocessing.pool import Pool
from multiprocessing import get_context
from typing import List, Optional, Tuple

import numpy as np
@@ -103,7 +103,7 @@ def run(self, verbose=True):
not all([exp.num_processes == 1 for exp in self.experiments]),
"Cannot set multiprocessing in " "Experiments and Benchmark.",
)
with Pool(self.num_processes) as pool:
with get_context("spawn").Pool(self.num_processes) as pool:
args_list = [(exp, verbose, i + 1) for i, exp in enumerate(self.experiments)]
pool.map_async(
self._run_exp,
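A note on the pool.map_async call in tot/benchmark.py above (general multiprocessing behavior, not a claim about the code outside the shown hunk): map_async returns an AsyncResult immediately rather than the results themselves, so callers typically keep that object and call .get() — or close and join the pool — to wait for the workers and to surface any exceptions they raised. A minimal, hypothetical sketch:

```python
from multiprocessing import get_context


def work(x):
    return x + 1


if __name__ == "__main__":
    with get_context("spawn").Pool(2) as pool:
        async_result = pool.map_async(work, [1, 2, 3])
        results = async_result.get()  # blocks until done; re-raises worker errors
    print(results)  # [2, 3, 4]
```

With the spawn context, both the submitted callable and every item in the argument list are pickled and sent to the child processes, so they must be picklable.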
4 changes: 2 additions & 2 deletions tot/experiment.py
@@ -4,7 +4,7 @@
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from multiprocessing.pool import Pool
from multiprocessing import get_context
from typing import List, Optional

import pandas as pd
@@ -442,7 +442,7 @@ def run(self):
self.fcst_test = []
time_start = time.time()
if self.num_processes > 1 and self.num_folds > 1:
with Pool(self.num_processes) as pool:
with get_context("spawn").Pool(self.num_processes) as pool:
args = [
(df_train, df_test, current_fold, received_ID_column, received_single_time_series)
for current_fold, (df_train, df_test) in enumerate(folds)
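The args construction in tot/experiment.py above bundles each fold's values into a single tuple per task (the pool call that consumes the list falls outside the displayed hunk). Bundling is the usual way to pass several arguments through Pool.map, which hands the worker exactly one object per task; Pool.starmap is the variant that unpacks tuples automatically. Under the spawn context each tuple — including the train and test DataFrames — is pickled and shipped to the child process. A small sketch of the pattern with made-up data, not the repository's actual worker:

```python
from multiprocessing import get_context


def run_fold(task):
    # Pool.map hands each worker a single object, so several per-fold values
    # travel together as one tuple and are unpacked inside the worker.
    fold_id, train, test = task
    return fold_id, len(train) + len(test)


if __name__ == "__main__":
    folds = [(0, [1, 2], [3]), (1, [4], [5, 6])]
    with get_context("spawn").Pool(2) as pool:
        print(pool.map(run_fold, folds))  # [(0, 3), (1, 3)]
    # Pool.starmap would instead unpack each tuple into separate
    # positional parameters of the worker.
```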
26 changes: 21 additions & 5 deletions tot/models/models_darts.py
@@ -1,6 +1,8 @@
import gc
import logging
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import get_context
from typing import Type

import pandas as pd
@@ -163,6 +165,11 @@ def maybe_drop_added_values_from_df(self, predicted: pd.DataFrame, df: pd.DataFr
return predicted


def _fit(args):
model, series = args
model.fit(series)


@dataclass
class DartsLocalForecastingModel(DartsForecastingModel):
"""
@@ -240,11 +247,20 @@ def fit(self, df: pd.DataFrame, freq: str, ids_weights: dict) -> None:
self.freq = freq
self.model_params["season_length"] = FREQ_TO_SEASON_LENGTH[freq]

for df_name, df in df.groupby("ID"):
model = self.model(**self.model_params)
series = convert_df_to_TimeSeries(df, freq=self.freq)
model.fit(series)
self.models_list.append(model)
self.models_list = [self.model(**self.model_params) for _ in range(len(df.groupby("ID")))]

with get_context("spawn").Pool() as pool:
args = [
(model, convert_df_to_TimeSeries(df_i, freq=self.freq))
for model, (_, df_i) in zip(self.models_list, df.groupby("ID"))
]
pool.map(
_fit,
args,
)
pool.close()
pool.join()
gc.collect()

def predict(
self, df: pd.DataFrame, received_single_time_series: bool, df_historic: pd.DataFrame = None
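Two background notes on the parallel fit added to tot/models/models_darts.py above (general properties of spawn pools, not claims about this repository's runtime behavior): the worker has to be a module-level function such as _fit because spawned children import the module and resolve the callable by name when unpickling it; and each child operates on a pickled copy of its arguments, so any state the parent wants back — for example a fitted model — is normally returned from the worker and collected from the pool call's return value. The sketch below illustrates that return-and-collect variant with a toy model; ToyModel and _fit_and_return are hypothetical names, not part of this codebase.

```python
from multiprocessing import get_context


class ToyModel:
    def __init__(self):
        self.mean_ = None

    def fit(self, series):
        self.mean_ = sum(series) / len(series)
        return self


def _fit_and_return(task):
    # Runs in the child on pickled copies; returning the fitted model sends
    # its state back to the parent through the pool's result channel.
    model, series = task
    return model.fit(series)


if __name__ == "__main__":
    jobs = [(ToyModel(), [1.0, 2.0, 3.0]), (ToyModel(), [10.0, 20.0])]
    with get_context("spawn").Pool(2) as pool:
        fitted = pool.map(_fit_and_return, jobs)
    print([m.mean_ for m in fitted])  # [2.0, 15.0]
```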
