Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@
"content_filter": types.FinishReason.SAFETY,
}

_SUPPORTED_FILE_CONTENT_MIME_TYPES = set(
["application/pdf", "application/json"]
)


class ChatCompletionFileUrlObject(TypedDict, total=False):
file_data: str
Expand Down Expand Up @@ -336,7 +340,7 @@ def _get_content(
"type": "audio_url",
"audio_url": {"url": data_uri, "format": format_type},
})
elif part.inline_data.mime_type == "application/pdf":
elif part.inline_data.mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES:
format_type = part.inline_data.mime_type
content_objects.append({
"type": "file",
Expand Down
53 changes: 35 additions & 18 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@
),
)

FILE_URI_TEST_CASES = [
pytest.param("gs://bucket/document.pdf", "application/pdf", id="pdf"),
pytest.param("gs://bucket/data.json", "application/json", id="json"),
]

STREAMING_MODEL_RESPONSE = [
ModelResponse(
Expand Down Expand Up @@ -1088,10 +1092,11 @@ def test_content_to_message_param_user_message():
assert message["content"] == "Test prompt"


def test_content_to_message_param_user_message_with_file_uri():
file_part = types.Part.from_uri(
file_uri="gs://bucket/document.pdf", mime_type="application/pdf"
)
@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES)
def test_content_to_message_param_user_message_with_file_uri(
file_uri, mime_type
):
file_part = types.Part.from_uri(file_uri=file_uri, mime_type=mime_type)
content = types.Content(
role="user",
parts=[
Expand All @@ -1106,14 +1111,15 @@ def test_content_to_message_param_user_message_with_file_uri():
assert message["content"][0]["type"] == "text"
assert message["content"][0]["text"] == "Summarize this file."
assert message["content"][1]["type"] == "file"
assert message["content"][1]["file"]["file_id"] == "gs://bucket/document.pdf"
assert message["content"][1]["file"]["format"] == "application/pdf"
assert message["content"][1]["file"]["file_id"] == file_uri
assert message["content"][1]["file"]["format"] == mime_type


def test_content_to_message_param_user_message_file_uri_only():
file_part = types.Part.from_uri(
file_uri="gs://bucket/only.pdf", mime_type="application/pdf"
)
@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES)
def test_content_to_message_param_user_message_file_uri_only(
file_uri, mime_type
):
file_part = types.Part.from_uri(file_uri=file_uri, mime_type=mime_type)
content = types.Content(
role="user",
parts=[
Expand All @@ -1125,8 +1131,8 @@ def test_content_to_message_param_user_message_file_uri_only():
assert message["role"] == "user"
assert isinstance(message["content"], list)
assert message["content"][0]["type"] == "file"
assert message["content"][0]["file"]["file_id"] == "gs://bucket/only.pdf"
assert message["content"][0]["file"]["format"] == "application/pdf"
assert message["content"][0]["file"]["file_id"] == file_uri
assert message["content"][0]["file"]["format"] == mime_type


def test_content_to_message_param_multi_part_function_response():
Expand Down Expand Up @@ -1307,17 +1313,28 @@ def test_get_content_pdf():
assert content[0]["file"]["format"] == "application/pdf"


def test_get_content_file_uri():
def test_get_content_json():
parts = [
types.Part.from_uri(
file_uri="gs://bucket/document.pdf",
mime_type="application/pdf",
types.Part.from_bytes(
data=b'{"hello":"world"}', mime_type="application/json"
)
]
content = _get_content(parts)
assert content[0]["type"] == "file"
assert content[0]["file"]["file_id"] == "gs://bucket/document.pdf"
assert content[0]["file"]["format"] == "application/pdf"
assert (
content[0]["file"]["file_data"]
== "data:application/json;base64,eyJoZWxsbyI6IndvcmxkIn0="
)
assert content[0]["file"]["format"] == "application/json"


@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES)
def test_get_content_file_uri(file_uri, mime_type):
parts = [types.Part.from_uri(file_uri=file_uri, mime_type=mime_type)]
content = _get_content(parts)
assert content[0]["type"] == "file"
assert content[0]["file"]["file_id"] == file_uri
assert content[0]["file"]["format"] == mime_type


def test_get_content_audio():
Expand Down