
Commit e0abd4b

Align I/O with Inference API (#99)

* Fix `task` type-hint and remove extra space in `logging`
* Align `transformers` and `diffusers` inputs with Inference API
* Remove duplicated `sentencepiece` extra requirement
* Remove `pipeline.task` check for `sentence-transformers`
* Add `warning` and `pop` unsupported parameters
* Fix `sentence-transformers` pipeline type-hints
* Update `sentence-ranking` type-hints
* Add missing type-hints and clean up the code a bit
* Fix failing `sentence-transformers` tests due to input parsing
* Fix "table-question-answering" payload check
* Fix "zero-shot-classification" payload check
* Check that payload is `dict` in advance
* Fix `HuggingFaceHandler` errors and checks
* Fix `sentence-transformers` pipelines as those don't have parameters
* Fix `INPUT` to `input_data` fixture
* Fix quality in `tests/unit/test_handler.py`
* Make `parameters` default to empty dict instead of None
* Add note on `token-classification` / `ner` task

  Apparently the parameters are indeed supported via the `__call__` method of `TokenClassificationPipeline` even if the docs say otherwise, since they are internally forwarded to the `_sanitize_parameters` function and then used within `__call__` rather than via `__init__`.

* Update `version` in `setup.py`
* Fix `generate_kwargs` payload handling for text2text-based tasks
* Fix `generate_kwargs` handling to flatten it into the first-level dict

  Co-authored-by: Célina <[email protected]>

* Update `generate_kwargs` handling as sometimes required
* Remove `generate` from supported generation kwargs key names
* Update `SentenceRankingPipeline` to handle `query`-`texts` pipelines

  Also adds some extra validation steps.

* Update typing and fix `sentence-transformers` tests
* Upgrade `transformers`, `sentence-transformers` and `peft` dependencies

---------

Co-authored-by: Célina <[email protected]>
1 parent 6b17e6c commit e0abd4b
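In practice, the alignment means request bodies follow the Inference API shape of an `inputs` field plus an optional `parameters` dict. A minimal sketch (the payload values below are made up for illustration; only the key layout comes from this change):

    # Illustrative payloads matching the conventions enforced by this commit.
    qa_payload = {
        # "question-answering" now requires a dict with both keys
        "inputs": {"question": "Where do I live?", "context": "I live in Amsterdam."},
    }
    zero_shot_payload = {
        "inputs": "This toolkit is great!",
        # camelCase "candidateLabels" is also accepted and normalized to "candidate_labels"
        "parameters": {"candidate_labels": ["positive", "negative"]},
    }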

File tree

7 files changed (+277, -149 lines)


setup.py

+5-5
@@ -5,15 +5,15 @@
 # We don't declare our dependency on transformers here because we build with
 # different packages for different variants
 
-VERSION = "0.5.2"
+VERSION = "0.5.3"
 
 # Ubuntu packages
 # libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
 # ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
 # libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg
 
 install_requires = [
-    "transformers[sklearn,sentencepiece,audio,vision,sentencepiece]==4.46.1",
+    "transformers[sklearn,sentencepiece,audio,vision]==4.47.0",
     "huggingface_hub[hf_transfer]==0.26.2",
     # vision
     "Pillow",
@@ -31,11 +31,11 @@
 
 extras = {}
 
-extras["st"] = ["sentence_transformers==3.2.1"]
-extras["diffusers"] = ["diffusers==0.31.0", "accelerate==1.0.1"]
+extras["st"] = ["sentence_transformers==3.3.1"]
+extras["diffusers"] = ["diffusers==0.31.0", "accelerate==1.1.0"]
 # Includes `peft` as PEFT requires `torch` so having `peft` as a core dependency
 # means that `torch` will be installed even if the `torch` extra is not specified.
-extras["torch"] = ["torch==2.3.1", "torchvision", "torchaudio", "peft==0.13.2"]
+extras["torch"] = ["torch==2.3.1", "torchvision", "torchaudio", "peft==0.14.0"]
 extras["test"] = [
     "pytest==7.2.1",
    "pytest-xdist",

src/huggingface_inference_toolkit/diffusers_utils.py

+19-9
@@ -22,9 +22,7 @@ def is_diffusers_available():
 
 
 class IEAutoPipelineForText2Image:
-    def __init__(
-        self, model_dir: str, device: Union[str, None] = None, **kwargs
-    ):  # needs "cuda" for GPU
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs):  # needs "cuda" for GPU
         dtype = torch.float32
         if device == "cuda":
             dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
@@ -36,9 +34,7 @@ def __init__(
         # try to use DPMSolverMultistepScheduler
         if isinstance(self.pipeline, StableDiffusionPipeline):
             try:
-                self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    self.pipeline.scheduler.config
-                )
+                self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
             except Exception:
                 pass
 
@@ -47,6 +43,13 @@ def __call__(
         prompt,
         **kwargs,
     ):
+        if "prompt" in kwargs:
+            logger.warning(
+                "prompt has been provided twice, both via arg and kwargs, so the `prompt` arg will be used "
+                "instead, and the `prompt` in kwargs will be discarded."
+            )
+            kwargs.pop("prompt")
+
         # diffusers doesn't support seed but rather the generator kwarg
         # see: https://github.com/huggingface/api-inference-community/blob/8e577e2d60957959ba02f474b2913d84a9086b82/docker_images/diffusers/app/pipelines/text_to_image.py#L172-L176
         if "seed" in kwargs:
@@ -58,9 +61,16 @@ def __call__(
         # TODO: add support for more images (Reason is correct output)
         if "num_images_per_prompt" in kwargs:
             kwargs.pop("num_images_per_prompt")
-            logger.warning(
-                "Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1."
-            )
+            logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")
+
+        if "target_size" in kwargs:
+            kwargs["height"] = kwargs["target_size"].pop("height", None)
+            kwargs["width"] = kwargs["target_size"].pop("width", None)
+            kwargs.pop("target_size")
+
+        if "output_type" in kwargs and kwargs["output_type"] != "pil":
+            kwargs.pop("output_type")
+            logger.warning("The `output_type` cannot be modified, and PIL will be used by default instead.")
 
         # Call pipeline with parameters
         out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
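The `target_size` and `output_type` branches above translate Inference-API-style parameters into what `diffusers` pipelines actually accept. A standalone sketch of that normalization (`normalize_text2image_kwargs` is a hypothetical helper written for illustration, not part of the commit):

    from typing import Any, Dict

    def normalize_text2image_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Hypothetical helper mirroring the __call__ checks above.
        kwargs = dict(kwargs)  # avoid mutating the caller's dict
        if "target_size" in kwargs:
            # The payload nests {"target_size": {"width": ..., "height": ...}};
            # diffusers pipelines expect flat `height`/`width` kwargs instead.
            target_size = kwargs.pop("target_size")
            kwargs["height"] = target_size.pop("height", None)
            kwargs["width"] = target_size.pop("width", None)
        if kwargs.get("output_type", "pil") != "pil":
            kwargs.pop("output_type")  # only PIL output is supported
        return kwargs

    assert normalize_text2image_kwargs({"target_size": {"height": 512, "width": 512}}) == {
        "height": 512,
        "width": 512,
    }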
src/huggingface_inference_toolkit/handler.py

+82-29

@@ -1,8 +1,9 @@
 import os
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Dict, Literal, Optional, Union
 
 from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
+from huggingface_inference_toolkit.sentence_transformers_utils import SENTENCE_TRANSFORMERS_TASKS
 from huggingface_inference_toolkit.utils import (
     check_and_register_custom_pipeline_from_directory,
     get_pipeline,
@@ -12,34 +13,87 @@
 class HuggingFaceHandler:
     """
     A Default Hugging Face Inference Handler which works with all
-    transformers pipelines, Sentence Transformers and Optimum.
+    Transformers, Diffusers, Sentence Transformers and Optimum pipelines.
     """
 
-    def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"):
+    def __init__(
+        self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt"
+    ) -> None:
         self.pipeline = get_pipeline(
-            model_dir=model_dir,
-            task=task,
+            model_dir=model_dir,  # type: ignore
+            task=task,  # type: ignore
             framework=framework,
             trust_remote_code=HF_TRUST_REMOTE_CODE,
         )
 
-    def __call__(self, data):
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Handles an inference request with input data and makes a prediction.
         Args:
             :data: (obj): the raw request body data.
             :return: prediction output
         """
         inputs = data.pop("inputs", data)
-        parameters = data.pop("parameters", None)
-
-        # pass inputs with all kwargs in data
-        if parameters is not None:
-            prediction = self.pipeline(inputs, **parameters)
-        else:
-            prediction = self.pipeline(inputs)
-        # postprocess the prediction
-        return prediction
+        parameters = data.pop("parameters", {})
+
+        # sentence transformers pipelines do not have the `task` arg
+        if any(isinstance(self.pipeline, v) for v in SENTENCE_TRANSFORMERS_TASKS.values()):
+            return self.pipeline(**inputs) if isinstance(inputs, dict) else self.pipeline(inputs)  # type: ignore
+
+        if self.pipeline.task == "question-answering":
+            if not isinstance(inputs, dict):
+                raise ValueError(f"inputs must be a dict, but a `{type(inputs)}` was provided instead.")
+            if not all(k in inputs for k in {"question", "context"}):
+                raise ValueError(
+                    f"{self.pipeline.task} expects `inputs` to be a dict containing both `question` and "
+                    "`context` as the keys, both of them being either a `str` or a `List[str]`."
+                )
+
+        if self.pipeline.task == "table-question-answering":
+            if not isinstance(inputs, dict):
+                raise ValueError(f"inputs must be a dict, but a `{type(inputs)}` was provided instead.")
+            if "question" in inputs:
+                inputs["query"] = inputs.pop("question")
+            if not all(k in inputs for k in {"table", "query"}):
+                raise ValueError(
+                    f"{self.pipeline.task} expects `inputs` to be a dict containing the keys `table` and "
+                    "either `question` or `query`."
+                )
+
+        if self.pipeline.task.__contains__("translation") or self.pipeline.task in {
+            "text-generation",
+            "image-to-text",
+            "automatic-speech-recognition",
+            "text-to-audio",
+            "text-to-speech",
+        }:
+            # `generate_kwargs` needs to be a dict, `generation_parameters` is here for forward compatibility
+            if "generation_parameters" in parameters:
+                parameters["generate_kwargs"] = parameters.pop("generation_parameters")
+
+        if self.pipeline.task.__contains__("translation") or self.pipeline.task in {"text-generation"}:
+            # flatten the values of `generate_kwargs` as it's not supported as is, but via top-level parameters
+            generate_kwargs = parameters.pop("generate_kwargs", {})
+            for key, value in generate_kwargs.items():
+                parameters[key] = value
+
+        if self.pipeline.task.__contains__("zero-shot-classification"):
+            if "candidateLabels" in parameters:
+                parameters["candidate_labels"] = parameters.pop("candidateLabels")
+            if not isinstance(inputs, dict):
+                inputs = {"sequences": inputs}
+            if "text" in inputs:
+                inputs["sequences"] = inputs.pop("text")
+            if not all(k in inputs for k in {"sequences"}) or not all(k in parameters for k in {"candidate_labels"}):
+                raise ValueError(
+                    f"{self.pipeline.task} expects `inputs` to be either a string or a dict containing the "
+                    "key `text` or `sequences`, and `parameters` to be a dict containing either `candidate_labels` "
+                    "or `candidateLabels`."
+                )
+
+        return (
+            self.pipeline(**inputs, **parameters) if isinstance(inputs, dict) else self.pipeline(inputs, **parameters)  # type: ignore
+        )
 
 
 class VertexAIHandler(HuggingFaceHandler):
@@ -48,21 +102,21 @@ class VertexAIHandler(HuggingFaceHandler):
     Vertex AI specific logic for inference.
     """
 
-    def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"):
-        super().__init__(model_dir, task, framework)
+    def __init__(
+        self, model_dir: Union[str, Path], task: Union[str, None] = None, framework: Literal["pt"] = "pt"
+    ) -> None:
+        super().__init__(model_dir=model_dir, task=task, framework=framework)
 
-    def __call__(self, data):
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
         Handles an inference request with input data and makes a prediction.
         Args:
             :data: (obj): the raw request body data.
             :return: prediction output
         """
         if "instances" not in data:
-            raise ValueError(
-                "The request body must contain a key 'instances' with a list of instances."
-            )
-        parameters = data.pop("parameters", None)
+            raise ValueError("The request body must contain a key 'instances' with a list of instances.")
+        parameters = data.pop("parameters", {})
 
         predictions = []
         # iterate over all instances and make predictions
@@ -74,9 +128,7 @@ def __call__(self, data):
         return {"predictions": predictions}
 
 
-def get_inference_handler_either_custom_or_default_handler(
-    model_dir: Path, task: Optional[str] = None
-):
+def get_inference_handler_either_custom_or_default_handler(model_dir: Path, task: Optional[str] = None) -> Any:
     """
     Returns the appropriate inference handler based on the given model directory and task.
 
@@ -88,9 +140,10 @@ def get_inference_handler_either_custom_or_default_handler(
         InferenceHandler: The appropriate inference handler based on the given model directory and task.
     """
     custom_pipeline = check_and_register_custom_pipeline_from_directory(model_dir)
-    if custom_pipeline:
+    if custom_pipeline is not None:
         return custom_pipeline
-    elif os.environ.get("AIP_MODE", None) == "PREDICTION":
+
+    if os.environ.get("AIP_MODE", None) == "PREDICTION":
         return VertexAIHandler(model_dir=model_dir, task=task)
-    else:
-        return HuggingFaceHandler(model_dir=model_dir, task=task)
+
+    return HuggingFaceHandler(model_dir=model_dir, task=task)
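To illustrate the `generate_kwargs` handling above: `generation_parameters` is first accepted as an alias for `generate_kwargs`, and for translation and `text-generation` tasks the nested dict is then flattened into top-level pipeline kwargs, since those pipelines do not accept `generate_kwargs` as-is. A minimal sketch of the transformation (the parameter values are made up):

    parameters = {"generation_parameters": {"max_new_tokens": 20, "temperature": 0.7}}

    # 1) alias `generation_parameters` to `generate_kwargs`
    if "generation_parameters" in parameters:
        parameters["generate_kwargs"] = parameters.pop("generation_parameters")

    # 2) for translation / text-generation, flatten into top-level kwargs
    for key, value in parameters.pop("generate_kwargs", {}).items():
        parameters[key] = value

    assert parameters == {"max_new_tokens": 20, "temperature": 0.7}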
src/huggingface_inference_toolkit/sentence_transformers_utils.py

@@ -1,4 +1,10 @@
 import importlib.util
+from typing import Any, Dict, List, Tuple, Union
+
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal
 
 _sentence_transformers = importlib.util.find_spec("sentence_transformers") is not None
 
@@ -12,40 +18,73 @@ def is_sentence_transformers_available():
 
 
 class SentenceSimilarityPipeline:
-    def __init__(self, model_dir: str, device: str = None, **kwargs):  # needs "cuda" for GPU
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None:
+        # `device` needs to be set to "cuda" for GPU
         self.model = SentenceTransformer(model_dir, device=device, **kwargs)
 
-    def __call__(self, inputs=None):
-        embeddings1 = self.model.encode(
-            inputs["source_sentence"], convert_to_tensor=True
-        )
-        embeddings2 = self.model.encode(inputs["sentences"], convert_to_tensor=True)
+    def __call__(self, source_sentence: str, sentences: List[str]) -> Dict[str, float]:
+        embeddings1 = self.model.encode(source_sentence, convert_to_tensor=True)
+        embeddings2 = self.model.encode(sentences, convert_to_tensor=True)
         similarities = util.pytorch_cos_sim(embeddings1, embeddings2).tolist()[0]
         return {"similarities": similarities}
 
 
 class SentenceEmbeddingPipeline:
-    def __init__(self, model_dir: str, device: str = None, **kwargs):  # needs "cuda" for GPU
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None:
+        # `device` needs to be set to "cuda" for GPU
         self.model = SentenceTransformer(model_dir, device=device, **kwargs)
 
-    def __call__(self, inputs):
-        embeddings = self.model.encode(inputs).tolist()
+    def __call__(self, sentences: Union[str, List[str]]) -> Dict[str, List[float]]:
+        embeddings = self.model.encode(sentences).tolist()
         return {"embeddings": embeddings}
 
 
-class RankingPipeline:
-    def __init__(self, model_dir: str, device: str = None, **kwargs):  # needs "cuda" for GPU
+class SentenceRankingPipeline:
+    def __init__(self, model_dir: str, device: Union[str, None] = None, **kwargs: Any) -> None:
+        # `device` needs to be set to "cuda" for GPU
         self.model = CrossEncoder(model_dir, device=device, **kwargs)
 
-    def __call__(self, inputs):
-        scores = self.model.predict(inputs).tolist()
-        return {"scores": scores}
+    def __call__(
+        self,
+        sentences: Union[Tuple[str, str], List[str], List[List[str]], List[Tuple[str, str]], None] = None,
+        query: Union[str, None] = None,
+        texts: Union[List[str], None] = None,
+        return_documents: bool = False,
+    ) -> Union[Dict[str, List[float]], List[Dict[Literal["index", "score", "text"], Any]]]:
+        if all(x is not None for x in [sentences, query, texts]):
+            raise ValueError(
+                f"The provided payload contains {sentences=} (i.e. 'inputs'), {query=}, and {texts=}"
+                " but all of those cannot be provided, you should provide either only 'sentences' i.e. 'inputs'"
+                " or both 'query' and 'texts' to run the ranking task."
+            )
+
+        if all(x is None for x in [sentences, query, texts]):
+            raise ValueError(
+                "No inputs have been provided within the input payload, make sure that the input payload"
+                " contains either 'sentences' i.e. 'inputs', or both 'query' and 'texts' to run the ranking task."
+            )
+
+        if sentences is not None:
+            scores = self.model.predict(sentences).tolist()
+            return {"scores": scores}
+
+        if query is None or not isinstance(query, str):
+            raise ValueError(f"Provided {query=}, but a non-empty string should be provided instead.")
+
+        if texts is None or not isinstance(texts, list) or not all(isinstance(text, str) for text in texts):
+            raise ValueError(f"Provided {texts=}, but a list of non-empty strings should be provided instead.")
+
+        scores = self.model.rank(query, texts, return_documents=return_documents)
+        # rename "corpus_id" key to "index" for all scores to match TEI
+        for score in scores:
+            score["index"] = score.pop("corpus_id")  # type: ignore
+        return scores  # type: ignore
 
 
 SENTENCE_TRANSFORMERS_TASKS = {
     "sentence-similarity": SentenceSimilarityPipeline,
     "sentence-embeddings": SentenceEmbeddingPipeline,
-    "sentence-ranking": RankingPipeline,
+    "sentence-ranking": SentenceRankingPipeline,
 }
 
 
@@ -56,9 +95,5 @@ def get_sentence_transformers_pipeline(task=None, model_dir=None, device=-1, **k
     kwargs.pop("framework", None)
 
     if task not in SENTENCE_TRANSFORMERS_TASKS:
-        raise ValueError(
-            f"Unknown task {task}. Available tasks are: {', '.join(SENTENCE_TRANSFORMERS_TASKS.keys())}"
-        )
-    return SENTENCE_TRANSFORMERS_TASKS[task](
-        model_dir=model_dir, device=device, **kwargs
-    )
+        raise ValueError(f"Unknown task {task}. Available tasks are: {', '.join(SENTENCE_TRANSFORMERS_TASKS.keys())}")
+    return SENTENCE_TRANSFORMERS_TASKS[task](model_dir=model_dir, device=device, **kwargs)
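A usage sketch for the renamed `SentenceRankingPipeline`, showing the two accepted input styles; the model directory is a placeholder for any local cross-encoder checkpoint:

    # Hypothetical usage; `model_dir` must point at a local CrossEncoder model.
    pipe = SentenceRankingPipeline(model_dir="./my-cross-encoder")

    # Legacy style: sentence pairs -> {"scores": [...]}
    pairs = [["How old are you?", "I am 25 years old."], ["How old are you?", "The weather is nice."]]
    print(pipe(sentences=pairs))

    # Inference API style: query + texts -> ranked list with TEI-compatible "index" keys
    print(pipe(query="How old are you?", texts=["I am 25 years old.", "The weather is nice."]))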
