diff --git a/.github/workflows/spark-tests.yaml b/.github/workflows/spark-tests.yaml
index b458e7c..5104a58 100644
--- a/.github/workflows/spark-tests.yaml
+++ b/.github/workflows/spark-tests.yaml
@@ -31,6 +31,11 @@ jobs:
     steps:
       - uses: actions/checkout@v3
 
+      - uses: actions/setup-java@v4
+        with:
+          distribution: 'temurin'
+          java-version: '17'
+
       - uses: r-lib/actions/setup-r@v2
         with:
           r-version: 'release'
diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml
index c031bf3..b19f44c 100644
--- a/.github/workflows/test-coverage.yaml
+++ b/.github/workflows/test-coverage.yaml
@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         config:
-          - {spark: '4.0.1', pyspark: '4.0.1', hadoop: '3', scala: '2.13', python: '3.10', name: 'PySpark 4'}
+          - {spark: '4.1.1', pyspark: '4.1.1', hadoop: '3', scala: '2.13', python: '3.13', name: 'PySpark 4'}
 
     env:
       GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
@@ -31,6 +31,11 @@ jobs:
     steps:
       - uses: actions/checkout@v3
 
+      - uses: actions/setup-java@v4
+        with:
+          distribution: 'temurin'
+          java-version: '21'
+
       - uses: r-lib/actions/setup-r@v2
         with:
           r-version: 'release'
@@ -45,8 +50,14 @@ jobs:
            any::tidymodels
            any::probably
            any::rpart
+           any::pak
          needs: check
 
+      - name: Install sparklyr dev
+        run: |
+          pak::pak("sparklyr/sparklyr")
+        shell: Rscript {0}
+
       - name: Cache Spark
         id: cache-spark
         uses: actions/cache@v3
diff --git a/R/python-to-pandas-cleaned.R b/R/python-to-pandas-cleaned.R
index cf6747b..ae8d948 100644
--- a/R/python-to-pandas-cleaned.R
+++ b/R/python-to-pandas-cleaned.R
@@ -23,10 +23,22 @@ to_pandas_cleaned <- function(x) {
   } else {
     # Pandas 3.0 conversion makes encases all columns inside lists
     collected <- try(
-      pandas_tbl$values |>
-        as.data.frame() |>
-        lapply(unlist) |>
-        set_names(x$columns),
+      {
+        as.data.frame(pandas_tbl$values) |>
+          lapply(\(col) {
+            map_vec(
+              col,
+              \(x) {
+                if (length(x) == 0 || is.nan(x)) {
+                  NA
+                } else {
+                  x[[1]]
+                }
+              }
+            )
+          }) |>
+          set_names(x$columns)
+      },
       silent = TRUE
     )
     if (inherits(collected, "try-error")) {
@@ -35,6 +47,42 @@
     }
   }
 
+  # Handle NULL columns by converting them to appropriate empty vectors
+  for (i in seq_along(collected)) {
+    if (is.null(collected[[i]])) {
+      collected[[i]] <- character(0)
+    }
+    # Convert arrow_binary to list
+    if (inherits(collected[[i]], "arrow_binary")) {
+      collected[[i]] <- as.list(collected[[i]])
+    }
+    # Convert UTC timestamps to local timezone
+    if (inherits(collected[[i]], "POSIXct")) {
+      # If the timestamp is in UTC, convert to local timezone
+      if (!is.null(attr(collected[[i]], "tzone")) && attr(collected[[i]], "tzone") == "UTC") {
+        attr(collected[[i]], "tzone") <- ""
+      }
+    }
+    # Convert pandas Arrow-backed arrays to regular R vectors
+    if (inherits(collected[[i]], "python.builtin.object")) {
+      col_class <- try(class(collected[[i]]), silent = TRUE)
+      if (!inherits(col_class, "try-error")) {
+        # Check if it's an Arrow-backed array from pandas
+        if (any(grepl("ArrowStringArray|ArrowExtensionArray", col_class))) {
+          # Convert to regular character vector using numpy
+          collected[[i]] <- try(
+            as.character(collected[[i]]$to_numpy()),
+            silent = TRUE
+          )
+          # Fallback to Python list conversion if to_numpy fails
+          if (inherits(collected[[i]], "try-error")) {
+            collected[[i]] <- as.character(collected[[i]]$tolist())
+          }
+        }
+      }
+    }
+  }
+
   collected <- collected |>
     dplyr::as_tibble()
