Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions docs/manager/graphql-reference/supergraph.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,34 @@ type AvailableServiceNode implements Node
service_variants: [String]!
}

type BackgroundTaskEvent implements Node
@join__implements(graph: STRAWBERRY, interface: "Node")
@join__type(graph: STRAWBERRY)
{
"""The Globally Unique ID of this object"""
id: ID!
taskId: ID!
eventName: String!
data: JSON!
timestamp: DateTime!
retryCount: Int
}

type BackgroundTaskProgress implements Node
@join__implements(graph: STRAWBERRY, interface: "Node")
@join__type(graph: STRAWBERRY)
{
"""The Globally Unique ID of this object"""
id: ID!
taskId: ID!
currentProgress: Float
totalProgress: Float
message: String
isCompleted: Boolean!
isFailed: Boolean!
isCancelled: Boolean!
}

"""
BigInt is an extension of the regular graphene.Int scalar type
to support integers outside the range of a signed 32-bit integer.
Expand Down Expand Up @@ -2429,6 +2457,13 @@ enum join__Graph {
STRAWBERRY @join__graph(name: "strawberry", url: "http://host.docker.internal:8091/admin/gql/strawberry")
}

"""
The `JSON` scalar type represents JSON values as specified by [ECMA-404](https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf).
"""
scalar JSON
@join__type(graph: STRAWBERRY)
@specifiedBy(url: "https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf")

"""
Allows use of a JSON String for input / output from the GraphQL schema.

Expand Down Expand Up @@ -4572,6 +4607,9 @@ type Query

"""Added in 25.14.0"""
defaultArtifactRegistry(artifactType: ArtifactType!): ArtifactRegistry @join__field(graph: STRAWBERRY)

"""Get current background task progress"""
backgroundTaskProgress(taskId: ID!): BackgroundTaskProgress @join__field(graph: STRAWBERRY)
}

type QuotaDetails
Expand Down Expand Up @@ -5115,6 +5153,9 @@ type Subscription
artifactImportProgressUpdated(artifactRevisionId: ID!): ArtifactImportProgressUpdatedPayload!
deploymentStatusChanged(deploymentId: ID!): DeploymentStatusChangedPayload!
replicaStatusChanged(revisionId: ID!): ReplicaStatusChangedPayload!

"""Subscribe to background task events"""
backgroundTaskEvents(taskId: ID!): BackgroundTaskEvent
}

"""Added in 25.5.0."""
Expand Down
33 changes: 33 additions & 0 deletions docs/manager/graphql-reference/v2-schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,28 @@ enum ArtifactType {
IMAGE
}

type BackgroundTaskEvent implements Node {
"""The Globally Unique ID of this object"""
id: ID!
taskId: ID!
eventName: String!
data: JSON!
timestamp: DateTime!
retryCount: Int
}

type BackgroundTaskProgress implements Node {
"""The Globally Unique ID of this object"""
id: ID!
taskId: ID!
currentProgress: Float
totalProgress: Float
message: String
isCompleted: Boolean!
isFailed: Boolean!
isCancelled: Boolean!
}

scalar ByteSize

"""Added in 25.14.0"""
Expand Down Expand Up @@ -483,6 +505,11 @@ input IntFilter {
lessThanOrEqual: Int = null
}

"""
The `JSON` scalar type represents JSON values as specified by [ECMA-404](https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf).
"""
scalar JSON @specifiedBy(url: "https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf")

"""A custom scalar for JSON strings using orjson"""
scalar JSONString

Expand Down Expand Up @@ -879,6 +906,9 @@ type Query {

"""Added in 25.14.0"""
defaultArtifactRegistry(artifactType: ArtifactType!): ArtifactRegistry

"""Get current background task progress"""
backgroundTaskProgress(taskId: ID!): BackgroundTaskProgress
}

