Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@
"content_filter": types.FinishReason.SAFETY,
}

# MIME types whose inline data is emitted as a generic "file" content object.
# A frozenset literal avoids building a throwaway list and keeps this
# module-level constant from being mutated by accident; membership tests
# (`mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES`) behave exactly as before.
_SUPPORTED_FILE_CONTENT_MIME_TYPES = frozenset({
    "application/pdf",
    "application/json",
    "text/plain",
})


class ChatCompletionFileUrlObject(TypedDict, total=False):
file_data: str
Expand Down Expand Up @@ -336,7 +340,7 @@ def _get_content(
"type": "audio_url",
"audio_url": {"url": data_uri, "format": format_type},
})
elif part.inline_data.mime_type == "application/pdf":
elif part.inline_data.mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES:
format_type = part.inline_data.mime_type
content_objects.append({
"type": "file",
Expand Down
82 changes: 52 additions & 30 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,32 @@
),
)

# Shared (file_uri, mime_type) parametrization rows for the file-URI tests.
# Each row's pytest id is the short format name listed first in the table.
FILE_URI_TEST_CASES = [
    pytest.param(uri, mime, id=label)
    for label, uri, mime in (
        ("pdf", "gs://bucket/document.pdf", "application/pdf"),
        ("json", "gs://bucket/data.json", "application/json"),
        ("txt", "gs://bucket/data.txt", "text/plain"),
    )
]

# Shared (file_data, mime_type, expected_base64) parametrization rows for the
# inline-bytes test. The third value is the full data URI (MIME type plus the
# base64 encoding of the payload) that the test compares against the
# "file_data" field. Each row's pytest id is the short format name.
FILE_BYTES_TEST_CASES = [
    pytest.param(payload, mime, data_uri, id=label)
    for label, payload, mime, data_uri in (
        (
            "pdf",
            b"test_pdf_data",
            "application/pdf",
            "data:application/pdf;base64,dGVzdF9wZGZfZGF0YQ==",
        ),
        (
            "json",
            b'{"hello":"world"}',
            "application/json",
            "data:application/json;base64,eyJoZWxsbyI6IndvcmxkIn0=",
        ),
        (
            "txt",
            b"hello world",
            "text/plain",
            "data:text/plain;base64,aGVsbG8gd29ybGQ=",
        ),
    )
]

STREAMING_MODEL_RESPONSE = [
ModelResponse(
Expand Down Expand Up @@ -1088,10 +1114,11 @@ def test_content_to_message_param_user_message():
assert message["content"] == "Test prompt"


def test_content_to_message_param_user_message_with_file_uri():
file_part = types.Part.from_uri(
file_uri="gs://bucket/document.pdf", mime_type="application/pdf"
)
@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES)
def test_content_to_message_param_user_message_with_file_uri(
file_uri, mime_type
):
file_part = types.Part.from_uri(file_uri=file_uri, mime_type=mime_type)
content = types.Content(
role="user",
parts=[
Expand All @@ -1106,14 +1133,15 @@ def test_content_to_message_param_user_message_with_file_uri():
assert message["content"][0]["type"] == "text"
assert message["content"][0]["text"] == "Summarize this file."
assert message["content"][1]["type"] == "file"
assert message["content"][1]["file"]["file_id"] == "gs://bucket/document.pdf"
assert message["content"][1]["file"]["format"] == "application/pdf"
assert message["content"][1]["file"]["file_id"] == file_uri
assert message["content"][1]["file"]["format"] == mime_type


def test_content_to_message_param_user_message_file_uri_only():
file_part = types.Part.from_uri(
file_uri="gs://bucket/only.pdf", mime_type="application/pdf"
)
@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES)
def test_content_to_message_param_user_message_file_uri_only(
file_uri, mime_type
):
file_part = types.Part.from_uri(file_uri=file_uri, mime_type=mime_type)
content = types.Content(
role="user",
parts=[
Expand All @@ -1125,8 +1153,8 @@ def test_content_to_message_param_user_message_file_uri_only():
assert message["role"] == "user"
assert isinstance(message["content"], list)
assert message["content"][0]["type"] == "file"
assert message["content"][0]["file"]["file_id"] == "gs://bucket/only.pdf"
assert message["content"][0]["file"]["format"] == "application/pdf"
assert message["content"][0]["file"]["file_id"] == file_uri
assert message["content"][0]["file"]["format"] == mime_type


def test_content_to_message_param_multi_part_function_response():
Expand Down Expand Up @@ -1294,30 +1322,24 @@ def test_get_content_video():
assert content[0]["video_url"]["format"] == "video/mp4"


def test_get_content_pdf():
parts = [
types.Part.from_bytes(data=b"test_pdf_data", mime_type="application/pdf")
]
@pytest.mark.parametrize(
"file_data,mime_type,expected_base64", FILE_BYTES_TEST_CASES
)
def test_get_content_file_bytes(file_data, mime_type, expected_base64):
parts = [types.Part.from_bytes(data=file_data, mime_type=mime_type)]
content = _get_content(parts)
assert content[0]["type"] == "file"
assert (
content[0]["file"]["file_data"]
== "data:application/pdf;base64,dGVzdF9wZGZfZGF0YQ=="
)
assert content[0]["file"]["format"] == "application/pdf"
assert content[0]["file"]["file_data"] == expected_base64
assert content[0]["file"]["format"] == mime_type


def test_get_content_file_uri():
parts = [
types.Part.from_uri(
file_uri="gs://bucket/document.pdf",
mime_type="application/pdf",
)
]
@pytest.mark.parametrize("file_uri,mime_type", FILE_URI_TEST_CASES)
def test_get_content_file_uri(file_uri, mime_type):
parts = [types.Part.from_uri(file_uri=file_uri, mime_type=mime_type)]
content = _get_content(parts)
assert content[0]["type"] == "file"
assert content[0]["file"]["file_id"] == "gs://bucket/document.pdf"
assert content[0]["file"]["format"] == "application/pdf"
assert content[0]["file"]["file_id"] == file_uri
assert content[0]["file"]["format"] == mime_type


def test_get_content_audio():
Expand Down