From ed67624c002721f4570d521e8d9c58110ea13750 Mon Sep 17 00:00:00 2001 From: tobiasdut <167294345+tobiasdut@users.noreply.github.com> Date: Wed, 4 Mar 2026 00:30:19 +0100 Subject: [PATCH 01/13] Fix pandas NULL column / date type --- R/python-to-pandas-cleaned.R | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/R/python-to-pandas-cleaned.R b/R/python-to-pandas-cleaned.R index cf6747b..b9a3b66 100644 --- a/R/python-to-pandas-cleaned.R +++ b/R/python-to-pandas-cleaned.R @@ -23,10 +23,22 @@ to_pandas_cleaned <- function(x) { } else { # Pandas 3.0 conversion makes encases all columns inside lists collected <- try( - pandas_tbl$values |> - as.data.frame() |> - lapply(unlist) |> - set_names(x$columns), + { + as.data.frame(pandas_tbl$values) |> + lapply(\(col) { + map_vec( + col, + \(x) { + if (length(x) == 0 || is.nan(x)) { + NA + } else { + x[[1]] + } + } + ) + }) |> + set_names(x$columns) + }, silent = TRUE ) if (inherits(collected, "try-error")) { @@ -64,6 +76,8 @@ to_pandas_cleaned <- function(x) { as.Date(map_vec(col, clean_col), origin = "1970-01-01") } else if (py_type == "date" && r_type == "character") { as.Date(col, origin = "1970-01-01") + } else if (py_type == "date" && r_type %in% c("numeric", "integer", "logical")) { + as.Date(col, origin = "1970-01-01") } else if (py_type == "boolean" && r_type == "list") { map_lgl(col, clean_col) } else if (py_type == "boolean" && r_type == "character") { From 0bdfd23b57257df08a4a017c8e64be3e846126c0 Mon Sep 17 00:00:00 2001 From: tobiasdut <167294345+tobiasdut@users.noreply.github.com> Date: Wed, 4 Mar 2026 00:31:40 +0100 Subject: [PATCH 02/13] Replace pipe placeholder with compatible lambda in tune-grid --- R/tune-grid.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/tune-grid.R b/R/tune-grid.R index 1f47c07..cc114d3 100644 --- a/R/tune-grid.R +++ b/R/tune-grid.R @@ -57,7 +57,7 @@ tune_grid_spark.pyspark_connection <- function( # section vec_resamples <- resamples |> vctrs::vec_split(by = 1:nrow(resamples)) |> - _$val + (\(x) x$val)() pasted_pkgs <- paste0("'", prepped$needed_pkgs, "'", collapse = ", ") # --------------- Prepares and uploads R objects to Spark -------------------- From fe29e36e00bc4b0633354a8744c01682883b80b1 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sat, 7 Mar 2026 17:19:14 -0600 Subject: [PATCH 03/13] Temporarily downgrades Spark version for GHA --- .github/workflows/test-coverage.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index c031bf3..79c03a1 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: config: - - {spark: '4.0.1', pyspark: '4.0.1', hadoop: '3', scala: '2.13', python: '3.10', name: 'PySpark 4'} + - {spark: '4.0.0', pyspark: '4.0.0', hadoop: '3', scala: '2.13', python: '3.10', name: 'PySpark 4'} env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} From 5a2004b92bcf81a40e1f5569240c5934502c9770 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sat, 7 Mar 2026 18:03:40 -0600 Subject: [PATCH 04/13] Temporariliy uses dev sparklyr --- .github/workflows/test-coverage.yaml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 79c03a1..21d6a2e 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: config: - - {spark: '4.0.0', pyspark: '4.0.0', hadoop: '3', scala: '2.13', python: '3.10', name: 'PySpark 4'} + - {spark: '4.1.1', pyspark: '4.1.1', hadoop: '3', scala: '2.13', python: '3.12', name: 'PySpark 4'} env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} @@ -45,8 +45,14 @@ jobs: any::tidymodels any::probably any::rpart + any::pak needs: check + - name: Install sparklyr dev + run: | + pak::pak("sparklyr/sparklyr") + shell: Rscript {0} + - name: Cache Spark id: cache-spark uses: actions/cache@v3 From 16b65e531ad3b5ee102b2795321c428675044f1a Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sat, 7 Mar 2026 19:49:47 -0600 Subject: [PATCH 05/13] Fixes to pandas cleaned --- R/python-to-pandas-cleaned.R | 23 ++++++++++++++++++++++- tests/testthat.R | 2 +- tests/testthat/helper-init.R | 1 + tests/testthat/helper-ml.R | 24 ++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 2 deletions(-) diff --git a/R/python-to-pandas-cleaned.R b/R/python-to-pandas-cleaned.R index b9a3b66..5b94648 100644 --- a/R/python-to-pandas-cleaned.R +++ b/R/python-to-pandas-cleaned.R @@ -47,6 +47,24 @@ to_pandas_cleaned <- function(x) { } } + # Handle NULL columns by converting them to appropriate empty vectors + for (i in seq_along(collected)) { + if (is.null(collected[[i]])) { + collected[[i]] <- character(0) + } + # Convert arrow_binary to list + if (inherits(collected[[i]], "arrow_binary")) { + collected[[i]] <- as.list(collected[[i]]) + } + # Convert UTC timestamps to local timezone + if (inherits(collected[[i]], "POSIXct")) { + # If the timestamp is in UTC, convert to local timezone + if (!is.null(attr(collected[[i]], "tzone")) && attr(collected[[i]], "tzone") == "UTC") { + attr(collected[[i]], "tzone") <- "" + } + } + } + collected <- collected |> dplyr::as_tibble() @@ -82,7 +100,10 @@ to_pandas_cleaned <- function(x) { map_lgl(col, clean_col) } else if (py_type == "boolean" && r_type == "character") { as.logical(col) - } else if (r_type == "numeric") { + } else if (r_type == "integer" && py_type %in% c("bigint", "long")) { + # Convert bigint/long to numeric since R integer is 32-bit + as.numeric(col) + } else if (r_type == "numeric" || r_type == "integer") { if (py_type %in% c("tinyint", "smallint", "int")) { ptype <- integer() } else { diff --git a/tests/testthat.R b/tests/testthat.R index 85e89fc..912812c 100644 --- a/tests/testthat.R +++ b/tests/testthat.R @@ -7,7 +7,7 @@ # * https://testthat.r-lib.org/articles/special-files.html # Sys.setenv("CODE_COVERAGE" = "true") -# Sys.setenv("SPARK_VERSION" = "4.0.1"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.10") +# Sys.setenv("SPARK_VERSION" = "4.1.0"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.12") # Sys.setenv("SPARK_VERSION" = "3.5.7"); Sys.setenv("SCALA_VERSION" = "2.12"); Sys.setenv("PYTHON_VERSION" = "3.10") if (identical(Sys.getenv("CODE_COVERAGE"), "true")) { library(testthat) diff --git a/tests/testthat/helper-init.R b/tests/testthat/helper-init.R index 9059abb..b359707 100644 --- a/tests/testthat/helper-init.R +++ b/tests/testthat/helper-init.R @@ -59,6 +59,7 @@ use_test_connect_start <- function() { cli_inform("SCALA_VERSION: {use_test_scala_spark()}") cli_inform("PYTHON_VERSION: {use_test_python_version()}") cli_h2("") + print(reticulate::py_list_packages()) withr::with_envvar( new = c( diff --git a/tests/testthat/helper-ml.R b/tests/testthat/helper-ml.R index ab3e5fc..a70fa10 100644 --- a/tests/testthat/helper-ml.R +++ b/tests/testthat/helper-ml.R @@ -3,6 +3,23 @@ use_test_pull <- function(x, table = FALSE) { if (table) { x <- table(x) } + + # Handle Spark ML vectors that come through as structured data frames + # (Spark 4.1+ with Pandas 3.0+ converts ML vectors to data frames with type, size, indices, values columns) + if (is.data.frame(x) && all(c("type", "size", "indices", "values") %in% names(x))) { + # Extract the values column which contains the actual vector data + x <- data.frame( + x = map_chr(x$values, function(vec) { + if (is.null(vec) || length(vec) == 0) { + "" + } else { + paste(as.vector(vec), collapse = ", ") + } + }) + ) + return(x) + } + if (inherits(x[[1]], "array")) { x <- as.double(map(x, as.vector)) } @@ -16,6 +33,13 @@ use_test_pull <- function(x, table = FALSE) { x = map_chr(x, function(x) paste(as.vector(x$array), collapse = ", ")) ) } + # Handle vectors that have been converted to numeric during pandas conversion + # (Spark 4.1+ with Pandas 3.0+ converts ML vectors to numeric vectors) + if (is.list(x) && length(x) > 0 && is.numeric(x[[1]]) && !is.data.frame(x)) { + x <- data.frame( + x = map_chr(x, function(vec) paste(as.vector(vec), collapse = ", ")) + ) + } x } From b102e02b54d25de36cd8532ca9ed3a7440587239 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sat, 7 Mar 2026 20:10:57 -0600 Subject: [PATCH 06/13] Updates Python version on CI --- .github/workflows/test-coverage.yaml | 2 +- tests/testthat.R | 2 +- tests/testthat/_snaps/ml-feature-transformers.md | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 21d6a2e..0f47194 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: config: - - {spark: '4.1.1', pyspark: '4.1.1', hadoop: '3', scala: '2.13', python: '3.12', name: 'PySpark 4'} + - {spark: '4.1.1', pyspark: '4.1.1', hadoop: '3', scala: '2.13', python: '3.13', name: 'PySpark 4'} env: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} diff --git a/tests/testthat.R b/tests/testthat.R index 912812c..00aad07 100644 --- a/tests/testthat.R +++ b/tests/testthat.R @@ -7,7 +7,7 @@ # * https://testthat.r-lib.org/articles/special-files.html # Sys.setenv("CODE_COVERAGE" = "true") -# Sys.setenv("SPARK_VERSION" = "4.1.0"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.12") +# Sys.setenv("SPARK_VERSION" = "4.1.1"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.13") # Sys.setenv("SPARK_VERSION" = "3.5.7"); Sys.setenv("SCALA_VERSION" = "2.12"); Sys.setenv("PYTHON_VERSION" = "3.10") if (identical(Sys.getenv("CODE_COVERAGE"), "true")) { library(testthat) diff --git a/tests/testthat/_snaps/ml-feature-transformers.md b/tests/testthat/_snaps/ml-feature-transformers.md index dba7be8..aa161f4 100644 --- a/tests/testthat/_snaps/ml-feature-transformers.md +++ b/tests/testthat/_snaps/ml-feature-transformers.md @@ -611,6 +611,7 @@ dplyr::pull(ft_ngram(ft_tokenizer(use_test_table_reviews(), "x", "token_x"), "token_x", "ngram_x")) Output + [1]> [[1]] [1] "this has" "has been" "been the" "the best" [5] "best tv" "tv i've" "i've ever" "ever used." @@ -888,6 +889,7 @@ Code dplyr::pull(ft_regex_tokenizer(use_test_table_reviews(), "x", "new_x")) Output + [1]> [[1]] [1] "this" "has" "been" "the" "best" "tv" "i've" [8] "ever" "used." "great" "screen," "and" "sound." @@ -1063,6 +1065,7 @@ dplyr::pull(ft_stop_words_remover(ft_tokenizer(use_test_table_reviews(), input_col = "x", output_col = "token_x"), input_col = "token_x", output_col = "stop_x")) Output + [1]> [[1]] [1] "best" "tv" "ever" "used." "great" "screen," "sound." @@ -1114,6 +1117,7 @@ Code dplyr::pull(ft_tokenizer(use_test_table_reviews(), input_col = "x", output_col = "token_x")) Output + [1]> [[1]] [1] "this" "has" "been" "the" "best" "tv" "i've" [8] "ever" "used." "great" "screen," "and" "sound." From d0a946c0947e0a282974d38fbd84ba130f5179f7 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sat, 7 Mar 2026 20:47:48 -0600 Subject: [PATCH 07/13] Explicitly sets Java version for CI --- .github/workflows/spark-tests.yaml | 6 ++++++ .github/workflows/test-coverage.yaml | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/.github/workflows/spark-tests.yaml b/.github/workflows/spark-tests.yaml index b458e7c..fc74e56 100644 --- a/.github/workflows/spark-tests.yaml +++ b/.github/workflows/spark-tests.yaml @@ -31,6 +31,12 @@ jobs: steps: - uses: actions/checkout@v3 + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '17' + cache: 'maven' + - uses: r-lib/actions/setup-r@v2 with: r-version: 'release' diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 0f47194..7f99f41 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -31,6 +31,12 @@ jobs: steps: - uses: actions/checkout@v3 + - uses: actions/setup-java@v4 + with: + distribution: 'temurin' + java-version: '21' + cache: 'maven' + - uses: r-lib/actions/setup-r@v2 with: r-version: 'release' From 6a97b688ff7100f6afd909de3c39c54f00dbc25f Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sat, 7 Mar 2026 20:50:08 -0600 Subject: [PATCH 08/13] Removes cache --- .github/workflows/spark-tests.yaml | 1 - .github/workflows/test-coverage.yaml | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/workflows/spark-tests.yaml b/.github/workflows/spark-tests.yaml index fc74e56..5104a58 100644 --- a/.github/workflows/spark-tests.yaml +++ b/.github/workflows/spark-tests.yaml @@ -35,7 +35,6 @@ jobs: with: distribution: 'temurin' java-version: '17' - cache: 'maven' - uses: r-lib/actions/setup-r@v2 with: diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml index 7f99f41..b19f44c 100644 --- a/.github/workflows/test-coverage.yaml +++ b/.github/workflows/test-coverage.yaml @@ -35,7 +35,6 @@ jobs: with: distribution: 'temurin' java-version: '21' - cache: 'maven' - uses: r-lib/actions/setup-r@v2 with: From 1f0908643bdfd35c93a61bec7709f01e88bda10c Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sat, 7 Mar 2026 21:07:03 -0600 Subject: [PATCH 09/13] Avoids compatability issues with Linux workstations --- R/python-to-pandas-cleaned.R | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/R/python-to-pandas-cleaned.R b/R/python-to-pandas-cleaned.R index 5b94648..ae8d948 100644 --- a/R/python-to-pandas-cleaned.R +++ b/R/python-to-pandas-cleaned.R @@ -63,6 +63,24 @@ to_pandas_cleaned <- function(x) { attr(collected[[i]], "tzone") <- "" } } + # Convert pandas Arrow-backed arrays to regular R vectors + if (inherits(collected[[i]], "python.builtin.object")) { + col_class <- try(class(collected[[i]]), silent = TRUE) + if (!inherits(col_class, "try-error")) { + # Check if it's an Arrow-backed array from pandas + if (any(grepl("ArrowStringArray|ArrowExtensionArray", col_class))) { + # Convert to regular character vector using numpy + collected[[i]] <- try( + as.character(collected[[i]]$to_numpy()), + silent = TRUE + ) + # Fallback to Python list conversion if to_numpy fails + if (inherits(collected[[i]], "try-error")) { + collected[[i]] <- as.character(collected[[i]]$tolist()) + } + } + } + } } collected <- collected |> From 81324e72475b991dc20ca87f8f772ff7bdae8f20 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sun, 8 Mar 2026 14:36:19 -0500 Subject: [PATCH 10/13] Properly implements test defer --- R/start-stop-service.R | 41 ++++++++++++++++++++++++++++++++++++----- tests/testthat/setup.R | 27 +++++++++++++++++++-------- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/R/start-stop-service.R b/R/start-stop-service.R index 7e49a9c..4771e88 100644 --- a/R/start-stop-service.R +++ b/R/start-stop-service.R @@ -74,12 +74,32 @@ spark_connect_service_start <- function( } ) - output <- prs$read_all_output() - cli_bullets(c(" " = "{.info {output}}")) - error <- prs$read_all_error() - if (error != "") { - cli_abort(error) + # Wait briefly to see if there's immediate output or errors + # Don't use read_all_output() as it blocks until process closes stdout + # Spark Connect is a long-running service, so we only check initial output + prs$poll_io(timeout = 2000) # Wait max 2 seconds + + # Read only what's available, don't wait for EOF + if (prs$is_alive()) { + output <- prs$read_output_lines() + if (length(output) > 0) { + cli_bullets(c(" " = "{.info {paste(output, collapse = '\n')}}")) + } + error <- prs$read_error_lines() + if (length(error) > 0 && any(nzchar(error))) { + cli_alert_warning(paste(error, collapse = "\n")) + } + } else { + # Process exited immediately, likely an error + if (prs$get_exit_status() != 0) { + error <- prs$read_all_error() + cli_abort(c("Failed to start Spark Connect service", "x" = error)) + } } + + # Store the process for potential cleanup + assign("spark_connect_process", prs, envir = .GlobalEnv) + cli_end() invisible() } @@ -97,6 +117,17 @@ spark_connect_service_stop <- function(version = "4.0", ...) { stderr = "|", stdin = "|" ) + + # Wait for shutdown command to complete + prs$wait(timeout = 5000) # Wait max 5 seconds + cli_bullets(c(" " = "{.info - Shutdown command sent}")) + + # Clean up stored process reference + if (exists("spark_connect_process", envir = .GlobalEnv)) { + rm("spark_connect_process", envir = .GlobalEnv) + } + cli_end() + invisible() } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 020ac7c..8f5b57c 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -3,11 +3,22 @@ suppressPackageStartupMessages(library(sparklyr)) suppressPackageStartupMessages(library(cli)) ## Clean up at end -# withr::defer({ -# Disconnecting from Spark -# withr::defer(spark_disconnect_all(), teardown_env()) -# Stopping Spark Connect service -# withr::defer(pysparklyr::spark_connect_service_stop("4.0.1")) -# Deleting main Python environment -# withr::defer(fs::dir_delete(use_test_env()), teardown_env()) -# }) +withr::defer({ + # Disconnecting from Spark + try(spark_disconnect_all(), silent = TRUE) + + # Stopping Spark Connect service + try(spark_connect_service_stop(use_test_version_spark()), silent = TRUE) + + # Kill process if still running + if (exists("spark_connect_process", envir = .GlobalEnv)) { + prs <- get("spark_connect_process", envir = .GlobalEnv) + if (prs$is_alive()) { + try(prs$kill(), silent = TRUE) + } + rm("spark_connect_process", envir = .GlobalEnv) + } + + # Deleting main Python environment (commented out as it may be too aggressive) + # try(fs::dir_delete(use_test_env()), silent = TRUE) +}, testthat::teardown_env()) From 3908747d9382dc378658910776d03793dcb7a220 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sun, 8 Mar 2026 15:30:08 -0500 Subject: [PATCH 11/13] Improvements to loop call --- R/tune-grid.R | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/R/tune-grid.R b/R/tune-grid.R index cc114d3..0048e13 100644 --- a/R/tune-grid.R +++ b/R/tune-grid.R @@ -82,12 +82,25 @@ tune_grid_spark.pyspark_connection <- function( # Uses the `loop_call` function as the base of the UDF that will be sent to # the Spark session. It works by modifying the text of the function, specifically # the file names it reads to load the different R object components + + # Inject function capture code for Spark 4.1.1+ compatibility + function_capture_code <- " + library(tidymodels) + # Capture internal tune functions for Spark worker serialization + get_data_subsets <- getFromNamespace('get_data_subsets', 'tune') + if (exists('.loop_over_all_stages', where = asNamespace('tune'))) { + loop_over_all_stages <- getFromNamespace('.loop_over_all_stages', 'tune') + } else { + loop_over_all_stages <- getFromNamespace('loop_over_all_stages', 'tune') + } + " + grid_code <- loop_call |> deparse() |> paste0(collapse = "\n") |> str_replace("\"rsample\"", pasted_pkgs) |> str_replace("debug <- TRUE", "debug <- FALSE") |> - str_replace("xy <- 1", "library(tidymodels)") |> + str_replace("xy <- 1", function_capture_code) |> str_replace("static.rds", path(hash_static, ext = "rds")) |> str_replace("resamples.rds", path(hash_resamples, ext = "rds")) @@ -307,6 +320,19 @@ loop_call <- function(x) { stop("Packages ", missing_pkgs, " are missing") } xy <- 1 + + # Fallback: Capture functions if not injected (for direct calls in tests) + if (!exists("get_data_subsets")) { + get_data_subsets <- getFromNamespace("get_data_subsets", "tune") + } + if (!exists("loop_over_all_stages")) { + if (exists(".loop_over_all_stages", where = asNamespace("tune"))) { + loop_over_all_stages <- getFromNamespace(".loop_over_all_stages", "tune") + } else { + loop_over_all_stages <- getFromNamespace("loop_over_all_stages", "tune") + } + } + # ------------------- Reads files with needed R objects ---------------------- # Loads the needed R objects from disk debug <- TRUE @@ -340,7 +366,7 @@ loop_call <- function(x) { index <- curr_x$index curr_resample <- resamples[[index]] - data_splits <- tune:::get_data_subsets( + data_splits <- get_data_subsets( static$wflow, curr_resample$splits[[1]], static$split_args @@ -351,13 +377,8 @@ loop_call <- function(x) { curr_grid <- tibble::as_tibble(curr_grid) assign(".Random.seed", c(1L, 2L, 3L), envir = .GlobalEnv) # ------ Sends current combination to `tune` for processing ---------------- - # TODO: This function check exists because the `tune` version in the Spark - # cluster may be CRAN. This needs to be removed by the time of release - if (exists(".loop_over_all_stages", where = "package:tune")) { - res <- tune::.loop_over_all_stages(curr_resample, curr_grid, static) - } else { - res <- tune:::loop_over_all_stages(curr_resample, curr_grid, static) - } + # Use the captured function to ensure it's available in Spark workers + res <- loop_over_all_stages(curr_resample, curr_grid, static) # -------------------- Extracts metrics from results ----------------------- # Mapping function accepts only tables as output, so only the metrics are # being sent back instead of the entire results object From 579febbee86f98b3790e0c47a7efc6ec3e5d913f Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sun, 8 Mar 2026 15:44:09 -0500 Subject: [PATCH 12/13] Further refinements --- R/tune-grid.R | 61 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/R/tune-grid.R b/R/tune-grid.R index 0048e13..9d43e90 100644 --- a/R/tune-grid.R +++ b/R/tune-grid.R @@ -78,22 +78,28 @@ tune_grid_spark.pyspark_connection <- function( } spark_session_add_file(vec_resamples, sc, hash_resamples) + # ------------------- Upload tune internal functions ------------------------- + # For Spark 4.1.1+ with Python 3.13, internal functions don't serialize properly + # Capture them here and upload as an RDS file + tune_fns <- list( + get_data_subsets = getFromNamespace("get_data_subsets", "tune"), + loop_over_all_stages = if (exists(".loop_over_all_stages", where = asNamespace("tune"))) { + getFromNamespace(".loop_over_all_stages", "tune") + } else { + getFromNamespace("loop_over_all_stages", "tune") + } + ) + hash_tune_fns <- "tune_fns" + spark_session_add_file(tune_fns, sc, hash_tune_fns) + # -------------------------- Creates the UDF --------------------------------- # Uses the `loop_call` function as the base of the UDF that will be sent to # the Spark session. It works by modifying the text of the function, specifically # the file names it reads to load the different R object components - # Inject function capture code for Spark 4.1.1+ compatibility - function_capture_code <- " - library(tidymodels) - # Capture internal tune functions for Spark worker serialization - get_data_subsets <- getFromNamespace('get_data_subsets', 'tune') - if (exists('.loop_over_all_stages', where = asNamespace('tune'))) { - loop_over_all_stages <- getFromNamespace('.loop_over_all_stages', 'tune') - } else { - loop_over_all_stages <- getFromNamespace('loop_over_all_stages', 'tune') - } - " + # Inject function loading code for Spark 4.1.1+ compatibility + # This will be inserted at the top of loop_call and will load tune functions + function_capture_code <- "library(tidymodels)" grid_code <- loop_call |> deparse() |> @@ -102,7 +108,8 @@ tune_grid_spark.pyspark_connection <- function( str_replace("debug <- TRUE", "debug <- FALSE") |> str_replace("xy <- 1", function_capture_code) |> str_replace("static.rds", path(hash_static, ext = "rds")) |> - str_replace("resamples.rds", path(hash_resamples, ext = "rds")) + str_replace("resamples.rds", path(hash_resamples, ext = "rds")) |> + str_replace("tune_fns.rds", path(hash_tune_fns, ext = "rds")) # -------------------- Creates and uploads the grid ------------------------- res_id_df <- purrr::map_df( @@ -321,35 +328,41 @@ loop_call <- function(x) { } xy <- 1 - # Fallback: Capture functions if not injected (for direct calls in tests) - if (!exists("get_data_subsets")) { - get_data_subsets <- getFromNamespace("get_data_subsets", "tune") - } - if (!exists("loop_over_all_stages")) { - if (exists(".loop_over_all_stages", where = asNamespace("tune"))) { - loop_over_all_stages <- getFromNamespace(".loop_over_all_stages", "tune") - } else { - loop_over_all_stages <- getFromNamespace("loop_over_all_stages", "tune") - } - } - # ------------------- Reads files with needed R objects ---------------------- # Loads the needed R objects from disk debug <- TRUE static_fname <- "static.rds" resample_fname <- "resamples.rds" + tune_fns_fname <- "tune_fns.rds" if (isFALSE(debug)) { pyspark <- reticulate::import("pyspark") static_file <- pyspark$SparkFiles$get(static_fname) resample_file <- pyspark$SparkFiles$get(resample_fname) + tune_fns_file <- pyspark$SparkFiles$get(tune_fns_fname) } else { temp_path <- Sys.getenv("TEMP_SPARK_GRID", unset = "~") static_file <- file.path(temp_path, static_fname) resample_file <- file.path(temp_path, resample_fname) + tune_fns_file <- file.path(temp_path, tune_fns_fname) } static <- readRDS(static_file) resamples <- readRDS(resample_file) + # Load tune internal functions (or use fallback for direct test calls) + if (file.exists(tune_fns_file)) { + tune_fns <- readRDS(tune_fns_file) + get_data_subsets <- tune_fns$get_data_subsets + loop_over_all_stages <- tune_fns$loop_over_all_stages + } else { + # Fallback for direct calls in tests (debug mode without uploaded file) + get_data_subsets <- getFromNamespace("get_data_subsets", "tune") + if (exists(".loop_over_all_stages", where = asNamespace("tune"))) { + loop_over_all_stages <- getFromNamespace(".loop_over_all_stages", "tune") + } else { + loop_over_all_stages <- getFromNamespace("loop_over_all_stages", "tune") + } + } + # ------------ Iterates through all the combinations in `x` ------------------ # Spark will more likely send more than one row (combination) in `x`. It # will depend on how the grid data frame was partitioned inside Spark. From 818053ef31a3e3f0423ac2f0b4e4ea7386586584 Mon Sep 17 00:00:00 2001 From: Edgar Ruiz <77294576+edgararuiz@users.noreply.github.com> Date: Sun, 8 Mar 2026 16:14:46 -0500 Subject: [PATCH 13/13] Confirms it work on dev and CRAN tune --- R/tune-grid.R | 46 +++++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/R/tune-grid.R b/R/tune-grid.R index 9d43e90..838bb6c 100644 --- a/R/tune-grid.R +++ b/R/tune-grid.R @@ -1,3 +1,17 @@ +# Helper to get tune functions that may have dots in dev version +tune_fn <- function(name) { + # Dev tune exports with dot prefix, CRAN tune without + dot_name <- paste0(".", name) + ns <- asNamespace("tune") + if (exists(dot_name, where = ns)) { + getFromNamespace(dot_name, "tune") + } else if (exists(name, where = ns)) { + getFromNamespace(name, "tune") + } else { + stop("Function '", name, "' not found in tune package") + } +} + #' @importFrom sparklyr tune_grid_spark #' @export tune_grid_spark.pyspark_connection <- function( @@ -81,13 +95,10 @@ tune_grid_spark.pyspark_connection <- function( # ------------------- Upload tune internal functions ------------------------- # For Spark 4.1.1+ with Python 3.13, internal functions don't serialize properly # Capture them here and upload as an RDS file + # Use tune_fn() to handle both dev (with dot) and CRAN (without dot) versions tune_fns <- list( - get_data_subsets = getFromNamespace("get_data_subsets", "tune"), - loop_over_all_stages = if (exists(".loop_over_all_stages", where = asNamespace("tune"))) { - getFromNamespace(".loop_over_all_stages", "tune") - } else { - getFromNamespace("loop_over_all_stages", "tune") - } + get_data_subsets = tune_fn("get_data_subsets"), + loop_over_all_stages = tune_fn("loop_over_all_stages") ) hash_tune_fns <- "tune_fns" spark_session_add_file(tune_fns, sc, hash_tune_fns) @@ -355,12 +366,17 @@ loop_call <- function(x) { loop_over_all_stages <- tune_fns$loop_over_all_stages } else { # Fallback for direct calls in tests (debug mode without uploaded file) - get_data_subsets <- getFromNamespace("get_data_subsets", "tune") - if (exists(".loop_over_all_stages", where = asNamespace("tune"))) { + # Try both naming conventions (dev uses dots, CRAN doesn't) + tryCatch({ + get_data_subsets <- getFromNamespace(".get_data_subsets", "tune") + }, error = function(e) { + get_data_subsets <<- getFromNamespace("get_data_subsets", "tune") + }) + tryCatch({ loop_over_all_stages <- getFromNamespace(".loop_over_all_stages", "tune") - } else { - loop_over_all_stages <- getFromNamespace("loop_over_all_stages", "tune") - } + }, error = function(e) { + loop_over_all_stages <<- getFromNamespace("loop_over_all_stages", "tune") + }) } # ------------ Iterates through all the combinations in `x` ------------------ @@ -485,7 +501,7 @@ prep_static <- function( data = resamples$splits[[1]]$data, grid_names = names(grid) ) - grid <- tune::.check_grid( + grid <- tune_fn("check_grid")( grid = grid, workflow = wf, pset = param_info @@ -509,7 +525,7 @@ prep_static <- function( control_err )) } - control <- tune::.update_parallel_over(control, resamples, grid) + control <- tune_fn("update_parallel_over")(control, resamples, grid) eval_time <- tune::check_eval_time_arg(eval_time, wf_metrics, call = call) needed_pkgs <- c( "rsample", @@ -529,11 +545,11 @@ prep_static <- function( out$static <- list( wflow = wf, param_info = param_info, - configs = tune::.get_config_key(grid, wf), + configs = tune_fn("get_config_key")(grid, wf), post_estimation = workflows::.workflow_postprocessor_requires_fit(wf), metrics = wf_metrics, metric_info = tibble::as_tibble(wf_metrics), - pred_types = tune::.determine_pred_types(wf, wf_metrics), + pred_types = tune_fn("determine_pred_types")(wf, wf_metrics), eval_time = eval_time, split_args = rsample::.get_split_args(resamples), control = control,