Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/put_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from transfer_queue import TransferQueueClient
from transfer_queue.controller import TransferQueueController
from transfer_queue.storage.simple_storage import SimpleStorageUnit
from transfer_queue.storage.backends.simple_storage import SimpleStorageUnit
Comment thread
fy2462 marked this conversation as resolved.
Outdated
from transfer_queue.utils.common import get_placement_group
from transfer_queue.utils.zmq_utils import process_zmq_server_info

Expand Down
2 changes: 1 addition & 1 deletion tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,7 @@ class TestStorageUnitDataStrict:

def test_put_data_length_mismatch_raises(self):
"""put_data must raise when global_indexes and field values have different lengths."""
from transfer_queue.storage.simple_storage import StorageUnitData
from transfer_queue.storage.backends.simple_storage import StorageUnitData

sud = StorageUnitData(storage_size=10)
# 3 indexes but only 2 values — must raise, not silently drop
Expand Down
4 changes: 2 additions & 2 deletions tests/test_simple_storage_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch
import zmq

from transfer_queue.storage.simple_storage import SimpleStorageUnit
from transfer_queue.storage.backends.simple_storage import SimpleStorageUnit
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType


Expand Down Expand Up @@ -420,7 +420,7 @@ def test_storage_unit_data_direct():

def test_storage_unit_data_capacity_uses_active_keys():
"""Capacity check must use _active_keys, not scan field_data."""
from transfer_queue.storage.simple_storage import StorageUnitData
from transfer_queue.storage.backends.simple_storage import StorageUnitData

storage = StorageUnitData(storage_size=3)

Expand Down
134 changes: 12 additions & 122 deletions transfer_queue/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import subprocess
import time
from importlib import resources
from typing import Any, Callable
from urllib.parse import urlparse

import ray
import torch
Expand All @@ -32,13 +30,9 @@
from transfer_queue.metadata import KVBatchMeta
from transfer_queue.sampler import * # noqa: F401
from transfer_queue.sampler import BaseSampler
from transfer_queue.storage.simple_storage import SimpleStorageUnit
from transfer_queue.utils.common import get_placement_group
from transfer_queue.storage.backends.base import StorageBackendFactory
from transfer_queue.utils.logging_utils import get_logger
from transfer_queue.utils.yuanrong_utils import (
cleanup_yuanrong_resources,
initialize_yuanrong_backend,
)
from transfer_queue.utils.yuanrong_utils import cleanup_yuanrong_resources
from transfer_queue.utils.zmq_utils import process_zmq_server_info

logger = get_logger(__name__)
Expand Down Expand Up @@ -70,125 +64,21 @@ def _maybe_create_tq_client(conf: DictConfig | None = None) -> TransferQueueClie
return _TQ_CLIENT


# TODO(hz): Adopt registry pattern to manage storage backends for better scalability.
def _maybe_create_tq_storage(conf: DictConfig) -> DictConfig:
global _TQ_STORAGE

if _TQ_STORAGE is None:
_TQ_STORAGE = {}
if conf.backend.storage_backend == "SimpleStorage":
# initialize SimpleStorageUnit
simple_storage_handles = {}
num_data_storage_units = conf.backend.SimpleStorage.num_data_storage_units
total_storage_size = conf.backend.SimpleStorage.total_storage_size
storage_placement_group = get_placement_group(num_data_storage_units, num_cpus_per_actor=1)

for storage_unit_rank in range(num_data_storage_units):
storage_node = SimpleStorageUnit.options( # type: ignore[attr-defined]
placement_group=storage_placement_group,
placement_group_bundle_index=storage_unit_rank,
name=f"TransferQueueStorageUnit#{storage_unit_rank}",
).remote(
storage_unit_size=math.ceil(total_storage_size / num_data_storage_units),
)
simple_storage_handles[f"TransferQueueStorageUnit#{storage_unit_rank}"] = storage_node
logger.info(f"TransferQueueStorageUnit#{storage_unit_rank} has been created.")

storage_zmq_info = process_zmq_server_info(simple_storage_handles)
backend_name = conf.backend.storage_backend
conf.backend[backend_name].zmq_info = storage_zmq_info
_TQ_STORAGE["SimpleStorage"] = simple_storage_handles
if conf.backend.storage_backend == "MooncakeStore":
if conf.backend.MooncakeStore.auto_init:
# Try to kill existing mooncake_master processes before starting a new one to avoid potential conflicts
check = subprocess.run(["pgrep", "-f", "mooncake_master"], stdout=subprocess.PIPE, text=True)
if check.returncode == 0:
pids = check.stdout.strip().replace("\n", ", ")
logger.info(f"Find existing mooncake_master (PID: {pids}), try to kill first...")

