diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py
index 7bdd1c17a..892c576e2 100644
--- a/miles/backends/fsdp_utils/actor.py
+++ b/miles/backends/fsdp_utils/actor.py
@@ -553,10 +553,10 @@ def update_weights(self) -> None:  # type: ignore[override]
         if self.args.debug_train_only or self.args.debug_rollout_only:
             return

-        rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
+        rollout_engines, rollout_engine_lock, has_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
             self.rollout_manager.get_updatable_engines_and_lock.remote()
         )
-        if num_new_engines > 0:
+        if has_new_engines:
             self.weight_updater.connect_rollout_engines(
                 rollout_engines,
                 rollout_engine_lock,
@@ -565,7 +565,7 @@ def update_weights(self) -> None:  # type: ignore[override]
             )
             dist.barrier(group=get_gloo_group())
             if dist.get_rank() == 0:
-                ray.get(self.rollout_manager.clear_updatable_num_new_engines.remote())
+                ray.get(self.rollout_manager.clear_updatable_has_new_engines.remote())

         self.weight_updater.update_weights()

diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py
index 6ec17be86..581664faa 100644
--- a/miles/backends/megatron_utils/actor.py
+++ b/miles/backends/megatron_utils/actor.py
@@ -494,14 +494,14 @@ def update_weights(self) -> None:
             ray.get(self.rollout_manager.recover_updatable_engines.remote())
         dist.barrier(group=get_gloo_group())

-        rollout_engines, rollout_engine_lock, num_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
+        rollout_engines, rollout_engine_lock, has_new_engines, engine_gpu_counts, engine_gpu_offsets = ray.get(
             self.rollout_manager.get_updatable_engines_and_lock.remote()
         )

         if self.args.offload_train:
             reload_process_groups()

-        if num_new_engines > 0:
+        if has_new_engines:
             self.weight_updater.connect_rollout_engines(
                 rollout_engines,
                 rollout_engine_lock,
@@ -510,7 +510,7 @@ def update_weights(self) -> None:
             )
             dist.barrier(group=get_gloo_group())
             if dist.get_rank() == 0:
-                ray.get(self.rollout_manager.clear_updatable_num_new_engines.remote())
+                ray.get(self.rollout_manager.clear_updatable_has_new_engines.remote())

         if self.args.offload_train and is_lora_enabled(self.args):
             # For LoRA, we must resume() to restore GPU memory backing for adapter
diff --git a/miles/ray/rollout/rollout_manager.py b/miles/ray/rollout/rollout_manager.py
index 4c380e449..1c0002a03 100644
--- a/miles/ray/rollout/rollout_manager.py
+++ b/miles/ray/rollout/rollout_manager.py
@@ -192,17 +192,17 @@ def get_updatable_engines_and_lock(self):
         engines = srv.engines if srv else []
         gpu_counts = srv.engine_gpu_counts if srv else []
         gpu_offsets = srv.engine_gpu_offsets if srv else []
-        num_new = srv.num_new_engines if srv else 0
-        return engines, self.rollout_engine_lock, num_new, gpu_counts, gpu_offsets
+        has_new = srv.has_new_engines if srv else False
+        return engines, self.rollout_engine_lock, has_new, gpu_counts, gpu_offsets

-    def clear_updatable_num_new_engines(self):
-        # when fault tolerance is not enabled, we need to manually clear num_new_engines after update_weights
+    def clear_updatable_has_new_engines(self):
+        # when fault tolerance is not enabled, we need to manually clear has_new_engines after update_weights
         srv = self._get_updatable_server()
         if srv:
-            srv.clear_num_new_engines()
+            srv.clear_has_new_engines()

     async def recover_updatable_engines(self) -> None:
