diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py new file mode 100644 index 000000000..8170aab77 --- /dev/null +++ b/src/google/adk/cli/adk_web_server.py @@ -0,0 +1,966 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from dataclasses import dataclass +import logging +import time +import traceback +import typing +from typing import Any +from typing import Callable +from typing import List +from typing import Literal +from typing import Optional + +from fastapi import FastAPI +from fastapi import HTTPException +from fastapi import Query +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.responses import StreamingResponse +from fastapi.staticfiles import StaticFiles +from fastapi.websockets import WebSocket +from fastapi.websockets import WebSocketDisconnect +from google.adk.evaluation.eval_set_results_manager import EvalSetResultsManager +from google.genai import types +import graphviz +from opentelemetry import trace +from opentelemetry.sdk.trace import export +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.sdk.trace import TracerProvider +from pydantic import Field +from pydantic import ValidationError +from starlette.types import Lifespan +from typing_extensions import override +from watchdog.observers import 
Observer + +from . import agent_graph +from ..agents.live_request_queue import LiveRequest +from ..agents.live_request_queue import LiveRequestQueue +from ..agents.run_config import RunConfig +from ..agents.run_config import StreamingMode +from ..artifacts.base_artifact_service import BaseArtifactService +from ..auth.credential_service.base_credential_service import BaseCredentialService +from ..errors.not_found_error import NotFoundError +from ..evaluation.eval_case import EvalCase +from ..evaluation.eval_case import SessionInput +from ..evaluation.eval_metrics import EvalMetric +from ..evaluation.eval_metrics import EvalMetricResult +from ..evaluation.eval_metrics import EvalMetricResultPerInvocation +from ..evaluation.eval_result import EvalSetResult +from ..evaluation.eval_sets_manager import EvalSetsManager +from ..events.event import Event +from ..memory.base_memory_service import BaseMemoryService +from ..runners import Runner +from ..sessions.base_session_service import BaseSessionService +from ..sessions.session import Session +from .cli_eval import EVAL_SESSION_ID_PREFIX +from .cli_eval import EvalStatus +from .utils import cleanup +from .utils import common +from .utils import evals +from .utils.base_agent_loader import BaseAgentLoader +from .utils.shared_value import SharedValue +from .utils.state import create_empty_state + +logger = logging.getLogger("google_adk." 
_EVAL_SET_FILE_EXTENSION = ".evalset.json"


class ApiServerSpanExporter(export.SpanExporter):
  """Span exporter that indexes LLM/tool spans by ADK event id.

  Spans named "call_llm", "send_data" or "execute_tool*" get their
  attributes (plus trace/span ids) stored in ``trace_dict``, keyed by the
  "gcp.vertex.agent.event_id" attribute, for the /debug/trace endpoints.
  """

  def __init__(self, trace_dict):
    # Shared mutable mapping: event_id -> span attributes.
    self.trace_dict = trace_dict

  def export(
      self, spans: typing.Sequence[ReadableSpan]
  ) -> export.SpanExportResult:
    for span in spans:
      # Only index the span kinds the debug endpoints care about.
      if span.name in ("call_llm", "send_data") or span.name.startswith(
          "execute_tool"
      ):
        attributes = dict(span.attributes)
        attributes["trace_id"] = span.get_span_context().trace_id
        attributes["span_id"] = span.get_span_context().span_id
        # Skip spans with a missing or empty ADK event id.
        if attributes.get("gcp.vertex.agent.event_id"):
          self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
    return export.SpanExportResult.SUCCESS

  def force_flush(self, timeout_millis: int = 30000) -> bool:
    # Export is synchronous; there is never anything buffered to flush.
    return True


class InMemoryExporter(export.SpanExporter):
  """Keeps finished spans in memory, grouped by ADK session id."""

  def __init__(self, trace_dict):
    super().__init__()
    self._spans = []
    # session_id -> list of trace ids observed for that session.
    self.trace_dict = trace_dict

  @override
  def export(
      self, spans: typing.Sequence[ReadableSpan]
  ) -> export.SpanExportResult:
    for span in spans:
      trace_id = span.context.trace_id
      if span.name == "call_llm":
        attributes = dict(span.attributes)
        session_id = attributes.get("gcp.vertex.agent.session_id")
        if session_id:
          # Equivalent to the manual "create list or append" dance.
          self.trace_dict.setdefault(session_id, []).append(trace_id)
    self._spans.extend(spans)
    return export.SpanExportResult.SUCCESS

  @override
  def force_flush(self, timeout_millis: int = 30000) -> bool:
    return True

  def get_finished_spans(self, session_id: str):
    """Returns all stored spans whose trace id belongs to ``session_id``."""
    trace_ids = self.trace_dict.get(session_id)
    if not trace_ids:  # Covers both "unknown session" and "no traces yet".
      return []
    wanted = set(trace_ids)  # O(1) membership instead of a list scan per span.
    return [x for x in self._spans if x.context.trace_id in wanted]

  def clear(self):
    self._spans.clear()


class AgentRunRequest(common.BaseModel):
  """Request body for the /run and /run_sse endpoints."""

  app_name: str
  user_id: str
  session_id: str
  new_message: types.Content
  streaming: bool = False
  state_delta: Optional[dict[str, Any]] = None


class AddSessionToEvalSetRequest(common.BaseModel):
  """Request body for adding an existing session to an eval set."""

  eval_id: str
  session_id: str
  user_id: str


class RunEvalRequest(common.BaseModel):
  """Request body for running evals in an eval set."""

  eval_ids: list[str]  # if empty, then all evals in the eval set are run.
  eval_metrics: list[EvalMetric]


class RunEvalResult(common.BaseModel):
  """Result of running a single eval case."""

  eval_set_file: str
  eval_set_id: str
  eval_id: str
  final_eval_status: EvalStatus
  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
      deprecated=True,
      description=(
          "This field is deprecated, use overall_eval_metric_results instead."
      ),
  )
  overall_eval_metric_results: list[EvalMetricResult]
  eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
  user_id: str
  session_id: str


class GetEventGraphResult(common.BaseModel):
  """Response body carrying the Graphviz DOT source of an event graph."""

  dot_src: str
