diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2416dd141..f71a67b51 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -41,6 +41,10 @@ - local: guides/pipelines title: Inference pipelines with AWS Neuron title: How-To Guides + - sections: + - local: community/contributing + title: Add support for a new model architecture + title: Contribute - sections: - local: package_reference/trainer title: Neuron Trainer diff --git a/docs/source/community/contributing.mdx b/docs/source/community/contributing.mdx new file mode 100644 index 000000000..a53e54297 --- /dev/null +++ b/docs/source/community/contributing.mdx @@ -0,0 +1,115 @@ + + +# Adding support for new architectures + + + +> **_NOTE:_** ❗This section does not apply to the decoder model’s inference with autoregressive sampling integrated through `transformers-neuronx`. If you want to add support for these models, please open an issue on the Optimum Neuron GitHub repo, and ping maintainers for help. + +You want to export and run a new model on AWS Inferentia or Trainium? Check the guideline, and submit a pull request to [🤗 Optimum Neuron's GitHub repo](https://github.com/huggingface/optimum-neuron/)! + +To support a new model architecture in the Optimum Neuron library here are some steps to follow: + +1. Implement a custom Neuron configuration. +2. Export and validate the model. +3. Contribute to the GitHub repo. + +## Implement a custom Neuron configuration + +To support the export of a new model to a Neuron compatible format, the first thing to do is to define a Neuron configuration, describing how to export the PyTorch model by specifying: + +1. The input names. +2. The output names. +3. The dummy inputs used to trace the model: the Neuron Compiler records the computational graph via tracing and works on the resulting `TorchScript` module. +4. The compilation arguments used to control the trade-off between hardware efficiency (latency, throughput) and accuracy. + +Depending on the choice of model and task, we represent the data above with configuration classes. Each configuration class is associated with +a specific model architecture, and follows the naming convention `ArchitectureNameNeuronConfig`. For instance, the configuration that specifies the Neuron +export of BERT models is `BertNeuronConfig`. + +Since many architectures share similar properties for their Neuron configuration, 🤗 Optimum adopts a 3-level class hierarchy: + +1. Abstract and generic base classes. These handle all the fundamental features, while being agnostic to the modality (text, image, audio, etc). +2. Middle-end classes. These are aware of the modality. Multiple config classes could exist for the same modality, depending on the inputs they support. They specify which input generators should be used for generating the dummy inputs, but remain model-agnostic. +3. Model-specific classes like the `BertNeuronConfig` mentioned above. These are the ones actually used to export models. + +### Example: Adding support for ESM models + +Here we take the support of [ESM models](https://huggingface.co/docs/transformers/model_doc/esm#esm) as an example. Let's create an `EsmNeuronConfig` class in the `optimum/exporters/neuron/model_configs.py`. + +When an Esm model interprets as a text encoder, we are able to inherit from the middle-end class [`TextEncoderNeuronConfig`](https://github.com/huggingface/optimum-neuron/blob/v0.0.18/optimum/exporters/neuron/config.py#L36). +Since the modeling and configuration of Esm is almost the same as BERT when it is interpreted as an encoder, we can use the `NormalizedConfigManager` with `model_type=bert` to normalize the configuration to generate dummy inputs for tracing the model. + +And one last step, since `optimum-neuron` is an extension of `optimum`, we need to register the Neuron config that we create to the [TasksManager](https://huggingface.co/docs/optimum/main/en/exporters/task_manager#optimum.exporters.TasksManager) with the `register_in_tasks_manager` decorator by specifying the model type and supported tasks. + +```python + +@register_in_tasks_manager("esm", *["feature-extraction", "fill-mask", "text-classification", "token-classification"]) +class EsmNeuronConfig(TextEncoderNeuronConfig): + NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert") + ATOL_FOR_VALIDATION = 1e-3 # absolute tolerance to compare for comparing model on CPUs + + @property + def inputs(self) -> List[str]: + return ["input_ids", "attention_mask"] + +``` + +## Export and validate the model + +With the Neuron configuration class that you implemented, now do a quick test if it works as expected: + +* Export + +```bash +optimum-cli export neuron --model facebook/esm2_t33_650M_UR50D --task text-classification --batch_size 1 --sequence_length 16 esm_neuron/ +``` + +During the export [`validate_model_outputs`](https://github.com/huggingface/optimum-neuron/blob/7b18de9ddfa5c664c94051304c651eaf855c3e0b/optimum/exporters/neuron/convert.py#L136) will be called to validate the outputs of your exported Neuron model by comparing them to the results of PyTorch on the CPU. You could also validate the model manually with: + +```python +from optimum.exporters.neuron import validate_model_outputs + +validate_model_outputs( + neuron_config, base_model, neuron_model_path, neuron_named_outputs, neuron_config.ATOL_FOR_VALIDATION +) +``` + +* Inference (optional) + +```python +from transformers import AutoTokenizer +from optimum.neuron import NeuronModelForSequenceClassification + +model = NeuronModelForSequenceClassification.from_pretrained("esm_neuron/") +tokenizer = AutoTokenizer.from_pretrained("esm_neuron/") +inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") +logits = model(**inputs).logits +``` + +## Contribute to the GitHub repo + +We are almost all set. Now submit a pull request to make your work accessible to all community members! + +* Open an issue in the [Optimum Neuron GitHub repo](https://github.com/huggingface/optimum-neuron/issues) to describe the new feature and make it visible to Optimum Neuron's maintainers. +* Add the model to the exporter test in [`optimum-neuron/tests/exporters/exporters_utils.py`](https://github.com/huggingface/optimum-neuron/blob/v0.0.18/tests/exporters/exporters_utils.py) and the inference test in [`optimum-neuron/tests/inference/inference_utils.py`](https://github.com/huggingface/optimum-neuron/blob/v0.0.18/tests/inference/inference_utils.py). +* Open a pull request! (Don't forget to link it to the issue you opened, so that the maintainers could better track it and provide help when needed.) + + + + +We usually test smaller checkpoints to accelerate the CIs, you could find tiny models for testing under the [`Hugging Face Internal Testing Organization`](https://huggingface.co/hf-internal-testing). + + + +You have made a new model accessible on Neuron for the community! Thanks for joining us in the endeavor of democratizing good machine learning 🤗. \ No newline at end of file diff --git a/docs/source/guides/overview.mdx b/docs/source/guides/overview.mdx index 64ebfaa0c..0255ccc75 100644 --- a/docs/source/guides/overview.mdx +++ b/docs/source/guides/overview.mdx @@ -21,9 +21,9 @@ Welcome to the 🤗 Optimum Neuron how-to guides! These guides tackle more advanced topics and will show you how to easily get the best from AWS Trainium / Inferentia: - [How to setup AWS Trainium instance](./setup_aws_instance) -- [How to fine-tune a Transformers model with AWS Trainium](./fine_tune) - [Training and Deployment using Amazon Sagemaker](./sagemaker) - [Neuron model cache](./cache_system) +- [How to fine-tune a Transformers model with AWS Trainium](./fine_tune) - [Distributed training with AWS Neuron](./distributed_training.mdx) - [Export a model to Inferentia](./export_model) - [Neuron Model Inference](./models) diff --git a/docs/source/package_reference/export.mdx b/docs/source/package_reference/export.mdx index 912ae5d81..7f0102ecf 100644 --- a/docs/source/package_reference/export.mdx +++ b/docs/source/package_reference/export.mdx @@ -28,27 +28,6 @@ exporting function according to the environment. Besides, you can check if the exported model is valid via [`~optimum.exporters.neuron.convert.validate_model_outputs`], which compares the compiled model's output on Neuron devices to the PyTorch model's output on CPU. -## Configuration classes for Neuron exports - -Exporting a PyTorch model to a Neuron compiled model involves specifying: - -1. The input names. -2. The output names. -3. The dummy inputs used to trace the model. This is needed by the Neuron Compiler to record the computational graph and convert it to a TorchScript module. -4. The compilation arguments used to control the trade-off between hardware efficiency (latency, throughput) and accuracy. - -Depending on the choice of model and task, we represent the data above with _configuration classes_. Each configuration class is associated with -a specific model architecture, and follows the naming convention `ArchitectureNameNeuronConfig`. For instance, the configuration which specifies the Neuron -export of BERT models is `BertNeuronConfig`. - -Since many architectures share similar properties for their Neuron configuration, 🤗 Optimum adopts a 3-level class hierarchy: - -1. Abstract and generic base classes. These handle all the fundamental features, while being agnostic to the modality (text, image, audio, etc). -2. Middle-end classes. These are aware of the modality, but multiple can exist for the same modality depending on the inputs they support. - They specify which input generators should be used for the dummy inputs, but remain model-agnostic. -3. Model-specific classes like the `BertNeuronConfig` mentioned above. These are the ones actually used to export models. - - ## Supported architectures diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index a08da0826..294928bbb 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -101,6 +101,16 @@ def outputs(self) -> List[str]: return self._TASK_TO_COMMON_OUTPUTS[self.task] +@register_in_tasks_manager("esm", *["feature-extraction", "fill-mask", "text-classification", "token-classification"]) +class EsmNeuronConfig(TextEncoderNeuronConfig): + NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("bert") + ATOL_FOR_VALIDATION = 1e-3 + + @property + def inputs(self) -> List[str]: + return ["input_ids", "attention_mask"] + + @register_in_tasks_manager("flaubert", *COMMON_TEXT_TASKS) class FlaubertNeuronConfig(ElectraNeuronConfig): pass diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c373e5588..b4b8e32b9 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -23,6 +23,7 @@ "deberta-v2": "hf-internal-testing/tiny-random-DebertaV2Model", # Failed for INF1: 'XSoftmax' "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", + "esm": "hf-internal-testing/tiny-random-EsmModel", "flaubert": "flaubert/flaubert_small_cased", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mpnet": "hf-internal-testing/tiny-random-MPNetModel", diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index f59656252..80316db9a 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -178,7 +178,6 @@ def test_export_separated_weights(self, test_name, name, model_name, task, neuro @parameterized.expand(_get_models_to_test(SENTENCE_TRANSFORMERS_MODELS)) @is_inferentia_test - @require_vision @require_sentence_transformers @requires_neuronx def test_export_sentence_transformers(self, test_name, name, model_name, task, neuron_config_constructor):