From d5ee450af577b4a15b1f81a6995982efcc436327 Mon Sep 17 00:00:00 2001 From: Nick Bobrowski <39348559+bonk1t@users.noreply.github.com> Date: Sun, 14 Dec 2025 03:07:28 +0000 Subject: [PATCH] feat: add realtime voice agents - add realtime session bridge with per-agent voice support - mount /{agency}/realtime websocket routes via run_fastapi - ship packaged browser + Twilio realtime demos - document voice agents and FastAPI integration options - add regression tests for voice config and realtime endpoints --- .../fastapi-integration.mdx | 9 + .../voice-agents/deployment.mdx | 91 +++ .../voice-agents/overview.mdx | 84 +++ docs/docs.json | 8 + docs/references/api.mdx | 11 + examples/README.md | 1 + examples/interactive/realtime/__init__.py | 1 + examples/interactive/realtime/demo.py | 51 ++ src/agency_swarm/__init__.py | 2 + src/agency_swarm/agency/core.py | 68 +- src/agency_swarm/agency/helpers.py | 4 + src/agency_swarm/agent/constants.py | 28 + src/agency_swarm/agent/core.py | 14 +- src/agency_swarm/integrations/__init__.py | 7 + src/agency_swarm/integrations/fastapi.py | 125 ++++ src/agency_swarm/integrations/realtime.py | 532 ++++++++++++++ src/agency_swarm/realtime/__init__.py | 4 + src/agency_swarm/realtime/agency.py | 155 ++++ src/agency_swarm/realtime/agent.py | 52 ++ .../ui/demos/realtime/__init__.py | 85 +++ .../ui/demos/realtime/app/__init__.py | 1 + .../ui/demos/realtime/app/server.py | 370 ++++++++++ .../ui/demos/realtime/app/static/app.js | 682 ++++++++++++++++++ .../app/static/audio-playback.worklet.js | 120 +++ .../app/static/audio-recorder.worklet.js | 56 ++ .../ui/demos/realtime/app/static/favicon.ico | 0 .../ui/demos/realtime/app/static/index.html | 299 ++++++++ .../ui/demos/realtime/twilio/README.md | 86 +++ .../ui/demos/realtime/twilio/__init__.py | 0 .../ui/demos/realtime/twilio/requirements.txt | 5 + .../ui/demos/realtime/twilio/server.py | 80 ++ .../demos/realtime/twilio/twilio_handler.py | 257 +++++++ .../fastapi/test_fastapi_metadata.py | 18 + 
tests/integration/realtime/__init__.py | 3 + tests/integration/realtime/test_realtime.py | 232 ++++++ .../test_agency_initialization.py | 44 ++ .../test_agent_initialization.py | 11 + tests/test_integrations_modules/__init__.py | 0 38 files changed, 3593 insertions(+), 3 deletions(-) create mode 100644 docs/additional-features/voice-agents/deployment.mdx create mode 100644 docs/additional-features/voice-agents/overview.mdx create mode 100644 examples/interactive/realtime/__init__.py create mode 100644 examples/interactive/realtime/demo.py create mode 100644 src/agency_swarm/integrations/__init__.py create mode 100644 src/agency_swarm/integrations/realtime.py create mode 100644 src/agency_swarm/realtime/__init__.py create mode 100644 src/agency_swarm/realtime/agency.py create mode 100644 src/agency_swarm/realtime/agent.py create mode 100644 src/agency_swarm/ui/demos/realtime/__init__.py create mode 100644 src/agency_swarm/ui/demos/realtime/app/__init__.py create mode 100644 src/agency_swarm/ui/demos/realtime/app/server.py create mode 100644 src/agency_swarm/ui/demos/realtime/app/static/app.js create mode 100644 src/agency_swarm/ui/demos/realtime/app/static/audio-playback.worklet.js create mode 100644 src/agency_swarm/ui/demos/realtime/app/static/audio-recorder.worklet.js create mode 100644 src/agency_swarm/ui/demos/realtime/app/static/favicon.ico create mode 100644 src/agency_swarm/ui/demos/realtime/app/static/index.html create mode 100644 src/agency_swarm/ui/demos/realtime/twilio/README.md create mode 100644 src/agency_swarm/ui/demos/realtime/twilio/__init__.py create mode 100644 src/agency_swarm/ui/demos/realtime/twilio/requirements.txt create mode 100644 src/agency_swarm/ui/demos/realtime/twilio/server.py create mode 100644 src/agency_swarm/ui/demos/realtime/twilio/twilio_handler.py create mode 100644 tests/integration/realtime/__init__.py create mode 100644 tests/integration/realtime/test_realtime.py create mode 100644 tests/test_integrations_modules/__init__.py 
diff --git a/docs/additional-features/fastapi-integration.mdx b/docs/additional-features/fastapi-integration.mdx index 8fae8487..48133dba 100644 --- a/docs/additional-features/fastapi-integration.mdx +++ b/docs/additional-features/fastapi-integration.mdx @@ -42,6 +42,8 @@ Optionally, you can specify following parameters: - enable_agui (default: `False`) - Enable AG-UI protocol compatibility for streaming endpoints - enable_logging (default: `False`) - Enable request tracking and expose `/get_logs` endpoint - logs_dir (default: `"activity-logs"`) - Directory for log files when logging is enabled +- enable_realtime (default: `False`) - Mount `/your_agency/realtime` websocket routes in addition to the REST endpoints +- realtime_options (default: `{}`) - Optional overrides for the realtime bridge (`model`, `turn_detection`, `input_audio_format`, etc.) This will create 4 endpoints for the agency: - `/test_agency/get_response` @@ -112,6 +114,8 @@ run_fastapi( "test_agency_2": create_agency_2, }, tools=[example_tool, test_tool], + enable_realtime=True, + realtime_options={"model": "gpt-realtime-mini"}, ) ``` @@ -126,6 +130,7 @@ This will create the following endpoints: - `/test_agency_2/get_metadata` - `/tool/ExampleTool` (for BaseTools) or `/tool/example_tool` (for function tools) - `/tool/TestTool` (for BaseTools) or `/tool/test_tool` (for function tools) +- `/test_agency_1/realtime` and `/test_agency_2/realtime` websocket routes when `enable_realtime=True` If `enable_logging=True`, a `/get_logs` endpoint is also added. @@ -185,6 +190,10 @@ print("Response:", tool_response.json()) When `enable_logging=True`: - `/get_logs` (POST) +- **Realtime Websocket Endpoint:** + When `enable_realtime=True`: + - `/your_agency_name/realtime` (WebSocket) + - **AG-UI Protocol:** When `enable_agui=True`, only the streaming endpoint is exposed and follows the AG-UI protocol for enhanced frontend integration. The cancel endpoint is not registered in AG-UI mode. 
diff --git a/docs/additional-features/voice-agents/deployment.mdx b/docs/additional-features/voice-agents/deployment.mdx new file mode 100644 index 00000000..6c6489a4 --- /dev/null +++ b/docs/additional-features/voice-agents/deployment.mdx @@ -0,0 +1,91 @@ +--- +title: "Deployment" +description: "Run the realtime FastAPI bridge, serve the bundled web client, and connect Twilio phone calls." +icon: "server" +--- + +Use this guide once your agent is ready and you want to host a realtime bridge or connect phone infrastructure. It builds on the [Overview](./overview) and assumes your agents are ready for deployment. + +## Host the FastAPI bridge + +`run_realtime` starts a FastAPI app that proxies between your Agency Swarm agents and the OpenAI Realtime API. The helper already converts your agency to the realtime runtime, exposes a `/realtime` websocket, and streams events back to callers. + +```python +from agency_swarm import Agency +from agency_swarm.integrations import run_realtime +from voice_agent import voice_agent + +agency = Agency(voice_agent) + +run_realtime( + agency=agency, + model="gpt-realtime", + host="0.0.0.0", + port=8000, + turn_detection={"type": "server_vad"}, +) +``` + +```bash +python app.py +``` + +The server prints every incoming websocket connection. When your agents declare `voice=`, the bridge carries that choice automatically; leave out the `voice` argument of `run_realtime` and the session starts with the entry agent's own declared voice. Supply `cors_origins` when you deploy behind a browser client that runs on a different domain. + + +`run_realtime(..., return_app=True)` returns the FastAPI `app` object if you want to mount it inside an existing application rather than start a dedicated Uvicorn process. 
+ + +## Serve voice endpoints from `run_fastapi` + +Keep your existing REST endpoints and add realtime voice routes with one flag: + +```python +run_fastapi( + agencies={"support": create_agency}, + enable_realtime=True, + realtime_options={ + "model": "gpt-realtime", + "turn_detection": {"type": "server_vad", "interrupt_response": True}, + }, + enable_logging=True, +) +``` + +`enable_realtime=True` mounts `/support/realtime` alongside the normal JSON endpoints. Pass `realtime_options` when you want to override model settings (they map directly to `run_realtime` keyword arguments). Authentication and logging apply to the new websocket route automatically. + +## Serve the packaged browser client + +The static site in `src/agency_swarm/ui/demos/realtime/app` is bundled with the library. Point it at your server by editing `examples/interactive/realtime/demo.py` or by hosting the static files yourself: + +```bash +python -m agency_swarm.ui.demos.realtime.app.server +``` + +This mounts the frontend and websocket bridge under the same process—ideal for internal demos or QA. + +## Twilio phone calls + +Pass a Twilio number to `run_realtime` to expose a media-stream bridge. The helper exposes `/incoming-call` (returns TwiML) and `/twilio/media-stream` for bidirectional audio. + +```python +run_realtime( + agency=agency, + model="gpt-realtime", + twilio_number="+15551234567", + twilio_audio_format="g711_ulaw", + twilio_greeting="Connecting you to the assistant.", +) +``` + +Deployment checklist: + +1. Start the server (with extras installed) and expose it publicly, e.g. `ngrok http 8000`. +2. In the Twilio Console, set your phone number’s voice webhook to `https://<your-public-domain>/incoming-call`. +3. Call the number—the helper streams audio in both directions and reuses your existing tools and handoffs. + +For a lower-level implementation (custom playback tracking, fine-grained buffering), see `src/agency_swarm/ui/demos/realtime/twilio/README.md`. 
+ + +Store your Twilio account SID and auth token in a local `.env`, export them before launching the demo, and keep `OPENAI_API_KEY` alongside them. The packaged server reads standard environment variables; no credentials live in source control. + diff --git a/docs/additional-features/voice-agents/overview.mdx b/docs/additional-features/voice-agents/overview.mdx new file mode 100644 index 00000000..1ed80195 --- /dev/null +++ b/docs/additional-features/voice-agents/overview.mdx @@ -0,0 +1,84 @@ +--- +title: "Overview" +description: "Design voice-first assistants with the same Agency Swarm agents you already use." +icon: "microphone" +--- + +Agency Swarm reuses your existing `Agent` definitions for voice. This page shows how to adapt agents for spoken conversations; deployment lives on the dedicated [Deployment](./deployment) guide. + +## What you can build + +- **Phone receptionist** — answers calls, routes to specialists, captures caller details. +- **Live support triage** — gathers context, lets callers interrupt, and escalates to a human or another agent. +- **Language coach** — listens, corrects pronunciation, and keeps the dialogue short and encouraging. + +## Prerequisites + +- Access to OpenAI Realtime models (`gpt-realtime` or `gpt-realtime-mini` are the recommended latest options). +- `agency-swarm` with FastAPI extras: + +```bash +pip install "agency-swarm[fastapi]" +``` + +## Define your agent (same API) + +You keep using the standard `Agent` class—voice agents are regular agents with the same tools, handoffs, and instructions. + +```python +from agency_swarm import Agent, function_tool + +@function_tool +def lookup_order(order_id: str) -> str: + """Return a short order status by ID.""" + return f"Order {order_id} has shipped and will arrive soon." + +voice_agent = Agent( + name="Voice Concierge", + instructions=( + "You are a friendly concierge. Answer in one or two sentences and offer to look up order " + "details when the caller mentions a number." 
+ ), + tools=[lookup_order], + voice="nova", +) +``` + + +Keep using the same agent definitions—add `voice=` only when you care about the spoken persona. + + +Set `voice` to any of the OpenAI realtime voices: `alloy`, `ash`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage`, or `shimmer`. Each agent can declare its own voice, and the realtime bridge keeps it consistent across handoffs. If you prefer variety without manual assignments, construct your agency with `randomize_agent_voices=True`; any agent missing an explicit voice receives a random pick at initialization. Pass `voice_random_seed` as well when you need the assignment to be reproducible across runs. + +## Add handoffs (optional) + +Handoffs work exactly as they do in text mode. Register your flows once and they will carry over to voice sessions. + +```python +from agency_swarm import Agency, Agent +from agency_swarm.tools import SendMessageHandoff + +billing = Agent(name="Billing", instructions="Handle billing questions briefly.") +faq = Agent(name="FAQ", instructions="Answer frequently asked questions.") + +concierge = Agent( + name="Concierge", + instructions="Greet the caller, collect intent, then hand off when a specialist is needed.", +) + +agency = Agency( + concierge, + communication_flows=[ + (concierge > billing, SendMessageHandoff), + (concierge > faq, SendMessageHandoff), + ], +) +``` + +When the concierge invokes `SendMessageHandoff`, the realtime session routes audio and tool access to the designated specialist agent. + +## Next steps + +- Try the [realtime browser demo](https://github.com/VRSEN/agency-swarm/tree/main/examples/interactive/realtime) +- [Deploy your agents](./deployment) for phone calls using Twilio. +- Review OpenAI’s realtime [Quickstart](https://openai.github.io/openai-agents-python/realtime/quickstart/) and [Guide](https://openai.github.io/openai-agents-python/realtime/guide/) for protocol details—Agency Swarm builds on those primitives. 
diff --git a/docs/docs.json b/docs/docs.json index 9a276f5b..ef55bd67 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -97,6 +97,14 @@ "additional-features/few-shot-examples", "additional-features/guardrails", "additional-features/streaming", + { + "group": "Voice Agents", + "icon": "microphone", + "pages": [ + "additional-features/voice-agents/overview", + "additional-features/voice-agents/deployment" + ] + }, "additional-features/fastapi-integration", "additional-features/mcp-tools-server", { diff --git a/docs/references/api.mdx b/docs/references/api.mdx index a1bf54bb..8ccdece1 100644 --- a/docs/references/api.mdx +++ b/docs/references/api.mdx @@ -23,6 +23,8 @@ class Agency: load_threads_callback: ThreadLoadCallback | None = None, save_threads_callback: ThreadSaveCallback | None = None, user_context: dict[str, Any] | None = None, + randomize_agent_voices: bool = False, + voice_random_seed: int | None = None, **kwargs: Any, ): """ @@ -38,6 +40,8 @@ class Agency: load_threads_callback: Callable used to load persisted conversation threads save_threads_callback: Callable used to save conversation threads user_context: Initial shared context accessible to all agents + randomize_agent_voices: Assign random voices (from the realtime voice list) to agents that do not set `voice` + voice_random_seed: Optional seed used to make voice randomization deterministic **kwargs: Captures deprecated parameters, emitting warnings when used """ ``` @@ -166,6 +170,8 @@ def run_fastapi( app_token_env: str = "APP_TOKEN", cors_origins: list[str] | None = None, enable_agui: bool = False, + enable_realtime: bool = False, + realtime_options: dict[str, Any] | None = None, ): """ Serve this agency via the FastAPI integration. 
@@ -176,6 +182,8 @@ def run_fastapi( app_token_env: Environment variable name for authentication token cors_origins: List of allowed CORS origins enable_agui: Enable Agency UI interface + enable_realtime: Mount `/agency_name/realtime` websocket routes in addition to REST endpoints + realtime_options: Optional overrides for realtime sessions (mirrors :func:`integrations.run_realtime`) """ ``` @@ -314,6 +322,8 @@ class Agent(BaseAgent[MasterContext]): tool_use_behavior ("run_llm_again" | "stop_on_first_tool" | StopAtTools | dict[str, Any] | Callable): Tool execution policy passed through to the agents SDK reset_tool_choice (bool | None): Whether to reset tool choice after tool calls + voice (Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] | None): + Optional realtime voice name for audio output """ ``` @@ -330,6 +340,7 @@ class Agent(BaseAgent[MasterContext]): - **`throw_input_guardrail_error`** (bool): Controls input guardrail mode—False for friendly (guidance as assistant), True for strict (raises exceptions) - **`handoff_reminder`** (str | None): Custom reminder appended to handoff prompts - **`tool_concurrency_manager`** (ToolConcurrencyManager): Coordinates concurrent tool execution +- **`voice`** (str | None): Preferred realtime voice for the agent ### Core Execution Methods diff --git a/examples/README.md b/examples/README.md index 77e5fd7a..96a2433a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -27,6 +27,7 @@ This directory contains runnable examples demonstrating key features of Agency S - **`fastapi_integration/`** – FastAPI server and client examples - `server.py` – FastAPI server with streaming support - `client.py` – Client examples for testing endpoints +- **`interactive/realtime/demo.py`** – Launch the packaged realtime voice/web demo (edit to customize agents) - **`mcp_servers.py`** – Using tools from MCP servers (local and hosted) - **`connectors.py`** – Google Calendar integration using OpenAI 
hosted tools diff --git a/examples/interactive/realtime/__init__.py b/examples/interactive/realtime/__init__.py new file mode 100644 index 00000000..9a7910b1 --- /dev/null +++ b/examples/interactive/realtime/__init__.py @@ -0,0 +1 @@ +"""Package marker for interactive realtime demo examples.""" diff --git a/examples/interactive/realtime/demo.py b/examples/interactive/realtime/demo.py new file mode 100644 index 00000000..f876fcbd --- /dev/null +++ b/examples/interactive/realtime/demo.py @@ -0,0 +1,51 @@ +""" +Interactive realtime voice demo. + +Launches the packaged browser frontend + FastAPI backend. +Edit this file to customize the agent behavior. +""" + +import sys +from pathlib import Path + +# Ensure local src/ is importable when running directly from the repo checkout. +sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "src")) + +from agency_swarm import Agency, Agent, function_tool +from agency_swarm.ui.demos.realtime import RealtimeDemoLauncher + + +@function_tool +def lookup_order(order_id: str) -> str: + """Return a short order status by ID.""" + return f"Order {order_id} has shipped and will arrive soon." + + +VOICE_AGENT = Agent( + name="Voice Concierge", + instructions=( + "You are a helpful voice concierge. Answer succinctly and offer to look up order details " + "with the provided tool when asked about an order number." 
+ ), + tools=[lookup_order], +) + +VOICE_AGENCY = Agency(VOICE_AGENT) + + +def main() -> None: + print("Agency Swarm Realtime Browser Demo") + print("=" * 50) + print("Open http://localhost:8000 after launch.") + print("Press Ctrl+C to stop.\n") + + RealtimeDemoLauncher.start( + VOICE_AGENCY, + model="gpt-realtime", + voice="alloy", + turn_detection={"type": "server_vad"}, + ) + + +if __name__ == "__main__": + main() diff --git a/src/agency_swarm/__init__.py b/src/agency_swarm/__init__.py index 36a39a4b..d72a1ab1 100644 --- a/src/agency_swarm/__init__.py +++ b/src/agency_swarm/__init__.py @@ -51,6 +51,7 @@ from .hooks import PersistenceHooks # noqa: E402 from .integrations.fastapi import run_fastapi # noqa: E402 from .integrations.mcp_server import run_mcp # noqa: E402 +from .integrations.realtime import run_realtime # noqa: E402 from .tools import ( # noqa: E402 BaseTool, CodeInterpreter, @@ -96,6 +97,7 @@ "PersistenceHooks", "SendMessage", "run_fastapi", + "run_realtime", "run_mcp", # Re-exports from Agents SDK "ModelSettings", diff --git a/src/agency_swarm/agency/core.py b/src/agency_swarm/agency/core.py index 2508e997..f3b53b29 100644 --- a/src/agency_swarm/agency/core.py +++ b/src/agency_swarm/agency/core.py @@ -2,12 +2,14 @@ import atexit import logging import os +import random import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from agents import RunConfig, RunHooks, RunResult, TResponseInputItem from agency_swarm.agent.agent_flow import AgentFlow +from agency_swarm.agent.constants import AGENT_REALTIME_VOICES, AgentVoice from agency_swarm.agent.core import AgencyContext, Agent from agency_swarm.agent.execution_streaming import StreamingRunResponse from agency_swarm.hooks import PersistenceHooks @@ -29,6 +31,7 @@ if TYPE_CHECKING: from agency_swarm.agent.context_types import AgentRuntimeState + from agency_swarm.realtime.agency import RealtimeAgency logger = logging.getLogger(__name__) @@ -90,6 +93,8 @@ def 
__init__( load_threads_callback: ThreadLoadCallback | None = None, save_threads_callback: ThreadSaveCallback | None = None, user_context: dict[str, Any] | None = None, + randomize_agent_voices: bool = False, + voice_random_seed: int | None = None, **kwargs: Any, ): """ @@ -121,6 +126,10 @@ def __init__( load_threads_callback (ThreadLoadCallback | None, optional): Callable to load conversation threads. save_threads_callback (ThreadSaveCallback | None, optional): Callable to save conversation threads. user_context (dict[str, Any] | None, optional): Initial shared context accessible to all agents. + randomize_agent_voices (bool, optional): When True, assigns a random supported realtime voice to every + agent that does not already declare one. Assignments persist for the lifetime of this agency instance. + voice_random_seed (int | None, optional): Optional seed used when randomizing voices to obtain a + deterministic shuffle (useful for testing). **kwargs: Catches other deprecated parameters, issuing warnings if used. 
Raises: @@ -195,6 +204,8 @@ def __init__( self.user_context = user_context or {} self.send_message_tool_class = send_message_tool_class + self._voice_random_seed = voice_random_seed if randomize_agent_voices else None + self._randomize_agent_voices = randomize_agent_voices # --- Initialize Core Components --- self.thread_manager = ThreadManager( @@ -217,6 +228,9 @@ def __init__( self._save_threads_callback = final_save_threads_callback initialize_agent_runtime_state(self) + if randomize_agent_voices: + self._assign_random_agent_voices() + if not self.agents: raise ValueError("Agency must contain at least one agent.") logger.info(f"Registered agents: {list(self.agents.keys())}") @@ -257,6 +271,45 @@ def get_agent_runtime_state(self, agent_name: str) -> "AgentRuntimeState": raise ValueError(f"No runtime state found for agent: {agent_name}") return self._agent_runtime_state[agent_name] + def _assign_random_agent_voices(self) -> None: + """Assign deterministic random voices to agents lacking an explicit voice.""" + unassigned_agents = [agent for agent in self.agents.values() if agent.voice is None] + if not unassigned_agents: + return + + rng = random.Random(self._voice_random_seed) + used_voices: set[str] = { + voice for voice in (agent.voice for agent in self.agents.values()) if voice is not None + } + available: list[str] = [voice for voice in AGENT_REALTIME_VOICES if voice not in used_voices] + if not available: + available = list(AGENT_REALTIME_VOICES) + rng.shuffle(available) + + for agent in unassigned_agents: + if not available: + available = [voice for voice in AGENT_REALTIME_VOICES if voice not in used_voices] + if not available: + available = list(AGENT_REALTIME_VOICES) + rng.shuffle(available) + voice_choice = available.pop() + used_voices.add(voice_choice) + agent.voice = cast(AgentVoice, voice_choice) + + def to_realtime(self, agent: "Agent | str | None" = None) -> "RealtimeAgency": + """Create a `RealtimeAgency` wrapper around this agency.""" + from 
agency_swarm.realtime.agency import RealtimeAgency + + resolved_agent: Agent | None + if agent is None or isinstance(agent, Agent): + resolved_agent = agent + else: + resolved_agent = self.agents.get(agent) + if resolved_agent is None: + raise ValueError(f"Agent '{agent}' is not registered in this agency.") + + return RealtimeAgency(self, agent=resolved_agent) + # Import and bind methods from split modules with proper type hints async def get_response( self, @@ -435,6 +488,8 @@ def run_fastapi( app_token_env: str = "APP_TOKEN", cors_origins: list[str] | None = None, enable_agui: bool = False, + enable_realtime: bool = False, + realtime_options: dict[str, Any] | None = None, ): """Serve this agency via the FastAPI integration. @@ -448,7 +503,16 @@ def run_fastapi( """ from .helpers import run_fastapi - return run_fastapi(self, host, port, app_token_env, cors_origins, enable_agui) + return run_fastapi( + self, + host, + port, + app_token_env, + cors_origins, + enable_agui, + enable_realtime, + realtime_options, + ) def get_agency_structure(self, include_tools: bool = True) -> dict[str, Any]: """Return a ReactFlow-compatible JSON structure describing the agency.""" diff --git a/src/agency_swarm/agency/helpers.py b/src/agency_swarm/agency/helpers.py index 1b0cfcc9..3e1e518b 100644 --- a/src/agency_swarm/agency/helpers.py +++ b/src/agency_swarm/agency/helpers.py @@ -132,6 +132,8 @@ def run_fastapi( app_token_env: str = "APP_TOKEN", cors_origins: list[str] | None = None, enable_agui: bool = False, + enable_realtime: bool = False, + realtime_options: dict[str, Any] | None = None, ) -> None: """Serve this agency via the FastAPI integration. 
@@ -170,6 +172,8 @@ def agency_factory(*, load_threads_callback=None, save_threads_callback=None, ** app_token_env=app_token_env, cors_origins=cors_origins, enable_agui=enable_agui, + enable_realtime=enable_realtime, + realtime_options=realtime_options, ) diff --git a/src/agency_swarm/agent/constants.py b/src/agency_swarm/agent/constants.py index ee945cc2..d1875440 100644 --- a/src/agency_swarm/agent/constants.py +++ b/src/agency_swarm/agent/constants.py @@ -1,5 +1,7 @@ """Agent module constants extracted to keep files under size limits (no behavior change).""" +from typing import Literal + # Combine old and new params for easier checking later AGENT_PARAMS = { # New/Current @@ -12,6 +14,7 @@ "include_search_results", "validation_attempts", "throw_input_guardrail_error", + "voice", # Old/Deprecated (to check in kwargs) "id", "tool_resources", @@ -24,3 +27,28 @@ # Constants for dynamic tool creation MESSAGE_PARAM = "message" + +# Canonical realtime voice options mirrored from openai-agents SDK v0.4.1 +AGENT_REALTIME_VOICES = ( + "alloy", + "ash", + "coral", + "echo", + "fable", + "onyx", + "nova", + "sage", + "shimmer", +) + +AgentVoice = Literal[ + "alloy", + "ash", + "coral", + "echo", + "fable", + "onyx", + "nova", + "sage", + "shimmer", +] diff --git a/src/agency_swarm/agent/core.py b/src/agency_swarm/agent/core.py index 36dca028..31e48619 100644 --- a/src/agency_swarm/agent/core.py +++ b/src/agency_swarm/agent/core.py @@ -2,7 +2,7 @@ import os import warnings from pathlib import Path -from typing import Any, TypeVar +from typing import Any, TypeVar, cast from agents import ( Agent as BaseAgent, @@ -29,6 +29,7 @@ ) from agency_swarm.agent.agent_flow import AgentFlow from agency_swarm.agent.attachment_manager import AttachmentManager +from agency_swarm.agent.constants import AGENT_REALTIME_VOICES, AgentVoice from agency_swarm.agent.execution_streaming import StreamingRunResponse from agency_swarm.agent.file_manager import AgentFileManager from 
agency_swarm.agent.tools import _attach_one_call_guard @@ -72,6 +73,7 @@ class Agent(BaseAgent[MasterContext]): include_search_results: bool = False validation_attempts: int = 1 throw_input_guardrail_error: bool = False + voice: AgentVoice | None # --- Internal State --- _associated_vector_store_id: str | None = None @@ -192,6 +194,16 @@ def __init__(self, **kwargs: Any): self.validation_attempts = int(current_agent_params.get("validation_attempts", 1)) self.throw_input_guardrail_error = bool(current_agent_params.get("throw_input_guardrail_error", False)) self.handoff_reminder = current_agent_params.get("handoff_reminder") + voice_value = current_agent_params.get("voice") + if voice_value is None: + self.voice = None + else: + if voice_value not in AGENT_REALTIME_VOICES: + raise ValueError( + f"Invalid voice '{voice_value}' for agent '{self.name}'. " + f"Valid options: {', '.join(AGENT_REALTIME_VOICES)}." + ) + self.voice = cast(AgentVoice, voice_value) # Internal state self._openai_client = None diff --git a/src/agency_swarm/integrations/__init__.py b/src/agency_swarm/integrations/__init__.py new file mode 100644 index 00000000..10752ad2 --- /dev/null +++ b/src/agency_swarm/integrations/__init__.py @@ -0,0 +1,7 @@ +"""Integration helpers exposed at the package level.""" + +from .fastapi import run_fastapi +from .mcp_server import run_mcp +from .realtime import run_realtime + +__all__ = ["run_fastapi", "run_mcp", "run_realtime"] diff --git a/src/agency_swarm/integrations/fastapi.py b/src/agency_swarm/integrations/fastapi.py index 4bd05ff7..0489ee9d 100644 --- a/src/agency_swarm/integrations/fastapi.py +++ b/src/agency_swarm/integrations/fastapi.py @@ -1,11 +1,20 @@ +import asyncio import logging import os from collections.abc import Callable, Mapping +from contextlib import suppress +from typing import Any from agents.tool import FunctionTool from agency_swarm.agency import Agency from agency_swarm.agent.core import Agent +from agency_swarm.integrations.realtime 
import ( + RealtimeSessionFactory, + _forward_session_events as _rt_forward_session_events, + _handle_client_payload as _rt_handle_client_payload, + build_model_settings, +) logger = logging.getLogger(__name__) @@ -22,6 +31,8 @@ def run_fastapi( enable_agui: bool = False, enable_logging: bool = False, logs_dir: str = "activity-logs", + enable_realtime: bool = False, + realtime_options: dict[str, Any] | None = None, ): """Launch a FastAPI server exposing endpoints for multiple agencies and tools. @@ -46,6 +57,12 @@ def run_fastapi( logs_dir : str Directory to store log files when logging is enabled. Defaults to 'activity-logs'. + enable_realtime : bool + When True, registers a websocket endpoint for each agency that mirrors the realtime + helper. Requires FastAPI extras. + realtime_options : dict[str, Any] | None + Optional configuration applied to realtime endpoints (e.g. ``model``, ``voice``, + ``turn_detection``). Matches the keyword arguments of :func:`run_realtime`. """ if (agencies is None or len(agencies) == 0) and (tools is None or len(tools) == 0): logger.warning("No endpoints to deploy. 
Please provide at least one agency or tool.") @@ -55,6 +72,7 @@ def run_fastapi( import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware + from starlette.websockets import WebSocket as StarletteWebSocket, WebSocketDisconnect from .fastapi_utils.endpoint_handlers import ( ActiveRunRegistry, @@ -174,6 +192,113 @@ def run_fastapi( endpoints.append(f"/{agency_name}/get_response_stream") endpoints.append(f"/{agency_name}/cancel_response_stream") + if enable_realtime: + realtime_defaults = dict(realtime_options or {}) + route_path = f"/{agency_name}/realtime" + + @app.websocket(route_path) + async def realtime_websocket( + websocket: StarletteWebSocket, + _agency_factory: Callable[..., Agency] = agency_factory, + _agency_name: str = agency_name, + _realtime_defaults: dict[str, Any] = realtime_defaults, + _app_token: str | None = app_token, + ) -> None: + auth_header = websocket.headers.get("authorization") + if _app_token: + if not auth_header or not auth_header.lower().startswith("bearer "): + await websocket.close(code=1008, reason="Unauthorized") + return + provided_token = auth_header.split(" ", 1)[1].strip() + if provided_token != _app_token: + await websocket.close(code=1008, reason="Unauthorized") + return + + await websocket.accept() + logger.info("Realtime websocket accepted for %s from %s", _agency_name, websocket.client) + session = None + try: + try: + agency_instance = _agency_factory(load_threads_callback=lambda: []) + except Exception: + logger.exception("Failed to instantiate agency for realtime endpoint", exc_info=True) + await websocket.close(code=1011, reason="Failed to initialize agency.") + return + + realtime_agency = agency_instance.to_realtime() + entry_voice = getattr(realtime_agency.entry_agent, "voice", None) + config = dict(_realtime_defaults) + base_settings = build_model_settings( + model=config.get("model", "gpt-realtime"), + voice=config.get("voice", entry_voice), + 
input_audio_format=config.get("input_audio_format"), + output_audio_format=config.get("output_audio_format"), + turn_detection=config.get("turn_detection"), + input_audio_noise_reduction=config.get("input_audio_noise_reduction"), + ) + session_factory = RealtimeSessionFactory(realtime_agency, base_settings) + except Exception: + logger.exception("Failed to prepare realtime session factory", exc_info=True) + await websocket.close(code=1011, reason="Failed to initialize realtime session.") + return + + try: + session = await session_factory.create_session() + except Exception: + logger.exception("Failed to initialize realtime session", exc_info=True) + await websocket.close(code=1011, reason="Failed to initialize realtime session.") + return + + try: + async with session as realtime_session: + events_task = asyncio.create_task( + _rt_forward_session_events( + realtime_session, + websocket.send_text, + initial_voice=session_factory.default_voice, + ) + ) + try: + while True: + message = await websocket.receive() + message_type = message.get("type") + if message_type == "websocket.disconnect": + break + if message_type != "websocket.receive": + continue + + text_data = message.get("text") + if text_data is not None: + await _rt_handle_client_payload(realtime_session, text_data) + continue + + bytes_data = message.get("bytes") + if bytes_data is not None: + await realtime_session.send_audio(bytes_data) + except WebSocketDisconnect: + logger.info( + "Realtime websocket disconnected by client %s for %s", + websocket.client, + _agency_name, + ) + except Exception: + logger.exception( + "Error while handling realtime websocket traffic for %s", + _agency_name, + exc_info=True, + ) + await websocket.close(code=1011, reason="Realtime session error.") + finally: + events_task.cancel() + with suppress(asyncio.CancelledError): + await events_task + finally: + if session is not None: + with suppress(Exception): + await session.close() + + endpoints.append(route_path) + 
app.add_api_route( f"/{agency_name}/get_metadata", make_metadata_endpoint(agency_metadata, verify_token), diff --git a/src/agency_swarm/integrations/realtime.py b/src/agency_swarm/integrations/realtime.py new file mode 100644 index 00000000..dec7748f --- /dev/null +++ b/src/agency_swarm/integrations/realtime.py @@ -0,0 +1,532 @@ +from __future__ import annotations + +import asyncio +import base64 +import binascii +import json +import logging +from collections.abc import Awaitable, Callable, Mapping +from contextlib import suppress +from typing import TYPE_CHECKING, Any, assert_never, cast + +from agents.realtime import RealtimeRunner, RealtimeSession +from agents.realtime.config import ( + RealtimeInputAudioNoiseReductionConfig, + RealtimeSessionModelSettings, + RealtimeTurnDetectionConfig, +) +from agents.realtime.events import ( + RealtimeAgentEndEvent, + RealtimeAgentStartEvent, + RealtimeAudio, + RealtimeAudioEnd, + RealtimeAudioInterrupted, + RealtimeError, + RealtimeGuardrailTripped, + RealtimeHandoffEvent, + RealtimeHistoryAdded, + RealtimeHistoryUpdated, + RealtimeInputAudioTimeoutTriggered, + RealtimeRawModelEvent, + RealtimeSessionEvent, + RealtimeToolEnd, + RealtimeToolStart, +) +from agents.realtime.model_inputs import ( + RealtimeModelRawClientMessage, + RealtimeModelSendRawMessage, + RealtimeModelSendSessionUpdate, +) +from starlette.websockets import WebSocket as StarletteWebSocket, WebSocketDisconnect + +from agency_swarm.agency.core import Agency +from agency_swarm.agent.core import Agent +from agency_swarm.context import MasterContext +from agency_swarm.realtime.agency import RealtimeAgency +from agency_swarm.utils.thread import ThreadManager + +if TYPE_CHECKING: + from fastapi import FastAPI + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +__all__ = ["run_realtime", "RealtimeSessionFactory", "build_model_settings"] + + +def _model_dump(item: Any) -> Any: + """Convert SDK events to JSON-serializable data.""" + dump = 
getattr(item, "model_dump", None) + if callable(dump): + try: + return dump(mode="json") + except TypeError: + return dump() + if isinstance(item, str | int | float | bool) or item is None: + return item + if isinstance(item, dict): + return item + return str(item) + + +def _serialize_event(event: RealtimeSessionEvent) -> dict[str, Any] | None: + """Translate realtime session events to JSON payloads for websocket clients.""" + if isinstance(event, RealtimeAudio): + audio_data = base64.b64encode(event.audio.data).decode("utf-8") + return { + "type": "audio", + "audio": audio_data, + "item_id": event.item_id, + "content_index": event.content_index, + "response_id": event.audio.response_id, + } + if isinstance(event, RealtimeAudioEnd): + return { + "type": "audio_end", + "item_id": event.item_id, + "content_index": event.content_index, + } + if isinstance(event, RealtimeAudioInterrupted): + return { + "type": "audio_interrupted", + "item_id": event.item_id, + "content_index": event.content_index, + } + if isinstance(event, RealtimeAgentStartEvent): + return {"type": "agent_start", "agent": event.agent.name} + if isinstance(event, RealtimeAgentEndEvent): + return {"type": "agent_end", "agent": event.agent.name} + if isinstance(event, RealtimeHandoffEvent): + return {"type": "handoff", "from": event.from_agent.name, "to": event.to_agent.name} + if isinstance(event, RealtimeToolStart): + return { + "type": "tool_start", + "agent": event.agent.name, + "tool": getattr(event.tool, "name", str(event.tool)), + } + if isinstance(event, RealtimeToolEnd): + return { + "type": "tool_end", + "agent": event.agent.name, + "tool": getattr(event.tool, "name", str(event.tool)), + "output": str(event.output), + } + if isinstance(event, RealtimeHistoryUpdated): + return { + "type": "history_updated", + "history": [_model_dump(item) for item in event.history], + } + if isinstance(event, RealtimeHistoryAdded): + return { + "type": "history_added", + "item": _model_dump(event.item), + } + 
if isinstance(event, RealtimeGuardrailTripped): + return { + "type": "guardrail_tripped", + "guardrails": [result.guardrail.get_name() for result in event.guardrail_results], + "message": event.message, + } + if isinstance(event, RealtimeError): + return {"type": "error", "error": str(event.error)} + if isinstance(event, RealtimeRawModelEvent): + raw_type = getattr(event.data, "type", "unknown") + payload = getattr(event.data, "model_dump", None) + data = payload(mode="json") if callable(payload) else str(event.data) + return {"type": "raw_model_event", "raw_type": raw_type, "data": data} + if isinstance(event, RealtimeInputAudioTimeoutTriggered): + return {"type": "input_audio_timeout_triggered"} + assert_never(event) + + +async def _forward_session_events( + session: RealtimeSession, + send: Callable[[str], Awaitable[Any]], + *, + initial_voice: str | None = None, +) -> None: + current_voice = initial_voice + async for event in session: + if isinstance(event, RealtimeAgentStartEvent): + desired_voice = getattr(event.agent, "voice", None) + if desired_voice and desired_voice != current_voice: + await session.model.send_event( + RealtimeModelSendSessionUpdate(session_settings={"voice": desired_voice}) + ) + logger.info("Updated realtime voice to %s for agent %s", desired_voice, event.agent.name) + current_voice = desired_voice + payload = _serialize_event(event) + if payload is not None: + await send(json.dumps(payload)) + + +async def _handle_client_payload(session: RealtimeSession, payload: str) -> None: + try: + message = json.loads(payload) + except json.JSONDecodeError: + logger.warning("Ignoring non-JSON realtime payload: %s", payload[:80]) + return + + msg_type = message.get("type") + if not isinstance(msg_type, str): + logger.warning("Realtime payload missing 'type': %s", message) + return + + if msg_type == "input_audio_buffer": + audio = message.get("audio") + if isinstance(audio, str): + try: + audio_bytes = base64.b64decode(audio) + except 
(binascii.Error, ValueError): + logger.warning("Failed to decode realtime audio payload.") + return + await session.send_audio(audio_bytes, commit=bool(message.get("commit", False))) + else: + logger.debug("Realtime audio payload missing 'audio' data.") + return + + if msg_type == "interrupt": + await session.interrupt() + return + + other = {k: v for k, v in message.items() if k != "type"} + raw_message: dict[str, Any] = {"type": msg_type} + if other: + raw_message["other_data"] = other + client_message = cast(RealtimeModelRawClientMessage, raw_message) + await session.model.send_event(RealtimeModelSendRawMessage(message=client_message)) + + +def build_model_settings( + *, + model: str, + voice: str | None, + input_audio_format: str | None, + output_audio_format: str | None, + turn_detection: dict[str, Any] | None, + input_audio_noise_reduction: dict[str, Any] | None, +) -> RealtimeSessionModelSettings: + settings: RealtimeSessionModelSettings = {"model_name": model} + if voice: + settings["voice"] = voice + if input_audio_format: + settings["input_audio_format"] = input_audio_format + if output_audio_format: + settings["output_audio_format"] = output_audio_format + if turn_detection: + settings["turn_detection"] = cast(RealtimeTurnDetectionConfig, turn_detection) + if input_audio_noise_reduction: + settings["input_audio_noise_reduction"] = cast( + RealtimeInputAudioNoiseReductionConfig, input_audio_noise_reduction + ) + return settings + + +async def _forward_events_to_twilio( + session: RealtimeSession, + websocket: StarletteWebSocket, + get_stream_sid: Callable[[], str | None], + *, + initial_voice: str | None = None, +) -> None: + current_voice = initial_voice + async for event in session: + stream_sid = get_stream_sid() + if stream_sid is None: + continue + + if isinstance(event, RealtimeAgentStartEvent): + desired_voice = getattr(event.agent, "voice", None) + if desired_voice and desired_voice != current_voice: + await session.model.send_event( + 
RealtimeModelSendSessionUpdate(session_settings={"voice": desired_voice}) + ) + logger.info("Updated realtime voice to %s for Twilio stream", desired_voice) + current_voice = desired_voice + continue + + if isinstance(event, RealtimeAudio): + payload = base64.b64encode(event.audio.data).decode("utf-8") + await websocket.send_text( + json.dumps({"event": "media", "streamSid": stream_sid, "media": {"payload": payload}}) + ) + elif isinstance(event, RealtimeAudioInterrupted): + await websocket.send_text(json.dumps({"event": "clear", "streamSid": stream_sid})) + + +class RealtimeSessionFactory: + def __init__(self, realtime_agency: RealtimeAgency, base_model_settings: Mapping[str, Any]): + self._agency = realtime_agency + self._base_model_settings = dict(base_model_settings) + + @property + def default_voice(self) -> str | None: + voice_value = self._base_model_settings.get("voice") + return cast(str | None, voice_value) + + async def create_session(self, overrides: dict[str, Any] | None = None) -> RealtimeSession: + runner = RealtimeRunner(self._agency.entry_agent) + merged_settings: dict[str, Any] = dict(self._base_model_settings) + if overrides: + for key, value in overrides.items(): + if value is not None: + merged_settings[key] = value + + model_settings = cast(RealtimeSessionModelSettings, merged_settings) + + session = await runner.run( + context=MasterContext( + thread_manager=ThreadManager(), + agents=self._agency.source_agents, + shared_instructions=self._agency.shared_instructions, + user_context=dict(self._agency.user_context), + agent_runtime_state=self._agency.runtime_state_map, + ), + model_config={"initial_model_settings": model_settings}, + ) + return session + + +def run_realtime( + *, + agency: Agency | RealtimeAgency, + entry_agent: Agent | str | None = None, + model: str = "gpt-realtime", + voice: str | None = None, + host: str = "0.0.0.0", + port: int = 8000, + turn_detection: dict[str, Any] | None = None, + input_audio_format: str | None = None, 
+ output_audio_format: str | None = None, + input_audio_noise_reduction: dict[str, Any] | None = None, + cors_origins: list[str] | None = None, + twilio_number: str | None = None, + twilio_audio_format: str | None = None, + twilio_greeting: str = "Connecting you now.", + return_app: bool = False, +) -> FastAPI | None: + """Launch a realtime FastAPI server backed by OpenAI's Realtime API.""" + + try: + from fastapi import FastAPI as FastAPIApp, Request as FastAPIRequest + from fastapi.middleware.cors import CORSMiddleware as FastAPICORSMiddleware + from fastapi.responses import PlainTextResponse as FastAPIPlainTextResponse + except ImportError as exc: + logger.error( + "Realtime dependencies are missing: %s. Install agency-swarm[fastapi] to use run_realtime.", + exc, + ) + return None + + app = FastAPIApp() + origins = cors_origins or ["*"] + app.add_middleware( + FastAPICORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + realtime_agency = _ensure_realtime_agency(agency, entry_agent) + entry_voice = getattr(realtime_agency.entry_agent, "voice", None) + effective_voice = voice if voice is not None else entry_voice + + base_settings = build_model_settings( + model=model, + voice=effective_voice, + input_audio_format=input_audio_format, + output_audio_format=output_audio_format, + turn_detection=turn_detection, + input_audio_noise_reduction=input_audio_noise_reduction, + ) + session_factory = RealtimeSessionFactory(realtime_agency, base_settings) + + @app.websocket("/realtime") + async def realtime_endpoint(websocket: StarletteWebSocket) -> None: + await websocket.accept() + print(f"[run_realtime] Accepted websocket from {websocket.client}", flush=True) + logger.info("Realtime websocket accepted from %s", websocket.client) + session: RealtimeSession | None = None + try: + try: + session = await session_factory.create_session() + except Exception: + logger.exception("Failed to initialize realtime 
session", exc_info=True) + await websocket.close(code=1011, reason="Failed to initialize realtime session.") + return + + try: + async with session as realtime_session: + events_task = asyncio.create_task( + _forward_session_events( + realtime_session, + websocket.send_text, + initial_voice=session_factory.default_voice, + ) + ) + try: + while True: + message = await websocket.receive() + message_type = message.get("type") + if message_type == "websocket.disconnect": + break + if message_type != "websocket.receive": + continue + + text_data = message.get("text") + if text_data is not None: + await _handle_client_payload(realtime_session, text_data) + continue + + bytes_data = message.get("bytes") + if bytes_data is not None: + await realtime_session.send_audio(bytes_data) + except WebSocketDisconnect: + logger.info("Realtime websocket disconnected by client %s", websocket.client) + except Exception: + logger.exception("Error while handling realtime websocket traffic", exc_info=True) + await websocket.close(code=1011, reason="Realtime session error.") + finally: + events_task.cancel() + with suppress(asyncio.CancelledError): + await events_task + finally: + if session is not None: + with suppress(Exception): + await session.close() + except Exception: + logger.exception("Realtime endpoint crashed", exc_info=True) + await websocket.close(code=1011, reason="Realtime endpoint failure.") + + listen_host = host + listen_port = port + + if twilio_number: + incoming_path = "/incoming-call" + media_path = "/twilio/media-stream" + logger.info("Twilio voice bridge enabled for %s", twilio_number) + + overrides: dict[str, Any] = {} + if twilio_audio_format: + overrides["input_audio_format"] = twilio_audio_format + overrides.setdefault("output_audio_format", twilio_audio_format) + if overrides and not overrides.get("output_audio_format") and output_audio_format: + overrides["output_audio_format"] = output_audio_format + + @app.post(incoming_path) + @app.get(incoming_path) + 
async def incoming_call(request: FastAPIRequest) -> FastAPIPlainTextResponse: + forwarded_proto = request.headers.get("x-forwarded-proto", request.url.scheme) + scheme = "https" if forwarded_proto in {"https", "wss"} else "http" + ws_scheme = "wss" if scheme == "https" else "ws" + host_header = request.headers.get("host", f"{listen_host}:{listen_port}") + ws_url = f"{ws_scheme}://{host_header}{media_path}" + + twiml = ( + '\n' + "\n" + f" {twilio_greeting}\n" + " \n" + f' \n' + " \n" + "" + ) + return FastAPIPlainTextResponse(content=twiml, media_type="text/xml") + + @app.websocket(media_path) + async def twilio_media_stream(websocket: StarletteWebSocket) -> None: + await websocket.accept() + try: + session = await session_factory.create_session(overrides=overrides or None) + except Exception: + logger.exception("Failed to initialize realtime session for Twilio bridge", exc_info=True) + await websocket.close(code=1011, reason="Realtime session initialization failed.") + return + stream_sid: str | None = None + + def _get_stream_sid() -> str | None: + return stream_sid + + try: + initial_voice = overrides.get("voice") if overrides else session_factory.default_voice + async with session as realtime_session: + events_task = asyncio.create_task( + _forward_events_to_twilio( + realtime_session, + websocket, + _get_stream_sid, + initial_voice=initial_voice, + ) + ) + try: + while True: + message_text = await websocket.receive_text() + try: + payload = json.loads(message_text) + except json.JSONDecodeError: + logger.warning("Invalid Twilio payload: %s", message_text[:80]) + continue + + event_type = payload.get("event") + if event_type == "start": + stream_sid = payload.get("start", {}).get("streamSid", stream_sid) + elif event_type == "media": + media_payload = payload.get("media", {}).get("payload") + if isinstance(media_payload, str): + try: + audio_bytes = base64.b64decode(media_payload) + except (binascii.Error, ValueError): + logger.warning("Failed to decode Twilio 
audio payload.") + continue + await realtime_session.send_audio(audio_bytes) + elif event_type == "mark": + continue + elif event_type == "stop": + break + else: + logger.debug("Unhandled Twilio event: %s", event_type) + except WebSocketDisconnect: + pass + finally: + events_task.cancel() + with suppress(asyncio.CancelledError): + await events_task + finally: + with suppress(Exception): + await session.close() + + if return_app: + return app + + try: + import uvicorn + except ImportError as exc: + logger.error("uvicorn is required to run the realtime server: %s", exc) + return None + + logger.info("Starting realtime server at http://%s:%s", host, port) + uvicorn.run(app, host=host, port=port) + return None + + +def _ensure_realtime_agency(agency: Agency | RealtimeAgency, entry_agent: Agent | str | None) -> RealtimeAgency: + if isinstance(agency, RealtimeAgency): + if entry_agent is not None: + raise ValueError("entry_agent must not be provided when a RealtimeAgency instance is supplied.") + return agency + + if isinstance(agency, Agency): + resolved_agent: Agent | None + if entry_agent is None: + resolved_agent = None + elif isinstance(entry_agent, Agent): + resolved_agent = entry_agent + else: + resolved_agent = agency.agents.get(entry_agent) + if resolved_agent is None: + raise ValueError(f"Agent '{entry_agent}' is not registered in the Agency.") + + return agency.to_realtime(resolved_agent) + + raise TypeError(f"Unsupported agency type: {type(agency)!r}") diff --git a/src/agency_swarm/realtime/__init__.py b/src/agency_swarm/realtime/__init__.py new file mode 100644 index 00000000..26ccb90a --- /dev/null +++ b/src/agency_swarm/realtime/__init__.py @@ -0,0 +1,4 @@ +from .agency import RealtimeAgency +from .agent import RealtimeAgent + +__all__ = ["RealtimeAgency", "RealtimeAgent"] diff --git a/src/agency_swarm/realtime/agency.py b/src/agency_swarm/realtime/agency.py new file mode 100644 index 00000000..64fe208d --- /dev/null +++ 
b/src/agency_swarm/realtime/agency.py @@ -0,0 +1,155 @@ +import inspect +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, cast + +from agents import RunContextWrapper +from agents.handoffs import Handoff +from agents.realtime import RealtimeAgent as SDKRealtimeAgent, realtime_handoff + +from agency_swarm.agent.context_types import AgentRuntimeState +from agency_swarm.agent.core import Agent +from agency_swarm.context import MasterContext +from agency_swarm.realtime.agent import RealtimeAgent + +if TYPE_CHECKING: + from agency_swarm.agency.core import Agency + + +class RealtimeAgency: + """Realtime-aware facade built from an existing `Agency`. + + It converts each registered Agent into a distinct `RealtimeAgent`, enforces that + every communication path is modeled as a handoff, and exposes the data needed to + drive the realtime runner. + """ + + def __init__(self, source: "Agency", agent: Agent | None = None) -> None: + self._source = source + self._realtime_agents: dict[str, RealtimeAgent] = { + name: RealtimeAgent(agent_instance) for name, agent_instance in source.agents.items() + } + self._entry_agent = self._resolve_entry(agent) + self._populate_handoffs() + + @property + def source(self) -> "Agency": + return self._source + + @property + def entry_agent(self) -> RealtimeAgent: + return self._entry_agent + + @property + def agents(self) -> dict[str, RealtimeAgent]: + return self._realtime_agents + + @property + def source_agents(self) -> dict[str, Agent]: + return self._source.agents + + @property + def shared_instructions(self) -> str | None: + return self._source.shared_instructions or None + + @property + def user_context(self) -> dict[str, Any]: + return self._source.user_context + + @property + def runtime_state_map(self) -> dict[str, AgentRuntimeState]: + return self._source._agent_runtime_state + + def _resolve_entry(self, agent: Agent | None) -> RealtimeAgent: + if agent is None: + if not self._source.entry_points: 
+ raise ValueError("RealtimeAgency requires the source Agency to declare an entry point.") + candidate = self._source.entry_points[0] + else: + if agent.name not in self._source.agents: + raise ValueError(f"Agent '{agent.name}' is not registered in the source Agency.") + candidate = agent + + realtime_agent = self._realtime_agents.get(candidate.name) + if realtime_agent is None: + raise ValueError(f"Realtime agent for '{candidate.name}' could not be constructed.") + return realtime_agent + + def _populate_handoffs(self) -> None: + runtime_state_map = self.runtime_state_map + for agent_name, realtime_agent in self._realtime_agents.items(): + runtime_state = runtime_state_map.get(agent_name) + if runtime_state is None: + raise ValueError(f"No runtime state available for agent '{agent_name}'.") + + converted = self._convert_handoffs( + source_agent=self._source.agents[agent_name], + runtime_state=runtime_state, + realtime_agent=realtime_agent, + ) + realtime_agent.handoffs = cast( + list[SDKRealtimeAgent[MasterContext] | Handoff[Any, SDKRealtimeAgent[Any]]], + converted, + ) + + def _convert_handoffs( + self, + *, + source_agent: Agent, + runtime_state: AgentRuntimeState, + realtime_agent: RealtimeAgent, + ) -> list[Handoff[Any, SDKRealtimeAgent[MasterContext]]]: + original_handoffs = getattr(runtime_state, "handoffs", []) + if not original_handoffs: + if getattr(runtime_state, "send_message_tools", {}): + raise ValueError( + f"RealtimeAgency requires communication flows to be modeled with SendMessageHandoff. " + f"Agent '{source_agent.name}' has send_message tools but no handoffs." + ) + return [] + + converted: list[Handoff[Any, SDKRealtimeAgent[MasterContext]]] = [] + for original in original_handoffs: + target = self._realtime_agents.get(original.agent_name) + if target is None: + raise ValueError( + f"Handoff '{original.tool_name}' on agent '{source_agent.name}' " + f"targets unknown agent '{original.agent_name}'." 
+ ) + + realtime_handoff_obj = realtime_handoff( + target, + tool_name_override=original.tool_name, + tool_description_override=original.tool_description, + is_enabled=_wrap_is_enabled(original.is_enabled, source_agent), + ) + + original_on_invoke = original.on_invoke_handoff + + async def _on_invoke( + ctx: RunContextWrapper[Any], input_json: str | None = None, *, _orig=original_on_invoke, _target=target + ) -> SDKRealtimeAgent[MasterContext]: + await _orig(ctx, input_json or "") + return _target + + realtime_handoff_obj.on_invoke_handoff = _on_invoke + realtime_handoff_obj.input_filter = original.input_filter + realtime_handoff_obj.input_json_schema = original.input_json_schema + realtime_handoff_obj.strict_json_schema = original.strict_json_schema + converted.append(realtime_handoff_obj) + return converted + + +def _wrap_is_enabled( + is_enabled: bool | Callable[[RunContextWrapper[Any], Agent], Any], + source_agent: Agent, +) -> bool | Callable[[RunContextWrapper[Any], SDKRealtimeAgent[MasterContext]], Awaitable[bool]]: + if not callable(is_enabled): + return bool(is_enabled) + + async def _wrapped(ctx: RunContextWrapper[Any], _: SDKRealtimeAgent[MasterContext]) -> bool: + result = is_enabled(ctx, source_agent) + if inspect.isawaitable(result): + return bool(await cast(Awaitable[Any], result)) + return bool(result) + + return _wrapped diff --git a/src/agency_swarm/realtime/agent.py b/src/agency_swarm/realtime/agent.py new file mode 100644 index 00000000..afffa088 --- /dev/null +++ b/src/agency_swarm/realtime/agent.py @@ -0,0 +1,52 @@ +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, cast + +from agents import RunContextWrapper +from agents.agent import MCPConfig +from agents.realtime import RealtimeAgent as SDKRealtimeAgent + +from agency_swarm.agent.core import Agent +from agency_swarm.context import MasterContext + + +class RealtimeAgent(SDKRealtimeAgent[MasterContext]): + """RealtimeAgent wrapper that preserves the 
interface of agency_swarm.agent.core.Agent.""" + + def __init__(self, source: Agent) -> None: + self._source = source + instructions = _wrap_instructions(source) + super().__init__( + name=source.name, + instructions=instructions, + handoff_description=source.handoff_description, + tools=list(source.tools), + mcp_servers=list(source.mcp_servers), + mcp_config=cast(MCPConfig, dict(source.mcp_config)), + prompt=source.prompt if (source.prompt is None or not callable(source.prompt)) else None, + output_guardrails=list(source.output_guardrails), + ) + self.voice = source.voice + + @property + def source(self) -> Agent: + """Return the originating agency-swarm Agent.""" + return self._source + + +def _wrap_instructions( + agent: Agent, +) -> str | Callable[[RunContextWrapper[Any], SDKRealtimeAgent[MasterContext]], Awaitable[str]] | None: + instructions = agent.instructions + if instructions is None or isinstance(instructions, str): + return cast(str | None, instructions) + + typed = cast(Callable[[RunContextWrapper[Any], Agent], Awaitable[str] | str], instructions) + + async def _wrapped(ctx: RunContextWrapper[Any], _: SDKRealtimeAgent[MasterContext]) -> str: + result = typed(ctx, agent) + if inspect.isawaitable(result): + return cast(str, await result) + return cast(str, result) + + return _wrapped diff --git a/src/agency_swarm/ui/demos/realtime/__init__.py b/src/agency_swarm/ui/demos/realtime/__init__.py new file mode 100644 index 00000000..2bea7a9c --- /dev/null +++ b/src/agency_swarm/ui/demos/realtime/__init__.py @@ -0,0 +1,85 @@ +from pathlib import Path +from typing import Any + +from agency_swarm.agency.core import Agency + + +class RealtimeDemoLauncher: + """Launch the realtime browser demo using a provided Agency Swarm agent graph.""" + + @staticmethod + def start( + agency: Agency, + *, + host: str = "0.0.0.0", + port: int = 8000, + model: str = "gpt-realtime", + voice: str | None = "alloy", + turn_detection: dict[str, Any] | None = None, + input_audio_format: 
str | None = None, + output_audio_format: str | None = None, + input_audio_noise_reduction: dict[str, Any] | None = None, + cors_origins: list[str] | None = None, + ) -> None: + """Start the realtime demo server and keep it running until interrupted.""" + if not isinstance(agency, Agency): + raise TypeError("RealtimeDemoLauncher.start expects an Agency instance.") + if not agency.entry_points: + raise ValueError("RealtimeDemoLauncher.start requires the Agency to define at least one entry point.") + entry_agent = agency.entry_points[0] + + try: + import uvicorn + except ImportError as exc: # pragma: no cover - dependency guard + raise RuntimeError( + 'Realtime demo requires uvicorn. Install extras: pip install "agency-swarm[fastapi]"' + ) from exc + + try: + from agency_swarm.integrations.realtime import ( + RealtimeSessionFactory, + build_model_settings, + ) + from agency_swarm.ui.demos.realtime.app.server import create_realtime_demo_app + except ImportError as exc: # pragma: no cover - import guard + missing = getattr(exc, "name", "") or str(exc) + if "fastapi" in missing or "starlette" in missing: + raise RuntimeError( + 'Realtime demo requires FastAPI. 
Install extras: pip install "agency-swarm[fastapi]"' + ) from exc + raise RuntimeError("Realtime demo assets are missing from the installation.") from exc + + demo_static_dir = Path(__file__).parent / "app" / "static" + if not demo_static_dir.exists(): + raise RuntimeError(f"Realtime demo static assets not found at {demo_static_dir}.") + + base_settings = build_model_settings( + model=model, + voice=voice, + input_audio_format=input_audio_format, + output_audio_format=output_audio_format, + turn_detection=turn_detection, + input_audio_noise_reduction=input_audio_noise_reduction, + ) + realtime_agency = agency.to_realtime(entry_agent) + session_factory = RealtimeSessionFactory(realtime_agency, base_settings) + app = create_realtime_demo_app( + session_factory, + static_dir=demo_static_dir, + cors_origins=cors_origins, + ) + + display_host = host if host != "0.0.0.0" else "localhost" + print( + f"\n\033[92;1mRealtime demo running\033[0m\nFrontend: http://{display_host}:{port}\nPress Ctrl+C to stop.\n" + ) + + uvicorn.run( + app, + host=host, + port=port, + ws_max_size=16 * 1024 * 1024, + ) + + +__all__ = ["RealtimeDemoLauncher"] diff --git a/src/agency_swarm/ui/demos/realtime/app/__init__.py b/src/agency_swarm/ui/demos/realtime/app/__init__.py new file mode 100644 index 00000000..047617d1 --- /dev/null +++ b/src/agency_swarm/ui/demos/realtime/app/__init__.py @@ -0,0 +1 @@ +"""FastAPI app and static assets for the realtime demo.""" diff --git a/src/agency_swarm/ui/demos/realtime/app/server.py b/src/agency_swarm/ui/demos/realtime/app/server.py new file mode 100644 index 00000000..05484853 --- /dev/null +++ b/src/agency_swarm/ui/demos/realtime/app/server.py @@ -0,0 +1,370 @@ +import asyncio +import base64 +import json +import logging +import struct +from contextlib import asynccontextmanager, suppress +from pathlib import Path +from types import TracebackType +from typing import Any, Protocol, assert_never, cast + +from agents.realtime import RealtimeSession, 
RealtimeSessionEvent +from agents.realtime.config import RealtimeUserInputMessage +from agents.realtime.items import RealtimeItem +from agents.realtime.model_inputs import ( + RealtimeModelRawClientMessage, + RealtimeModelSendRawMessage, + RealtimeModelSendSessionUpdate, +) +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles + +logger = logging.getLogger(__name__) + + +class RealtimeSessionContext(Protocol): + async def __aenter__(self) -> RealtimeSession: ... + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> bool | None: ... + + +class RealtimeSessionFactory(Protocol): + @property + def default_voice(self) -> str | None: ... + + async def create_session(self, overrides: dict[str, Any] | None = None) -> RealtimeSessionContext: ... + + +class RealtimeWebSocketManager: + def __init__(self, session_factory: RealtimeSessionFactory): + self._session_factory = session_factory + self._session_voices: dict[str, str | None] = {} + self._event_tasks: dict[str, asyncio.Task[None]] = {} + + self.active_sessions: dict[str, RealtimeSession] = {} + self.session_contexts: dict[str, RealtimeSessionContext] = {} + self.websockets: dict[str, WebSocket] = {} + + async def connect(self, websocket: WebSocket, session_id: str) -> bool: + await websocket.accept() + self.websockets[session_id] = websocket + self._session_voices[session_id] = self._session_factory.default_voice + + try: + session_context = await self._session_factory.create_session() + session = await session_context.__aenter__() + except Exception: + logger.exception("Failed to initialize realtime session") + with suppress(Exception): + await websocket.close(code=1011, reason="Failed to initialize realtime session.") + self.websockets.pop(session_id, None) + self._session_voices.pop(session_id, None) + return False + + 
self.active_sessions[session_id] = session + self.session_contexts[session_id] = session_context + + self._event_tasks[session_id] = asyncio.create_task(self._process_events(session_id)) + return True + + async def disconnect(self, session_id: str) -> None: + event_task = self._event_tasks.pop(session_id, None) + if event_task is not None: + event_task.cancel() + with suppress(asyncio.CancelledError): + await event_task + + session_context = self.session_contexts.pop(session_id, None) + if session_context is not None: + with suppress(Exception): + await session_context.__aexit__(None, None, None) + + self.active_sessions.pop(session_id, None) + self.websockets.pop(session_id, None) + self._session_voices.pop(session_id, None) + + async def send_audio(self, session_id: str, audio_bytes: bytes) -> None: + session = self.active_sessions.get(session_id) + if session is None: + return + await session.send_audio(audio_bytes) + + async def send_client_event(self, session_id: str, event: dict[str, Any]) -> None: + """Send a raw client event to the underlying realtime model.""" + session = self.active_sessions.get(session_id) + if session is None: + return + + await session.model.send_event( + RealtimeModelSendRawMessage( + message=cast(RealtimeModelRawClientMessage, event), + ) + ) + + async def send_user_message(self, session_id: str, message: RealtimeUserInputMessage) -> None: + """Send a structured user message via the higher-level API (supports input_image).""" + session = self.active_sessions.get(session_id) + if session is None: + return + await session.send_message(message) + + async def interrupt(self, session_id: str) -> None: + session = self.active_sessions.get(session_id) + if session is None: + return + await session.interrupt() + + async def _process_events(self, session_id: str) -> None: + session = self.active_sessions.get(session_id) + websocket = self.websockets.get(session_id) + if session is None or websocket is None: + return + + try: + async for event 
in session: + if event.type == "agent_start": + await self._apply_voice_update(session_id, session, event) + + event_data = await self._serialize_event(event) + await websocket.send_text(json.dumps(event_data)) + except asyncio.CancelledError: + raise + except Exception: + logger.exception("Error processing events for session %s", session_id) + + async def _apply_voice_update( + self, + session_id: str, + session: RealtimeSession, + event: RealtimeSessionEvent, + ) -> None: + desired_voice = getattr(getattr(event, "agent", None), "voice", None) + if not desired_voice: + return + + current_voice = self._session_voices.get(session_id) + if desired_voice == current_voice: + return + + try: + await session.model.send_event(RealtimeModelSendSessionUpdate(session_settings={"voice": desired_voice})) + except Exception: + logger.exception("Failed to update realtime voice to %s", desired_voice) + return + + logger.info("Updated realtime voice to %s", desired_voice) + self._session_voices[session_id] = desired_voice + + def _sanitize_history_item(self, item: RealtimeItem) -> dict[str, Any]: + """Remove large binary payloads from history items while keeping transcripts.""" + item_dict = item.model_dump() + content = item_dict.get("content") + if isinstance(content, list): + sanitized_content: list[Any] = [] + for part in content: + if isinstance(part, dict): + sanitized_part = part.copy() + if sanitized_part.get("type") in {"audio", "input_audio"}: + sanitized_part.pop("audio", None) + sanitized_content.append(sanitized_part) + else: + sanitized_content.append(part) + item_dict["content"] = sanitized_content + return item_dict + + async def _serialize_event(self, event: RealtimeSessionEvent) -> dict[str, Any]: + base_event: dict[str, Any] = {"type": event.type} + + if event.type == "agent_start": + base_event["agent"] = event.agent.name + elif event.type == "agent_end": + base_event["agent"] = event.agent.name + elif event.type == "handoff": + base_event["from"] = 
event.from_agent.name + base_event["to"] = event.to_agent.name + elif event.type == "tool_start": + base_event["tool"] = event.tool.name + elif event.type == "tool_end": + base_event["tool"] = event.tool.name + base_event["output"] = str(event.output) + elif event.type == "audio": + base_event["audio"] = base64.b64encode(event.audio.data).decode("utf-8") + elif event.type == "audio_interrupted": + pass + elif event.type == "audio_end": + pass + elif event.type == "history_updated": + base_event["history"] = [self._sanitize_history_item(item) for item in event.history] + elif event.type == "history_added": + try: + base_event["item"] = self._sanitize_history_item(event.item) + except Exception: + base_event["item"] = None + elif event.type == "guardrail_tripped": + base_event["guardrail_results"] = [{"name": result.guardrail.name} for result in event.guardrail_results] + elif event.type == "raw_model_event": + base_event["raw_model_event"] = {"type": event.data.type} + elif event.type == "error": + base_event["error"] = str(event.error) if hasattr(event, "error") else "Unknown error" + elif event.type == "input_audio_timeout_triggered": + pass + else: + assert_never(event) + + return base_event + + +def create_realtime_demo_app( + session_factory: RealtimeSessionFactory, + *, + static_dir: Path | None = None, + cors_origins: list[str] | None = None, +) -> FastAPI: + """Create a FastAPI app serving the realtime demo frontend + websocket bridge.""" + + manager = RealtimeWebSocketManager(session_factory) + static_path = static_dir or (Path(__file__).resolve().parent / "static") + + @asynccontextmanager + async def lifespan(_: FastAPI): + yield + + app = FastAPI(lifespan=lifespan) + + if cors_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, + ) + + @app.websocket("/ws/{session_id}") + async def websocket_endpoint(websocket: WebSocket, session_id: str) -> None: + 
connected = await manager.connect(websocket, session_id) + if not connected: + return + + image_buffers: dict[str, dict[str, Any]] = {} + try: + while True: + data = await websocket.receive_text() + message = json.loads(data) + message_type = message.get("type") + + if message_type == "audio": + int16_data = message.get("data") or [] + audio_bytes = struct.pack(f"{len(int16_data)}h", *int16_data) + await manager.send_audio(session_id, audio_bytes) + elif message_type == "image": + logger.info("Received image message from client (session %s).", session_id) + data_url = message.get("data_url") + prompt_text = message.get("text") or "Please describe this image." + if not data_url: + await websocket.send_text( + json.dumps({"type": "error", "error": "No data_url for image message."}) + ) + continue + + user_msg: RealtimeUserInputMessage = { + "type": "message", + "role": "user", + "content": ( + [ + {"type": "input_image", "image_url": data_url, "detail": "high"}, + {"type": "input_text", "text": prompt_text}, + ] + if prompt_text + else [{"type": "input_image", "image_url": data_url, "detail": "high"}] + ), + } + await manager.send_user_message(session_id, user_msg) + await websocket.send_text( + json.dumps({"type": "client_info", "info": "image_enqueued", "size": len(data_url)}) + ) + elif message_type == "commit_audio": + await manager.send_client_event(session_id, {"type": "input_audio_buffer.commit"}) + elif message_type == "image_start": + img_id = str(message.get("id")) + image_buffers[img_id] = { + "text": message.get("text") or "Please describe this image.", + "chunks": [], + } + await websocket.send_text( + json.dumps({"type": "client_info", "info": "image_start_ack", "id": img_id}) + ) + elif message_type == "image_chunk": + img_id = str(message.get("id")) + chunk = message.get("chunk", "") + if img_id in image_buffers: + image_buffers[img_id]["chunks"].append(chunk) + if len(image_buffers[img_id]["chunks"]) % 10 == 0: + await websocket.send_text( + 
json.dumps( + { + "type": "client_info", + "info": "image_chunk_ack", + "id": img_id, + "count": len(image_buffers[img_id]["chunks"]), + } + ) + ) + elif message_type == "image_end": + img_id = str(message.get("id")) + buf = image_buffers.pop(img_id, None) + if buf is None: + await websocket.send_text( + json.dumps({"type": "error", "error": "Unknown image id for image_end."}) + ) + continue + + data_url = "".join(buf["chunks"]) if buf["chunks"] else None + prompt_text = buf["text"] + if not data_url: + await websocket.send_text(json.dumps({"type": "error", "error": "Empty image."})) + continue + + user_msg = { + "type": "message", + "role": "user", + "content": ( + [ + {"type": "input_image", "image_url": data_url, "detail": "high"}, + {"type": "input_text", "text": prompt_text}, + ] + if prompt_text + else [{"type": "input_image", "image_url": data_url, "detail": "high"}] + ), + } + await manager.send_user_message(session_id, cast(RealtimeUserInputMessage, user_msg)) + await websocket.send_text( + json.dumps( + { + "type": "client_info", + "info": "image_enqueued", + "id": img_id, + "size": len(data_url), + } + ) + ) + elif message_type == "interrupt": + await manager.interrupt(session_id) + except WebSocketDisconnect: + pass + finally: + await manager.disconnect(session_id) + + app.mount("/", StaticFiles(directory=str(static_path), html=True), name="static") + + return app + + +if __name__ == "__main__": + raise SystemExit("Use agency_swarm.ui.demos.realtime.RealtimeDemoLauncher.start() to run the realtime demo.") diff --git a/src/agency_swarm/ui/demos/realtime/app/static/app.js b/src/agency_swarm/ui/demos/realtime/app/static/app.js new file mode 100644 index 00000000..0724cf4b --- /dev/null +++ b/src/agency_swarm/ui/demos/realtime/app/static/app.js @@ -0,0 +1,682 @@ +class RealtimeDemo { + constructor() { + this.ws = null; + this.isConnected = false; + this.isMuted = false; + this.isCapturing = false; + this.audioContext = null; + this.captureSource = null; + 
this.captureNode = null; + this.stream = null; + this.sessionId = this.generateSessionId(); + + this.isPlayingAudio = false; + this.playbackAudioContext = null; + this.playbackNode = null; + this.playbackInitPromise = null; + this.pendingPlaybackChunks = []; + this.playbackFadeSec = 0.02; // ~20ms fade to reduce clicks + this.messageNodes = new Map(); // item_id -> DOM node + this.seenItemIds = new Set(); // item_id set for append-only syncing + + this.initializeElements(); + this.setupEventListeners(); + } + + initializeElements() { + this.connectBtn = document.getElementById('connectBtn'); + this.muteBtn = document.getElementById('muteBtn'); + this.imageBtn = document.getElementById('imageBtn'); + this.imageInput = document.getElementById('imageInput'); + this.imagePrompt = document.getElementById('imagePrompt'); + this.status = document.getElementById('status'); + this.messagesContent = document.getElementById('messagesContent'); + this.eventsContent = document.getElementById('eventsContent'); + this.toolsContent = document.getElementById('toolsContent'); + } + + setupEventListeners() { + this.connectBtn.addEventListener('click', () => { + if (this.isConnected) { + this.disconnect(); + } else { + this.connect(); + } + }); + + this.muteBtn.addEventListener('click', () => { + this.toggleMute(); + }); + + // Image upload + this.imageBtn.addEventListener('click', (e) => { + e.preventDefault(); + e.stopPropagation(); + console.log('Send Image clicked'); + // Programmatically open the hidden file input + this.imageInput.click(); + }); + + this.imageInput.addEventListener('change', async (e) => { + console.log('Image input change fired'); + const file = e.target.files && e.target.files[0]; + if (!file) return; + await this._handlePickedFile(file); + this.imageInput.value = ''; + }); + + this._handlePickedFile = async (file) => { + try { + const dataUrl = await this.prepareDataURL(file); + const promptText = (this.imagePrompt && this.imagePrompt.value) || ''; + // Send 
to server; server forwards to Realtime API. + // Use chunked frames to avoid WS frame limits. + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + console.log('Interrupting and sending image (chunked) to server WebSocket'); + // Stop any current audio locally and tell model to interrupt + this.stopAudioPlayback(); + this.ws.send(JSON.stringify({ type: 'interrupt' })); + const id = 'img_' + Math.random().toString(36).slice(2); + const CHUNK = 60_000; // ~60KB per frame + this.ws.send(JSON.stringify({ type: 'image_start', id, text: promptText })); + for (let i = 0; i < dataUrl.length; i += CHUNK) { + const chunk = dataUrl.slice(i, i + CHUNK); + this.ws.send(JSON.stringify({ type: 'image_chunk', id, chunk })); + } + this.ws.send(JSON.stringify({ type: 'image_end', id })); + } else { + console.warn('Not connected; image will not be sent. Click Connect first.'); + } + // Add to UI immediately for better feedback + console.log('Adding local user image bubble'); + this.addUserImageMessage(dataUrl, promptText); + } catch (err) { + console.error('Failed to process image:', err); + } + }; + } + + generateSessionId() { + return 'session_' + Math.random().toString(36).substr(2, 9); + } + + async connect() { + try { + this.ws = new WebSocket(`ws://localhost:8000/ws/${this.sessionId}`); + + this.ws.onopen = () => { + this.isConnected = true; + this.updateConnectionUI(); + this.startContinuousCapture(); + }; + + this.ws.onmessage = (event) => { + const data = JSON.parse(event.data); + this.handleRealtimeEvent(data); + }; + + this.ws.onclose = () => { + this.isConnected = false; + this.updateConnectionUI(); + }; + + this.ws.onerror = (error) => { + console.error('WebSocket error:', error); + }; + + } catch (error) { + console.error('Failed to connect:', error); + } + } + + disconnect() { + if (this.ws) { + this.ws.close(); + } + this.stopContinuousCapture(); + } + + updateConnectionUI() { + if (this.isConnected) { + this.connectBtn.textContent = 'Disconnect'; + 
this.connectBtn.className = 'connect-btn connected'; + this.status.textContent = 'Connected'; + this.status.className = 'status connected'; + this.muteBtn.disabled = false; + } else { + this.connectBtn.textContent = 'Connect'; + this.connectBtn.className = 'connect-btn disconnected'; + this.status.textContent = 'Disconnected'; + this.status.className = 'status disconnected'; + this.muteBtn.disabled = true; + } + } + + toggleMute() { + this.isMuted = !this.isMuted; + this.updateMuteUI(); + } + + updateMuteUI() { + if (this.isMuted) { + this.muteBtn.textContent = '🔇 Mic Off'; + this.muteBtn.className = 'mute-btn muted'; + } else { + this.muteBtn.textContent = '🎤 Mic On'; + this.muteBtn.className = 'mute-btn unmuted'; + if (this.isCapturing) { + this.muteBtn.classList.add('active'); + } + } + } + + readFileAsDataURL(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result); + reader.onerror = reject; + reader.readAsDataURL(file); + }); + } + + async prepareDataURL(file) { + const original = await this.readFileAsDataURL(file); + try { + const img = new Image(); + img.decoding = 'async'; + const loaded = new Promise((res, rej) => { + img.onload = () => res(); + img.onerror = rej; + }); + img.src = original; + await loaded; + + const maxDim = 1024; + const maxSide = Math.max(img.width, img.height); + const scale = maxSide > maxDim ? 
(maxDim / maxSide) : 1; + const w = Math.max(1, Math.round(img.width * scale)); + const h = Math.max(1, Math.round(img.height * scale)); + + const canvas = document.createElement('canvas'); + canvas.width = w; canvas.height = h; + const ctx = canvas.getContext('2d'); + ctx.drawImage(img, 0, 0, w, h); + return canvas.toDataURL('image/jpeg', 0.85); + } catch (e) { + console.warn('Image resize failed; sending original', e); + return original; + } + } + + async startContinuousCapture() { + if (!this.isConnected || this.isCapturing) return; + + // Check if getUserMedia is available + if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) { + throw new Error('getUserMedia not available. Please use HTTPS or localhost.'); + } + + try { + this.stream = await navigator.mediaDevices.getUserMedia({ + audio: { + sampleRate: 24000, + channelCount: 1, + echoCancellation: true, + noiseSuppression: true + } + }); + + this.audioContext = new AudioContext({ sampleRate: 24000, latencyHint: 'interactive' }); + if (this.audioContext.state === 'suspended') { + try { await this.audioContext.resume(); } catch {} + } + + if (!this.audioContext.audioWorklet) { + throw new Error('AudioWorklet API not supported in this browser.'); + } + + await this.audioContext.audioWorklet.addModule('audio-recorder.worklet.js'); + + this.captureSource = this.audioContext.createMediaStreamSource(this.stream); + this.captureNode = new AudioWorkletNode(this.audioContext, 'pcm-recorder'); + + this.captureNode.port.onmessage = (event) => { + if (this.isMuted) return; + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) return; + + const chunk = event.data instanceof ArrayBuffer ? 
new Int16Array(event.data) : event.data; + if (!chunk || !(chunk instanceof Int16Array) || chunk.length === 0) return; + + this.ws.send(JSON.stringify({ + type: 'audio', + data: Array.from(chunk) + })); + }; + + this.captureSource.connect(this.captureNode); + this.captureNode.connect(this.audioContext.destination); + + this.isCapturing = true; + this.updateMuteUI(); + + } catch (error) { + console.error('Failed to start audio capture:', error); + } + } + + stopContinuousCapture() { + if (!this.isCapturing) return; + + this.isCapturing = false; + + if (this.captureSource) { + try { this.captureSource.disconnect(); } catch {} + this.captureSource = null; + } + + if (this.captureNode) { + this.captureNode.port.onmessage = null; + try { this.captureNode.disconnect(); } catch {} + this.captureNode = null; + } + + if (this.audioContext) { + this.audioContext.close(); + this.audioContext = null; + } + + if (this.stream) { + this.stream.getTracks().forEach(track => track.stop()); + this.stream = null; + } + + this.updateMuteUI(); + } + + handleRealtimeEvent(event) { + // Add to raw events pane + this.addRawEvent(event); + + // Add to tools panel if it's a tool or handoff event + if (event.type === 'tool_start' || event.type === 'tool_end' || event.type === 'handoff') { + this.addToolEvent(event); + } + + // Handle specific event types + switch (event.type) { + case 'audio': + this.playAudio(event.audio); + break; + case 'audio_interrupted': + this.stopAudioPlayback(); + break; + case 'input_audio_timeout_triggered': + // Ask server to commit the input buffer to expedite model response + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify({ type: 'commit_audio' })); + } + break; + case 'history_updated': + this.syncMissingFromHistory(event.history); + this.updateLastMessageFromHistory(event.history); + break; + case 'history_added': + // Append just the new item without clearing the thread. 
+ if (event.item) { + this.addMessageFromItem(event.item); + } + break; + } + } + updateLastMessageFromHistory(history) { + if (!history || !Array.isArray(history) || history.length === 0) return; + // Find the last message item in history + let last = null; + for (let i = history.length - 1; i >= 0; i--) { + const it = history[i]; + if (it && it.type === 'message') { last = it; break; } + } + if (!last) return; + const itemId = last.item_id; + + // Extract a text representation (for assistant transcript updates) + let text = ''; + if (Array.isArray(last.content)) { + for (const part of last.content) { + if (!part || typeof part !== 'object') continue; + if (part.type === 'text' && part.text) text += part.text; + else if (part.type === 'input_text' && part.text) text += part.text; + else if ((part.type === 'input_audio' || part.type === 'audio') && part.transcript) text += part.transcript; + } + } + + const node = this.messageNodes.get(itemId); + if (!node) { + // If we haven't rendered this item yet, append it now. + this.addMessageFromItem(last); + return; + } + + // Update only the text content of the bubble, preserving any images already present. + const bubble = node.querySelector('.message-bubble'); + if (bubble && text && text.trim()) { + // If there's an , keep it and only update the trailing caption/text node. 
+ const hasImg = !!bubble.querySelector('img'); + if (hasImg) { + // Ensure there is a caption div after the image + let cap = bubble.querySelector('.image-caption'); + if (!cap) { + cap = document.createElement('div'); + cap.className = 'image-caption'; + cap.style.marginTop = '0.5rem'; + bubble.appendChild(cap); + } + cap.textContent = text.trim(); + } else { + bubble.textContent = text.trim(); + } + this.scrollToBottom(); + } + } + + syncMissingFromHistory(history) { + if (!history || !Array.isArray(history)) return; + for (const item of history) { + if (!item || item.type !== 'message') continue; + const id = item.item_id; + if (!id) continue; + if (!this.seenItemIds.has(id)) { + this.addMessageFromItem(item); + } + } + } + + addMessageFromItem(item) { + try { + if (!item || item.type !== 'message') return; + const role = item.role; + let content = ''; + let imageUrls = []; + + if (Array.isArray(item.content)) { + for (const contentPart of item.content) { + if (!contentPart || typeof contentPart !== 'object') continue; + if (contentPart.type === 'text' && contentPart.text) { + content += contentPart.text; + } else if (contentPart.type === 'input_text' && contentPart.text) { + content += contentPart.text; + } else if (contentPart.type === 'input_audio' && contentPart.transcript) { + content += contentPart.transcript; + } else if (contentPart.type === 'audio' && contentPart.transcript) { + content += contentPart.transcript; + } else if (contentPart.type === 'input_image') { + const url = contentPart.image_url || contentPart.url; + if (typeof url === 'string' && url) imageUrls.push(url); + } + } + } + + let node = null; + if (imageUrls.length > 0) { + for (const url of imageUrls) { + node = this.addImageMessage(role, url, content.trim()); + } + } else if (content && content.trim()) { + node = this.addMessage(role, content.trim()); + } + if (node && item.item_id) { + this.messageNodes.set(item.item_id, node); + this.seenItemIds.add(item.item_id); + } + } catch (e) 
{ + console.error('Failed to add message from item:', e, item); + } + } + + addMessage(type, content) { + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${type}`; + + const bubbleDiv = document.createElement('div'); + bubbleDiv.className = 'message-bubble'; + bubbleDiv.textContent = content; + + messageDiv.appendChild(bubbleDiv); + this.messagesContent.appendChild(messageDiv); + this.scrollToBottom(); + + return messageDiv; + } + + addImageMessage(role, imageUrl, caption = '') { + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${role}`; + + const bubbleDiv = document.createElement('div'); + bubbleDiv.className = 'message-bubble'; + + const img = document.createElement('img'); + img.src = imageUrl; + img.alt = 'Uploaded image'; + img.style.maxWidth = '220px'; + img.style.borderRadius = '8px'; + img.style.display = 'block'; + + bubbleDiv.appendChild(img); + if (caption) { + const cap = document.createElement('div'); + cap.textContent = caption; + cap.style.marginTop = '0.5rem'; + bubbleDiv.appendChild(cap); + } + + messageDiv.appendChild(bubbleDiv); + this.messagesContent.appendChild(messageDiv); + this.scrollToBottom(); + + return messageDiv; + } + + addUserImageMessage(imageUrl, caption = '') { + return this.addImageMessage('user', imageUrl, caption); + } + + addRawEvent(event) { + const eventDiv = document.createElement('div'); + eventDiv.className = 'event'; + + const headerDiv = document.createElement('div'); + headerDiv.className = 'event-header'; + headerDiv.innerHTML = ` + ${event.type} + + `; + + const contentDiv = document.createElement('div'); + contentDiv.className = 'event-content collapsed'; + contentDiv.textContent = JSON.stringify(event, null, 2); + + headerDiv.addEventListener('click', () => { + const isCollapsed = contentDiv.classList.contains('collapsed'); + contentDiv.classList.toggle('collapsed'); + headerDiv.querySelector('span:last-child').textContent = isCollapsed ? 
'▲' : '▼'; + }); + + eventDiv.appendChild(headerDiv); + eventDiv.appendChild(contentDiv); + this.eventsContent.appendChild(eventDiv); + + // Auto-scroll events pane + this.eventsContent.scrollTop = this.eventsContent.scrollHeight; + } + + addToolEvent(event) { + const eventDiv = document.createElement('div'); + eventDiv.className = 'event'; + + let title = ''; + let description = ''; + let eventClass = ''; + + if (event.type === 'handoff') { + title = `🔄 Handoff`; + description = `From ${event.from} to ${event.to}`; + eventClass = 'handoff'; + } else if (event.type === 'tool_start') { + title = `🔧 Tool Started`; + description = `Running ${event.tool}`; + eventClass = 'tool'; + } else if (event.type === 'tool_end') { + title = `✅ Tool Completed`; + description = `${event.tool}: ${event.output || 'No output'}`; + eventClass = 'tool'; + } + + eventDiv.innerHTML = ` +
+
+
${title}
+
${description}
+
+ ${new Date().toLocaleTimeString()} +
+ `; + + this.toolsContent.appendChild(eventDiv); + + // Auto-scroll tools pane + this.toolsContent.scrollTop = this.toolsContent.scrollHeight; + } + + async playAudio(audioBase64) { + try { + if (!audioBase64 || audioBase64.length === 0) { + console.warn('Received empty audio data, skipping playback'); + return; + } + + const int16Array = this.decodeBase64ToInt16(audioBase64); + if (!int16Array || int16Array.length === 0) { + console.warn('Audio chunk has no samples, skipping'); + return; + } + + this.pendingPlaybackChunks.push(int16Array); + await this.ensurePlaybackNode(); + this.flushPendingPlaybackChunks(); + + } catch (error) { + console.error('Failed to play audio:', error); + this.pendingPlaybackChunks = []; + } + } + + async ensurePlaybackNode() { + if (this.playbackNode) { + return; + } + + if (!this.playbackInitPromise) { + this.playbackInitPromise = (async () => { + if (!this.playbackAudioContext) { + this.playbackAudioContext = new AudioContext({ sampleRate: 24000, latencyHint: 'interactive' }); + } + + if (this.playbackAudioContext.state === 'suspended') { + try { await this.playbackAudioContext.resume(); } catch {} + } + + if (!this.playbackAudioContext.audioWorklet) { + throw new Error('AudioWorklet API not supported in this browser.'); + } + + await this.playbackAudioContext.audioWorklet.addModule('audio-playback.worklet.js'); + + this.playbackNode = new AudioWorkletNode(this.playbackAudioContext, 'pcm-playback', { outputChannelCount: [1] }); + this.playbackNode.port.onmessage = (event) => { + const message = event.data; + if (!message || typeof message !== 'object') return; + if (message.type === 'drained') { + this.isPlayingAudio = false; + } + }; + + // Provide initial configuration for fades. 
+ const fadeSamples = Math.floor(this.playbackAudioContext.sampleRate * this.playbackFadeSec); + this.playbackNode.port.postMessage({ type: 'config', fadeSamples }); + + this.playbackNode.connect(this.playbackAudioContext.destination); + })().catch((error) => { + this.playbackInitPromise = null; + throw error; + }); + } + + await this.playbackInitPromise; + } + + flushPendingPlaybackChunks() { + if (!this.playbackNode) { + return; + } + + while (this.pendingPlaybackChunks.length > 0) { + const chunk = this.pendingPlaybackChunks.shift(); + if (!chunk || !(chunk instanceof Int16Array) || chunk.length === 0) { + continue; + } + + try { + this.playbackNode.port.postMessage( + { type: 'chunk', payload: chunk.buffer }, + [chunk.buffer] + ); + this.isPlayingAudio = true; + } catch (error) { + console.error('Failed to enqueue audio chunk to worklet:', error); + } + } + } + + decodeBase64ToInt16(audioBase64) { + try { + const binaryString = atob(audioBase64); + const length = binaryString.length; + const bytes = new Uint8Array(length); + for (let i = 0; i < length; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + return new Int16Array(bytes.buffer); + } catch (error) { + console.error('Failed to decode audio chunk:', error); + return null; + } + } + + stopAudioPlayback() { + console.log('Stopping audio playback due to interruption'); + + this.pendingPlaybackChunks = []; + + if (this.playbackNode) { + try { + this.playbackNode.port.postMessage({ type: 'stop' }); + } catch (error) { + console.error('Failed to notify playback worklet to stop:', error); + } + } + + this.isPlayingAudio = false; + + console.log('Audio playback stopped and queue cleared'); + } + + scrollToBottom() { + this.messagesContent.scrollTop = this.messagesContent.scrollHeight; + } +} + +// Initialize the demo when the page loads +document.addEventListener('DOMContentLoaded', () => { + new RealtimeDemo(); +}); diff --git a/src/agency_swarm/ui/demos/realtime/app/static/audio-playback.worklet.js 
/**
 * AudioWorklet processor that plays queued 16-bit PCM chunks on channel 0 of
 * its first output.
 *
 * Messages accepted on `port`:
 *   { type: 'chunk', payload: ArrayBuffer }  Int16 samples to enqueue
 *   { type: 'stop' }                         drop all queued audio, reply 'drained'
 *   { type: 'config', fadeSamples: number }  fade-in length for a chunk that
 *                                            arrives while nothing is queued
 *
 * Messages posted on `port`:
 *   { type: 'drained' }  the queue ran empty after audio was played, or a
 *                        'stop' was processed
 */
class PCMPlaybackProcessor extends AudioWorkletProcessor {
    constructor() {
        super();

        this.buffers = [];
        this.currentBuffer = null;
        this.currentIndex = 0;
        this.isCurrentlyPlaying = false;
        // `sampleRate` is a global of the worklet scope; default fade is 20 ms.
        this.fadeSamples = Math.round(sampleRate * 0.02);

        this.port.onmessage = (event) => this.handleMessage(event.data);
    }

    /** Dispatch one control message received from the main thread. */
    handleMessage(message) {
        if (!message || typeof message !== 'object') return;

        switch (message.type) {
            case 'chunk':
                this.enqueueChunk(message.payload);
                break;
            case 'stop':
                this.reset();
                this.port.postMessage({ type: 'drained' });
                break;
            case 'config': {
                const fadeSamples = message.fadeSamples;
                if (Number.isFinite(fadeSamples) && fadeSamples >= 0) {
                    this.fadeSamples = fadeSamples >>> 0;
                }
                break;
            }
            default:
                break;
        }
    }

    /** Convert an Int16 payload to floats in [-1, 1) and append it to the queue. */
    enqueueChunk(payload) {
        if (!(payload instanceof ArrayBuffer)) {
            return;
        }

        // Only whole samples are usable; viewing an odd-length buffer as Int16
        // without an explicit length would throw inside the message handler.
        const sampleCount = payload.byteLength >> 1;
        if (sampleCount === 0) {
            return;
        }

        const int16Data = new Int16Array(payload, 0, sampleCount);
        const floatData = new Float32Array(sampleCount);
        const scale = 1 / 32768;
        for (let i = 0; i < sampleCount; i++) {
            floatData[i] = int16Data[i] * scale;
        }

        // Ramp the gain up only when playback starts from an empty queue.
        if (!this.hasPendingAudio()) {
            const fadeLength = Math.min(this.fadeSamples, sampleCount);
            for (let i = 0; i < fadeLength; i++) {
                const gain = fadeLength <= 1 ? 1 : (i / fadeLength);
                floatData[i] *= gain;
            }
        }

        this.buffers.push(floatData);
    }

    /** Drop all queued audio and return to the idle state. */
    reset() {
        this.buffers = [];
        this.currentBuffer = null;
        this.currentIndex = 0;
        this.isCurrentlyPlaying = false;
    }

    /** True while any unplayed sample remains in the current or queued buffers. */
    hasPendingAudio() {
        if (this.currentBuffer && this.currentIndex < this.currentBuffer.length) {
            return true;
        }
        return this.buffers.length > 0;
    }

    /**
     * Make `currentBuffer` point at a buffer with unread samples.
     * @returns {boolean} false when the queue is exhausted.
     */
    loadNextBuffer() {
        while (!this.currentBuffer || this.currentIndex >= this.currentBuffer.length) {
            if (this.buffers.length === 0) {
                this.currentBuffer = null;
                this.currentIndex = 0;
                return false;
            }
            this.currentBuffer = this.buffers.shift();
            this.currentIndex = 0;
        }
        return true;
    }

    /**
     * Fill one render quantum from the queue, padding with silence.
     * Always returns true so the processor stays alive between utterances.
     */
    process(inputs, outputs) {
        const output = outputs[0];
        if (!output || output.length === 0) {
            return true;
        }

        const channel = output[0];
        let written = 0;

        while (written < channel.length && this.loadNextBuffer()) {
            const available = this.currentBuffer.length - this.currentIndex;
            const count = Math.min(channel.length - written, available);
            channel.set(
                this.currentBuffer.subarray(this.currentIndex, this.currentIndex + count),
                written
            );
            this.currentIndex += count;
            written += count;
        }
        channel.fill(0, written);

        // Track playback by samples actually consumed, not by sample value.
        // Otherwise a chunk that fits in a single quantum never marks the
        // processor as playing and 'drained' is never reported for it.
        if (written > 0) {
            this.isCurrentlyPlaying = true;
        }
        if (this.isCurrentlyPlaying && !this.hasPendingAudio()) {
            this.isCurrentlyPlaying = false;
            this.port.postMessage({ type: 'drained' });
        }

        return true;
    }
}

registerProcessor('pcm-playback', PCMPlaybackProcessor);
/**
 * AudioWorklet processor that converts channel 0 of its first input to
 * 16-bit PCM and posts the samples to the main thread as transferable
 * Int16Array chunks.
 *
 * A chunk is posted when the staging buffer reaches `chunkSize` samples, or
 * after `maxPendingFrames` consecutive process() calls that left samples
 * staged, whichever comes first.
 */
class PCMRecorderProcessor extends AudioWorkletProcessor {
    constructor() {
        super();
        this.chunkSize = 4096;
        this.buffer = new Int16Array(this.chunkSize);
        this.offset = 0;
        this.pendingFrames = 0;
        this.maxPendingFrames = 10;
    }

    /** Post the staged samples (if any) and reset the staging counters. */
    flushBuffer() {
        if (this.offset === 0) {
            return;
        }

        // slice() copies, so the staging buffer survives the transfer below.
        const staged = this.buffer.slice(0, this.offset);
        this.port.postMessage(staged, [staged.buffer]);

        this.offset = 0;
        this.pendingFrames = 0;
    }

    /** Stage one render quantum of input; always returns true to stay alive. */
    process(inputs) {
        const firstInput = inputs[0];
        const channel = firstInput && firstInput.length > 0 ? firstInput[0] : null;
        if (!channel || channel.length === 0) {
            return true;
        }

        for (const raw of channel) {
            const clamped = Math.min(1, Math.max(-1, raw));
            // Negative and positive halves use different scales so both -1 and
            // +1 land exactly on the Int16 limits.
            const fullScale = clamped < 0 ? 0x8000 : 0x7fff;
            this.buffer[this.offset] = clamped * fullScale;
            this.offset += 1;

            if (this.offset === this.chunkSize) {
                this.flushBuffer();
            }
        }

        if (this.offset === 0) {
            return true;
        }

        this.pendingFrames += 1;
        if (this.pendingFrames >= this.maxPendingFrames) {
            this.flushBuffer();
        }

        return true;
    }
}

registerProcessor('pcm-recorder', PCMRecorderProcessor);
+

Realtime Demo

+ +
+ +
+
+
+ Conversation +
+
+ +
+
+ + + + + Disconnected +
+
+ +
+
+
+ Event stream +
+
+ +
+
+ +
+
+ Tools & Handoffs +
+
+ +
+
+
+
+ + + + diff --git a/src/agency_swarm/ui/demos/realtime/twilio/README.md b/src/agency_swarm/ui/demos/realtime/twilio/README.md new file mode 100644 index 00000000..4eaf4e9e --- /dev/null +++ b/src/agency_swarm/ui/demos/realtime/twilio/README.md @@ -0,0 +1,86 @@ +# Realtime Twilio Integration + +This example demonstrates how to connect the OpenAI Realtime API to a phone call using Twilio's Media Streams. The server handles incoming phone calls and streams audio between Twilio and the OpenAI Realtime API, enabling real-time voice conversations with an AI agent over the phone. + +## Prerequisites + +- Python 3.12+ +- OpenAI API key with [Realtime API](https://platform.openai.com/docs/guides/realtime) access +- [Twilio](https://www.twilio.com/docs/voice) account with a phone number +- A tunneling service like [ngrok](https://ngrok.com/) to expose your local server + +## Setup + +1. **Start the server:** + + ```bash + uv run python server.py + ``` + + The server will start on port 8000 by default. + +2. **Expose the server publicly, e.g. via ngrok:** + + ```bash + ngrok http 8000 + ``` + + Note the public URL (e.g., `https://abc123.ngrok.io`) + +3. **Configure your Twilio phone number:** + - Log into your Twilio Console + - Select your phone number + - Set the webhook URL for incoming calls to: `https://your-ngrok-url.ngrok.io/incoming-call` + - Set the HTTP method to POST + +## Usage + +1. Call your Twilio phone number +2. You'll hear: "Hello! You're now connected to an AI assistant. You can start talking!" +3. Start speaking - the AI will respond in real-time +4. The assistant has access to tools like weather information and current time + +## How It Works + +1. **Incoming Call**: When someone calls your Twilio number, Twilio makes a request to `/incoming-call` +2. **TwiML Response**: The server returns TwiML that: + - Plays a greeting message + - Connects the call to a WebSocket stream at `/media-stream` +3. 
**WebSocket Connection**: Twilio establishes a WebSocket connection for bidirectional audio streaming +4. **Transport Layer**: The `TwilioHandler` class (in `twilio_handler.py`) owns the WebSocket message handling: + - Takes ownership of the Twilio WebSocket after initial handshake + - Runs its own message loop to process all Twilio messages + - Handles protocol differences between Twilio and OpenAI + - Automatically sets G.711 μ-law audio format for Twilio compatibility + - Manages audio chunk tracking for interruption support + - Wraps an OpenAI `RealtimeSession` instead of subclassing the realtime model +5. **Audio Processing**: + - Audio from the caller is base64 decoded and sent to OpenAI Realtime API + - Audio responses from OpenAI are base64 encoded and sent back to Twilio + - Twilio plays the audio to the caller + +## Configuration + +- **Port**: Set `PORT` environment variable (default: 8000) +- **OpenAI API Key**: Set `OPENAI_API_KEY` environment variable +- **Agent Instructions**: Modify the `RealtimeAgent` configuration in `twilio_handler.py` +- **Tools**: Add or modify function tools in `twilio_handler.py` + +## Troubleshooting + +- **WebSocket connection issues**: Ensure your ngrok URL is correct and publicly accessible +- **Audio quality**: Twilio streams audio in mulaw format at 8kHz, which may affect quality +- **Latency**: Network latency between Twilio, your server, and OpenAI affects response time +- **Logs**: Check the console output for detailed connection and error logs + +## Architecture + +``` +Phone Call → Twilio → WebSocket → TwilioHandler → OpenAI Realtime API + ↓ + RealtimeAgent with Tools + ↓ + Audio Response → Twilio → Phone Call +``` + +The `TwilioHandler` acts as a bridge between Twilio's Media Streams and OpenAI's Realtime API, handling the protocol differences and audio format conversions. It wraps an OpenAI realtime session to provide a clean interface for Twilio integration. 
class TwilioWebSocketManager:
    """Factory for per-call ``TwilioHandler`` objects.

    ``active_handlers`` is created empty; nothing in this class adds to it.
    """

    def __init__(self):
        # In a real app, you'd also want to clean up/close the handler when the call ends
        self.active_handlers: dict[str, TwilioHandler] = dict()

    async def new_session(self, websocket: WebSocket) -> TwilioHandler:
        """Build and return a handler bound to ``websocket``.

        The handler is only constructed here; starting it is left to the caller.
        """
        print("Creating twilio handler")
        return TwilioHandler(websocket)
class TwilioHandler:
    """Bridge one Twilio Media Streams WebSocket to a realtime agent session.

    Caller audio arrives as base64 mu-law in Twilio ``media`` messages, is
    buffered, and is forwarded unchanged to the session. Audio events from the
    session are base64-encoded and sent back to Twilio, each followed by a
    ``mark`` message; when Twilio echoes the mark, the playback tracker is
    credited with that chunk's byte count.
    """

    def __init__(self, twilio_websocket: WebSocket):
        self.twilio_websocket = twilio_websocket
        self._message_loop_task: asyncio.Task[None] | None = None
        self._realtime_session_task: asyncio.Task[None] | None = None
        self._buffer_flush_task: asyncio.Task[None] | None = None
        self.session: RealtimeSession | None = None
        self.playback_tracker = RealtimePlaybackTracker()

        # Audio buffering configuration (matching CLI demo)
        self.CHUNK_LENGTH_S = 0.05  # 50ms chunks like CLI demo
        self.SAMPLE_RATE = 8000  # Twilio uses 8kHz for g711_ulaw
        self.BUFFER_SIZE_BYTES = int(self.SAMPLE_RATE * self.CHUNK_LENGTH_S)  # 50ms worth of audio

        self._stream_sid: str | None = None
        self._audio_buffer: bytearray = bytearray()
        # Monotonic clock: this value is only used to measure elapsed time.
        self._last_buffer_send_time = time.monotonic()

        # Mark event tracking for playback
        self._mark_counter = 0
        self._mark_data: dict[str, tuple[str, int, int]] = {}  # mark_id -> (item_id, content_index, byte_count)

    async def start(self) -> None:
        """Open the realtime session, accept the WebSocket and start the loops.

        Raises:
            ValueError: If ``OPENAI_API_KEY`` is not set.
        """
        runner = RealtimeRunner(agent)
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is required")

        self.session = await runner.run(
            model_config={
                "api_key": api_key,
                "initial_model_settings": {
                    "input_audio_format": "g711_ulaw",
                    "output_audio_format": "g711_ulaw",
                    "turn_detection": {
                        "type": "semantic_vad",
                        "interrupt_response": True,
                        "create_response": True,
                    },
                },
                "playback_tracker": self.playback_tracker,
            }
        )

        await self.session.enter()

        await self.twilio_websocket.accept()
        print("Twilio WebSocket connection accepted")

        self._realtime_session_task = asyncio.create_task(self._realtime_session_loop())
        self._message_loop_task = asyncio.create_task(self._twilio_message_loop())
        self._buffer_flush_task = asyncio.create_task(self._buffer_flush_loop())

    async def wait_until_done(self) -> None:
        """Wait for the Twilio message loop to end, then release resources."""
        assert self._message_loop_task is not None
        try:
            await self._message_loop_task
        finally:
            # The other loops never finish on their own once the call is over.
            await self._cleanup()

    async def _cleanup(self) -> None:
        """Cancel the background loops and close the realtime session."""
        background = [
            task
            for task in (self._realtime_session_task, self._buffer_flush_task)
            if task is not None and not task.done()
        ]
        for task in background:
            task.cancel()
        if background:
            await asyncio.gather(*background, return_exceptions=True)

        if self.session is not None:
            try:
                await self.session.close()
            except Exception as e:
                print(f"Error closing realtime session: {e}")

    async def _realtime_session_loop(self) -> None:
        """Listen for events from the realtime session."""
        assert self.session is not None
        try:
            async for event in self.session:
                await self._handle_realtime_event(event)
        except Exception as e:
            print(f"Error in realtime session loop: {e}")

    async def _twilio_message_loop(self) -> None:
        """Listen for messages from Twilio WebSocket and handle them."""
        try:
            while True:
                message_text = await self.twilio_websocket.receive_text()
                message = json.loads(message_text)
                await self._handle_twilio_message(message)
        except json.JSONDecodeError as e:
            print(f"Failed to parse Twilio message as JSON: {e}")
        except Exception as e:
            print(f"Error in Twilio message loop: {e}")

    async def _handle_realtime_event(self, event: RealtimeSessionEvent) -> None:
        """Forward session audio to Twilio and relay interruptions."""
        if event.type == "audio":
            base64_audio = base64.b64encode(event.audio.data).decode("utf-8")
            await self.twilio_websocket.send_text(
                json.dumps(
                    {
                        "event": "media",
                        "streamSid": self._stream_sid,
                        "media": {"payload": base64_audio},
                    }
                )
            )

            # Send mark event for playback tracking
            self._mark_counter += 1
            mark_id = str(self._mark_counter)
            self._mark_data[mark_id] = (
                event.audio.item_id,
                event.audio.content_index,
                len(event.audio.data),
            )

            await self.twilio_websocket.send_text(
                json.dumps(
                    {
                        "event": "mark",
                        "streamSid": self._stream_sid,
                        "mark": {"name": mark_id},
                    }
                )
            )

        elif event.type == "audio_interrupted":
            print("Sending audio interrupted to Twilio")
            await self.twilio_websocket.send_text(json.dumps({"event": "clear", "streamSid": self._stream_sid}))
        elif event.type == "audio_end":
            print("Audio end")
        # Every other event type is intentionally ignored.

    async def _handle_twilio_message(self, message: dict[str, Any]) -> None:
        """Handle incoming messages from Twilio Media Stream."""
        try:
            event = message.get("event")

            if event == "connected":
                print("Twilio media stream connected")
            elif event == "start":
                start_data = message.get("start", {})
                self._stream_sid = start_data.get("streamSid")
                print(f"Media stream started with SID: {self._stream_sid}")
            elif event == "media":
                await self._handle_media_event(message)
            elif event == "mark":
                await self._handle_mark_event(message)
            elif event == "stop":
                print("Media stream stopped")
        except Exception as e:
            print(f"Error handling Twilio message: {e}")

    async def _handle_media_event(self, message: dict[str, Any]) -> None:
        """Handle audio data from Twilio - buffer it before sending to OpenAI."""
        media = message.get("media", {})
        payload = media.get("payload", "")
        if not payload:
            return

        try:
            # Decode base64 audio from Twilio (mu-law format)
            ulaw_bytes = base64.b64decode(payload)

            # Add original mu-law to buffer for OpenAI (they expect mu-law)
            self._audio_buffer.extend(ulaw_bytes)

            # Send buffered audio if we have enough data
            if len(self._audio_buffer) >= self.BUFFER_SIZE_BYTES:
                await self._flush_audio_buffer()

        except Exception as e:
            print(f"Error processing audio from Twilio: {e}")

    async def _handle_mark_event(self, message: dict[str, Any]) -> None:
        """Handle mark events from Twilio to update playback tracker."""
        try:
            mark_data = message.get("mark", {})
            mark_id = mark_data.get("name", "")

            # Look up stored data for this mark ID
            if mark_id in self._mark_data:
                item_id, item_content_index, byte_count = self._mark_data[mark_id]

                # The real audio is not kept; placeholder bytes of the same
                # length are passed to the tracker instead.
                audio_bytes = b"\x00" * byte_count

                # Update playback tracker
                self.playback_tracker.on_play_bytes(item_id, item_content_index, audio_bytes)
                print(f"Playback tracker updated: {item_id}, index {item_content_index}, {byte_count} bytes")

                # Clean up the stored data
                del self._mark_data[mark_id]

        except Exception as e:
            print(f"Error handling mark event: {e}")

    async def _flush_audio_buffer(self) -> None:
        """Send buffered audio to OpenAI."""
        if not self._audio_buffer or not self.session:
            return

        # Detach the pending bytes before awaiting. The message loop and the
        # periodic flush loop run concurrently, so clearing after the await
        # would discard audio appended meanwhile and could send the same bytes
        # twice.
        buffer_data = bytes(self._audio_buffer)
        self._audio_buffer.clear()

        try:
            await self.session.send_audio(buffer_data)
        except Exception as e:
            # Keep the unsent audio, ahead of anything that arrived meanwhile.
            self._audio_buffer[:0] = buffer_data
            print(f"Error sending buffered audio to OpenAI: {e}")
        else:
            self._last_buffer_send_time = time.monotonic()

    async def _buffer_flush_loop(self) -> None:
        """Periodically flush audio buffer to prevent stale data."""
        try:
            while True:
                await asyncio.sleep(self.CHUNK_LENGTH_S)  # Check every 50ms

                # If buffer has data and it's been too long since last send, flush it
                current_time = time.monotonic()
                if self._audio_buffer and current_time - self._last_buffer_send_time > self.CHUNK_LENGTH_S * 2:
                    await self._flush_audio_buffer()

        except Exception as e:
            print(f"Error in buffer flush loop: {e}")
app.routes} + assert "/test_agency/realtime" in route_paths diff --git a/tests/integration/realtime/__init__.py b/tests/integration/realtime/__init__.py new file mode 100644 index 00000000..6e91be9a --- /dev/null +++ b/tests/integration/realtime/__init__.py @@ -0,0 +1,3 @@ +""" +Integration tests for realtime orchestration. +""" diff --git a/tests/integration/realtime/test_realtime.py b/tests/integration/realtime/test_realtime.py new file mode 100644 index 00000000..509992df --- /dev/null +++ b/tests/integration/realtime/test_realtime.py @@ -0,0 +1,232 @@ +import asyncio + +import pytest +from agents import RunContextWrapper + +from agency_swarm import Agency, Agent +from agency_swarm.context import MasterContext +from agency_swarm.integrations import run_realtime +from agency_swarm.tools import SendMessageHandoff +from agency_swarm.utils.thread import ThreadManager + + +def _build_simple_realtime_agency() -> Agency: + """Create a tiny agency with a concierge handing off to billing.""" + billing = Agent(name="Billing", instructions="Handle billing questions.") + concierge = Agent( + name="Concierge", + instructions="Route requests.", + send_message_tool_class=SendMessageHandoff, + ) + return Agency( + concierge, + communication_flows=[ + (concierge > billing, SendMessageHandoff), + ], + ) + + +def _make_context(agency: Agency) -> RunContextWrapper[MasterContext]: + """Build a MasterContext so we can call realtime helpers directly.""" + master = MasterContext( + thread_manager=ThreadManager(), + agents=agency.agents, + user_context=dict(agency.user_context), + agent_runtime_state=agency._agent_runtime_state, + ) + return RunContextWrapper(master) + + +def test_realtime_agency_wraps_handoffs_and_agents() -> None: + agency = _build_simple_realtime_agency() + realtime_agency = agency.to_realtime() + + concierge_rt = realtime_agency.entry_agent + handoffs = concierge_rt.handoffs + + assert realtime_agency.source is agency + assert 
realtime_agency.source_agents["Concierge"] is agency.agents["Concierge"] + assert realtime_agency.agents["Concierge"] is concierge_rt + assert realtime_agency.shared_instructions is None + assert realtime_agency.user_context == {} + assert realtime_agency.runtime_state_map["Concierge"] is agency._agent_runtime_state["Concierge"] + assert len(handoffs) == 1 + assert handoffs[0].agent_name == "Billing" + + schema = handoffs[0].input_json_schema + assert schema.get("type") == "object" + assert schema["properties"]["recipient_agent"]["const"] == "Billing" + + target = asyncio.run(handoffs[0].on_invoke_handoff(_make_context(agency), "{}")) + assert target.name == "Billing" + + +def test_realtime_agency_requires_handoffs() -> None: + primary = Agent(name="Primary", instructions="Help the user.") + helper = Agent(name="Helper", instructions="Assist.") + agency = Agency(primary, communication_flows=[(primary, helper)]) + + with pytest.raises(ValueError): + agency.to_realtime() + + +def test_realtime_agency_allows_entry_by_name() -> None: + agency = _build_simple_realtime_agency() + realtime_agency = agency.to_realtime("Concierge") + + assert realtime_agency.entry_agent.name == "Concierge" + + with pytest.raises(ValueError): + agency.to_realtime("Unknown") + + +def test_realtime_agency_allows_entry_by_agent_object() -> None: + agency = _build_simple_realtime_agency() + concierge = agency.entry_points[0] + realtime_agency = agency.to_realtime(concierge) + + assert realtime_agency.entry_agent.name == "Concierge" + + ghost = Agent(name="Ghost", instructions="Help") + with pytest.raises(ValueError): + agency.to_realtime(ghost) + + +def test_realtime_agent_handles_callable_instructions() -> None: + captured: list[str] = [] + + def dynamic_instructions(ctx: RunContextWrapper[MasterContext], agent: Agent) -> str: + captured.append(agent.name) + return f"Hello from {agent.name}" + + greeter = Agent(name="Greeter", instructions="placeholder") + greeter.instructions = 
dynamic_instructions + + agency = Agency(greeter) + realtime_agent = agency.to_realtime().entry_agent + + prompt = asyncio.run(realtime_agent.get_system_prompt(_make_context(agency))) + + assert prompt == "Hello from Greeter" + assert realtime_agent.source is greeter + assert captured == ["Greeter"] + + async def async_dynamic(ctx: RunContextWrapper[MasterContext], agent: Agent) -> str: + captured.append(f"async:{agent.name}") + return f"Async hello from {agent.name}" + + greeter.instructions = async_dynamic + async_agency = Agency(greeter) + async_realtime_agent = async_agency.to_realtime().entry_agent + async_prompt = asyncio.run(async_realtime_agent.get_system_prompt(_make_context(async_agency))) + assert async_prompt == "Async hello from Greeter" + assert async_realtime_agent.source is greeter + assert captured == ["Greeter", "async:Greeter"] + + +def test_realtime_handoff_respects_is_enabled_callable() -> None: + from agency_swarm.realtime.agency import _wrap_is_enabled # type: ignore[attr-defined] + + agency = _build_simple_realtime_agency() + + async def enabled(_: RunContextWrapper[MasterContext], agent: Agent) -> bool: + return agent.name == "Concierge" + + concierge = agency.agents["Concierge"] + wrapped = _wrap_is_enabled(enabled, concierge) # type: ignore[arg-type] + + result = asyncio.run(wrapped(_make_context(agency), agency.to_realtime().entry_agent)) + assert result is True + + def disabled(_: RunContextWrapper[MasterContext], agent: Agent) -> bool: + return agent.name != "Concierge" + + sync_wrapped = _wrap_is_enabled(disabled, concierge) # type: ignore[arg-type] + assert asyncio.run(sync_wrapped(_make_context(agency), agency.to_realtime().entry_agent)) is False + + +def test_realtime_agency_requires_entry_point() -> None: + agency = _build_simple_realtime_agency() + agency.entry_points.clear() + + with pytest.raises(ValueError): + agency.to_realtime() + + +def test_realtime_agency_missing_runtime_state_raises() -> None: + agency = 
_build_simple_realtime_agency() + agency._agent_runtime_state.pop("Concierge") + + with pytest.raises(ValueError): + agency.to_realtime() + + +def test_realtime_agency_unknown_handoff_target_raises() -> None: + agency = _build_simple_realtime_agency() + runtime_state = agency.get_agent_runtime_state("Concierge") + runtime_state.handoffs[0].agent_name = "Missing" + + with pytest.raises(ValueError): + agency.to_realtime() + + +def test_realtime_agency_missing_realtime_agent_raises() -> None: + agency = _build_simple_realtime_agency() + realtime_agency = agency.to_realtime() + realtime_agency._realtime_agents.pop("Concierge") + + with pytest.raises(ValueError): + realtime_agency._resolve_entry(agency.entry_points[0]) # type: ignore[arg-type] + + +def test_run_realtime_accepts_realtime_agency() -> None: + agency = _build_simple_realtime_agency() + realtime_agency = agency.to_realtime() + + app = run_realtime(agency=realtime_agency, voice="alloy", return_app=True) + if app is None: + pytest.skip("FastAPI extras not installed") + + assert {route.path for route in app.routes} >= {"/realtime"} + + with pytest.raises(ValueError): + run_realtime(agency=realtime_agency, entry_agent=agency.entry_points[0], return_app=True) + + +def test_run_realtime_return_app_registers_realtime_endpoint() -> None: + agent = Agent(name="Voice Agent", instructions="Be concise.") + agency = Agency(agent) + + app = run_realtime(agency=agency, voice="alloy", return_app=True) + if app is None: + pytest.skip("FastAPI extras not installed") + + assert {route.path for route in app.routes} >= {"/realtime"} + + +def test_realtime_demo_entrypoint() -> None: + import examples.interactive.realtime.demo as realtime_demo + + assert hasattr(realtime_demo, "main") + + +def test_run_realtime_defaults_voice_to_entry_agent(monkeypatch: pytest.MonkeyPatch) -> None: + agent = Agent(name="Voiceful", instructions="Respond aloud.", voice="nova") + agency = Agency(agent) + + from agency_swarm.integrations import 
realtime as realtime_module + + captured = {} + original_init = realtime_module.RealtimeSessionFactory.__init__ + + def capture_init(self, realtime_agency, base_model_settings): + captured["voice"] = base_model_settings.get("voice") + original_init(self, realtime_agency, base_model_settings) + + monkeypatch.setattr(realtime_module.RealtimeSessionFactory, "__init__", capture_init) + + app = run_realtime(agency=agency, return_app=True) + if app is None: + pytest.skip("FastAPI extras not installed") + + assert captured.get("voice") == "nova" diff --git a/tests/test_agency_modules/test_agency_initialization.py b/tests/test_agency_modules/test_agency_initialization.py index 7a8738ad..9d598a66 100644 --- a/tests/test_agency_modules/test_agency_initialization.py +++ b/tests/test_agency_modules/test_agency_initialization.py @@ -3,6 +3,7 @@ import pytest from agency_swarm import Agency, Agent +from agency_swarm.agent.constants import AGENT_REALTIME_VOICES from agency_swarm.tools.send_message import SendMessage # --- Fixtures --- @@ -149,3 +150,46 @@ def test_agency_send_message_tool_class_does_not_mutate_agent(mock_agent): assert mock_agent.send_message_tool_class is None Agency(mock_agent, send_message_tool_class=_CustomSendMessage) assert mock_agent.send_message_tool_class is None + + +def test_agency_randomizes_agent_voices_with_seed(): + agent_a = Agent(name="AgentA", instructions="Respond with short answers.") + agent_b = Agent(name="AgentB", instructions="Provide detailed explanations.") + + Agency( + agent_a, + agent_b, + randomize_agent_voices=True, + voice_random_seed=21, + ) + + assert agent_a.voice in AGENT_REALTIME_VOICES + assert agent_b.voice in AGENT_REALTIME_VOICES + assert agent_a.voice != agent_b.voice + + clone_a = Agent(name="AgentA", instructions="Respond with short answers.") + clone_b = Agent(name="AgentB", instructions="Provide detailed explanations.") + Agency( + clone_a, + clone_b, + randomize_agent_voices=True, + voice_random_seed=21, + ) + + 
assert (agent_a.voice, agent_b.voice) == (clone_a.voice, clone_b.voice) + + +def test_agency_randomization_respects_explicit_voice(): + anchored = Agent(name="Anchored", instructions="Maintain voice.", voice="echo") + floating = Agent(name="Floating", instructions="Experiment with voices.") + + Agency( + anchored, + floating, + randomize_agent_voices=True, + voice_random_seed=5, + ) + + assert anchored.voice == "echo" + assert floating.voice in AGENT_REALTIME_VOICES + assert floating.voice != "echo" diff --git a/tests/test_agent_modules/test_agent_initialization.py b/tests/test_agent_modules/test_agent_initialization.py index a54afd13..85355de1 100644 --- a/tests/test_agent_modules/test_agent_initialization.py +++ b/tests/test_agent_modules/test_agent_initialization.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock +import pytest from agents import FunctionTool, ModelSettings, StopAtTools from pydantic import BaseModel, Field @@ -20,6 +21,16 @@ class SimpleOutput(BaseModel): # --- Initialization Tests --- +def test_agent_initialization_with_voice(): + agent = Agent(name="Voicey", instructions="Talk", voice="ash") + assert agent.voice == "ash" + + +def test_agent_initialization_invalid_voice(): + with pytest.raises(ValueError, match="Invalid voice 'invalid'"): + Agent(name="Voicey", instructions="Talk", voice="invalid") + + def test_agent_initialization_minimal(): """Test basic Agent initialization with minimal parameters.""" agent = Agent(name="Agent1", instructions="Be helpful") diff --git a/tests/test_integrations_modules/__init__.py b/tests/test_integrations_modules/__init__.py new file mode 100644 index 00000000..e69de29b