diff --git a/changes/2178.fix.md b/changes/2178.fix.md new file mode 100644 index 00000000000..9ab7ec11b86 --- /dev/null +++ b/changes/2178.fix.md @@ -0,0 +1,2 @@ +* Keep the container lifecycle sync loop alive +* Handle container creation errors more granularly diff --git a/src/ai/backend/agent/agent.py b/src/ai/backend/agent/agent.py index 2acd8a33932..f32bf3e164d 100644 --- a/src/ai/backend/agent/agent.py +++ b/src/ai/backend/agent/agent.py @@ -153,7 +153,14 @@ known_slot_types, ) from .stats import StatContext, StatModes -from .types import Container, ContainerLifecycleEvent, ContainerStatus, LifecycleEvent, MountInfo +from .types import ( + Container, + ContainerLifecycleEvent, + ContainerStatus, + KernelStatus, + LifecycleEvent, + MountInfo, +) from .utils import generate_local_instance_id, get_arch_name if TYPE_CHECKING: @@ -554,7 +561,6 @@ class AbstractAgent( redis: Redis restarting_kernels: MutableMapping[KernelId, RestartTracker] - terminating_kernels: Set[KernelId] timer_tasks: MutableSequence[asyncio.Task] container_lifecycle_queue: asyncio.Queue[ContainerLifecycleEvent | Sentinel] @@ -594,7 +600,6 @@ def __init__( self.computers = {} self.images = {} # repoTag -> digest self.restarting_kernels = {} - self.terminating_kernels = set() self.stat_ctx = StatContext( self, mode=StatModes(local_config["container"]["stats-type"]), @@ -698,7 +703,7 @@ async def _pipeline(r: Redis): self.timer_tasks.append(aiotools.create_timer(self.heartbeat, heartbeat_interval)) # Prepare auto-cleaning of idle kernels. - self.timer_tasks.append(aiotools.create_timer(self.sync_container_lifecycles, 10.0)) + self.timer_tasks.append(aiotools.create_timer(self.sync_container_lifecycles, 13.0)) if abuse_report_path := self.local_config["agent"].get("abuse-report-path"): log.info( @@ -956,7 +961,7 @@ async def collect_container_stat(self, interval: float): container_ids = [] async with self.registry_lock: for kernel_id, kernel_obj in [*self.kernel_registry.items()]: - if not kernel_obj.stats_enabled: + if not kernel_obj.stats_enabled or kernel_obj.status != KernelStatus.RUNNING: continue container_ids.append(kernel_obj["container_id"]) await self.stat_ctx.collect_container_stat(container_ids) @@ -974,7 +979,7 @@ async def collect_process_stat(self, interval: float): container_ids = [] async with self.registry_lock: for kernel_id, kernel_obj in [*self.kernel_registry.items()]: - if not kernel_obj.stats_enabled: + if not kernel_obj.stats_enabled or kernel_obj.status != KernelStatus.RUNNING: continue updated_kernel_ids.append(kernel_id) container_ids.append(kernel_obj["container_id"]) @@ -999,6 +1004,7 @@ async def _handle_start_event(self, ev: ContainerLifecycleEvent) -> None: kernel_obj = self.kernel_registry.get(ev.kernel_id) if kernel_obj is not None: kernel_obj.stats_enabled = True + kernel_obj.status = KernelStatus.RUNNING async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: try: @@ -1006,12 +1012,13 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: assert current_task is not None if ev.kernel_id not in self._ongoing_destruction_tasks: self._ongoing_destruction_tasks[ev.kernel_id] = current_task - self.terminating_kernels.add(ev.kernel_id) async with self.registry_lock: kernel_obj = self.kernel_registry.get(ev.kernel_id) if kernel_obj is None: log.warning( - "destroy_kernel(k:{0}) kernel missing (already dead?)", ev.kernel_id + "destroy_kernel(k:{0}, c:{1}) kernel missing (already dead?)", + ev.kernel_id, + ev.container_id, ) if ev.container_id is None: await
self.reconstruct_resource_usage() @@ -1028,6 +1035,7 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: return else: kernel_obj.stats_enabled = False + kernel_obj.status = KernelStatus.TERMINATING kernel_obj.termination_reason = ev.reason if kernel_obj.runner is not None: await kernel_obj.runner.close() @@ -1039,18 +1047,17 @@ async def _handle_destroy_event(self, ev: ContainerLifecycleEvent) -> None: ev.done_future.set_exception(e) raise finally: - if ev.container_id is not None: - await self.container_lifecycle_queue.put( - ContainerLifecycleEvent( - ev.kernel_id, - ev.session_id, - ev.container_id, - LifecycleEvent.CLEAN, - ev.reason, - suppress_events=ev.suppress_events, - done_future=ev.done_future, - ), - ) + await self.container_lifecycle_queue.put( + ContainerLifecycleEvent( + ev.kernel_id, + ev.session_id, + ev.container_id, + LifecycleEvent.CLEAN, + ev.reason, + suppress_events=ev.suppress_events, + done_future=ev.done_future, + ), + ) except asyncio.CancelledError: pass except Exception: @@ -1101,7 +1108,6 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None: self.port_pool.update(restored_ports) await kernel_obj.close() finally: - self.terminating_kernels.discard(ev.kernel_id) if restart_tracker := self.restarting_kernels.get(ev.kernel_id, None): restart_tracker.destroy_event.set() else: @@ -1113,7 +1119,11 @@ async def _handle_clean_event(self, ev: ContainerLifecycleEvent) -> None: ), ) # Notify cleanup waiters after all state updates. - if kernel_obj is not None and kernel_obj.clean_event is not None: + if ( + kernel_obj is not None + and kernel_obj.clean_event is not None + and not kernel_obj.clean_event.done() + ): kernel_obj.clean_event.set_result(None) if ev.done_future is not None and not ev.done_future.done(): ev.done_future.set_result(None) @@ -1266,71 +1276,98 @@ async def sync_container_lifecycles(self, interval: float) -> None: alive_kernels: Dict[KernelId, ContainerId] = {} kernel_session_map: Dict[KernelId, SessionId] = {} own_kernels: dict[KernelId, ContainerId] = {} - terminated_kernels = {} + terminated_kernels: dict[KernelId, ContainerLifecycleEvent] = {} - async with self.registry_lock: + def _get_session_id(container: Container) -> SessionId | None: + _session_id = container.labels.get("ai.backend.session-id") try: - # Check if: there are dead containers - for kernel_id, container in await self.enumerate_containers(DEAD_STATUS_SET): - if ( - kernel_id in self.restarting_kernels - or kernel_id in self.terminating_kernels - ): - continue - log.info( - "detected dead container during lifeycle sync (k:{}, c:{})", - kernel_id, - container.id, - ) - session_id = SessionId(UUID(container.labels["ai.backend.session-id"])) - terminated_kernels[kernel_id] = ContainerLifecycleEvent( - kernel_id, - session_id, - known_kernels[kernel_id], - LifecycleEvent.CLEAN, - KernelLifecycleEventReason.SELF_TERMINATED, - ) - for kernel_id, container in await self.enumerate_containers(ACTIVE_STATUS_SET): - alive_kernels[kernel_id] = container.id - session_id = SessionId(UUID(container.labels["ai.backend.session-id"])) - kernel_session_map[kernel_id] = session_id - own_kernels[kernel_id] = container.id - for kernel_id, kernel_obj in self.kernel_registry.items(): - known_kernels[kernel_id] = kernel_obj["container_id"] - session_id = kernel_obj.session_id - kernel_session_map[kernel_id] = session_id - # Check if: kernel_registry has the container but it's gone. 
- for kernel_id in known_kernels.keys() - alive_kernels.keys(): - if ( - kernel_id in self.restarting_kernels - or kernel_id in self.terminating_kernels - ): - continue - terminated_kernels[kernel_id] = ContainerLifecycleEvent( - kernel_id, - kernel_session_map[kernel_id], - known_kernels[kernel_id], - LifecycleEvent.CLEAN, - KernelLifecycleEventReason.SELF_TERMINATED, - ) - # Check if: there are containers not spawned by me. - for kernel_id in alive_kernels.keys() - known_kernels.keys(): - if kernel_id in self.restarting_kernels: - continue - terminated_kernels[kernel_id] = ContainerLifecycleEvent( - kernel_id, - kernel_session_map[kernel_id], - alive_kernels[kernel_id], - LifecycleEvent.DESTROY, - KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER, - ) - finally: - # Enqueue the events. - for kernel_id, ev in terminated_kernels.items(): - await self.container_lifecycle_queue.put(ev) + return SessionId(UUID(_session_id)) + except ValueError: + log.warning( + f"sync_container_lifecycles() invalid session-id (cid: {container.id}, sid:{_session_id})" + ) + return None + + try: + _containers = await self.enumerate_containers(ACTIVE_STATUS_SET | DEAD_STATUS_SET) + async with self.registry_lock: + try: + # Check if: there are dead containers + dead_containers = [ + (kid, container) + for kid, container in _containers + if container.status in DEAD_STATUS_SET + ] + for kernel_id, container in dead_containers: + if kernel_id in self.restarting_kernels: + continue + log.info( + "detected dead container during lifeycle sync (k:{}, c:{})", + kernel_id, + container.id, + ) + session_id = _get_session_id(container) + if session_id is None: + continue + terminated_kernels[kernel_id] = ContainerLifecycleEvent( + kernel_id, + session_id, + container.id, + LifecycleEvent.CLEAN, + KernelLifecycleEventReason.SELF_TERMINATED, + ) + active_containers = [ + (kid, container) + for kid, container in _containers + if container.status in ACTIVE_STATUS_SET + ] + for kernel_id, container in active_containers: + alive_kernels[kernel_id] = container.id + session_id = _get_session_id(container) + if session_id is None: + continue + kernel_session_map[kernel_id] = session_id + own_kernels[kernel_id] = container.id + for kernel_id, kernel_obj in self.kernel_registry.items(): + known_kernels[kernel_id] = kernel_obj["container_id"] + session_id = kernel_obj.session_id + kernel_session_map[kernel_id] = session_id + # Check if: kernel_registry has the container but it's gone. + for kernel_id in known_kernels.keys() - alive_kernels.keys(): + if ( + kernel_id in self.restarting_kernels + or self.kernel_registry[kernel_id].status == KernelStatus.PREPARING + ): + continue + terminated_kernels[kernel_id] = ContainerLifecycleEvent( + kernel_id, + kernel_session_map[kernel_id], + known_kernels[kernel_id], + LifecycleEvent.CLEAN, + KernelLifecycleEventReason.SELF_TERMINATED, + ) + # Check if: there are containers already deleted from my registry or not spawned by me. + for kernel_id in alive_kernels.keys() - known_kernels.keys(): + if kernel_id in self.restarting_kernels: + continue + terminated_kernels[kernel_id] = ContainerLifecycleEvent( + kernel_id, + kernel_session_map[kernel_id], + alive_kernels[kernel_id], + LifecycleEvent.DESTROY, + KernelLifecycleEventReason.TERMINATED_UNKNOWN_CONTAINER, + ) + finally: + # Enqueue the events. 
+ for kernel_id, ev in terminated_kernels.items(): + await self.container_lifecycle_queue.put(ev) - # Set container count - await self.set_container_count(len(own_kernels.keys())) + # Set container count + await self.set_container_count(len(own_kernels.keys())) + except asyncio.TimeoutError: + log.warning("sync_container_lifecycles() timeout, continuing") + except Exception as e: + log.exception(f"sync_container_lifecycles() failure, continuing (detail: {repr(e)})") async def set_container_count(self, container_count: int) -> None: await redis_helper.execute( @@ -1946,7 +1983,7 @@ async def create_kernel( service_ports, ) async with self.registry_lock: - self.kernel_registry[ctx.kernel_id] = kernel_obj + self.kernel_registry[kernel_id] = kernel_obj try: container_data = await ctx.start_container( kernel_obj, @@ -1958,7 +1995,7 @@ async def create_kernel( msg = e.message or "unknown" log.error( "Kernel failed to create container. Kernel is going to be destroyed." - f" (k:{ctx.kernel_id}, detail:{msg})", + f" (k:{kernel_id}, detail:{msg})", ) cid = e.container_id async with self.registry_lock: @@ -1973,17 +2010,22 @@ async def create_kernel( raise AgentError( f"Kernel failed to create container (k:{str(ctx.kernel_id)}, detail:{msg})" ) - except Exception: + except Exception as e: log.warning( - "Kernel failed to create container (k:{}). Kernel is going to be" - " unregistered.", + "Kernel failed to create container (k:{}). Kernel is going to be destroyed.", kernel_id, ) - async with self.registry_lock: - del self.kernel_registry[kernel_id] - raise + await self.inject_container_lifecycle_event( + kernel_id, + session_id, + LifecycleEvent.DESTROY, + KernelLifecycleEventReason.FAILED_TO_CREATE, + ) + raise AgentError( + f"Kernel failed to create container (k:{str(kernel_id)}, detail: {str(e)})" + ) async with self.registry_lock: - self.kernel_registry[ctx.kernel_id].data.update(container_data) + self.kernel_registry[kernel_id].data.update(container_data) await kernel_obj.init(self.event_producer) current_task = asyncio.current_task() @@ -2063,6 +2105,7 @@ async def create_kernel( }, ), ) + kernel_obj.status = KernelStatus.RUNNING if ( kernel_config["session_type"] == "batch" @@ -2501,9 +2544,14 @@ async def save_last_registry(self, force=False) -> None: return # don't save too frequently var_base_path = self.local_config["agent"]["var-base-path"] last_registry_file = f"last_registry.{self.local_instance_id}.dat" + running_kernel_registry = { + kid: kernel_obj + for kid, kernel_obj in self.kernel_registry.items() + if kernel_obj.status == KernelStatus.RUNNING + } try: with open(var_base_path / last_registry_file, "wb") as f: - pickle.dump(self.kernel_registry, f) + pickle.dump(running_kernel_registry, f) self.last_registry_written_time = now log.debug("saved {}", last_registry_file) except Exception as e: diff --git a/src/ai/backend/agent/docker/agent.py b/src/ai/backend/agent/docker/agent.py index aa979e60731..4786cd873c2 100644 --- a/src/ai/backend/agent/docker/agent.py +++ b/src/ai/backend/agent/docker/agent.py @@ -10,6 +10,7 @@ import signal import struct import sys +from collections.abc import Mapping from decimal import Decimal from functools import partial from io import StringIO @@ -22,13 +23,13 @@ FrozenSet, List, Literal, - Mapping, MutableMapping, Optional, Sequence, Set, Tuple, Union, + cast, ) from uuid import UUID @@ -120,6 +121,30 @@ def container_from_docker_container(src: DockerContainer) -> Container: ) +async def _clean_scratch( + loop: asyncio.AbstractEventLoop, + 
scratch_type: str, + scratch_root: Path, + kernel_id: KernelId, +) -> None: + scratch_dir = scratch_root / str(kernel_id) + tmp_dir = scratch_root / f"{kernel_id}_tmp" + try: + if sys.platform.startswith("linux") and scratch_type == "memory": + await destroy_scratch_filesystem(scratch_dir) + await destroy_scratch_filesystem(tmp_dir) + await loop.run_in_executor(None, shutil.rmtree, scratch_dir) + await loop.run_in_executor(None, shutil.rmtree, tmp_dir) + elif sys.platform.startswith("linux") and scratch_type == "hostfile": + await destroy_loop_filesystem(scratch_root, kernel_id) + else: + await loop.run_in_executor(None, shutil.rmtree, scratch_dir) + except CalledProcessError: + pass + except FileNotFoundError: + pass + + def _DockerError_reduce(self): return ( type(self), @@ -851,6 +876,18 @@ async def start_container( if self.local_config["debug"]["log-kernel-config"]: log.debug("full container config: {!r}", pretty(container_config)) + async def _rollback_container_creation() -> None: + await _clean_scratch( + loop, + self.local_config["container"]["scratch-type"], + self.local_config["container"]["scratch-root"], + self.kernel_id, + ) + self.port_pool.update(host_ports) + async with self.resource_lock: + for dev_name, device_alloc in resource_spec.allocations.items(): + self.computers[dev_name].alloc_map.free(device_alloc) + # We are all set! Create and start the container. async with closing_async(Docker()) as docker: container: Optional[DockerContainer] = None @@ -859,7 +896,7 @@ async def start_container( config=container_config, name=kernel_name ) assert container is not None - cid = container._id + cid = cast(str, container._id) resource_spec.container_id = cid # Write resource.txt again to update the container id. with open(self.config_dir / "resource.txt", "w") as f: @@ -873,52 +910,47 @@ async def start_container( kvpairs = await computer_ctx.instance.generate_resource_data(device_alloc) for k, v in kvpairs.items(): await writer.write(f"{k}={v}\n") - - await container.start() - - if self.internal_data.get("sudo_session_enabled", False): - exec = await container.exec( - [ - # file ownership is guaranteed to be set as root:root since command is executed on behalf of root user - "sh", - "-c", - 'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work', - ], - user="root", - ) - shell_response = await exec.start(detach=True) - if shell_response: - raise ContainerCreationError( - container_id=cid, - message=f"sudoers provision failed: {shell_response.decode()}", - ) except asyncio.CancelledError: if container is not None: raise ContainerCreationError( - container_id=cid, message="Container creation cancelled" + container_id=container._id, message="Container creation cancelled" ) raise - except Exception: - # Oops, we have to restore the allocated resources! 
- scratch_type = self.local_config["container"]["scratch-type"] - scratch_root = self.local_config["container"]["scratch-root"] - if sys.platform.startswith("linux") and scratch_type == "memory": - await destroy_scratch_filesystem(self.scratch_dir) - await destroy_scratch_filesystem(self.tmp_dir) - await loop.run_in_executor(None, shutil.rmtree, self.scratch_dir) - await loop.run_in_executor(None, shutil.rmtree, self.tmp_dir) - elif sys.platform.startswith("linux") and scratch_type == "hostfile": - await destroy_loop_filesystem(scratch_root, self.kernel_id) - else: - await loop.run_in_executor(None, shutil.rmtree, self.scratch_dir) - self.port_pool.update(host_ports) - async with self.resource_lock: - for dev_name, device_alloc in resource_spec.allocations.items(): - self.computers[dev_name].alloc_map.free(device_alloc) + except Exception as e: + await _rollback_container_creation() if container is not None: - raise ContainerCreationError(container_id=cid, message="unknown") + raise ContainerCreationError( + container_id=container._id, message=f"unknown. {repr(e)}" + ) raise + try: + await container.start() + except asyncio.CancelledError: + await _rollback_container_creation() + raise ContainerCreationError(container_id=cid, message="Container start cancelled") + except Exception as e: + await _rollback_container_creation() + raise ContainerCreationError(container_id=cid, message=f"unknown. {repr(e)}") + + if self.internal_data.get("sudo_session_enabled", False): + exec = await container.exec( + [ + # file ownership is guaranteed to be set as root:root since command is executed on behalf of root user + "sh", + "-c", + 'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work', + ], + user="root", + ) + shell_response = await exec.start(detach=True) + if shell_response: + await _rollback_container_creation() + raise ContainerCreationError( + container_id=cid, + message=f"sudoers provision failed: {shell_response.decode()}", + ) + additional_network_names: Set[str] = set() for dev_name, device_alloc in resource_spec.allocations.items(): n = await self.computers[dev_name].instance.get_docker_networks(device_alloc) @@ -1500,24 +1532,12 @@ async def log_iter(): log.warning("container deletion timeout (k:{}, c:{})", kernel_id, container_id) if not restarting: - scratch_type = self.local_config["container"]["scratch-type"] - scratch_root = self.local_config["container"]["scratch-root"] - scratch_dir = scratch_root / str(kernel_id) - tmp_dir = scratch_root / f"{kernel_id}_tmp" - try: - if sys.platform.startswith("linux") and scratch_type == "memory": - await destroy_scratch_filesystem(scratch_dir) - await destroy_scratch_filesystem(tmp_dir) - await loop.run_in_executor(None, shutil.rmtree, scratch_dir) - await loop.run_in_executor(None, shutil.rmtree, tmp_dir) - elif sys.platform.startswith("linux") and scratch_type == "hostfile": - await destroy_loop_filesystem(scratch_root, kernel_id) - else: - await loop.run_in_executor(None, shutil.rmtree, scratch_dir) - except CalledProcessError: - pass - except FileNotFoundError: - pass + await _clean_scratch( + loop, + self.local_config["container"]["scratch-type"], + self.local_config["container"]["scratch-root"], + kernel_id, + ) async def create_local_network(self, network_name: str) -> None: async with closing_async(Docker()) as docker: diff --git a/src/ai/backend/agent/kernel.py b/src/ai/backend/agent/kernel.py index b6fb207f830..a7bad2e5071 100644 --- a/src/ai/backend/agent/kernel.py +++ 
b/src/ai/backend/agent/kernel.py @@ -54,7 +54,7 @@ from .exception import UnsupportedBaseDistroError from .resources import KernelResourceSpec -from .types import AgentEventData +from .types import AgentEventData, KernelStatus log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] @@ -176,6 +176,7 @@ class AbstractKernel(UserDict, aobject, metaclass=ABCMeta): stats_enabled: bool # FIXME: apply TypedDict to data in Python 3.8 environ: Mapping[str, Any] + status: KernelStatus _tasks: Set[asyncio.Task] @@ -212,6 +213,7 @@ def __init__( self.environ = environ self.runner = None self.container_id = None + self.status = KernelStatus.PREPARING async def init(self, event_producer: EventProducer) -> None: log.debug( @@ -232,6 +234,9 @@ def __getstate__(self) -> Mapping[str, Any]: return props def __setstate__(self, props) -> None: + if "status" not in props: + # We pickle "running" kernels, not "preparing" or "terminating" ones + props["status"] = KernelStatus.RUNNING self.__dict__.update(props) # agent_config is set by the pickle.loads() caller. self.clean_event = None diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index 74f1a810377..9795d9a18ea 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -12,6 +12,7 @@ import signal import sys from collections import OrderedDict, defaultdict +from collections.abc import Collection from ipaddress import _BaseAddress as BaseIPAddress from ipaddress import ip_network from pathlib import Path @@ -58,6 +59,7 @@ ) from ai.backend.common.logging import BraceStyleAdapter, Logger from ai.backend.common.types import ( + AgentKernelRegistryByStatus, ClusterInfo, CommitStatus, HardwareMetadata, @@ -477,6 +479,116 @@ async def sync_kernel_registry( suppress_events=True, ) + @rpc_function + @collect_error + async def sync_and_get_kernels( + self, + preparing_kernels: Collection[str], + pulling_kernels: Collection[str], + running_kernels: Collection[str], + terminating_kernels: Collection[str], + ) -> dict[str, Any]: + """ + Sync kernel_registry and containers to truth data + and return kernel infos whose status is irreversible. 
+ """ + + actual_terminating_kernels: list[tuple[KernelId, str]] = [] + actual_terminated_kernels: list[tuple[KernelId, str]] = [] + + async with self.agent.registry_lock: + actual_existing_kernels = [ + kid + for kid in self.agent.kernel_registry + if kid not in self.agent.terminating_kernels + ] + + for raw_kernel_id in running_kernels: + kernel_id = KernelId(UUID(raw_kernel_id)) + if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: + if kernel_id in self.agent.terminating_kernels: + actual_terminating_kernels.append(( + kernel_id, + str( + kernel_obj.termination_reason + or KernelLifecycleEventReason.ALREADY_TERMINATED + ), + )) + else: + actual_terminated_kernels.append(( + kernel_id, + str(KernelLifecycleEventReason.ALREADY_TERMINATED), + )) + + for raw_kernel_id in terminating_kernels: + kernel_id = KernelId(UUID(raw_kernel_id)) + if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None: + if kernel_id not in self.agent.terminating_kernels: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.DESTROY, + kernel_obj.termination_reason + or KernelLifecycleEventReason.ALREADY_TERMINATED, + suppress_events=False, + ) + else: + actual_terminated_kernels.append(( + kernel_id, + str(KernelLifecycleEventReason.ALREADY_TERMINATED), + )) + + for kernel_id, kernel_obj in self.agent.kernel_registry.items(): + if kernel_id in terminating_kernels: + if kernel_id not in self.agent.terminating_kernels: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.DESTROY, + kernel_obj.termination_reason + or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER, + suppress_events=False, + ) + elif kernel_id in running_kernels: + pass + elif kernel_id in preparing_kernels: + # kernel_registry may not have `preparing` state kernels. + pass + elif kernel_id in pulling_kernels: + # kernel_registry does not have `pulling` state kernels. + # Let's just skip it. + pass + else: + # This kernel is not alive according to the truth data. + # The kernel should be destroyed or cleaned + if kernel_id in self.agent.terminating_kernels: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.CLEAN, + kernel_obj.termination_reason + or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER, + suppress_events=True, + ) + elif kernel_id in self.agent.restarting_kernels: + pass + else: + await self.agent.inject_container_lifecycle_event( + kernel_id, + kernel_obj.session_id, + LifecycleEvent.DESTROY, + kernel_obj.termination_reason + or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER, + suppress_events=True, + ) + + result = AgentKernelRegistryByStatus( + actual_existing_kernels, + actual_terminating_kernels, + actual_terminated_kernels, + ) + return result.to_json() + @rpc_function @collect_error async def create_kernels( diff --git a/src/ai/backend/agent/types.py b/src/ai/backend/agent/types.py index 730beb291a0..0089dde3a51 100644 --- a/src/ai/backend/agent/types.py +++ b/src/ai/backend/agent/types.py @@ -64,6 +64,16 @@ class Container: backend_obj: Any # used to keep the backend-specific data +class KernelStatus(enum.StrEnum): + """ + A type to track the status of `AbstractKernel` objects. 
+ """ + + PREPARING = enum.auto() + RUNNING = enum.auto() + TERMINATING = enum.auto() + + class LifecycleEvent(enum.IntEnum): DESTROY = 0 CLEAN = 1 diff --git a/src/ai/backend/common/types.py b/src/ai/backend/common/types.py index f8b53f8b346..183c16ad2fa 100644 --- a/src/ai/backend/common/types.py +++ b/src/ai/backend/common/types.py @@ -1279,3 +1279,29 @@ class ModelServiceProfile: ), RuntimeVariant.CMD: ModelServiceProfile(name="Predefined Image Command"), } + + +@dataclass +class AgentKernelRegistryByStatus(JSONSerializableMixin): + KernelTerminationInfo = tuple[KernelId, str] + + actual_existing_kernels: list[KernelId] + actual_terminating_kernels: list[KernelTerminationInfo] + actual_terminated_kernels: list[KernelTerminationInfo] + + def to_json(self) -> dict[str, list[KernelId]]: + return dataclasses.asdict(self) + + @classmethod + def from_json(cls, obj: Mapping[str, Any]) -> AgentKernelRegistryByStatus: + return cls(**cls.as_trafaret().check(obj)) + + @classmethod + def as_trafaret(cls) -> t.Trafaret: + from . import validators as tx + + return t.Dict({ + t.Key("actual_existing_kernels"): tx.ToList(t.String), + t.Key("actual_terminating_kernels"): tx.ToList(t.Tuple(t.String, t.String)), + t.Key("actual_terminated_kernels"): tx.ToList(t.Tuple(t.String, t.String)), + }) diff --git a/src/ai/backend/common/validators.py b/src/ai/backend/common/validators.py index 6f50415b3c0..844bb5917e2 100644 --- a/src/ai/backend/common/validators.py +++ b/src/ai/backend/common/validators.py @@ -218,6 +218,23 @@ def check_and_return(self, value: Any) -> T_enum: self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value) +class EnumList(t.Trafaret, Generic[T_enum]): + def __init__(self, enum_cls: Type[T_enum], *, use_name: bool = False) -> None: + self.enum_cls = enum_cls + self.use_name = use_name + + def check_and_return(self, value: Any) -> list[T_enum]: + try: + if self.use_name: + return [self.enum_cls[val] for val in value] + else: + return [self.enum_cls(val) for val in value] + except TypeError: + self._failure("cannot parse value into list", value=value) + except (KeyError, ValueError): + self._failure(f"value is not a valid member of {self.enum_cls.__name__}", value=value) + + class JSONString(t.Trafaret): def check_and_return(self, value: Any) -> dict: try: @@ -673,6 +690,19 @@ def check_and_return(self, value: Any) -> set: self._failure("value must be Iterable") +class ToList(t.List): + def check_common(self, value: Any) -> None: + return super().check_common(self.check_and_return(value)) # type: ignore[misc] + + def check_and_return(self, value: Any) -> list: + try: + return list(value) + except TypeError: + self._failure( + f"Cannot parse {type(value)} to list. 
value must be Iterable", value=value + ) + + class Delay(t.Trafaret): """ Convert a float or a tuple of 2 floats into a random generated float value diff --git a/src/ai/backend/manager/api/session.py b/src/ai/backend/manager/api/session.py index 41c7be61128..0575f605752 100644 --- a/src/ai/backend/manager/api/session.py +++ b/src/ai/backend/manager/api/session.py @@ -43,7 +43,11 @@ import trafaret as t from aiohttp import hdrs, web from dateutil.tz import tzutc -from pydantic import BaseModel, Field +from pydantic import ( + AliasChoices, + BaseModel, + Field, +) from redis.asyncio import Redis from sqlalchemy.orm import noload, selectinload from sqlalchemy.sql.expression import null, true @@ -967,6 +971,36 @@ async def sync_agent_registry(request: web.Request, params: Any) -> web.StreamRe return web.json_response({}, status=200) +class SyncAgentResourceRequestModel(BaseModel): + agent_id: AgentId = Field( + validation_alias=AliasChoices("agent_id", "agent"), + description="Target agent id to sync resource.", + ) + + +@server_status_required(ALL_ALLOWED) +@auth_required +@pydantic_params_api_handler(SyncAgentResourceRequestModel) +async def sync_agent_resource( + request: web.Request, params: SyncAgentResourceRequestModel +) -> web.Response: + root_ctx: RootContext = request.app["_root.context"] + requester_access_key, owner_access_key = await get_access_key_scopes(request) + + agent_id = params.agent_id + log.info( + "SYNC_AGENT_RESOURCE (ak:{}/{}, a:{})", requester_access_key, owner_access_key, agent_id + ) + + async with root_ctx.db.begin() as db_conn: + try: + await root_ctx.registry.sync_agent_resource(db_conn, [agent_id]) + except BackendError: + log.exception("SYNC_AGENT_RESOURCE: exception") + raise + return web.Response(status=204) + + @server_status_required(ALL_ALLOWED) @auth_required @check_api_params( @@ -2274,6 +2308,7 @@ def create_app( cors.add(app.router.add_route("POST", "/_/create-cluster", create_cluster)) cors.add(app.router.add_route("GET", "/_/match", match_sessions)) cors.add(app.router.add_route("POST", "/_/sync-agent-registry", sync_agent_registry)) + cors.add(app.router.add_route("POST", "/_/sync-agent-resource", sync_agent_resource)) session_resource = cors.add(app.router.add_resource(r"/{session_name}")) cors.add(session_resource.add_route("GET", get_info)) cors.add(session_resource.add_route("PATCH", restart)) diff --git a/src/ai/backend/manager/config.py b/src/ai/backend/manager/config.py index 1c5c9114ff9..93a34949f41 100644 --- a/src/ai/backend/manager/config.py +++ b/src/ai/backend/manager/config.py @@ -225,6 +225,7 @@ from .api.exceptions import ObjectNotFound, ServerMisconfiguredError from .models.session import SessionStatus from .pglock import PgAdvisoryLock +from .types import DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS, AgentResourceSyncTrigger log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined] @@ -296,6 +297,9 @@ "agent-selection-resource-priority", default=["cuda", "rocm", "tpu", "cpu", "mem"], ): t.List(t.String), + t.Key( + "agent-resource-sync-trigger", default=DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS + ): tx.EnumList(AgentResourceSyncTrigger), t.Key("importer-image", default="lablup/importer:manylinux2010"): t.String, t.Key("max-wsmsg-size", default=16 * (2**20)): t.ToInt, # default: 16 MiB tx.AliasedKey(["aiomonitor-termui-port", "aiomonitor-port"], default=48100): t.ToInt[ diff --git a/src/ai/backend/manager/registry.py b/src/ai/backend/manager/registry.py index 456959a5dc8..68f9a2c6382 100644 --- 
a/src/ai/backend/manager/registry.py +++ b/src/ai/backend/manager/registry.py @@ -12,6 +12,7 @@ import uuid import zlib from collections import defaultdict +from collections.abc import Collection from datetime import datetime from decimal import Decimal from io import BytesIO @@ -45,6 +46,7 @@ from redis.asyncio import Redis from sqlalchemy.exc import DBAPIError from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import AsyncSession as SASession from sqlalchemy.orm import load_only, noload, selectinload from sqlalchemy.orm.exc import NoResultFound from yarl import URL @@ -87,6 +89,7 @@ AbuseReport, AccessKey, AgentId, + AgentKernelRegistryByStatus, BinarySize, ClusterInfo, ClusterMode, @@ -173,7 +176,7 @@ reenter_txn_session, sql_json_merge, ) -from .types import UserScope +from .types import AgentResourceSyncTrigger, UserScope if TYPE_CHECKING: from sqlalchemy.engine.row import Row @@ -183,6 +186,7 @@ from ai.backend.common.events import EventDispatcher, EventProducer from .agent_cache import AgentRPCCache + from .exceptions import ErrorDetail from .models.storage import StorageSessionManager from .scheduler.types import AgentAllocationContext, KernelAgentBinding, SchedulingContext @@ -1658,6 +1662,10 @@ async def _create_kernels_in_one_agent( is_local = image_info["is_local"] resource_policy: KeyPairResourcePolicyRow = image_info["resource_policy"] auto_pull = image_info["auto_pull"] + agent_resource_sync_trigger = cast( + list[AgentResourceSyncTrigger], + self.local_config["manager"]["agent-resource-sync-trigger"], + ) assert agent_alloc_ctx.agent_id is not None assert scheduled_session.id is not None @@ -1676,6 +1684,10 @@ async def _update_kernel() -> None: await execute_with_retry(_update_kernel) + if AgentResourceSyncTrigger.BEFORE_KERNEL_CREATION in agent_resource_sync_trigger: + async with self.db.begin() as db_conn: + await self.sync_agent_resource(db_conn, [agent_alloc_ctx.agent_id]) + async with self.agent_cache.rpc_context( agent_alloc_ctx.agent_id, order_key=str(scheduled_session.id), @@ -1751,9 +1763,27 @@ async def _update_kernel() -> None: except (asyncio.TimeoutError, asyncio.CancelledError): log.warning("_create_kernels_in_one_agent(s:{}) cancelled", scheduled_session.id) except Exception as e: + ex = e + err_info = convert_to_status_data(ex, self.debug) + + def _has_insufficient_resource_err(_err_info: ErrorDetail) -> bool: + if _err_info["name"] == "InsufficientResource": + return True + if (sub_errors := _err_info.get("collection")) is not None: + for suberr in sub_errors: + if _has_insufficient_resource_err(suberr): + return True + return False + + if AgentResourceSyncTrigger.ON_CREATION_FAILURE in agent_resource_sync_trigger: + if _has_insufficient_resource_err(err_info["error"]): + async with self.db.begin() as db_conn: + await self.sync_agent_resource( + db_conn, + [agent_alloc_ctx.agent_id], + ) # The agent has already cancelled or issued the destruction lifecycle event # for this batch of kernels. 
- ex = e for binding in items: kernel_id = binding.kernel.id @@ -1777,7 +1807,7 @@ async def _update_failure() -> None: ), # ["PULLING", "PREPARING"] }, ), - status_data=convert_to_status_data(ex, self.debug), + status_data=err_info, ) ) await db_sess.execute(query) @@ -3056,6 +3086,102 @@ async def sync_agent_kernel_registry(self, agent_id: AgentId) -> None: (str(kernel.id), str(kernel.session_id)) for kernel in grouped_kernels ]) + async def _sync_agent_resource_and_get_kernels( + self, + agent_id: AgentId, + preparing_kernels: Collection[KernelId], + pulling_kernels: Collection[KernelId], + running_kernels: Collection[KernelId], + terminating_kernels: Collection[KernelId], + ) -> AgentKernelRegistryByStatus: + async with self.agent_cache.rpc_context(agent_id) as rpc: + resp: dict[str, Any] = await rpc.call.sync_and_get_kernels( + preparing_kernels, + pulling_kernels, + running_kernels, + terminating_kernels, + ) + return AgentKernelRegistryByStatus.from_json(resp) + + async def sync_agent_resource( + self, + db_connection: SAConnection, + agent_ids: Collection[AgentId], + ) -> dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError]: + result: dict[AgentId, AgentKernelRegistryByStatus | MultiAgentError] = {} + agent_kernel_by_status: dict[AgentId, dict[str, list[KernelId]]] = {} + stmt = ( + sa.select(AgentRow) + .where(AgentRow.id.in_(agent_ids)) + .options( + selectinload( + AgentRow.kernels.and_( + KernelRow.status.in_([ + KernelStatus.PREPARING, + KernelStatus.PULLING, + KernelStatus.RUNNING, + KernelStatus.TERMINATING, + ]) + ), + ).options(load_only(KernelRow.id, KernelRow.status)) + ) + ) + async with SASession(bind=db_connection) as db_session: + for _agent_row in await db_session.scalars(stmt): + agent_row = cast(AgentRow, _agent_row) + preparing_kernels: list[KernelId] = [] + pulling_kernels: list[KernelId] = [] + running_kernels: list[KernelId] = [] + terminating_kernels: list[KernelId] = [] + for kernel in agent_row.kernels: + kernel_status = cast(KernelStatus, kernel.status) + match kernel_status: + case KernelStatus.PREPARING: + preparing_kernels.append(KernelId(kernel.id)) + case KernelStatus.PULLING: + pulling_kernels.append(KernelId(kernel.id)) + case KernelStatus.RUNNING: + running_kernels.append(KernelId(kernel.id)) + case KernelStatus.TERMINATING: + terminating_kernels.append(KernelId(kernel.id)) + case _: + continue + agent_kernel_by_status[AgentId(agent_row.id)] = { + "preparing_kernels": preparing_kernels, + "pulling_kernels": pulling_kernels, + "running_kernels": running_kernels, + "terminating_kernels": terminating_kernels, + } + tasks = [] + for agent_id in agent_ids: + tasks.append( + self._sync_agent_resource_and_get_kernels( + agent_id, + agent_kernel_by_status[agent_id]["preparing_kernels"], + agent_kernel_by_status[agent_id]["pulling_kernels"], + agent_kernel_by_status[agent_id]["running_kernels"], + agent_kernel_by_status[agent_id]["terminating_kernels"], + ) + ) + responses = await asyncio.gather(*tasks, return_exceptions=True) + for aid, resp in zip(agent_ids, responses): + agent_errors = [] + if isinstance(resp, aiotools.TaskGroupError): + agent_errors.extend(resp.__errors__) + elif isinstance(resp, Exception): + agent_errors.append(resp) + if agent_errors: + result[aid] = MultiAgentError( + "agent(s) raise errors during kernel resource sync", + agent_errors, + ) + else: + assert isinstance( + resp, AgentKernelRegistryByStatus + ), f"response should be `AgentKernelRegistryByStatus`, not {type(resp)}" + result[aid] = resp + return result + async
def mark_kernel_terminated( self, kernel_id: KernelId, diff --git a/src/ai/backend/manager/scheduler/dispatcher.py b/src/ai/backend/manager/scheduler/dispatcher.py index a138b6ad650..3117459ac6a 100644 --- a/src/ai/backend/manager/scheduler/dispatcher.py +++ b/src/ai/backend/manager/scheduler/dispatcher.py @@ -20,6 +20,7 @@ Sequence, Tuple, Union, + cast, ) import aiotools @@ -88,6 +89,7 @@ ) from ..models.utils import ExtendedAsyncSAEngine as SAEngine from ..models.utils import execute_with_retry, sql_json_increment, sql_json_merge +from ..types import AgentResourceSyncTrigger from .predicates import ( check_concurrency, check_dependencies, @@ -259,6 +261,10 @@ async def schedule( log.debug("schedule(): triggered") manager_id = self.local_config["manager"]["id"] redis_key = f"manager.{manager_id}.schedule" + agent_resource_sync_trigger = cast( + list[AgentResourceSyncTrigger], + self.local_config["manager"]["agent-resource-sync-trigger"], + ) def _pipeline(r: Redis) -> RedisPipeline: pipe = r.pipeline() @@ -287,10 +293,6 @@ def _pipeline(r: Redis) -> RedisPipeline: # as its individual steps are composed of many short-lived transactions. async with self.lock_factory(LockID.LOCKID_SCHEDULE, 60): async with self.db.begin_readonly_session() as db_sess: - # query = ( - # sa.select(ScalingGroupRow) - # .join(ScalingGroupRow.agents.and_(AgentRow.status == AgentStatus.ALIVE)) - # ) query = ( sa.select(AgentRow.scaling_group) .where(AgentRow.status == AgentStatus.ALIVE) @@ -298,22 +300,41 @@ def _pipeline(r: Redis) -> RedisPipeline: ) result = await db_sess.execute(query) schedulable_scaling_groups = [row.scaling_group for row in result.fetchall()] - for sgroup_name in schedulable_scaling_groups: - try: - await self._schedule_in_sgroup( - sched_ctx, - sgroup_name, - ) - await redis_helper.execute( - self.redis_live, - lambda r: r.hset( - redis_key, - "resource_group", + + async with self.db.begin() as db_conn: + produce_do_prepare = False + for sgroup_name in schedulable_scaling_groups: + try: + kernel_agent_bindings = await self._schedule_in_sgroup( + sched_ctx, sgroup_name, - ), - ) - except Exception as e: - log.exception("schedule({}): scheduling error!\n{}", sgroup_name, repr(e)) + ) + if kernel_agent_bindings: + produce_do_prepare = True + await redis_helper.execute( + self.redis_live, + lambda r: r.hset( + redis_key, + "resource_group", + sgroup_name, + ), + ) + except Exception as e: + log.exception( + "schedule({}): scheduling error!\n{}", sgroup_name, repr(e) + ) + else: + if ( + AgentResourceSyncTrigger.AFTER_SCHEDULING + in agent_resource_sync_trigger + and kernel_agent_bindings + ): + selected_agent_ids = [ + binding.agent_alloc_ctx.agent_id + for binding in kernel_agent_bindings + if binding.agent_alloc_ctx.agent_id is not None + ] + await self.registry.sync_agent_resource(db_conn, selected_agent_ids) await redis_helper.execute( self.redis_live, lambda r: r.hset( @@ -322,6 +343,8 @@ def _pipeline(r: Redis) -> RedisPipeline: datetime.now(tzutc()).isoformat(), ), ) + if produce_do_prepare: + await self.event_producer.produce_event(DoPrepareEvent()) except DBAPIError as e: if getattr(e.orig, "pgcode", None) == "55P03": log.info( @@ -355,7 +378,7 @@ async def _schedule_in_sgroup( self, sched_ctx: SchedulingContext, sgroup_name: str, - ) -> None: + ) -> list[KernelAgentBinding]: async def _apply_cancellation( db_sess: SASession, session_ids: list[SessionId], reason="pending-timeout" ): @@ -426,7 +449,8 @@ async def _update(): len(cancelled_sessions), ) zero = ResourceSlot() - 
num_scheduled = 0 + kernel_agent_bindings_in_sgroup: list[KernelAgentBinding] = [] + while len(pending_sessions) > 0: async with self.db.begin_readonly_session() as db_sess: candidate_agents = await list_schedulable_agents_by_sgroup(db_sess, sgroup_name) @@ -440,7 +464,7 @@ async def _update(): if picked_session_id is None: # no session is picked. # continue to next sgroup. - return + return [] for picked_idx, sess_ctx in enumerate(pending_sessions): if sess_ctx.id == picked_session_id: break @@ -651,7 +675,7 @@ async def _update_session_status_data() -> None: try: match schedulable_sess.cluster_mode: case ClusterMode.SINGLE_NODE: - await self._schedule_single_node_session( + kernel_agent_bindings = await self._schedule_single_node_session( sched_ctx, scheduler, sgroup_name, @@ -661,7 +685,7 @@ async def _update_session_status_data() -> None: check_results, ) case ClusterMode.MULTI_NODE: - await self._schedule_multi_node_session( + kernel_agent_bindings = await self._schedule_multi_node_session( sched_ctx, scheduler, sgroup_name, @@ -695,9 +719,9 @@ async def _update_session_status_data() -> None: # _schedule_{single,multi}_node_session() already handle general exceptions. # Proceed to the next pending session and come back later continue - num_scheduled += 1 - if num_scheduled > 0: - await self.event_producer.produce_event(DoPrepareEvent()) + else: + kernel_agent_bindings_in_sgroup.extend(kernel_agent_bindings) + return kernel_agent_bindings_in_sgroup async def _filter_agent_by_container_limit( self, candidate_agents: list[AgentRow] @@ -730,7 +754,7 @@ async def _schedule_single_node_session( sess_ctx: SessionRow, agent_selection_resource_priority: list[str], check_results: List[Tuple[str, Union[Exception, PredicateResult]]], - ) -> None: + ) -> list[KernelAgentBinding]: """ Finds and assigns an agent having resources enough to host the entire session. """ @@ -994,6 +1018,11 @@ async def _finalize_scheduled() -> None: SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id), ) + kernel_agent_bindings: list[KernelAgentBinding] = [] + for kernel_row in sess_ctx.kernels: + kernel_agent_bindings.append(KernelAgentBinding(kernel_row, agent_alloc_ctx, set())) + return kernel_agent_bindings + async def _schedule_multi_node_session( self, sched_ctx: SchedulingContext, @@ -1003,7 +1032,7 @@ async def _schedule_multi_node_session( sess_ctx: SessionRow, agent_selection_resource_priority: list[str], check_results: List[Tuple[str, Union[Exception, PredicateResult]]], - ) -> None: + ) -> list[KernelAgentBinding]: """ Finds and assigns agents having resources enough to host each kernel in the session. """ @@ -1231,6 +1260,7 @@ async def _finalize_scheduled() -> None: await self.registry.event_producer.produce_event( SessionScheduledEvent(sess_ctx.id, sess_ctx.creation_id), ) + return kernel_agent_bindings async def prepare( self, diff --git a/src/ai/backend/manager/types.py b/src/ai/backend/manager/types.py index dee2842dd62..78833ff861b 100644 --- a/src/ai/backend/manager/types.py +++ b/src/ai/backend/manager/types.py @@ -56,3 +56,14 @@ class MountOptionModel(BaseModel): MountPermission | None, Field(validation_alias=AliasChoices("permission", "perm"), default=None), ] + + +class AgentResourceSyncTrigger(enum.StrEnum): + AFTER_SCHEDULING = "after-scheduling" + BEFORE_KERNEL_CREATION = "before-kernel-creation" + ON_CREATION_FAILURE = "on-creation-failure" + + +DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS: list[AgentResourceSyncTrigger] = [ + AgentResourceSyncTrigger.ON_CREATION_FAILURE +]
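For reference, a minimal sketch (not part of the patch) of how the new "agent-resource-sync-trigger" manager config key is expected to be parsed once this change is applied. The import paths follow the files touched above; the sample value passed to the schema is hypothetical.

# Illustrative only: assumes this patch is applied so that tx.EnumList and
# AgentResourceSyncTrigger exist at these paths.
import trafaret as t

from ai.backend.common import validators as tx
from ai.backend.manager.types import AgentResourceSyncTrigger

# Mirrors the key added to the manager config schema in config.py; the default
# matches the patch's DEFAULT_AGENT_RESOURE_SYNC_TRIGGERS.
schema = t.Dict({
    t.Key(
        "agent-resource-sync-trigger",
        default=[AgentResourceSyncTrigger.ON_CREATION_FAILURE],
    ): tx.EnumList(AgentResourceSyncTrigger),
}).allow_extra("*")

# e.g. in manager.toml: agent-resource-sync-trigger = ["after-scheduling", "on-creation-failure"]
cfg = schema.check({"agent-resource-sync-trigger": ["after-scheduling", "on-creation-failure"]})
assert cfg["agent-resource-sync-trigger"] == [
    AgentResourceSyncTrigger.AFTER_SCHEDULING,
    AgentResourceSyncTrigger.ON_CREATION_FAILURE,
]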