@@ -158,6 +158,10 @@ def __init__(
         # Share launch spread/replica counts across strategy instances
         self._model_spread_used_gpus: Dict[str, Set[int]] = {}
         self._active_model_counts: Dict[str, int] = {}
+        # Cached launch strategies per base model
+        self._launch_strategies: Dict[str, Any] = {}
+        # Protect concurrent allocations/releases so bookkeeping stays consistent
+        self._allocation_lock = threading.Lock()
         from ..constants import (
             XINFERENCE_LAUNCH_ALLOWED_GPUS,
             XINFERENCE_LAUNCH_STRATEGY,
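The new cache is keyed by base model uid, so every replica of a model resolves to the same strategy instance. A minimal sketch of that idea, assuming replica uids follow a "<base>-<replica-index>" convention (the patch itself uses parse_replica_model_uid and falls back to the raw uid, as shown further down):

# Hypothetical helper, not part of the patch: two replica uids collapse
# to one cache key, so both replicas share a single launch strategy.
def base_uid(model_uid: str) -> str:
    # Assumes "<base>-<replica-index>" uids for illustration only.
    head, sep, tail = model_uid.rpartition("-")
    return head if sep and tail.isdigit() else model_uid

assert base_uid("my-model-0") == base_uid("my-model-1") == "my-model"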
@@ -559,7 +563,7 @@ def _collect_user_specified_devices(self) -> Set[int]:
                 user_specified_allocated_devices.add(dev)
         return user_specified_allocated_devices

-    def _create_launch_strategy_instance(self):
+    def _gather_initial_gpu_memory_info(self) -> Optional[Dict[int, Dict[str, float]]]:
         # Try to seed strategy with current GPU memory snapshot from NVML
         initial_gpu_memory_info: Optional[Dict[int, Dict[str, float]]] = None
         try:
@@ -576,27 +580,64 @@ def _create_launch_strategy_instance(self):
             initial_gpu_memory_info = gpu_info or None
         except Exception:
             initial_gpu_memory_info = None
+        return initial_gpu_memory_info

+    def _create_launch_strategy_instance(
+        self, gpu_memory_info: Optional[Dict[int, Dict[str, float]]] = None
+    ):
+        if gpu_memory_info is None:
+            raise ValueError("gpu_memory_info is required to create launch strategy")
         return create_launch_strategy(
             strategy_name=self._launch_strategy_name,
             total_gpu_devices=self._total_gpu_devices,
             allowed_devices=self._launch_allowed_gpus,
-            gpu_memory_info=initial_gpu_memory_info,
-            model_spread_used_gpus=self._model_spread_used_gpus,
-            active_model_counts=self._active_model_counts,
+            gpu_memory_info=gpu_memory_info,
+        )
+
+    def _get_base_model_uid(self, model_uid: str) -> str:
+        try:
+            base_model_uid, _ = parse_replica_model_uid(model_uid)
+            return base_model_uid
+        except Exception:
+            return model_uid
+
+    def _get_or_create_launch_strategy(self, model_uid: str):
+        base_model_uid = self._get_base_model_uid(model_uid)
+        strategy = self._launch_strategies.get(base_model_uid)
+        if strategy is not None:
+            return strategy
+        strategy = self._create_launch_strategy_instance(
+            gpu_memory_info=self._gather_initial_gpu_memory_info()
         )
+        self._launch_strategies[base_model_uid] = strategy
+        return strategy
+
+    def ensure_launch_strategy(self, model_uid: str):
+        """
+        Ensure a launch strategy exists for the given base model.
+        Intended to be called from the supervisor before concurrent launches.
+        """
+        base_model_uid = self._get_base_model_uid(model_uid)
+        with self._allocation_lock:
+            if base_model_uid in self._launch_strategies:
+                return
+            strategy = self._create_launch_strategy_instance(
+                gpu_memory_info=self._gather_initial_gpu_memory_info()
+            )
+            self._launch_strategies[base_model_uid] = strategy

     def allocate_devices(self, model_uid: str, n_gpu: int) -> List[int]:
         spec = LaunchModelSpec(model_uid=model_uid, n_gpu=n_gpu)
-        strategy = self._create_launch_strategy_instance()
-        devices = strategy.allocate(
-            spec=spec,
-            total_gpu_devices=self._total_gpu_devices,
-            user_specified_allocated_devices=self._collect_user_specified_devices(),
-            allocated_gpus=self._gpu_to_model_uid,
-        )
-        for dev in devices:
-            self._gpu_to_model_uid[int(dev)].add(model_uid)
+        strategy = self._get_or_create_launch_strategy(model_uid)
+        with self._allocation_lock:
+            devices = strategy.allocate(
+                spec=spec,
+                total_gpu_devices=self._total_gpu_devices,
+                user_specified_allocated_devices=self._collect_user_specified_devices(),
+                allocated_gpus=self._gpu_to_model_uid,
+            )
+            for dev in devices:
+                self._gpu_to_model_uid[int(dev)].add(model_uid)
         return sorted(devices)

     def allocate_devices_for_model(
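For context, the intended supervisor-side call pattern might look like the sketch below. Only ensure_launch_strategy comes from this patch; the worker handle, the replica uid format, and the launch_one_replica helper are assumptions for illustration:

import asyncio

async def launch_one_replica(worker, model_uid: str, rep: int):
    # Stand-in for the supervisor's real per-replica launch path.
    worker.allocate_devices(f"{model_uid}-{rep}", n_gpu=1)

async def launch_model_with_replicas(worker, model_uid: str, replica: int):
    # Build the shared per-base-model strategy once, up front, so the
    # concurrent launches below reuse it instead of racing to create it.
    worker.ensure_launch_strategy(model_uid)
    await asyncio.gather(
        *(launch_one_replica(worker, model_uid, i) for i in range(replica))
    )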
@@ -616,15 +657,16 @@ def allocate_devices_for_model(
             model_format=model_format,
             quantization=quantization,
         )
-        strategy = self._create_launch_strategy_instance()
-        devices = strategy.allocate(
-            spec=spec,
-            total_gpu_devices=self._total_gpu_devices,
-            user_specified_allocated_devices=self._collect_user_specified_devices(),
-            allocated_gpus=self._gpu_to_model_uid,
-        )
-        for dev in devices:
-            self._gpu_to_model_uid[int(dev)].add(model_uid)
+        strategy = self._get_or_create_launch_strategy(model_uid)
+        with self._allocation_lock:
+            devices = strategy.allocate(
+                spec=spec,
+                total_gpu_devices=self._total_gpu_devices,
+                user_specified_allocated_devices=self._collect_user_specified_devices(),
+                allocated_gpus=self._gpu_to_model_uid,
+            )
+            for dev in devices:
+                self._gpu_to_model_uid[int(dev)].add(model_uid)
         return sorted(devices)

     async def allocate_devices_with_gpu_idx(
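Both call sites funnel their arguments through LaunchModelSpec, whose definition is not part of this diff. Judging only from the fields passed in the two hunks above, it is presumably something like the following dataclass sketch:

from dataclasses import dataclass
from typing import Optional

@dataclass
class LaunchModelSpec:
    # Field set inferred from the two call sites in this diff;
    # the real class may define more fields and different defaults.
    model_uid: str
    n_gpu: int
    model_format: Optional[str] = None
    quantization: Optional[str] = None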
@@ -666,35 +708,40 @@ async def allocate_devices_with_gpu_idx(
         return sorted(gpu_idx)

     def release_devices(self, model_uid: str):
-        devices = [
-            dev for dev, uids in self._gpu_to_model_uid.items() if model_uid in uids
-        ]
-        for dev in devices:
-            if model_uid in self._gpu_to_model_uid[dev]:
-                self._gpu_to_model_uid[dev].remove(model_uid)
-            if not self._gpu_to_model_uid[dev]:
-                del self._gpu_to_model_uid[dev]
-
-        # check embedding
-        for dev in self._gpu_to_embedding_model_uids:
-            if model_uid in self._gpu_to_embedding_model_uids[dev]:
-                self._gpu_to_embedding_model_uids[dev].remove(model_uid)
-
-        # check user-specified slots
-        for dev in list(self._user_specified_gpu_to_model_uids):
-            model_infos = [
-                info
-                for info in self._user_specified_gpu_to_model_uids[dev]
-                if info[0] == model_uid
+        base_model_uid = self._get_base_model_uid(model_uid)
+        strategy = self._launch_strategies.get(base_model_uid)
+        with self._allocation_lock:
+            devices = [
+                dev for dev, uids in self._gpu_to_model_uid.items() if model_uid in uids
             ]
-            for model_info in model_infos:
-                self._user_specified_gpu_to_model_uids[dev].remove(model_info)
-            if not self._user_specified_gpu_to_model_uids[dev]:
-                del self._user_specified_gpu_to_model_uids[dev]
-
-        # Keep strategy bookkeeping in sync for spread logic
-        strategy = self._create_launch_strategy_instance()
-        strategy.release(model_uid, devices)
+            for dev in devices:
+                if model_uid in self._gpu_to_model_uid[dev]:
+                    self._gpu_to_model_uid[dev].remove(model_uid)
+                if not self._gpu_to_model_uid[dev]:
+                    del self._gpu_to_model_uid[dev]
+
+            # check embedding
+            for dev in self._gpu_to_embedding_model_uids:
+                if model_uid in self._gpu_to_embedding_model_uids[dev]:
+                    self._gpu_to_embedding_model_uids[dev].remove(model_uid)
+
+            # check user-specified slots
+            for dev in list(self._user_specified_gpu_to_model_uids):
+                model_infos = [
+                    info
+                    for info in self._user_specified_gpu_to_model_uids[dev]
+                    if info[0] == model_uid
+                ]
+                for model_info in model_infos:
+                    self._user_specified_gpu_to_model_uids[dev].remove(model_info)
+                if not self._user_specified_gpu_to_model_uids[dev]:
+                    del self._user_specified_gpu_to_model_uids[dev]
+
+            # Keep strategy bookkeeping in sync for the spread logic
+            if strategy is not None:
+                strategy.release(model_uid, devices)
+                if strategy.is_idle():
+                    self._launch_strategies.pop(base_model_uid, None)

     async def _create_subpool(
         self,
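Taken together, release_devices now assumes each cached strategy exposes allocate, release, and an is_idle probe used to evict the cache entry. The strategy classes themselves are not in this diff; a minimal sketch of the implied interface, with types inferred from the call sites (illustrative only, not the real strategy code):

from typing import Any, Dict, List, Set

class LaunchStrategyBase:
    """Illustrative only: the strategy surface the worker relies on."""

    def __init__(self) -> None:
        # model_uid -> devices this strategy has handed out
        self._allocations: Dict[str, Set[int]] = {}

    def allocate(
        self,
        spec: Any,
        total_gpu_devices: List[int],
        user_specified_allocated_devices: Set[int],
        allocated_gpus: Dict[int, Set[str]],
    ) -> List[int]:
        # Concrete strategies pick devices here and record them in
        # self._allocations before returning.
        raise NotImplementedError

    def release(self, model_uid: str, devices: List[int]) -> None:
        # Forget the devices previously granted to this model uid.
        self._allocations.pop(model_uid, None)

    def is_idle(self) -> bool:
        # True once nothing remains allocated; release_devices uses this
        # to drop the cached strategy for the base model.
        return not self._allocations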