Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion miles/ray/rollout/rollout_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def start_rollout_servers(args, pg) -> dict[str, "RolloutServer"]:
router_port=router_port,
update_weights=model_cfg.update_weights,
)
handles, _ = group.start_engines(port_cursors)
handles, new_engine_indices = group.start_engines(port_cursors)
group.mark_alive(engine_indices=new_engine_indices)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The mark_alive call here is premature. group.start_engines returns Ray ObjectRefs for the asynchronous init calls, which are only resolved later at line 86 via ray.get(all_init_handles). Marking the engines as alive before they have finished initializing is inconsistent with the state's intended meaning. Additionally, per repository rules, when waiting for a server process to start, checking process liveness is not sufficient; the check must also verify that the server is actively listening for connections on its designated port.

References
  1. When waiting for a server process to start, checking process liveness (e.g., is_alive()) is not sufficient. The check must also verify that the server is actively listening for connections on its designated port, for instance by attempting a socket connection or making an HTTP request.

all_init_handles.extend(handles)
server_groups.append(group)

Expand Down
40 changes: 30 additions & 10 deletions miles/ray/rollout/server_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,39 @@
logger = logging.getLogger(__name__)


# NOTE: currently it is almost a dataclass without encapsulation;
# ideally, it may encapsulate all logic and ensure state transition only happens after internal actions,
# and no external code can touch its internals
# NOTE: currently it is almost a dataclass without encapsulation to minimize code diff
# (logic is batched currently while may be non-batched in the future)
# ideally, it may encapsulate all actions and states, and ensure state transition
# only happens after internal actions, while no external code can touch its internals
# for example:
# def __init__(...configs...)
# def init(): _allocate_engine(); _mark_allocated(); _init_engine(); _mark_alive()
# def stop(): _kill_engine(); _mark_stopped()
# and external code cannot directly mutate the engines
# this makes it more encapsulated, easier to reason about, and prevents state-resource inconsistency
class ServerEngine:
def __init__(self):
self._state = _StateStopped()

def mark_allocated(self, actor_handle: ray.actor.ActorHandle):
self._change_state("mark_allocated", _StateStopped, _StateAllocated(actor_handle=actor_handle))
def mark_allocated_uninitialized(self, actor_handle: ray.actor.ActorHandle):
self._change_state("mark_allocated", _StateStopped, _StateAllocatedUninitialized(actor_handle=actor_handle))

def mark_alive(self):
self._change_state(
"mark_alive", _StateAllocatedUninitialized, _StateAllocatedAlive(actor_handle=self.actor_handle)
)

def mark_stopped(self):
self._change_state("mark_stopped", (_StateStopped, _StateAllocated), _StateStopped())
self._change_state("mark_stopped", (_StateStopped, _StateAllocatedBase), _StateStopped())

@property
def actor_handle(self) -> ray.actor.ActorHandle:
assert isinstance(self._state, _StateAllocated)
assert isinstance(self._state, _StateAllocatedBase)
return self._state.actor_handle

@property
def is_allocated(self) -> bool:
return isinstance(self._state, _StateAllocated)
return isinstance(self._state, _StateAllocatedBase)

# TODO: unify w/ trainer `change_state`
def _change_state(
Expand All @@ -54,8 +66,16 @@ class _StateStopped(_StateBase):
pass


class _StateAllocated(_StateBase):
class _StateAllocatedBase(_StateBase):
actor_handle: ray.actor.ActorHandle


_State = _StateStopped | _StateAllocated
class _StateAllocatedUninitialized(_StateAllocatedBase):
pass


class _StateAllocatedAlive(_StateAllocatedBase):
pass


_State = _StateStopped | _StateAllocatedUninitialized | _StateAllocatedAlive
26 changes: 17 additions & 9 deletions miles/ray/rollout/server_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ def engines(self) -> list[ServerEngine]:
"""Node-0 engines only (for multi-node serving)."""
return self.all_engines[:: self.nodes_per_engine]

def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]:
def start_engines(self, port_cursors: PortCursors) -> tuple[list, list[int]]:
"""Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting.

Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list
Returns ``(init_handles, new_engine_indices)`` where *init_handles* is a list
of Ray ObjectRefs and *new_engine_indices* lists the indices of the newly started engines; *port_cursors* is mutated in-place.
"""
Comment on lines +58 to 63
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for start_engines is outdated. It still refers to curr_num_new_engines (an integer) as the second return value, but the method now returns new_engine_indices (a list of integers). Additionally, it incorrectly implies that port_cursors is part of the return value, whereas it is modified in-place.

