Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,42 @@ jobs:
fail-fast: true
matrix:
os: [ubuntu-latest]
python-version: ["3.9", "3.10", "3.12"]
# NOTE(stes): We test against the oldest still supported version of python (3.10),
# the python version of the latest Ubuntu LTS release (3.12)
# and the latest python version CEBRA supports fully (3.13).
#
# Python version chart:
# https://devguide.python.org/versions/
# Python Ubuntu LTS:
# https://documentation.ubuntu.com/ubuntu-for-developers/reference/availability/python/
python-version: ["3.10", "3.12", "3.13"]
# We aim to support the versions on pytorch.org
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.4.0", "2.6.0"]
torch-version: ["2.6.0", "2.9.1"]
sklearn-version: ["latest"]
include:
# windows test with standard config
- os: windows-latest
torch-version: 2.4.0
python-version: "3.10"
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "latest"

# legacy sklearn (several API differences)
- os: ubuntu-latest
torch-version: 2.6.0
python-version: "3.12"
sklearn-version: "legacy"

# TODO(stes): latest torch and python
# requires a PyTables release compatible with
# python 3.14. Update when new version is released.
#- os: ubuntu-latest
# torch-version: 2.9.1
# python-version: "3.14"
# sklearn-version: "latest"

# legacy support
- os: ubuntu-latest
torch-version: 2.4.0
python-version: "3.10"
Expand Down Expand Up @@ -82,7 +107,7 @@ jobs:
make check_for_binary

- name: Run pytest tests
timeout-minutes: 10
timeout-minutes: 15
run: |
make test

Expand Down
48 changes: 48 additions & 0 deletions .github/workflows/latest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
name: Python package (latest)

# This is a variant of the workflow in build.yml to check incompatibilities for the
# very latest versions of Python and the core dependencies.

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  build:
    timeout-minutes: 30
    runs-on: ubuntu-latest

    steps:
      - name: Cache dependencies
        id: pip-cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: pip-os_latest

      - name: Checkout code
        # NOTE: checkout@v2 runs on a deprecated Node runtime that GitHub is
        # removing from hosted runners; v4 matches cache@v4/setup-python@v5 above.
        uses: actions/checkout@v4

      - name: Set up Python 3.14
        uses: actions/setup-python@v5
        with:
          python-version: "3.14"

      - name: Install package
        run: |
          python -m pip install --upgrade pip setuptools wheel --pre
          python -m pip install --pre torch --extra-index-url https://download.pytorch.org/whl/cpu
          pip install --pre '.[dev]'

      - name: Run pytest tests
        timeout-minutes: 15
        run: |
          make test_reduced

      - name: Build package
        run: |
          make build
15 changes: 15 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ test_all: clean_test
test_fast: clean_test
python -m pytest --ff --ignore cebra/grid_search.py -m "not requires_dataset" tests cebra --runfast

# Reduced variant of the test suite. These tests should pass even if only the core package without any
# additional dataset/integration dependencies are installed.
test_reduced: clean_test
python -m pytest --ff --ignore cebra/grid_search.py \
--ignore tests/test_attribution.py \
--ignore tests/test_dlc.py \
--ignore tests/test_grid_search.py \
--ignore tests/test_integration_xcebra.py \
--ignore tests/test_load.py \
--ignore tests/test_plotly.py \
--ignore cebra/attribution \
--ignore cebra/integrations \
-m "not requires_dataset" tests cebra --runfast


