diff --git a/benchmark/text-generation-inference/performance/generate_csv.py b/benchmark/text-generation-inference/performance/generate_csv.py index 1e7770f63..366370e19 100644 --- a/benchmark/text-generation-inference/performance/generate_csv.py +++ b/benchmark/text-generation-inference/performance/generate_csv.py @@ -3,7 +3,6 @@ import os import pandas as pd - from guidellm.core import GuidanceReport, TextGenerationBenchmark @@ -16,11 +15,7 @@ def _benchmark_rate_id(benchmark: TextGenerationBenchmark) -> str: :return: A string representing the benchmark rate ID. :rtype: str """ - rate_id = ( - f"{benchmark.mode}@{benchmark.rate:.2f} req/sec" - if benchmark.rate - else f"{benchmark.mode}" - ) + rate_id = f"{benchmark.mode}@{benchmark.rate:.2f} req/sec" if benchmark.rate else f"{benchmark.mode}" return rate_id @@ -38,20 +33,20 @@ def main(): for path in paths: filename = os.path.basename(path) # Extract model_id - model_id, date = filename.replace(suffix, '').split('#') + model_id, date = filename.replace(suffix, "").split("#") with open(path) as f: report = GuidanceReport.from_json(f.read()) for benchmark in report.benchmarks: for b in benchmark.benchmarks_sorted: d = { - "model_id": model_id, - "Date": date, - "Input type": _benchmark_rate_id(b), - "Requests per Second": b.completed_request_rate, - "Request Latency (s)": b.request_latency, - "Time-to-first-token (ms)": b.time_to_first_token, - "Inter Token Latency (ms)": b.inter_token_latency, - "Output Token Throughput (t/s)": b.output_token_throughput, + "model_id": model_id, + "Date": date, + "Input type": _benchmark_rate_id(b), + "Requests per Second": b.completed_request_rate, + "Request Latency (s)": b.request_latency, + "Time-to-first-token (ms)": b.time_to_first_token, + "Inter Token Latency (ms)": b.inter_token_latency, + "Output Token Throughput (t/s)": b.output_token_throughput, } results.append(pd.DataFrame.from_dict(d, orient="index").transpose()) diff --git a/docs/source/package_reference/modeling.mdx b/docs/source/package_reference/modeling.mdx index 121275fb2..76d8e2fa9 100644 --- a/docs/source/package_reference/modeling.mdx +++ b/docs/source/package_reference/modeling.mdx @@ -68,6 +68,11 @@ The following Neuron model classes are available for natural language processing [[autodoc]] modeling.NeuronModelForCausalLM - forward +### NeuronModelForSeq2SeqLM + +[[autodoc]] modeling_seq2seq.NeuronModelForSeq2SeqLM + - forward + ## Computer Vision The following Neuron model classes are available for computer vision tasks. diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py index eb7749677..e14dd99df 100644 --- a/optimum/commands/export/neuronx.py +++ b/optimum/commands/export/neuronx.py @@ -112,6 +112,12 @@ def parse_args_neuronx(parser: "ArgumentParser"): choices=["bf16", "fp16", "tf32"], help='The data type to cast FP32 operations to when auto-cast mode is enabled. 
Can be `"bf16"`, `"fp16"` or `"tf32"`.', ) + optional_group.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Tensor parallelism size, the number of neuron cores on which to shard the model.", + ) optional_group.add_argument( "--dynamic-batch-size", action="store_true", diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index ce38b5a63..d458ded0d 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -264,6 +264,7 @@ def get_submodels_and_neuron_configs( task: str, output: Path, library_name: str, + tensor_parallel_size: int = 1, subfolder: str = "", dynamic_batch_size: bool = False, model_name_or_path: Optional[Union[str, Path]] = None, @@ -300,7 +301,14 @@ def get_submodels_and_neuron_configs( elif is_encoder_decoder: optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states} models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_encoder_decoder( - model, input_shapes, task, output, dynamic_batch_size, model_name_or_path, **optional_outputs + model=model, + input_shapes=input_shapes, + tensor_parallel_size=tensor_parallel_size, + task=task, + output=output, + dynamic_batch_size=dynamic_batch_size, + model_name_or_path=model_name_or_path, + **optional_outputs, ) else: # TODO: Enable optional outputs for encoders @@ -427,6 +435,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( def _get_submodels_and_neuron_configs_for_encoder_decoder( model: "PreTrainedModel", input_shapes: Dict[str, int], + tensor_parallel_size: int, task: str, output: Path, dynamic_batch_size: bool = False, @@ -442,15 +451,19 @@ def _get_submodels_and_neuron_configs_for_encoder_decoder( models_and_neuron_configs = get_encoder_decoder_models_for_export( model=model, task=task, + tensor_parallel_size=tensor_parallel_size, dynamic_batch_size=dynamic_batch_size, input_shapes=input_shapes, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + model_name_or_path=model_name_or_path, ) output_model_names = { ENCODER_NAME: os.path.join(ENCODER_NAME, NEURON_FILE_NAME), DECODER_NAME: os.path.join(DECODER_NAME, NEURON_FILE_NAME), } + model.config.save_pretrained(output) + model.generation_config.save_pretrained(output) maybe_save_preprocessors(model_name_or_path, output) return models_and_neuron_configs, output_model_names @@ -475,6 +488,7 @@ def load_models_and_neuron_configs( lora_weight_names: Optional[Union[str, List[str]]], lora_adapter_names: Optional[Union[str, List[str]]], lora_scales: Optional[Union[float, List[float]]], + tensor_parallel_size: int = 1, controlnet_ids: Optional[Union[str, List[str]]] = None, output_attentions: bool = False, output_hidden_states: bool = False, @@ -499,6 +513,7 @@ def load_models_and_neuron_configs( models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs( model=model, input_shapes=input_shapes, + tensor_parallel_size=tensor_parallel_size, task=task, library_name=library_name, output=output, @@ -522,6 +537,7 @@ def main_export( model_name_or_path: str, output: Union[str, Path], compiler_kwargs: Dict[str, Any], + tensor_parallel_size: int = 1, model: Optional[Union["PreTrainedModel", "ModelMixin"]] = None, task: str = "auto", dynamic_batch_size: bool = False, @@ -563,6 +579,7 @@ def main_export( model_name_or_path=model_name_or_path, output=output, model=model, + tensor_parallel_size=tensor_parallel_size, task=task, 
dynamic_batch_size=dynamic_batch_size, cache_dir=cache_dir, @@ -597,6 +614,12 @@ def main_export( ) # Validate compiled model + if do_validation and tensor_parallel_size > 1: + # TODO: support the validation of tp models. + logger.warning( + "The validation is not yet supported for tensor parallel model, the validation will be turned off." + ) + do_validation = False if do_validation is True: try: validate_models_outputs( @@ -698,6 +721,7 @@ def main(): model_name_or_path=args.model, output=args.output, compiler_kwargs=compiler_kwargs, + tensor_parallel_size=args.tensor_parallel_size, task=task, dynamic_batch_size=args.dynamic_batch_size, atol=args.atol, diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py index a5cd59a9d..9520b2483 100644 --- a/optimum/exporters/neuron/base.py +++ b/optimum/exporters/neuron/base.py @@ -146,6 +146,7 @@ def __init__( task: str, compiler_type: Optional[str] = None, compiler_version: Optional[str] = None, + tensor_parallel_size: int = 1, batch_size: Optional[int] = None, text_batch_size: Optional[int] = None, image_batch_size: Optional[int] = None, @@ -174,6 +175,7 @@ def __init__( self._config = config self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) self.mandatory_axes = () + self.tensor_parallel_size = tensor_parallel_size self.task = task self._axes: Dict[str, int] = {} self.dynamic_batch_size = dynamic_batch_size @@ -227,6 +229,14 @@ def task(self, value: str): self._task = value self.mandatory_axes = self.get_mandatory_axes_for_task(self.task) + @property + def tensor_parallel_size(self) -> int: + return self._tensor_parallel_size + + @tensor_parallel_size.setter + def tensor_parallel_size(self, value: int): + self._tensor_parallel_size = value + def __getattr__(self, attr_name) -> Any: if attr_name != "_axes" and attr_name in self._axes: return self._axes[attr_name] diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index 1f93590fc..1739f3782 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -21,6 +21,7 @@ import numpy as np import torch +from transformers import PreTrainedModel from ...exporters.error_utils import OutputMatchError, ShapeError from ...neuron.utils import ( @@ -28,6 +29,7 @@ convert_neuronx_compiler_args_to_neuron, is_neuron_available, is_neuronx_available, + is_neuronx_distributed_available, store_compilation_config, ) from ...neuron.utils.cache_utils import get_model_name_or_path @@ -42,11 +44,10 @@ is_sentence_transformers_available, logging, ) +from .config import TextSeq2SeqNeuronConfig if TYPE_CHECKING: - from transformers import PreTrainedModel - from .base import NeuronDefaultConfig if is_neuron_available(): @@ -68,6 +69,9 @@ if is_sentence_transformers_available(): from sentence_transformers import SentenceTransformer +if is_neuronx_distributed_available(): + import neuronx_distributed + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -182,6 +186,7 @@ def validate_model_outputs( inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes) ref_inputs = config.unflatten_inputs(inputs) if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False): + reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes) if "SentenceTransformer" in reference_model.__class__.__name__: reference_model = config.patch_model_for_export(reference_model, ref_inputs) @@ -297,7 +302,6 @@ def export_models( optlevel: 
str = "2", output_file_names: Optional[Dict[str, str]] = None, compiler_kwargs: Optional[Dict[str, Any]] = {}, - configs: Optional[Dict[str, Any]] = {}, model_name_or_path: Optional[str] = None, ) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]: """ @@ -324,8 +328,6 @@ def export_models( If None, will use the keys from `models_and_neuron_configs` as names. compiler_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Arguments to pass to the Neuron(x) compiler for exporting Neuron models. - configs (`Optional[Dict[str, Any]]`, defaults to `None`): - A list of pretrained model configs. model_name_or_path (`Optional[str]`, defaults to `None`): Path to pretrained model or model identifier from the Hugging Face Hub. Returns: @@ -364,7 +366,7 @@ def export_models( start_time = time.time() neuron_inputs, neuron_outputs = export( - model=submodel, + model_or_path=submodel, config=sub_neuron_config, output=output_path, compiler_workdir=compiler_workdir, @@ -377,14 +379,9 @@ def export_models( logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.") all_inputs[model_name] = neuron_inputs all_outputs[model_name] = neuron_outputs - # Add neuron specific configs to model components' original config - if hasattr(submodel, "config"): - model_config = submodel.config - elif configs and (model_name in configs.keys()): - model_config = configs[model_name] - else: - raise AttributeError("Cannot find model's configuration, please pass it with `configs`.") + # Add neuron specific configs to model components' original config + model_config = sub_neuron_config._config if is_diffusers_available() and isinstance(model_config, FrozenDict): model_config = OrderedDict(model_config) model_config = DiffusersPretrainedConfig.from_dict(model_config) @@ -396,6 +393,7 @@ def export_models( input_names=neuron_inputs, output_names=neuron_outputs, dynamic_batch_size=sub_neuron_config.dynamic_batch_size, + tensor_parallel_size=sub_neuron_config.tensor_parallel_size, compiler_type=NEURON_COMPILER_TYPE, compiler_version=NEURON_COMPILER_VERSION, inline_weights_to_neff=inline_weights_to_neff, @@ -426,7 +424,7 @@ def export_models( def export( - model: "PreTrainedModel", + model_or_path: Union["PreTrainedModel", str, Path], config: "NeuronDefaultConfig", output: Path, compiler_workdir: Optional[Path] = None, @@ -439,7 +437,7 @@ def export( ) -> Tuple[List[str], List[str]]: if is_neuron_available(): return export_neuron( - model=model, + model=model_or_path, config=config, output=output, compiler_workdir=compiler_workdir, @@ -451,7 +449,7 @@ def export( ) elif is_neuronx_available(): return export_neuronx( - model=model, + model_or_path=model_or_path, config=config, output=output, compiler_workdir=compiler_workdir, @@ -467,7 +465,7 @@ def export( def export_neuronx( - model: "PreTrainedModel", + model_or_path: Union["PreTrainedModel", str, Path], config: "NeuronDefaultConfig", output: Path, compiler_workdir: Optional[Path] = None, @@ -480,8 +478,8 @@ def export_neuronx( Exports a PyTorch model to a serialized TorchScript module compiled by neuronx-cc compiler. Args: - model ([`PreTrainedModel`]): - The model to export. + model_or_path (Union["PreTrainedModel", str, Path]): + The model to export or its location(case when applying the parallelism as the model needs to be loaded with the tracing). config ([`~exporter.NeuronDefaultConfig`]): The Neuron configuration associated with the exported model. 
output (`Path`): @@ -508,18 +506,21 @@ def export_neuronx( if isinstance(compiler_workdir, Path): compiler_workdir = compiler_workdir.as_posix() - if hasattr(model, "config"): - model.config.return_dict = True - model.config.torchscript = True - model.eval() + if hasattr(model_or_path, "config"): + model_or_path.config.return_dict = True + model_or_path.config.torchscript = True + if isinstance(model_or_path, PreTrainedModel): + model_or_path.eval() # Check if we need to override certain configuration item if config.values_override is not None: logger.info(f"Overriding {len(config.values_override)} configuration item(s)") for override_config_key, override_config_value in config.values_override.items(): logger.info(f"\t- {override_config_key} -> {override_config_value}") - setattr(model.config, override_config_key, override_config_value) + if isinstance(model_or_path, PreTrainedModel): + setattr(model_or_path.config, override_config_key, override_config_value) + # Prepare dummy inputs for tracing input_shapes = {} for axis in config.mandatory_axes: input_shapes[axis] = getattr(config, axis) @@ -528,14 +529,17 @@ def export_neuronx( dummy_inputs = config.flatten_inputs(dummy_inputs) dummy_inputs_tuple = tuple(dummy_inputs.values()) + # Prepare the model / function(tp) to trace aliases = {} - if hasattr(model, "config") and getattr(model.config, "is_encoder_decoder", False): - checked_model = config.patch_model_for_export(model, **input_shapes) - if getattr(config, "is_decoder", False): + tensor_parallel_size = config.tensor_parallel_size + if isinstance(config, TextSeq2SeqNeuronConfig): + checked_model = config.patch_model_for_export(model_or_path, **input_shapes) + if tensor_parallel_size == 1: aliases = config.generate_io_aliases(checked_model) else: - checked_model = config.patch_model_for_export(model, dummy_inputs) + checked_model = config.patch_model_for_export(model_or_path, dummy_inputs) + # Construct compiler configurations if auto_cast is not None: logger.info(f"Using Neuron: --auto-cast {auto_cast}") @@ -549,8 +553,7 @@ def export_neuronx( compiler_args.extend(["--optlevel", optlevel]) - # diffusers specific - compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) + compiler_args = add_stable_diffusion_compiler_args(config, compiler_args) # diffusers specific if config.dynamic_batch_size and not inline_weights_to_neff: logger.warning( @@ -558,23 +561,35 @@ def export_neuronx( ) inline_weights_to_neff = True - neuron_model = neuronx.trace( - checked_model, - dummy_inputs_tuple, - compiler_args=compiler_args, - input_output_aliases=aliases, - inline_weights_to_neff=inline_weights_to_neff, - compiler_workdir=compiler_workdir, - ) - - if config.dynamic_batch_size is True: - neuron_model = neuronx.dynamic_batch(neuron_model) - - # diffusers specific - improve_stable_diffusion_loading(config, neuron_model) + # Start trace + if tensor_parallel_size > 1: + # 1. use NxD to trace for parallel + neuron_model = neuronx_distributed.trace.parallel_model_trace( + checked_model, + dummy_inputs_tuple, + compiler_args=compiler_args, + inline_weights_to_neff=inline_weights_to_neff, + compiler_workdir=compiler_workdir, + tp_degree=tensor_parallel_size, + ) + neuronx_distributed.trace.parallel_model_save(neuron_model, output) + else: + # 2. 
use `torch_neuronx.trace` + neuron_model = neuronx.trace( + checked_model, + dummy_inputs_tuple, + compiler_args=compiler_args, + input_output_aliases=aliases, + inline_weights_to_neff=inline_weights_to_neff, + compiler_workdir=compiler_workdir, + ) + if config.dynamic_batch_size is True: + neuron_model = neuronx.dynamic_batch(neuron_model) + # diffusers specific + improve_stable_diffusion_loading(config, neuron_model) + torch.jit.save(neuron_model, output) - torch.jit.save(neuron_model, output) - del model + del model_or_path del checked_model del dummy_inputs del neuron_model diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index 63a5996c9..7a8ffee4e 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -15,15 +15,18 @@ """Model specific Neuron configurations.""" import copy +from functools import partial from typing import TYPE_CHECKING, Dict, List import torch +from ...neuron.distributed import ParallelizersManager from ...neuron.utils import ( ASTDummyAudioInputGenerator, DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyMaskedPosGenerator, + is_neuronx_distributed_available, ) from ...utils import ( DummyInputGenerator, @@ -59,6 +62,9 @@ ) +if is_neuronx_distributed_available(): + import neuronx_distributed + if TYPE_CHECKING: if is_diffusers_available(): from diffusers.models.vae import Decoder as VaeDecoder @@ -793,19 +799,69 @@ class T5EncoderNeuronConfig(TextSeq2SeqNeuronConfig): def is_decoder(self) -> bool: return False - def patch_model_for_export(self, model, device="xla", **kwargs): + def patch_model_for_export(self, model_or_path, device="xla", **kwargs): num_beams = kwargs.pop("num_beams", 1) - return self.CUSTOM_MODEL_WRAPPER(model, num_beams=num_beams, device=device) - - -@register_in_tasks_manager("opt", "text-generation") -class OPTNeuronConfig(TextNeuronDecoderConfig): - NEURONX_CLASS = "opt.model.OPTForSampling" - + sequence_length = kwargs.pop("sequence_length", None) + batch_size = kwargs.pop("batch_size", None) + + if self.tensor_parallel_size > 1: + # `torch.nn.modules` objects not eligible for pickling, the model needs to be loaded within the func. + return partial( + self.get_parallel_callable, + model_or_path, + sequence_length, + batch_size, + num_beams, + device, + self.tensor_parallel_size, + ) + else: + return self.CUSTOM_MODEL_WRAPPER( + model_or_path, + sequence_length=sequence_length, + batch_size=batch_size, + num_beams=num_beams, + device=device, + tensor_parallel_size=self.tensor_parallel_size, + ) + + def get_parallel_callable( + self, model_name_or_path, sequence_length, batch_size, num_beams, device, tensor_parallel_size + ): + """Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states.""" + model = TasksManager.get_model_from_task( + model_name_or_path=model_name_or_path, + task=self.task, + framework="pt", + library_name="transformers", + ) # TODO: add extra args, eg. revision, trust_remote_code, etc. 
+ model.config.use_cache = True + parallelizer = ParallelizersManager.parallelizer_for_model(model) + with parallelizer.saved_model_in_temporary_directory(model) as ckpt_path: + # Replace parallel layers + parallel_model = parallelizer._parallelize(model, parallelize_embeddings=False) + # Load the weights into the parallel layers + neuronx_distributed.parallel_layers.load(ckpt_path, parallel_model, sharded=False) + encoder = self.CUSTOM_MODEL_WRAPPER( + parallel_model, + sequence_length=sequence_length, + batch_size=batch_size, + num_beams=num_beams, + device=device, + tensor_parallel_size=tensor_parallel_size, + ) + encoder.eval() + aliases = self.generate_io_aliases(encoder) + return encoder, aliases -@register_in_tasks_manager("bloom", "text-generation") -class BloomNeuronConfig(TextNeuronDecoderConfig): - NEURONX_CLASS = "bloom.model.BloomForSampling" + def generate_io_aliases(self, encoder=None): + aliases = {} + if self.tensor_parallel_size > 1: + for i in range(len(encoder.past_key_values_sa)): + aliases[encoder.past_key_values_sa[i]] = i + for i in range(len(encoder.past_key_values_ca)): + aliases[encoder.past_key_values_ca[i]] = len(encoder.past_key_values_sa) + i + return aliases @register_in_tasks_manager("t5-decoder", "text2text-generation") @@ -850,27 +906,92 @@ def patch_model_for_export(self, model, device="xla", **kwargs): sequence_length = kwargs.pop("sequence_length", 1) num_beams = kwargs.pop("num_beams", 1) - return self.CUSTOM_MODEL_WRAPPER( - model, + trace_args = { + "model": model, + "batch_size": batch_size, + "sequence_length": sequence_length, + "num_beams": num_beams, + "output_hidden_states": self.output_hidden_states, + "output_attentions": self.output_attentions, + "device": device, + "tensor_parallel_size": self.tensor_parallel_size, + } + if self.tensor_parallel_size > 1: + return partial( + self.get_parallel_callable, + model, + batch_size, + sequence_length, + num_beams, + self.output_hidden_states, + self.output_attentions, + device, + self.tensor_parallel_size, + ) + else: + return self.CUSTOM_MODEL_WRAPPER(**trace_args) + + def get_parallel_callable( + self, + model_name_or_path, + batch_size, + sequence_length, + num_beams, + output_hidden_states, + output_attentions, + device, + tensor_parallel_size, + ): + """Unlike `torch_neuronx.trace`, `parallel_model_trace` requires a function returning a model object and a dictionary of states.""" + model = TasksManager.get_model_from_task( + model_name_or_path=model_name_or_path, + task=self.task, + framework="pt", + library_name="transformers", + ) # TODO: add extra args, eg. revision, trust_remote_code, etc. 
+ model.config.use_cache = True + parallelizer = ParallelizersManager.parallelizer_for_model(model) + with parallelizer.saved_model_in_temporary_directory(model) as ckpt_path: + # Replace parallel layers + parallel_model = parallelizer._parallelize(model, parallelize_embeddings=False) + # Load the weights into the parallel layers + neuronx_distributed.parallel_layers.load(ckpt_path, parallel_model, sharded=False) + + decoder = self.CUSTOM_MODEL_WRAPPER( + parallel_model, batch_size=batch_size, sequence_length=sequence_length, num_beams=num_beams, - output_hidden_states=self.output_hidden_states, - output_attentions=self.output_attentions, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, device=device, + tensor_parallel_size=tensor_parallel_size, ) + decoder.eval() + aliases = self.generate_io_aliases(decoder) + return decoder, aliases - def generate_io_aliases(self, model): - num_outputs_from_trace = 3 if model.num_beams > 1 else 1 + def generate_io_aliases(self, decoder): + num_outputs_from_trace = 3 if decoder.num_beams > 1 else 1 aliases = {} - for i in range(len(model.past_key_values_sa)): - aliases[model.past_key_values_sa[i]] = i + num_outputs_from_trace - for i in range(len(model.past_key_values_ca)): - aliases[model.past_key_values_ca[i]] = len(model.past_key_values_sa) + i + num_outputs_from_trace + for i in range(len(decoder.past_key_values_sa)): + aliases[decoder.past_key_values_sa[i]] = i + num_outputs_from_trace + for i in range(len(decoder.past_key_values_ca)): + aliases[decoder.past_key_values_ca[i]] = len(decoder.past_key_values_sa) + i + num_outputs_from_trace return aliases +@register_in_tasks_manager("opt", "text-generation") +class OPTNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "opt.model.OPTForSampling" + + +@register_in_tasks_manager("bloom", "text-generation") +class BloomNeuronConfig(TextNeuronDecoderConfig): + NEURONX_CLASS = "bloom.model.BloomForSampling" + + @register_in_tasks_manager("mistral", "text-generation") class MistralNeuronConfig(TextNeuronDecoderConfig): NEURONX_CLASS = "mistral.model.MistralForSampling" diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py index 9c83168ce..d91d7ed8e 100644 --- a/optimum/exporters/neuron/model_wrappers.py +++ b/optimum/exporters/neuron/model_wrappers.py @@ -19,6 +19,12 @@ import torch from transformers.models.t5.modeling_t5 import T5LayerCrossAttention +from ...neuron.utils import is_neuronx_distributed_available + + +if is_neuronx_distributed_available(): + import neuronx_distributed + if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel @@ -122,21 +128,74 @@ class T5EncoderWrapper(torch.nn.Module): def __init__( self, model: "PreTrainedModel", + sequence_length: Optional[int] = None, + batch_size: Optional[int] = None, num_beams: int = 1, device: str = "xla", - tp_degree: Optional[int] = None, + tensor_parallel_size: int = 1, ): super().__init__() self.model = model self.config = model.config self.num_beams = num_beams + self.sequence_length = sequence_length + self.batch_size = batch_size self.device = device - self.tp_degree = tp_degree + self.tensor_parallel_size = tensor_parallel_size + self.num_attention_heads_per_partition = self.config.num_heads # when tensor_parallel_size=1 + + if self.tensor_parallel_size > 1: + self.num_attention_heads_per_partition = ( + self.num_attention_heads_per_partition + // neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size() + ) + 
self.past_key_values_sa = torch.nn.ParameterList(
+                [
+                    torch.nn.Parameter(
+                        torch.ones(
+                            (
+                                self.num_beams * batch_size,
+                                self.num_attention_heads_per_partition,
+                                self.sequence_length - 1,
+                                self.config.d_kv,
+                            ),
+                            dtype=torch.float32,
+                        ),
+                        requires_grad=False,
+                    )
+                    for _ in range(self.config.num_decoder_layers * 2)
+                ]
+            )
+            self.past_key_values_ca = torch.nn.ParameterList(
+                [
+                    torch.nn.Parameter(
+                        torch.ones(
+                            (
+                                self.num_beams * batch_size,
+                                self.num_attention_heads_per_partition,
+                                self.sequence_length,
+                                self.config.d_kv,
+                            ),
+                            dtype=torch.float32,
+                        ),
+                        requires_grad=False,
+                    )
+                    for _ in range(self.config.num_decoder_layers * 2)
+                ]
+            )
 
     def forward(self, input_ids, attention_mask):
-        # Infer shapes
+        # Infer shapes of dummy inputs used for tracing
         batch_size = input_ids.shape[0]
         sequence_length = input_ids.shape[1]
+        if self.sequence_length is not None:
+            assert (
+                self.sequence_length == sequence_length
+            ), f"Different sequence length for the parallel partition ({self.sequence_length}) and for the dummy inputs ({sequence_length}). Make sure that they have the same value."
+        if self.batch_size is not None:
+            assert (
+                self.batch_size == batch_size
+            ), f"Different batch size for the parallel partition ({self.batch_size}) and for the dummy inputs ({batch_size}). Make sure that they have the same value."
 
         encoder_output = self.model.encoder(
             input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False
@@ -151,7 +210,7 @@ def forward(self, input_ids, attention_mask):
         present_key_value_states_sa = []
         present_key_value_states_ca = []
 
-        for block in decoder_blocks:
+        for i, block in enumerate(decoder_blocks):
             # Cross attention has to be initialized with the encoder hidden state
             cross_attention: T5LayerCrossAttention = block.layer[1]
             attention = cross_attention.EncDecAttention
@@ -159,34 +218,67 @@ def forward(self, input_ids, attention_mask):
             def shape(states):
                 """projection"""
                 return states.view(
-                    self.num_beams * batch_size, -1, self.config.num_heads, attention.key_value_proj_dim
+                    self.num_beams * batch_size,
+                    -1,
+                    self.num_attention_heads_per_partition,
+                    attention.key_value_proj_dim,
                 ).transpose(1, 2)
 
             key_states = shape(attention.k(encoder_hidden_states))
             value_states = shape(attention.v(encoder_hidden_states))
 
-            # cross_attn_kv_state
-            present_key_value_states_ca.append(key_states)
-            present_key_value_states_ca.append(value_states)
-
-            # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant.
-            # The kv cache is padded here to keep a fixed shape.
-            # [key states]
-            present_key_value_states_sa.append(
-                torch.zeros(
-                    (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
-                    dtype=torch.float32,
-                    device=self.device,
+            if not self.tensor_parallel_size > 1:
+                # cross_attn_kv_state
+                present_key_value_states_ca.append(key_states)
+                present_key_value_states_ca.append(value_states)
+
+                # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant.
+                # The kv cache is padded here to keep a fixed shape.
+ # [key states] + present_key_value_states_sa.append( + torch.zeros( + (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv), + dtype=torch.float32, + device=self.device, + ) ) - ) - # [value states] - present_key_value_states_sa.append( - torch.zeros( - (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv), - dtype=torch.float32, - device=self.device, + # [value states] + present_key_value_states_sa.append( + torch.zeros( + (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv), + dtype=torch.float32, + device=self.device, + ) + ) + else: + present_key_value_states_ca.append((self.past_key_values_ca[i * 2] * 0) + key_states) + present_key_value_states_ca.append((self.past_key_values_ca[i * 2 + 1] * 0) + value_states) + present_key_value_states_sa.append( + self.past_key_values_sa[i * 2] + * torch.zeros( + ( + self.num_beams * self.batch_size, + self.num_attention_heads_per_partition, + self.sequence_length - 1, + self.config.d_kv, + ), + dtype=torch.float32, + device=self.device, + ) + ) + present_key_value_states_sa.append( + self.past_key_values_sa[i * 2 + 1] + * torch.zeros( + ( + self.num_beams * self.batch_size, + self.num_attention_heads_per_partition, + self.sequence_length - 1, + self.config.d_kv, + ), + dtype=torch.float32, + device=self.device, + ) ) - ) return present_key_value_states_sa + present_key_value_states_ca @@ -204,7 +296,7 @@ def __init__( output_hidden_states: bool = False, output_attentions: bool = False, device: str = "xla", - tp_degree: Optional[int] = None, + tensor_parallel_size: int = 1, ): super().__init__() self.model = model @@ -215,7 +307,14 @@ def __init__( self.output_hidden_states = output_hidden_states self.output_attentions = output_attentions self.device = device - self.tp_degree = tp_degree + self.tensor_parallel_size = tensor_parallel_size + + self.num_attention_heads_per_partition = self.config.num_heads + if tensor_parallel_size > 1: + self.num_attention_heads_per_partition = ( + self.num_attention_heads_per_partition + // neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size() + ) # Initialize KV cache (num_beams, n_heads, seq_length, dim_per_head) if device == "cpu": @@ -238,7 +337,7 @@ def __init__( torch.ones( ( self.batch_size * self.num_beams, - self.config.num_heads, + self.num_attention_heads_per_partition, sequence_length - 1, self.config.d_kv, ), @@ -255,7 +354,7 @@ def __init__( torch.ones( ( self.batch_size * self.num_beams, - self.config.num_heads, + self.num_attention_heads_per_partition, sequence_length, self.config.d_kv, ), @@ -284,8 +383,7 @@ def update_past(self, past_key_values): def reorder_cache(self, past_key_values, beam_idx): for i in range(len(past_key_values)): - gather_index = beam_idx.view([beam_idx.shape[0], 1, 1, 1]).expand_as(past_key_values[i]) - past_key_values[i] = torch.gather(past_key_values[i], dim=0, index=gather_index) + past_key_values[i] = torch.index_select(past_key_values[i], 0, beam_idx) return past_key_values def forward( diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index a5673c6ae..690ea2bdf 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -17,6 +17,7 @@ import copy import os from collections import OrderedDict +from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -465,10 +466,12 @@ def 
replace_stable_diffusion_submodels(pipeline, submodels):
 def get_encoder_decoder_models_for_export(
     model: "PreTrainedModel",
     task: str,
+    tensor_parallel_size: int,
     input_shapes: Dict[str, int],
     dynamic_batch_size: Optional[bool] = False,
     output_attentions: bool = False,
     output_hidden_states: bool = False,
+    model_name_or_path: Optional[Union[str, Path]] = None,
 ) -> Dict[str, Tuple["PreTrainedModel", "NeuronDefaultConfig"]]:
     """
     Returns the components of an encoder-decoder model and their subsequent neuron configs.
@@ -479,6 +482,10 @@ def get_encoder_decoder_models_for_export(
     Args:
         model ("PreTrainedModel"):
            The model to export.
+        task (`str`):
+            The task to export the model for. If not specified, the task will be auto-inferred based on the model.
+        tensor_parallel_size (`int`):
+            Tensor parallelism size, the number of Neuron cores on which to shard the model.
        input_shapes (`Dict[str, int]`):
            Static shapes used for compiling the encoder and the decoder.
        dynamic_batch_size (`bool`, defaults to `False`):
@@ -487,6 +494,8 @@
            Whether or not for the traced model to return the attentions tensors of all attention layers.
        output_hidden_states (`bool`, defaults to `False`):
            Whether or not for the traced model to return the hidden states of all layers.
+        model_name_or_path (`Optional[Union[str, Path]]`, defaults to `None`):
+            The location from which the model is loaded; it is required when tensor parallelism is enabled, since the model needs to be loaded within the tracing API.

    Returns:
        `Dict[str, Tuple["PreTrainedModel", "NeuronDefaultConfig"]]`: A Dict containing the model and
@@ -507,9 +516,18 @@
         config=model.config,
         task=task,
         dynamic_batch_size=dynamic_batch_size,
+        tensor_parallel_size=tensor_parallel_size,
         **input_shapes,
     )
-    models_for_export[ENCODER_NAME] = (model, encoder_neuron_config)
+    if not tensor_parallel_size > 1:
+        models_for_export[ENCODER_NAME] = (model, encoder_neuron_config)
+    else:
+        if model_name_or_path:
+            models_for_export[ENCODER_NAME] = (model_name_or_path, encoder_neuron_config)
+        else:
+            raise ValueError(
+                f"`model_name_or_path` must be provided when tensor parallelism is enabled, but it is currently {model_name_or_path}."
+            )

     # Decoder
     model_type = getattr(model.config, "model_type") + "-decoder"
@@ -524,10 +542,19 @@
         config=model.config,
         task=task,
         dynamic_batch_size=dynamic_batch_size,
+        tensor_parallel_size=tensor_parallel_size,
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         **input_shapes,
     )
-    models_for_export[DECODER_NAME] = (model, decoder_neuron_config)
+    if not tensor_parallel_size > 1:
+        models_for_export[DECODER_NAME] = (model, decoder_neuron_config)
+    else:
+        if model_name_or_path:
+            models_for_export[DECODER_NAME] = (model_name_or_path, decoder_neuron_config)
+        else:
+            raise ValueError(
+                f"`model_name_or_path` must be provided when tensor parallelism is enabled, but it is currently {model_name_or_path}."
+ ) return models_for_export diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py index 2eca15285..6cb53d1c0 100644 --- a/optimum/neuron/modeling_seq2seq.py +++ b/optimum/neuron/modeling_seq2seq.py @@ -26,6 +26,7 @@ import torch from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig +from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.utils import ModelOutput @@ -43,6 +44,7 @@ ENCODER_NAME, NEURON_FILE_NAME, is_neuronx_available, + is_neuronx_distributed_available, ) @@ -52,8 +54,40 @@ if is_neuronx_available(): import torch_neuronx +if is_neuronx_distributed_available(): + import neuronx_distributed + logger = logging.getLogger(__name__) +_TOKENIZER_FOR_DOC = "AutoTokenizer" + +NEURON_SEQ2SEQ_MODEL_START_DOCSTRING = r""" + This model inherits from [`~neuron.modeling.NeuronTracedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving) + + Args: + encoder (`torch.jit._script.ScriptModule`): [torch.jit._script.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html) is the TorchScript module of the encoder with embedded NEFF(Neuron Executable File Format) compiled by neuron(x) compiler. + decoder (`torch.jit._script.ScriptModule`): [torch.jit._script.ScriptModule](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html) is the TorchScript module of the decoder with embedded NEFF(Neuron Executable File Format) compiled by neuron(x) compiler. + config (`transformers.PretrainedConfig`): [PretrainedConfig](https://huggingface.co/docs/transformers/main_classes/configuration#transformers.PretrainedConfig) is the Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`optimum.neuron.modeling.NeuronTracedModel.from_pretrained`] method to load the model weights. +""" + +NEURON_SEQ2SEQ_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using [`AutoTokenizer`](https://huggingface.co/docs/transformers/autoclass_tutorial#autotokenizer). + See [`PreTrainedTokenizer.encode`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.encode) and + [`PreTrainedTokenizer.__call__`](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizerBase.__call__) for details. + [What are input IDs?](https://huggingface.co/docs/transformers/glossary#input-ids) + attention_mask (`Union[torch.Tensor, None]` of shape `({0})`, defaults to `None`): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ [What are attention masks?](https://huggingface.co/docs/transformers/glossary#attention-mask) +""" + class NeuronModelForConditionalGeneration(NeuronTracedModel, ABC): base_model_prefix = "neuron_model" @@ -71,7 +105,6 @@ def __init__( neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None, configs: Optional[Dict[str, "PretrainedConfig"]] = None, generation_config: Optional[GenerationConfig] = None, - model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, **kwargs, ): self.config = config @@ -81,7 +114,6 @@ def __init__( self.neuron_configs[ENCODER_NAME] ) # only for the encoder self._attributes_init(model_save_dir, preprocessors, **kwargs) - self.model_and_config_save_paths = model_and_config_save_paths if model_and_config_save_paths else None self.encoder = NeuronEncoder( encoder, self, @@ -103,55 +135,39 @@ def __init__( if generation_config is None: generation_config = GenerationConfig.from_model_config(self.configs[DECODER_NAME]) self.generation_config = generation_config + self.tensor_parallel_size = self.neuron_configs[DECODER_NAME].tensor_parallel_size - def _save_pretrained( - self, - save_directory: Union[str, Path], - encoder_file_name: str = NEURON_FILE_NAME, - decoder_file_name: str = NEURON_FILE_NAME, - ): + def _save_pretrained(self, save_directory: Union[str, Path]): """ - Saves the model encoder and decoder as well as their configuration files to a - directory, so that it can be re-loaded using the - [`~optimum.neuron.modeling_seq2seq.NeuronModelForSeq2SeqLM.from_pretrained`] class method. + Saves a model and its configuration file to a directory, so that it can be re-loaded using the + [`~optimum.neuron.modeling_traced.NeuronTracedModel.from_pretrained`] class method. Args: - save_directory (`Union[str, Path`]): - The directory where to save the model files. - encoder_file_name (`str`, defaults to `NEURON_FILE_NAME`]): - The file name to save the encoder. - decoder_file_name (`str`, defaults to `NEURON_FILE_NAME`]): - The file name to save the decoder. + save_directory (`Union[str, Path]`): + Directory where to save the model file. """ - if self.model_and_config_save_paths is None: - logger.warning( - "`model_save_paths` is None which means that no path of Neuron model is defined. Nothing will be saved." 
- ) - return - - save_directory = Path(save_directory) - if not self.model_and_config_save_paths.get(ENCODER_NAME)[0].is_file(): - self.model_and_config_save_paths.pop(ENCODER_NAME) - - if not self.model_and_config_save_paths.get(DECODER_NAME)[0].is_file(): - self.model_and_config_save_paths.pop(DECODER_NAME) - - dst_paths = [ - save_directory / ENCODER_NAME / encoder_file_name, - save_directory / DECODER_NAME / decoder_file_name, - ] - src_paths = [ - Path(self.model_and_config_save_paths[ENCODER_NAME][0]), - Path(self.model_and_config_save_paths[DECODER_NAME][0]), - ] - - for src_path, dst_path in zip(src_paths, dst_paths): - dst_path.parent.mkdir(parents=True, exist_ok=True) - if src_path.is_file(): - shutil.copyfile(src_path, dst_path) - + shutil.copytree(self.model_save_dir, save_directory, dirs_exist_ok=True) self.generation_config.save_pretrained(save_directory) + @staticmethod + def load_model( + encoder_path: Union[str, Path], + decoder_path: Union[str, Path], + tensor_parallel_size: int, + ): + if tensor_parallel_size == 1: + # Initialize Neuron Runtime before loading models + runtime = torch.classes.neuron.Runtime() + runtime.initialize() + runtime.set_default_neuron_cores(0, 1) + encoder = NeuronTracedModel.load_model(encoder_path) + decoder = NeuronTracedModel.load_model(decoder_path) + torch_neuronx.move_trace_to_device(decoder, 0) + else: + encoder = neuronx_distributed.trace.parallel_model_load(encoder_path) + decoder = neuronx_distributed.trace.parallel_model_load(decoder_path) + return encoder, decoder + @classmethod def _from_pretrained( cls, @@ -205,14 +221,11 @@ def _from_pretrained( configs[name] = model_config neuron_configs[name] = cls._neuron_config_init(model_config) - # Initialize Neuron Runtime before loading models - runtime = torch.classes.neuron.Runtime() - runtime.initialize() - runtime.set_default_neuron_cores(0, 1) - - encoder = cls.load_model(model_and_config_save_paths[ENCODER_NAME][0]) - decoder = cls.load_model(model_and_config_save_paths[DECODER_NAME][0]) - torch_neuronx.move_trace_to_device(decoder, 0) + encoder, decoder = cls.load_model( + encoder_path=model_and_config_save_paths[ENCODER_NAME][0], + decoder_path=model_and_config_save_paths[DECODER_NAME][0], + tensor_parallel_size=configs["decoder"].neuron.get("tensor_parallel_size", 1), + ) if model_save_dir is None: model_save_dir = new_model_save_dir @@ -226,7 +239,6 @@ def _from_pretrained( local_files_only=local_files_only, token=token, revision=revision, - subfolder=os.path.join(subfolder, DECODER_NAME), ) except OSError: logger.info("Generation config file not found, using a generation config created from the model config.") @@ -242,7 +254,6 @@ def _from_pretrained( neuron_configs=neuron_configs, configs=configs, generation_config=generation_config, - model_and_config_save_paths=model_and_config_save_paths, ) @classmethod @@ -260,6 +271,7 @@ def _export( force_download: bool = True, cache_dir: Optional[str] = None, compiler_workdir: Optional[str] = None, + tensor_parallel_size: Optional[int] = 1, inline_weights_to_neff: bool = True, optlevel: str = "2", subfolder: str = "", @@ -299,6 +311,7 @@ def _export( model_name_or_path=model_id, output=save_dir_path, compiler_kwargs=compiler_kwargs, + tensor_parallel_size=tensor_parallel_size, task=task, dynamic_batch_size=dynamic_batch_size, cache_dir=cache_dir, @@ -350,10 +363,76 @@ def _combine_encoder_decoder_config(self, encoder_config: "PretrainedConfig", de return combined_config +TRANSLATION_EXAMPLE = r""" + *(Following models are compiled with 
neuronx compiler and can only be run on INF2.)* + Example of text-to-text generation with small T5 model: + + ```python + from transformers import {processor_class} + from optimum.neuron import {model_class} + + neuron_model = {model_class}.from_pretrained({checkpoint_regular}, export=True, dynamic_batch_size=False, batch_size=1, sequence_length=64, num_beams=4) + neuron_model.save_pretrained("t5_small_neuronx") + del neuron_model + + neuron_model = {model_class}.from_pretrained("t5_small_neuronx") + tokenizer = {processor_class}.from_pretrained("t5_small_neuronx") + inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt") + + output = neuron_model.generate( + **inputs, + num_return_sequences=1, + ) + results = [tokenizer.decode(t, skip_special_tokens=True) for t in output] + ``` + + *(For large models, in order to fit into Neuron cores, we need to apply tensor parallelism. Here below is an example ran on `inf2.24xlarge`.)* + Example of text-to-text generation with tensor parallelism: + + ```python + from transformers import {processor_class} + from optimum.neuron import {model_class} + # 1. compile + if __name__ == "__main__": # compulsory for parallel tracing since the API will spawn multiple processes. + neuron_model = {model_class}.from_pretrained( + {checkpoint_tp}, export=True, tensor_parallel_size=8, dynamic_batch_size=False, batch_size=1, sequence_length=128, num_beams=4, + ) + neuron_model.save_pretrained("flan_t5_xl_neuronx_tp8/") + del neuron_model + + # 2. inference + neuron_model = {model_class}.from_pretrained("flan_t5_xl_neuronx_tp8") + tokenizer = {processor_class}.from_pretrained("flan_t5_xl_neuronx_tp8") + inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt") + + output = neuron_model.generate( + **inputs, + num_return_sequences=1, + ) + results = [tokenizer.decode(t, skip_special_tokens=True) for t in output] + ``` +""" # noqa: W293 + + +@add_start_docstrings( + """ + Neuron Sequence-to-sequence model with a language modeling head for text2text-generation tasks. + """, + NEURON_SEQ2SEQ_MODEL_START_DOCSTRING, +) class NeuronModelForSeq2SeqLM(NeuronModelForConditionalGeneration, NeuronGenerationMixin): auto_model_class = AutoModelForSeq2SeqLM main_input_name = "input_ids" + @add_start_docstrings_to_model_forward( + NEURON_SEQ2SEQ_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + TRANSLATION_EXAMPLE.format( + processor_class=_TOKENIZER_FOR_DOC, + model_class="NeuronModelForSeq2SeqLM", + checkpoint_regular="google-t5/t5-small", + checkpoint_tp="google/flan-t5-xl", + ) + ) def forward( self, attention_mask: Optional[torch.FloatTensor] = None, @@ -438,9 +517,16 @@ def generate( axis=1, ) - # copy the new cache state to the decoder - for state, tensor in zip(self.decoder.model.parameters(), past_key_values): - state.copy_(tensor) + if self.tensor_parallel_size == 1: + # copy the new cache state to the decoder + for state, tensor in zip(self.decoder.model.parameters(), past_key_values): + state.copy_(tensor) + else: + # Here we iterate sharded encoders and decoders since the encoder on each rank will return cache as device tensors, + # we want to assign them to the cache of the sharded decoder on the same rank to avoid the copy. The KV cache always + # use pre-allocated memory, no host-device communication overhead. 
+ for decoder_tp, encoder_tp in zip(self.decoder.model.models, self.encoder.model.models): + decoder_tp.load_state_dict(encoder_tp.state_dict(), strict=False) output = super().generate( **inputs, diff --git a/optimum/neuron/modeling_traced.py b/optimum/neuron/modeling_traced.py index a8235c029..708e2c38e 100644 --- a/optimum/neuron/modeling_traced.py +++ b/optimum/neuron/modeling_traced.py @@ -436,6 +436,7 @@ def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronDefaultConfig # Fetch compiler information compiler_type = neuron_config.get("compiler_type") compiler_version = neuron_config.get("compiler_version") + tensor_parallel_size = neuron_config.get("tensor_parallel_size", 1) # Fetch mandatory shapes from config compile_shapes = { @@ -461,6 +462,7 @@ def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronDefaultConfig dynamic_batch_size=neuron_config.get("dynamic_batch_size", False), compiler_type=compiler_type, compiler_version=compiler_version, + tensor_parallel_size=tensor_parallel_size, **compile_shapes, ) diff --git a/optimum/neuron/utils/argument_utils.py b/optimum/neuron/utils/argument_utils.py index ecf57e621..8546802bd 100644 --- a/optimum/neuron/utils/argument_utils.py +++ b/optimum/neuron/utils/argument_utils.py @@ -144,6 +144,7 @@ def store_compilation_config( compiler_version: str, inline_weights_to_neff: bool, optlevel: str, + tensor_parallel_size: int = 1, model_type: Optional[str] = None, task: Optional[str] = None, input_names: Optional[List[str]] = None, @@ -170,6 +171,7 @@ def store_compilation_config( config_args[axis] = shape config_args["dynamic_batch_size"] = dynamic_batch_size + config_args["tensor_parallel_size"] = tensor_parallel_size # Add compilation args to the config config_args["optlevel"] = optlevel diff --git a/tests/cli/test_export_cli.py b/tests/cli/test_export_cli.py index cdce97e30..863a9f41a 100644 --- a/tests/cli/test_export_cli.py +++ b/tests/cli/test_export_cli.py @@ -303,6 +303,9 @@ def test_replace_unet(self): check=True, ) + @unittest.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." + ) @requires_neuronx def test_encoder_decoder(self): model_id = "hf-internal-testing/tiny-random-t5" @@ -332,6 +335,9 @@ def test_encoder_decoder(self): check=True, ) + @unittest.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." + ) @requires_neuronx def test_encoder_decoder_optional_outputs(self): model_id = "hf-internal-testing/tiny-random-t5" @@ -362,3 +368,37 @@ def test_encoder_decoder_optional_outputs(self): shell=False, check=True, ) + + @unittest.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." 
+ ) + @requires_neuronx + def test_encoder_decoder_tp2(self): + model_id = "michaelbenayoun/t5-tiny-random" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + [ + "optimum-cli", + "export", + "neuron", + "--model", + model_id, + "--task", + "text2text-generation", + "--tensor_parallel_size", + "2", + "--batch_size", + "1", + "--sequence_length", + "18", + "--num_beams", + "4", + "--auto_cast", + "matmul", + "--auto_cast_type", + "bf16", + tempdir, + ], + shell=False, + check=True, + ) diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index 28aeb219f..dcf2b09dd 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -152,7 +152,7 @@ def _neuronx_export( with NamedTemporaryFile("w") as output: try: _, neuron_outputs = export( - model=model, + model_or_path=model, config=neuron_config, output=Path(output.name), inline_weights_to_neff=inline_weights_to_neff, @@ -310,6 +310,9 @@ def test_export_sd_with_fused_lora_weights(self): ) +@unittest.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." +) @is_inferentia_test @requires_neuronx class NeuronEncoderDecoderExportTestCase(unittest.TestCase): diff --git a/tests/generation/conftest.py b/tests/generation/conftest.py index 0d19e865d..7f2b1b2f9 100644 --- a/tests/generation/conftest.py +++ b/tests/generation/conftest.py @@ -129,6 +129,28 @@ def neuron_seq2seq_greedy_path_with_optional_outputs(export_seq2seq_id): yield model_path +@pytest.fixture(scope="module") +@requires_neuronx +def neuron_seq2seq_tp2_path(): + model = NeuronModelForSeq2SeqLM.from_pretrained( + "michaelbenayoun/t5-tiny-random", + export=True, + tensor_parallel_size=2, + dynamic_batch_size=False, + batch_size=1, + sequence_length=64, + num_beams=4, + ) + model_dir = TemporaryDirectory() + model_path = model_dir.name + model.save_pretrained(model_path) + del model + # Yield instead of returning to keep a reference to the temporary directory. + # It will go out of scope and be released only once all tests needing the fixture + # have been completed. + yield model_path + + @pytest.fixture(scope="module") def neuron_push_seq2seq_id(export_seq2seq_id): model_name = export_seq2seq_id.split("/")[-1] diff --git a/tests/generation/test_export.py b/tests/generation/test_export.py index 7737274ef..cc6065f2d 100644 --- a/tests/generation/test_export.py +++ b/tests/generation/test_export.py @@ -20,6 +20,9 @@ from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx +@pytest.mark.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." +) @pytest.mark.parametrize( "batch_size, sequence_length, num_beams", [ @@ -40,6 +43,9 @@ def test_seq2seq_export(export_seq2seq_id, batch_size, sequence_length, num_beam return model +@pytest.mark.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." 
+) @is_inferentia_test @requires_neuronx def test_seq2seq_model_from_path(neuron_seq2seq_greedy_path): diff --git a/tests/generation/test_generate.py b/tests/generation/test_generate.py index 6bb4ceca1..d5de5e018 100644 --- a/tests/generation/test_generate.py +++ b/tests/generation/test_generate.py @@ -35,6 +35,9 @@ def _test_model_generation_trn(model, tokenizer, batch_size, input_length, **gen assert sample_output.shape[0] == batch_size +@pytest.mark.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." +) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_beam(neuron_seq2seq_beam_path): @@ -55,6 +58,9 @@ def test_seq2seq_generation_beam(neuron_seq2seq_beam_path): assert len(output[0].unique()) <= 5 + 1 # +1 for `decoder_start_token_id` +@pytest.mark.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." +) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_beam_with_optional_outputs(neuron_seq2seq_beam_path_with_optional_outputs): @@ -77,6 +83,9 @@ def test_seq2seq_generation_beam_with_optional_outputs(neuron_seq2seq_beam_path_ assert "decoder_hidden_states" in output +@pytest.mark.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." +) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_greedy(neuron_seq2seq_greedy_path): @@ -97,6 +106,9 @@ def test_seq2seq_generation_greedy(neuron_seq2seq_greedy_path): assert len(output[0]) <= 5 + 1 # +1 for `decoder_start_token_id` +@pytest.mark.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." +) @is_inferentia_test @requires_neuronx def test_seq2seq_generation_greedy_with_optional_outputs(neuron_seq2seq_greedy_path_with_optional_outputs): @@ -117,6 +129,29 @@ def test_seq2seq_generation_greedy_with_optional_outputs(neuron_seq2seq_greedy_p assert "decoder_hidden_states" in output +@pytest.mark.skip( + "T5 compilation broken since neuron sdk 2.20, wait for the fix: https://github.com/aws-neuron/aws-neuron-sdk/issues/1013." +) +@is_inferentia_test +@requires_neuronx +def test_seq2seq_generation_tp2(neuron_seq2seq_tp2_path): + model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_tp2_path) + tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_tp2_path) + inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt") + + output = model.generate( + **inputs, + num_return_sequences=1, + max_length=20, + output_attentions=True, + output_hidden_states=True, + return_dict_in_generate=True, + ) + assert "decoder_attentions" in output + assert "cross_attentions" in output + assert "decoder_hidden_states" in output + + @pytest.mark.skip("Makes pytest fail, to fix.") @pytest.mark.parametrize( "gen_kwargs", @@ -160,3 +195,10 @@ def test_general_seq2seq_generation(export_seq2seq_id, export_seq2seq_model_clas model = export_seq2seq_model_class.from_pretrained(export_seq2seq_id) tokenizer = AutoTokenizer.from_pretrained(export_seq2seq_id) _test_model_generation_trn(model, tokenizer, 1, 10, **gen_kwargs) + + +# Compulsory for multiprocessing tests, since we want children processes to be spawned only in the main program. +# eg. tensor parallel tracing, `neuronx_distributed.parallel_model_trace` will spawn multiple processes to trace +# and compile the model. 
+if __name__ == "__main__": + pytest.main([__file__])
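
The workflow this patch introduces (export with `tensor_parallel_size`, save, reload through the sharded loader, generate) can be exercised end to end as in the sketch below. This is a minimal illustration assembled from the `TRANSLATION_EXAMPLE` docstring and the `neuron_seq2seq_tp2_path` fixture above: the checkpoint, output directory and `tensor_parallel_size=2` are illustrative assumptions, and the `if __name__ == "__main__":` guard is required because parallel tracing spawns worker processes.

```python
# Minimal sketch of the tensor-parallel seq2seq export/inference flow.
# Assumptions: the tiny test checkpoint, the output directory name and
# tensor_parallel_size=2 are illustrative, not required values.
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForSeq2SeqLM

if __name__ == "__main__":
    # Parallel tracing spawns worker processes, so the export must run
    # from the main program.
    model = NeuronModelForSeq2SeqLM.from_pretrained(
        "michaelbenayoun/t5-tiny-random",  # tiny checkpoint used by the tests above
        export=True,
        tensor_parallel_size=2,  # shard the encoder/decoder over 2 Neuron cores
        dynamic_batch_size=False,
        batch_size=1,
        sequence_length=64,
        num_beams=4,
    )
    model.save_pretrained("t5_tiny_neuronx_tp2")  # sharded NEFFs + config + generation config
    del model

    # Reloading picks up `tensor_parallel_size` from the exported config and
    # goes through the sharded (parallel) model loader instead of the regular one.
    model = NeuronModelForSeq2SeqLM.from_pretrained("t5_tiny_neuronx_tp2")
    tokenizer = AutoTokenizer.from_pretrained("michaelbenayoun/t5-tiny-random")
    inputs = tokenizer("translate English to German: Let's eat good food.", return_tensors="pt")
    output = model.generate(**inputs, num_return_sequences=1, max_length=20)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```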