beginner_source/onnx/export_simple_model_to_onnx_tutorial.py (62 changes: 31 additions & 31 deletions)
@@ -2,26 +2,26 @@
 """
 `Introduction to ONNX <intro_onnx.html>`_ ||
 **Exporting a PyTorch model to ONNX** ||
-`Extending the ONNX Registry <onnx_registry_tutorial.html>`_
+`Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
+`Export a model with control flow to ONNX <export_control_flow_model_to_onnx_tutorial.html>`_
 
 Export a PyTorch model to ONNX
 ==============================
 
-**Author**: `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_
+**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_, `Justin Chu <[email protected]>`_, `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_.
 
 .. note::
-    As of PyTorch 2.1, there are two versions of ONNX Exporter.
-
-    * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
-    * ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0
+    Starting with PyTorch 2.5, there are two ONNX Exporter options available.
+    * ``torch.onnx.export(..., dynamo=True)`` is the recommended exporter that leverages ``torch.export`` and Torch FX for graph capture.
+    * ``torch.onnx.export`` is the legacy approach that relies on the deprecated TorchScript and is no longer recommended for use.
 
 """
 
 ###############################################################################
 # In the `60 Minute Blitz <https://tutorials.pytorch.kr/beginner/deep_learning_60min_blitz.html>`_,
 # we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
 # In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
-# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
+# ONNX format using the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter.
 #
 # While PyTorch is great for iterating on the development of models, the model can be deployed to production
 # using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
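As an aside to the updated note above, the two exporter paths can be compared side by side. This is a minimal sketch and not part of the diff; it assumes PyTorch 2.5+ and uses a throwaway ``nn.Linear`` as a stand-in for any ``nn.Module``:

    import torch

    model = torch.nn.Linear(4, 2)   # toy stand-in model, illustration only
    args = (torch.randn(1, 4),)     # example inputs as a tuple of tensors

    # Recommended path: the dynamo-based exporter returns an ONNXProgram
    # that can be inspected and then saved to disk.
    onnx_program = torch.onnx.export(model, args, dynamo=True)
    onnx_program.save("linear.onnx")

    # Legacy path: the TorchScript-based exporter writes directly to a file.
    torch.onnx.export(model, args, "linear_legacy.onnx")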
@@ -47,8 +47,7 @@
 #
 # .. code-block:: bash
 #
-#    pip install onnx
-#    pip install onnxscript
+#    pip install --upgrade onnx onnxscript
 #
 # 2. Author a simple image classifier model
 # -----------------------------------------
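A quick way to confirm the dependencies from the install hunk above are present, an illustrative addition using the standard-library ``importlib.metadata``:

    from importlib.metadata import version

    for package in ("onnx", "onnxscript"):
        # Raises PackageNotFoundError if the package is not installed.
        print(package, version(package))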
@@ -62,17 +61,16 @@
 import torch.nn.functional as F
 
 
-class MyModel(nn.Module):
-
+class ImageClassifierModel(nn.Module):
     def __init__(self):
-        super(MyModel, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
         self.conv2 = nn.Conv2d(6, 16, 5)
         self.fc1 = nn.Linear(16 * 5 * 5, 120)
         self.fc2 = nn.Linear(120, 84)
         self.fc3 = nn.Linear(84, 10)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
         x = F.max_pool2d(F.relu(self.conv2(x)), 2)
         x = torch.flatten(x, 1)
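A quick sanity check on the renamed ``ImageClassifierModel``, not part of the diff itself: a 1x1x32x32 input becomes 6x14x14 after the first conv and pool, then 16x5x5 after the second, which matches the ``16 * 5 * 5`` fan-in of ``fc1``, so the forward pass returns 10 logits. Assuming it is run inside the tutorial script where the class is defined:

    model = ImageClassifierModel()
    logits = model(torch.randn(1, 1, 32, 32))
    print(logits.shape)  # expected: torch.Size([1, 10])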
@@ -81,16 +79,18 @@ def forward(self, x):
         x = self.fc3(x)
         return x
 
+
 ######################################################################
 # 3. Export the model to ONNX format
 # ----------------------------------
 #
 # Now that we have our model defined, we need to instantiate it and create a random 32x32 input.
 # Next, we can export the model to ONNX format.
 
-torch_model = MyModel()
-torch_input = torch.randn(1, 1, 32, 32)
-onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
+torch_model = ImageClassifierModel()
+# Create example inputs for exporting the model. The inputs should be a tuple of tensors.
+example_inputs = (torch.randn(1, 1, 32, 32),)
+onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True)
 
 ######################################################################
 # As we can see, we didn't need any code change to the model.
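Before saving, the returned ``onnx_program`` can also be inspected in memory. A hedged sketch, assuming the ``ONNXProgram`` API that ships with PyTorch 2.5 (attribute names may differ across versions):

    # A readable dump of the captured ONNX graph.
    print(onnx_program)

    # The underlying ONNX protobuf, for downstream tooling.
    model_proto = onnx_program.model_proto
    print(model_proto.graph.name)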
@@ -102,13 +102,14 @@ def forward(self, x):
 # Although having the exported model loaded in memory is useful in many applications,
 # we can save it to disk with the following code:
 
-onnx_program.save("my_image_classifier.onnx")
+onnx_program.save("image_classifier_model.onnx")
 
 ######################################################################
 # You can load the ONNX file back into memory and check if it is well formed with the following code:
 
 import onnx
-onnx_model = onnx.load("my_image_classifier.onnx")
+
+onnx_model = onnx.load("image_classifier_model.onnx")
 onnx.checker.check_model(onnx_model)
 
 ######################################################################
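``onnx.checker.check_model`` raises on a malformed model rather than returning a status, so a defensive variant of the snippet above could read as follows (illustrative only; ``check_model`` also accepts a file path directly):

    import onnx

    try:
        onnx.checker.check_model("image_classifier_model.onnx")
    except onnx.checker.ValidationError as exc:
        print(f"Exported model is not well formed: {exc}")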
@@ -124,10 +125,10 @@ def forward(self, x):
 #    :align: center
 #
 #
-# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
+# Once Netron is open, we can drag and drop our ``image_classifier_model.onnx`` file into the browser or select it after
 # clicking the **Open model** button.
 #
-# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
+# .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png
 #    :width: 50%
 #
 #
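Besides the browser UI described in this hunk, Netron also ships as a Python package; a possible shortcut, assuming ``pip install netron`` (not covered by this diff):

    import netron

    # Serves the model on localhost and opens it in the default browser.
    netron.start("image_classifier_model.onnx")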
@@ -155,18 +156,18 @@
 
 import onnxruntime
 
-onnx_input = onnx_program.adapt_torch_inputs_to_onnx(torch_input)
-print(f"Input length: {len(onnx_input)}")
-print(f"Sample input: {onnx_input}")
-
-ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider'])
+onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
+print(f"Input length: {len(onnx_inputs)}")
+print(f"Sample input: {onnx_inputs}")
 
-def to_numpy(tensor):
-    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+ort_session = onnxruntime.InferenceSession(
+    "./image_classifier_model.onnx", providers=["CPUExecutionProvider"]
+)
 
-onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}
+onnxruntime_input = {input_arg.name: input_value for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)}
 
-onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+# ONNX Runtime returns a list of outputs
+onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
 
 ####################################################################
 # 7. Compare the PyTorch results with the ones from the ONNX Runtime
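A related convenience that the ONNX Runtime hunk above does not use: the ``ONNXProgram`` returned by the dynamo exporter is itself callable and runs the model through ONNX Runtime internally. A sketch under the assumption that PyTorch 2.5's ``ONNXProgram.__call__`` behaves this way:

    # Equivalent end-to-end run without building the session by hand.
    ort_outputs = onnx_program(*example_inputs)
    print(ort_outputs)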
@@ -178,8 +179,7 @@
 # For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's.
 # Before comparing the results, we need to convert the PyTorch's output to match ONNX's format.
 
-torch_outputs = torch_model(torch_input)
-torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)
+torch_outputs = torch_model(*example_inputs)
 
 assert len(torch_outputs) == len(onnxruntime_outputs)
 for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
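The body of the comparison loop is collapsed in this view. A plausible completion, assuming the tutorial compares elementwise with ``torch.testing.assert_close``; this is an assumption about the hidden lines, not a quote of them:

    # Hypothetical loop body: compare each PyTorch output row with ONNX Runtime's.
    for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
        torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))
    print("PyTorch and ONNX Runtime output matched!")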