# Run failed test firsts, using a single worker (for debugging)
test_debug: clean_test
python -m pytest -vvv -x --ff -m "not requires_dataset" tests
Expand Down
5 changes: 5 additions & 0 deletions cebra/integrations/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def __init__(self, axis: Optional[matplotlib.axes.Axes], figsize: tuple,
if axis is None:
self.fig = plt.figure(figsize=figsize, dpi=dpi)

def __del__(self):
    """Close the figure owned by this plot object on garbage collection.

    Closing via ``plt.close`` prevents matplotlib's global figure registry
    from growing unboundedly when many plot objects are created (figures
    created through ``plt.figure`` stay registered until explicitly closed).

    The body is wrapped in a broad ``try``/``except`` because ``__del__``
    can run during interpreter shutdown, when module globals such as
    ``plt`` may already have been torn down; a finalizer must never
    propagate exceptions (they would only be printed as
    "Exception ignored in __del__" noise).
    """
    try:
        # Only close a figure this instance actually created and still owns;
        # ``fig`` is set in ``__init__`` only when no external axis was passed.
        if hasattr(self, "fig") and self.fig is not None:
            plt.close(self.fig)
            self.fig = None
    except Exception:
        # Best effort: failing to close during shutdown is harmless.
        pass

@abc.abstractmethod
def _define_ax(
self, axis: Optional[matplotlib.axes.Axes]) -> matplotlib.axes.Axes:
Expand Down
9 changes: 6 additions & 3 deletions cebra/integrations/sklearn/cebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
#
"""Define the CEBRA model."""

import importlib.metadata
import itertools
from typing import (Callable, Dict, Iterable, List, Literal, Optional, Tuple,
Union)

import numpy as np
import numpy.typing as npt
import packaging.version
import importlib.metadata
import sklearn
import sklearn.utils.validation as sklearn_utils_validation
import torch
Expand All @@ -46,9 +46,12 @@

# NOTE(stes): From torch 2.6 onwards, we need to specify the following list
# when loading CEBRA models to allow weights_only = True.
# NOTE(stes): "numpy.dtypes.Int32DType" was added due to this issue with
# windows (https://github.com/AdaptiveMotorControlLab/CEBRA/pull/281#issuecomment-3764185072)
# on build (windows-latest, torch 2.6.0, python 3.12, latest sklearn)
CEBRA_LOAD_SAFE_GLOBALS = [
cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
np.dtypes.Float64DType, np.dtypes.Int64DType
np.dtypes.Int32DType, np.dtypes.Float64DType, np.dtypes.Int64DType
]


Expand Down Expand Up @@ -1398,7 +1401,7 @@ def save(self,
np.__version__,
'sklearn_version':
importlib.metadata.distribution("scikit-learn"
).version
).version
}
}, filename)
else:
Expand Down
8 changes: 4 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,10 @@ packages = find:
where =
- .
- tests
python_requires = >=3.9
python_requires = >=3.10
install_requires =
joblib
numpy<2.0;platform_system=="Windows"
numpy<2.0;platform_system!="Windows" and python_version<"3.10"
numpy;platform_system!="Windows" and python_version>="3.10"
literate-dataclasses
scikit-learn
Expand All @@ -57,6 +56,7 @@ datasets =
# additional data loading dependencies
hdf5storage # for creating .mat files in new format
openpyxl # for excel file format loading
tables # for hdf5 file format loading
integrations =
pandas
plotly
Expand Down Expand Up @@ -104,8 +104,8 @@ dev =
pytest-benchmark
pytest-xdist
pytest-timeout
pytest-sphinx
tables
# NOTE(stes): https://github.com/twmr/pytest-sphinx/issues/69
pytest-sphinx @ https://github.com/twmr/pytest-sphinx/archive/refs/tags/v0.7.0.tar.gz
licenseheaders
interrogate
# TODO(stes) Add back once upstream issue
Expand Down
27 changes: 16 additions & 11 deletions tests/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
# limitations under the License.
#
import collections.abc as collections_abc
import functools

import packaging.version
import pytest
import sklearn.utils.estimator_checks
import torch
Expand Down Expand Up @@ -63,17 +65,20 @@ def parametrize_slow(arg_names, fast_arguments, slow_arguments):


def parametrize_with_checks_slow(fast_arguments, slow_arguments):
fast_params = [
list(
sklearn.utils.estimator_checks.check_estimator(
fast_arg, generate_only=True))[0] for fast_arg in fast_arguments
]
slow_params = [
list(
sklearn.utils.estimator_checks.check_estimator(
slow_arg, generate_only=True))[0] for slow_arg in slow_arguments
]
return parametrize_slow("estimator,check", fast_params, slow_params)

# NOTE(stes): See https://github.com/AdaptiveMotorControlLab/CEBRA/issues/280, sklearn API changed in 1.6.
if packaging.version.parse(
sklearn.__version__) <= packaging.version.parse("1.6"):
generate_checks = functools.partial(
sklearn.utils.estimator_checks.check_estimator, generate_only=True)
else:
generate_checks = sklearn.utils.estimator_checks.estimator_checks_generator

def _generate_params(args):
return [next(generate_checks(arg)) for arg in args]

return parametrize_slow("estimator,check", _generate_params(fast_arguments),
_generate_params(slow_arguments))


def parametrize_device(func):
Expand Down
6 changes: 0 additions & 6 deletions tests/_utils_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ def cebra_transform_deprecated(cebra_model,
>>> embedding = cebra_model.transform(dataset)

"""
warnings.warn(
"The method is deprecated "
"but kept for testing puroposes."
"We recommend using `transform` instead.",
DeprecationWarning,
stacklevel=2)

sklearn_utils_validation.check_is_fitted(cebra_model, "n_features_")

Expand Down
Loading