diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index a3c39ace338..79f23a04767 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -1138,7 +1138,11 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None: ), ) # Notify cleanup waiters after all state updates. - if kernel_obj is not None and kernel_obj.clean_event is not None: + if ( + kernel_obj is not None + and kernel_obj.clean_event is not None + and not kernel_obj.clean_event.done() + ): kernel_obj.clean_event.set_result(None) if ev.done_future is not None and not ev.done_future.done(): ev.done_future.set_result(None) diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index 2f5159feb47..bb175c4b93c 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -64,6 +64,7 @@ ImageRegistry, KernelCreationConfig, KernelId, + KernelStatusCollection, QueueSentinel, SessionId, aobject, @@ -80,7 +81,7 @@ ) from .exception import ResourceError from .monitor import AgentErrorPluginContext, AgentStatsPluginContext -from .types import AgentBackend, LifecycleEvent, VolumeInfo +from .types import AgentBackend, KernelLifecycleStatus, LifecycleEvent, VolumeInfo from .utils import get_arch_name, get_subnet_ip if TYPE_CHECKING: @@ -478,6 +479,115 @@ async def sync_kernel_registry( suppress_events=True, ) + @rpc_function + @collect_error + async def sync_and_get_kernels( + self, + preparing_kernels: Iterable[UUID], + pulling_kernels: Iterable[UUID], + running_kernels: Iterable[UUID], + terminating_kernels: Iterable[UUID], + ) -> dict[str, Any]: + """ + Sync kernel_registry and containers to truth data + and return kernel infos whose status is irreversible. + """ + + actual_terminating_kernels: list[tuple[KernelId, str]] = [] + actual_terminated_kernels: list[tuple[KernelId, str]] = [] + + async with self.agent.registry_lock: + actual_existing_kernels = [ + kid + for kid, obj in self.agent.kernel_registry.items() + if obj.state == KernelLifecycleStatus.RUNNING + ] + + for raw_kernel_id in running_kernels: + kernel_id = KernelId(raw_kernel_id) + if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: + actual_terminating_kernels.append(( + kernel_id, + str( + kernel_obj.termination_reason + or KernelLifecycleEventReason.ALREADY_TERMINATED + ), + )) + else: + actual_terminated_kernels.append(( + kernel_id, + str(KernelLifecycleEventReason.ALREADY_TERMINATED), + )) + + for raw_kernel_id in terminating_kernels: + kernel_id = KernelId(raw_kernel_id) + if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.DESTROY, + kernel_obj.termination_reason + or KernelLifecycleEventReason.ALREADY_TERMINATED, + suppress_events=False, + ) + else: + actual_terminated_kernels.append(( + kernel_id, + str(KernelLifecycleEventReason.ALREADY_TERMINATED), + )) + + for kernel_id, kernel_obj in self.agent.kernel_registry.items(): + if kernel_id in terminating_kernels: + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.DESTROY, + kernel_obj.termination_reason + or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER, + suppress_events=False, + ) + elif kernel_id in running_kernels: + pass + elif kernel_id in preparing_kernels: + pass + elif kernel_id in pulling_kernels: + # kernel_registry does not have `pulling` state kernels. + # Let's skip it. + pass + else: + # This kernel is not alive according to the truth data. + # The kernel should be destroyed or cleaned + if kernel_obj.state == KernelLifecycleStatus.TERMINATING: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.CLEAN, + kernel_obj.termination_reason + or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER, + suppress_events=True, + ) + elif kernel_id in self.agent.restarting_kernels: + pass + else: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.DESTROY, + kernel_obj.termination_reason + or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER, + suppress_events=True, + ) + + result = KernelStatusCollection( + actual_existing_kernels, + actual_terminating_kernels, + actual_terminated_kernels, + ) + return result.to_json() + @rpc_function @collect_error async def create_kernels( diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index 8b52553727a..a58ad33b2d3 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -1263,3 +1263,29 @@ class ModelServiceProfile: ), RuntimeVariant.CMD: ModelServiceProfile(name="Predefined Image Command"), } + + +@dataclass +class KernelStatusCollection(JSONSerializableMixin): + KernelTerminationInfo = tuple[KernelId, str] + + actual_existing_kernels: list[KernelId] + actual_terminating_kernels: list[KernelTerminationInfo] + actual_terminated_kernels: list[KernelTerminationInfo] + + def to_json(self) -> dict[str, list[KernelId]]: + return dataclasses.asdict(self) + + @classmethod + def from_json(cls, obj: Mapping[str, Any]) -> KernelStatusCollection: + return cls(**cls.as_trafaret().check(obj)) + + @classmethod + def as_trafaret(cls) -> t.Trafaret: + from . import validators as tx + + return t.Dict({ + t.Key("actual_existing_kernels"): tx.ToList(tx.UUID), + t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)), + t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)), + }) diff --git a/src/ai/backend/common/validators.py b/src/ai/backend/common/validators.py index 82a28bc9d82..541a7e3ad3c 100644 --- a/src/ai/backend/common/validators.py +++ b/src/ai/backend/common/validators.py @@ -711,6 +711,19 @@ def check_and_return(self, value: Any) -> set: self._failure("value must be Iterable") +class ToList(t.List): + def check_common(self, value: Any) -> None: + return super().check_common(self.check_and_return(value)) # type: ignore[misc] + + def check_and_return(self, value: Any) -> list: + try: + return list(value) + except TypeError: + self._failure( + f"Cannot parse {type(value)} to list. value must be Iterable", value=value + ) + + class Delay(t.Trafaret): """ Convert a float or a tuple of 2 floats into a random generated float value diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index ddc8656b881..d62e78bd6eb 100644 --- a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -13,6 +13,7 @@ import uuid import zlib from collections import defaultdict +from collections.abc import Iterable, Mapping, MutableMapping, Sequence from datetime import datetime from decimal import Decimal from io import BytesIO @@ -22,10 +23,7 @@ Dict, List, Literal, - Mapping, - MutableMapping, Optional, - Sequence, Tuple, TypeAlias, Union, @@ -101,6 +99,7 @@ ImageRegistry, KernelEnqueueingConfig, KernelId, + KernelStatusCollection, ModelServiceStatus, RedisConnectionInfo, ResourceSlot, @@ -3277,6 +3276,23 @@ async def _update_session(db_session: AsyncSession) -> None: self._kernel_actual_allocated_resources[kernel_id] = actual_allocs await self.set_status_updatable_session(session_id) + async def _sync_agent_resource_and_get_kerenels( + self, + agent_id: AgentId, + preparing_kernels: Iterable[KernelId], + pulling_kernels: Iterable[KernelId], + running_kernels: Iterable[KernelId], + terminating_kernels: Iterable[KernelId], + ) -> KernelStatusCollection: + async with self.agent_cache.rpc_context(agent_id) as rpc: + resp: dict[str, Any] = await rpc.call.sync_and_get_kernels( + preparing_kernels, + pulling_kernels, + running_kernels, + terminating_kernels, + ) + return KernelStatusCollection.from_json(resp) + async def mark_kernel_terminated( self, kernel_id: KernelId, @@ -3486,6 +3502,88 @@ async def get_status_updatable_sessions(self) -> list[SessionId]: result.append(SessionId(msgpack.unpackb(raw_session_id))) return result + async def sync_agent_resource( + self, + db: ExtendedAsyncSAEngine, + agent_ids: Iterable[AgentId], + ) -> dict[AgentId, KernelStatusCollection | MultiAgentError]: + result: dict[AgentId, KernelStatusCollection | MultiAgentError] = {} + agent_kernel_by_status: dict[AgentId, dict[str, list[KernelId]]] = {} + stmt = ( + sa.select(AgentRow) + .where(AgentRow.id.in_(agent_ids)) + .options( + selectinload( + AgentRow.kernels.and_( + KernelRow.status.in_([ + KernelStatus.PREPARING, + KernelStatus.PULLING, + KernelStatus.RUNNING, + KernelStatus.TERMINATING, + ]) + ), + ).options(load_only(KernelRow.id, KernelRow.status)) + ) + ) + async with db.begin_readonly_session() as db_session: + for _agent_row in await db_session.scalars(stmt): + agent_row = cast(AgentRow, _agent_row) + preparing_kernels: list[KernelId] = [] + pulling_kernels: list[KernelId] = [] + running_kernels: list[KernelId] = [] + terminating_kernels: list[KernelId] = [] + for kernel in agent_row.kernels: + kernel_status = cast(KernelStatus, kernel.status) + match kernel_status: + case KernelStatus.PREPARING: + preparing_kernels.append(KernelId(kernel.id)) + case KernelStatus.PULLING: + pulling_kernels.append(KernelId(kernel.id)) + case KernelStatus.RUNNING: + running_kernels.append(KernelId(kernel.id)) + case KernelStatus.TERMINATING: + terminating_kernels.append(KernelId(kernel.id)) + case _: + continue + agent_kernel_by_status[AgentId(agent_row.id)] = { + "preparing_kernels": preparing_kernels, + "pulling_kernels": pulling_kernels, + "running_kernels": running_kernels, + "terminating_kernels": terminating_kernels, + } + aid_task_list: list[tuple[AgentId, asyncio.Task]] = [] + async with aiotools.PersistentTaskGroup() as tg: + for agent_id in agent_ids: + task = tg.create_task( + self._sync_agent_resource_and_get_kerenels( + agent_id, + agent_kernel_by_status[agent_id]["preparing_kernels"], + agent_kernel_by_status[agent_id]["pulling_kernels"], + agent_kernel_by_status[agent_id]["running_kernels"], + agent_kernel_by_status[agent_id]["terminating_kernels"], + ) + ) + aid_task_list.append((agent_id, task)) + for aid, task in aid_task_list: + agent_errors = [] + try: + resp = await task + except aiotools.TaskGroupError as e: + agent_errors.extend(e.__errors__) + except Exception as e: + agent_errors.append(e) + if agent_errors: + result[aid] = MultiAgentError( + "agent(s) raise errors during kernel resource sync", + agent_errors, + ) + else: + assert isinstance( + resp, KernelStatusCollection + ), f"response should be `KernelStatusCollection`, not {type(resp)}" + result[aid] = resp + return result + async def _get_user_email( self, kernel: KernelRow,