3030 List ,
3131 Literal ,
3232 Optional ,
33+ Set ,
3334 Tuple ,
3435 Type ,
3536 Union ,
4849from ..core .status_guard import InstanceInfo , LaunchStatus
4950from ..model .utils import get_engine_params_by_name
5051from ..types import PeftModelConfig
52+ from .launch_strategy import create_launch_strategy
5153from .metrics import record_metrics
5254from .resource import GPUStatus , ResourceStatus
5355from .utils import (
@@ -899,6 +901,44 @@ def _get_worker_refs_by_ip(self, ip: str) -> List[xo.ActorRefType["WorkerActor"]
899901 )
900902 return refs
901903
904+ def _build_gpu_memory_info (
905+ self , worker_ref
906+ ) -> Optional [Dict [int , Dict [str , float ]]]:
907+ """Use latest heartbeat data for GPU memory snapshot."""
908+ worker_status = self ._worker_status .get (worker_ref .address )
909+ if worker_status is None :
910+ return None
911+ gpu_info : Dict [int , Dict [str , float ]] = {}
912+ for dev , status in worker_status .status .items ():
913+ if isinstance (status , GPUStatus ) and str (dev ).startswith ("gpu-" ):
914+ try :
915+ idx = int (str (dev ).split ("-" , 1 )[1 ])
916+ except Exception :
917+ continue
918+ gpu_info [idx ] = {
919+ "total" : status .mem_total // (1024 ** 2 ),
920+ "used" : status .mem_used // (1024 ** 2 ),
921+ "available" : status .mem_free // (1024 ** 2 ),
922+ }
923+ return gpu_info or None
924+
925+ async def _install_strategy_on_worker (self , model_uid : str , worker_ref ) -> None :
926+ ctx = await worker_ref .get_launch_strategy_context ()
927+ gpu_memory_info = self ._build_gpu_memory_info (worker_ref )
928+ if gpu_memory_info is None :
929+ # Heartbeat disabled or missing: assume all visible GPUs are available with "infinite" mem
930+ gpu_memory_info = {
931+ dev : {"total" : float ("inf" ), "used" : 0.0 , "available" : float ("inf" )}
932+ for dev in ctx ["total_gpu_devices" ]
933+ }
934+ strategy = create_launch_strategy (
935+ strategy_name = ctx ["launch_strategy_name" ],
936+ total_gpu_devices = ctx ["total_gpu_devices" ],
937+ allowed_devices = ctx ["allowed_devices" ],
938+ gpu_memory_info = gpu_memory_info ,
939+ )
940+ await worker_ref .install_launch_strategy (model_uid , strategy )
941+
902942 @log_async (logger = logger )
903943 async def launch_builtin_model (
904944 self ,
@@ -1096,9 +1136,6 @@ async def _launch_one_model(worker_ref, _replica_model_uid, rank: int):
10961136 model_type = model_type or "LLM"
10971137
10981138 try :
1099- # Ensure per-base-model launch strategy is ready on worker before concurrent launches
1100- await worker_ref .ensure_launch_strategy (model_uid )
1101-
11021139 subpool_address = await worker_ref .launch_builtin_model (
11031140 model_uid = _replica_model_uid ,
11041141 model_name = model_name ,
@@ -1140,6 +1177,7 @@ async def _launch_model():
11401177 try :
11411178 # Pre-fetch worker loads for balanced scheduling
11421179 worker_candidates = []
1180+ prepared_workers : Set [str ] = set ()
11431181
11441182 if target_worker_refs :
11451183 workers = target_worker_refs
@@ -1188,6 +1226,11 @@ async def _launch_model():
11881226 _idx
11891227 ].append (worker_ref )
11901228
1229+ # Prepare launch strategy per worker once before launching replicas
1230+ if worker_ref .address not in prepared_workers :
1231+ await self ._install_strategy_on_worker (model_uid , worker_ref )
1232+ prepared_workers .add (worker_ref .address )
1233+
11911234 if enable_xavier and _idx == 0 :
11921235 """
11931236 Start the rank 0 model actor on the worker that holds the rank 1 replica,
@@ -1359,6 +1402,7 @@ async def _launch_model():
13591402 "n_worker cannot be larger than the number of available workers."
13601403 )
13611404 try :
1405+ prepared_workers : Set [str ] = set ()
13621406 for _idx , rep_model_uid in enumerate (
13631407 iter_replica_model_uid (model_uid , replica )
13641408 ):
@@ -1375,6 +1419,11 @@ async def _launch_model():
13751419 ].replica_to_worker_refs [_idx ].append (worker_ref )
13761420 nonlocal model_type
13771421 model_type = model_type or "LLM"
1422+ if worker_ref .address not in prepared_workers :
1423+ await self ._install_strategy_on_worker (
1424+ model_uid , worker_ref
1425+ )
1426+ prepared_workers .add (worker_ref .address )
13781427 if i_worker > 1 :
13791428 assert (
13801429 driver_info is not None
0 commit comments