Skip to content

Commit 2ccb59e

Browse files
authored
fix(BA-2754): configure mock accelerator (#6324)
1 parent 3b2f877 commit 2ccb59e

File tree

3 files changed

+92
-20
lines changed

3 files changed

+92
-20
lines changed

src/ai/backend/install/cli.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,15 @@
4040
from . import __version__
4141
from .common import detect_os
4242
from .context import DevContext, PackageContext, current_log
43-
from .types import CliArgs, DistInfo, InstallInfo, InstallModes, InstallVariable, PrerequisiteError
43+
from .types import (
44+
Accelerator,
45+
CliArgs,
46+
DistInfo,
47+
InstallInfo,
48+
InstallModes,
49+
InstallVariable,
50+
PrerequisiteError,
51+
)
4452

4553
top_tasks: WeakSet[asyncio.Task] = WeakSet()
4654

@@ -402,7 +410,10 @@ def __init__(
402410
self._enabled_menus.add(InstallModes.CONFIGURE)
403411
assert mode is not None
404412
self._mode = mode
405-
self.install_variable = InstallVariable(public_facing_address=args.public_facing_address)
413+
self.install_variable = InstallVariable(
414+
public_facing_address=args.public_facing_address,
415+
accelerator=Accelerator(args.accelerator) if args.accelerator is not None else None,
416+
)
406417

407418
def compose(self) -> ComposeResult:
408419
yield Label(id="heading")
@@ -526,6 +537,7 @@ def __init__(self, args: CliArgs | None = None) -> None:
526537
show_guide=False,
527538
non_interactive=False,
528539
public_facing_address="127.0.0.1",
540+
accelerator=None,
529541
)
530542
self._args = args
531543

@@ -609,6 +621,13 @@ async def action_shutdown(self, message: str | None = None, exit_code: int = 0)
609621
default=False,
610622
help="Show the post-install guide using INSTALL-INFO if present.",
611623
)
624+
@click.option(
625+
"--accelerator",
626+
type=click.Choice([a.value for a in Accelerator], case_sensitive=False),
627+
default=None,
628+
show_default=True,
629+
help="Select accelerator plugin (cuda, cuda_mock, cuda_mig_mock, rocm_mock, none)",
630+
)
612631
@click.option(
613632
"--headless",
614633
is_flag=True,
@@ -631,6 +650,7 @@ def main(
631650
non_interactive: bool,
632651
headless: bool,
633652
public_facing_address: str,
653+
accelerator: str,
634654
) -> None:
635655
"""The installer"""
636656
# check sudo permission
@@ -648,6 +668,7 @@ def main(
648668
show_guide=show_guide,
649669
non_interactive=non_interactive,
650670
public_facing_address=public_facing_address,
671+
accelerator=accelerator,
651672
)
652673
app = InstallerApp(args)
653674
app.run(headless=headless)

src/ai/backend/install/context.py

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from .http import wget
4949
from .python import check_python
5050
from .types import (
51+
Accelerator,
5152
DistInfo,
5253
HalfstackConfig,
5354
HostPortPair,
@@ -105,6 +106,10 @@ def __init__(
105106
def hydrate_install_info(self) -> InstallInfo:
106107
raise NotImplementedError
107108

109+
@abstractmethod
110+
async def _configure_mock_accelerator(self, accelerator: Accelerator) -> None:
111+
raise NotImplementedError
112+
108113
def add_post_guide(self, guide: PostGuide) -> None:
109114
self._post_guides.append(guide)
110115

@@ -461,6 +466,7 @@ async def configure_manager(self) -> None:
461466
async def configure_agent(self) -> None:
462467
halfstack = self.install_info.halfstack_config
463468
service = self.install_info.service_config
469+
accelerator = self.install_info.accelerator
464470
toml_path = self.copy_config("agent.toml")
465471
self.sed_in_place_multi(
466472
toml_path,
@@ -485,25 +491,27 @@ async def configure_agent(self) -> None:
485491
Path(self.install_info.service_config.agent_var_base_path).mkdir(
486492
parents=True, exist_ok=True
487493
)
488-
# enable the CUDA plugin (open-source version)
489-
# The agent will show an error log if the CUDA is not available in the system and report
490-
# "cuda.devices = 0" as the agent capacity, but it will still run.
494+
if accelerator is not None:
495+
if accelerator == Accelerator.CUDA:
496+
plugin_list = ['"ai.backend.accelerator.cuda_open"']
497+
elif accelerator in (
498+
Accelerator.CUDA_MOCK,
499+
Accelerator.CUDA_MIG_MOCK,
500+
Accelerator.ROCM_MOCK,
501+
):
502+
plugin_list = ['"ai.backend.accelerator.mock"']
503+
else:
504+
plugin_list = []
505+
506+
await self._configure_mock_accelerator(accelerator)
507+
else:
508+
plugin_list = []
509+
491510
self.sed_in_place(
492511
toml_path,
493-
re.compile("^(# )?allow-compute-plugins = .*", flags=re.M),
494-
'allow-compute-plugins = ["ai.backend.accelerator.cuda_open"]',
512+
re.compile(r"^(# )?allow-compute-plugins = .*", flags=re.M),
513+
f"allow-compute-plugins = [{', '.join(plugin_list)}]",
495514
)
496-
# TODO: let the installer enable the CUDA plugin only when it verifies CUDA availability or
497-
# via an explicit installer option/config.
498-
r"""
499-
if [ $ENABLE_CUDA -eq 1 ]; then
500-
sed_inplace "s/# allow-compute-plugins =.*/allow-compute-plugins = [\"ai.backend.accelerator.cuda_open\"]/" ./agent.toml
501-
elif [ $ENABLE_CUDA_MOCK -eq 1 ]; then
502-
sed_inplace "s/# allow-compute-plugins =.*/allow-compute-plugins = [\"ai.backend.accelerator.mock\"]/" ./agent.toml
503-
else
504-
sed_inplace "s/# allow-compute-plugins =.*/allow-compute-plugins = []/" ./agent.toml
505-
fi
506-
"""
507515

508516
async def configure_storage_proxy(self) -> None:
509517
halfstack = self.install_info.halfstack_config
@@ -870,6 +878,7 @@ def hydrate_install_info(self) -> InstallInfo:
870878
last_updated=datetime.now(tzutc()),
871879
halfstack_config=halfstack_config,
872880
service_config=service_config,
881+
accelerator=self.install_variable.accelerator,
873882
)
874883

875884
def copy_config(self, template_name: str) -> Path:
@@ -895,10 +904,23 @@ async def install(self) -> None:
895904
await install_editable_webui(self)
896905
await self.install_halfstack()
897906

898-
async def _configure_mock_accelerator(self) -> None:
907+
async def _configure_mock_accelerator(self, accelerator: Accelerator) -> None:
899908
"""
900909
cp "configs/accelerator/mock-accelerator.toml" mock-accelerator.toml
901910
"""
911+
mapping = {
912+
Accelerator.CUDA_MOCK: "configs/accelerator/mock-accelerator.toml",
913+
Accelerator.CUDA_MIG_MOCK: "configs/accelerator/cuda-mock-mig.toml",
914+
Accelerator.ROCM_MOCK: "configs/accelerator/rocm-mock.toml",
915+
}
916+
917+
src = mapping.get(accelerator)
918+
if not src:
919+
return
920+
921+
dst = Path("mock-accelerator.toml")
922+
print(f"[Installer] Copying accelerator config: {src} -> {dst}")
923+
shutil.copy(src, dst)
902924

903925
async def configure(self) -> None:
904926
self.log_header("Configuring manager...")
@@ -986,6 +1008,7 @@ def hydrate_install_info(self) -> InstallInfo:
9861008
last_updated=datetime.now(tzutc()),
9871009
halfstack_config=halfstack_config,
9881010
service_config=service_config,
1011+
accelerator=self.install_variable.accelerator,
9891012
)
9901013

9911014
def copy_config(self, template_name: str) -> Path:
@@ -1148,6 +1171,24 @@ async def install(self) -> None:
11481171
self.log_header("Installing databases (halfstack)...")
11491172
await self.install_halfstack()
11501173

1174+
async def _configure_mock_accelerator(self, accelerator: Accelerator) -> None:
1175+
"""
1176+
cp "configs/accelerator/mock-accelerator.toml" mock-accelerator.toml
1177+
"""
1178+
mapping = {
1179+
Accelerator.CUDA_MOCK: "configs/accelerator/mock-accelerator.toml",
1180+
Accelerator.CUDA_MIG_MOCK: "configs/accelerator/cuda-mock-mig.toml",
1181+
Accelerator.ROCM_MOCK: "configs/accelerator/rocm-mock.toml",
1182+
}
1183+
1184+
src = mapping.get(accelerator)
1185+
if not src:
1186+
return
1187+
1188+
dst = Path("mock-accelerator.toml")
1189+
print(f"[Installer] Copying accelerator config: {src} -> {dst}")
1190+
shutil.copy(src, dst)
1191+
11511192
async def configure(self) -> None:
11521193
self.log_header("Configuring manager...")
11531194
await self.configure_manager()

src/ai/backend/install/types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import enum
55
from datetime import datetime
66
from pathlib import Path
7-
from typing import cast
7+
from typing import Optional, cast
88

99
from pydantic import BaseModel, Field
1010
from rich.console import ConsoleRenderable, RichCast
@@ -53,6 +53,7 @@ class CliArgs:
5353
show_guide: bool
5454
non_interactive: bool
5555
public_facing_address: str
56+
accelerator: Optional[str] = None
5657

5758

5859
class PrerequisiteError(RichCast, Exception):
@@ -84,13 +85,21 @@ class DistInfo(BaseModel):
8485
image_refs: list[str] = Field(default_factory=list)
8586

8687

88+
class Accelerator(enum.StrEnum):
89+
CUDA = "cuda"
90+
CUDA_MOCK = "cuda_mock"
91+
CUDA_MIG_MOCK = "cuda_mig_mock"
92+
ROCM_MOCK = "rocm_mock"
93+
94+
8795
class InstallInfo(BaseModel):
8896
version: str
8997
type: InstallType
9098
last_updated: datetime
9199
base_path: Path
92100
halfstack_config: HalfstackConfig
93101
service_config: ServiceConfig
102+
accelerator: Optional[Accelerator] = None
94103

95104

96105
@dataclasses.dataclass()
@@ -169,3 +178,4 @@ class ServiceConfig:
169178
@dataclasses.dataclass
170179
class InstallVariable:
171180
public_facing_address: str = "127.0.0.1"
181+
accelerator: Optional[Accelerator] = None

0 commit comments

Comments
 (0)