Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
4e06e26
Precompute data to massively accelerate training in GPU
jobs-git May 27, 2025
913984b
Enable using both slow and fast path on `__getitem__ ()`
jobs-git May 28, 2025
e41ce8b
Fix formatting for lint
jobs-git May 28, 2025
ab34f02
Fix check error due to Dict
jobs-git May 28, 2025
394bc79
Fix check error in code quality
jobs-git May 28, 2025
cd12c80
Fix check error due to spacing
jobs-git May 28, 2025
087799f
Fix check error due to spacing
jobs-git May 28, 2025
7831ea7
Fix check error due to spacing
jobs-git May 28, 2025
a2007f9
Fix check error due to spacing
jobs-git May 28, 2025
9b9085e
Add `precompute` as settings in `TimeSeriesDataset` class
jobs-git Jun 1, 2025
9fe2eb0
Merge branch 'sktime:main' into patch-1
jobs-git Jun 5, 2025
3f18cdb
Added test
jobs-git Jun 5, 2025
4332f7a
Merge branch 'sktime:main' into patch-1
jobs-git Jun 5, 2025
ce8deb0
Merge branch 'sktime:main' into patch-1
jobs-git Jun 8, 2025
1d13042
Organize precompute logic in `__getitem__`
jobs-git Jun 8, 2025
2bd3e15
Organize precompute logic in __getitem__
jobs-git Jun 8, 2025
4e169fb
Merge branch 'sktime:main' into patch-1
jobs-git Jun 8, 2025
5f44326
Further improve performance to ~2000% by pre-collating batches
jobs-git Jun 12, 2025
eff2931
Fix check error
jobs-git Jun 12, 2025
6ad47cf
Fix check error
jobs-git Jun 12, 2025
ef48f01
Merge branch 'sktime:main' into patch-1
jobs-git Jun 12, 2025
1a75134
Add samples demonstrating GPU performance
jobs-git Jun 12, 2025
07d0ba9
Merge branch 'sktime:main' into patch-1
jobs-git Jun 13, 2025
8c25a39
added benchmark and usage examples
jobs-git Jun 14, 2025
06613a1
fix check error due to missing sktime dataset module
jobs-git Jun 14, 2025
c0c519b
fix check error due to missing cuda device from github test env
jobs-git Jun 14, 2025
7f97963
defaults to torch dataloader when no batch_sampler is set
jobs-git Jun 14, 2025
23868e3
prevent from overriding batch_sampler when both precompute and batch_…
jobs-git Jun 14, 2025
3ef3417
Merge branch 'sktime:main' into patch-1
jobs-git Jun 14, 2025
590a3b4
update changelog
jobs-git Jun 14, 2025
841f8cf
added test to check batch shape mismatch between precompute=True and …
jobs-git Jun 14, 2025
d8de451
updated changelog
jobs-git Jun 14, 2025
7df1500
Merge branch 'main' into pr/1850
fkiraly Jul 11, 2025
a2d4b8a
Merge branch 'sktime:main' into patch-1
jobs-git Sep 11, 2025
13b59e6
remove tutorial and sample usage
jobs-git Sep 11, 2025
8e09831
updated changelog
jobs-git Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
# Release Notes

## v1.5.0

Feature and maintenance update.

### Highlights

* added option to pre-calculate tensors in `TimeSeriesDataSet` by setting `precompute=True`.

### Enhancements

* [ENH] Precompute data to massively accelerate training by ~2000% on GPU by @jobs-git in https://github.com/sktime/pytorch-forecasting/pull/1850
* [ENH] Added test for `TimeSeriesDataSet` when `precompute=True` by @jobs-git in https://github.com/sktime/pytorch-forecasting/pull/1850
* [ENH] Added test to check batch shape mismatch between precompute=True and False by @jobs-git in https://github.com/sktime/pytorch-forecasting/pull/1850
* [ENH] Added benchmark test to compare with and without `precompute=True` in GPU and CPU by @jobs-git in https://github.com/sktime/pytorch-forecasting/pull/1850

### All Contributors

@jobs-git


## v1.4.0

Feature and maintenance update.
Expand Down
84 changes: 81 additions & 3 deletions pytorch_forecasting/data/timeseries/_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
from torch.distributions import Beta
from torch.nn.utils import rnn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler
from torch.utils.data.sampler import Sampler, SequentialSampler

from pytorch_forecasting.data.encoders import (
Expand Down Expand Up @@ -476,10 +476,15 @@ def __init__(
] = None,
randomize_length: Union[None, tuple[float, float], bool] = False,
predict_mode: bool = False,
precompute: bool = False,
):
"""Timeseries dataset holding data for models."""
super().__init__()

self.precollate_cache = []
self.precollate_idx = 0
self.precompute = precompute

# write variables to self and handle defaults
# -------------------------------------------
self.max_encoder_length = max_encoder_length
Expand Down Expand Up @@ -2095,7 +2100,7 @@ def calculate_decoder_length(
).clip(max=self.max_prediction_length)
return decoder_length

