Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 250 additions & 13 deletions areal/infra/rpc/ray_rpc_server.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import abc
import os
import shlex
import subprocess
import sys
import time
import traceback
from concurrent.futures import Future
from typing import Any

import ray
import requests

from areal.api import InferenceEngine, TrainEngine
from areal.api.cli_args import BaseExperimentConfig
from areal.infra.rpc.rtensor import RTensor
from areal.infra.rpc.serialization import deserialize_value, serialize_value
from areal.infra.utils.proc import kill_process_tree
from areal.utils import logging, name_resolve, seeding
from areal.utils.data import (
broadcast_tensor_container,
Expand All @@ -17,25 +25,19 @@
from areal.utils.network import find_free_ports


@ray.remote
class RayRPCServer:
class RayServer(abc.ABC):
"""
Ray engine container. Represents either:
- one training world rank, or
- one rollout instance

Supports multiple named engines per worker for colocation scenarios.

Placement group scheduling is controlled by the scheduler.
The actor is only responsible for the engine lifecycle and method calls
within this process.
Ray actor base class that all Ray actors under RayScheduler should inherit from
"""

def __init__(self):
def __init__(self, config: BaseExperimentConfig, **kwargs):
self._engines: dict[str, TrainEngine | InferenceEngine] = {}
self._default_engine_name: str | None = None # For backward compatibility
self._allocated_port = set()
self.logger = logging.getLogger("RayRPCServer")
self.config: BaseExperimentConfig = config
ctx = ray.get_runtime_context()
self.actor_name = ctx.get_actor_name()
self.logger = logging.getLogger(self.__class__.__name__)

def _get_device(self):
# lazy resolve the device inside worker process
Expand Down Expand Up @@ -83,6 +85,53 @@ def set_env(self, env: dict[str, str]) -> None:
for k, v in env.items():
os.environ[str(k)] = str(v)

def post_init(self, **kwargs) -> Any:
    """Hook invoked after the actor has been deployed.

    No-op by default. Subclasses (e.g. the HTTP launcher, which starts its
    server process here) may override this to perform post-deployment work.
    """
    # the HTTPLauncher needs this, but keeping this here for interface compatibility
    # launched after the actor has been deployed
    pass

@abc.abstractmethod
def create_engine(
    self,
    engine: str,
    *init_args,
    engine_name: str | None = None,
    **init_kwargs,
) -> None:
    """Instantiate an engine inside this actor.

    Args:
        engine: Identifier of the engine implementation to construct.
        *init_args: Positional arguments forwarded to the engine constructor.
        engine_name: Optional name to register the engine under, supporting
            multiple named engines per worker (colocation).
        **init_kwargs: Keyword arguments forwarded to the engine constructor.
    """
    raise NotImplementedError()

@abc.abstractmethod
def call(self, method: str, *args, engine_name: str | None = None, **kwargs) -> Any:
    """Invoke ``method`` on a hosted engine and return its result.

    Args:
        method: Name of the engine method to invoke.
        *args: Positional arguments for the method.
        engine_name: Which named engine to target; ``None`` targets the
            default engine (kept for backward compatibility).
        **kwargs: Keyword arguments for the method.
    """
    raise NotImplementedError()

@abc.abstractmethod
def destroy(self) -> None:
    """Tear down any hosted engines/processes and exit this Ray actor."""
    raise NotImplementedError()

def __ray_shutdown__(self):
    # NOTE(review): presumably invoked by Ray during actor shutdown — delegate
    # cleanup to destroy() so resources are released either way.
    self.destroy()

def __repr__(self):
    """Human-readable identifier: class name plus the Ray actor name."""
    cls_name = type(self).__name__
    return "{} [{}]".format(cls_name, self.actor_name)


@ray.remote
class RayRPCServer(RayServer):
"""
Ray engine container. Represents either:
- one training world rank, or
- one rollout instance

Supports multiple named engines per worker for colocation scenarios.

Placement group scheduling is controlled by the scheduler.
The actor is only responsible for the engine lifecycle and method calls
within this process.
"""

def __init__(self, config: BaseExperimentConfig, **kwargs):
    # All state (engine registry, logger, actor name) is set up by RayServer;
    # nothing extra is needed for the RPC server variant.
    super().__init__(config, **kwargs)

def create_engine(
self,
engine: str,
Expand Down Expand Up @@ -213,3 +262,191 @@ def destroy(self) -> None:
self._engines.clear()
self._default_engine_name = None
ray.actor.exit_actor()


@ray.remote
class RayHTTPLauncher(RayServer):
"""
Ray implementation of a launcher to launch proxy servers and any HTTP servers
"""

REQUIRED_ARGS = ("command", "worker_index", "role")

def __init__(self, config: BaseExperimentConfig, **kwargs):
    """Validate required launch kwargs and record launch parameters."""
    super().__init__(config, **kwargs)

    # Fail fast if the scheduler forgot any of the mandatory arguments.
    missing = []
    for key in self.REQUIRED_ARGS:
        if key not in kwargs:
            missing.append(key)
    if missing:
        raise TypeError(f"Missing required kwargs: {missing}")

    # Parameters describing what to launch and on whose behalf.
    self.command = kwargs["command"]
    self.worker_index = kwargs["worker_index"]
    self.role = kwargs["role"]

    # Where the server will live; port and process are filled in by post_init.
    self.worker_ip = ray.util.get_node_ip_address()
    self.worker_port = None
    self.worker_process: subprocess.Popen | None = None

def post_init(self, **kwargs):
    """Allocate a port (unless one is supplied) and start the HTTP server.

    Args:
        **kwargs: May contain ``port`` to pin the server to a specific port.
    """
    # Only allocate a fresh port when none was supplied. The previous
    # kwargs.get("port", self.alloc_ports(1)[0]) evaluated alloc_ports(1)
    # eagerly, leaking a port from the allocation pool even when the caller
    # provided one.
    if "port" in kwargs:
        self.worker_port = kwargs["port"]
    else:
        self.worker_port = self.alloc_ports(1)[0]
    self.worker_process = self.launch_server(port=self.worker_port)

def create_engine(
    self,
    engine: str,
    *init_args,
    engine_name: str | None = None,
    **init_kwargs,
) -> None:
    """Ask the launched HTTP server to construct an engine.

    Constructor arguments are serialized and shipped over HTTP; failures are
    logged with a traceback and re-raised to the caller.
    """
    self.logger.debug(f"Initializing engine {engine}")

    serialized_args = serialize_value(list(init_args))
    serialized_kwargs = serialize_value(init_kwargs)
    payload = {
        "engine": engine,
        "engine_name": engine_name,
        "init_args": serialized_args,
        "init_kwargs": serialized_kwargs,
    }

    try:
        self._post_request("create_engine", payload)
    except Exception as e:
        self.logger.error(
            f"RayHTTPLauncher failed to create engine '{engine}' : {e}\n"
            f"{traceback.format_exc()}"
        )
        raise

def call(
    self,
    method: str,
    *args,
    engine_name: str | None = None,
    rpc_meta: dict[str, Any] | None = None,
    **kwargs,
) -> Any:
    """Forward a method invocation to the HTTP server and return its result.

    Failures are logged with a traceback and re-raised to the caller.
    """
    self.logger.debug(
        f"Calling {method} on engine '{engine_name}' with arguments {args=} {kwargs=}"
    )

    # Assemble the request body; positional/keyword arguments are serialized
    # for transport.
    payload: dict[str, Any] = {"method": method, "engine_name": engine_name}
    payload["rpc_meta"] = rpc_meta
    payload["args"] = serialize_value(list(args))
    payload["kwargs"] = serialize_value(kwargs)

    try:
        return self._post_request("call", payload)
    except Exception as e:
        self.logger.error(
            f"RayHTTPLauncher failed for '{method}': {e}\n{traceback.format_exc()}"
        )
        raise

def destroy(self) -> None:
    """Stop the server process (if still running) and exit this actor."""
    proc = self.worker_process
    still_running = proc is not None and proc.poll() is None
    if still_running:
        kill_process_tree(proc.pid, timeout=3, graceful=True)
    self._default_engine_name = None
    ray.actor.exit_actor()

def launch_server(self, port):
    """Spawn the HTTP server subprocess and wait until it is healthy.

    Args:
        port: Port the server should listen on.

    Returns:
        The running ``subprocess.Popen`` handle (also stored on
        ``self.worker_process``).

    Raises:
        RuntimeError: If no command was configured, or the server fails its
            health check (the spawned process tree is killed first).
    """
    # keeping this as a separate function to support Awex server launches later
    if not self.command:
        raise RuntimeError(
            f"Command was not given to {self.__class__.__name__}.launch_server. Cannot launch without command."
        )

    # Run the configured module under the current interpreter.
    cmd = [sys.executable, "-m"]
    cmd.extend(shlex.split(self.command))
    cmd.extend(["--port", str(port)])

    cmd.extend(["--experiment-name", self.config.experiment_name])
    cmd.extend(["--trial-name", self.config.trial_name])
    cmd.extend(["--role", self.role])
    cmd.extend(["--worker-index", str(self.worker_index)])

    cluster_config = self.config.cluster
    # Renamed from `name_resolve`: the old local shadowed the name_resolve
    # module imported at the top of this file (areal.utils.name_resolve).
    name_resolve_config = cluster_config.name_resolve

    cmd.extend(["--name-resolve-type", name_resolve_config.type])
    cmd.extend(["--nfs-record-root", name_resolve_config.nfs_record_root])
    cmd.extend(["--etcd3-addr", name_resolve_config.etcd3_addr])
    cmd.extend(["--fileroot", str(cluster_config.fileroot)])

    _env = os.environ.copy()
    # NOTE(review): sys.stdout may lack a valid fileno() in some Ray
    # runtimes; if Popen ever raises here, fall back to stdout=None to
    # inherit the parent's stdout.
    self.worker_process = subprocess.Popen(
        cmd, env=_env, stdout=sys.stdout, stderr=subprocess.STDOUT
    )

    try:
        self._check_health()
    except Exception as e:
        self.logger.error(e)
        kill_process_tree(self.worker_process.pid, timeout=3, graceful=True)
        # Chain the health-check failure so the root cause is preserved.
        raise RuntimeError(f"Could not launch server with command {cmd}") from e

    return self.worker_process

def _post_request(
    self,
    endpoint,
    payload,
    http_timeout: float = 7200.0,
    max_retries: int = 3,
    retry_delay: float = 1.0,
):
    """POST ``payload`` to the worker's HTTP server with retries.

    Transient failures (connection errors, unexpected status codes) are
    retried with exponential backoff; application-level errors reported by
    the server (HTTP 400/404/500) fail immediately.

    Args:
        endpoint: Path component appended to the server base URL.
        payload: JSON-serializable request body.
        http_timeout: Per-request timeout in seconds.
        max_retries: Total number of attempts before giving up.
        retry_delay: Base backoff delay; doubles on each further attempt.

    Returns:
        The deserialized ``result`` field of a 200 response.

    Raises:
        RuntimeError: If the worker process has died, the server reports an
            application error, or all retries are exhausted.
    """
    url = f"{self.url}/{endpoint}"
    last_error = ""
    # adapted from local scheduler
    for attempt in range(1, max_retries + 1):
        if self.worker_process and self.worker_process.poll() is not None:
            raise RuntimeError("Worker has terminated")

        try:
            response = requests.post(url, json=payload, timeout=http_timeout)
        except Exception as e:
            last_error = str(e)
            self.logger.warning(
                f"Post failed when calling url {url} on actor '{self.actor_name}': {e}"
            )
        else:
            if response.status_code == 200:
                result = response.json().get("result")
                return deserialize_value(result)
            if response.status_code in (400, 404, 500):
                # Application errors are not transient: fail fast. (The old
                # code raised this inside the try block, where the broad
                # `except Exception` swallowed it and retried pointlessly.)
                try:
                    error_detail = response.json().get("detail", "unknown error")
                except ValueError:
                    error_detail = response.text or "unknown error"
                raise RuntimeError(error_detail)
            last_error = f"HTTP {response.status_code}: {response.text}"

        # otherwise retry
        if attempt < max_retries:
            delay = retry_delay * (2 ** (attempt - 1))
            self.logger.warning(
                f"Calling url {url} failed on actor '{self.actor_name}' "
                f"(attempt {attempt}/{max_retries}): {last_error}. "
                f"Retrying in {delay:.1f}s..."
            )
            # NOTE(review): time.sleep blocks this (synchronous) actor's
            # thread during backoff; revisit if the actor becomes async.
            time.sleep(delay)

    raise RuntimeError(
        f"Max retries exceeded trying to call url {url}: {last_error or 'unknown error'}"
    )

@property
def url(self):
    """Base URL of the HTTP server hosted by this launcher."""
    host, port = self.worker_ip, self.worker_port
    return "http://{}:{}".format(host, port)

def _check_health(self, timeout: float = 60.0):
    """Poll the server's /health endpoint until it answers or time runs out.

    Args:
        timeout: Maximum number of seconds to wait for a healthy response.

    Raises:
        RuntimeError: If the server process exits before becoming healthy,
            or the deadline passes without a 200 from /health.
    """
    url = f"{self.url}/health"
    # Use a monotonic clock for the deadline: time.time() can jump under
    # NTP/wall-clock adjustments, which would lengthen or shorten the wait.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        # Bail out early if the child process already died.
        if self.worker_process and self.worker_process.poll() is not None:
            raise RuntimeError("Server process exited before becoming healthy")

        try:
            r = requests.get(url, timeout=2.0)
            if r.status_code == 200:
                return
        except requests.RequestException:
            # expected during startup
            pass
        time.sleep(1)

    raise RuntimeError(f"Health check timed out for {url}")
Loading
Loading