4848from .http import wget
4949from .python import check_python
5050from .types import (
51+ Accelerator ,
5152 DistInfo ,
5253 HalfstackConfig ,
5354 HostPortPair ,
@@ -105,6 +106,10 @@ def __init__(
105106 def hydrate_install_info (self ) -> InstallInfo :
106107 raise NotImplementedError
107108
@abstractmethod
async def _configure_mock_accelerator(self, accelerator: Accelerator) -> None:
    """
    Prepare the mock-accelerator configuration for ``accelerator``.

    This is an abstract hook: concrete installer contexts must override it.
    ``configure_agent()`` awaits it whenever an accelerator has been selected
    in the install info.
    """
    raise NotImplementedError
112+
def add_post_guide(self, guide: PostGuide) -> None:
    """
    Append ``guide`` to this context's list of collected post guides.
    """
    self._post_guides.append(guide)
110115
@@ -461,6 +466,7 @@ async def configure_manager(self) -> None:
461466 async def configure_agent (self ) -> None :
462467 halfstack = self .install_info .halfstack_config
463468 service = self .install_info .service_config
469+ accelerator = self .install_info .accelerator
464470 toml_path = self .copy_config ("agent.toml" )
465471 self .sed_in_place_multi (
466472 toml_path ,
@@ -485,25 +491,27 @@ async def configure_agent(self) -> None:
485491 Path (self .install_info .service_config .agent_var_base_path ).mkdir (
486492 parents = True , exist_ok = True
487493 )
488- # enable the CUDA plugin (open-source version)
489- # The agent will show an error log if the CUDA is not available in the system and report
490- # "cuda.devices = 0" as the agent capacity, but it will still run.
494+ if accelerator is not None :
495+ if accelerator == Accelerator .CUDA :
496+ plugin_list = ['"ai.backend.accelerator.cuda_open"' ]
497+ elif accelerator in (
498+ Accelerator .CUDA_MOCK ,
499+ Accelerator .CUDA_MIG_MOCK ,
500+ Accelerator .ROCM_MOCK ,
501+ ):
502+ plugin_list = ['"ai.backend.accelerator.mock"' ]
503+ else :
504+ plugin_list = []
505+
506+ await self ._configure_mock_accelerator (accelerator )
507+ else :
508+ plugin_list = []
509+
491510 self .sed_in_place (
492511 toml_path ,
493- re .compile ("^(# )?allow-compute-plugins = .*" , flags = re .M ),
494- ' allow-compute-plugins = ["ai.backend.accelerator.cuda_open"]' ,
512+ re .compile (r "^(# )?allow-compute-plugins = .*" , flags = re .M ),
513+ f" allow-compute-plugins = [{ ', ' . join ( plugin_list ) } ]" ,
495514 )
496- # TODO: let the installer enable the CUDA plugin only when it verifies CUDA availability or
497- # via an explicit installer option/config.
498- r"""
499- if [ $ENABLE_CUDA -eq 1 ]; then
500- sed_inplace "s/# allow-compute-plugins =.*/allow-compute-plugins = [\"ai.backend.accelerator.cuda_open\"]/" ./agent.toml
501- elif [ $ENABLE_CUDA_MOCK -eq 1 ]; then
502- sed_inplace "s/# allow-compute-plugins =.*/allow-compute-plugins = [\"ai.backend.accelerator.mock\"]/" ./agent.toml
503- else
504- sed_inplace "s/# allow-compute-plugins =.*/allow-compute-plugins = []/" ./agent.toml
505- fi
506- """
507515
508516 async def configure_storage_proxy (self ) -> None :
509517 halfstack = self .install_info .halfstack_config
@@ -870,6 +878,7 @@ def hydrate_install_info(self) -> InstallInfo:
870878 last_updated = datetime .now (tzutc ()),
871879 halfstack_config = halfstack_config ,
872880 service_config = service_config ,
881+ accelerator = self .install_variable .accelerator ,
873882 )
874883
875884 def copy_config (self , template_name : str ) -> Path :
@@ -895,10 +904,23 @@ async def install(self) -> None:
895904 await install_editable_webui (self )
896905 await self .install_halfstack ()
897906
async def _configure_mock_accelerator(self, accelerator: Accelerator) -> None:
    """
    Copy the mock-accelerator template that matches ``accelerator``.

    Each mock variant (CUDA, CUDA MIG, ROCm) has its own template under
    ``configs/accelerator/``; the selected one is copied to
    ``mock-accelerator.toml``.  Both paths are relative, so they resolve
    against the current working directory.  Accelerators without an entry
    in the table below are ignored and nothing is copied.
    """
    templates = {
        Accelerator.CUDA_MOCK: "configs/accelerator/mock-accelerator.toml",
        Accelerator.CUDA_MIG_MOCK: "configs/accelerator/cuda-mock-mig.toml",
        Accelerator.ROCM_MOCK: "configs/accelerator/rocm-mock.toml",
    }

    # Guard clause: only the mock variants listed above need a config file.
    if accelerator not in templates:
        return

    src = templates[accelerator]
    dst = Path("mock-accelerator.toml")
    print(f"[Installer] Copying accelerator config: {src} -> {dst}")
    shutil.copy(src, dst)
902924
903925 async def configure (self ) -> None :
904926 self .log_header ("Configuring manager..." )
@@ -986,6 +1008,7 @@ def hydrate_install_info(self) -> InstallInfo:
9861008 last_updated = datetime .now (tzutc ()),
9871009 halfstack_config = halfstack_config ,
9881010 service_config = service_config ,
1011+ accelerator = self .install_variable .accelerator ,
9891012 )
9901013
9911014 def copy_config (self , template_name : str ) -> Path :
@@ -1148,6 +1171,24 @@ async def install(self) -> None:
11481171 self .log_header ("Installing databases (halfstack)..." )
11491172 await self .install_halfstack ()
11501173
async def _configure_mock_accelerator(self, accelerator: Accelerator) -> None:
    """
    Copy the mock-accelerator template that matches ``accelerator``.

    Each mock variant (CUDA, CUDA MIG, ROCm) has its own template under
    ``configs/accelerator/``; the selected one is copied to
    ``mock-accelerator.toml``.  Both paths are relative, so they resolve
    against the current working directory.  Accelerators without an entry
    in the table below are ignored and nothing is copied.
    """
    templates = {
        Accelerator.CUDA_MOCK: "configs/accelerator/mock-accelerator.toml",
        Accelerator.CUDA_MIG_MOCK: "configs/accelerator/cuda-mock-mig.toml",
        Accelerator.ROCM_MOCK: "configs/accelerator/rocm-mock.toml",
    }

    # Guard clause: only the mock variants listed above need a config file.
    if accelerator not in templates:
        return

    src = templates[accelerator]
    dst = Path("mock-accelerator.toml")
    print(f"[Installer] Copying accelerator config: {src} -> {dst}")
    shutil.copy(src, dst)
1191+
11511192 async def configure (self ) -> None :
11521193 self .log_header ("Configuring manager..." )
11531194 await self .configure_manager ()
0 commit comments