diff --git a/transfer_queue/interface.py b/transfer_queue/interface.py index ba74a35c..933b42c3 100644 --- a/transfer_queue/interface.py +++ b/transfer_queue/interface.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math import os import subprocess import time from importlib import resources from typing import Any, Callable -from urllib.parse import urlparse import ray import torch @@ -32,13 +30,9 @@ from transfer_queue.metadata import KVBatchMeta from transfer_queue.sampler import * # noqa: F401 from transfer_queue.sampler import BaseSampler -from transfer_queue.storage.simple_storage import SimpleStorageUnit -from transfer_queue.utils.common import get_placement_group +from transfer_queue.storage.bootstrap import StorageBootstrapProvider from transfer_queue.utils.logging_utils import get_logger -from transfer_queue.utils.yuanrong_utils import ( - cleanup_yuanrong_resources, - initialize_yuanrong_backend, -) +from transfer_queue.utils.yuanrong_utils import cleanup_yuanrong_resources from transfer_queue.utils.zmq_utils import process_zmq_server_info logger = get_logger(__name__) @@ -70,125 +64,23 @@ def _maybe_create_tq_client(conf: DictConfig | None = None) -> TransferQueueClie return _TQ_CLIENT -# TODO(hz): Adopt registry pattern to manage storage backends for better scalability. def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig: global _TQ_STORAGE if _TQ_STORAGE is None: _TQ_STORAGE = {} - if conf.backend.storage_backend == "SimpleStorage": - # initialize SimpleStorageUnit - simple_storage_handles = {} - num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units - total_storage_size = conf.backend.SimpleStorage.total_storage_size - storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) - - for storage_unit_rank in range(num_data_storage_units): - storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] - placement_group=storage_placement_group, - placement_group_bundle_index=storage_unit_rank, - name=f"TransferQueueStorageUnit#{storage_unit_rank}", - ).remote( - storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), - ) - simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node - logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") - - storage_zmq_info = process_zmq_server_info(simple_storage_handles) - backend_name = conf.backend.storage_backend - conf.backend[backend_name].zmq_info = storage_zmq_info - _TQ_STORAGE["SimpleStorage"] = simple_storage_handles - if conf.backend.storage_backend == "MooncakeStore": - if conf.backend.MooncakeStore.auto_init: - # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts - check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) - if check.returncode == 0: - pids = check.stdout.strip().replace("\n", ", ") - logger.info(f"Find existing mooncake_master (PID: {pids}), try to kill first...") - - result = os.system('pkill -f "[m]ooncake_master"') - if result == 0: - logger.info("Successfully killed existing mooncake_master processes.") - else: - raise RuntimeError(f"Failed to kill existing mooncake_master processes (exit code: {result}).") - - # process metadata_server - metadata_server_raw_address = conf.backend.MooncakeStore.metadata_server - if "://" not in metadata_server_raw_address: - metadata_server_raw_address = "//" + metadata_server_raw_address - - metadata_server_parsed = urlparse(metadata_server_raw_address) - - if not metadata_server_parsed.hostname or metadata_server_parsed.port is None: - raise ValueError( - f"Invalid metadata_server '{conf.backend.MooncakeStore.metadata_server}'. " - f"Host and port are required (e.g., host:port)." - ) - - metadata_server_host = metadata_server_parsed.hostname - metadata_server_port = str(metadata_server_parsed.port) - - # process master_server - master_server_raw_address = conf.backend.MooncakeStore.master_server_address - if "://" not in master_server_raw_address: - master_server_raw_address = "//" + master_server_raw_address - - master_server_parsed = urlparse(master_server_raw_address) - - if not master_server_parsed.hostname or master_server_parsed.port is None: - raise ValueError( - f"Invalid master_server_address '{conf.backend.MooncakeStore.master_server_address}'. " - f"Host and port are required (e.g., host:port)." - ) - - master_server_port = str(master_server_parsed.port) - - cmd = [ - "mooncake_master", - "-client_ttl=30", - "-default_kv_lease_ttl=999999", - "-default_kv_soft_pin_ttl=999999", - "--eviction_high_watermark_ratio=1.0", - "--eviction_ratio=0.0", - "--enable_http_metadata_server=true", - "--allow_evict_soft_pinned_objects=false", - f"--http_metadata_server_host={metadata_server_host}", - f"--http_metadata_server_port={metadata_server_port}", - f"--rpc_port={master_server_port}", - ] - - log_file_path = "/tmp/mooncake_master.log" - with open(log_file_path, "w") as log_file: - process = subprocess.Popen( - cmd, - stdout=log_file, - stderr=subprocess.STDOUT, - text=True, - bufsize=1, - universal_newlines=True, - start_new_session=True, - ) - time.sleep(3) - - if process.poll() is None: - logger.info( - f"mooncake_master started, PID: {process.pid}. Logs are at: {os.path.abspath(log_file_path)}" - ) - else: - error_msg = "" - try: - with open(log_file_path) as f: - error_msg = f.read() - except Exception as e: - error_msg = f"Failed to read log file: {e}" - - raise RuntimeError( - f"mooncake_master exited with error. Check {log_file_path} for detailed logs. " - f"Output:\n{error_msg}" - ) - _TQ_STORAGE["MooncakeStore"] = process - if conf.backend.storage_backend == "Yuanrong" and conf.backend.Yuanrong.auto_init: - _TQ_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf) + backend_name = conf.backend.storage_backend + provider_fn = StorageBootstrapProvider.get_provider(backend_name) + if provider_fn is not None: + backend_resources = provider_fn(conf) + if backend_resources is not None: + _TQ_STORAGE[backend_name] = backend_resources + else: + logger.error(f"Not found available {backend_name} storage resources, please check the config.") + else: + logger.error( + f"Storage backend {backend_name} not registered. Please add it to the StorageBootstrapProvider." + ) return conf diff --git a/transfer_queue/storage/bootstrap/__init__.py b/transfer_queue/storage/bootstrap/__init__.py new file mode 100644 index 00000000..e8ce25f1 --- /dev/null +++ b/transfer_queue/storage/bootstrap/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import mooncake_bootstrap, simple_storage_bootstrap, yuanrong_bootstrap # noqa: F401, I001 +from .provider import StorageBootstrapProvider + +__all__ = [ + "StorageBootstrapProvider", +] diff --git a/transfer_queue/storage/bootstrap/mooncake_bootstrap.py b/transfer_queue/storage/bootstrap/mooncake_bootstrap.py new file mode 100644 index 00000000..536599ca --- /dev/null +++ b/transfer_queue/storage/bootstrap/mooncake_bootstrap.py @@ -0,0 +1,127 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import time +from urllib.parse import urlparse + +from omegaconf import DictConfig + +from transfer_queue.storage.bootstrap.provider import StorageBootstrapProvider +from transfer_queue.utils.logging_utils import get_logger + +logger = get_logger(__name__) + + +@StorageBootstrapProvider.register_provider("MooncakeStore") +def initialize_mooncake_storage(conf: DictConfig) -> subprocess.Popen | None: + """ + Initialize Mooncake store backend. + Args: + conf (DictConfig): Configuration dictionary for the Mooncake store backend. + Returns: + subprocess.Popen | None: Process object for the Mooncake store backend process. + Raises: + ValueError: If the Mooncake store is not initialized successfully. + """ + if not conf.backend.MooncakeStore.auto_init: + return None + + # Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts + check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True) + if check.returncode == 0: + pids = check.stdout.strip().replace("\n", ", ") + logger.info(f"Find existing mooncake_master (PID: {pids}), try to kill first...") + + result = os.system('pkill -f "[m]ooncake_master"') + if result == 0: + logger.info("Successfully killed existing mooncake_master processes.") + else: + raise RuntimeError(f"Failed to kill existing mooncake_master processes (exit code: {result}).") + + # process metadata_server + metadata_server_raw_address = conf.backend.MooncakeStore.metadata_server + if "://" not in metadata_server_raw_address: + metadata_server_raw_address = "//" + metadata_server_raw_address + + metadata_server_parsed = urlparse(metadata_server_raw_address) + + if not metadata_server_parsed.hostname or metadata_server_parsed.port is None: + raise ValueError( + f"Invalid metadata_server '{conf.backend.MooncakeStore.metadata_server}'. " + f"Host and port are required (e.g., host:port)." + ) + + metadata_server_host = metadata_server_parsed.hostname + metadata_server_port = str(metadata_server_parsed.port) + + # process master_server + master_server_raw_address = conf.backend.MooncakeStore.master_server_address + if "://" not in master_server_raw_address: + master_server_raw_address = "//" + master_server_raw_address + + master_server_parsed = urlparse(master_server_raw_address) + + if not master_server_parsed.hostname or master_server_parsed.port is None: + raise ValueError( + f"Invalid master_server_address '{conf.backend.MooncakeStore.master_server_address}'. " + f"Host and port are required (e.g., host:port)." + ) + + master_server_port = str(master_server_parsed.port) + + cmd = [ + "mooncake_master", + "-client_ttl=30", + "-default_kv_lease_ttl=999999", + "-default_kv_soft_pin_ttl=999999", + "--eviction_high_watermark_ratio=1.0", + "--eviction_ratio=0.0", + "--enable_http_metadata_server=true", + "--allow_evict_soft_pinned_objects=false", + f"--http_metadata_server_host={metadata_server_host}", + f"--http_metadata_server_port={metadata_server_port}", + f"--rpc_port={master_server_port}", + ] + + log_file_path = "/tmp/mooncake_master.log" + with open(log_file_path, "w") as log_file: + process = subprocess.Popen( + cmd, + stdout=log_file, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + universal_newlines=True, + start_new_session=True, + ) + time.sleep(3) + + if process.poll() is None: + logger.info(f"mooncake_master started, PID: {process.pid}. Logs are at: {os.path.abspath(log_file_path)}") + else: + error_msg = "" + try: + with open(log_file_path) as f: + error_msg = f.read() + except Exception as e: + error_msg = f"Failed to read log file: {e}" + + raise RuntimeError( + f"mooncake_master exited with error. Check {log_file_path} for detailed logs. Output:\n{error_msg}" + ) + + return process diff --git a/transfer_queue/storage/bootstrap/provider.py b/transfer_queue/storage/bootstrap/provider.py new file mode 100644 index 00000000..504900f8 --- /dev/null +++ b/transfer_queue/storage/bootstrap/provider.py @@ -0,0 +1,42 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps +from typing import Callable + + +class StorageBootstrapProvider: + """Registry for storage backend bootstrap functions.""" + + _providers: dict[str, Callable] = {} + + @classmethod + def register_provider(cls, name: str): + """Decorator to register storage provider & returns function.""" + + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + cls._providers[name.lower()] = wrapper + return wrapper + + return decorator + + @classmethod + def get_provider(cls, name: str) -> Callable | None: + """Get storage provider function by name.""" + return cls._providers.get(name.lower(), None) diff --git a/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py b/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py new file mode 100644 index 00000000..1ab2f6b6 --- /dev/null +++ b/transfer_queue/storage/bootstrap/simple_storage_bootstrap.py @@ -0,0 +1,54 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any + +from omegaconf import DictConfig + +from transfer_queue.storage.bootstrap.provider import StorageBootstrapProvider +from transfer_queue.storage.simple_storage import SimpleStorageUnit +from transfer_queue.utils.common import get_placement_group +from transfer_queue.utils.logging_utils import get_logger +from transfer_queue.utils.zmq_utils import process_zmq_server_info + +logger = get_logger(__name__) + + +@StorageBootstrapProvider.register_provider("SimpleStorage") +def initialize_simple_storage(conf: DictConfig) -> dict[str, Any]: + """Initialize Simple storage with metastore mode.""" + + simple_storage_handles = {} + num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units + total_storage_size = conf.backend.SimpleStorage.total_storage_size + storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1) + + for storage_unit_rank in range(num_data_storage_units): + storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined] + placement_group=storage_placement_group, + placement_group_bundle_index=storage_unit_rank, + name=f"TransferQueueStorageUnit#{storage_unit_rank}", + ).remote( + storage_unit_size=math.ceil(total_storage_size / num_data_storage_units), + ) + simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node + logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.") + + storage_zmq_info = process_zmq_server_info(simple_storage_handles) + backend_name = conf.backend.storage_backend + conf.backend[backend_name].zmq_info = storage_zmq_info + + return simple_storage_handles diff --git a/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py b/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py new file mode 100644 index 00000000..e114f939 --- /dev/null +++ b/transfer_queue/storage/bootstrap/yuanrong_bootstrap.py @@ -0,0 +1,427 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import shutil +import subprocess +from typing import Any + +import ray +from omegaconf import DictConfig + +from transfer_queue.storage.bootstrap.provider import StorageBootstrapProvider +from transfer_queue.utils.yuanrong_utils import get_local_ip_addresses, kill_actors_and_placement_group + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) + + +def _parse_remote_h2d_device_ids(worker_args: str) -> str | None: + """Parse --remote_h2d_device_ids parameter from worker_args string. + + Args: + worker_args: Worker arguments string, e.g., "--arg1 value1 --remote_h2d_device_ids 0,1,2,3" + + Returns: + The device IDs string if found and valid, None otherwise. + + Raises: + RuntimeError: If --remote_h2d_device_ids flag is found but has invalid format. + """ + if not worker_args: + return None + + args_list = worker_args.split() + + # Find the index of --remote_h2d_device_ids + try: + idx = args_list.index("--remote_h2d_device_ids") + except ValueError: + return None + + # Check if there's a value after the flag + if idx + 1 >= len(args_list): + raise RuntimeError("--remote_h2d_device_ids flag found but no value provided") + + device_ids = args_list[idx + 1] + + # Validate the format: comma-separated digits + if not device_ids: + raise RuntimeError("Empty device IDs value after --remote_h2d_device_ids") + + # Validate each segment is a digit + parts = device_ids.split(",") + for part in parts: + if not part.isdigit(): + raise RuntimeError( + f"Invalid device ID format: '{device_ids}'. Expected comma-separated digits (e.g., '0,1,2,3')." + ) + + return device_ids + + +def start_datasystem_worker( + worker_address: str, + metastore_address: str, + is_head: bool, + worker_args: str = "", +) -> None: + """Start Yuanrong datasystem worker in metastore mode. + + Args: + worker_address: Worker address in format host:port + metastore_address: Metastore address in format host:port + is_head: Whether this node should start metastore service + worker_args: Additional arguments to append to dscli start command + + Raises: + RuntimeError: If dscli command fails + """ + if not shutil.which("dscli"): + raise RuntimeError("dscli executable not found in PATH. Please run `pip install openyuanrong-datasystem`.") + + cmd = ["dscli", "start", "-w", "--worker_address", worker_address] + cmd.extend(["--metastore_address", metastore_address]) + if is_head: + cmd.extend(["--start_metastore_service", "true"]) + + # Built-in default options + cmd.extend(["--arena_per_tenant", "1", "--enable_worker_worker_batch_get", "true"]) + + # Append worker_args if provided + if worker_args: + cmd.extend(worker_args.split()) + + node_type = "head node" if is_head else "worker node" + logger.info(f"Starting Yuanrong datasystem ({node_type}) at {worker_address}, worker_args={worker_args}") + + # Build environment with ASCEND_RT_VISIBLE_DEVICES if specified + env = None + device_ids = _parse_remote_h2d_device_ids(worker_args) + if device_ids: + env = os.environ.copy() + env["ASCEND_RT_VISIBLE_DEVICES"] = device_ids + logger.info( + f"Setting ASCEND_RT_VISIBLE_DEVICES={device_ids} for dscli subprocess ({node_type} at {worker_address})" + ) + + try: + ds_result = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=90, + env=env, + ) + except subprocess.TimeoutExpired as err: + raise RuntimeError(f"dscli start timed out: {err}") from err + + if ds_result.returncode == 0 and "[ OK ]" in ds_result.stdout: + logger.info( + f"dscli started Yuanrong datasystem ({node_type}, metastore mode) at {worker_address} successfully." + ) + else: + raise RuntimeError( + f"Failed to start datasystem ({node_type}, metastore mode) at {worker_address}. " + f"Return code: {ds_result.returncode}, Output: {ds_result.stdout}" + ) + + +def stop_datasystem_worker(worker_address: str) -> None: + """Stop Yuanrong datasystem worker. + + Args: + worker_address: Worker address in format host:port + """ + if worker_address: + try: + result = subprocess.run( + ["dscli", "stop", "--worker_address", worker_address], + timeout=90, + capture_output=True, + ) + if result.returncode == 0: + logger.info(f"Stopped datasystem worker at {worker_address} via dscli stop") + else: + error_msg = (result.stderr or result.stdout or b"").decode() + logger.warning( + f"Failed to stop datasystem worker at {worker_address}. " + f"Return code: {result.returncode}, Error: {error_msg}" + ) + except subprocess.TimeoutExpired as err: + logger.warning(f"dscli stop timed out for {worker_address}: {err}") + except Exception as e: + logger.warning(f"Failed to stop datasystem worker via dscli: {e}") + + +@ray.remote(num_cpus=0.1) +class YuanrongWorkerActor: + """Ray actor to manage Yuanrong datasystem worker on a node. + + This actor runs on each node in the Ray cluster and is responsible for + starting and stopping the Yuanrong datasystem worker process on that node. + + The actor determines its own rank and role (head or worker) by finding the + intersection of local IP addresses with the provided node IPs. + """ + + def __init__(self, node_ips: list[str], worker_port: int, metastore_port: int, worker_args: str = ""): + """Initialize the Yuanrong worker actor. + + Args: + node_ips: List of all node IPs in the Ray cluster + worker_port: Port for the datasystem worker + metastore_port: Port for the metastore service (on head node) + worker_args: Additional arguments to append to dscli start command + + Raises: + RuntimeError: If cannot determine this node's IP from node_ips + """ + local_ips = get_local_ip_addresses() + self.my_ip = None + + # Find the intersection between local IPs and node_ips + for ip in node_ips: + if ip in local_ips: + self.my_ip = ip + break + + if self.my_ip is None: + raise RuntimeError(f"Cannot determine local node IP. Local IPs: {local_ips}, Cluster node IPs: {node_ips}") + + self.node_ips = node_ips + self.worker_port = worker_port + self.metastore_port = metastore_port + self.worker_address = f"{self.my_ip}:{worker_port}" + self.worker_args = worker_args + + # First node in the list is assumed to be the head node. + # This assumption is based on how interface.py constructs node_ips from ray.nodes(). + self.head_node_ip = node_ips[0] + self.metastore_address = f"{self.head_node_ip}:{metastore_port}" + self.is_head = self.my_ip == self.head_node_ip + + logger.info( + f"YuanrongWorkerActor initialized on node {self.my_ip}: " + f"worker_address={self.worker_address}, " + f"metastore_address={self.metastore_address}, is_head={self.is_head}, worker_args={self.worker_args}" + ) + + def start(self) -> str: + """Start the datasystem worker on this node. + + Returns: + The worker address. + + Raises: + RuntimeError: If dscli command fails + """ + logger.info(f"Starting datasystem worker at {self.worker_address}...") + start_datasystem_worker( + self.worker_address, + metastore_address=self.metastore_address, + is_head=self.is_head, + worker_args=self.worker_args, + ) + logger.info(f"Datasystem worker started successfully at {self.worker_address}") + return self.worker_address + + def get_metastore_address(self) -> str: + """Get the metastore address. + + Returns: + The metastore address in format host:port + """ + return self.metastore_address + + def get_node_ip(self) -> str: + """Return the IP address of the node this actor is running on.""" + assert self.my_ip is not None + return self.my_ip + + def stop(self) -> None: + """Stop the datasystem worker on this node.""" + logger.info(f"Stopping datasystem worker at {self.worker_address}...") + stop_datasystem_worker(self.worker_address) + logger.info(f"Datasystem worker stopped successfully at {self.worker_address}") + + +@StorageBootstrapProvider.register_provider("Yuanrong") +def initialize_yuanrong_storage(conf: DictConfig) -> dict[str, Any] | None: + """Initialize Yuanrong storage with metastore mode. + + This function sets up the Yuanrong storage datasystem workers across all Ray nodes + using placement groups and actors. + + Args: + conf: Configuration containing Yuanrong storage settings + + Returns: + Dict containing worker_actors, metastore_address, and placement_group + + Raises: + RuntimeError: If Ray nodes not found or initialization fails + """ + if not conf.backend.Yuanrong.auto_init: + return None + + # Get Ray cluster information + nodes = ray.nodes() + if not nodes: + raise RuntimeError("No Ray nodes found. Is Ray initialized?") + + # Filter to only alive nodes and get their IPs + alive_nodes = [node for node in nodes if node.get("Alive", False)] + if not alive_nodes: + raise RuntimeError("No alive Ray nodes found") + + # Get driver node IP to use as head node + driver_ip = ray.util.get_node_ip_address() + head_node = None + other_nodes = [] + + # Separate head node (driver) from other nodes + for node in alive_nodes: + node_ip = node["NodeManagerAddress"] + if node_ip == driver_ip: + head_node = node + else: + other_nodes.append(node) + + if head_node is None: + raise RuntimeError(f"Driver node {driver_ip} not found in alive nodes") + + # Reorder nodes: head node first, then others + ordered_nodes = [head_node] + other_nodes + + # Extract node IPs in deterministic order + node_ips = [node["NodeManagerAddress"] for node in ordered_nodes] + worker_port = conf.backend.Yuanrong.worker_port + metastore_port = conf.backend.Yuanrong.metastore_port + worker_args = conf.backend.Yuanrong.get("worker_args", "") + + logger.info(f"Found {len(ordered_nodes)} alive Ray nodes: {node_ips}") + + # Create placement group using STRICT_SPREAD to ensure each bundle is on a distinct node + bundles = [{"CPU": 0.1} for _ in ordered_nodes] + + pg = ray.util.placement_group(bundles, strategy="STRICT_SPREAD") + try: + ray.get(pg.ready(), timeout=60) + except ray.exceptions.GetTimeoutError as e: + try: + ray.util.remove_placement_group(pg) + except Exception as cleanup_error: + logger.warning(f"Failed to remove placement group after readiness timeout: {cleanup_error}") + raise RuntimeError( + "Timed out waiting for Yuanrong placement group to become ready. " + f"Requested strategy=STRICT_SPREAD, bundles={bundles}. " + "This may be due to insufficient cluster capacity." + ) from e + except Exception as e: + try: + ray.util.remove_placement_group(pg) + except Exception as cleanup_error: + logger.warning(f"Failed to remove placement group after scheduling failure: {cleanup_error}") + raise RuntimeError( + f"Failed to create Yuanrong placement group. Requested strategy=STRICT_SPREAD, bundles={bundles}." + ) from e + + logger.info(f"Created placement group with {len(bundles)} bundles using STRICT_SPREAD") + + try: + # Create all worker actors using placement group + # Without node resources, actor scheduling order is not guaranteed to match node order + # We'll identify head node actor by checking which node it runs on + worker_actors = [] + for rank in range(len(ordered_nodes)): + actor = YuanrongWorkerActor.options( # type: ignore[attr-defined] + placement_group=pg, + placement_group_bundle_index=rank, + ).remote(node_ips, worker_port, metastore_port, worker_args) + worker_actors.append(actor) + + logger.info(f"Created {len(worker_actors)} YuanrongWorkerActor instances") + + # Find which actor is running on the head node (driver IP) + # The head node actor needs to start first to initialize metastore service + head_actor_index = None + for idx, actor in enumerate(worker_actors): + try: + node_ip = ray.get(actor.get_node_ip.remote()) + if node_ip == driver_ip: + head_actor_index = idx + break + except Exception: + pass + + if head_actor_index is None: + logger.warning("Could not identify head node actor, using actor 0 as default") + head_actor_index = 0 + + logger.info(f"Head node actor identified: actor {head_actor_index}") + + # Start head worker first to initialize metastore service + logger.info("Starting head worker to initialize metastore...") + ray.get(worker_actors[head_actor_index].start.remote()) + metastore_address = ray.get(worker_actors[head_actor_index].get_metastore_address.remote()) + logger.info(f"Head worker started, metastore address: {metastore_address}") + + # Start remaining worker actors in parallel + other_actors = [worker_actors[i] for i in range(len(worker_actors)) if i != head_actor_index] + if other_actors: + logger.info(f"Starting {len(other_actors)} worker actors in parallel...") + ray.get([actor.start.remote() for actor in other_actors]) + + logger.info( + f"Yuanrong backend started successfully: metastore at {metastore_address}, workers on {len(node_ips)} nodes" + ) + + return { + "worker_actors": worker_actors, + "metastore_address": metastore_address, + "placement_group": pg, + } + except Exception as e: + # Cleanup on initialization failure: attempt graceful stop of started workers first + logger.error(f"Failed to start Yuanrong workers: {e}, cleaning up...") + + # Try to gracefully stop workers that may have already started + if worker_actors: + stop_exceptions = [] + # Stop worker nodes (all except head node 0) first + if len(worker_actors) > 1: + stop_refs = [actor.stop.remote() for actor in worker_actors[1:]] + for idx, stop_ref in enumerate(stop_refs, start=1): + try: + ray.get(stop_ref, timeout=30) + except Exception as stop_e: + stop_exceptions.append(stop_e) + logger.warning(f"Failed to stop worker node actor {idx}: {stop_e}") + # Stop head node (actor 0) + try: + ray.get(worker_actors[0].stop.remote(), timeout=30) + except Exception as stop_e: + stop_exceptions.append(stop_e) + logger.warning(f"Failed to stop head node actor: {stop_e}") + + if stop_exceptions: + logger.warning(f"Encountered {len(stop_exceptions)} errors during graceful worker stop") + + # Then kill actors and remove placement group + kill_actors_and_placement_group(worker_actors, pg) + raise diff --git a/transfer_queue/utils/yuanrong_utils.py b/transfer_queue/utils/yuanrong_utils.py index 17d5ca82..5f8ddf6e 100644 --- a/transfer_queue/utils/yuanrong_utils.py +++ b/transfer_queue/utils/yuanrong_utils.py @@ -13,16 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - import logging import os -import shutil import socket -import subprocess from typing import Any import ray -from omegaconf import DictConfig logger = logging.getLogger(__name__) logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING)) @@ -126,238 +122,7 @@ def find_reachable_host(port: int, timeout: float = 1.0) -> str | None: return None -def _parse_remote_h2d_device_ids(worker_args: str) -> str | None: - """Parse --remote_h2d_device_ids parameter from worker_args string. - - Args: - worker_args: Worker arguments string, e.g., "--arg1 value1 --remote_h2d_device_ids 0,1,2,3" - - Returns: - The device IDs string if found and valid, None otherwise. - - Raises: - RuntimeError: If --remote_h2d_device_ids flag is found but has invalid format. - """ - if not worker_args: - return None - - args_list = worker_args.split() - - # Find the index of --remote_h2d_device_ids - try: - idx = args_list.index("--remote_h2d_device_ids") - except ValueError: - return None - - # Check if there's a value after the flag - if idx + 1 >= len(args_list): - raise RuntimeError("--remote_h2d_device_ids flag found but no value provided") - - device_ids = args_list[idx + 1] - - # Validate the format: comma-separated digits - if not device_ids: - raise RuntimeError("Empty device IDs value after --remote_h2d_device_ids") - - # Validate each segment is a digit - parts = device_ids.split(",") - for part in parts: - if not part.isdigit(): - raise RuntimeError( - f"Invalid device ID format: '{device_ids}'. Expected comma-separated digits (e.g., '0,1,2,3')." - ) - - return device_ids - - -def start_datasystem_worker( - worker_address: str, - metastore_address: str, - is_head: bool, - worker_args: str = "", -) -> None: - """Start Yuanrong datasystem worker in metastore mode. - - Args: - worker_address: Worker address in format host:port - metastore_address: Metastore address in format host:port - is_head: Whether this node should start metastore service - worker_args: Additional arguments to append to dscli start command - - Raises: - RuntimeError: If dscli command fails - """ - if not shutil.which("dscli"): - raise RuntimeError("dscli executable not found in PATH. Please run `pip install openyuanrong-datasystem`.") - - cmd = ["dscli", "start", "-w", "--worker_address", worker_address] - cmd.extend(["--metastore_address", metastore_address]) - if is_head: - cmd.extend(["--start_metastore_service", "true"]) - - # Built-in default options - cmd.extend(["--arena_per_tenant", "1", "--enable_worker_worker_batch_get", "true"]) - - # Append worker_args if provided - if worker_args: - cmd.extend(worker_args.split()) - - node_type = "head node" if is_head else "worker node" - logger.info(f"Starting Yuanrong datasystem ({node_type}) at {worker_address}, worker_args={worker_args}") - - # Build environment with ASCEND_RT_VISIBLE_DEVICES if specified - env = None - device_ids = _parse_remote_h2d_device_ids(worker_args) - if device_ids: - env = os.environ.copy() - env["ASCEND_RT_VISIBLE_DEVICES"] = device_ids - logger.info( - f"Setting ASCEND_RT_VISIBLE_DEVICES={device_ids} for dscli subprocess ({node_type} at {worker_address})" - ) - - try: - ds_result = subprocess.run( - cmd, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - timeout=90, - env=env, - ) - except subprocess.TimeoutExpired as err: - raise RuntimeError(f"dscli start timed out: {err}") from err - - if ds_result.returncode == 0 and "[ OK ]" in ds_result.stdout: - logger.info( - f"dscli started Yuanrong datasystem ({node_type}, metastore mode) at {worker_address} successfully." - ) - else: - raise RuntimeError( - f"Failed to start datasystem ({node_type}, metastore mode) at {worker_address}. " - f"Return code: {ds_result.returncode}, Output: {ds_result.stdout}" - ) - - -def stop_datasystem_worker(worker_address: str) -> None: - """Stop Yuanrong datasystem worker. - - Args: - worker_address: Worker address in format host:port - """ - if worker_address: - try: - result = subprocess.run( - ["dscli", "stop", "--worker_address", worker_address], - timeout=90, - capture_output=True, - ) - if result.returncode == 0: - logger.info(f"Stopped datasystem worker at {worker_address} via dscli stop") - else: - error_msg = (result.stderr or result.stdout or b"").decode() - logger.warning( - f"Failed to stop datasystem worker at {worker_address}. " - f"Return code: {result.returncode}, Error: {error_msg}" - ) - except subprocess.TimeoutExpired as err: - logger.warning(f"dscli stop timed out for {worker_address}: {err}") - except Exception as e: - logger.warning(f"Failed to stop datasystem worker via dscli: {e}") - - -@ray.remote(num_cpus=0.1) -class YuanrongWorkerActor: - """Ray actor to manage Yuanrong datasystem worker on a node. - - This actor runs on each node in the Ray cluster and is responsible for - starting and stopping the Yuanrong datasystem worker process on that node. - - The actor determines its own rank and role (head or worker) by finding the - intersection of local IP addresses with the provided node IPs. - """ - - def __init__(self, node_ips: list[str], worker_port: int, metastore_port: int, worker_args: str = ""): - """Initialize the Yuanrong worker actor. - - Args: - node_ips: List of all node IPs in the Ray cluster - worker_port: Port for the datasystem worker - metastore_port: Port for the metastore service (on head node) - worker_args: Additional arguments to append to dscli start command - - Raises: - RuntimeError: If cannot determine this node's IP from node_ips - """ - local_ips = get_local_ip_addresses() - self.my_ip = None - - # Find the intersection between local IPs and node_ips - for ip in node_ips: - if ip in local_ips: - self.my_ip = ip - break - - if self.my_ip is None: - raise RuntimeError(f"Cannot determine local node IP. Local IPs: {local_ips}, Cluster node IPs: {node_ips}") - - self.node_ips = node_ips - self.worker_port = worker_port - self.metastore_port = metastore_port - self.worker_address = f"{self.my_ip}:{worker_port}" - self.worker_args = worker_args - - # First node in the list is assumed to be the head node. - # This assumption is based on how interface.py constructs node_ips from ray.nodes(). - self.head_node_ip = node_ips[0] - self.metastore_address = f"{self.head_node_ip}:{metastore_port}" - self.is_head = self.my_ip == self.head_node_ip - - logger.info( - f"YuanrongWorkerActor initialized on node {self.my_ip}: " - f"worker_address={self.worker_address}, " - f"metastore_address={self.metastore_address}, is_head={self.is_head}, worker_args={self.worker_args}" - ) - - def start(self) -> str: - """Start the datasystem worker on this node. - - Returns: - The worker address. - - Raises: - RuntimeError: If dscli command fails - """ - logger.info(f"Starting datasystem worker at {self.worker_address}...") - start_datasystem_worker( - self.worker_address, - metastore_address=self.metastore_address, - is_head=self.is_head, - worker_args=self.worker_args, - ) - logger.info(f"Datasystem worker started successfully at {self.worker_address}") - return self.worker_address - - def get_metastore_address(self) -> str: - """Get the metastore address. - - Returns: - The metastore address in format host:port - """ - return self.metastore_address - - def get_node_ip(self) -> str: - """Return the IP address of the node this actor is running on.""" - assert self.my_ip is not None - return self.my_ip - - def stop(self) -> None: - """Stop the datasystem worker on this node.""" - logger.info(f"Stopping datasystem worker at {self.worker_address}...") - stop_datasystem_worker(self.worker_address) - logger.info(f"Datasystem worker stopped successfully at {self.worker_address}") - - -def _kill_actors_and_placement_group(worker_actors: list, placement_group: Any) -> None: +def kill_actors_and_placement_group(worker_actors: list, placement_group: Any) -> None: """Kill actors and remove placement group without stopping workers. Args: @@ -420,169 +185,6 @@ def cleanup_yuanrong_resources(storage_value: Any) -> None: logger.warning(f"Encountered {len(stop_exceptions)} errors while stopping workers") finally: # Kill actors and remove placement group even if graceful stop fails. - _kill_actors_and_placement_group(worker_actors, placement_group) + kill_actors_and_placement_group(worker_actors, placement_group) if placement_group: logger.info("Removed Yuanrong placement group") - - -def initialize_yuanrong_backend(conf: DictConfig) -> dict[str, Any]: - """Initialize Yuanrong backend with metastore mode. - - This function sets up the Yuanrong datasystem workers across all Ray nodes - using placement groups and actors. - - Args: - conf: Configuration containing Yuanrong settings - - Returns: - Dict containing worker_actors, metastore_address, and placement_group - - Raises: - RuntimeError: If Ray nodes not found or initialization fails - """ - # Get Ray cluster information - nodes = ray.nodes() - if not nodes: - raise RuntimeError("No Ray nodes found. Is Ray initialized?") - - # Filter to only alive nodes and get their IPs - alive_nodes = [node for node in nodes if node.get("Alive", False)] - if not alive_nodes: - raise RuntimeError("No alive Ray nodes found") - - # Get driver node IP to use as head node - driver_ip = ray.util.get_node_ip_address() - head_node = None - other_nodes = [] - - # Separate head node (driver) from other nodes - for node in alive_nodes: - node_ip = node["NodeManagerAddress"] - if node_ip == driver_ip: - head_node = node - else: - other_nodes.append(node) - - if head_node is None: - raise RuntimeError(f"Driver node {driver_ip} not found in alive nodes") - - # Reorder nodes: head node first, then others - ordered_nodes = [head_node] + other_nodes - - # Extract node IPs in deterministic order - node_ips = [node["NodeManagerAddress"] for node in ordered_nodes] - worker_port = conf.backend.Yuanrong.worker_port - metastore_port = conf.backend.Yuanrong.metastore_port - worker_args = conf.backend.Yuanrong.get("worker_args", "") - - logger.info(f"Found {len(ordered_nodes)} alive Ray nodes: {node_ips}") - - # Create placement group using STRICT_SPREAD to ensure each bundle is on a distinct node - bundles = [{"CPU": 0.1} for _ in ordered_nodes] - - pg = ray.util.placement_group(bundles, strategy="STRICT_SPREAD") - try: - ray.get(pg.ready(), timeout=60) - except ray.exceptions.GetTimeoutError as e: - try: - ray.util.remove_placement_group(pg) - except Exception as cleanup_error: - logger.warning(f"Failed to remove placement group after readiness timeout: {cleanup_error}") - raise RuntimeError( - "Timed out waiting for Yuanrong placement group to become ready. " - f"Requested strategy=STRICT_SPREAD, bundles={bundles}. " - "This may be due to insufficient cluster capacity." - ) from e - except Exception as e: - try: - ray.util.remove_placement_group(pg) - except Exception as cleanup_error: - logger.warning(f"Failed to remove placement group after scheduling failure: {cleanup_error}") - raise RuntimeError( - f"Failed to create Yuanrong placement group. Requested strategy=STRICT_SPREAD, bundles={bundles}." - ) from e - - logger.info(f"Created placement group with {len(bundles)} bundles using STRICT_SPREAD") - - try: - # Create all worker actors using placement group - # Without node resources, actor scheduling order is not guaranteed to match node order - # We'll identify head node actor by checking which node it runs on - worker_actors = [] - for rank in range(len(ordered_nodes)): - actor = YuanrongWorkerActor.options( # type: ignore[attr-defined] - placement_group=pg, - placement_group_bundle_index=rank, - ).remote(node_ips, worker_port, metastore_port, worker_args) - worker_actors.append(actor) - - logger.info(f"Created {len(worker_actors)} YuanrongWorkerActor instances") - - # Find which actor is running on the head node (driver IP) - # The head node actor needs to start first to initialize metastore service - head_actor_index = None - for idx, actor in enumerate(worker_actors): - try: - node_ip = ray.get(actor.get_node_ip.remote()) - if node_ip == driver_ip: - head_actor_index = idx - break - except Exception: - pass - - if head_actor_index is None: - logger.warning("Could not identify head node actor, using actor 0 as default") - head_actor_index = 0 - - logger.info(f"Head node actor identified: actor {head_actor_index}") - - # Start head worker first to initialize metastore service - logger.info("Starting head worker to initialize metastore...") - ray.get(worker_actors[head_actor_index].start.remote()) - metastore_address = ray.get(worker_actors[head_actor_index].get_metastore_address.remote()) - logger.info(f"Head worker started, metastore address: {metastore_address}") - - # Start remaining worker actors in parallel - other_actors = [worker_actors[i] for i in range(len(worker_actors)) if i != head_actor_index] - if other_actors: - logger.info(f"Starting {len(other_actors)} worker actors in parallel...") - ray.get([actor.start.remote() for actor in other_actors]) - - logger.info( - f"Yuanrong backend started successfully: metastore at {metastore_address}, workers on {len(node_ips)} nodes" - ) - - return { - "worker_actors": worker_actors, - "metastore_address": metastore_address, - "placement_group": pg, - } - except Exception as e: - # Cleanup on initialization failure: attempt graceful stop of started workers first - logger.error(f"Failed to start Yuanrong workers: {e}, cleaning up...") - - # Try to gracefully stop workers that may have already started - if worker_actors: - stop_exceptions = [] - # Stop worker nodes (all except head node 0) first - if len(worker_actors) > 1: - stop_refs = [actor.stop.remote() for actor in worker_actors[1:]] - for idx, stop_ref in enumerate(stop_refs, start=1): - try: - ray.get(stop_ref, timeout=30) - except Exception as stop_e: - stop_exceptions.append(stop_e) - logger.warning(f"Failed to stop worker node actor {idx}: {stop_e}") - # Stop head node (actor 0) - try: - ray.get(worker_actors[0].stop.remote(), timeout=30) - except Exception as stop_e: - stop_exceptions.append(stop_e) - logger.warning(f"Failed to stop head node actor: {stop_e}") - - if stop_exceptions: - logger.warning(f"Encountered {len(stop_exceptions)} errors during graceful worker stop") - - # Then kill actors and remove placement group - _kill_actors_and_placement_group(worker_actors, pg) - raise