type RawServiceConfig {
Expand Down Expand Up @@ -1053,6 +1083,9 @@ type Subscription {
artifactImportProgressUpdated(artifactRevisionId: ID!): ArtifactImportProgressUpdatedPayload!
deploymentStatusChanged(deploymentId: ID!): DeploymentStatusChangedPayload!
replicaStatusChanged(revisionId: ID!): ReplicaStatusChangedPayload!

"""Subscribe to background task events"""
backgroundTaskEvents(taskId: ID!): BackgroundTaskEvent
}

scalar UUID
Expand Down
38 changes: 38 additions & 0 deletions docs/manager/rest-reference/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -5005,6 +5005,44 @@
"description": "\n**Preconditions:**\n* User privilege required.\n"
}
},
"//gql/strawberry/ws": {
"head": {
"operationId": "root.handle_gql_ws",
"tags": [
"root"
],
"responses": {
"200": {
"description": "Successful response"
}
},
"security": [
{
"TokenAuth": []
}
],
"parameters": [],
"description": "\n**Preconditions:**\n* User privilege required.\n"
},
"get": {
"operationId": "root.handle_gql_ws.2",
"tags": [
"root"
],
"responses": {
"200": {
"description": "Successful response"
}
},
"security": [
{
"TokenAuth": []
}
],
"parameters": [],
"description": "\n**Preconditions:**\n* User privilege required.\n"
}
},
"/spec/graphiql": {
"get": {
"operationId": "spec.render_graphiql_graphene_html",
Expand Down
128 changes: 128 additions & 0 deletions src/ai/backend/manager/api/admin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import logging
import traceback
from http import HTTPStatus
Expand All @@ -17,6 +18,7 @@
from graphql.execution import ExecutionResult # pants: no-infer-dep
from pydantic import ConfigDict, Field

# Import Strawberry aiohttp views
from ai.backend.common import validators as tx
from ai.backend.common.api_handlers import APIResponse, BodyParam, MiddlewareParam, api_handler
from ai.backend.common.dto.manager.request import GraphQLReq
Expand Down Expand Up @@ -232,6 +234,9 @@ async def handle_gql_strawberry(
strawberry_ctx = StrawberryGQLContext(
processors=processors_ctx.processors,
config_provider=config_provider_ctx.config_provider,
event_hub=processors_ctx.event_hub,
event_fetcher=processors_ctx.event_fetcher,
valkey_bgtask=processors_ctx.valkey_bgtask,
)

query, variables, operation_name = (
Expand Down Expand Up @@ -302,6 +307,9 @@ async def init(app: web.Application) -> None:
auto_camelcase=False,
)
app_ctx.gql_v2_schema = strawberry_schema

log.info("Simple Strawberry WebSocket handler ready")

root_ctx: RootContext = app["_root.context"]
if root_ctx.config_provider.config.api.allow_graphql_schema_introspection:
log.warning(
Expand All @@ -314,6 +322,123 @@ async def shutdown(app: web.Application) -> None:
pass


@auth_required
async def handle_gql_ws(request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse(protocols=["graphql-transport-ws", "graphql-ws"])
await ws.prepare(request)

# Create context once
root_ctx: RootContext = request.app["_root.context"]
processors_ctx = await ProcessorsCtx.from_request(request)
context = StrawberryGQLContext(
processors=processors_ctx.processors,
config_provider=root_ctx.config_provider,
event_hub=processors_ctx.event_hub,
event_fetcher=processors_ctx.event_fetcher,
valkey_bgtask=processors_ctx.valkey_bgtask,
)

schema = request.app["admin.context"].gql_v2_schema

async for msg in ws:
if msg.type == web.WSMsgType.TEXT:
data = msg.json()

if data.get("type") == "connection_init":
await ws.send_str('{"type":"connection_ack"}')

elif data.get("type") == "subscribe":
subscription_id = data.get("id")
payload = data.get("payload", {})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check the format of how values are passed from GraphQL and define it with Pydantic.

query = payload.get("query", "")
variables = payload.get("variables", {})

log.info(
"Processing subscription: {}, query: {}, variables: {}",
subscription_id,
query[:30],
variables,
)

try:
# Execute subscription using Strawberry's subscribe method for proper AsyncGenerator handling
async_result = await schema.subscribe(
query,
variable_values=variables,
context_value=context,
)

log.info("Subscription subscribe result: {}", type(async_result))

if hasattr(async_result, "__aiter__"):
log.info("Processing subscription async generator")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If __aiter__ is not implemented, just leave it to raise.


async for result in async_result:
log.info(
"Subscription result: errors={}, data={}",
result.errors,
result.data,
)

if result.errors:
log.error("Subscription errors: {}", result.errors)
await ws.send_str(
json.dumps({
"id": subscription_id,
"type": "error",
"payload": [{"message": str(e)} for e in result.errors],
})
)
break
elif result.data:
log.info("Sending subscription data: {}", result.data)
await ws.send_str(
json.dumps({
"id": subscription_id,
"type": "next",
"payload": {"data": result.data},
})
)

# Send completion
log.info("Subscription completed, sending complete message")
await ws.send_str(json.dumps({"id": subscription_id, "type": "complete"}))
else:
# Fallback to regular execute for queries
log.info("Not a subscription, using regular execute")
result = async_result

if result.errors:
await ws.send_str(
json.dumps({
"id": subscription_id,
"type": "error",
"payload": [{"message": str(e)} for e in result.errors],
})
)
elif result.data:
await ws.send_str(
json.dumps({
"id": subscription_id,
"type": "next",
"payload": {"data": result.data},
})
)

except Exception as e:
log.error("Subscription execution error: {}", e)
log.exception("Full traceback:")
await ws.send_str(
json.dumps({
"id": subscription_id,
"type": "error",
"payload": [{"message": str(e)}],
})
)

return ws


def create_app(
default_cors_options: CORSOptions,
) -> Tuple[web.Application, Iterable[WebMiddleware]]:
Expand All @@ -329,4 +454,7 @@ def create_app(
cors.add(
app.router.add_route("POST", r"/gql/strawberry", gql_api_handler.handle_gql_strawberry)
)

cors.add(app.router.add_get(r"/gql/strawberry/ws", handle_gql_ws))

return app, []
Loading
Loading