
Commit cdc2701

NXP Backend: Update documentation to the new scheme (#15219)
Updating per template in #14873
Signed-off-by: Robert Kalmar <[email protected]>
1 parent 3a262ef commit cdc2701

File tree

9 files changed: +296 −120 lines changed

docs/source/backends-nxp.md

Lines changed: 0 additions & 99 deletions
@@ -1,99 +0,0 @@
# NXP eIQ Neutron Backend

This manual page introduces using ExecuTorch with the NXP eIQ Neutron Backend.
NXP offers accelerated machine learning model inference on edge devices.
To learn more about NXP's machine learning acceleration platform, please refer to [the official NXP website](https://www.nxp.com/applications/technologies/ai-and-machine-learning:MACHINE-LEARNING).

<div class="admonition tip">
For the up-to-date status of running ExecuTorch on the Neutron Backend, please visit the <a href="https://github.com/pytorch/executorch/blob/main/backends/nxp/README.md">manual page</a>.
</div>

## Features

ExecuTorch v1.0 supports running machine learning models on selected NXP chips (for now only i.MXRT700).
Among the currently supported machine learning models are:
- Convolution-based neural networks
- Full support for MobileNetV2 and CifarNet

## Prerequisites (Hardware and Software)

To successfully build the ExecuTorch project and convert models for the NXP eIQ Neutron Backend, you will need a computer running Linux.

If you want to test the runtime, you'll also need:
- Hardware with NXP's [i.MXRT700](https://www.nxp.com/products/i.MX-RT700) chip or a testing board like MIMXRT700-AVK
- [MCUXpresso IDE](https://www.nxp.com/design/design-center/software/development-software/mcuxpresso-software-and-tools-/mcuxpresso-integrated-development-environment-ide:MCUXpresso-IDE) or [MCUXpresso Visual Studio Code extension](https://www.nxp.com/design/design-center/software/development-software/mcuxpresso-software-and-tools-/mcuxpresso-for-visual-studio-code:MCUXPRESSO-VSC)

## Using NXP backend

To test converting a neural network model for inference on the NXP eIQ Neutron Backend, you can use our example script:

```shell
# cd to the root of executorch repository
./examples/nxp/aot_neutron_compile.sh [model (cifar10 or mobilenetv2)]
```

For a quick overview of how to convert a custom PyTorch model, take a look at our [example python script](https://github.com/pytorch/executorch/tree/release/1.0/examples/nxp/aot_neutron_compile.py).

### Partitioner API

The partitioner is defined in `NeutronPartitioner` in `backends/nxp/neutron_partitioner.py`. It has the following arguments:
* `compile_spec` - list of key-value pairs defining the compilation, e.g. for specifying the platform (i.MXRT700) and the Neutron Converter flavor.
* `custom_delegation_options` - custom options for specifying node delegation.

### Quantization

Quantization for the Neutron Backend is defined in `NeutronQuantizer` in `backends/nxp/quantizer/neutron_quantizer.py`.
Quantization follows the PT2E workflow, and INT8 quantization is supported. Operators are quantized statically; activations follow an affine and weights a symmetric per-tensor quantization scheme.

#### Supported operators

List of ATen operators supported by the Neutron quantizer:

`abs`, `adaptive_avg_pool2d`, `addmm`, `add.Tensor`, `avg_pool2d`, `cat`, `conv1d`, `conv2d`, `dropout`,
`flatten.using_ints`, `hardtanh`, `hardtanh_`, `linear`, `max_pool2d`, `mean.dim`, `pad`, `permute`, `relu`, `relu_`,
`reshape`, `view`, `softmax.int`, `sigmoid`, `tanh`, `tanh_`

#### Example

To quantize your model, you can either use the PT2E workflow:
```python
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Prepare your model in Aten dialect
aten_model = get_model_in_aten_dialect()
# Prepare calibration inputs, each tuple is one example, example tuple has items for each model input
calibration_inputs: list[tuple[torch.Tensor, ...]] = get_calibration_inputs()
target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09")
quantizer = NeutronQuantizer(neutron_target_spec=target_spec)

m = prepare_pt2e(aten_model, quantizer)
for data in calibration_inputs:
    m(*data)
m = convert_pt2e(m)
```

Or you can use the predefined function for post-training quantization from the NXP backend implementation:
```python
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize

...

target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09")
quantized_graph_module = calibrate_and_quantize(
    aten_model,
    calibration_inputs,
    NeutronQuantizer(neutron_target_spec=target_spec),
)
```

## Runtime Integration

To learn how to run the converted model on NXP hardware, use one of the ExecuTorch runtime example projects from the MCUXpresso IDE example projects list.
For a more fine-grained tutorial, visit [this manual page](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/topics/example_applications.html).

docs/source/backends-overview.md

Lines changed: 20 additions & 20 deletions
@@ -18,20 +18,20 @@ Backends are the bridge between your exported model and the hardware it runs on.
 
 ## Choosing a Backend
 
-| Backend                                                       | Platform(s) | Hardware Type | Typical Use Case                |
-|---------------------------------------------------------------|-------------|---------------|---------------------------------|
-| [XNNPACK](backends/xnnpack/xnnpack-overview.md)               | All         | CPU           | General-purpose, fallback       |
-| [Core ML](/backends/coreml/coreml-overview.md)                | iOS, macOS  | NPU/GPU/CPU   | Apple devices, high performance |
-| [Metal Performance Shaders](/backends/mps/mps-overview.md)    | iOS, macOS  | GPU           | Apple GPU acceleration          |
-| [Vulkan](/backends/vulkan/vulkan-overview.md)                 | Android     | GPU           | Android GPU acceleration        |
-| [Qualcomm](backends-qualcomm)                                 | Android     | NPU           | Qualcomm SoCs                   |
-| [MediaTek](backends-mediatek)                                 | Android     | NPU           | MediaTek SoCs                   |
-| [Arm Ethos-U](/backends/arm-ethos-u/arm-ethos-u-overview.md)  | Embedded    | NPU           | Arm MCUs                        |
-| [Arm VGF](/backends/arm-vgf/arm-vgf-overview.md)              | Android     | GPU           | Arm platforms                   |
-| [OpenVINO](build-run-openvino)                                | Embedded    | CPU/GPU/NPU   | Intel SoCs                      |
-| [NXP](backends-nxp)                                           | Embedded    | NPU           | NXP SoCs                        |
-| [Cadence](backends-cadence)                                   | Embedded    | DSP           | DSP-optimized workloads         |
-| [Samsung Exynos](/backends/samsung/samsung-overview.md)       | Android     | NPU           | Samsung SoCs                    |
+| Backend                                                       | Platform(s) | Hardware Type | Typical Use Case                |
+|---------------------------------------------------------------|-------------|---------------|---------------------------------|
+| [XNNPACK](backends/xnnpack/xnnpack-overview.md)               | All         | CPU           | General-purpose, fallback       |
+| [Core ML](/backends/coreml/coreml-overview.md)                | iOS, macOS  | NPU/GPU/CPU   | Apple devices, high performance |
+| [Metal Performance Shaders](/backends/mps/mps-overview.md)    | iOS, macOS  | GPU           | Apple GPU acceleration          |
+| [Vulkan](/backends/vulkan/vulkan-overview.md)                 | Android     | GPU           | Android GPU acceleration        |
+| [Qualcomm](backends-qualcomm)                                 | Android     | NPU           | Qualcomm SoCs                   |
+| [MediaTek](backends-mediatek)                                 | Android     | NPU           | MediaTek SoCs                   |
+| [Arm Ethos-U](/backends/arm-ethos-u/arm-ethos-u-overview.md)  | Embedded    | NPU           | Arm MCUs                        |
+| [Arm VGF](/backends/arm-vgf/arm-vgf-overview.md)              | Android     | GPU           | Arm platforms                   |
+| [OpenVINO](build-run-openvino)                                | Embedded    | CPU/GPU/NPU   | Intel SoCs                      |
+| [NXP](backends/nxp/nxp-overview.md)                           | Embedded    | NPU           | NXP SoCs                        |
+| [Cadence](backends-cadence)                                   | Embedded    | DSP           | DSP-optimized workloads         |
+| [Samsung Exynos](/backends/samsung/samsung-overview.md)       | Android     | NPU           | Samsung SoCs                    |
 
 **Tip:** For best performance, export a `.pte` file for each backend you plan to support.
 
@@ -50,15 +50,15 @@ Backends are the bridge between your exported model and the hardware it runs on.
 :hidden:
 :caption: Backend Overview
 
-backends/xnnpack/xnnpack-overview
+backends-xnnpack
 backends/coreml/coreml-overview
-backends/mps/mps-overview
-backends/vulkan/vulkan-overview
+backends-mps
+backends-vulkan
 backends-qualcomm
 backends-mediatek
-backends/arm-ethos-u/arm-ethos-u-overview
-backends/arm-vgf/arm-vgf-overview
+backends-arm-ethos-u
+backends-arm-vgf
 build-run-openvino
 backends-nxp
 backends-cadence
-backends/samsung/samsung-overview
+backends-samsung-exynos
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# NXP eIQ Neutron Backend

This manual page introduces the NXP eIQ Neutron backend.
NXP offers accelerated machine learning model inference on edge devices.
To learn more about NXP's machine learning acceleration platform, please refer to [the official NXP website](https://www.nxp.com/applications/technologies/ai-and-machine-learning:MACHINE-LEARNING).

<div class="admonition tip">
For the up-to-date status of running ExecuTorch on the Neutron backend, please visit the <a href="https://github.com/pytorch/executorch/blob/main/backends/nxp/README.md">manual page</a>.
</div>

## Features

ExecuTorch v1.0 supports running machine learning models on selected NXP chips (for now only i.MXRT700).
Among the currently supported machine learning models are:
- Convolution-based neural networks
- Full support for MobileNetV2 and CifarNet

## Target Requirements

- Hardware with NXP's [i.MXRT700](https://www.nxp.com/products/i.MX-RT700) chip or an evaluation board like MIMXRT700-EVK.

## Development Requirements

- [MCUXpresso IDE](https://www.nxp.com/design/design-center/software/development-software/mcuxpresso-software-and-tools-/mcuxpresso-integrated-development-environment-ide:MCUXpresso-IDE) or [MCUXpresso Visual Studio Code extension](https://www.nxp.com/design/design-center/software/development-software/mcuxpresso-software-and-tools-/mcuxpresso-for-visual-studio-code:MCUXPRESSO-VSC)
- [MCUXpresso SDK 25.06](https://mcuxpresso.nxp.com/mcuxsdk/25.06.00/html/index.html)
- eIQ Neutron Converter for MCUXpresso SDK 25.06, which you can install from the eIQ PyPI repository:

```commandline
$ pip install --index-url https://eiq.nxp.com/repository neutron_converter_SDK_25_06
```

Instead of manually installing the requirements (except the MCUXpresso IDE and SDK), you can use the setup script:
```commandline
$ ./examples/nxp/setup.sh
```

## Using NXP eIQ Backend

To test converting a neural network model for inference on the NXP eIQ Neutron backend, you can use our example script:

```shell
# cd to the root of executorch repository
./examples/nxp/aot_neutron_compile.sh [model (cifar10 or mobilenetv2)]
```

For a quick overview of how to convert a custom PyTorch model, take a look at our [example python script](https://github.com/pytorch/executorch/tree/release/1.0/examples/nxp/aot_neutron_compile.py). A condensed ahead-of-time sketch is shown below.
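The following is a minimal sketch of the ahead-of-time flow: quantize with the PT2E workflow, partition with the Neutron partitioner, lower, and serialize a `.pte` file. The quantizer and compile-spec settings mirror the example on the quantization page; the small stand-in model and the final save step are illustrative additions, not part of that example.

```python
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.exir import to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# A small stand-in model; replace it with your own convolution-based network.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.ReLU(),
    torch.nn.Flatten(),
    torch.nn.Linear(8 * 30 * 30, 10),
).eval()
sample_inputs = (torch.randn(1, 3, 32, 32),)

# Quantize with the NeutronQuantizer (PT2E flow).
target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09")
quantizer = NeutronQuantizer(neutron_target_spec=target_spec)
exported = torch.export.export(model, sample_inputs).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(*sample_inputs)  # calibrate with representative inputs
quantized = convert_pt2e(prepared)

# Partition and lower to the Neutron backend, then serialize the .pte file.
compile_spec = generate_neutron_compile_spec(
    "imxrt700",
    operators_not_to_delegate=None,
    neutron_converter_flavor="SDK_25_06",
)
et_program = to_edge_transform_and_lower(
    torch.export.export(quantized, sample_inputs),
    partitioner=[NeutronPartitioner(compile_spec=compile_spec)],
).to_executorch()

with open("model.pte", "wb") as f:
    f.write(et_program.buffer)
```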
## Runtime Integration

To learn how to run the converted model on NXP hardware, use one of the ExecuTorch runtime example projects from the MCUXpresso IDE example projects list.
For a more fine-grained tutorial, visit [this manual page](https://mcuxpresso.nxp.com/mcuxsdk/latest/html/middleware/eiq/executorch/docs/nxp/topics/example_applications.html).

## Reference

**→{doc}`nxp-partitioner` — Partitioner options.**

**→{doc}`nxp-quantization` — Supported quantization schemes.**

**→{doc}`tutorials/nxp-tutorials` — Tutorials.**

```{toctree}
:maxdepth: 2
:hidden:
:caption: NXP Backend

nxp-partitioner
nxp-quantization
tutorials/nxp-tutorials
```
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
===============
Partitioner API
===============

The Neutron partitioner API allows configuring how the model is delegated to Neutron. Passing a ``NeutronPartitioner`` instance with no additional parameters will run as much of the model as possible on the Neutron backend. This is the most common use-case.

It has the following arguments:

* ``compile_spec`` - list of key-value pairs defining the compilation, e.g. the target platform (i.MXRT700) and the Neutron Converter flavor.
* ``custom_delegation_options`` - custom options for specifying node delegation.

--------------------
Compile Spec Options
--------------------
To generate the compile spec for the Neutron backend, you can use the ``generate_neutron_compile_spec`` function or the ``NeutronCompileSpecBuilder().neutron_compile_spec()`` builder directly, as sketched below.
The following fields can be set:

* ``config`` - NXP platform defining the Neutron NPU configuration, e.g. "imxrt700".
* ``neutron_converter_flavor`` - Flavor of the neutron-converter module to use. The neutron-converter module named ``neutron_converter_SDK_25_06`` has the flavor ``SDK_25_06``. Set the flavor to match the MCUXpresso SDK version you will use.
* ``extra_flags`` - Extra flags for the Neutron compiler.
* ``operators_not_to_delegate`` - List of operators that will not be delegated.

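A minimal sketch of building a compile spec and passing it to the partitioner, mirroring the lowering step of the example on the quantization page. The exported, quantized program ``exported_program`` is assumed to already exist.

.. code-block:: python

   from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
   from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
   from executorch.exir import to_edge_transform_and_lower

   # Compile spec for the i.MXRT700 Neutron NPU configuration, matching MCUXpresso SDK 25.06.
   compile_spec = generate_neutron_compile_spec(
       "imxrt700",
       neutron_converter_flavor="SDK_25_06",
       operators_not_to_delegate=None,  # or a list of operators to keep out of the Neutron delegate
   )

   # Delegate as much of the exported program as possible to Neutron.
   et_program = to_edge_transform_and_lower(
       exported_program,
       partitioner=[NeutronPartitioner(compile_spec=compile_spec)],
   ).to_executorch()
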
-------------------------
Custom Delegation Options
-------------------------
By default, the Neutron backend is defensive, which means it does not delegate operators whose delegation cannot be decided statically during partitioning. As the model author you typically have insight into the model, so you can allow opportunistic delegation for some cases. For the list of options, see
`CustomDelegationOptions <https://github.com/pytorch/executorch/blob/release/1.0/backends/nxp/backend/custom_delegation_options.py#L11>`_; a brief sketch of passing these options to the partitioner follows below.

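A hypothetical sketch of passing custom delegation options to the partitioner; the available fields and their defaults are defined in the linked ``CustomDelegationOptions`` source, and the default construction shown here is an assumption.

.. code-block:: python

   from executorch.backends.nxp.backend.custom_delegation_options import CustomDelegationOptions
   from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner

   # Default-constructed options keep the defensive behavior; override individual
   # fields from custom_delegation_options.py to allow opportunistic delegation.
   delegation_options = CustomDelegationOptions()
   partitioner = NeutronPartitioner(
       compile_spec=compile_spec,  # built as shown in the previous section
       custom_delegation_options=delegation_options,
   )
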
================
Operator Support
================

Operators are the building blocks of the ML model. See `IRs <https://docs.pytorch.org/docs/stable/torch.compiler_ir.html>`_ for more information on the PyTorch operator set.

This section lists the Edge operators supported by the Neutron backend.
For the detailed constraints of the operators, see the conditions in the ``is_supported_*`` functions in the `Node converters <https://github.com/pytorch/executorch/blob/release/1.0/backends/nxp/neutron_partitioner.py#L192>`_.


.. csv-table:: Operator Support
   :file: op-support.csv
   :header-rows: 1
   :widths: 20 15 30 30
   :align: center
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# NXP eIQ Neutron Quantization

The eIQ Neutron NPU requires the delegated operators to be quantized. To quantize a PyTorch model for the Neutron backend, use the `NeutronQuantizer` from `backends/nxp/quantizer/neutron_quantizer.py`.
The `NeutronQuantizer` is configured to quantize the model with the quantization scheme supported by the eIQ Neutron NPU.

### Supported Quantization Schemes

The Neutron delegate supports the following quantization schemes:

- Static quantization with 8-bit symmetric weights and 8-bit asymmetric activations (via the PT2E quantization flow), per-tensor granularity.
- The following operators are supported at this time:
  - `aten.abs.default`
  - `aten.adaptive_avg_pool2d.default`
  - `aten.addmm.default`
  - `aten.add.Tensor`
  - `aten.avg_pool2d.default`
  - `aten.cat.default`
  - `aten.conv1d.default`
  - `aten.conv2d.default`
  - `aten.dropout.default`
  - `aten.flatten.using_ints`
  - `aten.hardtanh.default`
  - `aten.hardtanh_.default`
  - `aten.linear.default`
  - `aten.max_pool2d.default`
  - `aten.mean.dim`
  - `aten.mul.Tensor`
  - `aten.pad.default`
  - `aten.permute.default`
  - `aten.relu.default` and `aten.relu_.default`
  - `aten.reshape.default`
  - `aten.view.default`
  - `aten.softmax.int`
  - `aten.tanh.default`, `aten.tanh_.default`
  - `aten.sigmoid.default`
  - `aten.slice_copy.Tensor`

### Static 8-bit Quantization Using the PT2E Flow

To perform 8-bit quantization with the PT2E flow, perform the following steps prior to exporting the model to edge:

1) Create an instance of the `NeutronQuantizer` class.
2) Use `torch.export.export` to export the model to ATen Dialect.
3) Call `prepare_pt2e` with the instance of the `NeutronQuantizer` to annotate the model with observers for quantization.
4) As static quantization is required, run the prepared model with representative samples to calibrate the quantized tensor activation ranges.
5) Call `convert_pt2e` to quantize the model.
6) Export and lower the model using the standard flow.

The output of `convert_pt2e` is a PyTorch model which can be exported and lowered using the normal flow. As it is a regular PyTorch model, it can also be used to evaluate the accuracy of the quantized model using standard PyTorch techniques; a small sanity-check sketch follows the example below.

To quantize the model, you can use the PT2E workflow:

```python
import torch
import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.exir import to_edge_transform_and_lower
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval()
sample_inputs = (torch.randn(1, 3, 224, 224), )

target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09")
quantizer = NeutronQuantizer(neutron_target_spec=target_spec)  # (1)

training_ep = torch.export.export(model, sample_inputs).module()  # (2)
prepared_model = prepare_pt2e(training_ep, quantizer)  # (3)

for cal_sample in [torch.randn(1, 3, 224, 224)]:  # Replace with representative model inputs
    prepared_model(cal_sample)  # (4) Calibrate

quantized_model = convert_pt2e(prepared_model)  # (5)

compile_spec = generate_neutron_compile_spec(
    "imxrt700",
    operators_not_to_delegate=None,
    neutron_converter_flavor="SDK_25_06",
)

et_program = to_edge_transform_and_lower(  # (6)
    torch.export.export(quantized_model, sample_inputs),
    partitioner=[NeutronPartitioner(compile_spec=compile_spec)],
).to_executorch()
```
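As a quick accuracy sanity check, the quantized model returned by `convert_pt2e` can be compared against the float model on sample inputs. A minimal sketch using the names from the example above; the cosine-similarity comparison is just one possible metric, not part of the original example.

```python
import torch.nn.functional as F

with torch.no_grad():
    float_logits = model(*sample_inputs)
    quant_logits = quantized_model(*sample_inputs)

# A value close to 1.0 indicates the quantized model tracks the float model on this input.
similarity = F.cosine_similarity(float_logits.flatten(), quant_logits.flatten(), dim=0)
print(f"Cosine similarity between float and quantized outputs: {similarity.item():.4f}")
```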
90+
Or you can use the predefined function for post training quantization from NXP Backend implementation:
91+
```python
92+
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
93+
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
94+
from executorch.backends.nxp.quantizer.utils import calibrate_and_quantize
95+
96+
...
97+
98+
target_spec = NeutronTargetSpec(target="imxrt700", converter_flavor="SDK_25_09")
99+
quantized_graph_module = calibrate_and_quantize(
100+
aten_model,
101+
calibration_inputs,
102+
NeutronQuantizer(neutron_target_spec=target_spec),
103+
)
104+
```
105+
106+
See [PyTorch 2 Export Post Training Quantization](https://docs.pytorch.org/ao/main/tutorials_source/pt2e_quant_ptq.html) for more information.
