Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2317.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Skip cleaning containerless kernels whose containers are still being created.
25 changes: 17 additions & 8 deletions src/ai/backend/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
Container,
ContainerLifecycleEvent,
ContainerStatus,
KernelLifecycleStatus,
LifecycleEvent,
MountInfo,
)
Expand Down Expand Up @@ -560,7 +561,6 @@ class AbstractAgent(
redis: Redis

restarting_kernels: MutableMapping[KernelId, RestartTracker]
terminating_kernels: Set[KernelId]
timer_tasks: MutableSequence[asyncio.Task]
container_lifecycle_queue: asyncio.Queue[ContainerLifecycleEvent | Sentinel]

Expand Down Expand Up @@ -600,7 +600,6 @@ def __init__(
self.computers = {}
self.images = {} # repoTag -> digest
self.restarting_kernels = {}
self.terminating_kernels = set()
self.stat_ctx = StatContext(
self,
mode=StatModes(local_config["container"]["stats-type"]),
Expand Down Expand Up @@ -969,7 +968,10 @@ async def collect_container_stat(self, interval: float):
container_ids = []
async with self.registry_lock:
for kernel_id, kernel_obj in [*self.kernel_registry.items()]:
if not kernel_obj.stats_enabled:
if (
not kernel_obj.stats_enabled
or kernel_obj.state != KernelLifecycleStatus.RUNNING
):
continue
container_ids.append(kernel_obj["container_id"])
await self.stat_ctx.collect_container_stat(container_ids)
Expand All @@ -987,7 +989,10 @@ async def collect_process_stat(self, interval: float):
container_ids = []
async with self.registry_lock:
for kernel_id, kernel_obj in [*self.kernel_registry.items()]:
if not kernel_obj.stats_enabled:
if (
not kernel_obj.stats_enabled
or kernel_obj.state != KernelLifecycleStatus.RUNNING
):
continue
updated_kernel_ids.append(kernel_id)
container_ids.append(kernel_obj["container_id"])
Expand All @@ -1012,14 +1017,14 @@ async def _handle_start_event(self, ev: ContainerLifecycleEvent) -> None:
kernel_obj = self.kernel_registry.get(ev.kernel_id)
if kernel_obj is not None:
kernel_obj.stats_enabled = True
kernel_obj.state = KernelLifecycleStatus.RUNNING

async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None:
try:
current_task = asyncio.current_task()
assert current_task is not None
if ev.kernel_id not in self._ongoing_destruction_tasks:
self._ongoing_destruction_tasks[ev.kernel_id] = current_task
self.terminating_kernels.add(ev.kernel_id)
async with self.registry_lock:
kernel_obj = self.kernel_registry.get(ev.kernel_id)
if kernel_obj is None:
Expand All @@ -1042,6 +1047,7 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None:
ev.done_future.set_result(None)
return
else:
kernel_obj.state = KernelLifecycleStatus.TERMINATING
kernel_obj.stats_enabled = False
kernel_obj.termination_reason = ev.reason
if kernel_obj.runner is not None:
Expand Down Expand Up @@ -1115,7 +1121,6 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None:
self.port_pool.update(restored_ports)
await kernel_obj.close()
finally:
self.terminating_kernels.discard(ev.kernel_id)
if restart_tracker := self.restarting_kernels.get(ev.kernel_id, None):
restart_tracker.destroy_event.set()
else:
Expand Down Expand Up @@ -1349,9 +1354,10 @@ def _get_session_id(container: Container) -> SessionId | None:
kernel_session_map[kernel_id] = session_id
# Check if: kernel_registry has the container but it's gone.
for kernel_id in known_kernels.keys() - alive_kernels.keys():
kernel_obj = self.kernel_registry[kernel_id]
if (
kernel_id in self.restarting_kernels
or kernel_id in self.terminating_kernels
or kernel_obj.state != KernelLifecycleStatus.RUNNING
):
continue
log.debug(f"kernel with no container (kid: {kernel_id})")
Expand Down Expand Up @@ -1379,7 +1385,8 @@ def _get_session_id(container: Container) -> SessionId | None:
terminated_kernel_ids = ",".join([
str(kid) for kid in terminated_kernels.keys()
])
log.debug(f"Terminating kernels(ids:[{terminated_kernel_ids}])")
if terminated_kernel_ids:
log.debug(f"Terminate kernels(ids:[{terminated_kernel_ids}])")
for kernel_id, ev in terminated_kernels.items():
await self.container_lifecycle_queue.put(ev)

Expand Down Expand Up @@ -2141,6 +2148,8 @@ async def create_kernel(
},
),
)
async with self.registry_lock:
kernel_obj.state = KernelLifecycleStatus.RUNNING

# The startup command for the batch-type sessions will be executed by the manager
# upon firing of the "session_started" event.
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/agent/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

from .exception import UnsupportedBaseDistroError
from .resources import KernelResourceSpec
from .types import AgentEventData
from .types import AgentEventData, KernelLifecycleStatus

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]

Expand Down Expand Up @@ -177,6 +177,7 @@ class AbstractKernel(UserDict, aobject, metaclass=ABCMeta):
stats_enabled: bool
# FIXME: apply TypedDict to data in Python 3.8
environ: Mapping[str, Any]
status: KernelLifecycleStatus

_tasks: Set[asyncio.Task]

Expand Down Expand Up @@ -213,6 +214,7 @@ def __init__(
self.environ = environ
self.runner = None
self.container_id = None
self.state = KernelLifecycleStatus.PREPARING

async def init(self, event_producer: EventProducer) -> None:
log.debug(
Expand All @@ -233,6 +235,9 @@ def __getstate__(self) -> Mapping[str, Any]:
return props

def __setstate__(self, props) -> None:
# Used when a `Kernel` object is loaded from pickle data.
if "state" not in props:
props["state"] = KernelLifecycleStatus.RUNNING
Copy link
Member Author

@fregataa fregataa Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we shut down and restart an agent to update its version, the kernel_registry is dumped as a pickle file, and the agent loads that pickle file when it restarts. Old Kernel objects dumped before the version update do not have the state field when we restart the agent.
We need to insert a state value into those old Kernel objects.

self.__dict__.update(props)
# agent_config is set by the pickle.loads() caller.
self.clean_event = None
Expand Down
14 changes: 14 additions & 0 deletions src/ai/backend/agent/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ class Container:
backend_obj: Any # used to keep the backend-specific data


class KernelLifecycleStatus(enum.StrEnum):
    """
    The lifecycle status of an `AbstractKernel` object.

    A newly created kernel starts in the `PREPARING` state.
    It transitions from `PREPARING` to `RUNNING` once its container has
    started successfully, and from `RUNNING` to `TERMINATING` just before
    the kernel is destroyed.
    """

    # NOTE: enum.auto() on a StrEnum yields the lowercased member name,
    # so the values are "preparing", "running", and "terminating".
    PREPARING = enum.auto()
    RUNNING = enum.auto()
    TERMINATING = enum.auto()


class LifecycleEvent(enum.IntEnum):
DESTROY = 0
CLEAN = 1
Expand Down