Skip to content

Commit 7bc5ca3

Browse files
committed
Fix legacy loading logic
1 parent d71effa commit 7bc5ca3

3 files changed

Lines changed: 118 additions & 42 deletions

File tree

cebra/integrations/sklearn/cebra.py

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
import importlib.metadata
2525
import itertools
26+
import pickle
27+
import warnings
2628
from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
2729
Union)
2830

@@ -1336,6 +1338,26 @@ def _get_state(self):
13361338
}
13371339
return state
13381340

1341+
def _get_state_dict(self):
1342+
backend = "sklearn"
1343+
return {
1344+
'args': self.get_params(),
1345+
'state': self._get_state(),
1346+
'state_dict': self.solver_.state_dict(),
1347+
'metadata': {
1348+
'backend':
1349+
backend,
1350+
'cebra_version':
1351+
cebra.__version__,
1352+
'torch_version':
1353+
torch.__version__,
1354+
'numpy_version':
1355+
np.__version__,
1356+
'sklearn_version':
1357+
importlib.metadata.distribution("scikit-learn").version
1358+
}
1359+
}
1360+
13391361
def save(self,
13401362
filename: str,
13411363
backend: Literal["torch", "sklearn"] = "sklearn"):
@@ -1384,28 +1406,16 @@ def save(self,
13841406
"""
13851407
if sklearn_utils.check_fitted(self):
13861408
if backend == "torch":
1409+
warnings.warn(
1410+
"Saving with backend='torch' is deprecated and will be removed in a future version. "
1411+
"Please use backend='sklearn' instead.",
1412+
DeprecationWarning,
1413+
stacklevel=2,
1414+
)
13871415
checkpoint = torch.save(self, filename)
13881416

13891417
elif backend == "sklearn":
1390-
checkpoint = torch.save(
1391-
{
1392-
'args': self.get_params(),
1393-
'state': self._get_state(),
1394-
'state_dict': self.solver_.state_dict(),
1395-
'metadata': {
1396-
'backend':
1397-
backend,
1398-
'cebra_version':
1399-
cebra.__version__,
1400-
'torch_version':
1401-
torch.__version__,
1402-
'numpy_version':
1403-
np.__version__,
1404-
'sklearn_version':
1405-
importlib.metadata.distribution("scikit-learn"
1406-
).version
1407-
}
1408-
}, filename)
1418+
checkpoint = torch.save(self._get_state_dict(), filename)
14091419
else:
14101420
raise NotImplementedError(f"Unsupported backend: {backend}")
14111421
else:
@@ -1457,15 +1467,52 @@ def load(cls,
14571467
>>> tmp_file.unlink()
14581468
"""
14591469
supported_backends = ["auto", "sklearn", "torch"]
1470+
14601471
if backend not in supported_backends:
14611472
raise NotImplementedError(
14621473
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
14631474
)
14641475

1465-
checkpoint = _safe_torch_load(filename, weights_only, **kwargs)
1476+
if backend not in ["auto", "sklearn"]:
1477+
warnings.warn(
1478+
"From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
1479+
"the sklearn backend is now always used. Models saved with the torch backend can still be loaded.",
1480+
category=DeprecationWarning,
1481+
stacklevel=2,
1482+
)
14661483

1467-
if backend == "auto":
1468-
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
1484+
backend = "sklearn"
1485+
1486+
# NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
1487+
# the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
1488+
# introduced in torch 2.6.0.
1489+
try:
1490+
checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs)
1491+
except pickle.UnpicklingError as e:
1492+
if weights_only is not False:
1493+
if packaging.version.parse(
1494+
cebra.__version__) < packaging.version.parse("0.7"):
1495+
warnings.warn(
1496+
"Failed to unpickle checkpoint with weights_only=True. "
1497+
"Falling back to loading with weights_only=False. "
1498+
"This is unsafe and should only be done if you trust the source of the model file. "
1499+
"In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.",
1500+
category=UserWarning,
1501+
stacklevel=2,
1502+
)
1503+
else:
1504+
raise ValueError(
1505+
"Failed to unpickle checkpoint with weights_only=True. "
1506+
"This may be due to an incompatible model file format. "
1507+
"To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
1508+
"Example: CEBRA.load(filename, weights_only=False)."
1509+
) from e
1510+
1511+
checkpoint = _safe_torch_load(filename,
1512+
weights_only=False,
1513+
**kwargs)
1514+
checkpoint = _check_type_checkpoint(checkpoint)
1515+
checkpoint = checkpoint._get_state_dict()
14691516

