Skip to content

Conversation

@alisonshao
Copy link
Collaborator

@alisonshao alisonshao commented Dec 5, 2025

Summary

  • Always validate local cache first (removed is_in_ci() gate around validation)
  • Add post-download validation via new _validate_weights_after_download() function
  • Add SGLANG_IS_IN_CI=true to GPU CI workflows (pr-test.yml, nightly-test-nvidia.yml)

Problem

The safetensors validation only ran when SGLANG_IS_IN_CI=true was set, but this env var was missing from GPU workflows. Additionally, validation only checked cached files - freshly downloaded files were never validated.

This caused CI failures like:

safetensors_rust.SafetensorError: Error while deserializing header:
invalid JSON in header: EOF while parsing a value at line 1 column 0

Related CI failure: https://github.com/sgl-project/sglang/actions/runs/19948236909/job/57203231359

Test plan

  • Verify existing CI tests pass
  • Verify corrupted safetensors files are detected both in cache and after download

@gemini-code-assist
Copy link
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@alisonshao
Copy link
Collaborator Author

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Dec 5, 2025
@alisonshao

This comment was marked as off-topic.

@github-actions

This comment was marked as off-topic.

## Problem

The safetensors validation for corrupted files only ran:
1. When `SGLANG_IS_IN_CI=true` was set (missing from GPU workflows)
2. Only for cached files, not for freshly downloaded files

This caused CI failures like:
```
safetensors_rust.SafetensorError: Error while deserializing header:
invalid JSON in header: EOF while parsing a value at line 1 column 0
```

## Solution

1. **Always validate local cache first** - Removed the `is_in_ci()` check
   around `find_local_hf_snapshot_dir()` so validation runs regardless
   of environment

2. **Add post-download validation** - New `_validate_weights_after_download()`
   function validates safetensors files immediately after `snapshot_download()`
   completes, catching truncated downloads or network corruption

3. **Add SGLANG_IS_IN_CI to GPU workflows** - Added the environment variable
   to pr-test.yml and nightly-test-nvidia.yml for consistency with NPU workflows

## Performance Impact

Minimal - validation only reads safetensors headers (few KB), not tensor data.
For a 19-shard model, validation takes ~1-2 seconds.
@alisonshao alisonshao force-pushed the fix/validate-weights-after-download branch from 4748d17 to 4e5b4be Compare December 5, 2025 05:18
@alisonshao

This comment was marked as off-topic.

@github-actions

This comment was marked as off-topic.

@alisonshao
Copy link
Collaborator Author

Local Test Results

Tested both BEFORE download (cache validation) and AFTER download (post-download validation) scenarios:

Local test for safetensors validation logic - both before and after download.

   import os 
  import json
  import tempfile
  import shutil
  import logging
  import re
  import glob as glob_module
  from typing import List, Optional, Tuple

  logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
  logger = logging.getLogger(__name__)

  import safetensors
  import torch
  from safetensors.torch import save_file


  def _validate_safetensors_file(file_path: str) -> bool:
      """Validate that a safetensors file is readable and not corrupted."""
      try:
          with safetensors.safe_open(file_path, framework="pt", device="cpu") as f:
              _ = list(f.keys())
          return True
      except Exception as e:
          logger.warning("Corrupted safetensors file detected: %s - %s: %s", file_path, type(e).__name__, str(e))
          return False


  def _check_index_files_exist(snapshot_dir: str):
      index_files = [f for f in os.listdir(snapshot_dir) if f.endswith(".safetensors.index.json")]
      if not index_files:
          return True, None
      for index_file in index_files:
          index_path = os.path.join(snapshot_dir, index_file)
          try:
              with open(index_path) as f:
                  index_data = json.load(f)
              weight_map = index_data.get("weight_map", {})
              if not weight_map:
                  continue
              required_files = set(weight_map.values())
              missing_files = [f for f in required_files if not os.path.exists(os.path.join(snapshot_dir, f))]
              if missing_files:
                  return (False, f"Missing {len(missing_files)} file(s) from index {index_file}: {missing_files[:3]}")
          except Exception as e:
              logger.warning("Failed to read index file %s: %s", index_file, e)
      return True, None


  def _validate_sharded_model(snapshot_dir: str, weight_files: List[str]):
      index_check_valid, index_error = _check_index_files_exist(snapshot_dir)
      if not index_check_valid:
          return False, index_error, []
      shard_pattern = re.compile(r"(.*?)-(\d+)-of-(\d+)\.(safetensors|bin)")
      shard_groups = {}
      for f in weight_files:
          match = shard_pattern.match(os.path.basename(f))
          if match:
              group_key = f"{match.group(1)}-of-{match.group(3)}.{match.group(4)}"
              if group_key not in shard_groups:
                  shard_groups[group_key] = {"prefix": match.group(1), "total": int(match.group(3)), "suffix": match.group(4),
  "found_shards": [], "files": []}
              shard_groups[group_key]["found_shards"].append(int(match.group(2)))
              shard_groups[group_key]["files"].append(f)
      corrupted_files = []
      for group_key, info in shard_groups.items():
          missing = set(range(1, info["total"] + 1)) - set(info["found_shards"])
          if missing:
              return (False, f"Missing shards in {group_key}: {sorted(missing)}", [])
          if info["suffix"] == "safetensors":
              for f in info["files"]:
                  if not _validate_safetensors_file(f):
                      corrupted_files.append(f)
              if not os.path.exists(os.path.join(snapshot_dir, f"{info['prefix']}.safetensors.index.json")):
                  return (False, f"Missing index file", [])
      if corrupted_files:
          return (False, f"Corrupted: {[os.path.basename(f) for f in corrupted_files]}", corrupted_files)
      return True, None, []


  def _validate_weights_after_download(hf_folder: str, allow_patterns: List[str], model_name_or_path: str):
      weight_files = []
      for pattern in allow_patterns:
          weight_files.extend(glob_module.glob(os.path.join(hf_folder, pattern)))
      corrupted = [os.path.basename(f) for f in weight_files if f.endswith(".safetensors") and not
  _validate_safetensors_file(f)]
      if corrupted:
          raise RuntimeError(f"Downloaded model files are corrupted for {model_name_or_path}: {corrupted}")


  def create_corrupted_safetensors(path):
      with open(path, "wb") as f:
          f.write(b"\x00\x00\x00\x00\x00\x00\x00\x00garbage")

  def create_valid_safetensors(path):
      save_file({"dummy": torch.zeros(1)}, path)

  def create_sharded_model_index(snapshot_dir, num_shards):
      weight_map = {f"layer.{i}.weight": f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1)}
      with open(os.path.join(snapshot_dir, "model.safetensors.index.json"), "w") as f:
          json.dump({"weight_map": weight_map}, f)