result = os.system('pkill -f "[m]ooncake_master"')
if result == 0:
logger.info("Successfully killed existing mooncake_master processes.")
else:
raise RuntimeError(f"Failed to kill existing mooncake_master processes (exit code: {result}).")

# process metadata_server
metadata_server_raw_address = conf.backend.MooncakeStore.metadata_server
if "://" not in metadata_server_raw_address:
metadata_server_raw_address = "//" + metadata_server_raw_address

metadata_server_parsed = urlparse(metadata_server_raw_address)

if not metadata_server_parsed.hostname or metadata_server_parsed.port is None:
raise ValueError(
f"Invalid metadata_server '{conf.backend.MooncakeStore.metadata_server}'. "
f"Host and port are required (e.g., host:port)."
)

metadata_server_host = metadata_server_parsed.hostname
metadata_server_port = str(metadata_server_parsed.port)

# process master_server
master_server_raw_address = conf.backend.MooncakeStore.master_server_address
if "://" not in master_server_raw_address:
master_server_raw_address = "//" + master_server_raw_address

master_server_parsed = urlparse(master_server_raw_address)

if not master_server_parsed.hostname or master_server_parsed.port is None:
raise ValueError(
f"Invalid master_server_address '{conf.backend.MooncakeStore.master_server_address}'. "
f"Host and port are required (e.g., host:port)."
)

master_server_port = str(master_server_parsed.port)

cmd = [
"mooncake_master",
"-client_ttl=30",
"-default_kv_lease_ttl=999999",
"-default_kv_soft_pin_ttl=999999",
"--eviction_high_watermark_ratio=1.0",
"--eviction_ratio=0.0",
"--enable_http_metadata_server=true",
"--allow_evict_soft_pinned_objects=false",
f"--http_metadata_server_host={metadata_server_host}",
f"--http_metadata_server_port={metadata_server_port}",
f"--rpc_port={master_server_port}",
]

log_file_path = "/tmp/mooncake_master.log"
with open(log_file_path, "w") as log_file:
process = subprocess.Popen(
cmd,
stdout=log_file,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
start_new_session=True,
)
time.sleep(3)

if process.poll() is None:
logger.info(
f"mooncake_master started, PID: {process.pid}. Logs are at: {os.path.abspath(log_file_path)}"
)
else:
error_msg = ""
try:
with open(log_file_path) as f:
error_msg = f.read()
except Exception as e:
error_msg = f"Failed to read log file: {e}"

raise RuntimeError(
f"mooncake_master exited with error. Check {log_file_path} for detailed logs. "
f"Output:\n{error_msg}"
)
_TQ_STORAGE["MooncakeStore"] = process
if conf.backend.storage_backend == "Yuanrong" and conf.backend.Yuanrong.auto_init:
_TQ_STORAGE["Yuanrong"] = initialize_yuanrong_backend(conf)
backend_name = conf.backend.storage_backend
registered_backend_fn = StorageBackendFactory.get_backend(backend_name)
if registered_backend_fn:
Comment thread
fy2462 marked this conversation as resolved.
Outdated
backend_instance = registered_backend_fn(conf)
if backend_instance:
_TQ_STORAGE[backend_name] = backend_instance
Comment thread
fy2462 marked this conversation as resolved.
Outdated
else:
logger.error(f"Not found available {backend_name} storage backend instance, please check the config.")
else:
logger.error(f"Storage backend {backend_name} not registered. Please add it to the StorageBackendFactory.")
return conf


Expand Down
2 changes: 1 addition & 1 deletion transfer_queue/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .backends import SimpleStorageUnit, StorageUnitData
from .managers import (
AsyncSimpleStorageManager,
MooncakeStorageManager,
Expand All @@ -21,7 +22,6 @@
StorageManagerFactory,
YuanrongStorageManager,
)
from .simple_storage import SimpleStorageUnit, StorageUnitData

__all__ = [
"SimpleStorageUnit",
Expand Down
24 changes: 24 additions & 0 deletions transfer_queue/storage/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2026 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import mooncake_storage, simple_storage, yuanrong_storage # noqa: F401, I001
from .base import StorageBackendFactory
from .simple_storage import SimpleStorageUnit, StorageUnitData

__all__ = [
"StorageBackendFactory",
"SimpleStorageUnit",
"StorageUnitData",
]
40 changes: 40 additions & 0 deletions transfer_queue/storage/backends/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2026 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from typing import Callable


class StorageBackendFactory:
Comment thread
fy2462 marked this conversation as resolved.
Outdated
_backends: dict[str, Callable] = {}

@classmethod
def register_backend(cls, name: str):
"""Decorator to register storage backend & returns function."""

def decorator(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)

cls._backends[name.lower()] = wrapper
return wrapper

return decorator

@classmethod
def get_backend(cls, name: str) -> Callable | None:
"""Get storage backend function by name."""
return cls._backends.get(name.lower(), None)
Loading
Loading