|
13 | 13 |
|
14 | 14 | import numpy as np
|
15 | 15 | import pytest
|
| 16 | +import sklearn |
16 | 17 | from joblib import Memory
|
17 | 18 | from pytest import raises
|
18 | 19 | from sklearn.base import BaseEstimator, clone
|
|
30 | 31 | assert_array_almost_equal,
|
31 | 32 | assert_array_equal,
|
32 | 33 | )
|
| 34 | +from sklearn.utils.fixes import parse_version |
33 | 35 |
|
34 | 36 | from imblearn.datasets import make_imbalance
|
35 | 37 | from imblearn.pipeline import Pipeline, make_pipeline
|
36 | 38 | from imblearn.under_sampling import EditedNearestNeighbours as ENN
|
37 | 39 | from imblearn.under_sampling import RandomUnderSampler
|
38 | 40 | from imblearn.utils.estimator_checks import check_param_validation
|
39 | 41 |
|
| 42 | +sklearn_version = parse_version(sklearn.__version__) |
| 43 | + |
40 | 44 | JUNK_FOOD_DOCS = (
|
41 | 45 | "the pizza pizza beer copyright",
|
42 | 46 | "the pizza burger beer copyright",
|
@@ -1333,3 +1337,27 @@ def test_pipeline_param_validation():
|
1333 | 1337 | [("sampler", RandomUnderSampler()), ("classifier", LogisticRegression())]
|
1334 | 1338 | )
|
1335 | 1339 | check_param_validation("Pipeline", model)
|
| 1340 | + |
| 1341 | + |
| 1342 | +@pytest.mark.skipif( |
| 1343 | + sklearn_version < parse_version("1.2"), reason="requires scikit-learn >= 1.2" |
| 1344 | +) |
| 1345 | +def test_pipeline_with_set_output(): |
| 1346 | + pd = pytest.importorskip("pandas") |
| 1347 | + X, y = load_iris(return_X_y=True, as_frame=True) |
| 1348 | + pipeline = make_pipeline( |
| 1349 | + StandardScaler(), RandomUnderSampler(), LogisticRegression() |
| 1350 | + ).set_output(transform="default") |
| 1351 | + pipeline.fit(X, y) |
| 1352 | + |
| 1353 | + X_res, y_res = pipeline[:-1].fit_resample(X, y) |
| 1354 | + assert isinstance(X_res, np.ndarray) |
| 1355 | + # transformer will not change `y` and sampler will always preserve the type of `y` |
| 1356 | + assert isinstance(y_res, type(y)) |
| 1357 | + |
| 1358 | + pipeline.set_output(transform="pandas") |
| 1359 | + X_res, y_res = pipeline[:-1].fit_resample(X, y) |
| 1360 | + |
| 1361 | + assert isinstance(X_res, pd.DataFrame) |
| 1362 | + # transformer will not change `y` and sampler will always preserve the type of `y` |
| 1363 | + assert isinstance(y_res, type(y)) |
0 commit comments