@@ -64,11 +112,16 @@
         as.Date(map_vec(col, clean_col), origin = "1970-01-01")
       } else if (py_type == "date" && r_type == "character") {
         as.Date(col, origin = "1970-01-01")
+      } else if (py_type == "date" && r_type %in% c("numeric", "integer", "logical")) {
+        as.Date(col, origin = "1970-01-01")
       } else if (py_type == "boolean" && r_type == "list") {
         map_lgl(col, clean_col)
       } else if (py_type == "boolean" && r_type == "character") {
         as.logical(col)
-      } else if (r_type == "numeric") {
+      } else if (r_type == "integer" && py_type %in% c("bigint", "long")) {
+        # Convert bigint/long to numeric since R integer is 32-bit
+        as.numeric(col)
+      } else if (r_type == "numeric" || r_type == "integer") {
         if (py_type %in% c("tinyint", "smallint", "int")) {
           ptype <- integer()
         } else {
diff --git a/R/start-stop-service.R b/R/start-stop-service.R
index 7e49a9c..4771e88 100644
--- a/R/start-stop-service.R
+++ b/R/start-stop-service.R
@@ -74,12 +74,32 @@ spark_connect_service_start <- function(
     }
   )
 
-  output <- prs$read_all_output()
-  cli_bullets(c(" " = "{.info {output}}"))
-  error <- prs$read_all_error()
-  if (error != "") {
-    cli_abort(error)
+  # Wait briefly to see if there's immediate output or errors
+  # Don't use read_all_output() as it blocks until process closes stdout
+  # Spark Connect is a long-running service, so we only check initial output
+  prs$poll_io(timeout = 2000) # Wait max 2 seconds
+
+  # Read only what's available, don't wait for EOF
+  if (prs$is_alive()) {
+    output <- prs$read_output_lines()
+    if (length(output) > 0) {
+      cli_bullets(c(" " = "{.info {paste(output, collapse = '\n')}}"))
+    }
+    error <- prs$read_error_lines()
+    if (length(error) > 0 && any(nzchar(error))) {
+      cli_alert_warning(paste(error, collapse = "\n"))
+    }
+  } else {
+    # Process exited immediately, likely an error
+    if (prs$get_exit_status() != 0) {
+      error <- prs$read_all_error()
+      cli_abort(c("Failed to start Spark Connect service", "x" = error))
+    }
   }
+
+  # Store the process for potential cleanup
+  assign("spark_connect_process", prs, envir = .GlobalEnv)
+
   cli_end()
   invisible()
 }
@@ -97,6 +117,17 @@ spark_connect_service_stop <- function(version = "4.0", ...) {
     stderr = "|",
     stdin = "|"
   )
+
+  # Wait for shutdown command to complete
+  prs$wait(timeout = 5000) # Wait max 5 seconds
+  cli_bullets(c(" " = "{.info - Shutdown command sent}"))
+
+  # Clean up stored process reference
+  if (exists("spark_connect_process", envir = .GlobalEnv)) {
+    rm("spark_connect_process", envir = .GlobalEnv)
+  }
+  cli_end()
+
   invisible()
 }
