4 changes: 2 additions & 2 deletions pyproject.toml
@@ -15,8 +15,8 @@ dependencies = [
     "pydantic-settings>=2.10.1", # Config management
     "a2a-sdk>=0.3.0", # For Google Agent2Agent protocol
     "deprecated>=1.2.18",
-    "google-adk>=1.10.0", # For basic agent architecture
-    "litellm>=1.74.3", # For model inference
+    "google-adk>=1.15.0", # For basic agent architecture
+    "litellm>=1.79.3", # For model inference
     "loguru>=0.7.3", # For better logging
     "opentelemetry-exporter-otlp>=1.35.0",
     "opentelemetry-instrumentation-logging>=0.56b0",
48 changes: 40 additions & 8 deletions veadk/agent.py
@@ -15,12 +15,14 @@
 from __future__ import annotations
 
 import os
-from typing import Optional, Union
+from typing import Optional, Union, AsyncGenerator
 
-from google.adk.agents import LlmAgent, RunConfig
+from google.adk.agents import LlmAgent, RunConfig, InvocationContext
 from google.adk.agents.base_agent import BaseAgent
+from google.adk.agents.context_cache_config import ContextCacheConfig
 from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
 from google.adk.agents.run_config import StreamingMode
+from google.adk.events import Event
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.runners import Runner
 from google.genai import types
@@ -151,6 +153,10 @@ class Agent(LlmAgent):
 
     tracers: list[BaseTracer] = []
 
+    enable_responses: bool = False
+
+    context_cache_config: Optional[ContextCacheConfig] = None
+
     run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True)
     """Optional run processor for intercepting and processing agent execution flows.
 
@@ -197,12 +203,28 @@ def model_post_init(self, __context: Any) -> None:
         logger.info(f"Model extra config: {self.model_extra_config}")
 
         if not self.model:
-            self.model = LiteLlm(
-                model=f"{self.model_provider}/{self.model_name}",
-                api_key=self.model_api_key,
-                api_base=self.model_api_base,
-                **self.model_extra_config,
-            )
+            if self.enable_responses:
+                from veadk.models.ark_llm import ArkLlm
+
+                self.model = ArkLlm(
+                    model=f"{self.model_provider}/{self.model_name}",
+                    api_key=self.model_api_key,
+                    api_base=self.model_api_base,
+                    **self.model_extra_config,
+                )
+                if not self.context_cache_config:
+                    self.context_cache_config = ContextCacheConfig(
+                        cache_intervals=100,  # maximum number
+                        ttl_seconds=315360000,
+                        min_tokens=0,
+                    )
+            else:
+                self.model = LiteLlm(
+                    model=f"{self.model_provider}/{self.model_name}",
+                    api_key=self.model_api_key,
+                    api_base=self.model_api_base,
+                    **self.model_extra_config,
+                )
             logger.debug(
                 f"LiteLLM client created with config: {self.model_extra_config}"
             )
@@ -238,6 +260,16 @@ def model_post_init(self, __context: Any) -> None:
             f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
         )
 
+    async def _run_async_impl(
+        self, ctx: InvocationContext
+    ) -> AsyncGenerator[Event, None]:
+        if self.enable_responses:
+            if not ctx.context_cache_config:
+                ctx.context_cache_config = self.context_cache_config
+
+        async for event in super()._run_async_impl(ctx):
+            yield event
+
     async def _run(
         self,
         runner,
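For context, a minimal usage sketch of the new enable_responses path. The placeholder values are hypothetical and the keyword fields are assumed to be settable through the Agent constructor (they are accessed as pydantic fields in model_post_init above): with the flag on, model_post_init builds an ArkLlm client plus a default ContextCacheConfig, and _run_async_impl copies that config onto the invocation context at run time.

from veadk.agent import Agent

agent = Agent(
    name="demo_agent",                    # hypothetical agent name
    model_provider="openai",              # hypothetical provider prefix
    model_name="<ark-responses-model>",   # hypothetical model endpoint
    model_api_key="<ARK_API_KEY>",        # hypothetical credential
    model_api_base="<ark-api-base-url>",  # hypothetical base URL
    enable_responses=True,                # route through ArkLlm + Responses API
)

# With enable_responses left False (the default), model_post_init keeps building
# the original LiteLlm client, so existing behavior is unchanged.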
252 changes: 252 additions & 0 deletions veadk/models/ark_llm.py
@@ -0,0 +1,252 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from google/adk-python: src/google/adk/models/lite_llm.py at commit f1f44675e4a86b75e72cfd838efd8a0399f23e24

import json
from typing import Any, Dict, Union, AsyncGenerator

import litellm
import openai
from openai.types.responses import Response as OpenAITypeResponse, ResponseStreamEvent
from google.adk.models import LlmRequest, LlmResponse
from google.adk.models.lite_llm import (
LiteLlm,
_get_completion_inputs,
FunctionChunk,
TextChunk,
_message_to_generate_content_response,
UsageMetadataChunk,
)
from google.genai import types
from litellm import ChatCompletionAssistantMessage
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Function,
)
from pydantic import Field

from veadk.models.ark_transform import (
CompletionToResponsesAPIHandler,
)
from veadk.utils.logger import get_logger

# This will add functions to prompts if functions are provided.
litellm.add_function_to_prompt = True

logger = get_logger(__name__)


class ArkLlmClient:
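    """Minimal async client for the OpenAI-compatible Responses API.

    Requests bypass LiteLLM entirely: an openai.AsyncOpenAI client is created
    from the supplied api_base/api_key and the remaining kwargs are passed to
    responses.create(), returning either a full Response or an async stream of
    ResponseStreamEvent objects.
    """
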
async def aresponse(
self, **kwargs
) -> Union[OpenAITypeResponse, openai.AsyncStream[ResponseStreamEvent]]:
# 1. Get request params
api_base = kwargs.pop("api_base", None)
api_key = kwargs.pop("api_key", None)

# 2. Call openai responses
client = openai.AsyncOpenAI(
base_url=api_base,
api_key=api_key,
)

raw_response = await client.responses.create(**kwargs)
return raw_response


class ArkLlm(LiteLlm):
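    """LiteLlm variant that talks to the Ark Responses API.

    Chat-completions style inputs built by the ADK helpers are translated into
    Responses API arguments by CompletionToResponsesAPIHandler, and the
    previous_response_id carried in the request's cache metadata is forwarded
    so server-side conversation context can be reused across turns.
    """
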
llm_client: ArkLlmClient = Field(default_factory=ArkLlmClient)
_additional_args: Dict[str, Any] = None
transform_handler: CompletionToResponsesAPIHandler = Field(
default_factory=CompletionToResponsesAPIHandler
)

def __init__(self, **kwargs):
super().__init__(**kwargs)

async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
"""Generates content asynchronously.

Args:
            llm_request: LlmRequest, the request to send to the Ark model.
            stream: bool = False, whether to make a streaming call.

Yields:
LlmResponse: The model response.
"""
self._maybe_append_user_content(llm_request)
# logger.debug(_build_request_log(llm_request))

messages, tools, response_format, generation_params = _get_completion_inputs(
llm_request
)

if "functions" in self._additional_args:
# LiteLLM does not support both tools and functions together.
tools = None
# ------------------------------------------------------ #
# get previous_response_id
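        # The cache_name field of the ADK cache metadata is reused here to carry
        # the Responses API previous_response_id, chaining server-side context
        # across turns.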
previous_response_id = None
if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
previous_response_id = llm_request.cache_metadata.cache_name
completion_args = {
"model": self.model,
"messages": messages,
"tools": tools,
"response_format": response_format,
"previous_response_id": previous_response_id, # supply previous_response_id
}
# ------------------------------------------------------ #
completion_args.update(self._additional_args)

if generation_params:
completion_args.update(generation_params)
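        # Translate the chat-completions style arguments assembled above into
        # Responses API request arguments before calling the client.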
response_args = self.transform_handler.transform_request(**completion_args)

if stream:
text = ""
# Track function calls by index
function_calls = {} # index -> {name, args, id}
response_args["stream"] = True
aggregated_llm_response = None
aggregated_llm_response_with_tool_call = None
usage_metadata = None
fallback_index = 0
raw_response = await self.llm_client.aresponse(**response_args)
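            # Each streamed Responses API event is mapped back into chat-completion
            # style chunks (text, function call, usage) so the aggregation logic
            # below can mirror the upstream LiteLlm implementation.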
async for part in raw_response:
for (
model_response,
chunk,
finish_reason,
) in self.transform_handler.stream_event_to_chunk(
part, model=self.model
):
if isinstance(chunk, FunctionChunk):
index = chunk.index or fallback_index
if index not in function_calls:
function_calls[index] = {"name": "", "args": "", "id": None}

if chunk.name:
function_calls[index]["name"] += chunk.name
if chunk.args:
function_calls[index]["args"] += chunk.args

# check if args is completed (workaround for improper chunk
# indexing)
try:
json.loads(function_calls[index]["args"])
fallback_index += 1
except json.JSONDecodeError:
pass

function_calls[index]["id"] = (
chunk.id or function_calls[index]["id"] or str(index)
)
elif isinstance(chunk, TextChunk):
text += chunk.text
yield _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content=chunk.text,
),
is_partial=True,
)
elif isinstance(chunk, UsageMetadataChunk):
usage_metadata = types.GenerateContentResponseUsageMetadata(
prompt_token_count=chunk.prompt_tokens,
candidates_token_count=chunk.completion_tokens,
total_token_count=chunk.total_tokens,
)
# ------------------------------------------------------ #
if model_response.get("usage", {}).get("prompt_tokens_details"):
usage_metadata.cached_content_token_count = (
model_response.get("usage", {})
.get("prompt_tokens_details")
.cached_tokens
)
# ------------------------------------------------------ #

if (
finish_reason == "tool_calls" or finish_reason == "stop"
) and function_calls:
tool_calls = []
for index, func_data in function_calls.items():
if func_data["id"]:
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=func_data["id"],
function=Function(
name=func_data["name"],
arguments=func_data["args"],
index=index,
),
)
)
aggregated_llm_response_with_tool_call = (
_message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content=text,
tool_calls=tool_calls,
)
)
)
self.transform_handler.adapt_responses_api(
model_response,
aggregated_llm_response_with_tool_call,
stream=True,
)
text = ""
function_calls.clear()
elif finish_reason == "stop" and text:
aggregated_llm_response = _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant", content=text
)
)
self.transform_handler.adapt_responses_api(
model_response,
aggregated_llm_response,
stream=True,
)
text = ""

            # Wait until streaming ends to yield the aggregated llm_response, since
            # litellm tends to send the chunk containing usage_metadata after the
            # chunk whose finish_reason is tool_calls or stop.
if aggregated_llm_response:
if usage_metadata:
aggregated_llm_response.usage_metadata = usage_metadata
usage_metadata = None
yield aggregated_llm_response

if aggregated_llm_response_with_tool_call:
if usage_metadata:
aggregated_llm_response_with_tool_call.usage_metadata = (
usage_metadata
)
yield aggregated_llm_response_with_tool_call

else:
raw_response = await self.llm_client.aresponse(**response_args)
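            # Non-streaming: convert the complete Responses API result into ADK
            # LlmResponse objects.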
for (
llm_response
) in self.transform_handler.openai_response_to_generate_content_response(
llm_request, raw_response
):
yield llm_response