Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions optimum/exporters/onnx/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def main_export(
framework: str | None = "pt",
atol: float | None = None,
pad_token_id: int | None = None,
# inference kwargs
inf_kwargs: dict[str, Any] | None = None,
# module_arch_configs
module_arch_fields: dict[str, list[str]] | None = None,
# flag for export_by_inference
export_by_inference: bool = False,
# hub options
subfolder: str = "",
revision: str = "main",
Expand Down Expand Up @@ -416,6 +422,9 @@ def main_export(
use_subprocess=use_subprocess,
do_constant_folding=do_constant_folding,
slim=slim,
inf_kwargs=inf_kwargs,
module_arch_fields=module_arch_fields,
export_by_inference=export_by_inference,
**kwargs_shapes,
)

Expand Down
46 changes: 46 additions & 0 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from optimum.exporters.onnx.utils import (
PickableInferenceSession,
_get_submodels_and_onnx_configs,
_get_submodels_and_tensors_,
recursive_to_device,
)
from optimum.exporters.tasks import TasksManager
Expand Down Expand Up @@ -470,6 +471,7 @@ def export_pytorch(
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
model_kwargs: dict[str, Any] | None = None,
export_by_inference: bool = False,
) -> tuple[list[str], list[str]]:
"""Exports a PyTorch model to an ONNX Intermediate Representation.

Expand Down Expand Up @@ -528,6 +530,9 @@ def export_pytorch(
if input_shapes is None:
input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES

if export_by_inference is True:
input_shapes = {}

# Check that inputs match, and order them properly
dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)

Expand Down Expand Up @@ -628,6 +633,7 @@ def export_models(
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
model_kwargs: dict[str, Any] | None = None,
export_by_inference: bool = False,
) -> tuple[list[list[str]], list[list[str]]]:
"""Exports a Pytorch encoder decoder model to an ONNX Intermediate Representation.
The following method exports the encoder and decoder components of the model as separate
Expand Down Expand Up @@ -696,6 +702,7 @@ def export_models(
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
model_kwargs=model_kwargs,
export_by_inference=export_by_inference,
)
)

Expand All @@ -715,6 +722,7 @@ def export(
no_dynamic_axes: bool = False,
do_constant_folding: bool = True,
model_kwargs: dict[str, Any] | None = None,
export_by_inference: bool = False,
) -> tuple[list[str], list[str]]:
"""Exports a Pytorch model to an ONNX Intermediate Representation.

Expand Down Expand Up @@ -794,6 +802,7 @@ def export(
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
model_kwargs=model_kwargs,
export_by_inference=export_by_inference,
)

else:
Expand All @@ -802,6 +811,8 @@ def export(
)

if not disable_dynamic_axes_fix:
if export_by_inference is True:
input_shapes = {}
config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype)
return export_output

Expand All @@ -826,6 +837,9 @@ def onnx_export_from_model(
use_subprocess: bool = False,
do_constant_folding: bool = True,
slim: bool = False,
inf_kwargs: dict[str,Any] | None = None,
module_arch_fields: dict[str, list[str]] | None = None,
export_by_inference: bool = False,
**kwargs_shapes,
):
"""Full-suite ONNX export function, exporting **from a pre-loaded PyTorch model**. This function is especially useful in case one needs to do modifications on the model, as overriding a forward call, before exporting to ONNX.
Expand Down Expand Up @@ -981,6 +995,16 @@ def onnx_export_from_model(
f"Exporting with a sequence length of 1 a {model_type} model is not supported and can yield unexpected results."
)

if export_by_inference is True:
# inference model to trace input and output tensor shape
models_and_inputs, models_and_outputs = _get_submodels_and_tensors_(
model=model,
inf_kwargs=inf_kwargs,
)
else:
models_and_inputs = None
models_and_outputs = None

onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs(
model=model,
task=task,
Expand All @@ -993,6 +1017,9 @@ def onnx_export_from_model(
_variant=_variant,
library_name=library_name,
model_kwargs=model_kwargs,
models_and_inputs=models_and_inputs,
models_and_outputs=models_and_outputs,
module_arch_fields=module_arch_fields,
)

if library_name != "diffusers":
Expand Down Expand Up @@ -1088,8 +1115,27 @@ def onnx_export_from_model(
no_dynamic_axes=no_dynamic_axes,
do_constant_folding=do_constant_folding,
model_kwargs=model_kwargs,
export_by_inference=export_by_inference,
)

if models_and_outputs is not None:
import json
output_dir = os.path.join(output, "io_binding")
os.makedirs(output_dir, exist_ok=True)

for module_name, dummy_outputs in models_and_outputs.items():
# convert tuple -> list for json
serializable = {
name: list(shape)
for name, shape in dummy_outputs.items()
}

file_path = os.path.join(output_dir, f"{module_name}_outputs.json")

with open(file_path, "w") as f:
json.dump(serializable, f, indent=4)
print(f"Saved: {file_path}")

if optimize is not None:
from optimum.onnxruntime import AutoOptimizationConfig, ORTOptimizer

Expand Down
21 changes: 21 additions & 0 deletions optimum/exporters/onnx/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DummyAudioInputGenerator,
DummyPastKeyValuesGenerator,
DummyTransformerTextInputGenerator,
DummyInputGenerator,
NormalizedTextConfig,
is_transformers_version,
)
Expand Down Expand Up @@ -108,3 +109,23 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return super().generate(
input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
)

