Skip to content

Commit 70610f5

Browse files
rolshovenLuca Rolshoven
authored andcommitted
Fixed order of input parameters for onnx export
1 parent 4ebee43 commit 70610f5

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/setfit/exporters/onnx.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import warnings
33
from typing import Callable, Optional, Union
4+
from inspect import signature
45

56
import numpy as np
67
import onnx
@@ -87,9 +88,12 @@ def export_onnx_setfit_model(setfit_model: OnnxSetFitModel, inputs, output_path,
8788
for output_name in output_names:
8889
dynamic_axes_output[output_name] = {0: "batch_size"}
8990

90-
# Move inputs to the right device
91+
# Move inputs to the right device and put them in the right order
92+
forward_params = tuple(signature(setfit_model.model_body.forward).parameters.keys()) # keys of ordered dict are ordered
93+
ordered_kwargs = sorted(inputs.items(), key=lambda param: forward_params.index(param[0]))
94+
ordered_params = [param_value for (_, param_value) in ordered_kwargs]
9195
target = setfit_model.model_body.device
92-
args = tuple(value.to(target) for value in inputs.values())
96+
args = tuple(value.to(target) for value in ordered_params)
9397

9498
setfit_model.eval()
9599
with torch.no_grad():

0 commit comments

Comments
 (0)