diff --git a/changes/2179.fix.md b/changes/2179.fix.md
new file mode 100644
index 00000000000..ef6951949dd
--- /dev/null
+++ b/changes/2179.fix.md
@@ -0,0 +1 @@
+Sync the agent's kernel registry with the actual containers through a periodic sync loop.
diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py
index 19c72585189..477c17b1f8e 100644
--- a/src/ai/backend/agent/agent.py
+++ b/src/ai/backend/agent/agent.py
@@ -16,6 +16,7 @@
 import zlib
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
+from collections.abc import Container as ContainerT
 from decimal import Decimal
 from io import SEEK_END, BytesIO
 from pathlib import Path
@@ -1076,30 +1077,17 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
                 ev.done_future.set_exception(e)
             await self.produce_error_event()
         finally:
-            if ev.kernel_id in self.restarting_kernels:
-                # Don't forget as we are restarting it.
-                kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
-            else:
-                # Forget as we are done with this kernel.
-                kernel_obj = self.kernel_registry.pop(ev.kernel_id, None)
+            kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
             try:
                 if kernel_obj is not None:
-                    # Restore used ports to the port pool.
-                    port_range = self.local_config["container"]["port-range"]
-                    # Exclude out-of-range ports, because when the agent restarts
-                    # with a different port range, existing containers' host ports
-                    # may not belong to the new port range.
-                    if host_ports := kernel_obj.get("host_ports"):
-                        restored_ports = [
-                            *filter(
-                                lambda p: port_range[0] <= p <= port_range[1],
-                                host_ports,
-                            )
-                        ]
-                        self.port_pool.update(restored_ports)
+                    await self._restore_port_pool(kernel_obj)
                     await kernel_obj.close()
             finally:
                 self.terminating_kernels.discard(ev.kernel_id)
+                try:
+                    del self.kernel_registry[ev.kernel_id]
+                except KeyError:
+                    pass
                 if restart_tracker := self.restarting_kernels.get(ev.kernel_id, None):
                     restart_tracker.destroy_event.set()
                 else:
@@ -1116,6 +1104,20 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
            if ev.done_future is not None and not ev.done_future.done():
                ev.done_future.set_result(None)
 
+    async def _restore_port_pool(self, kernel_obj: AbstractKernel) -> None:
+        port_range = self.local_config["container"]["port-range"]
+        # Exclude out-of-range ports, because when the agent restarts
+        # with a different port range, existing containers' host ports
+        # may not belong to the new port range.
+        if host_ports := kernel_obj.get("host_ports"):
+            restored_ports = [
+                *filter(
+                    lambda p: port_range[0] <= p <= port_range[1],
+                    host_ports,
+                )
+            ]
+            self.port_pool.update(restored_ports)
+
     async def process_lifecycle_events(self) -> None:
         async def lifecycle_task_exception_handler(
             exc_type: Type[Exception],
@@ -1260,6 +1262,8 @@ async def sync_container_lifecycles(self, interval: float) -> None:
         for cases when we miss the container lifecycle events
         from the underlying implementation APIs due to the agent restarts or crashes.
""" + all_detected_kernels: set[KernelId] = set() + known_kernels: Dict[KernelId, ContainerId] = {} alive_kernels: Dict[KernelId, ContainerId] = {} kernel_session_map: Dict[KernelId, SessionId] = {} @@ -1270,6 +1274,7 @@ async def sync_container_lifecycles(self, interval: float) -> None: try: # Check if: there are dead containers for kernel_id, container in await self.enumerate_containers(DEAD_STATUS_SET): + all_detected_kernels.add(kernel_id) if ( kernel_id in self.restarting_kernels or kernel_id in self.terminating_kernels @@ -1289,6 +1294,7 @@ async def sync_container_lifecycles(self, interval: float) -> None: KernelLifecycleEventReason.SELF_TERMINATED, ) for kernel_id, container in await self.enumerate_containers(ACTIVE_STATUS_SET): + all_detected_kernels.add(kernel_id) alive_kernels[kernel_id] = container.id session_id = SessionId(UUID(container.labels["ai.backend.session-id"])) kernel_session_map[kernel_id] = session_id @@ -1323,6 +1329,7 @@ async def sync_container_lifecycles(self, interval: float) -> None: KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER, ) finally: + await self.prune_kernel_registry(all_detected_kernels) # Enqueue the events. for kernel_id, ev in terminated_kernels.items(): await self.container_lifecycle_queue.put(ev) @@ -1330,6 +1337,33 @@ async def sync_container_lifecycles(self, interval: float) -> None: # Set container count await self.set_container_count(len(own_kernels.keys())) + async def prune_kernel_registry( + self, detected_kernels: ContainerT[KernelId], *, ensure_cleaned: bool = True + ) -> None: + """ + Deregister containerless kernels from `kernel_registry` + since `_handle_clean_event()` does not deregister them. + """ + any_container_pruned = False + for kernel_id in [*self.kernel_registry.keys()]: + if kernel_id not in detected_kernels: + if ensure_cleaned: + # Don't need to process this through event task + # since there is no communication with any container here. + kernel_obj = self.kernel_registry[kernel_id] + kernel_obj.stats_enabled = False + if kernel_obj.runner is not None: + await kernel_obj.runner.close() + if kernel_obj.clean_event is not None and not kernel_obj.clean_event.done(): + kernel_obj.clean_event.set_result(None) + await self._restore_port_pool(kernel_obj) + await kernel_obj.close() + del self.kernel_registry[kernel_id] + self.terminating_kernels.discard(kernel_id) + any_container_pruned = True + if any_container_pruned: + await self.reconstruct_resource_usage() + async def set_container_count(self, container_count: int) -> None: await redis_helper.execute( self.redis_stat_pool, lambda r: r.set(f"container_count.{self.id}", container_count) @@ -2035,8 +2069,6 @@ async def create_kernel( " unregistered.", kernel_id, ) - async with self.registry_lock: - del self.kernel_registry[kernel_id] raise async with self.registry_lock: self.kernel_registry[ctx.kernel_id].data.update(container_data)