fix: Allow loading CUDA-saved models on CPU-only machines#2
Open
Conversation
@robosimon I need to remove the .gitignore and the "docs" of the issue; these don't belong in the code base or the PR.
Force-pushed from cd21bb5 to b952d02
@robosimon Done! Removed the docs commit (ISSUE_CUDA_LOAD_ON_CPU.md and .gitignore changes). The PR now only contains:
The branch is now clean, with only the necessary code changes.
Add comprehensive tests for the CUDA-to-CPU fallback fix:

- test_load_cuda_checkpoint_falls_back_to_cpu: parametrized with 4 CUDA device variants and 2 model architectures
- test_load_cuda_checkpoint_with_device_override: tests map_location behavior

These tests verify:

- Models saved with device='cuda' can load on CPU-only machines
- Device attributes are correctly updated to 'cpu'
- Model parameters are on CPU
- Model can perform inference (transform) after loading

Expected to FAIL before the fix is applied.
Force-pushed from 6fae503 to efe8b95
When a CEBRA checkpoint was saved on a CUDA device but is loaded on a machine without CUDA available, it now gracefully falls back to CPU instead of crashing with a RuntimeError.

Changes:

- Add _resolve_checkpoint_device() helper to handle device resolution
- Update _load_cebra_with_sklearn_backend() to use the resolved device
- Handle both string and torch.device types, including cuda:0 variants
- Update model device attributes after resolution

Fixes: loading a model saved with device='cuda' on a CPU-only machine
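The device-resolution step described in this commit can be sketched as follows. This is a minimal illustration, not the PR's actual code: `resolve_checkpoint_device` and its `cuda_available` flag are hypothetical stand-ins (the real helper would consult torch.cuda.is_available() directly).

```python
import warnings

def resolve_checkpoint_device(saved_device, cuda_available):
    """Map a checkpoint's saved device onto one usable on this machine.

    `saved_device` may be a string such as "cuda" or "cuda:0", or any
    object (e.g. a torch.device) whose str() form names the device.
    Hypothetical sketch of the CUDA -> CPU fallback described above.
    """
    name = str(saved_device)
    if name.startswith("cuda") and not cuda_available:
        # Fall back gracefully instead of letting the load crash.
        warnings.warn(
            f"Checkpoint was saved on {name!r} but CUDA is not available; "
            "falling back to CPU.",
            UserWarning,
        )
        return "cpu"
    return name
```

With `cuda_available=False`, both `"cuda"` and `"cuda:0"` resolve to `"cpu"`; when CUDA is present, the saved device is kept unchanged.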
Force-pushed from efe8b95 to 97d5b90
…lLab#296

- Refactored _safe_torch_load() to use recursion instead of duplicated logic
- Added meaningful error messages when the CPU fallback fails
- Added a UserWarning when auto-remapping CUDA/MPS to CPU
- Extended _resolve_checkpoint_device() to handle MPS fallback
- Added a test for MPS checkpoint fallback
- Added a test for a meaningful error on retry failure
- Added a test for the error with an explicit map_location
- Created tests/generate_cuda_checkpoint.py utility for GPU test data
- Removed binary checkpoint files from the repo
- Updated .gitignore to exclude test checkpoint binaries

All 53 tests pass (14 CUDA/MPS tests + 39 regression tests)
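The recursive retry-and-report pattern from this refactor can be sketched like so. This is a hedged illustration, not the PR's code: `safe_torch_load` takes the loader as a parameter (a hypothetical stand-in for torch.load) so the control flow can be shown without a torch dependency.

```python
import warnings

def safe_torch_load(loader, path, map_location=None):
    """Load a checkpoint, retrying on CPU if the saved device is unavailable.

    Sketch only: `loader` stands in for torch.load. Recursing with
    map_location="cpu" replaces a duplicated try/except branch, and a
    failed retry raises a meaningful chained error instead of the raw
    deserialization failure.
    """
    try:
        return loader(path, map_location=map_location)
    except RuntimeError as exc:
        # Only auto-remap when the caller did not pin a device themselves.
        if map_location is None and "CUDA" in str(exc):
            warnings.warn(
                f"Could not load {path!r} on its saved device; "
                "remapping to CPU.",
                UserWarning,
            )
            try:
                return safe_torch_load(loader, path, map_location="cpu")
            except RuntimeError as retry_exc:
                raise RuntimeError(
                    f"Loading {path!r} failed even after remapping to CPU."
                ) from retry_exc
        raise
```

An explicit `map_location` passed by the caller is respected: the remap branch is skipped and the original error propagates.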
Force-pushed from 1e6749f to 55f7589
- Add type annotations to _resolve_checkpoint_device: Union[str, torch.device] -> str
- Remove type mentions from the docstring Args/Returns sections
- Use torch.device.type instead of string startswith checks
- Remove obvious comments from _load_cebra_with_sklearn_backend
- Use torch.device for device type checking in the load backend

All 10 related tests pass.
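The switch from string startswith checks to torch.device.type can be illustrated with a small stand-in. `device_type` below is hypothetical and simply mirrors what `torch.device("cuda:0").type` returns, to show why type comparison is the more robust check.

```python
def device_type(device):
    """Return the device type, mirroring torch.device(...).type.

    A device string is "<type>" or "<type>:<index>", so the type is
    everything before the optional ":index" suffix. Comparing types
    with equality ("cuda" == "cuda") is more robust than startswith,
    which would also match unrelated device names.
    """
    return str(device).split(":", 1)[0]
```

Both `device_type("cuda")` and `device_type("cuda:0")` normalize to `"cuda"`, so one equality check covers every index variant.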
Fix #1
This PR adds automatic CPU fallback when loading CUDA-saved checkpoints on CPU-only machines.
Changes:

- Add _resolve_checkpoint_device() helper to detect the CUDA → CPU fallback scenario
- Update _load_cebra_with_sklearn_backend() to use the resolved device for all .to() calls
- Handle both str and torch.device types (including "cuda:0" variants)
- Update the cebra_.device_ and cebra_.device attributes after resolution

Test coverage:

- Device variants: "cuda", "cuda:0", torch.device("cuda"), torch.device("cuda", 0)
- Model architectures: "offset1-model", "parametrized-model-5"

Verification: