134 changes: 75 additions & 59 deletions src/google/adk/models/lite_llm.py
@@ -84,13 +84,19 @@ async def acompletion(
    Returns:
      The model response as a message.
    """

    return await acompletion(
        model=model,
        messages=messages,
        tools=tools,
        **kwargs,
    )
    try:
      return await acompletion(
          model=model,
          messages=messages,
          tools=tools,
          **kwargs,
      )
    except Exception as e:
      if "Internal Server Error" in str(e):
        logger.error("Internal Server Error encountered: %s", e)
      raise

  def completion(
      self, model, messages, tools, stream=False, **kwargs
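For callers, the wrapper's behavior is unchanged apart from the logging: it always re-raises. A minimal usage sketch, where the `call_with_fallback` helper and its fallback message are illustrative assumptions, not part of this change:

# Hypothetical caller-side handling; the wrapper above only logs
# "Internal Server Error" before re-raising, so handling stays with the caller.
async def call_with_fallback(client, model, messages, tools):
  try:
    return await client.acompletion(model, messages, tools)
  except Exception as e:
    if "Internal Server Error" in str(e):
      # e.g. surface a friendlier error, or switch to a fallback model
      raise RuntimeError("Model backend unavailable; try again later") from e
    raise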
@@ -478,7 +484,7 @@ def _build_function_declaration_log(
  if func_decl.parameters and func_decl.parameters.properties:
    param_str = str({
        k: v.model_dump(exclude_none=True)
        for k, v in func_decl.parameters.properties.items()
    })
  return_str = "None"
  if func_decl.response:
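The `.items()` call above matters: iterating a dict directly yields only its keys, so the two-target unpacking in the comprehension would fail. A quick illustration:

props = {"city": "str", "days": "int"}

# Correct: .items() yields (key, value) pairs.
{k: v.upper() for k, v in props.items()}  # {'city': 'STR', 'days': 'INT'}

# Incorrect: iterating the dict yields bare keys, and unpacking the
# string "city" into (k, v) raises "ValueError: too many values to unpack".
# {k: v for k, v in props}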
@@ -605,58 +611,68 @@ async def generate_content_async(
    }
    completion_args.update(self._additional_args)

    if stream:
      text = ""
      function_name = ""
      function_args = ""
      function_id = None
      completion_args["stream"] = True
      for part in self.llm_client.completion(**completion_args):
        for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):
            if chunk.name:
              function_name += chunk.name
            if chunk.args:
              function_args += chunk.args
            function_id = chunk.id or function_id
          elif isinstance(chunk, TextChunk):
            text += chunk.text
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(
                    role="assistant",
                    content=chunk.text,
                ),
                is_partial=True,
            )
          if finish_reason == "tool_calls" and function_id:
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(
                    role="assistant",
                    content="",
                    tool_calls=[
                        ChatCompletionMessageToolCall(
                            type="function",
                            id=function_id,
                            function=Function(
                                name=function_name,
                                arguments=function_args,
                            ),
                        )
                    ],
                )
            )
            function_name = ""
            function_args = ""
            function_id = None
          elif finish_reason == "stop" and text:
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(role="assistant", content=text)
            )
            text = ""

    else:
      response = await self.llm_client.acompletion(**completion_args)
      yield _model_response_to_generate_content_response(response)

    retries = 3
    for attempt in range(retries):
      try:
        if stream:
          text = ""
          function_name = ""
          function_args = ""
          function_id = None
          completion_args["stream"] = True
          for part in self.llm_client.completion(**completion_args):
            for chunk, finish_reason in _model_response_to_chunk(part):
              if isinstance(chunk, FunctionChunk):
                if chunk.name:
                  function_name += chunk.name
                if chunk.args:
                  function_args += chunk.args
                function_id = chunk.id or function_id
              elif isinstance(chunk, TextChunk):
                text += chunk.text
                yield _message_to_generate_content_response(
                    ChatCompletionAssistantMessage(
                        role="assistant",
                        content=chunk.text,
                    ),
                    is_partial=True,
                )
              if finish_reason == "tool_calls" and function_id:
                yield _message_to_generate_content_response(
                    ChatCompletionAssistantMessage(
                        role="assistant",
                        content="",
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                type="function",
                                id=function_id,
                                function=Function(
                                    name=function_name,
                                    arguments=function_args,
                                ),
                            )
                        ],
                    )
                )
                function_name = ""
                function_args = ""
                function_id = None
              elif finish_reason == "stop" and text:
                yield _message_to_generate_content_response(
                    ChatCompletionAssistantMessage(role="assistant", content=text)
                )
                text = ""
          break
        else:
          response = await self.llm_client.acompletion(**completion_args)
          yield _model_response_to_generate_content_response(response)
          break
      except Exception as e:
        if "Internal Server Error" in str(e) and attempt < retries - 1:
          logger.warning("Retrying due to error: %s", e)
        else:
          logger.error("Failed after %d attempts: %s", attempt + 1, e)
          raise

  @staticmethod
  @override
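The retry policy above is deliberately simple: a fixed budget of three attempts, retrying only when the error message contains "Internal Server Error". A standalone sketch of the same pattern follows; the helper name and the backoff are illustrative assumptions, not part of this change:

import asyncio
import logging

logger = logging.getLogger(__name__)


async def retry_on_server_error(fn, retries=3, base_delay=1.0):
  """Hypothetical helper mirroring the retry loop above.

  Retries `fn` only for errors that look like transient server failures,
  with a small exponential backoff between attempts.
  """
  for attempt in range(retries):
    try:
      return await fn()
    except Exception as e:
      transient = "Internal Server Error" in str(e)
      if transient and attempt < retries - 1:
        logger.warning("Retrying due to error: %s", e)
        await asyncio.sleep(base_delay * (2 ** attempt))
      else:
        logger.error("Failed after %d attempts: %s", attempt + 1, e)
        raise

Note that the in-tree version cannot delegate the streaming branch to a helper this simple: partial chunks may already have been yielded before a failure, so a retry there can re-emit content the consumer has already seen.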
44 changes: 44 additions & 0 deletions src/google/adk/tests/unittests/models/test_litellm.py
@@ -802,3 +802,47 @@ async def test_generate_content_async_stream(
]
== "string"
)


@pytest.mark.asyncio
async def test_generate_content_async_retry_logic(
    mock_acompletion, lite_llm_instance
):
  mock_acompletion.side_effect = [Exception("Internal Server Error"), mock_response]

  async for response in lite_llm_instance.generate_content_async(
      LLM_REQUEST_WITH_FUNCTION_DECLARATION
  ):
    assert response.content.role == "model"
    assert response.content.parts[0].text == "Test response"
    assert response.content.parts[1].function_call.name == "test_function"
    assert response.content.parts[1].function_call.args == {
        "test_arg": "test_value"
    }
    assert response.content.parts[1].function_call.id == "test_tool_call_id"

  assert mock_acompletion.call_count == 2


@pytest.mark.asyncio
async def test_generate_content_async_retry_logic_exceeds_retries(
    mock_acompletion, lite_llm_instance
):
  mock_acompletion.side_effect = Exception("Internal Server Error")

  with pytest.raises(Exception, match="Internal Server Error"):
    async for _ in lite_llm_instance.generate_content_async(
        LLM_REQUEST_WITH_FUNCTION_DECLARATION
    ):
      pass

  assert mock_acompletion.call_count == 3


@pytest.mark.asyncio
async def test_generate_content_async_handles_other_exceptions(
    mock_acompletion, lite_llm_instance
):
  mock_acompletion.side_effect = Exception("Some other error")

  with pytest.raises(Exception, match="Some other error"):
    async for _ in lite_llm_instance.generate_content_async(
        LLM_REQUEST_WITH_FUNCTION_DECLARATION
    ):
      pass

  assert mock_acompletion.call_count == 1
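These tests lean on the `mock_acompletion` and `lite_llm_instance` fixtures (and a `mock_response` payload) defined earlier in test_litellm.py. Roughly, as a hedged sketch whose details may differ from the real definitions:

import pytest
from unittest import mock

from google.adk.models.lite_llm import LiteLlm


@pytest.fixture
def mock_acompletion():
  # Async mock standing in for the client's acompletion call.
  return mock.AsyncMock()


@pytest.fixture
def lite_llm_instance(mock_acompletion):
  llm = LiteLlm(model="test_model")
  # Swap the real client for a stub whose acompletion is the mock above;
  # mock_response is assumed to be a canned litellm ModelResponse.
  llm.llm_client = mock.Mock(acompletion=mock_acompletion)
  return llm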