From 2cbf689052f78f35f3b752d568c7d3470f5f1590 Mon Sep 17 00:00:00 2001 From: jopemachine Date: Wed, 17 Sep 2025 05:44:07 +0000 Subject: [PATCH 1/4] WIP --- .../graphql-reference/supergraph.graphql | 41 ++++ .../graphql-reference/v2-schema.graphql | 33 +++ src/ai/backend/manager/api/admin.py | 128 ++++++++++ .../manager/api/gql/background_task.py | 226 ++++++++++++++++++ src/ai/backend/manager/api/gql/schema.py | 3 + src/ai/backend/manager/api/gql/types.py | 6 + src/ai/backend/manager/api/spec.py | 10 +- src/ai/backend/manager/dto/context.py | 13 +- src/ai/backend/web/proxy.py | 14 +- src/ai/backend/web/server.py | 10 + 10 files changed, 481 insertions(+), 3 deletions(-) create mode 100644 src/ai/backend/manager/api/gql/background_task.py diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 7fa786903b2..a3d97ba9533 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -614,6 +614,34 @@ type AvailableServiceNode implements Node service_variants: [String]! } +type BackgroundTaskEvent implements Node + @join__implements(graph: STRAWBERRY, interface: "Node") + @join__type(graph: STRAWBERRY) +{ + """The Globally Unique ID of this object""" + id: ID! + taskId: ID! + eventName: String! + data: JSON! + timestamp: DateTime! + retryCount: Int +} + +type BackgroundTaskProgress implements Node + @join__implements(graph: STRAWBERRY, interface: "Node") + @join__type(graph: STRAWBERRY) +{ + """The Globally Unique ID of this object""" + id: ID! + taskId: ID! + currentProgress: Float + totalProgress: Float + message: String + isCompleted: Boolean! + isFailed: Boolean! + isCancelled: Boolean! +} + """ BigInt is an extension of the regular graphene.Int scalar type to support integers outside the range of a signed 32-bit integer. @@ -2429,6 +2457,13 @@ enum join__Graph { STRAWBERRY @join__graph(name: "strawberry", url: "http://host.docker.internal:8091/admin/gql/strawberry") } +""" +The `JSON` scalar type represents JSON values as specified by [ECMA-404](https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf). +""" +scalar JSON + @join__type(graph: STRAWBERRY) + @specifiedBy(url: "https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf") + """ Allows use of a JSON String for input / output from the GraphQL schema. @@ -4572,6 +4607,9 @@ type Query """Added in 25.14.0""" defaultArtifactRegistry(artifactType: ArtifactType!): ArtifactRegistry @join__field(graph: STRAWBERRY) + + """Get current background task progress""" + backgroundTaskProgress(taskId: ID!): BackgroundTaskProgress @join__field(graph: STRAWBERRY) } type QuotaDetails @@ -5115,6 +5153,9 @@ type Subscription artifactImportProgressUpdated(artifactRevisionId: ID!): ArtifactImportProgressUpdatedPayload! deploymentStatusChanged(deploymentId: ID!): DeploymentStatusChangedPayload! replicaStatusChanged(revisionId: ID!): ReplicaStatusChangedPayload! + + """Subscribe to background task events""" + backgroundTaskEvents(taskId: ID!): BackgroundTaskEvent } """Added in 25.5.0.""" diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index c58b52a6b46..9f80495fa01 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -197,6 +197,28 @@ enum ArtifactType { IMAGE } +type BackgroundTaskEvent implements Node { + """The Globally Unique ID of this object""" + id: ID! + taskId: ID! + eventName: String! + data: JSON! + timestamp: DateTime! + retryCount: Int +} + +type BackgroundTaskProgress implements Node { + """The Globally Unique ID of this object""" + id: ID! + taskId: ID! + currentProgress: Float + totalProgress: Float + message: String + isCompleted: Boolean! + isFailed: Boolean! + isCancelled: Boolean! +} + scalar ByteSize """Added in 25.14.0""" @@ -483,6 +505,11 @@ input IntFilter { lessThanOrEqual: Int = null } +""" +The `JSON` scalar type represents JSON values as specified by [ECMA-404](https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf). +""" +scalar JSON @specifiedBy(url: "https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf") + """A custom scalar for JSON strings using orjson""" scalar JSONString @@ -879,6 +906,9 @@ type Query { """Added in 25.14.0""" defaultArtifactRegistry(artifactType: ArtifactType!): ArtifactRegistry + + """Get current background task progress""" + backgroundTaskProgress(taskId: ID!): BackgroundTaskProgress } type RawServiceConfig { @@ -1053,6 +1083,9 @@ type Subscription { artifactImportProgressUpdated(artifactRevisionId: ID!): ArtifactImportProgressUpdatedPayload! deploymentStatusChanged(deploymentId: ID!): DeploymentStatusChangedPayload! replicaStatusChanged(revisionId: ID!): ReplicaStatusChangedPayload! + + """Subscribe to background task events""" + backgroundTaskEvents(taskId: ID!): BackgroundTaskEvent } scalar UUID diff --git a/src/ai/backend/manager/api/admin.py b/src/ai/backend/manager/api/admin.py index 40b563e95c4..251e1768361 100644 --- a/src/ai/backend/manager/api/admin.py +++ b/src/ai/backend/manager/api/admin.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import logging import traceback from http import HTTPStatus @@ -17,6 +18,7 @@ from graphql.execution import ExecutionResult # pants: no-infer-dep from pydantic import ConfigDict, Field +# Import Strawberry aiohttp views from ai.backend.common import validators as tx from ai.backend.common.api_handlers import APIResponse, BodyParam, MiddlewareParam, api_handler from ai.backend.common.dto.manager.request import GraphQLReq @@ -232,6 +234,9 @@ async def handle_gql_strawberry( strawberry_ctx = StrawberryGQLContext( processors=processors_ctx.processors, config_provider=config_provider_ctx.config_provider, + event_hub=processors_ctx.event_hub, + event_fetcher=processors_ctx.event_fetcher, + valkey_bgtask=processors_ctx.valkey_bgtask, ) query, variables, operation_name = ( @@ -302,6 +307,9 @@ async def init(app: web.Application) -> None: auto_camelcase=False, ) app_ctx.gql_v2_schema = strawberry_schema + + log.info("Simple Strawberry WebSocket handler ready") + root_ctx: RootContext = app["_root.context"] if root_ctx.config_provider.config.api.allow_graphql_schema_introspection: log.warning( @@ -314,6 +322,123 @@ async def shutdown(app: web.Application) -> None: pass +@auth_required +async def handle_gql_ws(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse(protocols=["graphql-transport-ws", "graphql-ws"]) + await ws.prepare(request) + + # Create context once + root_ctx: RootContext = request.app["_root.context"] + processors_ctx = await ProcessorsCtx.from_request(request) + context = StrawberryGQLContext( + processors=processors_ctx.processors, + config_provider=root_ctx.config_provider, + event_hub=processors_ctx.event_hub, + event_fetcher=processors_ctx.event_fetcher, + valkey_bgtask=processors_ctx.valkey_bgtask, + ) + + schema = request.app["admin.context"].gql_v2_schema + + async for msg in ws: + if msg.type == web.WSMsgType.TEXT: + data = msg.json() + + if data.get("type") == "connection_init": + await ws.send_str('{"type":"connection_ack"}') + + elif data.get("type") == "subscribe": + subscription_id = data.get("id") + payload = data.get("payload", {}) + query = payload.get("query", "") + variables = payload.get("variables", {}) + + log.info( + "Processing subscription: {}, query: {}, variables: {}", + subscription_id, + query[:30], + variables, + ) + + try: + # Execute subscription using Strawberry's subscribe method for proper AsyncGenerator handling + async_result = await schema.subscribe( + query, + variable_values=variables, + context_value=context, + ) + + log.info("Subscription subscribe result: {}", type(async_result)) + + if hasattr(async_result, "__aiter__"): + log.info("Processing subscription async generator") + + async for result in async_result: + log.info( + "Subscription result: errors={}, data={}", + result.errors, + result.data, + ) + + if result.errors: + log.error("Subscription errors: {}", result.errors) + await ws.send_str( + json.dumps({ + "id": subscription_id, + "type": "error", + "payload": [{"message": str(e)} for e in result.errors], + }) + ) + break + elif result.data: + log.info("Sending subscription data: {}", result.data) + await ws.send_str( + json.dumps({ + "id": subscription_id, + "type": "next", + "payload": {"data": result.data}, + }) + ) + + # Send completion + log.info("Subscription completed, sending complete message") + await ws.send_str(json.dumps({"id": subscription_id, "type": "complete"})) + else: + # Fallback to regular execute for queries + log.info("Not a subscription, using regular execute") + result = async_result + + if result.errors: + await ws.send_str( + json.dumps({ + "id": subscription_id, + "type": "error", + "payload": [{"message": str(e)} for e in result.errors], + }) + ) + elif result.data: + await ws.send_str( + json.dumps({ + "id": subscription_id, + "type": "next", + "payload": {"data": result.data}, + }) + ) + + except Exception as e: + log.error("Subscription execution error: {}", e) + log.exception("Full traceback:") + await ws.send_str( + json.dumps({ + "id": subscription_id, + "type": "error", + "payload": [{"message": str(e)}], + }) + ) + + return ws + + def create_app( default_cors_options: CORSOptions, ) -> Tuple[web.Application, Iterable[WebMiddleware]]: @@ -329,4 +454,7 @@ def create_app( cors.add( app.router.add_route("POST", r"/gql/strawberry", gql_api_handler.handle_gql_strawberry) ) + + cors.add(app.router.add_get(r"/gql/strawberry/ws", handle_gql_ws)) + return app, [] diff --git a/src/ai/backend/manager/api/gql/background_task.py b/src/ai/backend/manager/api/gql/background_task.py new file mode 100644 index 00000000000..8a67a36ee55 --- /dev/null +++ b/src/ai/backend/manager/api/gql/background_task.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import logging +import uuid +from datetime import datetime +from typing import AsyncGenerator, Optional + +import strawberry +from strawberry import ID +from strawberry.relay import Node, NodeID +from strawberry.types import Info + +from ai.backend.common.events.hub.propagators.cache import WithCachePropagator +from ai.backend.common.events.types import EventCacheDomain, EventDomain +from ai.backend.logging import BraceStyleAdapter + +from .types import StrawberryGQLContext + +log = BraceStyleAdapter(logging.getLogger(__spec__.name)) + + +@strawberry.type +class BackgroundTaskEvent(Node): + """Background task event information""" + + id: NodeID[str] + task_id: ID + event_name: str + data: strawberry.scalars.JSON + timestamp: datetime + retry_count: Optional[int] = None + + +@strawberry.type +class BackgroundTaskProgress(Node): + """Background task progress information""" + + id: NodeID[str] + task_id: ID + current_progress: Optional[float] = None + total_progress: Optional[float] = None + message: Optional[str] = None + is_completed: bool = False + is_failed: bool = False + is_cancelled: bool = False + + +@strawberry.input +class BackgroundTaskEventFilter: + """Filter for background task events""" + + task_id: Optional[ID] = None + event_name: Optional[str] = None + + +async def _background_task_events_impl(task_id, context): + """ + Subscribe to real-time background task events for a specific task. + GraphQL Subscription equivalent of REST API /events/background-task + """ + print(f"[DEBUG] SUBSCRIPTION IMPL CALLED WITH task_id={task_id}") + log.info("=== BACKGROUND_TASK_EVENTS IMPL CALLED ===") + log.info("BACKGROUND_TASK_EVENTS: task_id={}", task_id) + + try: + # Get event hub and fetcher from context + event_hub = context.event_hub + event_fetcher = context.event_fetcher + task_uuid = uuid.UUID(task_id) + log.info("BACKGROUND_TASK_EVENTS subscription started (t:{})", task_uuid) + + # EventHub integration - similar to push_background_task_events() + propagator = WithCachePropagator(event_fetcher) + event_domain_key = (EventDomain.BGTASK, str(task_uuid)) + log.info("BACKGROUND_TASK_EVENTS: Registering propagator with key: {}", event_domain_key) + event_hub.register_event_propagator(propagator, [event_domain_key]) + + try: + cache_id = EventCacheDomain.BGTASK.cache_id(str(task_uuid)) + log.info("BACKGROUND_TASK_EVENTS: Cache ID: {}", cache_id) + + # 먼저 현재 캐시된 상태가 있는지 확인해보자 + cached_event = await event_fetcher.fetch_cached_event(cache_id) + if cached_event is not None: + user_event = cached_event.user_event() + if user_event is not None: + event_name = user_event.event_name() + if event_name is not None: + log.info("BACKGROUND_TASK_EVENTS: Yielding cached event: {}", event_name) + yield BackgroundTaskEvent( + id=ID(str(uuid.uuid4())), + task_id=task_id, + event_name=event_name, + data=user_event.user_event_mapping(), + timestamp=datetime.now(), + retry_count=user_event.retry_count(), + ) + else: + log.info("BACKGROUND_TASK_EVENTS: No cached event found for task {}", task_uuid) + + # 이제 새로운 이벤트를 기다린다 + log.info("BACKGROUND_TASK_EVENTS: Starting to listen for new events...") + async for event in propagator.receive(cache_id): + log.info("BACKGROUND_TASK_EVENTS: Received raw event: {}", event) + user_event = event.user_event() + if user_event is None: + log.warning( + "Received unsupported user event: {}", + event.event_name(), + ) + continue + + event_name = user_event.event_name() + if event_name is None: + log.warning("Event has no event_name") + continue + + log.info("BACKGROUND_TASK_EVENTS: Yielding new event: {}", event_name) + yield BackgroundTaskEvent( + id=ID(str(uuid.uuid4())), + task_id=task_id, + event_name=event_name, + data=user_event.user_event_mapping(), + timestamp=datetime.now(), + retry_count=user_event.retry_count(), + ) + + if user_event.is_close_event(): + log.debug( + "Received close event: {}", + user_event.event_name(), + ) + break + finally: + event_hub.unregister_event_propagator(propagator.id()) + log.info("BACKGROUND_TASK_EVENTS subscription ended (t:{})", task_uuid) + + except Exception as e: + log.error("BACKGROUND_TASK_EVENTS: Error in subscription: {}", e) + log.exception("Full traceback:") + # 예외가 발생해도 최소 하나의 이벤트는 yield하자 + yield BackgroundTaskEvent( + id=ID(str(uuid.uuid4())), + task_id=task_id, + event_name="error_fallback", + data={"error": str(e)}, + timestamp=datetime.now(), + retry_count=0, + ) + + +@strawberry.subscription(description="Subscribe to background task events") +async def background_task_events( + task_id: ID, + info: Info[StrawberryGQLContext], +) -> AsyncGenerator[Optional[BackgroundTaskEvent], None]: + """ + Subscribe to real-time background task events for a specific task. + GraphQL Subscription equivalent of REST API /events/background-task + """ + async for event in _background_task_events_impl(task_id, info.context): + yield event + + +@strawberry.field(description="Get current background task progress") +async def background_task_progress( + task_id: ID, + info: Info[StrawberryGQLContext], +) -> Optional[BackgroundTaskProgress]: + """ + Get current progress status of a specific background task. + Returns the latest progress information from Redis cache. + """ + task_uuid = uuid.UUID(task_id) + log.info("BACKGROUND_TASK_PROGRESS query started (t:{})", task_uuid) + + try: + # Get the cached progress event directly from Redis via EventFetcher + cache_id = EventCacheDomain.BGTASK.cache_id(str(task_uuid)) + cached_event = await info.context.event_fetcher.fetch_cached_event(cache_id) + + if cached_event is None: + log.debug("No cached progress data found for task (t:{})", task_uuid) + return None + + # Get user event from the cached event + user_event = cached_event.user_event() + if user_event is None: + log.debug("No user event found for task (t:{})", task_uuid) + return None + + event_name = user_event.event_name() + event_data = user_event.user_event_mapping() + + # Determine task status + is_completed = event_name == "bgtask_done" + is_failed = event_name == "bgtask_failed" + is_cancelled = event_name == "bgtask_cancelled" + + # Extract progress information + current_progress = None + total_progress = None + message = None + + if event_name == "bgtask_updated": + current_progress = event_data.get("current_progress") + total_progress = event_data.get("total_progress") + message = event_data.get("message") + elif is_completed or is_failed or is_cancelled: + message = event_data.get("message") + return BackgroundTaskProgress( + id=ID(str(task_uuid)), + task_id=task_id, + current_progress=current_progress, + total_progress=total_progress, + message=message, + is_completed=is_completed, + is_failed=is_failed, + is_cancelled=is_cancelled, + ) + + except Exception as e: + log.warning("Failed to get background task progress (t:{}): {}", task_uuid, e) + return None + finally: + log.info("BACKGROUND_TASK_PROGRESS query ended (t:{})", task_uuid) diff --git a/src/ai/backend/manager/api/gql/schema.py b/src/ai/backend/manager/api/gql/schema.py index ac95b353f4a..5a71ef44171 100644 --- a/src/ai/backend/manager/api/gql/schema.py +++ b/src/ai/backend/manager/api/gql/schema.py @@ -21,6 +21,7 @@ update_artifact, ) from .artifact_registry import default_artifact_registry +from .background_task import background_task_events, background_task_progress from .huggingface_registry import ( create_huggingface_registry, delete_huggingface_registry, @@ -81,6 +82,7 @@ class Query: reservoir_registry = reservoir_registry reservoir_registries = reservoir_registries default_artifact_registry = default_artifact_registry + background_task_progress = background_task_progress @strawberry.type @@ -120,6 +122,7 @@ class Subscription: artifact_import_progress_updated = artifact_import_progress_updated deployment_status_changed = deployment_status_changed replica_status_changed = replica_status_changed + background_task_events = background_task_events class CustomizedSchema(Schema): diff --git a/src/ai/backend/manager/api/gql/types.py b/src/ai/backend/manager/api/gql/types.py index b01e3cfd024..2c4598ab149 100644 --- a/src/ai/backend/manager/api/gql/types.py +++ b/src/ai/backend/manager/api/gql/types.py @@ -3,6 +3,9 @@ import attrs +from ai.backend.common.clients.valkey_client.valkey_bgtask.client import ValkeyBgtaskClient +from ai.backend.common.events.fetcher import EventFetcher +from ai.backend.common.events.hub.hub import EventHub from ai.backend.manager.config.provider import ManagerConfigProvider from ai.backend.manager.services.processors import Processors @@ -11,3 +14,6 @@ class StrawberryGQLContext: processors: Processors config_provider: ManagerConfigProvider + event_hub: EventHub + event_fetcher: EventFetcher + valkey_bgtask: ValkeyBgtaskClient diff --git a/src/ai/backend/manager/api/spec.py b/src/ai/backend/manager/api/spec.py index da7a682b9c1..6ebcc10af66 100644 --- a/src/ai/backend/manager/api/spec.py +++ b/src/ai/backend/manager/api/spec.py @@ -89,7 +89,15 @@ >