|
48 | 48 | from huggingface_hub.inference._providers.openai import OpenAIConversationalTask |
49 | 49 | from huggingface_hub.inference._providers.publicai import PublicAIConversationalTask |
50 | 50 | from huggingface_hub.inference._providers.replicate import ( |
| 51 | + ReplicateAutomaticSpeechRecognitionTask, |
51 | 52 | ReplicateImageToImageTask, |
52 | 53 | ReplicateTask, |
53 | 54 | ReplicateTextToSpeechTask, |
@@ -396,7 +397,7 @@ def test_automatic_speech_recognition_payload(self): |
396 | 397 | def test_automatic_speech_recognition_response(self): |
397 | 398 | helper = FalAIAutomaticSpeechRecognitionTask() |
398 | 399 | response = helper.get_response({"text": "Hello world"}) |
399 | | - assert response == "Hello world" |
| 400 | + assert response == {"text": "Hello world"} |
400 | 401 |
|
401 | 402 | with pytest.raises(ValueError): |
402 | 403 | helper.get_response({"text": 123}) |
@@ -1423,6 +1424,74 @@ def test_prepare_url(self): |
1423 | 1424 |
|
1424 | 1425 |
|
1425 | 1426 | class TestReplicateProvider: |
def test_automatic_speech_recognition_payload(self):
    """Check the payload built by the Replicate ASR helper.

    Two inputs are exercised:

    * an audio URL with an extra ``language`` parameter and a provider id
      without a version suffix -- the URL is expected under ``input.audio``
      next to the parameter, with no ``version`` key;
    * raw audio bytes with a provider id ending in ``:123`` -- the bytes are
      expected as a base64 ``data:audio/wav`` URL and ``"123"`` is expected
      under ``version``.
    """
    task_helper = ReplicateAutomaticSpeechRecognitionTask()

    # --- URL input, provider id without a version suffix ---
    unversioned_mapping = InferenceProviderMapping(
        provider="replicate",
        hf_model_id="openai/whisper-large-v3",
        providerId="openai/whisper-large-v3",
        task="automatic-speech-recognition",
        status="live",
    )
    url_payload = task_helper._prepare_payload_as_dict(
        "https://example.com/audio.mp3",
        {"language": "en"},
        unversioned_mapping,
    )
    expected_url_payload = {
        "input": {"audio": "https://example.com/audio.mp3", "language": "en"},
    }
    assert url_payload == expected_url_payload

    # --- Raw bytes input, provider id carrying a ":123" version suffix ---
    versioned_mapping = InferenceProviderMapping(
        provider="replicate",
        hf_model_id="openai/whisper-large-v3",
        providerId="openai/whisper-large-v3:123",
        task="automatic-speech-recognition",
        status="live",
    )
    raw_audio = b"dummy-audio"
    encoded_audio = base64.b64encode(raw_audio).decode()

    bytes_payload = task_helper._prepare_payload_as_dict(raw_audio, {}, versioned_mapping)

    expected_bytes_payload = {
        "input": {"audio": f"data:audio/wav;base64,{encoded_audio}"},
        "version": "123",
    }
    assert bytes_payload == expected_bytes_payload
| 1467 | + |
def test_automatic_speech_recognition_get_response_variants(self, mocker):
    """Check that every supported ``output`` shape is normalised to ``{"text": ...}``.

    Covered shapes: a plain string, a single-item list, a dict with a
    ``transcription`` key, a dict with a ``translation`` key, and a dict
    whose ``txt_file`` key holds a URL that is fetched through the patched
    ``get_session``. A non-text output (``123``) must raise ``ValueError``.
    """
    task_helper = ReplicateAutomaticSpeechRecognitionTask()

    # Outputs that already contain the transcript: (provider output, expected text).
    inline_cases = (
        ("hello", "hello"),
        (["hello-world"], "hello-world"),
        ({"transcription": "bonjour"}, "bonjour"),
        ({"translation": "hola"}, "hola"),
    )
    for provider_output, transcript in inline_cases:
        parsed = task_helper.get_response({"output": provider_output})
        assert parsed == {"text": transcript}

    # Output that points at a remote text file: the session is patched so the
    # download is served by a stub response instead of the network.
    get_session_mock = mocker.patch("huggingface_hub.inference._providers.replicate.get_session")
    downloaded = mocker.Mock(text="file text")
    downloaded.raise_for_status = lambda: None
    get_session_mock.return_value.get.return_value = downloaded

    parsed = task_helper.get_response({"output": {"txt_file": "https://example.com/output.txt"}})
    get_session_mock.return_value.get.assert_called_once_with("https://example.com/output.txt")
    assert parsed == {"text": "file text"}

    # Anything that is not a recognised text-bearing shape is rejected.
    with pytest.raises(ValueError):
        task_helper.get_response({"output": 123})
| 1494 | + |
1426 | 1495 | def test_prepare_headers(self): |
1427 | 1496 | helper = ReplicateTask("text-to-image") |
1428 | 1497 | headers = helper._prepare_headers({}, "my_replicate_key") |
|
0 commit comments