Skip to content

Commit e549cb5

Browse files
committed
fix: Remove background_tasks GQL types
1 parent 1690409 commit e549cb5

File tree

2 files changed

+120
-304
lines changed

2 files changed

+120
-304
lines changed

src/ai/backend/manager/api/admin.py

Lines changed: 120 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import json
44
import logging
55
import traceback
6+
from enum import Enum
67
from http import HTTPStatus
7-
from typing import TYPE_CHECKING, Any, Iterable, Optional, Self, Tuple, cast
8+
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterable, Optional, Self, Tuple, cast
89

910
import aiohttp_cors
1011
import attrs
@@ -16,7 +17,7 @@
1617
from graphql import ValidationRule, parse, validate
1718
from graphql.error import GraphQLError # pants: no-infer-dep
1819
from graphql.execution import ExecutionResult # pants: no-infer-dep
19-
from pydantic import ConfigDict, Field
20+
from pydantic import BaseModel, ConfigDict, Field
2021

2122
# Import Strawberry aiohttp views
2223
from ai.backend.common import validators as tx
@@ -52,6 +53,54 @@
5253
log = BraceStyleAdapter(logging.getLogger(__spec__.name))
5354

5455

56+
# WebSocket message type enum
57+
class GraphQLWSMessageType(str, Enum):
58+
CONNECTION_INIT = "connection_init"
59+
SUBSCRIBE = "subscribe"
60+
COMPLETE = "complete"
61+
62+
63+
# Payload types for WebSocket messages
64+
class GraphQLWSSubscribePayload(BaseModel):
65+
query: str
66+
variables: dict[str, Any] | None = None
67+
operationName: str | None = None
68+
69+
70+
# Union type for all WebSocket messages
71+
class GraphQLWSMessage(BaseModel):
72+
type: GraphQLWSMessageType
73+
id: str | None = None
74+
payload: dict[str, Any] | None = None # Will be validated in specific message types
75+
76+
77+
# Type for schema.subscribe return value - it returns either an AsyncIterator of ExecutionResult
78+
# or a single ExecutionResult for non-subscription operations
79+
SubscriptionResult = AsyncIterator[ExecutionResult]
80+
81+
82+
# WebSocket response message types
83+
class GraphQLWSConnectionAck(BaseModel):
84+
type: str = "connection_ack"
85+
86+
87+
class GraphQLWSNext(BaseModel):
88+
type: str = "next"
89+
id: str
90+
payload: dict[str, Any]
91+
92+
93+
class GraphQLWSError(BaseModel):
94+
type: str = "error"
95+
id: str | None = None
96+
payload: list[dict[str, str]]
97+
98+
99+
class GraphQLWSCompleteResponse(BaseModel):
100+
type: str = "complete"
101+
id: str
102+
103+
55104
class GQLLoggingMiddleware:
56105
def resolve(self, next, root, info: graphene.ResolveInfo, **args) -> Any:
57106
if info.path.prev is None: # indicates the root query
@@ -323,11 +372,10 @@ async def shutdown(app: web.Application) -> None:
323372

324373

325374
@auth_required
326-
async def handle_gql_ws(request: web.Request) -> web.WebSocketResponse:
375+
async def handle_gql_strawberry_ws(request: web.Request) -> web.WebSocketResponse:
327376
ws = web.WebSocketResponse(protocols=["graphql-transport-ws", "graphql-ws"])
328377
await ws.prepare(request)
329378

