Skip to content

Commit 5766b16

Browse files
committed
feat(experimental): integrate Ray RDT for weight syncing
Signed-off-by: Haichuan Hu <kaisennhu@gmail.com>
1 parent 19fec3b commit 5766b16

17 files changed

Lines changed: 2865 additions & 32 deletions

File tree

areal/experimental/inference_service/sglang/launch_server.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,21 @@ def areal_launch_server(server_args) -> None:
2424
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
2525

2626
# ---- BEGIN AREAL ----
27-
from areal.experimental.inference_service.sglang.awex import (
28-
register_awex_endpoints,
29-
)
27+
from areal.experimental.inference_service.sglang.awex import register_awex_endpoints
28+
from areal.experimental.inference_service.sglang.rdt import register_rdt_endpoints
3029
from areal.experimental.inference_service.sglang.rpc_proxy import RpcProxy
3130
from areal.experimental.inference_service.sglang.scheduler import (
3231
areal_run_scheduler_process,
3332
create_result_ipc,
33+
get_weight_update_backend,
3434
)
3535
# ---- END AREAL ----
3636

3737
# ---- BEGIN AREAL ----
38-
result_ipc = create_result_ipc()
38+
backend = getattr(server_args, "weight_update_backend", None)
39+
if backend is None:
40+
backend = get_weight_update_backend()
41+
result_ipc = create_result_ipc(backend)
3942
# ---- END AREAL ----
4043

