diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py
index 70da16426..a66456a95 100644
--- a/areal/api/cli_args.py
+++ b/areal/api/cli_args.py
@@ -989,6 +989,139 @@ class SlurmLauncherConfig:
     )
 
 
+@dataclass
+class SkyPilotLauncherConfig:
+    """Configuration for launching the training jobs with SkyPilot."""
+
+    # Basic task metadata
+    name: str | None = field(
+        default=None,
+        metadata={"help": "Optional task name displayed in SkyPilot."},
+    )
+    workdir: str | None = field(
+        default=None,
+        metadata={
+            "help": "Local path or git repo spec to sync as working directory (mirrors SkyPilot YAML workdir)."
+        },
+    )
+
+    # Core resource specification (subset of SkyPilot resources.* fields)
+    infra: str | None = field(
+        default=None,
+        metadata={
+            "help": "Infrastructure spec `cloud`, `cloud/region`, `cloud/region/zone`, or `k8s[/context]` (resources.infra).",
+        },
+    )
+    accelerator_type: str | None = field(
+        default=None,
+        metadata={
+            "help": "Accelerator request, e.g. 'H100', 'A100'. Number of GPUs on the node is determined by `cluster.n_gpus_per_node`.",
+        },
+    )
+    accelerator_args: str | None = field(
+        default=None,
+        metadata={
+            "help": "Additional accelerator args (YAML/JSON string) (resources.accelerator_args).",
+        },
+    )
+    use_spot: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use spot/preemptible instances (resources.use_spot)."
+        },
+    )
+    disk_size: str | None = field(
+        default=None,
+        metadata={
+            "help": "Boot disk size with optional unit, e.g. '256', '256GB' (resources.disk_size)."
+        },
+    )
+    disk_tier: str = field(
+        default="medium",
+        metadata={
+            "help": "Disk performance tier (resources.disk_tier).",
+            "choices": ["low", "medium", "high", "ultra", "best"],
+        },
+    )
+    network_tier: str = field(
+        default="standard",
+        metadata={
+            "help": "Network tier (resources.network_tier).",
+            "choices": ["standard", "best"],
+        },
+    )
+    ports: str | None = field(
+        default=None,
+        metadata={
+            "help": "Ports to expose, supports single '8080', range '10052-10100', or comma list (resources.ports)."
+        },
+    )
+    image_id: str | None = field(
+        default=None,
+        metadata={"help": "Custom base or docker image id (resources.image_id)."},
+    )
+    labels: str | None = field(
+        default=None,
+        metadata={
+            "help": "Instance/pod labels as key=value pairs joined by commas (resources.labels)."
+        },
+    )
+
+    any_of: str | None = field(
+        default=None,
+        metadata={
+            "help": "YAML/JSON string list of candidate resource dicts (resources.any_of)."
+        },
+    )
+    ordered: str | None = field(
+        default=None,
+        metadata={
+            "help": "YAML/JSON string list of ordered resource dicts (resources.ordered)."
+        },
+    )
+    job_recovery: str | None = field(
+        default=None,
+        metadata={
+            "help": "Job recovery strategy spec (resources.job_recovery). Provide JSON/YAML."
+        },
+    )
+
+    # Autostop: store flexible spec as raw string
+    autostop: str | None = field(
+        default=None,
+        metadata={
+            "help": "Autostop configuration: true/false/10/10h or JSON object (resources.autostop)."
+        },
+    )
+
+    # Environment & secret injection
+    envs: str | None = field(
+        default=None,
+        metadata={
+            "help": "Environment variables (envs) as KEY=VAL pairs joined by commas."
+        },
+    )
+    secrets: str | None = field(
+        default=None,
+        metadata={
+            "help": "Secrets usable in setup/run as KEY=VAL pairs joined by commas (will be redacted)."
+        },
+    )
+
+    volumes: str | None = field(
+        default=None,
+        metadata={
+            "help": "Kubernetes volume mappings (volumes) as YAML/JSON string or shorthand mount spec."
+        },
+    )
+    file_mounts: str | None = field(
+        default=None,
+        metadata={
+            "help": "File mounts mapping remote_path:local_path lines or JSON/YAML string (file_mounts)."
+        },
+    )
+
+
 @dataclass
 class LauncherConfig:
     """Configuration for launching the LLM server and trainer processes."""
@@ -1027,6 +1160,10 @@ class LauncherConfig:
         default_factory=SlurmLauncherConfig,
         metadata={"help": "Slurm launcher configuration."},
     )
+    skypilot: SkyPilotLauncherConfig = field(
+        default_factory=SkyPilotLauncherConfig,
+        metadata={"help": "SkyPilot launcher configuration."},
+    )
 
 
 @dataclass
diff --git a/areal/launcher/skypilot.py b/areal/launcher/skypilot.py
new file mode 100644
index 000000000..68e75ad65
--- /dev/null
+++ b/areal/launcher/skypilot.py
@@ -0,0 +1,632 @@
+"""Launch AReaL experiments on SkyPilot-managed clusters.
+
+This launcher mirrors the semantics of the Ray and Slurm launchers while
+delegating provisioning and task execution to the SkyPilot Python SDK.
+"""
+
+from __future__ import annotations
+
+import re
+import shlex
+import sys
+import time
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import yaml
+
+from areal.api.alloc_mode import AllocationMode, AllocationType
+from areal.api.cli_args import (
+    ClusterSpecConfig,
+    LauncherConfig,
+    RecoverConfig,
+    SGLangConfig,
+    SkyPilotLauncherConfig,
+    parse_cli_args,
+    to_structured_cfg,
+    vLLMConfig,
+)
+from areal.utils import logging, name_resolve, names
+from areal.utils.launcher import (
+    JobException,
+    JobInfo,
+    JobState,
+    get_env_vars,
+    validate_config_for_distributed_launcher,
+    wait_llm_server_addrs,
+)
+from areal.utils.recover import check_if_recover
+
+logger = logging.getLogger("SkyPilotLauncher")
+
+
+try:
+    import sky
+    from sky import JobStatus as SkyJobStatus
+except ImportError as exc:  # pragma: no cover - handled at runtime
+    raise ImportError(
+        "SkyPilot launcher requires the `skypilot` package. "
+        "Install it via `pip install -U skypilot`."
+    ) from exc
+
+
+SKY_TO_JOB_STATE: Dict[SkyJobStatus, JobState] = {
+    SkyJobStatus.INIT: JobState.PENDING,
+    SkyJobStatus.PENDING: JobState.PENDING,
+    SkyJobStatus.SETTING_UP: JobState.PENDING,
+    SkyJobStatus.RUNNING: JobState.RUNNING,
+    SkyJobStatus.SUCCEEDED: JobState.COMPLETED,
+    SkyJobStatus.FAILED: JobState.FAILED,
+    SkyJobStatus.FAILED_SETUP: JobState.FAILED,
+    SkyJobStatus.FAILED_DRIVER: JobState.FAILED,
+    SkyJobStatus.CANCELLED: JobState.CANCELLED,
+}
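+# Any SkyPilot status missing from this mapping is treated as
+# JobState.NOT_FOUND by _update_all() below.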
+
+
+SKY_WAIT_CHECK_TIME_INTERVAL = 5  # seconds
+
+
+def _readable_cluster_name(experiment_name: str, trial_name: str) -> str:
+    slug = f"areal-{experiment_name}-{trial_name}"
+    return re.sub(r"[^a-zA-Z0-9-]", "-", slug).lower()
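+# For example, experiment "gsm8k_grpo" with trial "trial0" yields the cluster
+# name "areal-gsm8k-grpo-trial0": everything outside [a-zA-Z0-9-] is replaced
+# with "-" and the result is lowercased.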
+
+
+def _parse_key_value_pairs(value: Optional[str]) -> Dict[str, str]:
+    if not value:
+        return {}
+    result = {}
+    for chunk in value.split(","):
+        if not chunk:
+            continue
+        if "=" not in chunk:
+            raise ValueError(
+                f"Environment/secret entry '{chunk}' must be in KEY=VALUE format."
+            )
+        key, val = chunk.split("=", 1)
+        result[key.strip()] = val.strip()
+    return result
+
+
+def _parse_yaml_like(value: Optional[str]) -> Any:
+    if not value:
+        return None
+    return yaml.safe_load(value)
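+# Examples:
+#   _parse_key_value_pairs("HF_TOKEN=abc,WANDB_MODE=offline")
+#     -> {"HF_TOKEN": "abc", "WANDB_MODE": "offline"}
+#   _parse_yaml_like('{"/storage": "gs://my-bucket"}')
+#     -> {"/storage": "gs://my-bucket"}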
+
+
+def _default_workdir(skypilot_cfg: SkyPilotLauncherConfig) -> str:
+    if skypilot_cfg.workdir:
+        return skypilot_cfg.workdir
+    return str(Path.cwd())
+
+
+RunSpec = Union[str, Callable[[int, List[str]], str]]
+
+
+class SkyPilotLauncher:
+    def __init__(
+        self,
+        experiment_name: str,
+        trial_name: str,
+        config: LauncherConfig,
+        cluster_spec_config: ClusterSpecConfig,
+    ):
+        self.experiment_name = experiment_name
+        self.trial_name = trial_name
+        self.config = config
+        self.skypilot_config = config.skypilot
+        self.cluster_spec_config = cluster_spec_config
+        self.total_nodes = cluster_spec_config.n_nodes
+
+        # Job bookkeeping: job id -> JobInfo/metadata, plus base job name ->
+        # the set of job ids submitted under that name.
+        self.jobs: Dict[int, JobInfo] = {}
+        self._job_meta: Dict[int, Dict[str, str]] = {}
+        self._job_groups: Dict[str, set] = {}
+        self._cluster_ready = False
+
+        self.cluster_name = _readable_cluster_name(experiment_name, trial_name)
+        self._ensure_cluster()
+
+    def _build_resources(self) -> sky.Resources:
+        assert self.skypilot_config.accelerator_type is not None
+        accelerator_type = self.skypilot_config.accelerator_type
+        n_gpus_per_node = self.cluster_spec_config.n_gpus_per_node
+        accelerators = f"{accelerator_type}:{n_gpus_per_node}"
+        # TODO: map the remaining SkyPilotLauncherConfig fields in cli_args.py
+        # (disk_size, disk_tier, ports, image_id, labels, ...) onto sky.Resources.
+        resources_kwargs: Dict[str, Any] = {"accelerators": accelerators}
+        if self.skypilot_config.infra:
+            # Assumes a SkyPilot version whose Resources accepts `infra`.
+            resources_kwargs["infra"] = self.skypilot_config.infra
+        if self.skypilot_config.use_spot:
+            resources_kwargs["use_spot"] = True
+        return sky.Resources(**resources_kwargs)
+
+    def _base_task(
+        self,
+        name: str,
+        num_nodes: int,
+        run: RunSpec,
+        extra_envs: Optional[Dict[str, str]] = None,
+    ) -> sky.Task:
+        base_envs = _parse_key_value_pairs(self.skypilot_config.envs)
+        secrets = _parse_key_value_pairs(self.skypilot_config.secrets)
+        if secrets:
+            base_envs.update(secrets)
+        if extra_envs:
+            base_envs.update(extra_envs)
+        workdir = _default_workdir(self.skypilot_config)
+        file_mounts = None
+        if self.skypilot_config.file_mounts:
+            file_mounts = _parse_yaml_like(self.skypilot_config.file_mounts)
+        resources = self._build_resources()
+        task_kwargs: Dict[str, Any] = {
+            "name": name,
+            "num_nodes": num_nodes,
+            "run": run,
+            "workdir": workdir,
+        }
+        if base_envs:
+            task_kwargs["envs"] = base_envs
+        if file_mounts:
+            task_kwargs["file_mounts"] = file_mounts
+        task = sky.Task(**task_kwargs)
+        task.set_resources(resources)
+        return task
+
+    def _ensure_cluster(self) -> None:
+        if self._cluster_ready:
+            return
+        provision_task = self._base_task(
+            name=f"{self.cluster_name}-provision",
+            num_nodes=self.total_nodes,
+            run="echo '[SkyPilot] Cluster ready for AReaL launching.'",  # noqa: E501
+        )
+        logger.info(
+            "Launching/repairing SkyPilot cluster '%s' with %d node(s).",
+            self.cluster_name,
+            self.total_nodes,
+        )
+        req_id = sky.launch(provision_task, cluster_name=self.cluster_name)
+        sky.stream_and_get(req_id)
+        self._cluster_ready = True
+
+    @property
+    def run_name(self) -> str:
+        return f"{self.experiment_name}_{self.trial_name}"
+
+    def submit(self, job_name: str, task: sky.Task) -> int:
+        job_ids = self.submit_array(job_name, [task])
+        return job_ids[0]
+
+    def submit_array(self, job_name: str, tasks: List[sky.Task]) -> List[int]:
+        assert tasks, "Tasks list cannot be empty."
+        job_ids: List[int] = []
+        for idx, task in enumerate(tasks):
+            derived_name = job_name if len(tasks) == 1 else f"{job_name}:{idx}"
+            task.name = derived_name
+            logger.info("Submitting SkyPilot task '%s'", derived_name)
+            request_id = sky.exec(task, cluster_name=self.cluster_name)
+            job_id, _ = sky.get(request_id)
+            self._register_job(job_id, derived_name)
+            job_ids.append(job_id)
+        return job_ids
+
+    def stop(self, job_name: str, force: bool = False) -> None:
+        job_ids = list(self._job_groups.get(job_name, set()))
+        if not job_ids:
+            return
+        logger.info("Stopping jobs %s (ids=%s)", job_name, job_ids)
+        sky.cancel(self.cluster_name, job_ids=job_ids)
+        for job_id in job_ids:
+            self._remove_job(job_id)
+
+    def stop_all(self, force: bool = False) -> None:
+        job_ids = list(self.jobs.keys())
+        if not job_ids:
+            return
+        logger.info("Stopping all SkyPilot jobs: %s", job_ids)
+        sky.cancel(self.cluster_name, job_ids=job_ids)
+        for job_id in job_ids:
+            self._remove_job(job_id)
+
+    def find(self, job_name: str) -> Optional[JobInfo]:
+        self._update_all()
+        job_ids = list(self._job_groups.get(job_name, set()))
+        if not job_ids:
+            return None
+        return self.jobs[job_ids[0]]
+
+    def find_all(self, job_name_regex: str = ".*") -> List[JobInfo]:
+        self._update_all()
+        pattern = re.compile(job_name_regex)
+        results: List[JobInfo] = []
+        for job_id, info in self.jobs.items():
+            base = self._job_meta[job_id]["base"]
+            if pattern.fullmatch(base):
+                results.append(info)
+        return results
+
+    def wait(
+        self,
+        timeout: Optional[int] = None,
+        check_status: Tuple[JobState, ...] = (
+            JobState.CANCELLED,
+            JobState.FAILED,
+            JobState.NOT_FOUND,
+        ),
+        remove_status: Tuple[JobState, ...] = (JobState.COMPLETED,),
+        update: bool = False,
+        job_names: Optional[Iterable[str]] = None,
+    ) -> None:
+        deadline = None if timeout is None else time.time() + timeout
+        target_ids = self._select_job_ids(job_names)
+        pending = set(target_ids)
+        if not pending:
+            return
+        logger.info("Waiting for jobs %s", sorted(pending))
+        while pending:
+            if deadline is not None and time.time() > deadline:
+                raise TimeoutError(
+                    f"Timeout waiting for jobs {sorted(pending)} to finish."
+                )
+            self._update_all()
+            for job_id in list(pending):
+                info = self.jobs.get(job_id)
+                if info is None:
+                    pending.discard(job_id)
+                    continue
+                state = info.state
+                base = self._job_meta[job_id]["base"]
+                if state in check_status:
+                    raise JobException(
+                        run_name=self.run_name,
+                        worker_type=base,
+                        host=self.cluster_name,
+                        reason=state,
+                    )
+                if state in remove_status:
+                    logger.info(
+                        "Job %s (id=%s) reached %s", info.name, job_id, state.name
+                    )
+                    pending.discard(job_id)
+                    if update:
+                        self._remove_job(job_id)
+            if pending:
+                time.sleep(SKY_WAIT_CHECK_TIME_INTERVAL)
+
+    def _register_job(self, job_id: int, job_name: str) -> None:
+        base = job_name.split(":", maxsplit=1)[0]
+        self.jobs[job_id] = JobInfo(
+            name=job_name,
+            state=JobState.PENDING,
+            host=self.cluster_name,
+        )
+        self._job_meta[job_id] = {"name": job_name, "base": base}
+        self._job_groups.setdefault(base, set()).add(job_id)
+
+    def _remove_job(self, job_id: int) -> None:
+        info = self._job_meta.pop(job_id, None)
+        self.jobs.pop(job_id, None)
+        if info is None:
+            return
+        base = info["base"]
+        group = self._job_groups.get(base)
+        if group and job_id in group:
+            group.remove(job_id)
+            if not group:
+                self._job_groups.pop(base, None)
+
+    def _select_job_ids(self, job_names: Optional[Iterable[str]]) -> List[int]:
+        if job_names is None:
+            return list(self.jobs.keys())
+        selected: List[int] = []
+        for base in job_names:
+            selected.extend(list(self._job_groups.get(base, set())))
+        return selected
+
+    def _update_all(self) -> None:
+        if not self.jobs:
+            return
+        job_ids = list(self.jobs.keys())
+        try:
+            status_request = sky.job_status(self.cluster_name, job_ids=job_ids)
+            statuses = sky.get(status_request)
+        except Exception as exc:  # pragma: no cover - best effort logging
+            logger.warning("Failed to query SkyPilot job status: %s", exc)
+            return
+        for job_id in job_ids:
+            info = self.jobs.get(job_id)
+            if info is None:
+                continue
+            status = statuses.get(job_id)
+            if status is None:
+                info.state = JobState.NOT_FOUND
+            else:
+                info.state = SKY_TO_JOB_STATE.get(status, JobState.NOT_FOUND)
+
+
+def _quoted_cmd(parts: List[str]) -> str:
+    return " ".join(shlex.quote(p) for p in parts)
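+# Example: _quoted_cmd(["python", "-m", "areal.launcher.sglang_server", "a b"])
+# returns "python -m areal.launcher.sglang_server 'a b'".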
+
+
+def _build_sglang_task(
+    launcher: SkyPilotLauncher,
+    job_name: str,
+    config_args: List[str],
+    allocation: AllocationMode,
+    sglang_cfg: SGLangConfig,
+    n_nodes: int,
+    gpus_per_node: int,
+    env_vars: Dict[str, str],
+) -> sky.Task:
+    assert allocation.gen_backend == "sglang"
+    base_seed = sglang_cfg.random_seed
+    n_sglang_servers = allocation.gen.dp_size
+    n_servers_per_node = max(n_sglang_servers // max(n_nodes, 1), 1)
+    cross_nodes = allocation.gen_instance_size > gpus_per_node
+
+    def run_generator(node_rank: int, host_ips: List[str]) -> str:
+        args = list(config_args)
+        args.append(f"sglang.random_seed={base_seed + node_rank * n_servers_per_node}")
+        cmd = _quoted_cmd(["python", "-m", "areal.launcher.sglang_server", *args])
+        exports: List[str] = []
+        if cross_nodes:
+            exports.append(f"export AREAL_SGLANG_MULTI_NODE_RANK={node_rank}")
+            exports.append(f"export AREAL_SGLANG_MULTI_NODE_MASTER_ADDR={host_ips[0]}")
+            # TODO: find free port
+            exports.append("export AREAL_SGLANG_MULTI_NODE_MASTER_PORT=17901")
+        if exports:
+            return " && ".join(exports + [cmd])
+        return cmd
+
+    return launcher._base_task(  # pylint: disable=protected-access
+        name=job_name,
+        num_nodes=n_nodes,
+        run=run_generator,
+        extra_envs=env_vars,
+    )
+
+
+def _build_vllm_task(
+    launcher: SkyPilotLauncher,
+    job_name: str,
+    config_args: List[str],
+    allocation: AllocationMode,
+    vllm_cfg: vLLMConfig,
+    n_nodes: int,
+    env_vars: Dict[str, str],
+) -> sky.Task:
+    assert allocation.gen_backend == "vllm"
+    base_seed = vllm_cfg.seed
+
+    def run_generator(node_rank: int, _host_ips: List[str]) -> str:
+        args = list(config_args)
+        args.append(f"vllm.seed={base_seed + node_rank}")
+        cmd = _quoted_cmd(["python", "-m", "areal.launcher.vllm_server", *args])
+        return cmd
+
+    return launcher._base_task(  # pylint: disable=protected-access
+        name=job_name,
+        num_nodes=n_nodes,
+        run=run_generator,
+        extra_envs=env_vars,
+    )
+
+
+def _build_trainer_task(
+    launcher: SkyPilotLauncher,
+    job_name: str,
+    trainer_entry: str,
+    trainer_args: List[str],
+    allocation: AllocationMode,
+    n_nodes: int,
+    gpus_per_node: int,
+    env_vars: Dict[str, str],
+    is_eval_only: bool,
+) -> sky.Task:
+    if is_eval_only:
+        cmd = _quoted_cmd(["python", trainer_entry, *trainer_args])
+        return launcher._base_task(
+            job_name,
+            num_nodes=1,
+            run=cmd,
+            extra_envs=env_vars,
+        )
+
+    # TODO: find free port
+    rendezvous_port = 29501
+
+    def run_generator(node_rank: int, host_ips: List[str]) -> str:
+        master_addr = host_ips[0]
+        torchrun_cmd = [
+            "torchrun",
+            "--nnodes",
+            str(n_nodes),
+            "--nproc-per-node",
+            str(gpus_per_node),
+            "--rdzv_backend",
+            "c10d",
+            "--rdzv_endpoint",
+            f"{master_addr}:{rendezvous_port}",
+            "--node_rank",
+            str(node_rank),
+            trainer_entry,
+            *trainer_args,
+        ]
+        cmd = _quoted_cmd(torchrun_cmd)
+        return cmd
+
+    return launcher._base_task(
+        job_name,
+        num_nodes=n_nodes,
+        run=run_generator,
+        extra_envs=env_vars,
+    )
+
+
+def skypilot_main(config, run_id: int = 0):
+    config.launcher = to_structured_cfg(config.launcher, LauncherConfig)
+    config.recover = to_structured_cfg(config.recover, RecoverConfig)
+    config.cluster = to_structured_cfg(config.cluster, ClusterSpecConfig)
+    config.launcher.skypilot = to_structured_cfg(
+        config.launcher.skypilot, SkyPilotLauncherConfig
+    )
+
+    is_recover_run = check_if_recover(config.recover, run_id)
+    validate_config_for_distributed_launcher(config)
+
+    name_resolve.reconfigure(config.cluster.name_resolve)
+    name_resolve.clear_subtree(
+        names.trial_root(
+            experiment_name=config.experiment_name, trial_name=config.trial_name
+        )
+    )
+
+    allocation_mode = AllocationMode.from_str(config.allocation_mode)
+    # The cluster is provisioned inside __init__ via _ensure_cluster().
+    launcher = SkyPilotLauncher(
+        experiment_name=config.experiment_name,
+        trial_name=config.trial_name,
+        config=config.launcher,
+        cluster_spec_config=config.cluster,
+    )
+
+    trainer_entry = sys.argv[1]
+    trainer_args = sys.argv[2:]
+
+    llm_job_name: Optional[str] = None
+    llm_addrs: List[str] = []
+
+    try:
+        gpus_per_node = config.cluster.n_gpus_per_node
+        llm_backend = allocation_mode.gen_backend
+        if llm_backend == "sglang":
+            llm_job_name = f"{launcher.cluster_name}-sglang"
+            config.sglang = to_structured_cfg(config.sglang, SGLangConfig)
+            n_llm_nodes = max(
+                (allocation_mode.gen.world_size + gpus_per_node - 1)
+                // max(gpus_per_node, 1),
+                1,
+            )
+            llm_env = get_env_vars(
+                config.cluster.cluster_name,
+                config.launcher.inference_server_env_vars,
+            )
+            task = _build_sglang_task(
+                launcher=launcher,
+                job_name=llm_job_name,
+                config_args=list(trainer_args),
+                allocation=allocation_mode,
+                sglang_cfg=config.sglang,
+                n_nodes=n_llm_nodes,
+                gpus_per_node=gpus_per_node,
+                env_vars=llm_env,
+            )
+            launcher.submit(llm_job_name, task)
+        elif llm_backend == "vllm":
+            llm_job_name = f"{launcher.cluster_name}-vllm"
+            config.vllm = to_structured_cfg(config.vllm, vLLMConfig)
+            n_llm_nodes = max(
+                (allocation_mode.gen.world_size + gpus_per_node - 1)
+                // max(gpus_per_node, 1),
+                1,
+            )
+            llm_env = get_env_vars(
+                config.cluster.cluster_name,
+                config.launcher.inference_server_env_vars,
+            )
+            task = _build_vllm_task(
+                launcher=launcher,
+                job_name=llm_job_name,
+                config_args=list(trainer_args),
+                allocation=allocation_mode,
+                vllm_cfg=config.vllm,
+                n_nodes=n_llm_nodes,
+                env_vars=llm_env,
+            )
+            launcher.submit(llm_job_name, task)
+
+        if llm_job_name is not None:
+            llm_addrs = wait_llm_server_addrs(
+                experiment_name=config.experiment_name,
+                trial_name=config.trial_name,
+                n_rollout_servers=allocation_mode.gen.dp_size,
+            )
+
+        if allocation_mode.type_ == AllocationType.LLM_SERVER_ONLY:
+            if llm_job_name is None:
+                logger.warning(
+                    "Allocation mode is LLM_SERVER_ONLY but no LLM job launched."
+                )
+            else:
+                launcher.wait(
+                    job_names=[llm_job_name],
+                    check_status=(
+                        JobState.FAILED,
+                        JobState.CANCELLED,
+                        JobState.NOT_FOUND,
+                    ),
+                    remove_status=(JobState.COMPLETED,),
+                    update=False,
+                )
+            return
+
+        trainer_env = dict(
+            **get_env_vars(
+                config.cluster.cluster_name,
+                config.launcher.trainer_env_vars,
+            ),
+            AREAL_RECOVER_RUN=str(int(is_recover_run)),
+        )
+        if llm_addrs:
+            trainer_env["AREAL_LLM_SERVER_ADDRS"] = ",".join(llm_addrs)
+
+        if allocation_mode.type_ == AllocationType.DECOUPLED_EVAL:
+            trainer_nodes = 1
+            gpus_per_node = 0
+        else:
+            trainer_nodes = max(
+                config.cluster.n_nodes
+                - (allocation_mode.gen.world_size // config.cluster.n_gpus_per_node),
+                1,
+            )
+            gpus_per_node = config.cluster.n_gpus_per_node
+
+        trainer_job_name = f"{launcher.cluster_name}-trainer"
+        trainer_task = _build_trainer_task(
+            launcher=launcher,
+            job_name=trainer_job_name,
+            trainer_entry=trainer_entry,
+            trainer_args=trainer_args,
+            allocation=allocation_mode,
+            n_nodes=trainer_nodes,
+            gpus_per_node=gpus_per_node,
+            env_vars=trainer_env,
+            is_eval_only=allocation_mode.type_ == AllocationType.DECOUPLED_EVAL,
+        )
+        launcher.submit(trainer_job_name, trainer_task)
+
+        launcher.wait(
+            job_names=[trainer_job_name],
+            check_status=(
+                JobState.FAILED,
+                JobState.CANCELLED,
+                JobState.NOT_FOUND,
+            ),
+            remove_status=(JobState.COMPLETED,),
+            update=True,
+        )
+    except (KeyboardInterrupt, TimeoutError, JobException) as exc:
+        logger.error("SkyPilot launcher encountered an error: %s", exc)
+        launcher.stop_all()
+        recoverable_states = {JobState.FAILED}
+        if (
+            isinstance(exc, JobException)
+            and exc.reason in recoverable_states
+            and run_id < config.recover.retries
+            and config.recover.mode in ("auto", "fault")
+        ):
+            time.sleep(10)
+            skypilot_main(config, run_id=run_id + 1)
+        else:
+            raise
+    finally:
+        if llm_job_name is not None:
+            try:
+                launcher.stop(llm_job_name)
+            except Exception as cancel_exc:  # pragma: no cover
+                logger.warning("Failed to cancel LLM server job: %s", cancel_exc)
+
+
+def main():
+    config, _ = parse_cli_args(sys.argv[1:])
+    skypilot_main(config, run_id=0)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/cli_reference.md b/docs/cli_reference.md
index a0c60f956..bdaf64d34 100644
--- a/docs/cli_reference.md
+++ b/docs/cli_reference.md
@@ -72,6 +72,7 @@ For detailed examples, see the experiment configurations in the `examples/` dire
 ### Others
 
 - [Scheduler Configuration](section-scheduler)
+- [SkyPilotLauncher Configuration](section-sky-pilot-launcher)
 
 ______________________________________________________________________
 
@@ -597,15 +598,16 @@
 Configuration for launching the LLM server and trainer processes.
 
-| Parameter | Type | Default | Description |
-| ------------------------------- | ----------------------------------------------- | ------------ | ------------------------------------------------------------------------------------------------ |
-| `inference_server_cpus_per_gpu` | integer | `4` | Number of CPUs allocated per GPU for inference server. |
-| `inference_server_mem_per_gpu` | integer | `32768` | Memory allocated per GPU for inference server in MB. |
-| `trainer_cpus_per_gpu` | integer | `4` | Number of CPUs allocated per GPU for training. |
-| `trainer_mem_per_gpu` | integer | `32768` | Memory allocated per GPU for training in MB. |
-| `inference_server_env_vars` | string | `""` | Environment variables for inference server, separated by commas. Example: 'ENV1=val1,ENV2=val2'. |
-| `trainer_env_vars` | string | `""` | Environment variables for training, separated by commas. Example: 'ENV1=val1,ENV2=val2'. |
-| `slurm` | [`SlurmLauncherConfig`](section-slurm-launcher) | **Required** | Slurm launcher configuration. |
+| Parameter | Type | Default | Description |
+| ------------------------------- | ------------------------------------------------------ | ------------ | ------------------------------------------------------------------------------------------------ |
+| `inference_server_cpus_per_gpu` | integer | `4` | Number of CPUs allocated per GPU for inference server. |
+| `inference_server_mem_per_gpu` | integer | `32768` | Memory allocated per GPU for inference server in MB. |
+| `trainer_cpus_per_gpu` | integer | `4` | Number of CPUs allocated per GPU for training. |
+| `trainer_mem_per_gpu` | integer | `32768` | Memory allocated per GPU for training in MB. |
+| `inference_server_env_vars` | string | `""` | Environment variables for inference server, separated by commas. Example: 'ENV1=val1,ENV2=val2'. |
+| `trainer_env_vars` | string | `""` | Environment variables for training, separated by commas. Example: 'ENV1=val1,ENV2=val2'. |
+| `slurm` | [`SlurmLauncherConfig`](section-slurm-launcher) | **Required** | Slurm launcher configuration. |
+| `skypilot` | [`SkyPilotLauncherConfig`](section-sky-pilot-launcher) | **Required** | SkyPilot launcher configuration. |
 
 (section-name-resolve)=
 
@@ -756,3 +758,32 @@
 | `reward_functioncall_config` | `Dict` | **Required** | - |
 | `reward_model_path` | string | `""` | - |
 | `reward_model_service_url` | string | `"http://localhost:30000/classify"` | - |
+
+(section-sky-pilot-launcher)=
+
+## SkyPilotLauncher Configuration
+
+Configuration for launching the training jobs with SkyPilot.
+
+| Parameter | Type | Default | Description |
+| ------------------ | -------------- | ------------ | ---------------------------------------------------------------------------------------------------------------- |
+| `name` | string \| None | `None` | Optional task name displayed in SkyPilot. |
+| `workdir` | string \| None | `None` | Local path or git repo spec to sync as working directory (mirrors SkyPilot YAML workdir). |
+| `infra` | string \| None | `None` | Infrastructure spec `cloud`, `cloud/region`, `cloud/region/zone`, or `k8s[/context]` (resources.infra). |
+| `accelerator_type` | string \| None | `None` | Accelerator request, e.g. 'H100', 'A100'. Number of GPUs on the node is determined by `cluster.n_gpus_per_node`. |
+| `accelerator_args` | string \| None | `None` | Additional accelerator args (YAML/JSON string) (resources.accelerator_args). |
+| `use_spot` | boolean | `False` | Whether to use spot/preemptible instances (resources.use_spot). |
+| `disk_size` | string \| None | `None` | Boot disk size with optional unit, e.g. '256', '256GB' (resources.disk_size). |
+| `disk_tier` | string | `"medium"` | Disk performance tier (resources.disk_tier). **Choices:** `low`, `medium`, `high`, `ultra`, `best` |
+| `network_tier` | string | `"standard"` | Network tier (resources.network_tier). **Choices:** `standard`, `best` |
+| `ports` | string \| None | `None` | Ports to expose, supports single '8080', range '10052-10100', or comma list (resources.ports). |
+| `image_id` | string \| None | `None` | Custom base or docker image id (resources.image_id). |
+| `labels` | string \| None | `None` | Instance/pod labels as key=value pairs joined by commas (resources.labels). |
+| `any_of` | string \| None | `None` | YAML/JSON string list of candidate resource dicts (resources.any_of). |
+| `ordered` | string \| None | `None` | YAML/JSON string list of ordered resource dicts (resources.ordered). |
+| `job_recovery` | string \| None | `None` | Job recovery strategy spec (resources.job_recovery). Provide JSON/YAML. |
+| `autostop` | string \| None | `None` | Autostop configuration: true/false/10/10h or JSON object (resources.autostop). |
+| `envs` | string \| None | `None` | Environment variables (envs) as KEY=VAL pairs joined by commas. |
+| `secrets` | string \| None | `None` | Secrets usable in setup/run as KEY=VAL pairs joined by commas (will be redacted). |
+| `volumes` | string \| None | `None` | Kubernetes volume mappings (volumes) as YAML/JSON string or shorthand mount spec. |
+| `file_mounts` | string \| None | `None` | File mounts mapping remote_path:local_path lines or JSON/YAML string (file_mounts). |
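+
+All of these options map onto string fields, so they can be set in the
+experiment YAML under `launcher.skypilot` or overridden on the command line
+with the standard `key=value` syntax. A minimal sketch (the entrypoint, config
+file, and values are illustrative only):
+
+```bash
+python3 -m areal.launcher.skypilot my_entrypoint.py --config my_config.yaml \
+  launcher.skypilot.accelerator_type=H100 \
+  launcher.skypilot.use_spot=true \
+  'launcher.skypilot.envs=HF_TOKEN=xxx,WANDB_MODE=offline'
+```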
diff --git a/docs/tutorial/installation.md b/docs/tutorial/installation.md
index 01ece031f..f93d69d96 100644
--- a/docs/tutorial/installation.md
+++ b/docs/tutorial/installation.md
@@ -77,6 +77,54 @@ bash examples/env/setup-pip-deps.sh
 > The SGLang patch is applied via `examples/env/setup-container-deps.sh` or
 > `examples/env/setup-pip-deps.sh`. To confirm whether it has been applied, run
 > `git status` in the `/sglang` directory (for Docker) or `AReaL/sglang` (for
 > custom setups). -->
+
+## (Optional) Install SkyPilot
+
+SkyPilot makes it easy to run AReaL on cloud or Kubernetes infrastructure. The
+following are the minimal steps to set up SkyPilot on GCP or Kubernetes.
+
+### Install SkyPilot
+
+```bash
+# In your conda environment
+# NOTE: SkyPilot requires 3.7 <= python <= 3.13
+pip install -U "skypilot[gcp,kubernetes]"
+```
+
+### GCP setup
+
+```bash
+# Install Google Cloud SDK
+conda install -y -c conda-forge google-cloud-sdk
+
+# Initialize gcloud and select your account/project
+gcloud init
+
+# (Optional) choose a project explicitly
+gcloud config set project <PROJECT_ID>
+
+# Create Application Default Credentials
+gcloud auth application-default login
+```
+
+### Kubernetes setup
+
+Check
+[here](https://docs.skypilot.co/en/latest/reference/kubernetes/kubernetes-setup.html)
+for a comprehensive guide on how to set up a Kubernetes cluster for SkyPilot.
+
+### Verify
+
+```bash
+sky check
+```
+
+If `GCP: enabled` or `Kubernetes: enabled` is shown, you're ready to use
+SkyPilot with AReaL.
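+
+As an optional end-to-end smoke test, you can launch a one-off task and tear it
+down afterwards. This sketch assumes a recent SkyPilot release that accepts an
+inline command as the entrypoint; it provisions a small VM, which may incur
+cost:
+
+```bash
+sky launch -c smoke-test 'echo hello from SkyPilot'
+sky down -y smoke-test
+```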
+
+Check
+[here](https://github.com/inclusionAI/AReaL/blob/main/examples/skypilot/README.md)
+for a detailed example of running AReaL with SkyPilot. For more options and
+details, see the official
+[SkyPilot installation guide](https://docs.skypilot.co/en/latest/getting-started/installation.html).
+
 ## (Optional) Launch Ray Cluster for Distributed Training
 
 On the first node, start the Ray Head:
diff --git a/examples/skypilot/README.md b/examples/skypilot/README.md
new file mode 100644
index 000000000..4e9bab38a
--- /dev/null
+++ b/examples/skypilot/README.md
@@ -0,0 +1,149 @@
+# Running AReaL with SkyPilot
+
+This README provides examples and guidelines for running AReaL experiments with
+SkyPilot. Make sure you have SkyPilot properly installed following
+[our installation guide](../../docs/tutorial/installation.md#optional-install-skypilot)
+before running this example. Note that all commands shown in this file are
+assumed to be executed from the root of the AReaL repository.
+
+## Running a Single-Node Experiment
+
+To run a single-node experiment, you only need to set up the node with SkyPilot
+and launch the experiment with the AReaL local launcher.
+[The following file](local.yaml) is a SkyPilot YAML that launches a simple
+GSM8K GRPO experiment with a single command. This example runs on GCP, but can
+easily be migrated to another cloud or a Kubernetes cluster by changing the
+`resources.infra` field in the SkyPilot YAML file.
+
+```yaml
+name: areal-test-skypilot
+
+resources:
+  infra: gcp
+  accelerators: A100:2
+  autostop:
+    idle_minutes: 10
+    down: true
+  cpus: 8+
+  memory: 32GB+
+  disk_size: 256GB
+  image_id: docker:ghcr.io/inclusionai/areal-runtime:v0.3.4
+
+num_nodes: 1
+
+workdir: .
+
+run: |
+  python3 -m areal.launcher.local examples/math/gsm8k_grpo.py \
+    --config examples/math/gsm8k_grpo.yaml \
+    experiment_name=gsm8k-grpo \
+    trial_name=trial0 \
+    cluster.n_gpus_per_node=2 \
+    allocation_mode=sglang.d1+d1 \
+    train_dataset.batch_size=4 \
+    actor.mb_spec.max_tokens_per_mb=4096
+```
+
+To run the experiment, execute:
+
+```bash
+sky launch -c areal-test examples/skypilot/local.yaml
+```
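+
+While the job runs (or after it finishes), the standard SkyPilot commands are
+handy for managing the experiment; `areal-test` is the cluster name passed via
+`-c` above:
+
+```bash
+# Re-attach to the job's output at any time.
+sky logs areal-test
+
+# Check the cluster's status.
+sky status
+
+# Tear the cluster down when you are done.
+sky down areal-test
+```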
+
+## Running a Multi-Node Experiment
+
+### Running AReaL with Ray Launcher
+
+The following example shows how to set up a Ray cluster with SkyPilot and then
+use AReaL to run GRPO on the GSM8K dataset on 2 nodes, each with 1 A100 GPU.
+This example runs on GCP, but can easily be migrated to another cloud or a
+Kubernetes cluster by changing the `resources.infra` field in the SkyPilot YAML
+file.
+
+Specify the resources and image used to run the experiment.
+
+```yaml
+resources:
+  infra: gcp
+  accelerators: A100:1
+  image_id: docker:ghcr.io/inclusionai/areal-runtime:v0.3.4
+  memory: 256+
+  cpus: 32+
+
+num_nodes: 2
+
+workdir: .
+```
+
+Designate shared storage. You can either use an existing cloud bucket or
+volume:
+
+```yaml
+file_mounts:
+  /storage: gs://areal-default
+```
+
+or create a new bucket or volume with SkyPilot:
+
+```yaml
+file_mounts:
+  /storage:
+    name: areal-test
+    store: gcs
+```
+
+For more information about shared storage with SkyPilot, check
+[SkyPilot Cloud Buckets](https://docs.skypilot.co/en/latest/reference/storage.html)
+and [SkyPilot Volume](https://docs.skypilot.co/en/latest/reference/volumes.html).
+
+Next, prepare the commands that set up the Ray cluster and run the experiment.
+
+```yaml
+run: |
+  # Get the Head node's IP and total number of nodes (environment variables injected by SkyPilot).
+  head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+  num_nodes=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
+
+  if [ "$SKYPILOT_NODE_RANK" = "0" ]; then
+    echo "Starting Ray head node..."
+    ray start --head --port=6379
+
+    while [ $(ray status | grep node_ | wc -l) -lt $num_nodes ]; do
+      echo "Waiting for all nodes to join... Current nodes: $(ray status | grep node_ | wc -l) / $num_nodes"
+      sleep 5
+    done
+
+    echo "Executing training script on head node..."
+    python3 -m areal.launcher.ray examples/math/gsm8k_grpo.py \
+      --config examples/skypilot/gsm8k_grpo_ray.yaml \
+      experiment_name=gsm8k-grpo \
+      trial_name=trial0
+  else
+    sleep 10
+    echo "Starting Ray worker node..."
+    ray start --address $head_ip:6379
+    sleep 5
+  fi
+
+  echo "Node setup complete for rank $SKYPILOT_NODE_RANK."
+```
+
+### Launch the Ray Cluster and AReaL
+
+Then you are ready to run AReaL with the following command:
+
+```bash
+sky launch -c areal-test examples/skypilot/ray_cluster.yaml
+```
+
+You should see AReaL running and producing training logs in your terminal.
+
+Successfully launched 2 nodes on GCP and deployed a Ray cluster:
+![Launching Ray Cluster](ray_launch.png)
+
+Successfully ran a training step:
+![Running a train step](train_step_success.png)
+
+### Running AReaL with SkyPilot Launcher
+
+AReaL plans to support a native SkyPilot launcher built on the
+[SkyPilot Python SDK](https://docs.skypilot.co/en/latest/reference/api.html),
+which is currently under development.
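+
+Once it lands, usage is expected to mirror the other distributed launchers; the
+following is a sketch only, since the flags may change while the launcher is
+under development:
+
+```bash
+python3 -m areal.launcher.skypilot examples/math/gsm8k_grpo.py \
+  --config examples/math/gsm8k_grpo.yaml \
+  experiment_name=gsm8k-grpo \
+  trial_name=trial0 \
+  launcher.skypilot.accelerator_type=A100
+```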
diff --git a/examples/skypilot/gsm8k_grpo_ray.yaml b/examples/skypilot/gsm8k_grpo_ray.yaml
new file mode 100644
index 000000000..ed973c14f
--- /dev/null
+++ b/examples/skypilot/gsm8k_grpo_ray.yaml
@@ -0,0 +1,153 @@
+experiment_name: gsm8k-grpo-on-ray
+trial_name: trial0
+
+seed: 1
+total_train_epochs: 10
+tokenizer_path: ${actor.path}
+async_training: true
+
+cluster:
+  n_nodes: 2
+  n_gpus_per_node: 1
+  fileroot: /storage/experiments
+  name_resolve:
+    type: ray
+    ray_actor_name: ray_kv_store
+
+allocation_mode: sglang.d1+d1
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 256
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_new_tokens: 1024
+  greedy: false
+  temperature: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: Qwen/Qwen2.5-1.5B-Instruct
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: false
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 4096
+  optimizer:
+    type: adam
+    lr: 1.70e-5
+    weight_decay: 0.017
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  backend: fsdp
+  group_size: ${gconfig.n_samples}
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behav_imp_weight_cap: 5.0
+  dynamic_sampling: false
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  max_new_tokens: ${gconfig.max_new_tokens}
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer: null
+  backend: fsdp
+
+# SGLang
+sglang:
+  model_path: ${actor.path}
+  random_seed: ${seed}
+  skip_tokenizer_init: true
+  dtype: ${actor.dtype}
+  max_running_requests: null
+  context_length: 32768
+  mem_fraction_static: 0.8
+
+# datasets
+train_dataset:
+  batch_size: 4
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 4
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+launcher:
+  inference_server_cpus_per_gpu: 4
+  inference_server_mem_per_gpu: 32768
+  trainer_cpus_per_gpu: 4
+  trainer_mem_per_gpu: 32768
diff --git a/examples/skypilot/local.yaml b/examples/skypilot/local.yaml
new file mode 100644
index 000000000..94b6ad3b7
--- /dev/null
+++ b/examples/skypilot/local.yaml
@@ -0,0 +1,29 @@
+name: areal-test-skypilot
+
+resources:
+  infra: gcp
+  accelerators: A100:2
+  autostop:
+    idle_minutes: 10
+    down: true
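+    # With `down: true`, SkyPilot tears the cluster down (not just stops it)
+    # after 10 idle minutes, so an unused A100 node does not keep accruing cost.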
+  cpus: 8+
+  memory: 32GB+
+  disk_size: 256GB
+  image_id: docker:ghcr.io/inclusionai/areal-runtime:v0.3.4
+
+num_nodes: 1
+
+file_mounts:
+  /storage: gs://areal-default
+
+workdir: .
+
+run: |
+  python3 -m areal.launcher.local examples/math/gsm8k_grpo.py \
+    --config examples/math/gsm8k_grpo.yaml \
+    experiment_name=gsm8k-grpo \
+    trial_name=trial0 \
+    cluster.n_gpus_per_node=2 \
+    allocation_mode=sglang.d1+d1 \
+    train_dataset.batch_size=4 \
+    actor.mb_spec.max_tokens_per_mb=4096
diff --git a/examples/skypilot/ray_cluster.yaml b/examples/skypilot/ray_cluster.yaml
new file mode 100644
index 000000000..d1c72cb11
--- /dev/null
+++ b/examples/skypilot/ray_cluster.yaml
@@ -0,0 +1,42 @@
+
+resources:
+  infra: gcp
+  accelerators: A100:1
+  image_id: docker:ghcr.io/inclusionai/areal-runtime:v0.3.4
+  memory: 32+
+  cpus: 8+
+
+num_nodes: 2
+
+workdir: .
+
+file_mounts:
+  /storage: gs://areal-default
+
+run: |
+  # Get the Head node's IP and total number of nodes (environment variables injected by SkyPilot).
+  head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
+  num_nodes=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
+
+  if [ "$SKYPILOT_NODE_RANK" = "0" ]; then
+    echo "Starting Ray head node..."
+    ray start --head --port=6379
+
+    while [ $(ray status | grep node_ | wc -l) -lt $num_nodes ]; do
+      echo "Waiting for all nodes to join... Current nodes: $(ray status | grep node_ | wc -l) / $num_nodes"
+      sleep 5
+    done
+
+    echo "Executing training script on head node..."
+    python3 -m areal.launcher.ray examples/math/gsm8k_grpo.py \
+      --config examples/skypilot/gsm8k_grpo_ray.yaml \
+      experiment_name=gsm8k-grpo \
+      trial_name=trial0
+  else
+    sleep 10
+    echo "Starting Ray worker node..."
+    ray start --address $head_ip:6379
+    sleep 5
+  fi
+
+  echo "Node setup complete for rank $SKYPILOT_NODE_RANK."
diff --git a/examples/skypilot/ray_launch.png b/examples/skypilot/ray_launch.png
new file mode 100644
index 000000000..207251b28
Binary files /dev/null and b/examples/skypilot/ray_launch.png differ
diff --git a/examples/skypilot/train_step_success.png b/examples/skypilot/train_step_success.png
new file mode 100644
index 000000000..4acb5f6ba
Binary files /dev/null and b/examples/skypilot/train_step_success.png differ