Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/2178.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
* Keep the container sync loop alive
* Handle container creation errors granularly
236 changes: 142 additions & 94 deletions src/ai/backend/agent/agent.py

Large diffs are not rendered by default.

136 changes: 78 additions & 58 deletions src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import signal
import struct
import sys
from collections.abc import Mapping
from decimal import Decimal
from functools import partial
from io import StringIO
Expand All @@ -22,13 +23,13 @@
FrozenSet,
List,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Set,
Tuple,
Union,
cast,
)
from uuid import UUID

Expand Down Expand Up @@ -120,6 +121,30 @@ def container_from_docker_container(src: DockerContainer) -> Container:
)


async def _clean_scratch(
    loop: asyncio.AbstractEventLoop,
    scratch_type: str,
    scratch_root: Path,
    kernel_id: KernelId,
) -> None:
    """Remove the scratch (and tmp) directories belonging to a kernel.

    The removal strategy depends on how the scratch was provisioned
    (``memory``/``hostfile`` on Linux vs. plain directories elsewhere).
    Cleanup is best-effort: already-missing paths and failing external
    mount/umount commands are silently ignored.
    """
    work_dir = scratch_root / str(kernel_id)
    work_tmp_dir = scratch_root / f"{kernel_id}_tmp"
    on_linux = sys.platform.startswith("linux")
    try:
        if on_linux and scratch_type == "memory":
            # Unmount the in-memory filesystems first, then drop the trees.
            for mounted in (work_dir, work_tmp_dir):
                await destroy_scratch_filesystem(mounted)
            for leftover in (work_dir, work_tmp_dir):
                await loop.run_in_executor(None, shutil.rmtree, leftover)
        elif on_linux and scratch_type == "hostfile":
            await destroy_loop_filesystem(scratch_root, kernel_id)
        else:
            await loop.run_in_executor(None, shutil.rmtree, work_dir)
    except (CalledProcessError, FileNotFoundError):
        # Best-effort cleanup: the scratch may already be gone or
        # partially removed by an earlier attempt.
        pass


