-
Notifications
You must be signed in to change notification settings - Fork 151
Add aliveness to rollout engine state #941
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: rollout_ft/23
Are you sure you want to change the base?
Changes from all commits
70ae54b
fd3063b
f7b43fd
a6d5989
e35fe34
a3f4975
e01c982
4791ea7
6d47973
88364c7
508eba9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -55,15 +55,15 @@ def engines(self) -> list[ServerEngine]: | |||||||||||||||||||||||||
| """Node-0 engines only (for multi-node serving).""" | ||||||||||||||||||||||||||
| return self.all_engines[:: self.nodes_per_engine] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: | ||||||||||||||||||||||||||
| def start_engines(self, port_cursors: PortCursors) -> tuple[list, list[int]]: | ||||||||||||||||||||||||||
| """Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list | ||||||||||||||||||||||||||
| Returns ``(init_handles, new_engine_indices)`` where *init_handles* is a list | ||||||||||||||||||||||||||
| of Ray ObjectRefs and *port_cursors* maps node index -> next free port. | ||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||
|
Comment on lines
+58
to
63
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring for start_engines is outdated. It still refers to curr_num_new_engines (an integer) as the second return value, but the method now returns new_engine_indices (a list of integers). Additionally, it incorrectly implies that port_cursors is part of the return value, whereas it is modified in-place.
Suggested change
|
||||||||||||||||||||||||||
| if self.args.debug_train_only or self.worker_type == "placeholder": | ||||||||||||||||||||||||||
| self.has_new_engines = False | ||||||||||||||||||||||||||
| return [], 0 | ||||||||||||||||||||||||||
| return [], [] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -72,6 +72,7 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: | |||||||||||||||||||||||||
| RolloutRayActor = ray.remote(SGLangEngine) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| new_engines = [] | ||||||||||||||||||||||||||
| new_engine_indices = [] | ||||||||||||||||||||||||||
| for i in range(len(self.all_engines)): | ||||||||||||||||||||||||||
| if self.all_engines[i].is_allocated: | ||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||
|
|
@@ -120,13 +121,14 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: | |||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| new_engines.append((global_rank, rollout_engine)) | ||||||||||||||||||||||||||
| self.all_engines[i].mark_allocated(rollout_engine) | ||||||||||||||||||||||||||
| new_engine_indices.append(i) | ||||||||||||||||||||||||||
| self.all_engines[i].mark_allocated_uninitialized(rollout_engine) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| curr_num_new_engines = len(new_engines) | ||||||||||||||||||||||||||
| self.has_new_engines |= curr_num_new_engines > 0 | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if curr_num_new_engines == 0: | ||||||||||||||||||||||||||
| return [], 0 | ||||||||||||||||||||||||||
| return [], [] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if self.args.rollout_external: | ||||||||||||||||||||||||||
| addr_and_ports = allocate_rollout_engine_addr_and_ports_external( | ||||||||||||||||||||||||||
|
|
@@ -152,7 +154,7 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]: | |||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| for index, engine in new_engines | ||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||
| return init_handles, curr_num_new_engines | ||||||||||||||||||||||||||
| return init_handles, new_engine_indices | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def stop_engines(self, rollout_engine_id: int): | ||||||||||||||||||||||||||
| logger.info(f"Killing server group {rollout_engine_id}...") | ||||||||||||||||||||||||||
|
|
@@ -176,13 +178,13 @@ def stop_engines(self, rollout_engine_id: int): | |||||||||||||||||||||||||
| async def recover(self, port_cursors: PortCursors): | ||||||||||||||||||||||||||
| dead_indices = [i for i, engine in enumerate(self.all_engines) if not engine.is_allocated] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| handles, curr_num_new_engines = self.start_engines(port_cursors) | ||||||||||||||||||||||||||
| handles, new_engine_indices = self.start_engines(port_cursors) | ||||||||||||||||||||||||||
| await asyncio.gather(*handles) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| release_handles = [] | ||||||||||||||||||||||||||
| all_resume_engines = [] | ||||||||||||||||||||||||||
| logger.info(f"Recovered {curr_num_new_engines} dead rollout engines (worker_type={self.worker_type})") | ||||||||||||||||||||||||||
| assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length" | ||||||||||||||||||||||||||
| logger.info(f"Recovered {len(new_engine_indices)} dead rollout engines (worker_type={self.worker_type})") | ||||||||||||||||||||||||||
| assert len(new_engine_indices) == len(dead_indices), "curr_num_new_engines does not match dead_indices length" | ||||||||||||||||||||||||||
| if self.needs_offload and dead_indices: | ||||||||||||||||||||||||||
| new_engines = [self.all_engines[i] for i in dead_indices] | ||||||||||||||||||||||||||
| release_handles.extend(engine.actor_handle.release_memory_occupation.remote() for engine in new_engines) | ||||||||||||||||||||||||||
|
|
@@ -199,6 +201,12 @@ async def recover(self, port_cursors: PortCursors): | |||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.mark_alive(engine_indices=new_engine_indices) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def mark_alive(self, engine_indices: list[int]): | ||||||||||||||||||||||||||
| for engine_index in engine_indices: | ||||||||||||||||||||||||||
| self.all_engines[engine_index].mark_alive() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def offload(self): | ||||||||||||||||||||||||||
| if not self.needs_offload: | ||||||||||||||||||||||||||
| return [] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mark_alive call here is premature. group.start_engines returns Ray ObjectRefs for the asynchronous init calls, which are only resolved later at line 86 via ray.get(all_init_handles). Marking the engines as alive before they have finished initializing is inconsistent with the state's intended meaning. Additionally, per repository rules, when waiting for a server process to start, checking process liveness is not sufficient; the check must also verify that the server is actively listening for connections on its designated port.
References