330-
# Create context once
331379
root_ctx: RootContext = request.app["_root.context"]
332380
processors_ctx = await ProcessorsCtx.from_request(request)
333381
context = StrawberryGQLContext(
@@ -342,35 +390,45 @@ async def handle_gql_ws(request: web.Request) -> web.WebSocketResponse:
342390

343391
async for msg in ws:
344392
if msg.type == web.WSMsgType.TEXT:
345-
data = msg.json()
346-
347-
if data.get("type") == "connection_init":
348-
await ws.send_str('{"type":"connection_ack"}')
349-
350-
elif data.get("type") == "subscribe":
351-
subscription_id = data.get("id")
352-
payload = data.get("payload", {})
353-
query = payload.get("query", "")
354-
variables = payload.get("variables", {})
355-
356-
log.info(
357-
"Processing subscription: {}, query: {}, variables: {}",
358-
subscription_id,
359-
query[:30],
360-
variables,
361-
)
362-
363-
try:
364-
# Execute subscription using Strawberry's subscribe method for proper AsyncGenerator handling
365-
async_result = await schema.subscribe(
366-
query,
367-
variable_values=variables,
368-
context_value=context,
393+
try:
394+
# Parse and validate WebSocket message using Pydantic
395+
raw_data = msg.json()
396+
ws_message = GraphQLWSMessage.model_validate(raw_data)
397+
398+
if ws_message.type == GraphQLWSMessageType.CONNECTION_INIT:
399+
response = GraphQLWSConnectionAck()
400+
await ws.send_str(json.dumps(response.model_dump()))
401+
402+
elif ws_message.type == GraphQLWSMessageType.SUBSCRIBE:
403+
if not ws_message.id or not ws_message.payload:
404+
raise ValueError("Subscribe message requires id and payload")
405+
406+
# Validate and parse subscription payload
407+
subscribe_payload = GraphQLWSSubscribePayload.model_validate(ws_message.payload)
408+
query = subscribe_payload.query
409+
variables = subscribe_payload.variables or {}
410+
411+
log.info(
412+
"Processing subscription: {}, query: {}, variables: {}",
413+
ws_message.id,
414+
query[:30],
415+
variables,
369416
)
370417

371-
log.info("Subscription subscribe result: {}", type(async_result))
418+
try:
419+
# Execute subscription using Strawberry's subscribe method
420+
async_result: SubscriptionResult = await schema.subscribe(
421+
query,
422+
variable_values=variables,
423+
context_value=context,
424+
)
425+
426+
log.info("Subscription subscribe result: {}", type(async_result))
427+
428+
if not hasattr(async_result, "__aiter__"):
429+
# TODO: Add exception
430+
raise ValueError("Expected an async iterator for subscription")
372431

373-
if hasattr(async_result, "__aiter__"):
374432
log.info("Processing subscription async generator")
375433

376434
async for result in async_result:
@@ -382,59 +440,44 @@ async def handle_gql_ws(request: web.Request) -> web.WebSocketResponse:
382440

383441
if result.errors:
384442
log.error("Subscription errors: {}", result.errors)
385-
await ws.send_str(
386-
json.dumps({
387-
"id": subscription_id,
388-
"type": "error",
389-
"payload": [{"message": str(e)} for e in result.errors],
390-
})
443+
error_response = GraphQLWSError(
444+
id=ws_message.id,
445+
payload=[{"message": str(e)} for e in result.errors],
391446
)
447+
await ws.send_str(json.dumps(error_response.model_dump()))
392448
break
393449
elif result.data:
394450
log.info("Sending subscription data: {}", result.data)
395-
await ws.send_str(
396-
json.dumps({
397-
"id": subscription_id,
398-
"type": "next",
399-
"payload": {"data": result.data},
400-
})
451+
next_response = GraphQLWSNext(
452+
id=ws_message.id, payload={"data": result.data}
401453
)
454+
await ws.send_str(json.dumps(next_response.model_dump()))
402455

403-
# Send completion
456+
# Send completion after async iterator is exhausted
404457
log.info("Subscription completed, sending complete message")
405-
await ws.send_str(json.dumps({"id": subscription_id, "type": "complete"}))
406-
else:
407-
# Fallback to regular execute for queries
408-
log.info("Not a subscription, using regular execute")
409-
result = async_result
410-
411-
if result.errors:
412-
await ws.send_str(
413-
json.dumps({
414-
"id": subscription_id,
415-
"type": "error",
416-
"payload": [{"message": str(e)} for e in result.errors],
417-
})
418-
)
419-
elif result.data:
420-
await ws.send_str(
421-
json.dumps({
422-
"id": subscription_id,
423-
"type": "next",
424-
"payload": {"data": result.data},
425-
})
426-
)
427-
428-
except Exception as e:
429-
log.error("Subscription execution error: {}", e)
430-
log.exception("Full traceback:")
431-
await ws.send_str(
432-
json.dumps({
433-
"id": subscription_id,
434-
"type": "error",
435-
"payload": [{"message": str(e)}],
436-
})
437-
)
458+
complete_response = GraphQLWSCompleteResponse(id=ws_message.id)
459+
await ws.send_str(json.dumps(complete_response.model_dump()))
460+
461+
except Exception as e:
462+
log.error("Subscription execution error: {}", e)
463+
log.exception("Full traceback:")
464+
error_response = GraphQLWSError(
465+
id=ws_message.id, payload=[{"message": str(e)}]
466+
)
467+
await ws.send_str(json.dumps(error_response.model_dump()))
468+
469+
elif ws_message.type == GraphQLWSMessageType.COMPLETE:
470+
if not ws_message.id:
471+
raise ValueError("Complete message requires id")
472+
log.info("Received complete message for subscription: {}", ws_message.id)
473+
474+
except Exception as e:
475+
# Handle message parsing and validation errors
476+
log.error("WebSocket message validation error: {}", e)
477+
error_response = GraphQLWSError(
478+
payload=[{"message": f"Invalid message format: {str(e)}"}]
479+
)
480+
await ws.send_str(json.dumps(error_response.model_dump()))
438481

439482
return ws
440483

@@ -454,7 +497,6 @@ def create_app(
454497
cors.add(
455498
app.router.add_route("POST", r"/gql/strawberry", gql_api_handler.handle_gql_strawberry)
456499
)
457-
458-
cors.add(app.router.add_get(r"/gql/strawberry/ws", handle_gql_ws))
500+
cors.add(app.router.add_get(r"/gql/strawberry/ws", handle_gql_strawberry_ws))
459501

460502
return app, []

0 commit comments

Comments
 (0)