Skip to content

Commit 017b7fd

Browse files
CopilotMMathisLab
andcommitted
Improve robustness of sklearn API handling
- Extract repeated logic into helper function - Add error handling for sklearn API changes - Avoid generator exhaustion by converting to list first - Filter out None values from check lists Co-authored-by: MMathisLab <28102185+MMathisLab@users.noreply.github.com>
1 parent 1f307d1 commit 017b7fd

1 file changed

Lines changed: 20 additions & 13 deletions

File tree

tests/_util.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,21 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=T
8080
check_estimator_sig = inspect.signature(sklearn.utils.estimator_checks.check_estimator)
8181
supports_generate_only = 'generate_only' in check_estimator_sig.parameters
8282

83+
def _get_first_check_for_estimator(estimator):
84+
"""Helper to get the first check for a given estimator in new sklearn API."""
85+
try:
86+
decorator = sklearn.utils.estimator_checks.parametrize_with_checks([estimator])
87+
# Extract the generator from the decorator
88+
gen = decorator.mark.args[1]
89+
# Convert to list and take first element to avoid generator exhaustion issues
90+
checks_list = list(gen)
91+
return checks_list[0] if checks_list else None
92+
except (AttributeError, IndexError, TypeError) as e:
93+
raise RuntimeError(
94+
f"Failed to extract checks from sklearn.utils.estimator_checks.parametrize_with_checks. "
95+
f"This may be due to sklearn API changes. Error: {e}"
96+
)
97+
8398
if supports_generate_only:
8499
# Old sklearn API (<= 1.4.x): use check_estimator with generate_only=True
85100
fast_params = [
@@ -95,19 +110,11 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=T
95110
else:
96111
# New sklearn API (>= 1.5): use parametrize_with_checks to get test params
97112
# For each estimator, get the first check
98-
fast_params = []
99-
for fast_arg in fast_arguments:
100-
decorator = sklearn.utils.estimator_checks.parametrize_with_checks([fast_arg])
101-
# Extract the generator from the decorator and get first item
102-
gen = decorator.mark.args[1]
103-
fast_params.append(next(gen))
104-
105-
slow_params = []
106-
for slow_arg in slow_arguments:
107-
decorator = sklearn.utils.estimator_checks.parametrize_with_checks([slow_arg])
108-
# Extract the generator from the decorator and get first item
109-
gen = decorator.mark.args[1]
110-
slow_params.append(next(gen))
113+
fast_params = [_get_first_check_for_estimator(fast_arg) for fast_arg in fast_arguments]
114+
slow_params = [_get_first_check_for_estimator(slow_arg) for slow_arg in slow_arguments]
115+
# Filter out any None values
116+
fast_params = [p for p in fast_params if p is not None]
117+
slow_params = [p for p in slow_params if p is not None]
111118

112119
return parametrize_slow("estimator,check", fast_params, slow_params)
113120

0 commit comments

Comments
 (0)