class EnumList(t.Trafaret, Generic[T_enum]):
    """Trafaret that turns an iterable of raw items into a list of enum members.

    Every item is looked up in ``enum_cls`` by value, or by member name when
    ``use_name`` is true.
    """

    def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None:
        """Remember the target enum class and the lookup mode.

        :param enum_cls: Enum class whose members are produced.
        :param use_name: Look items up by member name instead of by value.
        """
        self.enum_cls = enum_cls
        self.use_name = use_name

    def check_and_return(self, value: Any) -> list[T_enum]:
        """Convert every item of ``value`` into a member of ``enum_cls``.

        :param value: Iterable of enum values (or member names when
            ``use_name`` is set).
        :returns: The matching enum members, in input order.

        Problems are reported through ``self._failure``: a ``TypeError`` or a
        bare string yields "cannot parse value into list"; an item that does
        not name a member yields "value is not a valid member of ...".
        """
        # A bare string is iterable, so without this guard it would be looked
        # up character by character: a scalar such as "on-creation-failure"
        # would be reported as an invalid member (or even accepted, for enums
        # with one-character values) instead of as "not a list".
        if isinstance(value, (str, bytes)):
            self._failure("cannot parse value into list", value=value)
        try:
            if self.use_name:
                return [self.enum_cls[val] for val in value]
            else:
                return [self.enum_cls(val) for val in value]
        except TypeError:
            self._failure("cannot parse value into list", value=value)
        except (KeyError, ValueError):
            self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value)
class SyncAgentResourceRequestModel(BaseModel):
    """Request parameters of the agent resource sync API."""

    # Accepted under either the "agent_id" or the "agent" key.
    agent_id: AgentId = Field(
        validation_alias=AliasChoices("agent_id", "agent"),
        description="Target agent id to sync resource.",
    )


@server_status_required(ALL_ALLOWED)
@auth_required
@pydantic_params_api_handler(SyncAgentResourceRequestModel)
async def sync_agent_resource(
    request: web.Request, params: SyncAgentResourceRequestModel
) -> web.Response:
    """Ask the registry to sync the resource of a single agent.

    :param request: Incoming request; the root context is read from its app.
    :param params: Validated parameters carrying the target agent id.
    :returns: An empty 204 response, or an empty 500 response when the
        registry's result for the agent is a ``MultiAgentError``.
    :raises BackendError: Re-raised after logging when the registry call fails.
    """
    root_ctx: RootContext = request.app["_root.context"]
    requester_access_key, owner_access_key = await get_access_key_scopes(request)

    agent_id = params.agent_id
    log.info(
        "SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id
    )

    try:
        result = await root_ctx.registry.sync_agent_resource(root_ctx.db, [agent_id])
    except BackendError:
        log.exception("SYNC_AGENT_RESOURCE: exception")
        raise

    agent_result = result.get(agent_id)
    if isinstance(agent_result, MultiAgentError):
        # The response carries no body, so record the error here; otherwise
        # the cause of the 500 would be lost.
        log.error("SYNC_AGENT_RESOURCE: sync failed (a:{}): {!r}", agent_id, agent_result)
        return web.Response(status=500)
    # Any other result (including a missing entry) is answered with 204.
    return web.Response(status=204)
