Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,28 @@ jobs:
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.6.0", "2.10.0"]
sklearn-version: ["latest"]
numpy-version: ["latest"]

include:
# windows test with standard config
- os: windows-latest
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "latest"
numpy-version: "latest"

# legacy sklearn (several API differences)
- os: ubuntu-latest
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "legacy"
numpy-version: "latest"

- os: ubuntu-latest
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "latest"
numpy-version: "legacy"

# TODO(stes): latest torch and python
# requires a PyTables release compatible with
Expand All @@ -55,6 +65,7 @@ jobs:
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "legacy"
numpy-version: "latest"

runs-on: ${{ matrix.os }}

Expand Down Expand Up @@ -88,6 +99,11 @@ jobs:
run: |
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'

- name: Check numpy legacy version
if: matrix.numpy-version == 'legacy'
run: |
pip install "numpy<2" '.[dev,datasets,integrations]'

- name: Run the formatter
run: |
make format
Expand Down
99 changes: 73 additions & 26 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import importlib.metadata
import itertools
import pickle
import warnings
from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
Union)

Expand Down Expand Up @@ -1336,6 +1338,26 @@ def _get_state(self):
}
return state

def _get_state_dict(self):
    """Build the dictionary that represents a fitted model checkpoint.

    The returned dictionary is what the ``sklearn`` save backend writes to
    disk via ``torch.save``. It bundles the estimator hyperparameters,
    the fitted sklearn-level state, the solver weights, and version
    metadata used to validate compatibility when the model is re-loaded.

    Returns:
        A dictionary with keys ``args`` (estimator params from
        :py:meth:`get_params`), ``state`` (fitted attributes from
        :py:meth:`_get_state`), ``state_dict`` (the solver's torch state
        dict), and ``metadata`` (backend name plus cebra/torch/numpy/
        scikit-learn version strings).
    """
    # Version info is recorded so that a later `load` can detect
    # environment mismatches when restoring the checkpoint.
    metadata = {
        'backend': "sklearn",
        'cebra_version': cebra.__version__,
        'torch_version': torch.__version__,
        'numpy_version': np.__version__,
        'sklearn_version': importlib.metadata.distribution(
            "scikit-learn").version,
    }
    return {
        'args': self.get_params(),
        'state': self._get_state(),
        'state_dict': self.solver_.state_dict(),
        'metadata': metadata,
    }

def save(self,
filename: str,
backend: Literal["torch", "sklearn"] = "sklearn"):
Expand Down Expand Up @@ -1384,28 +1406,16 @@ def save(self,
"""
if sklearn_utils.check_fitted(self):
if backend == "torch":
warnings.warn(
"Saving with backend='torch' is deprecated and will be removed in a future version. "
"Please use backend='sklearn' instead.",
DeprecationWarning,
stacklevel=2,
)
checkpoint = torch.save(self, filename)

elif backend == "sklearn":
checkpoint = torch.save(
{
'args': self.get_params(),
'state': self._get_state(),
'state_dict': self.solver_.state_dict(),
'metadata': {
'backend':
backend,
'cebra_version':
cebra.__version__,
'torch_version':
torch.__version__,
'numpy_version':
np.__version__,
'sklearn_version':
importlib.metadata.distribution("scikit-learn"
).version
}
}, filename)
checkpoint = torch.save(self._get_state_dict(), filename)
else:
raise NotImplementedError(f"Unsupported backend: {backend}")
else:
Expand Down Expand Up @@ -1457,15 +1467,52 @@ def load(cls,
>>> tmp_file.unlink()
"""
supported_backends = ["auto", "sklearn", "torch"]

if backend not in supported_backends:
raise NotImplementedError(
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
)

checkpoint = _safe_torch_load(filename, weights_only, **kwargs)
if backend not in ["auto", "sklearn"]:
warnings.warn(
"From CEBRA version 0.6.1 onwards, the 'backend' parameter in cebra.CEBRA.load is deprecated and will be ignored; "
"the sklearn backend is now always used. Models saved with the torch backend can still be loaded.",
category=DeprecationWarning,
stacklevel=2,
)

if backend == "auto":
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
backend = "sklearn"

# NOTE(stes): For maximum backwards compatibility, we allow loading legacy checkpoints. From 0.7.0 onwards,
# the user will have to explicitly pass weights_only=False to load these checkpoints, following the changes
# introduced in torch 2.6.0.
try:
checkpoint = _safe_torch_load(filename, weights_only=True, **kwargs)
except pickle.UnpicklingError as e:
if weights_only is not False:
if packaging.version.parse(
cebra.__version__) < packaging.version.parse("0.7"):
warnings.warn(
"Failed to unpickle checkpoint with weights_only=True. "
"Falling back to loading with weights_only=False. "
"This is unsafe and should only be done if you trust the source of the model file. "
"In the future, loading these checkpoints will only work if weights_only=False is explicitly passed.",
category=UserWarning,
stacklevel=2,
)
else:
raise ValueError(
"Failed to unpickle checkpoint with weights_only=True. "
"This may be due to an incompatible model file format. "
"To attempt loading this checkpoint, please pass weights_only=False to CEBRA.load. "
"Example: CEBRA.load(filename, weights_only=False)."
) from e

checkpoint = _safe_torch_load(filename,
weights_only=False,
**kwargs)
checkpoint = _check_type_checkpoint(checkpoint)
checkpoint = checkpoint._get_state_dict()

if isinstance(checkpoint, dict) and backend == "torch":
raise RuntimeError(
Expand All @@ -1476,10 +1523,10 @@ def load(cls,
"Cannot use 'sklearn' backend a non dictionary-based checkpoint. "
"Please try a different backend.")

if backend == "sklearn":
cebra_ = _load_cebra_with_sklearn_backend(checkpoint)
else:
cebra_ = _check_type_checkpoint(checkpoint)
if backend != "sklearn":
raise ValueError(f"Unsupported backend: {backend}")

cebra_ = _load_cebra_with_sklearn_backend(checkpoint)

n_features = cebra_.n_features_
cebra_.solver_.n_features = ([
Expand Down
18 changes: 17 additions & 1 deletion cebra/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from __future__ import annotations

import fnmatch
import functools
import itertools
import sys
import textwrap
Expand Down Expand Up @@ -214,14 +215,29 @@ def _zip_dict(d):
yield dict(zip(keys, combination))

def _create_class(cls, **default_kwargs):
class_name = pattern.format(**default_kwargs)

@register(pattern.format(**default_kwargs), base=pattern)
@register(class_name, base=pattern)
class _ParametrizedClass(cls):

def __init__(self, *args, **kwargs):
default_kwargs.update(kwargs)
super().__init__(*args, **default_kwargs)

# Make the class pickleable by copying metadata from the base class
# and registering it in the module namespace
functools.update_wrapper(_ParametrizedClass, cls, updated=[])

# Set a unique qualname so pickle finds this class, not the base class
unique_name = f"{cls.__qualname__}_{class_name.replace('-', '_')}"
_ParametrizedClass.__qualname__ = unique_name
_ParametrizedClass.__name__ = unique_name

# Register in module namespace so pickle can find it via getattr
parent_module = sys.modules.get(cls.__module__)
if parent_module is not None:
setattr(parent_module, unique_name, _ParametrizedClass)

def _parametrize(cls):
for _default_kwargs in kwargs:
_create_class(cls, **_default_kwargs)
Expand Down
Loading
Loading