Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions docs/manager/rest-reference/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -5005,6 +5005,44 @@
"description": "\n**Preconditions:**\n* User privilege required.\n"
}
},
"//gql/strawberry/ws": {
"head": {
"operationId": "root.handle_gql_strawberry_ws",
"tags": [
"root"
],
"responses": {
"200": {
"description": "Successful response"
}
},
"security": [
{
"TokenAuth": []
}
],
"parameters": [],
"description": "\n**Preconditions:**\n* User privilege required.\n"
},
"get": {
"operationId": "root.handle_gql_strawberry_ws.2",
"tags": [
"root"
],
"responses": {
"200": {
"description": "Successful response"
}
},
"security": [
{
"TokenAuth": []
}
],
"parameters": [],
"description": "\n**Preconditions:**\n* User privilege required.\n"
}
},
"/spec/graphiql": {
"get": {
"operationId": "spec.render_graphiql_graphene_html",
Expand Down
171 changes: 169 additions & 2 deletions src/ai/backend/manager/api/admin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import json
import logging
import traceback
from enum import Enum
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Iterable, Optional, Self, Tuple, cast
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Optional, Self, Tuple, cast

import aiohttp_cors
import attrs
Expand All @@ -15,7 +17,7 @@
from graphql import ValidationRule, parse, validate
from graphql.error import GraphQLError # pants: no-infer-dep
from graphql.execution import ExecutionResult # pants: no-infer-dep
from pydantic import ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field

from ai.backend.common import validators as tx
from ai.backend.common.api_handlers import APIResponse, BodyParam, MiddlewareParam, api_handler
Expand Down Expand Up @@ -50,6 +52,54 @@
log = BraceStyleAdapter(logging.getLogger(__spec__.name))


# Kinds of messages a client may send over the GraphQL WebSocket endpoint.
class GraphQLWSMessageType(str, Enum):
    """Value of the ``type`` field in an incoming GraphQL WebSocket message.

    The enum mixes in ``str``, so each member compares equal to its raw
    wire value (for example ``"subscribe"``).
    """

    CONNECTION_INIT = "connection_init"
    SUBSCRIBE = "subscribe"
    COMPLETE = "complete"


# Payload carried by a "subscribe" message.
class GraphQLWSSubscribePayload(BaseModel):
    """Validated form of the ``payload`` object of a subscribe message.

    Only ``query`` is required; ``variables`` and ``operationName`` default
    to ``None`` when the client omits them.
    """

    # GraphQL document text to execute.
    query: str
    # Variable values for the document, if any.
    variables: dict[str, Any] | None = None
    # Name of the operation to run; camelCase to match the wire field name.
    operationName: str | None = None


# Generic envelope shared by every incoming WebSocket message.
class GraphQLWSMessage(BaseModel):
    """Top-level shape of a message received from the client.

    ``type`` must be one of ``GraphQLWSMessageType``; ``id`` and ``payload``
    are optional at this level and are checked per message type by the
    code that consumes the envelope.
    """

    type: GraphQLWSMessageType
    id: str | None = None
    # Left as a raw dict here; validated later against the model for the specific message type.
    payload: dict[str, Any] | None = None


# Type alias for the streaming result of schema.subscribe: an async iterator
# of ExecutionResult. For non-subscription operations schema.subscribe may
# instead return a single ExecutionResult; this alias names only the iterator
# form, so callers must check for the single-result form at runtime.
SubscriptionResult = AsyncIterator[ExecutionResult]


# Server-to-client response messages.
class GraphQLWSConnectionAck(BaseModel):
    """Acknowledgement sent in reply to a ``connection_init`` message."""

    type: str = "connection_ack"


class GraphQLWSNext(BaseModel):
    """Message that delivers one subscription result to the client."""

    type: str = "next"
    # Id of the subscribe message this result belongs to.
    id: str
    # Result body; the handler sends it as ``{"data": ...}``.
    payload: dict[str, Any]


class GraphQLWSError(BaseModel):
    """Message that reports one or more errors to the client."""

    type: str = "error"
    # Id of the failed operation; None when the error is not tied to one
    # (for example a message that could not be parsed).
    id: str | None = None
    # One entry per error; the handler fills each as ``{"message": ...}``.
    payload: list[dict[str, str]]


class GraphQLWSCompleteResponse(BaseModel):
    """Message telling the client that a subscription has finished."""

    type: str = "complete"
    # Id of the subscribe message whose stream has ended.
    id: str


class GQLLoggingMiddleware:
def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any:
if info.path.prev is None: # indicates the root query
Expand Down Expand Up @@ -232,6 +282,9 @@
strawberry_ctx = StrawberryGQLContext(
processors=processors_ctx.processors,
config_provider=config_provider_ctx.config_provider,
event_hub=processors_ctx.event_hub,
event_fetcher=processors_ctx.event_fetcher,
valkey_bgtask=processors_ctx.valkey_bgtask,
)

query, variables, operation_name = (
Expand Down Expand Up @@ -302,6 +355,7 @@
auto_camelcase=False,
)
app_ctx.gql_v2_schema = strawberry_schema

root_ctx: RootContext = app["_root.context"]
if root_ctx.config_provider.config.api.allow_graphql_schema_introspection:
log.warning(
Expand All @@ -314,6 +368,117 @@
pass


async def _send_gql_ws_message(ws: web.WebSocketResponse, message: BaseModel) -> None:
    """Send a protocol message model to the client as a JSON text frame."""
    await ws.send_str(json.dumps(message.model_dump()))


async def _stream_gql_subscription(
    ws: web.WebSocketResponse,
    schema: Any,
    context: StrawberryGQLContext,
    operation_id: str,
    subscribe_payload: GraphQLWSSubscribePayload,
) -> None:
    """Run one subscription and forward its results to the client.

    Sends a ``next`` message for every result that carries data, then a
    ``complete`` message once the stream is exhausted.  If the operation is
    rejected before execution, or a result carries errors, a single ``error``
    message is sent and nothing else follows for this operation id.

    Cancelling the task running this coroutine stops the stream; the
    underlying async generator is closed in all cases.
    """
    stream: Any = None
    try:
        stream = await schema.subscribe(
            subscribe_payload.query,
            variable_values=subscribe_payload.variables or {},
            context_value=context,
            # Needed to select the operation when the document defines several.
            operation_name=subscribe_payload.operationName,
        )
        log.info("Subscription subscribe result: {}", type(stream))

        if not hasattr(stream, "__aiter__"):
            # Not a stream: the operation was rejected before execution
            # (for example a validation failure).  Report the real errors
            # when the returned object exposes them.
            rejected, stream = stream, None
            pre_errors = getattr(rejected, "errors", None)
            if not pre_errors:
                raise ValueError("Expected an async iterator for subscription")
            log.error("Subscription errors: {}", pre_errors)
            await _send_gql_ws_message(
                ws,
                GraphQLWSError(
                    id=operation_id,
                    payload=[{"message": str(err)} for err in pre_errors],
                ),
            )
            return

        async for result in stream:
            if result.errors:
                log.error("Subscription errors: {}", result.errors)
                await _send_gql_ws_message(
                    ws,
                    GraphQLWSError(
                        id=operation_id,
                        payload=[{"message": str(err)} for err in result.errors],
                    ),
                )
                # An error message ends the operation, so do not follow it
                # with a "complete" message for the same id.
                return
            if result.data is not None:
                log.debug("Sending subscription data: {}", result.data)
                await _send_gql_ws_message(
                    ws,
                    GraphQLWSNext(id=operation_id, payload={"data": result.data}),
                )

        # Send completion after async iterator is exhausted
        log.info("Subscription completed, sending complete message")
        await _send_gql_ws_message(ws, GraphQLWSCompleteResponse(id=operation_id))
    except Exception as e:
        # log.exception records the traceback, so one call is enough.
        log.exception("Subscription execution error: {}", e)
        await _send_gql_ws_message(
            ws,
            GraphQLWSError(id=operation_id, payload=[{"message": str(e)}]),
        )
    finally:
        # Close the generator explicitly so resources held by the resolver
        # are released promptly on early exit or cancellation.
        close_stream = getattr(stream, "aclose", None)
        if close_stream is not None:
            try:
                await close_stream()
            except Exception:
                log.exception("Failed to close subscription stream: {}", operation_id)


@auth_required
async def handle_gql_strawberry_ws(request: web.Request) -> web.WebSocketResponse:
    """Serve GraphQL subscriptions over a WebSocket connection.

    Handles three client message types:

    * ``connection_init`` - answered with ``connection_ack``.
    * ``subscribe`` - starts the subscription in a background task so the
      receive loop stays responsive while results are streamed.
    * ``complete`` - cancels the running subscription with the given id.

    Messages that fail parsing or validation are answered with an ``error``
    message that has no id.  All subscriptions still running when the
    connection ends are cancelled.

    Returns the prepared ``WebSocketResponse`` once the connection is closed.
    """
    # Imported locally so this handler does not rely on the module import list.
    import asyncio

    ws = web.WebSocketResponse(protocols=["graphql-transport-ws", "graphql-ws"])
    await ws.prepare(request)

    root_ctx: RootContext = request.app["_root.context"]
    processors_ctx = await ProcessorsCtx.from_request(request)
    context = StrawberryGQLContext(
        processors=processors_ctx.processors,
        config_provider=root_ctx.config_provider,
        event_hub=processors_ctx.event_hub,
        event_fetcher=processors_ctx.event_fetcher,
        valkey_bgtask=processors_ctx.valkey_bgtask,
    )

    schema = request.app["admin.context"].gql_v2_schema

    # Running subscriptions of this connection, keyed by operation id.
    running: dict[str, asyncio.Task[None]] = {}

    def _forget(task: asyncio.Task[None], operation_id: str) -> None:
        """Drop a finished task from the registry and surface its failure."""
        if running.get(operation_id) is task:
            del running[operation_id]
        if not task.cancelled():
            exc = task.exception()
            if exc is not None:
                log.warning("Subscription task {} failed: {}", operation_id, repr(exc))

    try:
        async for msg in ws:
            if msg.type != web.WSMsgType.TEXT:
                continue
            try:
                # Parse and validate WebSocket message using Pydantic
                ws_message = GraphQLWSMessage.model_validate(msg.json())

                if ws_message.type == GraphQLWSMessageType.CONNECTION_INIT:
                    await _send_gql_ws_message(ws, GraphQLWSConnectionAck())

                elif ws_message.type == GraphQLWSMessageType.SUBSCRIBE:
                    if not ws_message.id or not ws_message.payload:
                        raise ValueError("Subscribe message requires id and payload")

                    # Validate and parse subscription payload
                    subscribe_payload = GraphQLWSSubscribePayload.model_validate(
                        ws_message.payload
                    )
                    operation_id = ws_message.id

                    active = running.get(operation_id)
                    if active is not None and not active.done():
                        raise ValueError(f"Subscriber for {operation_id} already exists")

                    log.info(
                        "Processing subscription: {}, query: {}, variables: {}",
                        operation_id,
                        subscribe_payload.query[:30],
                        subscribe_payload.variables or {},
                    )

                    task = asyncio.create_task(
                        _stream_gql_subscription(
                            ws, schema, context, operation_id, subscribe_payload
                        )
                    )
                    running[operation_id] = task
                    # Bind the id as a default argument to avoid late binding.
                    task.add_done_callback(lambda t, op_id=operation_id: _forget(t, op_id))

                elif ws_message.type == GraphQLWSMessageType.COMPLETE:
                    if not ws_message.id:
                        raise ValueError("Complete message requires id")
                    log.info("Received complete message for subscription: {}", ws_message.id)
                    finished = running.pop(ws_message.id, None)
                    if finished is not None:
                        finished.cancel()

            except Exception as e:
                # Handle message parsing and validation errors
                log.error("WebSocket message validation error: {}", e)
                await _send_gql_ws_message(
                    ws,
                    GraphQLWSError(payload=[{"message": f"Invalid message format: {str(e)}"}]),
                )
    finally:
        # The connection is gone (or the handler is being torn down): stop
        # every subscription that is still streaming and wait for cleanup.
        leftovers = list(running.values())
        for task in leftovers:
            task.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)

    return ws


def create_app(
default_cors_options: CORSOptions,
) -> Tuple[web.Application, Iterable[WebMiddleware]]:
Expand All @@ -329,4 +494,6 @@
cors.add(
app.router.add_route("POST", r"/gql/strawberry", gql_api_handler.handle_gql_strawberry)
)
cors.add(app.router.add_get(r"/gql/strawberry/ws", handle_gql_strawberry_ws))

return app, []
6 changes: 6 additions & 0 deletions src/ai/backend/manager/api/gql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import attrs

from ai.backend.common.clients.valkey_client.valkey_bgtask.client import ValkeyBgtaskClient
from ai.backend.common.events.fetcher import EventFetcher
from ai.backend.common.events.hub.hub import EventHub
from ai.backend.manager.config.provider import ManagerConfigProvider
from ai.backend.manager.services.processors import Processors

Expand All @@ -11,3 +14,6 @@
class StrawberryGQLContext:
processors: Processors
config_provider: ManagerConfigProvider
event_hub: EventHub
event_fetcher: EventFetcher
valkey_bgtask: ValkeyBgtaskClient
10 changes: 9 additions & 1 deletion src/ai/backend/manager/api/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,15 @@
></script>

<script>
const fetcher = GraphiQL.createFetcher({ url: '../../admin/gql/strawberry' });
const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
// Use Web Server proxy for GraphQL WebSocket subscriptions
const wsUrl = `${wsProtocol}//${window.location.host}/func/admin/gql/strawberry/ws`;
// const wsUrl = `${wsProtocol}//${window.location.host}/func/stream/gql`;

const fetcher = GraphiQL.createFetcher({
url: '../../admin/gql/strawberry',
subscriptionUrl: wsUrl,
});

ReactDOM.render(
React.createElement(GraphiQL, { fetcher: fetcher }),
Expand Down
13 changes: 12 additions & 1 deletion src/ai/backend/manager/dto/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,28 @@
from pydantic import ConfigDict

from ai.backend.common.api_handlers import MiddlewareParam
from ai.backend.common.clients.valkey_client.valkey_bgtask.client import ValkeyBgtaskClient
from ai.backend.common.events.fetcher import EventFetcher
from ai.backend.common.events.hub.hub import EventHub
from ai.backend.manager.api.context import RootContext
from ai.backend.manager.services.processors import Processors


class ProcessorsCtx(MiddlewareParam):
    """Middleware parameter exposing services taken from the root context."""

    processors: Processors
    event_hub: EventHub
    event_fetcher: EventFetcher
    valkey_bgtask: ValkeyBgtaskClient

    # The field types above are presumably not pydantic models, hence
    # arbitrary types must be allowed - confirm against their definitions.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @override
    @classmethod
    async def from_request(cls, request: web.Request) -> Self:
        """Build the context from ``request.app["_root.context"]``.

        Every declared field is populated from the attribute of the same
        name on the root context.
        """
        root_ctx: RootContext = request.app["_root.context"]
        # Single return that fills all four fields; the earlier
        # processors-only return left the other required fields unset and
        # made this one unreachable.
        return cls(
            processors=root_ctx.processors,
            event_hub=root_ctx.event_hub,
            event_fetcher=root_ctx.event_fetcher,
            valkey_bgtask=root_ctx.valkey_bgtask,
        )
14 changes: 13 additions & 1 deletion src/ai/backend/web/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,19 @@ async def websocket_handler(
override_api_version=request_api_version,
)
async with api_request.connect_websocket() as up_conn:
down_conn = web.WebSocketResponse()
# Support GraphQL WebSocket protocols
protocols = []
if "sec-websocket-protocol" in request.headers:
client_protocols = [
p.strip() for p in request.headers["sec-websocket-protocol"].split(",")
]
supported_protocols = ["graphql-transport-ws", "graphql-ws"]
protocols = [p for p in client_protocols if p in supported_protocols]
print(
f"=== GraphQL WebSocket protocols: client={client_protocols}, supported={protocols} ==="
)

down_conn = web.WebSocketResponse(protocols=protocols or [])
await down_conn.prepare(request)
web_socket_proxy = WebSocketProxy(up_conn.raw_websocket, down_conn)
await web_socket_proxy.proxy()
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,16 @@
cors.add(app.router.add_route("POST", "/func/{path:auth/signout}", web_handler))
cors.add(app.router.add_route("GET", "/func/{path:stream/kernel/_/events}", web_handler))
cors.add(app.router.add_route("GET", "/func/{path:stream/session/[^/]+/apps$}", web_handler))
# GraphQL WebSocket subscription endpoint (must be before generic stream handler)
cors.add(
app.router.add_route("GET", "/func/{path:admin/gql/strawberry/ws$}", websocket_handler)
)

# TODO: Use this endpoint

Check notice

Code scanning / devskim

A "TODO" or similar was left in source code, possibly indicating incomplete functionality Note

Suspicious comment
# cors.add(
# app.router.add_route("GET", "/func/{path:stream/gql$}", websocket_handler)
# )

cors.add(app.router.add_route("GET", "/func/{path:stream/.*$}", websocket_handler))
cors.add(app.router.add_route("GET", "/func/", anon_web_handler))

Expand Down
Loading