from .api.exceptions import ObjectNotFound, ServerMisconfiguredError from .models.session import SessionStatus from .pglock import PgAdvisoryLock +from .types import DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS, AgentResourceSyncTrigger log = BraceStyleAdapter(logging.getLogger(__spec__.name)) @@ -295,6 +296,9 @@ "agent-selection-resource-priority", default=["cuda", "rocm", "tpu", "cpu", "mem"], ): t.List(t.String), + t.Key( + "agent-resource-sync-trigger", default=DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS + ): tx.EnumList(AgentResourceSyncTrigger), t.Key("importer-image", default="lablup/importer:manylinux2010"): t.String, t.Key("max-wsmsg-size", default=16 * (2**20)): t.ToInt, # default: 16 MiB tx.AliasedKey(["aiomonitor-termui-port", "aiomonitor-port"], default=48100): t.ToInt[ diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index d62e78bd6eb..320a1d037cd 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -129,7 +129,7 @@ ) from .config import LocalConfig, SharedConfig from .defs import DEFAULT_IMAGE_ARCH, DEFAULT_ROLE, INTRINSIC_SLOTS -from .exceptions import MultiAgentError, convert_to_status_data +from .exceptions import ErrorStatusInfo, MultiAgentError, convert_to_status_data from .models import ( AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, AGENT_RESOURCE_OCCUPYING_SESSION_STATUSES, @@ -182,7 +182,7 @@ reenter_txn_session, sql_json_merge, ) -from .types import UserScope +from .types import AgentResourceSyncTrigger, UserScope if TYPE_CHECKING: from sqlalchemy.engine.row import Row @@ -1694,6 +1694,10 @@ async def _create_kernels_in_one_agent( is_local = image_info["is_local"] resource_policy: KeyPairResourcePolicyRow = image_info["resource_policy"] auto_pull = image_info["auto_pull"] + agent_resource_sync_trigger = cast( + list[AgentResourceSyncTrigger], + self.local_config["manager"]["agent-resource-sync-trigger"], + ) assert agent_alloc_ctx.agent_id is not None assert scheduled_session.id is not 
None @@ -1790,6 +1794,9 @@ async def _update_kernel() -> None: ex = e err_info = convert_to_status_data(ex, self.debug) + def _is_insufficient_resource_err(err_info: ErrorStatusInfo) -> bool: + return err_info["error"]["name"] == "InsufficientResource" + # The agent has already cancelled or issued the destruction lifecycle event # for this batch of kernels. for binding in items: @@ -1821,6 +1828,18 @@ async def _update_failure() -> None: await db_sess.execute(query) await execute_with_retry(_update_failure) + if ( + AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger + and _is_insufficient_resource_err(err_info) + ): + await self.sync_agent_resource( + self.db, + [ + binding.agent_alloc_ctx.agent_id + for binding in items + if binding.agent_alloc_ctx.agent_id is not None + ], + ) raise async def create_cluster_ssh_keypair(self) -> ClusterSSHKeyPair: diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index 6c3768f3897..1c3d49da8aa 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -20,6 +20,7 @@ Sequence, Tuple, Union, + cast, ) import aiotools @@ -94,6 +95,7 @@ ) from ..models.utils import ExtendedAsyncSAEngine as SAEngine from ..models.utils import execute_with_retry, sql_json_increment, sql_json_merge +from ..types import AgentResourceSyncTrigger from .predicates import ( check_concurrency, check_dependencies, @@ -265,6 +267,10 @@ async def schedule( log.debug("schedule(): triggered") manager_id = self.local_config["manager"]["id"] redis_key = f"manager.{manager_id}.schedule" + agent_resource_sync_trigger = cast( + list[AgentResourceSyncTrigger], + self.local_config["manager"]["agent-resource-sync-trigger"], + ) def _pipeline(r: Redis) -> RedisPipeline: pipe = r.pipeline() @@ -293,10 +299,6 @@ def _pipeline(r: Redis) -> RedisPipeline: # as its individual steps are composed of many short-lived transactions. 
async with self.lock_factory(LockID.LOCKID_SCHEDULE, 60): async with self.db.begin_readonly_session() as db_sess: - # query = ( - # sa.select(ScalingGroupRow) - # .join(ScalingGroupRow.agents.and_(AgentRow.status == AgentStatus.ALIVE)) - # ) query = ( sa.select(AgentRow.scaling_group) .where(AgentRow.status == AgentStatus.ALIVE) @@ -304,9 +306,10 @@ def _pipeline(r: Redis) -> RedisPipeline: ) result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] + for sgroup_name in schedulable_scaling_groups: try: - await self._schedule_in_sgroup( + kernel_agent_bindings = await self._schedule_in_sgroup( sched_ctx, sgroup_name, ) @@ -320,6 +323,17 @@ def _pipeline(r: Redis) -> RedisPipeline: ) except Exception as e: log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + else: + if ( + AgentResourceSyncTrigger.AFTER_SCHEDULING in agent_resource_sync_trigger + and kernel_agent_bindings + ): + selected_agent_ids = [ + binding.agent_alloc_ctx.agent_id + for binding in kernel_agent_bindings + if binding.agent_alloc_ctx.agent_id is not None + ] + await self.registry.sync_agent_resource(self.db, selected_agent_ids) await redis_helper.execute( self.redis_live, lambda r: r.hset( diff --git a/src/ai/backend/manager/types.py b/src/ai/backend/manager/types.py index dee2842dd62..78833ff861b 100644 --- a/src/ai/backend/manager/types.py +++ b/src/ai/backend/manager/types.py @@ -56,3 +56,14 @@ class MountOptionModel(BaseModel): MountPermission | None, Field(validation_alias=AliasChoices("permission", "perm"), default=None), ] + + +class AgentResourceSyncTrigger(enum.StrEnum): + AFTER_SCHEDULING = "after-scheduling" + BEFORE_KERNEL_CREATION = "before-kernel-creation" + ON_CREATION_FAILURE = "on-creation-failure" + + +DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS: list[AgentResourceSyncTrigger] = [ + AgentResourceSyncTrigger.ON_CREATION_FAILURE +]