diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 5f374dc2ad..f1856156b2 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -79,6 +79,10 @@ "content_filter": types.FinishReason.SAFETY, } +_SUPPORTED_FILE_CONTENT_MIME_TYPES = set( + ["application/pdf", "application/json", "text/plain"] +) + class ChatCompletionFileUrlObject(TypedDict, total=False): file_data: str @@ -336,7 +340,7 @@ def _get_content( "type": "audio_url", "audio_url": {"url": data_uri, "format": format_type}, }) - elif part.inline_data.mime_type == "application/pdf": + elif part.inline_data.mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES: format_type = part.inline_data.mime_type content_objects.append({ "type": "file", diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 1e1d47002d..c60d7fb20e 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -87,6 +87,32 @@ ), ) +FILE_URI_TEST_CASES = [ + pytest.param("gs://bucket/document.pdf", "application/pdf", id="pdf"), + pytest.param("gs://bucket/data.json", "application/json", id="json"), + pytest.param("gs://bucket/data.txt", "text/plain", id="txt"), +] + +FILE_BYTES_TEST_CASES = [ + pytest.param( + b"test_pdf_data", + "application/pdf", + "data:application/pdf;base64,dGVzdF9wZGZfZGF0YQ==", + id="pdf", + ), + pytest.param( + b'{"hello":"world"}', + "application/json", + "data:application/json;base64,eyJoZWxsbyI6IndvcmxkIn0=", + id="json", + ), + pytest.param( + b"hello world", + "text/plain", + "data:text/plain;base64,aGVsbG8gd29ybGQ=", + id="txt", + ), +] STREAMING_MODEL_RESPONSE = [ ModelResponse( @@ -1088,10 +1114,11 @@ def test_content_to_message_param_user_message(): assert message["content"] == "Test prompt" -def test_content_to_message_param_user_message_with_file_uri(): - file_part = types.Part.from_uri( - file_uri="gs://bucket/document.pdf", mime_type="application/pdf" - ) +@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES) +def test_content_to_message_param_user_message_with_file_uri( + file_uri, mime_type +): + file_part = types.Part.from_uri(file_uri=file_uri, mime_type=mime_type) content = types.Content( role="user", parts=[ @@ -1106,14 +1133,15 @@ def test_content_to_message_param_user_message_with_file_uri(): assert message["content"][0]["type"] == "text" assert message["content"][0]["text"] == "Summarize this file." assert message["content"][1]["type"] == "file" - assert message["content"][1]["file"]["file_id"] == "gs://bucket/document.pdf" - assert message["content"][1]["file"]["format"] == "application/pdf" + assert message["content"][1]["file"]["file_id"] == file_uri + assert message["content"][1]["file"]["format"] == mime_type -def test_content_to_message_param_user_message_file_uri_only(): - file_part = types.Part.from_uri( - file_uri="gs://bucket/only.pdf", mime_type="application/pdf" - ) +@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES) +def test_content_to_message_param_user_message_file_uri_only( + file_uri, mime_type +): + file_part = types.Part.from_uri(file_uri=file_uri, mime_type=mime_type) content = types.Content( role="user", parts=[ @@ -1125,8 +1153,8 @@ def test_content_to_message_param_user_message_file_uri_only(): assert message["role"] == "user" assert isinstance(message["content"], list) assert message["content"][0]["type"] == "file" - assert message["content"][0]["file"]["file_id"] == "gs://bucket/only.pdf" - assert message["content"][0]["file"]["format"] == "application/pdf" + assert message["content"][0]["file"]["file_id"] == file_uri + assert message["content"][0]["file"]["format"] == mime_type def test_content_to_message_param_multi_part_function_response(): @@ -1294,30 +1322,24 @@ def test_get_content_video(): assert content[0]["video_url"]["format"] == "video/mp4" -def test_get_content_pdf(): - parts = [ - types.Part.from_bytes(data=b"test_pdf_data", mime_type="application/pdf") - ] +@pytest.mark.parametrize( + "file_data,mime_type,expected_base64", FILE_BYTES_TEST_CASES +) +def test_get_content_file_bytes(file_data, mime_type, expected_base64): + parts = [types.Part.from_bytes(data=file_data, mime_type=mime_type)] content = _get_content(parts) assert content[0]["type"] == "file" - assert ( - content[0]["file"]["file_data"] - == "data:application/pdf;base64,dGVzdF9wZGZfZGF0YQ==" - ) - assert content[0]["file"]["format"] == "application/pdf" + assert content[0]["file"]["file_data"] == expected_base64 + assert content[0]["file"]["format"] == mime_type -def test_get_content_file_uri(): - parts = [ - types.Part.from_uri( - file_uri="gs://bucket/document.pdf", - mime_type="application/pdf", - ) - ] +@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES) +def test_get_content_file_uri(file_uri, mime_type): + parts = [types.Part.from_uri(file_uri=file_uri, mime_type=mime_type)] content = _get_content(parts) assert content[0]["type"] == "file" - assert content[0]["file"]["file_id"] == "gs://bucket/document.pdf" - assert content[0]["file"]["format"] == "application/pdf" + assert content[0]["file"]["file_id"] == file_uri + assert content[0]["file"]["format"] == mime_type def test_get_content_audio():