Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions miles/ray/rollout/server_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,20 @@ def start_engines(self, port_cursors: PortCursors) -> tuple[list, list[int]]:
]
return init_handles, new_engine_indices

def stop_engines(self, rollout_engine_id: int):
logger.info(f"Killing server group {rollout_engine_id}...")
for i in range(
rollout_engine_id * self.nodes_per_engine,
(rollout_engine_id + 1) * self.nodes_per_engine,
):
# There are two callers; only one of them exists in any given running system.
# 1. New callers (RolloutManager.stop_cell, main thread, async):
# this function is deliberately kept non-async to avoid introducing two states,
# "stopping (but not yet stopped)" vs. "stopped", since single-threaded async code does not yield
# without an await point.
# The drawback is that it freezes the whole async thread, which may be avoided later by
# moving `shutdown` mainly to local code.
# 2. Legacy callers (RolloutHealthMonitor, another thread, sync):
# calling this from another thread is still unsafe,
# because an engine may be observed as not stopped while it is being shut down,
# but that is the same as in the original code.
def stop_engines(self, engine_indices: list[int]):
logger.info(f"Killing server {engine_indices=}...")
for i in engine_indices:
engine = self.all_engines[i]
if engine.is_allocated:
logger.info(f"Shutting down and killing engine at index {i}")
Expand Down
10 changes: 9 additions & 1 deletion miles/utils/health_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,14 @@ def _check_engine_health(self, rollout_engine_id, engine) -> None:
logger.error(
f"Health check failed for rollout engine {rollout_engine_id} (ray timeout or error). Killing actor. Exception: {e}"
)
self._server_group.stop_engines(rollout_engine_id=rollout_engine_id)
nodes_per_engine = self._server_group.nodes_per_engine
self._server_group.stop_engines(
engine_indices=list(
range(
rollout_engine_id * nodes_per_engine,
(rollout_engine_id + 1) * nodes_per_engine,
)
)
)
else:
logger.debug(f"Health check passed for rollout engine {rollout_engine_id}")
Loading