diff --git a/miles/ray/rollout/rollout_manager.py b/miles/ray/rollout/rollout_manager.py
index 1d8b27e07f..d7c1072b3d 100644
--- a/miles/ray/rollout/rollout_manager.py
+++ b/miles/ray/rollout/rollout_manager.py
@@ -10,6 +10,7 @@
 from miles.ray.rollout.rollout_data_conversion import postprocess_rollout_data
 from miles.ray.rollout.rollout_server import RolloutServer, start_rollout_servers
 from miles.ray.rollout.router_manager import start_session_server
+from miles.ray.rollout.server_cell import get_cell_indexer_of_id_map
 from miles.ray.rollout.train_data_conversion import convert_samples_to_train_data, split_train_data_by_dp
 from miles.ray.utils import Lock
 from miles.rollout.base_types import (
@@ -224,6 +225,30 @@ def _get_updatable_server(self) -> RolloutServer | None:
         )
         return updatable[0] if updatable else None
 
+    # -------------------------- external start/stop -----------------------------
+
+    # TODO: add the `start_cell` counterpart of `stop_cell`.
+
+    async def stop_cell(self, cell_id: int):
+        """Stop every engine belonging to the cell identified by ``cell_id``.
+
+        ``cell_id`` is the cell's position in the list returned by
+        ``get_cell_indexer_of_id_map(self.servers)``.
+
+        Raises:
+            IndexError: if ``cell_id`` is not in ``[0, number of cells)``.
+        """
+        cell_indexers = get_cell_indexer_of_id_map(self.servers)
+        # Reject negative ids explicitly: plain list indexing would wrap around
+        # and silently stop a different cell.
+        if not 0 <= cell_id < len(cell_indexers):
+            raise IndexError(f"cell_id {cell_id} out of range, expected 0 <= cell_id < {len(cell_indexers)}")
+        idx = cell_indexers[cell_id]
+        group = self.servers[idx.srv_key].server_groups[idx.group_index]
+        # NOTE(review): the result of stop_engines is not awaited or returned here;
+        # confirm that it is a synchronous call.
+        group.stop_engines(engine_indices=idx.engine_indices)
+
     # -------------------------- misc APIs -----------------------------
 
     def get_num_rollout_per_epoch(self):
diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py
new file mode 100644
index 0000000000..140bf07880
--- /dev/null
+++ b/miles/ray/rollout/server_cell.py
@@ -0,0 +1,44 @@
+"""Cell indexing helpers for rollout servers."""
+
+from typing import NamedTuple
+
+from miles.ray.rollout.rollout_server import RolloutServer
+
+
+class CellIndexer(NamedTuple):
+    """Location of one cell inside a ``dict[str, RolloutServer]``."""
+
+    # Key of the owning server in the servers dict.
+    srv_key: str
+    # Position of the owning group in that server's ``server_groups``.
+    group_index: int
+    # ``nodes_per_engine`` consecutive indices covering the cell; presumably
+    # positions in the group's ``all_engines`` -- confirm against ``stop_engines``.
+    engine_indices: list[int]
+
+
+def get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]:
+    """Return one ``CellIndexer`` per cell; the list position is the cell id.
+
+    Cells are enumerated over servers in sorted key order, then over each
+    server's ``server_groups`` in order, then over each entry of the group's
+    ``engines``, so the ids are deterministic for a given ``servers`` layout.
+    """
+    result: list[CellIndexer] = []
+    for srv_key in sorted(servers):
+        for group_index, group in enumerate(servers[srv_key].server_groups):
+            nodes_per_engine = group.nodes_per_engine
+            num_cells = len(group.engines)
+            assert len(group.all_engines) == num_cells * nodes_per_engine, (
+                f"{srv_key=} {group_index=}: {len(group.all_engines)=} "
+                f"!= {num_cells=} * {nodes_per_engine=}"
+            )
+            result.extend(
+                CellIndexer(
+                    srv_key=srv_key,
+                    group_index=group_index,
+                    engine_indices=list(range(local_index * nodes_per_engine, (local_index + 1) * nodes_per_engine)),
+                )
+                for local_index in range(num_cells)
+            )
+    return result