def _DockerError_reduce(self):
return (
type(self),
Expand Down Expand Up @@ -851,6 +876,18 @@ async def start_container(
if self.local_config["debug"]["log-kernel-config"]:
log.debug("full container config: {!r}", pretty(container_config))

async def _rollback_container_creation() -> None:
    # Undo everything allocated so far for this container when creation
    # fails: scratch directories, published host ports, and per-device
    # resource allocations.
    # NOTE: this is a closure — `loop`, `host_ports`, and `resource_spec`
    # are captured from the enclosing start_container scope.
    await _clean_scratch(
        loop,
        self.local_config["container"]["scratch-type"],
        self.local_config["container"]["scratch-root"],
        self.kernel_id,
    )
    # Return the host ports reserved for this kernel to the shared pool.
    self.port_pool.update(host_ports)
    async with self.resource_lock:
        # Free each device allocation recorded in the resource spec.
        for dev_name, device_alloc in resource_spec.allocations.items():
            self.computers[dev_name].alloc_map.free(device_alloc)

# We are all set! Create and start the container.
async with closing_async(Docker()) as docker:
container: Optional[DockerContainer] = None
Expand All @@ -859,7 +896,7 @@ async def start_container(
config=container_config, name=kernel_name
)
assert container is not None
cid = container._id
cid = cast(str, container._id)
resource_spec.container_id = cid
# Write resource.txt again to update the container id.
with open(self.config_dir / "resource.txt", "w") as f:
Expand All @@ -873,52 +910,47 @@ async def start_container(
kvpairs = await computer_ctx.instance.generate_resource_data(device_alloc)
for k, v in kvpairs.items():
await writer.write(f"{k}={v}\n")

await container.start()

if self.internal_data.get("sudo_session_enabled", False):
exec = await container.exec(
[
# file ownership is guaranteed to be set as root:root since command is executed on behalf of root user
"sh",
"-c",
'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work',
],
user="root",
)
shell_response = await exec.start(detach=True)
if shell_response:
raise ContainerCreationError(
container_id=cid,
message=f"sudoers provision failed: {shell_response.decode()}",
)
except asyncio.CancelledError:
if container is not None:
raise ContainerCreationError(
container_id=cid, message="Container creation cancelled"
container_id=container._id, message="Container creation cancelled"
)
raise
except Exception:
# Oops, we have to restore the allocated resources!
scratch_type = self.local_config["container"]["scratch-type"]
scratch_root = self.local_config["container"]["scratch-root"]
if sys.platform.startswith("linux") and scratch_type == "memory":
await destroy_scratch_filesystem(self.scratch_dir)
await destroy_scratch_filesystem(self.tmp_dir)
await loop.run_in_executor(None, shutil.rmtree, self.scratch_dir)
await loop.run_in_executor(None, shutil.rmtree, self.tmp_dir)
elif sys.platform.startswith("linux") and scratch_type == "hostfile":
await destroy_loop_filesystem(scratch_root, self.kernel_id)
else:
await loop.run_in_executor(None, shutil.rmtree, self.scratch_dir)
self.port_pool.update(host_ports)
async with self.resource_lock:
for dev_name, device_alloc in resource_spec.allocations.items():
self.computers[dev_name].alloc_map.free(device_alloc)
except Exception as e:
await _rollback_container_creation()
if container is not None:
raise ContainerCreationError(container_id=cid, message="unknown")
raise ContainerCreationError(
container_id=container._id, message=f"unknown. {repr(e)}"
)
raise

try:
await container.start()
except asyncio.CancelledError:
await _rollback_container_creation()
raise ContainerCreationError(container_id=cid, message="Container start cancelled")
except Exception as e:
await _rollback_container_creation()
raise ContainerCreationError(container_id=cid, message=f"unknown. {repr(e)}")

if self.internal_data.get("sudo_session_enabled", False):
exec = await container.exec(
[
# file ownership is guaranteed to be set as root:root since command is executed on behalf of root user
"sh",
"-c",
'mkdir -p /etc/sudoers.d && echo "work ALL=(ALL:ALL) NOPASSWD:ALL" > /etc/sudoers.d/01-bai-work',
],
user="root",
)
shell_response = await exec.start(detach=True)
if shell_response:
await _rollback_container_creation()
raise ContainerCreationError(
container_id=cid,
message=f"sudoers provision failed: {shell_response.decode()}",
)

additional_network_names: Set[str] = set()
for dev_name, device_alloc in resource_spec.allocations.items():
n = await self.computers[dev_name].instance.get_docker_networks(device_alloc)
Expand Down Expand Up @@ -1500,24 +1532,12 @@ async def log_iter():
log.warning("container deletion timeout (k:{}, c:{})", kernel_id, container_id)

if not restarting:
scratch_type = self.local_config["container"]["scratch-type"]
scratch_root = self.local_config["container"]["scratch-root"]
scratch_dir = scratch_root / str(kernel_id)
tmp_dir = scratch_root / f"{kernel_id}_tmp"
try:
if sys.platform.startswith("linux") and scratch_type == "memory":
await destroy_scratch_filesystem(scratch_dir)
await destroy_scratch_filesystem(tmp_dir)
await loop.run_in_executor(None, shutil.rmtree, scratch_dir)
await loop.run_in_executor(None, shutil.rmtree, tmp_dir)
elif sys.platform.startswith("linux") and scratch_type == "hostfile":
await destroy_loop_filesystem(scratch_root, kernel_id)
else:
await loop.run_in_executor(None, shutil.rmtree, scratch_dir)
except CalledProcessError:
pass
except FileNotFoundError:
pass
await _clean_scratch(
loop,
self.local_config["container"]["scratch-type"],
self.local_config["container"]["scratch-root"],
kernel_id,
)

async def create_local_network(self, network_name: str) -> None:
async with closing_async(Docker()) as docker:
Expand Down
7 changes: 6 additions & 1 deletion src/ai/backend/agent/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

from .exception import UnsupportedBaseDistroError
from .resources import KernelResourceSpec
from .types import AgentEventData
from .types import AgentEventData, KernelStatus

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]

Expand Down Expand Up @@ -176,6 +176,7 @@ class AbstractKernel(UserDict, aobject, metaclass=ABCMeta):
stats_enabled: bool
# FIXME: apply TypedDict to data in Python 3.8
environ: Mapping[str, Any]
status: KernelStatus

_tasks: Set[asyncio.Task]

Expand Down Expand Up @@ -212,6 +213,7 @@ def __init__(
self.environ = environ
self.runner = None
self.container_id = None
self.status = KernelStatus.PREPARING

async def init(self, event_producer: EventProducer) -> None:
log.debug(
Expand All @@ -232,6 +234,9 @@ def __getstate__(self) -> Mapping[str, Any]:
return props

def __setstate__(self, props) -> None:
if "status" not in props:
# We pickle "running" kernels, not "preparing" or "terminating" ones
props["status"] = KernelStatus.RUNNING
self.__dict__.update(props)
# agent_config is set by the pickle.loads() caller.
self.clean_event = None
Expand Down
112 changes: 112 additions & 0 deletions src/ai/backend/agent/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import signal
import sys
from collections import OrderedDict, defaultdict
from collections.abc import Collection
from ipaddress import _BaseAddress as BaseIPAddress
from ipaddress import ip_network
from pathlib import Path
Expand Down Expand Up @@ -58,6 +59,7 @@
)
from ai.backend.common.logging import BraceStyleAdapter, Logger
from ai.backend.common.types import (
AgentKernelRegistryByStatus,
ClusterInfo,
CommitStatus,
HardwareMetadata,
Expand Down Expand Up @@ -477,6 +479,116 @@ async def sync_kernel_registry(
suppress_events=True,
)

@rpc_function
@collect_error
async def sync_and_get_kernels(
    self,
    preparing_kernels: Collection[str],
    pulling_kernels: Collection[str],
    running_kernels: Collection[str],
    terminating_kernels: Collection[str],
) -> dict[str, Any]:
    """
    Sync kernel_registry and containers to truth data
    and return kernel infos whose status is irreversible.

    The four arguments carry the manager's view of kernel states as raw
    string UUIDs.

    Returns the JSON form of ``AgentKernelRegistryByStatus`` listing the
    kernels that actually exist, are terminating, or are already terminated
    from the agent's point of view.
    """

    actual_terminating_kernels: list[tuple[KernelId, str]] = []
    actual_terminated_kernels: list[tuple[KernelId, str]] = []

    # Normalize the raw string ids to KernelId once, so the membership
    # tests against registry keys below compare KernelId with KernelId.
    # (A KernelId is UUID-based and never equals a raw string, so testing
    # e.g. `kernel_id in terminating_kernels` directly is always False
    # and would misclassify every registered kernel.)
    preparing_ids = {KernelId(UUID(k)) for k in preparing_kernels}
    pulling_ids = {KernelId(UUID(k)) for k in pulling_kernels}
    running_ids = {KernelId(UUID(k)) for k in running_kernels}
    terminating_ids = {KernelId(UUID(k)) for k in terminating_kernels}

    async with self.agent.registry_lock:
        # Kernels the agent itself still considers alive.
        actual_existing_kernels = [
            kid
            for kid in self.agent.kernel_registry
            if kid not in self.agent.terminating_kernels
        ]

        # Kernels the manager believes to be running.
        for raw_kernel_id in running_kernels:
            kernel_id = KernelId(UUID(raw_kernel_id))
            if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
                if kernel_id in self.agent.terminating_kernels:
                    # The agent is already tearing this kernel down.
                    actual_terminating_kernels.append((
                        kernel_id,
                        str(
                            kernel_obj.termination_reason
                            or KernelLifecycleEventReason.ALREADY_TERMINATED
                        ),
                    ))
            else:
                # Not in the registry at all: it is gone.
                actual_terminated_kernels.append((
                    kernel_id,
                    str(KernelLifecycleEventReason.ALREADY_TERMINATED),
                ))

        # Kernels the manager believes to be terminating.
        for raw_kernel_id in terminating_kernels:
            kernel_id = KernelId(UUID(raw_kernel_id))
            if (kernel_obj := self.agent.kernel_registry.get(kernel_id)) is not None:
                if kernel_id not in self.agent.terminating_kernels:
                    # The agent has not started termination yet; kick it off.
                    await self.agent.inject_container_lifecycle_event(
                        kernel_id,
                        kernel_obj.session_id,
                        LifecycleEvent.DESTROY,
                        kernel_obj.termination_reason
                        or KernelLifecycleEventReason.ALREADY_TERMINATED,
                        suppress_events=False,
                    )
            else:
                actual_terminated_kernels.append((
                    kernel_id,
                    str(KernelLifecycleEventReason.ALREADY_TERMINATED),
                ))

        # Reconcile every registered kernel against the truth data.
        for kernel_id, kernel_obj in self.agent.kernel_registry.items():
            if kernel_id in terminating_ids:
                if kernel_id not in self.agent.terminating_kernels:
                    await self.agent.inject_container_lifecycle_event(
                        kernel_id,
                        kernel_obj.session_id,
                        LifecycleEvent.DESTROY,
                        kernel_obj.termination_reason
                        or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER,
                        suppress_events=False,
                    )
            elif kernel_id in running_ids:
                pass
            elif kernel_id in preparing_ids:
                # kernel_registry may not have `preparing` state kernels.
                pass
            elif kernel_id in pulling_ids:
                # kernel_registry does not have `pulling` state kernels.
                # Let's just skip it.
                pass
            else:
                # This kernel is not alive according to the truth data.
                # The kernel should be destroyed or cleaned.
                if kernel_id in self.agent.terminating_kernels:
                    await self.agent.inject_container_lifecycle_event(
                        kernel_id,
                        kernel_obj.session_id,
                        LifecycleEvent.CLEAN,
                        kernel_obj.termination_reason
                        or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER,
                        suppress_events=True,
                    )
                elif kernel_id in self.agent.restarting_kernels:
                    pass
                else:
                    await self.agent.inject_container_lifecycle_event(
                        kernel_id,
                        kernel_obj.session_id,
                        LifecycleEvent.DESTROY,
                        kernel_obj.termination_reason
                        or KernelLifecycleEventReason.NOT_FOUND_IN_MANAGER,
                        suppress_events=True,
                    )

    result = AgentKernelRegistryByStatus(
        actual_existing_kernels,
        actual_terminating_kernels,
        actual_terminated_kernels,
    )
    return result.to_json()

@rpc_function
@collect_error
async def create_kernels(
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/agent/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ class Container:
backend_obj: Any # used to keep the backend-specific data


class KernelStatus(enum.StrEnum):
"""
A type to track the status of `AbstractKernel` objects.
"""

PREPARING = enum.auto()
RUNNING = enum.auto()
TERMINATING = enum.auto()


class LifecycleEvent(enum.IntEnum):
DESTROY = 0
CLEAN = 1
Expand Down
Loading
Loading