Update torchpippy (#2938)
* rm warning

* Take 3

* Take 4

* Annotate

* Take 6

* Updated

* Spec

* Last fix

* Don't pad input

* Finished

* Continue refactor

* Rm comment

* Adjust the err

* Start adjustment

* GPT2 works, T5 does not

* llama too now I think

* Flag the t5 example
muellerzr authored Aug 26, 2024
1 parent c212092 commit 939ce40
Showing 6 changed files with 53 additions and 29 deletions.
12 changes: 11 additions & 1 deletion examples/inference/pippy/bert.py
@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 512),  # bs x seq_len
+    size=(1, 512),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
@@ -49,6 +49,16 @@
 # available on all GPUs
 # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
 
+# Create new inputs of the expected size (n_processes)
+input = torch.randint(
+    low=0,
+    high=model.config.vocab_size,
+    size=(2, 512),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+
 # Move the inputs to the first device
 input = input.to("cuda:0")
 
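The bert.py diff above shows the pattern this commit applies to every example: trace `prepare_pippy` with an input sized for a single process, then build a fresh batch sized to the total number of processes for the actual forward pass. A minimal sketch of that flow, assuming a 2-GPU `accelerate launch` run and the `bert-base-cased` checkpoint (neither is spelled out in this diff):

import torch
from transformers import AutoModelForMaskedLM

from accelerate import PartialState, prepare_pippy

# Assumed checkpoint; any model taking order-based (positional) inputs works the same way.
model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
model.eval()

# 1. Trace with a microbatch sized for a *single* process (bs = 1).
trace_input = torch.randint(0, model.config.vocab_size, (1, 512), dtype=torch.int64)
model = prepare_pippy(model, split_points="auto", example_args=(trace_input,))

# 2. Run with a batch sized for *all* processes (bs = num_processes).
state = PartialState()
batch = torch.randint(0, model.config.vocab_size, (state.num_processes, 512), dtype=torch.int64)
with torch.no_grad():
    output = model(batch.to("cuda:0"))
# Without gather_output=True, only the last rank holds the real output.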
12 changes: 11 additions & 1 deletion examples/inference/pippy/gpt2.py
@@ -32,7 +32,7 @@
 input = torch.randint(
     low=0,
     high=model.config.vocab_size,
-    size=(2, 1024),  # bs x seq_len
+    size=(1, 1024),  # bs x seq_len
     device="cpu",
     dtype=torch.int64,
     requires_grad=False,
@@ -48,6 +48,16 @@
 # available on all GPUs
 # model = prepare_pippy(model, split_points="auto", example_args=(input,), gather_output=True)
 
+# Create new inputs of the expected size (n_processes)
+input = torch.randint(
+    low=0,
+    high=model.config.vocab_size,
+    size=(2, 1024),  # bs x seq_len
+    device="cpu",
+    dtype=torch.int64,
+    requires_grad=False,
+)
+
 # Move the inputs to the first device
 input = input.to("cuda:0")
 
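gpt2.py gets the identical batch-size split. The commented-out call kept in both examples also shows the `gather_output=True` variant, which broadcasts the last stage's output to every GPU; a sketch of that variant, with the `gpt2` checkpoint and `AutoModelForCausalLM` assumed for illustration:

import torch
from transformers import AutoModelForCausalLM

from accelerate import PartialState, prepare_pippy

model = AutoModelForCausalLM.from_pretrained("gpt2")  # assumed checkpoint
model.eval()

# Trace with a single-process microbatch, but ask for the output on every rank.
trace_input = torch.randint(0, model.config.vocab_size, (1, 1024), dtype=torch.int64)
model = prepare_pippy(model, split_points="auto", example_args=(trace_input,), gather_output=True)

state = PartialState()
batch = torch.randint(0, model.config.vocab_size, (state.num_processes, 1024), dtype=torch.int64)
with torch.no_grad():
    output = model(batch.to("cuda:0"))
# With gather_output=True, `output` is populated on every rank, not just the last one.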
4 changes: 3 additions & 1 deletion examples/inference/pippy/llama.py
@@ -27,7 +27,7 @@
 # Input configs
 # Create example inputs for the model
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-prompts = ("I would like to", "I really like to", "The weather is pretty")  # bs = 3
+prompts = ("I would like to", "I really like to")  # bs = 2, sending 2 per process
 tokenizer.pad_token = tokenizer.eos_token
 inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
@@ -43,6 +43,8 @@
 
 # currently we don't support `model.generate`
 # output = model.generate(**inputs, max_new_tokens=1)
+prompts = ("I would like to", "I really like to", "The weather is pretty")  # bs = 3
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 inputs = inputs.to(0)
 with torch.no_grad():
     output = model(**inputs)
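llama.py makes the same split for dictionary-style (tokenizer) inputs: a smaller prompt tuple at trace time, the full three-prompt batch at inference time. A rough sketch of that kwargs path; passing the tokenized dict through `example_kwargs` is an assumption here, based on the docstring change further down, not something this diff shows:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from accelerate import prepare_pippy

checkpoint = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint)
model.eval()

# Trace with the per-process prompt count (bs = 2, as in the diff above).
trace_inputs = tokenizer(("I would like to", "I really like to"), return_tensors="pt", padding=True)
model = prepare_pippy(model, split_points="auto", example_kwargs=dict(trace_inputs))

# Run the full batch; the same keys must be present on every call.
prompts = ("I would like to", "I really like to", "The weather is pretty")  # bs = 3
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(0)
with torch.no_grad():
    output = model(**inputs)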
9 changes: 9 additions & 0 deletions examples/inference/pippy/t5.py
@@ -14,12 +14,21 @@
 import time
 
 import torch
+from packaging import version
 from transformers import AutoModelForSeq2SeqLM
 
 from accelerate import PartialState, prepare_pippy
+from accelerate import __version__ as accelerate_version
 from accelerate.utils import set_seed
 
 
+if version.parse(accelerate_version) > version.parse("0.33.0"):
+    raise RuntimeError(
+        "Using encoder/decoder models is not supported with the `torch.pipelining` integration or accelerate>=0.34.0. "
+        "Please use a lower accelerate version and `torchpippy`, which this example uses."
+    )
+
+
 # Set the random seed to have reproducable outputs
 set_seed(42)
 
39 changes: 18 additions & 21 deletions src/accelerate/inference.py
@@ -79,22 +79,21 @@ def build_pipeline(model, split_points, args, kwargs, num_chunks):
     `AcceleratorState.num_processes`
     """
     # Note: We import here to reduce import time from general modules, and isolate outside dependencies
-    from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
-    from pippy.PipelineStage import PipelineStage
+    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline
 
     # We need to annotate the split points in the model for PiPPy
     state = PartialState()
-    annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points})
-    found_batch_size = find_pippy_batch_size(args, kwargs)
-    if found_batch_size != num_chunks:
-        if args is not None:
-            args = pad_input_tensors(args, found_batch_size, num_chunks)
-        if kwargs is not None:
-            kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
-    pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs)
-    stage = PipelineStage(pipe, state.local_process_index, device=state.device)
+    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
+    pipe = pipeline(
+        model,
+        mb_args=args,
+        mb_kwargs=kwargs,
+        split_spec=split_spec,
+    )
+    stage = pipe.build_stage(state.local_process_index, device=state.device)
+    schedule = ScheduleGPipe(stage, num_chunks)
 
-    return stage
+    return schedule
 
 
 def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
@@ -143,22 +143,20 @@ def prepare_pippy(
         no_split_module_classes (`List[str]`):
             A list of class names for layers we don't want to be split.
         example_args (tuple of model inputs):
-            The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible.
+            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
+            this method if possible.
         example_kwargs (dict of model inputs)
-            The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure
-            that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition
-            is true for all cases.
+            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
+            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
+            recommended unless the prior condition is true for all cases.
         num_chunks (`int`, defaults to the number of available GPUs):
             The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
             this can be tuned and played with. In general one should have num_chunks >= num_gpus.
         gather_output (`bool`, defaults to `False`):
             If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
     """
     if not is_pippy_available():
-        raise ImportError(
-            "`pippy` was not found to be installed on your system. Please "
-            "install using `pip install torchpippy` or ensure you have at least version 0.2.0"
-        )
+        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
     state = PartialState()
     example_args = send_to_device(example_args, "cpu")
     example_kwargs = send_to_device(example_kwargs, "cpu")
@@ -177,7 +174,7 @@ def prepare_pippy(
     model.hf_split_points = split_points
 
     def forward(*args, **kwargs):
-        return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs)
+        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)
 
     # To act like a decorator so that it can be popped when doing `extract_model_from_parallel`
     # Note: creates an infinite recursion loop with `generate`
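Taken together, `build_pipeline` now hands back a `ScheduleGPipe` instead of a raw `PipelineStage`, and `prepare_pippy` drives it through `stage.step` rather than `stage.forward`. A simplified sketch of how such a schedule is driven per rank, roughly mirroring what `pippy_forward` does (input padding, the single-process path, and output gathering are left out, and the helper name is hypothetical):

import torch

from accelerate import PartialState


def run_schedule(schedule, *args, **kwargs):
    # Hypothetical helper; mirrors the per-rank branching of pippy_forward.
    state = PartialState()
    output = None
    with torch.no_grad():
        if state.is_local_main_process:
            # The first stage receives the full batch and chops it into
            # the schedule's microbatches internally.
            schedule.step(*args, **kwargs)
        elif state.is_last_process:
            # Only the last stage produces (and returns) the model output.
            output = schedule.step()
        else:
            # Intermediate stages just relay activations.
            schedule.step()
    return output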
6 changes: 1 addition & 5 deletions src/accelerate/utils/imports.py
@@ -178,11 +178,7 @@ def is_deepspeed_available():
 
 
 def is_pippy_available():
-    package_exists = _is_package_available("pippy", "torchpippy")
-    if package_exists:
-        pippy_version = version.parse(importlib.metadata.version("torchpippy"))
-        return compare_versions(pippy_version, ">", "0.1.1")
-    return False
+    return is_torch_version(">=", "2.4.0")
 
 
 def is_bf16_available(ignore_tpu=False):
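With the `torchpippy` package dropped, `is_pippy_available` reduces to a PyTorch version gate. A short sketch of how calling code can keep using it; that the helper remains re-exported from `accelerate.utils` is an assumption, not something this diff shows:

from accelerate.utils import is_pippy_available

if is_pippy_available():  # True whenever torch >= 2.4.0 is installed
    from accelerate import prepare_pippy
else:
    raise RuntimeError("Pipeline-parallel inference now requires PyTorch 2.4.0 or later.")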
