Skip to content

Commit 3b2f877

Browse files
HyeockJinKimclaude
andauthored
feat(BA-2791): Implement scaling group filtering with injectable rules (#6424)
Co-authored-by: Claude <[email protected]>
1 parent 4027343 commit 3b2f877

File tree

16 files changed

+839
-415
lines changed

16 files changed

+839
-415
lines changed

changes/6424.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement scaling group filtering with injectable rules for session scheduling. The new ScalingGroupFilter applies configurable filter rules (public/private access, session type support) to determine eligible scaling groups before session creation, replacing the previous validation-only approach. This enables more flexible and extensible scaling group selection logic.

src/ai/backend/manager/api/auth.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,6 @@ async def _authenticate_via_jwt(
531531

532532
try:
533533
# 1. Decode token without verification to extract access_key
534-
log.info("jwt token: {}", jwt_token)
535534
unverified_payload = pyjwt.decode(
536535
jwt_token,
537536
options={"verify_signature": False},
@@ -727,8 +726,6 @@ async def auth_middleware(request: web.Request, handler) -> web.StreamResponse:
727726
# Detect authentication method and route to appropriate flow
728727
jwt_token = request.headers.get("X-BackendAI-Token")
729728
auth_header = request.headers.get("Authorization")
730-
log.info("AUTH.MIDDLEWARE(path:{}, headers:{})", request.path, dict(request.headers))
731-
732729
if jwt_token:
733730
# JWT authentication flow (GraphQL Federation)
734731
await _authenticate_via_jwt(request, root_ctx, jwt_token)

src/ai/backend/manager/errors/resource.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,16 @@ def error_code(cls) -> ErrorCode:
354354
operation=ErrorOperation.READ,
355355
error_detail=ErrorDetail.INVALID_PARAMETERS,
356356
)
357+
358+
359+
class NoAvailableScalingGroup(BackendAIError, web.HTTPBadRequest):
360+
error_type = "https://api.backend.ai/probs/no-available-scaling-group"
361+
error_title = "No scaling groups available for this session."
362+
363+
@classmethod
364+
def error_code(cls) -> ErrorCode:
365+
return ErrorCode(
366+
domain=ErrorDomain.SCALING_GROUP,
367+
operation=ErrorOperation.ACCESS,
368+
error_detail=ErrorDetail.NOT_FOUND,
369+
)

src/ai/backend/manager/sokovan/scheduling_controller/scheduling_controller.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@
4343
ClusterValidationRule,
4444
ContainerLimitRule,
4545
MountNameValidationRule,
46-
ScalingGroupAccessRule,
46+
PublicPrivateFilterRule,
47+
ScalingGroupFilter,
4748
ServicePortRule,
48-
SessionTypeRule,
49+
SessionTypeFilterRule,
4950
SessionValidator,
5051
)
5152

@@ -76,6 +77,7 @@ class SchedulingController:
7677

7778
# Services
7879
_scaling_group_resolver: ScalingGroupResolver
80+
_scaling_group_filter: ScalingGroupFilter
7981
_validator: SessionValidator
8082
_preparer: SessionPreparer
8183
_resource_calculator: ResourceCalculator
@@ -98,11 +100,16 @@ def __init__(self, args: SchedulingControllerArgs) -> None:
98100
# Initialize services
99101
self._scaling_group_resolver = ScalingGroupResolver()
100102

103+
# Initialize scaling group filter with rules
104+
filter_rules = [
105+
PublicPrivateFilterRule(),
106+
SessionTypeFilterRule(),
107+
]
108+
self._scaling_group_filter = ScalingGroupFilter(filter_rules)
109+
101110
# Initialize validator with rules
102111
validator_rules = [
103112
ContainerLimitRule(),
104-
ScalingGroupAccessRule(),
105-
SessionTypeRule(),
106113
ServicePortRule(),
107114
ClusterValidationRule(),
108115
MountNameValidationRule(),
@@ -196,10 +203,22 @@ async def enqueue_session(
196203
allowed_vfolder_types,
197204
)
198205

199-
# Phase 3: Validate
206+
# Phase 3: Filter and validate
200207
with self._metric_observer.measure_phase(
201208
"scheduling_controller", validated_scaling_group.name, "validation"
202209
):
210+
# Filter scaling groups based on session requirements
211+
# This will raise NoAvailableScalingGroup if filtering fails
212+
filter_result = self._scaling_group_filter.filter(
213+
session_spec,
214+
creation_context.allowed_scaling_groups,
215+
)
216+
217+
# Update context with filtered scaling groups for remaining validation
218+
creation_context.allowed_scaling_groups = filter_result.allowed_groups
219+
session_spec.scaling_group = filter_result.selected_scaling_group
220+
221+
# Run remaining validation rules
203222
self._validator.validate(
204223
session_spec,
205224
creation_context,

src/ai/backend/manager/sokovan/scheduling_controller/validators/__init__.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,35 @@
11
"""Validators for session creation."""
22

3-
from .base import SessionValidatorRule
3+
from .base import (
4+
ScalingGroupFilterResult,
5+
ScalingGroupFilterRule,
6+
ScalingGroupFilterRuleResult,
7+
SessionValidatorRule,
8+
)
49
from .cluster import ClusterValidationRule
510
from .mount import MountNameValidationRule
611
from .rules import (
712
ContainerLimitRule,
813
ResourceLimitRule,
9-
ScalingGroupAccessRule,
1014
ServicePortRule,
11-
SessionTypeRule,
15+
)
16+
from .scaling_group_filter import (
17+
PublicPrivateFilterRule,
18+
ScalingGroupFilter,
19+
SessionTypeFilterRule,
1220
)
1321
from .validator import SessionValidator
1422

1523
__all__ = [
1624
"SessionValidator",
1725
"SessionValidatorRule",
1826
"ContainerLimitRule",
19-
"ScalingGroupAccessRule",
20-
"SessionTypeRule",
27+
"ScalingGroupFilter",
28+
"ScalingGroupFilterRule",
29+
"ScalingGroupFilterResult",
30+
"ScalingGroupFilterRuleResult",
31+
"PublicPrivateFilterRule",
32+
"SessionTypeFilterRule",
2133
"ServicePortRule",
2234
"ResourceLimitRule",
2335
"ClusterValidationRule",

src/ai/backend/manager/sokovan/scheduling_controller/validators/base.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
"""Base classes for session validation rules."""
1+
"""Base classes for session validation and filtering rules."""
2+
3+
from __future__ import annotations
24

35
from abc import ABC, abstractmethod
6+
from dataclasses import dataclass
47

58
from ai.backend.manager.repositories.scheduler.types.session_creation import (
69
AllowedScalingGroup,
@@ -9,6 +12,58 @@
912
)
1013

1114

15+
@dataclass
16+
class ScalingGroupFilterRuleResult:
17+
"""Result of a single scaling group filter rule."""
18+
19+
allowed_groups: list[AllowedScalingGroup]
20+
"""Scaling groups that passed this rule."""
21+
22+
rejected_groups: dict[str, str]
23+
"""Scaling groups that were rejected by this rule, mapped to rejection reason."""
24+
25+
26+
@dataclass
27+
class ScalingGroupFilterResult:
28+
"""Final result of scaling group filtering with selected group."""
29+
30+
allowed_groups: list[AllowedScalingGroup]
31+
"""Scaling groups that passed all filters."""
32+
33+
selected_scaling_group: str
34+
"""The selected scaling group name (either specified or auto-selected)."""
35+
36+
37+
class ScalingGroupFilterRule(ABC):
38+
"""
39+
Abstract base class for scaling group filter rules.
40+
Each rule filters scaling groups based on specific criteria.
41+
"""
42+
43+
@abstractmethod
44+
def name(self) -> str:
45+
"""Return the filter rule name."""
46+
raise NotImplementedError
47+
48+
@abstractmethod
49+
def filter(
50+
self,
51+
spec: SessionCreationSpec,
52+
allowed_groups: list[AllowedScalingGroup],
53+
) -> ScalingGroupFilterRuleResult:
54+
"""
55+
Filter scaling groups based on session creation specification.
56+
57+
Args:
58+
spec: Session creation specification
59+
allowed_groups: List of scaling groups to filter
60+
61+
Returns:
62+
ScalingGroupFilterRuleResult containing allowed and rejected groups with reasons
63+
"""
64+
raise NotImplementedError
65+
66+
1267
class SessionValidatorRule(ABC):
1368
"""
1469
Abstract base class for session validator rules.
@@ -25,15 +80,13 @@ def validate(
2580
self,
2681
spec: SessionCreationSpec,
2782
context: SessionCreationContext,
28-
allowed_groups: list[AllowedScalingGroup],
2983
) -> None:
3084
"""
3185
Validate a session creation specification.
3286
3387
Args:
3488
spec: Session creation specification
3589
context: Pre-fetched context with all required data
36-
allowed_groups: List of allowed scaling groups for the user
3790
3891
Raises:
3992
InvalidAPIParameters or QuotaExceeded: If validation fails

src/ai/backend/manager/sokovan/scheduling_controller/validators/cluster.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from ai.backend.manager.errors.api import InvalidAPIParameters
44
from ai.backend.manager.repositories.scheduler.types.session_creation import (
5-
AllowedScalingGroup,
65
SessionCreationContext,
76
SessionCreationSpec,
87
)
@@ -20,7 +19,6 @@ def validate(
2019
self,
2120
spec: SessionCreationSpec,
2221
context: SessionCreationContext,
23-
allowed_groups: list[AllowedScalingGroup],
2422
) -> None:
2523
"""Validate cluster configuration and kernel specifications."""
2624
# Check if kernel_specs exists

src/ai/backend/manager/sokovan/scheduling_controller/validators/mount.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from ai.backend.manager.errors.api import InvalidAPIParameters
66
from ai.backend.manager.models import verify_vfolder_name
77
from ai.backend.manager.repositories.scheduler.types.session_creation import (
8-
AllowedScalingGroup,
98
SessionCreationContext,
109
SessionCreationSpec,
1110
)
@@ -23,7 +22,6 @@ def validate(
2322
self,
2423
spec: SessionCreationSpec,
2524
context: SessionCreationContext,
26-
allowed_groups: list[AllowedScalingGroup],
2725
) -> None:
2826
"""Validate mount names if mount map is provided."""
2927
mount_map = spec.creation_spec.get("mount_map") or {}

src/ai/backend/manager/sokovan/scheduling_controller/validators/rules.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
from ai.backend.common.types import SlotName, SlotTypes
88
from ai.backend.manager.errors.api import InvalidAPIParameters
99
from ai.backend.manager.errors.kernel import QuotaExceeded
10-
from ai.backend.manager.models import PRIVATE_SESSION_TYPES
1110
from ai.backend.manager.repositories.scheduler.types.session_creation import (
12-
AllowedScalingGroup,
1311
SessionCreationContext,
1412
SessionCreationSpec,
1513
)
@@ -29,7 +27,6 @@ def validate(
2927
self,
3028
spec: SessionCreationSpec,
3129
context: SessionCreationContext,
32-
allowed_groups: list[AllowedScalingGroup],
3330
) -> None:
3431
max_containers = spec.resource_policy.get("max_containers_per_session", 1)
3532
if spec.cluster_size > int(max_containers):
@@ -38,68 +35,6 @@ def validate(
3835
)
3936

4037

41-
class ScalingGroupAccessRule(SessionValidatorRule):
42-
"""Validates that the scaling group is accessible."""
43-
44-
@override
45-
def name(self) -> str:
46-
return "scaling_group_access"
47-
48-
@override
49-
def validate(
50-
self,
51-
spec: SessionCreationSpec,
52-
context: SessionCreationContext,
53-
allowed_groups: list[AllowedScalingGroup],
54-
) -> None:
55-
if not spec.scaling_group:
56-
# Should have been resolved already
57-
return
58-
59-
public_sgroup_only = spec.session_type not in PRIVATE_SESSION_TYPES
60-
61-
# Find the scaling group in allowed list
62-
for sg in allowed_groups:
63-
if sg.name == spec.scaling_group:
64-
if public_sgroup_only and sg.is_private:
65-
raise InvalidAPIParameters(
66-
f"Scaling group {spec.scaling_group} is not allowed for {spec.session_type} sessions"
67-
)
68-
return
69-
70-
raise InvalidAPIParameters(f"Scaling group {spec.scaling_group} is not accessible")
71-
72-
73-
class SessionTypeRule(SessionValidatorRule):
74-
"""Validates session type compatibility with scaling group."""
75-
76-
@override
77-
def name(self) -> str:
78-
return "session_type"
79-
80-
@override
81-
def validate(
82-
self,
83-
spec: SessionCreationSpec,
84-
context: SessionCreationContext,
85-
allowed_groups: list[AllowedScalingGroup],
86-
) -> None:
87-
if spec.scaling_group is None:
88-
# Should have been resolved already
89-
return
90-
91-
for sg in allowed_groups:
92-
if sg.name == spec.scaling_group:
93-
allowed_session_types = sg.scheduler_opts.allowed_session_types
94-
if spec.session_type not in allowed_session_types:
95-
raise InvalidAPIParameters(
96-
f"Session type {spec.session_type} is not allowed in scaling group {sg.name}"
97-
)
98-
return
99-
100-
raise InvalidAPIParameters(f"Scaling group {spec.scaling_group} is not accessible")
101-
102-
10338
class ServicePortRule(SessionValidatorRule):
10439
"""Validates preopen ports against service ports."""
10540

@@ -112,7 +47,6 @@ def validate(
11247
self,
11348
spec: SessionCreationSpec,
11449
context: SessionCreationContext,
115-
allowed_groups: list[AllowedScalingGroup],
11650
) -> None:
11751
# Check preopen_ports from creation_config (applies to all kernels)
11852
creation_preopen_ports = spec.creation_spec.get("preopen_ports")
@@ -186,7 +120,6 @@ def validate(
186120
self,
187121
spec: SessionCreationSpec,
188122
context: SessionCreationContext,
189-
allowed_groups: list[AllowedScalingGroup],
190123
) -> None:
191124
# Note: This validation should ideally be done after resource calculation
192125
# For now, we'll validate what we can from the spec

0 commit comments

Comments
 (0)