From 27899339384e306a57e82f51f175076ee55bf897 Mon Sep 17 00:00:00 2001
From: Sidd Karamcheti
Date: Mon, 26 Aug 2024 07:19:59 -0700
Subject: [PATCH] Decouple `prepare_data_loader()` from Accelerator (#3047)

---
 src/accelerate/data_loader.py |  13 ++--
 tests/test_data_loader.py     | 130 +++++++++++++++++++++++++++++++---
 2 files changed, 126 insertions(+), 17 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 0f1f97c0b49..df84be901b0 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -20,7 +20,7 @@
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
 
 from .logging import get_logger
-from .state import AcceleratorState, DistributedType, GradientState, PartialState, is_torch_xla_available
+from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
 from .utils import (
     RNGType,
     broadcast,
@@ -720,7 +720,7 @@ def __init__(
             torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
 
         self.gradient_state = GradientState()
-        self.state = AcceleratorState()
+        self.state = PartialState()
         self._drop_last = _drop_last
         self._non_blocking = _non_blocking
         self.skip_batches = skip_batches
@@ -937,10 +937,9 @@ def prepare_data_loader(
         device (`torch.device`):
             The target device for the returned `DataLoader`.
         num_processes (`int`, *optional*):
-            The number of processes running concurrently. Will default to the value given by
-            [`~state.AcceleratorState`].
+            The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
         process_index (`int`, *optional*):
-            The index of the current process. Will default to the value given by [`~state.AcceleratorState`].
+            The index of the current process. Will default to the value given by [`~state.PartialState`].
         split_batches (`bool`, *optional*, defaults to `False`):
             Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
             yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
@@ -1009,8 +1008,8 @@ def prepare_data_loader(
     if dispatch_batches and not put_on_device:
         raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
 
-    # Grab defaults from AcceleratorState
-    state = AcceleratorState()
+    # Grab defaults from PartialState
+    state = PartialState()
     if num_processes is None:
         num_processes = state.num_processes
     if process_index is None:
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
index d91e60db4ed..a0ec03418a7 100644
--- a/tests/test_data_loader.py
+++ b/tests/test_data_loader.py
@@ -20,7 +20,7 @@
 from parameterized import parameterized
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset
 
-from accelerate import Accelerator
+from accelerate import Accelerator, PartialState
 from accelerate.data_loader import (
     BatchSamplerShard,
     DataLoaderDispatcher,
@@ -29,11 +29,12 @@
     IterableDatasetShard,
     SkipBatchSampler,
     SkipDataLoader,
+    prepare_data_loader,
     skip_first_batches,
 )
+from accelerate.state import GradientState
 from accelerate.test_utils.testing import require_torchdata_stateful_dataloader
 from accelerate.utils import is_torchdata_stateful_dataloader_available
-from accelerate.utils.dataclasses import DataLoaderConfiguration
 
 
 if is_torchdata_stateful_dataloader_available():
@@ -401,9 +402,8 @@ def test_iterable_dataset_shard(self):
 
     def test_iterable_dataset_using_none_batch_size(self):
         dataset = SimpleIterableDataset(100)
-        accelerator = Accelerator()
         dataloader = DataLoader(dataset, batch_size=None)
-        dataloader = accelerator.prepare(dataloader)
+        dataloader = prepare_data_loader(dataloader)
         for d in dataloader:
             assert isinstance(d, torch.Tensor)
 
@@ -417,7 +417,6 @@ def test_dataloader_inheritance(self):
         `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter
         are instances of DataLoader and DataLoaderStateMixin.
         """
-        Accelerator()
         skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
         dl_shard = DataLoaderShard(range(16), batch_size=4)
         dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
@@ -454,7 +453,6 @@ def test_end_of_dataloader(self):
             assert dataloader.end_of_dataloader == (idx == 3)
 
     def test_end_of_dataloader_dispatcher(self):
-        Accelerator()
         dataloader = DataLoaderDispatcher(range(16), batch_size=4)
         for idx, _ in enumerate(dataloader):
             assert dataloader.end_of_dataloader == (idx == 3)
@@ -492,7 +490,6 @@ def test_end_of_dataloader(self):
 
     @require_torchdata_stateful_dataloader
     def test_end_of_dataloader_dispatcher(self):
-        Accelerator()
         dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
         assert isinstance(dataloader, StatefulDataLoader)
         for idx, _ in enumerate(dataloader):
@@ -535,8 +532,6 @@ def test_dataloader_dispatcher_state_dict(self, num_workers):
         """
         Test that saving a stateful dataloader's state, then loading it back, gives the same results.
""" - dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True) - Accelerator(dataloader_config=dataloader_config) dataset = list(range(16)) dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers) @@ -565,7 +560,6 @@ def test_dataloader_inheritance(self): `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True, subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin. """ - Accelerator() skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) @@ -689,3 +683,119 @@ def get_all_batches(dl, device): assert expected_batch_results[1] == dl_results[1] assert accelerator.gradient_state.active_dataloader is None + + @parameterized.expand([0, 2], name_func=parameterized_custom_name_func) + @require_torchdata_stateful_dataloader + def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers): + """ + Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce + the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader` when *not* using + Accelerator (and instead using the decoupled `PartialState` workflow). + """ + dataset = list(range(64)) + + # Set the seed for reproducibility + def g(): + return torch.Generator().manual_seed(42) + + state = PartialState() + stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g()) + skip_dl = SkipDataLoader( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_shard = DataLoaderShard( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + dl_dispatcher = DataLoaderDispatcher( + dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True + ) + + dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher] + + num_batches_to_skip = 8 + + def get_first_n_batches(dl, n, device): + """ + Iterate over the first `n` batches of a dataloader then break, returning the batches in a list. 
+            """
+            batches = []
+            for idx, batch in enumerate(dl):
+                if idx == n - 1:
+                    if hasattr(dl, "end"):
+                        dl.end()
+                    break
+                batches.append(batch.to(device))
+            return batches
+
+        # Iterate over all of the dataloaders identically, expecting the same values
+        expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, state.device)
+        batches_from_dataloaders = [
+            get_first_n_batches(dl, num_batches_to_skip, state.device) for dl in dataloaders_under_test
+        ]
+
+        for dl_batches in batches_from_dataloaders:
+            for expected, actual in zip(expected_batches, dl_batches):
+                assert torch.allclose(expected, actual)
+
+        # The adapters should all produce the same state_dict as the reference stateful dataloader
+        expected_state_dict = stateful_dl.state_dict()
+        skip_dl_state_dict = skip_dl.state_dict()
+        dl_shard_state_dict = dl_shard.state_dict()
+        dl_dispatcher_state_dict = dl_dispatcher.state_dict()
+
+        assert expected_state_dict == skip_dl_state_dict
+        assert expected_state_dict == dl_shard_state_dict
+        assert expected_state_dict == dl_dispatcher_state_dict
+
+        # Load the state dict into new dataloaders
+        manual_skip_dl = SkipDataLoader(
+            dataset,
+            batch_size=4,
+            num_workers=num_workers,
+            generator=g(),
+            skip_batches=num_batches_to_skip,
+            use_stateful_dataloader=True,
+        )
+        loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
+        loaded_stateful_dl.load_state_dict(expected_state_dict)
+        loaded_skip_dl = SkipDataLoader(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        loaded_skip_dl.load_state_dict(expected_state_dict)
+        loaded_dl_shard = DataLoaderShard(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        loaded_dl_shard.load_state_dict(expected_state_dict)
+        loaded_dl_dispatcher = DataLoaderDispatcher(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        loaded_dl_dispatcher.load_state_dict(expected_state_dict)
+
+        # Continue the iteration, expecting identical behavior across the board
+        def get_all_batches(dl, device):
+            """
+            Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
+            """
+            batches = []
+            num_batches_yielded = 0
+            for batch in dl:
+                batches.append(batch.to(device))
+                num_batches_yielded += 1
+            return (batches, num_batches_yielded)
+
+        expected_batch_results = get_all_batches(loaded_stateful_dl, state.device)
+        dataloader_batch_results = [
+            get_all_batches(dl, state.device)
+            for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
+        ]
+        for dl_results in dataloader_batch_results:
+            for expected, actual in zip(expected_batch_results[0], dl_results[0]):
+                assert torch.allclose(expected, actual)
+            assert expected_batch_results[1] == dl_results[1]
+
+        # Using the decoupled (`PartialState`) workflow, GradientState should be automatically initialized (with
+        # default parameters) by `DataLoaderDispatcher`
+        assert GradientState._shared_state != {}, "GradientState should already be initialized!"
+
+        gradient_state = GradientState()
+        assert gradient_state.active_dataloader is None
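
A minimal usage sketch of the decoupled workflow this patch enables (illustrative, not part of the diff above): `prepare_data_loader()` is called directly, without constructing an `Accelerator`, and `PartialState` supplies the default `num_processes`, `process_index`, and device. The toy `TensorDataset`, the batch size, and the variable names are assumptions for illustration; `prepare_data_loader`, `PartialState`, and the `device` argument are from the library as patched.

# Illustrative sketch, not part of the patch: prepare a DataLoader without an Accelerator.
# The TensorDataset and batch size below are assumptions; the API calls come from the patched library.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import PartialState
from accelerate.data_loader import prepare_data_loader

state = PartialState()  # holds num_processes / process_index / device; no Accelerator needed
dataset = TensorDataset(torch.arange(64, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=4)

# num_processes and process_index default to the values held by PartialState
prepared = prepare_data_loader(dataloader, device=state.device)

for (batch,) in prepared:
    pass  # in a multi-process launch, each rank iterates over its own subset of batches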