-
Notifications
You must be signed in to change notification settings - Fork 496
feat(infra): Support for proxy server through RayScheduler #1161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
04b28c2
50bfcac
12b4962
398f0bc
468c027
324203c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,13 +1,21 @@ | ||||||||||||||
| import abc | ||||||||||||||
| import os | ||||||||||||||
| import shlex | ||||||||||||||
| import subprocess | ||||||||||||||
| import sys | ||||||||||||||
| import time | ||||||||||||||
| import traceback | ||||||||||||||
| from concurrent.futures import Future | ||||||||||||||
| from typing import Any | ||||||||||||||
|
|
||||||||||||||
| import ray | ||||||||||||||
| import requests | ||||||||||||||
|
|
||||||||||||||
| from areal.api import InferenceEngine, TrainEngine | ||||||||||||||
| from areal.api.cli_args import BaseExperimentConfig | ||||||||||||||
| from areal.infra.rpc.rtensor import RTensor | ||||||||||||||
| from areal.infra.rpc.serialization import deserialize_value, serialize_value | ||||||||||||||
| from areal.infra.utils.proc import kill_process_tree | ||||||||||||||
| from areal.utils import logging, name_resolve, seeding | ||||||||||||||
| from areal.utils.data import ( | ||||||||||||||
| broadcast_tensor_container, | ||||||||||||||
|
|
@@ -17,25 +25,19 @@ | |||||||||||||
| from areal.utils.network import find_free_ports | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @ray.remote | ||||||||||||||
| class RayRPCServer: | ||||||||||||||
| class RayServer(abc.ABC): | ||||||||||||||
| """ | ||||||||||||||
| Ray engine container. Represents either: | ||||||||||||||
| - one training world rank, or | ||||||||||||||
| - one rollout instance | ||||||||||||||
|
|
||||||||||||||
| Supports multiple named engines per worker for colocation scenarios. | ||||||||||||||
|
|
||||||||||||||
| Placement group scheduling is controlled by the scheduler. | ||||||||||||||
| The actor is only responsible for the engine lifecycle and method calls | ||||||||||||||
| within this process. | ||||||||||||||
| Ray actor base class that all Ray actors under RayScheduler should inherit from | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| def __init__(self): | ||||||||||||||
| def __init__(self, config: BaseExperimentConfig, **kwargs): | ||||||||||||||
| self._engines: dict[str, TrainEngine | InferenceEngine] = {} | ||||||||||||||
| self._default_engine_name: str | None = None # For backward compatibility | ||||||||||||||
| self._allocated_port = set() | ||||||||||||||
| self.logger = logging.getLogger("RayRPCServer") | ||||||||||||||
| self.config: BaseExperimentConfig = config | ||||||||||||||
| ctx = ray.get_runtime_context() | ||||||||||||||
| self.actor_name = ctx.get_actor_name() | ||||||||||||||
| self.logger = logging.getLogger(self.__class__.__name__) | ||||||||||||||
|
|
||||||||||||||
| def _get_device(self): | ||||||||||||||
| # lazy resolve the device inside worker process | ||||||||||||||
|
|
@@ -83,6 +85,53 @@ def set_env(self, env: dict[str, str]) -> None: | |||||||||||||
| for k, v in env.items(): | ||||||||||||||
| os.environ[str(k)] = str(v) | ||||||||||||||
|
|
||||||||||||||
| def post_init(self, **kwargs) -> Any: | ||||||||||||||
| # the HTTPLauncher needs this, but keeping this here for interface compatibility | ||||||||||||||
| # launched after the actor has been deployed | ||||||||||||||
| pass | ||||||||||||||
|
|
||||||||||||||
| @abc.abstractmethod | ||||||||||||||
| def create_engine( | ||||||||||||||
| self, | ||||||||||||||
| engine: str, | ||||||||||||||
| *init_args, | ||||||||||||||
| engine_name: str | None = None, | ||||||||||||||
| **init_kwargs, | ||||||||||||||
| ) -> None: | ||||||||||||||
| raise NotImplementedError() | ||||||||||||||
|
|
||||||||||||||
| @abc.abstractmethod | ||||||||||||||
| def call(self, method: str, *args, engine_name: str | None = None, **kwargs) -> Any: | ||||||||||||||
| raise NotImplementedError() | ||||||||||||||
|
|
||||||||||||||
| @abc.abstractmethod | ||||||||||||||
| def destroy(self) -> None: | ||||||||||||||
| raise NotImplementedError() | ||||||||||||||
|
|
||||||||||||||
| def __ray_shutdown__(self): | ||||||||||||||
| self.destroy() | ||||||||||||||
|
|
||||||||||||||
| def __repr__(self): | ||||||||||||||
| return f"{self.__class__.__name__} [{self.actor_name}]" | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @ray.remote | ||||||||||||||
| class RayRPCServer(RayServer): | ||||||||||||||
| """ | ||||||||||||||
| Ray engine container. Represents either: | ||||||||||||||
| - one training world rank, or | ||||||||||||||
| - one rollout instance | ||||||||||||||
|
|
||||||||||||||
| Supports multiple named engines per worker for colocation scenarios. | ||||||||||||||
|
|
||||||||||||||
| Placement group scheduling is controlled by the scheduler. | ||||||||||||||
| The actor is only responsible for the engine lifecycle and method calls | ||||||||||||||
| within this process. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| def __init__(self, config: BaseExperimentConfig, **kwargs): | ||||||||||||||
| super().__init__(config, **kwargs) | ||||||||||||||
|
|
||||||||||||||
| def create_engine( | ||||||||||||||
| self, | ||||||||||||||
| engine: str, | ||||||||||||||
|
|
@@ -213,3 +262,191 @@ def destroy(self) -> None: | |||||||||||||
| self._engines.clear() | ||||||||||||||
| self._default_engine_name = None | ||||||||||||||
| ray.actor.exit_actor() | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| @ray.remote | ||||||||||||||
| class RayHTTPLauncher(RayServer): | ||||||||||||||
| """ | ||||||||||||||
| Ray implementation of a launcher to launch proxy servers and any HTTP servers | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| REQUIRED_ARGS = ("command", "worker_index", "role") | ||||||||||||||
|
|
||||||||||||||
| def __init__(self, config: BaseExperimentConfig, **kwargs): | ||||||||||||||
| super().__init__(config, **kwargs) | ||||||||||||||
|
|
||||||||||||||
| missing = [k for k in self.REQUIRED_ARGS if k not in kwargs] | ||||||||||||||
| if missing: | ||||||||||||||
| raise TypeError(f"Missing required kwargs: {missing}") | ||||||||||||||
|
|
||||||||||||||
| self.command = kwargs["command"] | ||||||||||||||
| self.worker_index = kwargs["worker_index"] | ||||||||||||||
| self.role = kwargs["role"] | ||||||||||||||
| self.worker_ip = ray.util.get_node_ip_address() | ||||||||||||||
| self.worker_port = None | ||||||||||||||
| self.worker_process: subprocess.Popen | None = None | ||||||||||||||
|
|
||||||||||||||
| def post_init(self, **kwargs): | ||||||||||||||
| self.worker_port = kwargs.get("port", self.alloc_ports(1)[0]) | ||||||||||||||
| self.worker_process = self.launch_server(port=self.worker_port) | ||||||||||||||
|
|
||||||||||||||
| def create_engine( | ||||||||||||||
| self, | ||||||||||||||
| engine: str, | ||||||||||||||
| *init_args, | ||||||||||||||
| engine_name: str | None = None, | ||||||||||||||
| **init_kwargs, | ||||||||||||||
| ) -> None: | ||||||||||||||
| self.logger.debug(f"Initializing engine {engine}") | ||||||||||||||
| payload = { | ||||||||||||||
| "engine": engine, | ||||||||||||||
| "engine_name": engine_name, | ||||||||||||||
| "init_args": serialize_value(list(init_args)), | ||||||||||||||
| "init_kwargs": serialize_value(init_kwargs), | ||||||||||||||
| } | ||||||||||||||
| try: | ||||||||||||||
| self._post_request("create_engine", payload) | ||||||||||||||
| except Exception as e: | ||||||||||||||
| self.logger.error( | ||||||||||||||
| f"RayHTTPLauncher failed to create engine '{engine}' : {e}\n" | ||||||||||||||
| f"{traceback.format_exc()}" | ||||||||||||||
| ) | ||||||||||||||
| raise | ||||||||||||||
|
|
||||||||||||||
| def call( | ||||||||||||||
| self, | ||||||||||||||
| method: str, | ||||||||||||||
| *args, | ||||||||||||||
| engine_name: str | None = None, | ||||||||||||||
| rpc_meta: dict[str, Any] | None = None, | ||||||||||||||
| **kwargs, | ||||||||||||||
| ) -> Any: | ||||||||||||||
| self.logger.debug( | ||||||||||||||
| f"Calling {method} on engine '{engine_name}' with arguments {args=} {kwargs=}" | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| payload = { | ||||||||||||||
| "method": method, | ||||||||||||||
| "engine_name": engine_name, | ||||||||||||||
| "rpc_meta": rpc_meta, | ||||||||||||||
| "args": serialize_value(list(args)), | ||||||||||||||
| "kwargs": serialize_value(kwargs), | ||||||||||||||
| } | ||||||||||||||
| try: | ||||||||||||||
| return self._post_request("call", payload) | ||||||||||||||
| except Exception as e: | ||||||||||||||
| self.logger.error( | ||||||||||||||
| f"RayHTTPLauncher failed for '{method}': {e}\n{traceback.format_exc()}" | ||||||||||||||
| ) | ||||||||||||||
| raise | ||||||||||||||
|
|
||||||||||||||
| def destroy(self) -> None: | ||||||||||||||
| if self.worker_process and self.worker_process.poll() is None: | ||||||||||||||
| kill_process_tree(self.worker_process.pid, timeout=3, graceful=True) | ||||||||||||||
| self._default_engine_name = None | ||||||||||||||
| ray.actor.exit_actor() | ||||||||||||||
|
|
||||||||||||||
| def launch_server(self, port): | ||||||||||||||
| # keeping this as a separate function to support Awex server launches later | ||||||||||||||
| if not self.command: | ||||||||||||||
| raise RuntimeError( | ||||||||||||||
| f"Command was not given to {self.__class__.__name__}.launch_server. Cannot launch without command." | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| cmd = [sys.executable, "-m"] | ||||||||||||||
| cmd.extend(shlex.split(self.command)) | ||||||||||||||
| cmd.extend(["--port", str(port)]) | ||||||||||||||
|
|
||||||||||||||
| cmd.extend(["--experiment-name", self.config.experiment_name]) | ||||||||||||||
| cmd.extend(["--trial-name", self.config.trial_name]) | ||||||||||||||
| cmd.extend(["--role", self.role]) | ||||||||||||||
| cmd.extend(["--worker-index", str(self.worker_index)]) | ||||||||||||||
|
|
||||||||||||||
| cluster_config = self.config.cluster | ||||||||||||||
| name_resolve = self.config.cluster.name_resolve | ||||||||||||||
|
|
||||||||||||||
| cmd.extend(["--name-resolve-type", name_resolve.type]) | ||||||||||||||
| cmd.extend(["--nfs-record-root", name_resolve.nfs_record_root]) | ||||||||||||||
| cmd.extend(["--etcd3-addr", name_resolve.etcd3_addr]) | ||||||||||||||
| cmd.extend(["--fileroot", str(cluster_config.fileroot)]) | ||||||||||||||
|
|
||||||||||||||
| _env = os.environ.copy() | ||||||||||||||
| self.worker_process = subprocess.Popen( | ||||||||||||||
| cmd, env=_env, stdout=sys.stdout, stderr=subprocess.STDOUT | ||||||||||||||
| ) | ||||||||||||||
|
Comment on lines
+374
to
+376
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't face such an issue when running with Ray. I can change to |
||||||||||||||
|
|
||||||||||||||
| try: | ||||||||||||||
| self._check_health() | ||||||||||||||
| except Exception as e: | ||||||||||||||
| self.logger.error(e) | ||||||||||||||
| kill_process_tree(self.worker_process.pid, timeout=3, graceful=True) | ||||||||||||||
| raise RuntimeError(f"Could not launch server with command {cmd}") | ||||||||||||||
|
|
||||||||||||||
| return self.worker_process | ||||||||||||||
|
|
||||||||||||||
| def _post_request( | ||||||||||||||
| self, | ||||||||||||||
| endpoint, | ||||||||||||||
| payload, | ||||||||||||||
| http_timeout: float = 7200.0, | ||||||||||||||
| max_retries: int = 3, | ||||||||||||||
| retry_delay: float = 1.0, | ||||||||||||||
| ): | ||||||||||||||
| url = f"{self.url}/{endpoint}" | ||||||||||||||
| last_error = "" | ||||||||||||||
| # adapted from local scheduler | ||||||||||||||
| for attempt in range(1, max_retries + 1): | ||||||||||||||
| if self.worker_process and self.worker_process.poll() is not None: | ||||||||||||||
| raise RuntimeError("Worker has terminated") | ||||||||||||||
|
|
||||||||||||||
| try: | ||||||||||||||
| response = requests.post(url, json=payload, timeout=http_timeout) | ||||||||||||||
| if response.status_code == 200: | ||||||||||||||
| result = response.json().get("result") | ||||||||||||||
| deserialized_result = deserialize_value(result) | ||||||||||||||
| return deserialized_result | ||||||||||||||
| elif response.status_code in [400, 404, 500]: | ||||||||||||||
| error_detail = response.json().get("detail", "unknown error") | ||||||||||||||
| raise RuntimeError(error_detail) | ||||||||||||||
|
hlyli marked this conversation as resolved.
Outdated
|
||||||||||||||
|
|
||||||||||||||
| last_error = f"HTTP {response.status_code}: {response.text}" | ||||||||||||||
| except Exception as e: | ||||||||||||||
| last_error = str(e) | ||||||||||||||
| self.logger.warning( | ||||||||||||||
| f"Post failed when calling url {url} on actor '{self.actor_name}': {e}" | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| # otherwise retry | ||||||||||||||
| if attempt < max_retries: | ||||||||||||||
| delay = retry_delay * (2 ** (attempt - 1)) | ||||||||||||||
| self.logger.warning( | ||||||||||||||
| f"Calling url {url} failed on actor '{self.actor_name}' " | ||||||||||||||
| f"(attempt {attempt}/{max_retries}): {last_error}. " | ||||||||||||||
| f"Retrying in {delay:.1f}s..." | ||||||||||||||
| ) | ||||||||||||||
| time.sleep(delay) | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using |
||||||||||||||
| raise RuntimeError( | ||||||||||||||
| f"Max retries exceeded trying to call url {url}: {last_error or 'unknown error'}" | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| @property | ||||||||||||||
| def url(self): | ||||||||||||||
| return f"http://{self.worker_ip}:{self.worker_port}" | ||||||||||||||
|
|
||||||||||||||
| def _check_health(self, timeout: float = 60.0): | ||||||||||||||
| url = f"{self.url}/health" | ||||||||||||||
| deadline = time.time() + timeout | ||||||||||||||
| while time.time() < deadline: | ||||||||||||||
| if self.worker_process and self.worker_process.poll() is not None: | ||||||||||||||
| raise RuntimeError("Server process exited before becoming healthy") | ||||||||||||||
|
|
||||||||||||||
| try: | ||||||||||||||
| r = requests.get(url, timeout=2.0) | ||||||||||||||
| if r.status_code == 200: | ||||||||||||||
| return | ||||||||||||||
| except requests.RequestException: | ||||||||||||||
| # expected during startup | ||||||||||||||
| pass | ||||||||||||||
| time.sleep(1) | ||||||||||||||
|
|
||||||||||||||
| raise RuntimeError(f"Health check timed out for {url}") | ||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.