Merged — changes from 14 commits.
654 changes: 154 additions & 500 deletions docs/source/tutorials/ptf_V2_example.ipynb

Large diffs are not rendered by default.

413 changes: 133 additions & 280 deletions docs/source/tutorials/tslib_v2_example.ipynb

Large diffs are not rendered by default.

15 changes: 14 additions & 1 deletion pytorch_forecasting/data/_tslib_data_module.py
@@ -170,6 +170,10 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
         x["target_scale"] = processed_data["target_scale"]

         y = processed_data["target"][future_indices]
+        if self.data_module.n_targets > 1:
+            y = list(torch.split(y, 1, dim=1))
+        else:
+            y = y.squeeze(-1)

         return x, y
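The per-sample target layout this hunk introduces is easy to sanity-check in isolation. A minimal sketch, assuming toy shapes of 8 future steps and 3 targets (not the module's real data):

```python
import torch

# Multi-target: split (time, n_targets) into a list of (time, 1) tensors,
# one per target, mirroring the new branch above.
y = torch.randn(8, 3)
parts = list(torch.split(y, 1, dim=1))
assert len(parts) == 3 and parts[0].shape == (8, 1)

# Single-target: the trailing dimension is squeezed away instead.
y_single = torch.randn(8, 1)
assert y_single.squeeze(-1).shape == (8,)
```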

@@ -294,6 +298,7 @@ def __init__(
         self.window_stride = window_stride

         self.time_series_metadata = time_series_dataset.get_metadata()
+        self.n_targets = len(self.time_series_metadata["cols"]["y"])

         for idx, col in enumerate(self.time_series_metadata["cols"]["x"]):
             if self.time_series_metadata["col_type"].get(col) == "C":
@@ -816,5 +821,13 @@ def collate_fn(batch):
             [x["static_continuous_features"] for x, _ in batch]
         )

-        y_batch = torch.stack([y for _, y in batch])
+        if isinstance(batch[0][1], (list, tuple)):
+            num_targets = len(batch[0][1])
+            y_batch = []
+            for i in range(num_targets):
+                target_tensors = [sample_y[i] for _, sample_y in batch]
+                stacked_target = torch.stack(target_tensors)
+                y_batch.append(stacked_target)
+        else:
+            y_batch = torch.stack([y for _, y in batch])
        return x_batch, y_batch
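The collate branch pairs with that per-sample layout: list-typed targets are stacked per target rather than into one tensor. A standalone sketch under the same toy shapes (batch of 4, prediction length 8, 2 targets; `x` entries omitted):

```python
import torch

batch = [({}, [torch.randn(8, 1), torch.randn(8, 1)]) for _ in range(4)]

if isinstance(batch[0][1], (list, tuple)):
    # one (batch, time, 1) tensor per target
    y_batch = [
        torch.stack([sample_y[i] for _, sample_y in batch])
        for i in range(len(batch[0][1]))
    ]
else:
    y_batch = torch.stack([y for _, y in batch])

assert len(y_batch) == 2 and y_batch[0].shape == (4, 8, 1)
```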
18 changes: 15 additions & 3 deletions pytorch_forecasting/data/data_module.py
@@ -151,6 +151,7 @@ def __init__(
         self._min_encoder_length = min_encoder_length or max_encoder_length
         self._categorical_encoders = _coerce_to_dict(categorical_encoders)
         self._scalers = _coerce_to_dict(scalers)
+        self.n_targets = len(self.time_series_metadata["cols"]["y"])

         self.categorical_indices = []
         self.continuous_indices = []
@@ -547,8 +548,11 @@ def __getitem__(self, idx):
         )

         y = data["target"][decoder_indices]
-        if y.ndim == 1:
-            y = y.unsqueeze(-1)
+
+        if self.data_module.n_targets > 1:
+            y = list(torch.split(y, 1, dim=1))
+        else:
+            y = y.squeeze(-1)

         return x, y

@@ -730,5 +734,13 @@ def collate_fn(batch):
            [x["static_continuous_features"] for x, _ in batch]
        )

-       y_batch = torch.stack([y for _, y in batch])
+       if isinstance(batch[0][1], (list, tuple)):
+           num_targets = len(batch[0][1])
+           y_batch = []
+           for i in range(num_targets):
+               target_tensors = [sample_y[i] for _, sample_y in batch]
+               stacked_target = torch.stack(target_tensors)
+               y_batch.append(stacked_target)
+       else:
+           y_batch = torch.stack([y for _, y in batch])
        return x_batch, y_batch
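Both data modules now emit the same `y` contract: a plain (batch, time) tensor for a single target, or a list of (batch, time, 1) tensors otherwise. Downstream code can branch on the container type; a hedged sketch with toy tensors standing in for model output, where `l1_loss` is just a stand-in for the configured Metric:

```python
import torch
import torch.nn.functional as F

y_true = [torch.randn(4, 8, 1), torch.randn(4, 8, 1)]  # two targets
y_pred = [torch.randn(4, 8, 1), torch.randn(4, 8, 1)]

if isinstance(y_true, list):
    # average a per-target loss over the list layout
    losses = [F.l1_loss(p, t) for p, t in zip(y_pred, y_true)]
    total = torch.stack(losses).mean()
else:
    total = F.l1_loss(y_pred, y_true)
```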
5 changes: 3 additions & 2 deletions pytorch_forecasting/layers/_output/_flatten_head.py
@@ -38,8 +38,9 @@ def forward(self, x):
         x = self.flatten(x)
         x = self.linear(x)
         x = self.dropout(x)
+        x = x.permute(0, 2, 1)

         if self.n_quantiles is not None:
-            batch_size, n_vars = x.shape[0], x.shape[1]
-            x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
+            batch_size = x.shape[0]
+            x = x.reshape(batch_size, -1, self.n_quantiles)
         return x
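The head now permutes before reshaping, so quantiles land in the trailing dimension without a per-variable axis. An illustrative shape walk, assuming toy sizes of batch 4, 2 variables, prediction length 8 and 3 quantiles (so the linear layer emits 8 * 3 features per variable):

```python
import torch

x = torch.randn(4, 2, 8 * 3)   # after flatten + linear + dropout
x = x.permute(0, 2, 1)         # (4, 24, 2)
x = x.reshape(4, -1, 3)        # (4, 16, 3): variables folded into the time axis

# the single-variable case reduces to (batch, prediction_length, n_quantiles)
x1 = torch.randn(4, 1, 8 * 3).permute(0, 2, 1).reshape(4, -1, 3)
assert x1.shape == (4, 8, 3)
```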
5 changes: 3 additions & 2 deletions pytorch_forecasting/models/base/_base_model_v2.py
@@ -14,13 +14,14 @@
 import torch.nn as nn
 from torch.optim import Optimizer

+from pytorch_forecasting.metrics import Metric
 from pytorch_forecasting.utils._classproperty import classproperty


 class BaseModel(LightningModule):
     def __init__(
         self,
-        loss: nn.Module,
+        loss: Metric,
         logging_metrics: Optional[list[nn.Module]] = None,
         optimizer: Optional[Union[Optimizer, str]] = "adam",
         optimizer_params: Optional[dict] = None,
@@ -32,7 +33,7 @@ def __init__(

         Parameters
         ----------
-        loss : nn.Module
+        loss : Metric
            Loss function to use for training.
        logging_metrics : Optional[List[nn.Module]], optional
            List of metrics to log during training, validation, and testing.
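Tightening the annotation from `nn.Module` to `Metric` matches what the v2 models now expect. A quick sketch of the call shapes, assuming the point-forecast and quantile Metric signatures from `pytorch_forecasting.metrics` behave as in v1:

```python
import torch
from pytorch_forecasting.metrics import MAE, QuantileLoss

y_true = torch.randn(4, 8)                    # (batch, time)
print(MAE()(torch.randn(4, 8), y_true))       # point loss: matching shapes

loss_fn = QuantileLoss(quantiles=[0.1, 0.5, 0.9])
print(loss_fn(torch.randn(4, 8, 3), y_true))  # predictions carry a quantile dim
```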
5 changes: 3 additions & 2 deletions pytorch_forecasting/models/base/_tslib_base_model_v2.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from torch.optim import Optimizer

+from pytorch_forecasting.metrics import Metric
 from pytorch_forecasting.models.base._base_model_v2 import BaseModel
@@ -18,7 +19,7 @@ class TslibBaseModel(BaseModel):

     Parameters
     ----------
-    loss : nn.Module
+    loss : Metric
        Loss function to use for training.
    logging_metrics : Optional[list[nn.Module]], optional
        list of metrics to log during training, validation, and testing.
@@ -36,7 +37,7 @@

     def __init__(
         self,
-        loss: nn.Module,
+        loss: Metric,
         logging_metrics: Optional[list[nn.Module]] = None,
         optimizer: Optional[Union[Optimizer, str]] = "adam",
         optimizer_params: Optional[dict] = None,
11 changes: 3 additions & 8 deletions pytorch_forecasting/models/dlinear/_dlinear_v2.py
@@ -239,22 +239,17 @@ def _reshape_output(self, output: torch.Tensor) -> torch.Tensor:
         Returns
         -------
         output: torch.Tensor
-            Reshaped tensor (batch_size, prediction_length, n_features, n_quantiles)
+            Reshaped tensor (batch_size, prediction_length, n_quantiles)
            or (batch_size, prediction_length, n_features) if n_quantiles is None.
        """
        if self.n_quantiles is not None:
-            batch_size, n_features = output.shape[0], output.shape[1]
+            batch_size = output.shape[0]
            output = output.reshape(
-                batch_size, n_features, self.prediction_length, self.n_quantiles
+                batch_size, self.prediction_length, self.n_quantiles
            )
-            output = output.permute(0, 2, 1, 3)  # (batch, time, features, quantiles)
        else:
            output = output.permute(0, 2, 1)  # (batch, time, features)

-        # univariate forecasting
-        if self.target_dim == 1 and output.shape[-1] == 1:
-            output = output.squeeze(-1)
-
        return output

    def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
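With the per-feature axis gone, the quantile branch reduces to a single reshape. An illustrative check, assuming (toy values, and assuming the projection emits a flat prediction_length * n_quantiles tail) prediction_length=8 and n_quantiles=3:

```python
import torch

out = torch.randn(4, 8 * 3)      # assumed flat head output per sample
out = out.reshape(4, 8, 3)       # (batch, prediction_length, n_quantiles)
assert out.shape == (4, 8, 3)    # matches the updated docstring
```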
14 changes: 7 additions & 7 deletions pytorch_forecasting/models/samformer/_samformer_v2_pkg.py
@@ -117,20 +117,20 @@ def get_test_train_params(cls):

     return [
         {
-            "loss": nn.MSELoss(),
+            # "loss": nn.MSELoss(),
             "hidden_size": 32,
             "use_revin": False,
         },
         {
-            "loss": nn.MSELoss(),
+            # "loss": nn.MSELoss(),
             "hidden_size": 16,
             "use_revin": True,
             "out_channels": 1,
             "persistence_weight": 0.0,
         },
-        # {
-        #     "loss": QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
-        #     "hidden_size": 32,
-        #     "use_revin": False,
-        # },
+        {
+            "loss": QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
+            "hidden_size": 32,
+            "use_revin": False,
+        },
     ]
@@ -109,7 +109,7 @@ def get_test_train_params(cls):
     `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
     `create_test_instance` uses the first (or only) dictionary in `params`
     """
-    import torch.nn as nn
+    from pytorch_forecasting.metrics import MAE, MAPE

     return [
         dict(
@@ -126,7 +126,7 @@ def get_test_train_params(cls):
             n_add_dec=2,
             dropout_rate=0.2,
             data_loader_kwargs=dict(max_encoder_length=5, max_prediction_length=3),
-            loss=nn.MSELoss(),
+            loss=MAE(),
         ),
         dict(
             hidden_size=64,
@@ -135,6 +135,6 @@ def get_test_train_params(cls):
             n_add_dec=2,
             dropout_rate=0.1,
             data_loader_kwargs=dict(max_encoder_length=4, max_prediction_length=2),
-            loss=nn.PoissonNLLLoss(),
+            loss=MAPE(),
         ),
     ]
3 changes: 3 additions & 0 deletions pytorch_forecasting/models/timexer/_timexer_pkg_v2.py
@@ -109,6 +109,8 @@ def get_test_train_params(cls):
     `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
     `create_test_instance` uses the first (or only) dictionary in `params`
     """
+    from pytorch_forecasting.metrics import QuantileLoss
+
     return [
         {},
         dict(
@@ -158,5 +160,6 @@ def get_test_train_params(cls):
                 context_length=16,
                 prediction_length=4,
             ),
+            loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
         ),
     ]
9 changes: 0 additions & 9 deletions pytorch_forecasting/models/timexer/_timexer_v2.py
@@ -311,11 +311,6 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

         dec_out = self.head(enc_out)

-        if self.n_quantiles is not None:
-            dec_out = dec_out.permute(0, 2, 1, 3)
-        else:
-            dec_out = dec_out.permute(0, 2, 1)
-
         return dec_out

     def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
@@ -330,10 +325,6 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         out = self._forecast(x)
         prediction = out[:, : self.prediction_length, :]

-        # check to see if the output shape is equal to number of targets
-        if prediction.size(2) != self.target_dim:
-            prediction = prediction[:, :, : self.target_dim]
-
         if "target_scale" in x:
             prediction = self.transform_output(prediction, x["target_scale"])
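Both deletions follow from the head change in `_flatten_head.py`: `dec_out` already arrives as (batch, time, ...) with the targets folded in, so neither the permute nor the `target_dim` truncation has anything left to do. A toy check of the remaining slice, assuming prediction_length=4 on a length-16 decoder output:

```python
import torch

dec_out = torch.randn(2, 16, 3)   # (batch, time, n_quantiles) from the head
prediction = dec_out[:, :4, :]    # keep the first prediction_length steps
assert prediction.shape == (2, 4, 3)
```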
5 changes: 2 additions & 3 deletions pytorch_forecasting/tests/test_all_estimators_v2.py
@@ -1,13 +1,12 @@
 """Automated tests based on the skbase test suite template."""

 from inspect import isclass
-import shutil

 import lightning.pytorch as pl
 from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.loggers import TensorBoardLogger
-import torch.nn as nn

+from pytorch_forecasting.metrics import SMAPE
 from pytorch_forecasting.tests.test_all_estimators import (
     EstimatorFixtureGenerator,
     EstimatorPackageConfig,
@@ -61,7 +60,7 @@ def _integration(
         loss = kwargs["loss"]
         kwargs.pop("loss")
     else:
-        loss = nn.MSELoss()
+        loss = SMAPE()

     net = estimator_cls(
         metadata=metadata,
2 changes: 1 addition & 1 deletion tests/test_data/test_data_module.py
@@ -469,4 +469,4 @@ def test_multivariate_target():
     dm.setup()

     x, y = dm.train_dataset[0]
-    assert y.shape[-1] == 2
+    assert len(y) == 2
2 changes: 1 addition & 1 deletion tests/test_models/test_dlinear_v2.py
@@ -124,7 +124,7 @@ def test_quantile_loss_output(sample_dataset):

     assert "prediction" in output
     pred = output["prediction"]
-    assert pred.ndim == 4
+    assert pred.ndim == 3
     assert pred.shape[-1] == len(quantiles)
     assert pred.shape[1] == metadata["prediction_length"]
1 change: 0 additions & 1 deletion tests/test_models/test_tft_v2.py
@@ -394,5 +394,4 @@ def test_model_with_datamodule_integration(
     assert batch_y.shape == (
         actual_batch_size,
         MAX_PREDICTION_LENGTH_TEST,
-        model_metadata_from_dm["target"],
     )
2 changes: 1 addition & 1 deletion tests/test_models/test_timexer_v2.py
@@ -297,7 +297,7 @@ def test_quantile_predictions(basic_metadata):
     output = model(sample_input_data)

     predictions = output["prediction"]
-    assert predictions.shape == (batch_size, 8, 1, 3)
+    assert predictions.shape == (batch_size, 8, 3)


 def test_missing_history_target_handling(basic_metadata):