From 63c83c3c2fc03763f54c36aa8ec2f0e02eadbeda Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:51:13 +0800 Subject: [PATCH 01/12] more --- miles/ray/rollout/server_cell.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 miles/ray/rollout/server_cell.py diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py new file mode 100644 index 000000000..68b1c09d2 --- /dev/null +++ b/miles/ray/rollout/server_cell.py @@ -0,0 +1,18 @@ +from typing import NamedTuple + +from miles.ray.rollout.rollout_server import RolloutServer + + +class CellIndexer(NamedTuple): + srv_key: str + group_index: int + engine_indices: list[int] + + +def get_cell_indexer_from_id(servers: dict[str, RolloutServer], cell_id: int) -> CellIndexer: + assert 0 <= cell_id < get_num_cells(servers) + return TODO + + +def get_num_cells(servers: dict[str, RolloutServer]) -> int: + return TODO From 11c167a4ef5aac054818d7e5a65280f3613980b0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:54:07 +0800 Subject: [PATCH 02/12] more --- miles/ray/rollout/rollout_manager.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/miles/ray/rollout/rollout_manager.py b/miles/ray/rollout/rollout_manager.py index 1d8b27e07..abfe231e4 100644 --- a/miles/ray/rollout/rollout_manager.py +++ b/miles/ray/rollout/rollout_manager.py @@ -10,6 +10,7 @@ from miles.ray.rollout.rollout_data_conversion import postprocess_rollout_data from miles.ray.rollout.rollout_server import RolloutServer, start_rollout_servers from miles.ray.rollout.router_manager import start_session_server +from miles.ray.rollout.server_cell import get_cell_indexer_from_id from miles.ray.rollout.train_data_conversion import convert_samples_to_train_data, split_train_data_by_dp from miles.ray.utils import Lock from miles.rollout.base_types import ( @@ -224,6 +225,17 @@ def _get_updatable_server(self) -> RolloutServer | None: ) return updatable[0] if updatable else None + # -------------------------- external start/stop ----------------------------- + + # TODO + # async def start_cell(self): + # pass + + async def stop_cell(self, cell_id: int): + indexer = get_cell_indexer_from_id(self.servers, cell_id) + group = self.servers[indexer.srv_key].server_groups[indexer.group_index] + group.stop_engines(TODO_translate) + # -------------------------- misc APIs ----------------------------- def get_num_rollout_per_epoch(self): From 504844abb50f19930761c3b23288098cdd34dbba Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:54:29 +0800 Subject: [PATCH 03/12] implement get_cell_indexer_from_id and get_num_cells A cell is nodes_per_engine consecutive engines from ServerGroup.all_engines. cell_id is a flat index across all servers and groups. --- miles/ray/rollout/server_cell.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index 68b1c09d2..ddab84d09 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -11,8 +11,26 @@ class CellIndexer(NamedTuple): def get_cell_indexer_from_id(servers: dict[str, RolloutServer], cell_id: int) -> CellIndexer: assert 0 <= cell_id < get_num_cells(servers) - return TODO + offset = 0 + for srv_key, srv in servers.items(): + for group_index, group in enumerate(srv.server_groups): + num_cells_in_group = len(group.engines) + if cell_id < offset + num_cells_in_group: + local_cell = cell_id - offset + npe = group.nodes_per_engine + engine_indices = list(range(local_cell * npe, (local_cell + 1) * npe)) + return CellIndexer( + srv_key=srv_key, + group_index=group_index, + engine_indices=engine_indices, + ) + offset += num_cells_in_group + raise AssertionError("unreachable") def get_num_cells(servers: dict[str, RolloutServer]) -> int: - return TODO + return sum( + len(group.engines) + for srv in servers.values() + for group in srv.server_groups + ) From a116d7b9684926ee144e878234d942a8b390b609 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:55:20 +0800 Subject: [PATCH 04/12] refactor: extract _get_cell_id_to_indexer_map as single source of truth --- miles/ray/rollout/server_cell.py | 33 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index ddab84d09..d4f3de3d7 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -9,28 +9,25 @@ class CellIndexer(NamedTuple): engine_indices: list[int] -def get_cell_indexer_from_id(servers: dict[str, RolloutServer], cell_id: int) -> CellIndexer: - assert 0 <= cell_id < get_num_cells(servers) - offset = 0 +def _get_cell_id_to_indexer_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]: + result: list[CellIndexer] = [] for srv_key, srv in servers.items(): for group_index, group in enumerate(srv.server_groups): - num_cells_in_group = len(group.engines) - if cell_id < offset + num_cells_in_group: - local_cell = cell_id - offset - npe = group.nodes_per_engine - engine_indices = list(range(local_cell * npe, (local_cell + 1) * npe)) - return CellIndexer( + npe = group.nodes_per_engine + for local_cell in range(len(group.engines)): + result.append(CellIndexer( srv_key=srv_key, group_index=group_index, - engine_indices=engine_indices, - ) - offset += num_cells_in_group - raise AssertionError("unreachable") + engine_indices=list(range(local_cell * npe, (local_cell + 1) * npe)), + )) + return result + + +def get_cell_indexer_from_id(servers: dict[str, RolloutServer], cell_id: int) -> CellIndexer: + mapping = _get_cell_id_to_indexer_map(servers) + assert 0 <= cell_id < len(mapping) + return mapping[cell_id] def get_num_cells(servers: dict[str, RolloutServer]) -> int: - return sum( - len(group.engines) - for srv in servers.values() - for group in srv.server_groups - ) + return len(_get_cell_id_to_indexer_map(servers)) From e1219fbd61ef5c4b028d7481d7b2ba400361e8e8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:55:47 +0800 Subject: [PATCH 05/12] more --- miles/ray/rollout/server_cell.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index d4f3de3d7..4d080af12 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -13,12 +13,12 @@ def _get_cell_id_to_indexer_map(servers: dict[str, RolloutServer]) -> list[CellI result: list[CellIndexer] = [] for srv_key, srv in servers.items(): for group_index, group in enumerate(srv.server_groups): - npe = group.nodes_per_engine for local_cell in range(len(group.engines)): result.append(CellIndexer( srv_key=srv_key, group_index=group_index, - engine_indices=list(range(local_cell * npe, (local_cell + 1) * npe)), + engine_indices=list( + range(local_cell * group.nodes_per_engine, (local_cell + 1) * group.nodes_per_engine)), )) return result From 4a665bf5c48d82fe8b41e1e9c655cf317edffdb7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:56:02 +0800 Subject: [PATCH 06/12] fmt --- miles/ray/rollout/server_cell.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index 4d080af12..dac0514bf 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -14,12 +14,15 @@ def _get_cell_id_to_indexer_map(servers: dict[str, RolloutServer]) -> list[CellI for srv_key, srv in servers.items(): for group_index, group in enumerate(srv.server_groups): for local_cell in range(len(group.engines)): - result.append(CellIndexer( - srv_key=srv_key, - group_index=group_index, - engine_indices=list( - range(local_cell * group.nodes_per_engine, (local_cell + 1) * group.nodes_per_engine)), - )) + result.append( + CellIndexer( + srv_key=srv_key, + group_index=group_index, + engine_indices=list( + range(local_cell * group.nodes_per_engine, (local_cell + 1) * group.nodes_per_engine) + ), + ) + ) return result From 9f8913993a619c957483f988bc465601ce5b153e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:56:29 +0800 Subject: [PATCH 07/12] more --- miles/ray/rollout/server_cell.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index dac0514bf..106f5847d 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -9,7 +9,7 @@ class CellIndexer(NamedTuple): engine_indices: list[int] -def _get_cell_id_to_indexer_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]: +def _get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]: result: list[CellIndexer] = [] for srv_key, srv in servers.items(): for group_index, group in enumerate(srv.server_groups): @@ -27,10 +27,10 @@ def _get_cell_id_to_indexer_map(servers: dict[str, RolloutServer]) -> list[CellI def get_cell_indexer_from_id(servers: dict[str, RolloutServer], cell_id: int) -> CellIndexer: - mapping = _get_cell_id_to_indexer_map(servers) + mapping = _get_cell_indexer_of_id_map(servers) assert 0 <= cell_id < len(mapping) return mapping[cell_id] def get_num_cells(servers: dict[str, RolloutServer]) -> int: - return len(_get_cell_id_to_indexer_map(servers)) + return len(_get_cell_indexer_of_id_map(servers)) From 5c7f68668deed7a1a30a70aa0b8b5b877b8d5faf Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:57:00 +0800 Subject: [PATCH 08/12] more --- miles/ray/rollout/rollout_manager.py | 6 +++--- miles/ray/rollout/server_cell.py | 12 +----------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/miles/ray/rollout/rollout_manager.py b/miles/ray/rollout/rollout_manager.py index abfe231e4..4f7a3d313 100644 --- a/miles/ray/rollout/rollout_manager.py +++ b/miles/ray/rollout/rollout_manager.py @@ -10,7 +10,7 @@ from miles.ray.rollout.rollout_data_conversion import postprocess_rollout_data from miles.ray.rollout.rollout_server import RolloutServer, start_rollout_servers from miles.ray.rollout.router_manager import start_session_server -from miles.ray.rollout.server_cell import get_cell_indexer_from_id +from miles.ray.rollout.server_cell import get_cell_indexer_of_id_map from miles.ray.rollout.train_data_conversion import convert_samples_to_train_data, split_train_data_by_dp from miles.ray.utils import Lock from miles.rollout.base_types import ( @@ -232,8 +232,8 @@ def _get_updatable_server(self) -> RolloutServer | None: # pass async def stop_cell(self, cell_id: int): - indexer = get_cell_indexer_from_id(self.servers, cell_id) - group = self.servers[indexer.srv_key].server_groups[indexer.group_index] + idx = get_cell_indexer_of_id_map(self.servers)[cell_id] + group = self.servers[idx.srv_key].server_groups[idx.group_index] group.stop_engines(TODO_translate) # -------------------------- misc APIs ----------------------------- diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index 106f5847d..21671c37d 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -9,7 +9,7 @@ class CellIndexer(NamedTuple): engine_indices: list[int] -def _get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]: +def get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]: result: list[CellIndexer] = [] for srv_key, srv in servers.items(): for group_index, group in enumerate(srv.server_groups): @@ -24,13 +24,3 @@ def _get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellI ) ) return result - - -def get_cell_indexer_from_id(servers: dict[str, RolloutServer], cell_id: int) -> CellIndexer: - mapping = _get_cell_indexer_of_id_map(servers) - assert 0 <= cell_id < len(mapping) - return mapping[cell_id] - - -def get_num_cells(servers: dict[str, RolloutServer]) -> int: - return len(_get_cell_indexer_of_id_map(servers)) From 3840993493d5aef01e30fe1dfa96b4cc501117e9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 18:59:38 +0800 Subject: [PATCH 09/12] more --- miles/ray/rollout/rollout_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/ray/rollout/rollout_manager.py b/miles/ray/rollout/rollout_manager.py index 4f7a3d313..d7c1072b3 100644 --- a/miles/ray/rollout/rollout_manager.py +++ b/miles/ray/rollout/rollout_manager.py @@ -234,7 +234,7 @@ def _get_updatable_server(self) -> RolloutServer | None: async def stop_cell(self, cell_id: int): idx = get_cell_indexer_of_id_map(self.servers)[cell_id] group = self.servers[idx.srv_key].server_groups[idx.group_index] - group.stop_engines(TODO_translate) + group.stop_engines(engine_indices=idx.engine_indices) # -------------------------- misc APIs ----------------------------- From 8131ae1c8b422910771a312b7413d72f113fb722 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 19:03:32 +0800 Subject: [PATCH 10/12] more --- miles/ray/rollout/server_cell.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index 21671c37d..4a7792d5a 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -11,7 +11,8 @@ class CellIndexer(NamedTuple): def get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]: result: list[CellIndexer] = [] - for srv_key, srv in servers.items(): + for srv_key in sorted(servers): + srv = servers[srv_key] for group_index, group in enumerate(srv.server_groups): for local_cell in range(len(group.engines)): result.append( From e6804eb75f7a95595bcc5fc73b27479dc182e5fe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 19:04:17 +0800 Subject: [PATCH 11/12] add assert for all_engines / engines invariant in cell indexer --- miles/ray/rollout/server_cell.py | 1 + 1 file changed, 1 insertion(+) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index 4a7792d5a..5b95f645a 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -14,6 +14,7 @@ def get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIn for srv_key in sorted(servers): srv = servers[srv_key] for group_index, group in enumerate(srv.server_groups): + assert len(group.all_engines) == len(group.engines) * group.nodes_per_engine for local_cell in range(len(group.engines)): result.append( CellIndexer( From f6c9b804b6ca2b123d5a893eb5b6b80f7aea93a4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Tue, 7 Apr 2026 19:04:44 +0800 Subject: [PATCH 12/12] more --- miles/ray/rollout/server_cell.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/miles/ray/rollout/server_cell.py b/miles/ray/rollout/server_cell.py index 5b95f645a..140bf0788 100644 --- a/miles/ray/rollout/server_cell.py +++ b/miles/ray/rollout/server_cell.py @@ -15,13 +15,13 @@ def get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIn srv = servers[srv_key] for group_index, group in enumerate(srv.server_groups): assert len(group.all_engines) == len(group.engines) * group.nodes_per_engine - for local_cell in range(len(group.engines)): + for local_index in range(len(group.engines)): result.append( CellIndexer( srv_key=srv_key, group_index=group_index, engine_indices=list( - range(local_cell * group.nodes_per_engine, (local_cell + 1) * group.nodes_per_engine) + range(local_index * group.nodes_per_engine, (local_index + 1) * group.nodes_per_engine) ), ) )