2323
2424import importlib .metadata
2525import itertools
26+ import pickle
27+ import warnings
2628from typing import (Callable , Dict , Iterable , List , Literal , Optional , Tuple ,
2729 Union )
2830
@@ -1336,6 +1338,26 @@ def _get_state(self):
13361338 }
13371339 return state
13381340
1341+ def _get_state_dict (self ):
1342+ backend = "sklearn"
1343+ return {
1344+ 'args' : self .get_params (),
1345+ 'state' : self ._get_state (),
1346+ 'state_dict' : self .solver_ .state_dict (),
1347+ 'metadata' : {
1348+ 'backend' :
1349+ backend ,
1350+ 'cebra_version' :
1351+ cebra .__version__ ,
1352+ 'torch_version' :
1353+ torch .__version__ ,
1354+ 'numpy_version' :
1355+ np .__version__ ,
1356+ 'sklearn_version' :
1357+ importlib .metadata .distribution ("scikit-learn" ).version
1358+ }
1359+ }
1360+
13391361 def save (self ,
13401362 filename : str ,
13411363 backend : Literal ["torch" , "sklearn" ] = "sklearn" ):
@@ -1384,28 +1406,16 @@ def save(self,
13841406 """
13851407 if sklearn_utils .check_fitted (self ):
13861408 if backend == "torch" :
1409+ warnings .warn (
1410+ "Saving with backend='torch' is deprecated and will be removed in a future version. "
1411+ "Please use backend='sklearn' instead." ,
1412+ DeprecationWarning ,
1413+ stacklevel = 2 ,
1414+ )
13871415 checkpoint = torch .save (self , filename )
13881416
13891417 elif backend == "sklearn" :
1390- checkpoint = torch .save (
1391- {
1392- 'args' : self .get_params (),
1393- 'state' : self ._get_state (),
1394- 'state_dict' : self .solver_ .state_dict (),
1395- 'metadata' : {
1396- 'backend' :
1397- backend ,
1398- 'cebra_version' :
1399- cebra .__version__ ,
1400- 'torch_version' :
1401- torch .__version__ ,
1402- 'numpy_version' :
1403- np .__version__ ,
1404- 'sklearn_version' :
1405- importlib .metadata .distribution ("scikit-learn"
1406- ).version
1407- }
1408- }, filename )
1418+ checkpoint = torch .save (self ._get_state_dict (), filename )
14091419 else :
14101420 raise NotImplementedError (f"Unsupported backend: { backend } " )
14111421 else :
@@ -1457,15 +1467,52 @@ def load(cls,
14571467 >>> tmp_file.unlink()
14581468 """
14591469 supported_backends = ["auto" , "sklearn" , "torch" ]
1470+
14601471 if backend not in supported_backends :
14611472 raise NotImplementedError (
14621473 f"Unsupported backend: '{ backend } '. Supported backends are: { ', ' .join (supported_backends )} "
14631474 )
14641475
1465- checkpoint = _safe_torch_load (filename , weights_only , ** kwargs )
1476+ if backend not in ["auto" , "sklearn" ]:
1477+ warnings .warn (
1478+ "From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
1479+ "the sklearn backend is now always used. Models saved with the torch backend can still be loaded." ,
1480+ category = DeprecationWarning ,
1481+ stacklevel = 2 ,
1482+ )
14661483
1467- if backend == "auto" :
1468- backend = "sklearn" if isinstance (checkpoint , dict ) else "torch"
1484+ backend = "sklearn"
1485+
1486+ # NOTE(stes): For maximum backwards compatibility, we allow to load legacy checkpoints. From 0.7.0 onwards,
1487+ # the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
1488+ # introduced in torch 2.6.0.
1489+ try :
1490+ checkpoint = _safe_torch_load (filename , weights_only = True , ** kwargs )
1491+ except pickle .UnpicklingError as e :
1492+ if weights_only is not False :
1493+ if packaging .version .parse (
1494+ cebra .__version__ ) < packaging .version .parse ("0.7" ):
1495+ warnings .warn (
1496+ "Failed to unpickle checkpoint with weights_only=True. "
1497+ "Falling back to loading with weights_only=False. "
1498+ "This is unsafe and should only be done if you trust the source of the model file. "
1499+ "In the future, loading these checkpoints will only work if weights_only=False is explicitly passed." ,
1500+ category = UserWarning ,
1501+ stacklevel = 2 ,
1502+ )
1503+ else :
1504+ raise ValueError (
1505+ "Failed to unpickle checkpoint with weights_only=True. "
1506+ "This may be due to an incompatible model file format. "
1507+ "To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
1508+ "Example: CEBRA.load(filename, weights_only=False)."
1509+ ) from e
1510+
1511+ checkpoint = _safe_torch_load (filename ,
1512+ weights_only = False ,
1513+ ** kwargs )
1514+ checkpoint = _check_type_checkpoint (checkpoint )
1515+ checkpoint = checkpoint ._get_state_dict ()
14691516
14701517 if isinstance (checkpoint , dict ) and backend == "torch" :
14711518 raise RuntimeError (
@@ -1476,10 +1523,10 @@ def load(cls,
14761523 "Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
14771524 "Please try a different backend." )
14781525
1479- if backend = = "sklearn" :
1480- cebra_ = _load_cebra_with_sklearn_backend ( checkpoint )
1481- else :
1482- cebra_ = _check_type_checkpoint (checkpoint )
1526+ if backend ! = "sklearn" :
1527+ raise ValueError ( f"Unsupported backend: { backend } " )
1528+
1529+ cebra_ = _load_cebra_with_sklearn_backend (checkpoint )
14831530
14841531 n_features = cebra_ .n_features_
14851532 cebra_ .solver_ .n_features = ([
0 commit comments