class DummyTupleInputGenerator(DummyInputGenerator):

def __init__(self, task: str, config_dim: dict[str, int], **kwargs):
super().__init__()
self.config_dim = config_dim
self.padding_side = "right"

def generate(self, input_name: str,
tensor_shape: tuple[int, ...],
framework: str = "pt",
int_dtype: str = "int64",
float_dtype: str = "fp32"):
if "input_id" in input_name:
min_value = 0
max_value = self.config_dim.get("vocab_size", 1000)
return self.random_int_tensor(list(tensor_shape), max_value, min_value=min_value, framework=framework, dtype=int_dtype)
elif "mask" in input_name:
return self.random_mask_tensor(list(tensor_shape), padding_side=self.padding_side, framework=framework, dtype=int_dtype)
return self.random_float_tensor(list(tensor_shape), framework=framework, dtype=float_dtype)
96 changes: 96 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
DummyMoonshineAudioInputGenerator,
DummySanaTransforemerTextInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
DummyTupleInputGenerator,
)
from optimum.exporters.onnx.model_patcher import (
BigBirdPegasusModelPatcher,
Expand Down Expand Up @@ -2852,3 +2853,98 @@ def outputs(self) -> dict[str, dict[int, str]]:
3: f"latent_width * {up_sampling_factor}",
}
}


class DummyOnnxConfig(OnnxConfig):

NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyTupleInputGenerator,
)

def __init__(
self,
config: PretrainedConfig,
task: str = "text-encoding",
preprocessors: list[Any] | None = None,
int_dtype: str = "int64",
float_dtype: str = "fp16",
model_inputs: dict[str, Any] | None = None,
model_outputs: dict[str, Any] | None = None,
config_dim: dict[str, int] | None = None,
):
super().__init__(config=config, task=task, preprocessors=preprocessors, int_dtype=int_dtype, float_dtype=float_dtype)
self.task = task
self.model_inputs = model_inputs
self.model_outputs = model_outputs
self.dummy_tuple_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](task=task, config_dim=config_dim)
self.config_dim = config_dim

def infer_dynamic_dims(self, tensor_shape: tuple[int, ...], config_dim: dict[str, int], name: str="input") -> dict[int, str]:
dynamic = {}
for idx, dim in enumerate(tensor_shape):
# Batch is always dynamic
if idx == 0:
dynamic[idx] = "batch"
continue

find_match = False
for key, value in config_dim.items():
if value == dim:
find_match = True
break

if find_match is True:
continue

dynamic[idx] = f"{name}_dim_{idx}"
return dynamic

@property
def inputs(self) -> dict[str,dict[int,str]]:
model_inputs_dynamic_axes = {}
if self.task == "text-encoding":
for key, value in self.model_inputs.items():
model_inputs_dynamic_axes[key] = {0: "batch_size", 1: "sequence_length"}
return model_inputs_dynamic_axes
if self.task == "backbone":
for key, value in self.model_inputs.items():
model_inputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, key)
return model_inputs_dynamic_axes
if self.task == "sample_encode":
for key, value in self.model_inputs.items():
model_inputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "encode")
return model_inputs_dynamic_axes
if self.task == "latent_decode":
for key, value in self.model_inputs.items():
model_inputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "decode")
return model_inputs_dynamic_axes
return model_inputs_dynamic_axes

@property
def outputs(self) -> dict[str, dict[int, str]]:
model_outputs_dynamic_axes = {}
if self.task == "text-encoding":
for key, value in self.model_outputs.items():
model_outputs_dynamic_axes[key] = {0: "batch_size", 1: "sequence_length"}
return model_outputs_dynamic_axes
if self.task == "backbone":
for key, value in self.model_outputs.items():
model_outputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, key)
return model_outputs_dynamic_axes
if self.task == "sample_encode":
for key, value in self.model_outputs.items():
model_outputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "encode")
return model_outputs_dynamic_axes
if self.task == "latent_decode":
for key, value in self.model_outputs.items():
model_outputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "decode")
return model_outputs_dynamic_axes
return model_outputs_dynamic_axes

def generate_dummy_inputs(self, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp16"):
dummy_inputs = {}
for key, value in self.model_inputs.items():
dummy_inputs[key] = self.dummy_tuple_input_generator.generate(key, value, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype)
return dummy_inputs
23 changes: 23 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,29 @@ def __ior_(g, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:

torch.onnx.register_custom_op_symbolic("aten::__ior__", __ior_, 14)


@symbolic_helper.parse_args("v", "v", "v")
def upsample_nearest_exact_symbolic(g, input, output_size, scale_h=None) -> torch._C.Value:
# Compute scales from scale_h
scales = g.op("Concat", g.op("Constant", value_t=torch.tensor([1.0, 1.0], dtype=torch.float32)), scale_h, axis_i=0)
empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))

return g.op(
"Resize",
input,
empty_roi, # roi (unused for nearest)
scales,
mode_s="nearest",
coordinate_transformation_mode_s="half_pixel",
nearest_mode_s="round_prefer_floor",
)

torch.onnx.register_custom_op_symbolic(
"aten::_upsample_nearest_exact2d", # PyTorch op name
upsample_nearest_exact_symbolic, # Your symbolic function
18, # Target ONNX opset
)

if is_torch_version("<", "2.9"):
# this was fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973
from torch.onnx import JitScalarType
Expand Down
Loading