diff --git a/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py index 895be83c..4cc8cb3c 100644 --- a/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py +++ b/beginner_source/onnx/export_simple_model_to_onnx_tutorial.py @@ -2,18 +2,18 @@ """ `Introduction to ONNX `_ || **Exporting a PyTorch model to ONNX** || -`Extending the ONNX Registry `_ +`Extending the ONNX exporter operator support `_ || +`Export a model with control flow to ONNX `_ Export a PyTorch model to ONNX ============================== -**Author**: `Thiago Crepaldi `_ +**Author**: `Ti-Tai Wang `_, `Justin Chu `_, `Thiago Crepaldi `_. .. note:: - As of PyTorch 2.1, there are two versions of ONNX Exporter. - - * ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0 - * ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0 + Starting with PyTorch 2.5, there are two ONNX Exporter options available. + * ``torch.onnx.export(..., dynamo=True)`` is the recommended exporter that leverages ``torch.export`` and Torch FX for graph capture. + * ``torch.onnx.export`` is the legacy approach that relies on the deprecated TorchScript and is no longer recommended for use. """ @@ -21,7 +21,7 @@ # In the `60 Minute Blitz `_, # we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images. # In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the -# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter. +# ONNX format using the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter. # # While PyTorch is great for iterating on the development of models, the model can be deployed to production # using different formats, including `ONNX `_ (Open Neural Network Exchange)! @@ -47,8 +47,7 @@ # # .. code-block:: bash # -# pip install onnx -# pip install onnxscript +# pip install --upgrade onnx onnxscript # # 2. Author a simple image classifier model # ----------------------------------------- @@ -62,17 +61,16 @@ import torch.nn.functional as F -class MyModel(nn.Module): - +class ImageClassifierModel(nn.Module): def __init__(self): - super(MyModel, self).__init__() + super().__init__() self.conv1 = nn.Conv2d(1, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) - def forward(self, x): + def forward(self, x: torch.Tensor): x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = torch.flatten(x, 1) @@ -81,6 +79,7 @@ def forward(self, x): x = self.fc3(x) return x + ###################################################################### # 3. Export the model to ONNX format # ---------------------------------- @@ -88,9 +87,10 @@ def forward(self, x): # Now that we have our model defined, we need to instantiate it and create a random 32x32 input. # Next, we can export the model to ONNX format. -torch_model = MyModel() -torch_input = torch.randn(1, 1, 32, 32) -onnx_program = torch.onnx.dynamo_export(torch_model, torch_input) +torch_model = ImageClassifierModel() +# Create example inputs for exporting the model. The inputs should be a tuple of tensors. +example_inputs = (torch.randn(1, 1, 32, 32),) +onnx_program = torch.onnx.export(torch_model, example_inputs, dynamo=True) ###################################################################### # As we can see, we didn't need any code change to the model. @@ -102,13 +102,14 @@ def forward(self, x): # Although having the exported model loaded in memory is useful in many applications, # we can save it to disk with the following code: -onnx_program.save("my_image_classifier.onnx") +onnx_program.save("image_classifier_model.onnx") ###################################################################### # You can load the ONNX file back into memory and check if it is well formed with the following code: import onnx -onnx_model = onnx.load("my_image_classifier.onnx") + +onnx_model = onnx.load("image_classifier_model.onnx") onnx.checker.check_model(onnx_model) ###################################################################### @@ -124,10 +125,10 @@ def forward(self, x): # :align: center # # -# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after +# Once Netron is open, we can drag and drop our ``image_classifier_model.onnx`` file into the browser or select it after # clicking the **Open model** button. # -# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png +# .. image:: ../../_static/img/onnx/image_classifier_onnx_model_on_netron_web_ui.png # :width: 50% # # @@ -155,18 +156,18 @@ def forward(self, x): import onnxruntime -onnx_input = onnx_program.adapt_torch_inputs_to_onnx(torch_input) -print(f"Input length: {len(onnx_input)}") -print(f"Sample input: {onnx_input}") - -ort_session = onnxruntime.InferenceSession("./my_image_classifier.onnx", providers=['CPUExecutionProvider']) +onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs] +print(f"Input length: {len(onnx_inputs)}") +print(f"Sample input: {onnx_inputs}") -def to_numpy(tensor): - return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() +ort_session = onnxruntime.InferenceSession( + "./image_classifier_model.onnx", providers=["CPUExecutionProvider"] +) -onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)} +onnxruntime_input = {input_arg.name: input_value for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)} -onnxruntime_outputs = ort_session.run(None, onnxruntime_input) +# ONNX Runtime returns a list of outputs +onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0] #################################################################### # 7. Compare the PyTorch results with the ones from the ONNX Runtime @@ -178,8 +179,7 @@ def to_numpy(tensor): # For that, we need to execute the PyTorch model with the same input and compare the results with ONNX Runtime's. # Before comparing the results, we need to convert the PyTorch's output to match ONNX's format. -torch_outputs = torch_model(torch_input) -torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs) +torch_outputs = torch_model(*example_inputs) assert len(torch_outputs) == len(onnxruntime_outputs) for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):