Skip to content

Commit 43dde5d

Browse files
authored
TST add set_output test in pipeline (#960)
1 parent 54a7b5b commit 43dde5d

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

imblearn/tests/test_pipeline.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import numpy as np
1515
import pytest
16+
import sklearn
1617
from joblib import Memory
1718
from pytest import raises
1819
from sklearn.base import BaseEstimator, clone
@@ -30,13 +31,16 @@
3031
assert_array_almost_equal,
3132
assert_array_equal,
3233
)
34+
from sklearn.utils.fixes import parse_version
3335

3436
from imblearn.datasets import make_imbalance
3537
from imblearn.pipeline import Pipeline, make_pipeline
3638
from imblearn.under_sampling import EditedNearestNeighbours as ENN
3739
from imblearn.under_sampling import RandomUnderSampler
3840
from imblearn.utils.estimator_checks import check_param_validation
3941

42+
sklearn_version = parse_version(sklearn.__version__)
43+
4044
JUNK_FOOD_DOCS = (
4145
"the pizza pizza beer copyright",
4246
"the pizza burger beer copyright",
@@ -1333,3 +1337,27 @@ def test_pipeline_param_validation():
13331337
[("sampler", RandomUnderSampler()), ("classifier", LogisticRegression())]
13341338
)
13351339
check_param_validation("Pipeline", model)
1340+
1341+
1342+
@pytest.mark.skipif(
1343+
sklearn_version < parse_version("1.2"), reason="requires scikit-learn >= 1.2"
1344+
)
1345+
def test_pipeline_with_set_output():
1346+
pd = pytest.importorskip("pandas")
1347+
X, y = load_iris(return_X_y=True, as_frame=True)
1348+
pipeline = make_pipeline(
1349+
StandardScaler(), RandomUnderSampler(), LogisticRegression()
1350+
).set_output(transform="default")
1351+
pipeline.fit(X, y)
1352+
1353+
X_res, y_res = pipeline[:-1].fit_resample(X, y)
1354+
assert isinstance(X_res, np.ndarray)
1355+
# transformer will not change `y` and sampler will always preserve the type of `y`
1356+
assert isinstance(y_res, type(y))
1357+
1358+
pipeline.set_output(transform="pandas")
1359+
X_res, y_res = pipeline[:-1].fit_resample(X, y)
1360+
1361+
assert isinstance(X_res, pd.DataFrame)
1362+
# transformer will not change `y` and sampler will always preserve the type of `y`
1363+
assert isinstance(y_res, type(y))

0 commit comments

Comments
 (0)