Skip to content

Commit bb61d2a

Browse files
authored
[ENH] Add Metrics support to ptf-v2 (#1960)
Fixes #1956, Fixes #1844. This PR tries to make `v2` compatible with `Metrics`. It also changes the contract of tensors to match v1: a `list` for multi-target and a `tensor` for single-target. Stacks on #1965.
1 parent 22a9781 commit bb61d2a

File tree

12 files changed

+434
-828
lines changed

12 files changed

+434
-828
lines changed

docs/source/tutorials/ptf_V2_example.ipynb

Lines changed: 154 additions & 500 deletions
Large diffs are not rendered by default.

docs/source/tutorials/tslib_v2_example.ipynb

Lines changed: 133 additions & 280 deletions
Large diffs are not rendered by default.

pytorch_forecasting/data/_tslib_data_module.py

Lines changed: 77 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,65 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
7070
7171
Returns
7272
-------
73-
x: dict[str, torch.Tensor]
74-
A dictionary containing the processed data.
75-
y: torch.Tensor
76-
The target variable.
73+
x : dict[str, torch.Tensor]
74+
Dict containing processed inputs for the model, with the following keys:
75+
76+
* ``history_cont`` : torch.Tensor of shape
77+
(context_length, n_history_cont_features)
78+
Continuous features for the encoder (historical data).
79+
* ``history_cat`` : torch.Tensor of shape
80+
(context_length, n_history_cat_features)
81+
Categorical features for the encoder (historical data).
82+
* ``future_cont`` : torch.Tensor of shape
83+
(prediction_length, n_future_cont_features)
84+
Known continuous features for the decoder (future data).
85+
* ``future_cat`` : torch.Tensor of shape
86+
(prediction_length, n_future_cat_features)
87+
Known categorical features for the decoder (future data).
88+
* ``history_length`` : torch.Tensor of shape (1,)
89+
Length of the encoder sequence.
90+
* ``future_length`` : torch.Tensor of shape (1,)
91+
Length of the decoder sequence.
92+
* ``history_mask`` : torch.Tensor of shape (context_length,)
93+
Boolean mask indicating valid encoder time points.
94+
* ``future_mask`` : torch.Tensor of shape (prediction_length,)
95+
Boolean mask indicating valid decoder time points.
96+
* ``groups`` : torch.Tensor of shape (1,)
97+
Group identifier for the time series instance.
98+
* ``history_time_idx`` : torch.Tensor of shape (context_length,)
99+
Time indices for the encoder sequence.
100+
* ``future_time_idx`` : torch.Tensor of shape (prediction_length,)
101+
Time indices for the decoder sequence.
102+
* ``history_target`` : torch.Tensor of shape (context_length,)
103+
Historical target values for the encoder sequence.
104+
* ``future_target`` : torch.Tensor of shape (prediction_length,)
105+
Target values for the decoder sequence.
106+
* ``future_target_len`` : torch.Tensor of shape (1,)
107+
Length of the decoder target sequence.
108+
109+
Optional fields, depending on dataset configuration:
110+
111+
* ``history_relative_time_idx`` : torch.Tensor of shape (context_length,),
112+
optional
113+
Relative time indices for the encoder sequence, present if
114+
`add_relative_time_idx` is True.
115+
* ``future_relative_time_idx`` : torch.Tensor of shape (prediction_length,),
116+
optional
117+
Relative time indices for the decoder sequence, present if
118+
`add_relative_time_idx` is True.
119+
* ``static_categorical_features`` : torch.Tensor of shape
120+
(1, n_static_features), optional
121+
Static categorical features if available.
122+
* ``static_continuous_features`` : torch.Tensor of shape
123+
(1, n_static_features), optional
124+
Static continuous features if available.
125+
* ``target_scale`` : torch.Tensor of shape (1,), optional
126+
Scaling factor for the target values if provided by the dataset.
127+
128+
y : torch.Tensor or list of torch.Tensor
129+
Target values for the decoder sequence.
130+
If ``n_targets`` > 1, a list of tensors each of shape (prediction_length,)
131+
is returned. Otherwise, a tensor of shape (prediction_length,) is returned.
77132
"""
78133

79134
series_idx, start_idx, context_length, prediction_length = self.windows[idx]
@@ -170,6 +225,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
170225
x["target_scale"] = processed_data["target_scale"]
171226

172227
y = processed_data["target"][future_indices]
228+
if self.data_module.n_targets > 1:
229+
y = [t.squeeze(-1) for t in torch.split(y, 1, dim=1)]
230+
else:
231+
y = y.squeeze(-1)
173232

174233
return x, y
175234

@@ -294,6 +353,7 @@ def __init__(
294353
self.window_stride = window_stride
295354

296355
self.time_series_metadata = time_series_dataset.get_metadata()
356+
self.n_targets = len(self.time_series_metadata["cols"]["y"])
297357

298358
for idx, col in enumerate(self.time_series_metadata["cols"]["x"]):
299359
if self.time_series_metadata["col_type"].get(col) == "C":
@@ -774,8 +834,11 @@ def collate_fn(batch):
774834
775835
Returns
776836
-------
777-
tuple[dict[str, torch.Tensor], torch.Tensor]
837+
tuple[dict[str, torch.Tensor], torch.Tensor or list of torch.Tensor]
778838
A tuple containing the collated data and the target variable.
839+
If the dataset has multiple targets, a list of tensors each of shape
840+
(batch_size, prediction_length,). Otherwise, a single tensor of shape
841+
(batch_size, prediction_length).
779842
"""
780843

781844
x_batch = {
@@ -816,5 +879,13 @@ def collate_fn(batch):
816879
[x["static_continuous_features"] for x, _ in batch]
817880
)
818881

819-
y_batch = torch.stack([y for _, y in batch])
882+
if isinstance(batch[0][1], (list, tuple)):
883+
num_targets = len(batch[0][1])
884+
y_batch = []
885+
for i in range(num_targets):
886+
target_tensors = [sample_y[i] for _, sample_y in batch]
887+
stacked_target = torch.stack(target_tensors)
888+
y_batch.append(stacked_target)
889+
else:
890+
y_batch = torch.stack([y for _, y in batch])
820891
return x_batch, y_batch

pytorch_forecasting/data/data_module.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def __init__(
127127
self.train_val_test_split = train_val_test_split
128128

129129
warn(
130-
"TimeSeries is part of an experimental rework of the "
130+
"EncoderDecoderTimeSeriesDataModule is part of an experimental "
131+
"rework of the "
131132
"pytorch-forecasting data layer, "
132133
"scheduled for release with v2.0.0. "
133134
"The API is not stable and may change without prior warning. "
@@ -151,6 +152,7 @@ def __init__(
151152
self._min_encoder_length = min_encoder_length or max_encoder_length
152153
self._categorical_encoders = _coerce_to_dict(categorical_encoders)
153154
self._scalers = _coerce_to_dict(scalers)
155+
self.n_targets = len(self.time_series_metadata["cols"]["y"])
154156

155157
self.categorical_indices = []
156158
self.continuous_indices = []
@@ -382,6 +384,13 @@ def __len__(self):
382384
def __getitem__(self, idx):
383385
"""Retrieve a processed time series window for dataloader input.
384386
387+
Parameters
388+
----------
389+
idx : int
390+
Index of the window to retrieve from the dataset.
391+
392+
Returns
393+
-------
385394
x : dict
386395
Dictionary containing model inputs:
387396
@@ -405,6 +414,8 @@ def __getitem__(self, idx):
405414
Time indices for the encoder sequence.
406415
* ``decoder_time_idx`` : tensor of shape (pred_length,)
407416
Time indices for the decoder sequence.
417+
* ``target_past`` : torch.Tensor of shape (enc_length,)
418+
Historical target values for the encoder sequence.
408419
* ``target_scale`` : tensor of shape (1,)
409420
Scaling factor for the target values.
410421
* ``encoder_mask`` : tensor of shape (enc_length,)
@@ -420,8 +431,10 @@ def __getitem__(self, idx):
420431
* ``static_continuous_features`` : tensor of shape (1, 0), optional
421432
Placeholder for static continuous features (currently empty).
422433
423-
y : tensor of shape ``(pred_length, n_targets)``
434+
y : torch.Tensor or list of torch.Tensor
424435
Target values for the decoder sequence.
436+
If ``n_targets`` > 1, a list of tensors each of shape (pred_length,)
437+
is returned. Otherwise, a tensor of shape (pred_length,) is returned.
425438
"""
426439
series_idx, start_idx, enc_length, pred_length = self.windows[idx]
427440
data = self.data_module._preprocess_data(series_idx)
@@ -547,8 +560,11 @@ def __getitem__(self, idx):
547560
)
548561

549562
y = data["target"][decoder_indices]
550-
if y.ndim == 1:
551-
y = y.unsqueeze(-1)
563+
564+
if self.data_module.n_targets > 1:
565+
y = [t.squeeze(-1) for t in torch.split(y, 1, dim=1)]
566+
else:
567+
y = y.squeeze(-1)
552568

553569
return x, y
554570

@@ -730,5 +746,13 @@ def collate_fn(batch):
730746
[x["static_continuous_features"] for x, _ in batch]
731747
)
732748

733-
y_batch = torch.stack([y for _, y in batch])
749+
if isinstance(batch[0][1], (list, tuple)):
750+
num_targets = len(batch[0][1])
751+
y_batch = []
752+
for i in range(num_targets):
753+
target_tensors = [sample_y[i] for _, sample_y in batch]
754+
stacked_target = torch.stack(target_tensors)
755+
y_batch.append(stacked_target)
756+
else:
757+
y_batch = torch.stack([y for _, y in batch])
734758
return x_batch, y_batch

pytorch_forecasting/models/base/_base_model_v2.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,39 +14,40 @@
1414
import torch.nn as nn
1515
from torch.optim import Optimizer
1616

17+
from pytorch_forecasting.metrics import Metric
1718
from pytorch_forecasting.utils._classproperty import classproperty
1819

1920

2021
class BaseModel(LightningModule):
22+
"""Base model for time series forecasting.
23+
24+
Parameters
25+
----------
26+
loss : Descendants of ``pytorch_forecasting.metrics.Metric`` class
27+
Loss function to use for training.
28+
logging_metrics : Optional[List[nn.Module]], optional
29+
List of metrics to log during training, validation, and testing.
30+
optimizer : Optional[Union[Optimizer, str]], optional
31+
Optimizer to use for training.
32+
Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`.
33+
optimizer_params : Optional[Dict], optional
34+
Parameters for the optimizer.
35+
lr_scheduler : Optional[str], optional
36+
Learning rate scheduler to use.
37+
Supported values: "reduce_lr_on_plateau", "step_lr".
38+
lr_scheduler_params : Optional[Dict], optional
39+
Parameters for the learning rate scheduler.
40+
"""
41+
2142
def __init__(
2243
self,
23-
loss: nn.Module,
44+
loss: Metric,
2445
logging_metrics: Optional[list[nn.Module]] = None,
2546
optimizer: Optional[Union[Optimizer, str]] = "adam",
2647
optimizer_params: Optional[dict] = None,
2748
lr_scheduler: Optional[str] = None,
2849
lr_scheduler_params: Optional[dict] = None,
2950
):
30-
"""
31-
Base model for time series forecasting.
32-
33-
Parameters
34-
----------
35-
loss : nn.Module
36-
Loss function to use for training.
37-
logging_metrics : Optional[List[nn.Module]], optional
38-
List of metrics to log during training, validation, and testing.
39-
optimizer : Optional[Union[Optimizer, str]], optional
40-
Optimizer to use for training.
41-
Can be a string ("adam", "sgd") or an instance of `torch.optim.Optimizer`.
42-
optimizer_params : Optional[Dict], optional
43-
Parameters for the optimizer.
44-
lr_scheduler : Optional[str], optional
45-
Learning rate scheduler to use.
46-
Supported values: "reduce_lr_on_plateau", "step_lr".
47-
lr_scheduler_params : Optional[Dict], optional
48-
Parameters for the learning rate scheduler.
49-
"""
5051
super().__init__()
5152
self.loss = loss
5253
self.logging_metrics = logging_metrics if logging_metrics is not None else []

pytorch_forecasting/models/base/_tslib_base_model_v2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch.nn as nn
1010
from torch.optim import Optimizer
1111

12+
from pytorch_forecasting.metrics import Metric
1213
from pytorch_forecasting.models.base._base_model_v2 import BaseModel
1314

1415

@@ -18,7 +19,7 @@ class TslibBaseModel(BaseModel):
1819
1920
Parameters
2021
----------
21-
loss : nn.Module
22+
loss : Descendants of ``pytorch_forecasting.metrics.Metric`` class
2223
Loss function to use for training.
2324
logging_metrics : Optional[list[nn.Module]], optional
2425
list of metrics to log during training, validation, and testing.
@@ -36,7 +37,7 @@ class TslibBaseModel(BaseModel):
3637

3738
def __init__(
3839
self,
39-
loss: nn.Module,
40+
loss: Metric,
4041
logging_metrics: Optional[list[nn.Module]] = None,
4142
optimizer: Optional[Union[Optimizer, str]] = "adam",
4243
optimizer_params: Optional[dict] = None,

pytorch_forecasting/models/samformer/_samformer_v2_pkg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,20 +117,20 @@ def get_test_train_params(cls):
117117

118118
return [
119119
{
120-
"loss": nn.MSELoss(),
120+
# "loss": nn.MSELoss(),
121121
"hidden_size": 32,
122122
"use_revin": False,
123123
},
124124
{
125-
"loss": nn.MSELoss(),
125+
# "loss": nn.MSELoss(),
126126
"hidden_size": 16,
127127
"use_revin": True,
128128
"out_channels": 1,
129129
"persistence_weight": 0.0,
130130
},
131-
# {
132-
# "loss": QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
133-
# "hidden_size": 32,
134-
# "use_revin": False,
135-
# },
131+
{
132+
"loss": QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
133+
"hidden_size": 32,
134+
"use_revin": False,
135+
},
136136
]

pytorch_forecasting/models/tide/_tide_dsipts/_tide_v2_pkg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def get_test_train_params(cls):
109109
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
110110
`create_test_instance` uses the first (or only) dictionary in `params`
111111
"""
112-
import torch.nn as nn
112+
from pytorch_forecasting.metrics import MAE, MAPE
113113

114114
return [
115115
dict(
@@ -126,7 +126,7 @@ def get_test_train_params(cls):
126126
n_add_dec=2,
127127
dropout_rate=0.2,
128128
data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3),
129-
loss=nn.MSELoss(),
129+
loss=MAE(),
130130
),
131131
dict(
132132
hidden_size=64,
@@ -135,6 +135,6 @@ def get_test_train_params(cls):
135135
n_add_dec=2,
136136
dropout_rate=0.1,
137137
data_loader_kwargs=dict(max_encoder_length=4, max_prediction_length=2),
138-
loss=nn.PoissonNLLLoss(),
138+
loss=MAPE(),
139139
),
140140
]

pytorch_forecasting/models/timexer/_timexer_pkg_v2.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def get_test_train_params(cls):
109109
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
110110
`create_test_instance` uses the first (or only) dictionary in `params`
111111
"""
112+
from pytorch_forecasting.metrics import QuantileLoss
113+
112114
return [
113115
{},
114116
dict(
@@ -158,5 +160,6 @@ def get_test_train_params(cls):
158160
context_length=16,
159161
prediction_length=4,
160162
),
163+
loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
161164
),
162165
]

pytorch_forecasting/tests/test_all_estimators_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Automated tests based on the skbase test suite template."""
22

3-
from inspect import isclass
43
import shutil
54

65
import lightning.pytorch as pl
@@ -9,6 +8,7 @@
98
import torch
109
import torch.nn as nn
1110

11+
from pytorch_forecasting.metrics import SMAPE
1212
from pytorch_forecasting.tests.test_all_estimators import (
1313
EstimatorFixtureGenerator,
1414
EstimatorPackageConfig,
@@ -62,7 +62,7 @@ def _integration(
6262
loss = kwargs["loss"]
6363
kwargs.pop("loss")
6464
else:
65-
loss = nn.MSELoss()
65+
loss = SMAPE()
6666

6767
net = estimator_cls(
6868
metadata=metadata,

0 commit comments

Comments
 (0)