diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py
index 00645523..d57159d3 100644
--- a/cebra/integrations/sklearn/cebra.py
+++ b/cebra/integrations/sklearn/cebra.py
@@ -1491,6 +1491,10 @@ def load(cls,
 
         backend = "sklearn"
 
+        if (not torch.cuda.is_available() and
+                "map_location" not in kwargs):
+            kwargs["map_location"] = "cpu"
+
         # NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
         # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
         # introduced in torch 2.6.0.
diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py
index 831ad49d..740db5f5 100644
--- a/tests/test_sklearn.py
+++ b/tests/test_sklearn.py
@@ -1097,6 +1097,103 @@ def get_ordered_cuda_devices():
 ) else []
 
 
+@pytest.mark.parametrize("saved_device", [
+    "cuda",
+    "cuda:0",
+    torch.device("cuda"),
+    torch.device("cuda", 0),
+])
+@pytest.mark.parametrize("model_architecture",
+                         ["offset1-model", "parametrized-model-5"])
+def test_load_cuda_checkpoint_falls_back_to_cpu(saved_device,
+                                                model_architecture,
+                                                monkeypatch):
+    X = np.random.uniform(0, 1, (100, 5))
+
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture=model_architecture, max_iterations=5,
+        device="cpu").fit(X)
+
+    with _windows_compatible_tempfile(mode="w+b") as tempname:
+        cebra_model.save(tempname)
+
+        checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname)
+        checkpoint["state"]["device_"] = saved_device
+        torch.save(checkpoint, tempname)
+
+        monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
+
+        loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
+
+        assert loaded_model.device_ == "cpu"
+        assert loaded_model.device == "cpu"
+        assert next(loaded_model.solver_.model.parameters()).device == torch.device(
+            "cpu")
+
+        X_test = np.random.uniform(0, 1, (10, 5))
+        embedding = loaded_model.transform(X_test)
+        assert embedding.shape[0] == 10
+        assert embedding.shape[1] > 0
+        assert isinstance(embedding, np.ndarray)
+
+
+def test_safe_torch_load_cuda_fallback(monkeypatch):
+    import os
+    import tempfile
+
+    checkpoint = {"test": torch.tensor([1.0, 2.0, 3.0])}
+
+    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
+        tempname = f.name
+    torch.save(checkpoint, tempname)
+
+    try:
+        original_torch_load = torch.load
+        call_count = [0]
+
+        def mock_torch_load(*args, **kwargs):
+            call_count[0] += 1
+            if call_count[0] == 1 and "map_location" not in kwargs:
+                raise RuntimeError(
+                    "Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False"
+                )
+            return original_torch_load(*args, **kwargs)
+
+        monkeypatch.setattr(torch, "load", mock_torch_load)
+
+        result = cebra_sklearn_cebra._safe_torch_load(tempname)
+        assert "test" in result
+        assert torch.allclose(result["test"], checkpoint["test"])
+        assert call_count[0] == 2
+
+    finally:
+        os.unlink(tempname)
+
+
+@pytest.mark.parametrize("saved_device", ["cuda", "cuda:0"])
+def test_load_cuda_checkpoint_with_device_override(saved_device, monkeypatch):
+    X = np.random.uniform(0, 1, (100, 5))
+
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset1-model", max_iterations=5,
+        device="cpu").fit(X)
+
+    with _windows_compatible_tempfile(mode="w+b") as tempname:
+        cebra_model.save(tempname)
+        checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname)
+        checkpoint["state"]["device_"] = saved_device
+        torch.save(checkpoint, tempname)
+
+        monkeypatch.setattr(torch.cuda, "is_available", lambda: False)
+
+        loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
+
+        X_test = np.random.uniform(0, 1, (10, 5))
+        embedding = loaded_model.transform(X_test)
+        assert embedding.shape[0] == 10
+        assert embedding.shape[1] > 0
+
+
 def test_fit_after_moving_to_device():
     expected_device = 'cpu'
     expected_type = type(expected_device)