Skip to content

Commit 5c3a252

Browse files
[Inference Providers] add support for ASR with replicate (#3538)
1 parent 5573ce7 commit 5c3a252

File tree

7 files changed

+143
-4
lines changed

7 files changed

+143
-4
lines changed

docs/source/en/guides/inference.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ For more details, refer to the [Inference Providers pricing documentation](https
196196
| --------------------------------------------------- | ----------------- | -------- | -------- | ------ | ------ | -------------- | ------------ | ---- | ------------ | ---------- | ---------------- | --------- | ------ | ---------- | --------- | --------- | --------- | -------- | --------- | ---- |
197197
| [`~InferenceClient.audio_classification`] |||||||||||||||||||||
198198
| [`~InferenceClient.audio_to_audio`] |||||||||||||||||||||
199-
| [`~InferenceClient.automatic_speech_recognition`] ||||||||||||||| ||||||
199+
| [`~InferenceClient.automatic_speech_recognition`] ||||||||||||||| ||||||
200200
| [`~InferenceClient.chat_completion`] |||||||||||||||||||||
201201
| [`~InferenceClient.document_question_answering`] |||||||||||||||||||||
202202
| [`~InferenceClient.feature_extraction`] |||||||||||||||||||||

src/huggingface_hub/inference/_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,7 @@ def automatic_speech_recognition(
452452
api_key=self.token,
453453
)
454454
response = self._inner_post(request_parameters)
455+
response = provider_helper.get_response(response, request_params=request_parameters)
455456
return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
456457

457458
@overload

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ async def automatic_speech_recognition(
472472
api_key=self.token,
473473
)
474474
response = await self._inner_post(request_parameters)
475+
response = provider_helper.get_response(response, request_params=request_parameters)
475476
return AutomaticSpeechRecognitionOutput.parse_obj_as_instance(response)
476477

477478
@overload

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@
3939
from .nscale import NscaleConversationalTask, NscaleTextToImageTask
4040
from .openai import OpenAIConversationalTask
4141
from .publicai import PublicAIConversationalTask
42-
from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask
42+
from .replicate import (
43+
ReplicateAutomaticSpeechRecognitionTask,
44+
ReplicateImageToImageTask,
45+
ReplicateTask,
46+
ReplicateTextToImageTask,
47+
ReplicateTextToSpeechTask,
48+
)
4349
from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask
4450
from .scaleway import ScalewayConversationalTask, ScalewayFeatureExtractionTask
4551
from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask
@@ -170,6 +176,7 @@
170176
"conversational": PublicAIConversationalTask(),
171177
},
172178
"replicate": {
179+
"automatic-speech-recognition": ReplicateAutomaticSpeechRecognitionTask(),
173180
"image-to-image": ReplicateImageToImageTask(),
174181
"text-to-image": ReplicateTextToImageTask(),
175182
"text-to-speech": ReplicateTextToSpeechTask(),

src/huggingface_hub/inference/_providers/fal_ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def get_response(self, response: Union[bytes, dict], request_params: Optional[Re
112112
text = _as_dict(response)["text"]
113113
if not isinstance(text, str):
114114
raise ValueError(f"Unexpected output format from FalAI API. Expected string, got {type(text)}.")
115-
return text
115+
return {"text": text}
116116

117117

118118
class FalAITextToImageTask(FalAITask):

src/huggingface_hub/inference/_providers/replicate.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,67 @@ def _prepare_payload_as_dict(
7272
return payload
7373

7474

75+
class ReplicateAutomaticSpeechRecognitionTask(ReplicateTask):
76+
def __init__(self) -> None:
77+
super().__init__("automatic-speech-recognition")
78+
79+
def _prepare_payload_as_dict(
80+
self,
81+
inputs: Any,
82+
parameters: dict,
83+
provider_mapping_info: InferenceProviderMapping,
84+
) -> Optional[dict]:
85+
mapped_model = provider_mapping_info.provider_id
86+
audio_url = _as_url(inputs, default_mime_type="audio/wav")
87+
88+
payload: dict[str, Any] = {
89+
"input": {
90+
**{"audio": audio_url},
91+
**filter_none(parameters),
92+
}
93+
}
94+
95+
if ":" in mapped_model:
96+
payload["version"] = mapped_model.split(":", 1)[1]
97+
98+
return payload
99+
100+
def get_response(self, response: Union[bytes, dict], request_params: Optional[RequestParameters] = None) -> Any:
101+
response_dict = _as_dict(response)
102+
output = response_dict.get("output")
103+
104+
if isinstance(output, str):
105+
return {"text": output}
106+
107+
if isinstance(output, list) and output:
108+
first_item = output[0]
109+
if isinstance(first_item, str):
110+
return {"text": first_item}
111+
if isinstance(first_item, dict):
112+
output = first_item
113+
114+
text: Optional[str] = None
115+
if isinstance(output, dict):
116+
transcription = output.get("transcription")
117+
if isinstance(transcription, str):
118+
text = transcription
119+
120+
translation = output.get("translation")
121+
if isinstance(translation, str):
122+
text = translation
123+
124+
txt_file = output.get("txt_file")
125+
if isinstance(txt_file, str):
126+
text_response = get_session().get(txt_file)
127+
text_response.raise_for_status()
128+
text = text_response.text
129+
130+
if text is not None:
131+
return {"text": text}
132+
133+
raise ValueError("Received malformed response from Replicate automatic-speech-recognition API")
134+
135+
75136
class ReplicateImageToImageTask(ReplicateTask):
76137
def __init__(self):
77138
super().__init__("image-to-image")

tests/test_inference_providers.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from huggingface_hub.inference._providers.openai import OpenAIConversationalTask
4949
from huggingface_hub.inference._providers.publicai import PublicAIConversationalTask
5050
from huggingface_hub.inference._providers.replicate import (
51+
ReplicateAutomaticSpeechRecognitionTask,
5152
ReplicateImageToImageTask,
5253
ReplicateTask,
5354
ReplicateTextToSpeechTask,
@@ -396,7 +397,7 @@ def test_automatic_speech_recognition_payload(self):
396397
def test_automatic_speech_recognition_response(self):
397398
helper = FalAIAutomaticSpeechRecognitionTask()
398399
response = helper.get_response({"text": "Hello world"})
399-
assert response == "Hello world"
400+
assert response == {"text": "Hello world"}
400401

401402
with pytest.raises(ValueError):
402403
helper.get_response({"text": 123})
@@ -1423,6 +1424,74 @@ def test_prepare_url(self):
14231424

14241425

14251426
class TestReplicateProvider:
1427+
def test_automatic_speech_recognition_payload(self):
1428+
helper = ReplicateAutomaticSpeechRecognitionTask()
1429+
1430+
mapping_info = InferenceProviderMapping(
1431+
provider="replicate",
1432+
hf_model_id="openai/whisper-large-v3",
1433+
providerId="openai/whisper-large-v3",
1434+
task="automatic-speech-recognition",
1435+
status="live",
1436+
)
1437+
1438+
payload = helper._prepare_payload_as_dict(
1439+
"https://example.com/audio.mp3",
1440+
{"language": "en"},
1441+
mapping_info,
1442+
)
1443+
1444+
assert payload == {"input": {"audio": "https://example.com/audio.mp3", "language": "en"}}
1445+
1446+
mapping_with_version = InferenceProviderMapping(
1447+
provider="replicate",
1448+
hf_model_id="openai/whisper-large-v3",
1449+
providerId="openai/whisper-large-v3:123",
1450+
task="automatic-speech-recognition",
1451+
status="live",
1452+
)
1453+
1454+
audio_bytes = b"dummy-audio"
1455+
encoded_audio = base64.b64encode(audio_bytes).decode()
1456+
1457+
payload = helper._prepare_payload_as_dict(
1458+
audio_bytes,
1459+
{},
1460+
mapping_with_version,
1461+
)
1462+
1463+
assert payload == {
1464+
"input": {"audio": f"data:audio/wav;base64,{encoded_audio}"},
1465+
"version": "123",
1466+
}
1467+
1468+
def test_automatic_speech_recognition_get_response_variants(self, mocker):
1469+
helper = ReplicateAutomaticSpeechRecognitionTask()
1470+
1471+
result = helper.get_response({"output": "hello"})
1472+
assert result == {"text": "hello"}
1473+
1474+
result = helper.get_response({"output": ["hello-world"]})
1475+
assert result == {"text": "hello-world"}
1476+
1477+
result = helper.get_response({"output": {"transcription": "bonjour"}})
1478+
assert result == {"text": "bonjour"}
1479+
1480+
result = helper.get_response({"output": {"translation": "hola"}})
1481+
assert result == {"text": "hola"}
1482+
1483+
mock_session = mocker.patch("huggingface_hub.inference._providers.replicate.get_session")
1484+
mock_response = mocker.Mock(text="file text")
1485+
mock_response.raise_for_status = lambda: None
1486+
mock_session.return_value.get.return_value = mock_response
1487+
1488+
result = helper.get_response({"output": {"txt_file": "https://example.com/output.txt"}})
1489+
mock_session.return_value.get.assert_called_once_with("https://example.com/output.txt")
1490+
assert result == {"text": "file text"}
1491+
1492+
with pytest.raises(ValueError):
1493+
helper.get_response({"output": 123})
1494+
14261495
def test_prepare_headers(self):
14271496
helper = ReplicateTask("text-to-image")
14281497
headers = helper._prepare_headers({}, "my_replicate_key")

0 commit comments

Comments
 (0)