def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
def __item_tensor__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
"""
Get sample for model

Expand Down Expand Up @@ -2356,6 +2361,75 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
(target, weight),
)

def __precompute__(self, kwargs):
    """
    Precompute and cache collated batches for the whole dataset.

    Resolves a batch sampler from ``kwargs``, then iterates it once,
    materializing every sample via ``__item_tensor__`` and collating each
    batch with ``_collate_fn``. The finished batches are appended to
    ``self.precollate_cache``, from which the fast collate function
    (``__fast_collate_fn__``) later serves them.

    Args:
        kwargs (dict): the ``DataLoader`` keyword arguments assembled by
            ``to_dataloader``. Must contain the keys ``batch_sampler``,
            ``shuffle``, ``batch_size`` and ``drop_last``.

    Raises:
        ValueError: if ``batch_sampler`` is a string other than
            ``"synchronized"``, or any non-``None``, non-string value
            (custom sampler instances are not supported in precompute mode).
    """
    batch_sampler = kwargs["batch_sampler"]
    if batch_sampler is None:
        # No explicit sampler: mirror the default DataLoader behavior
        # (random order when shuffling, sequential otherwise).
        sampler = (
            RandomSampler(self) if kwargs["shuffle"] else SequentialSampler(self)
        )
        batch_sampler = BatchSampler(
            sampler=sampler,
            batch_size=kwargs["batch_size"],
            drop_last=kwargs["drop_last"],
        )
    else:
        if isinstance(batch_sampler, str):
            sampler = kwargs["batch_sampler"]
            if sampler == "synchronized":
                # Batches where all samples share the same prediction time
                # point; only named sampler currently recognized here.
                batch_sampler = TimeSynchronizedBatchSampler(
                    SequentialSampler(self),
                    batch_size=kwargs["batch_size"],
                    shuffle=kwargs["shuffle"],
                    drop_last=kwargs["drop_last"],
                )
            else:
                raise ValueError(
                    f"batch_sampler '{batch_sampler}' is not recognized."
                )
        else:
            # NOTE(review): sampler *instances* are rejected on purpose —
            # precollating would otherwise freeze the instance's first
            # ordering for every epoch. Confirm against to_dataloader.
            raise ValueError(f"batch_sampler '{batch_sampler}' is not recognized.")

    # Materialize every batch once, up front. The sample order is fixed
    # here; epochs replay the cache in this order.
    for batch in batch_sampler:
        batch_samples = []

        for idx in batch:
            batch_result = self.__item_tensor__(idx)
            batch_samples.append(batch_result)

        # `batch` is rebound from index list to collated tensors.
        batch = self._collate_fn(batch_samples)
        self.precollate_cache.append(batch)

def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
    """
    Get sample for model.

    Args:
        idx (int): index of prediction (between ``0`` and ``len(dataset) - 1``)

    Returns:
        tuple[dict[str, torch.Tensor], torch.Tensor]: x and y for the model,
        or ``None`` when ``precompute=True`` — in that mode whole batches are
        served from the precollated cache by the fast collate function, so
        the per-item value returned here is never used.
    """
    if not self.precompute:
        return self.__item_tensor__(idx)
    # Precompute mode: placeholder only; real data comes from the cache.
    return None

def __fast_collate_fn__(self):
    """
    Build a collate function that serves batches from the precollate cache.

    The returned callable is meant to be passed as ``collate_fn`` to a
    ``DataLoader`` when ``precompute=True``: instead of collating the items
    it receives, it returns the next precollated batch from
    ``self.precollate_cache``, wrapping around so the cache can be replayed
    across epochs.

    Returns:
        Callable: a ``collate_fn`` replacement for ``DataLoader``.

    Raises:
        RuntimeError: (from the returned callable) if the cache is empty,
            i.e. ``__precompute__`` was never run. Previously this surfaced
            as an opaque ``IndexError``.
    """

    def _collate_fn_(batches):
        # `batches` (the per-item placeholders sampled by the DataLoader)
        # is intentionally ignored: full batches were already collated
        # in ``__precompute__``.
        if not self.precollate_cache:
            raise RuntimeError(
                "precollate_cache is empty; precompute batches before "
                "iterating the dataloader."
            )
        if self.precollate_idx >= len(self.precollate_cache):
            # End of cache reached: wrap around for the next epoch.
            self.precollate_idx = 0
        batch = self.precollate_cache[self.precollate_idx]
        self.precollate_idx += 1
        return batch

    return _collate_fn_

@staticmethod
def _collate_fn(
batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]],
Expand Down Expand Up @@ -2627,7 +2701,11 @@ def to_dataloader(
)
default_kwargs.update(kwargs)
kwargs = default_kwargs
if kwargs["batch_sampler"] is not None:

if self.precompute:
kwargs["collate_fn"] = self.__fast_collate_fn__()
self.__precompute__(kwargs)
elif kwargs["batch_sampler"] is not None:
sampler = kwargs["batch_sampler"]
if isinstance(sampler, str):
if sampler == "synchronized":
Expand Down
Loading
Loading