TESTS

  def test_cache_corrupted_shard():
      with tempfile.TemporaryDirectory() as d:
          create_sharded_model_index(d, 3)
          files = []
          for i in range(1, 4):
              p = os.path.join(d, f"model-{i:05d}-of-00003.safetensors")
              (create_corrupted_safetensors if i == 2 else create_valid_safetensors)(p)
              files.append(p)
          valid, msg, bad = _validate_sharded_model(d, files)
          assert not valid and len(bad) == 1
          print("✓ TEST 1: Cache - corrupted shard detected")

  def test_cache_missing_shard():
      with tempfile.TemporaryDirectory() as d:
          create_sharded_model_index(d, 3)
          files = [os.path.join(d, f"model-{i:05d}-of-00003.safetensors") for i in [1, 3]]
          for f in files: create_valid_safetensors(f)
          valid, msg, _ = _validate_sharded_model(d, files)
          assert not valid and "Missing" in msg
          print("✓ TEST 2: Cache - missing shard detected")

  def test_cache_valid():
      with tempfile.TemporaryDirectory() as d:
          create_sharded_model_index(d, 3)
          files = []
          for i in range(1, 4):
              p = os.path.join(d, f"model-{i:05d}-of-00003.safetensors")
              create_valid_safetensors(p)
              files.append(p)
          valid, _, _ = _validate_sharded_model(d, files)
          assert valid
          print("✓ TEST 3: Cache - valid model passes")

  def test_download_corrupted():
      with tempfile.TemporaryDirectory() as d:
          create_corrupted_safetensors(os.path.join(d, "model.safetensors"))
          try:
              _validate_weights_after_download(d, ["*.safetensors"], "test")
              assert False
          except RuntimeError:
              print("✓ TEST 4: Download - corruption caught")

  def test_download_sharded_corrupted():
      with tempfile.TemporaryDirectory() as d:
          for i in range(1, 4):
              p = os.path.join(d, f"model-{i:05d}-of-00003.safetensors")
              (create_corrupted_safetensors if i == 2 else create_valid_safetensors)(p)
          try:
              _validate_weights_after_download(d, ["*.safetensors"], "test")
              assert False
          except RuntimeError as e:
              assert "model-00002" in str(e)
              print("✓ TEST 5: Download - sharded corruption caught")

  def test_download_valid():
      with tempfile.TemporaryDirectory() as d:
          create_valid_safetensors(os.path.join(d, "model.safetensors"))
          _validate_weights_after_download(d, ["*.safetensors"], "test")
          print("✓ TEST 6: Download - valid passes")

  if __name__ == "__main__":
      test_cache_corrupted_shard()
      test_cache_missing_shard()
      test_cache_valid()
      test_download_corrupted()
      test_download_sharded_corrupted()
      test_download_valid()
      print("\nALL 6 TESTS PASSED!")

Output:

  ✓ TEST 1: Cache - corrupted shard detected
  ✓ TEST 2: Cache - missing shard detected
  ✓ TEST 3: Cache - valid model passes
  ✓ TEST 4: Download - corruption caught
  ✓ TEST 5: Download - sharded corruption caught
  ✓ TEST 6: Download - valid passes

  ALL 6 TESTS PASSED!

The validation correctly catches the exact error seen in CI:
SafetensorError: Error while deserializing header: invalid JSON in header: EOF while parsing a value at line 1 column 0

@Kangyan-Zhou Kangyan-Zhou merged commit b988c18 into main Dec 6, 2025
189 of 199 checks passed
@Kangyan-Zhou Kangyan-Zhou deleted the fix/validate-weights-after-download branch December 6, 2025 00:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants