Skip to content

Commit 516b26e

Browse files
committed
refactor: Change global id generator to support type safety
1 parent fcfe7ad commit 516b26e

File tree

4 files changed

+51
-13
lines changed

4 files changed

+51
-13
lines changed

src/ai/backend/manager/api/gql/base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from enum import StrEnum
66
from typing import TYPE_CHECKING, Any, Optional, Type, cast
77

8+
import graphene
89
import orjson
910
import strawberry
1011
from graphql import StringValueNode
@@ -160,7 +161,17 @@ def from_resource_slot(resource_slot: ResourceSlot) -> JSONString:
160161
return JSONString.serialize(resource_slot.to_json())
161162

162163

163-
def to_global_id(type_: Type[Any], local_id: uuid.UUID | str) -> str:
164+
def to_global_id(
165+
type_: Type[Any], local_id: uuid.UUID | str, is_target_graphene_object: bool = False
166+
) -> str:
167+
if is_target_graphene_object:
168+
# For compatibility with existing Graphene-based global IDs
169+
if not issubclass(type_, graphene.ObjectType):
170+
raise TypeError(
171+
"type_ must be a graphene ObjectType when is_target_graphene_object is True."
172+
)
173+
typename = type_.__name__
174+
return base64(f"{typename}:{local_id}")
164175
if not has_object_definition(type_):
165176
raise TypeError("type_ must be a Strawberry object type (Node or Edge).")
166177
typename = get_object_definition(type_, strict=True).name

src/ai/backend/manager/api/gql/model_deployment/model_deployment.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
ModelDeploymentStatus as CommonDeploymentStatus,
1818
)
1919
from ai.backend.common.exception import ModelDeploymentUnavailableError
20-
from ai.backend.manager.api.gql.base import OrderDirection, StringFilter, resolve_global_id
20+
from ai.backend.manager.api.gql.base import (
21+
OrderDirection,
22+
StringFilter,
23+
resolve_global_id,
24+
to_global_id,
25+
)
2126
from ai.backend.manager.api.gql.domain import Domain
2227
from ai.backend.manager.api.gql.model_deployment.access_token import (
2328
AccessToken,
@@ -46,7 +51,9 @@
4651
ReplicaSpec,
4752
ReplicaStateData,
4853
)
49-
from ai.backend.manager.models.gql_relay import AsyncNode
54+
from ai.backend.manager.models.gql_models.domain import DomainNode
55+
from ai.backend.manager.models.gql_models.group import GroupNode
56+
from ai.backend.manager.models.gql_models.user import UserNode
5057
from ai.backend.manager.services.deployment.actions.auto_scaling_rule.get_auto_scaling_rule_by_deployment_id import (
5158
GetAutoScalingRulesByDeploymentIdAction,
5259
)
@@ -185,12 +192,16 @@ class ModelDeploymentMetadata:
185192

186193
@strawberry.field
187194
async def project(self, info: Info[StrawberryGQLContext]) -> Project:
188-
project_global_id = AsyncNode.to_global_id("GroupNode", self._project_id)
195+
project_global_id = to_global_id(
196+
GroupNode, self._project_id, is_target_graphene_object=True
197+
)
189198
return Project(id=ID(project_global_id))
190199

191200
@strawberry.field
192201
async def domain(self, info: Info[StrawberryGQLContext]) -> Domain:
193-
domain_global_id = AsyncNode.to_global_id("DomainNode", self._domain_name)
202+
domain_global_id = to_global_id(
203+
DomainNode, self._domain_name, is_target_graphene_object=True
204+
)
194205
return Domain(id=ID(domain_global_id))
195206

196207
@classmethod
@@ -266,7 +277,9 @@ class ModelDeployment(Node):
266277

267278
@strawberry.field
268279
async def created_user(self, info: Info[StrawberryGQLContext]) -> User:
269-
user_global_id = AsyncNode.to_global_id("UserNode", self._created_user_id)
280+
user_global_id = to_global_id(
281+
UserNode, self._created_user_id, is_target_graphene_object=True
282+
)
270283
return User(id=strawberry.ID(user_global_id))
271284

272285
@strawberry.field

