Skip to content

Commit c6bc223

Browse files
authored
Skip pipeline.task check for diffusers and sentence-transformers (#101)
* Skip `task` for `diffusers` and `sentence-transformers` in `HuggingFaceHandler` * Remove unused `SENTENCE_TRANSFORMERS_TASKS` import * Fix `HuggingFaceHandler` for `diffusers` and `sentence-transformers`
1 parent 84cabf7 commit c6bc223

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/huggingface_inference_toolkit/handler.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,19 @@ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
3636
inputs = data.pop("inputs", data)
3737
parameters = data.pop("parameters", {})
3838

39-
# sentence transformers pipelines do not have the `task` arg
40-
if any(isinstance(self.pipeline, v) for v in SENTENCE_TRANSFORMERS_TASKS.values()):
41-
return self.pipeline(**inputs) if isinstance(inputs, dict) else self.pipeline(inputs) # type: ignore
39+
# diffusers and sentence transformers pipelines do not have the `task` arg
40+
if not hasattr(self.pipeline, "task"):
41+
# sentence transformers paramters not supported yet
42+
if any(isinstance(self.pipeline, v) for v in SENTENCE_TRANSFORMERS_TASKS.values()):
43+
return ( # type: ignore
44+
self.pipeline(**inputs) if isinstance(inputs, dict) else self.pipeline(inputs)
45+
)
46+
# diffusers does support kwargs
47+
return ( # type: ignore
48+
self.pipeline(**inputs, **parameters)
49+
if isinstance(inputs, dict)
50+
else self.pipeline(inputs, **parameters)
51+
)
4252

4353
if self.pipeline.task == "question-answering":
4454
if not isinstance(inputs, dict):

0 commit comments

Comments
 (0)