Skip to content

Commit

Permalink
[InferenceClient] Better handling of task parameters (#2812)
Browse files Browse the repository at this point in the history
* fix discrepancies for text-to-image parameters and add extra_parameters argument

* revamp inference providers tests

* nit

* fix test

* add examples with extra parameters

* remove nested dict image size

* filter out @deprecated params

* fix

* fix test

* fixing bugs introduced by the LLM
  • Loading branch information
hanouticelina authored Jan 31, 2025
1 parent 37e79dc commit 07e1adb
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 67 deletions.
54 changes: 44 additions & 10 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
TextGenerationInputGrammarType,
TextGenerationOutput,
TextGenerationStreamOutput,
TextToImageTargetSize,
TextToSpeechEarlyStoppingEnum,
TokenClassificationAggregationStrategy,
TokenClassificationOutputElement,
Expand Down Expand Up @@ -474,8 +473,6 @@ def automatic_speech_recognition(
model (`str`, *optional*):
The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. If not provided, the default recommended model for ASR will be used.
parameters (Dict[str, Any], *optional*):
Additional parameters to pass to the model.
Returns:
[`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.
Expand Down Expand Up @@ -2392,9 +2389,8 @@ def text_to_image(
guidance_scale: Optional[float] = None,
model: Optional[str] = None,
scheduler: Optional[str] = None,
target_size: Optional[TextToImageTargetSize] = None,
seed: Optional[int] = None,
**kwargs,
extra_parameters: Optional[Dict[str, Any]] = None,
) -> "Image":
"""
Generate an image based on a given text using a specified model.
Expand Down Expand Up @@ -2426,10 +2422,11 @@ def text_to_image(
Defaults to None.
scheduler (`str`, *optional*):
Override the scheduler with a compatible one.
target_size (`TextToImageTargetSize`, *optional*):
The size in pixel of the output image
seed (`int`, *optional*):
Seed for the random number generator.
extra_parameters (`Dict[str, Any]`, *optional*):
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
for supported parameters.
Returns:
`Image`: The generated image.
Expand Down Expand Up @@ -2482,6 +2479,21 @@ def text_to_image(
... )
>>> image.save("astronaut.png")
```
Example using Replicate provider with extra parameters
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient(
... provider="replicate", # Use replicate provider
... api_key="hf_...", # Pass your HF token
... )
>>> image = client.text_to_image(
... "An astronaut riding a horse on the moon.",
... model="black-forest-labs/FLUX.1-schnell",
... extra_parameters={"output_quality": 100},
... )
>>> image.save("astronaut.png")
```
"""
provider_helper = get_provider_helper(self.provider, task="text-to-image")
request_parameters = provider_helper.prepare_request(
Expand All @@ -2493,9 +2505,8 @@ def text_to_image(
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"scheduler": scheduler,
"target_size": target_size,
"seed": seed,
**kwargs,
**(extra_parameters or {}),
},
headers=self.headers,
model=model or self.model,
Expand All @@ -2515,6 +2526,7 @@ def text_to_video(
num_frames: Optional[float] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
extra_parameters: Optional[Dict[str, Any]] = None,
) -> bytes:
"""
Generate a video based on a given text.
Expand All @@ -2538,6 +2550,9 @@ def text_to_video(
expense of slower inference.
seed (`int`, *optional*):
Seed for the random number generator.
extra_parameters (`Dict[str, Any]`, *optional*):
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
for supported parameters.
Returns:
`bytes`: The generated video.
Expand Down Expand Up @@ -2583,6 +2598,7 @@ def text_to_video(
"num_frames": num_frames,
"num_inference_steps": num_inference_steps,
"seed": seed,
**(extra_parameters or {}),
},
headers=self.headers,
model=model or self.model,
Expand Down Expand Up @@ -2613,6 +2629,7 @@ def text_to_speech(
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
use_cache: Optional[bool] = None,
extra_parameters: Optional[Dict[str, Any]] = None,
) -> bytes:
"""
Synthesize an audio of a voice pronouncing a given text.
Expand Down Expand Up @@ -2670,7 +2687,9 @@ def text_to_speech(
paper](https://hf.co/papers/2202.00666) for more details.
use_cache (`bool`, *optional*):
Whether the model should use the past last key/values attentions to speed up decoding
extra_parameters (`Dict[str, Any]`, *optional*):
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
for supported parameters.
Returns:
`bytes`: The generated audio.
Expand Down Expand Up @@ -2717,6 +2736,20 @@ def text_to_speech(
... )
>>> Path("hello_world.flac").write_bytes(audio)
```
Example using Replicate provider with extra parameters
```py
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient(
... provider="replicate", # Use replicate provider
... api_key="hf_...", # Pass your HF token
... )
>>> audio = client.text_to_speech(
... "Hello, my name is Kokoro, an awesome text-to-speech model.",
... model="hexgrad/Kokoro-82M",
... extra_parameters={"voice": "af_nicole"},
... )
>>> Path("hello.flac").write_bytes(audio)
```
"""
provider_helper = get_provider_helper(self.provider, task="text-to-speech")
request_parameters = provider_helper.prepare_request(
Expand All @@ -2738,6 +2771,7 @@ def text_to_speech(
"top_p": top_p,
"typical_p": typical_p,
"use_cache": use_cache,
**(extra_parameters or {}),
},
headers=self.headers,
model=model or self.model,
Expand Down
54 changes: 44 additions & 10 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
TextGenerationInputGrammarType,
TextGenerationOutput,
TextGenerationStreamOutput,
TextToImageTargetSize,
TextToSpeechEarlyStoppingEnum,
TokenClassificationAggregationStrategy,
TokenClassificationOutputElement,
Expand Down Expand Up @@ -507,8 +506,6 @@ async def automatic_speech_recognition(
model (`str`, *optional*):
The model to use for ASR. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
Inference Endpoint. If not provided, the default recommended model for ASR will be used.
parameters (Dict[str, Any], *optional*):
Additional parameters to pass to the model.
Returns:
[`AutomaticSpeechRecognitionOutput`]: An item containing the transcribed text and optionally the timestamp chunks.
Expand Down Expand Up @@ -2448,9 +2445,8 @@ async def text_to_image(
guidance_scale: Optional[float] = None,
model: Optional[str] = None,
scheduler: Optional[str] = None,
target_size: Optional[TextToImageTargetSize] = None,
seed: Optional[int] = None,
**kwargs,
extra_parameters: Optional[Dict[str, Any]] = None,
) -> "Image":
"""
Generate an image based on a given text using a specified model.
Expand Down Expand Up @@ -2482,10 +2478,11 @@ async def text_to_image(
Defaults to None.
scheduler (`str`, *optional*):
Override the scheduler with a compatible one.
target_size (`TextToImageTargetSize`, *optional*):
The size in pixel of the output image
seed (`int`, *optional*):
Seed for the random number generator.
extra_parameters (`Dict[str, Any]`, *optional*):
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
for supported parameters.
Returns:
`Image`: The generated image.
Expand Down Expand Up @@ -2539,6 +2536,21 @@ async def text_to_image(
... )
>>> image.save("astronaut.png")
```
Example using Replicate provider with extra parameters
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient(
... provider="replicate", # Use replicate provider
... api_key="hf_...", # Pass your HF token
... )
>>> image = await client.text_to_image(
... "An astronaut riding a horse on the moon.",
... model="black-forest-labs/FLUX.1-schnell",
... extra_parameters={"output_quality": 100},
... )
>>> image.save("astronaut.png")
```
"""
provider_helper = get_provider_helper(self.provider, task="text-to-image")
request_parameters = provider_helper.prepare_request(
Expand All @@ -2550,9 +2562,8 @@ async def text_to_image(
"num_inference_steps": num_inference_steps,
"guidance_scale": guidance_scale,
"scheduler": scheduler,
"target_size": target_size,
"seed": seed,
**kwargs,
**(extra_parameters or {}),
},
headers=self.headers,
model=model or self.model,
Expand All @@ -2572,6 +2583,7 @@ async def text_to_video(
num_frames: Optional[float] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
extra_parameters: Optional[Dict[str, Any]] = None,
) -> bytes:
"""
Generate a video based on a given text.
Expand All @@ -2595,6 +2607,9 @@ async def text_to_video(
expense of slower inference.
seed (`int`, *optional*):
Seed for the random number generator.
extra_parameters (`Dict[str, Any]`, *optional*):
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
for supported parameters.
Returns:
`bytes`: The generated video.
Expand Down Expand Up @@ -2640,6 +2655,7 @@ async def text_to_video(
"num_frames": num_frames,
"num_inference_steps": num_inference_steps,
"seed": seed,
**(extra_parameters or {}),
},
headers=self.headers,
model=model or self.model,
Expand Down Expand Up @@ -2670,6 +2686,7 @@ async def text_to_speech(
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
use_cache: Optional[bool] = None,
extra_parameters: Optional[Dict[str, Any]] = None,
) -> bytes:
"""
Synthesize an audio of a voice pronouncing a given text.
Expand Down Expand Up @@ -2727,7 +2744,9 @@ async def text_to_speech(
paper](https://hf.co/papers/2202.00666) for more details.
use_cache (`bool`, *optional*):
Whether the model should use the past last key/values attentions to speed up decoding
extra_parameters (`Dict[str, Any]`, *optional*):
Additional provider-specific parameters to pass to the model. Refer to the provider's documentation
for supported parameters.
Returns:
`bytes`: The generated audio.
Expand Down Expand Up @@ -2775,6 +2794,20 @@ async def text_to_speech(
... )
>>> Path("hello_world.flac").write_bytes(audio)
```
Example using Replicate provider with extra parameters
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient(
... provider="replicate", # Use replicate provider
... api_key="hf_...", # Pass your HF token
... )
>>> audio = await client.text_to_speech(
... "Hello, my name is Kokoro, an awesome text-to-speech model.",
... model="hexgrad/Kokoro-82M",
... extra_parameters={"voice": "af_nicole"},
... )
>>> Path("hello.flac").write_bytes(audio)
```
"""
provider_helper = get_provider_helper(self.provider, task="text-to-speech")
request_parameters = provider_helper.prepare_request(
Expand All @@ -2796,6 +2829,7 @@ async def text_to_speech(
"top_p": top_p,
"typical_p": typical_p,
"use_cache": use_cache,
**(extra_parameters or {}),
},
headers=self.headers,
model=model or self.model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ class TextToImageParameters(BaseInferenceType):
"""Override the scheduler with a compatible one."""
seed: Optional[int] = None
"""Seed for the random number generator."""
target_size: Optional[TextToImageTargetSize] = None
"""The size in pixel of the output image"""


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/inference/_providers/fal_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def __init__(self):

def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {k: v for k, v in parameters.items() if v is not None}
if "image_size" not in parameters and "width" in parameters and "height" in parameters:
if "width" in parameters and "height" in parameters:
parameters["image_size"] = {
"width": parameters.pop("width"),
"height": parameters.pop("height"),
Expand Down
8 changes: 7 additions & 1 deletion src/huggingface_hub/inference/_providers/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,16 @@ def __init__(self):
super().__init__("text-to-image")

def _prepare_payload(self, inputs: Any, parameters: Dict[str, Any]) -> Dict[str, Any]:
parameters = {k: v for k, v in parameters.items() if v is not None}
if "num_inference_steps" in parameters:
parameters["steps"] = parameters.pop("num_inference_steps")
if "guidance_scale" in parameters:
parameters["guidance"] = parameters.pop("guidance_scale")

payload = {
"prompt": inputs,
"response_format": "base64",
**{k: v for k, v in parameters.items() if v is not None},
**parameters,
}
return payload

Expand Down
Loading

0 comments on commit 07e1adb

Please sign in to comment.