Skip to content

Commit 18306a4

Browse files
committed
Fix sklearn API change for check_estimator
1 parent 5a9190e commit 18306a4

1 file changed

Lines changed: 10 additions & 12 deletions

File tree

tests/_util.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
# limitations under the License.
2121
#
2222
import collections.abc as collections_abc
23-
23+
import packaging.version
24+
import functools
2425
import pytest
2526
import sklearn.utils.estimator_checks
2627
import torch
@@ -63,17 +64,14 @@ def parametrize_slow(arg_names, fast_arguments, slow_arguments):
6364

6465

6566
def parametrize_with_checks_slow(fast_arguments, slow_arguments):
66-
fast_params = [
67-
list(
68-
sklearn.utils.estimator_checks.check_estimator(
69-
fast_arg, generate_only=True))[0] for fast_arg in fast_arguments
70-
]
71-
slow_params = [
72-
list(
73-
sklearn.utils.estimator_checks.check_estimator(
74-
slow_arg, generate_only=True))[0] for slow_arg in slow_arguments
75-
]
76-
return parametrize_slow("estimator,check", fast_params, slow_params)
67+
68+
# NOTE(stes): See https://github.com/AdaptiveMotorControlLab/CEBRA/issues/280, sklearn API changed in 1.6.
69+
if packaging.version.parse(sklearn.__version__) <= packaging.version.parse("1.6"):
70+
generate_checks = functools.partial(sklearn.utils.estimator_checks.check_estimator, generate_only=True)
71+
else:
72+
generate_checks = sklearn.utils.estimator_checks.estimator_checks_generator
73+
generate_params = lambda args: [next(generate_checks(arg)) for arg in args]
74+
return parametrize_slow("estimator,check", generate_params(fast_arguments), generate_params(slow_arguments))
7775

7876

7977
def parametrize_device(func):

0 commit comments

Comments
 (0)