diff --git a/changes/2317.fix.md b/changes/2317.fix.md new file mode 100644 index 00000000000..c4e6f72a887 --- /dev/null +++ b/changes/2317.fix.md @@ -0,0 +1 @@ +Skip cleaning up containerless kernels that are still creating their containers. diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index dbd922a9847..20fbc7f3b24 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -157,6 +157,7 @@ Container, ContainerLifecycleEvent, ContainerStatus, + KernelLifecycleStatus, LifecycleEvent, MountInfo, ) @@ -560,7 +561,6 @@ class AbstractAgent( redis: Redis restarting_kernels: MutableMapping[KernelId, RestartTracker] - terminating_kernels: Set[KernelId] timer_tasks: MutableSequence[asyncio.Task] container_lifecycle_queue: asyncio.Queue[ContainerLifecycleEvent | Sentinel] @@ -600,7 +600,6 @@ def __init__( self.computers = {} self.images = {} # repoTag -> digest self.restarting_kernels = {} - self.terminating_kernels = set() self.stat_ctx = StatContext( self, mode=StatModes(local_config["container"]["stats-type"]), @@ -969,7 +968,10 @@ async def collect_container_stat(self, interval: float): container_ids = [] async with self.registry_lock: for kernel_id, kernel_obj in [*self.kernel_registry.items()]: - if not kernel_obj.stats_enabled: + if ( + not kernel_obj.stats_enabled + or kernel_obj.state != KernelLifecycleStatus.RUNNING + ): continue container_ids.append(kernel_obj["container_id"]) await self.stat_ctx.collect_container_stat(container_ids) @@ -987,7 +989,10 @@ async def collect_process_stat(self, interval: float): container_ids = [] async with self.registry_lock: for kernel_id, kernel_obj in [*self.kernel_registry.items()]: - if not kernel_obj.stats_enabled: + if ( + not kernel_obj.stats_enabled + or kernel_obj.state != KernelLifecycleStatus.RUNNING + ): continue updated_kernel_ids.append(kernel_id) container_ids.append(kernel_obj["container_id"]) @@ -1012,6 +1017,7 @@ async def _handle_start_event(self, ev: 
ContainerLifecycleEvent) -> None: kernel_obj = self.kernel_registry.get(ev.kernel_id) if kernel_obj is not None: kernel_obj.stats_enabled = True + kernel_obj.state = KernelLifecycleStatus.RUNNING async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: try: @@ -1019,7 +1025,6 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: assert current_task is not None if ev.kernel_id not in self._ongoing_destruction_tasks: self._ongoing_destruction_tasks[ev.kernel_id] = current_task - self.terminating_kernels.add(ev.kernel_id) async with self.registry_lock: kernel_obj = self.kernel_registry.get(ev.kernel_id) if kernel_obj is None: @@ -1042,6 +1047,7 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: ev.done_future.set_result(None) return else: + kernel_obj.state = KernelLifecycleStatus.TERMINATING kernel_obj.stats_enabled = False kernel_obj.termination_reason = ev.reason if kernel_obj.runner is not None: @@ -1115,7 +1121,6 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None: self.port_pool.update(restored_ports) await kernel_obj.close() finally: - self.terminating_kernels.discard(ev.kernel_id) if restart_tracker := self.restarting_kernels.get(ev.kernel_id, None): restart_tracker.destroy_event.set() else: @@ -1349,9 +1354,10 @@ def _get_session_id(container: Container) -> SessionId | None: kernel_session_map[kernel_id] = session_id # Check if: kernel_registry has the container but it's gone. 
for kernel_id in known_kernels.keys() - alive_kernels.keys(): + kernel_obj = self.kernel_registry[kernel_id] if ( kernel_id in self.restarting_kernels - or kernel_id in self.terminating_kernels + or kernel_obj.state != KernelLifecycleStatus.RUNNING ): continue log.debug(f"kernel with no container (kid: {kernel_id})") @@ -1379,7 +1385,8 @@ def _get_session_id(container: Container) -> SessionId | None: terminated_kernel_ids = ",".join([ str(kid) for kid in terminated_kernels.keys() ]) - log.debug(f"Terminating kernels(ids:[{terminated_kernel_ids}])") + if terminated_kernel_ids: + log.debug(f"Terminate kernels(ids:[{terminated_kernel_ids}])") for kernel_id, ev in terminated_kernels.items(): await self.container_lifecycle_queue.put(ev) @@ -2141,6 +2148,8 @@ async def create_kernel( }, ), ) + async with self.registry_lock: + kernel_obj.state = KernelLifecycleStatus.RUNNING # The startup command for the batch-type sessions will be executed by the manager # upon firing of the "session_started" event. 
diff --git a/src/ai/backend/agent/kernel.py b/src/ai/backend/agent/kernel.py index b9704e22e3b..8d0904cd6f2 100644 --- a/src/ai/backend/agent/kernel.py +++ b/src/ai/backend/agent/kernel.py @@ -55,7 +55,7 @@ from .exception import UnsupportedBaseDistroError from .resources import KernelResourceSpec -from .types import AgentEventData +from .types import AgentEventData, KernelLifecycleStatus log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] @@ -177,6 +177,7 @@ class AbstractKernel(UserDict, aobject, metaclass=ABCMeta): stats_enabled: bool # FIXME: apply TypedDict to data in Python 3.8 environ: Mapping[str, Any] + state: KernelLifecycleStatus _tasks: Set[asyncio.Task] @@ -213,6 +214,7 @@ def __init__( self.environ = environ self.runner = None self.container_id = None + self.state = KernelLifecycleStatus.PREPARING async def init(self, event_producer: EventProducer) -> None: log.debug( @@ -233,6 +235,9 @@ def __getstate__(self) -> Mapping[str, Any]: return props def __setstate__(self, props) -> None: + # Used when a `Kernel` object is loaded from pickle data. + if "state" not in props: + props["state"] = KernelLifecycleStatus.RUNNING self.__dict__.update(props) # agent_config is set by the pickle.loads() caller. self.clean_event = None diff --git a/src/ai/backend/agent/types.py b/src/ai/backend/agent/types.py index 730beb291a0..8d7d6ffe018 100644 --- a/src/ai/backend/agent/types.py +++ b/src/ai/backend/agent/types.py @@ -64,6 +64,20 @@ class Container: backend_obj: Any # used to keep the backend-specific data +class KernelLifecycleStatus(enum.StrEnum): + """ + The lifecycle status of an `AbstractKernel` object. + + By default, the state of a newly created kernel is `PREPARING`. + The state of a kernel changes from `PREPARING` to `RUNNING` after the kernel starts a container successfully. + It changes from `RUNNING` to `TERMINATING` before the kernel is destroyed. 
+ """ + + PREPARING = enum.auto() + RUNNING = enum.auto() + TERMINATING = enum.auto() + + class LifecycleEvent(enum.IntEnum): DESTROY = 0 CLEAN = 1