Commit b661a9c

jnothman authored and amueller committed
TST Improve SelectFromModel tests (scikit-learn#9733)
Should fix one of the issues in scikit-learn#9393
1 parent 0e1d261 commit b661a9c

1 file changed: +21 -6 lines

Diff for: sklearn/feature_selection/tests/test_from_model.py

@@ -40,7 +40,6 @@ def test_input_estimator_unchanged():
     assert_true(transformer.estimator is est)
 
 
-@skip_if_32bit
 def test_feature_importances():
     X, y = datasets.make_classification(
         n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
@@ -59,17 +58,33 @@ def test_feature_importances():
     feature_mask = np.abs(importances) > func(importances)
     assert_array_almost_equal(X_new, X[:, feature_mask])
 
+
+def test_sample_weight():
+    # Ensure sample weights are passed to underlying estimator
+    X, y = datasets.make_classification(
+        n_samples=100, n_features=10, n_informative=3, n_redundant=0,
+        n_repeated=0, shuffle=False, random_state=0)
+
     # Check with sample weights
     sample_weight = np.ones(y.shape)
     sample_weight[y == 1] *= 100
 
-    est = RandomForestClassifier(n_estimators=50, random_state=0)
+    est = LogisticRegression(random_state=0, fit_intercept=False)
     transformer = SelectFromModel(estimator=est)
+    transformer.fit(X, y, sample_weight=None)
+    mask = transformer._get_support_mask()
     transformer.fit(X, y, sample_weight=sample_weight)
-    importances = transformer.estimator_.feature_importances_
+    weighted_mask = transformer._get_support_mask()
+    assert not np.all(weighted_mask == mask)
     transformer.fit(X, y, sample_weight=3 * sample_weight)
-    importances_bis = transformer.estimator_.feature_importances_
-    assert_almost_equal(importances, importances_bis)
+    reweighted_mask = transformer._get_support_mask()
+    assert np.all(weighted_mask == reweighted_mask)
+
+
+def test_coef_default_threshold():
+    X, y = datasets.make_classification(
+        n_samples=100, n_features=10, n_informative=3, n_redundant=0,
+        n_repeated=0, shuffle=False, random_state=0)
 
     # For the Lasso and related models, the threshold defaults to 1e-5
     transformer = SelectFromModel(estimator=Lasso(alpha=0.1))
@@ -80,7 +95,7 @@ def test_feature_importances():
 
 
 @skip_if_32bit
-def test_feature_importances_2d_coef():
+def test_2d_coef():
     X, y = datasets.make_classification(
         n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
         n_repeated=0, shuffle=False, random_state=0, n_classes=4)
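
For readers unfamiliar with the SelectFromModel API, the sketch below restates what the new test_sample_weight does as a standalone script. It is a minimal illustration, not part of the commit: it swaps the private _get_support_mask() used in the test for the public get_support(), which is assumed to return the same boolean mask here.

# Minimal standalone version of the behaviour checked by test_sample_weight.
import numpy as np
from sklearn import datasets
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression

# Same toy data as in the diff above.
X, y = datasets.make_classification(
    n_samples=100, n_features=10, n_informative=3, n_redundant=0,
    n_repeated=0, shuffle=False, random_state=0)

# Heavily up-weight the positive class so the fitted coefficients shift.
sample_weight = np.ones(y.shape)
sample_weight[y == 1] *= 100

est = LogisticRegression(random_state=0, fit_intercept=False)
transformer = SelectFromModel(estimator=est)

# Unweighted fit vs. weighted fit: the test expects the selected features
# to differ, which shows sample_weight reaches the wrapped estimator.
transformer.fit(X, y)
mask = transformer.get_support()
transformer.fit(X, y, sample_weight=sample_weight)
weighted_mask = transformer.get_support()
assert not np.all(weighted_mask == mask)

# Rescaling every weight by the same constant should leave the selection
# unchanged.
transformer.fit(X, y, sample_weight=3 * sample_weight)
assert np.all(weighted_mask == transformer.get_support())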
