33import json
44import logging
55import traceback
6+ from enum import Enum
67from http import HTTPStatus
7- from typing import TYPE_CHECKING , Any , Iterable , Optional , Self , Tuple , cast
8+ from typing import TYPE_CHECKING , Any , AsyncIterator , Iterable , Optional , Self , Tuple , cast
89
910import aiohttp_cors
1011import attrs
1617from graphql import ValidationRule , parse , validate
1718from graphql .error import GraphQLError # pants: no-infer-dep
1819from graphql .execution import ExecutionResult # pants: no-infer-dep
19- from pydantic import ConfigDict , Field
20+ from pydantic import BaseModel , ConfigDict , Field
2021
2122# Import Strawberry aiohttp views
2223from ai .backend .common import validators as tx
5253log = BraceStyleAdapter (logging .getLogger (__spec__ .name ))
5354
5455
56+ # WebSocket message type enum
57+ class GraphQLWSMessageType (str , Enum ):
58+ CONNECTION_INIT = "connection_init"
59+ SUBSCRIBE = "subscribe"
60+ COMPLETE = "complete"
61+
62+
63+ # Payload types for WebSocket messages
64+ class GraphQLWSSubscribePayload (BaseModel ):
65+ query : str
66+ variables : dict [str , Any ] | None = None
67+ operationName : str | None = None
68+
69+
70+ # Union type for all WebSocket messages
71+ class GraphQLWSMessage (BaseModel ):
72+ type : GraphQLWSMessageType
73+ id : str | None = None
74+ payload : dict [str , Any ] | None = None # Will be validated in specific message types
75+
76+
77+ # Type for schema.subscribe return value - it returns either an AsyncIterator of ExecutionResult
78+ # or a single ExecutionResult for non-subscription operations
79+ SubscriptionResult = AsyncIterator [ExecutionResult ]
80+
81+
82+ # WebSocket response message types
83+ class GraphQLWSConnectionAck (BaseModel ):
84+ type : str = "connection_ack"
85+
86+
87+ class GraphQLWSNext (BaseModel ):
88+ type : str = "next"
89+ id : str
90+ payload : dict [str , Any ]
91+
92+
93+ class GraphQLWSError (BaseModel ):
94+ type : str = "error"
95+ id : str | None = None
96+ payload : list [dict [str , str ]]
97+
98+
99+ class GraphQLWSCompleteResponse (BaseModel ):
100+ type : str = "complete"
101+ id : str
102+
103+
55104class GQLLoggingMiddleware :
56105 def resolve (self , next , root , info : graphene .ResolveInfo , ** args ) -> Any :
57106 if info .path .prev is None : # indicates the root query
@@ -323,11 +372,10 @@ async def shutdown(app: web.Application) -> None:
323372
324373
325374@auth_required
326- async def handle_gql_ws (request : web .Request ) -> web .WebSocketResponse :
375+ async def handle_gql_strawberry_ws (request : web .Request ) -> web .WebSocketResponse :
327376 ws = web .WebSocketResponse (protocols = ["graphql-transport-ws" , "graphql-ws" ])
328377 await ws .prepare (request )
329378
330- # Create context once
331379 root_ctx : RootContext = request .app ["_root.context" ]
332380 processors_ctx = await ProcessorsCtx .from_request (request )
333381 context = StrawberryGQLContext (
@@ -342,35 +390,45 @@ async def handle_gql_ws(request: web.Request) -> web.WebSocketResponse:
342390
343391 async for msg in ws :
344392 if msg .type == web .WSMsgType .TEXT :
345- data = msg .json ()
346-
347- if data .get ("type" ) == "connection_init" :
348- await ws .send_str ('{"type":"connection_ack"}' )
349-
350- elif data .get ("type" ) == "subscribe" :
351- subscription_id = data .get ("id" )
352- payload = data .get ("payload" , {})
353- query = payload .get ("query" , "" )
354- variables = payload .get ("variables" , {})
355-
356- log .info (
357- "Processing subscription: {}, query: {}, variables: {}" ,
358- subscription_id ,
359- query [:30 ],
360- variables ,
361- )
362-
363- try :
364- # Execute subscription using Strawberry's subscribe method for proper AsyncGenerator handling
365- async_result = await schema .subscribe (
366- query ,
367- variable_values = variables ,
368- context_value = context ,
393+ try :
394+ # Parse and validate WebSocket message using Pydantic
395+ raw_data = msg .json ()
396+ ws_message = GraphQLWSMessage .model_validate (raw_data )
397+
398+ if ws_message .type == GraphQLWSMessageType .CONNECTION_INIT :
399+ response = GraphQLWSConnectionAck ()
400+ await ws .send_str (json .dumps (response .model_dump ()))
401+
402+ elif ws_message .type == GraphQLWSMessageType .SUBSCRIBE :
403+ if not ws_message .id or not ws_message .payload :
404+ raise ValueError ("Subscribe message requires id and payload" )
405+
406+ # Validate and parse subscription payload
407+ subscribe_payload = GraphQLWSSubscribePayload .model_validate (ws_message .payload )
408+ query = subscribe_payload .query
409+ variables = subscribe_payload .variables or {}
410+
411+ log .info (
412+ "Processing subscription: {}, query: {}, variables: {}" ,
413+ ws_message .id ,
414+ query [:30 ],
415+ variables ,
369416 )
370417
371- log .info ("Subscription subscribe result: {}" , type (async_result ))
418+ try :
419+ # Execute subscription using Strawberry's subscribe method
420+ async_result : SubscriptionResult = await schema .subscribe (
421+ query ,
422+ variable_values = variables ,
423+ context_value = context ,
424+ )
425+
426+ log .info ("Subscription subscribe result: {}" , type (async_result ))
427+
428+ if not hasattr (async_result , "__aiter__" ):
429+ # TODO: Add exception
430+ raise ValueError ("Expected an async iterator for subscription" )
372431
373- if hasattr (async_result , "__aiter__" ):
374432 log .info ("Processing subscription async generator" )
375433
376434 async for result in async_result :
@@ -382,59 +440,44 @@ async def handle_gql_ws(request: web.Request) -> web.WebSocketResponse:
382440
383441 if result .errors :
384442 log .error ("Subscription errors: {}" , result .errors )
385- await ws .send_str (
386- json .dumps ({
387- "id" : subscription_id ,
388- "type" : "error" ,
389- "payload" : [{"message" : str (e )} for e in result .errors ],
390- })
443+ error_response = GraphQLWSError (
444+ id = ws_message .id ,
445+ payload = [{"message" : str (e )} for e in result .errors ],
391446 )
447+ await ws .send_str (json .dumps (error_response .model_dump ()))
392448 break
393449 elif result .data :
394450 log .info ("Sending subscription data: {}" , result .data )
395- await ws .send_str (
396- json .dumps ({
397- "id" : subscription_id ,
398- "type" : "next" ,
399- "payload" : {"data" : result .data },
400- })
451+ next_response = GraphQLWSNext (
452+ id = ws_message .id , payload = {"data" : result .data }
401453 )
454+ await ws .send_str (json .dumps (next_response .model_dump ()))
402455
403- # Send completion
456+ # Send completion after async iterator is exhausted
404457 log .info ("Subscription completed, sending complete message" )
405- await ws .send_str (json .dumps ({"id" : subscription_id , "type" : "complete" }))
406- else :
407- # Fallback to regular execute for queries
408- log .info ("Not a subscription, using regular execute" )
409- result = async_result
410-
411- if result .errors :
412- await ws .send_str (
413- json .dumps ({
414- "id" : subscription_id ,
415- "type" : "error" ,
416- "payload" : [{"message" : str (e )} for e in result .errors ],
417- })
418- )
419- elif result .data :
420- await ws .send_str (
421- json .dumps ({
422- "id" : subscription_id ,
423- "type" : "next" ,
424- "payload" : {"data" : result .data },
425- })
426- )
427-
428- except Exception as e :
429- log .error ("Subscription execution error: {}" , e )
430- log .exception ("Full traceback:" )
431- await ws .send_str (
432- json .dumps ({
433- "id" : subscription_id ,
434- "type" : "error" ,
435- "payload" : [{"message" : str (e )}],
436- })
437- )
458+ complete_response = GraphQLWSCompleteResponse (id = ws_message .id )
459+ await ws .send_str (json .dumps (complete_response .model_dump ()))
460+
461+ except Exception as e :
462+ log .error ("Subscription execution error: {}" , e )
463+ log .exception ("Full traceback:" )
464+ error_response = GraphQLWSError (
465+ id = ws_message .id , payload = [{"message" : str (e )}]
466+ )
467+ await ws .send_str (json .dumps (error_response .model_dump ()))
468+
469+ elif ws_message .type == GraphQLWSMessageType .COMPLETE :
470+ if not ws_message .id :
471+ raise ValueError ("Complete message requires id" )
472+ log .info ("Received complete message for subscription: {}" , ws_message .id )
473+
474+ except Exception as e :
475+ # Handle message parsing and validation errors
476+ log .error ("WebSocket message validation error: {}" , e )
477+ error_response = GraphQLWSError (
478+ payload = [{"message" : f"Invalid message format: { str (e )} " }]
479+ )
480+ await ws .send_str (json .dumps (error_response .model_dump ()))
438481
439482 return ws
440483
@@ -454,7 +497,6 @@ def create_app(
454497 cors .add (
455498 app .router .add_route ("POST" , r"/gql/strawberry" , gql_api_handler .handle_gql_strawberry )
456499 )
457-
458- cors .add (app .router .add_get (r"/gql/strawberry/ws" , handle_gql_ws ))
500+ cors .add (app .router .add_get (r"/gql/strawberry/ws" , handle_gql_strawberry_ws ))
459501
460502 return app , []
0 commit comments