Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2178.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Keep `sync_container_lifecycles()` bgtask alive in a loop.
230 changes: 145 additions & 85 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@
known_slot_types,
)
from .stats import StatContext, StatModes
from .types import Container, ContainerLifecycleEvent, ContainerStatus, LifecycleEvent, MountInfo
from .types import (
Container,
ContainerLifecycleEvent,
ContainerStatus,
LifecycleEvent,
MountInfo,
)
from .utils import generate_local_instance_id, get_arch_name

if TYPE_CHECKING:
Expand Down Expand Up @@ -737,6 +743,7 @@ async def shutdown(self, stop_signal: signal.Signals) -> None:
if kernel_obj.runner is not None:
await kernel_obj.runner.close()
await kernel_obj.close()
await self.save_last_registry(force=True)
if stop_signal == signal.SIGTERM:
await self.clean_all_kernels(blocking=True)

Expand Down Expand Up @@ -1011,7 +1018,9 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None:
kernel_obj = self.kernel_registry.get(ev.kernel_id)
if kernel_obj is None:
log.warning(
"destroy_kernel(k:{0}) kernel missing (already dead?)", ev.kernel_id
"destroy_kernel(k:{0}, c:{1}) kernel missing (already dead?)",
ev.kernel_id,
ev.container_id,
)
if ev.container_id is None:
await self.reconstruct_resource_usage()
Expand Down Expand Up @@ -1039,18 +1048,17 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None:
ev.done_future.set_exception(e)
raise
finally:
if ev.container_id is not None:
await self.container_lifecycle_queue.put(
ContainerLifecycleEvent(
ev.kernel_id,
ev.session_id,
ev.container_id,
LifecycleEvent.CLEAN,
ev.reason,
suppress_events=ev.suppress_events,
done_future=ev.done_future,
),
)
await self.container_lifecycle_queue.put(
ContainerLifecycleEvent(
ev.kernel_id,
ev.session_id,
ev.container_id,
LifecycleEvent.CLEAN,
ev.reason,
suppress_events=ev.suppress_events,
done_future=ev.done_future,
),
)
except asyncio.CancelledError:
pass
except Exception:
Expand Down Expand Up @@ -1262,75 +1270,122 @@ async def sync_container_lifecycles(self, interval: float) -> None:
for cases when we miss the container lifecycle events from the underlying implementation APIs
due to the agent restarts or crashes.
"""
known_kernels: Dict[KernelId, ContainerId] = {}
known_kernels: Dict[KernelId, ContainerId | None] = {}
alive_kernels: Dict[KernelId, ContainerId] = {}
kernel_session_map: Dict[KernelId, SessionId] = {}
own_kernels: dict[KernelId, ContainerId] = {}
terminated_kernels = {}
terminated_kernels: dict[KernelId, ContainerLifecycleEvent] = {}

async with self.registry_lock:
def _get_session_id(container: Container) -> SessionId | None:
_session_id = container.labels.get("ai.backend.session-id")
try:
# Check if: there are dead containers
for kernel_id, container in await self.enumerate_containers(DEAD_STATUS_SET):
if (
kernel_id in self.restarting_kernels
or kernel_id in self.terminating_kernels
):
continue
log.info(
"detected dead container during lifecycle sync (k:{}, c:{})",
kernel_id,
container.id,
)
session_id = SessionId(UUID(container.labels["ai.backend.session-id"]))
terminated_kernels[kernel_id] = ContainerLifecycleEvent(
kernel_id,
session_id,
known_kernels[kernel_id],
LifecycleEvent.CLEAN,
KernelLifecycleEventReason.SELF_TERMINATED,
)
for kernel_id, container in await self.enumerate_containers(ACTIVE_STATUS_SET):
alive_kernels[kernel_id] = container.id
session_id = SessionId(UUID(container.labels["ai.backend.session-id"]))
kernel_session_map[kernel_id] = session_id
own_kernels[kernel_id] = container.id
for kernel_id, kernel_obj in self.kernel_registry.items():
known_kernels[kernel_id] = kernel_obj["container_id"]
session_id = kernel_obj.session_id
kernel_session_map[kernel_id] = session_id
# Check if: kernel_registry has the container but it's gone.
for kernel_id in known_kernels.keys() - alive_kernels.keys():
if (
kernel_id in self.restarting_kernels
or kernel_id in self.terminating_kernels
):
continue
terminated_kernels[kernel_id] = ContainerLifecycleEvent(
kernel_id,
kernel_session_map[kernel_id],
known_kernels[kernel_id],
LifecycleEvent.CLEAN,
KernelLifecycleEventReason.SELF_TERMINATED,
return SessionId(UUID(_session_id))
except ValueError:
log.warning(
f"sync_container_lifecycles() invalid session-id (cid: {container.id}, sid:{_session_id})"
)
return None

log.debug("sync_container_lifecycles(): triggered")
try:
_containers = await self.enumerate_containers(ACTIVE_STATUS_SET | DEAD_STATUS_SET)
async with self.registry_lock:
try:
# Check if: there are dead containers
dead_containers = [
(kid, container)
for kid, container in _containers
if container.status in DEAD_STATUS_SET
]
log.debug(
f"detected dead containers: {[container.id[:12] for _, container in dead_containers]}"
)
# Check if: there are containers not spawned by me.
for kernel_id in alive_kernels.keys() - known_kernels.keys():
if kernel_id in self.restarting_kernels:
continue
terminated_kernels[kernel_id] = ContainerLifecycleEvent(
kernel_id,
kernel_session_map[kernel_id],
alive_kernels[kernel_id],
LifecycleEvent.DESTROY,
KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER,
for kernel_id, container in dead_containers:
if kernel_id in self.restarting_kernels:
continue
log.info(
"detected dead container during lifecycle sync (k:{}, c:{})",
kernel_id,
container.id,
)
session_id = _get_session_id(container)
if session_id is None:
continue
terminated_kernels[kernel_id] = ContainerLifecycleEvent(
kernel_id,
session_id,
container.id,
LifecycleEvent.CLEAN,
KernelLifecycleEventReason.SELF_TERMINATED,
)
active_containers = [
(kid, container)
for kid, container in _containers
if container.status in ACTIVE_STATUS_SET
]
log.debug(
f"detected active containers: {[container.id[:12] for _, container in active_containers]}"
)
finally:
# Enqueue the events.
for kernel_id, ev in terminated_kernels.items():
await self.container_lifecycle_queue.put(ev)

# Set container count
await self.set_container_count(len(own_kernels.keys()))
for kernel_id, container in active_containers:
alive_kernels[kernel_id] = container.id
session_id = _get_session_id(container)
if session_id is None:
continue
kernel_session_map[kernel_id] = session_id
own_kernels[kernel_id] = container.id
for kernel_id, kernel_obj in self.kernel_registry.items():
known_kernels[kernel_id] = (
ContainerId(kernel_obj.container_id)
if kernel_obj.container_id is not None
else None
)
session_id = kernel_obj.session_id
kernel_session_map[kernel_id] = session_id
# Check if: kernel_registry has the container but it's gone.
for kernel_id in known_kernels.keys() - alive_kernels.keys():
if (
kernel_id in self.restarting_kernels
or kernel_id in self.terminating_kernels
):
continue
log.debug(f"kernel with no container (kid: {kernel_id})")
terminated_kernels[kernel_id] = ContainerLifecycleEvent(
kernel_id,
kernel_session_map[kernel_id],
known_kernels[kernel_id],
LifecycleEvent.CLEAN,
KernelLifecycleEventReason.CONTAINER_NOT_FOUND,
)
# Check if: there are containers already deleted from my registry.
for kernel_id in alive_kernels.keys() - known_kernels.keys():
if kernel_id in self.restarting_kernels:
continue
log.debug(f"kernel not found in registry (kid:{kernel_id})")
terminated_kernels[kernel_id] = ContainerLifecycleEvent(
kernel_id,
kernel_session_map[kernel_id],
alive_kernels[kernel_id],
LifecycleEvent.DESTROY,
KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER,
)
finally:
# Enqueue the events.
terminated_kernel_ids = ",".join([
str(kid) for kid in terminated_kernels.keys()
])
log.debug(f"Terminating kernels(ids:[{terminated_kernel_ids}])")
for kernel_id, ev in terminated_kernels.items():
await self.container_lifecycle_queue.put(ev)

# Set container count
await self.set_container_count(len(own_kernels.keys()))
except asyncio.CancelledError:
pass
except asyncio.TimeoutError:
log.warning("sync_container_lifecycles() timeout, continuing")
except Exception as e:
log.exception(f"sync_container_lifecycles() failure, continuing (detail: {repr(e)})")
await self.produce_error_event()

async def set_container_count(self, container_count: int) -> None:
await redis_helper.execute(
Expand Down Expand Up @@ -1946,7 +2001,7 @@ async def create_kernel(
service_ports,
)
async with self.registry_lock:
self.kernel_registry[ctx.kernel_id] = kernel_obj
self.kernel_registry[kernel_id] = kernel_obj
try:
container_data = await ctx.start_container(
kernel_obj,
Expand All @@ -1958,7 +2013,7 @@ async def create_kernel(
msg = e.message or "unknown"
log.error(
"Kernel failed to create container. Kernel is going to be destroyed."
f" (k:{ctx.kernel_id}, detail:{msg})",
f" (k:{kernel_id}, detail:{msg})",
)
cid = e.container_id
async with self.registry_lock:
Expand All @@ -1973,17 +2028,22 @@ async def create_kernel(
raise AgentError(
f"Kernel failed to create container (k:{str(ctx.kernel_id)}, detail:{msg})"
)
except Exception:
except Exception as e:
log.warning(
"Kernel failed to create container (k:{}). Kernel is going to be"
" unregistered.",
"Kernel failed to create container (k:{}). Kernel is going to be destroyed.",
kernel_id,
)
async with self.registry_lock:
del self.kernel_registry[kernel_id]
raise
await self.inject_container_lifecycle_event(
kernel_id,
session_id,
LifecycleEvent.DESTROY,
KernelLifecycleEventReason.FAILED_TO_CREATE,
)
raise AgentError(
f"Kernel failed to create container (k:{str(kernel_id)}, detail: {str(e)})"
)
async with self.registry_lock:
self.kernel_registry[ctx.kernel_id].data.update(container_data)
self.kernel_registry[kernel_id].data.update(container_data)
await kernel_obj.init(self.event_producer)

current_task = asyncio.current_task()
Expand Down
13 changes: 8 additions & 5 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import signal
import struct
import sys
from collections.abc import Mapping
from decimal import Decimal
from functools import partial
from io import StringIO
Expand All @@ -22,13 +23,13 @@
FrozenSet,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from uuid import UUID

Expand Down Expand Up @@ -860,7 +861,7 @@ async def start_container(
config=container_config, name=kernel_name
)
assert container is not None
cid = container._id
cid = cast(str, container._id)
resource_spec.container_id = cid
# Write resource.txt again to update the container id.
with open(self.config_dir / "resource.txt", "w") as f:
Expand Down Expand Up @@ -896,10 +897,10 @@ async def start_container(
except asyncio.CancelledError:
if container is not None:
raise ContainerCreationError(
container_id=cid, message="Container creation cancelled"
container_id=container._id, message="Container creation cancelled"
)
raise
except Exception:
except Exception as e:
# Oops, we have to restore the allocated resources!
scratch_type = self.local_config["container"]["scratch-type"]
scratch_root = self.local_config["container"]["scratch-root"]
Expand All @@ -917,7 +918,9 @@ async def start_container(
for dev_name, device_alloc in resource_spec.allocations.items():
self.computers[dev_name].alloc_map.free(device_alloc)
if container is not None:
raise ContainerCreationError(container_id=cid, message="unknown")
raise ContainerCreationError(
container_id=container._id, message=f"unknown. {repr(e)}"
)
raise

additional_network_names: Set[str] = set()
Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/common/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class KernelLifecycleEventReason(enum.StrEnum):
UNKNOWN = "unknown"
USER_REQUESTED = "user-requested"
NOT_FOUND_IN_MANAGER = "not-found-in-manager"
CONTAINER_NOT_FOUND = "container-not-found"

@classmethod
def from_value(cls, value: Optional[str]) -> Optional[KernelLifecycleEventReason]:
Expand Down
Loading