diff --git a/examples/inference/pippy/bert.py b/examples/inference/pippy/bert.py
index bed3562337b..474409f5d0f 100644
--- a/examples/inference/pippy/bert.py
+++ b/examples/inference/pippy/bert.py
@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 512),  # bs x seq_len
+    size=(1, 512),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
@@ -49,6 +49,16 @@
 # available on all GPUs
 # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
 
+# Create new inputs of the expected size (n_processes)
+input = torch.randint(
+    low=0,
+    high=model.config.vocab_size,
+    size=(2, 512),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+
 # Move the inputs to the first device
 input = input.to("cuda:0")
 
diff --git a/examples/inference/pippy/gpt2.py b/examples/inference/pippy/gpt2.py
index 994327f3c0d..d1f232b51de 100644
--- a/examples/inference/pippy/gpt2.py
+++ b/examples/inference/pippy/gpt2.py
@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 1024),  # bs x seq_len
+    size=(1, 1024),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
@@ -48,6 +48,16 @@
 # available on all GPUs
 # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
 
+# Create new inputs of the expected size (n_processes)
+input = torch.randint(
+    low=0,
+    high=model.config.vocab_size,
+    size=(2, 1024),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+
 # Move the inputs to the first device
 input = input.to("cuda:0")
 
diff --git a/examples/inference/pippy/llama.py b/examples/inference/pippy/llama.py
index a1b2e12bb8f..631da07bfcf 100644
--- a/examples/inference/pippy/llama.py
+++ b/examples/inference/pippy/llama.py
@@ -27,7 +27,7 @@
 # Input configs
 # Create example inputs for the model
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-prompts = ("I would like to", "I really like to", "The weather is pretty")  # bs = 3
+prompts = ("I would like to", "I really like to")  # bs = 2, sending 2 per process
 tokenizer.pad_token = tokenizer.eos_token
 inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
@@ -43,6 +43,8 @@
 # currently we don't support `model.generate`
 # output = model.generate(**inputs, max_new_tokens=1)
 
+prompts = ("I would like to", "I really like to", "The weather is pretty")  # bs = 3
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 inputs = inputs.to(0)
 with torch.no_grad():
     output = model(**inputs)
diff --git a/examples/inference/pippy/t5.py b/examples/inference/pippy/t5.py
index 2f9218aef14..b134eb5372c 100644
--- a/examples/inference/pippy/t5.py
+++ b/examples/inference/pippy/t5.py
@@ -14,12 +14,21 @@
 import time
 
 import torch
+from packaging import version
 from transformers import AutoModelForSeq2SeqLM
 
 from accelerate import PartialState, prepare_pippy
+from accelerate import __version__ as accelerate_version
 from accelerate.utils import set_seed
 
 
+if version.parse(accelerate_version) > version.parse("0.33.0"):
+    raise RuntimeError(
+        "Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. "
+        "Please use a lower accelerate version and `torchpippy`, which this example uses."
+    )
+
+
 # Set the random seed to have reproducable outputs
 set_seed(42)
 
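The example changes above all follow the same pattern: `prepare_pippy` is now traced with inputs sized for a single process, while the real forward pass is fed one chunk per process. A minimal sketch of that pattern, assuming a masked-LM BERT checkpoint and one GPU per process (the model loading is not shown in this diff, so the `AutoModelForMaskedLM` / `bert-base-uncased` names are illustrative):

import torch
from transformers import AutoModelForMaskedLM

from accelerate import PartialState, prepare_pippy

model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")  # illustrative checkpoint
model.eval()
vocab_size = model.config.vocab_size

# Trace-time input: batch size 1, i.e. what a single process sees per chunk
example_input = torch.randint(0, vocab_size, (1, 512), dtype=torch.int64)
model = prepare_pippy(model, split_points="auto", example_args=(example_input,))

# Run-time input: one chunk per process (bs == n_processes), moved to the first device
full_input = torch.randint(0, vocab_size, (PartialState().num_processes, 512), dtype=torch.int64)
full_input = full_input.to("cuda:0")

with torch.no_grad():
    output = model(full_input)

Launched with e.g. `accelerate launch --num_processes 2 bert.py`, `PartialState().num_processes` matches the hard-coded `size=(2, 512)` in the diff for a two-GPU run.
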
diff --git a/src/accelerate/inference.py b/src/accelerate/inference.py
index 4f9b07081b2..7ee0bdd6f63 100644
--- a/src/accelerate/inference.py
+++ b/src/accelerate/inference.py
@@ -79,22 +79,21 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
     `AcceleratorState.num_processes`
     """
     # Note: We import here to reduce import time from general modules, and isolate outside dependencies
-    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
-    from pippy.PipelineStage import PipelineStage
+    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
 
     # We need to annotate the split points in the model for PiPPy
     state = PartialState()
-    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
-    found_batch_size = find_pippy_batch_size(args, kwargs)
-    if found_batch_size != num_chunks:
-        if args is not None:
-            args = pad_input_tensors(args, found_batch_size, num_chunks)
-        if kwargs is not None:
-            kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
-    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
-    stage = PipelineStage(pipe, state.local_process_index, device=state.device)
+    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
+    pipe = pipeline(
+        model,
+        mb_args=args,
+        mb_kwargs=kwargs,
+        split_spec=split_spec,
+    )
+    stage = pipe.build_stage(state.local_process_index, device=state.device)
+    schedule = ScheduleGPipe(stage, num_chunks)
 
-    return stage
+    return schedule
 
 
 def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
@@ -143,11 +142,12 @@ def prepare_pippy(
         no_split_module_classes (`List[str]`):
             A list of class names for layers we don't want to be split.
         example_args (tuple of model inputs):
-            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
+            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
+            this method if possible.
         example_kwargs (dict of model inputs)
-            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
-            that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
-            is true for all cases.
+            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
+            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
+            recommended unless the prior condition is true for all cases.
         num_chunks (`int`, defaults to the number of available GPUs):
             The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
             this can be tuned and played with. In general one should have num_chunks >= num_gpus.
@@ -155,10 +155,7 @@
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
     """
     if not is_pippy_available():
-        raise ImportError(
-            "`pippy` was not found to be installed on your system. Please "
-            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
-        )
+        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
     state = PartialState()
     example_args = send_to_device(example_args, "cpu")
     example_kwargs = send_to_device(example_kwargs, "cpu")
@@ -177,7 +174,7 @@
     model.hf_split_points = split_points
 
     def forward(*args, **kwargs):
-        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)
+        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
 
     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
     # Note: creates an infinite recursion loop with `generate`
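The rewritten `build_pipeline` above swaps the `torchpippy` calls for `torch.distributed.pipelining`, and the wrapped forward now drives a schedule (`stage.step`) instead of a bare stage forward. Below is a self-contained sketch of that flow outside accelerate, assuming PyTorch >= 2.4 and a `torchrun` launch with exactly two processes (one stage per rank); the toy `nn.Sequential`, shapes, and split point are illustrative and not part of the diff:

import torch
import torch.distributed as dist
from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()  # assumed: world_size == 2
device = torch.device(f"cuda:{rank}")

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
)

# Trace the model into stages, starting a new stage at submodule "2".
# mb_args is a *microbatch*, mirroring the per-process example_args above.
pipe = pipeline(
    model,
    mb_args=(torch.randn(1, 512),),
    split_spec={"2": SplitPoint.BEGINNING},
)

# Each rank materializes its own stage and wraps it in a GPipe schedule
stage = pipe.build_stage(rank, device=device)
schedule = ScheduleGPipe(stage, n_microbatches=world_size)

# step() replaces a plain forward(): the first rank feeds the full batch,
# and the last rank returns the assembled output
if rank == 0:
    schedule.step(torch.randn(world_size, 512, device=device))
else:
    output = schedule.step()

dist.destroy_process_group()

This is why the wrapper in `prepare_pippy` now passes `stage.step` to `pippy_forward`: stepping the schedule both pumps the microbatches through the stages and returns the output on the final stage.
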
Please " - "install using `pip install torchpippy` or ensure you have at least version 0.2.0" - ) + raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.") state = PartialState() example_args = send_to_device(example_args, "cpu") example_kwargs = send_to_device(example_kwargs, "cpu") @@ -177,7 +174,7 @@ def prepare_pippy( model.hf_split_points = split_points def forward(*args, **kwargs): - return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs) + return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs) # To act like a decorator so that it can be popped when doing `extract_model_from_parallel` # Note: creates an infinite recursion loop with `generate` diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 3badfefd684..15f802e5926 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -178,11 +178,7 @@ def is_deepspeed_available(): def is_pippy_available(): - package_exists = _is_package_available("pippy", "torchpippy") - if package_exists: - pippy_version = version.parse(importlib.metadata.version("torchpippy")) - return compare_versions(pippy_version, ">", "0.1.1") - return False + return is_torch_version(">=", "2.4.0") def is_bf16_available(ignore_tpu=False):