Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/2404.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Update the legacy `gpu_allocated` metric fields to account for all accelerator devices: not only `cuda.devices` and `cuda.shares`, but also MIG variants and other NPUs. (Known issue: all resources visible to each user and group MUST use a consistent fraction mode.)
2 changes: 1 addition & 1 deletion src/ai/backend/accelerator/cuda_open/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ async def generate_mounts(

def get_metadata(self) -> AcceleratorMetadata:
return {
"slot_name": self.slot_types[0][0],
"slot_name": str(self.slot_types[0][0]),
"human_readable_name": "GPU",
"description": "CUDA-capable GPU",
"display_unit": "GPU",
Expand Down
6 changes: 3 additions & 3 deletions src/ai/backend/agent/alloc_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def check_exclusive(self, a: SlotName, b: SlotName) -> bool:
return True
for t in self.exclusive_slot_types:
if "*" in t:
a_in_exclusive_set = a_in_exclusive_set or fnmatch.fnmatchcase(a, t)
b_in_exclusive_set = b_in_exclusive_set or fnmatch.fnmatchcase(b, t)
a_in_exclusive_set = a_in_exclusive_set or fnmatch.fnmatchcase(str(a), str(t))
b_in_exclusive_set = b_in_exclusive_set or fnmatch.fnmatchcase(str(b), str(t))
return a_in_exclusive_set and b_in_exclusive_set

def format_current_allocations(self) -> str:
Expand Down Expand Up @@ -713,7 +713,7 @@ def distribute_evenly(
def allocate_across_devices(
dev_allocs: list[tuple[DeviceId, Decimal]],
remaining_alloc: Decimal,
slot_name: str,
slot_name: SlotName,
) -> dict[DeviceId, Decimal]:
slot_allocation: dict[DeviceId, Decimal] = {}
n_devices = len(dev_allocs)
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/docker/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def prepare_resource_spec(self) -> Tuple[KernelResourceSpec, Optional[Mapp
slots = slots.normalize_slots(ignore_unknown=True)
resource_spec = KernelResourceSpec(
allocations={},
slots={**slots}, # copy
slots=slots.copy(),
mounts=[],
scratch_disk_size=0, # TODO: implement (#70)
)
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/dummy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def prepare_resource_spec(
slots = slots.normalize_slots(ignore_unknown=True)
resource_spec = KernelResourceSpec(
allocations={},
slots={**slots}, # copy
slots=slots.copy(),
mounts=[],
scratch_disk_size=0, # TODO: implement (#70)
)
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/agent/kubernetes/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _kernel_resource_spec_read():
slots = slots.normalize_slots(ignore_unknown=True)
resource_spec = KernelResourceSpec(
allocations={},
slots={**slots}, # copy
slots=slots.copy(),
mounts=[],
scratch_disk_size=0, # TODO: implement (#70)
)
Expand Down
17 changes: 9 additions & 8 deletions src/ai/backend/agent/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class KernelResourceSpec:
while kernel containers are running.
"""

slots: Mapping[SlotName, str]
slots: ResourceSlot
"""Stores the original user-requested resource slots."""

allocations: MutableMapping[DeviceName, Mapping[SlotName, Mapping[DeviceId, Decimal]]]
Expand Down Expand Up @@ -186,7 +186,7 @@ def read_from_string(cls, text: str) -> "KernelResourceSpec":
return cls(
scratch_disk_size=BinarySize.finite_from_str(kvpairs["SCRATCH_SIZE"]),
allocations=dict(allocations),
slots=ResourceSlot(load_json(kvpairs["SLOTS"])),
slots=ResourceSlot.from_json(load_json(kvpairs["SLOTS"])),
mounts=mounts,
)

Expand Down Expand Up @@ -520,7 +520,8 @@ def _read_kernel_resource_spec(path: Path) -> None:
return
if resource_spec is None:
return
for slot_name in resource_spec.slots.keys():
for raw_slot_name in resource_spec.slots.keys():
slot_name = SlotName(raw_slot_name)
slot_allocs[slot_name] += Decimal(resource_spec.slots[slot_name])

async def _wrap_future(fut: asyncio.Future) -> None:
Expand Down Expand Up @@ -615,8 +616,8 @@ def allocate(

# Sort out the device names in the resource spec based on the configured allocation order
dev_names: set[DeviceName] = set()
for slot_name in slots.keys():
dev_name = slot_name.split(".", maxsplit=1)[0]
for raw_slot_name in slots.keys():
dev_name = raw_slot_name.split(".", maxsplit=1)[0]
dev_names.add(DeviceName(dev_name))
ordered_dev_names = sorted(dev_names, key=lambda item: alloc_order.index(item))

Expand All @@ -635,9 +636,9 @@ def allocate(
computer_ctx = computers[dev_name]
device_id_map = {device.device_id: device for device in computer_ctx.devices}
device_specific_slots = {
SlotName(slot_name): Decimal(alloc)
for slot_name, alloc in slots.items()
if slot_name == dev_name or slot_name.startswith(f"{dev_name}.")
SlotName(raw_slot_name): Decimal(alloc)
for raw_slot_name, alloc in slots.items()
if raw_slot_name == dev_name or raw_slot_name.startswith(f"{dev_name}.")
}
try:
if isinstance(computer_ctx.alloc_map, FractionAllocMap):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _prepare_resource_spec(
slots = slots.normalize_slots(ignore_unknown=True)
resource_spec = KernelResourceSpec(
allocations={},
slots={**slots}, # copy
slots=slots.copy(),
mounts=[],
scratch_disk_size=0, # TODO: implement (#70)
)
Expand Down
19 changes: 11 additions & 8 deletions src/ai/backend/common/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import temporenc

from .typed_validators import AutoDirectoryPath
from .types import BinarySize, ResourceSlot
from .types import BinarySize, ResourceSlot, SlotName

__all__ = ("packb", "unpackb")

Expand All @@ -31,6 +31,7 @@ class ExtTypes(enum.IntEnum):
ENUM = 6
IMAGE_REF = 7
RESOURCE_SLOT = 8
SLOT_NAME = 9
BACKENDAI_BINARY_SIZE = 16
AUTO_DIRECTORY_PATH = 17

Expand All @@ -57,6 +58,8 @@ def _default(obj: object) -> _msgpack.ExtType:
return _msgpack.ExtType(ExtTypes.AUTO_DIRECTORY_PATH, os.fsencode(obj))
case ResourceSlot():
return _msgpack.ExtType(ExtTypes.RESOURCE_SLOT, pickle.dumps(obj, protocol=5))
case SlotName():
return _msgpack.ExtType(ExtTypes.SLOT_NAME, pickle.dumps(obj, protocol=5))
case enum.Enum():
return _msgpack.ExtType(ExtTypes.ENUM, pickle.dumps(obj, protocol=5))
case ImageRef():
Expand All @@ -65,21 +68,21 @@ def _default(obj: object) -> _msgpack.ExtType:


class ExtFunc(Protocol):
def __call__(self, data: bytes) -> Any:
pass
def __call__(self, data: bytes, /) -> Any: ...


_DEFAULT_EXT_HOOK: Mapping[ExtTypes, ExtFunc] = {
ExtTypes.UUID: lambda data: uuid.UUID(bytes=data),
ExtTypes.DATETIME: lambda data: temporenc.unpackb(data).datetime(),
ExtTypes.DECIMAL: lambda data: pickle.loads(data),
ExtTypes.DECIMAL: pickle.loads,
ExtTypes.POSIX_PATH: lambda data: PosixPath(os.fsdecode(data)),
ExtTypes.PURE_POSIX_PATH: lambda data: PurePosixPath(os.fsdecode(data)),
ExtTypes.AUTO_DIRECTORY_PATH: lambda data: AutoDirectoryPath(os.fsdecode(data)),
ExtTypes.ENUM: lambda data: pickle.loads(data),
ExtTypes.RESOURCE_SLOT: lambda data: pickle.loads(data),
ExtTypes.BACKENDAI_BINARY_SIZE: lambda data: pickle.loads(data),
ExtTypes.IMAGE_REF: lambda data: pickle.loads(data),
ExtTypes.ENUM: pickle.loads,
ExtTypes.RESOURCE_SLOT: pickle.loads,
ExtTypes.SLOT_NAME: pickle.loads,
ExtTypes.BACKENDAI_BINARY_SIZE: pickle.loads,
ExtTypes.IMAGE_REF: pickle.loads,
}


Expand Down
14 changes: 12 additions & 2 deletions src/ai/backend/common/resilience/policies/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Optional, ParamSpec, TypeVar
from typing import Final, Optional, ParamSpec, TypeVar

from ai.backend.common.exception import (
BackendAIError,
Expand Down Expand Up @@ -47,6 +47,13 @@ class BackoffStrategy(enum.StrEnum):
EXPONENTIAL = "exponential"


# Exception types that are always added to a retry policy's non-retryable set:
# the policy constructor places them ahead of the caller-supplied
# `RetryArgs.non_retryable_exceptions`.
# NOTE(review): presumably chosen because they signal programming errors that
# a retry cannot fix -- confirm with the author.
DEFAULT_NON_RETRYABLE_ERRORS: Final = (
TypeError,
NameError,
AttributeError,
)


@dataclass
class RetryArgs:
"""Arguments for RetryPolicy."""
Expand Down Expand Up @@ -88,7 +95,10 @@ def __init__(self, args: RetryArgs) -> None:
self._retry_delay = args.retry_delay
self._backoff_strategy = args.backoff_strategy
self._max_delay = args.max_delay
self._non_retryable_exceptions = args.non_retryable_exceptions
self._non_retryable_exceptions = (
*DEFAULT_NON_RETRYABLE_ERRORS,
*args.non_retryable_exceptions,
)

async def execute(
self,
Expand Down
111 changes: 99 additions & 12 deletions src/ai/backend/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import textwrap
import uuid
from abc import ABC, ABCMeta, abstractmethod
from collections import UserDict, defaultdict, namedtuple
from collections import UserDict, UserString, defaultdict, namedtuple
from collections.abc import AsyncIterator, Iterable
from contextvars import ContextVar
from dataclasses import dataclass, field
Expand Down Expand Up @@ -313,7 +313,52 @@ def check_typed_tuple(value: Tuple[Any, ...], types: Tuple[Type, ...]) -> Tuple:
AGENTID_STORAGE = AgentId("storage")
DeviceName = NewType("DeviceName", str)
DeviceId = NewType("DeviceId", str)
SlotName = NewType("SlotName", str)


class SlotName(UserString):
    """
    A string-like resource slot name of the form ``<device>.<major>:<minor>``.

    The text before the first ``.`` is the device name; the remainder is split
    at the first ``:`` into a major type and a minor type.  Missing parts are
    exposed as empty strings.
    """

    __slots__ = ("_parsed", "_device_name", "_major_type", "_minor_type")

    def __init__(self, value: str | SlotName) -> None:
        # The components start out empty and are filled in by _parse() on first access.
        self._device_name = ""
        self._major_type = ""
        self._minor_type = ""
        self._parsed = False
        super().__init__(value)

    def _parse(self) -> None:
        """Split the name into its components once and remember the result."""
        # Parsing is deferred because SlotName objects are created very
        # frequently in code paths that never look at the parsed attributes.
        if not self._parsed:
            device, _, type_spec = self.data.partition(".")
            major, _, minor = type_spec.partition(":")
            self._device_name = device
            self._major_type = major
            self._minor_type = minor
            self._parsed = True

    @property
    def device_name(self) -> str:
        """The part of the name before the first ``.``."""
        self._parse()
        return self._device_name

    @property
    def major_type(self) -> str:
        """The part between the first ``.`` and the first following ``:``."""
        self._parse()
        return self._major_type

    @property
    def minor_type(self) -> str:
        """The part after the first ``:`` that follows the first ``.``."""
        self._parse()
        return self._minor_type

    def is_accelerator(self) -> bool:
        """Return True when the major type is one of the device/share spellings."""
        return self.major_type in ("device", "devices", "share", "shares")


MetricKey = NewType("MetricKey", str)

AccessKey = NewType("AccessKey", str)
Expand Down Expand Up @@ -808,22 +853,64 @@ def _validate_binary_size(v: Any) -> BinarySize:
# Create a custom type annotation for BinarySize fields
BinarySizeField = Annotated[BinarySize, PlainValidator(_validate_binary_size)]

type RawResourceValue = int | float | str | Decimal | BinarySize

class ResourceSlot(UserDict):

class ResourceSlot(UserDict[str, Decimal]):
"""
key: `str` type slot name.
value: `str` or `Decimal` type value. Do not convert this to `float` or `int`.
value: `Decimal` type value. Do not convert this to `float` or `int` for calculation accuracy.
"""

__slots__ = ("data",)

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __init__(
    self,
    data: Mapping[SlotName, RawResourceValue | None]
    | Mapping[str, RawResourceValue | None]
    | None = None,
    **kwargs: RawResourceValue | None,  # for legacy codes (TODO: update all kwarg-based init)
) -> None:
    """
    Build the slot mapping from an optional mapping plus legacy keyword arguments.

    Keys are stored as plain strings and values are converted through
    ``_process_raw_value()``.  Entries whose value is ``None`` are skipped.
    Keyword arguments are applied first, so an entry in ``data`` with the
    same key overrides them.
    """
    normalized: dict[str, Decimal] = {}
    # Order matters: kwargs first, then the mapping, so the mapping takes precedence.
    for source in (kwargs, {} if data is None else data):
        for raw_key, raw_value in source.items():
            if raw_value is None:
                continue
            slot_key = str(raw_key)
            normalized[slot_key] = self._process_raw_value(slot_key, raw_value)
    super().__init__(normalized)

@classmethod
def from_known_slots(cls, known_slots: Mapping[SlotName, SlotTypes]) -> ResourceSlot:
    """Create an instance that holds a zero value for every key of ``known_slots``."""
    zeroed = {slot_name: Decimal(0) for slot_name in known_slots}
    return cls(zeroed)

@classmethod
def _process_raw_value(cls, key: str, value: RawResourceValue) -> Decimal:
    """
    Convert a raw slot value to ``Decimal``.

    A string value for a key whose guessed slot type is ``SlotTypes.BYTES``
    is parsed with ``BinarySize.from_str()`` first; everything else is passed
    directly to ``Decimal()``.
    """
    is_bytes_slot = cls._guess_slot_type(str(key)) == SlotTypes.BYTES
    if is_bytes_slot and isinstance(value, str):
        return Decimal(BinarySize.from_str(value))
    return Decimal(value)

def __setitem__(self, key: str | SlotName, value: RawResourceValue | None) -> None:
normalized_key = str(key)
if value is None:
self.data.pop(normalized_key, None)
return
self.data[normalized_key] = self._process_raw_value(normalized_key, value)

def __getitem__(self, key: str | SlotName) -> Decimal:
normalized_key = str(key)
return self.data[normalized_key]

def copy(self) -> Self:
    """Return a new instance of the same class built from a shallow copy of the data."""
    snapshot = self.data.copy()
    return type(self)(snapshot)

def sync_keys(self, other: ResourceSlot) -> None:
self_only_keys = self.data.keys() - other.data.keys()
other_only_keys = other.data.keys() - self.data.keys()
Expand Down Expand Up @@ -916,16 +1003,16 @@ def normalize_slots(self, *, ignore_unknown: bool) -> ResourceSlot:
raise ValueError(f"Unknown slots: {', '.join(map(repr, unknown_slots))}")
data = {k: v for k, v in self.data.items() if k in known_slots}
for k in unset_slots:
data[k] = Decimal(0)
data[str(k)] = Decimal(0)
return type(self)(data)

@classmethod
def _normalize_value(cls, key: str, value: Any, unit: SlotTypes) -> Decimal:
def _normalize_value(cls, key: str, value: RawResourceValue, unit: SlotTypes) -> Decimal:
try:
if unit == SlotTypes.BYTES:
if isinstance(value, Decimal):
return Decimal(value) if value.is_finite() else value
if isinstance(value, int):
return value
if isinstance(value, (int, float)):
return Decimal(value)
value = Decimal(BinarySize.from_str(value))
else:
Expand Down Expand Up @@ -999,7 +1086,7 @@ def from_user_input(
# fill missing
for k in slot_types.keys():
if k not in data:
data[k] = Decimal(0)
data[str(k)] = Decimal(0)
except KeyError as e:
extra_guide = ""
if e.args[0] == "shmem":
Expand All @@ -1010,7 +1097,7 @@ def from_user_input(
def to_humanized(self, slot_types: Mapping) -> Mapping[str, str]:
try:
return {
k: type(self)._humanize_value(v, slot_types[k])
k: type(self)._humanize_value(Decimal(v), slot_types[k])
for k, v in self.data.items()
if v is not None
}
Expand Down
Loading
Loading