add variant for model loading in from_transformers
eaidova committed Jan 17, 2025
1 parent 6b98d62 commit 5ae26d0
Showing 10 changed files with 66 additions and 64 deletions.
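
In short: each `_from_transformers` entry point now pops a `variant` keyword and forwards it to `main_export` as `model_variant`, so variant-specific weight files (e.g. `*.fp16.safetensors`) can be picked up when a model is exported on the fly. A minimal usage sketch, assuming a repository that ships fp16 variant weights (the repo id and variant name below are illustrative, not taken from this commit):

    # Illustrative only: the repo id and variant are assumptions, not from the diff.
    from optimum.intel import OVModelForCausalLM

    model = OVModelForCausalLM.from_pretrained(
        "some-org/llm-with-fp16-variant",  # hypothetical repo shipping *.fp16.safetensors
        export=True,     # convert the transformers checkpoint to OpenVINO on the fly
        variant="fp16",  # popped in _from_transformers, forwarded as model_variant
    )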
48 changes: 1 addition & 47 deletions optimum/exporters/openvino/convert.py
@@ -1013,11 +1013,10 @@ def _get_submodels_and_export_configs(
 def get_diffusion_models_for_export_ext(
     pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
 ):
-<<<<<<< HEAD
     is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
     is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
     is_flux = pipeline.__class__.__name__.startswith("Flux")
-    is_sana = pipeline.__class__.__name__.startswith("Sana")
+    is_sana = pipeline.__class__.__name__.startswith("Sana")
     is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
     is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")

@@ -1036,51 +1035,6 @@ def get_diffusion_models_for_export_ext(
         models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
     elif is_flux:
         models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
-=======
-    if is_diffusers_version(">=", "0.29.0"):
-        from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
-
-        sd3_pipes = [StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline]
-        if is_diffusers_version(">=", "0.30.0"):
-            from diffusers import StableDiffusion3InpaintPipeline
-
-            sd3_pipes.append(StableDiffusion3InpaintPipeline)
-
-        is_sd3 = isinstance(pipeline, tuple(sd3_pipes))
-    else:
-        is_sd3 = False
-
-    if is_diffusers_version(">=", "0.30.0"):
-        from diffusers import FluxPipeline
-
-        flux_pipes = [FluxPipeline]
-
-        if is_diffusers_version(">=", "0.31.0"):
-            from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline
-
-            flux_pipes.extend([FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline])
-
-        if is_diffusers_version(">=", "0.32.0"):
-            from diffusers import FluxFillPipeline
-
-            flux_pipes.append(FluxFillPipeline)
-
-        is_flux = isinstance(pipeline, tuple(flux_pipes))
-    else:
-        is_flux = False
-
-    if is_diffusers_version(">=", "0.32.0"):
-        from diffusers import SanaPipeline
-
-        is_sana = isinstance(pipeline, SanaPipeline)
-    else:
-        is_sana = False
-
-    if not any([is_sana, is_flux, is_sd3]):
-        return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
-    if is_sd3:
-        models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
->>>>>>> add pipeline
     elif is_sana:
         models_for_export = get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype)
     else:
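
The hunk above also cleans up a merge conflict that had been left in the file: the retained `HEAD` side dispatches on the pipeline class-name prefix, while the removed branch gated `isinstance` checks behind diffusers version probes. A small sketch of why the name-prefix style needs no version-dependent imports (the stand-in class below is hypothetical):

    # Stand-in class, for illustration only; no diffusers import is required.
    class StableDiffusion3Img2ImgPipeline:
        pass

    pipeline = StableDiffusion3Img2ImgPipeline()
    is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
    print(is_sd3)  # True, independent of the installed diffusers version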
6 changes: 6 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -106,6 +106,7 @@
     Qwen2VLVisionEmbMergerPatcher,
     QwenModelPatcher,
     RotaryEmbPatcher,
+    SanaTextEncoderModelPatcher,
     StatefulSeq2SeqDecoderPatcher,
     UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
@@ -1903,6 +1904,11 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
             "attention_mask": {0: "batch_size", 1: "sequence_length"},
         }

+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return SanaTextEncoderModelPatcher(self, model, model_kwargs)
+

 class DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
     SUPPORTED_INPUT_NAMES = (
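
For context, `patch_model_for_export` follows the usual optimum contract: the export config returns a context manager that swaps in export-friendly behavior around tracing. A self-contained toy sketch of that contract (toy names, not optimum's real internals):

    # Toy illustration of the config -> patcher contract used above.
    class ToyPatcher:
        def __init__(self, config, model, model_kwargs=None):
            self.model = model

        def __enter__(self):
            print("patching model for export")

        def __exit__(self, exc_type, exc_value, traceback):
            print("restoring original model")

    class ToyExportConfig:
        def patch_model_for_export(self, model, model_kwargs=None):
            return ToyPatcher(self, model, model_kwargs)

    with ToyExportConfig().patch_model_for_export(model="dummy"):
        pass  # tracing/conversion would happen here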
38 changes: 34 additions & 4 deletions optimum/exporters/openvino/model_patcher.py
@@ -21,9 +21,11 @@

 import torch
 import torch.nn.functional as F
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available

+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import (
     DecoderModelPatcher,
     ModelPatcher,
@@ -114,18 +116,20 @@ def patch_model_with_bettertransformer(model):
     return model


-def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
+def patch_update_causal_mask(
+    model, transformers_version, inner_model_name="model", patch_fn=None, patch_extrnal_model=False
+):
     if is_transformers_version(">=", transformers_version):
-        inner_model = getattr(model, inner_model_name, None)
+        inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
         if inner_model is not None:
             if hasattr(inner_model, "_update_causal_mask"):
                 inner_model._orig_update_causal_mask = inner_model._update_causal_mask
                 patch_fn = patch_fn or _llama_gemma_update_causal_mask
                 inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)


-def unpatch_update_causal_mask(model, inner_model_name="model"):
-    inner_model = getattr(model, inner_model_name, None)
+def unpatch_update_causal_mask(model, inner_model_name="model", patch_extrnal_model=False):
+    inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
     if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
         inner_model._update_causal_mask = inner_model._orig_update_causal_mask
@@ -3791,3 +3795,29 @@ def patched_forward(*args, **kwargs):
             model.forward = patched_forward

         super().__init__(config, model, model_kwargs)
+
+
+class SanaTextEncoderModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        patch_update_causal_mask(self._model, "4.39.0", None, patch_extrnal_model=True)
+
+        if self._model.config._attn_implementation != "sdpa":
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
+            if is_transformers_version("<", "4.47.0"):
+                from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
+
+                sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
+                for layer in self._model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, None, True)
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.layers:
+                if hasattr(layer.self_attn, "_orig_forward"):
+                    layer.self_attn.forward = layer.self_attn._orig_forward
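
`SanaTextEncoderModelPatcher` relies on the same save/patch/restore idiom used throughout this file: stash the bound method, rebind a replacement with `types.MethodType`, and restore it on exit. A runnable toy version of the idiom (toy class, not the real patcher):

    import types

    class Toy:
        def forward(self, x):
            return x

    def patched_forward(self, x):
        return x + 1

    model = Toy()
    model._orig_forward = model.forward                       # save the bound original
    model.forward = types.MethodType(patched_forward, model)  # patch in place
    print(model.forward(1))  # 2
    model.forward = model._orig_forward                       # restore, as in __exit__
    print(model.forward(1))  # 1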
12 changes: 9 additions & 3 deletions optimum/exporters/openvino/utils.py
@@ -257,9 +257,15 @@ def deduce_diffusers_dtype(model_name_or_path, **loading_kwargs):
             model_part_name = "unet"
         if model_part_name:
             directory = path / model_part_name
-            safetensors_files = [
-                filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1
-            ]
+
+            pattern = "*.safetensors"
+            if "variant" in loading_kwargs:
+                variant = loading_kwargs["variant"]
+                pattern = f"*.{variant}.safetensors"
+                safetensors_files = list(directory.glob(pattern))
+            else:
+                # filter out variant files
+                safetensors_files = [filename for filename in directory.glob(pattern) if len(filename.suffixes) == 1]
             safetensors_file = None
             if len(safetensors_files) > 0:
                 safetensors_file = safetensors_files.pop(0)
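
The suffix check above works because `pathlib` treats every dotted component after the stem as a suffix, so variant files carry two suffixes while plain weight files carry one. A quick runnable check:

    from pathlib import Path

    print(Path("diffusion_pytorch_model.safetensors").suffixes)       # ['.safetensors']
    print(Path("diffusion_pytorch_model.fp16.safetensors").suffixes)  # ['.fp16', '.safetensors']
    # len(suffixes) == 1 keeps only the non-variant file, while
    # f"*.{variant}.safetensors" globs one specific variant.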
3 changes: 3 additions & 0 deletions optimum/intel/openvino/modeling_base.py
@@ -594,6 +594,8 @@ def _from_transformers(
         else:
             ov_config = OVConfig(dtype="fp32")

+        variant = kwargs.pop("variant", None)
+
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -607,6 +609,7 @@
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
             library_name=cls._library_name,
+            model_variant=variant,
         )

         return cls._from_pretrained(
2 changes: 2 additions & 0 deletions optimum/intel/openvino/modeling_base_seq2seq.py
@@ -408,6 +408,7 @@ def _from_transformers(
         else:
             ov_config = OVConfig(dtype="fp32")
         stateful = kwargs.get("stateful", True)
+        variant = kwargs.pop("variant", None)

         main_export(
             model_name_or_path=model_id,
@@ -422,6 +423,7 @@
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
             stateful=stateful,
+            model_variant=variant,
         )

         return cls._from_pretrained(
3 changes: 3 additions & 0 deletions optimum/intel/openvino/modeling_decoder.py
@@ -310,6 +310,8 @@ def _from_transformers(
         if torch_dtype is not None:
             model_loading_kwargs["torch_dtype"] = torch_dtype

+        variant = kwargs.pop("variant", None)
+
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -325,6 +327,7 @@
             stateful=stateful,
             model_loading_kwargs=model_loading_kwargs,
             library_name=cls._library_name,
+            model_variant=variant,
         )

         if config.model_type == "phi3" and config.max_position_embeddings != getattr(
2 changes: 2 additions & 0 deletions optimum/intel/openvino/modeling_diffusion.py
@@ -575,6 +575,7 @@ def _from_transformers(

         model_save_dir = TemporaryDirectory()
         model_save_path = Path(model_save_dir.name)
+        variant = kwargs.pop("variant", None)

         main_export(
             model_name_or_path=model_id,
@@ -589,6 +590,7 @@
             force_download=force_download,
             ov_config=ov_config,
             library_name=cls._library_name,
+            model_variant=variant,
         )

         return cls._from_pretrained(
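
Diffusion pipelines are the most common users of weight variants, since many diffusers repositories publish fp16 copies of each component. A hedged sketch (hypothetical repo id; the keyword path is the one added in this commit):

    from optimum.intel import OVStableDiffusionPipeline

    pipe = OVStableDiffusionPipeline.from_pretrained(
        "some-org/sd-with-fp16-variant",  # hypothetical repo with *.fp16.safetensors parts
        export=True,
        variant="fp16",  # forwarded to main_export; presumably also steers deduce_diffusers_dtype's glob
    )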
2 changes: 2 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -615,6 +615,7 @@ def _from_transformers(
             ov_config = OVConfig(dtype="fp32" if load_in_8bit is False else "auto")

         stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
+        variant = kwargs.pop("variant", None)

         main_export(
             model_name_or_path=model_id,
@@ -629,6 +630,7 @@
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
             stateful=stateful,
+            model_variant=variant,
         )
         config = AutoConfig.from_pretrained(save_dir_path, trust_remote_code=trust_remote_code)
         return cls._from_pretrained(
14 changes: 4 additions & 10 deletions tests/openvino/test_diffusion.py
@@ -149,12 +149,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
        for output_type in ["latent", "np", "pt"]:
            inputs["output_type"] = output_type
            if model_arch == "sana":
-                if output_type == "latent":
-                    continue
                # resolution binning will lead to resize output to standard resolution and back that can interpolate floating-point deviations
                inputs["use_resolution_binning"] = False
-                atol = 4e-2
-            else:
-                atol = 6e-3
+
+            atol = 1e-4

            ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
            diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
@@ -166,12 +163,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
        for output_type in ["latent", "np", "pt"]:
            inputs["output_type"] = output_type
            if model_arch == "sana":
-                if output_type == "latent":
-                    continue
                # resolution binning will lead to resize output to standard resolution and back that can interpolate floating-point deviations
                inputs["use_resolution_binning"] = False
-                atol = 4e-2
-            else:
-                atol = 6e-3
+
+            atol = 6e-3

            ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
            diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
