diff --git a/docs/additional-features/chatkit-integration.mdx b/docs/additional-features/chatkit-integration.mdx new file mode 100644 index 00000000..c35d2c65 --- /dev/null +++ b/docs/additional-features/chatkit-integration.mdx @@ -0,0 +1,99 @@ +--- +title: "ChatKit Integration" +description: "Add a chat UI to your agency with OpenAI ChatKit." +icon: "comments" +--- + +[OpenAI ChatKit](https://platform.openai.com/docs/guides/chatkit) is a React-based chat UI. Agency Swarm includes a ready-to-use demo. + + +Requires [Node.js](https://nodejs.org/) v18+ and npm. + + +## Quick Start + +```python +from agency_swarm import Agency, Agent + +agent = Agent(name="Assistant", instructions="You are helpful.") +agency = Agency(agent, name="my_agency") + +agency.chatkit_demo() # Opens browser at http://localhost:3000 +``` + + +```python +agency.chatkit_demo( + host="0.0.0.0", # Backend host + port=8000, # Backend port + frontend_port=3000, # Frontend port + cors_origins=None, # CORS origins list + open_browser=True, # Auto-open browser +) +``` + + +--- + +## Backend Only + +Use `enable_chatkit=True` to expose the ChatKit endpoint without the demo frontend: + +```python +from agency_swarm import Agency, Agent, run_fastapi + +def create_agency(**kwargs): + agent = Agent(name="Assistant", instructions="You are helpful.") + return Agency(agent, name="my_agency") + +run_fastapi(agencies={"my_agency": create_agency}, enable_chatkit=True) +``` + +This exposes `/{agency_name}/chatkit` for your own ChatKit frontend. + + + +Point your ChatKit React app to the backend: + +```typescript +const chatkit = useChatKit({ + api: { url: "http://localhost:8000/my_agency/chatkit" }, +}); +``` + +Or use a Vite proxy: + +```typescript +// vite.config.ts +export default defineConfig({ + server: { + proxy: { + "/chatkit": { + target: "http://localhost:8000", + rewrite: (path) => `/my_agency${path}`, + }, + }, + }, +}); +``` + + +--- + +## Persistence + +By default, ChatKit is stateless. For conversation persistence, pass custom `RunHooks` via `hooks_override` parameter in `get_response` or `get_response_stream`: + +```python +from agents import RunHooks + +class MyPersistenceHooks(RunHooks): + async def on_agent_end(self, context, agent, output): + messages = context.context.thread_manager.get_all_messages() + db.save(thread_id, messages) # Save to your database + +result = await agency.get_response( + message=user_message, + hooks_override=MyPersistenceHooks(), +) +``` diff --git a/docs/additional-features/fastapi-integration.mdx b/docs/additional-features/fastapi-integration.mdx index d358e6bd..ceb99665 100644 --- a/docs/additional-features/fastapi-integration.mdx +++ b/docs/additional-features/fastapi-integration.mdx @@ -40,6 +40,7 @@ Optionally, you can specify following parameters: - return_app (default: False) - If True, will return the FastAPI instead of running the server - cors_origins: (default: ["*"]) - enable_agui (default: `False`) - Enable AG-UI protocol compatibility for streaming endpoints +- enable_chatkit (default: `False`) - Enable [ChatKit](https://platform.openai.com/docs/guides/chatkit) protocol endpoint - enable_logging (default: `False`) - Enable request tracking and expose `/get_logs` endpoint - logs_dir (default: `"activity-logs"`) - Directory for log files when logging is enabled @@ -184,6 +185,9 @@ print("Response:", tool_response.json()) - **AG-UI Protocol:** When `enable_agui=True`, only the streaming endpoint is exposed and follows the AG-UI protocol for enhanced frontend integration. +- **ChatKit Protocol:** + When `enable_chatkit=True`, adds `/{agency_name}/chatkit` endpoint. See [ChatKit Integration](/additional-features/chatkit-integration). + --- ## Inspecting Tool Schemas @@ -251,3 +255,4 @@ When using the agency endpoints (`/{your_agency}/get_response` and `/{your_agenc Behavior with `file_urls`: - The server downloads each URL, uploads it to OpenAI, waits until processed, and uses the resulting File IDs. - `file_ids_map` (shape: `{ filename: file_id }`) is returned in the non‑streaming JSON response of `POST /get_response` and in the final `event: messages` SSE payload of `POST /get_response_stream`. + diff --git a/docs/docs.json b/docs/docs.json index 9a276f5b..30383ec9 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -98,6 +98,7 @@ "additional-features/guardrails", "additional-features/streaming", "additional-features/fastapi-integration", + "additional-features/chatkit-integration", "additional-features/mcp-tools-server", { "group": "Custom Communication Flows", diff --git a/examples/interactive/chatkit_demo.py b/examples/interactive/chatkit_demo.py new file mode 100644 index 00000000..da12e865 --- /dev/null +++ b/examples/interactive/chatkit_demo.py @@ -0,0 +1,70 @@ +""" +Agency Swarm ChatKit Demo + +This example demonstrates the ChatKit UI capabilities of Agency Swarm v1.x. +Sets up a frontend and backend server for the OpenAI ChatKit UI chat demo. +""" + +import sys +from pathlib import Path + +# Add the src directory to the path so we can import agency_swarm +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from agency_swarm import Agency, Agent, RunContextWrapper, function_tool + + +@function_tool() +async def example_tool(wrapper: RunContextWrapper) -> str: + """Example tool for chatkit demo""" + return "Example tool executed" + + +def create_demo_agency(): + """Create a demo agency for chatkit demo""" + + # Create agents using v1.x pattern (direct instantiation) + ceo = Agent( + name="CEO", + description="Chief Executive Officer - oversees all operations", + instructions="You are the CEO responsible for high-level decision making and coordination.", + tools=[example_tool], + ) + + worker = Agent( + name="Worker", + description="Worker - performs tasks", + instructions="Follow instructions given by the CEO.", + tools=[example_tool], + ) + + # Create agency with communication flows (v1.x pattern) + agency = Agency( + ceo, # Entry point agent (positional argument) + communication_flows=[ceo > worker], + name="ChatKitDemoAgency", + ) + + return agency + + +def main(): + """Launch interactive ChatKit demo""" + print("Agency Swarm ChatKit Demo") + print("=" * 50) + print() + + try: + agency = create_demo_agency() + # Launch the ChatKit UI demo with backend and frontend servers. + agency.chatkit_demo() + + except Exception as e: + print(f"❌ Demo failed with error: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 5b1d4d0a..ee3b152b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,13 @@ artifacts = [ "ui/demos/copilot/**/*.mjs", "ui/demos/copilot/**/*.ico", "ui/demos/copilot/**/*.svg", + # ChatKit frontend files + "ui/demos/chatkit/**/*.ts", + "ui/demos/chatkit/**/*.tsx", + "ui/demos/chatkit/**/*.css", + "ui/demos/chatkit/**/*.json", + "ui/demos/chatkit/**/*.mjs", + "ui/demos/chatkit/**/*.html", ] [tool.hatch.build.targets.wheel.force-include] @@ -104,8 +111,11 @@ artifacts = [ include = [ "src/agency_swarm/**/*.py", "src/agency_swarm/**/*.ts", + "src/agency_swarm/**/*.tsx", + "src/agency_swarm/**/*.html", "src/agency_swarm/ui/templates/**", "src/agency_swarm/ui/demos/copilot/**", + "src/agency_swarm/ui/demos/chatkit/**", ] exclude = ["**/node_modules/**", "**/.next/**"] diff --git a/src/agency_swarm/agency/core.py b/src/agency_swarm/agency/core.py index 701354c0..115f81e8 100644 --- a/src/agency_swarm/agency/core.py +++ b/src/agency_swarm/agency/core.py @@ -495,3 +495,18 @@ def copilot_demo( from .visualization import copilot_demo return copilot_demo(self, host, port, frontend_port, cors_origins) + + def chatkit_demo( + self, + host: str = "0.0.0.0", + port: int = 8000, + frontend_port: int = 3000, + cors_origins: list[str] | None = None, + open_browser: bool = True, + ) -> None: + """ + Run a ChatKit demo of the agency. + """ + from .visualization import chatkit_demo + + return chatkit_demo(self, host, port, frontend_port, cors_origins, open_browser) diff --git a/src/agency_swarm/agency/visualization.py b/src/agency_swarm/agency/visualization.py index 89a94724..1103e9f3 100644 --- a/src/agency_swarm/agency/visualization.py +++ b/src/agency_swarm/agency/visualization.py @@ -184,3 +184,21 @@ def copilot_demo( Run a copilot demo of the agency. """ CopilotDemoLauncher.start(agency, host=host, port=port, frontend_port=frontend_port, cors_origins=cors_origins) + + +def chatkit_demo( + agency: "Agency", + host: str = "0.0.0.0", + port: int = 8000, + frontend_port: int = 3000, + cors_origins: list[str] | None = None, + open_browser: bool = True, +) -> None: + """ + Run a ChatKit demo of the agency. + """ + from agency_swarm.ui.demos.chatkit import ChatkitDemoLauncher + + ChatkitDemoLauncher.start( + agency, host=host, port=port, frontend_port=frontend_port, cors_origins=cors_origins, open_browser=open_browser + ) diff --git a/src/agency_swarm/integrations/fastapi.py b/src/agency_swarm/integrations/fastapi.py index 7d971f5f..9bc23404 100644 --- a/src/agency_swarm/integrations/fastapi.py +++ b/src/agency_swarm/integrations/fastapi.py @@ -20,6 +20,7 @@ def run_fastapi( return_app: bool = False, cors_origins: list[str] | None = None, enable_agui: bool = False, + enable_chatkit: bool = False, enable_logging: bool = False, logs_dir: str = "activity-logs", ): @@ -39,6 +40,11 @@ def run_fastapi( server_url : str | None Optional base URL to be included in the server OpenAPI schema. Defaults to ``http://{host}:{port}`` + enable_agui : bool + Enable AG-UI protocol compatibility for streaming endpoints. + enable_chatkit : bool + Enable ChatKit protocol endpoints (adds /{agency}/chatkit routes). + ChatKit is OpenAI's pre-built chat UI framework. enable_logging : bool Enable request tracking and file logging. When enabled, adds middleware to track requests and allows conditional @@ -154,6 +160,17 @@ def run_fastapi( endpoints.append(f"/{agency_name}/get_response") endpoints.append(f"/{agency_name}/get_response_stream") + # Add ChatKit endpoint if enabled + if enable_chatkit: + from .fastapi_utils.chatkit_handlers import ChatkitRequest, make_chatkit_endpoint + + app.add_api_route( + f"/{agency_name}/chatkit", + make_chatkit_endpoint(ChatkitRequest, agency_factory, verify_token), + methods=["POST"], + ) + endpoints.append(f"/{agency_name}/chatkit") + app.add_api_route( f"/{agency_name}/get_metadata", make_metadata_endpoint(agency_metadata, verify_token), @@ -189,6 +206,11 @@ def run_fastapi( if return_app: return app - logger.info(f"Starting FastAPI {'AG-UI ' if enable_agui else ''}server at http://{host}:{port}") + mode_str = "" + if enable_agui: + mode_str = "AG-UI " + elif enable_chatkit: + mode_str = "ChatKit " + logger.info(f"Starting FastAPI {mode_str}server at http://{host}:{port}") uvicorn.run(app, host=host, port=port) diff --git a/src/agency_swarm/integrations/fastapi_utils/chatkit_handlers.py b/src/agency_swarm/integrations/fastapi_utils/chatkit_handlers.py new file mode 100644 index 00000000..0d909725 --- /dev/null +++ b/src/agency_swarm/integrations/fastapi_utils/chatkit_handlers.py @@ -0,0 +1,173 @@ +""" +ChatKit endpoint handlers for Agency Swarm. + +Provides FastAPI endpoint handlers that implement the ChatKit protocol, +enabling OpenAI's ChatKit UI to communicate with Agency Swarm agents. +""" + +import json +import logging +import time +import uuid +from collections.abc import AsyncGenerator, Callable +from typing import Any + +from fastapi import Depends, Request +from fastapi.responses import Response, StreamingResponse +from pydantic import BaseModel, Field + +from agency_swarm import Agency +from agency_swarm.tools.mcp_manager import attach_persistent_mcp_servers +from agency_swarm.ui.core.chatkit_adapter import ChatkitAdapter + +logger = logging.getLogger(__name__) + + +class ChatkitUserInput(BaseModel): + """User input for ChatKit messages.""" + + content: list[dict[str, Any]] = Field(default_factory=list) + attachments: list[str] = Field(default_factory=list) + quoted_text: str | None = None + + +class ChatkitParams(BaseModel): + """Parameters for ChatKit requests.""" + + thread_id: str | None = None + input: ChatkitUserInput | None = None + + +class ChatkitRequest(BaseModel): + """Request model for ChatKit protocol. + + ChatKit sends requests with a 'type' field indicating the operation: + - threads.create: Create new thread with initial message + - threads.add_user_message: Add message to existing thread + - threads.get_by_id: Get thread by ID (non-streaming) + - items.list: List items in a thread (non-streaming) + """ + + type: str = "threads.create" + params: ChatkitParams = Field(default_factory=ChatkitParams) + context: dict[str, Any] = Field(default_factory=dict) + + +def _serialize_event(event: dict[str, Any]) -> bytes: + """Serialize a ChatKit event to SSE format.""" + return f"data: {json.dumps(event)}\n\n".encode() + + +def make_chatkit_endpoint( + request_model: type[ChatkitRequest], + agency_factory: Callable[..., Agency], + verify_token: Callable[..., Any], +) -> Callable[..., Any]: + """Create a ChatKit protocol endpoint handler. + + This endpoint handles requests from ChatKit UI and streams responses + using the ChatKit ThreadStreamEvent protocol. + + Args: + request_model: Pydantic model for request validation (ChatkitRequest) + agency_factory: Factory function that returns an Agency instance + verify_token: Token verification dependency + + Returns: + FastAPI endpoint handler function + """ + _ = request_model # Mark as used + + async def handler( + request: Request, + token: str = Depends(verify_token), + ) -> Response: + """Handle ChatKit protocol requests.""" + body = await request.body() + try: + data = json.loads(body) if body else {} + except json.JSONDecodeError: + data = {} + + req_type = data.get("type", "threads.create") + params = data.get("params", {}) + context = data.get("context", {}) + + # Handle non-streaming requests + if req_type in ("threads.get_by_id", "threads.list", "items.list", "items.feedback"): + return Response( + content=json.dumps({"data": [], "has_more": False}), + media_type="application/json", + ) + + thread_id = params.get("thread_id") or str(uuid.uuid4()) + run_id = str(uuid.uuid4()) + is_new_thread = req_type == "threads.create" + + # Extract user message + user_message = "" + user_input = params.get("input", {}) + if user_input: + content_list = user_input.get("content", []) + for part in content_list: + if isinstance(part, dict): + if part.get("type") == "input_text": + user_message += part.get("text", "") + elif "text" in part: + user_message += part.get("text", "") + + if not user_message: + return Response( + content=json.dumps({"thread_id": thread_id, "status": "no_input"}), + media_type="application/json", + ) + + agency = agency_factory() + await attach_persistent_mcp_servers(agency) + + async def event_generator() -> AsyncGenerator[bytes]: + """Generate ChatKit SSE events from Agency Swarm responses.""" + adapter = ChatkitAdapter() + + if is_new_thread: + yield _serialize_event(adapter._create_thread_created_event(thread_id)) + + user_item_id = f"user_{uuid.uuid4().hex[:8]}" + user_item = { + "id": user_item_id, + "type": "user_message", + "thread_id": thread_id, + "created_at": int(time.time()), + "content": [{"type": "input_text", "text": user_message}], + "attachments": [], + } + yield _serialize_event({"type": "thread.item.done", "item": user_item}) + + try: + async for event in agency.get_response_stream( + message=user_message, + context_override=context if context else None, + ): + chatkit_event = adapter.openai_to_chatkit_events(event, run_id=run_id, thread_id=thread_id) + if chatkit_event: + events = chatkit_event if isinstance(chatkit_event, list) else [chatkit_event] + for evt in events: + yield _serialize_event(evt) + + except Exception as exc: + logger.exception("Error during ChatKit streaming") + yield _serialize_event({"type": "thread.error", "error": {"message": str(exc)}}) + + yield _serialize_event({"type": "thread.run.completed", "thread_id": thread_id, "run_id": run_id}) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + return handler diff --git a/src/agency_swarm/ui/core/chatkit_adapter.py b/src/agency_swarm/ui/core/chatkit_adapter.py new file mode 100644 index 00000000..645715ac --- /dev/null +++ b/src/agency_swarm/ui/core/chatkit_adapter.py @@ -0,0 +1,484 @@ +""" +ChatKit Adapter for Agency Swarm. + +Converts OpenAI Agents SDK streaming events to ChatKit ThreadStreamEvent format. +This enables Agency Swarm agents to be served via OpenAI's ChatKit UI. + +ChatKit event types: +- thread.created - New thread created +- thread.item.added - New item (message, tool call) added +- thread.item.updated - Item updated (streaming text deltas) +- thread.item.done - Item finalized + +ThreadItem types: +- assistant_message - Assistant response +- user_message - User input +- client_tool_call - Tool invocation +""" + +import json +import logging +import time +import uuid +from typing import Any + +logger = logging.getLogger(__name__) + + +class ChatkitAdapter: + """ + Converts between OpenAI Agents SDK events and ChatKit ThreadStreamEvent format. + + Each instance maintains its own run state to track message IDs and tool calls. + """ + + _TOOL_TYPES = {"function_call", "file_search_call", "code_interpreter_call"} + _TOOL_ARG_DELTA_TYPES = { + "response.function_call_arguments.delta", + "response.code_interpreter_call_code.delta", + } + + def __init__(self) -> None: + """Initialize a new ChatkitAdapter with clean per-instance state.""" + self._run_state: dict[str, dict[str, Any]] = {} + self._message_counter = 0 + + def clear_run_state(self, run_id: str | None = None) -> None: + """Clear run state for a specific run_id or all runs if run_id is None.""" + if run_id is None: + self._run_state.clear() + else: + self._run_state.pop(run_id, None) + + def _generate_item_id(self) -> str: + """Generate a unique item ID for ChatKit items.""" + self._message_counter += 1 + return f"item_{self._message_counter}_{uuid.uuid4().hex[:8]}" + + def _create_thread_created_event(self, thread_id: str) -> dict[str, Any]: + """Create a thread.created event.""" + return { + "type": "thread.created", + "thread": { + "id": thread_id, + "created_at": int(time.time()), + "metadata": {}, + }, + } + + def _create_item_added_event( + self, + item_id: str, + item_type: str, + content: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Create a thread.item.added event.""" + item: dict[str, Any] = { + "id": item_id, + "type": item_type, + "created_at": int(time.time()), + } + if content: + item.update(content) + return { + "type": "thread.item.added", + "item": item, + } + + def _create_item_updated_event( + self, + item_id: str, + update: dict[str, Any], + ) -> dict[str, Any]: + """Create a thread.item.updated event for streaming content.""" + return { + "type": "thread.item.updated", + "item_id": item_id, + "update": update, + } + + def _create_item_done_event(self, item_id: str) -> dict[str, Any]: + """Create a thread.item.done event to finalize an item.""" + return { + "type": "thread.item.done", + "item_id": item_id, + } + + def _create_assistant_message_item( + self, + item_id: str, + text: str = "", + ) -> dict[str, Any]: + """Create an assistant_message item structure.""" + return { + "content": [ + { + "type": "output_text", + "text": text, + "annotations": [], + } + ], + } + + def _create_tool_call_item( + self, + call_id: str, + name: str, + arguments: str = "{}", + status: str = "in_progress", + output: str | None = None, + ) -> dict[str, Any]: + """Create a client_tool_call item structure.""" + item: dict[str, Any] = { + "call_id": call_id, + "name": name, + "arguments": arguments, + "status": status, + } + if output is not None: + item["output"] = output + return item + + def openai_to_chatkit_events( + self, + event: Any, + *, + run_id: str, + thread_id: str, + ) -> dict[str, Any] | list[dict[str, Any]] | None: + """ + Convert a single OpenAI Agents SDK StreamEvent into one or more ChatKit events. + + Args: + event: The OpenAI Agents SDK event to convert + run_id: Unique identifier for this run + thread_id: The ChatKit thread ID + + Returns: + A ChatKit event dict, list of event dicts, or None if no conversion + """ + state = self._run_state.setdefault( + run_id, + { + "call_id_by_item": {}, + "item_id_by_call": {}, + "current_message_id": None, + "accumulated_text": {}, + }, + ) + call_id_by_item: dict[str, str] = state["call_id_by_item"] + item_id_by_call: dict[str, str] = state["item_id_by_call"] + accumulated_text: dict[str, str] = state["accumulated_text"] + + logger.debug("Received event: %s", event) + try: + converted_event = None + + if getattr(event, "type", None) == "raw_response_event": + converted_event = self._handle_raw_response( + event.data, + call_id_by_item, + item_id_by_call, + accumulated_text, + state, + ) + + if getattr(event, "type", None) == "run_item_stream_event": + converted_event = self._handle_run_item_stream( + event, + state, + accumulated_text, + ) + + return converted_event + + except Exception as exc: + logger.exception("Error converting event to ChatKit format") + return { + "type": "thread.error", + "error": {"message": str(exc)}, + } + + def _handle_raw_response( + self, + oe: Any, + call_id_by_item: dict[str, str], + item_id_by_call: dict[str, str], + accumulated_text: dict[str, str], + state: dict[str, Any], + ) -> dict[str, Any] | list[dict[str, Any]] | None: + """Translate raw_response_event.data into ChatKit events.""" + etype = getattr(oe, "type", "") + + # --- Output item added ------------------------------------------------- + if etype == "response.output_item.added": + raw_item = getattr(oe, "item", None) + if not raw_item: + logger.warning("raw_response_event ignored: missing item for type %s", etype) + return None + + # Assistant message start + if getattr(raw_item, "type", "") == "message" and getattr(raw_item, "role", "") == "assistant": + msg_id = getattr(raw_item, "id", None) or self._generate_item_id() + state["current_message_id"] = msg_id + accumulated_text[msg_id] = "" + return self._create_item_added_event( + msg_id, + "assistant_message", + self._create_assistant_message_item(msg_id, ""), + ) + + # Tool call start + if getattr(raw_item, "type", "") in self._TOOL_TYPES: + call_id, tool_name, _ = self._tool_meta(raw_item) + if not call_id: + logger.warning("raw_response_event ignored: tool call without call_id") + return None + item_id = self._generate_item_id() + call_id_by_item[raw_item.id] = call_id + item_id_by_call[call_id] = item_id + return self._create_item_added_event( + item_id, + "client_tool_call", + self._create_tool_call_item(call_id, tool_name or "tool", "{}", "in_progress"), + ) + + # --- Text delta -------------------------------------------------------- + if etype == "response.output_text.delta": + raw_item_id: str | None = getattr(oe, "item_id", None) + delta_text: str = getattr(oe, "delta", "") + if raw_item_id and delta_text: + # Use the current message ID from state + current_msg_id: str = state.get("current_message_id") or raw_item_id + accumulated_text[current_msg_id] = accumulated_text.get(current_msg_id, "") + delta_text + return self._create_item_updated_event( + current_msg_id, + { + "type": "assistant_message.content_part.text_delta", + "content_index": 0, + "delta": delta_text, + }, + ) + logger.warning("raw_response_event ignored: text delta without item_id") + return None + + # --- Output item done -------------------------------------------------- + if etype == "response.output_item.done": + raw_item = getattr(oe, "item", None) + if not raw_item: + logger.warning("raw_response_event ignored: output_item.done without item") + return None + + if getattr(raw_item, "type", "") == "message": + msg_id = state.get("current_message_id") or getattr(raw_item, "id", None) + if msg_id: + return self._create_item_done_event(msg_id) + logger.warning("raw_response_event ignored: message done without id") + return None + + if getattr(raw_item, "type", "") in self._TOOL_TYPES: + call_id, tool_name, arguments = self._tool_meta(raw_item) + if not call_id: + logger.warning("raw_response_event ignored: tool done without call_id") + return None + chatkit_item_id: str | None = item_id_by_call.get(call_id) + if chatkit_item_id: + return [ + self._create_item_updated_event( + chatkit_item_id, + { + "type": "client_tool_call.arguments_done", + "arguments": arguments or "{}", + }, + ), + self._create_item_done_event(chatkit_item_id), + ] + return None + + # --- Tool-argument deltas --------------------------------------------- + if etype in self._TOOL_ARG_DELTA_TYPES: + tool_delta_item_id: str | None = getattr(oe, "item_id", None) + tool_call_id: str | None = call_id_by_item.get(tool_delta_item_id) if tool_delta_item_id else None + if tool_call_id: + chatkit_tool_item_id: str | None = item_id_by_call.get(tool_call_id) + if chatkit_tool_item_id: + return self._create_item_updated_event( + chatkit_tool_item_id, + { + "type": "client_tool_call.arguments_delta", + "delta": getattr(oe, "delta", ""), + }, + ) + logger.warning("raw_response_event ignored: tool arg delta without mapping") + return None + + return None + + def _handle_run_item_stream( + self, + event: Any, + state: dict[str, Any], + accumulated_text: dict[str, str], + ) -> dict[str, Any] | list[dict[str, Any]] | None: + """Translate run_item_stream_event into ChatKit events.""" + name = getattr(event, "name", "") + item_id_by_call: dict[str, str] = state["item_id_by_call"] + + # --- Assistant message complete ---------------------------------------- + if name == "message_output_created": + output_item = getattr(event, "item", None) + if not output_item: + logger.warning("run_item_stream_event ignored: missing output item for %s", name) + return None + + raw_item = getattr(output_item, "raw_item", None) + msg_id: str | None = state.get("current_message_id") or getattr(raw_item, "id", None) + output_content = (getattr(raw_item, "content", None) or [None])[0] + if not output_content or not msg_id: + return None + + output_text = getattr(output_content, "text", None) + if not output_text: + return None + + # Final message snapshot + return self._create_item_updated_event( + msg_id, + { + "type": "assistant_message.content_part.done", + "content_index": 0, + "content": { + "type": "output_text", + "text": output_text, + "annotations": [], + }, + }, + ) + + # --- Tool output ------------------------------------------------------- + if name == "tool_output": + output_item = getattr(event, "item", None) + if not output_item: + logger.warning("run_item_stream_event ignored: tool_output without item") + return None + + raw_item = getattr(output_item, "raw_item", None) + call_id = raw_item.get("call_id") if isinstance(raw_item, dict) else getattr(output_item, "call_id", None) + if not call_id: + logger.warning("run_item_stream_event ignored: tool_output without call_id") + return None + + output_text = getattr(output_item, "output", None) + item_id = item_id_by_call.get(call_id) + if item_id and output_text: + return self._create_item_updated_event( + item_id, + { + "type": "client_tool_call.output", + "output": output_text, + "status": "completed", + }, + ) + return None + + return None + + def _tool_meta(self, raw_item: Any) -> tuple[str | None, str | None, str | None]: + """Return (call_id, tool_name, arguments) for a tool raw_item.""" + item_type = getattr(raw_item, "type", "") + + if item_type == "function_call": + return ( + getattr(raw_item, "call_id", None), + getattr(raw_item, "name", "tool"), + getattr(raw_item, "arguments", None), + ) + + if item_type == "file_search_call": + return ( + getattr(raw_item, "id", None), + "FileSearchTool", + json.dumps( + { + "queries": getattr(raw_item, "queries", None), + "results": getattr(raw_item, "results", None), + } + ), + ) + + if item_type == "code_interpreter_call": + return ( + getattr(raw_item, "id", None), + "CodeInterpreterTool", + json.dumps( + { + "code": getattr(raw_item, "code", None), + "container_id": getattr(raw_item, "container_id", None), + "outputs": getattr(raw_item, "outputs", None), + } + ), + ) + + return None, None, None + + @staticmethod + def chatkit_messages_to_chat_history(items: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Convert a list of ChatKit thread items to Agency Swarm chat history format. + + Args: + items: List of ChatKit thread items (user_message, assistant_message, etc.) + + Returns: + List of messages in Agency Swarm's flat chat history format + """ + messages: list[dict[str, Any]] = [] + + for item in items: + item_type = item.get("type", "") + + # User message + if item_type == "user_message": + content = item.get("content", []) + text = "" + for part in content: + if part.get("type") == "input_text": + text += part.get("text", "") + messages.append({"role": "user", "content": text}) + + # Assistant message + elif item_type == "assistant_message": + content = item.get("content", []) + text = "" + for part in content: + if part.get("type") == "output_text": + text += part.get("text", "") + messages.append({"role": "assistant", "content": text}) + + # Tool call + elif item_type == "client_tool_call": + call_id = item.get("call_id", item.get("id", "")) + name = item.get("name", "tool") + arguments = item.get("arguments", "{}") + status = item.get("status", "completed") + messages.append( + { + "id": item.get("id", call_id), + "call_id": call_id, + "type": "function_call", + "arguments": arguments, + "name": name, + "status": status, + } + ) + # If tool has output, add it + if item.get("output"): + messages.append( + { + "call_id": call_id, + "output": item["output"], + "type": "function_call_output", + } + ) + + return messages diff --git a/src/agency_swarm/ui/demos/chatkit.py b/src/agency_swarm/ui/demos/chatkit.py new file mode 100644 index 00000000..4926853c --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit.py @@ -0,0 +1,88 @@ +from agency_swarm import Agency + + +class ChatkitDemoLauncher: + @staticmethod + def start( + agency_instance: Agency, + host: str = "0.0.0.0", + port: int = 8000, + frontend_port: int = 3000, + cors_origins: list[str] | None = None, + open_browser: bool = True, + ) -> None: + """Launch the ChatKit UI demo with backend and frontend servers.""" + import atexit + import os + import shutil + import subprocess + import threading + import time + import webbrowser + from pathlib import Path + + from agency_swarm.integrations.fastapi import run_fastapi + + fe_path = Path(__file__).parent / "chatkit" + + npm_exe = shutil.which("npm") or shutil.which("npm.cmd") + if npm_exe is None: + raise RuntimeError( + "npm was not found on your PATH. Install Node.js (https://nodejs.org) " + "and ensure `npm` is accessible before running ChatkitDemoLauncher." + ) + + if not (fe_path / "node_modules").exists(): + print( + "\033[93m[ChatKit Demo] 'node_modules' not found in chatkit app directory. " + "Running 'npm install' to install frontend dependencies...\033[0m" + ) + try: + subprocess.check_call([npm_exe, "install"], cwd=fe_path) + print( + "\033[92m[ChatKit Demo] Frontend dependencies installed successfully. " + "Frontend might take a few seconds to load.\033[0m" + ) + except subprocess.CalledProcessError as e: + raise RuntimeError( + f"Failed to install frontend dependencies in {fe_path}. Please check your npm setup and try again." + ) from e + + agency_name = getattr(agency_instance, "name", None) or "agency" + agency_name = agency_name.replace(" ", "_") + + # Set environment variables for the Vite frontend + os.environ["CHATKIT_BACKEND_URL"] = f"http://{host}:{port}" + os.environ["CHATKIT_FRONTEND_PORT"] = str(frontend_port) + os.environ["CHATKIT_AGENCY_NAME"] = agency_name + + proc = subprocess.Popen( + [npm_exe, "run", "dev"], + cwd=fe_path, + stdout=subprocess.DEVNULL, + stderr=subprocess.STDOUT, + ) + atexit.register(proc.terminate) + + url = f"http://localhost:{frontend_port}" + print( + f"\n\033[92;1m🚀 ChatKit UI running at {url}\n" + " It might take a moment for the page to load the first time you open it.\033[0m\n" + ) + + if open_browser: + + def delayed_open() -> None: + time.sleep(3) + webbrowser.open(url) + + threading.Thread(target=delayed_open, daemon=True).start() + + run_fastapi( + agencies={agency_name: lambda **kwargs: agency_instance}, + host=host, + port=port, + app_token_env="", + cors_origins=cors_origins, + enable_chatkit=True, + ) diff --git a/src/agency_swarm/ui/demos/chatkit/.gitignore b/src/agency_swarm/ui/demos/chatkit/.gitignore new file mode 100644 index 00000000..b431156f --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/.gitignore @@ -0,0 +1,3 @@ +node_modules +dist +.vite diff --git a/src/agency_swarm/ui/demos/chatkit/index.html b/src/agency_swarm/ui/demos/chatkit/index.html new file mode 100644 index 00000000..5c4f940d --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/index.html @@ -0,0 +1,13 @@ + + + + + + Agency Swarm - ChatKit Demo + + + +
+ + + diff --git a/src/agency_swarm/ui/demos/chatkit/package.json b/src/agency_swarm/ui/demos/chatkit/package.json new file mode 100644 index 00000000..9568f024 --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/package.json @@ -0,0 +1,26 @@ +{ + "name": "agency-swarm-chatkit-demo", + "version": "0.1.0", + "private": true, + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "preview": "vite preview" + }, + "dependencies": { + "@openai/chatkit-react": ">=1.1.1 <2.0.0", + "react": "^19.2.0", + "react-dom": "^19.2.0" + }, + "devDependencies": { + "@tailwindcss/postcss": "^4", + "@types/react": "^19.0.8", + "@types/react-dom": "^19.0.3", + "@vitejs/plugin-react-swc": "^3.5.0", + "postcss": "^8.4.47", + "tailwindcss": "^4", + "typescript": "^5.6.3", + "vite": "^7.1.9" + } +} diff --git a/src/agency_swarm/ui/demos/chatkit/postcss.config.mjs b/src/agency_swarm/ui/demos/chatkit/postcss.config.mjs new file mode 100644 index 00000000..c2ddf748 --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/postcss.config.mjs @@ -0,0 +1,5 @@ +export default { + plugins: { + "@tailwindcss/postcss": {}, + }, +}; diff --git a/src/agency_swarm/ui/demos/chatkit/src/App.tsx b/src/agency_swarm/ui/demos/chatkit/src/App.tsx new file mode 100644 index 00000000..840e6c93 --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/src/App.tsx @@ -0,0 +1,9 @@ +import { ChatKitPanel } from "./components/ChatKitPanel"; + +export default function App() { + return ( +
+ +
+ ); +} diff --git a/src/agency_swarm/ui/demos/chatkit/src/components/ChatKitPanel.tsx b/src/agency_swarm/ui/demos/chatkit/src/components/ChatKitPanel.tsx new file mode 100644 index 00000000..32d3a69b --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/src/components/ChatKitPanel.tsx @@ -0,0 +1,48 @@ +import { ChatKit, useChatKit } from "@openai/chatkit-react"; + +// ChatKit connects to our Agency Swarm backend via the Vite proxy +const CHATKIT_API_URL = "/chatkit"; + +export function ChatKitPanel() { + const chatkit = useChatKit({ + api: { + url: CHATKIT_API_URL, + domainKey: "pk_local_dev", // Required for CustomApiConfig + }, + theme: { + colorScheme: "dark", + radius: "round", + density: "normal", + }, + composer: { + attachments: { enabled: false }, + placeholder: "Type a message...", + }, + startScreen: { + greeting: "Welcome to Agency Swarm", + prompts: [ + { + icon: "circle-question", + label: "Say hello", + prompt: "Hello! What can you help me with?", + }, + { + icon: "bolt", + label: "Test calculation", + prompt: "What is 15 * 7?", + }, + { + icon: "user", + label: "Greet me", + prompt: "Please greet me, my name is User", + }, + ], + }, + }); + + return ( +
+ +
+ ); +} diff --git a/src/agency_swarm/ui/demos/chatkit/src/index.css b/src/agency_swarm/ui/demos/chatkit/src/index.css new file mode 100644 index 00000000..446caf6d --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/src/index.css @@ -0,0 +1,34 @@ +@import "tailwindcss"; + +:root { + --background: #ffffff; + --foreground: #171717; + color-scheme: light; +} + +:root[data-color-scheme="dark"] { + --background: #0a0a0a; + --foreground: #ededed; + color-scheme: dark; +} + +@media (prefers-color-scheme: dark) { + :root:not([data-color-scheme]) { + --background: #0a0a0a; + --foreground: #ededed; + color-scheme: dark; + } +} + +@theme inline { + --color-background: var(--background); + --color-foreground: var(--foreground); + --font-sans: Arial, Helvetica, sans-serif; + --font-mono: SFMono-Regular, Consolas, "Liberation Mono", monospace; +} + +body { + background: var(--background); + color: var(--foreground); + font-family: var(--font-sans); +} diff --git a/src/agency_swarm/ui/demos/chatkit/src/main.tsx b/src/agency_swarm/ui/demos/chatkit/src/main.tsx new file mode 100644 index 00000000..65dc4ec6 --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/src/main.tsx @@ -0,0 +1,10 @@ +import { StrictMode } from "react"; +import { createRoot } from "react-dom/client"; +import App from "./App"; +import "./index.css"; + +createRoot(document.getElementById("root")!).render( + + + +); diff --git a/src/agency_swarm/ui/demos/chatkit/tsconfig.json b/src/agency_swarm/ui/demos/chatkit/tsconfig.json new file mode 100644 index 00000000..a4c834a6 --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "target": "ES2020", + "useDefineForClassFields": true, + "lib": ["ES2020", "DOM", "DOM.Iterable"], + "module": "ESNext", + "skipLibCheck": true, + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "resolveJsonModule": true, + "isolatedModules": true, + "noEmit": true, + "jsx": "react-jsx", + "strict": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "noFallthroughCasesInSwitch": true + }, + "include": ["src"] +} diff --git a/src/agency_swarm/ui/demos/chatkit/vite.config.ts b/src/agency_swarm/ui/demos/chatkit/vite.config.ts new file mode 100644 index 00000000..b6d63085 --- /dev/null +++ b/src/agency_swarm/ui/demos/chatkit/vite.config.ts @@ -0,0 +1,25 @@ +import { defineConfig } from "vite"; +import react from "@vitejs/plugin-react-swc"; + +// Backend URL from environment, set by ChatkitDemoLauncher +const backendUrl = process.env.CHATKIT_BACKEND_URL ?? "http://127.0.0.1:8000"; +const agencyName = process.env.CHATKIT_AGENCY_NAME ?? "agency"; + +export default defineConfig({ + plugins: [react()], + define: { + // Pass agency name to the frontend + "import.meta.env.VITE_AGENCY_NAME": JSON.stringify(agencyName), + }, + server: { + port: parseInt(process.env.CHATKIT_FRONTEND_PORT ?? "3000"), + host: "0.0.0.0", + proxy: { + "/chatkit": { + target: backendUrl, + changeOrigin: true, + rewrite: (path) => `/${agencyName}${path}`, + }, + }, + }, +}); diff --git a/tests/integration/fastapi/test_fastapi_chatkit.py b/tests/integration/fastapi/test_fastapi_chatkit.py new file mode 100644 index 00000000..46e50634 --- /dev/null +++ b/tests/integration/fastapi/test_fastapi_chatkit.py @@ -0,0 +1,233 @@ +"""Integration tests for ChatKit FastAPI endpoint.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from agency_swarm import Agency, Agent, run_fastapi + + +@dataclass +class ChatkitContextTracker: + """Tracks context and messages received by the test agent.""" + + last_context: dict[str, Any] | None = None + last_message: str | None = None + + def reset(self) -> None: + self.last_context = None + self.last_message = None + + +class ChatkitTestAgent(Agent): + """Test agent that records inputs for verification.""" + + def __init__(self, tracker: ChatkitContextTracker): + super().__init__(name="ChatkitTestAgent", instructions="Test agent for ChatKit") + self._tracker = tracker + + def get_response_stream( + self, + message, + sender_name=None, + context_override: dict[str, Any] | None = None, + **kwargs: Any, + ): + # Extract last user message from history list + if isinstance(message, list): + for msg in reversed(message): + if isinstance(msg, dict) and msg.get("role") == "user": + self._tracker.last_message = msg.get("content", "") + break + else: + self._tracker.last_message = str(message) + self._tracker.last_context = context_override + + async def _generator(): + # Simulate streaming events + yield SimpleNamespace( + type="raw_response_event", + data=SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace(type="message", role="assistant", id="msg-1"), + ), + ) + yield SimpleNamespace( + type="raw_response_event", + data=SimpleNamespace( + type="response.output_text.delta", + item_id="msg-1", + delta="Hello from ChatKit!", + ), + ) + yield SimpleNamespace( + type="raw_response_event", + data=SimpleNamespace( + type="response.output_item.done", + item=SimpleNamespace(type="message", id="msg-1"), + ), + ) + + return _generator() + + +@dataclass +class ChatkitAgencyFactory: + """Factory for creating agencies with ChatKit test agent.""" + + tracker: ChatkitContextTracker = field(default_factory=ChatkitContextTracker) + + def __call__(self, load_threads_callback=None, save_threads_callback=None): + self.tracker.reset() + agent = ChatkitTestAgent(self.tracker) + return Agency( + agent, + name="chatkit_test", + load_threads_callback=load_threads_callback, + save_threads_callback=save_threads_callback, + ) + + +@pytest.fixture +def chatkit_factory() -> ChatkitAgencyFactory: + """Fixture providing a ChatKit test agency factory.""" + return ChatkitAgencyFactory() + + +def test_chatkit_endpoint_receives_user_message(chatkit_factory: ChatkitAgencyFactory): + """Verify that ChatKit endpoint extracts and passes user message.""" + app = run_fastapi( + agencies={"chatkit_test": chatkit_factory}, + return_app=True, + app_token_env="", + enable_chatkit=True, + ) + client = TestClient(app) + + # Use proper ChatKit protocol format + payload = { + "type": "threads.create", + "params": { + "input": { + "content": [{"type": "input_text", "text": "Hello ChatKit!"}], + }, + }, + "context": {"user_plan": "premium"}, + } + + with client.stream("POST", "/chatkit_test/chatkit", json=payload) as response: + assert response.status_code == 200 + events = list(response.iter_lines()) + assert len(events) > 0 + + assert chatkit_factory.tracker.last_message == "Hello ChatKit!" + # Filter out internal keys for comparison + context = chatkit_factory.tracker.last_context or {} + user_context = {k: v for k, v in context.items() if not k.startswith("_")} + assert user_context == {"user_plan": "premium"} + + +def test_chatkit_endpoint_streams_events(chatkit_factory: ChatkitAgencyFactory): + """Verify that ChatKit endpoint streams events in correct format.""" + app = run_fastapi( + agencies={"chatkit_test": chatkit_factory}, + return_app=True, + app_token_env="", + enable_chatkit=True, + ) + client = TestClient(app) + + payload = { + "type": "threads.create", + "params": { + "input": {"content": [{"type": "input_text", "text": "Test"}]}, + }, + } + + with client.stream("POST", "/chatkit_test/chatkit", json=payload) as response: + assert response.status_code == 200 + events = [] + for line in response.iter_lines(): + if line.startswith("data: "): + events.append(json.loads(line[6:])) + + # Should have thread.created, user message item, assistant events, and completion + event_types = [e.get("type") for e in events] + assert "thread.created" in event_types + assert "thread.item.added" in event_types + assert "thread.run.completed" in event_types + + +def test_chatkit_endpoint_handles_existing_thread(chatkit_factory: ChatkitAgencyFactory): + """Verify that adding message to existing thread works.""" + app = run_fastapi( + agencies={"chatkit_test": chatkit_factory}, + return_app=True, + app_token_env="", + enable_chatkit=True, + ) + client = TestClient(app) + + # Use threads.add_user_message for existing thread + payload = { + "type": "threads.add_user_message", + "params": { + "thread_id": "existing-thread", + "input": {"content": [{"type": "input_text", "text": "New message"}]}, + }, + } + + with client.stream("POST", "/chatkit_test/chatkit", json=payload) as response: + assert response.status_code == 200 + events = list(response.iter_lines()) + assert len(events) > 0 + + # New message should be processed + assert chatkit_factory.tracker.last_message == "New message" + + +def test_chatkit_endpoint_returns_json_for_no_input(chatkit_factory: ChatkitAgencyFactory): + """Verify that empty input returns JSON response.""" + app = run_fastapi( + agencies={"chatkit_test": chatkit_factory}, + return_app=True, + app_token_env="", + enable_chatkit=True, + ) + client = TestClient(app) + + payload = {"type": "threads.create", "params": {}} + + response = client.post("/chatkit_test/chatkit", json=payload) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "no_input" + + +def test_chatkit_endpoint_coexists_with_standard_endpoints(chatkit_factory: ChatkitAgencyFactory): + """Verify ChatKit endpoint works alongside standard endpoints.""" + app = run_fastapi( + agencies={"chatkit_test": chatkit_factory}, + return_app=True, + app_token_env="", + enable_chatkit=True, + ) + client = TestClient(app) + + # Standard endpoints should exist + response = client.get("/chatkit_test/get_metadata") + assert response.status_code == 200 + + # ChatKit endpoint should also exist + payload = { + "type": "threads.create", + "params": {"input": {"content": [{"type": "input_text", "text": "Test"}]}}, + } + with client.stream("POST", "/chatkit_test/chatkit", json=payload) as response: + assert response.status_code == 200 diff --git a/tests/test_agent_modules/test_chatkit_adapter.py b/tests/test_agent_modules/test_chatkit_adapter.py new file mode 100644 index 00000000..5c5bcc15 --- /dev/null +++ b/tests/test_agent_modules/test_chatkit_adapter.py @@ -0,0 +1,566 @@ +"""Tests for the ChatKit adapter module.""" + +import json +from types import SimpleNamespace + +from agency_swarm.ui.core.chatkit_adapter import ChatkitAdapter + + +def make_raw_event(data): + """Create a raw response event wrapper.""" + return SimpleNamespace(type="raw_response_event", data=data) + + +def make_stream_event(name, item): + """Create a run item stream event wrapper.""" + return SimpleNamespace(type="run_item_stream_event", name=name, item=item) + + +class TestChatkitMessagesToChatHistory: + """Tests for chatkit_messages_to_chat_history conversion.""" + + def test_converts_user_message(self): + """User messages are converted to role:user format.""" + items = [ + { + "id": "user-1", + "type": "user_message", + "content": [{"type": "input_text", "text": "Hello world"}], + } + ] + + history = ChatkitAdapter.chatkit_messages_to_chat_history(items) + + assert history == [{"role": "user", "content": "Hello world"}] + + def test_converts_assistant_message(self): + """Assistant messages are converted to role:assistant format.""" + items = [ + { + "id": "asst-1", + "type": "assistant_message", + "content": [{"type": "output_text", "text": "Hi there"}], + } + ] + + history = ChatkitAdapter.chatkit_messages_to_chat_history(items) + + assert history == [{"role": "assistant", "content": "Hi there"}] + + def test_converts_tool_call(self): + """Tool calls are converted to function_call format.""" + items = [ + { + "id": "tool-1", + "type": "client_tool_call", + "call_id": "call-123", + "name": "search", + "arguments": '{"query": "test"}', + "status": "completed", + } + ] + + history = ChatkitAdapter.chatkit_messages_to_chat_history(items) + + assert len(history) == 1 + assert history[0]["type"] == "function_call" + assert history[0]["name"] == "search" + assert history[0]["call_id"] == "call-123" + + def test_converts_tool_call_with_output(self): + """Tool calls with output include function_call_output.""" + items = [ + { + "id": "tool-1", + "type": "client_tool_call", + "call_id": "call-123", + "name": "search", + "arguments": '{"query": "test"}', + "status": "completed", + "output": "Result data", + } + ] + + history = ChatkitAdapter.chatkit_messages_to_chat_history(items) + + assert len(history) == 2 + assert history[0]["type"] == "function_call" + assert history[1]["type"] == "function_call_output" + assert history[1]["output"] == "Result data" + + def test_handles_multiple_content_parts(self): + """Multiple content parts are concatenated.""" + items = [ + { + "id": "user-1", + "type": "user_message", + "content": [ + {"type": "input_text", "text": "Part 1. "}, + {"type": "input_text", "text": "Part 2."}, + ], + } + ] + + history = ChatkitAdapter.chatkit_messages_to_chat_history(items) + + assert history[0]["content"] == "Part 1. Part 2." + + +class TestChatkitAdapterEventCreation: + """Tests for ChatKit event creation methods.""" + + def test_create_thread_created_event(self): + """Creates a valid thread.created event.""" + adapter = ChatkitAdapter() + event = adapter._create_thread_created_event("thread-123") + + assert event["type"] == "thread.created" + assert event["thread"]["id"] == "thread-123" + assert "created_at" in event["thread"] + + def test_create_item_added_event(self): + """Creates a valid thread.item.added event.""" + adapter = ChatkitAdapter() + event = adapter._create_item_added_event( + "item-1", + "assistant_message", + {"content": [{"type": "output_text", "text": "Hello"}]}, + ) + + assert event["type"] == "thread.item.added" + assert event["item"]["id"] == "item-1" + assert event["item"]["type"] == "assistant_message" + assert event["item"]["content"][0]["text"] == "Hello" + + def test_create_item_updated_event(self): + """Creates a valid thread.item.updated event.""" + adapter = ChatkitAdapter() + event = adapter._create_item_updated_event( + "item-1", + {"type": "assistant_message.content_part.text_delta", "delta": "Hi"}, + ) + + assert event["type"] == "thread.item.updated" + assert event["item_id"] == "item-1" + assert event["update"]["delta"] == "Hi" + + def test_create_item_done_event(self): + """Creates a valid thread.item.done event.""" + adapter = ChatkitAdapter() + event = adapter._create_item_done_event("item-1") + + assert event["type"] == "thread.item.done" + assert event["item_id"] == "item-1" + + def test_create_assistant_message_item(self): + """Creates assistant message item structure.""" + adapter = ChatkitAdapter() + item = adapter._create_assistant_message_item("item-1", "Hello world") + + assert item["content"][0]["type"] == "output_text" + assert item["content"][0]["text"] == "Hello world" + assert item["content"][0]["annotations"] == [] + + def test_create_tool_call_item(self): + """Creates tool call item structure.""" + adapter = ChatkitAdapter() + item = adapter._create_tool_call_item("call-1", "search", '{"q": "test"}', "in_progress") + + assert item["call_id"] == "call-1" + assert item["name"] == "search" + assert item["arguments"] == '{"q": "test"}' + assert item["status"] == "in_progress" + + def test_create_tool_call_item_with_output(self): + """Tool call item includes output when provided.""" + adapter = ChatkitAdapter() + item = adapter._create_tool_call_item("call-1", "search", "{}", "completed", output="Result") + + assert item["output"] == "Result" + + +class TestChatkitAdapterEventConversion: + """Tests for OpenAI -> ChatKit event conversion.""" + + def test_assistant_message_start(self): + """Converts message start event.""" + adapter = ChatkitAdapter() + raw_event = make_raw_event( + SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace(type="message", role="assistant", id="msg-1"), + ) + ) + + result = adapter.openai_to_chatkit_events(raw_event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.item.added" + assert result["item"]["type"] == "assistant_message" + + def test_text_delta_event(self): + """Converts text delta event.""" + adapter = ChatkitAdapter() + # First emit a message start to set up state + adapter.openai_to_chatkit_events( + make_raw_event( + SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace(type="message", role="assistant", id="msg-1"), + ) + ), + run_id="run-1", + thread_id="thread-1", + ) + + delta_event = make_raw_event( + SimpleNamespace( + type="response.output_text.delta", + item_id="msg-1", + delta="Hello", + ) + ) + + result = adapter.openai_to_chatkit_events(delta_event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.item.updated" + assert result["update"]["type"] == "assistant_message.content_part.text_delta" + assert result["update"]["delta"] == "Hello" + + def test_text_delta_without_item_id_is_ignored(self): + """Text delta without item_id returns None.""" + adapter = ChatkitAdapter() + event = make_raw_event( + SimpleNamespace( + type="response.output_text.delta", + item_id=None, + delta="Hi", + ) + ) + + result = adapter.openai_to_chatkit_events(event, run_id="run-1", thread_id="thread-1") + + assert result is None + + def test_message_done_event(self): + """Converts message done event.""" + adapter = ChatkitAdapter() + # Set up state first + adapter.openai_to_chatkit_events( + make_raw_event( + SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace(type="message", role="assistant", id="msg-1"), + ) + ), + run_id="run-1", + thread_id="thread-1", + ) + + done_event = make_raw_event( + SimpleNamespace( + type="response.output_item.done", + item=SimpleNamespace(type="message", id="msg-1"), + ) + ) + + result = adapter.openai_to_chatkit_events(done_event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.item.done" + + def test_tool_call_start(self): + """Converts tool call start event.""" + adapter = ChatkitAdapter() + event = make_raw_event( + SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace( + type="function_call", + id="item-1", + call_id="call-1", + name="search", + arguments="{}", + ), + ) + ) + + result = adapter.openai_to_chatkit_events(event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.item.added" + assert result["item"]["type"] == "client_tool_call" + assert result["item"]["name"] == "search" + + def test_tool_call_without_call_id_is_ignored(self): + """Tool call without call_id returns None.""" + adapter = ChatkitAdapter() + event = make_raw_event( + SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace( + type="function_call", + id="item-1", + call_id=None, + name="search", + arguments="{}", + ), + ) + ) + + result = adapter.openai_to_chatkit_events(event, run_id="run-1", thread_id="thread-1") + + assert result is None + + def test_tool_arguments_delta(self): + """Converts tool arguments delta event.""" + adapter = ChatkitAdapter() + # Set up tool call first + adapter.openai_to_chatkit_events( + make_raw_event( + SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace( + type="function_call", + id="item-1", + call_id="call-1", + name="search", + arguments="{}", + ), + ) + ), + run_id="run-1", + thread_id="thread-1", + ) + + delta_event = make_raw_event( + SimpleNamespace( + type="response.function_call_arguments.delta", + item_id="item-1", + delta='{"q": "test', + ) + ) + + result = adapter.openai_to_chatkit_events(delta_event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.item.updated" + assert result["update"]["type"] == "client_tool_call.arguments_delta" + assert result["update"]["delta"] == '{"q": "test' + + def test_tool_done_returns_multiple_events(self): + """Tool done returns list with update and done events.""" + adapter = ChatkitAdapter() + # Set up tool call first + adapter.openai_to_chatkit_events( + make_raw_event( + SimpleNamespace( + type="response.output_item.added", + item=SimpleNamespace( + type="function_call", + id="item-1", + call_id="call-1", + name="search", + arguments="{}", + ), + ) + ), + run_id="run-1", + thread_id="thread-1", + ) + + done_event = make_raw_event( + SimpleNamespace( + type="response.output_item.done", + item=SimpleNamespace( + type="function_call", + id="item-1", + call_id="call-1", + name="search", + arguments='{"q": "test"}', + ), + ) + ) + + result = adapter.openai_to_chatkit_events(done_event, run_id="run-1", thread_id="thread-1") + + assert isinstance(result, list) + assert len(result) == 2 + assert result[0]["type"] == "thread.item.updated" + assert result[1]["type"] == "thread.item.done" + + def test_handles_exception_with_error_event(self): + """Exceptions are converted to thread.error events.""" + from unittest.mock import MagicMock, PropertyMock + + adapter = ChatkitAdapter() + event = MagicMock() + event.type = "raw_response_event" + # Create a property that raises an exception when accessed + type(event).data = PropertyMock(side_effect=RuntimeError("boom")) + + result = adapter.openai_to_chatkit_events(event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.error" + assert "boom" in result["error"]["message"] + + +class TestChatkitAdapterRunItemStream: + """Tests for run_item_stream_event handling.""" + + def test_message_output_created(self): + """Converts message_output_created event.""" + adapter = ChatkitAdapter() + # Set up state first + adapter._run_state["run-1"] = { + "call_id_by_item": {}, + "item_id_by_call": {}, + "current_message_id": "msg-1", + "accumulated_text": {}, + } + + item = SimpleNamespace( + raw_item=SimpleNamespace( + id="msg-1", + content=[SimpleNamespace(text="Hello", annotations=[])], + ) + ) + event = make_stream_event("message_output_created", item) + + result = adapter.openai_to_chatkit_events(event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.item.updated" + assert result["update"]["type"] == "assistant_message.content_part.done" + assert result["update"]["content"]["text"] == "Hello" + + def test_tool_output(self): + """Converts tool_output event.""" + adapter = ChatkitAdapter() + # Set up state with tool call + adapter._run_state["run-1"] = { + "call_id_by_item": {}, + "item_id_by_call": {"call-1": "chatkit-item-1"}, + "current_message_id": None, + "accumulated_text": {}, + } + + item = SimpleNamespace( + raw_item={"call_id": "call-1"}, + call_id="call-1", + output="Tool result", + ) + event = make_stream_event("tool_output", item) + + result = adapter.openai_to_chatkit_events(event, run_id="run-1", thread_id="thread-1") + + assert result["type"] == "thread.item.updated" + assert result["update"]["type"] == "client_tool_call.output" + assert result["update"]["output"] == "Tool result" + + def test_tool_output_without_call_id_is_ignored(self): + """Tool output without call_id returns None.""" + adapter = ChatkitAdapter() + adapter._run_state["run-1"] = { + "call_id_by_item": {}, + "item_id_by_call": {}, + "current_message_id": None, + "accumulated_text": {}, + } + + item = SimpleNamespace(raw_item={}, call_id=None, output="Result") + event = make_stream_event("tool_output", item) + + result = adapter.openai_to_chatkit_events(event, run_id="run-1", thread_id="thread-1") + + assert result is None + + +class TestChatkitAdapterRunState: + """Tests for run state management.""" + + def test_clear_run_state_all(self): + """Clears all run state.""" + adapter = ChatkitAdapter() + adapter._run_state["run-1"] = {"test": "data"} + adapter._run_state["run-2"] = {"test": "data2"} + + adapter.clear_run_state() + + assert adapter._run_state == {} + + def test_clear_run_state_specific(self): + """Clears specific run state.""" + adapter = ChatkitAdapter() + adapter._run_state["run-1"] = {"test": "data"} + adapter._run_state["run-2"] = {"test": "data2"} + + adapter.clear_run_state("run-1") + + assert "run-1" not in adapter._run_state + assert "run-2" in adapter._run_state + + def test_generate_item_id_is_unique(self): + """Generated item IDs are unique.""" + adapter = ChatkitAdapter() + ids = {adapter._generate_item_id() for _ in range(100)} + assert len(ids) == 100 + + +class TestChatkitAdapterToolMeta: + """Tests for _tool_meta helper method.""" + + def test_function_call_meta(self): + """Extracts metadata from function_call.""" + adapter = ChatkitAdapter() + raw_item = SimpleNamespace( + type="function_call", + call_id="call-1", + name="search", + arguments='{"q": "test"}', + ) + + call_id, name, args = adapter._tool_meta(raw_item) + + assert call_id == "call-1" + assert name == "search" + assert args == '{"q": "test"}' + + def test_file_search_call_meta(self): + """Extracts metadata from file_search_call.""" + adapter = ChatkitAdapter() + raw_item = SimpleNamespace( + type="file_search_call", + id="file-1", + queries=["foo"], + results=["bar"], + ) + + call_id, name, args = adapter._tool_meta(raw_item) + + assert call_id == "file-1" + assert name == "FileSearchTool" + assert json.loads(args)["queries"] == ["foo"] + + def test_code_interpreter_call_meta(self): + """Extracts metadata from code_interpreter_call.""" + adapter = ChatkitAdapter() + raw_item = SimpleNamespace( + type="code_interpreter_call", + id="ci-1", + code="print(42)", + container_id="cid", + outputs=["42"], + ) + + call_id, name, args = adapter._tool_meta(raw_item) + + assert call_id == "ci-1" + assert name == "CodeInterpreterTool" + assert json.loads(args)["code"] == "print(42)" + + def test_unknown_type_returns_none(self): + """Unknown type returns None tuple.""" + adapter = ChatkitAdapter() + raw_item = SimpleNamespace(type="unknown_type") + + call_id, name, args = adapter._tool_meta(raw_item) + + assert call_id is None + assert name is None + assert args is None