Skip to content

Commit d241040

Browse files
committed
Make automatic centering in PCA methods optional (fixes #734) #808
1 parent 0ea276d commit d241040

File tree

4 files changed

+342
-23
lines changed

4 files changed

+342
-23
lines changed

dask_ml/decomposition/incremental_pca.py

+7
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,21 @@ def __init__(
132132
svd_solver="auto",
133133
iterated_power=0,
134134
random_state=None,
135+
center=True,
135136
):
136137
self.n_components = n_components
137138
self.whiten = whiten
139+
self.center = center
138140
self.copy = copy
139141
self.batch_size = batch_size
140142
self.svd_solver = svd_solver
141143
self.iterated_power = iterated_power
142144
self.random_state = random_state
143145

146+
def _check_params(self):
147+
if self.center is False:
148+
raise ValueError("IncrementalPCA with center=False is not supported.")
149+
144150
def _fit(self, X, y=None):
145151
"""Fit the model with X, using minibatches of size batch_size.
146152
@@ -238,6 +244,7 @@ def partial_fit(self, X, y=None, check_input=True):
238244
self : object
239245
Returns the instance itself.
240246
"""
247+
self._check_params()
241248
if check_input:
242249
if sparse.issparse(X):
243250
raise TypeError(

dask_ml/decomposition/pca.py

+65-17
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,21 @@ class PCA(sklearn.decomposition.PCA):
8686
If None, the random number generator is the RandomState instance used
8787
by `da.random`. Used when ``svd_solver`` == 'randomized'.
8888
89+
center : bool, optional (default True)
90+
When True (the default), the underlying data gets centered at zero
91+
by subtracting the mean of the data from the data itself.
92+
93+
PCA is performed on centered data because it is a regression model
94+
without an intercept. As such, its principal components originate at the
95+
origin of the transformed space.
96+
97+
``center=False`` may be employed when performing PCA on already
98+
centered data.
99+
100+
Since centering is a required step of whitening, combining ``center=False``
101+
with ``whiten=True`` may result in unexpected behavior when the data have
102+
not been centered beforehand.
103+
89104
Attributes
90105
----------
91106
components_ : array, shape (n_components, n_features)
@@ -152,18 +167,27 @@ class PCA(sklearn.decomposition.PCA):
152167
PCA(copy=True, iterated_power='auto', n_components=2, random_state=None,
153168
svd_solver='auto', tol=0.0, whiten=False)
154169
>>> print(pca.explained_variance_ratio_) # doctest: +ELLIPSIS
155-
[ 0.99244... 0.00755...]
170+
[0.99244289 0.00755711]
156171
>>> print(pca.singular_values_) # doctest: +ELLIPSIS
157-
[ 6.30061... 0.54980...]
172+
[6.30061232 0.54980396]
158173
159174
>>> pca = PCA(n_components=2, svd_solver='full')
160175
>>> pca.fit(dX) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
161176
PCA(copy=True, iterated_power='auto', n_components=2, random_state=None,
162177
svd_solver='full', tol=0.0, whiten=False)
163178
>>> print(pca.explained_variance_ratio_) # doctest: +ELLIPSIS
164-
[ 0.99244... 0.00755...]
179+
[0.99244289 0.00755711]
180+
>>> print(pca.singular_values_) # doctest: +ELLIPSIS
181+
[6.30061232 0.54980396]
182+
183+
>>> dX_mean_0 = dX - dX.mean(axis=0)
184+
>>> pca = PCA(n_components=2, svd_solver='full', center=False)
185+
>>> pca.fit(dX_mean_0)
186+
PCA(center=False, n_components=2, svd_solver='full')
187+
>>> print(pca.explained_variance_ratio_) # doctest: +ELLIPSIS
188+
[0.99244289 0.00755711]
165189
>>> print(pca.singular_values_) # doctest: +ELLIPSIS
166-
[ 6.30061... 0.54980...]
190+
[6.30061232 0.54980396]
167191
168192
Notes
169193
-----
@@ -175,6 +199,10 @@ class PCA(sklearn.decomposition.PCA):
175199
``dask.linalg.svd_compressed``.
176200
* n_components : ``n_components='mle'`` is not allowed.
177201
Fractional ``n_components`` between 0 and 1 is not allowed.
202+
* center : if ``True`` (the default), automatically center input data before
203+
performing PCA.
204+
Set this parameter to ``False`` if the input data have already been
205+
centered before running ``fit()``.
178206
"""
179207

180208
def __init__(
@@ -186,10 +214,12 @@ def __init__(
186214
tol=0.0,
187215
iterated_power=0,
188216
random_state=None,
217+
center=True,
189218
):
190219
self.n_components = n_components
191220
self.copy = copy
192221
self.whiten = whiten
222+
self.center = center
193223
self.svd_solver = svd_solver
194224
self.tol = tol
195225
self.iterated_power = iterated_power
@@ -198,6 +228,7 @@ def __init__(
198228
def fit(self, X, y=None):
199229
if not dask.is_dask_collection(X):
200230
raise TypeError(_TYPE_MSG.format(type(X)))
231+
201232
self._fit(X)
202233
self.n_features_in_ = X.shape[1]
203234
return self
@@ -266,8 +297,10 @@ def _fit(self, X):
266297

267298
solver = self._get_solver(X, n_components)
268299

269-
self.mean_ = X.mean(0)
270-
X -= self.mean_
300+
self.mean_ = X.mean(axis=0)
301+
302+
if self.center:
303+
X -= self.mean_
271304

272305
if solver in {"full", "tsqr"}:
273306
U, S, V = da.linalg.svd(X)
@@ -370,14 +403,20 @@ def transform(self, X):
370403
X_new : array-like, shape (n_samples, n_components)
371404
372405
"""
373-
check_is_fitted(self, ["mean_", "components_"])
406+
check_is_fitted(self, "components_")
407+
408+
if self.whiten:
409+
check_is_fitted(self, "explained_variance_")
410+
411+
if self.center:
412+
check_is_fitted(self, "mean_")
413+
if self.mean_ is not None:
414+
X -= self.mean_
374415

375-
# X = check_array(X)
376-
if self.mean_ is not None:
377-
X = X - self.mean_
378416
X_transformed = da.dot(X, self.components_.T)
379417
if self.whiten:
380418
X_transformed /= np.sqrt(self.explained_variance_)
419+
381420
return X_transformed
382421

383422
def fit_transform(self, X, y=None):
@@ -396,7 +435,6 @@ def fit_transform(self, X, y=None):
396435
X_new : array-like, shape (n_samples, n_components)
397436
398437
"""
399-
# X = check_array(X)
400438
if not dask.is_dask_collection(X):
401439
raise TypeError(_TYPE_MSG.format(type(X)))
402440
U, S, V = self._fit(X)
@@ -431,18 +469,25 @@ def inverse_transform(self, X):
431469
If whitening is enabled, inverse_transform does not compute the
432470
exact inverse operation of transform.
433471
"""
434-
check_is_fitted(self, "mean_")
472+
check_is_fitted(self, "components_")
473+
474+
if self.center:
475+
check_is_fitted(self, "mean_")
476+
offset = self.mean_
477+
else:
478+
offset = 0
435479

436480
if self.whiten:
481+
check_is_fitted(self, "explained_variance_")
437482
return (
438483
da.dot(
439484
X,
440485
np.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_,
441486
)
442-
+ self.mean_
487+
+ offset
443488
)
444-
else:
445-
return da.dot(X, self.components_) + self.mean_
489+
490+
return da.dot(X, self.components_) + offset
446491

447492
def score_samples(self, X):
448493
"""Return the log-likelihood of each sample.
@@ -463,8 +508,11 @@ def score_samples(self, X):
463508
"""
464509
check_is_fitted(self, "mean_")
465510

466-
# X = check_array(X)
467-
Xr = X - self.mean_
511+
if self.center:
512+
Xr = X - self.mean_
513+
else:
514+
Xr = X
515+
468516
n_features = X.shape[1]
469517
precision = self.get_precision() # [n_features, n_features]
470518
log_like = -0.5 * (Xr * (da.dot(Xr, precision))).sum(axis=1)

tests/test_incremental_pca.py

+17
Original file line numberDiff line numberDiff line change
@@ -475,3 +475,20 @@ def test_incremental_pca_partial_fit_float_division():
475475
np.testing.assert_allclose(
476476
singular_vals_float_samples_seen, singular_vals_int_samples_seen
477477
)
478+
479+
480+
def test_incremental_pca_no_centering_not_supported():
481+
rng = np.random.RandomState(0)
482+
A = rng.randn(5, 3) + 2
483+
A = da.from_array(A, chunks=[3, -1])
484+
485+
pca = IncrementalPCA(n_components=2, center=False)
486+
487+
with pytest.raises(ValueError, match="not supported"):
488+
pca.partial_fit(A)
489+
490+
with pytest.raises(ValueError, match="not supported"):
491+
pca.fit(A)
492+
493+
with pytest.raises(ValueError, match="not supported"):
494+
pca.fit_transform(A)

0 commit comments

Comments
 (0)