4144
(
@@ -60,7 +63,10 @@ def areal_launch_server(server_args) -> None:
6063

6164
# ---- BEGIN AREAL ----
6265
rpc_proxy = RpcProxy(port_args, result_ipc)
63-
register_awex_endpoints(app, rpc_proxy)
66+
if backend == "awex":
67+
register_awex_endpoints(app, rpc_proxy)
68+
elif backend == "rdt":
69+
register_rdt_endpoints(app, rpc_proxy)
6470
# ---- END AREAL ----
6571

6672
try:
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""RDT HTTP endpoints for IW weight update.
3+
4+
Reference: areal.experimental.inference_service.sglang.awex
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from typing import TYPE_CHECKING
10+
11+
from fastapi import FastAPI, Request
12+
from fastapi.responses import JSONResponse
13+
14+
from areal.utils import logging
15+
16+
if TYPE_CHECKING:
17+
from areal.experimental.inference_service.sglang.rpc_proxy import RpcProxy
18+
19+
logger = logging.getLogger("RDTIWEndpoints")
20+
21+
22+
def register_rdt_endpoints(app: FastAPI, rpc_proxy: RpcProxy) -> None:
    """Register ``/rdt/*`` weight-update endpoints on IW's FastAPI app.

    Each endpoint dispatches to all scheduler processes via RpcProxy,
    using collective_rpc_with_result or collective_rpc.

    Every handler answers HTTP 500 with ``{"error": <message>}`` when the
    dispatch (or the parsing of the request body) raises; the failure is
    logged together with its traceback.

    Args:
        app: FastAPI application
        rpc_proxy: RpcProxy for scheduler subprocess communication
    """

    def _failure(action: str, exc: Exception) -> JSONResponse:
        """Log *exc* and build the HTTP 500 reply shared by all handlers.

        Must be called from inside an ``except`` block so that
        ``logger.exception`` can attach the active traceback.
        """
        logger.exception("Failed to %s: %s", action, exc)
        # Some exceptions stringify to "", so fall back to the class name
        # to keep the error field informative.
        message = str(exc) or type(exc).__name__
        return JSONResponse(status_code=500, content={"error": message})

    async def _json_object(request: Request) -> dict:
        """Return the request body parsed as a JSON object.

        Raises:
            TypeError: the body is valid JSON but not an object, so it
                cannot be used as keyword arguments.
        """
        data = await request.json()
        if not isinstance(data, dict):
            raise TypeError(
                "Expected a JSON object in the request body, "
                f"got {type(data).__name__}"
            )
        return data

    @app.get("/rdt/report_parallelism")
    async def report_parallelism() -> JSONResponse:
        """Report IW parallelism strategy for TransferPlan building."""
        try:
            result = rpc_proxy.collective_rpc_with_result("rdt_report_parallelism")
            if not isinstance(result, dict):
                err_msg = f"Expected dict from rdt_report_parallelism, got {type(result).__name__}"
                logger.error(err_msg)
                return JSONResponse(status_code=500, content={"error": err_msg})
            return JSONResponse(content=result)
        except Exception as e:
            return _failure("report parallelism", e)

    @app.post("/rdt/report_weight_meta")
    async def report_weight_meta() -> JSONResponse:
        """Report IW weight metadata for TransferPlan building."""
        try:
            result = rpc_proxy.collective_rpc_with_result("rdt_report_weight_meta")
            return JSONResponse(content={"status": "ok", "meta": result})
        except Exception as e:
            return _failure("report weight meta", e)

    @app.post("/rdt/init_weight_update_group")
    async def init_weight_update_group(request: Request) -> JSONResponse:
        """Initialize RDT weight update group.

        Args passed via JSON body:
            pair_name: TW-IW pair identifier
            kv_store_url: Gateway KV store URL
            tw_actor_bytes_b64_list: Base64-encoded TW actor handle bytes
            infer_world_size: Total IW world size
            train_world_size: Total TW world size
            num_engines: Number of IW engines
            transfer_rank: IW's transfer rank
        """
        try:
            data = await _json_object(request)
            # The body's keys are forwarded verbatim as keyword arguments.
            rpc_proxy.collective_rpc("rdt_init_weight_update_group", **data)
            return JSONResponse(content={"status": "ok"})
        except Exception as e:
            return _failure("init RDT weight update group", e)

    @app.post("/rdt/execute_weight_update")
    async def execute_weight_update(request: Request) -> JSONResponse:
        """Execute RDT weight update - pull from TW via Ray RPC.

        Args passed via JSON body:
            version: Weight version number (optional, default 0)
        """
        try:
            data = await _json_object(request)
            version = data.get("version", 0)
            rpc_proxy.collective_rpc("rdt_execute_weight_update", version=version)
            return JSONResponse(content={"status": "ok", "version": version})
        except Exception as e:
            return _failure("execute RDT weight update", e)

    # ---------------------------------------------------------------------------
    # Debug endpoints for E2E testing
    # ---------------------------------------------------------------------------

    @app.post("/rdt/debug/randomize_parameters")
    async def randomize_parameters() -> JSONResponse:
        """Randomize model parameters for testing."""
        try:
            rpc_proxy.collective_rpc("rdt_randomize_parameters")
            return JSONResponse(content={"status": "ok"})
        except Exception as e:
            return _failure("randomize parameters", e)

    @app.post("/rdt/debug/get_parameters")
    async def get_parameters(request: Request) -> JSONResponse:
        """Save parameters to disk for validation."""
        try:
            data = await _json_object(request)
            # The body's keys are forwarded verbatim as keyword arguments.
            rpc_proxy.collective_rpc("rdt_get_parameters", **data)
            return JSONResponse(content={"status": "ok"})
        except Exception as e:
            return _failure("get parameters", e)

areal/experimental/inference_service/sglang/scheduler.py

Lines changed: 122 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
"""AwexSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler."""
2+
"""AwexSchedulerBridge/RDTSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler."""
33

44
from __future__ import annotations
55

@@ -12,9 +12,16 @@
1212
import zmq
1313
from sglang.srt.server_args import PortArgs, ServerArgs
1414

15+
from areal.experimental.weight_update import (
16+
BACKEND_AWEX,
17+
BACKEND_RDT,
18+
WEIGHT_UPDATE_BACKEND_ENV,
19+
get_weight_update_backend,
20+
)
1521
from areal.infra.rpc.serialization import serialize_value
1622

1723
RESULT_IPC_ENV = "AREAL_AWEX_RESULT_IPC"
24+
RDT_RESULT_IPC_ENV = "AREAL_RDT_RESULT_IPC"
1825

1926

2027
class AwexSchedulerBridge:
@@ -118,6 +125,96 @@ def awex_randomize_parameters(self) -> None:
118125
self._require_adapter().randomize_parameters()
119126

120127

128+
class RDTSchedulerBridge:
129+
"""Compose RDT weight-update capabilities onto a plain Scheduler instance.
130+
131+
Lifecycle:
132+
1. Created after ``Scheduler.__init__()`` in :func:`areal_run_scheduler_process`
133+
2. :meth:`bind` attaches ``rdt_*`` methods to the scheduler via ``setattr``
134+
3. ``handle_rpc_request`` dispatches via ``getattr(self, method)`` and finds them
135+
4. Methods delegate to :class:`RDTSGLangAdapter` for actual work
136+
5. Data-returning methods push results via ZMQ PUSH (tp_rank 0, dp_rank 0 only)
137+
138+
No inheritance. No monkey-patch. The scheduler instance remains a plain
139+
``sglang.srt.managers.scheduler.Scheduler``.
140+
"""
141+
142+
def __init__(self, scheduler: Any) -> None:
143+
self._scheduler = scheduler
144+
self._adapter: Any | None = None
145+
self._result_push: zmq.Socket | None = None
146+
147+
result_ipc = os.environ.get(RDT_RESULT_IPC_ENV)
148+
if (
149+
result_ipc
150+
and scheduler.tp_rank == 0
151+
and (getattr(scheduler, "dp_rank", None) is None or scheduler.dp_rank == 0)
152+
):
153+
ctx = zmq.Context(1)
154+
self._result_push = ctx.socket(zmq.PUSH)
155+
self._result_push.connect(result_ipc)
156+
157+
def bind(self) -> None:
158+
"""Attach ``rdt_*`` methods to the scheduler instance."""
159+
methods = [
160+
"rdt_report_weight_meta",
161+
"rdt_report_parallelism",
162+
"rdt_init_weight_update_group",
163+
"rdt_execute_weight_update",
164+
"rdt_randomize_parameters",
165+
"rdt_get_parameters",
166+
]
167+
for name in methods:
168+
setattr(self._scheduler, name, getattr(self, name))
169+
170+
def _require_adapter(self) -> Any:
171+
if self._adapter is None:
172+
from areal.experimental.weight_update.rdt.sglang_adapter import (
173+
RDTSGLangAdapter,
174+
)
175+
176+
self._adapter = RDTSGLangAdapter(self._scheduler)
177+
return self._adapter
178+
179+
def _push_result(self, result: Any) -> None:
180+
if self._result_push is not None:
181+
self._result_push.send_pyobj(result)
182+
183+
def rdt_report_weight_meta(self) -> None:
184+
adapter = self._require_adapter()
185+
local_meta = adapter.get_weight_metadata()
186+
s = self._scheduler
187+
188+
if s.tp_size > 1:
189+
gathered: list[list] = [[] for _ in range(s.tp_size)]
190+
dist.all_gather_object(gathered, local_meta, group=s.tp_cpu_group)
191+
all_meta: list = []
192+
for rank_meta in gathered:
193+
all_meta.extend(rank_meta)
194+
self._push_result(serialize_value(all_meta))
195+
else:
196+
self._push_result(serialize_value(local_meta))
197+
198+
def rdt_report_parallelism(self) -> None:
199+
self._push_result(self._require_adapter().parallelism_strategy)
200+
201+
def rdt_init_weight_update_group(self, **kwargs: Any) -> None:
202+
self._require_adapter().rdt_init_weight_update_group(**kwargs)
203+
204+
def rdt_execute_weight_update(self, version: int = 0) -> None:
205+
self._require_adapter().rdt_execute_weight_update(version)
206+
207+
def rdt_randomize_parameters(self) -> None:
208+
"""Randomize model parameters for testing."""
209+
self._require_adapter().randomize_parameters()
210+
211+
def rdt_get_parameters(
212+
self, save_path: str, names: list[str] | None = None
213+
) -> None:
214+
"""Save parameters to disk for validation."""
215+
self._require_adapter().save_parameters(save_path, names)
216+
217+
121218
# ---------------------------------------------------------------------------
122219
# Duplicated from sglang.srt.managers.scheduler.run_scheduler_process
123220
# (SGLang commit pinned in this repo).
@@ -216,7 +313,11 @@ def areal_run_scheduler_process(
216313
)
217314

218315
# ---- BEGIN AREAL ----
219-
AwexSchedulerBridge(scheduler).bind()
316+
backend = get_weight_update_backend()
317+
if backend == BACKEND_AWEX:
318+
AwexSchedulerBridge(scheduler).bind()
319+
elif backend == BACKEND_RDT:
320+
RDTSchedulerBridge(scheduler).bind()
220321
PPSchedulerBridge(scheduler, server_args).bind()
221322
# ---- END AREAL ----
222323

@@ -229,7 +330,23 @@ def areal_run_scheduler_process(
229330
parent_process.send_signal(signal.SIGQUIT)
230331

231332

232-
def create_result_ipc() -> str:
233-
path = f"ipc://{tempfile.mktemp(prefix='areal_result_')}"
234-
os.environ[RESULT_IPC_ENV] = path
333+
def create_result_ipc(backend: str) -> str:
    """Create result IPC path for given backend.

    Sets environment variable for scheduler subprocess to read:
    ``RESULT_IPC_ENV`` for the AWEX backend, ``RDT_RESULT_IPC_ENV`` for the
    RDT backend. Any other value sets neither of the two, but *backend* is
    always exported through ``WEIGHT_UPDATE_BACKEND_ENV``.

    The socket path lives inside a freshly created private directory
    (``tempfile.mkdtemp``) instead of a name from the deprecated
    ``tempfile.mktemp``, which another process could claim between name
    generation and use. The directory is not removed by this function.

    Args:
        backend: "awex" or "rdt"

    Returns:
        IPC path string
    """
    # The socket file itself must not exist yet, so only its parent
    # directory is created here.
    socket_dir = tempfile.mkdtemp(prefix=f"areal_{backend}_result_")
    path = f"ipc://{os.path.join(socket_dir, 'ipc.sock')}"

    if backend == BACKEND_AWEX:
        os.environ[RESULT_IPC_ENV] = path
    elif backend == BACKEND_RDT:
        os.environ[RDT_RESULT_IPC_ENV] = path

    os.environ[WEIGHT_UPDATE_BACKEND_ENV] = backend
    return path

areal/experimental/training_service/worker/app.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from areal.experimental.training_service.worker.awex import create_awex_blueprint
1515
from areal.experimental.training_service.worker.config import TrainWorkerConfig
1616
from areal.experimental.training_service.worker.engine import create_engine_module
17+
from areal.experimental.training_service.worker.rdt import create_rdt_blueprint
18+
from areal.experimental.weight_update import get_weight_update_backend
1719
from areal.infra.platforms import current_platform
1820
from areal.infra.rpc.serialization import deserialize_value, serialize_value
1921
from areal.utils import logging
@@ -198,14 +200,25 @@ def _get_node_addr() -> str:
198200
)
199201
)
200202

201-
app.register_blueprint(
202-
create_awex_blueprint(
203-
flask_module=flask,
204-
get_engine=_get_engine,
205-
submit_to_engine_thread=_submit_to_engine_thread,
206-
run_endpoint=_run_endpoint,
203+
backend = get_weight_update_backend()
204+
if backend == "awex":
205+
app.register_blueprint(
206+
create_awex_blueprint(
207+
flask_module=flask,
208+
get_engine=_get_engine,
209+
submit_to_engine_thread=_submit_to_engine_thread,
210+
run_endpoint=_run_endpoint,
211+
)
212+
)
213+
elif backend == "rdt":
214+
app.register_blueprint(
215+
create_rdt_blueprint(
216+
flask_module=flask,
217+
get_engine=_get_engine,
218+
submit_to_engine_thread=_submit_to_engine_thread,
219+
run_endpoint=_run_endpoint,
220+
)
207221
)
208-
)
209222

210223
from areal.infra.rpc.guard.data_blueprint import data_bp
211224

0 commit comments

Comments
 (0)