diff --git a/R/tune-grid.R b/R/tune-grid.R
index 1f47c07..838bb6c 100644
--- a/R/tune-grid.R
+++ b/R/tune-grid.R
@@ -1,3 +1,17 @@
+# Helper to get tune functions that may have dots in dev version
+tune_fn <- function(name) {
+  # Dev tune exports with dot prefix, CRAN tune without
+  dot_name <- paste0(".", name)
+  ns <- asNamespace("tune")
+  if (exists(dot_name, where = ns)) {
+    getFromNamespace(dot_name, "tune")
+  } else if (exists(name, where = ns)) {
+    getFromNamespace(name, "tune")
+  } else {
+    stop("Function '", name, "' not found in tune package")
+  }
+}
+
 #' @importFrom sparklyr tune_grid_spark
 #' @export
 tune_grid_spark.pyspark_connection <- function(
@@ -57,7 +71,7 @@ tune_grid_spark.pyspark_connection <- function(
   # section
   vec_resamples <- resamples |>
     vctrs::vec_split(by = 1:nrow(resamples)) |>
-    _$val
+    (\(x) x$val)()
 
   pasted_pkgs <- paste0("'", prepped$needed_pkgs, "'", collapse = ", ")
   # --------------- Prepares and uploads R objects to Spark --------------------
@@ -78,18 +92,35 @@ tune_grid_spark.pyspark_connection <- function(
   }
   spark_session_add_file(vec_resamples, sc, hash_resamples)
 
+  # ------------------- Upload tune internal functions -------------------------
+  # For Spark 4.1.1+ with Python 3.13, internal functions don't serialize properly
+  # Capture them here and upload as an RDS file
+  # Use tune_fn() to handle both dev (with dot) and CRAN (without dot) versions
+  tune_fns <- list(
+    get_data_subsets = tune_fn("get_data_subsets"),
+    loop_over_all_stages = tune_fn("loop_over_all_stages")
+  )
+  hash_tune_fns <- "tune_fns"
+  spark_session_add_file(tune_fns, sc, hash_tune_fns)
+
   # -------------------------- Creates the UDF ---------------------------------
   # Uses the `loop_call` function as the base of the UDF that will be sent to
   # the Spark session. It works by modifying the text of the function, specifically
   # the file names it reads to load the different R object components
+
+  # Inject package loading code for Spark 4.1.1+ compatibility
+  # This replaces the `xy <- 1` placeholder at the top of loop_call
+  function_capture_code <- "library(tidymodels)"
+
   grid_code <- loop_call |>
     deparse() |>
     paste0(collapse = "\n") |>
     str_replace("\"rsample\"", pasted_pkgs) |>
     str_replace("debug <- TRUE", "debug <- FALSE") |>
-    str_replace("xy <- 1", "library(tidymodels)") |>
+    str_replace("xy <- 1", function_capture_code) |>
     str_replace("static.rds", path(hash_static, ext = "rds")) |>
-    str_replace("resamples.rds", path(hash_resamples, ext = "rds"))
+    str_replace("resamples.rds", path(hash_resamples, ext = "rds")) |>
+    str_replace("tune_fns.rds", path(hash_tune_fns, ext = "rds"))
 
   # -------------------- Creates and uploads the grid -------------------------
   res_id_df <- purrr::map_df(
@@ -307,23 +338,47 @@ loop_call <- function(x) {
     stop("Packages ", missing_pkgs, " are missing")
   }
   xy <- 1
+
   # ------------------- Reads files with needed R objects ----------------------
   # Loads the needed R objects from disk
   debug <- TRUE
   static_fname <- "static.rds"
   resample_fname <- "resamples.rds"
+  tune_fns_fname <- "tune_fns.rds"
   if (isFALSE(debug)) {
     pyspark <- reticulate::import("pyspark")
     static_file <- pyspark$SparkFiles$get(static_fname)
     resample_file <- pyspark$SparkFiles$get(resample_fname)
+    tune_fns_file <- pyspark$SparkFiles$get(tune_fns_fname)
   } else {
     temp_path <- Sys.getenv("TEMP_SPARK_GRID", unset = "~")
     static_file <- file.path(temp_path, static_fname)
     resample_file <- file.path(temp_path, resample_fname)
+    tune_fns_file <- file.path(temp_path, tune_fns_fname)
   }
   static <- readRDS(static_file)
   resamples <- readRDS(resample_file)
+  # Load tune internal functions (or use fallback for direct test calls)
+  if (file.exists(tune_fns_file)) {
+    tune_fns <- readRDS(tune_fns_file)
+    get_data_subsets <- tune_fns$get_data_subsets
+    loop_over_all_stages <- tune_fns$loop_over_all_stages
+  } else {
+    # Fallback for direct calls in tests (debug mode without uploaded file)
+    # Try both naming conventions (dev uses dots, CRAN doesn't)
+    tryCatch({
+      get_data_subsets <- getFromNamespace(".get_data_subsets", "tune")
+    }, error = function(e) {
+      get_data_subsets <<- getFromNamespace("get_data_subsets", "tune")
+    })
+    tryCatch({
+      loop_over_all_stages <- getFromNamespace(".loop_over_all_stages", "tune")
+    }, error = function(e) {
+      loop_over_all_stages <<- getFromNamespace("loop_over_all_stages", "tune")
+    })
+  }
+
   # ------------ Iterates through all the combinations in `x` ------------------
   # Spark will more likely send more than one row (combination) in `x`. It
   # will depend on how the grid data frame was partitioned inside Spark.
@@ -340,7 +395,7 @@ loop_call <- function(x) {
     index <- curr_x$index
     curr_resample <- resamples[[index]]
 
-    data_splits <- tune:::get_data_subsets(
+    data_splits <- get_data_subsets(
       static$wflow,
       curr_resample$splits[[1]],
       static$split_args
@@ -351,13 +406,8 @@ loop_call <- function(x) {
     curr_grid <- tibble::as_tibble(curr_grid)
     assign(".Random.seed", c(1L, 2L, 3L), envir = .GlobalEnv)
     # ------ Sends current combination to `tune` for processing ----------------
-    # TODO: This function check exists because the `tune` version in the Spark
-    # cluster may be CRAN. This needs to be removed by the time of release
-    if (exists(".loop_over_all_stages", where = "package:tune")) {
-      res <- tune::.loop_over_all_stages(curr_resample, curr_grid, static)
-    } else {
-      res <- tune:::loop_over_all_stages(curr_resample, curr_grid, static)
-    }
+    # Use the captured function to ensure it's available in Spark workers
+    res <- loop_over_all_stages(curr_resample, curr_grid, static)
     # -------------------- Extracts metrics from results -----------------------
     # Mapping function accepts only tables as output, so only the metrics are
     # being sent back instead of the entire results object
@@ -451,7 +501,7 @@ prep_static <- function(
     data = resamples$splits[[1]]$data,
     grid_names = names(grid)
   )
-  grid <- tune::.check_grid(
+  grid <- tune_fn("check_grid")(
     grid = grid,
     workflow = wf,
     pset = param_info
@@ -475,7 +525,7 @@ prep_static <- function(
       control_err
     ))
   }
-  control <- tune::.update_parallel_over(control, resamples, grid)
+  control <- tune_fn("update_parallel_over")(control, resamples, grid)
   eval_time <- tune::check_eval_time_arg(eval_time, wf_metrics, call = call)
   needed_pkgs <- c(
     "rsample",
@@ -495,11 +545,11 @@ prep_static <- function(
   out$static <- list(
     wflow = wf,
     param_info = param_info,
-    configs = tune::.get_config_key(grid, wf),
+    configs = tune_fn("get_config_key")(grid, wf),
     post_estimation = workflows::.workflow_postprocessor_requires_fit(wf),
     metrics = wf_metrics,
     metric_info = tibble::as_tibble(wf_metrics),
-    pred_types = tune::.determine_pred_types(wf, wf_metrics),
+    pred_types = tune_fn("determine_pred_types")(wf, wf_metrics),
     eval_time = eval_time,
     split_args = rsample::.get_split_args(resamples),
     control = control,
diff --git a/tests/testthat.R b/tests/testthat.R
index 85e89fc..00aad07 100644
--- a/tests/testthat.R
+++ b/tests/testthat.R
@@ -7,7 +7,7 @@
 # * https://testthat.r-lib.org/articles/special-files.html
 
 # Sys.setenv("CODE_COVERAGE" = "true")
-# Sys.setenv("SPARK_VERSION" = "4.0.1"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.10")
+# Sys.setenv("SPARK_VERSION" = "4.1.1"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.13")
 # Sys.setenv("SPARK_VERSION" = "3.5.7"); Sys.setenv("SCALA_VERSION" = "2.12"); Sys.setenv("PYTHON_VERSION" = "3.10")
 if (identical(Sys.getenv("CODE_COVERAGE"), "true")) {
   library(testthat)
diff --git a/tests/testthat/_snaps/ml-feature-transformers.md b/tests/testthat/_snaps/ml-feature-transformers.md
index dba7be8..aa161f4 100644
--- a/tests/testthat/_snaps/ml-feature-transformers.md
+++ b/tests/testthat/_snaps/ml-feature-transformers.md
@@ -611,6 +611,7 @@
       dplyr::pull(ft_ngram(ft_tokenizer(use_test_table_reviews(), "x", "token_x"),
       "token_x", "ngram_x"))
     Output
+      [1]>
       [[1]]
      [1] "this has"  "has been"  "been the"  "the best" 
      [5] "best tv"   "tv i've"   "i've ever" "ever used."
@@ -888,6 +889,7 @@
     Code
       dplyr::pull(ft_regex_tokenizer(use_test_table_reviews(), "x", "new_x"))
     Output
+      [1]>
       [[1]]
      [1] "this"    "has"     "been"    "the"     "best"    "tv"      "i've"   
      [8] "ever"    "used."   "great"   "screen," "and"     "sound." 
@@ -1063,6 +1065,7 @@
       dplyr::pull(ft_stop_words_remover(ft_tokenizer(use_test_table_reviews(),
      input_col = "x", output_col = "token_x"), input_col = "token_x", output_col = "stop_x"))
     Output
+      [1]>
       [[1]]
      [1] "best"    "tv"      "ever"    "used."   "great"   "screen," "sound."
@@ -1114,6 +1117,7 @@
     Code
       dplyr::pull(ft_tokenizer(use_test_table_reviews(), input_col = "x", output_col = "token_x"))
     Output
+      [1]>
       [[1]]
      [1] "this"    "has"     "been"    "the"     "best"    "tv"      "i've"   
      [8] "ever"    "used."   "great"   "screen," "and"     "sound." 
diff --git a/tests/testthat/helper-init.R b/tests/testthat/helper-init.R
index 9059abb..b359707 100644
--- a/tests/testthat/helper-init.R
+++ b/tests/testthat/helper-init.R
@@ -59,6 +59,7 @@ use_test_connect_start <- function() {
   cli_inform("SCALA_VERSION: {use_test_scala_spark()}")
   cli_inform("PYTHON_VERSION: {use_test_python_version()}")
   cli_h2("")
+  print(reticulate::py_list_packages())
 
   withr::with_envvar(
     new = c(
diff --git a/tests/testthat/helper-ml.R b/tests/testthat/helper-ml.R
index ab3e5fc..a70fa10 100644
--- a/tests/testthat/helper-ml.R
+++ b/tests/testthat/helper-ml.R
@@ -3,6 +3,23 @@ use_test_pull <- function(x, table = FALSE) {
   if (table) {
     x <- table(x)
   }
+
+  # Handle Spark ML vectors that come through as structured data frames
+  # (Spark 4.1+ with Pandas 3.0+ converts ML vectors to data frames with type, size, indices, values columns)
+  if (is.data.frame(x) && all(c("type", "size", "indices", "values") %in% names(x))) {
+    # Extract the values column which contains the actual vector data
+    x <- data.frame(
+      x = map_chr(x$values, function(vec) {
+        if (is.null(vec) || length(vec) == 0) {
+          ""
+        } else {
+          paste(as.vector(vec), collapse = ", ")
+        }
+      })
+    )
+    return(x)
+  }
+
   if (inherits(x[[1]], "array")) {
     x <- as.double(map(x, as.vector))
   }
@@ -16,6 +33,13 @@ use_test_pull <- function(x, table = FALSE) {
       x = map_chr(x, function(x) paste(as.vector(x$array), collapse = ", "))
     )
   }
+  # Handle vectors that have been converted to numeric during pandas conversion
+  # (Spark 4.1+ with Pandas 3.0+ converts ML vectors to numeric vectors)
+  if (is.list(x) && length(x) > 0 && is.numeric(x[[1]]) && !is.data.frame(x)) {
+    x <- data.frame(
+      x = map_chr(x, function(vec) paste(as.vector(vec), collapse = ", "))
+    )
+  }
   x
 }
diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R
index 020ac7c..8f5b57c 100644
--- a/tests/testthat/setup.R
+++ b/tests/testthat/setup.R
@@ -3,11 +3,22 @@
 suppressPackageStartupMessages(library(sparklyr))
 suppressPackageStartupMessages(library(cli))
 ## Clean up at end
-# withr::defer({
-#   Disconnecting from Spark
-#   withr::defer(spark_disconnect_all(), teardown_env())
-#   Stopping Spark Connect service
-#   withr::defer(pysparklyr::spark_connect_service_stop("4.0.1"))
-#   Deleting main Python environment
-#   withr::defer(fs::dir_delete(use_test_env()), teardown_env())
-# })
+withr::defer({
+  # Disconnecting from Spark
+  try(spark_disconnect_all(), silent = TRUE)
+
+  # Stopping Spark Connect service
+  try(spark_connect_service_stop(use_test_version_spark()), silent = TRUE)
+
+  # Kill process if still running
+  if (exists("spark_connect_process", envir = .GlobalEnv)) {
+    prs <- get("spark_connect_process", envir = .GlobalEnv)
+    if (prs$is_alive()) {
+      try(prs$kill(), silent = TRUE)
+    }
+    rm("spark_connect_process", envir = .GlobalEnv)
+  }
+
+  # Deleting main Python environment (commented out as it may be too aggressive)
+  # try(fs::dir_delete(use_test_env()), silent = TRUE)
+}, testthat::teardown_env())