diff --git a/docs/manager/rest-reference/openapi.json b/docs/manager/rest-reference/openapi.json index 4d8f5cba095..5e50de43e3f 100644 --- a/docs/manager/rest-reference/openapi.json +++ b/docs/manager/rest-reference/openapi.json @@ -5005,6 +5005,44 @@ "description": "\n**Preconditions:**\n* User privilege required.\n" } }, + "//gql/strawberry/ws": { + "head": { + "operationId": "root.handle_gql_strawberry_ws", + "tags": [ + "root" + ], + "responses": { + "200": { + "description": "Successful response" + } + }, + "security": [ + { + "TokenAuth": [] + } + ], + "parameters": [], + "description": "\n**Preconditions:**\n* User privilege required.\n" + }, + "get": { + "operationId": "root.handle_gql_strawberry_ws.2", + "tags": [ + "root" + ], + "responses": { + "200": { + "description": "Successful response" + } + }, + "security": [ + { + "TokenAuth": [] + } + ], + "parameters": [], + "description": "\n**Preconditions:**\n* User privilege required.\n" + } + }, "/spec/graphiql": { "get": { "operationId": "spec.render_graphiql_graphene_html", diff --git a/src/ai/backend/manager/api/admin.py b/src/ai/backend/manager/api/admin.py index 40b563e95c4..89d6271f9e5 100644 --- a/src/ai/backend/manager/api/admin.py +++ b/src/ai/backend/manager/api/admin.py @@ -1,9 +1,11 @@ from __future__ import annotations +import json import logging import traceback +from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Iterable, Optional, Self, Tuple, cast +from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Optional, Self, Tuple, cast import aiohttp_cors import attrs @@ -15,7 +17,7 @@ from graphql import ValidationRule, parse, validate from graphql.error import GraphQLError # pants: no-infer-dep from graphql.execution import ExecutionResult # pants: no-infer-dep -from pydantic import ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field from ai.backend.common import validators as tx from ai.backend.common.api_handlers import 
class GraphQLWSMessageType(str, Enum):
    """Message types the GraphQL WebSocket handler accepts from clients."""

    CONNECTION_INIT = "connection_init"
    SUBSCRIBE = "subscribe"
    COMPLETE = "complete"


class GraphQLWSSubscribePayload(BaseModel):
    """Payload of a ``subscribe`` message: the operation to execute."""

    query: str
    variables: Optional[dict[str, Any]] = None
    operationName: Optional[str] = None


class GraphQLWSMessage(BaseModel):
    """Envelope shared by every incoming WebSocket message.

    This is a single generic model rather than a union of per-type models:
    only ``type`` is mandatory, and ``payload`` stays an untyped mapping.
    """

    type: GraphQLWSMessageType
    id: Optional[str] = None
    # Kept raw here; the handler for a specific message type is expected to
    # validate it against the matching payload model.
    payload: Optional[dict[str, Any]] = None


# Alias for the streaming form of a subscription result. It names only the
# async-iterator case; a single ExecutionResult is NOT covered by this alias.
SubscriptionResult = AsyncIterator[ExecutionResult]


class GraphQLWSConnectionAck(BaseModel):
    """Server reply acknowledging a ``connection_init`` message."""

    type: str = "connection_ack"


class GraphQLWSNext(BaseModel):
    """Server message carrying one execution result of a subscription."""

    type: str = "next"
    id: str
    payload: dict[str, Any]


class GraphQLWSError(BaseModel):
    """Server message reporting errors; ``id`` is optional."""

    type: str = "error"
    id: Optional[str] = None
    payload: list[dict[str, str]]


class GraphQLWSCompleteResponse(BaseModel):
    """Server message signalling that a subscription has finished."""

    type: str = "complete"
    id: str
async def _send_gql_ws_message(ws: web.WebSocketResponse, message: BaseModel) -> None:
    """Serialize one protocol message and send it as a text frame.

    Messages for an already-closed socket are dropped, and a connection that
    goes away mid-send is logged instead of raised, so that error-reporting
    paths cannot fail a second time while trying to report.
    """
    if ws.closed:
        return
    try:
        await ws.send_str(json.dumps(message.model_dump()))
    except ConnectionError as e:
        log.debug("Dropping WebSocket message, connection lost: {}", e)


async def _run_gql_ws_subscription(
    ws: web.WebSocketResponse,
    schema: Any,
    context: StrawberryGQLContext,
    operation_id: str,
    subscribe_payload: GraphQLWSSubscribePayload,
) -> None:
    """Execute one subscription and stream its results to the client.

    Sends a ``next`` message for every result that carries data. The
    operation ends with exactly one terminal message: ``error`` when
    execution fails or a result carries errors, ``complete`` when the stream
    is exhausted normally. No terminal message is sent when the surrounding
    task is cancelled.
    """
    stream: Any = None
    try:
        stream = await schema.subscribe(
            subscribe_payload.query,
            variable_values=subscribe_payload.variables or {},
            context_value=context,
            operation_name=subscribe_payload.operationName,
        )

        if not hasattr(stream, "__aiter__"):
            # Not a stream. If the returned object carries errors (presumably
            # a single execution result from a failed parse/validation --
            # confirm against the Strawberry version in use), report those
            # instead of a generic failure.
            errors = getattr(stream, "errors", None)
            if not errors:
                raise ValueError("Expected an async iterator for subscription")
            log.error("Subscription errors: {}", errors)
            await _send_gql_ws_message(
                ws,
                GraphQLWSError(
                    id=operation_id,
                    payload=[{"message": str(e)} for e in errors],
                ),
            )
            return

        async for result in stream:
            if result.errors:
                log.error("Subscription errors: {}", result.errors)
                await _send_gql_ws_message(
                    ws,
                    GraphQLWSError(
                        id=operation_id,
                        payload=[{"message": str(e)} for e in result.errors],
                    ),
                )
                # The error message is terminal for this operation, so no
                # ``complete`` message must follow it.
                return
            if result.data:
                # Result payloads may be large or sensitive: keep them out of
                # info-level logs.
                log.debug("Sending subscription data for {}", operation_id)
                await _send_gql_ws_message(
                    ws,
                    GraphQLWSNext(id=operation_id, payload={"data": result.data}),
                )

        log.info("Subscription completed, sending complete message")
        await _send_gql_ws_message(ws, GraphQLWSCompleteResponse(id=operation_id))
    except Exception as e:
        log.exception("Subscription execution error: {}", e)
        await _send_gql_ws_message(
            ws,
            GraphQLWSError(id=operation_id, payload=[{"message": str(e)}]),
        )
    finally:
        # Close the stream on every exit path (early return, failure,
        # cancellation) so the underlying generator can release resources.
        close_stream = getattr(stream, "aclose", None)
        if close_stream is not None:
            try:
                await close_stream()
            except Exception as e:
                log.warning("Failed to close subscription stream {}: {}", operation_id, e)


@auth_required
async def handle_gql_strawberry_ws(request: web.Request) -> web.WebSocketResponse:
    """Serve GraphQL subscriptions of the Strawberry schema over a WebSocket.

    Handles three client message types:

    * ``connection_init`` -- answered with ``connection_ack``.
    * ``subscribe`` -- starts the operation in its own task, so the socket
      keeps reading messages while results are streamed and several
      subscriptions can run on one connection. An id that is still running
      is rejected.
    * ``complete`` -- cancels the running subscription with that id.

    Messages that fail parsing or validation are answered with an ``error``
    message without an id. All running subscriptions are cancelled when the
    connection ends.

    Returns:
        The prepared ``WebSocketResponse`` after the message loop has ended.
    """
    # Imported locally because only this handler needs asyncio.
    import asyncio

    ws = web.WebSocketResponse(protocols=["graphql-transport-ws", "graphql-ws"])
    await ws.prepare(request)

    root_ctx: RootContext = request.app["_root.context"]
    processors_ctx = await ProcessorsCtx.from_request(request)
    context = StrawberryGQLContext(
        processors=processors_ctx.processors,
        config_provider=root_ctx.config_provider,
        event_hub=processors_ctx.event_hub,
        event_fetcher=processors_ctx.event_fetcher,
        valkey_bgtask=processors_ctx.valkey_bgtask,
    )

    schema = request.app["admin.context"].gql_v2_schema

    # Running subscriptions, keyed by the client-chosen operation id.
    running: dict[str, asyncio.Task[None]] = {}

    def _track(operation_id: str, task: asyncio.Task[None]) -> None:
        """Register a subscription task and forget it once it finishes."""
        running[operation_id] = task

        def _forget(done: asyncio.Task[None]) -> None:
            # Identity check: the id may already belong to a newer task.
            if running.get(operation_id) is done:
                del running[operation_id]

        task.add_done_callback(_forget)

    try:
        async for msg in ws:
            if msg.type != web.WSMsgType.TEXT:
                continue
            try:
                # Parse and validate the WebSocket message using Pydantic
                ws_message = GraphQLWSMessage.model_validate(msg.json())

                if ws_message.type == GraphQLWSMessageType.CONNECTION_INIT:
                    await _send_gql_ws_message(ws, GraphQLWSConnectionAck())

                elif ws_message.type == GraphQLWSMessageType.SUBSCRIBE:
                    if not ws_message.id or not ws_message.payload:
                        raise ValueError("Subscribe message requires id and payload")
                    if ws_message.id in running:
                        raise ValueError(
                            f"Subscription id {ws_message.id!r} is already in use"
                        )

                    subscribe_payload = GraphQLWSSubscribePayload.model_validate(
                        ws_message.payload
                    )
                    log.info(
                        "Processing subscription: {}, query: {}, variables: {}",
                        ws_message.id,
                        subscribe_payload.query[:30],
                        subscribe_payload.variables or {},
                    )
                    _track(
                        ws_message.id,
                        asyncio.create_task(
                            _run_gql_ws_subscription(
                                ws, schema, context, ws_message.id, subscribe_payload
                            )
                        ),
                    )

                elif ws_message.type == GraphQLWSMessageType.COMPLETE:
                    if not ws_message.id:
                        raise ValueError("Complete message requires id")
                    log.info("Received complete message for subscription: {}", ws_message.id)
                    finished = running.pop(ws_message.id, None)
                    if finished is not None:
                        finished.cancel()

            except Exception as e:
                # Handle message parsing and validation errors
                log.error("WebSocket message validation error: {}", e)
                await _send_gql_ws_message(
                    ws,
                    GraphQLWSError(payload=[{"message": f"Invalid message format: {str(e)}"}]),
                )
    finally:
        # The connection is gone: stop every subscription still streaming and
        # wait for their cleanup to finish.
        leftover = list(running.values())
        for task in leftover:
            task.cancel()
        if leftover:
            await asyncio.gather(*leftover, return_exceptions=True)

    return ws
r"/gql/strawberry", gql_api_handler.handle_gql_strawberry) ) + cors.add(app.router.add_get(r"/gql/strawberry/ws", handle_gql_strawberry_ws)) + return app, [] diff --git a/src/ai/backend/manager/api/gql/types.py b/src/ai/backend/manager/api/gql/types.py index b01e3cfd024..2c4598ab149 100644 --- a/src/ai/backend/manager/api/gql/types.py +++ b/src/ai/backend/manager/api/gql/types.py @@ -3,6 +3,9 @@ import attrs +from ai.backend.common.clients.valkey_client.valkey_bgtask.client import ValkeyBgtaskClient +from ai.backend.common.events.fetcher import EventFetcher +from ai.backend.common.events.hub.hub import EventHub from ai.backend.manager.config.provider import ManagerConfigProvider from ai.backend.manager.services.processors import Processors @@ -11,3 +14,6 @@ class StrawberryGQLContext: processors: Processors config_provider: ManagerConfigProvider + event_hub: EventHub + event_fetcher: EventFetcher + valkey_bgtask: ValkeyBgtaskClient diff --git a/src/ai/backend/manager/api/spec.py b/src/ai/backend/manager/api/spec.py index da7a682b9c1..6ebcc10af66 100644 --- a/src/ai/backend/manager/api/spec.py +++ b/src/ai/backend/manager/api/spec.py @@ -89,7 +89,15 @@ >