diff --git a/areal/infra/rpc/ray_rpc_server.py b/areal/infra/rpc/ray_rpc_server.py
index 55b5f9d07d..0e75f6df3c 100644
--- a/areal/infra/rpc/ray_rpc_server.py
+++ b/areal/infra/rpc/ray_rpc_server.py
@@ -1,13 +1,21 @@
+import abc
 import os
+import shlex
+import subprocess
+import sys
+import time
 import traceback
 from concurrent.futures import Future
 from typing import Any

 import ray
+import requests

 from areal.api import InferenceEngine, TrainEngine
 from areal.api.cli_args import BaseExperimentConfig
 from areal.infra.rpc.rtensor import RTensor
+from areal.infra.rpc.serialization import deserialize_value, serialize_value
+from areal.infra.utils.proc import kill_process_tree
 from areal.utils import logging, name_resolve, seeding
 from areal.utils.data import (
     broadcast_tensor_container,
@@ -17,25 +25,19 @@
 from areal.utils.network import find_free_ports


-@ray.remote
-class RayRPCServer:
+class RayServer(abc.ABC):
     """
-    Ray engine container. Represents either:
-    - one training world rank, or
-    - one rollout instance
-
-    Supports multiple named engines per worker for colocation scenarios.
-
-    Placement group scheduling is controlled by the scheduler.
-    The actor is only responsible for the engine lifecycle and method calls
-    within this process.
+    Base class for all Ray actors deployed by RayScheduler.
     """

-    def __init__(self):
+    def __init__(self, config: BaseExperimentConfig, **kwargs):
         self._engines: dict[str, TrainEngine | InferenceEngine] = {}
         self._default_engine_name: str | None = None  # For backward compatibility
         self._allocated_port = set()
-        self.logger = logging.getLogger("RayRPCServer")
+        self.config: BaseExperimentConfig = config
+        ctx = ray.get_runtime_context()
+        self.actor_name = ctx.get_actor_name()
+        self.logger = logging.getLogger(self.__class__.__name__)

     def _get_device(self):
         # lazy resolve the device inside worker process
@@ -83,6 +85,53 @@ def set_env(self, env: dict[str, str]) -> None:
         for k, v in env.items():
             os.environ[str(k)] = str(v)

+    def post_init(self, **kwargs) -> Any:
+        # Hook that runs after the actor has been deployed. Only the
+        # RayHTTPLauncher needs it; kept here as a no-op for interface compatibility.
+        pass
+
+    @abc.abstractmethod
+    def create_engine(
+        self,
+        engine: str,
+        *init_args,
+        engine_name: str | None = None,
+        **init_kwargs,
+    ) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def call(self, method: str, *args, engine_name: str | None = None, **kwargs) -> Any:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def destroy(self) -> None:
+        raise NotImplementedError()
+
+    def __ray_shutdown__(self):
+        self.destroy()
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} [{self.actor_name}]"
+
+
+@ray.remote
+class RayRPCServer(RayServer):
+    """
+    Ray engine container. Represents either:
+    - one training world rank, or
+    - one rollout instance
+
+    Supports multiple named engines per worker for colocation scenarios.
+
+    Placement group scheduling is controlled by the scheduler.
+    The actor is only responsible for the engine lifecycle and method calls
+    within this process.
+    """
+
+    def __init__(self, config: BaseExperimentConfig, **kwargs):
+        super().__init__(config, **kwargs)
+
     def create_engine(
         self,
         engine: str,
@@ -213,3 +262,203 @@ def destroy(self) -> None:
         self._engines.clear()
         self._default_engine_name = None
         ray.actor.exit_actor()
+
+
+@ray.remote
+class RayHTTPLauncher(RayServer):
+    """
+    Ray actor that launches an HTTP server subprocess (e.g. a proxy server) and forwards engine RPCs to it.
+    """
+
+    REQUIRED_ARGS = ("command", "worker_index", "role")
+
+    def __init__(self, config: BaseExperimentConfig, **kwargs):
+        super().__init__(config, **kwargs)
+
+        missing = [k for k in self.REQUIRED_ARGS if k not in kwargs]
+        if missing:
+            raise TypeError(f"Missing required kwargs: {missing}")
+
+        self.command = kwargs["command"]
+        self.worker_index = kwargs["worker_index"]
+        self.role = kwargs["role"]
+        self.worker_ip = ray.util.get_node_ip_address()
+        self.worker_port = None
+        self.worker_process: subprocess.Popen | None = None
+
+    def post_init(self, **kwargs):
+        # use `or` so a fallback port is only allocated when none was provided
+        self.worker_port = kwargs.get("port") or self.alloc_ports(1)[0]
+        self.worker_process = self.launch_server(port=self.worker_port)
+
+    def create_engine(
+        self,
+        engine: str,
+        *init_args,
+        engine_name: str | None = None,
+        **init_kwargs,
+    ) -> None:
+        self.logger.debug(f"Initializing engine {engine}")
+        payload = {
+            "engine": engine,
+            "engine_name": engine_name,
+            "init_args": serialize_value(list(init_args)),
+            "init_kwargs": serialize_value(init_kwargs),
+        }
+        try:
+            self._post_request("create_engine", payload)
+        except Exception as e:
+            self.logger.error(
+                f"RayHTTPLauncher failed to create engine '{engine}': {e}\n"
+                f"{traceback.format_exc()}"
+            )
+            raise
+
+    def call(
+        self,
+        method: str,
+        *args,
+        engine_name: str | None = None,
+        rpc_meta: dict[str, Any] | None = None,
+        **kwargs,
+    ) -> Any:
+        self.logger.debug(
+            f"Calling {method} on engine '{engine_name}' with arguments {args=} {kwargs=}"
+        )
+
+        payload = {
+            "method": method,
+            "engine_name": engine_name,
+            "rpc_meta": rpc_meta,
+            "args": serialize_value(list(args)),
+            "kwargs": serialize_value(kwargs),
+        }
+        try:
+            return self._post_request("call", payload)
+        except Exception as e:
+            self.logger.error(
+                f"RayHTTPLauncher failed for '{method}': {e}\n{traceback.format_exc()}"
+            )
+            raise
+
+    def destroy(self) -> None:
+        if self.worker_process and self.worker_process.poll() is None:
+            kill_process_tree(self.worker_process.pid, timeout=3, graceful=True)
+        self._default_engine_name = None
+        ray.actor.exit_actor()
+
+    def launch_server(self, port):
+        # kept as a separate method so Awex server launches can reuse it later
+        if not self.command:
+            raise RuntimeError(
+                f"Command was not given to {self.__class__.__name__}.launch_server. Cannot launch without command."
+            )
+
+        cmd = [sys.executable, "-m"]
+        cmd.extend(shlex.split(self.command))
+        cmd.extend(["--port", str(port)])
+
+        cmd.extend(["--experiment-name", self.config.experiment_name])
+        cmd.extend(["--trial-name", self.config.trial_name])
+        cmd.extend(["--role", self.role])
+        cmd.extend(["--worker-index", str(self.worker_index)])
+
+        cluster_config = self.config.cluster
+        name_resolve_config = cluster_config.name_resolve
+
+        cmd.extend(["--name-resolve-type", name_resolve_config.type])
+        cmd.extend(["--nfs-record-root", name_resolve_config.nfs_record_root])
+        cmd.extend(["--etcd3-addr", name_resolve_config.etcd3_addr])
+        cmd.extend(["--fileroot", str(cluster_config.fileroot)])
+
+        _env = os.environ.copy()
+        self.worker_process = subprocess.Popen(
+            cmd, env=_env, stdout=sys.stdout, stderr=subprocess.STDOUT
+        )
+
+        try:
+            self._check_health()
+        except Exception as e:
+            self.logger.error(e)
+            kill_process_tree(self.worker_process.pid, timeout=3, graceful=True)
+            raise RuntimeError(f"Could not launch server with command {cmd}") from e
+
+        return self.worker_process
+
+    def _post_request(
+        self,
+        endpoint,
+        payload,
+        http_timeout: float = 7200.0,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
+    ):
+        url = f"{self.url}/{endpoint}"
+        last_error = ""
+        # adapted from the local scheduler
+        for attempt in range(1, max_retries + 1):
+            if self.worker_process and self.worker_process.poll() is not None:
+                raise RuntimeError("Worker has terminated")
+
+            try:
+                response = requests.post(url, json=payload, timeout=http_timeout)
+                response.raise_for_status()
+                result = response.json().get("result")
+                deserialized_result = deserialize_value(result)
+                return deserialized_result
+
+            except requests.exceptions.HTTPError as e:
+                resp = e.response
+
+                if resp is not None and resp.status_code in [400, 404, 500]:
+                    try:
+                        error_detail = resp.json().get("detail", "unknown error")
+                    except Exception:
+                        error_detail = resp.text or "unknown error"
+                    raise RuntimeError(error_detail)
+
+                last_error = (
+                    f"HTTP {resp.status_code}: {resp.text}"
+                    if resp is not None
+                    else str(e)
+                )
+            except Exception as e:
+                last_error = str(e)
+                self.logger.warning(
+                    f"POST failed when calling url {url} on actor '{self.actor_name}': {e}"
+                )
+
+            # transient failure: retry with exponential backoff
+            if attempt < max_retries:
+                delay = retry_delay * (2 ** (attempt - 1))
+                self.logger.warning(
+                    f"Calling url {url} failed on actor '{self.actor_name}' "
+                    f"(attempt {attempt}/{max_retries}): {last_error}. "
+                    f"Retrying in {delay:.1f}s..."
+                )
+                time.sleep(delay)
+        raise RuntimeError(
+            f"Max retries exceeded trying to call url {url}: {last_error or 'unknown error'}"
+        )
+
+    @property
+    def url(self):
+        return f"http://{self.worker_ip}:{self.worker_port}"
+
+    def _check_health(self, timeout: float = 60.0):
+        url = f"{self.url}/health"
+        deadline = time.time() + timeout
+        while time.time() < deadline:
+            if self.worker_process and self.worker_process.poll() is not None:
+                raise RuntimeError("Server process exited before becoming healthy")
+
+            try:
+                r = requests.get(url, timeout=2.0)
+                if r.status_code == 200:
+                    return
+            except requests.RequestException:
+                # expected during startup
+                pass
+            time.sleep(1)
+
+        raise RuntimeError(f"Health check timed out for {url}")
diff --git a/areal/infra/scheduler/ray.py b/areal/infra/scheduler/ray.py
index 6ea45ed85c..091aef2943 100644
--- a/areal/infra/scheduler/ray.py
+++ b/areal/infra/scheduler/ray.py
@@ -19,7 +19,7 @@
     SchedulingSpec,
     SchedulingStrategyType,
 )
-from areal.infra.rpc.ray_rpc_server import RayRPCServer
+from areal.infra.rpc.ray_rpc_server import RayHTTPLauncher, RayRPCServer
 from areal.infra.scheduler.exceptions import (
     EngineCallError,
     WorkerCreationError,
@@ -56,7 +56,7 @@ class RayScheduler(Scheduler):
     def __init__(
         self,
-        startup_timeout: float = 30.0,
+        startup_timeout: float = 300.0,
         *,
         exp_config: BaseExperimentConfig | None = None,
     ):
@@ -168,19 +168,18 @@ def _create_ray_workers(
         worker_ids: list[str] = []

         placement_strategy = self._get_placement_strategy(schedulings)
-        placement_groups = placement_strategy.create_placement_group(
+        placement_strategy.create_placement_group(
             role,
             schedulings,
             self.exp_config.cluster.n_gpus_per_node,
             timeout=self.startup_timeout,
         )

-        master_ip, master_port = get_placement_group_master_ip_and_port(
-            placement_groups[0], placement_group_bundle_index=0
-        )
-
         for idx, spec in enumerate(schedulings):
             options, pg_scheduling_strategy = placement_strategy.actor_resources(spec)
+            master_ip, master_port = get_placement_group_master_ip_and_port(
+                pg_scheduling_strategy.placement_group, placement_group_bundle_index=0
+            )
             worker_id = f"{role}/{idx}"
             env = self._build_env_vars(spec)
             actor = RayRPCServer.options(
@@ -188,7 +187,7 @@ def _create_ray_workers(
                 name=worker_id,
                 runtime_env=RuntimeEnv(env_vars=env),
                 scheduling_strategy=pg_scheduling_strategy,
-            ).remote()
+            ).remote(config=self.exp_config)

             # 0 needed to pad the list as the trainer takes index 1 for ports
             worker_ports = ["0", str(master_port)]
@@ -215,7 +214,8 @@ def _create_forked_workers_internal(
         role: str,
         target_role: str,
         target_workers: list[RayWorkerInfo],
-        schedulings,
+        schedulings: list[SchedulingSpec],
+        command: str | None = None,
     ) -> list[str]:
         """Create forked workers on same placement groups as target workers.
@@ -244,9 +244,11 @@ def _create_forked_workers_internal(

         worker_info_list: list[RayWorkerInfo] = []
         worker_ids: list[str] = []
+        post_init_tasks: list[ray.ObjectRef] = []

         for idx, (target_wi, spec) in enumerate(zip(target_workers, schedulings)):
-            worker_id = f"{role}/{idx}"
+            # include the parent role in the Ray actor name; role and index alone can collide when forking multiple times
+            worker_id = f"{target_role}/{role}/{idx}"

             # Reuse placement group from target worker
             pg = target_wi.placement_group
@@ -272,13 +274,23 @@ def _create_forked_workers_internal(
                 additional_options = dict(num_gpus=0.01)
             else:
                 additional_options = {"resources": {device: 0.01}}
-            actor = RayRPCServer.options(
+            if command and "rpc.rpc_server" not in command:
+                actor_cls = RayHTTPLauncher
+            else:
+                actor_cls = RayRPCServer
+            actor = actor_cls.options(
                 **additional_options,
                 num_cpus=0,  # Minimal CPU allocation for forked actor
                 name=worker_id,
                 runtime_env=RuntimeEnv(env_vars=target_wi.env_vars),
                 scheduling_strategy=PlacementGroupSchedulingStrategy(**strategy_kwargs),
-            ).remote()
+            ).remote(
+                config=self.exp_config,
+                # consumed by RayHTTPLauncher; RayRPCServer ignores these
+                command=command,
+                worker_index=idx,
+                role=role,
+            )

             # Build Worker object with same IP/ports as target
             worker_ports = ray.get(
@@ -286,6 +298,8 @@ def _create_forked_workers_internal(
                     count=len(target_wi.worker.worker_ports)
                 )
             )
+            # schedule post-init hooks (RayHTTPLauncher launches its HTTP server here)
+            post_init_tasks.append(actor.post_init.remote(port=worker_ports[0]))

             worker = Worker(
                 id=worker_id,
@@ -306,8 +320,14 @@ def _create_forked_workers_internal(
             worker_info_list.append(wi)
             worker_ids.append(worker_id)

+        try:
+            ray.get(post_init_tasks)
+        except Exception:
+            self._cleanup_forked_workers(worker_info_list)
+            raise
+
         # Register forked workers
-        self._workers[role] = worker_info_list
+        self._workers.setdefault(role, []).extend(worker_info_list)
         for wi in worker_info_list:
             self._worker_info_by_id[wi.worker.id] = wi

@@ -551,23 +571,20 @@ def fork_workers(

         Creates new Ray actors colocated with existing workers of the target role.
-        The ``command`` parameter is ignored — Ray actors always run RayRPCServer.
+        If ``command`` names an HTTP server module, forks run RayHTTPLauncher instead of RayRPCServer.
         """
-        if command is not None:
-            logger.warning(
-                f"RayScheduler.fork_workers: 'command' parameter is ignored. "
-                f"Ray actors always use RayRPCServer. Got command='{command}'"
-            )

         if target_role not in self._workers:
             raise WorkerNotFoundError(f"Target role '{target_role}' not found for fork")

         schedulings = []
         for target_wi in target_workers:
             # Use minimal resources for forked workers
-            schedulings.append(SchedulingSpec(cpu=0, mem=0, gpu=1, port_count=1))
+            # Request 0 GPUs to avoid scheduling issues; forks so far are CPU-only.
+            # Future GPU-requiring forks should extend this to accept a scheduling spec.
+            schedulings.append(SchedulingSpec(cpu=0, mem=0, gpu=0, port_count=1))

-        worker_ids = self._create_forked_workers(
-            role, target_role, target_workers, schedulings
+        worker_ids = self._create_forked_workers_internal(
+            role, target_role, target_workers, schedulings, command
         )
         self._colocated_roles[role] = target_role
         return worker_ids
diff --git a/areal/infra/utils/ray.py b/areal/infra/utils/ray.py
index 2bc630ef48..f6ab5e573d 100644
--- a/areal/infra/utils/ray.py
+++ b/areal/infra/utils/ray.py
@@ -13,10 +13,11 @@ def _master_ip_and_port():
         port = find_free_ports(1, (10000, 60000))[0]
         return host_ip, port

+    # request zero resources; otherwise this probe task cannot be scheduled in some scenarios
    future = ray.remote(
-        num_cpus=1,
+        num_cpus=0,
         num_gpus=0,
-        memory=10 * 1024 * 1024,  # Convert MB to bytes
+        memory=0,
         scheduling_strategy=PlacementGroupSchedulingStrategy(
             placement_group=placement_group,
             placement_group_bundle_index=placement_group_bundle_index,
diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py
index d6f8dc4720..e7a0774cdb 100644
--- a/areal/trainer/rl_trainer.py
+++ b/areal/trainer/rl_trainer.py
@@ -1244,9 +1244,6 @@ def _ensure_proxy_started(self) -> None:
         if not is_single_controller():
             raise NotImplementedError("Proxy workers not supported in SPMD mode")

-        if self.config.scheduler.type == "ray":
-            raise NotImplementedError("Proxy workers not supported with RayScheduler")
-
         assert isinstance(self.rollout, RolloutController)

         logger.info("Initializing proxy workers for AgentWorkflow support")
diff --git a/tests/test_ray_scheduler.py b/tests/test_ray_scheduler.py
index 90376ed0ae..43afacc164 100644
--- a/tests/test_ray_scheduler.py
+++ b/tests/test_ray_scheduler.py
@@ -3,6 +3,7 @@
 from unittest.mock import Mock, patch

 import pytest
+import ray
 from ray.util.state import summarize_actors

 from areal.api import Job, Worker
@@ -376,3 +377,33 @@ def test_non_fork_colocation_reuses_workers(self):

         # Clean up
         scheduler.delete_workers()
+
+    def test_fork_proxy(self):
+        # fork a proxy server from a rollout worker
+        config = BaseExperimentConfig()
+
+        scheduler = RayScheduler(startup_timeout=60.0, exp_config=config)
+
+        rollout_job = Job(
+            replicas=1,
+            role="rollout",
+            tasks=[
+                SchedulingSpec(cpu=1, mem=1, gpu=0),
+            ],
+        )
+        scheduler.create_workers(rollout_job)
+
+        command = "areal.experimental.openai.proxy.proxy_rollout_server"
+        worker_ids = scheduler.fork_workers(
+            role="proxy",
+            target_role="rollout",
+            command=command,
+        )
+
+        # health-check each proxy to make sure it is active
+        for wid in worker_ids:
+            worker_info = scheduler._worker_info_by_id[wid]
+            actor_ref = worker_info.actor
+            ray.get(actor_ref._check_health.remote())
+
+        scheduler.delete_workers()
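A minimal end-to-end sketch of the new fork path, mirroring `test_fork_proxy` above. Illustrative only: it assumes a running Ray cluster, and the `SchedulingSpec` re-export from `areal.infra.scheduler.ray` is an assumption (the diff shows it imported there, but its home module is truncated).

```python
# Sketch, not part of the diff: forking an HTTP proxy next to rollout workers.
import ray

from areal.api import Job  # the tests import Job from areal.api
from areal.api.cli_args import BaseExperimentConfig
from areal.infra.scheduler.ray import RayScheduler, SchedulingSpec  # re-export assumed

ray.init()
scheduler = RayScheduler(startup_timeout=60.0, exp_config=BaseExperimentConfig())

# 1. Create target workers: plain RayRPCServer actors in a placement group.
scheduler.create_workers(
    Job(replicas=1, role="rollout", tasks=[SchedulingSpec(cpu=1, mem=1, gpu=0)])
)

# 2. Fork colocated proxies. Because `command` does not contain
#    "rpc.rpc_server", _create_forked_workers_internal deploys RayHTTPLauncher
#    actors; each one's post_init runs `python -m <command> --port <p> ...`
#    and later forwards create_engine/call RPCs to that server over HTTP.
proxy_ids = scheduler.fork_workers(
    role="proxy",
    target_role="rollout",
    command="areal.experimental.openai.proxy.proxy_rollout_server",
)

# Each launcher health-checks its subprocess during post_init; a manual check
# (as in the test) hits GET /health on the spawned server.
for wid in proxy_ids:
    ray.get(scheduler._worker_info_by_id[wid].actor._check_health.remote())

scheduler.delete_workers()
```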