-        """Restart any dead rollout engines and update num_new_engines for update_weights detection.
+ """Restart any dead rollout engines and update has_new_engines for update_weights detection. Recovers the updatable model (the one that receives weight updates from training). diff --git a/miles/ray/rollout/rollout_server.py b/miles/ray/rollout/rollout_server.py index 53a92bce9..9cc1ad118 100644 --- a/miles/ray/rollout/rollout_server.py +++ b/miles/ray/rollout/rollout_server.py @@ -60,7 +60,7 @@ def start_rollout_servers(args, pg) -> dict[str, "RolloutServer"]: pg=pg, all_engines=[None] * num_engines if group_cfg.worker_type != "placeholder" else [], num_gpus_per_engine=gpus_per_engine, - num_new_engines=0, + has_new_engines=False, worker_type=group_cfg.worker_type, rank_offset=engine_offset, gpu_offset=gpu_offset, @@ -71,7 +71,7 @@ def start_rollout_servers(args, pg) -> dict[str, "RolloutServer"]: router_port=router_port, update_weights=model_cfg.update_weights, ) - handles = group.start_engines(port_cursors) + handles, _ = group.start_engines(port_cursors) all_init_handles.extend(handles) server_groups.append(group) @@ -163,12 +163,12 @@ def all_engines(self): return [e for g in self.server_groups for e in g.all_engines] @property - def num_new_engines(self): - return sum(g.num_new_engines for g in self.server_groups) + def has_new_engines(self) -> bool: + return any(g.has_new_engines for g in self.server_groups) - def clear_num_new_engines(self): + def clear_has_new_engines(self): for g in self.server_groups: - g.num_new_engines = 0 + g.has_new_engines = False @property def engine_gpu_counts(self) -> list[int]: diff --git a/miles/ray/rollout/server_group.py b/miles/ray/rollout/server_group.py index 4dbfe9fbe..9b37cbf18 100644 --- a/miles/ray/rollout/server_group.py +++ b/miles/ray/rollout/server_group.py @@ -33,7 +33,8 @@ class ServerGroup: pg: Any # (placement_group, reordered_bundle_indices, reordered_gpu_ids) all_engines: list num_gpus_per_engine: int - num_new_engines: int + # NOTE: this may have risk when recovering engines parallelly; may use source of truth (all_engines) later + has_new_engines: bool worker_type: str = "regular" # "regular", "prefill", or "decode" rank_offset: int = 0 gpu_offset: int = 0 @@ -53,15 +54,15 @@ def engines(self): """Node-0 engines only (for multi-node serving).""" return self.all_engines[:: self.nodes_per_engine] - def start_engines(self, port_cursors: PortCursors) -> list: + def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: """Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting. - Returns ``(init_handles, port_cursors)`` where *init_handles* is a list + Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list of Ray ObjectRefs and *port_cursors* maps node index -> next free port. 
""" if self.args.debug_train_only or self.worker_type == "placeholder": - self.num_new_engines = 0 - return [] + self.has_new_engines = False + return [], 0 num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node) @@ -120,10 +121,11 @@ def start_engines(self, port_cursors: PortCursors) -> list: new_engines.append((global_rank, rollout_engine)) self.all_engines[i] = rollout_engine - self.num_new_engines = len(new_engines) + curr_num_new_engines = len(new_engines) + self.has_new_engines |= curr_num_new_engines > 0 - if self.num_new_engines == 0: - return [] + if curr_num_new_engines == 0: + return [], 0 if self.args.rollout_external: addr_and_ports = allocate_rollout_engine_addr_and_ports_external( @@ -149,7 +151,7 @@ def start_engines(self, port_cursors: PortCursors) -> list: ) for index, engine in new_engines ] - return init_handles + return init_handles, curr_num_new_engines def stop_engines(self, rollout_engine_id: int): logger.info(f"Killing server group {rollout_engine_id}...") @@ -173,12 +175,13 @@ def stop_engines(self, rollout_engine_id: int): async def recover(self, port_cursors: PortCursors): dead_indices = [i for i, engine in enumerate(self.all_engines) if engine is None] - await asyncio.gather(*self.start_engines(port_cursors)) + handles, curr_num_new_engines = self.start_engines(port_cursors) + await asyncio.gather(*handles) release_handles = [] all_resume_engines = [] - logger.info(f"Recovered {self.num_new_engines} dead rollout engines (worker_type={self.worker_type})") - assert self.num_new_engines == len(dead_indices), "num_new_engines does not match dead_indices length" + logger.info(f"Recovered {curr_num_new_engines} dead rollout engines (worker_type={self.worker_type})") + assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length" if self.needs_offload and dead_indices: new_engines = [self.all_engines[i] for i in dead_indices] release_handles.extend(engine.release_memory_occupation.remote() for engine in new_engines)