diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 3cfc1423440..d9b2e0e8980 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -1164,11 +1164,11 @@ def prepare_data_loader(
 class SkipBatchSampler(BatchSampler):
     """
     A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`.
+    Should not be used if the original dataloader is a `StatefulDataLoader`.
     """
 
     def __init__(self, batch_sampler, skip_batches=0):
         self.batch_sampler = batch_sampler
-        self.sampler = batch_sampler.sampler
         self.skip_batches = skip_batches
 
     def __iter__(self):
@@ -1186,15 +1186,14 @@ def __len__(self):
 
 class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
     """
-    Subclass of a PyTorch `DataLoader` that will skip the first batches.
+    Subclass of a PyTorch `DataLoader` that will skip the first batches. Generally it's preferable to use
+    `skip_first_batches`/`torchdata.StatefulDataLoader` instead of this class.
 
     Args:
         dataset (`torch.utils.data.dataset.Dataset`):
             The dataset to use to build this datalaoder.
         skip_batches (`int`, *optional*, defaults to 0):
             The number of batches to skip at the beginning.
-        use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
-            Whether to have this class adapt `StatefulDataLoader` from `torchdata` instead of the regular `DataLoader`.
         kwargs:
             All other keyword arguments to pass to the regular `DataLoader` initialization.
     """
@@ -1215,11 +1214,9 @@ def __iter__(self):
 
 def skip_first_batches(dataloader, num_batches=0):
     """
-    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
+    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. Should not be used if
+    the original dataloader is a `StatefulDataLoader`.
""" - if is_torchdata_stateful_dataloader_available(): - from torchdata.stateful_dataloader import StatefulDataLoader - state = PartialState() if state.distributed_type == DistributedType.XLA: device = dataloader.device @@ -1263,7 +1260,6 @@ def skip_first_batches(dataloader, num_batches=0): split_batches=dataloader.split_batches, batch_sampler=new_batch_sampler, _drop_last=dataloader._drop_last, - use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) elif isinstance(dataloader, DataLoaderShard): @@ -1280,17 +1276,12 @@ def skip_first_batches(dataloader, num_batches=0): device=dataloader.device, rng_types=dataloader.rng_types, synchronized_generator=dataloader.synchronized_generator, - use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs, ) else: if new_batch_sampler is None: # Need to manually skip batches in the dataloader - dataloader = SkipDataLoader( - dataset, skip_batches=num_batches, use_stateful_dataloader=dataloader.use_stateful_dataloader, **kwargs - ) - elif is_torchdata_stateful_dataloader_available() and isinstance(dataloader, StatefulDataLoader): - dataloader = StatefulDataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) + dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) else: dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) diff --git a/src/accelerate/test_utils/scripts/external_deps/test_pippy.py b/src/accelerate/test_utils/scripts/external_deps/test_pippy.py index 86dd4139551..389b963e0cb 100644 --- a/src/accelerate/test_utils/scripts/external_deps/test_pippy.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_pippy.py @@ -12,23 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import torch
-from torchvision.models import resnet34
 from transformers import (
     BertConfig,
     BertForMaskedLM,
     GPT2Config,
     GPT2ForSequenceClassification,
-    T5Config,
-    T5ForConditionalGeneration,
 )
 
 from accelerate import PartialState
 from accelerate.inference import prepare_pippy
-from accelerate.utils import DistributedType, send_to_device, set_seed
+from accelerate.utils import DistributedType, set_seed
 
 
 model_to_config = {
-    "t5": (T5ForConditionalGeneration, T5Config, 1024),
     "bert": (BertForMaskedLM, BertConfig, 512),
     "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024),
 }
@@ -42,23 +38,19 @@ def get_model_and_data_for_text(model_name, device, num_processes: int = 2):
     # config_args["pad_token_id"] = 0
     model_config = config(**config_args)
     model = initializer(model_config)
-    return model, torch.randint(
-        low=0,
-        high=model_config.vocab_size,
-        size=(num_processes, seq_len),
-        device=device,
-        dtype=torch.int64,
-        requires_grad=False,
-    )
+    kwargs = dict(low=0, high=model_config.vocab_size, device=device, dtype=torch.int64, requires_grad=False)
+    trace_input = torch.randint(size=(1, seq_len), **kwargs)
+    inference_inputs = torch.randint(size=(num_processes, seq_len), **kwargs)
+    return model, trace_input, inference_inputs
 
 
-def test_gpt2(batch_size: int = 2):
+def test_bert(batch_size: int = 2):
     set_seed(42)
     state = PartialState()
-    model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
-    model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules)
+    model, trace_input, inference_inputs = get_model_and_data_for_text("bert", "cpu", batch_size)
+    model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
     # For inference args need to be a tuple
-    inputs = inputs.to("cuda")
+    inputs = inference_inputs.to("cuda")
     with torch.no_grad():
         output = model(inputs)
     # Zach: Check that we just grab the real outputs we need at the end
@@ -68,20 +60,15 @@ def test_gpt2(batch_size: int = 2):
         assert output is not None, "Output was not generated in the last process!"
 
 
-def test_t5(batch_size: int = 2):
+def test_gpt2(batch_size: int = 2):
     set_seed(42)
     state = PartialState()
-    model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size)
-    example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs}
-    model = prepare_pippy(
-        model,
-        no_split_module_classes=model._no_split_modules,
-        example_kwargs=example_inputs,
-    )
+    model, trace_input, inference_inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size)
+    model = prepare_pippy(model, example_args=(trace_input,), no_split_module_classes=model._no_split_modules)
     # For inference args need to be a tuple
-    inputs = send_to_device(example_inputs, "cuda:0")
+    inputs = inference_inputs.to("cuda")
     with torch.no_grad():
-        output = model(*inputs.values())
+        output = model(inputs)
     # Zach: Check that we just grab the real outputs we need at the end
     if not state.is_last_process:
         assert output is None, "Output was not generated on just the last process!"
@@ -89,42 +76,41 @@ def test_gpt2(batch_size: int = 2):
         assert output is not None, "Output was not generated in the last process!"
 
 
-def test_resnet(batch_size: int = 2):
-    set_seed(42)
-    state = PartialState()
-    model = resnet34()
-    input_tensor = torch.rand(batch_size, 3, 224, 224)
-    model = prepare_pippy(
-        model,
-        example_args=(input_tensor,),
-    )
-    inputs = send_to_device(input_tensor, "cuda:0")
-    with torch.no_grad():
-        output = model(inputs)
-    # Zach: Check that we just grab the real outputs we need at the end
-    if not state.is_last_process:
-        assert output is None, "Output was not generated on just the last process!"
-    else:
-        assert output is not None, "Output was not generated in the last process!"
+# Currently disabled, enable again once PyTorch pippy interface can trace a resnet34
+# def test_resnet(batch_size: int = 2):
+#     set_seed(42)
+#     state = PartialState()
+#     model = resnet34()
+#     input_tensor = torch.rand(1, 3, 224, 224)
+#     model = prepare_pippy(
+#         model,
+#         example_args=(input_tensor,),
+#     )
+#     inference_inputs = torch.rand(batch_size, 3, 224, 224)
+#     inputs = send_to_device(inference_inputs, "cuda:0")
+#     with torch.no_grad():
+#         output = model(inputs)
+#     # Zach: Check that we just grab the real outputs we need at the end
+#     if not state.is_last_process:
+#         assert output is None, "Output was not generated on just the last process!"
+#     else:
+#         assert output is not None, "Output was not generated in the last process!"
 
 
 if __name__ == "__main__":
     state = PartialState()
     state.print("Testing pippy integration...")
-    if state.distributed_type == DistributedType.MULTI_GPU:
-        state.print("Testing GPT2...")
-        test_gpt2()
-        # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
-        # due to references
-        # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
-        # test_gpt2(3)
-        state.print("Testing T5...")
-        test_t5()
-        test_t5(1)
-        test_t5(3)
-        state.print("Testing CV model...")
-        test_resnet()
-        test_resnet(3)
+    try:
+        if state.distributed_type == DistributedType.MULTI_GPU:
+            state.print("Testing GPT2...")
+            test_gpt2()
+            # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue
+            # due to references
+            # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope
+            # test_gpt2(3)
+            state.print("Testing BERT...")
+            test_bert()
+        else:
+            print("Less than two GPUs found, not running tests!")
+    finally:
         state.destroy_process_group()
-    else:
-        print("Less than two GPUs found, not running tests!")
diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py
index a0ec03418a7..1fe34bdc6ed 100644
--- a/tests/test_data_loader.py
+++ b/tests/test_data_loader.py
@@ -469,13 +469,6 @@ def test_skip_data_loader(self):
         assert isinstance(dataloader, StatefulDataLoader)
         assert [t.tolist() for t in dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
 
-    @require_torchdata_stateful_dataloader
-    def test_skip_first_batches(self):
-        dataloader = StatefulDataLoader(list(range(16)), batch_size=4)
-        new_dataloader = skip_first_batches(dataloader, num_batches=2)
-        assert isinstance(new_dataloader, StatefulDataLoader)
-        assert [t.tolist() for t in new_dataloader] == [[8, 9, 10, 11], [12, 13, 14, 15]]
-
     @require_torchdata_stateful_dataloader
     def test_end_of_dataloader(self):
         dataloader = DataLoaderShard(list(range(16)), batch_size=4, use_stateful_dataloader=True)
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 553a65008fc..337edd8f8a1 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -19,7 +19,7 @@
 import tempfile
 import unittest
 from pathlib import Path
-from unittest import mock
+from unittest import mock, skip
 
 import torch
 
@@ -261,6 +261,9 @@ def test_ddp_comm_hook(self):
         testargs = ["examples/by_feature/ddp_comm_hook.py", "--ddp_comm_hook", "fp16"]
         run_command(self.launch_args + testargs)
 
+    @skip(
+        reason="stable-diffusion-v1-5 is no longer available. Potentially `Comfy-Org/stable-diffusion-v1-5-archive` once diffusers support is added."
+    )
     @require_multi_device
     def test_distributed_inference_examples_stable_diffusion(self):
         testargs = ["examples/inference/distributed/stable_diffusion.py"]
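
For context on the data_loader.py change above, here is a minimal illustrative sketch (not part of the patch) of the two resume paths the updated docstrings point to: `skip_first_batches` for a regular `DataLoader`, and `state_dict()`/`load_state_dict()` when a `torchdata.stateful_dataloader.StatefulDataLoader` is in use. The toy dataset and variable names below are placeholders, not code from the repository.

    from accelerate import skip_first_batches
    from torch.utils.data import DataLoader
    from torchdata.stateful_dataloader import StatefulDataLoader

    # Plain DataLoader: resume mid-epoch by skipping the batches already consumed.
    train_dataloader = DataLoader(list(range(64)), batch_size=4)
    resumed_dataloader = skip_first_batches(train_dataloader, num_batches=2)

    # StatefulDataLoader: restore its own state instead of wrapping it in
    # skip_first_batches, which the updated docstrings now advise against.
    loader = StatefulDataLoader(list(range(64)), batch_size=4)
    for step, batch in enumerate(loader):
        if step == 1:
            saved_state = loader.state_dict()  # would normally be written to a checkpoint
            break

    restored_loader = StatefulDataLoader(list(range(64)), batch_size=4)
    restored_loader.load_state_dict(saved_state)  # iteration continues after the batches already seen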