From 7d0dbb59487e52e7c477cf52593749824ea5ae9e Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Thu, 25 Jan 2024 15:39:08 +0100 Subject: [PATCH] Allow exporting decoder models using optimum-cli (#422) * refactor(export): add root config class * feat(decoder): accept batch_size = None * feat(decoder): accept num_cores = None * feat(cli): support exporting decoder models * tests: use NeuronDefaultConfig * doc: update export cli section * test: add export decoder cli test * Apply suggestions from code review Co-authored-by: Michael Benayoun * feat(decoder): check that the host has neuron devices * ci(inf2): move generation tests up * test(decoder): extend fixture scope to session --------- Co-authored-by: Michael Benayoun --- .github/workflows/test_inf2.yml | 8 +-- docs/source/guides/export_model.mdx | 20 +++--- optimum/commands/export/neuronx.py | 6 ++ optimum/exporters/neuron/__init__.py | 4 +- optimum/exporters/neuron/__main__.py | 43 +++++++++++-- optimum/exporters/neuron/base.py | 77 ++++++++++++++++------- optimum/exporters/neuron/config.py | 14 ++--- optimum/exporters/neuron/convert.py | 24 +++---- optimum/exporters/neuron/model_configs.py | 8 +-- optimum/exporters/neuron/utils.py | 10 +-- optimum/neuron/modeling_base.py | 12 ++-- optimum/neuron/modeling_decoder.py | 12 +++- optimum/neuron/modeling_diffusion.py | 10 +-- optimum/neuron/modeling_seq2seq.py | 6 +- tests/cli/test_export_decoder_cli.py | 47 ++++++++++++++ tests/exporters/test_export.py | 4 +- tests/generation/conftest.py | 8 ++- 17 files changed, 219 insertions(+), 94 deletions(-) create mode 100644 tests/cli/test_export_decoder_cli.py diff --git a/.github/workflows/test_inf2.yml b/.github/workflows/test_inf2.yml index a296128ce..e1e8b3015 100644 --- a/.github/workflows/test_inf2.yml +++ b/.github/workflows/test_inf2.yml @@ -43,6 +43,10 @@ jobs: run: | source aws_neuron_venv_pytorch/bin/activate HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/cli + - name: Run generation tests + run: | + source aws_neuron_venv_pytorch/bin/activate + HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation - name: Run exporters tests run: | source aws_neuron_venv_pytorch/bin/activate @@ -51,10 +55,6 @@ jobs: run: | source aws_neuron_venv_pytorch/bin/activate HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/inference - - name: Run generation tests - run: | - source aws_neuron_venv_pytorch/bin/activate - HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m is_inferentia_test tests/generation - name: Run pipelines tests run: | source aws_neuron_venv_pytorch/bin/activate diff --git a/docs/source/guides/export_model.mdx b/docs/source/guides/export_model.mdx index 968195a7e..0e629ceea 100644 --- a/docs/source/guides/export_model.mdx +++ b/docs/source/guides/export_model.mdx @@ -25,7 +25,7 @@ optimum-cli export neuron \ --model bert-base-uncased \ --sequence_length 128 \ --batch_size 1 \ - bert_neuron/ + bert_neuron/ ``` Check out the help for more options: @@ -36,7 +36,7 @@ optimum-cli export neuron --help ## Why compile to Neuron model? -AWS provides two generations of the Inferentia accelerator built for machine learning inference with higher throughput, lower latency but lower cost: [inf2 (NeuronCore-v2)](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf2-arch.html) and [inf1 (NeuronCore-v1)](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf1-arch.html#aws-inf1-arch). +AWS provides two generations of the Inferentia accelerator built for machine learning inference with higher throughput, lower latency but lower cost: [inf2 (NeuronCore-v2)](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf2-arch.html) and [inf1 (NeuronCore-v1)](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/inf1-arch.html#aws-inf1-arch). In production environments, to deploy 🤗 [Transformers](https://huggingface.co/docs/transformers/index) models on Neuron devices, you need to compile your models and export them to a serialized format before inference. Through Ahead-Of-Time (AOT) compilation with Neuron Compiler( [neuronx-cc](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/compiler/neuronx-cc/index.html) or [neuron-cc](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/compiler/neuron-cc/neuron-cc.html) ), your models will be converted to serialized and optimized [TorchScript modules](https://pytorch.org/docs/stable/generated/torch.jit.ScriptModule.html). @@ -49,10 +49,10 @@ To understand a little bit more about the compilation, here are general steps ex Although pre-compilation avoids overhead during the inference, traced Neuron module has some limitations: -* Traced Neuron module will be static, which requires fixed input shapes and data types used passed during the compilation. As the model won't be dynamically recompiled, the inference will fail if any of the above conditions change. +* Traced Neuron module will be static, which requires fixed input shapes and data types used during the compilation. As the model won't be dynamically recompiled, the inference will fail if any of the above conditions change. (*But these limitations could be bypass with [dynamic batching](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/api-reference-guide/inference/api-torch-neuronx-trace.html#dynamic-batching) and [bucketing](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/appnotes/torch-neuron/bucketing-app-note.html#bucketing-app-note)*). * Neuron models are hardware-specialized, which means: - * Models traced with Neuron can no longer be executed in non-Neuron environment. + * Models traced with Neuron can no longer be executed in non-Neuron environment. * Models compiled for inf1 (NeuronCore-v1) are not compatible with inf2 (NeuronCore-v2), and vice versa. In this guide, we'll show you how to export your models to serialized models optimized for Neuron devices. @@ -60,7 +60,7 @@ In this guide, we'll show you how to export your models to serialized models opt 🤗 Optimum provides support for the Neuron export by leveraging configuration objects. -These configuration objects come ready made for a number of model architectures, and are designed to be easily extendable to other architectures. +These configuration objects come ready made for a number of model architectures, and are designed to be easily extendable to other architectures. **To check the supported architectures, go to the [configuration reference page](../package_reference/configuration).** @@ -89,7 +89,7 @@ optimum-cli export neuron --help usage: optimum-cli export neuron [-h] -m MODEL [--task TASK] [--atol ATOL] [--cache_dir CACHE_DIR] [--trust-remote-code] [--compiler_workdir COMPILER_WORKDIR] [--disable-validation] [--auto_cast {none,matmul,all}] - [--auto_cast_type {bf16,fp16,tf32}] [--dynamic-batch-size] [--unet UNET] + [--auto_cast_type {bf16,fp16,tf32}] [--dynamic-batch-size] [--num_cores NUM_CORES] [--unet UNET] [--output_hidden_states] [--output_attentions] [--batch_size BATCH_SIZE] [--sequence_length SEQUENCE_LENGTH] [--num_beams NUM_BEAMS] [--num_choices NUM_CHOICES] [--num_channels NUM_CHANNELS] [--width WIDTH] [--height HEIGHT] @@ -137,6 +137,8 @@ Optional arguments: --dynamic-batch-size Enable dynamic batch size for neuron compiled model. If this option is enabled, the input batch size can be a multiple of the batch size during the compilation, but it comes with a potential tradeoff in terms of latency. + --num_cores NUM_CORES + The number of cores on which the model should be deployed (text-generation only). --unet UNET UNet model ID on huggingface.co or path on disk to load model from. This will replace the unet in the original Stable Diffusion pipeline. --output_hidden_states @@ -173,7 +175,7 @@ Exporting a checkpoint can be done as follows: optimum-cli export neuron --model distilbert-base-uncased-distilled-squad --batch_size 1 --sequence_length 16 distilbert_base_uncased_squad_neuron/ ``` -You should see the following logs which validate the model on Neuron deivces by comparing with PyTorch model on CPU: +You should see the following logs which validate the model on Neuron devices by comparing with PyTorch model on CPU: ```bash Validating Neuron model... @@ -192,7 +194,7 @@ As you can see, the task was automatically detected. This was possible because t optimum-cli export neuron --model local_path --task question-answering --batch_size 1 --sequence_length 16 --dynamic-batch-size distilbert_base_uncased_squad_neuron/ ``` -Note that providing the `--task` argument for a model on the Hub will disable the automatic task detection. The resulting `model.neuron` file, can then be loaded and run on Neuron devices. +Note that providing the `--task` argument for a model on the Hub will disable the automatic task detection. The resulting `model.neuron` file, can then be loaded and run on Neuron devices. ## Exporting a model to Neuron via NeuronModel @@ -204,7 +206,7 @@ You will also be able to export your models to Neuron format with `optimum.neuro >>> input_shapes = {"batch_size": 1, "sequence_length": 64} # mandatory shapes >>> model = NeuronModelForSequenceClassification.from_pretrained( ... "distilbert-base-uncased-finetuned-sst-2-english", export=True, **input_shapes -... ) +... ) # Save the model >>> model.save_pretrained("./distilbert-base-uncased-finetuned-sst-2-english_neuron/") diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py index c711a6c2e..58dba4923 100644 --- a/optimum/commands/export/neuronx.py +++ b/optimum/commands/export/neuronx.py @@ -109,6 +109,12 @@ def parse_args_neuronx(parser: "ArgumentParser"): action="store_true", help="Enable dynamic batch size for neuron compiled model. If this option is enabled, the input batch size can be a multiple of the batch size during the compilation, but it comes with a potential tradeoff in terms of latency.", ) + optional_group.add_argument( + "--num_cores", + type=int, + default=None, + help="The number of cores on which the model should be deployed (text-generation only).", + ) optional_group.add_argument( "--unet", default=None, diff --git a/optimum/exporters/neuron/__init__.py b/optimum/exporters/neuron/__init__.py index 313507c14..c7dd3ec1a 100644 --- a/optimum/exporters/neuron/__init__.py +++ b/optimum/exporters/neuron/__init__.py @@ -24,7 +24,7 @@ "normalize_input_shapes", "normalize_stable_diffusion_input_shapes", ], - "base": ["NeuronConfig"], + "base": ["NeuronDefaultConfig"], "convert": ["export", "export_models", "validate_model_outputs", "validate_models_outputs"], "utils": [ "DiffusersPretrainedConfig", @@ -40,7 +40,7 @@ normalize_input_shapes, normalize_stable_diffusion_input_shapes, ) - from .base import NeuronConfig + from .base import NeuronDefaultConfig from .convert import export, export_models, validate_model_outputs, validate_models_outputs from .utils import ( DiffusersPretrainedConfig, diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 00bd8522d..a4c2eb28c 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -24,6 +24,7 @@ from requests.exceptions import ConnectionError as RequestsConnectionError from transformers import AutoConfig, PretrainedConfig +from ...neuron import NeuronModelForCausalLM from ...neuron.utils import ( DECODER_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, @@ -41,6 +42,7 @@ from ...utils.save_utils import maybe_save_preprocessors from ..error_utils import AtolError, OutputMatchError, ShapeError from ..tasks import TasksManager +from .base import NeuronDecoderConfig from .convert import export_models, validate_models_outputs from .model_configs import * # noqa: F403 from .utils import ( @@ -106,7 +108,7 @@ def infer_task(task: str, model_name_or_path: str) -> str: return task -def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int]: +def get_input_shapes_and_config_class(task: str, args: argparse.Namespace) -> Dict[str, int]: config = AutoConfig.from_pretrained(args.model) model_type = config.model_type.replace("_", "-") @@ -116,9 +118,9 @@ def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int neuron_config_constructor = TasksManager.get_exporter_config_constructor( model_type=model_type, exporter="neuron", task=task ) - mandatory_axes = neuron_config_constructor.func.get_mandatory_axes_for_task(task) - input_shapes = {name: getattr(args, name) for name in mandatory_axes} - return input_shapes + input_args = neuron_config_constructor.func.get_input_args_for_task(task) + input_shapes = {name: getattr(args, name) for name in input_args} + return input_shapes, neuron_config_constructor.func def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]: @@ -457,6 +459,19 @@ def main_export( ) +def decoder_export( + model_name_or_path: str, + output: Union[str, Path], + **kwargs, +): + output = Path(output) + if not output.parent.exists(): + output.parent.mkdir(parents=True) + + model = NeuronModelForCausalLM.from_pretrained(model_name_or_path, export=True, **kwargs) + model.save_pretrained(output) + + def main(): parser = ArgumentParser(f"Hugging Face Optimum {NEURON_COMPILER} exporter") @@ -468,7 +483,6 @@ def main(): task = infer_task(args.task, args.model) is_stable_diffusion = "stable-diffusion" in task is_sentence_transformers = args.library_name == "sentence_transformers" - compiler_kwargs = infer_compiler_kwargs(args) if is_stable_diffusion: input_shapes = normalize_stable_diffusion_input_shapes(args) @@ -477,9 +491,26 @@ def main(): input_shapes = normalize_sentence_transformers_input_shapes(args) submodels = None else: - input_shapes = normalize_input_shapes(task, args) + input_shapes, neuron_config_class = get_input_shapes_and_config_class(task, args) + if NeuronDecoderConfig in inspect.getmro(neuron_config_class): + # TODO: warn about ignored args: + # dynamic_batch_size, compiler_workdir, optlevel, + # atol, disable_validation, library_name + decoder_export( + model_name_or_path=args.model, + output=args.output, + task=task, + cache_dir=args.cache_dir, + trust_remote_code=args.trust_remote_code, + subfolder=args.subfolder, + auto_cast_type=args.auto_cast_type, + num_cores=args.num_cores, + **input_shapes, + ) + return submodels = None + compiler_kwargs = infer_compiler_kwargs(args) optional_outputs = customize_optional_outputs(args) optlevel = parse_optlevel(args) diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py index 5f7277b53..240859e69 100644 --- a/optimum/exporters/neuron/base.py +++ b/optimum/exporters/neuron/base.py @@ -38,9 +38,46 @@ class MissingMandatoryAxisDimension(ValueError): pass -class NeuronConfig(ExportConfig, ABC): +class NeuronConfig(ExportConfig): + """Base class for Neuron exportable models + + Class attributes: + + - INPUT_ARGS (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either: + - An argument name, for instance "batch_size" or "sequence_length", that indicates that the argument can + be passed to export the model, + - Or a tuple containing two elements: + - The first one is either a string or a tuple of strings and specifies for which task(s) the argument is relevant + - The second one is the argument name. + + Input arguments can be mandatory for some export types, as specified in child classes. + + Args: + task (`str`): + The task the model should be exported for. + """ + + INPUT_ARGS = () + + @classmethod + def get_input_args_for_task(cls, task: str) -> Tuple[str]: + axes = [] + for axis in cls.INPUT_ARGS: + if isinstance(axis, tuple): + tasks, name = axis + if not isinstance(tasks, tuple): + tasks = (tasks,) + if task not in tasks: + continue + else: + name = axis + axes.append(name) + return tuple(axes) + + +class NeuronDefaultConfig(NeuronConfig, ABC): """ - Base class for Neuron exportable model describing metadata on how to export the model through the TorchScript format. + Base class for configuring the export of Neuron TorchScript models. Class attributes: @@ -50,14 +87,14 @@ class NeuronConfig(ExportConfig, ABC): [`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs. - ATOL_FOR_VALIDATION (`Union[float, Dict[str, float]]`) -- A float or a dictionary mapping task names to float, where the float values represent the absolute tolerance value to use during model conversion validation. - - MANDATORY_AXES (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either: - - An axis name, for instance "batch_size" or "sequence_length", that indicates that the axis dimension is - needed to export the model, + - INPUT_ARGS (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either: + - An argument name, for instance "batch_size" or "sequence_length", that indicates that the argument MUST + be passed to export the model, - Or a tuple containing two elements: - - The first one is either a string or a tuple of strings and specifies for which task(s) the axis is needed - - The second one is the axis name. + - The first one is either a string or a tuple of strings and specifies for which task(s) the argument must be passed + - The second one is the argument name. - For example: `MANDATORY_AXES = ("batch_size", "sequence_length", ("multiple-choice", "num_choices"))` means that + For example: `INPUT_ARGS = ("batch_size", "sequence_length", ("multiple-choice", "num_choices"))` means that to export the model, the batch size and sequence length values always need to be specified, and that a value for the number of possible choices is needed when the task is multiple-choice. @@ -74,13 +111,12 @@ class NeuronConfig(ExportConfig, ABC): The data type of float tensors, could be ["fp32", "fp16", "bf16"], default to "fp32". The rest of the arguments are used to specify the shape of the inputs the model can take. - They are required or not depending on the model the `NeuronConfig` is designed for. + They are required or not depending on the model the `NeuronDefaultConfig` is designed for. """ NORMALIZED_CONFIG_CLASS = None DUMMY_INPUT_GENERATOR_CLASSES = () ATOL_FOR_VALIDATION: Union[float, Dict[str, float]] = 1e-5 - MANDATORY_AXES = () MODEL_TYPE = None _TASK_TO_COMMON_OUTPUTS = { @@ -165,18 +201,7 @@ def __init__( @classmethod def get_mandatory_axes_for_task(cls, task: str) -> Tuple[str]: - axes = [] - for axis in cls.MANDATORY_AXES: - if isinstance(axis, tuple): - tasks, name = axis - if not isinstance(tasks, tuple): - tasks = (tasks,) - if task not in tasks: - continue - else: - name = axis - axes.append(name) - return tuple(axes) + return cls.get_input_args_for_task(task) @property def task(self) -> str: @@ -343,12 +368,15 @@ def forward(self, *input): return ModelWrapper(model, list(dummy_inputs.keys())) -class NeuronDecoderConfig(ExportConfig): +class NeuronDecoderConfig(NeuronConfig): """ Base class for configuring the export of Neuron Decoder models Class attributes: + - INPUT_ARGS (`Tuple[Union[str, Tuple[Union[str, Tuple[str]]]]]`) -- A tuple where each element is either: + - An argument name, for instance "batch_size" or "sequence_length", that indicates that the argument can + be passed to export the model, - NEURONX_CLASS (`str`) -- the name of the transformers-neuronx class to instantiate for the model. It is a full class name defined relatively to the transformers-neuronx module, e.g. `gpt2.model.GPT2ForSampling` [`~optimum.utils.DummyInputGenerator`] specifying how to create dummy inputs. @@ -359,9 +387,10 @@ class NeuronDecoderConfig(ExportConfig): task (`str`): The task the model should be exported for. """ + INPUT_ARGS = ("batch_size", "sequence_length") NEURONX_CLASS = None - def __init__(self, task): + def __init__(self, task: str): if not is_transformers_neuronx_available(): raise ModuleNotFoundError( "The mandatory transformers-neuronx package is missing. Please install optimum[neuronx]." diff --git a/optimum/exporters/neuron/config.py b/optimum/exporters/neuron/config.py index 01a3ae86a..33e680ef3 100644 --- a/optimum/exporters/neuron/config.py +++ b/optimum/exporters/neuron/config.py @@ -27,31 +27,31 @@ DummyVisionInputGenerator, logging, ) -from .base import NeuronConfig, NeuronDecoderConfig +from .base import NeuronDecoderConfig, NeuronDefaultConfig logger = logging.get_logger(__name__) -class TextEncoderNeuronConfig(NeuronConfig): +class TextEncoderNeuronConfig(NeuronDefaultConfig): """ Handles encoder-based text architectures. """ DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,) - MANDATORY_AXES = ("batch_size", "sequence_length", ("multiple-choice", "num_choices")) + INPUT_ARGS = ("batch_size", "sequence_length", ("multiple-choice", "num_choices")) -class VisionNeuronConfig(NeuronConfig): +class VisionNeuronConfig(NeuronDefaultConfig): """ Handles vision architectures. """ DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,) - MANDATORY_AXES = ("batch_size", "num_channels", "width", "height") + INPUT_ARGS = ("batch_size", "num_channels", "width", "height") -class TextAndVisionNeuronConfig(NeuronConfig): +class TextAndVisionNeuronConfig(NeuronDefaultConfig): """ Handles multi-modal text and vision architectures. """ @@ -67,7 +67,7 @@ class TextNeuronDecoderConfig(NeuronDecoderConfig): pass -class TextSeq2SeqNeuronConfig(NeuronConfig): +class TextSeq2SeqNeuronConfig(NeuronDefaultConfig): """ Handles encoder-decoder-based text architectures. """ diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index 7e0dc6533..08fa1b21b 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -41,7 +41,7 @@ if TYPE_CHECKING: from transformers import PreTrainedModel - from .base import NeuronConfig + from .base import NeuronDefaultConfig if is_neuron_available(): import torch.neuron as neuron # noqa: F811 @@ -67,7 +67,7 @@ def validate_models_outputs( models_and_neuron_configs: Dict[ - str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronConfig"] + str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"] ], neuron_named_outputs: List[List[str]], output_dir: Path, @@ -79,7 +79,7 @@ def validate_models_outputs( The following method validates the Neuron models exported using the `export_models` method. Args: - models_and_neuron_configs (`Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`, `torch.nn.Module`], `NeuronConfig`]]): + models_and_neuron_configs (`Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`, `torch.nn.Module`], `NeuronDefaultConfig`]]): A dictionnary containing the models to export and their corresponding neuron configs. neuron_named_outputs (`List[List[str]]`): The names of the outputs to check. @@ -134,7 +134,7 @@ def validate_models_outputs( def validate_model_outputs( - config: "NeuronConfig", + config: "NeuronDefaultConfig", reference_model: Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"], neuron_model_path: Path, neuron_named_outputs: List[str], @@ -144,7 +144,7 @@ def validate_model_outputs( Validates the export by checking that the outputs from both the reference and the exported model match. Args: - config ([`~optimum.neuron.exporter.NeuronConfig`]: + config ([`~optimum.neuron.exporter.NeuronDefaultConfig`]: The configuration used to export the model. reference_model ([`Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"]`]): The model used for the export. @@ -269,7 +269,7 @@ def validate_model_outputs( def export_models( models_and_neuron_configs: Dict[ - str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronConfig"] + str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"] ], output_dir: Path, compiler_workdir: Optional[Path] = None, @@ -282,7 +282,7 @@ def export_models( Exports a Pytorch model with multiple component models to separate files. Args: - models_and_neuron_configs (`Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], `NeuronConfig`]]): + models_and_neuron_configs (`Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], `NeuronDefaultConfig`]]): A dictionnary containing the models to export and their corresponding neuron configs. output_dir (`Path`): Output directory to store the exported Neuron models. @@ -389,7 +389,7 @@ def export_models( def export( model: "PreTrainedModel", - config: "NeuronConfig", + config: "NeuronDefaultConfig", output: Path, compiler_workdir: Optional[Path] = None, optlevel: str = "2", @@ -418,7 +418,7 @@ def export( def export_neuronx( model: "PreTrainedModel", - config: "NeuronConfig", + config: "NeuronDefaultConfig", output: Path, compiler_workdir: Optional[Path] = None, optlevel: str = "2", @@ -431,7 +431,7 @@ def export_neuronx( Args: model ([`PreTrainedModel`]): The model to export. - config ([`~exporter.NeuronConfig`]): + config ([`~exporter.NeuronDefaultConfig`]): The Neuron configuration associated with the exported model. output (`Path`): Directory to store the exported Neuron model. @@ -553,7 +553,7 @@ def improve_stable_diffusion_loading(config, neuron_model): def export_neuron( model: "PreTrainedModel", - config: "NeuronConfig", + config: "NeuronDefaultConfig", output: Path, auto_cast: Optional[str] = None, auto_cast_type: str = "bf16", @@ -566,7 +566,7 @@ def export_neuron( Args: model ([`PreTrainedModel`]): The model to export. - config ([`~exporter.NeuronConfig`]): + config ([`~exporter.NeuronDefaultConfig`]): The Neuron configuration associated with the exported model. output (`Path`): Directory to store the exported Neuron model. diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index 38ee03a62..5b0f20786 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -256,7 +256,7 @@ def outputs(self) -> List[str]: class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig): CUSTOM_MODEL_WRAPPER = SentenceTransformersCLIPNeuronWrapper ATOL_FOR_VALIDATION = 1e-3 - MANDATORY_AXES = ("batch_size", "sequence_length", "num_channels", "width", "height") + INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height") @property def outputs(self) -> List[str]: @@ -269,7 +269,7 @@ def patch_model_for_export(self, model, dummy_inputs): @register_in_tasks_manager("unet", *["semantic-segmentation"]) class UNetNeuronConfig(VisionNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 - MANDATORY_AXES = ("batch_size", "sequence_length", "num_channels", "width", "height") + INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height") MODEL_TYPE = "unet" CUSTOM_MODEL_WRAPPER = UnetNeuronWrapper NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( @@ -416,7 +416,7 @@ class LLamaNeuronConfig(TextNeuronDecoderConfig): @register_in_tasks_manager("t5-encoder", "text2text-generation") class T5EncoderNeuronConfig(TextSeq2SeqNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 - MANDATORY_AXES = ("batch_size", "sequence_length", "num_beams") + INPUT_ARGS = ("batch_size", "sequence_length", "num_beams") MODEL_TYPE = "t5-encoder" CUSTOM_MODEL_WRAPPER = T5EncoderWrapper NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args( @@ -451,7 +451,7 @@ class BloomNeuronConfig(TextNeuronDecoderConfig): class T5DecoderNeuronConfig(TextSeq2SeqNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 DUMMY_INPUT_GENERATOR_CLASSES = TextSeq2SeqNeuronConfig.DUMMY_INPUT_GENERATOR_CLASSES + (DummyBeamValuesGenerator,) - MANDATORY_AXES = ("batch_size", "sequence_length", "num_beams") + INPUT_ARGS = ("batch_size", "sequence_length", "num_beams") MODEL_TYPE = "t5-decoder" CUSTOM_MODEL_WRAPPER = T5DecoderWrapper NORMALIZED_CONFIG_CLASS = T5LikeNormalizedTextConfig diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index b49817f40..6c9678675 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -67,7 +67,7 @@ if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel - from .base import NeuronConfig + from .base import NeuronDefaultConfig if is_diffusers_available(): from diffusers import ModelMixin, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline @@ -135,7 +135,7 @@ def get_stable_diffusion_models_for_export( vae_encoder_input_shapes: Dict[str, int], vae_decoder_input_shapes: Dict[str, int], dynamic_batch_size: Optional[bool] = False, -) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "NeuronConfig"]]: +) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "NeuronDefaultConfig"]]: """ Returns the components of a Stable Diffusion model and their subsequent neuron configs. These components are chosen because they represent the bulk of the compute in the pipeline, @@ -159,7 +159,7 @@ def get_stable_diffusion_models_for_export( Whether the Neuron compiled model supports dynamic batch size. Returns: - `Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronConfig`]`: A Dict containing the model and + `Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronDefaultConfig`]`: A Dict containing the model and Neuron configs for the different components of the model. """ models_for_export = _get_submodels_for_export_stable_diffusion(pipeline=pipeline, task=task) @@ -354,7 +354,7 @@ def get_encoder_decoder_models_for_export( dynamic_batch_size: Optional[bool] = False, output_attentions: bool = False, output_hidden_states: bool = False, -) -> Dict[str, Tuple["PreTrainedModel", "NeuronConfig"]]: +) -> Dict[str, Tuple["PreTrainedModel", "NeuronDefaultConfig"]]: """ Returns the components of an encoder-decoder model and their subsequent neuron configs. The encoder includes the compute of encoder hidden states and the initialization of KV @@ -374,7 +374,7 @@ def get_encoder_decoder_models_for_export( Whether or not for the traced model to return the hidden states of all layers. Returns: - `Dict[str, Tuple["PreTrainedModel", "NeuronConfig"]]`: A Dict containing the model and + `Dict[str, Tuple["PreTrainedModel", "NeuronDefaultConfig"]]`: A Dict containing the model and Neuron configs for the different components of the model. """ models_for_export = {} diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py index 985f5c265..1f3f0b108 100644 --- a/optimum/neuron/modeling_base.py +++ b/optimum/neuron/modeling_base.py @@ -40,7 +40,7 @@ if TYPE_CHECKING: from transformers import PretrainedConfig - from ..exporters.neuron import NeuronConfig + from ..exporters.neuron import NeuronDefaultConfig logger = logging.getLogger(__name__) @@ -77,7 +77,7 @@ def __init__( model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, model_file_name: Optional[str] = None, preprocessors: Optional[List] = None, - neuron_config: Optional["NeuronConfig"] = None, + neuron_config: Optional["NeuronDefaultConfig"] = None, **kwargs, ): super().__init__(model, config) @@ -132,7 +132,7 @@ def _from_pretrained( subfolder: str = "", local_files_only: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - neuron_config: Optional["NeuronConfig"] = None, + neuron_config: Optional["NeuronDefaultConfig"] = None, **kwargs, ) -> "NeuronBaseModel": model_path = Path(model_id) @@ -395,9 +395,9 @@ def _attributes_init( self.auto_model_class.register(AutoConfig, self.__class__) @classmethod - def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronConfig": + def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronDefaultConfig": """ - Builds a `NeuronConfig` with an instance of the `PretrainedConfig` and the task. + Builds a `NeuronDefaultConfig` with an instance of the `PretrainedConfig` and the task. """ if not hasattr(config, "neuron"): logger.warning( @@ -434,7 +434,7 @@ def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronConfig": ) @classmethod - def get_input_static_shapes(cls, neuron_config: "NeuronConfig") -> Dict[str, int]: + def get_input_static_shapes(cls, neuron_config: "NeuronDefaultConfig") -> Dict[str, int]: """ Gets a dictionary of inputs with their valid static shapes. """ diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py index c38f950ea..f62aecf90 100644 --- a/optimum/neuron/modeling_decoder.py +++ b/optimum/neuron/modeling_decoder.py @@ -181,12 +181,15 @@ def _export( use_auth_token: Optional[str] = None, revision: Optional[str] = None, task: Optional[str] = None, - batch_size: Optional[int] = 1, + batch_size: Optional[int] = None, sequence_length: Optional[int] = None, - num_cores: Optional[int] = 2, + num_cores: Optional[int] = None, auto_cast_type: Optional[str] = "fp32", **kwargs, ) -> "NeuronDecoderModel": + if not os.path.isdir("/sys/class/neuron_device/"): + raise SystemError("Decoder models can only be exported on a neuron platform.") + if task is None: task = TasksManager.infer_task_from_model(cls.auto_model_class) @@ -208,10 +211,15 @@ def _export( model_info = api.repo_info(model_id, revision=revision) checkpoint_revision = model_info.sha + if batch_size is None: + batch_size = 1 # If the sequence_length was not specified, deduce it from the model configuration if sequence_length is None: # Note: for older models, max_position_embeddings is an alias for n_positions sequence_length = config.max_position_embeddings + if num_cores is None: + # Use all available cores + num_cores = len(os.listdir("/sys/class/neuron_device/")) * 2 # Update the config config.neuron = { diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 8761a1f51..34747a191 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -72,7 +72,7 @@ if TYPE_CHECKING: - from ..exporters.neuron import NeuronConfig + from ..exporters.neuron import NeuronDefaultConfig logger = logging.getLogger(__name__) @@ -98,7 +98,7 @@ def __init__( tokenizer_2: Optional[CLIPTokenizer] = None, feature_extractor: Optional[CLIPFeatureExtractor] = None, configs: Optional[Dict[str, "PretrainedConfig"]] = None, - neuron_configs: Optional[Dict[str, "NeuronConfig"]] = None, + neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, ): @@ -132,7 +132,7 @@ def __init__( A model extracting features from generated images to be used as inputs for the `safety_checker` configs (Optional[Dict[str, "PretrainedConfig"]], defaults to `None`): A dictionary configurations for components of the pipeline. - neuron_configs (Optional["NeuronConfig"], defaults to `None`): + neuron_configs (Optional["NeuronDefaultConfig"], defaults to `None`): A list of Neuron configurations. model_save_dir (`Optional[Union[str, Path, TemporaryDirectory]]`, defaults to `None`): The directory under which the exported Neuron models were saved. @@ -677,7 +677,7 @@ def __init__( model: torch.jit._script.ScriptModule, parent_model: NeuronBaseModel, config: Optional[Union[DiffusersPretrainedConfig, PretrainedConfig]] = None, - neuron_config: Optional["NeuronConfig"] = None, + neuron_config: Optional["NeuronDefaultConfig"] = None, model_type: str = "unet", device: Optional[int] = None, ): @@ -829,7 +829,7 @@ def __init__( tokenizer_2: Optional[CLIPTokenizer] = None, feature_extractor: Optional[CLIPFeatureExtractor] = None, configs: Optional[Dict[str, "PretrainedConfig"]] = None, - neuron_configs: Optional[Dict[str, "NeuronConfig"]] = None, + neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, add_watermarker: Optional[bool] = None, diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py index 51efc9d62..048f189b3 100644 --- a/optimum/neuron/modeling_seq2seq.py +++ b/optimum/neuron/modeling_seq2seq.py @@ -31,7 +31,7 @@ from transformers.utils import ModelOutput from ..exporters.neuron import ( - NeuronConfig, + NeuronDefaultConfig, main_export, ) from ..exporters.tasks import TasksManager @@ -68,7 +68,7 @@ def __init__( encoder_file_name: Optional[str] = NEURON_FILE_NAME, decoder_file_name: Optional[str] = NEURON_FILE_NAME, preprocessors: Optional[List] = None, - neuron_configs: Optional[Dict[str, "NeuronConfig"]] = None, + neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None, configs: Optional[Dict[str, "PretrainedConfig"]] = None, generation_config: Optional[GenerationConfig] = None, model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, @@ -531,7 +531,7 @@ def __init__( model: torch.jit._script.ScriptModule, parent_model: NeuronBaseModel, config: Optional["PretrainedConfig"] = None, - neuron_config: Optional["NeuronConfig"] = None, + neuron_config: Optional["NeuronDefaultConfig"] = None, model_type: str = "encoder", device: Optional[int] = None, ): diff --git a/tests/cli/test_export_decoder_cli.py b/tests/cli/test_export_decoder_cli.py new file mode 100644 index 000000000..c2f81d5ad --- /dev/null +++ b/tests/cli/test_export_decoder_cli.py @@ -0,0 +1,47 @@ +import subprocess +from tempfile import TemporaryDirectory + +import pytest +from transformers import AutoConfig + +from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx + + +@is_inferentia_test +@requires_neuronx +@pytest.mark.parametrize("batch_size, sequence_length, auto_cast_type", [[1, 512, "fp16"], [2, 128, "bf16"]]) +@pytest.mark.parametrize("num_cores", [1, 2]) +def test_export_decoder_cli(batch_size, sequence_length, auto_cast_type, num_cores): + model_id = "hf-internal-testing/tiny-random-gpt2" + with TemporaryDirectory() as tempdir: + subprocess.run( + [ + "optimum-cli", + "export", + "neuron", + "--model", + model_id, + "--sequence_length", + f"{sequence_length}", + "--batch_size", + f"{batch_size}", + "--auto_cast_type", + auto_cast_type, + "--num_cores", + f"{num_cores}", + "--task", + "text-generation", + tempdir, + ], + shell=False, + check=True, + ) + # Check exported config + config = AutoConfig.from_pretrained(tempdir) + neuron_config = getattr(config, "neuron", None) + assert neuron_config is not None + assert neuron_config["batch_size"] == batch_size + assert neuron_config["sequence_length"] == sequence_length + assert neuron_config["auto_cast_type"] == auto_cast_type + assert neuron_config["num_cores"] == num_cores + assert neuron_config["checkpoint_id"] == model_id diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index 9ce117176..d7f24bd22 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -26,7 +26,7 @@ from transformers.testing_utils import require_vision from optimum.exporters.neuron import ( - NeuronConfig, + NeuronDefaultConfig, build_stable_diffusion_components_mandatory_shapes, export, export_models, @@ -104,7 +104,7 @@ def _neuronx_export( model_type: str, model_name: str, task: str, - neuron_config_constructor: "NeuronConfig", + neuron_config_constructor: "NeuronDefaultConfig", dynamic_batch_size: bool = False, ): if "sentence-transformers" in model_type: diff --git a/tests/generation/conftest.py b/tests/generation/conftest.py index fba494bf5..6e08cae1c 100644 --- a/tests/generation/conftest.py +++ b/tests/generation/conftest.py @@ -44,7 +44,9 @@ } -@pytest.fixture(scope="module", params=[DECODER_MODEL_NAMES[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES]) +@pytest.fixture( + scope="session", params=[DECODER_MODEL_NAMES[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES] +) def export_decoder_id(request): return request.param @@ -66,7 +68,7 @@ def export_seq2seq_model_class(request): return request.param -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") @requires_neuronx def neuron_decoder_path(export_decoder_id): model = NeuronModelForCausalLM.from_pretrained( @@ -161,7 +163,7 @@ def neuron_seq2seq_greedy_path_with_optional_outputs(export_seq2seq_id): yield model_path -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def neuron_push_decoder_id(export_decoder_id): model_name = export_decoder_id.split("/")[-1] repo_id = f"{USER}/{model_name}-neuronx"