diff --git a/changes/5672.feature.md b/changes/5672.feature.md new file mode 100644 index 00000000000..8d38f74997a --- /dev/null +++ b/changes/5672.feature.md @@ -0,0 +1 @@ +Implement API Layer of Model Deployment diff --git a/docs/manager/graphql-reference/schema.graphql b/docs/manager/graphql-reference/schema.graphql index 2fbe1f376c0..b581080da89 100644 --- a/docs/manager/graphql-reference/schema.graphql +++ b/docs/manager/graphql-reference/schema.graphql @@ -128,7 +128,7 @@ type Query { """Added in 24.03.1""" id: String reference: String - architecture: String = "aarch64" + architecture: String = "x86_64" ): Image images( """ @@ -2341,7 +2341,7 @@ type Mutation { ): RescanImages preload_image(references: [String]!, target_agents: [String]!): PreloadImage unload_image(references: [String]!, target_agents: [String]!): UnloadImage - modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage + modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage """Added in 25.6.0""" clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload @@ -2350,7 +2350,7 @@ type Mutation { forget_image_by_id(image_id: String!): ForgetImageById """Deprecated since 25.4.0. Use `forget_image_by_id` instead.""" - forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.") + forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @deprecated(reason: "Deprecated since 25.4.0. 
Use `forget_image_by_id` instead.") """Added in 25.4.0""" purge_image_by_id( @@ -2362,7 +2362,7 @@ type Mutation { """Added in 24.03.1""" untag_image_from_registry(image_id: String!): UntagImageFromRegistry - alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage + alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage dealias_image(alias: String!): DealiasImage clear_images(registry: String): ClearImages @@ -2937,7 +2937,7 @@ type ClearImageCustomResourceLimitPayload { """Added in 25.6.0.""" input ClearImageCustomResourceLimitKey { image_canonical: String! - architecture: String! = "aarch64" + architecture: String! = "x86_64" } """Added in 24.03.0.""" diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index c3e743af019..bce6104960c 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -30,17 +30,17 @@ type AccessToken implements Node """The Globally Unique ID of this object""" id: ID! - """Added in 25.13.0: The access token.""" + """Added in 25.16.0: The access token.""" token: String! - """Added in 25.13.0: The creation timestamp of the access token.""" + """Added in 25.16.0: The creation timestamp of the access token.""" createdAt: DateTime! - """Added in 25.13.0: The expiration timestamp of the access token.""" + """Added in 25.16.0: The expiration timestamp of the access token.""" validUntil: DateTime! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type AccessTokenConnection @join__type(graph: STRAWBERRY) { @@ -63,8 +63,22 @@ type AccessTokenEdge node: AccessToken! } +"""Added in 25.16.0""" +input AccessTokenOrderBy + @join__type(graph: STRAWBERRY) +{ + field: AccessTokenOrderField! + direction: OrderDirection! 
= DESC +} + +enum AccessTokenOrderField + @join__type(graph: STRAWBERRY) +{ + CREATED_AT @join__enumValue(graph: STRAWBERRY) +} + """ -Added in 25.13.0. This enum represents the activeness status of a replica, indicating whether the deployment is currently active and able to serve requests. +Added in 25.16.0. This enum represents the activeness status of a replica, indicating whether the deployment is currently active and able to serve requests. """ enum ActivenessStatus @join__type(graph: STRAWBERRY) @@ -73,7 +87,7 @@ enum ActivenessStatus INACTIVE @join__enumValue(graph: STRAWBERRY) } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ActivenessStatusFilter @join__type(graph: STRAWBERRY) { @@ -81,7 +95,7 @@ input ActivenessStatusFilter equals: ActivenessStatus = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input AddModelRevisionInput @join__type(graph: STRAWBERRY) { @@ -95,7 +109,7 @@ input AddModelRevisionInput extraMounts: [ExtraVFolderMountInput!] } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type AddModelRevisionPayload @join__type(graph: STRAWBERRY) { @@ -847,26 +861,26 @@ type AutoScalingRule implements Node """The Globally Unique ID of this object""" id: ID! - """Added in 25.13.0 (e.g. KERNEL, INFERENCE_FRAMEWORK)""" + """Added in 25.16.0 (e.g. KERNEL, INFERENCE_FRAMEWORK)""" metricSource: AutoScalingMetricSource! metricName: String! - """Added in 25.13.0: The minimum threshold for scaling (e.g. 0.5)""" + """Added in 25.16.0: The minimum threshold for scaling (e.g. 0.5)""" minThreshold: Decimal - """Added in 25.13.0: The maximum threshold for scaling (e.g. 21.0)""" + """Added in 25.16.0: The maximum threshold for scaling (e.g. 21.0)""" maxThreshold: Decimal - """Added in 25.13.0: The step size for scaling (e.g. 1).""" + """Added in 25.16.0: The step size for scaling (e.g. 1).""" stepSize: Int! - """Added in 25.13.0: The time window (seconds) for scaling (e.g. 60).""" + """Added in 25.16.0: The time window (seconds) for scaling (e.g. 
60).""" timeWindow: Int! - """Added in 25.13.0: The minimum number of replicas (e.g. 1).""" + """Added in 25.16.0: The minimum number of replicas (e.g. 1).""" minReplicas: Int - """Added in 25.13.0: The maximum number of replicas (e.g. 10).""" + """Added in 25.16.0: The maximum number of replicas (e.g. 10).""" maxReplicas: Int createdAt: DateTime! lastTriggeredAt: DateTime! @@ -1022,7 +1036,7 @@ input ClearImageCustomResourceLimitKey @join__type(graph: GRAPHENE) { image_canonical: String! - architecture: String! = "aarch64" + architecture: String! = "x86_64" } """Added in 25.6.0.""" @@ -1039,7 +1053,7 @@ type ClearImages msg: String } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ClusterConfig @join__type(graph: STRAWBERRY) { @@ -1047,7 +1061,7 @@ type ClusterConfig size: Int! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ClusterConfigInput @join__type(graph: STRAWBERRY) { @@ -1055,7 +1069,7 @@ input ClusterConfigInput size: Int! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" enum ClusterMode @join__type(graph: STRAWBERRY) { @@ -1417,11 +1431,11 @@ input CreateAccessTokenInput @join__type(graph: STRAWBERRY) { """ - Added in 25.13.0: The ID of the model deployment for which the access token is created. + Added in 25.16.0: The ID of the model deployment for which the access token is created. """ modelDeploymentId: ID! - """Added in 25.13.0: The expiration timestamp of the access token.""" + """Added in 25.16.0: The expiration timestamp of the access token.""" validUntil: DateTime! } @@ -1633,7 +1647,7 @@ input CreateKeyPairResourcePolicyInput max_pending_session_resource_slots: JSONString } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input CreateModelDeploymentInput @join__type(graph: STRAWBERRY) { @@ -1644,14 +1658,14 @@ input CreateModelDeploymentInput initialRevision: CreateModelRevisionInput! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type CreateModelDeploymentPayload @join__type(graph: STRAWBERRY) { deployment: ModelDeployment! 
} -"""Added in 25.13.0""" +"""Added in 25.16.0""" input CreateModelRevisionInput @join__type(graph: STRAWBERRY) { @@ -1664,7 +1678,7 @@ input CreateModelRevisionInput extraMounts: [ExtraVFolderMountInput!] } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type CreateModelRevisionPayload @join__type(graph: STRAWBERRY) { @@ -2074,14 +2088,14 @@ type DeleteKeyPairResourcePolicy msg: String } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeleteModelDeploymentInput @join__type(graph: STRAWBERRY) { id: ID! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type DeleteModelDeploymentPayload @join__type(graph: STRAWBERRY) { @@ -2178,7 +2192,7 @@ type DeleteVFSStoragePayload id: ID! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentFilter @join__type(graph: STRAWBERRY) { @@ -2193,7 +2207,7 @@ input DeploymentFilter NOT: [DeploymentFilter!] = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentOrderBy @join__type(graph: STRAWBERRY) { @@ -2201,7 +2215,6 @@ input DeploymentOrderBy direction: OrderDirection! = DESC } -"""Added in 25.13.0""" enum DeploymentOrderField @join__type(graph: STRAWBERRY) { @@ -2211,7 +2224,7 @@ enum DeploymentOrderField } """ -Added in 25.13.0. This enum represents the deployment status of a model deployment, indicating its current state. +Added in 25.16.0. This enum represents the deployment status of a model deployment, indicating its current state. """ enum DeploymentStatus @join__type(graph: STRAWBERRY) @@ -2224,14 +2237,14 @@ enum DeploymentStatus STOPPED @join__enumValue(graph: STRAWBERRY) } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type DeploymentStatusChangedPayload @join__type(graph: STRAWBERRY) { deployment: ModelDeployment! 
} -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentStatusFilter @join__type(graph: STRAWBERRY) { @@ -2239,14 +2252,14 @@ input DeploymentStatusFilter equals: DeploymentStatus = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type DeploymentStrategy @join__type(graph: STRAWBERRY) { type: DeploymentStrategyType! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentStrategyInput @join__type(graph: STRAWBERRY) { @@ -2254,7 +2267,7 @@ input DeploymentStrategyInput } """ -Added in 25.13.0. This enum represents the deployment strategy type of a model deployment, indicating the strategy used for deployment. +Added in 25.16.0. This enum represents the deployment strategy type of a model deployment, indicating the strategy used for deployment. """ enum DeploymentStrategyType @join__type(graph: STRAWBERRY) @@ -2589,7 +2602,7 @@ type ExtraVFolderMount implements Node vfolder: VirtualFolderNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ExtraVFolderMountConnection @join__type(graph: STRAWBERRY) { @@ -2612,7 +2625,7 @@ type ExtraVFolderMountEdge node: ExtraVFolderMount! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ExtraVFolderMountInput @join__type(graph: STRAWBERRY) { @@ -2889,7 +2902,7 @@ type ImageEdge cursor: String! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ImageInput @join__type(graph: STRAWBERRY) { @@ -3303,7 +3316,7 @@ enum link__Purpose { } """ -Added in 25.13.0. This enum represents the liveness status of a replica, indicating whether the deployment is currently running and able to serve requests. +Added in 25.16.0. This enum represents the liveness status of a replica, indicating whether the deployment is currently running and able to serve requests. 
""" enum LivenessStatus @join__type(graph: STRAWBERRY) @@ -3314,7 +3327,7 @@ enum LivenessStatus DEGRADED @join__enumValue(graph: STRAWBERRY) } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input LivenessStatusFilter @join__type(graph: STRAWBERRY) { @@ -3399,7 +3412,7 @@ type ModelCardEdge cursor: String! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeployment implements Node @join__implements(graph: STRAWBERRY, interface: "Node") @join__type(graph: STRAWBERRY) @@ -3409,14 +3422,14 @@ type ModelDeployment implements Node metadata: ModelDeploymentMetadata! networkAccess: ModelDeploymentNetworkAccess! revision: ModelRevision - scalingRule: ScalingRule! - replicaState: ReplicaState! defaultDeploymentStrategy: DeploymentStrategy! createdUser: UserNode! + scalingRule: ScalingRule! + replicaState: ReplicaState! revisionHistory(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeploymentConnection @join__type(graph: STRAWBERRY) { @@ -3439,20 +3452,20 @@ type ModelDeploymentEdge node: ModelDeployment! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeploymentMetadata @join__type(graph: STRAWBERRY) { name: String! status: DeploymentStatus! tags: [String!]! - project: GroupNode! - domain: DomainNode! createdAt: DateTime! updatedAt: DateTime! + project: GroupNode! + domain: DomainNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelDeploymentMetadataInput @join__type(graph: STRAWBERRY) { @@ -3462,17 +3475,17 @@ input ModelDeploymentMetadataInput tags: [String!] = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeploymentNetworkAccess @join__type(graph: STRAWBERRY) { endpointUrl: String preferredDomainName: String openToPublic: Boolean! - accessTokens: AccessTokenConnection! 
+ accessTokens(orderBy: [AccessTokenOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): AccessTokenConnection! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelDeploymentNetworkAccessInput @join__type(graph: STRAWBERRY) { @@ -3480,16 +3493,16 @@ input ModelDeploymentNetworkAccessInput openToPublic: Boolean! = false } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelMountConfig @join__type(graph: STRAWBERRY) { - vfolder: VirtualFolderNode! mountDestination: String! definitionPath: String! + vfolder: VirtualFolderNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelMountConfigInput @join__type(graph: STRAWBERRY) { @@ -3498,14 +3511,13 @@ input ModelMountConfigInput definitionPath: String! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelReplica implements Node @join__implements(graph: STRAWBERRY, interface: "Node") @join__type(graph: STRAWBERRY) { """The Globally Unique ID of this object""" id: ID! - revision: ModelRevision! """ This represents whether the replica has been checked and its health state. @@ -3538,9 +3550,10 @@ type ModelReplica implements Node The session ID associated with the replica. This can be null right after replica creation. """ session: ComputeSessionNode! + revision: ModelRevision! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelReplicaConnection @join__type(graph: STRAWBERRY) { @@ -3563,7 +3576,7 @@ type ModelReplicaEdge node: ModelReplica! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelRevision implements Node @join__implements(graph: STRAWBERRY, interface: "Node") @join__type(graph: STRAWBERRY) @@ -3576,11 +3589,11 @@ type ModelRevision implements Node modelRuntimeConfig: ModelRuntimeConfig! modelMountConfig: ModelMountConfig! extraMounts: ExtraVFolderMountConnection! - image: ImageNode! createdAt: DateTime! + image: ImageNode! 
} -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelRevisionConnection @join__type(graph: STRAWBERRY) { @@ -3603,7 +3616,7 @@ type ModelRevisionEdge node: ModelRevision! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelRevisionFilter @join__type(graph: STRAWBERRY) { @@ -3615,7 +3628,7 @@ input ModelRevisionFilter NOT: [ModelRevisionFilter!] = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelRevisionOrderBy @join__type(graph: STRAWBERRY) { @@ -3623,16 +3636,14 @@ input ModelRevisionOrderBy direction: OrderDirection! = DESC } -"""Added in 25.13.0""" enum ModelRevisionOrderField @join__type(graph: STRAWBERRY) { CREATED_AT @join__enumValue(graph: STRAWBERRY) NAME @join__enumValue(graph: STRAWBERRY) - ID @join__enumValue(graph: STRAWBERRY) } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelRuntimeConfig @join__type(graph: STRAWBERRY) { @@ -3645,7 +3656,7 @@ type ModelRuntimeConfig environ: JSONString } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelRuntimeConfigInput @join__type(graph: STRAWBERRY) { @@ -4252,7 +4263,7 @@ type Mutation ): RescanImages @join__field(graph: GRAPHENE) preload_image(references: [String]!, target_agents: [String]!): PreloadImage @join__field(graph: GRAPHENE) unload_image(references: [String]!, target_agents: [String]!): UnloadImage @join__field(graph: GRAPHENE) - modify_image(architecture: String = "aarch64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE) + modify_image(architecture: String = "x86_64", props: ModifyImageInput!, target: String!): ModifyImage @join__field(graph: GRAPHENE) """Added in 25.6.0""" clear_image_custom_resource_limit(key: ClearImageCustomResourceLimitKey!): ClearImageCustomResourceLimitPayload @join__field(graph: GRAPHENE) @@ -4261,7 +4272,7 @@ type Mutation forget_image_by_id(image_id: String!): ForgetImageById @join__field(graph: GRAPHENE) """Deprecated since 25.4.0. 
Use `forget_image_by_id` instead.""" - forget_image(architecture: String = "aarch64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.") + forget_image(architecture: String = "x86_64", reference: String!): ForgetImage @join__field(graph: GRAPHENE) @deprecated(reason: "Deprecated since 25.4.0. Use `forget_image_by_id` instead.") """Added in 25.4.0""" purge_image_by_id( @@ -4273,7 +4284,7 @@ type Mutation """Added in 24.03.1""" untag_image_from_registry(image_id: String!): UntagImageFromRegistry @join__field(graph: GRAPHENE) - alias_image(alias: String!, architecture: String = "aarch64", target: String!): AliasImage @join__field(graph: GRAPHENE) + alias_image(alias: String!, architecture: String = "x86_64", target: String!): AliasImage @join__field(graph: GRAPHENE) dealias_image(alias: String!): DealiasImage @join__field(graph: GRAPHENE) clear_images(registry: String): ClearImages @join__field(graph: GRAPHENE) @@ -4596,24 +4607,26 @@ type Mutation """ cancelImportArtifact(input: CancelArtifactInput!): CancelImportArtifactPayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" createModelDeployment(input: CreateModelDeploymentInput!): CreateModelDeploymentPayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" updateModelDeployment(input: UpdateModelDeploymentInput!): UpdateModelDeploymentPayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" deleteModelDeployment(input: DeleteModelDeploymentInput!): DeleteModelDeploymentPayload! @join__field(graph: STRAWBERRY) """ - Added in 25.13.0. Force syncs up-to-date replica information. In normal situations this will be automatically handled by Backend.AI schedulers + Added in 25.16.0. Force syncs up-to-date replica information. 
In normal situations this will be automatically handled by Backend.AI schedulers """ syncReplicas(input: SyncReplicaInput!): SyncReplicaPayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """ + Added in 25.16.0. Create model revision which is not attached to any deployment. + """ createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload! @join__field(graph: STRAWBERRY) """Added in 25.14.0""" @@ -4622,13 +4635,13 @@ type Mutation """Added in 25.14.0""" updateObjectStorage(input: UpdateObjectStorageInput!): UpdateObjectStoragePayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" createAutoScalingRule(input: CreateAutoScalingRuleInput!): CreateAutoScalingRulePayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" updateAutoScalingRule(input: UpdateAutoScalingRuleInput!): UpdateAutoScalingRulePayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" deleteAutoScalingRule(input: DeleteAutoScalingRuleInput!): DeleteAutoScalingRulePayload! @join__field(graph: STRAWBERRY) """Added in 25.14.0""" @@ -4701,7 +4714,7 @@ type Mutation """ rejectArtifactRevision(input: RejectArtifactInput!): RejectArtifactPayload! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" createAccessToken(input: CreateAccessTokenInput!): CreateAccessTokenPayload! @join__field(graph: STRAWBERRY) } @@ -5088,7 +5101,7 @@ type Query """Added in 24.03.1""" id: String reference: String - architecture: String = "aarch64" + architecture: String = "x86_64" ): Image @join__field(graph: GRAPHENE) images( """ @@ -5392,22 +5405,22 @@ type Query """ artifactRevisions(filter: ArtifactRevisionFilter = null, orderBy: [ArtifactRevisionOrderBy!] 
= null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ArtifactRevisionConnection! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" deployments(filter: DeploymentFilter = null, orderBy: [DeploymentOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelDeploymentConnection! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" deployment(id: ID!): ModelDeployment @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" revisions(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" revision(id: ID!): ModelRevision! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" replicas(filter: ReplicaFilter = null, orderBy: [ReplicaOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelReplicaConnection! @join__field(graph: STRAWBERRY) - """Added in 25.13.0""" + """Added in 25.16.0""" replica(id: ID!): ModelReplica @join__field(graph: STRAWBERRY) """Added in 25.14.0""" @@ -5441,11 +5454,11 @@ type Query agentStats: AgentStats! @join__field(graph: STRAWBERRY) """ - Added in 25.13.0 Get configuration JSON Schemas for all inference runtimes + Added in 25.16.0 Get configuration JSON Schemas for all inference runtimes """ inferenceRuntimeConfigs: JSON! @join__field(graph: STRAWBERRY) - """Added in 25.13.0. Get JSON Schema for inference runtime configuration""" + """Added in 25.16.0. 
Get JSON Schema for inference runtime configuration""" inferenceRuntimeConfig(name: String!): JSON! @join__field(graph: STRAWBERRY) } @@ -5474,7 +5487,7 @@ input QuotaScopeInput } """ -Added in 25.13.0. This enum represents the readiness status of a replica, indicating whether the deployment has been checked and its health state. +Added in 25.16.0. This enum represents the readiness status of a replica, indicating whether the deployment has been checked and its health state. """ enum ReadinessStatus @join__type(graph: STRAWBERRY) @@ -5484,7 +5497,7 @@ enum ReadinessStatus UNHEALTHY @join__enumValue(graph: STRAWBERRY) } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ReadinessStatusFilter @join__type(graph: STRAWBERRY) { @@ -5541,7 +5554,7 @@ type RejectArtifactPayload artifactRevision: ArtifactRevision! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ReplicaFilter @join__type(graph: STRAWBERRY) { @@ -5554,7 +5567,7 @@ input ReplicaFilter NOT: [ReplicaFilter!] = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ReplicaOrderBy @join__type(graph: STRAWBERRY) { @@ -5562,14 +5575,14 @@ input ReplicaOrderBy direction: OrderDirection! = DESC } -"""Added in 25.13.0""" enum ReplicaOrderField @join__type(graph: STRAWBERRY) { CREATED_AT @join__enumValue(graph: STRAWBERRY) + ID @join__enumValue(graph: STRAWBERRY) } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ReplicaState @join__type(graph: STRAWBERRY) { @@ -5577,7 +5590,7 @@ type ReplicaState replicas(filter: ReplicaFilter = null, orderBy: [ReplicaOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelReplicaConnection! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ReplicaStatusChangedPayload @join__type(graph: STRAWBERRY) { @@ -5637,12 +5650,10 @@ type ReservoirRegistryEdge node: ReservoirRegistry! 
} -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ResourceConfig @join__type(graph: STRAWBERRY) { - resourceGroup: ScalingGroupNode! - """ Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}" """ @@ -5652,9 +5663,10 @@ type ResourceConfig Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}" """ resourceOpts: JSONString + resourceGroup: ScalingGroupNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ResourceConfigInput @join__type(graph: STRAWBERRY) { @@ -5671,7 +5683,7 @@ input ResourceConfigInput resourceOpts: JSONString = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ResourceGroupInput @join__type(graph: STRAWBERRY) { @@ -5863,7 +5875,7 @@ type ScalinGroupEdge cursor: String! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ScalingRule @join__type(graph: STRAWBERRY) { @@ -6181,10 +6193,10 @@ type Subscription """ artifactImportProgressUpdated(artifactRevisionId: ID!): ArtifactImportProgressUpdatedPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" deploymentStatusChanged(deploymentId: ID!): DeploymentStatusChangedPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" replicaStatusChanged(revisionId: ID!): ReplicaStatusChangedPayload! """ @@ -6198,14 +6210,14 @@ type Subscription backgroundTaskEvents(taskId: ID!): BackgroundTaskEventPayload! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input SyncReplicaInput @join__type(graph: STRAWBERRY) { modelDeploymentId: ID! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type SyncReplicaPayload @join__type(graph: STRAWBERRY) { @@ -6346,7 +6358,7 @@ type UpdateHuggingFaceRegistryPayload huggingfaceRegistry: HuggingFaceRegistry! 
} -"""Added in 25.13.0""" +"""Added in 25.16.0""" input UpdateModelDeploymentInput @join__type(graph: STRAWBERRY) { @@ -6360,7 +6372,7 @@ input UpdateModelDeploymentInput preferredDomainName: String = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type UpdateModelDeploymentPayload @join__type(graph: STRAWBERRY) { diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index e761c3206ff..ad853229088 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -8,17 +8,17 @@ type AccessToken implements Node { """The Globally Unique ID of this object""" id: ID! - """Added in 25.13.0: The access token.""" + """Added in 25.16.0: The access token.""" token: String! - """Added in 25.13.0: The creation timestamp of the access token.""" + """Added in 25.16.0: The creation timestamp of the access token.""" createdAt: DateTime! - """Added in 25.13.0: The expiration timestamp of the access token.""" + """Added in 25.16.0: The expiration timestamp of the access token.""" validUntil: DateTime! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type AccessTokenConnection { """Pagination data for this connection""" pageInfo: PageInfo! @@ -37,21 +37,31 @@ type AccessTokenEdge { node: AccessToken! } +"""Added in 25.16.0""" +input AccessTokenOrderBy { + field: AccessTokenOrderField! + direction: OrderDirection! = DESC +} + +enum AccessTokenOrderField { + CREATED_AT +} + """ -Added in 25.13.0. This enum represents the activeness status of a replica, indicating whether the deployment is currently active and able to serve requests. +Added in 25.16.0. This enum represents the activeness status of a replica, indicating whether the deployment is currently active and able to serve requests. """ enum ActivenessStatus { ACTIVE INACTIVE } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ActivenessStatusFilter { in: [ActivenessStatus!] 
= null equals: ActivenessStatus = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input AddModelRevisionInput { name: String = null deploymentId: ID! @@ -63,7 +73,7 @@ input AddModelRevisionInput { extraMounts: [ExtraVFolderMountInput!] } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type AddModelRevisionPayload { revision: ModelRevision! } @@ -441,26 +451,26 @@ type AutoScalingRule implements Node { """The Globally Unique ID of this object""" id: ID! - """Added in 25.13.0 (e.g. KERNEL, INFERENCE_FRAMEWORK)""" + """Added in 25.16.0 (e.g. KERNEL, INFERENCE_FRAMEWORK)""" metricSource: AutoScalingMetricSource! metricName: String! - """Added in 25.13.0: The minimum threshold for scaling (e.g. 0.5)""" + """Added in 25.16.0: The minimum threshold for scaling (e.g. 0.5)""" minThreshold: Decimal - """Added in 25.13.0: The maximum threshold for scaling (e.g. 21.0)""" + """Added in 25.16.0: The maximum threshold for scaling (e.g. 21.0)""" maxThreshold: Decimal - """Added in 25.13.0: The step size for scaling (e.g. 1).""" + """Added in 25.16.0: The step size for scaling (e.g. 1).""" stepSize: Int! - """Added in 25.13.0: The time window (seconds) for scaling (e.g. 60).""" + """Added in 25.16.0: The time window (seconds) for scaling (e.g. 60).""" timeWindow: Int! - """Added in 25.13.0: The minimum number of replicas (e.g. 1).""" + """Added in 25.16.0: The minimum number of replicas (e.g. 1).""" minReplicas: Int - """Added in 25.13.0: The maximum number of replicas (e.g. 10).""" + """Added in 25.16.0: The maximum number of replicas (e.g. 10).""" maxReplicas: Int createdAt: DateTime! lastTriggeredAt: DateTime! @@ -532,19 +542,19 @@ type CleanupArtifactRevisionsPayload { artifactRevisions: ArtifactRevisionConnection! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ClusterConfig { mode: ClusterMode! size: Int! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ClusterConfigInput { mode: ClusterMode! size: Int! 
} -"""Added in 25.13.0""" +"""Added in 25.16.0""" enum ClusterMode { SINGLE_NODE MULTI_NODE @@ -556,11 +566,11 @@ extend type ComputeSessionNode @key(fields: "id") { input CreateAccessTokenInput { """ - Added in 25.13.0: The ID of the model deployment for which the access token is created. + Added in 25.16.0: The ID of the model deployment for which the access token is created. """ modelDeploymentId: ID! - """Added in 25.13.0: The expiration timestamp of the access token.""" + """Added in 25.16.0: The expiration timestamp of the access token.""" validUntil: DateTime! } @@ -596,7 +606,7 @@ type CreateHuggingFaceRegistryPayload { huggingfaceRegistry: HuggingFaceRegistry! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input CreateModelDeploymentInput { metadata: ModelDeploymentMetadataInput! networkAccess: ModelDeploymentNetworkAccessInput! @@ -605,12 +615,12 @@ input CreateModelDeploymentInput { initialRevision: CreateModelRevisionInput! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type CreateModelDeploymentPayload { deployment: ModelDeployment! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input CreateModelRevisionInput { name: String = null clusterConfig: ClusterConfigInput! @@ -621,7 +631,7 @@ input CreateModelRevisionInput { extraMounts: [ExtraVFolderMountInput!] } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type CreateModelRevisionPayload { revision: ModelRevision! } @@ -790,12 +800,12 @@ type DeleteHuggingFaceRegistryPayload { id: ID! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeleteModelDeploymentInput { id: ID! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type DeleteModelDeploymentPayload { id: ID! } @@ -830,7 +840,7 @@ type DeleteVFSStoragePayload { id: ID! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentFilter { name: StringFilter = null status: DeploymentStatusFilter = null @@ -843,13 +853,12 @@ input DeploymentFilter { NOT: [DeploymentFilter!] 
= null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentOrderBy { field: DeploymentOrderField! direction: OrderDirection! = DESC } -"""Added in 25.13.0""" enum DeploymentOrderField { CREATED_AT UPDATED_AT @@ -857,7 +866,7 @@ enum DeploymentOrderField { } """ -Added in 25.13.0. This enum represents the deployment status of a model deployment, indicating its current state. +Added in 25.16.0. This enum represents the deployment status of a model deployment, indicating its current state. """ enum DeploymentStatus { PENDING @@ -868,29 +877,29 @@ enum DeploymentStatus { STOPPED } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type DeploymentStatusChangedPayload { deployment: ModelDeployment! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentStatusFilter { in: [DeploymentStatus!] = null equals: DeploymentStatus = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type DeploymentStrategy { type: DeploymentStrategyType! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input DeploymentStrategyInput { type: DeploymentStrategyType! } """ -Added in 25.13.0. This enum represents the deployment strategy type of a model deployment, indicating the strategy used for deployment. +Added in 25.16.0. This enum represents the deployment strategy type of a model deployment, indicating the strategy used for deployment. """ enum DeploymentStrategyType { ROLLING @@ -908,7 +917,7 @@ type ExtraVFolderMount implements Node { vfolder: VirtualFolderNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ExtraVFolderMountConnection { """Pagination data for this connection""" pageInfo: PageInfo! @@ -927,7 +936,7 @@ type ExtraVFolderMountEdge { node: ExtraVFolderMount! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ExtraVFolderMountInput { vfolderId: ID! mountDestination: String @@ -989,7 +998,7 @@ type HuggingFaceRegistryEdge { node: HuggingFaceRegistry! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ImageInput { name: String! 
architecture: String! @@ -1038,11 +1047,11 @@ The `JSON` scalar type represents JSON values as specified by [ECMA-404](https:/ """ scalar JSON @specifiedBy(url: "https://ecma-international.org/wp-content/uploads/ECMA-404_2nd_edition_december_2017.pdf") -"""Added in 25.13.0""" +"""Added in 25.15.0""" scalar JSONString """ -Added in 25.13.0. This enum represents the liveness status of a replica, indicating whether the deployment is currently running and able to serve requests. +Added in 25.16.0. This enum represents the liveness status of a replica, indicating whether the deployment is currently running and able to serve requests. """ enum LivenessStatus { NOT_CHECKED @@ -1051,27 +1060,27 @@ enum LivenessStatus { DEGRADED } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input LivenessStatusFilter { in: [LivenessStatus!] = null equals: LivenessStatus = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeployment implements Node { """The Globally Unique ID of this object""" id: ID! metadata: ModelDeploymentMetadata! networkAccess: ModelDeploymentNetworkAccess! revision: ModelRevision - scalingRule: ScalingRule! - replicaState: ReplicaState! defaultDeploymentStrategy: DeploymentStrategy! createdUser: UserNode! + scalingRule: ScalingRule! + replicaState: ReplicaState! revisionHistory(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeploymentConnection { """Pagination data for this connection""" pageInfo: PageInfo! @@ -1090,18 +1099,18 @@ type ModelDeploymentEdge { node: ModelDeployment! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeploymentMetadata { name: String! status: DeploymentStatus! tags: [String!]! - project: GroupNode! - domain: DomainNode! createdAt: DateTime! updatedAt: DateTime! 
+ project: GroupNode! + domain: DomainNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelDeploymentMetadataInput { projectId: ID! domainName: String! @@ -1109,39 +1118,38 @@ input ModelDeploymentMetadataInput { tags: [String!] = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelDeploymentNetworkAccess { endpointUrl: String preferredDomainName: String openToPublic: Boolean! - accessTokens: AccessTokenConnection! + accessTokens(orderBy: [AccessTokenOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): AccessTokenConnection! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelDeploymentNetworkAccessInput { preferredDomainName: String = null openToPublic: Boolean! = false } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelMountConfig { - vfolder: VirtualFolderNode! mountDestination: String! definitionPath: String! + vfolder: VirtualFolderNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelMountConfigInput { vfolderId: ID! mountDestination: String! definitionPath: String! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelReplica implements Node { """The Globally Unique ID of this object""" id: ID! - revision: ModelRevision! """ This represents whether the replica has been checked and its health state. @@ -1174,9 +1182,10 @@ type ModelReplica implements Node { The session ID associated with the replica. This can be null right after replica creation. """ session: ComputeSessionNode! + revision: ModelRevision! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelReplicaConnection { """Pagination data for this connection""" pageInfo: PageInfo! @@ -1195,7 +1204,7 @@ type ModelReplicaEdge { node: ModelReplica! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelRevision implements Node { """The Globally Unique ID of this object""" id: ID! 
@@ -1205,11 +1214,11 @@ type ModelRevision implements Node { modelRuntimeConfig: ModelRuntimeConfig! modelMountConfig: ModelMountConfig! extraMounts: ExtraVFolderMountConnection! - image: ImageNode! createdAt: DateTime! + image: ImageNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelRevisionConnection { """Pagination data for this connection""" pageInfo: PageInfo! @@ -1228,7 +1237,7 @@ type ModelRevisionEdge { node: ModelRevision! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelRevisionFilter { name: StringFilter = null deploymentId: ID = null @@ -1238,20 +1247,18 @@ input ModelRevisionFilter { NOT: [ModelRevisionFilter!] = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelRevisionOrderBy { field: ModelRevisionOrderField! direction: OrderDirection! = DESC } -"""Added in 25.13.0""" enum ModelRevisionOrderField { CREATED_AT NAME - ID } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ModelRuntimeConfig { runtimeVariant: String! inferenceRuntimeConfig: JSON @@ -1262,7 +1269,7 @@ type ModelRuntimeConfig { environ: JSONString } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ModelRuntimeConfigInput { runtimeVariant: String! inferenceRuntimeConfig: JSON = null @@ -1408,24 +1415,26 @@ type Mutation { """ cancelImportArtifact(input: CancelArtifactInput!): CancelImportArtifactPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" createModelDeployment(input: CreateModelDeploymentInput!): CreateModelDeploymentPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" updateModelDeployment(input: UpdateModelDeploymentInput!): UpdateModelDeploymentPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" deleteModelDeployment(input: DeleteModelDeploymentInput!): DeleteModelDeploymentPayload! """ - Added in 25.13.0. Force syncs up-to-date replica information. In normal situations this will be automatically handled by Backend.AI schedulers + Added in 25.16.0. Force syncs up-to-date replica information. 
In normal situations this will be automatically handled by Backend.AI schedulers """ syncReplicas(input: SyncReplicaInput!): SyncReplicaPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" addModelRevision(input: AddModelRevisionInput!): AddModelRevisionPayload! - """Added in 25.13.0""" + """ + Added in 25.16.0. Create model revision which is not attached to any deployment. + """ createModelRevision(input: CreateModelRevisionInput!): CreateModelRevisionPayload! """Added in 25.14.0""" @@ -1434,13 +1443,13 @@ type Mutation { """Added in 25.14.0""" updateObjectStorage(input: UpdateObjectStorageInput!): UpdateObjectStoragePayload! - """Added in 25.13.0""" + """Added in 25.16.0""" createAutoScalingRule(input: CreateAutoScalingRuleInput!): CreateAutoScalingRulePayload! - """Added in 25.13.0""" + """Added in 25.16.0""" updateAutoScalingRule(input: UpdateAutoScalingRuleInput!): UpdateAutoScalingRulePayload! - """Added in 25.13.0""" + """Added in 25.16.0""" deleteAutoScalingRule(input: DeleteAutoScalingRuleInput!): DeleteAutoScalingRulePayload! """Added in 25.14.0""" @@ -1513,7 +1522,7 @@ type Mutation { """ rejectArtifactRevision(input: RejectArtifactInput!): RejectArtifactPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" createAccessToken(input: CreateAccessTokenInput!): CreateAccessTokenPayload! } @@ -1626,22 +1635,22 @@ type Query { """ artifactRevisions(filter: ArtifactRevisionFilter = null, orderBy: [ArtifactRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ArtifactRevisionConnection! - """Added in 25.13.0""" + """Added in 25.16.0""" deployments(filter: DeploymentFilter = null, orderBy: [DeploymentOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelDeploymentConnection! 
- """Added in 25.13.0""" + """Added in 25.16.0""" deployment(id: ID!): ModelDeployment - """Added in 25.13.0""" + """Added in 25.16.0""" revisions(filter: ModelRevisionFilter = null, orderBy: [ModelRevisionOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelRevisionConnection! - """Added in 25.13.0""" + """Added in 25.16.0""" revision(id: ID!): ModelRevision! - """Added in 25.13.0""" + """Added in 25.16.0""" replicas(filter: ReplicaFilter = null, orderBy: [ReplicaOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelReplicaConnection! - """Added in 25.13.0""" + """Added in 25.16.0""" replica(id: ID!): ModelReplica """Added in 25.14.0""" @@ -1675,16 +1684,16 @@ type Query { agentStats: AgentStats! """ - Added in 25.13.0 Get configuration JSON Schemas for all inference runtimes + Added in 25.16.0 Get configuration JSON Schemas for all inference runtimes """ inferenceRuntimeConfigs: JSON! - """Added in 25.13.0. Get JSON Schema for inference runtime configuration""" + """Added in 25.16.0. Get JSON Schema for inference runtime configuration""" inferenceRuntimeConfig(name: String!): JSON! } """ -Added in 25.13.0. This enum represents the readiness status of a replica, indicating whether the deployment has been checked and its health state. +Added in 25.16.0. This enum represents the readiness status of a replica, indicating whether the deployment has been checked and its health state. """ enum ReadinessStatus { NOT_CHECKED @@ -1692,7 +1701,7 @@ enum ReadinessStatus { UNHEALTHY } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ReadinessStatusFilter { in: [ReadinessStatus!] = null equals: ReadinessStatus = null @@ -1739,7 +1748,7 @@ type RejectArtifactPayload { artifactRevision: ArtifactRevision! 
} -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ReplicaFilter { readinessStatus: ReadinessStatusFilter = null livenessStatus: LivenessStatusFilter = null @@ -1750,24 +1759,24 @@ input ReplicaFilter { NOT: [ReplicaFilter!] = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ReplicaOrderBy { field: ReplicaOrderField! direction: OrderDirection! = DESC } -"""Added in 25.13.0""" enum ReplicaOrderField { CREATED_AT + ID } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ReplicaState { desiredReplicaCount: Int! replicas(filter: ReplicaFilter = null, orderBy: [ReplicaOrderBy!] = null, before: String = null, after: String = null, first: Int = null, last: Int = null, limit: Int = null, offset: Int = null): ModelReplicaConnection! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ReplicaStatusChangedPayload { replica: ModelReplica! } @@ -1803,10 +1812,8 @@ type ReservoirRegistryEdge { node: ReservoirRegistry! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ResourceConfig { - resourceGroup: ScalingGroupNode! - """ Resource Slots are a JSON string that describes the resources allocated for the deployment. Example: "resourceSlots": "{\"cpu\": \"1\", \"mem\": \"1073741824\", \"cuda.device\": \"0\"}" """ @@ -1816,9 +1823,10 @@ type ResourceConfig { Resource Options are a JSON string that describes additional options for the resources. This is especially used for shared memory configurations. Example: "resourceOpts": "{\"shmem\": \"64m\"}" """ resourceOpts: JSONString + resourceGroup: ScalingGroupNode! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ResourceConfigInput { resourceGroup: ResourceGroupInput! @@ -1833,7 +1841,7 @@ input ResourceConfigInput { resourceOpts: JSONString = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input ResourceGroupInput { name: String! } @@ -1864,7 +1872,7 @@ extend type ScalingGroupNode @key(fields: "id") { id: ID! 
@external } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type ScalingRule { autoScalingRules: [AutoScalingRule!]! } @@ -2035,10 +2043,10 @@ type Subscription { """ artifactImportProgressUpdated(artifactRevisionId: ID!): ArtifactImportProgressUpdatedPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" deploymentStatusChanged(deploymentId: ID!): DeploymentStatusChangedPayload! - """Added in 25.13.0""" + """Added in 25.16.0""" replicaStatusChanged(revisionId: ID!): ReplicaStatusChangedPayload! """ @@ -2052,12 +2060,12 @@ type Subscription { backgroundTaskEvents(taskId: ID!): BackgroundTaskEventPayload! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input SyncReplicaInput { modelDeploymentId: ID! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type SyncReplicaPayload { success: Boolean! } @@ -2137,7 +2145,7 @@ type UpdateHuggingFaceRegistryPayload { huggingfaceRegistry: HuggingFaceRegistry! } -"""Added in 25.13.0""" +"""Added in 25.16.0""" input UpdateModelDeploymentInput { id: ID! openToPublic: Boolean = null @@ -2149,7 +2157,7 @@ input UpdateModelDeploymentInput { preferredDomainName: String = null } -"""Added in 25.13.0""" +"""Added in 25.16.0""" type UpdateModelDeploymentPayload { deployment: ModelDeployment! 
} diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index 9053946f057..6129918ba59 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -476,7 +476,7 @@ def error_code(cls) -> ErrorCode: ) -class BgtaskNotFoundError(BackendAIError, web.HTTPNotFound): +class BgtaskNotFound(BackendAIError, web.HTTPNotFound): error_type = "https://api.backend.ai/probs/bgtask-not-found" error_title = "Background Task Not Found" @@ -700,7 +700,7 @@ def error_code(cls) -> ErrorCode: ) -class VFolderNotFoundError(BackendAIError, web.HTTPNotFound): +class VFolderNotFound(BackendAIError, web.HTTPNotFound): error_type = "https://api.backend.ai/probs/vfolder-not-found" error_title = "Virtual Folder Not Found" @@ -713,7 +713,7 @@ def error_code(cls) -> ErrorCode: ) -class UserNotFoundError(BackendAIError, web.HTTPNotFound): +class UserNotFound(BackendAIError, web.HTTPNotFound): error_type = "https://api.backend.ai/probs/user-not-found" error_title = "User Not Found" @@ -726,7 +726,7 @@ def error_code(cls) -> ErrorCode: ) -class GroupNotFoundError(BackendAIError, web.HTTPNotFound): +class GroupNotFound(BackendAIError, web.HTTPNotFound): error_type = "https://api.backend.ai/probs/group-not-found" error_title = "Project Not Found" @@ -739,7 +739,7 @@ def error_code(cls) -> ErrorCode: ) -class DomainNotFoundError(BackendAIError, web.HTTPNotFound): +class DomainNotFound(BackendAIError, web.HTTPNotFound): error_type = "https://api.backend.ai/probs/domain-not-found" error_title = "Domain Not Found" @@ -750,3 +750,42 @@ def error_code(cls) -> ErrorCode: operation=ErrorOperation.READ, error_detail=ErrorDetail.NOT_FOUND, ) + + +class ModelDeploymentNotFound(BackendAIError, web.HTTPNotFound): + error_type = "https://api.backend.ai/probs/model-deployment-not-found" + error_title = "Model Deployment Not Found" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.MODEL_DEPLOYMENT, + 
operation=ErrorOperation.READ, + error_detail=ErrorDetail.NOT_FOUND, + ) + + +class ModelDeploymentUnavailable(BackendAIError, web.HTTPServiceUnavailable): + error_type = "https://api.backend.ai/probs/model-deployment-unavailable" + error_title = "Model Deployment Unavailable" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.MODEL_DEPLOYMENT, + operation=ErrorOperation.EXECUTE, + error_detail=ErrorDetail.UNAVAILABLE, + ) + + +class ModelRevisionNotFound(BackendAIError, web.HTTPNotFound): + error_type = "https://api.backend.ai/probs/model-revision-not-found" + error_title = "Model Revision Not Found" + + @classmethod + def error_code(cls) -> ErrorCode: + return ErrorCode( + domain=ErrorDomain.MODEL_DEPLOYMENT, + operation=ErrorOperation.READ, + error_detail=ErrorDetail.NOT_FOUND, + ) diff --git a/src/ai/backend/manager/api/admin.py b/src/ai/backend/manager/api/admin.py index 16a4f06f0b1..6331d10c079 100644 --- a/src/ai/backend/manager/api/admin.py +++ b/src/ai/backend/manager/api/admin.py @@ -19,6 +19,7 @@ from ai.backend.common import validators as tx from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.api.gql.data_loader.registry import DataLoaderRegistry from ai.backend.manager.api.gql.types import StrawberryGQLContext from ..api.gql.schema import schema as strawberry_schema @@ -65,6 +66,7 @@ async def get_context( # type: ignore[override] config_provider=root_context.config_provider, event_hub=root_context.event_hub, event_fetcher=root_context.event_fetcher, + dataloader_registry=DataLoaderRegistry(), ) diff --git a/src/ai/backend/manager/api/gql/base.py b/src/ai/backend/manager/api/gql/base.py index 6e58e063d36..5255b062a2d 100644 --- a/src/ai/backend/manager/api/gql/base.py +++ b/src/ai/backend/manager/api/gql/base.py @@ -2,15 +2,17 @@ import uuid from collections.abc import Mapping +from dataclasses import dataclass from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, 
Type, cast +from typing import TYPE_CHECKING, Any, Optional, Protocol, Type, TypeVar, cast -import orjson +import graphene import strawberry from graphql import StringValueNode from graphql_relay.utils import base64, unbase64 from strawberry.types import get_object_definition, has_object_definition +from ai.backend.common.json import dump_json_str, load_json from ai.backend.common.types import ResourceSlot if TYPE_CHECKING: @@ -136,31 +138,41 @@ class Ordering(StrEnum): DESC_NULLS_LAST = "DESC_NULLS_LAST" -@strawberry.scalar(description="Added in 25.13.0") +@strawberry.scalar(description="Added in 25.15.0") class JSONString: @staticmethod def parse_value(value: str | bytes) -> Mapping[str, Any]: if isinstance(value, str): - return orjson.loads(value) + return load_json(value) if isinstance(value, bytes): - return orjson.loads(value) + return load_json(value) return value @staticmethod def serialize(value: Any) -> JSONString: if isinstance(value, (dict, list)): - return cast(JSONString, orjson.dumps(value).decode("utf-8")) + return cast(JSONString, dump_json_str(value)) elif isinstance(value, str): return cast(JSONString, value) else: - return cast(JSONString, orjson.dumps(value).decode("utf-8")) + return cast(JSONString, dump_json_str(value)) @staticmethod def from_resource_slot(resource_slot: ResourceSlot) -> JSONString: return JSONString.serialize(resource_slot.to_json()) -def to_global_id(type_: Type[Any], local_id: uuid.UUID | str) -> str: +def to_global_id( + type_: Type[Any], local_id: uuid.UUID | str, is_target_graphene_object: bool = False +) -> str: + if is_target_graphene_object: + # For compatibility with existing Graphene-based global IDs + if not issubclass(type_, graphene.ObjectType): + raise TypeError( + "type_ must be a graphene ObjectType when is_target_graphene_object is True." 
+ ) + typename = type_.__name__ + return base64(f"{typename}:{local_id}") if not has_object_definition(type_): raise TypeError("type_ must be a Strawberry object type (Node or Edge).") typename = get_object_definition(type_, strict=True).name @@ -203,3 +215,75 @@ def build_pagination_options( pagination.backward = BackwardPaginationOptions(before=before, last=last) return pagination + + +@dataclass +class PageInfo: + has_next_page: bool + has_previous_page: bool + start_cursor: Optional[str] = None + end_cursor: Optional[str] = None + + def to_strawberry_page_info(self) -> "strawberry.relay.PageInfo": + return strawberry.relay.PageInfo( + has_next_page=self.has_next_page, + has_previous_page=self.has_previous_page, + start_cursor=self.start_cursor, + end_cursor=self.end_cursor, + ) + + +class HasCursor(Protocol): + cursor: str + + +TEdge = TypeVar("TEdge", bound=HasCursor) + + +def build_page_info( + edges: list[TEdge], total_count: int, pagination_options: PaginationOptions +) -> PageInfo: + """Build PageInfo from edges and pagination options""" + has_next_page = False + has_previous_page = False + + if pagination_options.offset: + # Offset-based pagination + offset = pagination_options.offset.offset or 0 + + has_previous_page = offset > 0 + has_next_page = (offset + len(edges)) < total_count + + elif pagination_options.forward: + # Forward pagination (after/first) + first = pagination_options.forward.first + if first is not None: + # If we got exactly the requested number and there might be more + has_next_page = len(edges) == first + else: + # If no first specified, check if we have all items + has_next_page = len(edges) < total_count + has_previous_page = pagination_options.forward.after is not None + + elif pagination_options.backward: + # Backward pagination (before/last) + last = pagination_options.backward.last + if last is not None: + # If we got exactly the requested number, there might be more before + has_previous_page = len(edges) == last + else: + # 
If no last specified, assume there could be previous items + has_previous_page = True + has_next_page = pagination_options.backward.before is not None + + else: + # Default case - assume we have all items if no pagination specified + has_next_page = len(edges) < total_count + has_previous_page = False + + return PageInfo( + has_next_page=has_next_page, + has_previous_page=has_previous_page, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ) diff --git a/src/ai/backend/manager/data/model_deployment/__init__.py b/src/ai/backend/manager/api/gql/data_loader/__init__.py similarity index 100% rename from src/ai/backend/manager/data/model_deployment/__init__.py rename to src/ai/backend/manager/api/gql/data_loader/__init__.py diff --git a/src/ai/backend/manager/api/gql/data_loader/registry.py b/src/ai/backend/manager/api/gql/data_loader/registry.py new file mode 100644 index 00000000000..eed1b9ddd64 --- /dev/null +++ b/src/ai/backend/manager/api/gql/data_loader/registry.py @@ -0,0 +1,26 @@ +from typing import TYPE_CHECKING, Any, Awaitable, Callable + +from aiotools import apartial +from strawberry.dataloader import DataLoader + +if TYPE_CHECKING: + from ai.backend.manager.api.gql.types import StrawberryGQLContext + + +class DataLoaderRegistry: + _loader: dict[Callable, DataLoader] + + def __init__(self) -> None: + self._loader = {} + + def get_loader( + self, + func: Callable[["StrawberryGQLContext", Any], Awaitable[Any]], + context: "StrawberryGQLContext", + ) -> DataLoader: + loader = self._loader.get(func, None) + if loader is None: + new_loader = DataLoader(apartial(func, context)) + self._loader[func] = new_loader + return new_loader + return loader diff --git a/src/ai/backend/manager/api/gql/model_deployment/access_token.py b/src/ai/backend/manager/api/gql/model_deployment/access_token.py index 92429f17b99..8336730daf9 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/access_token.py +++ 
b/src/ai/backend/manager/api/gql/model_deployment/access_token.py @@ -1,29 +1,54 @@ -from datetime import datetime, timedelta +from datetime import datetime +from typing import Self from uuid import UUID import strawberry from strawberry import ID, Info from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.manager.api.gql.base import OrderDirection from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.access_token import ModelDeploymentAccessTokenCreator +from ai.backend.manager.data.deployment.types import ( + AccessTokenOrderField, + ModelDeploymentAccessTokenData, +) +from ai.backend.manager.services.deployment.actions.access_token.create_access_token import ( + CreateAccessTokenAction, +) + + +@strawberry.input(description="Added in 25.16.0") +class AccessTokenOrderBy: + field: AccessTokenOrderField + direction: OrderDirection = OrderDirection.DESC @strawberry.type class AccessToken(Node): - id: NodeID - token: str = strawberry.field(description="Added in 25.13.0: The access token.") + id: NodeID[str] + token: str = strawberry.field(description="Added in 25.16.0: The access token.") created_at: datetime = strawberry.field( - description="Added in 25.13.0: The creation timestamp of the access token." + description="Added in 25.16.0: The creation timestamp of the access token." ) valid_until: datetime = strawberry.field( - description="Added in 25.13.0: The expiration timestamp of the access token." + description="Added in 25.16.0: The expiration timestamp of the access token." 
) + @classmethod + def from_dataclass(cls, data: ModelDeploymentAccessTokenData) -> Self: + return cls( + id=ID(str(data.id)), + token=data.token, + created_at=data.created_at, + valid_until=data.valid_until, + ) + AccessTokenEdge = Edge[AccessToken] -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class AccessTokenConnection(Connection[AccessToken]): count: int @@ -32,59 +57,34 @@ def __init__(self, *args, count: int, **kwargs): self.count = count -mock_access_token_1 = AccessToken( - id=UUID("13cd8325-9307-49e4-94eb-ded2581363f8"), - token="mock-token-1", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=12), -) - -mock_access_token_2 = AccessToken( - id=UUID("dc1a223a-7437-4e6f-aedf-23417d0486dd"), - token="mock-token-2", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=1), -) - -mock_access_token_3 = AccessToken( - id=UUID("39f8b49e-0ddf-4dfb-92d6-003c771684b7"), - token="mock-token-3", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=100), -) - -mock_access_token_4 = AccessToken( - id=UUID("85a6ed1e-133b-4f58-9c06-f667337c6111"), - token="mock-token-4", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=10), -) - -mock_access_token_5 = AccessToken( - id=UUID("c42f8578-b31d-4203-b858-93f93b4b9549"), - token="mock-token-5", - created_at=datetime.now(), - valid_until=datetime.now() + timedelta(hours=3), -) - - @strawberry.input class CreateAccessTokenInput: model_deployment_id: ID = strawberry.field( - description="Added in 25.13.0: The ID of the model deployment for which the access token is created." + description="Added in 25.16.0: The ID of the model deployment for which the access token is created." ) valid_until: datetime = strawberry.field( - description="Added in 25.13.0: The expiration timestamp of the access token." + description="Added in 25.16.0: The expiration timestamp of the access token." 
) + def to_creator(self) -> "ModelDeploymentAccessTokenCreator": + return ModelDeploymentAccessTokenCreator( + model_deployment_id=UUID(self.model_deployment_id), + valid_until=self.valid_until, + ) + @strawberry.type class CreateAccessTokenPayload: access_token: AccessToken -@strawberry.mutation(description="Added in 25.13.0") +@strawberry.mutation(description="Added in 25.16.0") async def create_access_token( input: CreateAccessTokenInput, info: Info[StrawberryGQLContext] ) -> CreateAccessTokenPayload: - return CreateAccessTokenPayload(access_token=mock_access_token_1) + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + result = await deployment_processor.create_access_token.wait_for_complete( + action=CreateAccessTokenAction(input.to_creator()) + ) + return CreateAccessTokenPayload(access_token=AccessToken.from_dataclass(result.data)) diff --git a/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py b/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py index c133ff32d91..b86dc083262 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py +++ b/src/ai/backend/manager/api/gql/model_deployment/auto_scaling_rule.py @@ -1,14 +1,28 @@ -from datetime import datetime, timedelta +from datetime import datetime from decimal import Decimal from enum import StrEnum -from typing import Optional +from typing import Optional, Self from uuid import UUID import strawberry from strawberry import ID, Info from strawberry.relay import Node, NodeID +from ai.backend.common.types import AutoScalingMetricSource as CommonAutoScalingMetricSource from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.scale import ModelDeploymentAutoScalingRuleCreator +from ai.backend.manager.data.deployment.scale_modifier import ModelDeploymentAutoScalingRuleModifier +from ai.backend.manager.data.deployment.types import 
ModelDeploymentAutoScalingRuleData +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.create_auto_scaling_rule import ( + CreateAutoScalingRuleAction, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.delete_auto_scaling_rule import ( + DeleteAutoScalingRuleAction, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.update_auto_scaling_rule import ( + UpdateAutoScalingRuleAction, +) +from ai.backend.manager.types import OptionalState @strawberry.enum(description="Added in 25.1.0") @@ -19,37 +33,53 @@ class AutoScalingMetricSource(StrEnum): @strawberry.type class AutoScalingRule(Node): - id: NodeID + id: NodeID[str] metric_source: AutoScalingMetricSource = strawberry.field( - description="Added in 25.13.0 (e.g. KERNEL, INFERENCE_FRAMEWORK)" + description="Added in 25.16.0 (e.g. KERNEL, INFERENCE_FRAMEWORK)" ) metric_name: str = strawberry.field() min_threshold: Optional[Decimal] = strawberry.field( - description="Added in 25.13.0: The minimum threshold for scaling (e.g. 0.5)" + description="Added in 25.16.0: The minimum threshold for scaling (e.g. 0.5)" ) max_threshold: Optional[Decimal] = strawberry.field( - description="Added in 25.13.0: The maximum threshold for scaling (e.g. 21.0)" + description="Added in 25.16.0: The maximum threshold for scaling (e.g. 21.0)" ) step_size: int = strawberry.field( - description="Added in 25.13.0: The step size for scaling (e.g. 1)." + description="Added in 25.16.0: The step size for scaling (e.g. 1)." ) time_window: int = strawberry.field( - description="Added in 25.13.0: The time window (seconds) for scaling (e.g. 60)." + description="Added in 25.16.0: The time window (seconds) for scaling (e.g. 60)." ) min_replicas: Optional[int] = strawberry.field( - description="Added in 25.13.0: The minimum number of replicas (e.g. 1)." + description="Added in 25.16.0: The minimum number of replicas (e.g. 1)." 
) max_replicas: Optional[int] = strawberry.field( - description="Added in 25.13.0: The maximum number of replicas (e.g. 10)." + description="Added in 25.16.0: The maximum number of replicas (e.g. 10)." ) created_at: datetime last_triggered_at: datetime + @classmethod + def from_dataclass(cls, data: ModelDeploymentAutoScalingRuleData) -> Self: + return cls( + id=ID(str(data.id)), + metric_source=AutoScalingMetricSource(data.metric_source.name), + metric_name=data.metric_name, + min_threshold=data.min_threshold, + max_threshold=data.max_threshold, + step_size=data.step_size, + time_window=data.time_window, + min_replicas=data.min_replicas, + max_replicas=data.max_replicas, + created_at=data.created_at, + last_triggered_at=data.last_triggered_at, + ) + # Input Types @strawberry.input @@ -64,6 +94,19 @@ class CreateAutoScalingRuleInput: min_replicas: Optional[int] max_replicas: Optional[int] + def to_creator(self) -> ModelDeploymentAutoScalingRuleCreator: + return ModelDeploymentAutoScalingRuleCreator( + model_deployment_id=UUID(self.model_deployment_id), + metric_source=CommonAutoScalingMetricSource(self.metric_source.lower()), + metric_name=self.metric_name, + min_threshold=self.min_threshold, + max_threshold=self.max_threshold, + step_size=self.step_size, + time_window=self.time_window, + min_replicas=self.min_replicas, + max_replicas=self.max_replicas, + ) + @strawberry.input class UpdateAutoScalingRuleInput: @@ -77,6 +120,26 @@ class UpdateAutoScalingRuleInput: min_replicas: Optional[int] max_replicas: Optional[int] + def to_action(self) -> UpdateAutoScalingRuleAction: + optional_state_metric_source = OptionalState[CommonAutoScalingMetricSource].nop() + if isinstance(self.metric_source, AutoScalingMetricSource): + optional_state_metric_source = OptionalState[CommonAutoScalingMetricSource].update( + CommonAutoScalingMetricSource(self.metric_source) + ) + return UpdateAutoScalingRuleAction( + auto_scaling_rule_id=UUID(self.id), + 
modifier=ModelDeploymentAutoScalingRuleModifier( + metric_source=optional_state_metric_source, + metric_name=OptionalState[str].from_graphql(self.metric_name), + min_threshold=OptionalState[Decimal].from_graphql(self.min_threshold), + max_threshold=OptionalState[Decimal].from_graphql(self.max_threshold), + step_size=OptionalState[int].from_graphql(self.step_size), + time_window=OptionalState[int].from_graphql(self.time_window), + min_replicas=OptionalState[int].from_graphql(self.min_replicas), + max_replicas=OptionalState[int].from_graphql(self.max_replicas), + ), + ) + @strawberry.input class DeleteAutoScalingRuleInput: @@ -99,89 +162,41 @@ class DeleteAutoScalingRulePayload: id: ID -mock_scaling_rule_0 = AutoScalingRule( - id=UUID("77117a41-87f3-43b7-ba24-40dd5e978720"), - metric_source=AutoScalingMetricSource.KERNEL, - metric_name="memory_usage", - min_threshold=None, - max_threshold=Decimal("90"), - step_size=1, - time_window=120, - min_replicas=1, - max_replicas=3, - created_at=datetime.now() - timedelta(days=15), - last_triggered_at=datetime.now() - timedelta(hours=6), -) - -mock_scaling_rule_1 = AutoScalingRule( - id=UUID("7ff8c1f5-cf8c-4ea2-911c-24ca0f4c2efb"), - metric_source=AutoScalingMetricSource.KERNEL, - metric_name="cpu_usage", - min_threshold=None, - max_threshold=Decimal("80"), - step_size=1, - time_window=300, - min_replicas=1, - max_replicas=5, - created_at=datetime.now() - timedelta(days=10), - last_triggered_at=datetime.now() - timedelta(hours=2), -) - -mock_scaling_rule_2 = AutoScalingRule( - id=UUID("483e2158-e089-482b-8cef-260805649cf1"), - metric_source=AutoScalingMetricSource.INFERENCE_FRAMEWORK, - metric_name="requests_per_second", - min_threshold=None, - max_threshold=Decimal("1000"), - step_size=2, - time_window=600, - min_replicas=2, - max_replicas=10, - created_at=datetime.now() - timedelta(days=5), - last_triggered_at=datetime.now() - timedelta(hours=12), -) - - -@strawberry.mutation(description="Added in 25.13.0") 
+@strawberry.mutation(description="Added in 25.16.0") async def create_auto_scaling_rule( input: CreateAutoScalingRuleInput, info: Info[StrawberryGQLContext] ) -> CreateAutoScalingRulePayload: - return CreateAutoScalingRulePayload(auto_scaling_rule=mock_scaling_rule_0) + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + result = await deployment_processor.create_auto_scaling_rule.wait_for_complete( + action=CreateAutoScalingRuleAction(input.to_creator()) + ) + return CreateAutoScalingRulePayload( + auto_scaling_rule=AutoScalingRule.from_dataclass(result.data) + ) -@strawberry.mutation(description="Added in 25.13.0") +@strawberry.mutation(description="Added in 25.16.0") async def update_auto_scaling_rule( input: UpdateAutoScalingRuleInput, info: Info[StrawberryGQLContext] ) -> UpdateAutoScalingRulePayload: + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + action_result = await deployment_processor.update_auto_scaling_rule.wait_for_complete( + input.to_action() + ) return UpdateAutoScalingRulePayload( - auto_scaling_rule=AutoScalingRule( - id=input.id, - metric_source=input.metric_source - if input.metric_source - else mock_scaling_rule_1.metric_source, - metric_name=input.metric_name if input.metric_name else mock_scaling_rule_1.metric_name, - min_threshold=input.min_threshold - if input.min_threshold - else mock_scaling_rule_1.min_threshold, - max_threshold=input.max_threshold - if input.max_threshold - else mock_scaling_rule_1.max_threshold, - step_size=input.step_size if input.step_size else mock_scaling_rule_1.step_size, - time_window=input.time_window if input.time_window else mock_scaling_rule_1.time_window, - min_replicas=input.min_replicas - if input.min_replicas - else mock_scaling_rule_1.min_replicas, - max_replicas=input.max_replicas - if input.max_replicas - else mock_scaling_rule_1.max_replicas, - created_at=datetime.now(), - 
last_triggered_at=datetime.now(), - ) + auto_scaling_rule=AutoScalingRule.from_dataclass(action_result.data) ) -@strawberry.mutation(description="Added in 25.13.0") +@strawberry.mutation(description="Added in 25.16.0") async def delete_auto_scaling_rule( input: DeleteAutoScalingRuleInput, info: Info[StrawberryGQLContext] ) -> DeleteAutoScalingRulePayload: - return DeleteAutoScalingRulePayload(id=input.id) + deployment_processor = info.context.processors.deployment + assert deployment_processor is not None + _ = await deployment_processor.delete_auto_scaling_rule.wait_for_complete( + DeleteAutoScalingRuleAction(auto_scaling_rule_id=UUID(input.id)) + ) + return DeleteAutoScalingRulePayload(id=ID(input.id)) diff --git a/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py b/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py index 34f42693a5b..8c751a48cb7 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py +++ b/src/ai/backend/manager/api/gql/model_deployment/model_deployment.py @@ -1,88 +1,122 @@ -from datetime import datetime, timedelta -from enum import StrEnum +from collections.abc import Sequence +from datetime import datetime from typing import AsyncGenerator, Optional from uuid import UUID, uuid4 import strawberry from strawberry import ID, Info -from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo +from strawberry.relay import Connection, Edge, Node, NodeID +from ai.backend.common.contexts.user import current_user from ai.backend.common.data.model_deployment.types import ( DeploymentStrategy as CommonDeploymentStrategy, ) from ai.backend.common.data.model_deployment.types import ( ModelDeploymentStatus as CommonDeploymentStatus, ) -from ai.backend.manager.api.gql.base import OrderDirection, StringFilter -from ai.backend.manager.api.gql.domain import Domain, mock_domain +from ai.backend.common.exception import ModelDeploymentNotFound, ModelDeploymentUnavailable +from 
ai.backend.manager.api.gql.base import ( + OrderDirection, + StringFilter, + build_page_info, + build_pagination_options, + resolve_global_id, + to_global_id, +) +from ai.backend.manager.api.gql.domain import Domain from ai.backend.manager.api.gql.model_deployment.access_token import ( + AccessToken, AccessTokenConnection, AccessTokenEdge, - mock_access_token_1, - mock_access_token_2, - mock_access_token_3, - mock_access_token_4, - mock_access_token_5, + AccessTokenOrderBy, ) from ai.backend.manager.api.gql.model_deployment.auto_scaling_rule import ( AutoScalingRule, - mock_scaling_rule_1, - mock_scaling_rule_2, ) from ai.backend.manager.api.gql.model_deployment.model_replica import ( ModelReplicaConnection, - ModelReplicaEdge, ReplicaFilter, ReplicaOrderBy, - mock_model_replica_1, - mock_model_replica_2, - mock_model_replica_3, + resolve_replicas, ) -from ai.backend.manager.api.gql.project import Project, mock_project +from ai.backend.manager.api.gql.project import Project from ai.backend.manager.api.gql.types import StrawberryGQLContext -from ai.backend.manager.api.gql.user import User, mock_user_id +from ai.backend.manager.api.gql.user import User +from ai.backend.manager.data.deployment.creator import NewDeploymentCreator +from ai.backend.manager.data.deployment.modifier import NewDeploymentModifier +from ai.backend.manager.data.deployment.types import ( + DeploymentMetadata, + DeploymentNetworkSpec, + DeploymentOrderField, + ModelDeploymentData, + ModelDeploymentMetadataInfo, + ReplicaSpec, + ReplicaStateData, +) +from ai.backend.manager.errors.user import UserNotFound +from ai.backend.manager.models.gql_models.domain import DomainNode +from ai.backend.manager.models.gql_models.group import GroupNode +from ai.backend.manager.models.gql_models.user import UserNode +from ai.backend.manager.repositories.deployment.types.types import ( + AccessTokenOrderingOptions, + DeploymentFilterOptions, + DeploymentOrderingOptions, + DeploymentStatusFilterType, +) +from 
ai.backend.manager.repositories.deployment.types.types import ( + DeploymentStatusFilter as RepoDeploymentStatusFilter, +) +from ai.backend.manager.services.deployment.actions.access_token.list_access_tokens import ( + ListAccessTokensAction, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.batch_load_auto_scaling_rules import ( + BatchLoadAutoScalingRulesAction, +) +from ai.backend.manager.services.deployment.actions.batch_load_deployments import ( + BatchLoadDeploymentsAction, +) +from ai.backend.manager.services.deployment.actions.create_deployment import ( + CreateDeploymentAction, +) +from ai.backend.manager.services.deployment.actions.destroy_deployment import ( + DestroyDeploymentAction, +) +from ai.backend.manager.services.deployment.actions.list_deployments import ListDeploymentsAction +from ai.backend.manager.services.deployment.actions.sync_replicas import SyncReplicaAction +from ai.backend.manager.services.deployment.actions.update_deployment import UpdateDeploymentAction +from ai.backend.manager.types import OptionalState, TriState from .model_revision import ( CreateModelRevisionInput, ModelRevision, ModelRevisionConnection, - ModelRevisionEdge, ModelRevisionFilter, ModelRevisionOrderBy, - mock_model_revision_1, - mock_model_revision_2, - mock_model_revision_3, + resolve_revisions, ) DeploymentStatus = strawberry.enum( CommonDeploymentStatus, name="DeploymentStatus", - description="Added in 25.13.0. This enum represents the deployment status of a model deployment, indicating its current state.", + description="Added in 25.16.0. This enum represents the deployment status of a model deployment, indicating its current state.", ) DeploymentStrategyType = strawberry.enum( CommonDeploymentStrategy, name="DeploymentStrategyType", - description="Added in 25.13.0. This enum represents the deployment strategy type of a model deployment, indicating the strategy used for deployment.", + description="Added in 25.16.0. 
This enum represents the deployment strategy type of a model deployment, indicating the strategy used for deployment.", ) -@strawberry.enum(description="Added in 25.13.0") -class DeploymentOrderField(StrEnum): - CREATED_AT = "CREATED_AT" - UPDATED_AT = "UPDATED_AT" - NAME = "NAME" - - -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class DeploymentStrategy: type: DeploymentStrategyType -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ReplicaState: - desired_replica_count: int _replica_ids: strawberry.Private[list[UUID]] + desired_replica_count: int @strawberry.field async def replicas( @@ -97,50 +131,195 @@ async def replicas( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelReplicaConnection: - return ModelReplicaConnection( - count=2, - edges=[ - ModelReplicaEdge(node=mock_model_replica_1, cursor="replica-cursor-1"), - ModelReplicaEdge(node=mock_model_replica_2, cursor="replica-cursor-2"), - ], + final_filter = ReplicaFilter(ids_in=self._replica_ids) + if filter: + final_filter = ReplicaFilter(AND=[final_filter, filter]) + + return await resolve_replicas( + info=info, + filter=final_filter, + order_by=order_by, + before=before, + after=after, + first=first, + last=last, + limit=limit, + offset=offset, ) -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ScalingRule: - auto_scaling_rules: list[AutoScalingRule] + _scaling_rule_ids: strawberry.Private[list[UUID]] + @strawberry.field + async def auto_scaling_rules(self, info: Info[StrawberryGQLContext]) -> list[AutoScalingRule]: + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + result = await processor.batch_load_auto_scaling_rules.wait_for_complete( + BatchLoadAutoScalingRulesAction(auto_scaling_rule_ids=self._scaling_rule_ids) + ) -@strawberry.type(description="Added in 25.13.0") + return [AutoScalingRule.from_dataclass(rule) for rule in result.data] + + +@strawberry.type(description="Added in 25.16.0") class ModelDeploymentMetadata: + _project_id: strawberry.Private[UUID] + _domain_name: strawberry.Private[str] name: str status: DeploymentStatus tags: list[str] - project: Project - domain: Domain created_at: datetime updated_at: datetime + @strawberry.field + async def project(self, info: Info[StrawberryGQLContext]) -> Project: + project_global_id = to_global_id( + GroupNode, self._project_id, is_target_graphene_object=True + ) + return Project(id=ID(project_global_id)) + + @strawberry.field + async def domain(self, info: Info[StrawberryGQLContext]) -> Domain: + domain_global_id = to_global_id( + DomainNode, self._domain_name, is_target_graphene_object=True + ) + return Domain(id=ID(domain_global_id)) + + @classmethod + def from_dataclass(cls, data: ModelDeploymentMetadataInfo) -> "ModelDeploymentMetadata": + return cls( + name=data.name, + status=DeploymentStatus(data.status), + tags=data.tags, + _project_id=data.project_id, + _domain_name=data.domain_name, + created_at=data.created_at, + updated_at=data.updated_at, + ) + + +def _convert_gql_revision_ordering_to_repo_ordering( + order_by: Optional[list[AccessTokenOrderBy]], +) -> AccessTokenOrderingOptions: + if order_by is None or len(order_by) == 0: + return AccessTokenOrderingOptions() + + repo_ordering = [] + for order in order_by: + desc = order.direction == OrderDirection.DESC + repo_ordering.append((order.field, desc)) -@strawberry.type(description="Added in 25.13.0") + return AccessTokenOrderingOptions(order_by=repo_ordering) + + +@strawberry.type(description="Added in 25.16.0") class ModelDeploymentNetworkAccess: + _access_token_ids: 
strawberry.Private[Optional[list[UUID]]] endpoint_url: Optional[str] = None preferred_domain_name: Optional[str] = None open_to_public: bool = False - access_tokens: AccessTokenConnection + + @strawberry.field + async def access_tokens( + self, + info: Info[StrawberryGQLContext], + order_by: Optional[list[AccessTokenOrderBy]] = None, + before: Optional[str] = None, + after: Optional[str] = None, + first: Optional[int] = None, + last: Optional[int] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + ) -> AccessTokenConnection: + """Resolve access tokens using dataloader.""" + repo_ordering = _convert_gql_revision_ordering_to_repo_ordering(order_by) + + pagination_options = build_pagination_options( + before=before, + after=after, + first=first, + last=last, + limit=limit, + offset=offset, + ) + + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + action_result = await processor.list_access_tokens.wait_for_complete( + ListAccessTokensAction( + pagination=pagination_options, + ordering=repo_ordering, + ) + ) + edges = [] + tokens = action_result.data + total_count = action_result.total_count + + for token in tokens: + edges.append( + AccessTokenEdge( + node=AccessToken.from_dataclass(token), + cursor=to_global_id(AccessToken, token.id), + ) + ) + + page_info = build_page_info(edges, total_count, pagination_options) + + return AccessTokenConnection( + count=total_count, edges=edges, page_info=page_info.to_strawberry_page_info() + ) + + @classmethod + def from_dataclass(cls, data: DeploymentNetworkSpec) -> "ModelDeploymentNetworkAccess": + return cls( + _access_token_ids=data.access_token_ids, + endpoint_url=data.url, + preferred_domain_name=data.preferred_domain_name, + open_to_public=data.open_to_public, + ) # Main ModelDeployment Type -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ModelDeployment(Node): id: NodeID metadata: ModelDeploymentMetadata network_access: ModelDeploymentNetworkAccess revision: Optional[ModelRevision] = None - scaling_rule: ScalingRule - replica_state: ReplicaState default_deployment_strategy: DeploymentStrategy - created_user: User + _revision_history_ids: strawberry.Private[list[UUID]] + _replica_state_data: strawberry.Private[ReplicaStateData] + _created_user_id: strawberry.Private[UUID] + _scaling_rule_ids: strawberry.Private[list[UUID]] + + @strawberry.field + async def created_user(self, info: Info[StrawberryGQLContext]) -> User: + user_global_id = to_global_id( + UserNode, self._created_user_id, is_target_graphene_object=True + ) + return User(id=strawberry.ID(user_global_id)) + + @strawberry.field + async def scaling_rule(self, info: Info[StrawberryGQLContext]) -> ScalingRule: + return ScalingRule( + _scaling_rule_ids=self._scaling_rule_ids, + ) + + @strawberry.field + async def replica_state(self, info: 
Info[StrawberryGQLContext]) -> ReplicaState: + return ReplicaState( + desired_replica_count=self._replica_state_data.desired_replica_count, + _replica_ids=self._replica_state_data.replica_ids, + ) @strawberry.field async def revision_history( @@ -155,23 +334,85 @@ async def revision_history( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelRevisionConnection: - return ModelRevisionConnection( - count=2, - edges=[ - ModelRevisionEdge(node=mock_model_revision_1, cursor="rev-cursor-1"), - ModelRevisionEdge(node=mock_model_revision_2, cursor="rev-cursor-2"), - ], + final_filter = ModelRevisionFilter(ids_in=self._revision_history_ids) + if filter: + final_filter = ModelRevisionFilter(AND=[final_filter, filter]) + + return await resolve_revisions( + info=info, + filter=final_filter, + order_by=order_by, + before=before, + after=after, + first=first, + last=last, + limit=limit, + offset=offset, + ) + + @classmethod + async def batch_load_by_ids( + cls, ctx: StrawberryGQLContext, deployment_ids: Sequence[UUID] + ) -> list["ModelDeployment"]: + """Batch load deployments by their IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + result = await processor.batch_load_deployments.wait_for_complete( + BatchLoadDeploymentsAction(deployment_ids=list(deployment_ids)) + ) + + deployment_map = {deployment.id: deployment for deployment in result.data} + model_deployments = [] + + for deployment_id in deployment_ids: + if deployment_id not in deployment_map: + raise ModelDeploymentNotFound(f"Deployment with ID {deployment_id} not found") + model_deployments.append(cls.from_dataclass(deployment_map[deployment_id])) + + return model_deployments + + @classmethod + def from_dataclass( + cls, + data: ModelDeploymentData, + ) -> "ModelDeployment": + metadata = ModelDeploymentMetadata( + name=data.metadata.name, + status=DeploymentStatus(data.metadata.status), + tags=data.metadata.tags, + _project_id=data.metadata.project_id, + _domain_name=data.metadata.domain_name, + created_at=data.metadata.created_at, + updated_at=data.metadata.updated_at, + ) + + return cls( + id=ID(str(data.id)), + metadata=metadata, + network_access=ModelDeploymentNetworkAccess.from_dataclass(data.network_access), + revision=ModelRevision.from_dataclass(data.revision) if data.revision else None, + default_deployment_strategy=DeploymentStrategy( + type=DeploymentStrategyType(data.default_deployment_strategy) + ), + _created_user_id=data.created_user_id, + _revision_history_ids=data.revision_history_ids, + _scaling_rule_ids=data.scaling_rule_ids, + _replica_state_data=data.replica_state, ) # Filter Types -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class DeploymentStatusFilter: in_: Optional[list[DeploymentStatus]] = strawberry.field(name="in", default=None) equals: Optional[DeploymentStatus] = None -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class DeploymentFilter: name: Optional[StringFilter] = None status: Optional[DeploymentStatusFilter] = None @@ -184,36 +425,66 @@ class DeploymentFilter: OR: 
Optional[list["DeploymentFilter"]] = None NOT: Optional[list["DeploymentFilter"]] = None - -@strawberry.input(description="Added in 25.13.0") + def to_repo_filter(self) -> DeploymentFilterOptions: + repo_filter = DeploymentFilterOptions() + + repo_filter.name = self.name + repo_filter.open_to_public = self.open_to_public + repo_filter.tags = self.tags + repo_filter.endpoint_url = self.endpoint_url + repo_filter.id = UUID(self.id) if self.id else None + if self.status: + if self.status.in_ is not None: + repo_filter.status = RepoDeploymentStatusFilter( + type=DeploymentStatusFilterType.IN, + values=[CommonDeploymentStatus(status) for status in self.status.in_], + ) + elif self.status.equals is not None: + repo_filter.status = RepoDeploymentStatusFilter( + type=DeploymentStatusFilterType.EQUALS, + values=[CommonDeploymentStatus(self.status.equals)], + ) + + # Handle logical operations + if self.AND: + repo_filter.AND = [f.to_repo_filter() for f in self.AND] + if self.OR: + repo_filter.OR = [f.to_repo_filter() for f in self.OR] + if self.NOT: + repo_filter.NOT = [f.to_repo_filter() for f in self.NOT] + + return repo_filter + + +@strawberry.input(description="Added in 25.16.0") class DeploymentOrderBy: field: DeploymentOrderField direction: OrderDirection = OrderDirection.DESC # Payload Types -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class CreateModelDeploymentPayload: deployment: ModelDeployment -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class UpdateModelDeploymentPayload: deployment: ModelDeployment -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class DeleteModelDeploymentPayload: id: ID -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class DeploymentStatusChangedPayload: deployment: ModelDeployment # Input Types -@strawberry.input(description="Added 
in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ModelDeploymentMetadataInput: project_id: ID domain_name: str @@ -221,18 +492,24 @@ class ModelDeploymentMetadataInput: tags: Optional[list[str]] = None -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ModelDeploymentNetworkAccessInput: preferred_domain_name: Optional[str] = None open_to_public: bool = False + def to_network_spec(self) -> DeploymentNetworkSpec: + return DeploymentNetworkSpec( + open_to_public=self.open_to_public, + preferred_domain_name=self.preferred_domain_name, + ) + -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class DeploymentStrategyInput: type: DeploymentStrategyType -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class CreateModelDeploymentInput: metadata: ModelDeploymentMetadataInput network_access: ModelDeploymentNetworkAccessInput @@ -240,8 +517,31 @@ class CreateModelDeploymentInput: desired_replica_count: int initial_revision: CreateModelRevisionInput + def to_creator(self) -> NewDeploymentCreator: + name = self.metadata.name or f"deployment-{uuid4().hex[:8]}" + tag = ",".join(self.metadata.tags) if self.metadata.tags else None + user_data = current_user() + if user_data is None: + raise UserNotFound("User not found in context") + metadata_for_creator = DeploymentMetadata( + name=name, + domain=self.metadata.domain_name, + project=UUID(str(self.metadata.project_id)), + resource_group=self.initial_revision.resource_config.resource_group.name, + created_user=user_data.user_id, + session_owner=user_data.user_id, + created_at=None, + tag=tag, + ) + return NewDeploymentCreator( + metadata=metadata_for_creator, + replica_spec=ReplicaSpec(replica_count=self.desired_replica_count), + network=self.network_access.to_network_spec(), + model_revision=self.initial_revision.to_model_revision_creator(), + ) + 
-@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class UpdateModelDeploymentInput: id: ID open_to_public: Optional[bool] = None @@ -252,148 +552,33 @@ class UpdateModelDeploymentInput: name: Optional[str] = None preferred_domain_name: Optional[str] = None - -@strawberry.input(description="Added in 25.13.0") -class DeleteModelDeploymentInput: - id: ID - - -# TODO: After implementing the actual logic, remove these mock objects -# Mock Model Deployments -mock_model_deployment_id_1 = "8c3105c3-3a02-42e3-aa00-6923cdcd114c" -mock_created_user_id_1 = "9a41b189-72fa-4265-afe8-04172ec5d26b" -mock_model_deployment_1 = ModelDeployment( - id=UUID(mock_model_deployment_id_1), - metadata=ModelDeploymentMetadata( - name="Llama 3.8B Instruct", - status=DeploymentStatus.READY, - tags=["production", "llm", "chat", "instruct"], - created_at=datetime.now() - timedelta(days=30), - updated_at=datetime.now() - timedelta(hours=2), - project=mock_project, - domain=mock_domain, - ), - network_access=ModelDeploymentNetworkAccess( - endpoint_url="https://api.backend.ai/models/dep-001", - preferred_domain_name="llama-3-8b.models.backend.ai", - open_to_public=True, - access_tokens=AccessTokenConnection( - count=5, - edges=[ - AccessTokenEdge(node=mock_access_token_1, cursor="token-cursor-1"), - AccessTokenEdge(node=mock_access_token_2, cursor="token-cursor-2"), - AccessTokenEdge(node=mock_access_token_3, cursor="token-cursor-3"), - AccessTokenEdge(node=mock_access_token_4, cursor="token-cursor-4"), - AccessTokenEdge(node=mock_access_token_5, cursor="token-cursor-5"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="token-cursor-1", - end_cursor="token-cursor-5", + def to_modifier(self) -> NewDeploymentModifier: + strategy_type = None + if self.default_deployment_strategy is not None: + strategy_type = CommonDeploymentStrategy(self.default_deployment_strategy.type) + return NewDeploymentModifier( + 
open_to_public=OptionalState[bool].from_graphql(self.open_to_public), + tags=OptionalState[list[str]].from_graphql(self.tags), + default_deployment_strategy=OptionalState[CommonDeploymentStrategy].from_graphql( + strategy_type ), - ), - ), - revision=mock_model_revision_1, - scaling_rule=ScalingRule(auto_scaling_rules=[mock_scaling_rule_1, mock_scaling_rule_2]), - replica_state=ReplicaState( - desired_replica_count=3, - _replica_ids=[mock_model_replica_1.id, mock_model_replica_2.id, mock_model_replica_3.id], - ), - default_deployment_strategy=DeploymentStrategy(type=DeploymentStrategyType.ROLLING), - created_user=User(id=mock_user_id), -) + active_revision_id=OptionalState[UUID].from_graphql(UUID(self.active_revision_id)), + desired_replica_count=OptionalState[int].from_graphql(self.desired_replica_count), + name=OptionalState[str].from_graphql(self.name), + preferred_domain_name=TriState[str].from_graphql(self.preferred_domain_name), + ) -mock_model_deployment_id_2 = "5f839a95-17bd-43b0-a029-a132aa60ae71" -mock_created_user_id_2 = "75994553-fa63-4464-9398-67b6b96c8d11" -mock_model_deployment_2 = ModelDeployment( - id=UUID(mock_model_deployment_id_2), - metadata=ModelDeploymentMetadata( - name="Mistral 7B v0.3", - status=DeploymentStatus.READY, - tags=["staging", "llm", "experimental"], - created_at=datetime.now() - timedelta(days=20), - updated_at=datetime.now() - timedelta(days=1), - project=mock_project, - domain=mock_domain, - ), - network_access=ModelDeploymentNetworkAccess( - endpoint_url="https://api.backend.ai/models/dep-002", - preferred_domain_name="mistral-7b.models.backend.ai", - open_to_public=False, - access_tokens=AccessTokenConnection( - count=2, - edges=[ - AccessTokenEdge(node=mock_access_token_1, cursor="token-cursor-1"), - AccessTokenEdge(node=mock_access_token_2, cursor="token-cursor-2"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="token-cursor-1", - end_cursor="token-cursor-5", - ), - ), - ), - 
revision=mock_model_revision_3, - scaling_rule=ScalingRule(auto_scaling_rules=[]), - replica_state=ReplicaState( - desired_replica_count=1, - _replica_ids=[mock_model_replica_3.id], - ), - default_deployment_strategy=DeploymentStrategy(type=DeploymentStrategyType.BLUE_GREEN), - created_user=User(id=mock_user_id), -) -mock_model_deployment_id_3 = "d040c413-a5df-4292-a5f4-0e0d85f7a1d4" -mock_created_user_id_3 = "640b0af8-8140-4e58-8ca4-96daba325be8" -mock_model_deployment_3 = ModelDeployment( - id=UUID(mock_model_deployment_id_3), - metadata=ModelDeploymentMetadata( - name="Gemma 2.9B", - status=DeploymentStatus.STOPPED, - project=mock_project, - domain=mock_domain, - tags=["development", "llm", "testing"], - created_at=datetime.now() - timedelta(days=15), - updated_at=datetime.now() - timedelta(days=7), - ), - network_access=ModelDeploymentNetworkAccess( - endpoint_url=None, - preferred_domain_name=None, - open_to_public=False, - access_tokens=AccessTokenConnection( - count=4, - edges=[ - AccessTokenEdge(node=mock_access_token_1, cursor="token-cursor-1"), - AccessTokenEdge(node=mock_access_token_2, cursor="token-cursor-2"), - AccessTokenEdge(node=mock_access_token_3, cursor="token-cursor-3"), - AccessTokenEdge(node=mock_access_token_4, cursor="token-cursor-4"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="token-cursor-1", - end_cursor="token-cursor-4", - ), - ), - ), - revision=None, - scaling_rule=ScalingRule(auto_scaling_rules=[]), - replica_state=ReplicaState( - desired_replica_count=0, - _replica_ids=[], - ), - default_deployment_strategy=DeploymentStrategy(type=DeploymentStrategyType.BLUE_GREEN), - created_user=User(id=mock_user_id), -) +@strawberry.input(description="Added in 25.16.0") +class DeleteModelDeploymentInput: + id: ID ModelDeploymentEdge = Edge[ModelDeployment] # Connection types for Relay support -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class 
ModelDeploymentConnection(Connection[ModelDeployment]): count: int @@ -402,6 +587,19 @@ def __init__(self, *args, count: int, **kwargs): self.count = count +def _convert_gql_deployment_ordering_to_repo( + order_by: Optional[list[DeploymentOrderBy]], +) -> DeploymentOrderingOptions: + if order_by is None or len(order_by) == 0: + return DeploymentOrderingOptions() + + repo_ordering = [] + for order in order_by: + desc = order.direction == OrderDirection.DESC + repo_ordering.append((order.field, desc)) + return DeploymentOrderingOptions(order_by=repo_ordering) + + async def resolve_deployments( info: Info[StrawberryGQLContext], filter: Optional[DeploymentFilter] = None, @@ -413,24 +611,54 @@ async def resolve_deployments( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelDeploymentConnection: - return ModelDeploymentConnection( - count=3, - edges=[ - ModelDeploymentEdge(node=mock_model_deployment_1, cursor="deployment-cursor-1"), - ModelDeploymentEdge(node=mock_model_deployment_2, cursor="deployment-cursor-2"), - ModelDeploymentEdge(node=mock_model_deployment_3, cursor="deployment-cursor-3"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="deployment-cursor-1", - end_cursor="deployment-cursor-3", - ), + repo_filter = None + if filter: + repo_filter = filter.to_repo_filter() + + repo_ordering = _convert_gql_deployment_ordering_to_repo(order_by) + + pagination_options = build_pagination_options( + before=before, + after=after, + first=first, + last=last, + limit=limit, + offset=offset, + ) + + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + action_result = await processor.list_deployments.wait_for_complete( + ListDeploymentsAction( + pagination=pagination_options, ordering=repo_ordering, filters=repo_filter + ) + ) + edges = [] + for deployment in action_result.data: + edges.append( + ModelDeploymentEdge( + node=ModelDeployment.from_dataclass(deployment), cursor=str(deployment.id) + ) + ) + page_info = build_page_info( + edges=edges, + total_count=action_result.total_count, + pagination_options=pagination_options, ) + connection = ModelDeploymentConnection( + count=action_result.total_count, + edges=edges, + page_info=page_info.to_strawberry_page_info(), + ) + return connection + # Resolvers -@strawberry.field(description="Added in 25.13.0") +@strawberry.field(description="Added in 25.16.0") async def deployments( info: Info[StrawberryGQLContext], filter: Optional[DeploymentFilter] = None, @@ -457,63 +685,107 @@ async def deployments( ) -@strawberry.field(description="Added in 25.13.0") -async def deployment(id: ID) -> Optional[ModelDeployment]: +@strawberry.field(description="Added in 25.16.0") +async def deployment(id: ID, info: Info[StrawberryGQLContext]) -> Optional[ModelDeployment]: """Get a specific deployment by ID.""" - return mock_model_deployment_1 + _, deployment_id = resolve_global_id(id) + deployment_dataloader = info.context.dataloader_registry.get_loader( + ModelDeployment.batch_load_by_ids, info.context + ) + deployment: list[ModelDeployment] = await deployment_dataloader.load(deployment_id) + return deployment[0] -@strawberry.mutation(description="Added in 25.13.0") + +@strawberry.mutation(description="Added in 25.16.0") async def create_model_deployment( input: CreateModelDeploymentInput, info: Info[StrawberryGQLContext] -) -> CreateModelDeploymentPayload: +) -> "CreateModelDeploymentPayload": """Create a new model deployment.""" - # Create a dummy deployment for placeholder - return CreateModelDeploymentPayload(deployment=mock_model_deployment_1) + processor = 
info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." + ) + + result = await processor.create_deployment.wait_for_complete( + CreateDeploymentAction(creator=input.to_creator()) + ) -@strawberry.mutation(description="Added in 25.13.0") + return CreateModelDeploymentPayload(deployment=ModelDeployment.from_dataclass(result.data)) + + +@strawberry.mutation(description="Added in 25.16.0") async def update_model_deployment( input: UpdateModelDeploymentInput, info: Info[StrawberryGQLContext] ) -> UpdateModelDeploymentPayload: """Update an existing model deployment.""" - # Create a dummy deployment for placeholder - return UpdateModelDeploymentPayload(deployment=mock_model_deployment_1) + _, deployment_id = resolve_global_id(input.id) + deployment_processor = info.context.processors.deployment + if deployment_processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." + ) + action_result = await deployment_processor.update_deployment.wait_for_complete( + UpdateDeploymentAction(deployment_id=UUID(deployment_id), modifier=input.to_modifier()) + ) + return UpdateModelDeploymentPayload( + deployment=ModelDeployment.from_dataclass(action_result.data) + ) -@strawberry.mutation(description="Added in 25.13.0") +@strawberry.mutation(description="Added in 25.16.0") async def delete_model_deployment( input: DeleteModelDeploymentInput, info: Info[StrawberryGQLContext] ) -> DeleteModelDeploymentPayload: """Delete a model deployment.""" - return DeleteModelDeploymentPayload(id=ID(str(uuid4()))) + _, deployment_id = resolve_global_id(input.id) + deployment_processor = info.context.processors.deployment + if deployment_processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + _ = await deployment_processor.destroy_deployment.wait_for_complete( + DestroyDeploymentAction(endpoint_id=UUID(deployment_id)) + ) + return DeleteModelDeploymentPayload(id=input.id) -@strawberry.subscription(description="Added in 25.13.0") +@strawberry.subscription(description="Added in 25.16.0") async def deployment_status_changed( deployment_id: ID, info: Info[StrawberryGQLContext] ) -> AsyncGenerator[DeploymentStatusChangedPayload, None]: """Subscribe to deployment status changes.""" - deployment = [mock_model_deployment_1, mock_model_deployment_2, mock_model_deployment_3] + # Mock implementation + # In real implementation, this would yield artifacts when status changes + if False: # Placeholder to make this a generator + yield DeploymentStatusChangedPayload(deployment_id=deployment_id) - for dep in deployment: - yield DeploymentStatusChangedPayload(deployment=dep) - -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class SyncReplicaInput: model_deployment_id: ID -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class SyncReplicaPayload: success: bool @strawberry.mutation( - description="Added in 25.13.0. Force syncs up-to-date replica information. In normal situations this will be automatically handled by Backend.AI schedulers" + description="Added in 25.16.0. Force syncs up-to-date replica information. In normal situations this will be automatically handled by Backend.AI schedulers" ) async def sync_replicas( input: SyncReplicaInput, info: Info[StrawberryGQLContext] ) -> SyncReplicaPayload: + _, deployment_id = resolve_global_id(input.model_deployment_id) + deployment_processor = info.context.processors.deployment + if deployment_processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + await deployment_processor.sync_replicas.wait_for_complete( + SyncReplicaAction(deployment_id=UUID(deployment_id)) + ) return SyncReplicaPayload(success=True) diff --git a/src/ai/backend/manager/api/gql/model_deployment/model_replica.py b/src/ai/backend/manager/api/gql/model_deployment/model_replica.py index f68a6263cee..71362e57f5a 100644 --- a/src/ai/backend/manager/api/gql/model_deployment/model_replica.py +++ b/src/ai/backend/manager/api/gql/model_deployment/model_replica.py @@ -1,7 +1,7 @@ -from datetime import datetime, timedelta -from enum import StrEnum -from typing import AsyncGenerator, Optional, cast -from uuid import UUID, uuid4 +from collections.abc import Sequence +from datetime import datetime +from typing import AsyncGenerator, Optional +from uuid import UUID import strawberry from strawberry import ID, Info @@ -10,89 +10,158 @@ from ai.backend.common.data.model_deployment.types import ActivenessStatus as CommonActivenessStatus from ai.backend.common.data.model_deployment.types import LivenessStatus as CommonLivenessStatus from ai.backend.common.data.model_deployment.types import ReadinessStatus as CommonReadinessStatus -from ai.backend.manager.api.gql.base import JSONString, OrderDirection +from ai.backend.common.exception import ModelDeploymentUnavailable +from ai.backend.manager.api.gql.base import ( + JSONString, + OrderDirection, + build_page_info, + build_pagination_options, + resolve_global_id, + to_global_id, +) from ai.backend.manager.api.gql.session import Session from ai.backend.manager.api.gql.types import StrawberryGQLContext -from ai.backend.manager.models.gql_relay import AsyncNode +from ai.backend.manager.data.deployment.types import ModelReplicaData, ReplicaOrderField +from ai.backend.manager.models.gql_models.session import ComputeSessionNode +from ai.backend.manager.repositories.deployment.types.types import ( + ActivenessStatusFilter as RepoActivenessStatus, +) +from ai.backend.manager.repositories.deployment.types.types 
import ( + ActivenessStatusFilterType, + LivenessStatusFilterType, + ModelReplicaFilterOptions, + ModelReplicaOrderingOptions, + ReadinessStatusFilterType, +) +from ai.backend.manager.repositories.deployment.types.types import ( + LivenessStatusFilter as RepoLivenessStatusFilter, +) +from ai.backend.manager.repositories.deployment.types.types import ( + ReadinessStatusFilter as RepoReadinessStatusFilter, +) +from ai.backend.manager.services.deployment.actions.batch_load_replicas_by_revision_ids import ( + BatchLoadReplicasByRevisionIdsAction, +) +from ai.backend.manager.services.deployment.actions.list_replicas import ListReplicasAction +from ai.backend.manager.types import PaginationOptions from .model_revision import ( ModelRevision, - mock_model_revision_1, ) ReadinessStatus = strawberry.enum( CommonReadinessStatus, name="ReadinessStatus", - description="Added in 25.13.0. This enum represents the readiness status of a replica, indicating whether the deployment has been checked and its health state.", + description="Added in 25.16.0. This enum represents the readiness status of a replica, indicating whether the deployment has been checked and its health state.", ) LivenessStatus = strawberry.enum( CommonLivenessStatus, name="LivenessStatus", - description="Added in 25.13.0. This enum represents the liveness status of a replica, indicating whether the deployment is currently running and able to serve requests.", + description="Added in 25.16.0. This enum represents the liveness status of a replica, indicating whether the deployment is currently running and able to serve requests.", ) ActivenessStatus = strawberry.enum( CommonActivenessStatus, name="ActivenessStatus", - description="Added in 25.13.0. This enum represents the activeness status of a replica, indicating whether the deployment is currently active and able to serve requests.", + description="Added in 25.16.0. 
This enum represents the activeness status of a replica, indicating whether the deployment is currently active and able to serve requests.", ) -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ReadinessStatusFilter: in_: Optional[list[ReadinessStatus]] = strawberry.field(name="in", default=None) equals: Optional[ReadinessStatus] = None -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class LivenessStatusFilter: in_: Optional[list[LivenessStatus]] = strawberry.field(name="in", default=None) equals: Optional[LivenessStatus] = None -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ActivenessStatusFilter: in_: Optional[list[ActivenessStatus]] = strawberry.field(name="in", default=None) equals: Optional[ActivenessStatus] = None -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ReplicaFilter: readiness_status: Optional[ReadinessStatusFilter] = None liveness_status: Optional[LivenessStatusFilter] = None activeness_status: Optional[ActivenessStatusFilter] = None id: Optional[ID] = None + ids_in: strawberry.Private[Optional[Sequence[UUID]]] = None AND: Optional[list["ReplicaFilter"]] = None OR: Optional[list["ReplicaFilter"]] = None NOT: Optional[list["ReplicaFilter"]] = None - -@strawberry.enum(description="Added in 25.13.0") -class ReplicaOrderField(StrEnum): - CREATED_AT = "CREATED_AT" - - -@strawberry.input(description="Added in 25.13.0") + def to_repo_filter(self) -> ModelReplicaFilterOptions: + repo_filter = ModelReplicaFilterOptions() + + if self.readiness_status: + if self.readiness_status.in_: + repo_filter.readiness_status_filter = RepoReadinessStatusFilter( + type=ReadinessStatusFilterType.IN, + values=[ReadinessStatus(status) for status in self.readiness_status.in_], + ) + elif self.readiness_status.equals: + 
repo_filter.readiness_status_filter = RepoReadinessStatusFilter( + type=ReadinessStatusFilterType.EQUALS, + values=[ReadinessStatus(self.readiness_status.equals)], + ) + if self.liveness_status: + if self.liveness_status.in_: + repo_filter.liveness_status_filter = RepoLivenessStatusFilter( + type=LivenessStatusFilterType.IN, + values=[LivenessStatus(status) for status in self.liveness_status.in_], + ) + elif self.liveness_status.equals: + repo_filter.liveness_status_filter = RepoLivenessStatusFilter( + type=LivenessStatusFilterType.EQUALS, + values=[LivenessStatus(self.liveness_status.equals)], + ) + if self.activeness_status: + if self.activeness_status.in_: + repo_filter.activeness_status_filter = RepoActivenessStatus( + type=ActivenessStatusFilterType.IN, + values=[ActivenessStatus(status) for status in self.activeness_status.in_], + ) + elif self.activeness_status.equals: + repo_filter.activeness_status_filter = RepoActivenessStatus( + type=ActivenessStatusFilterType.EQUALS, + values=[ActivenessStatus(self.activeness_status.equals)], + ) + + if self.id: + repo_filter.id = UUID(self.id) + if self.ids_in: + repo_filter.ids_in = list(self.ids_in) + + # Handle logical operations + if self.AND: + repo_filter.AND = [f.to_repo_filter() for f in self.AND] + if self.OR: + repo_filter.OR = [f.to_repo_filter() for f in self.OR] + if self.NOT: + repo_filter.NOT = [f.to_repo_filter() for f in self.NOT] + + return repo_filter + + +@strawberry.input(description="Added in 25.16.0") class ReplicaOrderBy: field: ReplicaOrderField direction: OrderDirection = OrderDirection.DESC -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ModelReplica(Node): id: NodeID - revision: ModelRevision _session_id: strawberry.Private[UUID] - - @strawberry.field( - description="The session ID associated with the replica. This can be null right after replica creation." 
- ) - async def session(self, info: Info[StrawberryGQLContext]) -> "Session": - session_global_id = AsyncNode.to_global_id("ComputeSessionNode", self._session_id) - return Session(id=ID(session_global_id)) - + _revision_id: strawberry.Private[UUID] readiness_status: ReadinessStatus = strawberry.field( description="This represents whether the replica has been checked and its health state.", ) @@ -111,11 +180,69 @@ async def session(self, info: Info[StrawberryGQLContext]) -> "Session": description='live statistics of the routing node. e.g. "live_stat": "{\\"cpu_util\\": {\\"current\\": \\"7.472\\", \\"capacity\\": \\"1000\\", \\"pct\\": \\"0.75\\", \\"unit_hint\\": \\"percent\\"}}"' ) + @strawberry.field( + description="The session ID associated with the replica. This can be null right after replica creation." + ) + async def session(self, info: Info[StrawberryGQLContext]) -> "Session": + session_global_id = to_global_id( + ComputeSessionNode, self._session_id, is_target_graphene_object=True + ) + return Session(id=ID(session_global_id)) + + @strawberry.field + async def revision(self, info: Info[StrawberryGQLContext]) -> ModelRevision: + """Resolve revision using dataloader.""" + revision_loader = info.context.dataloader_registry.get_loader( + ModelRevision.batch_load_by_ids, info.context + ) + revision: list[ModelRevision] = await revision_loader.load(self._revision_id) + return revision[0] + + @classmethod + def from_dataclass(cls, data: ModelReplicaData) -> "ModelReplica": + return cls( + id=ID(str(data.id)), + _revision_id=data.revision_id, + _session_id=data.session_id, + readiness_status=ReadinessStatus(data.readiness_status), + liveness_status=LivenessStatus(data.liveness_status), + activeness_status=ActivenessStatus(data.activeness_status), + weight=data.weight, + detail=JSONString.serialize(data.detail), + created_at=data.created_at, + live_stat=JSONString.serialize(data.live_stat), + ) + + @classmethod + async def batch_load_by_revision_ids( + cls, ctx: 
StrawberryGQLContext, revision_ids: Sequence[UUID] + ) -> list[list["ModelReplica"]]: + """Batch load replicas by their revision IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." + ) + + action_result = await processor.batch_load_replicas_by_revision_ids.wait_for_complete( + BatchLoadReplicasByRevisionIdsAction(revision_ids=list(revision_ids)) + ) + replicas_map = action_result.data + + result = [] + for revision_id in revision_ids: + replica_data_list = replicas_map.get(revision_id, []) + replica_objects = [ + cls.from_dataclass(replica_data) for replica_data in replica_data_list + ] + result.append(replica_objects) + return result + ModelReplicaEdge = Edge[ModelReplica] -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ModelReplicaConnection(Connection[ModelReplica]): count: int @@ -123,85 +250,49 @@ def __init__(self, *args, count: int, **kwargs): super().__init__(*args, **kwargs) self.count = count + @classmethod + def from_dataclass(cls, replicas_data: list[ModelReplicaData]) -> "ModelReplicaConnection": + nodes = [ModelReplica.from_dataclass(data) for data in replicas_data] + edges = [ModelReplicaEdge(node=node, cursor=str(node.id)) for node in nodes] -# Mock Model Replicas -mock_model_replica_1 = ModelReplica( - id=UUID("b62f9890-228a-40c9-a614-63387805b9a7"), - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.HEALTHY, - liveness_status=CommonLivenessStatus.HEALTHY, - activeness_status=CommonActivenessStatus.ACTIVE, - weight=1, - detail=cast( - JSONString, - '{"type": "creation_success", "message": "Model replica created successfully", "status": "operational"}', - ), - created_at=datetime.now() - timedelta(days=5), - live_stat=cast( - JSONString, - '{"requests": 1523, "latency_ms": 187, "tokens_per_second": 42.5}', - ), -) - - 
-mock_model_replica_2 = ModelReplica( - id=UUID("7562e9d4-a368-4e28-9092-65eb91534bac"), - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.HEALTHY, - liveness_status=CommonLivenessStatus.HEALTHY, - activeness_status=CommonActivenessStatus.ACTIVE, - weight=2, - detail=cast( - JSONString, - '{"type": "creation_success", "message": "Model replica created successfully", "status": "operational"}', - ), - created_at=datetime.now() - timedelta(days=5), - live_stat=cast( - JSONString, - '{"requests": 1456, "latency_ms": 195, "tokens_per_second": 41.2}', - ), -) + page_info = PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ) -mock_model_replica_3 = ModelReplica( - id=UUID("2a2388ea-a312-422a-b77e-0e0b61c48145"), - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.UNHEALTHY, - liveness_status=CommonLivenessStatus.HEALTHY, - activeness_status=CommonActivenessStatus.INACTIVE, - weight=0, - detail=cast( - JSONString, - '{"type": "creation_failed", "errors": [{"src": "", "name": "InvalidAPIParameters", "repr": ""}]}', - ), - created_at=datetime.now() - timedelta(days=2), - live_stat=cast(JSONString, '{"requests": 0, "latency_ms": 0, "tokens_per_second": 0}'), -) + return cls(count=len(nodes), edges=edges, page_info=page_info) -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ReplicaStatusChangedPayload: replica: ModelReplica -@strawberry.field(description="Added in 25.13.0") +@strawberry.field(description="Added in 25.16.0") async def replica(id: ID, info: Info[StrawberryGQLContext]) -> Optional[ModelReplica]: """Get a specific replica by ID.""" - - return ModelReplica( - id=id, - revision=mock_model_revision_1, - _session_id=uuid4(), - readiness_status=CommonReadinessStatus.NOT_CHECKED, - 
liveness_status=CommonLivenessStatus.HEALTHY, - activeness_status=CommonActivenessStatus.ACTIVE, - weight=1, - detail=cast(JSONString, "{}"), - created_at=datetime.now() - timedelta(days=2), - live_stat=cast(JSONString, '{"requests": 0, "latency_ms": 0, "tokens_per_second": 0}'), + _, replica_id = resolve_global_id(id) + replica_loader = info.context.dataloader_registry.get_loader( + ModelReplica.batch_load_by_revision_ids, info.context ) + replicas: list[ModelReplica] = await replica_loader.load(UUID(replica_id)) + return replicas[0] + + +def _convert_gql_replica_ordering_to_repo_ordering( + order_by: Optional[list[ReplicaOrderBy]], +) -> ModelReplicaOrderingOptions: + if not order_by: + return ModelReplicaOrderingOptions() + + repo_order_by = [] + for order in order_by: + desc = order.direction == OrderDirection.DESC + repo_order_by.append((order.field, desc)) + + return ModelReplicaOrderingOptions(order_by=repo_order_by) async def resolve_replicas( @@ -215,23 +306,52 @@ async def resolve_replicas( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelReplicaConnection: + repo_filter = None + if filter: + repo_filter = filter.to_repo_filter() + + repo_ordering = _convert_gql_replica_ordering_to_repo_ordering(order_by) + + pagination_options = build_pagination_options( + before=before, + after=after, + first=first, + last=last, + limit=limit, + offset=offset, + ) + + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + action_result = await processor.list_replicas.wait_for_complete( + ListReplicasAction( + pagination=PaginationOptions(), + ordering=repo_ordering, + filters=repo_filter, + ) + ) + edges = [] + for replica_data in action_result.data: + node = ModelReplica.from_dataclass(replica_data) + edge = ModelReplicaEdge(node=node, cursor=str(node.id)) + edges.append(edge) + + page_info = build_page_info( + edges=edges, total_count=action_result.total_count, pagination_options=pagination_options + ) + return ModelReplicaConnection( - count=3, - edges=[ - ModelReplicaEdge(node=mock_model_replica_1, cursor="replica-cursor-1"), - ModelReplicaEdge(node=mock_model_replica_2, cursor="replica-cursor-2"), - ModelReplicaEdge(node=mock_model_replica_3, cursor="replica-cursor-3"), - ], - page_info=PageInfo( - has_next_page=False, - has_previous_page=False, - start_cursor="replica-cursor-1", - end_cursor="replica-cursor-3", - ), + count=action_result.total_count, + edges=edges, + page_info=page_info.to_strawberry_page_info(), ) -@strawberry.field(description="Added in 25.13.0") +@strawberry.field(description="Added in 25.16.0") async def replicas( info: Info[StrawberryGQLContext], filter: Optional[ReplicaFilter] = None, @@ -256,12 +376,10 @@ async def replicas( ) -@strawberry.subscription(description="Added in 25.13.0") +@strawberry.subscription(description="Added in 25.16.0") async def replica_status_changed( revision_id: ID, ) -> AsyncGenerator[ReplicaStatusChangedPayload, None]: """Subscribe to replica status changes.""" - replicas = [mock_model_replica_1, mock_model_replica_2, mock_model_replica_3] - - for replica in replicas: + if False: # Replace with actual subscription logic yield ReplicaStatusChangedPayload(replica=replica) diff --git a/src/ai/backend/manager/api/gql/model_deployment/model_revision.py b/src/ai/backend/manager/api/gql/model_deployment/model_revision.py index 5de9d72feee..9def3798fc8 100644 --- 
a/src/ai/backend/manager/api/gql/model_deployment/model_revision.py +++ b/src/ai/backend/manager/api/gql/model_deployment/model_revision.py @@ -1,15 +1,28 @@ -from datetime import datetime, timedelta -from decimal import Decimal -from enum import Enum, StrEnum +from collections.abc import Mapping, Sequence +from datetime import datetime +from enum import StrEnum +from pathlib import PurePosixPath from typing import Any, Optional, cast -from uuid import UUID, uuid4 +from uuid import UUID import strawberry from strawberry import ID, Info from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo from strawberry.scalars import JSON -from ai.backend.manager.api.gql.base import JSONString, OrderDirection, StringFilter +from ai.backend.common.exception import ModelDeploymentUnavailable, ModelRevisionNotFound +from ai.backend.common.types import ClusterMode as CommonClusterMode +from ai.backend.common.types import MountPermission as CommonMountPermission +from ai.backend.common.types import RuntimeVariant +from ai.backend.manager.api.gql.base import ( + JSONString, + OrderDirection, + StringFilter, + build_page_info, + build_pagination_options, + resolve_global_id, + to_global_id, +) from ai.backend.manager.api.gql.image import ( Image, ) @@ -19,34 +32,83 @@ from ai.backend.manager.api.gql.types import StrawberryGQLContext from ai.backend.manager.api.gql.vfolder import ( ExtraVFolderMountConnection, - ExtraVFolderMountEdge, VFolder, - mock_extra_mount_1, - mock_extra_mount_2, - mock_vfolder_id, ) -from ai.backend.manager.data.model_deployment.inference_runtime_config import ( +from ai.backend.manager.data.deployment.creator import ModelRevisionCreator, VFolderMountsCreator +from ai.backend.manager.data.deployment.inference_runtime_config import ( MOJORuntimeConfig, NVDIANIMRuntimeConfig, SGLangRuntimeConfig, VLLMRuntimeConfig, ) +from ai.backend.manager.data.deployment.types import ( + ClusterConfigData, + ExecutionSpec, + ModelMountConfigData, + 
ModelRevisionData, + ModelRevisionOrderField, + ModelRuntimeConfigData, + MountInfo, + ResourceConfigData, + ResourceSpec, +) +from ai.backend.manager.data.image.types import ImageIdentifier +from ai.backend.manager.models.gql_models.image import ImageNode +from ai.backend.manager.models.gql_models.scaling_group import ScalingGroupNode +from ai.backend.manager.models.gql_models.vfolder import VirtualFolderNode +from ai.backend.manager.repositories.deployment.types.types import ( + ModelRevisionFilterOptions, + ModelRevisionOrderingOptions, +) +from ai.backend.manager.services.deployment.actions.model_revision.add_model_revision import ( + AddModelRevisionAction, +) +from ai.backend.manager.services.deployment.actions.model_revision.batch_load_revisions import ( + BatchLoadRevisionsAction, +) +from ai.backend.manager.services.deployment.actions.model_revision.create_model_revision import ( + CreateModelRevisionAction, +) +from ai.backend.manager.services.deployment.actions.model_revision.list_revisions import ( + ListRevisionsAction, +) +MountPermission = strawberry.enum( + CommonMountPermission, + name="MountPermission", + description="Added in 25.16.0. This enum represents the permission level for a mounted volume. 
It can be ro, rw, wd", +) -@strawberry.enum(description="Added in 25.13.0") + +@strawberry.enum(description="Added in 25.16.0") class ClusterMode(StrEnum): SINGLE_NODE = "SINGLE_NODE" MULTI_NODE = "MULTI_NODE" -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ModelMountConfig: - vfolder: VFolder + _vfolder_id: strawberry.Private[UUID] mount_destination: str definition_path: str + @strawberry.field + async def vfolder(self, info: Info[StrawberryGQLContext]) -> VFolder: + vfolder_global_id = to_global_id( + VirtualFolderNode, self._vfolder_id, is_target_graphene_object=True + ) + return VFolder(id=ID(vfolder_global_id)) + + @classmethod + def from_dataclass(cls, data: ModelMountConfigData) -> "ModelMountConfig": + return cls( + _vfolder_id=data.vfolder_id, + mount_destination=data.mount_destination, + definition_path=data.definition_path, + ) -@strawberry.type(description="Added in 25.13.0") + +@strawberry.type(description="Added in 25.16.0") class ModelRuntimeConfig: runtime_variant: str inference_runtime_config: Optional[JSON] = None @@ -55,10 +117,18 @@ class ModelRuntimeConfig: default=None, ) + @classmethod + def from_dataclass(cls, data: ModelRuntimeConfigData) -> "ModelRuntimeConfig": + return cls( + runtime_variant=data.runtime_variant, + inference_runtime_config=data.inference_runtime_config, + environ=JSONString.serialize(data.environ) if data.environ else None, + ) + -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ResourceConfig: - resource_group: ResourceGroup + _resource_group_name: strawberry.Private[str] resource_slots: JSONString = strawberry.field( description='Resource Slots are a JSON string that describes the resources allocated for the deployment. 
Example: "resourceSlots": "{\\"cpu\\": \\"1\\", \\"mem\\": \\"1073741824\\", \\"cuda.device\\": \\"0\\"}"' ) @@ -67,220 +137,154 @@ class ResourceConfig: default=None, ) + @strawberry.field + def resource_group(self) -> "ResourceGroup": + """Resolves the federated ResourceGroup.""" + global_id = to_global_id( + ScalingGroupNode, self._resource_group_name, is_target_graphene_object=True + ) + return ResourceGroup(id=ID(global_id)) + + @classmethod + def from_dataclass(cls, data: ResourceConfigData) -> "ResourceConfig": + return cls( + _resource_group_name=data.resource_group_name, + resource_slots=JSONString.from_resource_slot(data.resource_slot), + resource_opts=JSONString.serialize(data.resource_opts), + ) -@strawberry.type(description="Added in 25.13.0") + +@strawberry.type(description="Added in 25.16.0") class ClusterConfig: mode: ClusterMode size: int + @classmethod + def from_dataclass(cls, data: ClusterConfigData) -> "ClusterConfig": + return cls( + mode=ClusterMode(data.mode.name), + size=data.size, + ) + -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ModelRevision(Node): + _image_id: strawberry.Private[UUID] id: NodeID name: str - cluster_config: ClusterConfig resource_config: ResourceConfig - model_runtime_config: ModelRuntimeConfig model_mount_config: ModelMountConfig extra_mounts: ExtraVFolderMountConnection - - image: Image - created_at: datetime + @strawberry.field + async def image(self, info: Info[StrawberryGQLContext]) -> Image: + image_global_id = to_global_id(ImageNode, self._image_id, is_target_graphene_object=True) + return Image(id=ID(image_global_id)) + + @classmethod + def from_dataclass(cls, data: ModelRevisionData) -> "ModelRevision": + return cls( + id=ID(str(data.id)), + name=data.name, + cluster_config=ClusterConfig.from_dataclass(data.cluster_config), + resource_config=ResourceConfig.from_dataclass(data.resource_config), + 
model_runtime_config=ModelRuntimeConfig.from_dataclass(data.model_runtime_config), + model_mount_config=ModelMountConfig.from_dataclass(data.model_mount_config), + extra_mounts=ExtraVFolderMountConnection.from_dataclass(data.extra_vfolder_mounts), + _image_id=data.image_id, + created_at=data.created_at, + ) + + @classmethod + async def batch_load_by_ids( + cls, ctx: StrawberryGQLContext, revision_ids: Sequence[UUID] + ) -> list["ModelRevision"]: + """Batch load revisions by their IDs.""" + processor = ctx.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." + ) + + result = await processor.batch_load_revisions.wait_for_complete( + BatchLoadRevisionsAction(revision_ids=list(revision_ids)) + ) + + revision_map = {revision.id: revision for revision in result.data} + revisions = [] + for revision_id in revision_ids: + if revision_id not in revision_map: + raise ModelRevisionNotFound(f"Revision {revision_id} not found") + revisions.append(cls.from_dataclass(revision_map[revision_id])) + + return revisions + # Filter and Order Types -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ModelRevisionFilter: name: Optional[StringFilter] = None deployment_id: Optional[ID] = None id: Optional[ID] = None + ids_in: strawberry.Private[Optional[Sequence[UUID]]] = None AND: Optional[list["ModelRevisionFilter"]] = None OR: Optional[list["ModelRevisionFilter"]] = None NOT: Optional[list["ModelRevisionFilter"]] = None + def to_repo_filter(self) -> ModelRevisionFilterOptions: + repo_filter = ModelRevisionFilterOptions() -@strawberry.enum(description="Added in 25.13.0") -class ModelRevisionOrderField(Enum): - CREATED_AT = "CREATED_AT" - NAME = "NAME" - ID = "ID" + # Handle basic filters + repo_filter.name = self.name + repo_filter.deployment_id = UUID(self.deployment_id) if self.deployment_id else None + repo_filter.id = UUID(self.id) 
if self.id else None + repo_filter.ids_in = list(self.ids_in) if self.ids_in is not None else None + # Handle logical operations + if self.AND: + repo_filter.AND = [f.to_repo_filter() for f in self.AND] + if self.OR: + repo_filter.OR = [f.to_repo_filter() for f in self.OR] + if self.NOT: + repo_filter.NOT = [f.to_repo_filter() for f in self.NOT] -@strawberry.input(description="Added in 25.13.0") + return repo_filter + + +@strawberry.input(description="Added in 25.16.0") class ModelRevisionOrderBy: field: ModelRevisionOrderField direction: OrderDirection = OrderDirection.DESC -# TODO: After implementing the actual logic, remove these mock objects -# Mock Model Revisions - - -def _generate_random_name() -> str: - return f"revision-{uuid4()}" - - -mock_inference_runtime_config = { - "tp_size": 2, - "pp_size": 4, - "ep_enable": True, - "sp_size": 8, - "max_model_length": 4096, - "batch_size": 32, - "memory_util_percentage": Decimal("0.90"), - "kv_storage_dtype": "float16", - "trust_remote_code": True, - "tool_call_parser": "granite", - "reasoning_parser": "deepseek_r1", -} -mock_image_global_id = ID("SW1hZ2VOb2RlOjQwMWZjYjM4LTkwMWYtNDdjYS05YmJjLWQyMjUzYjk4YTZhMA==") -mock_revision_id_1 = "d19f8f78-f308-45a9-ab7b-1c63346024fd" -mock_model_revision_1 = ModelRevision( - id=UUID(mock_revision_id_1), - name="llama-3-8b-instruct-v1.0", - cluster_config=ClusterConfig(mode=ClusterMode.SINGLE_NODE, size=1), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant="custom", - inference_runtime_config=mock_inference_runtime_config, - environ=cast(JSONString, '{"CUDA_VISIBLE_DEVICES": "0"}'), - ), - model_mount_config=ModelMountConfig( - 
vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="models/llama-3-8b/config.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=2, - edges=[ - ExtraVFolderMountEdge(node=mock_extra_mount_1, cursor="extra-mount-cursor-1"), - ExtraVFolderMountEdge(node=mock_extra_mount_2, cursor="extra-mount-cursor-2"), - ], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now() - timedelta(days=10), -) - -mock_revision_id_2 = "3c81bc63-24c1-4a8f-9ad2-8a19899690c3" -mock_model_revision_2 = ModelRevision( - id=UUID(mock_revision_id_2), - name="llama-3-8b-instruct-v1.1", - cluster_config=ClusterConfig(mode=ClusterMode.SINGLE_NODE, size=1), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant="vllm", - inference_runtime_config=mock_inference_runtime_config, - environ=cast(JSONString, '{"CUDA_VISIBLE_DEVICES": "0,1"}'), - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="models/llama-3-8b/config.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=2, - edges=[ - ExtraVFolderMountEdge(node=mock_extra_mount_1, cursor="extra-mount-cursor-1"), - ExtraVFolderMountEdge(node=mock_extra_mount_2, cursor="extra-mount-cursor-2"), - ], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now() - timedelta(days=5), -) - - -mock_revision_id_3 = 
"86d1a714-b177-4851-897f-da36f306fe30" -mock_model_revision_3 = ModelRevision( - id=UUID(mock_revision_id_3), - name="mistral-7b-v0.3-initial", - cluster_config=ClusterConfig(mode=ClusterMode.SINGLE_NODE, size=1), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant="vllm", - inference_runtime_config=mock_inference_runtime_config, - environ=cast(JSONString, '{"CUDA_VISIBLE_DEVICES": "2"}'), - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="models/mistral-7b/config.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=0, - edges=[], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now() - timedelta(days=20), -) - - # Payload Types -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class CreateModelRevisionPayload: revision: ModelRevision -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class AddModelRevisionPayload: revision: ModelRevision # Input Types -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ClusterConfigInput: mode: ClusterMode size: int -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ResourceGroupInput: name: str -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ResourceConfigInput: resource_group: ResourceGroupInput 
resource_slots: JSONString = strawberry.field( @@ -292,13 +296,13 @@ class ResourceConfigInput: ) -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ImageInput: name: str architecture: str -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ModelRuntimeConfigInput: runtime_variant: str inference_runtime_config: Optional[JSON] = None @@ -308,20 +312,20 @@ class ModelRuntimeConfigInput: ) -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ModelMountConfigInput: vfolder_id: ID mount_destination: str definition_path: str -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class ExtraVFolderMountInput: vfolder_id: ID mount_destination: Optional[str] -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class CreateModelRevisionInput: name: Optional[str] = None cluster_config: ClusterConfigInput @@ -331,8 +335,57 @@ class CreateModelRevisionInput: model_mount_config: ModelMountConfigInput extra_mounts: Optional[list[ExtraVFolderMountInput]] + def to_model_revision_creator(self) -> ModelRevisionCreator: + image_identifier = ImageIdentifier( + canonical=self.image.name, + architecture=self.image.architecture, + ) + + resource_spec = ResourceSpec( + cluster_mode=CommonClusterMode(self.cluster_config.mode), + cluster_size=self.cluster_config.size, + resource_slots=cast(Mapping[str, Any], self.resource_config.resource_slots), + resource_opts=cast(Mapping[str, Any] | None, self.resource_config.resource_opts), + ) + + extra_mounts = [] + if self.extra_mounts is not None: + extra_mounts = [ + MountInfo( + vfolder_id=UUID(str(extra_mount.vfolder_id)), + kernel_path=PurePosixPath( + extra_mount.mount_destination + if extra_mount.mount_destination is not None + else "" + ), + ) + for extra_mount in self.extra_mounts + 
] + + mounts = VFolderMountsCreator( + model_vfolder_id=UUID(str(self.model_mount_config.vfolder_id)), + model_definition_path=self.model_mount_config.definition_path, + model_mount_destination=self.model_mount_config.mount_destination, + extra_mounts=extra_mounts, + ) + + execution_spec = ExecutionSpec( + environ=cast(Optional[dict[str, str]], self.model_runtime_config.environ), + runtime_variant=RuntimeVariant(self.model_runtime_config.runtime_variant), + inference_runtime_config=cast( + Optional[dict[str, Any]], self.model_runtime_config.inference_runtime_config + ), + ) + + return ModelRevisionCreator( + image_identifier=image_identifier, + resource_spec=resource_spec, + mounts=mounts, + execution=execution_spec, + ) + -@strawberry.input(description="Added in 25.13.0") +@strawberry.input(description="Added in 25.16.0") class AddModelRevisionInput: name: Optional[str] = None deployment_id: ID @@ -343,11 +396,60 @@ class AddModelRevisionInput: model_mount_config: ModelMountConfigInput extra_mounts: Optional[list[ExtraVFolderMountInput]] + def to_model_revision_creator(self) -> ModelRevisionCreator: + image_identifier = ImageIdentifier( + canonical=self.image.name, + architecture=self.image.architecture, + ) + + resource_spec = ResourceSpec( + cluster_mode=CommonClusterMode(self.cluster_config.mode), + cluster_size=self.cluster_config.size, + resource_slots=cast(Mapping[str, Any], self.resource_config.resource_slots), + resource_opts=cast(Mapping[str, Any] | None, self.resource_config.resource_opts), + ) + + extra_mounts = [] + if self.extra_mounts is not None: + extra_mounts = [ + MountInfo( + vfolder_id=UUID(str(extra_mount.vfolder_id)), + kernel_path=PurePosixPath( + extra_mount.mount_destination + if extra_mount.mount_destination is not None + else "" + ), + ) + for extra_mount in self.extra_mounts + ] + + mounts = VFolderMountsCreator( + model_vfolder_id=UUID(str(self.model_mount_config.vfolder_id)), + 
model_definition_path=self.model_mount_config.definition_path, + model_mount_destination=self.model_mount_config.mount_destination, + extra_mounts=extra_mounts, + ) + + execution_spec = ExecutionSpec( + environ=cast(Optional[dict[str, str]], self.model_runtime_config.environ), + runtime_variant=RuntimeVariant(self.model_runtime_config.runtime_variant), + inference_runtime_config=cast( + Optional[dict[str, Any]], self.model_runtime_config.inference_runtime_config + ), + ) + + return ModelRevisionCreator( + image_identifier=image_identifier, + resource_spec=resource_spec, + mounts=mounts, + execution=execution_spec, + ) + ModelRevisionEdge = Edge[ModelRevision] -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ModelRevisionConnection(Connection[ModelRevision]): count: int @@ -355,9 +457,23 @@ def __init__(self, *args, count: int, **kwargs: Any): super().__init__(*args, **kwargs) self.count = count + @classmethod + def from_dataclass(cls, revisions_data: list[ModelRevisionData]) -> "ModelRevisionConnection": + nodes = [ModelRevision.from_dataclass(data) for data in revisions_data] + edges = [ModelRevisionEdge(node=node, cursor=str(node.id)) for node in nodes] + + page_info = PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ) + + return cls(count=len(nodes), edges=edges, page_info=page_info) + @strawberry.field( - description="Added in 25.13.0. Get JSON Schema for inference runtime configuration" + description="Added in 25.16.0. 
Get JSON Schema for inference runtime configuration" ) async def inference_runtime_config(name: str) -> JSON: match name.lower(): @@ -376,7 +492,7 @@ async def inference_runtime_config(name: str) -> JSON: @strawberry.field( - description="Added in 25.13.0 Get configuration JSON Schemas for all inference runtimes" + description="Added in 25.16.0 Get configuration JSON Schemas for all inference runtimes" ) async def inference_runtime_configs(info: Info[StrawberryGQLContext]) -> JSON: all_configs = { @@ -389,6 +505,20 @@ async def inference_runtime_configs(info: Info[StrawberryGQLContext]) -> JSON: return all_configs +def _convert_gql_revision_ordering_to_repo_ordering( + order_by: Optional[list[ModelRevisionOrderBy]], +) -> ModelRevisionOrderingOptions: + if order_by is None or len(order_by) == 0: + return ModelRevisionOrderingOptions() + + repo_ordering = [] + for order in order_by: + desc = order.direction == OrderDirection.DESC + repo_ordering.append((order.field, desc)) + + return ModelRevisionOrderingOptions(order_by=repo_ordering) + + async def resolve_revisions( info: Info[StrawberryGQLContext], filter: Optional[ModelRevisionFilter] = None, @@ -400,21 +530,56 @@ async def resolve_revisions( limit: Optional[int] = None, offset: Optional[int] = None, ) -> ModelRevisionConnection: - # Implement the logic to resolve the revisions based on the provided filters and pagination - return ModelRevisionConnection( - count=3, - edges=[ - ModelRevisionEdge(node=mock_model_revision_1, cursor="revision-cursor-1"), - ModelRevisionEdge(node=mock_model_revision_2, cursor="revision-cursor-2"), - ModelRevisionEdge(node=mock_model_revision_3, cursor="revision-cursor-3"), - ], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), + repo_filter = None + if filter: + repo_filter = filter.to_repo_filter() + + repo_ordering = _convert_gql_revision_ordering_to_repo_ordering(order_by) + + pagination_options = 
build_pagination_options( + before=before, + after=after, + first=first, + last=last, + limit=limit, + offset=offset, ) + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." + ) + action_result = await processor.list_revisions.wait_for_complete( + ListRevisionsAction( + pagination=pagination_options, + ordering=repo_ordering, + filters=repo_filter, + ) + ) + + edges = [] + revisions = action_result.data + total_count = action_result.total_count + for revision in revisions: + edges.append( + ModelRevisionEdge( + node=ModelRevision.from_dataclass(revision), + cursor=to_global_id(ModelRevision, revision.id), + ) + ) + + page_info = build_page_info(edges, total_count, pagination_options) + + connection = ModelRevisionConnection( + count=total_count, + edges=edges, + page_info=page_info.to_strawberry_page_info(), + ) + return connection -@strawberry.field(description="Added in 25.13.0") + +@strawberry.field(description="Added in 25.16.0") async def revisions( info: Info[StrawberryGQLContext], filter: Optional[ModelRevisionFilter] = None, @@ -440,99 +605,53 @@ async def revisions( ) -@strawberry.field(description="Added in 25.13.0") +@strawberry.field(description="Added in 25.16.0") async def revision(id: ID, info: Info[StrawberryGQLContext]) -> ModelRevision: """Get a specific revision by ID.""" - return mock_model_revision_1 - - -@strawberry.mutation(description="Added in 25.13.0") -async def create_model_revision( - input: CreateModelRevisionInput, info: Info[StrawberryGQLContext] -) -> CreateModelRevisionPayload: - """Create a new model revision.""" - revision = ModelRevision( - id=UUID("4cc91efb-7297-47ec-80c4-6e9c4378ae8b"), - name=_generate_random_name(), - cluster_config=ClusterConfig( - mode=ClusterMode.SINGLE_NODE, - size=1, - ), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - 
resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - runtime_variant=input.model_runtime_config.runtime_variant, - inference_runtime_config=input.model_runtime_config.inference_runtime_config, - environ=None, - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="model.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=0, - edges=[], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now(), + _, revision_id = resolve_global_id(id) + revision_loader = info.context.dataloader_registry.get_loader( + ModelRevision.batch_load_by_ids, info.context ) - return CreateModelRevisionPayload(revision=revision) + revision: list[ModelRevision] = await revision_loader.load(revision_id) + return revision[0] -@strawberry.mutation(description="Added in 25.13.0") +@strawberry.mutation(description="Added in 25.16.0") async def add_model_revision( input: AddModelRevisionInput, info: Info[StrawberryGQLContext] ) -> AddModelRevisionPayload: """Add a model revision to a deployment.""" - revision = ModelRevision( - id=UUID("dda405f0-6463-45c4-a5ca-3721cc8d730c"), - name=_generate_random_name(), - cluster_config=ClusterConfig( - mode=ClusterMode.SINGLE_NODE, - size=1, - ), - resource_config=ResourceConfig( - resource_group=ResourceGroup(id=ID("U2NhbGluZ0dyb3VwTm9kZTpkZWZhdWx0")), - resource_slots=cast( - JSONString, - '{"cpu": 8, "mem": "32G", "cuda.shares": 1, "cuda.device": 1}', - ), - resource_opts=cast( - JSONString, - '{"shmem": "2G", "reserved_time": "24h", "scaling_group": "us-east-1"}', - ), - ), - model_runtime_config=ModelRuntimeConfig( - 
runtime_variant=input.model_runtime_config.runtime_variant, - inference_runtime_config=input.model_runtime_config.inference_runtime_config, - environ=None, - ), - model_mount_config=ModelMountConfig( - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/models", - definition_path="model.yaml", - ), - extra_mounts=ExtraVFolderMountConnection( - count=0, - edges=[], - page_info=PageInfo( - has_next_page=False, has_previous_page=False, start_cursor=None, end_cursor=None - ), - ), - image=Image(id=mock_image_global_id), - created_at=datetime.now(), + + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." + ) + + result = await processor.add_model_revision.wait_for_complete( + AddModelRevisionAction( + model_deployment_id=UUID(input.deployment_id), adder=input.to_model_revision_creator() + ) ) - return AddModelRevisionPayload(revision=revision) + + return AddModelRevisionPayload(revision=ModelRevision.from_dataclass(result.revision)) + + +@strawberry.mutation( + description="Added in 25.16.0. Create model revision which is not attached to any deployment." +) +async def create_model_revision( + input: CreateModelRevisionInput, info: Info[StrawberryGQLContext] +) -> CreateModelRevisionPayload: + """Create a new model revision without attaching it to any deployment.""" + processor = info.context.processors.deployment + if processor is None: + raise ModelDeploymentUnavailable( + "Model Deployment feature is unavailable. Please contact support." 
+ ) + + result = await processor.create_model_revision.wait_for_complete( + CreateModelRevisionAction(creator=input.to_model_revision_creator()) + ) + + return CreateModelRevisionPayload(revision=ModelRevision.from_dataclass(result.revision)) diff --git a/src/ai/backend/manager/api/gql/project.py b/src/ai/backend/manager/api/gql/project.py index f239f3da5cd..2de6e380919 100644 --- a/src/ai/backend/manager/api/gql/project.py +++ b/src/ai/backend/manager/api/gql/project.py @@ -9,7 +9,3 @@ class Project: @classmethod def resolve_reference(cls, id: ID, info: Info) -> "Project": return cls(id=id) - - -mock_project_id = ID("UHJvamVjdE5vZGU6ZjM4ZGVhMjMtNTBmYS00MmEwLWI1YWUtMzM4ZjVmNDY5M2Y0") -mock_project = Project(id=mock_project_id) diff --git a/src/ai/backend/manager/api/gql/types.py b/src/ai/backend/manager/api/gql/types.py index d70a1cc662e..92a65ee3897 100644 --- a/src/ai/backend/manager/api/gql/types.py +++ b/src/ai/backend/manager/api/gql/types.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from ai.backend.manager.api.gql.data_loader.registry import DataLoaderRegistry from ai.backend.manager.config.provider import ManagerConfigProvider from ai.backend.manager.services.processors import Processors @@ -18,3 +19,4 @@ class StrawberryGQLContext: config_provider: ManagerConfigProvider event_hub: EventHub event_fetcher: EventFetcher + dataloader_registry: DataLoaderRegistry diff --git a/src/ai/backend/manager/api/gql/user.py b/src/ai/backend/manager/api/gql/user.py index b6c75561d55..15be7655e16 100644 --- a/src/ai/backend/manager/api/gql/user.py +++ b/src/ai/backend/manager/api/gql/user.py @@ -9,6 +9,3 @@ class User: @classmethod def resolve_reference(cls, id: ID, info: Info) -> "User": return cls(id=id) - - -mock_user_id = ID("VXNlck5vZGU6ZjM4ZGVhMjMtNTBmYS00MmEwLWI1YWUtMzM4ZjVmNDY5M2Y0") diff --git a/src/ai/backend/manager/api/gql/vfolder.py b/src/ai/backend/manager/api/gql/vfolder.py index 56b2411a23b..a0a27775143 100644 --- 
a/src/ai/backend/manager/api/gql/vfolder.py +++ b/src/ai/backend/manager/api/gql/vfolder.py @@ -1,9 +1,13 @@ from typing import Any -from uuid import uuid4 +from uuid import UUID import strawberry from strawberry import ID, Info -from strawberry.relay import Connection, Edge, Node, NodeID +from strawberry.relay import Connection, Edge, Node, NodeID, PageInfo + +from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.data.deployment.types import ExtraVFolderMountData +from ai.backend.manager.models.gql_relay import AsyncNode @strawberry.federation.type(keys=["id"], name="VirtualFolderNode", extend=True) @@ -15,20 +19,31 @@ def resolve_reference(cls, id: ID, info: Info) -> "VFolder": return cls(id=id) -mock_vfolder_id = ID("VmlydHVhbEZvbGRlck5vZGU6YmEzMzE5ZGQtMTFmZC00Yjk4LTkzNGMtNjUxYTQ4YTVmMzM0") - - @strawberry.type class ExtraVFolderMount(Node): - id: NodeID + id: NodeID[str] mount_destination: str - vfolder: VFolder + _vfolder_id: strawberry.Private[UUID] + + @strawberry.field + async def vfolder(self, info: Info[StrawberryGQLContext]) -> VFolder: + vfolder_global_id = AsyncNode.to_global_id("VirtualFolderNode", self._vfolder_id) + return VFolder(id=ID(vfolder_global_id)) + + @classmethod + def from_dataclass(cls, data: ExtraVFolderMountData) -> "ExtraVFolderMount": + return cls( + # TODO: fix id generation logic + id=ID(f"{data.vfolder_id}:{data.mount_destination}"), + mount_destination=data.mount_destination, + _vfolder_id=data.vfolder_id, + ) ExtraVFolderMountEdge = Edge[ExtraVFolderMount] -@strawberry.type(description="Added in 25.13.0") +@strawberry.type(description="Added in 25.16.0") class ExtraVFolderMountConnection(Connection[ExtraVFolderMount]): count: int @@ -36,15 +51,17 @@ def __init__(self, *args, count: int, **kwargs: Any): super().__init__(*args, **kwargs) self.count = count - -mock_extra_mount_1 = ExtraVFolderMount( - id=uuid4(), - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/extra_models/model1", 
-) - -mock_extra_mount_2 = ExtraVFolderMount( - id=uuid4(), - vfolder=VFolder(id=mock_vfolder_id), - mount_destination="/extra_models/model2", -) + @classmethod + def from_dataclass( + cls, mounts_data: list[ExtraVFolderMountData] + ) -> "ExtraVFolderMountConnection": + nodes = [ExtraVFolderMount.from_dataclass(data) for data in mounts_data] + edges = [Edge(node=node, cursor=str(node.id)) for node in nodes] + page_info = PageInfo( + has_next_page=False, + has_previous_page=False, + start_cursor=edges[0].cursor if edges else None, + end_cursor=edges[-1].cursor if edges else None, + ) + + return cls(count=len(nodes), edges=edges, page_info=page_info) diff --git a/src/ai/backend/manager/api/service.py b/src/ai/backend/manager/api/service.py index ea233dee1c7..ba3f39b712b 100644 --- a/src/ai/backend/manager/api/service.py +++ b/src/ai/backend/manager/api/service.py @@ -56,9 +56,9 @@ ServiceConfig, ServiceInfo, ) -from ai.backend.manager.services.deployment.actions.create_deployment import ( - CreateDeploymentAction, - CreateDeploymentActionResult, +from ai.backend.manager.services.deployment.actions.create_legacy_deployment import ( + CreateLegacyDeploymentAction, + CreateLegacyDeploymentActionResult, ) from ai.backend.manager.services.deployment.actions.destroy_deployment import ( DestroyDeploymentAction, @@ -735,7 +735,7 @@ async def create(request: web.Request, params: NewServiceRequestModel) -> ServeI and root_ctx.processors.deployment is not None ): # Create deployment using the new deployment controller - deployment_action = CreateDeploymentAction( + deployment_action = CreateLegacyDeploymentAction( creator=DeploymentCreator( metadata=DeploymentMetadata( name=params.service_name, @@ -756,8 +756,8 @@ async def create(request: web.Request, params: NewServiceRequestModel) -> ServeI ), ) ) - deployment_result: CreateDeploymentActionResult = ( - await root_ctx.processors.deployment.create_deployment.wait_for_complete( + deployment_result: 
CreateLegacyDeploymentActionResult = ( + await root_ctx.processors.deployment.create_legacy_deployment.wait_for_complete( deployment_action ) ) diff --git a/src/ai/backend/manager/data/deployment/access_token.py b/src/ai/backend/manager/data/deployment/access_token.py new file mode 100644 index 00000000000..d0545e3ebba --- /dev/null +++ b/src/ai/backend/manager/data/deployment/access_token.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass +from datetime import datetime +from uuid import UUID + + +@dataclass +class ModelDeploymentAccessTokenCreator: + model_deployment_id: UUID + valid_until: datetime diff --git a/src/ai/backend/manager/data/deployment/creator.py b/src/ai/backend/manager/data/deployment/creator.py index ee50a0c8eaa..72d97ec3cf7 100644 --- a/src/ai/backend/manager/data/deployment/creator.py +++ b/src/ai/backend/manager/data/deployment/creator.py @@ -1,15 +1,44 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional from uuid import UUID from ai.backend.manager.data.deployment.types import ( DeploymentMetadata, DeploymentNetworkSpec, + ExecutionSpec, ModelRevisionSpec, + MountInfo, + MountMetadata, ReplicaSpec, + ResourceSpec, ) from ai.backend.manager.data.image.types import ImageIdentifier +@dataclass +class VFolderMountsCreator: + model_vfolder_id: UUID + model_definition_path: Optional[str] = None + model_mount_destination: str = "/models" + extra_mounts: list[MountInfo] = field(default_factory=list) + + +@dataclass +class ModelRevisionCreator: + image_identifier: ImageIdentifier + resource_spec: ResourceSpec + mounts: VFolderMountsCreator + execution: ExecutionSpec + + def to_revision_spec(self, mount_metadata: MountMetadata) -> ModelRevisionSpec: + return ModelRevisionSpec( + image_identifier=self.image_identifier, + resource_spec=self.resource_spec, + mounts=mount_metadata, + execution=self.execution, + ) + + @dataclass class DeploymentCreator: metadata: DeploymentMetadata @@ -37,3 +66,11 
@@ def project(self) -> UUID: def name(self) -> str: """Get the deployment name from metadata.""" return self.metadata.name + + +@dataclass +class NewDeploymentCreator: + metadata: DeploymentMetadata + replica_spec: ReplicaSpec + network: DeploymentNetworkSpec + model_revision: ModelRevisionCreator diff --git a/src/ai/backend/manager/data/model_deployment/inference_runtime_config.py b/src/ai/backend/manager/data/deployment/inference_runtime_config.py similarity index 100% rename from src/ai/backend/manager/data/model_deployment/inference_runtime_config.py rename to src/ai/backend/manager/data/deployment/inference_runtime_config.py diff --git a/src/ai/backend/manager/data/deployment/modifier.py b/src/ai/backend/manager/data/deployment/modifier.py index 48de852dd6b..a6b4d8ca7ea 100644 --- a/src/ai/backend/manager/data/deployment/modifier.py +++ b/src/ai/backend/manager/data/deployment/modifier.py @@ -2,6 +2,7 @@ from typing import Any, Optional, override from uuid import UUID +from ai.backend.common.data.model_deployment.types import DeploymentStrategy from ai.backend.manager.types import OptionalState, PartialModifier, TriState @@ -89,3 +90,32 @@ def fields_to_update(self) -> dict[str, Any]: if self.model_revision: to_update.update(self.model_revision.fields_to_update()) return to_update + + +@dataclass +class NewDeploymentModifier(PartialModifier): + name: OptionalState[str] = field(default_factory=OptionalState[str].nop) + tags: OptionalState[list[str]] = field(default_factory=OptionalState[list[str]].nop) + desired_replica_count: OptionalState[int] = field(default_factory=OptionalState[int].nop) + open_to_public: OptionalState[bool] = field(default_factory=OptionalState[bool].nop) + preferred_domain_name: TriState[str] = field(default_factory=TriState[str].nop) + default_deployment_strategy: OptionalState[DeploymentStrategy] = field( + default_factory=OptionalState[DeploymentStrategy].nop + ) + active_revision_id: OptionalState[UUID] = field( + 
default_factory=OptionalState[UUID].nop + ) # TODO: Check if TriState is more appropriate + + @override + def fields_to_update(self) -> dict[str, Any]: + to_update: dict[str, Any] = {} + self.name.update_dict(to_update, "name") + tag = self.tags.optional_value() + if tag is not None: + to_update["tags"] = ",".join(tag) + self.desired_replica_count.update_dict(to_update, "desired_replica_count") + self.open_to_public.update_dict(to_update, "open_to_public") + self.preferred_domain_name.update_dict(to_update, "preferred_domain_name") + self.default_deployment_strategy.update_dict(to_update, "default_deployment_strategy") + self.active_revision_id.update_dict(to_update, "current_revision_id") + return to_update diff --git a/src/ai/backend/manager/data/deployment/scale.py b/src/ai/backend/manager/data/deployment/scale.py index b255163f6da..40b082835dc 100644 --- a/src/ai/backend/manager/data/deployment/scale.py +++ b/src/ai/backend/manager/data/deployment/scale.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from datetime import datetime +from decimal import Decimal from typing import Optional from uuid import UUID from ai.backend.common.types import AutoScalingMetricComparator, AutoScalingMetricSource +# Dataclasses for auto scaling rules used in Model Service (legacy) @dataclass class AutoScalingCondition: metric_source: AutoScalingMetricSource @@ -35,3 +37,31 @@ class AutoScalingRule: action: AutoScalingAction created_at: datetime last_triggered_at: Optional[datetime] + + +# Dataclasses for auto scaling rules used in Model Deployment +@dataclass +class ModelDeploymentAutoScalingRuleCreator: + model_deployment_id: UUID + metric_source: AutoScalingMetricSource + metric_name: str + min_threshold: Optional[Decimal] + max_threshold: Optional[Decimal] + step_size: int + time_window: int + min_replicas: Optional[int] + max_replicas: Optional[int] + + +@dataclass +class ModelDeploymentAutoScalingRule: + id: UUID + model_deployment_id: UUID + metric_source: 
AutoScalingMetricSource + metric_name: str + min_threshold: Optional[Decimal] + max_threshold: Optional[Decimal] + step_size: int + time_window: int + min_replicas: Optional[int] + max_replicas: Optional[int] diff --git a/src/ai/backend/manager/data/deployment/scale_modifier.py b/src/ai/backend/manager/data/deployment/scale_modifier.py index b8cf20db48a..feefc3da662 100644 --- a/src/ai/backend/manager/data/deployment/scale_modifier.py +++ b/src/ai/backend/manager/data/deployment/scale_modifier.py @@ -1,43 +1,12 @@ from dataclasses import dataclass, field -from datetime import datetime +from decimal import Decimal from typing import Any, Optional, override -from uuid import UUID from ai.backend.common.types import AutoScalingMetricComparator, AutoScalingMetricSource from ai.backend.manager.types import OptionalState, PartialModifier -@dataclass -class AutoScalingCondition: - metric_source: AutoScalingMetricSource - metric_name: str - threshold: str - comparator: AutoScalingMetricComparator - - -@dataclass -class AutoScalingAction: - step_size: int - cooldown_seconds: int - min_replicas: Optional[int] = None - max_replicas: Optional[int] = None - - -@dataclass -class AutoScalingRuleCreator: - condition: AutoScalingCondition - action: AutoScalingAction - - -@dataclass -class AutoScalingRule: - id: UUID - condition: AutoScalingCondition - action: AutoScalingAction - created_at: datetime - last_triggered_at: Optional[datetime] - - +# Dataclasses for auto scaling rules used in Model Service (legacy) @dataclass class AutoScalingConditionModifier(PartialModifier): metric_source: OptionalState[AutoScalingMetricSource] = field( @@ -91,3 +60,31 @@ def fields_to_update(self) -> dict[str, Any]: to_update.update(self.condition_modifier.fields_to_update()) to_update.update(self.action_modifier.fields_to_update()) return to_update + + +# Dataclasses for auto scaling rules used in Model Deployment +@dataclass +class ModelDeploymentAutoScalingRuleModifier(PartialModifier): + 
metric_source: OptionalState[AutoScalingMetricSource] = field( + default_factory=OptionalState[AutoScalingMetricSource].nop + ) + metric_name: OptionalState[str] = field(default_factory=OptionalState[str].nop) + min_threshold: OptionalState[Decimal] = field(default_factory=OptionalState[Decimal].nop) + max_threshold: OptionalState[Decimal] = field(default_factory=OptionalState[Decimal].nop) + step_size: OptionalState[int] = field(default_factory=OptionalState[int].nop) + time_window: OptionalState[int] = field(default_factory=OptionalState[int].nop) + min_replicas: OptionalState[int] = field(default_factory=OptionalState[int].nop) + max_replicas: OptionalState[int] = field(default_factory=OptionalState[int].nop) + + @override + def fields_to_update(self) -> dict[str, Any]: + to_update: dict[str, Any] = {} + self.metric_source.update_dict(to_update, "metric_source") + self.metric_name.update_dict(to_update, "metric_name") + self.min_threshold.update_dict(to_update, "min_threshold") + self.max_threshold.update_dict(to_update, "max_threshold") + self.step_size.update_dict(to_update, "step_size") + self.time_window.update_dict(to_update, "time_window") + self.min_replicas.update_dict(to_update, "min_replicas") + self.max_replicas.update_dict(to_update, "max_replicas") + return to_update diff --git a/src/ai/backend/manager/data/deployment/types.py b/src/ai/backend/manager/data/deployment/types.py index b1fcac603d6..1b39259f83e 100644 --- a/src/ai/backend/manager/data/deployment/types.py +++ b/src/ai/backend/manager/data/deployment/types.py @@ -2,13 +2,29 @@ from collections.abc import Mapping from dataclasses import dataclass, field from datetime import datetime +from decimal import Decimal from functools import lru_cache +from pathlib import PurePosixPath from typing import Any, Optional from uuid import UUID import yarl -from ai.backend.common.types import ClusterMode, RuntimeVariant, SessionId, VFolderMount +from ai.backend.common.data.model_deployment.types import ( 
+ ActivenessStatus, + DeploymentStrategy, + LivenessStatus, + ModelDeploymentStatus, + ReadinessStatus, +) +from ai.backend.common.types import ( + AutoScalingMetricSource, + ClusterMode, + ResourceSlot, + RuntimeVariant, + SessionId, + VFolderMount, +) from ai.backend.manager.data.deployment.scale import AutoScalingRule from ai.backend.manager.data.image.types import ImageIdentifier @@ -95,6 +111,12 @@ class MountSpec: mount_options: Mapping[UUID, dict[str, Any]] +@dataclass +class MountInfo: + vfolder_id: UUID + kernel_path: PurePosixPath + + @dataclass class MountMetadata: model_vfolder_id: UUID @@ -142,6 +164,7 @@ class ExecutionSpec: environ: Optional[dict[str, str]] = None runtime_variant: RuntimeVariant = RuntimeVariant.CUSTOM callback_url: Optional[yarl.URL] = None + inference_runtime_config: Optional[Mapping[str, Any]] = None @dataclass @@ -155,7 +178,9 @@ class ModelRevisionSpec: @dataclass class DeploymentNetworkSpec: open_to_public: bool + access_token_ids: Optional[list[UUID]] = None url: Optional[str] = None + preferred_domain_name: Optional[str] = None @dataclass @@ -218,3 +243,138 @@ class DeploymentInfoWithAutoScalingRules: deployment_info: DeploymentInfo rules: list[AutoScalingRule] = field(default_factory=list) + + +@dataclass +class ModelDeploymentAutoScalingRuleData: + id: UUID + model_deployment_id: UUID + metric_source: AutoScalingMetricSource + metric_name: str + min_threshold: Optional[Decimal] + max_threshold: Optional[Decimal] + step_size: int + time_window: int + min_replicas: Optional[int] + max_replicas: Optional[int] + created_at: datetime + last_triggered_at: datetime + + +@dataclass +class ModelDeploymentAccessTokenData: + id: UUID + token: str + valid_until: datetime + created_at: datetime + + +@dataclass +class ModelReplicaData: + id: UUID + revision_id: UUID + session_id: UUID + readiness_status: ReadinessStatus + liveness_status: LivenessStatus + activeness_status: ActivenessStatus + weight: int + detail: dict[str, Any] + 
created_at: datetime + live_stat: dict[str, Any] + + +@dataclass +class ClusterConfigData: + mode: ClusterMode + size: int + + +@dataclass +class ResourceConfigData: + resource_group_name: str + resource_slot: ResourceSlot + resource_opts: Mapping[str, Any] = field(default_factory=dict) + + +@dataclass +class ModelRuntimeConfigData: + runtime_variant: RuntimeVariant + inference_runtime_config: Optional[Mapping[str, Any]] = None + environ: Optional[dict[str, Any]] = None + + +@dataclass +class ModelMountConfigData: + vfolder_id: UUID + mount_destination: str + definition_path: str + + +@dataclass +class ExtraVFolderMountData: + vfolder_id: UUID + mount_destination: str + + +@dataclass +class ModelRevisionData: + id: UUID + name: str + cluster_config: ClusterConfigData + resource_config: ResourceConfigData + model_runtime_config: ModelRuntimeConfigData + model_mount_config: ModelMountConfigData + created_at: datetime + image_id: UUID + extra_vfolder_mounts: list[ExtraVFolderMountData] = field(default_factory=list) + + +@dataclass +class ModelDeploymentMetadataInfo: + name: str + status: ModelDeploymentStatus + tags: list[str] + project_id: UUID + domain_name: str + created_at: datetime + updated_at: datetime + + +@dataclass +class ReplicaStateData: + desired_replica_count: int + replica_ids: list[UUID] + + +@dataclass +class ModelDeploymentData: + id: UUID + metadata: ModelDeploymentMetadataInfo + network_access: DeploymentNetworkSpec + revision: Optional[ModelRevisionData] + revision_history_ids: list[UUID] + scaling_rule_ids: list[UUID] + replica_state: ReplicaStateData + default_deployment_strategy: DeploymentStrategy + created_user_id: UUID + access_token_ids: Optional[UUID] = None + + +class DeploymentOrderField(enum.StrEnum): + CREATED_AT = "CREATED_AT" + UPDATED_AT = "UPDATED_AT" + NAME = "NAME" + + +class ModelRevisionOrderField(enum.StrEnum): + CREATED_AT = "CREATED_AT" + NAME = "NAME" + + +class ReplicaOrderField(enum.StrEnum): + CREATED_AT = "CREATED_AT" + 
ID = "ID" + + +class AccessTokenOrderField(enum.StrEnum): + CREATED_AT = "CREATED_AT" diff --git a/src/ai/backend/manager/data/model_deployment/BUILD b/src/ai/backend/manager/data/model_deployment/BUILD deleted file mode 100644 index 73574424040..00000000000 --- a/src/ai/backend/manager/data/model_deployment/BUILD +++ /dev/null @@ -1 +0,0 @@ -python_sources(name="src") diff --git a/src/ai/backend/manager/data/scaling_group/types.py b/src/ai/backend/manager/data/scaling_group/types.py new file mode 100644 index 00000000000..a99fb3781d1 --- /dev/null +++ b/src/ai/backend/manager/data/scaling_group/types.py @@ -0,0 +1,20 @@ +from collections.abc import Mapping +from dataclasses import dataclass +from datetime import datetime +from typing import Any + + +@dataclass +class ScalingGroupData: + name: str + description: str + is_active: bool + is_public: bool + created_at: datetime + wsproxy_addr: str + wsproxy_api_token: str + driver: str + driver_opts: Mapping[str, Any] + scheduler: str + scheduler_opts: Mapping[str, Any] + use_host_network: bool diff --git a/src/ai/backend/manager/dto/context.py b/src/ai/backend/manager/dto/context.py index d8f3e1e520d..0e8d078a90c 100644 --- a/src/ai/backend/manager/dto/context.py +++ b/src/ai/backend/manager/dto/context.py @@ -31,3 +31,14 @@ class ProcessorsCtx(MiddlewareParam): async def from_request(cls, request: web.Request) -> Self: root_ctx: RootContext = request.app["_root.context"] return cls(processors=root_ctx.processors) + + +class RequestCtx(MiddlewareParam): + request: web.Request + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @override + @classmethod + async def from_request(cls, request: web.Request) -> Self: + return cls(request=request) diff --git a/src/ai/backend/manager/models/gql.py b/src/ai/backend/manager/models/gql.py index 627db64ad9e..31c5386227f 100644 --- a/src/ai/backend/manager/models/gql.py +++ b/src/ai/backend/manager/models/gql.py @@ -514,7 +514,7 @@ class Mutation(graphene.ObjectType): 
class Query(graphene.ObjectType): """ All available GraphQL queries. - Type name changed from 'Queries' to 'Query' in 25.13.0 + Type name changed from 'Queries' to 'Query' in 25.15.0 """ node = AsyncNode.Field() diff --git a/src/ai/backend/manager/models/gql_models/domain.py b/src/ai/backend/manager/models/gql_models/domain.py index 8cc32a90a07..03fa9bc1e91 100644 --- a/src/ai/backend/manager/models/gql_models/domain.py +++ b/src/ai/backend/manager/models/gql_models/domain.py @@ -21,7 +21,7 @@ from sqlalchemy.engine.row import Row from sqlalchemy.ext.asyncio import AsyncSession -from ai.backend.common.exception import DomainNotFoundError +from ai.backend.common.exception import DomainNotFound from ai.backend.common.types import ResourceSlot, Sentinel from ai.backend.manager.data.domain.types import ( DomainCreator, @@ -320,7 +320,7 @@ async def get_connection( async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs) -> DomainNode: domain_node = await DomainNode.get_node(info, self.id) if domain_node is None: - raise DomainNotFoundError() + raise DomainNotFound(f"Domain not found: {self.id}") return domain_node diff --git a/src/ai/backend/manager/models/gql_models/group.py b/src/ai/backend/manager/models/gql_models/group.py index b8799fb44aa..c690662b546 100644 --- a/src/ai/backend/manager/models/gql_models/group.py +++ b/src/ai/backend/manager/models/gql_models/group.py @@ -19,7 +19,7 @@ from graphql import Undefined from sqlalchemy.engine.row import Row -from ai.backend.common.exception import GroupNotFoundError +from ai.backend.common.exception import GroupNotFound from ai.backend.common.types import ResourceSlot from ai.backend.manager.data.group.types import GroupCreator, GroupData, GroupModifier from ai.backend.manager.models.rbac import ProjectScope @@ -235,7 +235,7 @@ async def get_node(cls, info: graphene.ResolveInfo, id) -> Self: async with graph_ctx.db.begin_readonly_session() as db_session: group_row = (await 
db_session.scalars(query)).first() if group_row is None: - raise GroupNotFoundError() + raise GroupNotFound(f"Group not found: {group_id}") return cls.from_row(graph_ctx, group_row) @classmethod diff --git a/src/ai/backend/manager/models/gql_models/scaling_group.py b/src/ai/backend/manager/models/gql_models/scaling_group.py index 1506bb1ecdb..e15960f770b 100644 --- a/src/ai/backend/manager/models/gql_models/scaling_group.py +++ b/src/ai/backend/manager/models/gql_models/scaling_group.py @@ -22,10 +22,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import load_only -from ai.backend.common.exception import ScalingGroupNotFoundError from ai.backend.common.types import AccessKey, ResourceSlot from ai.backend.logging.utils import BraceStyleAdapter -from ai.backend.manager.errors.resource import ScalingGroupDeletionFailure +from ai.backend.manager.errors.resource import ScalingGroupDeletionFailure, ScalingGroupNotFound from ai.backend.manager.models.agent import AgentStatus from ai.backend.manager.models.user import UserRole from ai.backend.manager.models.utils import execute_with_txn_retry @@ -133,7 +132,7 @@ async def __resolve_reference(self, info: graphene.ResolveInfo, **kwargs) -> "Sc ) result = await db_session.scalar(query_stmt) if result is None: - raise ScalingGroupNotFoundError() + raise ScalingGroupNotFound(f"Scaling group not found: {scaling_group_name}") return ScalingGroupNode.from_row(graph_ctx, result) @classmethod diff --git a/src/ai/backend/manager/models/gql_models/user.py b/src/ai/backend/manager/models/gql_models/user.py index f538361075f..6b01e7340c2 100644 --- a/src/ai/backend/manager/models/gql_models/user.py +++ b/src/ai/backend/manager/models/gql_models/user.py @@ -20,7 +20,7 @@ from graphql import Undefined from sqlalchemy.engine.row import Row -from ai.backend.common.exception import UserNotFoundError +from ai.backend.common.exception import UserNotFound from ai.backend.manager.data.user.types import ( UserCreator, 
UserData, @@ -193,7 +193,7 @@ async def get_node(cls, info: graphene.ResolveInfo, id) -> Self: async with graph_ctx.db.begin_readonly_session() as db_session: user_row = (await db_session.scalars(query)).first() if user_row is None: - raise UserNotFoundError() + raise UserNotFound(f"User not found: {user_id}") return cls.from_row(graph_ctx, user_row) _queryfilter_fieldspec: Mapping[str, FieldSpecItem] = { diff --git a/src/ai/backend/manager/models/gql_models/vfolder.py b/src/ai/backend/manager/models/gql_models/vfolder.py index a0805ab1674..6ffad9f37d7 100644 --- a/src/ai/backend/manager/models/gql_models/vfolder.py +++ b/src/ai/backend/manager/models/gql_models/vfolder.py @@ -24,7 +24,7 @@ from sqlalchemy.orm import joinedload from ai.backend.common.config import model_definition_iv -from ai.backend.common.exception import VFolderNotFoundError +from ai.backend.common.exception import VFolderNotFound from ai.backend.common.types import ( VFolderID, VFolderUsageMode, @@ -402,7 +402,7 @@ async def __resolve_reference( ) -> "VirtualFolderNode": vfolder_node = await VirtualFolderNode.get_node(info, self.id) if vfolder_node is None: - raise VFolderNotFoundError(self.id) + raise VFolderNotFound(f"Virtual folder not found: {self.id}") return vfolder_node diff --git a/src/ai/backend/manager/repositories/deployment/types/types.py b/src/ai/backend/manager/repositories/deployment/types/types.py new file mode 100644 index 00000000000..36e40279ea1 --- /dev/null +++ b/src/ai/backend/manager/repositories/deployment/types/types.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional +from uuid import UUID + +from ai.backend.common.data.model_deployment.types import ( + ActivenessStatus, + LivenessStatus, + ModelDeploymentStatus, + ReadinessStatus, +) +from ai.backend.manager.api.gql.base import StringFilter +from ai.backend.manager.data.deployment.types import ( + 
AccessTokenOrderField, + DeploymentOrderField, + ModelRevisionOrderField, + ReplicaOrderField, +) + + +@dataclass +class DeploymentOrderingOptions: + """Ordering options for deployment queries.""" + + order_by: list[tuple[DeploymentOrderField, bool]] = field( + default_factory=lambda: [(DeploymentOrderField.CREATED_AT, True)] + ) # (field, desc) + + +@dataclass +class ModelRevisionOrderingOptions: + """Ordering options for model revision queries.""" + + order_by: list[tuple[ModelRevisionOrderField, bool]] = field( + default_factory=lambda: [(ModelRevisionOrderField.CREATED_AT, True)] + ) # (field, desc) + + +@dataclass +class ModelReplicaOrderingOptions: + """Ordering options for model replica queries.""" + + order_by: list[tuple[ReplicaOrderField, bool]] = field( + default_factory=lambda: [(ReplicaOrderField.CREATED_AT, True)] + ) # (field, desc) + + +@dataclass +class AccessTokenOrderingOptions: + """Ordering options for access token queries.""" + + order_by: list[tuple[AccessTokenOrderField, bool]] = field( + default_factory=lambda: [(AccessTokenOrderField.CREATED_AT, True)] + ) # (field, desc) + + +class DeploymentStatusFilterType(Enum): + IN = "in" + EQUALS = "equals" + + +@dataclass +class DeploymentStatusFilter: + """Status filter with operation type and values.""" + + type: DeploymentStatusFilterType + values: list[ModelDeploymentStatus] + + +@dataclass +class DeploymentFilterOptions: + """Filtering options for deployments.""" + + name: Optional[StringFilter] = None + status: Optional[DeploymentStatusFilter] = None + open_to_public: Optional[bool] = None + tags: Optional[StringFilter] = None + endpoint_url: Optional[StringFilter] = None + id: Optional[UUID] = None + + # Logical operations + AND: Optional[list["DeploymentFilterOptions"]] = None + OR: Optional[list["DeploymentFilterOptions"]] = None + NOT: Optional[list["DeploymentFilterOptions"]] = None + + +@dataclass +class ModelRevisionFilterOptions: + """Filtering options for model revisions.""" + + 
name: Optional[StringFilter] = None + deployment_id: Optional[UUID] = None + id: Optional[UUID] = None + ids_in: Optional[list[UUID]] = None + + # Logical operations + AND: Optional[list["ModelRevisionFilterOptions"]] = None + OR: Optional[list["ModelRevisionFilterOptions"]] = None + NOT: Optional[list["ModelRevisionFilterOptions"]] = None + + +class ReadinessStatusFilterType(Enum): + IN = "in" + EQUALS = "equals" + + +@dataclass +class ReadinessStatusFilter: + """Readiness status filter with operation type and values.""" + + type: ReadinessStatusFilterType + values: list[ReadinessStatus] + + +class LivenessStatusFilterType(Enum): + IN = "in" + EQUALS = "equals" + + +@dataclass +class LivenessStatusFilter: + """Liveness status filter with operation type and values.""" + + type: LivenessStatusFilterType + values: list[LivenessStatus] + + +class ActivenessStatusFilterType(Enum): + IN = "in" + EQUALS = "equals" + + +@dataclass +class ActivenessStatusFilter: + """Activeness status filter with operation type and values.""" + + type: ActivenessStatusFilterType + values: list[ActivenessStatus] + + +@dataclass +class ModelReplicaFilterOptions: + """Filtering options for model replicas.""" + + readiness_status_filter: Optional[ReadinessStatusFilter] = None + liveness_status_filter: Optional[LivenessStatusFilter] = None + activeness_status_filter: Optional[ActivenessStatusFilter] = None + id: Optional[UUID] = None + ids_in: Optional[list[UUID]] = None + + # Logical operations + AND: Optional[list["ModelReplicaFilterOptions"]] = None + OR: Optional[list["ModelReplicaFilterOptions"]] = None + NOT: Optional[list["ModelReplicaFilterOptions"]] = None diff --git a/src/ai/backend/manager/services/deployment/actions/access_token/__init__.py b/src/ai/backend/manager/services/deployment/actions/access_token/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/services/deployment/actions/access_token/base.py 
b/src/ai/backend/manager/services/deployment/actions/access_token/base.py new file mode 100644 index 00000000000..ab9c7607c69 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/access_token/base.py @@ -0,0 +1,10 @@ +from typing import override + +from ai.backend.manager.actions.action import BaseAction + + +class DeploymentAccessTokenBaseAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> str: + return "deployment_access_token" diff --git a/src/ai/backend/manager/services/deployment/actions/access_token/create_access_token.py b/src/ai/backend/manager/services/deployment/actions/access_token/create_access_token.py new file mode 100644 index 00000000000..9b596a04039 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/access_token/create_access_token.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.access_token import ModelDeploymentAccessTokenCreator +from ai.backend.manager.data.deployment.types import ModelDeploymentAccessTokenData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class CreateAccessTokenAction(DeploymentBaseAction): + creator: ModelDeploymentAccessTokenCreator + + @override + def entity_id(self) -> Optional[str]: + return str(self.creator.model_deployment_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class CreateAccessTokenActionResult(BaseActionResult): + data: ModelDeploymentAccessTokenData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/access_token/list_access_tokens.py b/src/ai/backend/manager/services/deployment/actions/access_token/list_access_tokens.py new file mode 100644 index 00000000000..f6c33d1645f --- /dev/null +++ 
b/src/ai/backend/manager/services/deployment/actions/access_token/list_access_tokens.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelDeploymentAccessTokenData +from ai.backend.manager.repositories.deployment.types.types import AccessTokenOrderingOptions +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction +from ai.backend.manager.types import PaginationOptions + + +@dataclass +class ListAccessTokensAction(DeploymentBaseAction): + pagination: PaginationOptions + ordering: Optional[AccessTokenOrderingOptions] = None + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "list_access_tokens" + + +@dataclass +class ListAccessTokensActionResult(BaseActionResult): + data: list[ModelDeploymentAccessTokenData] + total_count: int + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/__init__.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/base.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/base.py new file mode 100644 index 00000000000..33c4781a994 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/base.py @@ -0,0 +1,10 @@ +from typing import override + +from ai.backend.manager.actions.action import BaseAction + + +class AutoScalingRuleBaseAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> str: + return "auto_scaling_rule" diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/batch_load_auto_scaling_rules.py 
b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/batch_load_auto_scaling_rules.py new file mode 100644 index 00000000000..497181b2dae --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/batch_load_auto_scaling_rules.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelDeploymentAutoScalingRuleData, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + +@dataclass +class BatchLoadAutoScalingRulesAction(AutoScalingRuleBaseAction): + auto_scaling_rule_ids: list[UUID] + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "batch_load_auto_scaling_rules" + + +@dataclass +class BatchLoadAutoScalingRulesActionResult(BaseActionResult): + data: list[ModelDeploymentAutoScalingRuleData] + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/create_auto_scaling_rule.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/create_auto_scaling_rule.py new file mode 100644 index 00000000000..0e68936e500 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/create_auto_scaling_rule.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.scale import ModelDeploymentAutoScalingRuleCreator +from ai.backend.manager.data.deployment.types import ModelDeploymentAutoScalingRuleData +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + 
+@dataclass +class CreateAutoScalingRuleAction(AutoScalingRuleBaseAction): + creator: ModelDeploymentAutoScalingRuleCreator + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class CreateAutoScalingRuleActionResult(BaseActionResult): + data: ModelDeploymentAutoScalingRuleData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/delete_auto_scaling_rule.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/delete_auto_scaling_rule.py new file mode 100644 index 00000000000..4c0c5d3f51c --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/delete_auto_scaling_rule.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + +@dataclass +class DeleteAutoScalingRuleAction(AutoScalingRuleBaseAction): + auto_scaling_rule_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return str(self.auto_scaling_rule_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "delete" + + +@dataclass +class DeleteAutoScalingRuleActionResult(BaseActionResult): + success: bool + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/update_auto_scaling_rule.py b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/update_auto_scaling_rule.py new file mode 100644 index 00000000000..534a147d64a --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/auto_scaling_rule/update_auto_scaling_rule.py @@ -0,0 +1,36 @@ +from 
dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.scale_modifier import ModelDeploymentAutoScalingRuleModifier +from ai.backend.manager.data.deployment.types import ( + ModelDeploymentAutoScalingRuleData, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.base import ( + AutoScalingRuleBaseAction, +) + + +@dataclass +class UpdateAutoScalingRuleAction(AutoScalingRuleBaseAction): + auto_scaling_rule_id: UUID + modifier: ModelDeploymentAutoScalingRuleModifier + + @override + def entity_id(self) -> Optional[str]: + return str(self.auto_scaling_rule_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "update" + + +@dataclass +class UpdateAutoScalingRuleActionResult(BaseActionResult): + data: ModelDeploymentAutoScalingRuleData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/base.py b/src/ai/backend/manager/services/deployment/actions/base.py index 4cbfb6ff613..4b73f7009ee 100644 --- a/src/ai/backend/manager/services/deployment/actions/base.py +++ b/src/ai/backend/manager/services/deployment/actions/base.py @@ -1,13 +1,9 @@ -"""Base action for deployment service.""" - from typing import override from ai.backend.manager.actions.action import BaseAction class DeploymentBaseAction(BaseAction): - """Base action for deployment operations.""" - @override @classmethod def entity_type(cls) -> str: diff --git a/src/ai/backend/manager/services/deployment/actions/batch_load_deployments.py b/src/ai/backend/manager/services/deployment/actions/batch_load_deployments.py new file mode 100644 index 00000000000..57aeb89fa77 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/batch_load_deployments.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import Optional, 
override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelDeploymentData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class BatchLoadDeploymentsAction(DeploymentBaseAction): + deployment_ids: list[UUID] + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "batch_load_deployments" + + +@dataclass +class BatchLoadDeploymentsActionResult(BaseActionResult): + data: list[ModelDeploymentData] + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/batch_load_replicas_by_revision_ids.py b/src/ai/backend/manager/services/deployment/actions/batch_load_replicas_by_revision_ids.py new file mode 100644 index 00000000000..137a66c5105 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/batch_load_replicas_by_revision_ids.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelReplicaData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class BatchLoadReplicasByRevisionIdsAction(DeploymentBaseAction): + revision_ids: list[UUID] + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "batch_load_replicas_by_revision_ids" + + +@dataclass +class BatchLoadReplicasByRevisionIdsActionResult(BaseActionResult): + data: dict[UUID, list[ModelReplicaData]] + + @override + def entity_id(self) -> Optional[str]: + return None # This is a list operation for replicas diff --git 
a/src/ai/backend/manager/services/deployment/actions/create_deployment.py b/src/ai/backend/manager/services/deployment/actions/create_deployment.py index bdfbbf2537f..3c491b471ba 100644 --- a/src/ai/backend/manager/services/deployment/actions/create_deployment.py +++ b/src/ai/backend/manager/services/deployment/actions/create_deployment.py @@ -4,16 +4,16 @@ from typing import Optional, override from ai.backend.manager.actions.action import BaseActionResult -from ai.backend.manager.data.deployment.creator import DeploymentCreator -from ai.backend.manager.data.deployment.types import DeploymentInfo +from ai.backend.manager.data.deployment.creator import NewDeploymentCreator +from ai.backend.manager.data.deployment.types import ModelDeploymentData from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction @dataclass class CreateDeploymentAction(DeploymentBaseAction): - """Action to create a new deployment.""" + """Action to create a new deployment(Model Service).""" - creator: DeploymentCreator + creator: NewDeploymentCreator @override def entity_id(self) -> Optional[str]: @@ -27,7 +27,7 @@ def operation_type(cls) -> str: @dataclass class CreateDeploymentActionResult(BaseActionResult): - data: DeploymentInfo + data: ModelDeploymentData @override def entity_id(self) -> Optional[str]: diff --git a/src/ai/backend/manager/services/deployment/actions/create_legacy_deployment.py b/src/ai/backend/manager/services/deployment/actions/create_legacy_deployment.py new file mode 100644 index 00000000000..3632ce8730c --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/create_legacy_deployment.py @@ -0,0 +1,34 @@ +"""Action for creating legacy deployments(Model Service).""" + +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.creator import DeploymentCreator +from ai.backend.manager.data.deployment.types import 
DeploymentInfo +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class CreateLegacyDeploymentAction(DeploymentBaseAction): + """Action to create a new legacy deployment(Model Service).""" + + creator: DeploymentCreator + + @override + def entity_id(self) -> Optional[str]: + return None # New deployment doesn't have an ID yet + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class CreateLegacyDeploymentActionResult(BaseActionResult): + data: DeploymentInfo + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/list_deployments.py b/src/ai/backend/manager/services/deployment/actions/list_deployments.py new file mode 100644 index 00000000000..37978692f09 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/list_deployments.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelDeploymentData +from ai.backend.manager.repositories.deployment.types.types import ( + DeploymentFilterOptions, + DeploymentOrderingOptions, +) +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction +from ai.backend.manager.types import PaginationOptions + + +@dataclass +class ListDeploymentsAction(DeploymentBaseAction): + pagination: PaginationOptions + ordering: Optional[DeploymentOrderingOptions] = None + filters: Optional[DeploymentFilterOptions] = None + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "list_deployments" + + +@dataclass +class ListDeploymentsActionResult(BaseActionResult): + data: list[ModelDeploymentData] + # Note: Total number of deployments, this is not equals to len(data) + total_count: 
int + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/list_replicas.py b/src/ai/backend/manager/services/deployment/actions/list_replicas.py new file mode 100644 index 00000000000..a8116433f7e --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/list_replicas.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelReplicaData +from ai.backend.manager.repositories.deployment.types.types import ( + ModelReplicaFilterOptions, + ModelReplicaOrderingOptions, +) +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction +from ai.backend.manager.types import PaginationOptions + + +@dataclass +class ListReplicasAction(DeploymentBaseAction): + pagination: PaginationOptions + ordering: Optional[ModelReplicaOrderingOptions] = None + filters: Optional[ModelReplicaFilterOptions] = None + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "list_replicas" + + +@dataclass +class ListReplicasActionResult(BaseActionResult): + data: list[ModelReplicaData] + # Note: Total number of replicas, this is not equal to len(data) + total_count: int + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/__init__.py b/src/ai/backend/manager/services/deployment/actions/model_revision/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/add_model_revision.py b/src/ai/backend/manager/services/deployment/actions/model_revision/add_model_revision.py new file mode 100644 index 00000000000..a728bef1a57 --- /dev/null +++ 
b/src/ai/backend/manager/services/deployment/actions/model_revision/add_model_revision.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.creator import ModelRevisionCreator +from ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class AddModelRevisionAction(ModelRevisionBaseAction): + model_deployment_id: UUID + adder: ModelRevisionCreator + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class AddModelRevisionActionResult(BaseActionResult): + revision: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.revision.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/base.py b/src/ai/backend/manager/services/deployment/actions/model_revision/base.py new file mode 100644 index 00000000000..b60ebea80f5 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/base.py @@ -0,0 +1,10 @@ +from typing import override + +from ai.backend.manager.actions.action import BaseAction + + +class ModelRevisionBaseAction(BaseAction): + @override + @classmethod + def entity_type(cls) -> str: + return "model_revision" diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/batch_load_revisions.py b/src/ai/backend/manager/services/deployment/actions/model_revision/batch_load_revisions.py new file mode 100644 index 00000000000..f75011c75e1 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/batch_load_revisions.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from dataclasses import dataclass +from 
typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class BatchLoadRevisionsAction(ModelRevisionBaseAction): + revision_ids: list[UUID] + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "batch_load_revisions" + + +@dataclass +class BatchLoadRevisionsActionResult(BaseActionResult): + data: list[ModelRevisionData] + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/create_model_revision.py b/src/ai/backend/manager/services/deployment/actions/model_revision/create_model_revision.py new file mode 100644 index 00000000000..b74ad36a83f --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/create_model_revision.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.creator import ModelRevisionCreator +from ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class CreateModelRevisionAction(ModelRevisionBaseAction): + creator: ModelRevisionCreator + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "create" + + +@dataclass +class CreateModelRevisionActionResult(BaseActionResult): + revision: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.revision.id) diff --git 
a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_deployment_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_deployment_id.py new file mode 100644 index 00000000000..254de284e94 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_deployment_id.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelRevisionData, +) +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionByDeploymentIdAction(ModelRevisionBaseAction): + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionByDeploymentIdActionResult(BaseActionResult): + data: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_id.py new file mode 100644 index 00000000000..9b6b7db51ed --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_id.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelRevisionData, +) +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionByIdAction(ModelRevisionBaseAction): + revision_id: UUID + + 
@override + def entity_id(self) -> Optional[str]: + return str(self.revision_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionByIdActionResult(BaseActionResult): + data: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_replica_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_replica_id.py new file mode 100644 index 00000000000..2045e012b24 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revision_by_replica_id.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ( + ModelRevisionData, +) +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionByReplicaIdAction(ModelRevisionBaseAction): + replica_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionByReplicaIdActionResult(BaseActionResult): + data: ModelRevisionData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/get_revisions_by_deployment_id.py b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revisions_by_deployment_id.py new file mode 100644 index 00000000000..00ea0244b69 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/get_revisions_by_deployment_id.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing 
import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) + + +@dataclass +class GetRevisionsByDeploymentIdAction(ModelRevisionBaseAction): + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return "get" + + +@dataclass +class GetRevisionsByDeploymentIdActionResult(BaseActionResult): + data: list[ModelRevisionData] + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/model_revision/list_revisions.py b/src/ai/backend/manager/services/deployment/actions/model_revision/list_revisions.py new file mode 100644 index 00000000000..b143066b4f4 --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/model_revision/list_revisions.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, override + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.types import ModelRevisionData +from ai.backend.manager.repositories.deployment.types.types import ( + ModelRevisionFilterOptions, + ModelRevisionOrderingOptions, +) +from ai.backend.manager.services.deployment.actions.model_revision.base import ( + ModelRevisionBaseAction, +) +from ai.backend.manager.types import PaginationOptions + + +@dataclass +class ListRevisionsAction(ModelRevisionBaseAction): + pagination: PaginationOptions + ordering: Optional[ModelRevisionOrderingOptions] = None + filters: Optional[ModelRevisionFilterOptions] = None + + @override + def entity_id(self) -> Optional[str]: + return None + + @override + @classmethod + def operation_type(cls) -> str: + return 
"list_revisions" + + +@dataclass +class ListRevisionsActionResult(BaseActionResult): + data: list[ModelRevisionData] + total_count: int + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/sync_replicas.py b/src/ai/backend/manager/services/deployment/actions/sync_replicas.py new file mode 100644 index 00000000000..bfd5e20207b --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/sync_replicas.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class SyncReplicaAction(DeploymentBaseAction): + """Action to sync replicas for an existing deployment.""" + + deployment_id: UUID + + @override + def entity_id(self) -> Optional[str]: + return str(self.deployment_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "sync_replicas" + + +@dataclass +class SyncReplicaActionResult(BaseActionResult): + success: bool + + @override + def entity_id(self) -> Optional[str]: + return None diff --git a/src/ai/backend/manager/services/deployment/actions/update_deployment.py b/src/ai/backend/manager/services/deployment/actions/update_deployment.py new file mode 100644 index 00000000000..58c0b0dbd8d --- /dev/null +++ b/src/ai/backend/manager/services/deployment/actions/update_deployment.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import Optional, override +from uuid import UUID + +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.data.deployment.modifier import NewDeploymentModifier +from ai.backend.manager.data.deployment.types import ModelDeploymentData +from ai.backend.manager.services.deployment.actions.base import DeploymentBaseAction + + +@dataclass +class 
UpdateDeploymentAction(DeploymentBaseAction): + """Action to update an existing deployment.""" + + deployment_id: UUID + modifier: NewDeploymentModifier + + @override + def entity_id(self) -> Optional[str]: + return str(self.deployment_id) + + @override + @classmethod + def operation_type(cls) -> str: + return "update" + + +@dataclass +class UpdateDeploymentActionResult(BaseActionResult): + data: ModelDeploymentData + + @override + def entity_id(self) -> Optional[str]: + return str(self.data.id) diff --git a/src/ai/backend/manager/services/deployment/processors.py b/src/ai/backend/manager/services/deployment/processors.py index 6d82cfd1163..ea8fea45545 100644 --- a/src/ai/backend/manager/services/deployment/processors.py +++ b/src/ai/backend/manager/services/deployment/processors.py @@ -1,36 +1,302 @@ """Deployment service processors for GraphQL API.""" -from typing import TYPE_CHECKING, override +from typing import Protocol, override from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.processor import ActionProcessor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec +from ai.backend.manager.services.deployment.actions.access_token.create_access_token import ( + CreateAccessTokenAction, + CreateAccessTokenActionResult, +) +from ai.backend.manager.services.deployment.actions.access_token.list_access_tokens import ( + ListAccessTokensAction, + ListAccessTokensActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.batch_load_auto_scaling_rules import ( + BatchLoadAutoScalingRulesAction, + BatchLoadAutoScalingRulesActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.create_auto_scaling_rule import ( + CreateAutoScalingRuleAction, + CreateAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.delete_auto_scaling_rule import ( + DeleteAutoScalingRuleAction, + 
DeleteAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.update_auto_scaling_rule import ( + UpdateAutoScalingRuleAction, + UpdateAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.batch_load_deployments import ( + BatchLoadDeploymentsAction, + BatchLoadDeploymentsActionResult, +) +from ai.backend.manager.services.deployment.actions.batch_load_replicas_by_revision_ids import ( + BatchLoadReplicasByRevisionIdsAction, + BatchLoadReplicasByRevisionIdsActionResult, +) from ai.backend.manager.services.deployment.actions.create_deployment import ( CreateDeploymentAction, CreateDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.create_legacy_deployment import ( + CreateLegacyDeploymentAction, + CreateLegacyDeploymentActionResult, +) from ai.backend.manager.services.deployment.actions.destroy_deployment import ( DestroyDeploymentAction, DestroyDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.list_deployments import ( + ListDeploymentsAction, + ListDeploymentsActionResult, +) +from ai.backend.manager.services.deployment.actions.list_replicas import ( + ListReplicasAction, + ListReplicasActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.add_model_revision import ( + AddModelRevisionAction, + AddModelRevisionActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.batch_load_revisions import ( + BatchLoadRevisionsAction, + BatchLoadRevisionsActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.create_model_revision import ( + CreateModelRevisionAction, + CreateModelRevisionActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_deployment_id import ( + GetRevisionByDeploymentIdAction, + GetRevisionByDeploymentIdActionResult, +) +from 
ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_id import ( + GetRevisionByIdAction, + GetRevisionByIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_replica_id import ( + GetRevisionByReplicaIdAction, + GetRevisionByReplicaIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revisions_by_deployment_id import ( + GetRevisionsByDeploymentIdAction, + GetRevisionsByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.list_revisions import ( + ListRevisionsAction, + ListRevisionsActionResult, +) +from ai.backend.manager.services.deployment.actions.sync_replicas import ( + SyncReplicaAction, + SyncReplicaActionResult, +) +from ai.backend.manager.services.deployment.actions.update_deployment import ( + UpdateDeploymentAction, + UpdateDeploymentActionResult, +) + + +class DeploymentServiceProtocol(Protocol): + async def create_deployment( + self, action: CreateDeploymentAction + ) -> CreateDeploymentActionResult: ... + + async def create_legacy_deployment( + self, action: CreateLegacyDeploymentAction + ) -> CreateLegacyDeploymentActionResult: ... + + async def update_deployment( + self, action: UpdateDeploymentAction + ) -> UpdateDeploymentActionResult: ... + + async def destroy_deployment( + self, action: DestroyDeploymentAction + ) -> DestroyDeploymentActionResult: ... + + async def batch_load_deployments( + self, action: BatchLoadDeploymentsAction + ) -> BatchLoadDeploymentsActionResult: ... + + async def create_auto_scaling_rule( + self, action: CreateAutoScalingRuleAction + ) -> CreateAutoScalingRuleActionResult: ... + + async def update_auto_scaling_rule( + self, action: UpdateAutoScalingRuleAction + ) -> UpdateAutoScalingRuleActionResult: ... + + async def delete_auto_scaling_rule( + self, action: DeleteAutoScalingRuleAction + ) -> DeleteAutoScalingRuleActionResult: ... 
+ + async def create_access_token( + self, action: CreateAccessTokenAction + ) -> CreateAccessTokenActionResult: ... + + async def sync_replicas(self, action: SyncReplicaAction) -> SyncReplicaActionResult: ... + + async def add_model_revision( + self, action: AddModelRevisionAction + ) -> AddModelRevisionActionResult: ... + + async def batch_load_auto_scaling_rules( + self, action: BatchLoadAutoScalingRulesAction + ) -> BatchLoadAutoScalingRulesActionResult: ... + + async def get_revision_by_deployment_id( + self, action: GetRevisionByDeploymentIdAction + ) -> GetRevisionByDeploymentIdActionResult: ... + + async def get_revision_by_replica_id( + self, action: GetRevisionByReplicaIdAction + ) -> GetRevisionByReplicaIdActionResult: ... + + async def get_revision_by_id( + self, action: GetRevisionByIdAction + ) -> GetRevisionByIdActionResult: ... + + async def get_revisions_by_deployment_id( + self, action: GetRevisionsByDeploymentIdAction + ) -> GetRevisionsByDeploymentIdActionResult: ... + + async def batch_load_replicas_by_revision_ids( + self, action: BatchLoadReplicasByRevisionIdsAction + ) -> BatchLoadReplicasByRevisionIdsActionResult: ... + + async def batch_load_revisions( + self, action: BatchLoadRevisionsAction + ) -> BatchLoadRevisionsActionResult: ... + + async def list_replicas(self, action: ListReplicasAction) -> ListReplicasActionResult: ... + async def list_revisions(self, action: ListRevisionsAction) -> ListRevisionsActionResult: ... + + async def create_model_revision( + self, action: CreateModelRevisionAction + ) -> CreateModelRevisionActionResult: ... -if TYPE_CHECKING: - from ai.backend.manager.services.deployment.service import DeploymentService + async def list_access_tokens( + self, action: ListAccessTokensAction + ) -> ListAccessTokensActionResult: ... 
class DeploymentProcessors(AbstractProcessorPackage): """Processors for deployment operations.""" create_deployment: ActionProcessor[CreateDeploymentAction, CreateDeploymentActionResult] + update_deployment: ActionProcessor[UpdateDeploymentAction, UpdateDeploymentActionResult] destroy_deployment: ActionProcessor[DestroyDeploymentAction, DestroyDeploymentActionResult] + create_legacy_deployment: ActionProcessor[ + CreateLegacyDeploymentAction, CreateLegacyDeploymentActionResult + ] + create_auto_scaling_rule: ActionProcessor[ + CreateAutoScalingRuleAction, CreateAutoScalingRuleActionResult + ] + update_auto_scaling_rule: ActionProcessor[ + UpdateAutoScalingRuleAction, UpdateAutoScalingRuleActionResult + ] + delete_auto_scaling_rule: ActionProcessor[ + DeleteAutoScalingRuleAction, DeleteAutoScalingRuleActionResult + ] + create_access_token: ActionProcessor[CreateAccessTokenAction, CreateAccessTokenActionResult] + list_access_tokens: ActionProcessor[ListAccessTokensAction, ListAccessTokensActionResult] + sync_replicas: ActionProcessor[SyncReplicaAction, SyncReplicaActionResult] + add_model_revision: ActionProcessor[AddModelRevisionAction, AddModelRevisionActionResult] + batch_load_auto_scaling_rules: ActionProcessor[ + BatchLoadAutoScalingRulesAction, BatchLoadAutoScalingRulesActionResult + ] + get_revision_by_id: ActionProcessor[GetRevisionByIdAction, GetRevisionByIdActionResult] + batch_load_revisions: ActionProcessor[BatchLoadRevisionsAction, BatchLoadRevisionsActionResult] + get_revision_by_deployment_id: ActionProcessor[ + GetRevisionByDeploymentIdAction, GetRevisionByDeploymentIdActionResult + ] + get_revision_by_replica_id: ActionProcessor[ + GetRevisionByReplicaIdAction, GetRevisionByReplicaIdActionResult + ] + list_deployments: ActionProcessor[ListDeploymentsAction, ListDeploymentsActionResult] + batch_load_deployments: ActionProcessor[ + BatchLoadDeploymentsAction, BatchLoadDeploymentsActionResult + ] + get_revisions_by_deployment_id: ActionProcessor[ + 
GetRevisionsByDeploymentIdAction, GetRevisionsByDeploymentIdActionResult + ] + batch_load_replicas_by_revision_ids: ActionProcessor[ + BatchLoadReplicasByRevisionIdsAction, BatchLoadReplicasByRevisionIdsActionResult + ] + list_replicas: ActionProcessor[ListReplicasAction, ListReplicasActionResult] + list_revisions: ActionProcessor[ListRevisionsAction, ListRevisionsActionResult] + create_model_revision: ActionProcessor[ + CreateModelRevisionAction, CreateModelRevisionActionResult + ] - def __init__(self, service: "DeploymentService", action_monitors: list[ActionMonitor]) -> None: - self.create_deployment = ActionProcessor(service.create, action_monitors) - self.destroy_deployment = ActionProcessor(service.destroy, action_monitors) + def __init__( + self, service: DeploymentServiceProtocol, action_monitors: list[ActionMonitor] + ) -> None: + self.create_auto_scaling_rule = ActionProcessor( + service.create_auto_scaling_rule, action_monitors + ) + self.update_auto_scaling_rule = ActionProcessor( + service.update_auto_scaling_rule, action_monitors + ) + self.delete_auto_scaling_rule = ActionProcessor( + service.delete_auto_scaling_rule, action_monitors + ) + self.batch_load_deployments = ActionProcessor( + service.batch_load_deployments, action_monitors + ) + self.create_deployment = ActionProcessor(service.create_deployment, action_monitors) + self.destroy_deployment = ActionProcessor(service.destroy_deployment, action_monitors) + self.update_deployment = ActionProcessor(service.update_deployment, action_monitors) + self.create_legacy_deployment = ActionProcessor( + service.create_legacy_deployment, action_monitors + ) + self.create_access_token = ActionProcessor(service.create_access_token, action_monitors) + self.list_access_tokens = ActionProcessor(service.list_access_tokens, action_monitors) + self.sync_replicas = ActionProcessor(service.sync_replicas, action_monitors) + self.add_model_revision = ActionProcessor(service.add_model_revision, action_monitors) + 
self.batch_load_auto_scaling_rules = ActionProcessor( + service.batch_load_auto_scaling_rules, action_monitors + ) + self.get_revision_by_replica_id = ActionProcessor( + service.get_revision_by_replica_id, action_monitors + ) + self.get_revision_by_id = ActionProcessor(service.get_revision_by_id, action_monitors) + self.get_revisions_by_deployment_id = ActionProcessor( + service.get_revisions_by_deployment_id, action_monitors + ) + self.batch_load_replicas_by_revision_ids = ActionProcessor( + service.batch_load_replicas_by_revision_ids, action_monitors + ) + self.list_replicas = ActionProcessor(service.list_replicas, action_monitors) + self.list_revisions = ActionProcessor(service.list_revisions, action_monitors) + self.create_model_revision = ActionProcessor(service.create_model_revision, action_monitors) + self.batch_load_revisions = ActionProcessor(service.batch_load_revisions, action_monitors) @override def supported_actions(self) -> list[ActionSpec]: return [ CreateDeploymentAction.spec(), DestroyDeploymentAction.spec(), + CreateAutoScalingRuleAction.spec(), + UpdateAutoScalingRuleAction.spec(), + UpdateDeploymentAction.spec(), + DeleteAutoScalingRuleAction.spec(), + CreateAccessTokenAction.spec(), + SyncReplicaAction.spec(), + AddModelRevisionAction.spec(), + BatchLoadAutoScalingRulesAction.spec(), + GetRevisionByDeploymentIdAction.spec(), + GetRevisionByReplicaIdAction.spec(), + GetRevisionByIdAction.spec(), + GetRevisionsByDeploymentIdAction.spec(), + ListRevisionsAction.spec(), + ListReplicasAction.spec(), + CreateLegacyDeploymentAction.spec(), + CreateModelRevisionAction.spec(), + BatchLoadRevisionsAction.spec(), + BatchLoadDeploymentsAction.spec(), + ListAccessTokensAction.spec(), + BatchLoadReplicasByRevisionIdsAction.spec(), ] diff --git a/src/ai/backend/manager/services/deployment/service.py b/src/ai/backend/manager/services/deployment/service.py index ee9f2745b35..8356a47a707 100644 --- a/src/ai/backend/manager/services/deployment/service.py +++ 
b/src/ai/backend/manager/services/deployment/service.py @@ -1,16 +1,123 @@ """Deployment service for managing model deployments.""" import logging +from datetime import datetime, timedelta +from decimal import Decimal +from uuid import uuid4 +from ai.backend.common.data.model_deployment.types import ( + DeploymentStrategy, + ModelDeploymentStatus, +) +from ai.backend.common.types import ( + AutoScalingMetricSource, + ClusterMode, + ResourceSlot, + RuntimeVariant, +) from ai.backend.logging.utils import BraceStyleAdapter +from ai.backend.manager.data.deployment.types import ( + ClusterConfigData, + DeploymentNetworkSpec, + ExtraVFolderMountData, + ModelDeploymentAccessTokenData, + ModelDeploymentAutoScalingRuleData, + ModelDeploymentData, + ModelDeploymentMetadataInfo, + ModelMountConfigData, + ModelRevisionData, + ModelRuntimeConfigData, + ReplicaStateData, + ResourceConfigData, +) +from ai.backend.manager.services.deployment.actions.access_token.create_access_token import ( + CreateAccessTokenAction, + CreateAccessTokenActionResult, +) +from ai.backend.manager.services.deployment.actions.access_token.list_access_tokens import ( + ListAccessTokensAction, + ListAccessTokensActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.batch_load_auto_scaling_rules import ( + BatchLoadAutoScalingRulesAction, + BatchLoadAutoScalingRulesActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.create_auto_scaling_rule import ( + CreateAutoScalingRuleAction, + CreateAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.delete_auto_scaling_rule import ( + DeleteAutoScalingRuleAction, + DeleteAutoScalingRuleActionResult, +) +from ai.backend.manager.services.deployment.actions.auto_scaling_rule.update_auto_scaling_rule import ( + UpdateAutoScalingRuleAction, + UpdateAutoScalingRuleActionResult, +) +from 
ai.backend.manager.services.deployment.actions.batch_load_deployments import ( + BatchLoadDeploymentsAction, + BatchLoadDeploymentsActionResult, +) +from ai.backend.manager.services.deployment.actions.batch_load_replicas_by_revision_ids import ( + BatchLoadReplicasByRevisionIdsAction, + BatchLoadReplicasByRevisionIdsActionResult, +) from ai.backend.manager.services.deployment.actions.create_deployment import ( CreateDeploymentAction, CreateDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.create_legacy_deployment import ( + CreateLegacyDeploymentAction, + CreateLegacyDeploymentActionResult, +) from ai.backend.manager.services.deployment.actions.destroy_deployment import ( DestroyDeploymentAction, DestroyDeploymentActionResult, ) +from ai.backend.manager.services.deployment.actions.list_replicas import ( + ListReplicasAction, + ListReplicasActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.add_model_revision import ( + AddModelRevisionAction, + AddModelRevisionActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.batch_load_revisions import ( + BatchLoadRevisionsAction, + BatchLoadRevisionsActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.create_model_revision import ( + CreateModelRevisionAction, + CreateModelRevisionActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_deployment_id import ( + GetRevisionByDeploymentIdAction, + GetRevisionByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_id import ( + GetRevisionByIdAction, + GetRevisionByIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revision_by_replica_id import ( + GetRevisionByReplicaIdAction, + GetRevisionByReplicaIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.get_revisions_by_deployment_id 
import ( + GetRevisionsByDeploymentIdAction, + GetRevisionsByDeploymentIdActionResult, +) +from ai.backend.manager.services.deployment.actions.model_revision.list_revisions import ( + ListRevisionsAction, + ListRevisionsActionResult, +) +from ai.backend.manager.services.deployment.actions.sync_replicas import ( + SyncReplicaAction, + SyncReplicaActionResult, +) +from ai.backend.manager.services.deployment.actions.update_deployment import ( + UpdateDeploymentAction, + UpdateDeploymentActionResult, +) from ai.backend.manager.sokovan.deployment import DeploymentController from ai.backend.manager.sokovan.deployment.types import DeploymentLifecycleType @@ -26,23 +133,96 @@ def __init__(self, deployment_controller: DeploymentController) -> None: """Initialize deployment service with controller.""" self._deployment_controller = deployment_controller - async def create(self, action: CreateDeploymentAction) -> CreateDeploymentActionResult: - """Create a new deployment. + async def create_deployment( + self, action: CreateDeploymentAction + ) -> CreateDeploymentActionResult: + return CreateDeploymentActionResult( + data=ModelDeploymentData( + id=uuid4(), + metadata=ModelDeploymentMetadataInfo( + name="test-deployment", + status=ModelDeploymentStatus.READY, + tags=["tag1", "tag2"], + project_id=uuid4(), + domain_name="default", + created_at=datetime.now(), + updated_at=datetime.now(), + ), + network_access=DeploymentNetworkSpec( + open_to_public=True, + url="http://example.com", + preferred_domain_name="example.com", + access_token_ids=[uuid4()], + ), + revision_history_ids=[uuid4(), uuid4()], + revision=mock_revision_data_1, + scaling_rule_ids=[uuid4(), uuid4()], + replica_state=ReplicaStateData( + desired_replica_count=3, + replica_ids=[uuid4(), uuid4(), uuid4()], + ), + default_deployment_strategy=DeploymentStrategy.ROLLING, + created_user_id=uuid4(), + ) + ) + + async def create_legacy_deployment( + self, action: CreateLegacyDeploymentAction + ) -> 
CreateLegacyDeploymentActionResult: + """Create a new legacy deployment(Model Serving). Args: - action: Create deployment action containing the creator specification + action: Create legacy deployment action containing the creator specification Returns: - CreateDeploymentActionResult: Result containing the created deployment info + CreateLegacyDeploymentActionResult: Result containing the created deployment info """ log.info("Creating deployment with name: {}", action.creator.name) deployment_info = await self._deployment_controller.create_deployment(action.creator) await self._deployment_controller.mark_lifecycle_needed( DeploymentLifecycleType.CHECK_PENDING ) - return CreateDeploymentActionResult(data=deployment_info) + return CreateLegacyDeploymentActionResult(data=deployment_info) - async def destroy(self, action: DestroyDeploymentAction) -> DestroyDeploymentActionResult: + async def update_deployment( + self, action: UpdateDeploymentAction + ) -> UpdateDeploymentActionResult: + await self._deployment_controller.mark_lifecycle_needed( + DeploymentLifecycleType.CHECK_REPLICA + ) + return UpdateDeploymentActionResult( + data=ModelDeploymentData( + id=action.deployment_id, + metadata=ModelDeploymentMetadataInfo( + name="test-deployment", + status=ModelDeploymentStatus.READY, + tags=["tag1", "tag2"], + project_id=uuid4(), + domain_name="default", + created_at=datetime.now(), + updated_at=datetime.now(), + ), + network_access=DeploymentNetworkSpec( + open_to_public=True, + url="http://example.com", + preferred_domain_name="example.com", + access_token_ids=[uuid4()], + ), + revision_history_ids=[uuid4(), uuid4()], + revision=mock_revision_data_1, + scaling_rule_ids=[uuid4(), uuid4()], + replica_state=ReplicaStateData( + desired_replica_count=3, + replica_ids=[uuid4(), uuid4(), uuid4()], + ), + default_deployment_strategy=DeploymentStrategy.ROLLING, + created_user_id=uuid4(), + ) + ) + + async def destroy_deployment( + self, action: DestroyDeploymentAction + ) -> 
DestroyDeploymentActionResult: """Destroy an existing deployment. Args: @@ -55,3 +235,265 @@ async def destroy(self, action: DestroyDeploymentAction) -> DestroyDeploymentAct success = await self._deployment_controller.destroy_deployment(action.endpoint_id) await self._deployment_controller.mark_lifecycle_needed(DeploymentLifecycleType.DESTROYING) return DestroyDeploymentActionResult(success=success) + + async def batch_load_deployments( + self, action: BatchLoadDeploymentsAction + ) -> BatchLoadDeploymentsActionResult: + return BatchLoadDeploymentsActionResult( + data=[ + ModelDeploymentData( + id=deployment_id, + metadata=ModelDeploymentMetadataInfo( + name=f"test-deployment-{i}", + status=ModelDeploymentStatus.READY, + tags=["tag1", "tag2"], + project_id=uuid4(), + domain_name="default", + created_at=datetime.now(), + updated_at=datetime.now(), + ), + network_access=DeploymentNetworkSpec( + open_to_public=True, + url="http://example.com", + preferred_domain_name="example.com", + access_token_ids=[uuid4()], + ), + revision_history_ids=[uuid4(), uuid4()], + revision=mock_revision_data_1, + scaling_rule_ids=[uuid4(), uuid4()], + replica_state=ReplicaStateData( + desired_replica_count=3, + replica_ids=[uuid4(), uuid4(), uuid4()], + ), + default_deployment_strategy=DeploymentStrategy.ROLLING, + created_user_id=uuid4(), + ) + for i, deployment_id in enumerate(action.deployment_ids) + ] + ) + + async def create_auto_scaling_rule( + self, action: CreateAutoScalingRuleAction + ) -> CreateAutoScalingRuleActionResult: + return CreateAutoScalingRuleActionResult( + data=ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=action.creator.model_deployment_id, + metric_source=action.creator.metric_source, + metric_name=action.creator.metric_name, + min_threshold=action.creator.min_threshold, + max_threshold=action.creator.max_threshold, + step_size=action.creator.step_size, + time_window=action.creator.time_window, + min_replicas=action.creator.min_replicas, + 
max_replicas=action.creator.max_replicas, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ) + ) + + async def update_auto_scaling_rule( + self, action: UpdateAutoScalingRuleAction + ) -> UpdateAutoScalingRuleActionResult: + return UpdateAutoScalingRuleActionResult( + data=ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=uuid4(), + metric_source=AutoScalingMetricSource.KERNEL, + metric_name="test-metric", + min_threshold=Decimal("0.5"), + max_threshold=Decimal("21.0"), + step_size=1, + time_window=60, + min_replicas=1, + max_replicas=10, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ) + ) + + async def delete_auto_scaling_rule( + self, action: DeleteAutoScalingRuleAction + ) -> DeleteAutoScalingRuleActionResult: + return DeleteAutoScalingRuleActionResult(success=True) + + async def create_access_token( + self, action: CreateAccessTokenAction + ) -> CreateAccessTokenActionResult: + return CreateAccessTokenActionResult( + data=ModelDeploymentAccessTokenData( + id=uuid4(), + token="test_token", + valid_until=datetime.now() + timedelta(hours=1), + created_at=datetime.now(), + ) + ) + + async def list_access_tokens( + self, action: ListAccessTokensAction + ) -> ListAccessTokensActionResult: + tokens = [] + for i in range(5): + tokens.append( + ModelDeploymentAccessTokenData( + id=uuid4(), + token=f"test_token_{i}", + valid_until=datetime.now() + timedelta(hours=24 * (i + 1)), + created_at=datetime.now() - timedelta(hours=i), + ) + ) + return ListAccessTokensActionResult( + data=tokens, + total_count=len(tokens), + ) + + async def sync_replicas(self, action: SyncReplicaAction) -> SyncReplicaActionResult: + return SyncReplicaActionResult(success=True) + + async def add_model_revision( + self, action: AddModelRevisionAction + ) -> AddModelRevisionActionResult: + return AddModelRevisionActionResult(revision=mock_revision_data_2) + + async def batch_load_auto_scaling_rules( + self, action: 
BatchLoadAutoScalingRulesAction + ) -> BatchLoadAutoScalingRulesActionResult: + return BatchLoadAutoScalingRulesActionResult( + data=[ + ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=uuid4(), + metric_source=AutoScalingMetricSource.KERNEL, + metric_name="test-metric", + min_threshold=Decimal("0.5"), + max_threshold=Decimal("21.0"), + step_size=1, + time_window=60, + min_replicas=1, + max_replicas=10, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ), + ModelDeploymentAutoScalingRuleData( + id=uuid4(), + model_deployment_id=uuid4(), + metric_source=AutoScalingMetricSource.KERNEL, + metric_name="test-metric", + min_threshold=Decimal("0.0"), + max_threshold=Decimal("10.0"), + step_size=2, + time_window=200, + min_replicas=1, + max_replicas=5, + created_at=datetime.now(), + last_triggered_at=datetime.now(), + ), + ] + ) + + async def list_replicas(self, action: ListReplicasAction) -> ListReplicasActionResult: + return ListReplicasActionResult( + data=[], + total_count=0, + ) + + async def batch_load_revisions( + self, action: BatchLoadRevisionsAction + ) -> BatchLoadRevisionsActionResult: + return BatchLoadRevisionsActionResult(data=[mock_revision_data_1, mock_revision_data_2]) + + async def list_revisions(self, action: ListRevisionsAction) -> ListRevisionsActionResult: + return ListRevisionsActionResult(data=[], total_count=0) + + async def get_revision_by_deployment_id( + self, action: GetRevisionByDeploymentIdAction + ) -> GetRevisionByDeploymentIdActionResult: + return GetRevisionByDeploymentIdActionResult(data=mock_revision_data_1) + + async def get_revision_by_id( + self, action: GetRevisionByIdAction + ) -> GetRevisionByIdActionResult: + return GetRevisionByIdActionResult(data=mock_revision_data_1) + + async def get_revision_by_replica_id( + self, action: GetRevisionByReplicaIdAction + ) -> GetRevisionByReplicaIdActionResult: + return GetRevisionByReplicaIdActionResult(data=mock_revision_data_1) + + async def 
get_revisions_by_deployment_id( + self, action: GetRevisionsByDeploymentIdAction + ) -> GetRevisionsByDeploymentIdActionResult: + # For now, return mock revision data list + return GetRevisionsByDeploymentIdActionResult( + data=[mock_revision_data_1, mock_revision_data_2] + ) + + async def batch_load_replicas_by_revision_ids( + self, action: BatchLoadReplicasByRevisionIdsAction + ) -> BatchLoadReplicasByRevisionIdsActionResult: + # For now, return empty replica list + return BatchLoadReplicasByRevisionIdsActionResult(data={}) + + async def create_model_revision( + self, action: CreateModelRevisionAction + ) -> CreateModelRevisionActionResult: + return CreateModelRevisionActionResult(revision=mock_revision_data_2) + + +mock_revision_data_1 = ModelRevisionData( + id=uuid4(), + name="test-revision", + cluster_config=ClusterConfigData( + mode=ClusterMode.SINGLE_NODE, + size=1, + ), + resource_config=ResourceConfigData( + resource_group_name="default", + resource_slot=ResourceSlot.from_json({"cpu": 1, "memory": 1024}), + ), + model_mount_config=ModelMountConfigData( + vfolder_id=uuid4(), + mount_destination="/model", + definition_path="model-definition.yaml", + ), + model_runtime_config=ModelRuntimeConfigData( + runtime_variant=RuntimeVariant.VLLM, + inference_runtime_config={"tp_size": 2, "max_length": 1024}, + ), + extra_vfolder_mounts=[ + ExtraVFolderMountData( + vfolder_id=uuid4(), + mount_destination="/var", + ), + ExtraVFolderMountData( + vfolder_id=uuid4(), + mount_destination="/example", + ), + ], + image_id=uuid4(), + created_at=datetime.now(), +) + +mock_revision_data_2 = ModelRevisionData( + id=uuid4(), + name="test-revision-2", + cluster_config=ClusterConfigData( + mode=ClusterMode.MULTI_NODE, + size=1, + ), + resource_config=ResourceConfigData( + resource_group_name="default", + resource_slot=ResourceSlot.from_json({"cpu": 1, "memory": 1024}), + ), + model_mount_config=ModelMountConfigData( + vfolder_id=uuid4(), + mount_destination="/model", + 
definition_path="model-definition.yaml", + ), + model_runtime_config=ModelRuntimeConfigData( + runtime_variant=RuntimeVariant.NIM, + inference_runtime_config={"tp_size": 2, "max_length": 1024}, + ), + image_id=uuid4(), + created_at=datetime.now(), +)