Skip to content

Commit 92d1bfe

Browse files
authored
[R] Move gc data protection to R side (#11104)
1 parent 57ce062 commit 92d1bfe

File tree

2 files changed

+31
-14
lines changed

2 files changed

+31
-14
lines changed

Diff for: R-package/R/xgb.DMatrix.R

+28-5
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@ xgb.QuantileDMatrix <- function(
353353
)
354354
data_iterator <- .single.data.iterator(iterator_env)
355355

356+
env_keep_alive <- new.env()
357+
env_keep_alive$keepalive <- NULL
358+
356359
# Note: the ProxyDMatrix has its finalizer assigned in the R externalptr
357360
# object, but that finalizer will only be called once the object is
358361
# garbage-collected, which doesn't happen immediately after it goes out
@@ -363,9 +366,10 @@ xgb.QuantileDMatrix <- function(
363366
.Call(XGDMatrixFree_R, proxy_handle)
364367
})
365368
iterator_next <- function() {
366-
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
369+
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
367370
}
368371
iterator_reset <- function() {
372+
env_keep_alive$keepalive <- NULL
369373
return(data_iterator$f_reset(iterator_env))
370374
}
371375
calling_env <- environment()
@@ -553,7 +557,8 @@ xgb.DataBatch <- function(
553557
}
554558

555559
# This is only for internal usage, class is not exposed to the user.
556-
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
560+
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator, env_keep_alive) {
561+
env_keep_alive$keepalive <- NULL
557562
lst <- data_iterator$f_next(data_iterator$env)
558563
if (is.null(lst)) {
559564
return(0L)
@@ -566,13 +571,19 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
566571
stop("Either one of 'group' or 'qid' should be NULL")
567572
}
568573
if (is.data.frame(lst$data)) {
569-
tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types)
574+
data <- lst$data
575+
lst$data <- NULL
576+
tmp <- .process.df.for.dmatrix(data, lst$feature_types)
570577
lst$feature_types <- tmp$feature_types
578+
data <- NULL
579+
env_keep_alive$keepalive <- tmp
571580
.Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst)
572581
} else if (is.matrix(lst$data)) {
582+
env_keep_alive$keepalive <- lst
573583
.Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data)
574584
} else if (inherits(lst$data, "dgRMatrix")) {
575585
tmp <- list(p = lst$data@p, j = lst$data@j, x = lst$data@x, ncol = ncol(lst$data))
586+
env_keep_alive$keepalive <- tmp
576587
.Call(XGProxyDMatrixSetDataCSR_R, proxy_handle, tmp)
577588
} else {
578589
stop("'data' has unsupported type.")
@@ -712,14 +723,23 @@ xgb.ExtMemDMatrix <- function(
712723
cache_prefix <- path.expand(cache_prefix)
713724
nthread <- as.integer(NVL(nthread, -1L))
714725

726+
# The purpose of this environment is to keep data alive (protected from the
727+
# garbage collector) after setting the data in the proxy dmatrix. The data
728+
# held here (under name 'keepalive') should be unset (leaving it unprotected
729+
# for garbage collection) before the start of each data iteration batch and
730+
# during each iterator reset.
731+
env_keep_alive <- new.env()
732+
env_keep_alive$keepalive <- NULL
733+
715734
proxy_handle <- .make.proxy.handle()
716735
on.exit({
717736
.Call(XGDMatrixFree_R, proxy_handle)
718737
})
719738
iterator_next <- function() {
720-
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
739+
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
721740
}
722741
iterator_reset <- function() {
742+
env_keep_alive$keepalive <- NULL
723743
return(data_iterator$f_reset(data_iterator$env))
724744
}
725745
calling_env <- environment()
@@ -779,14 +799,17 @@ xgb.QuantileDMatrix.from_iterator <- function( # nolint
779799

780800
nthread <- as.integer(NVL(nthread, -1L))
781801

802+
env_keep_alive <- new.env()
803+
env_keep_alive$keepalive <- NULL
782804
proxy_handle <- .make.proxy.handle()
783805
on.exit({
784806
.Call(XGDMatrixFree_R, proxy_handle)
785807
})
786808
iterator_next <- function() {
787-
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
809+
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
788810
}
789811
iterator_reset <- function() {
812+
env_keep_alive$keepalive <- NULL
790813
return(data_iterator$f_reset(data_iterator$env))
791814
}
792815
calling_env <- environment()

Diff for: R-package/src/xgboost_R.cc

+3-9
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) {
687687
{
688688
std::string array_str = MakeArrayInterfaceFromRMat(R_mat);
689689
res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str());
690-
R_SetExternalPtrProtected(handle, R_mat);
691690
}
692691
CHECK_CALL(res_code);
693692
R_API_END();
@@ -708,7 +707,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) {
708707
array_str_indices.c_str(),
709708
array_str_data.c_str(),
710709
ncol);
711-
R_SetExternalPtrProtected(handle, lst);
712710
}
713711
CHECK_CALL(res_code);
714712
R_API_END();
@@ -722,7 +720,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) {
722720
{
723721
std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst);
724722
res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str());
725-
R_SetExternalPtrProtected(handle, lst);
726723
}
727724
CHECK_CALL(res_code);
728725
R_API_END();
@@ -736,20 +733,17 @@ struct _RDataIterator {
736733
SEXP f_reset;
737734
SEXP calling_env;
738735
SEXP continuation_token;
739-
SEXP proxy_dmat;
740736

741737
_RDataIterator(
742-
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) :
738+
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) :
743739
f_next(f_next), f_reset(f_reset), calling_env(calling_env),
744-
continuation_token(continuation_token), proxy_dmat(proxy_dmat) {}
740+
continuation_token(continuation_token) {}
745741

746742
void reset() {
747-
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
748743
SafeExecFun(this->f_reset, this->calling_env, this->continuation_token);
749744
}
750745

751746
int next() {
752-
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
753747
SEXP R_res = Rf_protect(
754748
SafeExecFun(this->f_next, this->calling_env, this->continuation_token));
755749
int res = Rf_asInteger(R_res);
@@ -777,7 +771,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R(
777771

778772
int res_code;
779773
try {
780-
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat);
774+
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token);
781775

782776
std::string str_cache_prefix;
783777
xgboost::Json jconfig{xgboost::Object{}};

0 commit comments

Comments
 (0)