Skip to content
12 changes: 12 additions & 0 deletions miles/ray/rollout/rollout_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from miles.ray.rollout.rollout_data_conversion import postprocess_rollout_data
from miles.ray.rollout.rollout_server import RolloutServer, start_rollout_servers
from miles.ray.rollout.router_manager import start_session_server
from miles.ray.rollout.server_cell import get_cell_indexer_of_id_map
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Updating the import to reflect the suggested function rename in server_cell.py.

Suggested change
from miles.ray.rollout.server_cell import get_cell_indexer_of_id_map
from miles.ray.rollout.server_cell import get_cell_id_to_indexer_map

from miles.ray.rollout.train_data_conversion import convert_samples_to_train_data, split_train_data_by_dp
from miles.ray.utils import Lock
from miles.rollout.base_types import (
Expand Down Expand Up @@ -224,6 +225,17 @@ def _get_updatable_server(self) -> RolloutServer | None:
)
return updatable[0] if updatable else None

# -------------------------- external start/stop -----------------------------

# TODO: implement start_cell (the counterpart of stop_cell below).
# async def start_cell(self):
# pass

async def stop_cell(self, cell_id: int):
    """Stop the engines that make up the cell identified by ``cell_id``.

    ``cell_id`` is a position in the list returned by
    ``get_cell_indexer_of_id_map(self.servers)``; the entry found there names
    the server key, the group index and the engine indices to stop.

    Args:
        cell_id: Zero-based cell id.

    Raises:
        IndexError: If ``cell_id`` is negative or not smaller than the number
            of cells.
    """
    # Reject negatives before any lookup: plain list indexing would wrap
    # around and silently stop a different cell than the caller asked for.
    if cell_id < 0:
        raise IndexError(f"cell_id must be non-negative, got {cell_id}")

    # The mapping is rebuilt on every call rather than cached, because nothing
    # visible here guarantees that self.servers never changes.
    cell_indexers = get_cell_indexer_of_id_map(self.servers)
    if cell_id >= len(cell_indexers):
        raise IndexError(
            f"cell_id {cell_id} is out of range; {len(cell_indexers)} cell(s) available"
        )

    idx = cell_indexers[cell_id]
    group = self.servers[idx.srv_key].server_groups[idx.group_index]
    # NOTE(review): the result of stop_engines is not awaited or returned;
    # confirm it is a plain synchronous call.
    group.stop_engines(engine_indices=idx.engine_indices)
Comment on lines +234 to +237
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Recomputing the cell mapping on every call to stop_cell is inefficient, especially if the number of engines is large. Additionally, the cell_id should be validated to ensure it is within the valid range and to prevent unexpected behavior with negative indices (which Python lists allow). Caching the mapping lazily on the instance is a good way to optimize this since self.servers is static after initialization.

Suggested change
async def stop_cell(self, cell_id: int):
idx = get_cell_indexer_of_id_map(self.servers)[cell_id]
group = self.servers[idx.srv_key].server_groups[idx.group_index]
group.stop_engines(engine_indices=idx.engine_indices)
async def stop_cell(self, cell_id: int):
if not hasattr(self, "_cell_id_to_indexer"):
self._cell_id_to_indexer = get_cell_id_to_indexer_map(self.servers)
if not (0 <= cell_id < len(self._cell_id_to_indexer)):
raise IndexError(f"cell_id {cell_id} is out of range (0-{len(self._cell_id_to_indexer) - 1})")
idx = self._cell_id_to_indexer[cell_id]
group = self.servers[idx.srv_key].server_groups[idx.group_index]
group.stop_engines(engine_indices=idx.engine_indices)


# -------------------------- misc APIs -----------------------------

def get_num_rollout_per_epoch(self):
Expand Down
26 changes: 26 additions & 0 deletions miles/ray/rollout/server_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import NamedTuple

from miles.ray.rollout.rollout_server import RolloutServer


class CellIndexer(NamedTuple):
    """Address of one cell inside a ``servers`` mapping.

    ``get_cell_indexer_of_id_map`` produces one instance per cell; a caller
    resolves it as ``servers[srv_key].server_groups[group_index]`` and then
    works with ``engine_indices`` on that group.
    """

    # Key of the owning server in the ``servers`` dict.
    srv_key: str
    # Position of the group inside that server's ``server_groups``.
    group_index: int
    # Consecutive engine indices covered by this cell within the group.
    engine_indices: list[int]


def get_cell_indexer_of_id_map(servers: dict[str, RolloutServer]) -> list[CellIndexer]:
    """Build the list that maps a cell id to its ``CellIndexer``.

    The cell id is the position in the returned list. Cells are numbered by
    walking ``servers`` in iteration order, then each server's
    ``server_groups`` in order, then one cell per entry of ``group.engines``.
    Each cell covers ``group.nodes_per_engine`` consecutive engine indices.

    Args:
        servers: Mapping from server key to rollout server.

    Returns:
        One ``CellIndexer`` per cell; empty when there are no cells.
    """
    # NOTE(review): a reviewer proposed renaming this function to
    # get_cell_id_to_indexer_map for readability; the current name is kept so
    # that existing imports continue to work.
    cells: list[CellIndexer] = []
    for key, server in servers.items():
        for group_pos, group in enumerate(server.server_groups):
            cells.extend(
                CellIndexer(
                    srv_key=key,
                    group_index=group_pos,
                    engine_indices=list(
                        range(
                            cell * group.nodes_per_engine,
                            (cell + 1) * group.nodes_per_engine,
                        )
                    ),
                )
                for cell in range(len(group.engines))
            )
    return cells
Loading