14701517
if isinstance(checkpoint, dict) and backend == "torch":
14711518
raise RuntimeError(
@@ -1476,10 +1523,10 @@ def load(cls,
14761523
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
14771524
"Please try a different backend.")
14781525

1479-
if backend == "sklearn":
1480-
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
1481-
else:
1482-
cebra_ = _check_type_checkpoint(checkpoint)
1526+
if backend != "sklearn":
1527+
raise ValueError(f"Unsupported backend: {backend}")
1528+
1529+
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
14831530

14841531
n_features = cebra_.n_features_
14851532
cebra_.solver_.n_features = ([

cebra/registry.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from __future__ import annotations
4747

4848
import fnmatch
49+
import functools
4950
import itertools
5051
import sys
5152
import textwrap
@@ -214,14 +215,29 @@ def _zip_dict(d):
214215
yield dict(zip(keys, combination))
215216

216217
def _create_class(cls, **default_kwargs):
218+
class_name = pattern.format(**default_kwargs)
217219

218-
@register(pattern.format(**default_kwargs), base=pattern)
220+
@register(class_name, base=pattern)
219221
class _ParametrizedClass(cls):
220222

221223
def __init__(self, *args, **kwargs):
222224
default_kwargs.update(kwargs)
223225
super().__init__(*args, **default_kwargs)
224226

227+
# Make the class pickleable by copying metadata from the base class
228+
# and registering it in the module namespace
229+
functools.update_wrapper(_ParametrizedClass, cls, updated=[])
230+
231+
# Set a unique qualname so pickle finds this class, not the base class
232+
unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}"
233+
_ParametrizedClass.__qualname__ = unique_name
234+
_ParametrizedClass.__name__ = unique_name
235+
236+
# Register in module namespace so pickle can find it via getattr
237+
parent_module = sys.modules.get(cls.__module__)
238+
if parent_module is not None:
239+
setattr(parent_module, unique_name, _ParametrizedClass)
240+
225241
def _parametrize(cls):
226242
for _default_kwargs in kwargs:
227243
_create_class(cls, **_default_kwargs)

tests/test_sklearn.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,7 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
10531053

10541054
@pytest.mark.parametrize("action", _iterate_actions())
10551055
@pytest.mark.parametrize("backend_save", ["torch", "sklearn"])
1056-
@pytest.mark.parametrize("backend_load", ["auto", "torch", "sklearn"])
1056+
@pytest.mark.parametrize("backend_load", ["sklearn", "auto", "torch"])
10571057
@pytest.mark.parametrize("model_architecture",
10581058
["offset1-model", "parametrized-model-5"])
10591059
@pytest.mark.parametrize("device", ["cpu"] +
@@ -1072,20 +1072,14 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture,
10721072
with pytest.raises(ValueError):
10731073
original_model.save(tempname, backend=backend_save)
10741074
else:
1075-
if "parametrized" in original_model.model_architecture and backend_save == "torch":
1076-
with pytest.raises(AttributeError):
1077-
original_model.save(tempname, backend=backend_save)
1078-
else:
1079-
original_model.save(tempname, backend=backend_save)
1075+
original_model.save(tempname, backend=backend_save)
1076+
1077+
weights_only = None
10801078

1081-
if (backend_load != "auto") and (backend_save != backend_load):
1082-
with pytest.raises(RuntimeError):
1083-
cebra_sklearn_cebra.CEBRA.load(tempname, backend_load)
1084-
else:
1085-
loaded_model = cebra_sklearn_cebra.CEBRA.load(
1086-
tempname, backend_load)
1087-
_assert_equal(original_model, loaded_model)
1088-
action(loaded_model)
1079+
loaded_model = cebra_sklearn_cebra.CEBRA.load(
1080+
tempname, backend_load, weights_only=weights_only)
1081+
_assert_equal(original_model, loaded_model)
1082+
action(loaded_model)
10891083

10901084

10911085
def get_ordered_cuda_devices():
@@ -1489,7 +1483,7 @@ def test_new_transform(model_architecture, device):
14891483
X,
14901484
session_id=0)
14911485
assert np.allclose(embedding1, embedding2, rtol=1e-5,
1492-
atol=1e-8), "Arrays are not close enough"
1486+
atol=1e-8), "Arrays are not close enough"
14931487

14941488
embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
14951489
embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
@@ -1603,3 +1597,22 @@ def test_read_write():
16031597
cebra_model.save(tempname)
16041598
loaded_model = cebra.CEBRA.load(tempname)
16051599
_assert_equal(cebra_model, loaded_model)
1600+
1601+
1602+
def test_repro_pickle_error():
1603+
"""The torch backend for save/loading fails with python 3.14.
1604+
1605+
See https://github.com/AdaptiveMotorControlLab/CEBRA/pull/292.
1606+
1607+
This test is a minimal repro of the error.
1608+
"""
1609+
1610+
model = cebra_sklearn_cebra.CEBRA(model_architecture='parametrized-model-5',
1611+
max_iterations=5,
1612+
batch_size=100,
1613+
device='cpu')
1614+
1615+
model.fit(np.random.randn(1000, 10))
1616+
1617+
with _windows_compatible_tempfile(mode="w+b", delete=True) as tempname:
1618+
model.save(tempname, backend="torch")

0 commit comments

Comments
 (0)