Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/setfit/exporters/onnx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import warnings
from typing import Callable, Optional, Union
from inspect import signature

import numpy as np
import onnx
Expand Down Expand Up @@ -87,9 +88,12 @@ def export_onnx_setfit_model(setfit_model: OnnxSetFitModel, inputs, output_path,
for output_name in output_names:
dynamic_axes_output[output_name] = {0: "batch_size"}

# Move inputs to the right device
# Move inputs to the right device and put them in the right order
forward_params = tuple(signature(setfit_model.model_body.forward).parameters.keys()) # keys of ordered dict are ordered
ordered_kwargs = sorted(inputs.items(), key=lambda param: forward_params.index(param[0]))
ordered_params = [param_value for (_, param_value) in ordered_kwargs]
target = setfit_model.model_body.device
args = tuple(value.to(target) for value in inputs.values())
args = tuple(value.to(target) for value in ordered_params)

setfit_model.eval()
with torch.no_grad():
Expand Down