File tree Expand file tree Collapse file tree 1 file changed +16
-1
lines changed
pytorch_forecasting/models/nbeats Expand file tree Collapse file tree 1 file changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -60,9 +60,24 @@ def get_base_test_params(cls):
6060
6161 @classmethod
6262 def _get_test_dataloaders_from (cls , params ):
63- """Get dataloaders from parameters."""
63+ loss = params .get ("loss" , None )
64+ data_loader_kwargs = params .get ("data_loader_kwargs" , {})
65+ from pytorch_forecasting .metrics import TweedieLoss
6466 from pytorch_forecasting .tests ._data_scenarios import (
67+ data_with_covariates ,
6568 dataloaders_fixed_window_without_covariates ,
69+ make_dataloaders ,
6670 )
6771
72+ if isinstance (loss , TweedieLoss ):
73+ dwc = data_with_covariates ()
74+ dl_default_kwargs = dict (
75+ target = "target" ,
76+ time_varying_unknown_reals = ["target" ],
77+ add_relative_time_idx = False ,
78+ )
79+ dl_default_kwargs .update (data_loader_kwargs )
80+ dataloaders_with_covariates = make_dataloaders (dwc , ** dl_default_kwargs )
81+ return dataloaders_with_covariates
82+
6883 return dataloaders_fixed_window_without_covariates ()
You can’t perform that action at this time.
0 commit comments