@@ -80,6 +80,21 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=T
8080 check_estimator_sig = inspect .signature (sklearn .utils .estimator_checks .check_estimator )
8181 supports_generate_only = 'generate_only' in check_estimator_sig .parameters
8282
83+ def _get_first_check_for_estimator (estimator ):
84+ """Helper to get the first check for a given estimator in new sklearn API."""
85+ try :
86+ decorator = sklearn .utils .estimator_checks .parametrize_with_checks ([estimator ])
87+ # Extract the generator from the decorator
88+ gen = decorator .mark .args [1 ]
89+ # Convert to list and take first element to avoid generator exhaustion issues
90+ checks_list = list (gen )
91+ return checks_list [0 ] if checks_list else None
92+ except (AttributeError , IndexError , TypeError ) as e :
93+ raise RuntimeError (
94+ f"Failed to extract checks from sklearn.utils.estimator_checks.parametrize_with_checks. "
95+ f"This may be due to sklearn API changes. Error: { e } "
96+ )
97+
8398 if supports_generate_only :
8499 # Old sklearn API (<= 1.4.x): use check_estimator with generate_only=True
85100 fast_params = [
@@ -95,19 +110,11 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=T
95110 else :
96111 # New sklearn API (>= 1.5): use parametrize_with_checks to get test params
97112 # For each estimator, get the first check
98- fast_params = []
99- for fast_arg in fast_arguments :
100- decorator = sklearn .utils .estimator_checks .parametrize_with_checks ([fast_arg ])
101- # Extract the generator from the decorator and get first item
102- gen = decorator .mark .args [1 ]
103- fast_params .append (next (gen ))
104-
105- slow_params = []
106- for slow_arg in slow_arguments :
107- decorator = sklearn .utils .estimator_checks .parametrize_with_checks ([slow_arg ])
108- # Extract the generator from the decorator and get first item
109- gen = decorator .mark .args [1 ]
110- slow_params .append (next (gen ))
113+ fast_params = [_get_first_check_for_estimator (fast_arg ) for fast_arg in fast_arguments ]
114+ slow_params = [_get_first_check_for_estimator (slow_arg ) for slow_arg in slow_arguments ]
115+ # Filter out any None values
116+ fast_params = [p for p in fast_params if p is not None ]
117+ slow_params = [p for p in slow_params if p is not None ]
111118
112119 return parametrize_slow ("estimator,check" , fast_params , slow_params )
113120
0 commit comments