src/ai/backend/manager/api/gql/model_deployment/model_replica.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414
from ai.backend.common.data.model_deployment.types import LivenessStatus as CommonLivenessStatus
1515
from ai.backend.common.data.model_deployment.types import ReadinessStatus as CommonReadinessStatus
1616
from ai.backend.common.exception import ModelDeploymentUnavailableError
17-
from ai.backend.manager.api.gql.base import JSONString, OrderDirection, resolve_global_id
17+
from ai.backend.manager.api.gql.base import (
18+
JSONString,
19+
OrderDirection,
20+
resolve_global_id,
21+
to_global_id,
22+
)
1823
from ai.backend.manager.api.gql.session import Session
1924
from ai.backend.manager.api.gql.types import StrawberryGQLContext
2025
from ai.backend.manager.data.deployment.types import ModelReplicaData
21-
from ai.backend.manager.models.gql_relay import AsyncNode
26+
from ai.backend.manager.models.gql_models.session import ComputeSessionNode
2227
from ai.backend.manager.services.deployment.actions.get_replicas_by_deployment_id import (
2328
GetReplicasByDeploymentIdAction,
2429
)
@@ -119,7 +124,9 @@ class ModelReplica(Node):
119124
description="The session ID associated with the replica. This can be null right after replica creation."
120125
)
121126
async def session(self, info: Info[StrawberryGQLContext]) -> "Session":
122-
session_global_id = AsyncNode.to_global_id("ComputeSessionNode", self._session_id)
127+
session_global_id = to_global_id(
128+
ComputeSessionNode, self._session_id, is_target_graphene_object=True
129+
)
123130
return Session(id=ID(session_global_id))
124131

125132
@strawberry.field

src/ai/backend/manager/api/gql/model_deployment/model_revision.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
OrderDirection,
2323
StringFilter,
2424
resolve_global_id,
25+
to_global_id,
2526
)
2627
from ai.backend.manager.api.gql.image import (
2728
Image,
@@ -55,7 +56,9 @@
5556
ResourceSpec,
5657
)
5758
from ai.backend.manager.data.image.types import ImageIdentifier
58-
from ai.backend.manager.models.gql_relay import AsyncNode
59+
from ai.backend.manager.models.gql_models.image import ImageNode
60+
from ai.backend.manager.models.gql_models.scaling_group import ScalingGroupNode
61+
from ai.backend.manager.models.gql_models.vfolder import VirtualFolderNode
5962
from ai.backend.manager.services.deployment.actions.model_revision.add_model_revision import (
6063
AddModelRevisionAction,
6164
)
@@ -91,7 +94,9 @@ class ModelMountConfig:
9194

9295
@strawberry.field
9396
async def vfolder(self, info: Info[StrawberryGQLContext]) -> VFolder:
94-
vfolder_global_id = AsyncNode.to_global_id("VirtualFolderNode", self._vfolder_id)
97+
vfolder_global_id = to_global_id(
98+
VirtualFolderNode, self._vfolder_id, is_target_graphene_object=True
99+
)
95100
return VFolder(id=ID(vfolder_global_id))
96101

97102
@classmethod
@@ -135,7 +140,9 @@ class ResourceConfig:
135140
@strawberry.field
136141
def resource_group(self) -> "ResourceGroup":
137142
"""Resolves the federated ResourceGroup."""
138-
global_id = AsyncNode.to_global_id("ScalingGroupNode", self._resource_group_name)
143+
global_id = to_global_id(
144+
ScalingGroupNode, self._resource_group_name, is_target_graphene_object=True
145+
)
139146
return ResourceGroup(id=ID(global_id))
140147

141148
@classmethod
@@ -174,7 +181,7 @@ class ModelRevision(Node):
174181

175182
@strawberry.field
176183
async def image(self, info: Info[StrawberryGQLContext]) -> Image:
177-
image_global_id = AsyncNode.to_global_id("ImageNode", self._image_id)
184+
image_global_id = to_global_id(ImageNode, self._image_id, is_target_graphene_object=True)
178185
return Image(id=ID(image_global_id))
179186

180187
@classmethod

0 commit comments

Comments
 (0)