Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [R-package] Add support for specifying training indices in lgb.cv() #3989

Closed
wants to merge 10 commits into from
12 changes: 11 additions & 1 deletion R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ CVBooster <- R6::R6Class(
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
#' (each element must be a vector of test fold's indices). When folds are supplied,
#' the \code{nfold} and \code{stratified} parameters are ignored.
#' @param train_folds \code{list} specifying which indicies to use for training. If \code{NULL}
julioasotodv marked this conversation as resolved.
Show resolved Hide resolved
#' (the default) all indices not specified in \code{folds} will be used for training.
#' @param colnames feature names, if not null, will use this to overwrite the names in dataset
#' @param categorical_feature categorical features. This can either be a character vector of feature
#' names or an integer vector with the indices of the features (e.g.
Expand Down Expand Up @@ -83,6 +85,7 @@ lgb.cv <- function(params = list()
, showsd = TRUE
, stratified = TRUE
, folds = NULL
, train_folds = NULL
julioasotodv marked this conversation as resolved.
Show resolved Hide resolved
, init_model = NULL
, colnames = NULL
, categorical_feature = NULL
Expand Down Expand Up @@ -302,7 +305,14 @@ lgb.cv <- function(params = list()
} else {
test_indices <- folds[[k]]
}
train_indices <- seq_len(nrow(data))[-test_indices]

# Generate train_indices from either the train_folds argument
# or as the opposite of (test)folds argument:
if (!is.null(train_folds)) {
train_indices <- train_folds[[k]]
} else {
train_indices <- seq_len(nrow(data))[-test_indices]
}

# set up test set
indexDT <- data.table::data.table(
Expand Down
4 changes: 4 additions & 0 deletions R-package/man/lgb.cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.