From 459b4a464142e23e7c0858f70c2ebe59b6e4426b Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 11 Jun 2013 11:08:34 +0200 Subject: [PATCH 1/2] Enforce n_folds >= 2 for k-fold cross-validation --- doc/whats_new.rst | 4 ++++ sklearn/cross_validation.py | 19 ++++++++++++++++--- sklearn/tests/test_cross_validation.py | 8 ++++++-- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 70b1c23b2a6b8..0ebdcdd72e2a8 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -121,6 +121,10 @@ API changes summary - ``gcv_mode="auto"`` no longer tries to perform SVD on a densified sparse matrix in :class:`sklearn.linear_model.RidgeCV`. + - :class:`cross_validation.KFold` and + :class:`cross_validation.StratifiedKFold` now enforce `n_folds >= 2` + otherwise a ``ValueError`` is raised. By `Olivier Grisel`_. + .. _changes_0_13_1: diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index b858257c4b2aa..42c58e91f838d 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -218,7 +218,7 @@ class KFold(object): Total number of elements. n_folds : int, default=3 - Number of folds. + Number of folds. Must be at least 2. indices : boolean, optional (default True) Return train/test split as arrays of indices, rather than a boolean @@ -274,7 +274,13 @@ def __init__(self, n, n_folds=3, indices=True, shuffle=False, self.n = int(n) if abs(n_folds - int(n_folds)) >= np.finfo('f').eps: raise ValueError("n_folds must be an integer") - self.n_folds = int(n_folds) + self.n_folds = n_folds = int(n_folds) + if n_folds < 2: + raise ValueError( + "KFold cross validation requires at least one" + " train / test split by setting n_folds=2 or more," + " got n_folds=%d." + % n_folds) self.indices = indices self.idxs = np.arange(n) if shuffle: @@ -326,7 +332,7 @@ class StratifiedKFold(object): Samples to split in K folds. n_folds : int, default=3 - Number of folds. + Number of folds. Must be at least 2.
indices : boolean, optional (default True) Return train/test split as arrays of indices, rather than a boolean @@ -366,6 +372,13 @@ def __init__(self, y, n_folds=3, indices=True, k=None): _validate_kfold(n_folds, n) _, y_sorted = unique(y, return_inverse=True) min_labels = np.min(np.bincount(y_sorted)) + n_folds = int(n_folds) + if n_folds < 2: + raise ValueError( + "StratifiedKFold cross validation requires at least one" + " train / test split by setting n_folds=2 or more," + " got n_folds=%d." + % n_folds) if n_folds > min_labels: warnings.warn(("The least populated class in y has only %d" " members, which is too few. The minimum" diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py index 64320c544b9aa..3f78861242ca6 100644 --- a/sklearn/tests/test_cross_validation.py +++ b/sklearn/tests/test_cross_validation.py @@ -107,14 +107,18 @@ def test_kfold_valueerrors(): # a characteristic of the code and not a behavior assert_true("The least populated class" in str(w[0])) - # Error when number of folds is <= 0 + # Error when number of folds is <= 1 assert_raises(ValueError, cval.KFold, 2, 0) + assert_raises(ValueError, cval.KFold, 2, 1) + assert_raises(ValueError, cval.StratifiedKFold, y, 0) + assert_raises(ValueError, cval.StratifiedKFold, y, 1) # When n is not integer: - assert_raises(ValueError, cval.KFold, 2.5, 1) + assert_raises(ValueError, cval.KFold, 2.5, 2) # When n_folds is not integer: assert_raises(ValueError, cval.KFold, 5, 1.5) + assert_raises(ValueError, cval.StratifiedKFold, y, 1.5) def test_kfold_indices(): From a4c1733b7e4884c9efbc8801edc5cf660f5a6e34 Mon Sep 17 00:00:00 2001 From: Joel Nothman Date: Wed, 12 Jun 2013 23:30:54 +1000 Subject: [PATCH 2/2] COSMIT move n_folds validation to _validate_kfold --- sklearn/cross_validation.py | 47 +++++++++++++------------------------ 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py index 
42c58e91f838d..61837dd3639f5 100644 --- a/sklearn/cross_validation.py +++ b/sklearn/cross_validation.py @@ -195,12 +195,17 @@ def __len__(self): / factorial(self.p)) -def _validate_kfold(k, n_samples): - if k <= 0: - raise ValueError("Cannot have number of folds k below 1.") - if k > n_samples: - raise ValueError("Cannot have number of folds k=%d greater than" - " the number of samples: %d." % (k, n_samples)) +def _validate_kfold(n_folds, n): + if abs(n - int(n)) >= np.finfo('f').eps: + raise ValueError("n must be an integer") + if abs(n_folds - int(n_folds)) >= np.finfo('f').eps: + raise ValueError("n_folds must be an integer") + if n_folds < 2: + raise ValueError("Cannot have number of folds n_folds below 2.") + if n_folds > n: + raise ValueError("Cannot have number of folds n_folds=%d greater than" + " the number of samples: %d." % (n_folds, n)) + return int(n_folds), int(n) class KFold(object): @@ -266,23 +271,10 @@ def __init__(self, n, n_folds=3, indices=True, shuffle=False, warnings.warn("The parameter k was renamed to n_folds and will be" " removed in 0.15.", DeprecationWarning) n_folds = k - _validate_kfold(n_folds, n) + self.n_folds, self.n = _validate_kfold(n_folds, n) random_state = check_random_state(random_state) - - if abs(n - int(n)) >= np.finfo('f').eps: - raise ValueError("n must be an integer") - self.n = int(n) - if abs(n_folds - int(n_folds)) >= np.finfo('f').eps: - raise ValueError("n_folds must be an integer") - self.n_folds = n_folds = int(n_folds) - if n_folds < 2: - raise ValueError( - "KFold cross validation requires at least one" - " train / test split by setting n_folds=2 or more," - " got n_folds=%d." 
- % n_folds) self.indices = indices - self.idxs = np.arange(n) + self.idxs = np.arange(self.n) if shuffle: random_state.shuffle(self.idxs) @@ -369,22 +361,15 @@ def __init__(self, y, n_folds=3, indices=True, k=None): n_folds = k y = np.asarray(y) n = y.shape[0] - _validate_kfold(n_folds, n) + self.n_folds, self.n = _validate_kfold(n_folds, n) _, y_sorted = unique(y, return_inverse=True) min_labels = np.min(np.bincount(y_sorted)) - n_folds = int(n_folds) - if n_folds < 2: - raise ValueError( - "StratifiedKFold cross validation requires at least one" - " train / test split by setting n_folds=2 or more," - " got n_folds=%d." - % n_folds) - if n_folds > min_labels: + if self.n_folds > min_labels: warnings.warn(("The least populated class in y has only %d" " members, which is too few. The minimum" " number of labels for any class cannot" " be less than n_folds=%d." - % (min_labels, n_folds)), Warning) + % (min_labels, self.n_folds)), Warning) self.y = y self.n_folds = n_folds self.indices = indices