1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -84,6 +84,7 @@
)
DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8 = "Qwen/Qwen3-1.7B-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE = "gaunernst/DeepSeek-V2-Lite-Chat-FP8"
DEFAULT_MODEL_NAME_FOR_TEST_MXFP8 = "zianglih/Qwen3-30B-A3B-MXFP8"

# MXFP4 models
# Standard MXFP4 MoE test model
96 changes: 96 additions & 0 deletions test/registered/rl/test_update_weights_from_disk.py
@@ -10,10 +10,12 @@
from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MXFP8,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_blackwell_system,
is_in_ci,
popen_launch_server,
)
@@ -305,6 +307,100 @@ def test_update_weights_abort_all_requests(self):
self.assertNotEqual(updated_model_path, origin_model_path)


@unittest.skipIf(not is_blackwell_system(), "MXFP8 requires Blackwell (CUDA)")
class TestServerUpdateWeightsFromDiskMXFP8(CustomTestCase):
model = DEFAULT_MODEL_NAME_FOR_TEST_MXFP8
base_url = DEFAULT_URL_FOR_TEST
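# Backend pairings to exercise: explicit triton FP8 GEMM + cutlass MoE runner, and the auto-selected defaults.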
backend_test_suites = [
{"fp8_gemm_backend": "triton", "moe_runner_backend": "cutlass"},
{"fp8_gemm_backend": "auto", "moe_runner_backend": "auto"},
]

def _launch_server(self, fp8_gemm_backend, moe_runner_backend):
return popen_launch_server(
self.model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--fp8-gemm-backend",
fp8_gemm_backend,
"--moe-runner-backend",
moe_runner_backend,
],
)

def run_decode(self):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {"temperature": 0, "max_new_tokens": 16},
},
)
payload = response.json()
return payload["text"]

def get_model_info(self):
response = requests.get(self.base_url + "/get_model_info")
return response.json()["model_path"]

def run_update_weights(
self,
model_path,
flush_cache=True,
abort_all_requests=False,
):
response = requests.post(
self.base_url + "/update_weights_from_disk",
json={
"model_path": model_path,
"flush_cache": flush_cache,
"abort_all_requests": abort_all_requests,
},
)
ret = response.json()
print(json.dumps(ret))
return ret

def test_parameterized_update_weights_mxfp8(self):
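# For each backend pairing, launch a fresh server and reload the same MXFP8 checkpoint from disk, with and without flushing the cache.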
update_test_suites = [
{"flush_cache": True, "abort_all_requests": False},
{"flush_cache": False, "abort_all_requests": False},
]
for backend_test_suite in self.backend_test_suites:
with self.subTest(**backend_test_suite):
process = self._launch_server(
backend_test_suite["fp8_gemm_backend"],
backend_test_suite["moe_runner_backend"],
)
try:
origin_model_path = self.get_model_info()
self.assertEqual(origin_model_path, self.model)
origin_response = self.run_decode()
self.assertTrue(len(origin_response) > 0)

for update_test_suite in update_test_suites:
with self.subTest(
fp8_gemm_backend=backend_test_suite["fp8_gemm_backend"],
moe_runner_backend=backend_test_suite["moe_runner_backend"],
flush_cache=update_test_suite["flush_cache"],
abort_all_requests=update_test_suite["abort_all_requests"],
):
ret = self.run_update_weights(
self.model,
flush_cache=update_test_suite["flush_cache"],
abort_all_requests=update_test_suite[
"abort_all_requests"
],
)
self.assertTrue(ret["success"])
self.assertEqual(self.get_model_info(), self.model)
updated_response = self.run_decode()
self.assertTrue(len(updated_response) > 0)
finally:
kill_process_tree(process.pid)


###############################################################################
# Parameterized Tests for update_weights_from_disk
# Test coverage is determined based on the value of is_in_ci:
177 changes: 176 additions & 1 deletion test/registered/rl/test_update_weights_from_distributed.py
@@ -26,17 +26,20 @@
import requests
import torch
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM

import sglang as sgl
from sglang.srt.utils import init_custom_process_group
from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MXFP8,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
find_available_port,
is_blackwell_system,
is_in_amd_ci,
is_in_ci,
popen_launch_server,
@@ -848,5 +851,177 @@ def test_update_weights_from_distributed(self):
)


@unittest.skipIf(not is_blackwell_system(), "MXFP8 requires Blackwell (CUDA)")
class TestServerUpdateWeightsFromDistributedMXFP8(CustomTestCase):
model = DEFAULT_MODEL_NAME_FOR_TEST_MXFP8
base_url = DEFAULT_URL_FOR_TEST
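# Backend pairings to exercise: explicit triton FP8 GEMM + cutlass MoE runner, and the auto-selected defaults.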
backend_test_suites = [
{"fp8_gemm_backend": "triton", "moe_runner_backend": "cutlass"},
{"fp8_gemm_backend": "auto", "moe_runner_backend": "auto"},
]

@classmethod
def setUpClass(cls):
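# hidden_size usually sits at the top level of the config; multimodal configs nest it under text_config.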
config = AutoConfig.from_pretrained(cls.model, trust_remote_code=True)
cls.hidden_size = getattr(
config,
"hidden_size",
getattr(getattr(config, "text_config", None), "hidden_size", None),
)
if cls.hidden_size is None:
raise ValueError("Cannot resolve hidden_size for MXFP8 model config.")

def _launch_server(self, fp8_gemm_backend, moe_runner_backend):
return popen_launch_server(
self.model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--base-gpu-id",
"1",
"--fp8-gemm-backend",
fp8_gemm_backend,
"--moe-runner-backend",
moe_runner_backend,
],
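# Disable NCCL cuMem allocation and NVLink SHARP (NVLS) in the launched server process.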
env={
"NCCL_CUMEM_ENABLE": "0",
"NCCL_NVLS_ENABLE": "0",
},
)

def _post_json(self, endpoint, payload, timeout=180):
response = requests.post(
f"{self.base_url}{endpoint}",
json=payload,
timeout=timeout,
)
response.raise_for_status()
return response.json()

def _run_decode(self):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {"temperature": 0, "max_new_tokens": 16},
},
timeout=120,
)
response.raise_for_status()
return response.json()["text"]

def test_parameterized_update_weights_from_distributed_mxfp8(self):
self.assertGreaterEqual(
torch.cuda.device_count(),
2,
"MXFP8 distributed update test requires at least 2 GPUs.",
)
param_name = "model.norm.weight"
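# Each suite overwrites the final norm weight with a constant value: once via the default per-tensor path and once via the flattened-bucket load format.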
update_test_suites = [
{"load_format": None, "target_value": 1.375},
{"load_format": "flattened_bucket", "target_value": 1.5},
]

for backend_test_suite in self.backend_test_suites:
with self.subTest(**backend_test_suite):
process = self._launch_server(
backend_test_suite["fp8_gemm_backend"],
backend_test_suite["moe_runner_backend"],
)
try:
for update_test_suite in update_test_suites:
with self.subTest(
fp8_gemm_backend=backend_test_suite["fp8_gemm_backend"],
moe_runner_backend=backend_test_suite["moe_runner_backend"],
load_format=update_test_suite["load_format"],
target_value=update_test_suite["target_value"],
):
group_name = (
"test_parameter_update_group_mxfp8_"
f"{random.randint(0, 10**8)}"
)
master_port = find_available_port(50000)
group = None
try:
# Match existing distributed test mapping:
# sender rank 0 -> GPU 0, server rank 1 -> GPU 1.
torch.cuda.set_device(0)
origin_response = self._run_decode()
self.assertTrue(len(origin_response) > 0)
param_shape = [self.hidden_size]

init_payload = {
"master_address": "127.0.0.1",
"master_port": str(master_port),
"rank_offset": 1,
"world_size": 2,
"group_name": group_name,
"backend": "nccl",
}

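# /init_weights_update_group blocks until the peer rank joins, so issue it on a worker thread while this process creates the client-side NCCL group.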
with ThreadPoolExecutor(1) as executor:
init_future = executor.submit(
self._post_json,
"/init_weights_update_group",
init_payload,
)
group = init_custom_process_group(
backend="nccl",
init_method=f"tcp://127.0.0.1:{master_port}",
world_size=2,
rank=0,
group_name=group_name,
)
init_ret = init_future.result(timeout=240)

self.assertTrue(init_ret["success"], f"{init_ret=}")

update_payload = {
"names": [param_name],
"dtypes": ["bfloat16"],
"shapes": [param_shape],
"group_name": group_name,
"flush_cache": True,
"load_format": update_test_suite["load_format"],
}

source_tensor = torch.full(
tuple(param_shape),
update_test_suite["target_value"],
device="cuda:0",
dtype=torch.bfloat16,
)

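# Likewise, the update request blocks on the broadcast, so post it from a worker thread and then broadcast the new tensor from rank 0.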
with ThreadPoolExecutor(1) as executor:
update_future = executor.submit(
self._post_json,
"/update_weights_from_distributed",
update_payload,
240,
)
torch.distributed.broadcast(
source_tensor, src=0, group=group
)
update_ret = update_future.result(timeout=300)

self.assertTrue(update_ret["success"], f"{update_ret=}")
updated_response = self._run_decode()
self.assertTrue(len(updated_response) > 0)
finally:
if group is not None:
torch.distributed.destroy_process_group(group)
try:
self._post_json(
"/destroy_weights_update_group",
{"group_name": group_name},
timeout=120,
)
except Exception:
pass
finally:
terminate_process(process)


if __name__ == "__main__":
unittest.main()