
Commit 0c81b6f

Fix safetensors validation to catch corruption after download
## Problem

The safetensors validation for corrupted files only ran:

1. When `SGLANG_IS_IN_CI=true` was set (missing from GPU workflows)
2. Only for cached files, not for freshly downloaded files

This caused CI failures like:

```
safetensors_rust.SafetensorError: Error while deserializing header: invalid JSON in header: EOF while parsing a value at line 1 column 0
```

## Solution

1. **Always validate local cache first** - Removed the `is_in_ci()` check around `find_local_hf_snapshot_dir()` so validation runs regardless of environment
2. **Add post-download validation** - New `_validate_weights_after_download()` function validates safetensors files immediately after `snapshot_download()` completes, catching truncated downloads or network corruption
3. **Add SGLANG_IS_IN_CI to GPU workflows** - Added the environment variable to pr-test.yml and nightly-test-nvidia.yml for consistency with NPU workflows

## Performance Impact

Minimal - validation only reads safetensors headers (a few KB), not tensor data. For a 19-shard model, validation takes ~1-2 seconds.
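For context on the performance note: safetensors shards store an 8-byte length prefix and a JSON header ahead of the tensor data, so a validity check only needs to parse that header. A rough sketch of such a header-only check using the `safetensors` package (illustrative helper name; not code from this commit):

```python
# Illustrative only: shows why header validation is cheap. safe_open parses just
# the JSON header (a few KB); listing keys never touches tensor data.
from safetensors import safe_open


def headers_look_valid(shard_paths) -> bool:
    """Return True if every shard's safetensors header deserializes cleanly."""
    try:
        for path in shard_paths:
            with safe_open(path, framework="pt") as f:  # framework choice is arbitrary here
                _ = f.keys()  # forces header parsing only
        return True
    except Exception:
        # A truncated shard fails here with the same kind of
        # "Error while deserializing header" failure quoted in the CI log above.
        return False
```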
1 parent 4c5074e commit 0c81b6f

File tree

3 files changed: 70 additions, 9 deletions

.github/workflows/nightly-test-nvidia.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -47,6 +47,9 @@ concurrency:
   group: nightly-test-nvidia-${{ github.ref }}
   cancel-in-progress: true
 
+env:
+  SGLANG_IS_IN_CI: true
+
 jobs:
   # General tests - 1 GPU
   nightly-test-general-1-gpu-runner:
```

.github/workflows/pr-test.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -25,6 +25,9 @@ concurrency:
   group: pr-test-${{ github.ref }}
   cancel-in-progress: true
 
+env:
+  SGLANG_IS_IN_CI: true
+
 jobs:
   # =============================================== check changes ====================================================
   check-changes:
```
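The `SGLANG_IS_IN_CI` variable added in both workflows is what the Python-side `is_in_ci()` helper consults. Its implementation is not part of this diff; a minimal sketch of how such a check is commonly written (assumed behavior, not the repository's exact code):

```python
import os


def is_in_ci() -> bool:
    # Hypothetical sketch: treat the job as CI when SGLANG_IS_IN_CI is set truthily.
    return os.environ.get("SGLANG_IS_IN_CI", "false").lower() in ("1", "true", "yes")
```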

python/sglang/srt/model_loader/weight_utils.py

Lines changed: 64 additions & 9 deletions

```diff
@@ -421,6 +421,55 @@ def find_local_hf_snapshot_dir(
     return None
 
 
+def _validate_weights_after_download(
+    hf_folder: str,
+    allow_patterns: List[str],
+    model_name_or_path: str,
+) -> None:
+    """Validate downloaded weight files to catch corruption early.
+
+    This function validates safetensors files after download to catch
+    corruption issues (truncated downloads, network errors, etc.) before
+    model loading fails with cryptic errors.
+
+    Args:
+        hf_folder: Path to the downloaded model folder
+        allow_patterns: Patterns used to match weight files
+        model_name_or_path: Model identifier for error messages
+
+    Raises:
+        RuntimeError: If any weight files are corrupted
+    """
+    import glob as glob_module
+
+    # Find all weight files that were downloaded
+    weight_files: List[str] = []
+    for pattern in allow_patterns:
+        weight_files.extend(glob_module.glob(os.path.join(hf_folder, pattern)))
+
+    if not weight_files:
+        return  # No weight files to validate
+
+    # Validate safetensors files
+    corrupted_files = []
+    for f in weight_files:
+        if f.endswith(".safetensors") and os.path.exists(f):
+            if not _validate_safetensors_file(f):
+                corrupted_files.append(os.path.basename(f))
+
+    if corrupted_files:
+        # Clean up corrupted files so next attempt re-downloads them
+        _cleanup_corrupted_files_selective(
+            model_name_or_path,
+            [os.path.join(hf_folder, f) for f in corrupted_files],
+        )
+        raise RuntimeError(
+            f"Downloaded model files are corrupted for {model_name_or_path}: "
+            f"{corrupted_files}. The corrupted files have been removed. "
+            "Please retry to re-download the model."
+        )
+
+
 def download_weights_from_hf(
     model_name_or_path: str,
     cache_dir: Optional[str],
```
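The new `_validate_weights_after_download()` delegates the per-file check to `_validate_safetensors_file()`, which is defined elsewhere in the file and not shown in this diff. As orientation only, a header-level check consistent with the safetensors on-disk format (8-byte little-endian header length followed by a JSON header) might look like the sketch below; the real helper may differ:

```python
import json
import os
import struct


def _validate_safetensors_file(path: str) -> bool:
    """Sketch: return False if the safetensors header is missing, truncated, or not valid JSON."""
    try:
        file_size = os.path.getsize(path)
        with open(path, "rb") as fh:
            prefix = fh.read(8)  # 8-byte little-endian header length
            if len(prefix) < 8:
                return False  # file truncated before the length prefix
            (header_len,) = struct.unpack("<Q", prefix)
            if 8 + header_len > file_size:
                return False  # header claims more bytes than the file contains
            json.loads(fh.read(header_len))  # header must parse as JSON
        return True
    except (OSError, ValueError):
        return False
```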
```diff
@@ -446,17 +495,19 @@ def download_weights_from_hf(
         str: The path to the downloaded model weights.
     """
 
-    if is_in_ci():
-        # If the weights are already local, skip downloading and returns the path.
-        # This is used to skip too-many Huggingface API calls in CI.
-        path = find_local_hf_snapshot_dir(
-            model_name_or_path, cache_dir, allow_patterns, revision
-        )
-        if path is not None:
-            return path
+    # Always check for valid local cache first.
+    # This validates cached files and cleans up corrupted ones.
+    path = find_local_hf_snapshot_dir(
+        model_name_or_path, cache_dir, allow_patterns, revision
+    )
+    if path is not None:
+        # Valid local cache found, skip download
+        return path
 
+    # In CI, skip HF API calls if we're in offline mode or want to avoid rate limits
+    # But we already checked for local cache above, so if we're here we need to download
     if not huggingface_hub.constants.HF_HUB_OFFLINE:
-        # Before we download we look at that is available:
+        # Before we download we look at what is available:
         fs = HfFileSystem()
         file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
 
```
```diff
@@ -480,6 +531,10 @@ def download_weights_from_hf(
             revision=revision,
             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
         )
+
+    # Validate downloaded files to catch corruption early
+    _validate_weights_after_download(hf_folder, allow_patterns, model_name_or_path)
+
     return hf_folder
 
 
```
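Because corrupted shards are deleted before the `RuntimeError` is raised, the error message's "please retry" instruction maps naturally onto a caller-side retry loop. A hypothetical usage sketch, assuming the positional signature shown in the hunks above and a `revision` default of `None`:

```python
from sglang.srt.model_loader.weight_utils import download_weights_from_hf


def download_with_retry(model_id: str, attempts: int = 2) -> str:
    # Hypothetical wrapper; argument order follows the signature in the diff.
    last_err = None
    for _ in range(attempts):
        try:
            # Corrupted shards detected after download raise RuntimeError and are
            # removed from the cache, so the next attempt can fetch them again.
            return download_weights_from_hf(model_id, None, ["*.safetensors"])
        except RuntimeError as err:
            last_err = err
    raise last_err


# Example call with a placeholder model id:
# folder = download_with_retry("facebook/opt-125m")
```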