11# SPDX-License-Identifier: Apache-2.0
2- """AwexSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler."""
2+ """AwexSchedulerBridge/RDTSchedulerBridge + PPSchedulerBridge: compose weight-update methods onto SGLang Scheduler."""
33
44from __future__ import annotations
55
1212import zmq
1313from sglang .srt .server_args import PortArgs , ServerArgs
1414
15+ from areal .experimental .weight_update import (
16+ BACKEND_AWEX ,
17+ BACKEND_RDT ,
18+ WEIGHT_UPDATE_BACKEND_ENV ,
19+ get_weight_update_backend ,
20+ )
1521from areal .infra .rpc .serialization import serialize_value
1622
1723RESULT_IPC_ENV = "AREAL_AWEX_RESULT_IPC"
24+ RDT_RESULT_IPC_ENV = "AREAL_RDT_RESULT_IPC"
1825
1926
2027class AwexSchedulerBridge :
@@ -118,6 +125,96 @@ def awex_randomize_parameters(self) -> None:
118125 self ._require_adapter ().randomize_parameters ()
119126
120127
128+ class RDTSchedulerBridge :
129+ """Compose RDT weight-update capabilities onto a plain Scheduler instance.
130+
131+ Lifecycle:
132+ 1. Created after ``Scheduler.__init__()`` in :func:`areal_run_scheduler_process`
133+ 2. :meth:`bind` attaches ``rdt_*`` methods to the scheduler via ``setattr``
134+ 3. ``handle_rpc_request`` dispatches via ``getattr(self, method)`` and finds them
135+ 4. Methods delegate to :class:`RDTSGLangAdapter` for actual work
136+ 5. Data-returning methods push results via ZMQ PUSH (tp_rank 0, dp_rank 0 only)
137+
138+ No inheritance. No monkey-patch. The scheduler instance remains a plain
139+ ``sglang.srt.managers.scheduler.Scheduler``.
140+ """
141+
142+ def __init__ (self , scheduler : Any ) -> None :
143+ self ._scheduler = scheduler
144+ self ._adapter : Any | None = None
145+ self ._result_push : zmq .Socket | None = None
146+
147+ result_ipc = os .environ .get (RDT_RESULT_IPC_ENV )
148+ if (
149+ result_ipc
150+ and scheduler .tp_rank == 0
151+ and (getattr (scheduler , "dp_rank" , None ) is None or scheduler .dp_rank == 0 )
152+ ):
153+ ctx = zmq .Context (1 )
154+ self ._result_push = ctx .socket (zmq .PUSH )
155+ self ._result_push .connect (result_ipc )
156+
157+ def bind (self ) -> None :
158+ """Attach ``rdt_*`` methods to the scheduler instance."""
159+ methods = [
160+ "rdt_report_weight_meta" ,
161+ "rdt_report_parallelism" ,
162+ "rdt_init_weight_update_group" ,
163+ "rdt_execute_weight_update" ,
164+ "rdt_randomize_parameters" ,
165+ "rdt_get_parameters" ,
166+ ]
167+ for name in methods :
168+ setattr (self ._scheduler , name , getattr (self , name ))
169+
170+ def _require_adapter (self ) -> Any :
171+ if self ._adapter is None :
172+ from areal .experimental .weight_update .rdt .sglang_adapter import (
173+ RDTSGLangAdapter ,
174+ )
175+
176+ self ._adapter = RDTSGLangAdapter (self ._scheduler )
177+ return self ._adapter
178+
179+ def _push_result (self , result : Any ) -> None :
180+ if self ._result_push is not None :
181+ self ._result_push .send_pyobj (result )
182+
183+ def rdt_report_weight_meta (self ) -> None :
184+ adapter = self ._require_adapter ()
185+ local_meta = adapter .get_weight_metadata ()
186+ s = self ._scheduler
187+
188+ if s .tp_size > 1 :
189+ gathered : list [list ] = [[] for _ in range (s .tp_size )]
190+ dist .all_gather_object (gathered , local_meta , group = s .tp_cpu_group )
191+ all_meta : list = []
192+ for rank_meta in gathered :
193+ all_meta .extend (rank_meta )
194+ self ._push_result (serialize_value (all_meta ))
195+ else :
196+ self ._push_result (serialize_value (local_meta ))
197+
198+ def rdt_report_parallelism (self ) -> None :
199+ self ._push_result (self ._require_adapter ().parallelism_strategy )
200+
201+ def rdt_init_weight_update_group (self , ** kwargs : Any ) -> None :
202+ self ._require_adapter ().rdt_init_weight_update_group (** kwargs )
203+
204+ def rdt_execute_weight_update (self , version : int = 0 ) -> None :
205+ self ._require_adapter ().rdt_execute_weight_update (version )
206+
207+ def rdt_randomize_parameters (self ) -> None :
208+ """Randomize model parameters for testing."""
209+ self ._require_adapter ().randomize_parameters ()
210+
211+ def rdt_get_parameters (
212+ self , save_path : str , names : list [str ] | None = None
213+ ) -> None :
214+ """Save parameters to disk for validation."""
215+ self ._require_adapter ().save_parameters (save_path , names )
216+
217+
121218# ---------------------------------------------------------------------------
122219# Duplicated from sglang.srt.managers.scheduler.run_scheduler_process
123220# (SGLang commit pinned in this repo).
@@ -216,7 +313,11 @@ def areal_run_scheduler_process(
216313 )
217314
218315 # ---- BEGIN AREAL ----
219- AwexSchedulerBridge (scheduler ).bind ()
316+ backend = get_weight_update_backend ()
317+ if backend == BACKEND_AWEX :
318+ AwexSchedulerBridge (scheduler ).bind ()
319+ elif backend == BACKEND_RDT :
320+ RDTSchedulerBridge (scheduler ).bind ()
220321 PPSchedulerBridge (scheduler , server_args ).bind ()
221322 # ---- END AREAL ----
222323
@@ -229,7 +330,23 @@ def areal_run_scheduler_process(
229330 parent_process .send_signal (signal .SIGQUIT )
230331
231332
232- def create_result_ipc () -> str :
233- path = f"ipc://{ tempfile .mktemp (prefix = 'areal_result_' )} "
234- os .environ [RESULT_IPC_ENV ] = path
333+ def create_result_ipc (backend : str ) -> str :
334+ """Create result IPC path for given backend.
335+
336+ Sets environment variable for scheduler subprocess to read.
337+
338+ Args:
339+ backend: "awex" or "rdt"
340+
341+ Returns:
342+ IPC path string
343+ """
344+ path = f"ipc://{ tempfile .mktemp (prefix = f'areal_{ backend } _result_' )} "
345+
346+ if backend == BACKEND_AWEX :
347+ os .environ [RESULT_IPC_ENV ] = path
348+ elif backend == BACKEND_RDT :
349+ os .environ [RDT_RESULT_IPC_ENV ] = path
350+
351+ os .environ [WEIGHT_UPDATE_BACKEND_ENV ] = backend
235352 return path
0 commit comments