From dfc25c17a98aaad81e1e2f140db83d17cd78f393 Mon Sep 17 00:00:00 2001 From: Alejandro Cruzado-Ruiz Date: Wed, 23 Jul 2025 16:33:15 -0700 Subject: [PATCH] feat: modularize fast_api.py to allow simpler construction of API Server PiperOrigin-RevId: 786467758 --- src/google/adk/cli/adk_web_server.py | 984 ++++++++++++++++++ src/google/adk/cli/agent_graph.py | 8 +- src/google/adk/cli/fast_api.py | 914 +--------------- src/google/adk/cli/utils/__init__.py | 26 +- .../adk/cli/utils/agent_change_handler.py | 45 + src/google/adk/cli/utils/shared_value.py | 30 + src/google/adk/cli/utils/state.py | 47 + 7 files changed, 1164 insertions(+), 890 deletions(-) create mode 100644 src/google/adk/cli/adk_web_server.py create mode 100644 src/google/adk/cli/utils/agent_change_handler.py create mode 100644 src/google/adk/cli/utils/shared_value.py create mode 100644 src/google/adk/cli/utils/state.py diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py new file mode 100644 index 000000000..d2467ec8f --- /dev/null +++ b/src/google/adk/cli/adk_web_server.py @@ -0,0 +1,984 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +import logging +import os +import time +import traceback +import typing +from typing import Any +from typing import Callable +from typing import List +from typing import Literal +from typing import Optional + +from fastapi import FastAPI +from fastapi import HTTPException +from fastapi import Query +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.responses import StreamingResponse +from fastapi.staticfiles import StaticFiles +from fastapi.websockets import WebSocket +from fastapi.websockets import WebSocketDisconnect +from google.adk.evaluation.eval_set_results_manager import EvalSetResultsManager +from google.genai import types +import graphviz +from opentelemetry import trace +from opentelemetry.sdk.trace import export as export_lib +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace import TracerProvider +from pydantic import Field +from pydantic import ValidationError +from starlette.types import Lifespan +from typing_extensions import override +from watchdog.observers import Observer + +from . import agent_graph +from ..agents.live_request_queue import LiveRequest +from ..agents.live_request_queue import LiveRequestQueue +from ..agents.run_config import RunConfig +from ..agents.run_config import StreamingMode +from ..artifacts.base_artifact_service import BaseArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..errors.not_found_error import NotFoundError +from ..evaluation.base_eval_service import InferenceConfig +from ..evaluation.base_eval_service import InferenceRequest +from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..evaluation.eval_case import EvalCase +from ..evaluation.eval_case import SessionInput +from ..evaluation.eval_metrics import EvalMetric +from ..evaluation.eval_metrics import EvalMetricResult +from ..evaluation.eval_metrics import EvalMetricResultPerInvocation +from ..evaluation.eval_result import EvalSetResult +from ..evaluation.eval_sets_manager import EvalSetsManager +from ..events.event import Event +from ..memory.base_memory_service import BaseMemoryService +from ..runners import Runner +from ..sessions.base_session_service import BaseSessionService +from ..sessions.session import Session +from .cli_eval import EVAL_SESSION_ID_PREFIX +from .cli_eval import EvalStatus +from .utils import cleanup +from .utils import common +from .utils import envs +from .utils import evals +from .utils.base_agent_loader import BaseAgentLoader +from .utils.shared_value import SharedValue +from .utils.state import create_empty_state + +logger = logging.getLogger("google_adk." + __name__) + +_EVAL_SET_FILE_EXTENSION = ".evalset.json" + + +class ApiServerSpanExporter(export_lib.SpanExporter): + + def __init__(self, trace_dict): + self.trace_dict = trace_dict + + def export( + self, spans: typing.Sequence[ReadableSpan] + ) -> export_lib.SpanExportResult: + for span in spans: + if ( + span.name == "call_llm" + or span.name == "send_data" + or span.name.startswith("execute_tool") + ): + attributes = dict(span.attributes) + attributes["trace_id"] = span.get_span_context().trace_id + attributes["span_id"] = span.get_span_context().span_id + if attributes.get("gcp.vertex.agent.event_id", None): + self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes + return export_lib.SpanExportResult.SUCCESS + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + +class InMemoryExporter(export_lib.SpanExporter): + + def __init__(self, trace_dict): + super().__init__() + self._spans = [] + self.trace_dict = trace_dict + + @override + def export( + self, spans: typing.Sequence[ReadableSpan] + ) -> export_lib.SpanExportResult: + for span in spans: + trace_id = span.context.trace_id + if span.name == "call_llm": + attributes = dict(span.attributes) + session_id = attributes.get("gcp.vertex.agent.session_id", None) + if session_id: + if session_id not in self.trace_dict: + self.trace_dict[session_id] = [trace_id] + else: + self.trace_dict[session_id] += [trace_id] + self._spans.extend(spans) + return export_lib.SpanExportResult.SUCCESS + + @override + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True + + def get_finished_spans(self, session_id: str): + trace_ids = self.trace_dict.get(session_id, None) + if trace_ids is None or not trace_ids: + return [] + return [x for x in self._spans if x.context.trace_id in trace_ids] + + def clear(self): + self._spans.clear() + + +class AgentRunRequest(common.BaseModel): + app_name: str + user_id: str + session_id: str + new_message: types.Content + streaming: bool = False + state_delta: Optional[dict[str, Any]] = None + + +class AddSessionToEvalSetRequest(common.BaseModel): + eval_id: str + session_id: str + user_id: str + + +class RunEvalRequest(common.BaseModel): + eval_ids: list[str] # if empty, then all evals in the eval set are run. + eval_metrics: list[EvalMetric] + + +class RunEvalResult(common.BaseModel): + eval_set_file: str + eval_set_id: str + eval_id: str + final_eval_status: EvalStatus + eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( + deprecated=True, + default=[], + description=( + "This field is deprecated, use overall_eval_metric_results instead." + ), + ) + overall_eval_metric_results: list[EvalMetricResult] + eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation] + user_id: str + session_id: str + + +class GetEventGraphResult(common.BaseModel): + dot_src: str + + +class AdkWebServer: + """Helper class for setting up and running the ADK web server on FastAPI. + + You construct this class with all the Services required to run ADK agents and + can then call the get_fast_api_app method to get a FastAPI app instance that + can will use your provided service instances, static assets, and agent loader. + If you pass in a web_assets_dir, the static assets will be served under + /dev-ui in addition to the API endpoints created by default. + + You can add add additional API endpoints by modifying the FastAPI app + instance returned by get_fast_api_app as this class exposes the agent runners + and most other bits of state retained during the lifetime of the server. + + Attributes: + agent_loader: An instance of BaseAgentLoader for loading agents. + session_service: An instance of BaseSessionService for managing sessions. + memory_service: An instance of BaseMemoryService for managing memory. + artifact_service: An instance of BaseArtifactService for managing + artifacts. + credential_service: An instance of BaseCredentialService for managing + credentials. + eval_sets_manager: An instance of EvalSetsManager for managing evaluation + sets. + eval_set_results_manager: An instance of EvalSetResultsManager for + managing evaluation set results. + agents_dir: Root directory containing subdirs for agents with those + containing resources (e.g. .env files, eval sets, etc.) for the agents. + runners_to_clean: Set of runner names marked for cleanup. + current_app_name_ref: A shared reference to the latest ran app name. + runner_dict: A dict of instantiated runners for each app. + """ + + def __init__( + self, + *, + agent_loader: BaseAgentLoader, + session_service: BaseSessionService, + memory_service: BaseMemoryService, + artifact_service: BaseArtifactService, + credential_service: BaseCredentialService, + eval_sets_manager: EvalSetsManager, + eval_set_results_manager: EvalSetResultsManager, + agents_dir: str, + ): + self.agent_loader = agent_loader + self.session_service = session_service + self.memory_service = memory_service + self.artifact_service = artifact_service + self.credential_service = credential_service + self.eval_sets_manager = eval_sets_manager + self.eval_set_results_manager = eval_set_results_manager + self.agents_dir = agents_dir + # Internal propeties we want to allow being modified from callbacks. + self.runners_to_clean: set[str] = set() + self.current_app_name_ref: SharedValue[str] = SharedValue(value="") + self.runner_dict = {} + + async def get_runner_async(self, app_name: str) -> Runner: + """Returns the runner for the given app.""" + if app_name in self.runners_to_clean: + self.runners_to_clean.remove(app_name) + runner = self.runner_dict.pop(app_name, None) + await cleanup.close_runners(list([runner])) + + envs.load_dotenv_for_agent(os.path.basename(app_name), self.agents_dir) + if app_name in self.runner_dict: + return self.runner_dict[app_name] + root_agent = self.agent_loader.load_agent(app_name) + runner = Runner( + app_name=app_name, + agent=root_agent, + artifact_service=self.artifact_service, + session_service=self.session_service, + memory_service=self.memory_service, + credential_service=self.credential_service, + ) + self.runner_dict[app_name] = runner + return runner + + def get_fast_api_app( + self, + lifespan: Optional[Lifespan[FastAPI]] = None, + allow_origins: Optional[list[str]] = None, + web_assets_dir: Optional[str] = None, + setup_observer: Callable[ + [Observer, "AdkWebServer"], None + ] = lambda o, s: None, + tear_down_observer: Callable[ + [Observer, "AdkWebServer"], None + ] = lambda o, s: None, + register_processors: Callable[[TracerProvider], None] = lambda o: None, + ): + """Creates a FastAPI app for the ADK web server. + + By default it'll just return a FastAPI instance with the API server + endpoints, + but if you specify a web_assets_dir, it'll also serve the static web assets + from that directory. + + Args: + lifespan: The lifespan of the FastAPI app. + allow_origins: The origins that are allowed to make cross-origin requests. + web_assets_dir: The directory containing the web assets to serve. + setup_observer: Callback for setting up the file system observer. + tear_down_observer: Callback for cleaning up the file system observer. + register_processors: Callback for additional Span processors to be added + to the TracerProvider. + + Returns: + A FastAPI app instance. + """ + # Properties we don't need to modify from callbacks + trace_dict = {} + session_trace_dict = {} + # Set up a file system watcher to detect changes in the agents directory. + observer = Observer() + setup_observer(observer, self) + + @asynccontextmanager + async def internal_lifespan(app: FastAPI): + try: + if lifespan: + async with lifespan(app) as lifespan_context: + yield lifespan_context + else: + yield + finally: + tear_down_observer(observer, self) + # Create tasks for all runner closures to run concurrently + await cleanup.close_runners(list(self.runner_dict.values())) + + # Set up tracing in the FastAPI server. + provider = TracerProvider() + provider.add_span_processor( + export_lib.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)) + ) + memory_exporter = InMemoryExporter(session_trace_dict) + provider.add_span_processor(export_lib.SimpleSpanProcessor(memory_exporter)) + + register_processors(provider) + + trace.set_tracer_provider(provider) + + # Run the FastAPI server. + app = FastAPI(lifespan=internal_lifespan) + + if allow_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=allow_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/list-apps") + def list_apps() -> list[str]: + return self.agent_loader.list_agents() + + @app.get("/debug/trace/{event_id}") + def get_trace_dict(event_id: str) -> Any: + event_dict = trace_dict.get(event_id, None) + if event_dict is None: + raise HTTPException(status_code=404, detail="Trace not found") + return event_dict + + @app.get("/debug/trace/session/{session_id}") + def get_session_trace(session_id: str) -> Any: + spans = memory_exporter.get_finished_spans(session_id) + if not spans: + return [] + return [ + { + "name": s.name, + "span_id": s.context.span_id, + "trace_id": s.context.trace_id, + "start_time": s.start_time, + "end_time": s.end_time, + "attributes": dict(s.attributes), + "parent_span_id": s.parent.span_id if s.parent else None, + } + for s in spans + ] + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def get_session( + app_name: str, user_id: str, session_id: str + ) -> Session: + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + self.current_app_name_ref.value = app_name + return session + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions", + response_model_exclude_none=True, + ) + async def list_sessions(app_name: str, user_id: str) -> list[Session]: + list_sessions_response = await self.session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + return [ + session + for session in list_sessions_response.sessions + # Remove sessions that were generated as a part of Eval. + if not session.id.startswith(EVAL_SESSION_ID_PREFIX) + ] + + @app.post( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def create_session_with_id( + app_name: str, + user_id: str, + session_id: str, + state: Optional[dict[str, Any]] = None, + ) -> Session: + if ( + await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + is not None + ): + logger.warning("Session already exists: %s", session_id) + raise HTTPException( + status_code=400, detail=f"Session already exists: {session_id}" + ) + logger.info("New session created: %s", session_id) + return await self.session_service.create_session( + app_name=app_name, user_id=user_id, state=state, session_id=session_id + ) + + @app.post( + "/apps/{app_name}/users/{user_id}/sessions", + response_model_exclude_none=True, + ) + async def create_session( + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + events: Optional[list[Event]] = None, + ) -> Session: + logger.info("New session created") + session = await self.session_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + if events: + for event in events: + await self.session_service.append_event(session=session, event=event) + + return session + + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}", + response_model_exclude_none=True, + ) + def create_eval_set( + app_name: str, + eval_set_id: str, + ): + """Creates an eval set, given the id.""" + try: + self.eval_sets_manager.create_eval_set(app_name, eval_set_id) + except ValueError as ve: + raise HTTPException( + status_code=400, + detail=str(ve), + ) from ve + + @app.get( + "/apps/{app_name}/eval_sets", + response_model_exclude_none=True, + ) + def list_eval_sets(app_name: str) -> list[str]: + """Lists all eval sets for the given app.""" + try: + return self.eval_sets_manager.list_eval_sets(app_name) + except NotFoundError as e: + logger.warning(e) + return [] + + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}/add_session", + response_model_exclude_none=True, + ) + async def add_session_to_eval_set( + app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest + ): + # Get the session + session = await self.session_service.get_session( + app_name=app_name, user_id=req.user_id, session_id=req.session_id + ) + assert session, "Session not found." + + # Convert the session data to eval invocations + invocations = evals.convert_session_to_eval_invocations(session) + + # Populate the session with initial session state. + initial_session_state = create_empty_state( + self.agent_loader.load_agent(app_name) + ) + + new_eval_case = EvalCase( + eval_id=req.eval_id, + conversation=invocations, + session_input=SessionInput( + app_name=app_name, + user_id=req.user_id, + state=initial_session_state, + ), + creation_timestamp=time.time(), + ) + + try: + self.eval_sets_manager.add_eval_case( + app_name, eval_set_id, new_eval_case + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) from ve + + @app.get( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals", + response_model_exclude_none=True, + ) + def list_evals_in_eval_set( + app_name: str, + eval_set_id: str, + ) -> list[str]: + """Lists all evals in an eval set.""" + eval_set_data = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) + + if not eval_set_data: + raise HTTPException( + status_code=400, detail=f"Eval set `{eval_set_id}` not found." + ) + + return sorted([x.eval_id for x in eval_set_data.eval_cases]) + + @app.get( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", + response_model_exclude_none=True, + ) + def get_eval( + app_name: str, eval_set_id: str, eval_case_id: str + ) -> EvalCase: + """Gets an eval case in an eval set.""" + eval_case_to_find = self.eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + if eval_case_to_find: + return eval_case_to_find + + raise HTTPException( + status_code=404, + detail=( + f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found." + ), + ) + + @app.put( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", + response_model_exclude_none=True, + ) + def update_eval( + app_name: str, + eval_set_id: str, + eval_case_id: str, + updated_eval_case: EvalCase, + ): + if ( + updated_eval_case.eval_id + and updated_eval_case.eval_id != eval_case_id + ): + raise HTTPException( + status_code=400, + detail=( + "Eval id in EvalCase should match the eval id in the API route." + ), + ) + + # Overwrite the value. We are either overwriting the same value or an empty + # field. + updated_eval_case.eval_id = eval_case_id + try: + self.eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + except NotFoundError as nfe: + raise HTTPException(status_code=404, detail=str(nfe)) from nfe + + @app.delete("/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}") + def delete_eval(app_name: str, eval_set_id: str, eval_case_id: str): + try: + self.eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + except NotFoundError as nfe: + raise HTTPException(status_code=404, detail=str(nfe)) from nfe + + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval", + response_model_exclude_none=True, + ) + async def run_eval( + app_name: str, eval_set_id: str, req: RunEvalRequest + ) -> list[RunEvalResult]: + """Runs an eval given the details in the eval request.""" + # Create a mapping from eval set file to all the evals that needed to be + # run. + try: + from ..evaluation.local_eval_service import LocalEvalService + from .cli_eval import _collect_eval_results + from .cli_eval import _collect_inferences + + eval_set = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) + + if not eval_set: + raise HTTPException( + status_code=400, detail=f"Eval set `{eval_set_id}` not found." + ) + + root_agent = self.agent_loader.load_agent(app_name) + + eval_case_results = [] + + eval_service = LocalEvalService( + root_agent=root_agent, + eval_sets_manager=self.eval_sets_manager, + eval_set_results_manager=self.eval_set_results_manager, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + inference_request = InferenceRequest( + app_name=app_name, + eval_set_id=eval_set.eval_set_id, + eval_case_ids=req.eval_ids, + inference_config=InferenceConfig(), + ) + inference_results = await _collect_inferences( + inference_requests=[inference_request], eval_service=eval_service + ) + + eval_case_results = await _collect_eval_results( + inference_results=inference_results, + eval_service=eval_service, + eval_metrics=req.eval_metrics, + ) + except ModuleNotFoundError as e: + logger.exception("%s", e) + raise HTTPException( + status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE + ) from e + + run_eval_results = [] + for eval_case_result in eval_case_results: + run_eval_results.append( + RunEvalResult( + eval_set_file=eval_case_result.eval_set_file, + eval_set_id=eval_set_id, + eval_id=eval_case_result.eval_id, + final_eval_status=eval_case_result.final_eval_status, + overall_eval_metric_results=eval_case_result.overall_eval_metric_results, + eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, + user_id=eval_case_result.user_id, + session_id=eval_case_result.session_id, + ) + ) + + return run_eval_results + + @app.get( + "/apps/{app_name}/eval_results/{eval_result_id}", + response_model_exclude_none=True, + ) + def get_eval_result( + app_name: str, + eval_result_id: str, + ) -> EvalSetResult: + """Gets the eval result for the given eval id.""" + try: + return self.eval_set_results_manager.get_eval_set_result( + app_name, eval_result_id + ) + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) from ve + except ValidationError as ve: + raise HTTPException(status_code=500, detail=str(ve)) from ve + + @app.get( + "/apps/{app_name}/eval_results", + response_model_exclude_none=True, + ) + def list_eval_results(app_name: str) -> list[str]: + """Lists all eval results for the given app.""" + return self.eval_set_results_manager.list_eval_set_results(app_name) + + @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") + async def delete_session(app_name: str, user_id: str, session_id: str): + await self.session_service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + response_model_exclude_none=True, + ) + async def load_artifact( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + version: Optional[int] = Query(None), + ) -> Optional[types.Part]: + artifact = await self.artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + version=version, + ) + if not artifact: + raise HTTPException(status_code=404, detail="Artifact not found") + return artifact + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", + response_model_exclude_none=True, + ) + async def load_artifact_version( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + version_id: int, + ) -> Optional[types.Part]: + artifact = await self.artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + version=version_id, + ) + if not artifact: + raise HTTPException(status_code=404, detail="Artifact not found") + return artifact + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", + response_model_exclude_none=True, + ) + async def list_artifact_names( + app_name: str, user_id: str, session_id: str + ) -> list[str]: + return await self.artifact_service.list_artifact_keys( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", + response_model_exclude_none=True, + ) + async def list_artifact_versions( + app_name: str, user_id: str, session_id: str, artifact_name: str + ) -> list[int]: + return await self.artifact_service.list_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + ) + + @app.delete( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + ) + async def delete_artifact( + app_name: str, user_id: str, session_id: str, artifact_name: str + ): + await self.artifact_service.delete_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + ) + + @app.post("/run", response_model_exclude_none=True) + async def agent_run(req: AgentRunRequest) -> list[Event]: + session = await self.session_service.get_session( + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + runner = await self.get_runner_async(req.app_name) + events = [ + event + async for event in runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + ) + ] + logger.info("Generated %s events in agent run", len(events)) + logger.debug("Events generated: %s", events) + return events + + @app.post("/run_sse") + async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: + # SSE endpoint + session = await self.session_service.get_session( + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + # Convert the events to properly formatted SSE + async def event_generator(): + try: + stream_mode = ( + StreamingMode.SSE if req.streaming else StreamingMode.NONE + ) + runner = await self.get_runner_async(req.app_name) + async for event in runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + ): + # Format as SSE data + sse_event = event.model_dump_json(exclude_none=True, by_alias=True) + logger.debug( + "Generated event in agent run streaming: %s", sse_event + ) + yield f"data: {sse_event}\n\n" + except Exception as e: + logger.exception("Error in event_generator: %s", e) + # You might want to yield an error event here + yield f'data: {{"error": "{str(e)}"}}\n\n' + + # Returns a streaming response with the proper media type for SSE + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph", + response_model_exclude_none=True, + ) + async def get_event_graph( + app_name: str, user_id: str, session_id: str, event_id: str + ): + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + session_events = session.events if session else [] + event = next((x for x in session_events if x.id == event_id), None) + if not event: + return {} + + function_calls = event.get_function_calls() + function_responses = event.get_function_responses() + root_agent = self.agent_loader.load_agent(app_name) + dot_graph = None + if function_calls: + function_call_highlights = [] + for function_call in function_calls: + from_name = event.author + to_name = function_call.name + function_call_highlights.append((from_name, to_name)) + dot_graph = await agent_graph.get_agent_graph( + root_agent, function_call_highlights + ) + elif function_responses: + function_responses_highlights = [] + for function_response in function_responses: + from_name = function_response.name + to_name = event.author + function_responses_highlights.append((from_name, to_name)) + dot_graph = await agent_graph.get_agent_graph( + root_agent, function_responses_highlights + ) + else: + from_name = event.author + to_name = "" + dot_graph = await agent_graph.get_agent_graph( + root_agent, [(from_name, to_name)] + ) + if dot_graph and isinstance(dot_graph, graphviz.Digraph): + return GetEventGraphResult(dot_src=dot_graph.source) + else: + return {} + + @app.websocket("/run_live") + async def agent_live_run( + websocket: WebSocket, + app_name: str, + user_id: str, + session_id: str, + modalities: List[Literal["TEXT", "AUDIO"]] = Query( + default=["TEXT", "AUDIO"] + ), # Only allows "TEXT" or "AUDIO" + ) -> None: + await websocket.accept() + + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + # Accept first so that the client is aware of connection establishment, + # then close with a specific code. + await websocket.close(code=1002, reason="Session not found") + return + + live_request_queue = LiveRequestQueue() + + async def forward_events(): + runner = await self.get_runner_async(app_name) + async for event in runner.run_live( + session=session, live_request_queue=live_request_queue + ): + await websocket.send_text( + event.model_dump_json(exclude_none=True, by_alias=True) + ) + + async def process_messages(): + try: + while True: + data = await websocket.receive_text() + # Validate and send the received message to the live queue. + live_request_queue.send(LiveRequest.model_validate_json(data)) + except ValidationError as ve: + logger.error("Validation error in process_messages: %s", ve) + + # Run both tasks concurrently and cancel all if one fails. + tasks = [ + asyncio.create_task(forward_events()), + asyncio.create_task(process_messages()), + ] + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_EXCEPTION + ) + try: + # This will re-raise any exception from the completed tasks. + for task in done: + task.result() + except WebSocketDisconnect: + logger.info("Client disconnected during process_messages.") + except Exception as e: + logger.exception("Error during live websocket communication: %s", e) + traceback.print_exc() + WEBSOCKET_INTERNAL_ERROR_CODE = 1011 + WEBSOCKET_MAX_BYTES_FOR_REASON = 123 + await websocket.close( + code=WEBSOCKET_INTERNAL_ERROR_CODE, + reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON], + ) + finally: + for task in pending: + task.cancel() + + if web_assets_dir: + import mimetypes + + mimetypes.add_type("application/javascript", ".js", True) + mimetypes.add_type("text/javascript", ".js", True) + + @app.get("/") + async def redirect_root_to_dev_ui(): + return RedirectResponse("/dev-ui/") + + @app.get("/dev-ui") + async def redirect_dev_ui_add_slash(): + return RedirectResponse("/dev-ui/") + + app.mount( + "/dev-ui/", + StaticFiles(directory=web_assets_dir, html=True, follow_symlink=True), + name="static", + ) + + return app diff --git a/src/google/adk/cli/agent_graph.py b/src/google/adk/cli/agent_graph.py index 2df968f81..e919010cc 100644 --- a/src/google/adk/cli/agent_graph.py +++ b/src/google/adk/cli/agent_graph.py @@ -19,11 +19,11 @@ import graphviz -from ..agents import BaseAgent -from ..agents import LoopAgent -from ..agents import ParallelAgent -from ..agents import SequentialAgent +from ..agents.base_agent import BaseAgent from ..agents.llm_agent import LlmAgent +from ..agents.loop_agent import LoopAgent +from ..agents.parallel_agent import ParallelAgent +from ..agents.sequential_agent import SequentialAgent from ..tools.agent_tool import AgentTool from ..tools.base_tool import BaseTool from ..tools.function_tool import FunctionTool diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 09cd5d2e6..99608d7be 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -14,205 +14,42 @@ from __future__ import annotations -import asyncio -from contextlib import asynccontextmanager import json import logging import os from pathlib import Path import shutil -import time -import traceback -import typing from typing import Any -from typing import List -from typing import Literal +from typing import Mapping from typing import Optional import click from fastapi import FastAPI -from fastapi import HTTPException -from fastapi import Query from fastapi import UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse -from fastapi.responses import StreamingResponse -from fastapi.staticfiles import StaticFiles -from fastapi.websockets import WebSocket -from fastapi.websockets import WebSocketDisconnect -from google.genai import types -import graphviz -from opentelemetry import trace from opentelemetry.sdk.trace import export -from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace import TracerProvider -from pydantic import Field -from pydantic import ValidationError from starlette.types import Lifespan -from typing_extensions import override -from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer -from ..agents import RunConfig -from ..agents.live_request_queue import LiveRequest -from ..agents.live_request_queue import LiveRequestQueue -from ..agents.run_config import StreamingMode from ..artifacts.gcs_artifact_service import GcsArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService -from ..errors.not_found_error import NotFoundError -from ..evaluation.base_eval_service import InferenceConfig -from ..evaluation.base_eval_service import InferenceRequest -from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE -from ..evaluation.eval_case import EvalCase -from ..evaluation.eval_case import SessionInput -from ..evaluation.eval_metrics import EvalMetric -from ..evaluation.eval_metrics import EvalMetricResult -from ..evaluation.eval_metrics import EvalMetricResultPerInvocation -from ..evaluation.eval_result import EvalSetResult from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager -from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..runners import Runner from ..sessions.in_memory_session_service import InMemorySessionService -from ..sessions.session import Session from ..sessions.vertex_ai_session_service import VertexAiSessionService from ..utils.feature_decorator import working_in_progress -from .cli_eval import EVAL_SESSION_ID_PREFIX -from .cli_eval import EvalStatus -from .utils import cleanup -from .utils import common -from .utils import create_empty_state +from .adk_web_server import AdkWebServer from .utils import envs from .utils import evals +from .utils.agent_change_handler import AgentChangeEventHandler from .utils.agent_loader import AgentLoader logger = logging.getLogger("google_adk." + __name__) -_EVAL_SET_FILE_EXTENSION = ".evalset.json" -_app_name = "" -_runners_to_clean = set() - - -class AgentChangeEventHandler(FileSystemEventHandler): - - def __init__(self, agent_loader: AgentLoader): - self.agent_loader = agent_loader - - def on_modified(self, event): - if not (event.src_path.endswith(".py") or event.src_path.endswith(".yaml")): - return - logger.info("Change detected in agents directory: %s", event.src_path) - self.agent_loader.remove_agent_from_cache(_app_name) - _runners_to_clean.add(_app_name) - - -class ApiServerSpanExporter(export.SpanExporter): - - def __init__(self, trace_dict): - self.trace_dict = trace_dict - - def export( - self, spans: typing.Sequence[ReadableSpan] - ) -> export.SpanExportResult: - for span in spans: - if ( - span.name == "call_llm" - or span.name == "send_data" - or span.name.startswith("execute_tool") - ): - attributes = dict(span.attributes) - attributes["trace_id"] = span.get_span_context().trace_id - attributes["span_id"] = span.get_span_context().span_id - if attributes.get("gcp.vertex.agent.event_id", None): - self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes - return export.SpanExportResult.SUCCESS - - def force_flush(self, timeout_millis: int = 30000) -> bool: - return True - - -class InMemoryExporter(export.SpanExporter): - - def __init__(self, trace_dict): - super().__init__() - self._spans = [] - self.trace_dict = trace_dict - - @override - def export( - self, spans: typing.Sequence[ReadableSpan] - ) -> export.SpanExportResult: - for span in spans: - trace_id = span.context.trace_id - if span.name == "call_llm": - attributes = dict(span.attributes) - session_id = attributes.get("gcp.vertex.agent.session_id", None) - if session_id: - if session_id not in self.trace_dict: - self.trace_dict[session_id] = [trace_id] - else: - self.trace_dict[session_id] += [trace_id] - self._spans.extend(spans) - return export.SpanExportResult.SUCCESS - - @override - def force_flush(self, timeout_millis: int = 30000) -> bool: - return True - - def get_finished_spans(self, session_id: str): - trace_ids = self.trace_dict.get(session_id, None) - if trace_ids is None or not trace_ids: - return [] - return [x for x in self._spans if x.context.trace_id in trace_ids] - - def clear(self): - self._spans.clear() - - -class AgentRunRequest(common.BaseModel): - app_name: str - user_id: str - session_id: str - new_message: types.Content - streaming: bool = False - state_delta: Optional[dict[str, Any]] = None - - -class AddSessionToEvalSetRequest(common.BaseModel): - eval_id: str - session_id: str - user_id: str - - -class RunEvalRequest(common.BaseModel): - eval_ids: list[str] # if empty, then all evals in the eval set are run. - eval_metrics: list[EvalMetric] - - -class RunEvalResult(common.BaseModel): - eval_set_file: str - eval_set_id: str - eval_id: str - final_eval_status: EvalStatus - eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( - deprecated=True, - default=[], - description=( - "This field is deprecated, use overall_eval_metric_results instead." - ), - ) - overall_eval_metric_results: list[EvalMetricResult] - eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation] - user_id: str - session_id: str - - -class GetEventGraphResult(common.BaseModel): - dot_src: str - def get_fast_api_app( *, @@ -231,66 +68,7 @@ def get_fast_api_app( reload_agents: bool = False, lifespan: Optional[Lifespan[FastAPI]] = None, ) -> FastAPI: - # InMemory tracing dict. - trace_dict: dict[str, Any] = {} - session_trace_dict: dict[str, Any] = {} - - # Set up tracing in the FastAPI server. - provider = TracerProvider() - provider.add_span_processor( - export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)) - ) - memory_exporter = InMemoryExporter(session_trace_dict) - provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter)) - if trace_to_cloud: - from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter - - envs.load_dotenv_for_agent("", agents_dir) - if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): - processor = export.BatchSpanProcessor( - CloudTraceSpanExporter(project_id=project_id) - ) - provider.add_span_processor(processor) - else: - logger.warning( - "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will" - " not be enabled." - ) - - trace.set_tracer_provider(provider) - - @asynccontextmanager - async def internal_lifespan(app: FastAPI): - try: - if lifespan: - async with lifespan(app) as lifespan_context: - yield lifespan_context - else: - yield - finally: - if reload_agents: - observer.stop() - observer.join() - # Create tasks for all runner closures to run concurrently - await cleanup.close_runners(list(runner_dict.values())) - - # Run the FastAPI server. - app = FastAPI(lifespan=internal_lifespan) - - if allow_origins: - app.add_middleware( - CORSMiddleware, - allow_origins=allow_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - runner_dict = {} - # Set up eval managers. - eval_sets_manager = None - eval_set_results_manager = None if eval_storage_uri: gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( eval_storage_uri @@ -397,439 +175,72 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name): # initialize Agent Loader agent_loader = AgentLoader(agents_dir) - # Set up a file system watcher to detect changes in the agents directory. - observer = Observer() - if reload_agents: - event_handler = AgentChangeEventHandler(agent_loader) - observer.schedule(event_handler, agents_dir, recursive=True) - observer.start() - - @app.get("/list-apps") - def list_apps() -> list[str]: - return agent_loader.list_agents() - - @app.get("/debug/trace/{event_id}") - def get_trace_dict(event_id: str) -> Any: - event_dict = trace_dict.get(event_id, None) - if event_dict is None: - raise HTTPException(status_code=404, detail="Trace not found") - return event_dict - - @app.get("/debug/trace/session/{session_id}") - def get_session_trace(session_id: str) -> Any: - spans = memory_exporter.get_finished_spans(session_id) - if not spans: - return [] - return [ - { - "name": s.name, - "span_id": s.context.span_id, - "trace_id": s.context.trace_id, - "start_time": s.start_time, - "end_time": s.end_time, - "attributes": dict(s.attributes), - "parent_span_id": s.parent.span_id if s.parent else None, - } - for s in spans - ] - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}", - response_model_exclude_none=True, + adk_web_server = AdkWebServer( + agent_loader=agent_loader, + session_service=session_service, + artifact_service=artifact_service, + memory_service=memory_service, + credential_service=credential_service, + eval_sets_manager=eval_sets_manager, + eval_set_results_manager=eval_set_results_manager, + agents_dir=agents_dir, ) - async def get_session( - app_name: str, user_id: str, session_id: str - ) -> Session: - session = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - global _app_name - _app_name = app_name - return session + # Callbacks & other optional args for when constructing the FastAPI instance + extra_fast_api_args = {} - @app.get( - "/apps/{app_name}/users/{user_id}/sessions", - response_model_exclude_none=True, - ) - async def list_sessions(app_name: str, user_id: str) -> list[Session]: - list_sessions_response = await session_service.list_sessions( - app_name=app_name, user_id=user_id - ) - return [ - session - for session in list_sessions_response.sessions - # Remove sessions that were generated as a part of Eval. - if not session.id.startswith(EVAL_SESSION_ID_PREFIX) - ] + if trace_to_cloud: + from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter - @app.post( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}", - response_model_exclude_none=True, - ) - async def create_session_with_id( - app_name: str, - user_id: str, - session_id: str, - state: Optional[dict[str, Any]] = None, - ) -> Session: - if ( - await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id + def register_processors(provider: TracerProvider) -> None: + envs.load_dotenv_for_agent("", agents_dir) + if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): + processor = export.BatchSpanProcessor( + CloudTraceSpanExporter(project_id=project_id) ) - is not None - ): - logger.warning("Session already exists: %s", session_id) - raise HTTPException( - status_code=400, detail=f"Session already exists: {session_id}" - ) - logger.info("New session created: %s", session_id) - return await session_service.create_session( - app_name=app_name, user_id=user_id, state=state, session_id=session_id - ) - - @app.post( - "/apps/{app_name}/users/{user_id}/sessions", - response_model_exclude_none=True, - ) - async def create_session( - app_name: str, - user_id: str, - state: Optional[dict[str, Any]] = None, - events: Optional[list[Event]] = None, - ) -> Session: - logger.info("New session created") - session = await session_service.create_session( - app_name=app_name, user_id=user_id, state=state - ) - - if events: - for event in events: - await session_service.append_event(session=session, event=event) - - return session - - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}", - response_model_exclude_none=True, - ) - def create_eval_set( - app_name: str, - eval_set_id: str, - ): - """Creates an eval set, given the id.""" - try: - eval_sets_manager.create_eval_set(app_name, eval_set_id) - except ValueError as ve: - raise HTTPException( - status_code=400, - detail=str(ve), - ) from ve - - @app.get( - "/apps/{app_name}/eval_sets", - response_model_exclude_none=True, - ) - def list_eval_sets(app_name: str) -> list[str]: - """Lists all eval sets for the given app.""" - try: - return eval_sets_manager.list_eval_sets(app_name) - except NotFoundError as e: - logger.warning(e) - return [] - - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}/add_session", - response_model_exclude_none=True, - ) - async def add_session_to_eval_set( - app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest - ): - # Get the session - session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id - ) - assert session, "Session not found." - - # Convert the session data to eval invocations - invocations = evals.convert_session_to_eval_invocations(session) - - # Populate the session with initial session state. - initial_session_state = create_empty_state( - agent_loader.load_agent(app_name) - ) - - new_eval_case = EvalCase( - eval_id=req.eval_id, - conversation=invocations, - session_input=SessionInput( - app_name=app_name, user_id=req.user_id, state=initial_session_state - ), - creation_timestamp=time.time(), - ) - - try: - eval_sets_manager.add_eval_case(app_name, eval_set_id, new_eval_case) - except ValueError as ve: - raise HTTPException(status_code=400, detail=str(ve)) from ve - - @app.get( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals", - response_model_exclude_none=True, - ) - def list_evals_in_eval_set( - app_name: str, - eval_set_id: str, - ) -> list[str]: - """Lists all evals in an eval set.""" - eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id) - - if not eval_set_data: - raise HTTPException( - status_code=400, detail=f"Eval set `{eval_set_id}` not found." - ) - - return sorted([x.eval_id for x in eval_set_data.eval_cases]) - - @app.get( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", - response_model_exclude_none=True, - ) - def get_eval(app_name: str, eval_set_id: str, eval_case_id: str) -> EvalCase: - """Gets an eval case in an eval set.""" - eval_case_to_find = eval_sets_manager.get_eval_case( - app_name, eval_set_id, eval_case_id - ) - - if eval_case_to_find: - return eval_case_to_find - - raise HTTPException( - status_code=404, - detail=f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found.", - ) - - @app.put( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", - response_model_exclude_none=True, - ) - def update_eval( - app_name: str, - eval_set_id: str, - eval_case_id: str, - updated_eval_case: EvalCase, - ): - if updated_eval_case.eval_id and updated_eval_case.eval_id != eval_case_id: - raise HTTPException( - status_code=400, - detail=( - "Eval id in EvalCase should match the eval id in the API route." - ), - ) - - # Overwrite the value. We are either overwriting the same value or an empty - # field. - updated_eval_case.eval_id = eval_case_id - try: - eval_sets_manager.update_eval_case( - app_name, eval_set_id, updated_eval_case - ) - except NotFoundError as nfe: - raise HTTPException(status_code=404, detail=str(nfe)) from nfe - - @app.delete("/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}") - def delete_eval(app_name: str, eval_set_id: str, eval_case_id: str): - try: - eval_sets_manager.delete_eval_case(app_name, eval_set_id, eval_case_id) - except NotFoundError as nfe: - raise HTTPException(status_code=404, detail=str(nfe)) from nfe - - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval", - response_model_exclude_none=True, - ) - async def run_eval( - app_name: str, eval_set_id: str, req: RunEvalRequest - ) -> list[RunEvalResult]: - """Runs an eval given the details in the eval request.""" - # Create a mapping from eval set file to all the evals that needed to be - # run. - try: - from ..evaluation.local_eval_service import LocalEvalService - from .cli_eval import _collect_eval_results - from .cli_eval import _collect_inferences - - eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) - - if not eval_set: - raise HTTPException( - status_code=400, detail=f"Eval set `{eval_set_id}` not found." + provider.add_span_processor(processor) + else: + logger.warning( + "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will" + " not be enabled." ) - root_agent = agent_loader.load_agent(app_name) - - eval_case_results = [] - - eval_service = LocalEvalService( - root_agent=root_agent, - eval_sets_manager=eval_sets_manager, - eval_set_results_manager=eval_set_results_manager, - session_service=session_service, - artifact_service=artifact_service, - ) - inference_request = InferenceRequest( - app_name=app_name, - eval_set_id=eval_set.eval_set_id, - eval_case_ids=req.eval_ids, - inference_config=InferenceConfig(), - ) - inference_results = await _collect_inferences( - inference_requests=[inference_request], eval_service=eval_service - ) - - eval_case_results = await _collect_eval_results( - inference_results=inference_results, - eval_service=eval_service, - eval_metrics=req.eval_metrics, - ) - except ModuleNotFoundError as e: - logger.exception("%s", e) - raise HTTPException( - status_code=400, detail=MISSING_EVAL_DEPENDENCIES_MESSAGE - ) from e - - run_eval_results = [] - for eval_case_result in eval_case_results: - run_eval_results.append( - RunEvalResult( - eval_set_file=eval_case_result.eval_set_file, - eval_set_id=eval_set_id, - eval_id=eval_case_result.eval_id, - final_eval_status=eval_case_result.final_eval_status, - overall_eval_metric_results=eval_case_result.overall_eval_metric_results, - eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, - user_id=eval_case_result.user_id, - session_id=eval_case_result.session_id, - ) - ) + extra_fast_api_args.update( + register_processors=register_processors, + ) - return run_eval_results + if reload_agents: - @app.get( - "/apps/{app_name}/eval_results/{eval_result_id}", - response_model_exclude_none=True, - ) - def get_eval_result( - app_name: str, - eval_result_id: str, - ) -> EvalSetResult: - """Gets the eval result for the given eval id.""" - try: - return eval_set_results_manager.get_eval_set_result( - app_name, eval_result_id + def setup_observer(observer: Observer, adk_web_server: AdkWebServer): + agent_change_handler = AgentChangeEventHandler( + agent_loader=agent_loader, + runners_to_clean=adk_web_server.runners_to_clean, + current_app_name_ref=adk_web_server.current_app_name_ref, ) - except ValueError as ve: - raise HTTPException(status_code=404, detail=str(ve)) from ve - except ValidationError as ve: - raise HTTPException(status_code=500, detail=str(ve)) from ve - - @app.get( - "/apps/{app_name}/eval_results", - response_model_exclude_none=True, - ) - def list_eval_results(app_name: str) -> list[str]: - """Lists all eval results for the given app.""" - return eval_set_results_manager.list_eval_set_results(app_name) - - @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") - async def delete_session(app_name: str, user_id: str, session_id: str): - await session_service.delete_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) + observer.schedule(agent_change_handler, agents_dir, recursive=True) + observer.start() - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", - response_model_exclude_none=True, - ) - async def load_artifact( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - version: Optional[int] = Query(None), - ) -> Optional[types.Part]: - artifact = await artifact_service.load_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - version=version, - ) - if not artifact: - raise HTTPException(status_code=404, detail="Artifact not found") - return artifact + def tear_down_observer(observer: Observer, _: AdkWebServer): + observer.stop() + observer.join() - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", - response_model_exclude_none=True, - ) - async def load_artifact_version( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - version_id: int, - ) -> Optional[types.Part]: - artifact = await artifact_service.load_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - version=version_id, + extra_fast_api_args.update( + setup_observer=setup_observer, + tear_down_observer=tear_down_observer, ) - if not artifact: - raise HTTPException(status_code=404, detail="Artifact not found") - return artifact - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", - response_model_exclude_none=True, - ) - async def list_artifact_names( - app_name: str, user_id: str, session_id: str - ) -> list[str]: - return await artifact_service.list_artifact_keys( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", - response_model_exclude_none=True, - ) - async def list_artifact_versions( - app_name: str, user_id: str, session_id: str, artifact_name: str - ) -> list[int]: - return await artifact_service.list_versions( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, + if web: + BASE_DIR = Path(__file__).parent.resolve() + ANGULAR_DIST_PATH = BASE_DIR / "browser" + extra_fast_api_args.update( + web_assets_dir=ANGULAR_DIST_PATH, ) - @app.delete( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + app = adk_web_server.get_fast_api_app( + lifespan=lifespan, + allow_origins=allow_origins, + **extra_fast_api_args, ) - async def delete_artifact( - app_name: str, user_id: str, session_id: str, artifact_name: str - ): - await artifact_service.delete_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - ) @working_in_progress("builder_save is not ready for use.") @app.post("/builder/save", response_model_exclude_none=True) @@ -858,202 +269,6 @@ async def builder_build(files: list[UploadFile]) -> bool: return True - @app.post("/run", response_model_exclude_none=True) - async def agent_run(req: AgentRunRequest) -> list[Event]: - session = await session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - runner = await _get_runner_async(req.app_name) - events = [ - event - async for event in runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - ) - ] - logger.info("Generated %s events in agent run", len(events)) - logger.debug("Events generated: %s", events) - return events - - @app.post("/run_sse") - async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: - # SSE endpoint - session = await session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - # Convert the events to properly formatted SSE - async def event_generator(): - try: - stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE - runner = await _get_runner_async(req.app_name) - async for event in runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), - ): - # Format as SSE data - sse_event = event.model_dump_json(exclude_none=True, by_alias=True) - logger.debug("Generated event in agent run streaming: %s", sse_event) - yield f"data: {sse_event}\n\n" - except Exception as e: - logger.exception("Error in event_generator: %s", e) - # You might want to yield an error event here - yield f'data: {{"error": "{str(e)}"}}\n\n' - - # Returns a streaming response with the proper media type for SSE - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - ) - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph", - response_model_exclude_none=True, - ) - async def get_event_graph( - app_name: str, user_id: str, session_id: str, event_id: str - ): - session = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - session_events = session.events if session else [] - event = next((x for x in session_events if x.id == event_id), None) - if not event: - return {} - - from . import agent_graph - - function_calls = event.get_function_calls() - function_responses = event.get_function_responses() - root_agent = agent_loader.load_agent(app_name) - dot_graph = None - if function_calls: - function_call_highlights = [] - for function_call in function_calls: - from_name = event.author - to_name = function_call.name - function_call_highlights.append((from_name, to_name)) - dot_graph = await agent_graph.get_agent_graph( - root_agent, function_call_highlights - ) - elif function_responses: - function_responses_highlights = [] - for function_response in function_responses: - from_name = function_response.name - to_name = event.author - function_responses_highlights.append((from_name, to_name)) - dot_graph = await agent_graph.get_agent_graph( - root_agent, function_responses_highlights - ) - else: - from_name = event.author - to_name = "" - dot_graph = await agent_graph.get_agent_graph( - root_agent, [(from_name, to_name)] - ) - if dot_graph and isinstance(dot_graph, graphviz.Digraph): - return GetEventGraphResult(dot_src=dot_graph.source) - else: - return {} - - @app.websocket("/run_live") - async def agent_live_run( - websocket: WebSocket, - app_name: str, - user_id: str, - session_id: str, - modalities: List[Literal["TEXT", "AUDIO"]] = Query( - default=["TEXT", "AUDIO"] - ), # Only allows "TEXT" or "AUDIO" - ) -> None: - await websocket.accept() - - session = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - if not session: - # Accept first so that the client is aware of connection establishment, - # then close with a specific code. - await websocket.close(code=1002, reason="Session not found") - return - - live_request_queue = LiveRequestQueue() - - async def forward_events(): - runner = await _get_runner_async(app_name) - async for event in runner.run_live( - session=session, live_request_queue=live_request_queue - ): - await websocket.send_text( - event.model_dump_json(exclude_none=True, by_alias=True) - ) - - async def process_messages(): - try: - while True: - data = await websocket.receive_text() - # Validate and send the received message to the live queue. - live_request_queue.send(LiveRequest.model_validate_json(data)) - except ValidationError as ve: - logger.error("Validation error in process_messages: %s", ve) - - # Run both tasks concurrently and cancel all if one fails. - tasks = [ - asyncio.create_task(forward_events()), - asyncio.create_task(process_messages()), - ] - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_EXCEPTION - ) - try: - # This will re-raise any exception from the completed tasks. - for task in done: - task.result() - except WebSocketDisconnect: - logger.info("Client disconnected during process_messages.") - except Exception as e: - logger.exception("Error during live websocket communication: %s", e) - traceback.print_exc() - WEBSOCKET_INTERNAL_ERROR_CODE = 1011 - WEBSOCKET_MAX_BYTES_FOR_REASON = 123 - await websocket.close( - code=WEBSOCKET_INTERNAL_ERROR_CODE, - reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON], - ) - finally: - for task in pending: - task.cancel() - - async def _get_runner_async(app_name: str) -> Runner: - """Returns the runner for the given app.""" - if app_name in _runners_to_clean: - _runners_to_clean.remove(app_name) - runner = runner_dict.pop(app_name, None) - await cleanup.close_runners(list([runner])) - - envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir) - if app_name in runner_dict: - return runner_dict[app_name] - root_agent = agent_loader.load_agent(app_name) - runner = Runner( - app_name=app_name, - agent=root_agent, - artifact_service=artifact_service, - session_service=session_service, - memory_service=memory_service, - credential_service=credential_service, - ) - runner_dict[app_name] = runner - return runner - if a2a: try: from a2a.server.apps import A2AStarletteApplication @@ -1084,7 +299,7 @@ def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" async def _get_a2a_runner_async() -> Runner: - return await _get_runner_async(captured_app_name) + return await adk_web_server.get_runner_async(captured_app_name) return _get_a2a_runner_async @@ -1135,28 +350,5 @@ async def _get_a2a_runner_async() -> Runner: except Exception as e: logger.error("Failed to setup A2A agent %s: %s", app_name, e) # Continue with other agents even if one fails - if web: - import mimetypes - - mimetypes.add_type("application/javascript", ".js", True) - mimetypes.add_type("text/javascript", ".js", True) - BASE_DIR = Path(__file__).parent.resolve() - ANGULAR_DIST_PATH = BASE_DIR / "browser" - - @app.get("/") - async def redirect_root_to_dev_ui(): - return RedirectResponse("/dev-ui/") - - @app.get("/dev-ui") - async def redirect_dev_ui_add_slash(): - return RedirectResponse("/dev-ui/") - - app.mount( - "/dev-ui/", - StaticFiles( - directory=ANGULAR_DIST_PATH, html=True, follow_symlink=True - ), - name="static", - ) return app diff --git a/src/google/adk/cli/utils/__init__.py b/src/google/adk/cli/utils/__init__.py index 846c15635..8aa11b252 100644 --- a/src/google/adk/cli/utils/__init__.py +++ b/src/google/adk/cli/utils/__init__.py @@ -18,32 +18,8 @@ from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent +from .state import create_empty_state __all__ = [ 'create_empty_state', ] - - -def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]): - for sub_agent in agent.sub_agents: - _create_empty_state(sub_agent, all_state) - - if ( - isinstance(agent, LlmAgent) - and agent.instruction - and isinstance(agent.instruction, str) - ): - for key in re.findall(r'{([\w]+)}', agent.instruction): - all_state[key] = '' - - -def create_empty_state( - agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None -) -> dict[str, Any]: - """Creates empty str for non-initialized states.""" - non_initialized_states = {} - _create_empty_state(agent, non_initialized_states) - for key in initialized_states or {}: - if key in non_initialized_states: - del non_initialized_states[key] - return non_initialized_states diff --git a/src/google/adk/cli/utils/agent_change_handler.py b/src/google/adk/cli/utils/agent_change_handler.py new file mode 100644 index 000000000..6e9228088 --- /dev/null +++ b/src/google/adk/cli/utils/agent_change_handler.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File system event handler for agent changes to trigger hot reload for agents.""" + +from __future__ import annotations + +import logging + +from watchdog.events import FileSystemEventHandler + +from .agent_loader import AgentLoader +from .shared_value import SharedValue + +logger = logging.getLogger("google_adk." + __name__) + + +class AgentChangeEventHandler(FileSystemEventHandler): + + def __init__( + self, + agent_loader: AgentLoader, + runners_to_clean: set[str], + current_app_name_ref: SharedValue[str], + ): + self.agent_loader = agent_loader + self.runners_to_clean = runners_to_clean + self.current_app_name_ref = current_app_name_ref + + def on_modified(self, event): + if not (event.src_path.endswith(".py") or event.src_path.endswith(".yaml")): + return + logger.info("Change detected in agents directory: %s", event.src_path) + self.agent_loader.remove_agent_from_cache(self.current_app_name_ref.value) + self.runners_to_clean.add(self.current_app_name_ref.value) diff --git a/src/google/adk/cli/utils/shared_value.py b/src/google/adk/cli/utils/shared_value.py new file mode 100644 index 000000000..e9202df92 --- /dev/null +++ b/src/google/adk/cli/utils/shared_value.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Generic +from typing import TypeVar + +import pydantic + +T = TypeVar("T") + + +class SharedValue(pydantic.BaseModel, Generic[T]): + """Simple wrapper around a value to allow modifying it from callbacks.""" + + model_config = pydantic.ConfigDict( + arbitrary_types_allowed=True, + ) + value: T diff --git a/src/google/adk/cli/utils/state.py b/src/google/adk/cli/utils/state.py new file mode 100644 index 000000000..29d0b1f24 --- /dev/null +++ b/src/google/adk/cli/utils/state.py @@ -0,0 +1,47 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import re +from typing import Any +from typing import Optional + +from ...agents.base_agent import BaseAgent +from ...agents.llm_agent import LlmAgent + + +def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]): + for sub_agent in agent.sub_agents: + _create_empty_state(sub_agent, all_state) + + if ( + isinstance(agent, LlmAgent) + and agent.instruction + and isinstance(agent.instruction, str) + ): + for key in re.findall(r'{([\w]+)}', agent.instruction): + all_state[key] = '' + + +def create_empty_state( + agent: BaseAgent, initialized_states: Optional[dict[str, Any]] = None +) -> dict[str, Any]: + """Creates empty str for non-initialized states.""" + non_initialized_states = {} + _create_empty_state(agent, non_initialized_states) + for key in initialized_states or {}: + if key in non_initialized_states: + del non_initialized_states[key] + return non_initialized_states