Skip to content

Commit 01dd9d4

Browse files
committed
Add param validation to PCA classes (dask#808)
1 parent 17cfedd commit 01dd9d4

File tree

4 files changed

+45
-49
lines changed

4 files changed

+45
-49
lines changed

dask_ml/decomposition/incremental_pca.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,6 @@ def __init__(
134134
iterated_power=0,
135135
random_state=None,
136136
):
137-
if center is False:
138-
raise NotImplementedError(
139-
"IncrementalPCA with center=False is not supported."
140-
)
141-
142137
self.n_components = n_components
143138
self.whiten = whiten
144139
self.center = center
@@ -148,6 +143,11 @@ def __init__(
148143
self.iterated_power = iterated_power
149144
self.random_state = random_state
150145

146+
def _check_params(self):
147+
super()._check_params()
148+
if self.center is False:
149+
raise ValueError("IncrementalPCA with center=False is not supported.")
150+
151151
def _fit(self, X, y=None):
152152
"""Fit the model with X, using minibatches of size batch_size.
153153
@@ -245,6 +245,7 @@ def partial_fit(self, X, y=None, check_input=True):
245245
self : object
246246
Returns the instance itself.
247247
"""
248+
self._check_params()
248249
if check_input:
249250
if sparse.issparse(X):
250251
raise TypeError(

dask_ml/decomposition/pca.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numbers
2-
import warnings
32

43
import dask
54
import dask.array as da
@@ -222,16 +221,14 @@ def __init__(
222221
self.iterated_power = iterated_power
223222
self.random_state = random_state
224223

225-
if whiten and not center:
226-
warnings.warn(
227-
"Whitening requires centering. Please, ensure that your data "
228-
"is already centered, in order to avoid unexpected behavior.",
229-
RuntimeWarning,
230-
)
224+
def _check_params(self):
225+
pass
231226

232227
def fit(self, X, y=None):
233228
if not dask.is_dask_collection(X):
234229
raise TypeError(_TYPE_MSG.format(type(X)))
230+
231+
self._check_params()
235232
self._fit(X)
236233
self.n_features_in_ = X.shape[1]
237234
return self

tests/test_incremental_pca.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,5 +478,17 @@ def test_incremental_pca_partial_fit_float_division():
478478

479479

480480
def test_incremental_pca_no_centering_not_supported():
481-
with pytest.raises(NotImplementedError, match="not supported"):
482-
IncrementalPCA(n_components=2, center=False)
481+
rng = np.random.RandomState(0)
482+
A = rng.randn(5, 3) + 2
483+
A = da.from_array(A, chunks=[3, -1])
484+
485+
pca = IncrementalPCA(n_components=2, center=False)
486+
487+
with pytest.raises(ValueError, match="not supported"):
488+
pca.partial_fit(A)
489+
490+
with pytest.raises(ValueError, match="not supported"):
491+
pca.fit(A)
492+
493+
with pytest.raises(ValueError, match="not supported"):
494+
pca.fit_transform(A)

tests/test_pca.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -183,19 +183,15 @@ def test_whitening():
183183

184184
# Test no centering.
185185
X_mean_0 = dX_mean_0.copy()
186-
with pytest.warns(RuntimeWarning) as record:
187-
pca = dd.PCA(
188-
n_components=n_components,
189-
whiten=True,
190-
center=False,
191-
copy=copy,
192-
svd_solver=solver,
193-
random_state=0,
194-
iterated_power=4,
195-
)
196-
197-
assert len(record) == 1
198-
assert "Whitening requires centering" in record[0].message.args[0]
186+
pca = dd.PCA(
187+
n_components=n_components,
188+
whiten=True,
189+
center=False,
190+
copy=copy,
191+
svd_solver=solver,
192+
random_state=0,
193+
iterated_power=4,
194+
)
199195

200196
X_whitened = pca.fit_transform(X_mean_0.copy())
201197
assert X_whitened.shape == (n_samples, n_components)
@@ -358,19 +354,14 @@ def test_inverse_transform_no_centering():
358354
Y_inverse = pca.inverse_transform(Y)
359355
assert_almost_equal(dX_mean_0.compute(), Y_inverse, decimal=3)
360356

361-
# Capture warning about employing whitening without centering.
362-
with pytest.warns(RuntimeWarning) as record:
363-
# As above, but with whitening.
364-
pca = dd.PCA(
365-
n_components=2,
366-
svd_solver="full",
367-
random_state=0,
368-
whiten=True,
369-
center=False,
370-
).fit(dX_mean_0)
371-
372-
assert len(record) == 1
373-
assert "Whitening requires centering" in record[0].message.args[0]
357+
# As above, but with whitening.
358+
pca = dd.PCA(
359+
n_components=2,
360+
svd_solver="full",
361+
random_state=0,
362+
whiten=True,
363+
center=False,
364+
).fit(dX_mean_0)
374365

375366
Y = pca.transform(dX_mean_0)
376367
Y_inverse = pca.inverse_transform(Y)
@@ -818,15 +809,10 @@ def test_pca_score2():
818809
ll2 = pca.score(dX)
819810
assert ll1 > ll2
820811

821-
# Capture warning about employing whitening without centering.
822-
with pytest.warns(RuntimeWarning) as record:
823-
apca = dd.PCA(n_components=2, whiten=True, center=False, svd_solver=solver)
824-
apca.fit(dX_mean_0)
825-
all2 = apca.score(dX_mean_0)
826-
assert all1 > all2
827-
828-
assert len(record) == 1
829-
assert "Whitening requires centering" in record[0].message.args[0]
812+
apca = dd.PCA(n_components=2, whiten=True, center=False, svd_solver=solver)
813+
apca.fit(dX_mean_0)
814+
all2 = apca.score(dX_mean_0)
815+
assert all1 > all2
830816

831817

832818
def test_pca_score3():

0 commit comments

Comments
 (0)