diff --git a/docs/source/basic_tutorials/migration.md b/docs/source/basic_tutorials/migration.md
index 6220702e977..8fb2c32f981 100644
--- a/docs/source/basic_tutorials/migration.md
+++ b/docs/source/basic_tutorials/migration.md
@@ -219,3 +219,9 @@ During training, you may want to save the current state of the model, optimizer,
To further customize where and how states are saved through [`~Accelerator.save_state`], use the [`~utils.ProjectConfiguration`] class. For example, if `automatic_checkpoint_naming` is enabled, each saved checkpoint is stored at `Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}`.
Any other stateful items to be stored should be registered with the [`~Accelerator.register_for_checkpointing`] method so they can be saved and loaded. Every object passed to this method to be stored must have a `load_state_dict` and `state_dict` function.
+
+
+
+If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, you can additionally pass `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`]. This extends Accelerate's DataLoader classes with a `load_state_dict` and `state_dict` method, so that [`~Accelerator.save_state`] and [`~Accelerator.load_state`] also track how far into the training dataloader you have read, and training can resume from that point.
+
+
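+For example, a minimal sketch of how this fits into a training script (the dataset and checkpoint directory below are placeholders for your own):
+
+```python
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+from accelerate import Accelerator
+from accelerate.utils import DataLoaderConfiguration
+
+# Placeholder dataset and dataloader, just for illustration
+dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 5))
+train_dataloader = DataLoader(dataset, batch_size=8)
+
+dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
+accelerator = Accelerator(dataloader_config=dataloader_config)
+train_dataloader = accelerator.prepare(train_dataloader)
+
+# ... train over part of `train_dataloader` ...
+
+# `save_state` now also records how far into `train_dataloader` training has progressed
+accelerator.save_state("my_checkpoint")
+
+# `load_state` restores that position along with any other registered state
+accelerator.load_state("my_checkpoint")
+```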
diff --git a/docs/source/concept_guides/internal_mechanism.md b/docs/source/concept_guides/internal_mechanism.md
index e0b715dfa63..2410d882bb5 100644
--- a/docs/source/concept_guides/internal_mechanism.md
+++ b/docs/source/concept_guides/internal_mechanism.md
@@ -69,4 +69,10 @@ setting the same seed in the main random number generator in all processes.
+
+
+If you have [`torchdata>=0.8.0`](https://github.com/pytorch/data/tree/main) installed, and you have passed `use_stateful_dataloader=True` into your [`~utils.DataLoaderConfiguration`], these classes will directly inherit from `StatefulDataLoader` instead, and maintain a `state_dict`.
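+
+As a small sketch (assuming `torchdata>=0.8.0` is installed), you can verify this on a prepared dataloader:
+
+```python
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+from torchdata.stateful_dataloader import StatefulDataLoader
+
+from accelerate import Accelerator
+from accelerate.utils import DataLoaderConfiguration
+
+dataloader = DataLoader(TensorDataset(torch.randn(16, 3)), batch_size=4)
+accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
+prepared = accelerator.prepare(dataloader)
+
+# The wrapper Accelerate returns also passes `isinstance` checks against `StatefulDataLoader`
+# and exposes `state_dict()`/`load_state_dict()`.
+assert isinstance(prepared, StatefulDataLoader)
+state = prepared.state_dict()
+```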
+
+
+
For more details about the internals, see the [Internals page](package_reference/torch_wrappers).
diff --git a/setup.py b/setup.py
index 27d609cfa11..85a90bf82d8 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@
"datasets",
"diffusers",
"evaluate",
+ "torchdata>=0.8.0",
"torchpippy>=0.2.0",
"transformers",
"scipy",
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 3f5f1279132..4ed80537144 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -583,6 +583,12 @@ def use_seedable_sampler(self):
def non_blocking(self):
return self.dataloader_config.non_blocking
+ @property
+ def use_stateful_dataloader(self):
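+        # Fall back to `False` if the dataloader configuration does not define this attribute.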
+ if hasattr(self.dataloader_config, "use_stateful_dataloader"):
+ return self.dataloader_config.use_stateful_dataloader
+ return False
+
@property
def project_dir(self):
return self.project_configuration.project_dir
@@ -2068,6 +2074,7 @@ def prepare_data_loader(
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
non_blocking=self.non_blocking,
+ use_stateful_dataloader=self.use_stateful_dataloader,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index f0e88c645ec..e5b6364b4ab 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -30,6 +30,7 @@
get_data_structure,
initialize_tensors,
is_torch_version,
+ is_torchdata_stateful_dataloader_available,
send_to_device,
slice_tensors,
synchronize_rng_states,
@@ -388,9 +389,75 @@ def end(self):
self.gradient_state._remove_dataloader(self)
-class DataLoaderShard(DataLoader, DataLoaderStateMixin):
+class DataLoaderAdapter:
"""
- Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup.
+ A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
+    compatibility reasons, this class dynamically subclasses the class it wraps, so it can be used as a drop-in
+    replacement.
+ """
+
+ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
+ self.use_stateful_dataloader = use_stateful_dataloader
+        if use_stateful_dataloader:
+            if not is_torchdata_stateful_dataloader_available():
+                raise ImportError(
+                    "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
+                )
+            from torchdata.stateful_dataloader import StatefulDataLoader
+
+            self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+        else:
+            self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
+
+ # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
+ # In C++ terms, this is analogous to creating `DataLoaderAdapter : T`, where T is a DataLoader or
+ # StatefulDataLoader
+ #
+ # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
+ # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
+ # dispatching scattered throughout various functions and files.
+ #
+ # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
+ # transparently.
+ #
+ # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
+ # but this would not be backwards compatible with existing code which assumes
+ # DataLoaderShard/DataLoaderDispatcher are DataLoaders.
+ base_cls = self.__class__
+ base_cls_name = self.__class__.__name__
+        parent_cls = self.base_dataloader.__class__
+        self.__class__ = type(base_cls_name, (base_cls, parent_cls), {})
+
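+        # Snapshot an initial state so `state_dict()` can be queried even before iteration starts.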
+ if hasattr(self.base_dataloader, "state_dict"):
+ self.dl_state_dict = self.base_dataloader.state_dict()
+
+ def __getattr__(self, name):
+ # Avoid infinite recursion if we try to access a nonexistent base_dataloader attribute.
+ if name == "base_dataloader":
+ raise AttributeError()
+ # Delegate attribute access to the internal dataloader
+ return getattr(self.base_dataloader, name)
+
+ def state_dict(self):
+ return self.dl_state_dict
+
+ def load_state_dict(self, state_dict):
+ self.base_dataloader.load_state_dict(state_dict)
+        self.dl_state_dict = state_dict
+
+ def _update_state_dict(self):
+ # The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
+ # E.g. the implementation of DataLoaderShard involves having an underlying iterator 1 element ahead of
+ # what it wants to yield.
+ #
+ # _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
+ if hasattr(self.base_dataloader, "state_dict"):
+ self.dl_state_dict = self.base_dataloader.state_dict()
+
+
+class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
+ """
+ Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
Args:
dataset (`torch.utils.data.dataset.Dataset`):
@@ -409,6 +476,8 @@ class DataLoaderShard(DataLoader, DataLoaderStateMixin):
A random number generator to keep synchronized across processes.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**kwargs (additional keyword arguments, *optional*):
All other keyword arguments to pass to the regular `DataLoader` initialization.
@@ -428,11 +497,12 @@ def __init__(
rng_types=None,
synchronized_generator=None,
skip_batches=0,
+ use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
**kwargs,
):
- super().__init__(dataset, **kwargs)
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.device = device
self.rng_types = rng_types
self.synchronized_generator = synchronized_generator
@@ -448,7 +518,7 @@ def __iter__(self):
self.begin()
self.set_epoch(self.iteration)
- dataloader_iter = super().__iter__()
+ dataloader_iter = self.base_dataloader.__iter__()
# We iterate one batch ahead to check when we are at the end
try:
current_batch = next(dataloader_iter)
@@ -461,6 +531,7 @@ def __iter__(self):
# But we still move it to the device so it is done before `StopIteration` is reached
if self.device is not None:
current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
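+                # Snapshot the state here, before prefetching `next_batch`, so it matches the batch about to be yielded.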
+ self._update_state_dict()
next_batch = next(dataloader_iter)
if batch_index >= self.skip_batches:
yield current_batch
@@ -564,10 +635,10 @@ def dataloader(self):
return self._loader
-class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
+class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
"""
- Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each
- process their part of the batch.
+ Subclass of `DataLoaderAdapter` that will iterate and preprocess on process 0 only, then dispatch on each process
+ their part of the batch.
Args:
split_batches (`bool`, *optional*, defaults to `False`):
@@ -579,6 +650,8 @@ class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin):
size of the `dataloader` is a round multiple of `batch_size`.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning of an iteration.
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
**Available attributes:**
@@ -594,6 +667,7 @@ def __init__(
dataset,
split_batches: bool = False,
skip_batches=0,
+ use_stateful_dataloader=False,
_drop_last: bool = False,
_non_blocking: bool = False,
slice_fn=None,
@@ -606,7 +680,7 @@ def __init__(
# We need to save the shuffling state of the DataPipe
if isinstance(dataset, ShufflerIterDataPipe):
shuffle = dataset._shuffle_enabled
- super().__init__(dataset, **kwargs)
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.split_batches = split_batches
if shuffle:
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
@@ -627,12 +701,14 @@ def _fetch_batches(self, iterator):
try:
if self.split_batches:
# One batch of the main iterator is dispatched and split.
+ self._update_state_dict()
batch = next(iterator)
else:
# num_processes batches of the main iterator are concatenated then dispatched and split.
# We add the batches one by one so we have the remainder available when drop_last=False.
batches = []
for _ in range(self.state.num_processes):
+ self._update_state_dict()
batches.append(next(iterator))
try:
batch = concatenate(batches, dim=0)
@@ -673,9 +749,9 @@ def __iter__(self):
# NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
# shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
# But, we only iterate through the DataLoader on process 0.
- main_iterator = super().__iter__()
+ main_iterator = self.base_dataloader.__iter__()
elif self.state.process_index == 0:
- main_iterator = super().__iter__()
+ main_iterator = self.base_dataloader.__iter__()
stop_iteration = False
self._stop_iteration = False
first_batch = None
@@ -812,6 +888,7 @@ def prepare_data_loader(
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
non_blocking: bool = False,
+ use_stateful_dataloader: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -873,6 +950,10 @@ def prepare_data_loader(
non_blocking (`bool`, *optional*, defaults to `False`):
If set to `True`, dataloader will utilize non-blocking host-to-device transfers. If the dataloader has
`pin_memory` set to `True`, this will help to increase overlap between data transfer and computations.
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+ "If set to true, the dataloader prepared by the Accelerator will be backed by "
+ "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
+ This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
Returns:
@@ -1006,6 +1087,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
slice_fn=slice_fn_for_dispatch,
+ use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
elif sampler_is_batch_sampler:
@@ -1018,6 +1100,7 @@ def prepare_data_loader(
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
synchronized_generator=synchronized_generator,
+ use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
else:
@@ -1029,6 +1112,7 @@ def prepare_data_loader(
synchronized_generator=synchronized_generator,
_drop_last=dataloader.drop_last,
_non_blocking=non_blocking,
+ use_stateful_dataloader=use_stateful_dataloader,
**kwargs,
)
@@ -1046,6 +1130,7 @@ class SkipBatchSampler(BatchSampler):
def __init__(self, batch_sampler, skip_batches=0):
self.batch_sampler = batch_sampler
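+        # Also expose the wrapped batch sampler's `sampler`, since callers may expect a `.sampler` attribute.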
+ self.sampler = batch_sampler.sampler
self.skip_batches = skip_batches
def __iter__(self):
@@ -1061,7 +1146,7 @@ def __len__(self):
return len(self.batch_sampler) - self.skip_batches
-class SkipDataLoader(DataLoader):
+class SkipDataLoader(DataLoaderAdapter):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.
@@ -1070,17 +1155,20 @@ class SkipDataLoader(DataLoader):
The dataset to use to build this dataloader.
skip_batches (`int`, *optional*, defaults to 0):
The number of batches to skip at the beginning.
+ use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
+ Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
kwargs:
All other keyword arguments to pass to the regular `DataLoader` initialization.
"""
- def __init__(self, dataset, skip_batches=0, **kwargs):
- super().__init__(dataset, **kwargs)
+ def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
+ super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.skip_batches = skip_batches
def __iter__(self):
- for index, batch in enumerate(super().__iter__()):
+ for index, batch in enumerate(self.base_dataloader.__iter__()):
if index >= self.skip_batches:
+ self._update_state_dict()
yield batch
@@ -1088,6 +1176,9 @@ def skip_first_batches(dataloader, num_batches=0):
"""
Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
"""
+ if is_torchdata_stateful_dataloader_available():
+ from torchdata.stateful_dataloader import StatefulDataLoader
+
state = PartialState()
if state.distributed_type == DistributedType.XLA:
device = dataloader.device
@@ -1131,6 +1222,7 @@ def skip_first_batches(dataloader, num_batches=0):
split_batches=dataloader.split_batches,
batch_sampler=new_batch_sampler,
_drop_last=dataloader._drop_last,
+ use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
elif isinstance(dataloader, DataLoaderShard):
@@ -1147,12 +1239,17 @@ def skip_first_batches(dataloader, num_batches=0):
device=dataloader.device,
rng_types=dataloader.rng_types,
synchronized_generator=dataloader.synchronized_generator,
+ use_stateful_dataloader=dataloader.use_stateful_dataloader,
**kwargs,
)
else:
if new_batch_sampler is None:
# Need to manually skip batches in the dataloader
- dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
+            dataloader = SkipDataLoader(
+                dataset,
+                skip_batches=num_batches,
+                use_stateful_dataloader=getattr(dataloader, "use_stateful_dataloader", False),
+                **kwargs,
+            )
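+        # If the original dataloader was already a `StatefulDataLoader`, rebuild the skipped one as a
+        # `StatefulDataLoader` so its `state_dict` support is preserved.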
+ elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader):
+ dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
else:
dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)
diff --git a/src/accelerate/test_utils/scripts/test_sync.py b/src/accelerate/test_utils/scripts/test_sync.py
index 1029f475f6c..20b5b2c752e 100644
--- a/src/accelerate/test_utils/scripts/test_sync.py
+++ b/src/accelerate/test_utils/scripts/test_sync.py
@@ -305,12 +305,12 @@ def test_gradient_accumulation_with_opt_and_scheduler(
def test_dataloader_break():
accelerator = Accelerator()
-
first_dset = RegressionDataset(length=80)
first_dataloader = DataLoader(first_dset, batch_size=16)
second_dset = RegressionDataset(length=96)
second_dataloader = DataLoader(second_dset, batch_size=16)
first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader)
+
assert accelerator.gradient_state.active_dataloader is None
for iteration, _ in enumerate(first_dataloader):
assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader)
diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index 43a55e2339f..f9ade458eac 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -52,6 +52,7 @@
is_timm_available,
is_torch_version,
is_torch_xla_available,
+ is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
@@ -429,6 +430,18 @@ def require_trackers(test_case):
)(test_case)
+def require_torchdata_stateful_dataloader(test_case):
+ """
+ Decorator marking a test that requires torchdata.stateful_dataloader.
+
+ These tests are skipped when torchdata with stateful_dataloader module isn't installed.
+
+ """
+ return unittest.skipUnless(
+ is_torchdata_stateful_dataloader_available(), "test requires torchdata.stateful_dataloader"
+ )(test_case)
+
+
class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 07ba2fdcf8f..ed6c77d8de6 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -107,6 +107,8 @@
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
+ is_torchdata_available,
+ is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_transformer_engine_available,
is_transformers_available,
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 0f35a294736..7fa7810878e 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -749,7 +749,7 @@ class DataLoaderConfiguration:
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
- " underlying dataset is an `IterableDataslet`, `False` otherwise."
+ " underlying dataset is an `IterableDataset`, `False` otherwise."
},
)
even_batches: bool = field(
@@ -777,6 +777,13 @@ class DataLoaderConfiguration:
" prepared dataloader has `pin_memory` set to `True` to work properly."
},
)
+ use_stateful_dataloader: bool = field(
+ default=False,
+ metadata={
+ "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by "
+ "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
+ },
+ )
@dataclass
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 592e62e6172..3badfefd684 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -431,3 +431,16 @@ def is_xpu_available(check_device=False):
def is_dvclive_available():
return _is_package_available("dvclive")
+
+
+def is_torchdata_available():
+ return _is_package_available("torchdata")
+
+
+# TODO: Remove this function once stateful_dataloader is a stable feature in torchdata.
+def is_torchdata_stateful_dataloader_available():
+ package_exists = _is_package_available("torchdata")
+ if package_exists:
+ torchdata_version = version.parse(importlib.metadata.version("torchdata"))
+ return compare_versions(torchdata_version, ">=", "0.8.0")
+ return False
diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index 4ef7a94b281..42df216123a 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import json
import os
import pickle
@@ -26,6 +27,7 @@
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from accelerate.accelerator import Accelerator
+from accelerate.data_loader import skip_first_batches
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import (
require_bnb,
@@ -35,9 +37,20 @@
slow,
torch_device,
)
-from accelerate.test_utils.testing import AccelerateTestCase, require_cuda, require_non_torch_xla
-from accelerate.utils import FP8RecipeKwargs, patch_environment
+from accelerate.test_utils.testing import (
+ AccelerateTestCase,
+ require_cuda,
+ require_non_torch_xla,
+ require_torchdata_stateful_dataloader,
+)
+from accelerate.utils import FP8RecipeKwargs, is_torchdata_stateful_dataloader_available, patch_environment
+from accelerate.utils.dataclasses import DataLoaderConfiguration
from accelerate.utils.modeling import get_state_dict_from_offload, load_checkpoint_in_model
+from accelerate.utils.random import set_seed
+
+
+if is_torchdata_stateful_dataloader_available():
+ from torchdata.stateful_dataloader import StatefulDataLoader
class ModelWithTiedWeights(torch.nn.Module):
@@ -58,7 +71,6 @@ def create_components(tied_weights=False):
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=2, epochs=1)
train_dl = DataLoader(TensorDataset(torch.tensor([1, 2, 3])))
valid_dl = DataLoader(TensorDataset(torch.tensor([4, 5, 6])))
-
return model, optimizer, scheduler, train_dl, valid_dl
@@ -73,6 +85,21 @@ def forward(self, x):
return self.linear2(self.batchnorm(self.linear1(x)))
+def create_dataloaders_for_test(batch_size=3, n_train_batches: int = 12, n_valid_batches: int = 2, num_workers=0):
+ "Generates a tuple of dummy DataLoaders to test with"
+
+ def get_dataset(n_batches):
+ x = torch.randn(batch_size * n_batches, 3)
+ y = torch.randn(batch_size * n_batches, 5)
+ return TensorDataset(x, y)
+
+ train_dataset = get_dataset(n_train_batches)
+ valid_dataset = get_dataset(n_valid_batches)
+ train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
+ valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers)
+ return (train_dataloader, valid_dataloader)
+
+
def get_signature(model):
return sum(param.abs().sum().item() for param in model.parameters())
@@ -89,7 +116,12 @@ def parameterized_custom_name_func(func, param_num, param):
# customize the test name generator function as we want both params to appear in the sub-test
# name, as by default it shows only the first param
param_based_name = "use_safetensors" if param.args[0] is True else "use_pytorch"
- param_based_name += "_tied_weights" if (len(param.args) == 2 and param.args[1] is True) else ""
+ if len(param.args) > 1:
+ param_based_name += "_tied_weights" if param.args[1] is True else ""
+ if len(param.args) > 2:
+ param_based_name += f"_num_workers_{param.args[2]}"
+ if len(param.args) > 3:
+ param_based_name += "_dispatch_batches" if param.args[3] is True else "_no_dispatch_batches"
return f"{func.__name__}_{param_based_name}"
@@ -615,3 +647,133 @@ def test_can_unwrap_model(self):
# check that pickle roundtrip works
model_loaded = pickle.loads(pickle.dumps(model))
model_loaded(inputs)
+
+ # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward.
+ @require_torchdata_stateful_dataloader
+ def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
+ """Test that setting `use_stateful_dataloader=True` in `DataLoaderConfiguration` prepares a `StatefulDataLoader` object instead of a `DataLoader` object."""
+ dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
+ accelerator = Accelerator(dataloader_config=dataloader_config)
+ model, optimizer, scheduler, train_dl, valid_dl = create_components()
+
+ (
+ prepared_model,
+ prepared_optimizer,
+ prepared_scheduler,
+ prepared_train_dl,
+ prepared_valid_dl,
+ ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
+
+ assert prepared_model in accelerator._models
+ assert prepared_optimizer in accelerator._optimizers
+ assert prepared_scheduler in accelerator._schedulers
+ assert prepared_train_dl in accelerator._dataloaders
+ assert prepared_valid_dl in accelerator._dataloaders
+ assert isinstance(prepared_train_dl, StatefulDataLoader)
+ assert isinstance(prepared_valid_dl, StatefulDataLoader)
+
+ @parameterized.expand(
+ itertools.product([True, False], [True, False], [0, 2], [True, False]),
+ name_func=parameterized_custom_name_func,
+ )
+ @require_torchdata_stateful_dataloader
+ def test_save_model_with_stateful_dataloader(self, use_safetensors, tied_weights, num_workers, dispatch_batches):
+ """
+ Test that saving and loading a model with a stateful dataloader returns the same model,
+ and that the dataloader's iterator is restored properly."""
+ set_seed(42)
+ dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True)
+ accelerator = Accelerator(dataloader_config=dataloader_config)
+
+ model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)
+ train_dl, valid_dl = create_dataloaders_for_test(num_workers=num_workers)
+ model = ModelForTest()
+
+ (
+ prepared_model,
+ prepared_optimizer,
+ prepared_scheduler,
+ prepared_train_dl,
+ prepared_valid_dl,
+ ) = accelerator.prepare(model, optimizer, scheduler, train_dl, valid_dl)
+
+ assert isinstance(prepared_train_dl, StatefulDataLoader)
+ assert isinstance(prepared_valid_dl, StatefulDataLoader)
+
+ # Perform 3 training iterations to ensure the dataloader's iterator is advanced
+ num_batches_to_skip = 3
+ model.train()
+ for step, batch in enumerate(prepared_train_dl):
+ x, y = batch
+            x = x.to(accelerator.device)
+            y = y.to(accelerator.device)
+ with accelerator.accumulate(prepared_model):
+ outputs = prepared_model(x)
+ loss = torch.nn.functional.mse_loss(outputs, y)
+ accelerator.backward(loss)
+ prepared_optimizer.step()
+ prepared_scheduler.step()
+ prepared_optimizer.zero_grad()
+ if step == num_batches_to_skip - 1:
+ state_dict = prepared_train_dl.state_dict()
+ # When breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state.
+ # TODO: Maybe this could be done automatically?
+ prepared_train_dl.end()
+ break
+
+ assert accelerator.gradient_state.active_dataloader is None
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ # Save model for later use
+ accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors)
+
+ # Starting from where we left off, train this model to the end of the DataLoader
+ prepared_train_dl = skip_first_batches(prepared_train_dl, num_batches_to_skip)
+ batches_seen_with_original_dl = 0
+ for batch in prepared_train_dl:
+ x, y = batch
+                x = x.to(accelerator.device)
+                y = y.to(accelerator.device)
+ with accelerator.accumulate(prepared_model):
+ outputs = prepared_model(x)
+ loss = torch.nn.functional.mse_loss(outputs, y)
+ accelerator.backward(loss)
+ prepared_optimizer.step()
+ prepared_scheduler.step()
+ prepared_optimizer.zero_grad()
+ batches_seen_with_original_dl += 1
+
+ original_linear1 = prepared_model.linear1.weight.clone()
+ original_batchnorm = prepared_model.batchnorm.weight.clone()
+ original_linear2 = prepared_model.linear2.weight.clone()
+
+ # Load the model and state dict
+ load_checkpoint_in_model(model, tmpdirname)
+ stateful_train_dl, _ = create_dataloaders_for_test(num_workers=num_workers)
+ prepared_stateful_train_dl = accelerator.prepare_data_loader(stateful_train_dl)
+ prepared_stateful_train_dl.load_state_dict(state_dict)
+
+ # Train this to the end of the DataLoader
+ batches_seen_with_loaded_dl = 0
+ for batch in prepared_stateful_train_dl:
+ x, y = batch
+                x = x.to(accelerator.device)
+                y = y.to(accelerator.device)
+ with accelerator.accumulate(prepared_model):
+ outputs = prepared_model(x)
+ loss = torch.nn.functional.mse_loss(outputs, y)
+ accelerator.backward(loss)
+ prepared_optimizer.step()
+ prepared_scheduler.step()
+ prepared_optimizer.zero_grad()
+ batches_seen_with_loaded_dl += 1
+
+ new_linear1 = prepared_model.linear1.weight
+ new_batchnorm = prepared_model.batchnorm.weight
+ new_linear2 = prepared_model.linear2.weight
+
+ # Assert equalities
+ assert batches_seen_with_original_dl == batches_seen_with_loaded_dl
+ assert torch.allclose(original_linear1, new_linear1)
+ assert torch.allclose(original_batchnorm, new_batchnorm)
+ assert torch.allclose(original_linear2, new_linear2)
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
index 2f360d71bcb..d60e599722d 100644
--- a/tests/test_data_loader.py
+++ b/tests/test_data_loader.py
@@ -15,6 +15,9 @@
import random
import unittest
+import pytest
+import torch
+from parameterized import parameterized
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
from accelerate import Accelerator
@@ -22,11 +25,28 @@
BatchSamplerShard,
DataLoaderDispatcher,
DataLoaderShard,
+ DataLoaderStateMixin,
IterableDatasetShard,
SkipBatchSampler,
SkipDataLoader,
skip_first_batches,
)
+from accelerate.test_utils.testing import require_torchdata_stateful_dataloader
+from accelerate.utils import is_torchdata_stateful_dataloader_available
+from accelerate.utils.dataclasses import DataLoaderConfiguration
+
+
+if is_torchdata_stateful_dataloader_available():
+ from torchdata.stateful_dataloader import (
+ StatefulDataLoader,
+ )
+
+
+def parameterized_custom_name_func(func, param_num, param):
+ # customize the test name generator function as we want both params to appear in the sub-test
+ # name, as by default it shows only the first param
+ param_based_name = f"num_workers_{param.args[0]}"
+ return f"{func.__name__}_{param_based_name}"
class RandomIterableDataset(IterableDataset):
@@ -369,6 +389,29 @@ def test_skip_batch_sampler(self):
new_batch_sampler = SkipBatchSampler(batch_sampler, 2)
assert list(new_batch_sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]]
+ def test_dataloader_inheritance(self):
+ """
+ `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter
+ are instances of DataLoader and DataLoaderStateMixin.
+ """
+ Accelerator()
+ skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
+ dl_shard = DataLoaderShard(range(16), batch_size=4)
+ dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
+ assert isinstance(skip_dl, DataLoader)
+ assert isinstance(dl_shard, DataLoader)
+ assert isinstance(dl_dispatcher, DataLoader)
+
+ assert isinstance(dl_shard, DataLoaderStateMixin)
+ assert isinstance(dl_dispatcher, DataLoaderStateMixin)
+
+ assert isinstance(skip_dl.base_dataloader, DataLoader)
+ assert isinstance(dl_shard.base_dataloader, DataLoader)
+ assert isinstance(dl_dispatcher.base_dataloader, DataLoader)
+
+ with pytest.raises(AttributeError):
+ _ = DataLoaderShard.base_dataloader
+
def test_skip_data_loader(self):
dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2)
assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
@@ -396,3 +439,230 @@ def test_end_of_dataloader_dispatcher(self):
# Test it also works on the second iteration
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
+
+
+class StatefulDataLoaderTester(unittest.TestCase):
+ @require_torchdata_stateful_dataloader
+ def test_skip_data_loader(self):
+ dataloader = SkipDataLoader(list(range(16)), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
+ assert isinstance(dataloader, StatefulDataLoader)
+ assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
+
+ @require_torchdata_stateful_dataloader
+ def test_skip_first_batches(self):
+ dataloader = StatefulDataLoader(list(range(16)), batch_size=4)
+ new_dataloader = skip_first_batches(dataloader, num_batches=2)
+ assert isinstance(new_dataloader, StatefulDataLoader)
+ assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
+
+ @require_torchdata_stateful_dataloader
+ def test_end_of_dataloader(self):
+ dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True)
+ assert dataloader.use_stateful_dataloader
+ assert isinstance(dataloader, StatefulDataLoader)
+ for idx, _ in enumerate(dataloader):
+ assert dataloader.end_of_dataloader == (idx == 3)
+
+ # Test it also works on the second iteration
+ for idx, _ in enumerate(dataloader):
+ assert dataloader.end_of_dataloader == (idx == 3)
+
+ @require_torchdata_stateful_dataloader
+ def test_end_of_dataloader_dispatcher(self):
+ Accelerator()
+ dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
+ assert isinstance(dataloader, StatefulDataLoader)
+ for idx, _ in enumerate(dataloader):
+ assert dataloader.end_of_dataloader == (idx == 3)
+
+ # Test it also works on the second iteration
+ for idx, _ in enumerate(dataloader):
+ assert dataloader.end_of_dataloader == (idx == 3)
+
+ @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
+ @require_torchdata_stateful_dataloader
+ def test_dataloader_state_dict(self, num_workers):
+ """
+ Test that saving a stateful dataloader's state, then loading it back, gives the same results.
+ """
+ dataset = list(range(16))
+ dataloader = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
+
+ assert dataloader.use_stateful_dataloader
+ assert isinstance(dataloader, StatefulDataLoader)
+ vals = []
+ for idx, val in enumerate(dataloader):
+ vals.append(val)
+ if idx == 1:
+ sd = dataloader.state_dict()
+ assert len(vals) == 4
+
+ dataloader2 = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
+ dataloader2.load_state_dict(sd)
+
+ data1 = vals[2:]
+ data2 = list(dataloader2)
+ assert len(data1) == len(data2)
+ for d1, d2 in zip(data1, data2):
+ assert torch.allclose(d1, d2)
+
+ @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
+ @require_torchdata_stateful_dataloader
+ def test_dataloader_dispatcher_state_dict(self, num_workers):
+ """
+ Test that saving a stateful dataloader's state, then loading it back, gives the same results.
+ """
+ dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
+ Accelerator(dataloader_config=dataloader_config)
+ dataset = list(range(16))
+ dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
+
+ assert dataloader.use_stateful_dataloader
+ assert isinstance(dataloader, StatefulDataLoader)
+ vals = []
+ for idx, val in enumerate(dataloader):
+ vals.append(val)
+ if idx == 1:
+ sd = dataloader.state_dict()
+ assert len(vals) == 4
+ dataloader2 = DataLoaderDispatcher(
+ dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers
+ )
+ dataloader2.load_state_dict(sd)
+
+ data1 = vals[2:]
+ data2 = list(dataloader2)
+ assert len(data1) == len(data2)
+ for d1, d2 in zip(data1, data2):
+ assert torch.allclose(d1, d2)
+
+ @require_torchdata_stateful_dataloader
+ def test_dataloader_inheritance(self):
+ """
+ `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True,
+ subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.
+ """
+ Accelerator()
+ skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
+ dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
+ dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
+ assert isinstance(skip_dl, StatefulDataLoader)
+ assert isinstance(dl_shard, StatefulDataLoader)
+ assert isinstance(dl_dispatcher, StatefulDataLoader)
+
+ assert isinstance(dl_shard, DataLoaderStateMixin)
+ assert isinstance(dl_dispatcher, DataLoaderStateMixin)
+
+ assert isinstance(skip_dl.base_dataloader, StatefulDataLoader)
+ assert isinstance(dl_shard.base_dataloader, StatefulDataLoader)
+ assert isinstance(dl_dispatcher.base_dataloader, StatefulDataLoader)
+
+ @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
+ @require_torchdata_stateful_dataloader
+ def test_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
+ """
+ Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
+ the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader`.
+ """
+ dataset = list(range(64))
+
+ # Set the seed for reproducibility
+ def g():
+ return torch.Generator().manual_seed(42)
+
+ accelerator = Accelerator()
+ stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
+ skip_dl = SkipDataLoader(
+ dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+ )
+ dl_shard = DataLoaderShard(
+ dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+ )
+ dl_dispatcher = DataLoaderDispatcher(
+ dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+ )
+
+ dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]
+
+ num_batches_to_skip = 8
+
+ def get_first_n_batches(dl, n, device):
+ """
+ Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
+ """
+ batches = []
+ for idx, batch in enumerate(dl):
+ if idx == n - 1:
+ if hasattr(dl, "end"):
+ dl.end()
+ break
+ batches.append(batch.to(device))
+ return batches
+
+ # Iterate over all of the dataloaders identically, expect the same values
+ expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, accelerator.device)
+ batches_from_dataloaders = [
+ get_first_n_batches(dl, num_batches_to_skip, accelerator.device) for dl in dataloaders_under_test
+ ]
+
+ for dl_batches in batches_from_dataloaders:
+ for expected, actual in zip(expected_batches, dl_batches):
+ assert torch.allclose(expected, actual)
+
+ # The adapters should all produce the same state_dict as the reference stateful dataloader
+ expected_state_dict = stateful_dl.state_dict()
+ skip_dl_state_dict = skip_dl.state_dict()
+ dl_shard_state_dict = dl_shard.state_dict()
+ dl_dispatcher_state_dict = dl_dispatcher.state_dict()
+
+ assert expected_state_dict == skip_dl_state_dict
+ assert expected_state_dict == dl_shard_state_dict
+ assert expected_state_dict == dl_dispatcher_state_dict
+
+ # Load the state dict into new dataloaders
+ manual_skip_dl = SkipDataLoader(
+ dataset,
+ batch_size=4,
+ num_workers=num_workers,
+ generator=g(),
+ skip_batches=num_batches_to_skip,
+ use_stateful_dataloader=True,
+ )
+ loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
+ loaded_stateful_dl.load_state_dict(expected_state_dict)
+ loaded_skip_dl = SkipDataLoader(
+ dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+ )
+ loaded_skip_dl.load_state_dict(expected_state_dict)
+ loaded_dl_shard = DataLoaderShard(
+ dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+ )
+ loaded_dl_shard.load_state_dict(expected_state_dict)
+ loaded_dl_dispatcher = DataLoaderDispatcher(
+ dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+ )
+ loaded_dl_dispatcher.load_state_dict(expected_state_dict)
+
+ # Continue the iteration, expecting identical behavior across the board
+ def get_all_batches(dl, device):
+ """
+ Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
+ """
+ batches = []
+ num_batches_yielded = 0
+ for batch in dl:
+ batches.append(batch.to(device))
+ num_batches_yielded += 1
+ return (batches, num_batches_yielded)
+
+ expected_batch_results = get_all_batches(loaded_stateful_dl, accelerator.device)
+ dataloader_batch_results = [
+ get_all_batches(dl, accelerator.device)
+ for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
+ ]
+ for dl_results in dataloader_batch_results:
+            for expected, actual in zip(expected_batch_results[0], dl_results[0]):
+                assert torch.allclose(expected, actual)
+ assert expected_batch_results[1] == dl_results[1]
+
+ assert accelerator.gradient_state.active_dataloader is None