Add Lora support to stable diffusion #483

Merged 5 commits on Feb 28, 2024
32 changes: 31 additions & 1 deletion optimum/commands/export/neuronx.py
@@ -132,10 +132,40 @@ def parse_args_neuronx(parser: "ArgumentParser"):
action="store_true",
help=("Whether or not for the traced model to return the hidden states of all layers."),
)
optional_group.add_argument(
"--lora_model_ids",
default=None,
nargs="*",
type=str,
help=(
"List of model ids (eg. `ostris/super-cereal-sdxl-lora`) of pretrained lora models hosted on the Hub or paths to local directories containing the lora weights."
),
)
optional_group.add_argument(
"--lora_weight_names",
default=None,
nargs="*",
type=str,
help="List of lora weights file names.",
)
optional_group.add_argument(
"--lora_adapter_names",
default=None,
nargs="*",
type=str,
help="List of the adapter names to be used for referencing the loaded adapter models.",
)
optional_group.add_argument(
"--lora_scales",
default=None,
nargs="*",
type=float,
help="List of scaling factors for the lora adapters.",
)
optional_group.add_argument(
"--output_attentions",
action="store_true",
help=("Whether or not for the traced model to return the attentions tensors of all attention layers."),
help="Whether or not for the traced model to return the attentions tensors of all attention layers.",
)

input_group = parser.add_argument_group("Input shapes")
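For context on how these `nargs="*"` options are consumed, here is a minimal sketch mirroring only the four new lora flags (not the full `optimum-cli export neuron` parser); the second model id and the weight file names are illustrative:

```python
from argparse import ArgumentParser

# Minimal parser mirroring only the lora options added in this PR (illustrative, not the full CLI).
parser = ArgumentParser()
parser.add_argument("--lora_model_ids", default=None, nargs="*", type=str)
parser.add_argument("--lora_weight_names", default=None, nargs="*", type=str)
parser.add_argument("--lora_adapter_names", default=None, nargs="*", type=str)
parser.add_argument("--lora_scales", default=None, nargs="*", type=float)

# Two loras passed at once: each flag takes a space-separated list, aligned by position.
args = parser.parse_args(
    [
        "--lora_model_ids", "ostris/super-cereal-sdxl-lora", "user/another-sdxl-lora",
        "--lora_weight_names", "cereal_box_sdxl_v1.safetensors", "pytorch_lora_weights.safetensors",
        "--lora_adapter_names", "cereal", "other",
        "--lora_scales", "0.5", "0.8",
    ]
)
print(args.lora_model_ids)  # ['ostris/super-cereal-sdxl-lora', 'user/another-sdxl-lora']
print(args.lora_scales)     # [0.5, 0.8]
```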
65 changes: 58 additions & 7 deletions optimum/exporters/neuron/__main__.py
@@ -19,7 +19,7 @@
import os
from argparse import ArgumentParser
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from requests.exceptions import ConnectionError as RequestsConnectionError
from transformers import AutoConfig, PretrainedConfig
@@ -245,6 +245,10 @@ def _get_submodels_and_neuron_configs(
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
):
is_stable_diffusion = "stable-diffusion" in task
is_encoder_decoder = (
@@ -258,12 +262,16 @@
f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
)
models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
model,
input_shapes,
task,
output,
dynamic_batch_size,
submodels,
model=model,
input_shapes=input_shapes,
task=task,
output=output,
dynamic_batch_size=dynamic_batch_size,
submodels=submodels,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
)
elif is_encoder_decoder:
optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
@@ -291,13 +299,37 @@ def _get_submodels_and_neuron_configs(
return models_and_neuron_configs, output_model_names


def _normalize_lora_params(lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales):
if isinstance(lora_model_ids, str):
lora_model_ids = [lora_model_ids]
if isinstance(lora_weight_names, str):
lora_weight_names = [lora_weight_names]
if isinstance(lora_adapter_names, str):
lora_adapter_names = [lora_adapter_names]
if isinstance(lora_scales, float):
lora_scales = [lora_scales]
return lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales


def _get_submodels_and_neuron_configs_for_stable_diffusion(
model: Union["PreTrainedModel", "DiffusionPipeline"],
input_shapes: Dict[str, int],
task: str,
output: Path,
dynamic_batch_size: bool = False,
submodels: Optional[Dict[str, Union[Path, str]]] = None,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
):
check_compiler_compatibility_for_stable_diffusion()
model = replace_stable_diffusion_submodels(model, submodels)
@@ -317,10 +349,17 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion(
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
model.save_config(output)

lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales = _normalize_lora_params(
lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales
)
models_and_neuron_configs = get_stable_diffusion_models_for_export(
pipeline=model,
task=task,
dynamic_batch_size=dynamic_batch_size,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
**input_shapes,
)
output_model_names = {
@@ -395,6 +434,10 @@ def main_export(
output_attentions: bool = False,
output_hidden_states: bool = False,
library_name: Optional[str] = None,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
**input_shapes,
):
output = Path(output)
@@ -434,6 +477,10 @@
submodels=submodels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
)

_, neuron_outputs = export_models(
@@ -556,6 +603,10 @@ def main():
do_validation=not args.disable_validation,
submodels=submodels,
library_name=args.library_name,
lora_model_ids=getattr(args, "lora_model_ids", None),
lora_weight_names=getattr(args, "lora_weight_names", None),
lora_adapter_names=getattr(args, "lora_adapter_names", None),
lora_scales=getattr(args, "lora_scales", None),
**optional_outputs,
**input_shapes,
)
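As a quick illustration of `_normalize_lora_params`, scalars are promoted to single-element lists so the downstream export code can always iterate; a minimal sketch, assuming the private helper is imported directly from the module patched above (the weight file name is illustrative):

```python
from optimum.exporters.neuron.__main__ import _normalize_lora_params

# Scalars are wrapped into single-element lists; lists and None pass through unchanged.
ids, weights, adapters, scales = _normalize_lora_params(
    "ostris/super-cereal-sdxl-lora",   # str   -> ["ostris/super-cereal-sdxl-lora"]
    "cereal_box_sdxl_v1.safetensors",  # str   -> ["cereal_box_sdxl_v1.safetensors"]
    "cereal",                          # str   -> ["cereal"]
    0.9,                               # float -> [0.9]
)
print(ids, weights, adapters, scales)
```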
60 changes: 58 additions & 2 deletions optimum/exporters/neuron/utils.py
@@ -17,7 +17,7 @@
import copy
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
from transformers import PretrainedConfig
@@ -135,6 +135,10 @@ def get_stable_diffusion_models_for_export(
vae_encoder_input_shapes: Dict[str, int],
vae_decoder_input_shapes: Dict[str, int],
dynamic_batch_size: Optional[bool] = False,
lora_model_ids: Optional[List[str]] = None,
lora_weight_names: Optional[List[str]] = None,
lora_adapter_names: Optional[List[str]] = None,
lora_scales: Optional[List[float]] = None,
) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "NeuronDefaultConfig"]]:
"""
Returns the components of a Stable Diffusion model and their subsequent neuron configs.
@@ -157,12 +161,27 @@
Static shapes used for compiling vae decoder.
dynamic_batch_size (`bool`, defaults to `False`):
Whether the Neuron compiled model supports dynamic batch size.
lora_model_ids (`Optional[List[str]]`, defaults to `None`):
List of model ids (e.g. `ostris/super-cereal-sdxl-lora`) of pretrained lora models hosted on the Hub or paths to local directories containing the lora weights.
lora_weight_names (`Optional[List[str]]`, defaults to `None`):
List of lora weights file names.
lora_adapter_names (`Optional[List[str]]`, defaults to `None`):
List of adapter names to be used for referencing the loaded adapter models.
lora_scales (`Optional[List[float]]`, defaults to `None`):
List of scaling factors for lora adapters.

Returns:
`Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronDefaultConfig`]`: A Dict containing the model and
Neuron configs for the different components of the model.
"""
models_for_export = _get_submodels_for_export_stable_diffusion(pipeline=pipeline, task=task)
models_for_export = _get_submodels_for_export_stable_diffusion(
pipeline=pipeline,
task=task,
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
)
library_name = "diffusers"

# Text encoders
@@ -255,15 +274,52 @@ def get_stable_diffusion_models_for_export(
return models_for_export


def _load_lora_weights_to_pipeline(
pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"],
lora_model_ids: Optional[List[str]] = None,
weight_names: Optional[List[str]] = None,
adapter_names: Optional[List[str]] = None,
lora_scales: Optional[List[float]] = None,
):
if lora_model_ids and weight_names:
if len(lora_model_ids) == 1:
pipeline.load_lora_weights(lora_model_ids[0], weight_name=weight_names[0])
# To trace the lora weights, we need PEFT to fuse the adapters directly into the model weights; passing the lora scale to the Neuron pipeline at inference time does not work.
pipeline.fuse_lora(lora_scale=lora_scales[0])
elif len(lora_model_ids) > 1:
if not len(lora_model_ids) == len(weight_names) == len(adapter_names):
raise ValueError(
f"weight_name and lora_scale are required to fuse more than one lora. You have {len(lora_model_ids)} lora models to fuse, but you have {len(weight_names)} lora weight names and {len(adapter_names)} adapter names."
)
for model_id, weight_name, adapter_name in zip(lora_model_ids, weight_names, adapter_names):
pipeline.load_lora_weights(model_id, weight_name=weight_name, adapter_name=adapter_name)

if lora_scales:
pipeline.set_adapters(adapter_names, adapter_weights=lora_scales)
pipeline.fuse_lora()


def _get_submodels_for_export_stable_diffusion(
pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"],
task: str,
lora_model_ids: Optional[List[str]] = None,
lora_weight_names: Optional[List[str]] = None,
lora_adapter_names: Optional[List[str]] = None,
lora_scales: Optional[List[float]] = None,
) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]:
"""
Returns the components of a Stable Diffusion model.
"""
is_sdxl = "xl" in task

_load_lora_weights_to_pipeline(
pipeline=pipeline,
lora_model_ids=lora_model_ids,
weight_names=lora_weight_names,
adapter_names=lora_adapter_names,
lora_scales=lora_scales,
)

models_for_export = []
if hasattr(pipeline, "text_encoder_2"):
projection_dim = pipeline.text_encoder_2.config.projection_dim
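The helper above delegates to the diffusers/PEFT lora API; roughly, the multi-adapter branch does the equivalent of the following standalone sketch (the second repo id, the weight file names, and the scales are placeholders, and loading named adapters plus `fuse_lora` requires `peft` to be installed):

```python
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

# Load each lora under its own adapter name, then fuse them into the base weights
# so the fused model can later be traced for Neuron compilation.
pipeline.load_lora_weights(
    "ostris/super-cereal-sdxl-lora", weight_name="cereal_box_sdxl_v1.safetensors", adapter_name="cereal"
)
pipeline.load_lora_weights(
    "user/another-sdxl-lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="other"
)
pipeline.set_adapters(["cereal", "other"], adapter_weights=[0.5, 0.8])
pipeline.fuse_lora()
```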
18 changes: 17 additions & 1 deletion optimum/neuron/modeling_diffusion.py
@@ -21,7 +21,7 @@
from abc import abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from huggingface_hub import snapshot_download
@@ -553,6 +553,10 @@ def _export(
disable_fallback: bool = False,
dynamic_batch_size: bool = False,
data_parallel_mode: Optional[str] = None,
lora_model_ids: Optional[Union[str, List[str]]] = None,
lora_weight_names: Optional[Union[str, List[str]]] = None,
lora_adapter_names: Optional[Union[str, List[str]]] = None,
lora_scales: Optional[Union[float, List[float]]] = None,
**kwargs_shapes,
) -> "NeuronStableDiffusionPipelineBase":
"""
@@ -615,6 +619,14 @@
data_parallel_mode (`Optional[str]`, defaults to `None`):
Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only
load unet into both cores of each device), "all"(load the whole pipeline into both cores).
lora_model_ids (`Optional[Union[str, List[str]]]`, defaults to `None`):
Local paths or repo ids (e.g. `ostris/super-cereal-sdxl-lora`) of the lora models on the Hugging Face Hub.
lora_weight_names (`Optional[Union[str, List[str]]]`, defaults to `None`):
Lora weights file names.
lora_adapter_names (`Optional[Union[str, List[str]]]`, defaults to `None`):
Adapter names used to reference the loaded adapter models.
lora_scales (`Optional[Union[float, List[float]]]`, defaults to `None`):
Scaling factors for the lora adapters.
kwargs_shapes (`Dict[str, int]`):
Shapes to use during inference. This argument allows to override the default shapes used during the export.
"""
@@ -654,6 +666,10 @@
use_auth_token=use_auth_token,
do_validation=False,
submodels={"unet": unet_id},
lora_model_ids=lora_model_ids,
lora_weight_names=lora_weight_names,
lora_adapter_names=lora_adapter_names,
lora_scales=lora_scales,
library_name=cls.library_name,
**input_shapes,
)
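For orientation, a rough sketch of how these new keyword arguments are expected to surface to users through `from_pretrained(..., export=True)`; the static shapes and the lora weight file name are illustrative assumptions, and compilation requires a Neuron-enabled environment:

```python
from optimum.neuron import NeuronStableDiffusionXLPipeline

# Compile an SDXL pipeline with a lora fused into the weights at export time.
# The lora scale must be fixed here; it cannot be changed at inference time.
pipeline = NeuronStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    export=True,
    lora_model_ids="ostris/super-cereal-sdxl-lora",
    lora_weight_names="cereal_box_sdxl_v1.safetensors",  # illustrative file name
    lora_adapter_names="cereal",
    lora_scales=0.9,
    batch_size=1,
    height=1024,
    width=1024,
    num_images_per_prompt=1,
)
pipeline.save_pretrained("sdxl_neuron_lora/")
```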
@@ -172,9 +172,11 @@ def __call__(
self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt)

# 3. Encode input prompt
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored."
)
lora_scale = None

# NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided
# distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the
@@ -161,9 +161,11 @@ def __call__(
)

# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored."
)
text_encoder_lora_scale = None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
num_images_per_prompt,
@@ -212,9 +212,11 @@ def __call__(
)

# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored."
)
text_encoder_lora_scale = None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
num_images_per_prompt,
@@ -214,9 +214,11 @@ def __call__(
)

# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored."
)
text_encoder_lora_scale = None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
num_images_per_prompt,
@@ -252,7 +252,11 @@ def __call__(
)

# 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored."
)
lora_scale = None
(
prompt_embeds,
negative_prompt_embeds,
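All five pipeline hunks above enforce the same behavior: since the lora is fused into the compiled weights, a `scale` passed via `cross_attention_kwargs` at inference only triggers the new warning and is ignored. A short sketch, reusing the directory saved in the earlier export example:

```python
from optimum.neuron import NeuronStableDiffusionXLPipeline

# Reload the pipeline compiled in the earlier sketch; the lora is already fused into its weights.
pipeline = NeuronStableDiffusionXLPipeline.from_pretrained("sdxl_neuron_lora/")

# The scale below only triggers the warning added in this PR; the effective lora
# strength remains the `lora_scales` value fixed at compilation.
image = pipeline(
    "a cereal box, watercolor style",
    cross_attention_kwargs={"scale": 0.2},  # ignored: the scale was fused at export time
).images[0]
image.save("out.png")
```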