Merged
21 changes: 21 additions & 0 deletions .github/workflows/nightly-test-nvidia.yml
@@ -420,7 +420,28 @@ jobs:
        run: |
          IS_BLACKWELL=1 bash scripts/ci/ci_install_dependency.sh

      - name: Run Mistral-Large-3 nightly performance test
        timeout-minutes: 180
        env:
          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}
          PERFETTO_RELAY_URL: ${{ vars.PERFETTO_RELAY_URL }}
          GPU_CONFIG: "8-gpu-b200"
          SGLANG_ENABLE_JIT_DEEPGEMM: "0"
        run: |
          rm -rf test/performance_profiles_mistral_large3/
          cd test
          IS_BLACKWELL=1 python3 nightly/test_mistral_large3_perf.py

      - name: Publish Mistral-Large-3 traces to storage repo
        env:
          GITHUB_TOKEN: ${{ secrets.GH_PAT_FOR_NIGHTLY_CI_DATA }}
          GITHUB_RUN_ID: ${{ github.run_id }}
          GITHUB_RUN_NUMBER: ${{ github.run_number }}
        run: |
          python3 scripts/ci/publish_traces.py --traces-dir test/performance_profiles_mistral_large3

      - name: Run DeepSeek v3.1 nightly performance test
        if: always()
        timeout-minutes: 180
        env:
          TRACE_BASE_URL: https://raw.githubusercontent.com/sglang-bot/sglang-ci-data/main/traces/${{ github.run_id }}
179 changes: 97 additions & 82 deletions python/sglang/srt/model_loader/weight_utils.py
@@ -79,13 +79,15 @@ def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


def get_lock(
    model_name_or_path: str, cache_dir: Optional[str] = None, suffix: str = ""
):
    lock_dir = cache_dir or temp_dir
    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + suffix + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
    return lock
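
For illustration, a minimal sketch of how the new suffix argument keeps the validation lock distinct from the default download lock (hypothetical usage, not taken from this diff):

    # Hypothetical usage: distinct suffixes yield distinct lock files, so
    # cache validation never contends with a concurrent weight download.
    download_lock = get_lock("mistralai/Mistral-Large-3-675B-Instruct-2512")
    validation_lock = get_lock(
        "mistralai/Mistral-Large-3-675B-Instruct-2512", suffix="-validation"
    )
    with validation_lock:
        pass  # only one TP rank validates/cleans the cache at a time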
@@ -309,8 +311,22 @@ def find_local_hf_snapshot_dir(
    except Exception as e:
        logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    # if local snapshot exists, validate it contains at least one weight file
    # matching allow_patterns before skipping download.
    if found_local_snapshot_dir is None:
        return None

    # Use file lock to prevent multiple processes (TP ranks) from
    # validating and cleaning up the same model cache simultaneously.
    # This prevents race conditions where multiple ranks detect corruption
    # and try to delete the same files at the same time.
    with get_lock(model_name_or_path, cache_dir, suffix="-validation"):
Collaborator:
Why do we need this change for the Mistral Large 3 model?

Collaborator (Author):
It's a general bug fix discovered while debugging the Mistral test. When model weights are corrupted or missing, all 8 TP ranks detect the issue simultaneously and race to delete the cache directory. One rank succeeds; the others fail with "No such file or directory" errors, which was causing CI failures.
The fix adds a file lock so that only one rank performs validation and cleanup at a time. Other models simply happened to have complete caches, so the bug was never triggered.
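
A minimal sketch of the pattern described above (check, acquire the lock, then re-check before cleaning up), assuming a filelock-style lock; cleanup_snapshot_once and its arguments are hypothetical, not the PR's actual helper:

    import os
    import shutil

    import filelock


    def cleanup_snapshot_once(snapshot_dir: str, lock_path: str) -> None:
        # Every rank may detect the corruption, but only the rank holding
        # the lock deletes; late ranks re-check and return without error.
        if not os.path.isdir(snapshot_dir):
            return  # another rank already cleaned it up
        with filelock.FileLock(lock_path, mode=0o666):
            # Re-check after acquiring the lock: the directory may have
            # been removed while this rank was waiting.
            if not os.path.isdir(snapshot_dir):
                return
            shutil.rmtree(snapshot_dir)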

        # Re-check if snapshot dir still exists after acquiring lock
        # (another process may have already cleaned it up)
        if not os.path.isdir(found_local_snapshot_dir):
            return None

        # Check for incomplete files and clean up if found
        repo_folder = os.path.abspath(
            os.path.join(found_local_snapshot_dir, "..", "..")
        )
@@ -334,91 +350,90 @@
                )
                return None

        local_weight_files: List[str] = []
        try:
            for pattern in allow_patterns:
                matched_files = glob.glob(
                    os.path.join(found_local_snapshot_dir, pattern)
                )
                for f in matched_files:
                    # os.path.exists returns False for broken symlinks.
                    if not os.path.exists(f):
                        continue
                    local_weight_files.append(f)
        except Exception as e:
            logger.warning(
                "Failed to scan local snapshot %s with patterns %s: %s",
                found_local_snapshot_dir,
                allow_patterns,
                e,
            )
            local_weight_files = []

        # Validate sharded models and check for corruption
        if local_weight_files:
            is_valid, error_msg, corrupted_files = _validate_sharded_model(
                found_local_snapshot_dir, local_weight_files
            )
            if not is_valid:
                if corrupted_files:
                    # Selective cleanup: only remove corrupted files
                    log_info_on_rank0(
                        logger,
                        f"Found {len(corrupted_files)} corrupted file(s) for "
                        f"{model_name_or_path}: {error_msg}. "
                        "Will selectively clean and re-download only these files.",
                    )
                    _cleanup_corrupted_files_selective(
                        model_name_or_path, corrupted_files
                    )
                    return None
                else:
                    # Cannot selectively clean (e.g., missing shards) - remove entire cache
                    log_info_on_rank0(
                        logger,
                        f"Validation failed for {model_name_or_path}: {error_msg}. "
                        "Will remove entire cache and re-download.",
                    )
                    _cleanup_corrupted_model_cache(
                        model_name_or_path, found_local_snapshot_dir, error_msg
                    )
                    return None

        # Also validate single (non-sharded) safetensors files
        for f in local_weight_files:
            base_name = os.path.basename(f)
            # Check if this is a single model file (not sharded)
            # Include adapter_model.safetensors for LoRA adapters
            if base_name in [
                "model.safetensors",
                "pytorch_model.safetensors",
                "adapter_model.safetensors",
            ]:
                if not _validate_safetensors_file(f):
                    log_info_on_rank0(
                        logger,
                        f"Corrupted model file {base_name} for {model_name_or_path}. "
                        "Will selectively clean and re-download this file.",
                    )
                    # Selective cleanup for single file
                    _cleanup_corrupted_files_selective(model_name_or_path, [f])
                    return None

        if len(local_weight_files) > 0:
            log_info_on_rank0(
                logger,
                f"Found local HF snapshot for {model_name_or_path} at "
                f"{found_local_snapshot_dir}; skipping download.",
            )
            return found_local_snapshot_dir
        else:
            log_info_on_rank0(
                logger,
                f"Local HF snapshot at {found_local_snapshot_dir} has no files matching "
                f"{allow_patterns}; will attempt download.",
            )
            return None
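
The diff calls _validate_sharded_model and _validate_safetensors_file without showing them. For context, a minimal sketch of one plausible integrity check, based on the safetensors on-disk format (an 8-byte little-endian header length followed by a JSON header); this is an illustrative assumption, not the PR's actual implementation:

    import json
    import os
    import struct


    def validate_safetensors_header(path: str) -> bool:
        # A .safetensors file begins with an 8-byte little-endian length,
        # followed by a JSON header of exactly that many bytes. A truncated
        # download fails the size check or the JSON parse below.
        try:
            size = os.path.getsize(path)
            with open(path, "rb") as f:
                (header_len,) = struct.unpack("<Q", f.read(8))
                if 8 + header_len > size:
                    return False
                json.loads(f.read(header_len))
            return True
        except Exception:
            return False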


def download_weights_from_hf(
105 changes: 105 additions & 0 deletions test/nightly/test_mistral_large3_perf.py
@@ -0,0 +1,105 @@
import os
import unittest
from types import SimpleNamespace

from nightly_utils import NightlyBenchmarkRunner

from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    _parse_int_list_env,
    popen_launch_server,
)

register_cuda_ci(est_time=600, suite="nightly-8-gpu-b200", nightly=True)

MISTRAL_LARGE3_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
PROFILE_DIR = "performance_profiles_mistral_large3"


class TestNightlyMistralLarge3Performance(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Set environment variable to disable JIT DeepGemm
        os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"

        cls.model = MISTRAL_LARGE3_MODEL_PATH
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.batch_sizes = [1, 1, 8, 16, 64]
        cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_INPUT_LENS", "4096"))
        cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_OUTPUT_LENS", "512"))

        # Mistral-Large-3-675B requires TP=8 and trtllm_mla attention backend
        cls.other_args = [
            "--tp",
            "8",
            "--attention-backend",
            "trtllm_mla",
            "--model-loader-extra-config",
            '{"enable_multithread_load": true}',
            "--chat-template",
            "mistral",
        ]

        cls.runner = NightlyBenchmarkRunner(PROFILE_DIR, cls.__name__, cls.base_url)
        cls.runner.setup_profile_directory()

    @classmethod
    def tearDownClass(cls):
        # Clean up environment variable
        if "SGLANG_ENABLE_JIT_DEEPGEMM" in os.environ:
            del os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"]

    def test_bench_one_batch(self):
Collaborator:
Can we also add an accuracy test for it?

Collaborator (Author):
Yes, I can do that. What score threshold should we expect? How do we determine this value?

Collaborator (Author):
Just added a test_accuracy_mgsm method using the same eval framework as the other text models. The initial threshold is set to 0.90 as a placeholder; I will calibrate it after observing actual model performance.

Collaborator (Author):
MGSM accuracy for mistralai/Mistral-Large-3-675B-Instruct-2512: 0.972
https://github.com/sgl-project/sglang/actions/runs/19961215892/job/57241902940#step:5:2471
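
As a hedged illustration of the calibration the author describes (hypothetical helper, not part of this PR): average the observed nightly scores and subtract a safety margin for run-to-run variance, so the 0.972 observed here would suggest raising the 0.90 placeholder to roughly 0.92 with a 0.05 margin.

    def calibrate_threshold(observed_scores, margin=0.05):
        # Hypothetical calibration: mean observed accuracy minus a safety
        # margin that absorbs nightly run-to-run variance.
        return round(sum(observed_scores) / len(observed_scores) - margin, 2)


    # e.g. calibrate_threshold([0.972]) == 0.92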

        results, success = self.runner.run_benchmark_for_model(
            model_path=self.model,
            batch_sizes=self.batch_sizes,
            input_lens=self.input_lens,
            output_lens=self.output_lens,
            other_args=self.other_args,
        )

        self.runner.add_report(results)
        self.runner.write_final_report()

        if not success:
            raise AssertionError(
                f"Benchmark failed for {self.model}. Check the logs for details."
            )

    def test_accuracy_mgsm(self):
        """Run MGSM accuracy evaluation for Mistral Large 3."""
        process = popen_launch_server(
            model=self.model,
            base_url=self.base_url,
            other_args=self.other_args,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

        try:
            args = SimpleNamespace(
                base_url=self.base_url,
                model=self.model,
                eval_name="mgsm_en",
                num_examples=None,
                num_threads=1024,
            )
            metrics = run_eval(args)
            print(f"MGSM accuracy for {self.model}: {metrics['score']}")

            # Placeholder threshold - adjust after first successful run
            expected_threshold = 0.90
            self.assertGreaterEqual(
                metrics["score"],
                expected_threshold,
                f"MGSM accuracy {metrics['score']} below threshold {expected_threshold}",
            )
        finally:
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()