Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2179.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Sync the agent's kernel registry with the actual containers through a periodic loop.
74 changes: 53 additions & 21 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import zlib
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from collections.abc import Container as ContainerT
from decimal import Decimal
from io import SEEK_END, BytesIO
from pathlib import Path
Expand Down Expand Up @@ -1076,30 +1077,17 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
ev.done_future.set_exception(e)
await self.produce_error_event()
finally:
if ev.kernel_id in self.restarting_kernels:
# Don't forget as we are restarting it.
kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
else:
# Forget as we are done with this kernel.
kernel_obj = self.kernel_registry.pop(ev.kernel_id, None)
kernel_obj = self.kernel_registry.get(ev.kernel_id, None)
try:
if kernel_obj is not None:
# Restore used ports to the port pool.
port_range = self.local_config["container"]["port-range"]
# Exclude out-of-range ports, because when the agent restarts
# with a different port range, existing containers' host ports
# may not belong to the new port range.
if host_ports := kernel_obj.get("host_ports"):
restored_ports = [
*filter(
lambda p: port_range[0] <= p <= port_range[1],
host_ports,
)
]
self.port_pool.update(restored_ports)
await self._restore_port_pool(kernel_obj)
await kernel_obj.close()
finally:
self.terminating_kernels.discard(ev.kernel_id)
try:
del self.kernel_registry[ev.kernel_id]
except KeyError:
pass
if restart_tracker := self.restarting_kernels.get(ev.kernel_id, None):
restart_tracker.destroy_event.set()
else:
Expand All @@ -1116,6 +1104,20 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
if ev.done_future is not None and not ev.done_future.done():
ev.done_future.set_result(None)

async def _restore_port_pool(self, kernel_obj: AbstractKernel) -> None:
    """
    Return the kernel's host ports to the agent's shared port pool.

    Ports that fall outside the currently configured port range are
    skipped: when the agent restarts with a different ``port-range``,
    host ports of pre-existing containers may no longer belong to the
    new range and must not be handed out again.
    """
    range_lo, range_hi = self.local_config["container"]["port-range"]
    host_ports = kernel_obj.get("host_ports")
    if host_ports:
        self.port_pool.update(
            port for port in host_ports if range_lo <= port <= range_hi
        )

async def process_lifecycle_events(self) -> None:
async def lifecycle_task_exception_handler(
exc_type: Type[Exception],
Expand Down Expand Up @@ -1260,6 +1262,8 @@ async def sync_container_lifecycles(self, interval: float) -> None:
for cases when we miss the container lifecycle events from the underlying implementation APIs
due to the agent restarts or crashes.
"""
all_detected_kernels: set[KernelId] = set()

known_kernels: Dict[KernelId, ContainerId] = {}
alive_kernels: Dict[KernelId, ContainerId] = {}
kernel_session_map: Dict[KernelId, SessionId] = {}
Expand All @@ -1270,6 +1274,7 @@ async def sync_container_lifecycles(self, interval: float) -> None:
try:
# Check if: there are dead containers
for kernel_id, container in await self.enumerate_containers(DEAD_STATUS_SET):
all_detected_kernels.add(kernel_id)
if (
kernel_id in self.restarting_kernels
or kernel_id in self.terminating_kernels
Expand All @@ -1289,6 +1294,7 @@ async def sync_container_lifecycles(self, interval: float) -> None:
KernelLifecycleEventReason.SELF_TERMINATED,
)
for kernel_id, container in await self.enumerate_containers(ACTIVE_STATUS_SET):
all_detected_kernels.add(kernel_id)
alive_kernels[kernel_id] = container.id
session_id = SessionId(UUID(container.labels["ai.backend.session-id"]))
kernel_session_map[kernel_id] = session_id
Expand Down Expand Up @@ -1323,13 +1329,41 @@ async def sync_container_lifecycles(self, interval: float) -> None:
KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER,
)
finally:
await self.prune_kernel_registry(all_detected_kernels)
# Enqueue the events.
for kernel_id, ev in terminated_kernels.items():
await self.container_lifecycle_queue.put(ev)

# Set container count
await self.set_container_count(len(own_kernels.keys()))

async def prune_kernel_registry(
    self, detected_kernels: ContainerT[KernelId], *, ensure_cleaned: bool = True
) -> None:
    """
    Deregister containerless kernels from `kernel_registry`
    since `_handle_clean_event()` does not deregister them.

    :param detected_kernels: Kernel IDs observed to have a backing
        container; any registered kernel *not* in this collection is
        treated as containerless and pruned.
    :param ensure_cleaned: When True (the default), run the kernel
        object's cleanup steps before dropping it from the registry:
        disable stats, close the runner, resolve the pending clean
        event, restore its host ports to the pool, and close the
        kernel object.
    """
    any_container_pruned = False
    # Snapshot the keys because entries are deleted during iteration.
    for kernel_id in [*self.kernel_registry.keys()]:
        if kernel_id not in detected_kernels:
            if ensure_cleaned:
                # Don't need to process this through event task
                # since there is no communication with any container here.
                kernel_obj = self.kernel_registry[kernel_id]
                kernel_obj.stats_enabled = False
                if kernel_obj.runner is not None:
                    await kernel_obj.runner.close()
                # Unblock any waiter on the kernel's clean event so that
                # in-flight destroy/clean flows do not hang forever.
                if kernel_obj.clean_event is not None and not kernel_obj.clean_event.done():
                    kernel_obj.clean_event.set_result(None)
                await self._restore_port_pool(kernel_obj)
                await kernel_obj.close()
            del self.kernel_registry[kernel_id]
            self.terminating_kernels.discard(kernel_id)
            any_container_pruned = True
    if any_container_pruned:
        # NOTE(review): presumably recomputes the agent's occupied
        # resource slots now that the pruned kernels are gone — confirm
        # against `reconstruct_resource_usage()`'s definition.
        await self.reconstruct_resource_usage()

async def set_container_count(self, container_count: int) -> None:
await redis_helper.execute(
self.redis_stat_pool, lambda r: r.set(f"container_count.{self.id}", container_count)
Expand Down Expand Up @@ -2035,8 +2069,6 @@ async def create_kernel(
" unregistered.",
kernel_id,
)
async with self.registry_lock:
del self.kernel_registry[kernel_id]
raise
async with self.registry_lock:
self.kernel_registry[ctx.kernel_id].data.update(container_data)
Expand Down