diff --git a/miles/ray/rollout/rollout_server.py b/miles/ray/rollout/rollout_server.py index aa0370b13..08b927f06 100644 --- a/miles/ray/rollout/rollout_server.py +++ b/miles/ray/rollout/rollout_server.py @@ -74,7 +74,8 @@ def start_rollout_servers(args, pg) -> dict[str, "RolloutServer"]: router_port=router_port, update_weights=model_cfg.update_weights, ) - handles, _ = group.start_engines(port_cursors) + handles, new_engine_indices = group.start_engines(port_cursors) + group.mark_alive(engine_indices=new_engine_indices) all_init_handles.extend(handles) server_groups.append(group) diff --git a/miles/ray/rollout/server_engine.py b/miles/ray/rollout/server_engine.py index 532eded76..0d6aeab22 100644 --- a/miles/ray/rollout/server_engine.py +++ b/miles/ray/rollout/server_engine.py @@ -8,27 +8,39 @@ logger = logging.getLogger(__name__) -# NOTE: currently it is almost a dataclass without encapsulation; -# ideally, it may encapsulate all logic and ensure state transition only happens after internal actions, -# and no external code can touch its internals +# NOTE: currently it is almost a dataclass without encapsulation to minimize code diff +# (logic is batched currently but may be non-batched in the future) +# ideally, it may encapsulate all actions and states, and ensure state transition +# only happens after internal actions, while no external code can touch its internals +# for example: +# def __init__(...configs...) 
+# def init(): _allocate_engine(); _mark_allocated(); _init_engine(); _mark_alive() +# def stop(): _kill_engine(); _mark_stopped() +# and external code cannot directly mutate the engines +# this makes it more encapsulated, easier to reason about, and prevents state-resource inconsistency class ServerEngine: def __init__(self): self._state = _StateStopped() - def mark_allocated(self, actor_handle: ray.actor.ActorHandle): - self._change_state("mark_allocated", _StateStopped, _StateAllocated(actor_handle=actor_handle)) + def mark_allocated_uninitialized(self, actor_handle: ray.actor.ActorHandle): + self._change_state("mark_allocated_uninitialized", _StateStopped, _StateAllocatedUninitialized(actor_handle=actor_handle)) + + def mark_alive(self): + self._change_state( + "mark_alive", _StateAllocatedUninitialized, _StateAllocatedAlive(actor_handle=self.actor_handle) + ) def mark_stopped(self): - self._change_state("mark_stopped", (_StateStopped, _StateAllocated), _StateStopped()) + self._change_state("mark_stopped", (_StateStopped, _StateAllocatedBase), _StateStopped()) @property def actor_handle(self) -> ray.actor.ActorHandle: - assert isinstance(self._state, _StateAllocated) + assert isinstance(self._state, _StateAllocatedBase) return self._state.actor_handle @property def is_allocated(self) -> bool: - return isinstance(self._state, _StateAllocated) + return isinstance(self._state, _StateAllocatedBase) # TODO: unify w/ trainer `change_state` def _change_state( @@ -54,8 +66,16 @@ class _StateStopped(_StateBase): pass -class _StateAllocated(_StateBase): +class _StateAllocatedBase(_StateBase): actor_handle: ray.actor.ActorHandle -_State = _StateStopped | _StateAllocated +class _StateAllocatedUninitialized(_StateAllocatedBase): + pass + + +class _StateAllocatedAlive(_StateAllocatedBase): + pass + + +_State = _StateStopped | _StateAllocatedUninitialized | _StateAllocatedAlive diff --git a/miles/ray/rollout/server_group.py b/miles/ray/rollout/server_group.py index 8fc0fedcb..b6af1edb7 
100644 --- a/miles/ray/rollout/server_group.py +++ b/miles/ray/rollout/server_group.py @@ -55,15 +55,15 @@ def engines(self) -> list[ServerEngine]: """Node-0 engines only (for multi-node serving).""" return self.all_engines[:: self.nodes_per_engine] - def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: + def start_engines(self, port_cursors: PortCursors) -> tuple[list, list[int]]: """Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting. - Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list + Returns ``(init_handles, new_engine_indices)`` where *init_handles* is a list - of Ray ObjectRefs and *port_cursors* maps node index -> next free port. + of Ray ObjectRefs and *new_engine_indices* lists the indices of the newly allocated engines. """ if self.args.debug_train_only or self.worker_type == "placeholder": self.has_new_engines = False - return [], 0 + return [], [] num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node) @@ -72,6 +72,7 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: RolloutRayActor = ray.remote(SGLangEngine) new_engines = [] + new_engine_indices = [] for i in range(len(self.all_engines)): if self.all_engines[i].is_allocated: continue @@ -120,13 +121,14 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: ) new_engines.append((global_rank, rollout_engine)) - self.all_engines[i].mark_allocated(rollout_engine) + new_engine_indices.append(i) + self.all_engines[i].mark_allocated_uninitialized(rollout_engine) curr_num_new_engines = len(new_engines) self.has_new_engines |= curr_num_new_engines > 0 if curr_num_new_engines == 0: - return [], 0 + return [], [] if self.args.rollout_external: addr_and_ports = allocate_rollout_engine_addr_and_ports_external( @@ -152,7 +154,7 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: ) for index, engine in new_engines ] - return init_handles, curr_num_new_engines + return init_handles, new_engine_indices def stop_engines(self, rollout_engine_id: int): 
logger.info(f"Killing server group {rollout_engine_id}...") @@ -176,13 +178,13 @@ def stop_engines(self, rollout_engine_id: int): async def recover(self, port_cursors: PortCursors): dead_indices = [i for i, engine in enumerate(self.all_engines) if not engine.is_allocated] - handles, curr_num_new_engines = self.start_engines(port_cursors) + handles, new_engine_indices = self.start_engines(port_cursors) await asyncio.gather(*handles) release_handles = [] all_resume_engines = [] - logger.info(f"Recovered {curr_num_new_engines} dead rollout engines (worker_type={self.worker_type})") - assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length" + logger.info(f"Recovered {len(new_engine_indices)} dead rollout engines (worker_type={self.worker_type})") + assert len(new_engine_indices) == len(dead_indices), "new_engine_indices does not match dead_indices length" if self.needs_offload and dead_indices: new_engines = [self.all_engines[i] for i in dead_indices] release_handles.extend(engine.actor_handle.release_memory_occupation.remote() for engine in new_engines) @@ -199,6 +201,12 @@ async def recover(self, port_cursors: PortCursors): ] ) + self.mark_alive(engine_indices=new_engine_indices) + + def mark_alive(self, engine_indices: list[int]): + for engine_index in engine_indices: + self.all_engines[engine_index].mark_alive() + def offload(self): if not self.needs_offload: return []