Skip to content

fix: Allow loading CUDA-saved models on CPU-only machines#2

Open
robosimon wants to merge 4 commits intomainfrom
fix/cuda-load-on-cpu
Open

fix: Allow loading CUDA-saved models on CPU-only machines#2
robosimon wants to merge 4 commits intomainfrom
fix/cuda-load-on-cpu

Conversation

@robosimon
Copy link
Copy Markdown
Owner

@robosimon robosimon commented Feb 11, 2026

Fix #1

This PR adds automatic CPU fallback when loading CUDA-saved checkpoints on CPU-only machines.

Changes:

  • Add _resolve_checkpoint_device() helper to detect CUDA → CPU fallback scenario
  • Update _load_cebra_with_sklearn_backend() to use resolved device for all .to() calls
  • Handle both str and torch.device types (including "cuda:0" variants)
  • Update cebra_.device_ and cebra_.device attributes after resolution
  • Add comprehensive tests with monkeypatched CPU-only environment (10 test cases)

Test coverage:

  • 4 CUDA device variants: "cuda", "cuda:0", torch.device("cuda"), torch.device("cuda", 0)
  • 2 model architectures: "offset1-model", "parametrized-model-5"
  • All save/load tests pass: 78 passed, 1 skipped

Verification:

  • Tests without fix: 8 FAILED (confirm bug exists)
  • Tests with fix: 10 PASSED (all green)

@robosimon
Copy link
Copy Markdown
Owner Author

@robosimon I need to remove the .gitignore and the "docs" of the issue, this doesn't belong into the code base or PR

@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch from cd21bb5 to b952d02 Compare February 11, 2026 15:25
@robosimon
Copy link
Copy Markdown
Owner Author

@robosimon Done! Removed the docs commit (ISSUE_CUDA_LOAD_ON_CPU.md and .gitignore changes). The PR now only contains:

  • Commit 1: Test - Add tests for loading CUDA-saved models on CPU-only machines
  • Commit 2: Fix - Allow loading CUDA-saved models on CPU-only machines

The branch is now clean with only the necessary code changes.

Add comprehensive tests for the CUDA-to-CPU fallback fix:
- test_load_cuda_checkpoint_falls_back_to_cpu: parametrized with 4 CUDA device variants and 2 model architectures
- test_load_cuda_checkpoint_with_device_override: tests map_location behavior

These tests verify:
- Models saved with device='cuda' can load on CPU-only machines
- Device attributes are correctly updated to 'cpu'
- Model parameters are on CPU
- Model can perform inference (transform) after loading

Expected to FAIL before the fix is applied.
@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch 2 times, most recently from 6fae503 to efe8b95 Compare February 11, 2026 16:07
When a CEBRA checkpoint was saved on a CUDA device but is loaded on a
machine without CUDA available, it now gracefully falls back to CPU
instead of crashing with RuntimeError.

Changes:
- Add _resolve_checkpoint_device() helper to handle device resolution
- Update _load_cebra_with_sklearn_backend() to use resolved device
- Handle both string and torch.device types, including cuda:0 variants
- Update model device attributes after resolution

Fixes: Loading model saved with device='cuda' on CPU-only machine
@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch from efe8b95 to 97d5b90 Compare February 11, 2026 16:12
…lLab#296

- Refactored _safe_torch_load() to use recursion instead of duplicate logic
- Added meaningful error messages when CPU fallback fails
- Added UserWarning when auto-remapping CUDA/MPS to CPU
- Extended _resolve_checkpoint_device() to handle MPS fallback
- Added test for MPS checkpoint fallback
- Added test for meaningful error on retry failure
- Added test for error with explicit map_location
- Created tests/generate_cuda_checkpoint.py utility for GPU test data
- Removed binary checkpoint files from repo
- Updated .gitignore to exclude test checkpoint binaries

All 53 tests pass (14 CUDA/MPS tests + 39 regression tests)
@robosimon robosimon force-pushed the fix/cuda-load-on-cpu branch from 1e6749f to 55f7589 Compare February 11, 2026 19:30
- Add type annotations to _resolve_checkpoint_device: Union[str, torch.device] -> str
- Remove type mentions from docstring Args/Returns
- Use torch.device.type instead of string startswith checks
- Remove obvious comments from _load_cebra_with_sklearn_backend
- Use torch.device for device type checking in load backend

All 10 related tests pass.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Loading a model saved on CUDA fails on CPU-only machines

1 participant