Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ API changes summary
- ``gcv_mode="auto"`` no longer tries to perform SVD on a densified
sparse matrix in :class:`sklearn.linear_model.RidgeCV`.

- :class:`cross_validation.KFold` and
:class:`cross_validation.StratifiedKFold` now enforce ``n_folds >= 2``;
otherwise a ``ValueError`` is raised. By `Olivier Grisel`_.


.. _changes_0_13_1:

Expand Down
38 changes: 18 additions & 20 deletions sklearn/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,17 @@ def __len__(self):
/ factorial(self.p))


# Validate the requested fold count ``k`` against the sample count
# ``n_samples``. Raises ValueError when ``k <= 0`` or when ``k > n_samples``;
# nothing is returned.
# NOTE(review): a second definition of the same name follows immediately
# below; presumably this is the pre-change side of the diff -- confirm.
def _validate_kfold(k, n_samples):
if k <= 0:
raise ValueError("Cannot have number of folds k below 1.")
if k > n_samples:
raise ValueError("Cannot have number of folds k=%d greater than"
" the number of samples: %d." % (k, n_samples))
def _validate_kfold(n_folds, n):
    """Check and normalise the arguments shared by the K-fold iterators.

    Parameters
    ----------
    n_folds : int
        Requested number of folds. Must be integer-valued, at least 2 and
        not greater than ``n``.

    n : int
        Total number of samples. Must be integer-valued.

    Returns
    -------
    n_folds, n : tuple of int
        Both arguments converted with ``int``.

    Raises
    ------
    ValueError
        If either argument is not numerically an integer (including
        infinities and NaN), if ``n_folds < 2`` or if ``n_folds > n``.
    """
    # ``n`` is checked before ``n_folds`` so that, when both are invalid,
    # the error reported is the one about ``n``.
    for name, value in (("n", n), ("n_folds", n_folds)):
        try:
            is_integral = abs(value - int(value)) < np.finfo('f').eps
        except (OverflowError, ValueError):
            # int() raises OverflowError for +/-inf and ValueError for NaN;
            # neither is an integer, so report them like any other
            # non-integral value instead of leaking the conversion error.
            is_integral = False
        if not is_integral:
            raise ValueError("%s must be an integer" % name)
    if n_folds < 2:
        raise ValueError("Cannot have number of folds n_folds below 2.")
    if n_folds > n:
        raise ValueError("Cannot have number of folds n_folds=%d greater than"
                         " the number of samples: %d." % (n_folds, n))
    return int(n_folds), int(n)


class KFold(object):
Expand All @@ -218,7 +223,7 @@ class KFold(object):
Total number of elements.

n_folds : int, default=3
Number of folds.
Number of folds. Must be at least 2.

indices : boolean, optional (default True)
Return train/test split as arrays of indices, rather than a boolean
Expand Down Expand Up @@ -266,17 +271,10 @@ def __init__(self, n, n_folds=3, indices=True, shuffle=False,
warnings.warn("The parameter k was renamed to n_folds and will be"
" removed in 0.15.", DeprecationWarning)
n_folds = k
_validate_kfold(n_folds, n)
self.n_folds, self.n = _validate_kfold(n_folds, n)
random_state = check_random_state(random_state)

if abs(n - int(n)) >= np.finfo('f').eps:
raise ValueError("n must be an integer")
self.n = int(n)
if abs(n_folds - int(n_folds)) >= np.finfo('f').eps:
raise ValueError("n_folds must be an integer")
self.n_folds = int(n_folds)
self.indices = indices
self.idxs = np.arange(n)
self.idxs = np.arange(self.n)
if shuffle:
random_state.shuffle(self.idxs)

Expand Down Expand Up @@ -326,7 +324,7 @@ class StratifiedKFold(object):
Samples to split in K folds.

n_folds : int, default=3
Number of folds.
Number of folds. Must be at least 2.

indices : boolean, optional (default True)
Return train/test split as arrays of indices, rather than a boolean
Expand Down Expand Up @@ -363,15 +361,15 @@ def __init__(self, y, n_folds=3, indices=True, k=None):
n_folds = k
y = np.asarray(y)
n = y.shape[0]
_validate_kfold(n_folds, n)
self.n_folds, self.n = _validate_kfold(n_folds, n)
_, y_sorted = unique(y, return_inverse=True)
min_labels = np.min(np.bincount(y_sorted))
if n_folds > min_labels:
if self.n_folds > min_labels:
warnings.warn(("The least populated class in y has only %d"
" members, which is too few. The minimum"
" number of labels for any class cannot"
" be less than n_folds=%d."
% (min_labels, n_folds)), Warning)
% (min_labels, self.n_folds)), Warning)
self.y = y
self.n_folds = n_folds
self.indices = indices
Expand Down
8 changes: 6 additions & 2 deletions sklearn/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,18 @@ def test_kfold_valueerrors():
# a characteristic of the code and not a behavior
assert_true("The least populated class" in str(w[0]))

# Error when number of folds is <= 0
# Error when number of folds is <= 1
assert_raises(ValueError, cval.KFold, 2, 0)
assert_raises(ValueError, cval.KFold, 2, 1)
assert_raises(ValueError, cval.StratifiedKFold, y, 0)
assert_raises(ValueError, cval.StratifiedKFold, y, 1)

# When n is not integer:
assert_raises(ValueError, cval.KFold, 2.5, 1)
assert_raises(ValueError, cval.KFold, 2.5, 2)

# When n_folds is not integer:
assert_raises(ValueError, cval.KFold, 5, 1.5)
assert_raises(ValueError, cval.StratifiedKFold, y, 1.5)


def test_kfold_indices():
Expand Down