diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index c7de5970..5c113d9b 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -67,6 +67,12 @@ def main_export( framework: str | None = "pt", atol: float | None = None, pad_token_id: int | None = None, + # inference kwargs + inf_kwargs: dict[str, Any] | None = None, + # module_arch_configs + module_arch_fields: dict[str, list[str]] | None = None, + # flag for export_by_inference + export_by_inference: bool = False, # hub options subfolder: str = "", revision: str = "main", @@ -416,6 +422,9 @@ def main_export( use_subprocess=use_subprocess, do_constant_folding=do_constant_folding, slim=slim, + inf_kwargs=inf_kwargs, + module_arch_fields=module_arch_fields, + export_by_inference=export_by_inference, **kwargs_shapes, ) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 75959418..ea571ab7 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -38,6 +38,7 @@ from optimum.exporters.onnx.utils import ( PickableInferenceSession, _get_submodels_and_onnx_configs, + _get_submodels_and_tensors_, recursive_to_device, ) from optimum.exporters.tasks import TasksManager @@ -470,6 +471,7 @@ def export_pytorch( no_dynamic_axes: bool = False, do_constant_folding: bool = True, model_kwargs: dict[str, Any] | None = None, + export_by_inference: bool = False, ) -> tuple[list[str], list[str]]: """Exports a PyTorch model to an ONNX Intermediate Representation. @@ -528,6 +530,9 @@ def export_pytorch( if input_shapes is None: input_shapes = {} # will use the defaults from DEFAULT_DUMMY_SHAPES + if export_by_inference is True: + input_shapes = {} + # Check that inputs match, and order them properly dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes) @@ -628,6 +633,7 @@ def export_models( no_dynamic_axes: bool = False, do_constant_folding: bool = True, model_kwargs: dict[str, Any] | None = None, + export_by_inference: bool = False, ) -> tuple[list[list[str]], list[list[str]]]: """Exports a Pytorch encoder decoder model to an ONNX Intermediate Representation. The following method exports the encoder and decoder components of the model as separate @@ -696,6 +702,7 @@ def export_models( no_dynamic_axes=no_dynamic_axes, do_constant_folding=do_constant_folding, model_kwargs=model_kwargs, + export_by_inference=export_by_inference, ) ) @@ -715,6 +722,7 @@ def export( no_dynamic_axes: bool = False, do_constant_folding: bool = True, model_kwargs: dict[str, Any] | None = None, + export_by_inference: bool = False, ) -> tuple[list[str], list[str]]: """Exports a Pytorch model to an ONNX Intermediate Representation. @@ -794,6 +802,7 @@ def export( no_dynamic_axes=no_dynamic_axes, do_constant_folding=do_constant_folding, model_kwargs=model_kwargs, + export_by_inference=export_by_inference, ) else: @@ -802,6 +811,8 @@ def export( ) if not disable_dynamic_axes_fix: + if export_by_inference is True: + input_shapes = {} config.fix_dynamic_axes(output, device=device, input_shapes=input_shapes, dtype=dtype) return export_output @@ -826,6 +837,9 @@ def onnx_export_from_model( use_subprocess: bool = False, do_constant_folding: bool = True, slim: bool = False, + inf_kwargs: dict[str,Any] | None = None, + module_arch_fields: dict[str, list[str]] | None = None, + export_by_inference: bool = False, **kwargs_shapes, ): """Full-suite ONNX export function, exporting **from a pre-loaded PyTorch model**. This function is especially useful in case one needs to do modifications on the model, as overriding a forward call, before exporting to ONNX. @@ -981,6 +995,16 @@ def onnx_export_from_model( f"Exporting with a sequence length of 1 a {model_type} model is not supported and can yield unexpected results." ) + if export_by_inference is True: + # inference model to trace input and output tensor shape + models_and_inputs, models_and_outputs = _get_submodels_and_tensors_( + model=model, + inf_kwargs=inf_kwargs, + ) + else: + models_and_inputs = None + models_and_outputs = None + onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs( model=model, task=task, @@ -993,6 +1017,9 @@ def onnx_export_from_model( _variant=_variant, library_name=library_name, model_kwargs=model_kwargs, + models_and_inputs=models_and_inputs, + models_and_outputs=models_and_outputs, + module_arch_fields=module_arch_fields, ) if library_name != "diffusers": @@ -1088,8 +1115,27 @@ def onnx_export_from_model( no_dynamic_axes=no_dynamic_axes, do_constant_folding=do_constant_folding, model_kwargs=model_kwargs, + export_by_inference=export_by_inference, ) + if models_and_outputs is not None: + import json + output_dir = os.path.join(output, "io_binding") + os.makedirs(output_dir, exist_ok=True) + + for module_name, dummy_outputs in models_and_outputs.items(): + # convert tuple -> list for json + serializable = { + name: list(shape) + for name, shape in dummy_outputs.items() + } + + file_path = os.path.join(output_dir, f"{module_name}_outputs.json") + + with open(file_path, "w") as f: + json.dump(serializable, f, indent=4) + print(f"Saved: {file_path}") + if optimize is not None: from optimum.onnxruntime import AutoOptimizationConfig, ORTOptimizer diff --git a/optimum/exporters/onnx/input_generators.py b/optimum/exporters/onnx/input_generators.py index e11ebd16..3824df6f 100644 --- a/optimum/exporters/onnx/input_generators.py +++ b/optimum/exporters/onnx/input_generators.py @@ -17,6 +17,7 @@ DummyAudioInputGenerator, DummyPastKeyValuesGenerator, DummyTransformerTextInputGenerator, + DummyInputGenerator, NormalizedTextConfig, is_transformers_version, ) @@ -108,3 +109,23 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return super().generate( input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype ) + +class DummyTupleInputGenerator(DummyInputGenerator): + + def __init__(self, task: str, config_dim: dict[str, int], **kwargs): + super().__init__() + self.config_dim = config_dim + self.padding_side = "right" + + def generate(self, input_name: str, + tensor_shape: tuple[int, ...], + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32"): + if "input_id" in input_name: + min_value = 0 + max_value = self.config_dim.get("vocab_size", 1000) + return self.random_int_tensor(list(tensor_shape), max_value, min_value=min_value, framework=framework, dtype=int_dtype) + elif "mask" in input_name: + return self.random_mask_tensor(list(tensor_shape), padding_side=self.padding_side, framework=framework, dtype=int_dtype) + return self.random_float_tensor(list(tensor_shape), framework=framework, dtype=float_dtype) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index c4ac1c39..bbc5e69e 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -37,6 +37,7 @@ DummyMoonshineAudioInputGenerator, DummySanaTransforemerTextInputGenerator, GPTBigCodeDummyPastKeyValuesGenerator, + DummyTupleInputGenerator, ) from optimum.exporters.onnx.model_patcher import ( BigBirdPegasusModelPatcher, @@ -2852,3 +2853,98 @@ def outputs(self) -> dict[str, dict[int, str]]: 3: f"latent_width * {up_sampling_factor}", } } + + +class DummyOnnxConfig(OnnxConfig): + + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTupleInputGenerator, + ) + + def __init__( + self, + config: PretrainedConfig, + task: str = "text-encoding", + preprocessors: list[Any] | None = None, + int_dtype: str = "int64", + float_dtype: str = "fp16", + model_inputs: dict[str, Any] | None = None, + model_outputs: dict[str, Any] | None = None, + config_dim: dict[str, int] | None = None, + ): + super().__init__(config=config, task=task, preprocessors=preprocessors, int_dtype=int_dtype, float_dtype=float_dtype) + self.task = task + self.model_inputs = model_inputs + self.model_outputs = model_outputs + self.dummy_tuple_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](task=task, config_dim=config_dim) + self.config_dim = config_dim + + def infer_dynamic_dims(self, tensor_shape: tuple[int, ...], config_dim: dict[str, int], name: str="input") -> dict[int, str]: + dynamic = {} + for idx, dim in enumerate(tensor_shape): + # Batch is always dynamic + if idx == 0: + dynamic[idx] = "batch" + continue + + find_match = False + for key, value in config_dim.items(): + if value == dim: + find_match = True + break + + if find_match is True: + continue + + dynamic[idx] = f"{name}_dim_{idx}" + return dynamic + + @property + def inputs(self) -> dict[str,dict[int,str]]: + model_inputs_dynamic_axes = {} + if self.task == "text-encoding": + for key, value in self.model_inputs.items(): + model_inputs_dynamic_axes[key] = {0: "batch_size", 1: "sequence_length"} + return model_inputs_dynamic_axes + if self.task == "backbone": + for key, value in self.model_inputs.items(): + model_inputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, key) + return model_inputs_dynamic_axes + if self.task == "sample_encode": + for key, value in self.model_inputs.items(): + model_inputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "encode") + return model_inputs_dynamic_axes + if self.task == "latent_decode": + for key, value in self.model_inputs.items(): + model_inputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "decode") + return model_inputs_dynamic_axes + return model_inputs_dynamic_axes + + @property + def outputs(self) -> dict[str, dict[int, str]]: + model_outputs_dynamic_axes = {} + if self.task == "text-encoding": + for key, value in self.model_outputs.items(): + model_outputs_dynamic_axes[key] = {0: "batch_size", 1: "sequence_length"} + return model_outputs_dynamic_axes + if self.task == "backbone": + for key, value in self.model_outputs.items(): + model_outputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, key) + return model_outputs_dynamic_axes + if self.task == "sample_encode": + for key, value in self.model_outputs.items(): + model_outputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "encode") + return model_outputs_dynamic_axes + if self.task == "latent_decode": + for key, value in self.model_outputs.items(): + model_outputs_dynamic_axes[key] = self.infer_dynamic_dims(value, self.config_dim, "decode") + return model_outputs_dynamic_axes + return model_outputs_dynamic_axes + + def generate_dummy_inputs(self, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp16"): + dummy_inputs = {} + for key, value in self.model_inputs.items(): + dummy_inputs[key] = self.dummy_tuple_input_generator.generate(key, value, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) + return dummy_inputs \ No newline at end of file diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 3438b9ef..471a5f74 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -76,6 +76,29 @@ def __ior_(g, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value: torch.onnx.register_custom_op_symbolic("aten::__ior__", __ior_, 14) + +@symbolic_helper.parse_args("v", "v", "v") +def upsample_nearest_exact_symbolic(g, input, output_size, scale_h=None) -> torch._C.Value: + # Compute scales from scale_h + scales = g.op("Concat", g.op("Constant", value_t=torch.tensor([1.0, 1.0], dtype=torch.float32)), scale_h, axis_i=0) + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + return g.op( + "Resize", + input, + empty_roi, # roi (unused for nearest) + scales, + mode_s="nearest", + coordinate_transformation_mode_s="half_pixel", + nearest_mode_s="round_prefer_floor", + ) + +torch.onnx.register_custom_op_symbolic( + "aten::_upsample_nearest_exact2d", # PyTorch op name + upsample_nearest_exact_symbolic, # Your symbolic function + 18, # Target ONNX opset +) + if is_torch_version("<", "2.9"): # this was fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973 from torch.onnx import JitScalarType diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 1da33042..33337ef3 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -24,6 +24,7 @@ from optimum.exporters.tasks import TasksManager from optimum.exporters.utils import _get_submodels_and_export_configs from optimum.utils.import_utils import is_diffusers_available, is_transformers_version +from optimum.exporters.onnx.model_configs import DummyOnnxConfig if TYPE_CHECKING: @@ -233,6 +234,94 @@ def get_sana_models_for_export(pipeline: DiffusionPipeline, int_dtype: str = "in return models_for_export +def generate_config_dim( + model: PreTrainedModel, + dim_name: list[str] | None = None, +): + if dim_name is None: + return {} + tmp = {k: getattr(model.config, k) for k in dim_name if hasattr(model.config, k)} + return {k: getattr(model.config, k) for k in dim_name if hasattr(model.config, k)} + +def get_dynamic_models_for_export( + pipeline: DiffusionPipeline, + models_and_inputs: dict | None = None, + models_and_outputs: dict | None = None, + module_arch_fields: dict[str, list[str]] | None = None, + int_dtype: str = "int64", + float_dtype: str = "fp32" +): + import copy + import types + from functools import partial + + models_for_export = {} + text_encoder = pipeline.text_encoder + text_encoder_config = DummyOnnxConfig(config=text_encoder.config, + task="text-encoding", + preprocessors=None, + int_dtype=int_dtype, + float_dtype=float_dtype, + model_inputs=models_and_inputs["text_encoder"], + model_outputs=models_and_outputs["text_encoder"], + config_dim=generate_config_dim(text_encoder, module_arch_fields["text_encoder"])) + models_for_export["text_encoder"] = (text_encoder, text_encoder_config) + + if hasattr(pipeline, "text_encoder_2") and "text_encoder_2" in models_and_outputs.keys(): + text_encoder_2 = pipeline.text_encoder_2 + text_encoder_2_config = DummyOnnxConfig(config=text_encoder_2.config, + task="text-encoding", + preprocessors=None, + int_dtype=int_dtype, + float_dtype=float_dtype, + model_inputs=models_and_inputs["text_encoder_2"], + model_outputs=models_and_outputs["text_encoder_2"], + config_dim=generate_config_dim(text_encoder, module_arch_fields["text_encoder_2"])) + models_for_export["text_encoder_2"] = (text_encoder_2, text_encoder_2_config) + + transformer = pipeline.transformer + transformer_config = DummyOnnxConfig(config=transformer.config, + task="backbone", + preprocessors=None, + int_dtype=int_dtype, + float_dtype=float_dtype, + model_inputs=models_and_inputs["transformer"], + model_outputs=models_and_outputs["transformer"], + config_dim=generate_config_dim(transformer, module_arch_fields["transformer"])) + models_for_export["transformer"] = (transformer, transformer_config) + + if "vae_encoder" in models_and_inputs.keys(): + vae_encoder = copy.deepcopy(pipeline.vae) + # proper forward wrapper + def encode_forward(self, sample): + return vae_encoder.encode(self, x=sample, return_dict=False) + vae_encoder.forward = types.MethodType(encode_forward, vae_encoder) + vae_encoder_config = DummyOnnxConfig(config=vae_encoder.config, + task="sample_encode", + preprocessors=None, + int_dtype=int_dtype, + float_dtype=float_dtype, + model_inputs=models_and_inputs["vae_encoder"], + model_outputs=models_and_outputs["vae_encoder"], + config_dim=generate_config_dim(vae_decoder, module_arch_fields["vae_encoder"])) + models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_config) + + if "vae_decoder" in models_and_inputs.keys(): + vae_decoder = copy.deepcopy(pipeline.vae) + # proper forward wrapper + def decode_forward(self, latent_sample): + return vae_decoder.decode(self, z=latent_sample, return_dict=False) + vae_decoder.forward = types.MethodType(decode_forward, vae_decoder) + vae_decoder_config = DummyOnnxConfig(config=vae_decoder.config, + task="latent_decode", + preprocessors=None, + int_dtype=int_dtype, + float_dtype=float_dtype, + model_inputs=models_and_inputs["vae_decoder"], + model_outputs=models_and_outputs["vae_decoder"], + config_dim=generate_config_dim(vae_decoder, module_arch_fields["vae_decoder"])) + models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_config) + return models_for_export def _get_submodels_and_onnx_configs( model: PreTrainedModel, @@ -247,6 +336,9 @@ def _get_submodels_and_onnx_configs( fn_get_submodels: Callable | None = None, preprocessors: list[Any] | None = None, model_kwargs: dict | None = None, + models_and_inputs: dict | None = None, + models_and_outputs: dict | None = None, + module_arch_fields: dict[str, list[str]] | None = None, ): if library_name == "transformers" and model.config.model_type == "metaclip_2": export_config_constructor = TasksManager.get_exporter_config_constructor( @@ -264,6 +356,10 @@ def _get_submodels_and_onnx_configs( if library_name == "diffusers" and model.__class__.__name__.startswith("Sana"): return None, get_sana_models_for_export(model, int_dtype, float_dtype) + ## use inference to trace input and output shape + if library_name == "diffusers" and models_and_inputs is not None and models_and_outputs is not None and module_arch_fields is not None: + return None, get_dynamic_models_for_export(model, models_and_inputs, models_and_outputs, module_arch_fields, int_dtype, float_dtype) + return _get_submodels_and_export_configs( model, task, @@ -279,3 +375,160 @@ def _get_submodels_and_onnx_configs( model_kwargs, exporter="onnx", ) + +def make_positional_hook(dummy_inputs, module_name): + import inspect + def hook(module, args, kwargs): + sig = inspect.signature(module.forward) + params = list(sig.parameters.values()) + # remove self if present + if params and params[0].name == "self": + params = params[1:] + named_shapes = {} + for p, v in zip(params, args): + if torch.is_tensor(v): + named_shapes[p.name] = tuple(v.shape) + for k, v in kwargs.items(): + if torch.is_tensor(v): + named_shapes[k] = tuple(v.shape) + dummy_inputs[module_name] = named_shapes + return None # do not modify inputs + return hook + +def get_output_name_and_shape(output, name): + from dataclasses import fields, is_dataclass + + named_shapes = {} + if torch.is_tensor(output): + named_shapes[name] = tuple(output.shape) + elif is_dataclass(output): + for f in fields(output): + val = getattr(output, f.name) + if torch.is_tensor(val): + named_shapes[f.name] = tuple(val.shape) + elif isinstance(output, (tuple, list)): + for i, x in enumerate(output): + if torch.is_tensor(x): + named_shapes[f"{name}_{i}"] = tuple(x.shape) + elif isinstance(output, dict): + for k, v in output.items(): + if torch.is_tensor(v): + named_shapes[k] = tuple(v.shape) + return named_shapes + + +def make_dataclass_output_hook(dummy_outputs, module_name): + def hook(module, args, output): + dummy_outputs[module_name] = get_output_name_and_shape(output, "sample") + return None # don't modify output + return hook + +def _get_submodels_and_tensors_( + model: PreTrainedModel | DiffusionPipeline, + inf_kwargs: dict[str, Any] | None = None, +): + import torch.nn as nn + import inspect + import types + + # key: module_name, value: {input_name: tensor_shape} + dummy_inputs = {} + dummy_outputs = {} + + hooks = [] + transformer_original_forward = None + orig_decode = None + orig_encode = None + + for name, module in model.components.items(): + if isinstance(module, nn.Module): + dummy_inputs[name] = {} + dummy_outputs[name] = {} + + if "text_encoder" in dummy_inputs.keys(): + hooks.append( + model.text_encoder.register_forward_pre_hook(make_positional_hook(dummy_inputs, "text_encoder"), with_kwargs=True)) + hooks.append( + model.text_encoder.register_forward_hook(make_dataclass_output_hook(dummy_outputs, "text_encoder"))) + + if "text_encoder_2" in dummy_inputs.keys(): + hooks.append( + model.text_encoder_2.register_forward_pre_hook(make_positional_hook(dummy_inputs, "text_encoder_2"), with_kwargs=True)) + hooks.append( + model.text_encoder_2.register_forward_hook(make_dataclass_output_hook(dummy_outputs, "text_encoder_2"))) + + if "transformer" in dummy_inputs.keys(): + transformer_original_forward = model.transformer.forward + def wrapped_forward(*args, **kwargs): + for key, value in kwargs.items(): + if torch.is_tensor(value): + dummy_inputs["transformer"][key] = tuple(value.shape) + return transformer_original_forward(*args, **kwargs) + + model.transformer.forward = wrapped_forward + hooks.append( + model.transformer.register_forward_hook(make_dataclass_output_hook(dummy_outputs, "transformer"))) + + if "vae" in dummy_inputs.keys(): + dummy_inputs["vae_encoder"] = {} + dummy_inputs["vae_decoder"] = {} + # hook encoder + wrap_encode = model.vae.encode + for cell in wrap_encode.__closure__: + if inspect.isfunction(cell.cell_contents): + orig_decode = cell.cell_contents + break + if orig_encode is None: + sig = None + else: + sig = inspect.signature(orig_encode) + def hooked_encode(self, *args, **kwargs): + if sig is not None: + bound = sig.bind_partial(self, *args, **kwargs) + for name, value in bound.arguments.items(): + if torch.is_tensor(value): + dummy_inputs["vae_encoder"][name] = tuple(value.shape) + output = wrap_encode(*args, **kwargs) + dummy_output["vae_encoder"] = get_output_name_and_shape(output, "latent_dist") + return output + model.vae.encode = types.MethodType(hooked_encode, model.vae) + + wrap_decode = model.vae.decode + for cell in wrap_decode.__closure__: + if inspect.isfunction(cell.cell_contents): + orig_decode = cell.cell_contents + break + if orig_decode is None: + sig = None + else: + sig = inspect.signature(orig_decode) + def hooked_decode(self, *args, **kwargs): + if sig is not None: + bound = sig.bind_partial(self, *args, **kwargs) + for name, value in bound.arguments.items(): + if torch.is_tensor(value): + dummy_inputs["vae_decoder"]["latent_sample"] = tuple(value.shape) + output = wrap_decode(*args, **kwargs) + dummy_outputs["vae_decoder"] = get_output_name_and_shape(output, "sample") + return output + model.vae.decode = types.MethodType(hooked_decode, model.vae) + + output = model(**inf_kwargs).frames[0] # yes, we can inference + + filtered_inputs = {k: v for k, v in dummy_inputs.items() if v} + filtered_outputs = {k: v for k, v in dummy_outputs.items() if v} + + # remove all the model hooks + for h in hooks: + h.remove() + if transformer_original_forward is not None: + model.transformer.forward = transformer_original_forward + if orig_decode is not None: + model.vae.decode = orig_decode + if orig_encode is not None: + model.vae.encode = orig_encode + + return filtered_inputs, filtered_outputs + + + diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 713f3bfa..ea9d8fd5 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -38,6 +38,8 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, + WanPipeline, + HunyuanVideo15Pipeline, ) from diffusers.pipelines.auto_pipeline import ( AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, @@ -59,7 +61,7 @@ from onnxruntime import InferenceSession, SessionOptions from optimum.exporters.onnx import main_export from optimum.onnxruntime.base import ORTParentMixin, ORTSessionMixin -from optimum.onnxruntime.utils import get_device_for_provider, prepare_providers_and_provider_options +from optimum.onnxruntime.utils import get_device_for_provider, prepare_providers_and_provider_options, load_shapes_as_torch_size from optimum.utils import ( DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_3_SUBFOLDER, @@ -274,6 +276,12 @@ def from_pretrained( providers: Sequence[str] | None = None, provider_options: Sequence[dict[str, Any]] | dict[str, Any] | None = None, session_options: SessionOptions | None = None, + # inference kwargs + inf_kwargs: dict[str, Any] | None = None, + # module_arch_configs + module_arch_fields: dict[str, list[str]] | None = None, + # flag to use export_by_inference + export_by_inference: bool = False, # inference options use_io_binding: bool | None = None, # hub options and preloaded models @@ -387,11 +395,15 @@ def from_pretrained( no_post_process=True, do_validation=False, task=cls.task, + inf_kwargs = inf_kwargs, + module_arch_fields = module_arch_fields, + export_by_inference = export_by_inference, # export related arguments **export_kwargs, # hub related arguments - **hub_kwargs, + **hub_kwargs ) + # download the model if needed if not model_save_path.is_dir(): @@ -475,7 +487,18 @@ def from_pretrained( ort_pipeline.register_to_config(**config) ort_pipeline.register_to_config(_name_or_path=config.get("_name_or_path", model_name_or_path)) - + for key, comp in ort_pipeline.components.items(): + output_dir = os.path.join(model_save_path, "io_binding") + file_path = os.path.join(output_dir, f"{key}_outputs.json") + if key == "vae": + if comp.encoder is not None: + file_path = os.path.join(output_dir, f"{key}_encoder_outputs.json") + comp.encoder.set_io_binding_file(file_path) + if comp.decoder is not None: + file_path = os.path.join(output_dir, f"{key}_decoder_outputs.json") + comp.decoder.set_io_binding_file(file_path) + else: + comp.set_io_binding_file(file_path) return ort_pipeline def save_pretrained( @@ -572,6 +595,11 @@ def __init__( config_dict = self._dict_from_json_file(config_file_path) self.register_to_config(**config_dict) + self.io_binding_file = None + + def set_io_binding_file(self, filename): + self.io_binding_file = filename + def save_pretrained(self, save_directory: str | Path): """Saves the ONNX model and its configuration file to a directory, so that it can be re-loaded using the [`from_pretrained`] class method. @@ -635,8 +663,10 @@ def forward( } if self.use_io_binding: - known_output_shapes = {"out_sample": sample.shape} + known_output_shapes = load_shapes_as_torch_size(self.io_binding_file) + known_output_shapes["out_sample"] = sample.shape + known_output_buffers = None # in LCM, the scheduler uses both the input sample (latents) and the output sample (model_pred) to compute the next latents # latents, denoised = self.scheduler.step(model_pred, t, latents, **extra_step_kwargs, return_dict=False) @@ -701,7 +731,8 @@ def forward( } if self.use_io_binding: - known_output_shapes = {"out_hidden_states": hidden_states.shape} + known_output_shapes = load_shapes_as_torch_size(self.io_binding_file) + known_output_shapes["out_hidden_states"] = hidden_states.shape known_output_buffers = None # in Flux model, the scheduler uses both the input hidden_states (latents) and the output hidden_states (noise_pred) to compute the next latents @@ -750,7 +781,10 @@ def forward( } if self.use_io_binding: - output_shapes, output_buffers = self._prepare_io_binding(model_inputs) + known_output_shapes = load_shapes_as_torch_size(self.io_binding_file) + output_shapes, output_buffers = self._prepare_io_binding(model_inputs, + known_output_shapes=known_output_shapes, + known_output_buffers=None) if self.device.type == "cpu": self.session.run_with_iobinding(self._io_binding) @@ -807,7 +841,10 @@ def forward( } if self.use_io_binding: - output_shapes, output_buffers = self._prepare_io_binding(model_inputs) + known_output_shapes = load_shapes_as_torch_size(self.io_binding_file) + output_shapes, output_buffers = self._prepare_io_binding(model_inputs, + known_output_shapes=known_output_shapes, + known_output_buffers=None) if self.device.type == "cpu": self.session.run_with_iobinding(self._io_binding) @@ -841,7 +878,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # can be missing from models exported long ago - if not hasattr(self.config, "scaling_factor"): + if not hasattr(self.config, "scaling_factor") and hasattr(self.config, "block_out_channels"): logger.warning( "The `scaling_factor` attribute is missing from the VAE decoder configuration. " "Please re-export the model with newer version of optimum and diffusers to avoid this warning." @@ -861,7 +898,12 @@ def forward( } if self.use_io_binding: - output_shapes, output_buffers = self._prepare_io_binding(model_inputs) + + known_output_shapes = load_shapes_as_torch_size(self.io_binding_file) + + output_shapes, output_buffers = self._prepare_io_binding(model_inputs, + known_output_shapes=known_output_shapes, + known_output_buffers=None) if self.device.type == "cpu": self.session.run_with_iobinding(self._io_binding) @@ -1059,6 +1101,26 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsi main_input_name = "image" auto_model_class = LatentConsistencyModelImg2ImgPipeline +@add_end_docstrings(ORT_PIPELINE_DOCSTRING) +class ORTWanPipeline(ORTDiffusionPipeline, WanPipeline): + """ONNX Runtime-powered Pipeline for text-guided text-to-video generation using transformer Model and corresponding to [WanPipeline] + (https://github.com/huggingface/diffusers/blob/6290fdfda40610ce7b99920146853614ba529c6e/src/diffusers/pipelines/wan/pipeline_wan.py#L95). + """ + + task = "text-to-video" + main_input_name = "prompt" + auto_model_class = WanPipeline + +@add_end_docstrings(ORT_PIPELINE_DOCSTRING) +class ORTHunyuanVideo15Pipeline(ORTDiffusionPipeline, HunyuanVideo15Pipeline): + """ONNX Runtime-powered Pipeline for text-guided text-to-video generation using transformer Model and corresponding to [WanPipeline] + (https://github.com/huggingface/diffusers/blob/6290fdfda40610ce7b99920146853614ba529c6e/src/diffusers/pipelines/wan/pipeline_wan.py#L95). + """ + + task = "text-to-video" + main_input_name = "prompt" + auto_model_class = HunyuanVideo15Pipeline + ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( [ @@ -1083,10 +1145,18 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsi ] ) +ORT_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", ORTWanPipeline), + ("hunyuan", ORTHunyuanVideo15Pipeline) + ] +) + SUPPORTED_ORT_PIPELINES_MAPPINGS = [ ORT_TEXT2IMAGE_PIPELINES_MAPPING, ORT_IMAGE2IMAGE_PIPELINES_MAPPING, ORT_INPAINT_PIPELINES_MAPPING, + ORT_TEXT2VIDEO_PIPELINES_MAPPING, ] @@ -1190,6 +1260,7 @@ class ORTSanaPipeline(ORTUnavailablePipeline): *ORT_TEXT2IMAGE_PIPELINES_MAPPING.values(), *ORT_IMAGE2IMAGE_PIPELINES_MAPPING.values(), *ORT_INPAINT_PIPELINES_MAPPING.values(), + *ORT_TEXT2VIDEO_PIPELINES_MAPPING.values(), ] @@ -1311,6 +1382,28 @@ class ORTPipelineForInpainting(ORTPipelineForTask): config_name = "model_index.json" auto_model_class = AutoPipelineForInpainting ort_pipelines_mapping = ORT_INPAINT_PIPELINES_MAPPING + + +class ORTPipelineForText2Video(ORTPipelineForTask): + """[`ORTPipelineForText2Video`] is a generic pipeline class that instantiates an text2video pipeline class. The + specific underlying pipeline class is automatically selected from either the + [`~ORTPipelineForText2Video.from_pretrained`] or [`~ORTPipelineForText2Video.from_pipe`] methods. + + This class cannot be instantiated using `__init__()` (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + - **auto_model_class** (`Type[DiffusionPipeline]`) -- The corresponding/equivalent Diffusers pipeline class. + - **ort_pipelines_mapping** (`OrderedDict`) -- The mapping between the model names/architectures and the + corresponding ORT pipeline class. + + """ + + config_name = "model_index.json" + auto_model_class = DiffusionPipeline + ort_pipelines_mapping = ORT_TEXT2VIDEO_PIPELINES_MAPPING GENERIC_ORT_PIPELINES = [ @@ -1318,6 +1411,7 @@ class ORTPipelineForInpainting(ORTPipelineForTask): ORTPipelineForText2Image, ORTPipelineForImage2Image, ORTPipelineForInpainting, + ORTPipelineForText2Video, ] # Documentation updates diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 6b34aa15..790413c6 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -423,3 +423,19 @@ def get_dtype_from_session(session: ort.InferenceSession) -> torch.dtype: return torch_dtype return torch.float32 + +def load_shapes_as_torch_size(path): + import json + + if not os.path.exists(path): + return {} # or return {} + + with open(path, "r") as f: + data = json.load(f) + + shapes = { + key: torch.Size(shape) # convert list -> torch.Size + for key, shape in data.items() + } + + return shapes