+ eval_sets_manager: An instance of EvalSetsManager for managing evaluation + sets. + eval_set_results_manager: An instance of EvalSetResultsManager for + managing evaluation set results. + runners_to_clean: Set of runner names marked for cleanup. + current_app_name_ref: A shared reference to the latest ran app name. + runner_dict: A dict of instantiated runners for each app. + """ + + agent_loader: BaseAgentLoader + session_service: BaseSessionService + memory_service: BaseMemoryService + artifact_service: BaseArtifactService + credential_service: BaseCredentialService + eval_sets_manager: EvalSetsManager + eval_set_results_manager: EvalSetResultsManager + + def __post_init__(self): + """For propeties we want to allow being modified from callbacks.""" + self.runners_to_clean: set[str] = set() + self.current_app_name_ref: SharedValue[str] = SharedValue(value="") + self.runner_dict = {} + + async def get_runner_async(self, app_name: str) -> Runner: + """Returns the runner for the given app.""" + if app_name in self.runners_to_clean: + self.runners_to_clean.remove(app_name) + runner = self.runner_dict.pop(app_name, None) + await cleanup.close_runners(list([runner])) + + if app_name in self.runner_dict: + return self.runner_dict[app_name] + root_agent = self.agent_loader.load_agent(app_name) + runner = Runner( + app_name=app_name, + agent=root_agent, + artifact_service=self.artifact_service, + session_service=self.session_service, + memory_service=self.memory_service, + credential_service=self.credential_service, + ) + self.runner_dict[app_name] = runner + return runner + + def get_fast_api_app( + self, + lifespan: Optional[Lifespan[FastAPI]] = None, + allow_origins: Optional[list[str]] = None, + web_assets_dir: Optional[str] = None, + setup_observer: Callable[ + [Observer, "AdkWebServer"], None + ] = lambda o, s: None, + tear_down_observer: Callable[ + [Observer, "AdkWebServer"], None + ] = lambda o, s: None, + register_processors: Callable[[TracerProvider], None] = 
lambda o: None, + ): + """Creates a FastAPI app for the ADK web server. + + By default it'll just return a FastAPI instance with the API server + endpoints, + but if you specify a web_assets_dir, it'll also serve the static web assets + from that directory. + + Args: + lifespan: The lifespan of the FastAPI app. + allow_origins: The origins that are allowed to make cross-origin requests. + web_assets_dir: The directory containing the web assets to serve. + setup_observer: Callback for setting up the file system observer. + tear_down_observer: Callback for cleaning up the file system observer. + register_processors: Callback for additional Span processors to be added + to the TracerProvider. + + Returns: + A FastAPI app instance. + """ + # Properties we don't need to modify from callbacks + trace_dict = {} + session_trace_dict = {} + # Set up a file system watcher to detect changes in the agents directory. + observer = Observer() + setup_observer(observer, self) + + @asynccontextmanager + async def internal_lifespan(app: FastAPI): + try: + if lifespan: + async with lifespan(app) as lifespan_context: + yield lifespan_context + else: + yield + finally: + tear_down_observer(observer, self) + # Create tasks for all runner closures to run concurrently + await cleanup.close_runners(list(self.runner_dict.values())) + + # Set up tracing in the FastAPI server. + provider = TracerProvider() + provider.add_span_processor( + export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)) + ) + memory_exporter = InMemoryExporter(session_trace_dict) + provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter)) + + register_processors(provider) + + trace.set_tracer_provider(provider) + + # Run the FastAPI server. 
+ app = FastAPI(lifespan=internal_lifespan) + + if allow_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=allow_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/list-apps") + def list_apps() -> list[str]: + return self.agent_loader.list_agents() + + @app.get("/debug/trace/{event_id}") + def get_trace_dict(event_id: str) -> Any: + event_dict = trace_dict.get(event_id, None) + if event_dict is None: + raise HTTPException(status_code=404, detail="Trace not found") + return event_dict + + @app.get("/debug/trace/session/{session_id}") + def get_session_trace(session_id: str) -> Any: + spans = memory_exporter.get_finished_spans(session_id) + if not spans: + return [] + return [ + { + "name": s.name, + "span_id": s.context.span_id, + "trace_id": s.context.trace_id, + "start_time": s.start_time, + "end_time": s.end_time, + "attributes": dict(s.attributes), + "parent_span_id": s.parent.span_id if s.parent else None, + } + for s in spans + ] + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def get_session( + app_name: str, user_id: str, session_id: str + ) -> Session: + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + self.current_app_name_ref.value = app_name + return session + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions", + response_model_exclude_none=True, + ) + async def list_sessions(app_name: str, user_id: str) -> list[Session]: + list_sessions_response = await self.session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + return [ + session + for session in list_sessions_response.sessions + # Remove sessions that were generated as a part of Eval. 
+ if not session.id.startswith(EVAL_SESSION_ID_PREFIX) + ] + + @app.post( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def create_session_with_id( + app_name: str, + user_id: str, + session_id: str, + state: Optional[dict[str, Any]] = None, + ) -> Session: + if ( + await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + is not None + ): + logger.warning("Session already exists: %s", session_id) + raise HTTPException( + status_code=400, detail=f"Session already exists: {session_id}" + ) + logger.info("New session created: %s", session_id) + return await self.session_service.create_session( + app_name=app_name, user_id=user_id, state=state, session_id=session_id + ) + + @app.post( + "/apps/{app_name}/users/{user_id}/sessions", + response_model_exclude_none=True, + ) + async def create_session( + app_name: str, + user_id: str, + state: Optional[dict[str, Any]] = None, + events: Optional[list[Event]] = None, + ) -> Session: + logger.info("New session created") + session = await self.session_service.create_session( + app_name=app_name, user_id=user_id, state=state + ) + + if events: + for event in events: + await self.session_service.append_event(session=session, event=event) + + return session + + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}", + response_model_exclude_none=True, + ) + def create_eval_set( + app_name: str, + eval_set_id: str, + ): + """Creates an eval set, given the id.""" + try: + self.eval_sets_manager.create_eval_set(app_name, eval_set_id) + except ValueError as ve: + raise HTTPException( + status_code=400, + detail=str(ve), + ) from ve + + @app.get( + "/apps/{app_name}/eval_sets", + response_model_exclude_none=True, + ) + def list_eval_sets(app_name: str) -> list[str]: + """Lists all eval sets for the given app.""" + try: + return self.eval_sets_manager.list_eval_sets(app_name) + except NotFoundError as e: + logger.warning(e) + 
return [] + + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}/add_session", + response_model_exclude_none=True, + ) + async def add_session_to_eval_set( + app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest + ): + # Get the session + session = await self.session_service.get_session( + app_name=app_name, user_id=req.user_id, session_id=req.session_id + ) + assert session, "Session not found." + + # Convert the session data to eval invocations + invocations = evals.convert_session_to_eval_invocations(session) + + # Populate the session with initial session state. + initial_session_state = create_empty_state( + self.agent_loader.load_agent(app_name) + ) + + new_eval_case = EvalCase( + eval_id=req.eval_id, + conversation=invocations, + session_input=SessionInput( + app_name=app_name, + user_id=req.user_id, + state=initial_session_state, + ), + creation_timestamp=time.time(), + ) + + try: + self.eval_sets_manager.add_eval_case( + app_name, eval_set_id, new_eval_case + ) + except ValueError as ve: + raise HTTPException(status_code=400, detail=str(ve)) from ve + + @app.get( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals", + response_model_exclude_none=True, + ) + def list_evals_in_eval_set( + app_name: str, + eval_set_id: str, + ) -> list[str]: + """Lists all evals in an eval set.""" + eval_set_data = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) + + if not eval_set_data: + raise HTTPException( + status_code=400, detail=f"Eval set `{eval_set_id}` not found." 
+ ) + + return sorted([x.eval_id for x in eval_set_data.eval_cases]) + + @app.get( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", + response_model_exclude_none=True, + ) + def get_eval( + app_name: str, eval_set_id: str, eval_case_id: str + ) -> EvalCase: + """Gets an eval case in an eval set.""" + eval_case_to_find = self.eval_sets_manager.get_eval_case( + app_name, eval_set_id, eval_case_id + ) + + if eval_case_to_find: + return eval_case_to_find + + raise HTTPException( + status_code=404, + detail=( + f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found." + ), + ) + + @app.put( + "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", + response_model_exclude_none=True, + ) + def update_eval( + app_name: str, + eval_set_id: str, + eval_case_id: str, + updated_eval_case: EvalCase, + ): + if ( + updated_eval_case.eval_id + and updated_eval_case.eval_id != eval_case_id + ): + raise HTTPException( + status_code=400, + detail=( + "Eval id in EvalCase should match the eval id in the API route." + ), + ) + + # Overwrite the value. We are either overwriting the same value or an empty + # field. 
+ updated_eval_case.eval_id = eval_case_id + try: + self.eval_sets_manager.update_eval_case( + app_name, eval_set_id, updated_eval_case + ) + except NotFoundError as nfe: + raise HTTPException(status_code=404, detail=str(nfe)) from nfe + + @app.delete("/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}") + def delete_eval(app_name: str, eval_set_id: str, eval_case_id: str): + try: + self.eval_sets_manager.delete_eval_case( + app_name, eval_set_id, eval_case_id + ) + except NotFoundError as nfe: + raise HTTPException(status_code=404, detail=str(nfe)) from nfe + + @app.post( + "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval", + response_model_exclude_none=True, + ) + async def run_eval( + app_name: str, eval_set_id: str, req: RunEvalRequest + ) -> list[RunEvalResult]: + """Runs an eval given the details in the eval request.""" + from .cli_eval import run_evals + + # Create a mapping from eval set file to all the evals that needed to be + # run. + eval_set = self.eval_sets_manager.get_eval_set(app_name, eval_set_id) + + if not eval_set: + raise HTTPException( + status_code=400, detail=f"Eval set `{eval_set_id}` not found." + ) + + if req.eval_ids: + eval_cases = [ + e for e in eval_set.eval_cases if e.eval_id in req.eval_ids + ] + eval_set_to_evals = {eval_set_id: eval_cases} + else: + logger.info( + "Eval ids to run list is empty. We will run all eval cases." 
+ ) + eval_set_to_evals = {eval_set_id: eval_set.eval_cases} + + root_agent = self.agent_loader.load_agent(app_name) + run_eval_results = [] + eval_case_results = [] + try: + async for eval_case_result in run_evals( + eval_set_to_evals, + root_agent, + getattr(root_agent, "reset_data", None), + req.eval_metrics, + session_service=self.session_service, + artifact_service=self.artifact_service, + ): + run_eval_results.append( + RunEvalResult( + app_name=app_name, + eval_set_file=eval_case_result.eval_set_file, + eval_set_id=eval_set_id, + eval_id=eval_case_result.eval_id, + final_eval_status=eval_case_result.final_eval_status, + eval_metric_results=eval_case_result.eval_metric_results, + overall_eval_metric_results=eval_case_result.overall_eval_metric_results, + eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, + user_id=eval_case_result.user_id, + session_id=eval_case_result.session_id, + ) + ) + eval_case_result.session_details = ( + await self.session_service.get_session( + app_name=app_name, + user_id=eval_case_result.user_id, + session_id=eval_case_result.session_id, + ) + ) + eval_case_results.append(eval_case_result) + except ModuleNotFoundError as e: + logger.exception("%s", e) + raise HTTPException(status_code=400, detail=str(e)) from e + + self.eval_set_results_manager.save_eval_set_result( + app_name, eval_set_id, eval_case_results + ) + + return run_eval_results + + @app.get( + "/apps/{app_name}/eval_results/{eval_result_id}", + response_model_exclude_none=True, + ) + def get_eval_result( + app_name: str, + eval_result_id: str, + ) -> EvalSetResult: + """Gets the eval result for the given eval id.""" + try: + return self.eval_set_results_manager.get_eval_set_result( + app_name, eval_result_id + ) + except ValueError as ve: + raise HTTPException(status_code=404, detail=str(ve)) from ve + except ValidationError as ve: + raise HTTPException(status_code=500, detail=str(ve)) from ve + + @app.get( + 
"/apps/{app_name}/eval_results", + response_model_exclude_none=True, + ) + def list_eval_results(app_name: str) -> list[str]: + """Lists all eval results for the given app.""" + return self.eval_set_results_manager.list_eval_set_results(app_name) + + @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") + async def delete_session(app_name: str, user_id: str, session_id: str): + await self.session_service.delete_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + response_model_exclude_none=True, + ) + async def load_artifact( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + version: Optional[int] = Query(None), + ) -> Optional[types.Part]: + artifact = await self.artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + version=version, + ) + if not artifact: + raise HTTPException(status_code=404, detail="Artifact not found") + return artifact + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", + response_model_exclude_none=True, + ) + async def load_artifact_version( + app_name: str, + user_id: str, + session_id: str, + artifact_name: str, + version_id: int, + ) -> Optional[types.Part]: + artifact = await self.artifact_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + version=version_id, + ) + if not artifact: + raise HTTPException(status_code=404, detail="Artifact not found") + return artifact + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", + response_model_exclude_none=True, + ) + async def list_artifact_names( + app_name: str, user_id: str, session_id: str + ) -> list[str]: + return await self.artifact_service.list_artifact_keys( + app_name=app_name, 
user_id=user_id, session_id=session_id + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", + response_model_exclude_none=True, + ) + async def list_artifact_versions( + app_name: str, user_id: str, session_id: str, artifact_name: str + ) -> list[int]: + return await self.artifact_service.list_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + ) + + @app.delete( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + ) + async def delete_artifact( + app_name: str, user_id: str, session_id: str, artifact_name: str + ): + await self.artifact_service.delete_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=artifact_name, + ) + + @app.post("/run", response_model_exclude_none=True) + async def agent_run(req: AgentRunRequest) -> list[Event]: + session = await self.session_service.get_session( + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + runner = await self.get_runner_async(req.app_name) + events = [ + event + async for event in runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + ) + ] + logger.info("Generated %s events in agent run: %s", len(events), events) + return events + + @app.post("/run_sse") + async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: + # SSE endpoint + session = await self.session_service.get_session( + app_name=req.app_name, user_id=req.user_id, session_id=req.session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + # Convert the events to properly formatted SSE + async def event_generator(): + try: + stream_mode = ( + StreamingMode.SSE if req.streaming else StreamingMode.NONE + ) + runner = await self.get_runner_async(req.app_name) + async for 
event in runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + ): + # Format as SSE data + sse_event = event.model_dump_json(exclude_none=True, by_alias=True) + logger.info("Generated event in agent run streaming: %s", sse_event) + yield f"data: {sse_event}\n\n" + except Exception as e: + logger.exception("Error in event_generator: %s", e) + # You might want to yield an error event here + yield f'data: {{"error": "{str(e)}"}}\n\n' + + # Returns a streaming response with the proper media type for SSE + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + ) + + @app.get( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph", + response_model_exclude_none=True, + ) + async def get_event_graph( + app_name: str, user_id: str, session_id: str, event_id: str + ): + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + session_events = session.events if session else [] + event = next((x for x in session_events if x.id == event_id), None) + if not event: + return {} + + function_calls = event.get_function_calls() + function_responses = event.get_function_responses() + root_agent = self.agent_loader.load_agent(app_name) + dot_graph = None + if function_calls: + function_call_highlights = [] + for function_call in function_calls: + from_name = event.author + to_name = function_call.name + function_call_highlights.append((from_name, to_name)) + dot_graph = await agent_graph.get_agent_graph( + root_agent, function_call_highlights + ) + elif function_responses: + function_responses_highlights = [] + for function_response in function_responses: + from_name = function_response.name + to_name = event.author + function_responses_highlights.append((from_name, to_name)) + dot_graph = await agent_graph.get_agent_graph( + root_agent, 
function_responses_highlights + ) + else: + from_name = event.author + to_name = "" + dot_graph = await agent_graph.get_agent_graph( + root_agent, [(from_name, to_name)] + ) + if dot_graph and isinstance(dot_graph, graphviz.Digraph): + return GetEventGraphResult(dot_src=dot_graph.source) + else: + return {} + + @app.websocket("/run_live") + async def agent_live_run( + websocket: WebSocket, + app_name: str, + user_id: str, + session_id: str, + modalities: List[Literal["TEXT", "AUDIO"]] = Query( + default=["TEXT", "AUDIO"] + ), # Only allows "TEXT" or "AUDIO" + ) -> None: + await websocket.accept() + + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + # Accept first so that the client is aware of connection establishment, + # then close with a specific code. + await websocket.close(code=1002, reason="Session not found") + return + + live_request_queue = LiveRequestQueue() + + async def forward_events(): + runner = await self.get_runner_async(app_name) + async for event in runner.run_live( + session=session, live_request_queue=live_request_queue + ): + await websocket.send_text( + event.model_dump_json(exclude_none=True, by_alias=True) + ) + + async def process_messages(): + try: + while True: + data = await websocket.receive_text() + # Validate and send the received message to the live queue. + live_request_queue.send(LiveRequest.model_validate_json(data)) + except ValidationError as ve: + logger.error("Validation error in process_messages: %s", ve) + + # Run both tasks concurrently and cancel all if one fails. + tasks = [ + asyncio.create_task(forward_events()), + asyncio.create_task(process_messages()), + ] + done, pending = await asyncio.wait( + tasks, return_when=asyncio.FIRST_EXCEPTION + ) + try: + # This will re-raise any exception from the completed tasks. 
+ for task in done: + task.result() + except WebSocketDisconnect: + logger.info("Client disconnected during process_messages.") + except Exception as e: + logger.exception("Error during live websocket communication: %s", e) + traceback.print_exc() + WEBSOCKET_INTERNAL_ERROR_CODE = 1011 + WEBSOCKET_MAX_BYTES_FOR_REASON = 123 + await websocket.close( + code=WEBSOCKET_INTERNAL_ERROR_CODE, + reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON], + ) + finally: + for task in pending: + task.cancel() + + if web_assets_dir: + import mimetypes + + mimetypes.add_type("application/javascript", ".js", True) + mimetypes.add_type("text/javascript", ".js", True) + + @app.get("/") + async def redirect_root_to_dev_ui(): + return RedirectResponse("/dev-ui/") + + @app.get("/dev-ui") + async def redirect_dev_ui_add_slash(): + return RedirectResponse("/dev-ui/") + + app.mount( + "/dev-ui/", + StaticFiles(directory=web_assets_dir, html=True, follow_symlink=True), + name="static", + ) + + return app diff --git a/src/google/adk/cli/agent_graph.py b/src/google/adk/cli/agent_graph.py index 2df968f81..e919010cc 100644 --- a/src/google/adk/cli/agent_graph.py +++ b/src/google/adk/cli/agent_graph.py @@ -19,11 +19,11 @@ import graphviz -from ..agents import BaseAgent -from ..agents import LoopAgent -from ..agents import ParallelAgent -from ..agents import SequentialAgent +from ..agents.base_agent import BaseAgent from ..agents.llm_agent import LlmAgent +from ..agents.loop_agent import LoopAgent +from ..agents.parallel_agent import ParallelAgent +from ..agents.sequential_agent import SequentialAgent from ..tools.agent_tool import AgentTool from ..tools.base_tool import BaseTool from ..tools.function_tool import FunctionTool diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 79d0bfe65..9314a63db 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -31,7 +31,7 @@ from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session 
import Session from .utils import envs -from .utils.agent_loader import AgentLoader +from .utils.file_system_agent_loader import FileSystemAgentLoader class InputFile(BaseModel): @@ -137,7 +137,7 @@ async def run_cli( session = await session_service.create_session( app_name=agent_folder_name, user_id=user_id ) - root_agent = AgentLoader(agents_dir=agent_parent_dir).load_agent( + root_agent = FileSystemAgentLoader(agents_dir=agent_parent_dir).load_agent( agent_folder_name ) envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index b9630103b..db79734b6 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -25,7 +25,7 @@ from typing import Optional import uuid -from ..agents import Agent +from ..agents.llm_agent import Agent from ..artifacts.base_artifact_service import BaseArtifactService from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from ..evaluation.eval_case import EvalCase diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 937dc6e1f..776ec4d12 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -14,200 +14,47 @@ from __future__ import annotations -import asyncio -from contextlib import asynccontextmanager import json import logging import os from pathlib import Path -import shutil -import time -import traceback -import typing -from typing import Any -from typing import List -from typing import Literal from typing import Optional import click from fastapi import FastAPI -from fastapi import HTTPException -from fastapi import Query -from fastapi import UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse -from fastapi.responses import StreamingResponse -from fastapi.staticfiles import StaticFiles -from fastapi.websockets import WebSocket -from fastapi.websockets import WebSocketDisconnect -from 
google.genai import types -import graphviz -from opentelemetry import trace from opentelemetry.sdk.trace import export -from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace import TracerProvider -from pydantic import Field -from pydantic import ValidationError from starlette.types import Lifespan -from typing_extensions import override -from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer +import yaml -from ..agents import RunConfig -from ..agents.live_request_queue import LiveRequest -from ..agents.live_request_queue import LiveRequestQueue -from ..agents.run_config import StreamingMode from ..artifacts.gcs_artifact_service import GcsArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService -from ..errors.not_found_error import NotFoundError -from ..evaluation.eval_case import EvalCase -from ..evaluation.eval_case import SessionInput -from ..evaluation.eval_metrics import EvalMetric -from ..evaluation.eval_metrics import EvalMetricResult -from ..evaluation.eval_metrics import EvalMetricResultPerInvocation -from ..evaluation.eval_result import EvalSetResult from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager -from ..events.event import Event from ..memory.in_memory_memory_service import InMemoryMemoryService from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService from ..runners import Runner from ..sessions.in_memory_session_service import InMemorySessionService -from ..sessions.session import Session from ..sessions.vertex_ai_session_service import VertexAiSessionService from ..utils.feature_decorator import working_in_progress -from .cli_eval import EVAL_SESSION_ID_PREFIX -from .cli_eval import EvalStatus -from .utils import cleanup +from .adk_web_server 
import AdkWebServer from .utils import common -from .utils import create_empty_state from .utils import envs from .utils import evals -from .utils.agent_loader import AgentLoader +from .utils.agent_change_handler import AgentChangeEventHandler +from .utils.file_system_agent_loader import FileSystemAgentLoader logger = logging.getLogger("google_adk." + __name__) -_EVAL_SET_FILE_EXTENSION = ".evalset.json" -_app_name = "" -_runners_to_clean = set() - -class AgentChangeEventHandler(FileSystemEventHandler): - - def __init__(self, agent_loader: AgentLoader): - self.agent_loader = agent_loader - - def on_modified(self, event): - if not (event.src_path.endswith(".py") or event.src_path.endswith(".yaml")): - return - logger.info("Change detected in agents directory: %s", event.src_path) - self.agent_loader.remove_agent_from_cache(_app_name) - _runners_to_clean.add(_app_name) - - -class ApiServerSpanExporter(export.SpanExporter): - - def __init__(self, trace_dict): - self.trace_dict = trace_dict - - def export( - self, spans: typing.Sequence[ReadableSpan] - ) -> export.SpanExportResult: - for span in spans: - if ( - span.name == "call_llm" - or span.name == "send_data" - or span.name.startswith("execute_tool") - ): - attributes = dict(span.attributes) - attributes["trace_id"] = span.get_span_context().trace_id - attributes["span_id"] = span.get_span_context().span_id - if attributes.get("gcp.vertex.agent.event_id", None): - self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes - return export.SpanExportResult.SUCCESS - - def force_flush(self, timeout_millis: int = 30000) -> bool: - return True - - -class InMemoryExporter(export.SpanExporter): - - def __init__(self, trace_dict): - super().__init__() - self._spans = [] - self.trace_dict = trace_dict - - @override - def export( - self, spans: typing.Sequence[ReadableSpan] - ) -> export.SpanExportResult: - for span in spans: - trace_id = span.context.trace_id - if span.name == "call_llm": - attributes = 
dict(span.attributes) - session_id = attributes.get("gcp.vertex.agent.session_id", None) - if session_id: - if session_id not in self.trace_dict: - self.trace_dict[session_id] = [trace_id] - else: - self.trace_dict[session_id] += [trace_id] - self._spans.extend(spans) - return export.SpanExportResult.SUCCESS - - @override - def force_flush(self, timeout_millis: int = 30000) -> bool: - return True - - def get_finished_spans(self, session_id: str): - trace_ids = self.trace_dict.get(session_id, None) - if trace_ids is None or not trace_ids: - return [] - return [x for x in self._spans if x.context.trace_id in trace_ids] - - def clear(self): - self._spans.clear() - - -class AgentRunRequest(common.BaseModel): - app_name: str - user_id: str - session_id: str - new_message: types.Content - streaming: bool = False - state_delta: Optional[dict[str, Any]] = None - - -class AddSessionToEvalSetRequest(common.BaseModel): - eval_id: str - session_id: str - user_id: str - - -class RunEvalRequest(common.BaseModel): - eval_ids: list[str] # if empty, then all evals in the eval set are run. - eval_metrics: list[EvalMetric] - - -class RunEvalResult(common.BaseModel): - eval_set_file: str - eval_set_id: str - eval_id: str - final_eval_status: EvalStatus - eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field( - deprecated=True, - description=( - "This field is deprecated, use overall_eval_metric_results instead." - ), - ) - overall_eval_metric_results: list[EvalMetricResult] - eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation] - user_id: str - session_id: str - - -class GetEventGraphResult(common.BaseModel): - dot_src: str +class AgentBuildRequest(common.BaseModel): + agent_name: str + agent_type: str + model: str + description: str + instruction: str def get_fast_api_app( @@ -226,66 +73,7 @@ def get_fast_api_app( reload_agents: bool = False, lifespan: Optional[Lifespan[FastAPI]] = None, ) -> FastAPI: - # InMemory tracing dict. 
- trace_dict: dict[str, Any] = {} - session_trace_dict: dict[str, Any] = {} - - # Set up tracing in the FastAPI server. - provider = TracerProvider() - provider.add_span_processor( - export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict)) - ) - memory_exporter = InMemoryExporter(session_trace_dict) - provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter)) - if trace_to_cloud: - from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter - - envs.load_dotenv_for_agent("", agents_dir) - if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): - processor = export.BatchSpanProcessor( - CloudTraceSpanExporter(project_id=project_id) - ) - provider.add_span_processor(processor) - else: - logger.warning( - "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will" - " not be enabled." - ) - - trace.set_tracer_provider(provider) - - @asynccontextmanager - async def internal_lifespan(app: FastAPI): - try: - if lifespan: - async with lifespan(app) as lifespan_context: - yield lifespan_context - else: - yield - finally: - if reload_agents: - observer.stop() - observer.join() - # Create tasks for all runner closures to run concurrently - await cleanup.close_runners(list(runner_dict.values())) - - # Run the FastAPI server. - app = FastAPI(lifespan=internal_lifespan) - - if allow_origins: - app.add_middleware( - CORSMiddleware, - allow_origins=allow_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - runner_dict = {} - # Set up eval managers. - eval_sets_manager = None - eval_set_results_manager = None if eval_storage_uri: gcs_eval_managers = evals.create_gcs_eval_managers_from_uri( eval_storage_uri @@ -385,680 +173,96 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name): credential_service = InMemoryCredentialService() # initialize Agent Loader - agent_loader = AgentLoader(agents_dir) - - # Set up a file system watcher to detect changes in the agents directory. 
- observer = Observer() - if reload_agents: - event_handler = AgentChangeEventHandler(agent_loader) - observer.schedule(event_handler, agents_dir, recursive=True) - observer.start() - - @app.get("/list-apps") - def list_apps() -> list[str]: - base_path = Path.cwd() / agents_dir - if not base_path.exists(): - raise HTTPException(status_code=404, detail="Path not found") - if not base_path.is_dir(): - raise HTTPException(status_code=400, detail="Not a directory") - agent_names = [ - x - for x in os.listdir(base_path) - if os.path.isdir(os.path.join(base_path, x)) - and not x.startswith(".") - and x != "__pycache__" - ] - agent_names.sort() - return agent_names - - @app.get("/debug/trace/{event_id}") - def get_trace_dict(event_id: str) -> Any: - event_dict = trace_dict.get(event_id, None) - if event_dict is None: - raise HTTPException(status_code=404, detail="Trace not found") - return event_dict - - @app.get("/debug/trace/session/{session_id}") - def get_session_trace(session_id: str) -> Any: - spans = memory_exporter.get_finished_spans(session_id) - if not spans: - return [] - return [ - { - "name": s.name, - "span_id": s.context.span_id, - "trace_id": s.context.trace_id, - "start_time": s.start_time, - "end_time": s.end_time, - "attributes": dict(s.attributes), - "parent_span_id": s.parent.span_id if s.parent else None, - } - for s in spans - ] - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}", - response_model_exclude_none=True, - ) - async def get_session( - app_name: str, user_id: str, session_id: str - ) -> Session: - session = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - global _app_name - _app_name = app_name - return session - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions", - response_model_exclude_none=True, - ) - async def list_sessions(app_name: str, user_id: str) -> list[Session]: 
- list_sessions_response = await session_service.list_sessions( - app_name=app_name, user_id=user_id - ) - return [ - session - for session in list_sessions_response.sessions - # Remove sessions that were generated as a part of Eval. - if not session.id.startswith(EVAL_SESSION_ID_PREFIX) - ] - - @app.post( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}", - response_model_exclude_none=True, - ) - async def create_session_with_id( - app_name: str, - user_id: str, - session_id: str, - state: Optional[dict[str, Any]] = None, - ) -> Session: - if ( - await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - is not None - ): - logger.warning("Session already exists: %s", session_id) - raise HTTPException( - status_code=400, detail=f"Session already exists: {session_id}" - ) - logger.info("New session created: %s", session_id) - return await session_service.create_session( - app_name=app_name, user_id=user_id, state=state, session_id=session_id - ) - - @app.post( - "/apps/{app_name}/users/{user_id}/sessions", - response_model_exclude_none=True, - ) - async def create_session( - app_name: str, - user_id: str, - state: Optional[dict[str, Any]] = None, - events: Optional[list[Event]] = None, - ) -> Session: - logger.info("New session created") - session = await session_service.create_session( - app_name=app_name, user_id=user_id, state=state - ) - - if events: - for event in events: - await session_service.append_event(session=session, event=event) - - return session - - def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str: - return os.path.join( - agents_dir, - app_name, - eval_set_id + _EVAL_SET_FILE_EXTENSION, - ) - - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}", - response_model_exclude_none=True, - ) - def create_eval_set( - app_name: str, - eval_set_id: str, - ): - """Creates an eval set, given the id.""" - try: - eval_sets_manager.create_eval_set(app_name, eval_set_id) - except 
ValueError as ve: - raise HTTPException( - status_code=400, - detail=str(ve), - ) from ve - - @app.get( - "/apps/{app_name}/eval_sets", - response_model_exclude_none=True, - ) - def list_eval_sets(app_name: str) -> list[str]: - """Lists all eval sets for the given app.""" - try: - return eval_sets_manager.list_eval_sets(app_name) - except NotFoundError as e: - logger.warning(e) - return [] - - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}/add_session", - response_model_exclude_none=True, - ) - async def add_session_to_eval_set( - app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest - ): - # Get the session - session = await session_service.get_session( - app_name=app_name, user_id=req.user_id, session_id=req.session_id - ) - assert session, "Session not found." - - # Convert the session data to eval invocations - invocations = evals.convert_session_to_eval_invocations(session) - - # Populate the session with initial session state. - initial_session_state = create_empty_state( - agent_loader.load_agent(app_name) - ) - - new_eval_case = EvalCase( - eval_id=req.eval_id, - conversation=invocations, - session_input=SessionInput( - app_name=app_name, user_id=req.user_id, state=initial_session_state - ), - creation_timestamp=time.time(), - ) - - try: - eval_sets_manager.add_eval_case(app_name, eval_set_id, new_eval_case) - except ValueError as ve: - raise HTTPException(status_code=400, detail=str(ve)) from ve - - @app.get( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals", - response_model_exclude_none=True, - ) - def list_evals_in_eval_set( - app_name: str, - eval_set_id: str, - ) -> list[str]: - """Lists all evals in an eval set.""" - eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id) - - if not eval_set_data: - raise HTTPException( - status_code=400, detail=f"Eval set `{eval_set_id}` not found." 
- ) - - return sorted([x.eval_id for x in eval_set_data.eval_cases]) - - @app.get( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", - response_model_exclude_none=True, - ) - def get_eval(app_name: str, eval_set_id: str, eval_case_id: str) -> EvalCase: - """Gets an eval case in an eval set.""" - eval_case_to_find = eval_sets_manager.get_eval_case( - app_name, eval_set_id, eval_case_id - ) - - if eval_case_to_find: - return eval_case_to_find - - raise HTTPException( - status_code=404, - detail=f"Eval set `{eval_set_id}` or Eval `{eval_case_id}` not found.", - ) - - @app.put( - "/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}", - response_model_exclude_none=True, - ) - def update_eval( - app_name: str, - eval_set_id: str, - eval_case_id: str, - updated_eval_case: EvalCase, - ): - if updated_eval_case.eval_id and updated_eval_case.eval_id != eval_case_id: - raise HTTPException( - status_code=400, - detail=( - "Eval id in EvalCase should match the eval id in the API route." - ), - ) - - # Overwrite the value. We are either overwriting the same value or an empty - # field. 
- updated_eval_case.eval_id = eval_case_id - try: - eval_sets_manager.update_eval_case( - app_name, eval_set_id, updated_eval_case - ) - except NotFoundError as nfe: - raise HTTPException(status_code=404, detail=str(nfe)) from nfe - - @app.delete("/apps/{app_name}/eval_sets/{eval_set_id}/evals/{eval_case_id}") - def delete_eval(app_name: str, eval_set_id: str, eval_case_id: str): - try: - eval_sets_manager.delete_eval_case(app_name, eval_set_id, eval_case_id) - except NotFoundError as nfe: - raise HTTPException(status_code=404, detail=str(nfe)) from nfe - - @app.post( - "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval", - response_model_exclude_none=True, + agent_loader = FileSystemAgentLoader(agents_dir) + + adk_web_server = AdkWebServer( + agent_loader=agent_loader, + session_service=session_service, + artifact_service=artifact_service, + memory_service=memory_service, + credential_service=credential_service, + eval_sets_manager=eval_sets_manager, + eval_set_results_manager=eval_set_results_manager, ) - async def run_eval( - app_name: str, eval_set_id: str, req: RunEvalRequest - ) -> list[RunEvalResult]: - """Runs an eval given the details in the eval request.""" - from .cli_eval import run_evals - # Create a mapping from eval set file to all the evals that needed to be - # run. - eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id) + # Callbacks & other optional args for when constructing the FastAPI instance + extra_fast_api_args = {} - if not eval_set: - raise HTTPException( - status_code=400, detail=f"Eval set `{eval_set_id}` not found." - ) - - if req.eval_ids: - eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids] - eval_set_to_evals = {eval_set_id: eval_cases} - else: - logger.info("Eval ids to run list is empty. 
We will run all eval cases.") - eval_set_to_evals = {eval_set_id: eval_set.eval_cases} + if trace_to_cloud: + from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter - root_agent = agent_loader.load_agent(app_name) - run_eval_results = [] - eval_case_results = [] - try: - async for eval_case_result in run_evals( - eval_set_to_evals, - root_agent, - getattr(root_agent, "reset_data", None), - req.eval_metrics, - session_service=session_service, - artifact_service=artifact_service, - ): - run_eval_results.append( - RunEvalResult( - app_name=app_name, - eval_set_file=eval_case_result.eval_set_file, - eval_set_id=eval_set_id, - eval_id=eval_case_result.eval_id, - final_eval_status=eval_case_result.final_eval_status, - eval_metric_results=eval_case_result.eval_metric_results, - overall_eval_metric_results=eval_case_result.overall_eval_metric_results, - eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation, - user_id=eval_case_result.user_id, - session_id=eval_case_result.session_id, - ) + def register_processors(provider: TracerProvider) -> None: + envs.load_dotenv_for_agent("", agents_dir) + if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None): + processor = export.BatchSpanProcessor( + CloudTraceSpanExporter(project_id=project_id) ) - eval_case_result.session_details = await session_service.get_session( - app_name=app_name, - user_id=eval_case_result.user_id, - session_id=eval_case_result.session_id, + provider.add_span_processor(processor) + else: + logger.warning( + "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will" + " not be enabled." 
) - eval_case_results.append(eval_case_result) - except ModuleNotFoundError as e: - logger.exception("%s", e) - raise HTTPException(status_code=400, detail=str(e)) from e - eval_set_results_manager.save_eval_set_result( - app_name, eval_set_id, eval_case_results + extra_fast_api_args.update( + register_processors=register_processors, ) - return run_eval_results + if reload_agents: - @app.get( - "/apps/{app_name}/eval_results/{eval_result_id}", - response_model_exclude_none=True, - ) - def get_eval_result( - app_name: str, - eval_result_id: str, - ) -> EvalSetResult: - """Gets the eval result for the given eval id.""" - try: - return eval_set_results_manager.get_eval_set_result( - app_name, eval_result_id + def setup_observer(observer: Observer, adk_web_server: AdkWebServer): + agent_change_handler = AgentChangeEventHandler( + agent_loader=agent_loader, + runners_to_clean=adk_web_server.runners_to_clean, + current_app_name_ref=adk_web_server.current_app_name_ref, ) - except ValueError as ve: - raise HTTPException(status_code=404, detail=str(ve)) from ve - except ValidationError as ve: - raise HTTPException(status_code=500, detail=str(ve)) from ve - - @app.get( - "/apps/{app_name}/eval_results", - response_model_exclude_none=True, - ) - def list_eval_results(app_name: str) -> list[str]: - """Lists all eval results for the given app.""" - return eval_set_results_manager.list_eval_set_results(app_name) - - @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}") - async def delete_session(app_name: str, user_id: str, session_id: str): - await session_service.delete_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", - response_model_exclude_none=True, - ) - async def load_artifact( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - version: Optional[int] = Query(None), - ) -> Optional[types.Part]: - artifact = 
await artifact_service.load_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - version=version, - ) - if not artifact: - raise HTTPException(status_code=404, detail="Artifact not found") - return artifact + observer.schedule(agent_change_handler, agents_dir, recursive=True) + observer.start() - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}", - response_model_exclude_none=True, - ) - async def load_artifact_version( - app_name: str, - user_id: str, - session_id: str, - artifact_name: str, - version_id: int, - ) -> Optional[types.Part]: - artifact = await artifact_service.load_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - version=version_id, - ) - if not artifact: - raise HTTPException(status_code=404, detail="Artifact not found") - return artifact + def tear_down_observer(observer: Observer, _: AdkWebServer): + observer.stop() + observer.join() - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", - response_model_exclude_none=True, - ) - async def list_artifact_names( - app_name: str, user_id: str, session_id: str - ) -> list[str]: - return await artifact_service.list_artifact_keys( - app_name=app_name, user_id=user_id, session_id=session_id + extra_fast_api_args.update( + setup_observer=setup_observer, + tear_down_observer=tear_down_observer, ) - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions", - response_model_exclude_none=True, - ) - async def list_artifact_versions( - app_name: str, user_id: str, session_id: str, artifact_name: str - ) -> list[int]: - return await artifact_service.list_versions( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, + if web: + BASE_DIR = Path(__file__).parent.resolve() + ANGULAR_DIST_PATH = BASE_DIR / "browser" + 
extra_fast_api_args.update( + web_assets_dir=ANGULAR_DIST_PATH, ) - @app.delete( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", + app = adk_web_server.get_fast_api_app( + lifespan=lifespan, + allow_origins=allow_origins, + **extra_fast_api_args, ) - async def delete_artifact( - app_name: str, user_id: str, session_id: str, artifact_name: str - ): - await artifact_service.delete_artifact( - app_name=app_name, - user_id=user_id, - session_id=session_id, - filename=artifact_name, - ) @working_in_progress("builder_save is not ready for use.") @app.post("/builder/save", response_model_exclude_none=True) - async def builder_build(files: list[UploadFile]) -> bool: + async def builder_build(req: AgentBuildRequest): base_path = Path.cwd() / agents_dir - - for file in files: - try: - # File name format: {app_name}/{agent_name}.yaml - if not file.filename: - logger.exception("Agent name is missing in the input files") - return False - - agent_name, filename = file.filename.split("/") - - agent_dir = os.path.join(base_path, agent_name) - os.makedirs(agent_dir, exist_ok=True) - file_path = os.path.join(agent_dir, filename) - - with open(file_path, "w") as buffer: - shutil.copyfileobj(file.file, buffer) - - except Exception as e: - logger.exception("Error in builder_build: %s", e) - return False - - return True - - @app.post("/run", response_model_exclude_none=True) - async def agent_run(req: AgentRunRequest) -> list[Event]: - session = await session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - runner = await _get_runner_async(req.app_name) - events = [ - event - async for event in runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - ) - ] - logger.info("Generated %s events in agent run: %s", len(events), events) - return events - - @app.post("/run_sse") 
- async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse: - # SSE endpoint - session = await session_service.get_session( - app_name=req.app_name, user_id=req.user_id, session_id=req.session_id - ) - if not session: - raise HTTPException(status_code=404, detail="Session not found") - - # Convert the events to properly formatted SSE - async def event_generator(): - try: - stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE - runner = await _get_runner_async(req.app_name) - async for event in runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), - ): - # Format as SSE data - sse_event = event.model_dump_json(exclude_none=True, by_alias=True) - logger.info("Generated event in agent run streaming: %s", sse_event) - yield f"data: {sse_event}\n\n" - except Exception as e: - logger.exception("Error in event_generator: %s", e) - # You might want to yield an error event here - yield f'data: {{"error": "{str(e)}"}}\n\n' - - # Returns a streaming response with the proper media type for SSE - return StreamingResponse( - event_generator(), - media_type="text/event-stream", - ) - - @app.get( - "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph", - response_model_exclude_none=True, - ) - async def get_event_graph( - app_name: str, user_id: str, session_id: str, event_id: str - ): - session = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - session_events = session.events if session else [] - event = next((x for x in session_events if x.id == event_id), None) - if not event: - return {} - - from . 
import agent_graph - - function_calls = event.get_function_calls() - function_responses = event.get_function_responses() - root_agent = agent_loader.load_agent(app_name) - dot_graph = None - if function_calls: - function_call_highlights = [] - for function_call in function_calls: - from_name = event.author - to_name = function_call.name - function_call_highlights.append((from_name, to_name)) - dot_graph = await agent_graph.get_agent_graph( - root_agent, function_call_highlights - ) - elif function_responses: - function_responses_highlights = [] - for function_response in function_responses: - from_name = function_response.name - to_name = event.author - function_responses_highlights.append((from_name, to_name)) - dot_graph = await agent_graph.get_agent_graph( - root_agent, function_responses_highlights - ) - else: - from_name = event.author - to_name = "" - dot_graph = await agent_graph.get_agent_graph( - root_agent, [(from_name, to_name)] - ) - if dot_graph and isinstance(dot_graph, graphviz.Digraph): - return GetEventGraphResult(dot_src=dot_graph.source) - else: - return {} - - @app.websocket("/run_live") - async def agent_live_run( - websocket: WebSocket, - app_name: str, - user_id: str, - session_id: str, - modalities: List[Literal["TEXT", "AUDIO"]] = Query( - default=["TEXT", "AUDIO"] - ), # Only allows "TEXT" or "AUDIO" - ) -> None: - await websocket.accept() - - session = await session_service.get_session( - app_name=app_name, user_id=user_id, session_id=session_id - ) - if not session: - # Accept first so that the client is aware of connection establishment, - # then close with a specific code. 
- await websocket.close(code=1002, reason="Session not found") - return - - live_request_queue = LiveRequestQueue() - - async def forward_events(): - runner = await _get_runner_async(app_name) - async for event in runner.run_live( - session=session, live_request_queue=live_request_queue - ): - await websocket.send_text( - event.model_dump_json(exclude_none=True, by_alias=True) - ) - - async def process_messages(): - try: - while True: - data = await websocket.receive_text() - # Validate and send the received message to the live queue. - live_request_queue.send(LiveRequest.model_validate_json(data)) - except ValidationError as ve: - logger.error("Validation error in process_messages: %s", ve) - - # Run both tasks concurrently and cancel all if one fails. - tasks = [ - asyncio.create_task(forward_events()), - asyncio.create_task(process_messages()), - ] - done, pending = await asyncio.wait( - tasks, return_when=asyncio.FIRST_EXCEPTION - ) + agent = { + "agent_class": req.agent_type, + "name": req.agent_name, + "model": req.model, + "description": req.description, + "instruction": f"""{req.instruction}""", + } try: - # This will re-raise any exception from the completed tasks. 
- for task in done: - task.result() - except WebSocketDisconnect: - logger.info("Client disconnected during process_messages.") + agent_dir = os.path.join(base_path, req.agent_name) + os.makedirs(agent_dir, exist_ok=True) + file_path = os.path.join(agent_dir, "root_agent.yaml") + with open(file_path, "w") as file: + yaml.dump(agent, file, default_flow_style=False) + adk_web_server.agent_loader.load_agent(agent_name=req.agent_name) + return True except Exception as e: - logger.exception("Error during live websocket communication: %s", e) - traceback.print_exc() - WEBSOCKET_INTERNAL_ERROR_CODE = 1011 - WEBSOCKET_MAX_BYTES_FOR_REASON = 123 - await websocket.close( - code=WEBSOCKET_INTERNAL_ERROR_CODE, - reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON], - ) - finally: - for task in pending: - task.cancel() - - async def _get_runner_async(app_name: str) -> Runner: - """Returns the runner for the given app.""" - if app_name in _runners_to_clean: - _runners_to_clean.remove(app_name) - runner = runner_dict.pop(app_name, None) - await cleanup.close_runners(list([runner])) - - envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir) - if app_name in runner_dict: - return runner_dict[app_name] - root_agent = agent_loader.load_agent(app_name) - runner = Runner( - app_name=app_name, - agent=root_agent, - artifact_service=artifact_service, - session_service=session_service, - memory_service=memory_service, - credential_service=credential_service, - ) - runner_dict[app_name] = runner - return runner + logger.exception("Error in builder_build: %s", e) + return False if a2a: try: @@ -1090,7 +294,7 @@ def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" async def _get_a2a_runner_async() -> Runner: - return await _get_runner_async(captured_app_name) + return await adk_web_server.get_runner_async(captured_app_name) return _get_a2a_runner_async @@ -1141,28 +345,5 @@ async def _get_a2a_runner_async() -> Runner: 
except Exception as e: logger.error("Failed to setup A2A agent %s: %s", app_name, e) # Continue with other agents even if one fails - if web: - import mimetypes - - mimetypes.add_type("application/javascript", ".js", True) - mimetypes.add_type("text/javascript", ".js", True) - BASE_DIR = Path(__file__).parent.resolve() - ANGULAR_DIST_PATH = BASE_DIR / "browser" - - @app.get("/") - async def redirect_root_to_dev_ui(): - return RedirectResponse("/dev-ui/") - - @app.get("/dev-ui") - async def redirect_dev_ui_add_slash(): - return RedirectResponse("/dev-ui/") - - app.mount( - "/dev-ui/", - StaticFiles( - directory=ANGULAR_DIST_PATH, html=True, follow_symlink=True - ), - name="static", - ) return app diff --git a/src/google/adk/cli/utils/agent_change_handler.py b/src/google/adk/cli/utils/agent_change_handler.py new file mode 100644 index 000000000..88e46604d --- /dev/null +++ b/src/google/adk/cli/utils/agent_change_handler.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File system event handler for agent changes to trigger hot reload for agents.""" + +from __future__ import annotations + +import logging + +from watchdog.events import FileSystemEventHandler + +from .file_system_agent_loader import FileSystemAgentLoader +from .shared_value import SharedValue + +logger = logging.getLogger("google_adk." 
+ __name__) + + +class AgentChangeEventHandler(FileSystemEventHandler): + + def __init__( + self, + agent_loader: FileSystemAgentLoader, + runners_to_clean: set[str], + current_app_name_ref: SharedValue[str], + ): + self.agent_loader = agent_loader + self.runners_to_clean = runners_to_clean + self.current_app_name_ref = current_app_name_ref + + def on_modified(self, event): + if not (event.src_path.endswith(".py") or event.src_path.endswith(".yaml")): + return + logger.info("Change detected in agents directory: %s", event.src_path) + self.agent_loader.remove_agent_from_cache(self.current_app_name_ref.value) + self.runners_to_clean.add(self.current_app_name_ref.value) diff --git a/src/google/adk/cli/utils/base_agent_loader.py b/src/google/adk/cli/utils/base_agent_loader.py new file mode 100644 index 000000000..d41a50197 --- /dev/null +++ b/src/google/adk/cli/utils/base_agent_loader.py @@ -0,0 +1,36 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Base class for agent loaders.""" + +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + +from ...agents.base_agent import BaseAgent + + +class BaseAgentLoader(ABC): + """Abstract base class for agent loaders.""" + + @abstractmethod + def load_agent(self, agent_name: str) -> BaseAgent: + """Loads an instance of an agent with the given name.""" + raise NotImplementedError + + @abstractmethod + def list_agents(self) -> list[str]: + """Lists all agents available in the agent loader in alphabetical order.""" + raise NotImplementedError diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/file_system_agent_loader.py similarity index 91% rename from src/google/adk/cli/utils/agent_loader.py rename to src/google/adk/cli/utils/file_system_agent_loader.py index 1e2068463..69783f1f8 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/file_system_agent_loader.py @@ -17,22 +17,26 @@ import importlib import logging import os +from pathlib import Path import sys from typing import Optional +from google.adk.cli.utils import envs from pydantic import ValidationError +from typing_extensions import override -from . import envs from ...agents import config_agent_utils from ...agents.base_agent import BaseAgent from ...utils.feature_decorator import working_in_progress +from .base_agent_loader import BaseAgentLoader logger = logging.getLogger("google_adk." + __name__) -class AgentLoader: +class FileSystemAgentLoader(BaseAgentLoader): """Centralized agent loading with proper isolation, caching, and .env loading. 
- Support loading agents from below folder/file structures: + + Supports loading agents from below folder/file structures: a) {agent_name}.agent as a module name: agents_dir/{agent_name}/agent.py (with root_agent defined in the module) b) {agent_name} as a module name @@ -41,7 +45,6 @@ class AgentLoader: agents_dir/{agent_name}/__init__.py (with root_agent in the package) d) {agent_name} as a YAML config folder: agents_dir/{agent_name}/root_agent.yaml defines the root agent - """ def __init__(self, agents_dir: str): @@ -188,6 +191,7 @@ def _perform_load(self, agent_name: str) -> BaseAgent: " exposed." ) + @override def load_agent(self, agent_name: str) -> BaseAgent: """Load an agent module (with caching & .env) and return its root_agent.""" if agent_name in self._agent_cache: @@ -199,6 +203,20 @@ def load_agent(self, agent_name: str) -> BaseAgent: self._agent_cache[agent_name] = agent return agent + @override + def list_agents(self) -> list[str]: + """Lists all agents available in the agent loader (sorted alphabetically).""" + base_path = Path.cwd() / self.agents_dir + agent_names = [ + x + for x in os.listdir(base_path) + if os.path.isdir(os.path.join(base_path, x)) + and not x.startswith(".") + and x != "__pycache__" + ] + agent_names.sort() + return agent_names + def remove_agent_from_cache(self, agent_name: str): # Clear module cache for the agent and its submodules keys_to_delete = [ diff --git a/src/google/adk/cli/utils/shared_value.py b/src/google/adk/cli/utils/shared_value.py new file mode 100644 index 000000000..e9202df92 --- /dev/null +++ b/src/google/adk/cli/utils/shared_value.py @@ -0,0 +1,30 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Generic +from typing import TypeVar + +import pydantic + +T = TypeVar("T") + + +class SharedValue(pydantic.BaseModel, Generic[T]): + """Simple wrapper around a value to allow modifying it from callbacks.""" + + model_config = pydantic.ConfigDict( + arbitrary_types_allowed=True, + ) + value: T diff --git a/src/google/adk/cli/utils/__init__.py b/src/google/adk/cli/utils/state.py similarity index 97% rename from src/google/adk/cli/utils/__init__.py rename to src/google/adk/cli/utils/state.py index 846c15635..29d0b1f24 100644 --- a/src/google/adk/cli/utils/__init__.py +++ b/src/google/adk/cli/utils/state.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import re from typing import Any from typing import Optional @@ -19,10 +21,6 @@ from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent -__all__ = [ - 'create_empty_state', -] - def _create_empty_state(agent: BaseAgent, all_state: dict[str, Any]): for sub_agent in agent.sub_agents: diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d4f9382e3..e144d0952 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -188,6 +188,9 @@ def __init__(self, agents_dir: str): def load_agent(self, app_name): return root_agent + def list_agents(self): + return ["test_app"] + return MockAgentLoader(".") @@ -445,7 +448,7 @@ def test_app( return_value=mock_memory_service, ), patch( - "google.adk.cli.fast_api.AgentLoader", + "google.adk.cli.fast_api.FileSystemAgentLoader", return_value=mock_agent_loader, ), patch( @@ -596,7 +599,7 @@ def test_app_with_a2a( return_value=mock_memory_service, ), patch( - "google.adk.cli.fast_api.AgentLoader", + "google.adk.cli.fast_api.FileSystemAgentLoader", return_value=mock_agent_loader, ), patch( diff --git a/tests/unittests/cli/utils/test_agent_loader.py b/tests/unittests/cli/utils/test_file_system_agent_loader.py similarity index 94% rename from tests/unittests/cli/utils/test_agent_loader.py rename to tests/unittests/cli/utils/test_file_system_agent_loader.py index 2b68f3cc3..079d3c6ab 100644 --- a/tests/unittests/cli/utils/test_agent_loader.py +++ b/tests/unittests/cli/utils/test_file_system_agent_loader.py @@ -18,13 +18,13 @@ import tempfile from textwrap import dedent -from google.adk.cli.utils.agent_loader import AgentLoader +from google.adk.cli.utils.file_system_agent_loader import FileSystemAgentLoader from pydantic import ValidationError import pytest -class TestAgentLoader: - """Unit tests for AgentLoader focusing on interface behavior.""" +class TestFileSystemAgentLoader: + """Unit tests for 
FileSystemAgentLoader focusing on interface behavior.""" @pytest.fixture(autouse=True) def cleanup_sys_path(self): @@ -47,7 +47,8 @@ def create_agent_structure( Args: temp_dir: The temporary directory to create the agent in agent_name: Name of the agent - structure_type: One of 'module', 'package_with_root', 'package_with_agent_module' + structure_type: One of 'module', 'package_with_root', + 'package_with_agent_module' """ if structure_type == "module": # Structure: agents_dir/agent_name.py @@ -140,7 +141,7 @@ def test_load_agent_as_module(self): self.create_agent_structure(temp_path, "module_agent", "module") # Load the agent - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent = loader.load_agent("module_agent") # Assert agent was loaded correctly @@ -159,7 +160,7 @@ def test_load_agent_as_package_with_root_agent(self): ) # Load the agent - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent = loader.load_agent("package_agent") # Assert agent was loaded correctly @@ -177,7 +178,7 @@ def test_load_agent_as_package_with_agent_module(self): ) # Load the agent - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent = loader.load_agent("modular_agent") # Assert agent was loaded correctly @@ -193,7 +194,7 @@ def test_agent_caching_returns_same_instance(self): self.create_agent_structure(temp_path, "cached_agent", "module") # Load the agent twice - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent1 = loader.load_agent("cached_agent") agent2 = loader.load_agent("cached_agent") @@ -215,7 +216,7 @@ def test_env_loading_for_agent(self): ) # Load the agent - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent = loader.load_agent("env_agent") # Assert environment variables were loaded @@ -250,7 +251,7 @@ def __init__(self): root_agent = SubmoduleAgent() """)) - loader = 
AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent = loader.load_agent(agent_name) # Assert that the module version was loaded due to the new loading order @@ -269,7 +270,7 @@ def test_load_multiple_different_agents(self): ) # Load all agents - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent1 = loader.load_agent("agent_one") agent2 = loader.load_agent("agent_two") agent3 = loader.load_agent("agent_three") @@ -285,7 +286,7 @@ def test_load_multiple_different_agents(self): def test_agent_not_found_error(self): """Test that appropriate error is raised when agent is not found.""" with tempfile.TemporaryDirectory() as temp_dir: - loader = AgentLoader(temp_dir) + loader = FileSystemAgentLoader(temp_dir) agents_dir = temp_dir # For use in the expected message string # Try to load non-existent agent @@ -321,7 +322,7 @@ def __init__(self): # Note: No root_agent defined """)) - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) # Try to load agent without root_agent with pytest.raises(ValueError) as exc_info: @@ -348,7 +349,7 @@ def __init__(self): root_agent = {agent_name.title()}Agent() """)) - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) with pytest.raises(ModuleNotFoundError) as exc_info: loader.load_agent(agent_name) @@ -376,7 +377,7 @@ def __init__(self): root_agent = {agent_name.title()}Agent() """)) - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) # SyntaxError is a subclass of Exception, and importlib might wrap it # The loader is expected to prepend its message and re-raise. 
with pytest.raises( @@ -411,7 +412,7 @@ def __init__(self): root_agent = {agent_name.title()}Agent() """)) - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) # SyntaxError is a subclass of Exception, and importlib might wrap it # The loader is expected to prepend its message and re-raise. with pytest.raises( @@ -436,7 +437,7 @@ def test_sys_path_modification(self): # Check sys.path before assert str(temp_path) not in sys.path - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) # Path should not be added yet - only added during load assert str(temp_path) not in sys.path @@ -483,7 +484,7 @@ def test_load_agent_from_yaml_config(self): self.create_yaml_agent_structure(temp_path, agent_name, yaml_content) # Load the agent - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent = loader.load_agent(agent_name) # Assert agent was loaded correctly @@ -514,7 +515,7 @@ def test_yaml_agent_caching_returns_same_instance(self): self.create_yaml_agent_structure(temp_path, agent_name, yaml_content) # Load the agent twice - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) agent1 = loader.load_agent(agent_name) agent2 = loader.load_agent(agent_name) @@ -525,7 +526,7 @@ def test_yaml_agent_caching_returns_same_instance(self): def test_yaml_agent_not_found_error(self): """Test that appropriate error is raised when YAML agent is not found.""" with tempfile.TemporaryDirectory() as temp_dir: - loader = AgentLoader(temp_dir) + loader = FileSystemAgentLoader(temp_dir) agents_dir = temp_dir # For use in the expected message string # Try to load non-existent YAML agent @@ -565,7 +566,7 @@ def test_yaml_agent_invalid_yaml_error(self): temp_path, agent_name, invalid_yaml_content ) - loader = AgentLoader(str(temp_path)) + loader = FileSystemAgentLoader(str(temp_path)) # Try to load agent with invalid YAML with pytest.raises(ValidationError) as 
exc_info: