Skip to content

Commit

Permalink
[Inference] Improve the support of sentence transformers (#408)
Browse files Browse the repository at this point in the history
* support sentence transformers

* support clip

* fix clip and add tests

* add dependencies

* fix style

* adapt for inf1

* abandon  inf1

* fix name typo and remove useless outputs

* add test for inference

* support inference

* update doc

* fix typo

---------

Co-authored-by: JingyaHuang <[email protected]>
  • Loading branch information
JingyaHuang and JingyaHuang authored Jan 16, 2024
1 parent 43d2f90 commit 9837efa
Show file tree
Hide file tree
Showing 15 changed files with 356 additions and 38 deletions.
4 changes: 4 additions & 0 deletions docs/source/package_reference/modeling.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ The following Neuron model classes are available for natural language processing

[[autodoc]] modeling.NeuronModelForFeatureExtraction

### NeuronModelForSenetenceTransformers

[[autodoc]] modeling.NeuronModelForSenetenceTransformers

### NeuronModelForMaskedLM

[[autodoc]] modeling.NeuronModelForMaskedLM
Expand Down
15 changes: 15 additions & 0 deletions optimum/commands/export/neuronx.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def parse_args_neuronx(parser: "ArgumentParser"):
f" {str(list(TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS.keys()) + list(TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS.keys()))}."
),
)
optional_group.add_argument(
"--library-name",
type=str,
choices=["transformers", "diffusers", "sentence_transformers"],
default=None,
help=("The library on the model." " If not provided, will attempt to infer the local checkpoint's library."),
)
optional_group.add_argument(
"--subfolder",
type=str,
default="",
help=(
"In case the relevant files are located inside a subfolder of the model repo either locally or on huggingface.co, specify the folder name here."
),
)
optional_group.add_argument(
"--atol",
type=float,
Expand Down
56 changes: 43 additions & 13 deletions optimum/exporters/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,47 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from .__main__ import (
infer_stable_diffusion_shapes_from_diffusers,
main_export,
normalize_input_shapes,
normalize_stable_diffusion_input_shapes,
)
from .base import NeuronConfig
from .convert import export, export_models, validate_model_outputs, validate_models_outputs
from .utils import (
DiffusersPretrainedConfig,
build_stable_diffusion_components_mandatory_shapes,
get_stable_diffusion_models_for_export,
)
from transformers.utils import _LazyModule


_import_structure = {
"__main__": [
"infer_stable_diffusion_shapes_from_diffusers",
"main_export",
"normalize_input_shapes",
"normalize_stable_diffusion_input_shapes",
],
"base": ["NeuronConfig"],
"convert": ["export", "export_models", "validate_model_outputs", "validate_models_outputs"],
"utils": [
"DiffusersPretrainedConfig",
"build_stable_diffusion_components_mandatory_shapes",
"get_stable_diffusion_models_for_export",
],
}

if TYPE_CHECKING:
from .__main__ import (
infer_stable_diffusion_shapes_from_diffusers,
main_export,
normalize_input_shapes,
normalize_stable_diffusion_input_shapes,
)
from .base import NeuronConfig
from .convert import export, export_models, validate_model_outputs, validate_models_outputs
from .utils import (
DiffusersPretrainedConfig,
build_stable_diffusion_components_mandatory_shapes,
get_stable_diffusion_models_for_export,
)
else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
24 changes: 23 additions & 1 deletion optimum/exporters/neuron/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,19 @@ def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int
return input_shapes


def normalize_sentence_transformers_input_shapes(args: argparse.Namespace) -> Dict[str, int]:
args = vars(args) if isinstance(args, argparse.Namespace) else args
mandatory_axes = {"batch_size", "sequence_length"}
if "clip" in args.get("model", "").lower():
mandatory_axes.update(["num_channels", "width", "height"])
if not mandatory_axes.issubset(set(args.keys())):
raise AttributeError(
f"Shape of {mandatory_axes} are mandatory for neuron compilation, while {mandatory_axes.difference(args.keys())} are not given."
)
mandatory_shapes = {name: args[name] for name in mandatory_axes}
return mandatory_shapes


def customize_optional_outputs(args: argparse.Namespace) -> Dict[str, bool]:
"""
Customize optional outputs of the traced model, eg. if `output_attentions=True`, the attentions tensors will be traced.
Expand Down Expand Up @@ -249,7 +262,8 @@ def _get_submodels_and_neuron_configs(
model=model, exporter="neuron", task=task
)
neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
model_name = model.name_or_path.split("/")[-1]
model_name = getattr(model, "name_or_path", None) or model_name_or_path
model_name = model_name.split("/")[-1] if model_name else model.config.model_type
output_model_names = {model_name: "model.neuron"}
models_and_neuron_configs = {model_name: (model, neuron_config)}
maybe_save_preprocessors(model_name_or_path, output)
Expand Down Expand Up @@ -358,6 +372,7 @@ def main_export(
submodels: Optional[Dict[str, Union[Path, str]]] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
library_name: Optional[str] = None,
**input_shapes,
):
output = Path(output)
Expand All @@ -378,6 +393,7 @@ def main_export(
"force_download": force_download,
"trust_remote_code": trust_remote_code,
"framework": "pt",
"library_name": library_name,
}
model = TasksManager.get_model_from_task(**model_kwargs)

Expand Down Expand Up @@ -451,11 +467,15 @@ def main():

task = infer_task(args.task, args.model)
is_stable_diffusion = "stable-diffusion" in task
is_sentence_transformers = args.library_name == "sentence_transformers"
compiler_kwargs = infer_compiler_kwargs(args)

if is_stable_diffusion:
input_shapes = normalize_stable_diffusion_input_shapes(args)
submodels = {"unet": args.unet}
elif is_sentence_transformers:
input_shapes = normalize_sentence_transformers_input_shapes(args)
submodels = None
else:
input_shapes = normalize_input_shapes(task, args)
submodels = None
Expand All @@ -474,8 +494,10 @@ def main():
compiler_workdir=args.compiler_workdir,
optlevel=optlevel,
trust_remote_code=args.trust_remote_code,
subfolder=args.subfolder,
do_validation=not args.disable_validation,
submodels=submodels,
library_name=args.library_name,
**optional_outputs,
**input_shapes,
)
Expand Down
25 changes: 16 additions & 9 deletions optimum/exporters/neuron/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import numpy as np
import torch
from transformers import PretrainedConfig

from ...exporters.error_utils import OutputMatchError, ShapeError
from ...neuron.utils import (
Expand All @@ -31,7 +30,11 @@
store_compilation_config,
)
from ...neuron.utils.version_utils import get_neuroncc_version, get_neuronxcc_version
from ...utils import is_diffusers_available, logging
from ...utils import (
is_diffusers_available,
is_sentence_transformers_available,
logging,
)
from .utils import DiffusersPretrainedConfig


Expand All @@ -56,6 +59,8 @@
from diffusers import ModelMixin
from diffusers.configuration_utils import FrozenDict

if is_sentence_transformers_available():
from sentence_transformers import SentenceTransformer

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

Expand Down Expand Up @@ -130,7 +135,7 @@ def validate_models_outputs(

def validate_model_outputs(
config: "NeuronConfig",
reference_model: Union["PreTrainedModel", "ModelMixin"],
reference_model: Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"],
neuron_model_path: Path,
neuron_named_outputs: List[str],
atol: Optional[float] = None,
Expand All @@ -141,7 +146,7 @@ def validate_model_outputs(
Args:
config ([`~optimum.neuron.exporter.NeuronConfig`]:
The configuration used to export the model.
reference_model ([`Union["PreTrainedModel", "ModelMixin"]`]):
reference_model ([`Union["PreTrainedModel", "SentenceTransformer", "ModelMixin"]`]):
The model used for the export.
neuron_model_path (`Path`):
The path to the exported model.
Expand Down Expand Up @@ -169,9 +174,13 @@ def validate_model_outputs(
with torch.no_grad():
reference_model.eval()
ref_inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
if getattr(reference_model.config, "is_encoder_decoder", False):
if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False):
reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
if "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr(
if "SentenceTransformer" in reference_model.__class__.__name__:
reference_model = config.patch_model_for_export(reference_model, ref_inputs)
ref_outputs = reference_model(**ref_inputs)
neuron_inputs = tuple(config.flatten_inputs(ref_inputs).values())
elif "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr(
reference_model.config, "is_encoder_decoder", False
):
# VAE components for stable diffusion or Encoder-Decoder models
Expand Down Expand Up @@ -359,8 +368,6 @@ def export_models(
output_attentions=getattr(sub_neuron_config, "output_attentions", False),
output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False),
)
if isinstance(model_config, PretrainedConfig):
model_config = DiffusersPretrainedConfig.from_dict(model_config.__dict__)
model_config.save_pretrained(output_path.parent)
except Exception as e:
failed_models.append((i, model_name))
Expand Down Expand Up @@ -469,7 +476,7 @@ def export_neuronx(
dummy_inputs_tuple = tuple(dummy_inputs.values())

aliases = {}
if getattr(model.config, "is_encoder_decoder", False):
if hasattr(model, "config") and getattr(model.config, "is_encoder_decoder", False):
checked_model = config.patch_model_for_export(model, **input_shapes)
if getattr(config, "is_decoder", False):
aliases = config.generate_io_aliases(checked_model)
Expand Down
37 changes: 37 additions & 0 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
NormalizedConfigManager,
NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
is_diffusers_available,
)
from ...utils.normalized_config import T5LikeNormalizedTextConfig
Expand All @@ -41,6 +42,8 @@
VisionNeuronConfig,
)
from .model_wrappers import (
SentenceTransformersCLIPNeuronWrapper,
SentenceTransformersTransformerNeuronWrapper,
T5DecoderWrapper,
T5EncoderWrapper,
UnetNeuronWrapper,
Expand Down Expand Up @@ -171,6 +174,24 @@ class DebertaV2NeuronConfig(DebertaNeuronConfig):
pass


@register_in_tasks_manager("sentence-transformers-transformer", *["feature-extraction", "sentence-similarity"])
class SentenceTransformersTransformerNeuronConfig(TextEncoderNeuronConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
CUSTOM_MODEL_WRAPPER = SentenceTransformersTransformerNeuronWrapper
ATOL_FOR_VALIDATION = 1e-3

@property
def inputs(self) -> List[str]:
return ["input_ids", "attention_mask"]

@property
def outputs(self) -> List[str]:
return ["token_embeddings", "sentence_embedding"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))


class CLIPNormalizedConfig(NormalizedTextAndVisionConfig):
TEXT_CONFIG = "text_config"
VISION_CONFIG = "vision_config"
Expand Down Expand Up @@ -229,6 +250,22 @@ def outputs(self) -> List[str]:
return common_outputs


# TODO: We should decouple clip text and vision, this would need fix on Optimum main. For the current workaround
# users can pass dummy text inputs when encoding image, vice versa.
@register_in_tasks_manager("sentence-transformers-clip", *["feature-extraction", "sentence-similarity"])
class SentenceTransformersCLIPNeuronConfig(CLIPNeuronConfig):
CUSTOM_MODEL_WRAPPER = SentenceTransformersCLIPNeuronWrapper
ATOL_FOR_VALIDATION = 1e-3
MANDATORY_AXES = ("batch_size", "sequence_length", "num_channels", "width", "height")

@property
def outputs(self) -> List[str]:
return ["text_embeds", "image_embeds"]

def patch_model_for_export(self, model, dummy_inputs):
return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))


@register_in_tasks_manager("unet", *["semantic-segmentation"])
class UNetNeuronConfig(VisionNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
Expand Down
35 changes: 35 additions & 0 deletions optimum/exporters/neuron/model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,38 @@ def forward(
neuron_outputs += cross_attentions

return neuron_outputs


class SentenceTransformersTransformerNeuronWrapper(torch.nn.Module):
def __init__(self, model, input_names: List[str]):
super().__init__()
self.model = model
self.input_names = input_names

def forward(self, input_ids, attention_mask):
out_tuple = self.model({"input_ids": input_ids, "attention_mask": attention_mask})

return out_tuple["token_embeddings"], out_tuple["sentence_embedding"]


class SentenceTransformersCLIPNeuronWrapper(torch.nn.Module):
def __init__(self, model, input_names: List[str]):
super().__init__()
self.model = model
self.input_names = input_names

def forward(self, input_ids, pixel_values, attention_mask):
vision_outputs = self.model[0].model.vision_model(pixel_values=pixel_values)
image_embeds = self.model[0].model.visual_projection(vision_outputs[1])

text_outputs = self.model[0].model.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
)
text_embeds = self.model[0].model.text_projection(text_outputs[1])

if len(self.model) > 1:
image_embeds = self.model[1:](image_embeds)
text_embeds = self.model[1:](text_embeds)

return (text_embeds, image_embeds)
2 changes: 2 additions & 0 deletions optimum/neuron/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"modeling_base": ["NeuronBaseModel"],
"modeling": [
"NeuronModelForFeatureExtraction",
"NeuronModelForSenetenceTransformers",
"NeuronModelForMaskedLM",
"NeuronModelForQuestionAnswering",
"NeuronModelForSequenceClassification",
Expand Down Expand Up @@ -60,6 +61,7 @@
NeuronModelForMaskedLM,
NeuronModelForMultipleChoice,
NeuronModelForQuestionAnswering,
NeuronModelForSenetenceTransformers,
NeuronModelForSequenceClassification,
NeuronModelForTokenClassification,
)
Expand Down
Loading

0 comments on commit 9837efa

Please sign in to comment.