[Inference] Add SentenceTransformers support to pipeline for `feature-extraction` (#583)

* v1

* apply feedback

* add feedback

* fix style

---------

Co-authored-by: Jingya HUANG <[email protected]>
philschmid and JingyaHuang authored May 6, 2024
1 parent f24e60d commit 9361b55
Showing 3 changed files with 121 additions and 1 deletion.
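
At a glance: `pipeline()` now detects checkpoints whose source library is `sentence_transformers` and routes them to `NeuronModelForSentenceTransformers` plus a Sentence Transformers-compatible `FeatureExtractionPipeline` (both wired up in the diffs below). A minimal usage sketch, adapted from the docstring example added in this commit (model id, shapes, and the expected output shape are taken from that example):

```python
from optimum.neuron import pipeline

# Export and load the Sentence Transformers model on Neuron; the source
# library is auto-detected, so the sentence-embedding pipeline is selected.
extractor = pipeline(
    model="sentence-transformers/all-MiniLM-L6-v2",
    task="feature-extraction",
    export=True,
    batch_size=2,
    sequence_length=128,
)

result = extractor("This is a simple test.", return_tensors=True)
print(result.shape)  # per the docstring below: torch.Size([1, 384])
```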
14 changes: 13 additions & 1 deletion optimum/neuron/pipelines/transformers/base.py
@@ -19,7 +19,6 @@
 
 from transformers import (
     AutoConfig,
-    FeatureExtractionPipeline,
     FillMaskPipeline,
     Pipeline,
     PreTrainedModel,
@@ -37,12 +36,17 @@
 
 from optimum.modeling_base import OptimizedModel
 from optimum.neuron.modeling_base import NeuronBaseModel
+from optimum.neuron.pipelines.transformers.sentence_transformers import (
+    FeatureExtractionPipeline,
+    is_sentence_transformer_model,
+)
 
 from ...modeling import (
     NeuronModelForCausalLM,
     NeuronModelForFeatureExtraction,
     NeuronModelForMaskedLM,
     NeuronModelForQuestionAnswering,
+    NeuronModelForSentenceTransformers,
     NeuronModelForSequenceClassification,
     NeuronModelForTokenClassification,
 )
@@ -119,6 +123,13 @@ def load_pipeline(
     elif isinstance(model, str):
         model_id = model
         neuronx_model_class = supported_tasks[targeted_task]["class"][0]
+        # Try to determine the correct feature extraction class to use.
+        if targeted_task == "feature-extraction" and is_sentence_transformer_model(
+            model, token=token, revision=revision
+        ):
+            logger.info("Using Sentence Transformers compatible Feature extraction pipeline")
+            neuronx_model_class = NeuronModelForSentenceTransformers
+
         model = neuronx_model_class.from_pretrained(
             model, export=export, **compiler_args, **input_shapes, **hub_kwargs, **kwargs
         )
@@ -267,5 +278,6 @@ def pipeline(
         feature_extractor=feature_extractor,
         use_fast=use_fast,
         batch_size=batch_size,
+        pipeline_class=NEURONX_SUPPORTED_TASKS[task]["impl"],
         **kwargs,
     )
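
The routing above hinges on `is_sentence_transformer_model`, defined in the new module below: it asks `TasksManager` which library a checkpoint belongs to, and the regular `transformers` pipeline stays in place when that probe raises a `ValueError`. A standalone sketch of the same check (the repo id is only an example):

```python
from optimum.exporters.tasks import TasksManager

# The same probe load_pipeline performs for the feature-extraction task:
# a "sentence_transformers" answer triggers NeuronModelForSentenceTransformers.
library = TasksManager.infer_library_from_model("sentence-transformers/all-MiniLM-L6-v2")
print(library == "sentence_transformers")  # expected: True
```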
90 changes: 90 additions & 0 deletions optimum/neuron/pipelines/transformers/sentence_transformers.py
@@ -0,0 +1,90 @@
+from typing import Dict
+
+from transformers.pipelines.base import GenericTensor, Pipeline
+
+from optimum.utils import is_sentence_transformers_available
+
+
+if is_sentence_transformers_available():
+    from optimum.exporters.tasks import TasksManager
+
+
+def is_sentence_transformer_model(model: str, token: str = None, revision: str = None):
+    """Checks if the model is a sentence transformer model based on provided model id"""
+    try:
+        _library_name = TasksManager.infer_library_from_model(model, use_auth_token=token, revision=revision)
+        return _library_name == "sentence_transformers"
+    except ValueError:
+        return False
+
+
+class FeatureExtractionPipeline(Pipeline):
+    """
+    Sentence Transformers compatible Feature extraction pipeline uses no model head.
+    This pipeline extracts the sentence embeddings from the sentence transformers, which can be used
+    in embedding-based tasks like clustering and search. The pipeline is based on the `transformers` library.
+    And automatically used instead of the `transformers` library's pipeline when the model is a sentence transformer model.
+    Example:
+    ```python
+    >>> from optimum.neuron import pipeline
+    >>> extractor = pipeline(model="sentence-transformers/all-MiniLM-L6-v2", task="feature-extraction", export=True, batch_size=2, sequence_length=128)
+    >>> result = extractor("This is a simple test.", return_tensors=True)
+    >>> result.shape # This is a tensor of shape [1, dimension] representing the input string.
+    torch.Size([1, 384])
+    ```
+    """
+
+    def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
+        if tokenize_kwargs is None:
+            tokenize_kwargs = {}
+
+        if truncation is not None:
+            if "truncation" in tokenize_kwargs:
+                raise ValueError(
+                    "truncation parameter defined twice (given as keyword argument as well as in tokenize_kwargs)"
+                )
+            tokenize_kwargs["truncation"] = truncation
+
+        preprocess_params = tokenize_kwargs
+
+        postprocess_params = {}
+        if return_tensors is not None:
+            postprocess_params["return_tensors"] = return_tensors
+
+        return preprocess_params, {}, postprocess_params
+
+    def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
+        model_inputs = self.tokenizer(inputs, return_tensors=self.framework, **tokenize_kwargs)
+        return model_inputs
+
+    def _forward(self, model_inputs):
+        model_outputs = self.model(**model_inputs)
+        return model_outputs
+
+    def postprocess(self, _model_outputs, return_tensors=False):
+        # Needed change for sentence transformers.
+        # Check if the model outputs sentence embeddings or not.
+        if hasattr(_model_outputs, "sentence_embedding"):
+            model_outputs = _model_outputs.sentence_embedding
+        else:
+            model_outputs = _model_outputs
+        # [0] is the first available tensor, logits or last_hidden_state.
+        if return_tensors:
+            return model_outputs[0]
+        if self.framework == "pt":
+            return model_outputs[0].tolist()
+
+    def __call__(self, *args, **kwargs):
+        """
+        Extract the features of the input(s).
+        Args:
+            args (`str` or `List[str]`): One or several texts (or one list of texts) to get the features of.
+        Return:
+            A nested list of `float`: The features computed by the model.
+        """
+        return super().__call__(*args, **kwargs)
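
Note the contract in `postprocess`: a sentence-transformers model exposes `sentence_embedding`, so the pipeline returns that tensor (with `return_tensors=True`) or its `tolist()` form, instead of the token-level `last_hidden_state` a plain model would yield. A short sketch of both output modes, reusing the `extractor` built at the top of this page (behavior as asserted by the new test below):

```python
# Default path: .tolist() on the first tensor, so a single sentence comes
# back as a plain Python list of floats, which is what the new test asserts.
vector = extractor("My Name is Philipp.")
assert all(isinstance(value, float) for value in vector)

# return_tensors=True skips the conversion and returns the torch tensor.
embedding = extractor("My Name is Philipp.", return_tensors=True)
```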
18 changes: 18 additions & 0 deletions tests/inference/test_modeling.py
@@ -432,6 +432,24 @@ def test_sentence_transformers_clip(self, model_arch):
 
         gc.collect()
 
+    @parameterized.expand(["transformer"], skip_on_empty=True)
+    @requires_neuronx
+    def test_pipeline_model(self, model_arch):
+        input_shapes = {
+            "batch_size": 1,
+            "sequence_length": 16,
+        }
+        model_id = SENTENCE_TRANSFORMERS_MODEL_NAMES[model_arch]
+        neuron_model = self.NEURON_MODEL_CLASS.from_pretrained(model_id, export=True, **input_shapes)
+        tokenizer = get_preprocessor(model_id)
+        pipe = pipeline(self.TASK, model=neuron_model, tokenizer=tokenizer)
+        text = "My Name is Philipp."
+        outputs = pipe(text)
+
+        self.assertTrue(all(isinstance(item, float) for item in outputs))
+
+        gc.collect()
+
 
 @is_inferentia_test
 class NeuronModelForMaskedLMIntegrationTest(NeuronModelTestMixin):
