Add tensor parallel support to T5 via NxD #697

Merged: 29 commits, merged on Oct 24, 2024. The diff below shows changes from 10 commits.
6 changes: 6 additions & 0 deletions optimum/commands/export/neuronx.py
@@ -112,6 +112,12 @@ def parse_args_neuronx(parser: "ArgumentParser"):
choices=["bf16", "fp16", "tf32"],
help='The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.',
)
optional_group.add_argument(
"--tensor_parallel_size",
type=int,
default=1,
help="Tensor parallelism degree, the number of devices on which to shard the model.",
)
optional_group.add_argument(
"--dynamic-batch-size",
action="store_true",
20 changes: 19 additions & 1 deletion optimum/exporters/neuron/__main__.py
@@ -261,6 +261,7 @@ def infer_stable_diffusion_shapes_from_diffusers(
def get_submodels_and_neuron_configs(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
tensor_parallel_size: int,
task: str,
output: Path,
library_name: str,
@@ -300,7 +301,14 @@ def get_submodels_and_neuron_configs(
elif is_encoder_decoder:
optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_encoder_decoder(
model, input_shapes, task, output, dynamic_batch_size, model_name_or_path, **optional_outputs
model=model,
input_shapes=input_shapes,
tensor_parallel_size=tensor_parallel_size,
task=task,
output=output,
dynamic_batch_size=dynamic_batch_size,
model_name_or_path=model_name_or_path,
**optional_outputs,
)
else:
# TODO: Enable optional outputs for encoders
@@ -427,6 +435,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
def _get_submodels_and_neuron_configs_for_encoder_decoder(
model: "PreTrainedModel",
input_shapes: Dict[str, int],
tensor_parallel_size: int,
task: str,
output: Path,
dynamic_batch_size: bool = False,
@@ -442,15 +451,19 @@ def _get_submodels_and_neuron_configs_for_encoder_decoder(
models_and_neuron_configs = get_encoder_decoder_models_for_export(
model=model,
task=task,
tensor_parallel_size=tensor_parallel_size,
dynamic_batch_size=dynamic_batch_size,
input_shapes=input_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
model_name_or_path=model_name_or_path,
)
output_model_names = {
ENCODER_NAME: os.path.join(ENCODER_NAME, NEURON_FILE_NAME),
DECODER_NAME: os.path.join(DECODER_NAME, NEURON_FILE_NAME),
}
model.config.save_pretrained(output)
model.generation_config.save_pretrained(output)
maybe_save_preprocessors(model_name_or_path, output)

return models_and_neuron_configs, output_model_names
@@ -460,6 +473,7 @@ def load_models_and_neuron_configs(
model_name_or_path: str,
output: Path,
model: Optional[Union["PreTrainedModel", "ModelMixin"]],
tensor_parallel_size: int,
task: str,
dynamic_batch_size: bool,
cache_dir: Optional[str],
@@ -499,6 +513,7 @@
models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
model=model,
input_shapes=input_shapes,
tensor_parallel_size=tensor_parallel_size,
task=task,
library_name=library_name,
output=output,
@@ -522,6 +537,7 @@ def main_export(
model_name_or_path: str,
output: Union[str, Path],
compiler_kwargs: Dict[str, Any],
tensor_parallel_size: int = 1,
model: Optional[Union["PreTrainedModel", "ModelMixin"]] = None,
task: str = "auto",
dynamic_batch_size: bool = False,
@@ -563,6 +579,7 @@
model_name_or_path=model_name_or_path,
output=output,
model=model,
tensor_parallel_size=tensor_parallel_size,
task=task,
dynamic_batch_size=dynamic_batch_size,
cache_dir=cache_dir,
@@ -698,6 +715,7 @@ def main():
model_name_or_path=args.model,
output=args.output,
compiler_kwargs=compiler_kwargs,
tensor_parallel_size=args.tensor_parallel_size,
task=task,
dynamic_batch_size=args.dynamic_batch_size,
atol=args.atol,
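For orientation, a minimal sketch of driving the new argument programmatically instead of through `optimum-cli export neuron ... --tensor_parallel_size 2`; the model id, output directory, compiler options, and shape keyword arguments below are illustrative assumptions, not values taken from this PR:

```python
# Hedged sketch: export a T5 checkpoint with tensor parallelism using the new
# `tensor_parallel_size` argument added in this PR. All concrete values are assumptions.
from optimum.exporters.neuron import main_export

main_export(
    model_name_or_path="google/flan-t5-small",   # assumption: any T5-family checkpoint
    output="t5_neuron_tp2/",                     # assumption: output directory
    compiler_kwargs={"auto_cast": "matmul", "auto_cast_type": "bf16"},  # assumption: compiler options
    tensor_parallel_size=2,                      # new argument: shard the model over 2 Neuron cores
    task="text2text-generation",
    batch_size=1,                                # assumption: static shape axes passed as kwargs
    sequence_length=128,
    num_beams=4,
)
```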
10 changes: 10 additions & 0 deletions optimum/exporters/neuron/base.py
@@ -146,6 +146,7 @@ def __init__(
task: str,
compiler_type: Optional[str] = None,
compiler_version: Optional[str] = None,
tensor_parallel_size: int = 1,
batch_size: Optional[int] = None,
text_batch_size: Optional[int] = None,
image_batch_size: Optional[int] = None,
@@ -174,6 +175,7 @@ def __init__(
self._config = config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self.mandatory_axes = ()
self.tp_degree = tensor_parallel_size
self.task = task
self._axes: Dict[str, int] = {}
self.dynamic_batch_size = dynamic_batch_size
@@ -226,6 +228,14 @@ def task(self) -> str:
def task(self, value: str):
self._task = value
self.mandatory_axes = self.get_mandatory_axes_for_task(self.task)

@property
def tp_degree(self) -> int:
return self._tp_degree

@tp_degree.setter
def tp_degree(self, value: int):
self._tp_degree = value

def __getattr__(self, attr_name) -> Any:
if attr_name != "_axes" and attr_name in self._axes:
18 changes: 18 additions & 0 deletions optimum/exporters/neuron/config.py
@@ -165,3 +165,21 @@ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
]

return dummy_inputs_generators

def load_pretrained_with_parallel_attn(self, model, ckpt_path):
# Parallel implementation of Attention modules.
import neuronx_distributed
from .t5_model_layers import ParallelSelfAttention, ParallelFF, ParallelCrossAttention

for index, block in enumerate(model.decoder.block):
if index == 0:
block.layer[0] = ParallelSelfAttention(model.config,
has_relative_attention_bias=True)
else:
block.layer[0] = ParallelSelfAttention(model.config)
block.layer[1] = ParallelCrossAttention(model.config)
block.layer[2] = ParallelFF(model.config)
# Load the weights into the parallel layers
neuronx_distributed.parallel_layers.load(ckpt_path, model, sharded=False)

return model
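As a hedged illustration of how this loader could be consumed on each tensor-parallel rank (the checkpoint path and model id are assumptions; the actual wiring lives in the exporter's tracing code, not in this file):

```python
# Hedged sketch: rebuild a T5 model with the parallel decoder layers above,
# then load the unsharded checkpoint so neuronx_distributed can shard the
# weights across tensor-parallel ranks.
from transformers import T5ForConditionalGeneration

def build_parallel_t5(neuron_config, ckpt_path="t5_decoder/checkpoint.pt"):  # path is an assumption
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")  # assumption
    return neuron_config.load_pretrained_with_parallel_attn(model, ckpt_path)
```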
97 changes: 60 additions & 37 deletions optimum/exporters/neuron/convert.py
@@ -21,13 +21,15 @@

import numpy as np
import torch
from transformers import PreTrainedModel

from ...exporters.error_utils import OutputMatchError, ShapeError
from ...neuron.utils import (
DiffusersPretrainedConfig,
convert_neuronx_compiler_args_to_neuron,
is_neuron_available,
is_neuronx_available,
is_neuronx_distributed_available,
store_compilation_config,
)
from ...neuron.utils.cache_utils import get_model_name_or_path
@@ -42,11 +44,10 @@
is_sentence_transformers_available,
logging,
)
from .config import TextSeq2SeqNeuronConfig


if TYPE_CHECKING:
from transformers import PreTrainedModel

from .base import NeuronDefaultConfig

if is_neuron_available():
@@ -68,6 +69,9 @@
if is_sentence_transformers_available():
from sentence_transformers import SentenceTransformer

if is_neuronx_distributed_available():
import neuronx_distributed

logger = logging.get_logger(__name__) # pylint: disable=invalid-name


@@ -182,6 +186,7 @@ def validate_model_outputs(
inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
ref_inputs = config.unflatten_inputs(inputs)
if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False):

reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
if "SentenceTransformer" in reference_model.__class__.__name__:
reference_model = config.patch_model_for_export(reference_model, ref_inputs)
@@ -364,7 +369,7 @@

start_time = time.time()
neuron_inputs, neuron_outputs = export(
model=submodel,
model_or_path=submodel,
config=sub_neuron_config,
output=output_path,
compiler_workdir=compiler_workdir,
@@ -396,6 +401,7 @@
input_names=neuron_inputs,
output_names=neuron_outputs,
dynamic_batch_size=sub_neuron_config.dynamic_batch_size,
tensor_parallel_size=sub_neuron_config.tp_degree,
compiler_type=NEURON_COMPILER_TYPE,
compiler_version=NEURON_COMPILER_VERSION,
inline_weights_to_neff=inline_weights_to_neff,
@@ -426,7 +432,7 @@


def export(
model: "PreTrainedModel",
model_or_path: Union["PreTrainedModel", str, Path],
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
@@ -439,7 +445,7 @@
) -> Tuple[List[str], List[str]]:
if is_neuron_available():
return export_neuron(
model=model,
model=model_or_path,
config=config,
output=output,
compiler_workdir=compiler_workdir,
@@ -451,7 +457,7 @@
)
elif is_neuronx_available():
return export_neuronx(
model=model,
model_or_path=model_or_path,
config=config,
output=output,
compiler_workdir=compiler_workdir,
@@ -467,7 +473,7 @@


def export_neuronx(
model: "PreTrainedModel",
model_or_path: Union["PreTrainedModel", str, Path],
config: "NeuronDefaultConfig",
output: Path,
compiler_workdir: Optional[Path] = None,
@@ -480,8 +486,8 @@
Exports a PyTorch model to a serialized TorchScript module compiled by neuronx-cc compiler.

Args:
model ([`PreTrainedModel`]):
The model to export.
model_or_path (Union["PreTrainedModel", str, Path]):
The model to export, or the path it can be loaded from (needed when tensor parallelism is applied, since the model must be loaded inside the tracing function).
config ([`~exporter.NeuronDefaultConfig`]):
The Neuron configuration associated with the exported model.
output (`Path`):
@@ -508,18 +514,21 @@
if isinstance(compiler_workdir, Path):
compiler_workdir = compiler_workdir.as_posix()

if hasattr(model, "config"):
model.config.return_dict = True
model.config.torchscript = True
model.eval()
if hasattr(model_or_path, "config"):
model_or_path.config.return_dict = True
model_or_path.config.torchscript = True
if isinstance(model_or_path, PreTrainedModel):
model_or_path.eval()

# Check if we need to override certain configuration item
if config.values_override is not None:
logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
for override_config_key, override_config_value in config.values_override.items():
logger.info(f"\t- {override_config_key} -> {override_config_value}")
setattr(model.config, override_config_key, override_config_value)
if isinstance(model_or_path, PreTrainedModel):
setattr(model_or_path.config, override_config_key, override_config_value)

# Prepare dummy inputs for tracing
input_shapes = {}
for axis in config.mandatory_axes:
input_shapes[axis] = getattr(config, axis)
@@ -528,14 +537,17 @@
dummy_inputs = config.flatten_inputs(dummy_inputs)
dummy_inputs_tuple = tuple(dummy_inputs.values())

# Prepare the model / function (tp) to trace
aliases = {}
if hasattr(model, "config") and getattr(model.config, "is_encoder_decoder", False):
checked_model = config.patch_model_for_export(model, **input_shapes)
if getattr(config, "is_decoder", False):
tp_degree = config.tp_degree
if hasattr(model_or_path, "config") and isinstance(config, TextSeq2SeqNeuronConfig):
checked_model = config.patch_model_for_export(model_or_path, **input_shapes)
if tp_degree == 1:
aliases = config.generate_io_aliases(checked_model)
else:
checked_model = config.patch_model_for_export(model, dummy_inputs)
checked_model = config.patch_model_for_export(model_or_path, dummy_inputs)

# Construct compiler configurations
if auto_cast is not None:
logger.info(f"Using Neuron: --auto-cast {auto_cast}")

@@ -549,32 +561,43 @@

compiler_args.extend(["--optlevel", optlevel])

# diffusers specific
compiler_args = add_stable_diffusion_compiler_args(config, compiler_args)
compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) # diffusers specific

if config.dynamic_batch_size and not inline_weights_to_neff:
logger.warning(
"Dynamic batching is not yet compatible with the weights/neff non-inlined model. `inline_weights_to_neff` is set to True. If you still want to separate the neff and weights, please set `dynamic_batch_size=False`."
)
inline_weights_to_neff = True

neuron_model = neuronx.trace(
checked_model,
dummy_inputs_tuple,
compiler_args=compiler_args,
input_output_aliases=aliases,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
)

if config.dynamic_batch_size is True:
neuron_model = neuronx.dynamic_batch(neuron_model)

# diffusers specific
improve_stable_diffusion_loading(config, neuron_model)

torch.jit.save(neuron_model, output)
del model
# Start trace
if tp_degree > 1:
# 1. use NxD to trace for parallel
neuron_model = neuronx_distributed.trace.parallel_model_trace(
checked_model,
dummy_inputs_tuple,
compiler_args=compiler_args,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
tp_degree=tp_degree,
)
neuronx_distributed.trace.parallel_model_save(neuron_model, output)

Review thread on the parallel_model_trace call:
Collaborator: It is ok in a first step, but for the Llama example they are not using this anymore; instead they use the ModelBuilder class, which wraps the model into NxDModel classes containing several sub-models with different input shapes (bucketing).
Collaborator (author): Is the use of bucketing mature and justified? I think we can start with parallel_model_trace anyway.
Collaborator: It goes a bit beyond that, because prefill / decode already use two different input shapes, not even mentioning bucketing, and using the builder allows the same weights to be shared between all the alternate graphs.
else:
# 2. use `torch_neuronx.trace`
neuron_model = neuronx.trace(
checked_model,
dummy_inputs_tuple,
compiler_args=compiler_args,
input_output_aliases=aliases,
inline_weights_to_neff=inline_weights_to_neff,
compiler_workdir=compiler_workdir,
)
if config.dynamic_batch_size is True:
neuron_model = neuronx.dynamic_batch(neuron_model)
# diffusers specific
improve_stable_diffusion_loading(config, neuron_model)
torch.jit.save(neuron_model, output)

del model_or_path
del checked_model
del dummy_inputs
del neuron_model
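For completeness, a hedged sketch of loading the exported artifacts back (paths are assumptions): `parallel_model_save` pairs with `parallel_model_load` on the NxD side, while the single-device branch keeps using `torch.jit.save` / `torch.jit.load`:

```python
# Hedged sketch: reload the compiled artifacts produced by the two branches above.
import torch
import neuronx_distributed

# tensor_parallel_size > 1: artifacts written by parallel_model_save
tp_decoder = neuronx_distributed.trace.parallel_model_load("t5_neuron_tp2/decoder/model.neuron")  # path is an assumption

# tensor_parallel_size == 1: a regular TorchScript module
encoder = torch.jit.load("t5_neuron_tp2/encoder/model.neuron")  # path is an assumption
```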