Add IP-adapter support for stable diffusion #766

Open · wants to merge 13 commits into base: main
30 changes: 30 additions & 0 deletions optimum/commands/export/neuronx.py
@@ -186,6 +186,36 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=str,
help="List of model ids (eg. `thibaud/controlnet-openpose-sdxl-1.0`) of ControlNet models.",
)
optional_group.add_argument(
"--ip_adapter_ids",
default=None,
nargs="*",
type=str,
help=(
"Model ids (eg. `h94/IP-Adapter`) of IP-Adapter models hosted on the Hub or paths to local directories containing the IP-Adapter weights."
),
)
optional_group.add_argument(
"--ip_adapter_subfolders",
default=None,
nargs="*",
type=str,
help="The subfolder location of a model file within a larger model repository on the Hub or locally. If a list is passed, it should have the same length as `ip_adapter_weight_names`.",
)
optional_group.add_argument(
"--ip_adapter_weight_names",
default=None,
nargs="*",
type=str,
help="The name of the weight file to load. If a list is passed, it should have the same length as `ip_adapter_subfolders`.",
)
optional_group.add_argument(
"--ip_adapter_scales",
default=None,
nargs="*",
type=float,
help="Scaling factors for the IP-Adapters.",
)
optional_group.add_argument(
"--output_attentions",
action="store_true",
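Note: the four `--ip_adapter_*` flags are parallel lists with one entry per adapter, mirroring the `load_ip_adapter` convention in `diffusers`. A minimal sketch of matching values; the subfolder and weight name below follow the usual layout of the `h94/IP-Adapter` repo and are illustrative, not mandated by this PR:

```python
# One entry per adapter; ids, subfolders, and weight names align element-wise.
ip_adapter_ids = ["h94/IP-Adapter"]
ip_adapter_subfolders = ["models"]
ip_adapter_weight_names = ["ip-adapter_sd15.bin"]
ip_adapter_scales = [0.5]  # 0.0 ignores the image prompt, 1.0 follows it fully
```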
41 changes: 41 additions & 0 deletions optimum/exporters/neuron/__main__.py
@@ -278,6 +278,24 @@ def infer_stable_diffusion_shapes_from_diffusers(
"encoder_hidden_size": encoder_hidden_size,
}

# Image encoder
if getattr(model, "image_encoder", None):
input_shapes["image_encoder"] = {
"batch_size": input_shapes[unet_or_transformer_name]["batch_size"],
"num_channels": model.image_encoder.config.num_channels,
"width": model.image_encoder.config.image_size,
"height": model.image_encoder.config.image_size,
}
# IP-Adapter: add `image_embeds` as an input of the unet/transformer.
# The unet takes `ip_adapter_image_embeds` of shape
# [batch_size, 1, (self.image_encoder.config.image_size // patch_size) ** 2 + 1, self.image_encoder.config.hidden_size].
if getattr(model.unet.config, "encoder_hid_dim_type", None) == "ip_image_proj":
input_shapes[unet_or_transformer_name]["image_encoder_sequence_length"] = (
model.image_encoder.vision_model.embeddings.position_embedding.weight.shape[0]
)
input_shapes[unet_or_transformer_name]["image_encoder_hidden_size"] = (
model.image_encoder.vision_model.embeddings.position_embedding.weight.shape[1]
)

return input_shapes
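The two new axes are read off the vision tower's position-embedding table, which already encodes the `(image_size // patch_size) ** 2 + 1` sequence length described in the comment above. A quick sanity check with illustrative CLIP ViT-H/14 numbers (the encoder commonly shipped with `h94/IP-Adapter`):

```python
# Illustrative values; the exporter reads the real ones from
# model.image_encoder.vision_model.embeddings.position_embedding.weight.
image_size, patch_size, hidden_size = 224, 14, 1280
sequence_length = (image_size // patch_size) ** 2 + 1  # 256 patches + 1 CLS = 257
assert sequence_length == 257
# position_embedding.weight.shape == (sequence_length, hidden_size) == (257, 1280)
```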


@@ -430,6 +448,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
lora_scales=lora_scales,
controlnet_ids=controlnet_ids,
controlnet_input_shapes=input_shapes.get("controlnet", None),
image_encoder_input_shapes=input_shapes.get("image_encoder", None),
)
output_model_names = {
DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
@@ -449,6 +468,8 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
output_model_names[DIFFUSION_MODEL_TRANSFORMER_NAME] = os.path.join(
DIFFUSION_MODEL_TRANSFORMER_NAME, NEURON_FILE_NAME
)
if getattr(model, "image_encoder", None) is not None:
output_model_names["image_encoder"] = os.path.join("image_encoder", NEURON_FILE_NAME)

# ControlNet models
if controlnet_ids:
@@ -522,6 +543,10 @@ def load_models_and_neuron_configs(
torch_dtype: Optional[Union[str, torch.dtype]] = None,
tensor_parallel_size: int = 1,
controlnet_ids: Optional[Union[str, List[str]]] = None,
ip_adapter_ids: Optional[Union[str, List[str]]] = None,
ip_adapter_subfolders: Optional[Union[str, List[str]]] = None,
ip_adapter_weight_names: Optional[Union[str, List[str]]] = None,
ip_adapter_scales: Optional[Union[float, List[float]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
**input_shapes,
@@ -542,6 +567,10 @@
}
if model is None:
model = TasksManager.get_model_from_task(**model_kwargs)
# Load the IP-Adapter(s) if specified
if ip_adapter_ids:
model.load_ip_adapter(ip_adapter_ids, subfolder=ip_adapter_subfolders, weight_name=ip_adapter_weight_names)
model.set_ip_adapter_scale(ip_adapter_scales)

models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs(
model=model,
@@ -597,6 +626,10 @@ def main_export(
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
controlnet_ids: Optional[Union[str, List[str]]] = None,
ip_adapter_ids: Optional[Union[str, List[str]]] = None,
ip_adapter_subfolders: Optional[Union[str, List[str]]] = None,
ip_adapter_weight_names: Optional[Union[str, List[str]]] = None,
ip_adapter_scales: Optional[Union[float, List[float]]] = None,
**input_shapes,
):
output = Path(output)
@@ -634,6 +667,10 @@
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
controlnet_ids=controlnet_ids,
ip_adapter_ids=ip_adapter_ids,
ip_adapter_subfolders=ip_adapter_subfolders,
ip_adapter_weight_names=ip_adapter_weight_names,
ip_adapter_scales=ip_adapter_scales,
**input_shapes,
)

@@ -777,6 +814,10 @@ def main():
lora_adapter_names=getattr(args, "lora_adapter_names", None),
lora_scales=getattr(args, "lora_scales", None),
controlnet_ids=getattr(args, "controlnet_ids", None),
ip_adapter_ids=getattr(args, "ip_adapter_ids", None),
ip_adapter_subfolders=getattr(args, "ip_adapter_subfolders", None),
ip_adapter_weight_names=getattr(args, "ip_adapter_weight_names", None),
ip_adapter_scales=getattr(args, "ip_adapter_scales", None),
**optional_outputs,
**input_shapes,
)
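Taken together, the new kwargs thread from the CLI into `main_export`, then into `load_models_and_neuron_configs`, which calls `load_ip_adapter` before the submodels are exported. A hedged sketch of the Python entry point; the model id, output path, task, and shape arguments are illustrative and follow the exporter's existing signature, only the `ip_adapter_*` kwargs are new here:

```python
from optimum.exporters.neuron import main_export

main_export(
    model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5",  # illustrative
    output="sd15_neuron_ip_adapter/",
    task="text-to-image",
    batch_size=1,
    height=512,
    width=512,
    num_images_per_prompt=1,
    ip_adapter_ids=["h94/IP-Adapter"],
    ip_adapter_subfolders=["models"],
    ip_adapter_weight_names=["ip-adapter_sd15.bin"],
    ip_adapter_scales=[0.5],
)
```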
10 changes: 9 additions & 1 deletion optimum/exporters/neuron/base.py
@@ -119,6 +119,7 @@ class NeuronDefaultConfig(NeuronExportConfig, ABC):
DUMMY_INPUT_GENERATOR_CLASSES = ()
ATOL_FOR_VALIDATION: Union[float, Dict[str, float]] = 1e-5
MODEL_TYPE = None
CUSTOM_MODEL_WRAPPER = None

_TASK_TO_COMMON_OUTPUTS = {
"depth-estimation": ["predicted_depth"],
@@ -166,6 +167,8 @@ def __init__(
num_beams: Optional[int] = None,
vae_scale_factor: Optional[int] = None,
encoder_hidden_size: Optional[int] = None,
image_encoder_sequence_length: Optional[int] = None,
image_encoder_hidden_size: Optional[int] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
int_dtype: Union[str, torch.dtype] = "int64",
@@ -205,6 +208,8 @@
"patch_size": patch_size or getattr(self._config, "patch_size", None),
"vae_scale_factor": vae_scale_factor,
"encoder_hidden_size": encoder_hidden_size,
"image_encoder_sequence_length": image_encoder_sequence_length,
"image_encoder_hidden_size": image_encoder_hidden_size,
}
input_shapes = {}
for name, value in axes_values.items():
@@ -425,4 +430,7 @@ def forward(self, *input):

return outputs

return ModelWrapper(model, list(dummy_inputs.keys()))
if self.CUSTOM_MODEL_WRAPPER is None:
return ModelWrapper(model, list(dummy_inputs.keys()))
else:
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))
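This dispatch is what lets the per-config `patch_model_for_export` overrides be deleted in `traced_configs.py` below: a config now only declares its wrapper class. A hypothetical subclass, for illustration:

```python
# Hypothetical config; setting the class attribute is enough for the base
# patch_model_for_export to wrap the model before tracing.
class MyEncoderNeuronConfig(VisionNeuronConfig):
    CUSTOM_MODEL_WRAPPER = CLIPVisionModelNeuronWrapper
```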
9 changes: 4 additions & 5 deletions optimum/exporters/neuron/convert.py
@@ -186,22 +186,21 @@ def validate_model_outputs(
reference_model.eval()
inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
ref_inputs = config.unflatten_inputs(inputs)
if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False):
if hasattr(reference_model, "config") and getattr(config._config, "is_encoder_decoder", False):
reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
if "SentenceTransformer" in reference_model.__class__.__name__:
reference_model = config.patch_model_for_export(reference_model, ref_inputs)
ref_outputs = reference_model(**ref_inputs)
neuron_inputs = tuple(config.flatten_inputs(inputs).values())
elif "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr(
reference_model.config, "is_encoder_decoder", False
config._config, "is_encoder_decoder", False
):
# VAE components for stable diffusion or Encoder-Decoder models
ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
neuron_inputs = tuple(inputs.values())
elif any(
pattern in getattr(config._config, "_class_name", "").lower() for pattern in ["controlnet", "transformer"]
):
elif config.CUSTOM_MODEL_WRAPPER is not None:
ref_inputs = config.flatten_inputs(inputs)
reference_model = config.patch_model_for_export(reference_model, ref_inputs)
neuron_inputs = ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
51 changes: 33 additions & 18 deletions optimum/exporters/neuron/model_configs/traced_configs.py
@@ -41,6 +41,7 @@
ASTDummyAudioInputGenerator,
DummyBeamValuesGenerator,
DummyControNetInputGenerator,
DummyIPAdapterInputGenerator,
DummyMaskedPosGenerator,
is_neuronx_distributed_available,
)
@@ -52,6 +53,7 @@
VisionNeuronConfig,
)
from ..model_wrappers import (
CLIPVisionModelNeuronWrapper,
ControlNetNeuronWrapper,
NoCacheModelWrapper,
PixartTransformerNeuronWrapper,
@@ -146,9 +148,6 @@ class PhiNeuronConfig(ElectraNeuronConfig):
def inputs(self) -> List[str]:
return ["input_ids", "attention_mask"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))


@register_in_tasks_manager("roformer", *COMMON_TEXT_TASKS)
class RoFormerNeuronConfig(ElectraNeuronConfig):
@@ -235,15 +234,26 @@ def inputs(self) -> List[str]:
def outputs(self) -> List[str]:
return ["token_embeddings", "sentence_embedding"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))


class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
TEXT_CONFIG = "text_config"
VISION_CONFIG = "vision_config"


@register_in_tasks_manager("clip-vision-model", *["feature-extraction"], library_name="diffusers")
class CLIPVisionModelNeuronConfig(VisionNeuronConfig):
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
CUSTOM_MODEL_WRAPPER = CLIPVisionModelNeuronWrapper

@property
def inputs(self) -> List[str]:
return ["pixel_values"]

@property
def outputs(self) -> List[str]:
return ["image_embeds", "last_hidden_state", "hidden_states"]


@register_in_tasks_manager("clip", *["feature-extraction", "zero-shot-image-classification"])
class CLIPNeuronConfig(TextAndVisionNeuronConfig):
NORMALIZED_CONFIG_CLASS = CLIPNormalizedConfig
@@ -311,9 +321,6 @@ class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig):
def outputs(self) -> List[str]:
return ["text_embeds", "image_embeds"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))

def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
for name, axis_dim in self._axes.items():
self._axes[name] = kwargs.pop(name, axis_dim)
@@ -598,6 +605,7 @@ class UNetNeuronConfig(VisionNeuronConfig):
DummyTimestepInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummyControNetInputGenerator,
DummyIPAdapterInputGenerator,
)

@property
@@ -616,6 +624,10 @@ def inputs(self) -> List[str]:
# outputs of controlnet
common_inputs += ["down_block_additional_residuals", "mid_block_additional_residual"]

if self.with_ip_adapter:
# add output of image encoder
common_inputs += ["image_embeds"]

return common_inputs

@property
@@ -648,9 +660,6 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
else:
return dummy_inputs

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))

@property
def is_sdxl(self) -> bool:
return self._is_sdxl
@@ -667,6 +676,18 @@ def with_controlnet(self) -> bool:
def with_controlnet(self, with_controlnet: bool):
self._with_controlnet = with_controlnet

@property
def with_ip_adapter(self) -> bool:
return self._with_ip_adapter

@with_ip_adapter.setter
def with_ip_adapter(self, with_ip_adapter: bool):
self._with_ip_adapter = with_ip_adapter
if with_ip_adapter:
self.mandatory_axes += ("image_encoder_sequence_length", "image_encoder_hidden_size")
setattr(self, "image_encoder_sequence_length", self.input_shapes["image_encoder_sequence_length"])
setattr(self, "image_encoder_hidden_size", self.input_shapes["image_encoder_hidden_size"])


@register_in_tasks_manager("pixart-transformer-2d", *["semantic-segmentation"], library_name="diffusers")
class PixartTransformerNeuronConfig(VisionNeuronConfig):
@@ -707,9 +728,6 @@ def inputs(self) -> List[str]:
def outputs(self) -> List[str]:
return ["out_hidden_states"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))


@register_in_tasks_manager("controlnet", *["semantic-segmentation"], library_name="diffusers")
class ControlNetNeuronConfig(VisionNeuronConfig):
@@ -755,9 +773,6 @@ def inputs(self) -> List[str]:
def outputs(self) -> List[str]:
return ["down_block_res_samples", "mid_block_res_sample"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))


@register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers")
class VaeEncoderNeuronConfig(VisionNeuronConfig):
17 changes: 16 additions & 1 deletion optimum/exporters/neuron/model_wrappers.py
@@ -40,14 +40,15 @@ def forward(self, *inputs):
if len(inputs) != len(self.input_names):
raise ValueError(
f"The model needs {len(self.input_names)} inputs: {self.input_names}."
f" But only {len(input)} inputs are passed."
f" But only {len(inputs)} inputs are passed."
)

ordered_inputs = dict(zip(self.input_names, inputs))

added_cond_kwargs = {
"text_embeds": ordered_inputs.pop("text_embeds", None),
"time_ids": ordered_inputs.pop("time_ids", None),
"image_embeds": ordered_inputs.pop("image_embeds", None),
}
sample = ordered_inputs.pop("sample", None)
timestep = ordered_inputs.pop("timestep").float().expand((sample.shape[0],))
@@ -568,6 +569,20 @@ def forward(self, input_ids, attention_mask):
return out_tuple["token_embeddings"], out_tuple["sentence_embedding"]


class CLIPVisionModelNeuronWrapper(torch.nn.Module):
def __init__(self, model, input_names: List[str]):
super().__init__()
self.model = model
self.input_names = input_names

def forward(self, pixel_values):
vision_outputs = self.model.vision_model(pixel_values=pixel_values, output_hidden_states=True)
pooled_output = vision_outputs[1]
image_embeds = self.model.visual_projection(pooled_output)

return (image_embeds, vision_outputs.last_hidden_state, vision_outputs.hidden_states)
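A hedged sketch of how this wrapper is exercised before tracing; the checkpoint and dummy shape are illustrative (the `models/image_encoder` subfolder is the usual `h94/IP-Adapter` layout):

```python
import torch
from transformers import CLIPVisionModelWithProjection

clip = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder"  # illustrative checkpoint
)
wrapper = CLIPVisionModelNeuronWrapper(clip, ["pixel_values"])
pixel_values = torch.zeros(1, 3, clip.config.image_size, clip.config.image_size)
image_embeds, last_hidden_state, hidden_states = wrapper(pixel_values)
```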


class SentenceTransformersCLIPNeuronWrapper(torch.nn.Module):
def __init__(self, model, input_names: List[str]):
super().__init__()