5 changes: 3 additions & 2 deletions pytorch_forecasting/layers/_output/_flatten_head.py
@@ -38,8 +38,9 @@ def forward(self, x):
         x = self.flatten(x)
         x = self.linear(x)
         x = self.dropout(x)
+        x = x.permute(0, 2, 1)

         if self.n_quantiles is not None:
-            batch_size, n_vars = x.shape[0], x.shape[1]
-            x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
+            batch_size = x.shape[0]
+            x = x.reshape(batch_size, -1, self.n_quantiles)
         return x
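Side note (not part of the diff): a minimal shape walk-through of the updated head for the univariate quantile case. The sizes and the stand-in linear layer below are illustrative assumptions, not values taken from the repository.

import torch
import torch.nn as nn

# Toy sizes chosen for illustration only.
batch_size, n_vars, head_nf = 4, 1, 32
prediction_length, n_quantiles = 8, 3

# Stand-in for the head's linear stage (flatten assumed already applied).
linear = nn.Linear(head_nf, prediction_length * n_quantiles)

x = torch.randn(batch_size, n_vars, head_nf)   # (batch, n_vars, flattened features)
x = linear(x)                                  # (batch, n_vars, pred_len * n_quantiles)
x = x.permute(0, 2, 1)                         # (batch, pred_len * n_quantiles, n_vars)
x = x.reshape(batch_size, -1, n_quantiles)     # (batch, pred_len, n_quantiles) when n_vars == 1
print(x.shape)                                 # torch.Size([4, 8, 3])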
11 changes: 3 additions & 8 deletions pytorch_forecasting/models/dlinear/_dlinear_v2.py
@@ -239,22 +239,17 @@ def _reshape_output(self, output: torch.Tensor) -> torch.Tensor:
         Returns
         -------
         output: torch.Tensor
-            Reshaped tensor (batch_size, prediction_length, n_features, n_quantiles)
+            Reshaped tensor (batch_size, prediction_length, n_quantiles)
             or (batch_size, prediction_length, n_features) if n_quantiles is None.
         """
         if self.n_quantiles is not None:
-            batch_size, n_features = output.shape[0], output.shape[1]
+            batch_size = output.shape[0]
             output = output.reshape(
-                batch_size, n_features, self.prediction_length, self.n_quantiles
+                batch_size, self.prediction_length, self.n_quantiles
             )
-            output = output.permute(0, 2, 1, 3)  # (batch, time, features, quantiles)
         else:
             output = output.permute(0, 2, 1)  # (batch, time, features)

-        # univariate forecasting
-        if self.target_dim == 1 and output.shape[-1] == 1:
-            output = output.squeeze(-1)
-
         return output

     def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
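For reference, a rough sketch of what the slimmed-down quantile branch of _reshape_output now computes. The sizes are placeholders, and the assumed input layout (prediction_length * n_quantiles values per sample for a univariate target) is an assumption of the sketch, not something stated in the diff.

import torch

batch_size, prediction_length, n_quantiles = 4, 8, 3

# Assume the projection layer emits prediction_length * n_quantiles values
# per sample for a univariate target.
output = torch.randn(batch_size, prediction_length * n_quantiles)

output = output.reshape(batch_size, prediction_length, n_quantiles)
print(output.shape)   # torch.Size([4, 8, 3]) -- matches the updated docstring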
9 changes: 0 additions & 9 deletions pytorch_forecasting/models/timexer/_timexer_v2.py
@@ -311,11 +311,6 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

         dec_out = self.head(enc_out)

-        if self.n_quantiles is not None:
-            dec_out = dec_out.permute(0, 2, 1, 3)
-        else:
-            dec_out = dec_out.permute(0, 2, 1)
-
         return dec_out

     def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
@@ -330,10 +325,6 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
         out = self._forecast(x)
         prediction = out[:, : self.prediction_length, :]

-        # check to see if the output shape is equal to number of targets
-        if prediction.size(2) != self.target_dim:
-            prediction = prediction[:, :, : self.target_dim]
-
         if "target_scale" in x:
             prediction = self.transform_output(prediction, x["target_scale"])

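To see why the permutes in _forecast and the target-dim slice in forward become redundant, here is a toy pass-through under the assumption that the head already returns (batch, prediction_length, n_quantiles), as in the first file of this diff; the tensors are synthetic.

import torch

batch_size, prediction_length, n_quantiles = 4, 8, 3

# Stand-in for self.head(enc_out) after the flatten-head change.
dec_out = torch.randn(batch_size, prediction_length, n_quantiles)

# forward() now only trims to the prediction horizon; no permute or
# target-dimension slicing is needed for the univariate quantile case.
prediction = dec_out[:, :prediction_length, :]
assert prediction.shape == (batch_size, prediction_length, n_quantiles)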
2 changes: 1 addition & 1 deletion tests/test_models/test_dlinear_v2.py
@@ -124,7 +124,7 @@ def test_quantile_loss_output(sample_dataset):

     assert "prediction" in output
     pred = output["prediction"]
-    assert pred.ndim == 4
+    assert pred.ndim == 3
     assert pred.shape[-1] == len(quantiles)
     assert pred.shape[1] == metadata["prediction_length"]

2 changes: 1 addition & 1 deletion tests/test_models/test_timexer_v2.py
@@ -297,7 +297,7 @@ def test_quantile_predictions(basic_metadata):
     output = model(sample_input_data)

     predictions = output["prediction"]
-    assert predictions.shape == (batch_size, 8, 1, 3)
+    assert predictions.shape == (batch_size, 8, 3)


 def test_missing_history_target_handling(basic_metadata):