Suggested change
def start_engines(self, port_cursors: PortCursors) -> tuple[list, list[int]]:
"""Create Ray actors, allocate ports, and fire ``engine.init()`` without waiting.
Returns ``(init_handles, curr_num_new_engines)`` where *init_handles* is a list
of Ray ObjectRefs and *port_cursors* maps node index -> next free port.
"""
def start_engines(self, port_cursors: PortCursors) -> tuple[list, list[int]]:
"""Create Ray actors, allocate ports, and fire engine.init() without waiting.
Returns (init_handles, new_engine_indices) where *init_handles* is a list
of Ray ObjectRefs and *new_engine_indices* is a list of indices of the new engines.
"""

if self.args.debug_train_only or self.worker_type == "placeholder":
self.has_new_engines = False
return [], 0
return [], []

num_gpu_per_engine = min(self.num_gpus_per_engine, self.args.num_gpus_per_node)

Expand All @@ -72,6 +72,7 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]:
RolloutRayActor = ray.remote(SGLangEngine)

new_engines = []
new_engine_indices = []
for i in range(len(self.all_engines)):
if self.all_engines[i].is_allocated:
continue
Expand Down Expand Up @@ -120,13 +121,14 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]:
)

new_engines.append((global_rank, rollout_engine))
self.all_engines[i].mark_allocated(rollout_engine)
new_engine_indices.append(i)
self.all_engines[i].mark_allocated_uninitialized(rollout_engine)

curr_num_new_engines = len(new_engines)
self.has_new_engines |= curr_num_new_engines > 0

if curr_num_new_engines == 0:
return [], 0
return [], []

if self.args.rollout_external:
addr_and_ports = allocate_rollout_engine_addr_and_ports_external(
Expand All @@ -152,7 +154,7 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, int]:
)
for index, engine in new_engines
]
return init_handles, curr_num_new_engines
return init_handles, new_engine_indices

def stop_engines(self, rollout_engine_id: int):
logger.info(f"Killing server group {rollout_engine_id}...")
Expand All @@ -176,13 +178,13 @@ def stop_engines(self, rollout_engine_id: int):
async def recover(self, port_cursors: PortCursors):
dead_indices = [i for i, engine in enumerate(self.all_engines) if not engine.is_allocated]

handles, curr_num_new_engines = self.start_engines(port_cursors)
handles, new_engine_indices = self.start_engines(port_cursors)
await asyncio.gather(*handles)

release_handles = []
all_resume_engines = []
logger.info(f"Recovered {curr_num_new_engines} dead rollout engines (worker_type={self.worker_type})")
assert curr_num_new_engines == len(dead_indices), "curr_num_new_engines does not match dead_indices length"
logger.info(f"Recovered {len(new_engine_indices)} dead rollout engines (worker_type={self.worker_type})")
assert len(new_engine_indices) == len(dead_indices), "curr_num_new_engines does not match dead_indices length"
if self.needs_offload and dead_indices:
new_engines = [self.all_engines[i] for i in dead_indices]
release_handles.extend(engine.actor_handle.release_memory_occupation.remote() for engine in new_engines)
Expand All @@ -199,6 +201,12 @@ async def recover(self, port_cursors: PortCursors):
]
)

self.mark_alive(engine_indices=new_engine_indices)

def mark_alive(self, engine_indices: list[int]):
for engine_index in engine_indices:
self.all_engines[engine_index].mark_alive()

def offload(self):
if not self.needs_offload:
return []
Expand Down
Loading