diff --git a/miles/ray/rollout/rollout_manager.py b/miles/ray/rollout/rollout_manager.py index 4dae70f16..1d8b27e07 100644 --- a/miles/ray/rollout/rollout_manager.py +++ b/miles/ray/rollout/rollout_manager.py @@ -186,9 +186,11 @@ async def onload_kv(self): # -------------------------- engine management ----------------------------- - def get_updatable_engines_and_lock(self): + async def get_updatable_engines_and_lock(self): """Return engines eligible for weight updates.""" srv = self._get_updatable_server() + if srv: + await srv.wait_all_engines_alive() engines = [e.actor_handle for e in srv.engines] if srv else [] gpu_counts = srv.engine_gpu_counts if srv else [] gpu_offsets = srv.engine_gpu_offsets if srv else [] diff --git a/miles/ray/rollout/rollout_server.py b/miles/ray/rollout/rollout_server.py index 08b927f06..608236f45 100644 --- a/miles/ray/rollout/rollout_server.py +++ b/miles/ray/rollout/rollout_server.py @@ -209,3 +209,12 @@ async def onload(self, tags: list[str] | None = None): async def check_weights(self, action: str): return await asyncio.gather(*[g.check_weights(action=action) for g in self.server_groups]) + + async def wait_all_engines_alive(self, timeout: float = 600): + sleep_time = 2 + for _ in range(int(timeout // sleep_time)): + if all(e.is_alive for g in self.server_groups for e in g.all_engines): + return + await asyncio.sleep(sleep_time) + logger.info("wait_all_engines_alive looping...") + raise TimeoutError(f"Timed out after {timeout}s waiting for engines to become ready") diff --git a/miles/ray/rollout/server_engine.py b/miles/ray/rollout/server_engine.py index 0d6aeab22..6ca0c4b43 100644 --- a/miles/ray/rollout/server_engine.py +++ b/miles/ray/rollout/server_engine.py @@ -42,6 +42,10 @@ def actor_handle(self) -> ray.actor.ActorHandle: def is_allocated(self) -> bool: return isinstance(self._state, _StateAllocatedBase) + @property + def is_alive(self) -> bool: + return isinstance(self._state, _StateAllocatedAlive) + # TODO: unify w/ trainer `change_state` def _change_state( self,