diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 04f8b748..0ad02af6 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -32,6 +32,8 @@ DummySeq2SeqDecoderTextInputGenerator, DummySeq2SeqPastKeyValuesGenerator, DummyTextInputGenerator, + DummyTimestepInputGenerator, + DummyVideoInputGenerator, DummyVisionInputGenerator, logging, ) @@ -410,3 +412,9 @@ def post_process_exported_models( models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.is_merged = True return models_and_onnx_configs, onnx_files_subpaths + + +class VideoOnnxConfig(OnnxConfig): + """Handles video architectures.""" + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoInputGenerator, DummyTimestepInputGenerator) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index c4ac1c39..c52e79d4 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -31,6 +31,7 @@ TextDecoderWithPositionIdsOnnxConfig, TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig, + VideoOnnxConfig, VisionOnnxConfig, ) from optimum.exporters.onnx.input_generators import ( @@ -83,9 +84,11 @@ DummyTransformerTextInputGenerator, DummyTransformerTimestepInputGenerator, DummyTransformerVisionInputGenerator, + DummyVideoInputGenerator, DummyVisionEmbeddingsGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator, DummyVisionInputGenerator, + DummyWanTimestepInputGenerator, DummyXPathSeqInputGenerator, FalconDummyPastKeyValuesGenerator, GemmaDummyPastKeyValuesGenerator, @@ -1385,6 +1388,40 @@ class SiglipVisionModelOnnxConfig(CLIPVisionModelOnnxConfig): pass +@register_tasks_manager_onnx("unet-3d-condition", *["semantic-segmentation"], library_name="diffusers") +class UNet3DOnnxConfig(VideoOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + in_channels="in_channels", + hidden_size="text_encoder_projection_dim", + vocab_size="vocab_size", + 
allow_new=True, + ) + DUMMY_INPUT_GENERATOR_CLASSES = ( + *VideoOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES, + DummyTransformerTextInputGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + return { + "sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"}, + "timestep": {}, # a scalar with no dimension + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return { + "out_sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"}, + } + + @property + def torch_to_onnx_output_map(self) -> dict[str, str]: + return { + "sample": "out_sample", + } + + @register_tasks_manager_onnx("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers") class UNetOnnxConfig(VisionOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( @@ -2852,3 +2889,130 @@ def outputs(self) -> dict[str, dict[int, str]]: 3: f"latent_width * {up_sampling_factor}", } } + + +@register_tasks_manager_onnx("umt5-encoder", *["feature-extraction"], library_name="diffusers") +class UMT5EncoderOnnxConfig(TextEncoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + MIN_TRANSFORMERS_VERSION = version.parse("4.46.0") + + @property + def inputs(self): + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + } + + @property + def outputs(self): + return {"last_hidden_state": {0: "batch_size", 1: "sequence_length"}} + + +@register_tasks_manager_onnx("wan-transformer-3d", *["semantic-segmentation"], library_name="diffusers") +class WanTransformer3DOnnxConfig(VideoOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + in_channels="in_channels", + out_channels="out_channels", + hidden_size="text_dim", + z_dim="z_dim", + expand_timesteps="expand_timesteps", + scale_factor_temporal="vae_scale_factor_temporal", + scale_factor_spatial="vae_scale_factor_spatial", 
+ vocab_size="vocab_size", + allow_new=True, + ) + MIN_TRANSFORMERS_VERSION = version.parse("4.46.0") + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyTransformerTextInputGenerator, + DummyWanTimestepInputGenerator, + DummyVideoInputGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + if self._normalized_config.expand_timesteps is True: + return { + "latent_sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "timestep": {0: "batch_size", 1: "seq_len"}, + } + return { + "latent_sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "timestep": {0: "batch_size"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return { + "sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"}, + } + + def rename_ambiguous_inputs(self, inputs): + # The model signature names this input `hidden_states`, hence the export input `latent_sample` is renamed. 
+ model_inputs = inputs + model_inputs["hidden_states"] = inputs["latent_sample"] + model_inputs.pop("latent_sample") + + return model_inputs + + +@register_tasks_manager_onnx("vae-encoder-video", *["semantic-segmentation"], library_name="diffusers") +class VaeEncoderVideoOnnxConfig(VaeEncoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + in_channels="in_channels", + z_dim="z_dim", + scale_factor_temporal="scale_factor_temporal", + scale_factor_spatial="scale_factor_spatial", + allow_new=True, + ) + MIN_TRANSFORMERS_VERSION = version.parse("4.46.0") + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoInputGenerator,) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + return { + "sample": {0: "batch_size", 2: "num_frames", 3: "height", 4: "width"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return { + "latent_parameters": { + 0: "batch_size", + 2: f"1 + ( num_frames - 1 ) // {self._normalized_config.scale_factor_temporal}", + 3: f"height / {self._normalized_config.scale_factor_spatial}", + 4: f"width / {self._normalized_config.scale_factor_spatial}", + } + } + + +@register_tasks_manager_onnx("vae-decoder-video", *["semantic-segmentation"], library_name="diffusers") +class VaeDecoderVideoOnnxConfig(VaeEncoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + in_channels="in_channels", + out_channels="out_channels", + z_dim="z_dim", + scale_factor_temporal="scale_factor_temporal", + scale_factor_spatial="scale_factor_spatial", + allow_new=True, + ) + MIN_TRANSFORMERS_VERSION = version.parse("4.46.0") + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVideoInputGenerator,) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + return { + "latent_sample": {0: "batch_size", 2: "latent_num_frames", 3: "latent_height", 4: "latent_width"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return { + "sample": { + 0: "batch_size", + 2: f"1 + ( latent_num_frames - 1 ) * 
{self._normalized_config.scale_factor_temporal}", + 3: f"latent_height * {self._normalized_config.scale_factor_spatial}", + 4: f"latent_width * {self._normalized_config.scale_factor_spatial}", + } + } diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 3438b9ef..486cd72d 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -76,6 +76,30 @@ def __ior_(g, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value: torch.onnx.register_custom_op_symbolic("aten::__ior__", __ior_, 14) + +@symbolic_helper.parse_args("v", "v", "v") +def upsample_nearest_exact_symbolic(g, input, output_size, scale_h=None) -> torch._C.Value: + # Compute scales from scale_h + scales = g.op("Concat", g.op("Constant", value_t=torch.tensor([1.0, 1.0], dtype=torch.float32)), scale_h, axis_i=0) + empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32)) + + return g.op( + "Resize", + input, + empty_roi, # roi (unused for nearest) + scales, + mode_s="nearest", + coordinate_transformation_mode_s="half_pixel", + nearest_mode_s="round_prefer_floor", + ) + + +torch.onnx.register_custom_op_symbolic( + "aten::_upsample_nearest_exact2d", # PyTorch op name + upsample_nearest_exact_symbolic, # symbolic function defined above + 18, # Target ONNX opset +) + if is_torch_version("<", "2.9"): # this was fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973 from torch.onnx import JitScalarType diff --git a/optimum/onnxruntime/constants.py b/optimum/onnxruntime/constants.py index 023f64a1..89350447 100644 --- a/optimum/onnxruntime/constants.py +++ b/optimum/onnxruntime/constants.py @@ -23,3 +23,8 @@ DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx" DECODER_MERGED_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?merged(.*)?\.onnx" ONNX_FILE_PATTERN = r".*\.onnx$" + +# Some newer text-to-video pipelines, such as Wan, handle the encoder-decoder scaling at the model level 
instead of pipeline level. +ENCODER_DECODER_HANDLES_SCALING_FACTOR = [ + "AutoencoderKLWan", +] diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 713f3bfa..6f4e277c 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -30,6 +30,7 @@ AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image, + AutoPipelineForText2Video, LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, StableDiffusionImg2ImgPipeline, @@ -38,6 +39,8 @@ StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, + TextToVideoSDPipeline, + WanPipeline, ) from diffusers.pipelines.auto_pipeline import ( AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, @@ -59,6 +62,7 @@ from onnxruntime import InferenceSession, SessionOptions from optimum.exporters.onnx import main_export from optimum.onnxruntime.base import ORTParentMixin, ORTSessionMixin +from optimum.onnxruntime.constants import ENCODER_DECODER_HANDLES_SCALING_FACTOR from optimum.onnxruntime.utils import get_device_for_provider, prepare_providers_and_provider_options from optimum.utils import ( DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, @@ -785,9 +789,11 @@ def forward( class ORTVaeEncoder(ORTModelMixin): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # can be missing from models exported long ago - if not hasattr(self.config, "scaling_factor"): + if ( + not hasattr(self.config, "scaling_factor") + and self.config._class_name not in ENCODER_DECODER_HANDLES_SCALING_FACTOR + ): logger.warning( "The `scaling_factor` attribute is missing from the VAE encoder configuration. " "Please re-export the model with newer version of optimum and diffusers to avoid this warning." 
@@ -841,7 +847,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # can be missing from models exported long ago - if not hasattr(self.config, "scaling_factor"): + if ( + not hasattr(self.config, "scaling_factor") + and self.config._class_name not in ENCODER_DECODER_HANDLES_SCALING_FACTOR + ): logger.warning( "The `scaling_factor` attribute is missing from the VAE decoder configuration. " "Please re-export the model with newer version of optimum and diffusers to avoid this warning." @@ -886,7 +895,7 @@ class ORTVae(ORTParentMixin): def __init__(self, encoder: ORTVaeEncoder | None = None, decoder: ORTVaeDecoder | None = None): self.encoder = encoder self.decoder = decoder - + self.temperal_downsample = getattr(getattr(self.encoder, "config", None), "temperal_downsample", None) self.initialize_ort_attributes(parts=list(filter(None, {self.encoder, self.decoder}))) def decode(self, *args, **kwargs): @@ -1060,6 +1069,30 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsi auto_model_class = LatentConsistencyModelImg2ImgPipeline +@add_end_docstrings(ORT_PIPELINE_DOCSTRING) +class ORTWanPipeline(ORTDiffusionPipeline, WanPipeline): + """ONNX Runtime-powered Pipeline for text-guided text-to-video generation using transformer Model and corresponding to [WanPipeline] + (https://github.com/huggingface/diffusers/blob/6290fdfda40610ce7b99920146853614ba529c6e/src/diffusers/pipelines/wan/pipeline_wan.py#L95). + """ + + task = "text-to-video" + main_input_name = "prompt" + auto_model_class = WanPipeline + + +@add_end_docstrings(ORT_PIPELINE_DOCSTRING) +class ORTTextToVideoSDPipeline(ORTDiffusionPipeline, TextToVideoSDPipeline): + """ONNX Runtime-powered Pipeline for text-to-video using Unet Model. + + Corresponds to + [TextToVideoSDPipeline](https://github.com/huggingface/diffusers/blob/8b4722de57a9a2646466b8bb7095c4fd465193fa/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py#L70C7-L70C28). 
+ """ + + task = "text-to-video" + main_input_name = "prompt" + auto_model_class = TextToVideoSDPipeline + + ORT_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( [ ("latent-consistency", ORTLatentConsistencyModelPipeline), @@ -1083,10 +1116,18 @@ class ORTLatentConsistencyModelImg2ImgPipeline(ORTDiffusionPipeline, LatentConsi ] ) +ORT_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", ORTWanPipeline), + ("text-to-video-sd", ORTTextToVideoSDPipeline), + ] +) + SUPPORTED_ORT_PIPELINES_MAPPINGS = [ ORT_TEXT2IMAGE_PIPELINES_MAPPING, ORT_IMAGE2IMAGE_PIPELINES_MAPPING, ORT_INPAINT_PIPELINES_MAPPING, + ORT_TEXT2VIDEO_PIPELINES_MAPPING, ] @@ -1190,6 +1231,7 @@ class ORTSanaPipeline(ORTUnavailablePipeline): *ORT_TEXT2IMAGE_PIPELINES_MAPPING.values(), *ORT_IMAGE2IMAGE_PIPELINES_MAPPING.values(), *ORT_INPAINT_PIPELINES_MAPPING.values(), + *ORT_TEXT2VIDEO_PIPELINES_MAPPING.values(), ] @@ -1313,11 +1355,34 @@ class ORTPipelineForInpainting(ORTPipelineForTask): ort_pipelines_mapping = ORT_INPAINT_PIPELINES_MAPPING +class ORTPipelineForText2Video(ORTPipelineForTask): + """[`ORTPipelineForText2Video`] is a generic pipeline class that instantiates a text2video pipeline class. The + specific underlying pipeline class is automatically selected from either the + [`~ORTPipelineForText2Video.from_pretrained`] or [`~ORTPipelineForText2Video.from_pipe`] methods. + + This class cannot be instantiated using `__init__()` (throws an error). + + Class attributes: + + - **config_name** (`str`) -- The configuration filename that stores the class and module names of all the + diffusion pipeline's components. + - **auto_model_class** (`Type[DiffusionPipeline]`) -- The corresponding/equivalent Diffusers pipeline class. + - **ort_pipelines_mapping** (`OrderedDict`) -- The mapping between the model names/architectures and the + corresponding ORT pipeline class. 
+ + """ + + config_name = "model_index.json" + auto_model_class = AutoPipelineForText2Video + ort_pipelines_mapping = ORT_TEXT2VIDEO_PIPELINES_MAPPING + + GENERIC_ORT_PIPELINES = [ ORTDiffusionPipeline, ORTPipelineForText2Image, ORTPipelineForImage2Image, ORTPipelineForInpainting, + ORTPipelineForText2Video, ] # Documentation updates