5 changes: 5 additions & 0 deletions .github/workflows/spark-tests.yaml
@@ -31,6 +31,11 @@ jobs:
steps:
- uses: actions/checkout@v3

- uses: actions/setup-java@v4
with:
distribution: 'temurin'
java-version: '17'

- uses: r-lib/actions/setup-r@v2
with:
r-version: 'release'
13 changes: 12 additions & 1 deletion .github/workflows/test-coverage.yaml
@@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
config:
- {spark: '4.0.1', pyspark: '4.0.1', hadoop: '3', scala: '2.13', python: '3.10', name: 'PySpark 4'}
- {spark: '4.1.1', pyspark: '4.1.1', hadoop: '3', scala: '2.13', python: '3.13', name: 'PySpark 4'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
@@ -31,6 +31,11 @@ jobs:
steps:
- uses: actions/checkout@v3

- uses: actions/setup-java@v4
with:
distribution: 'temurin'
java-version: '21'

- uses: r-lib/actions/setup-r@v2
with:
r-version: 'release'
@@ -45,8 +50,14 @@ jobs:
any::tidymodels
any::probably
any::rpart
any::pak
needs: check

- name: Install sparklyr dev
run: |
pak::pak("sparklyr/sparklyr")
shell: Rscript {0}

- name: Cache Spark
id: cache-spark
uses: actions/cache@v3
63 changes: 58 additions & 5 deletions R/python-to-pandas-cleaned.R
@@ -23,10 +23,22 @@ to_pandas_cleaned <- function(x) {
} else {
# Pandas 3.0 conversion encases all columns inside lists
collected <- try(
pandas_tbl$values |>
as.data.frame() |>
lapply(unlist) |>
set_names(x$columns),
{
as.data.frame(pandas_tbl$values) |>
lapply(\(col) {
map_vec(
col,
\(x) {
if (length(x) == 0 || is.nan(x)) {
NA
} else {
x[[1]]
}
}
)
}) |>
set_names(x$columns)
},
silent = TRUE
)
if (inherits(collected, "try-error")) {
@@ -35,6 +47,42 @@
}
}

# Handle NULL columns by converting them to appropriate empty vectors
for (i in seq_along(collected)) {
if (is.null(collected[[i]])) {
collected[[i]] <- character(0)
}
# Convert arrow_binary to list
if (inherits(collected[[i]], "arrow_binary")) {
collected[[i]] <- as.list(collected[[i]])
}
# Convert UTC timestamps to local timezone
if (inherits(collected[[i]], "POSIXct")) {
# If the timestamp is in UTC, convert to local timezone
if (!is.null(attr(collected[[i]], "tzone")) && attr(collected[[i]], "tzone") == "UTC") {
attr(collected[[i]], "tzone") <- ""
}
}
# Convert pandas Arrow-backed arrays to regular R vectors
if (inherits(collected[[i]], "python.builtin.object")) {
col_class <- try(class(collected[[i]]), silent = TRUE)
if (!inherits(col_class, "try-error")) {
# Check if it's an Arrow-backed array from pandas
if (any(grepl("ArrowStringArray|ArrowExtensionArray", col_class))) {
# Convert to regular character vector using numpy
collected[[i]] <- try(
as.character(collected[[i]]$to_numpy()),
silent = TRUE
)
# Fallback to Python list conversion if to_numpy fails
if (inherits(collected[[i]], "try-error")) {
collected[[i]] <- as.character(collected[[i]]$tolist())
}
}
}
}
}

collected <- collected |>
dplyr::as_tibble()

@@ -64,11 +112,16 @@ to_pandas_cleaned <- function(x) {
as.Date(map_vec(col, clean_col), origin = "1970-01-01")
} else if (py_type == "date" && r_type == "character") {
as.Date(col, origin = "1970-01-01")
} else if (py_type == "date" && r_type %in% c("numeric", "integer", "logical")) {
as.Date(col, origin = "1970-01-01")
} else if (py_type == "boolean" && r_type == "list") {
map_lgl(col, clean_col)
} else if (py_type == "boolean" && r_type == "character") {
as.logical(col)
} else if (r_type == "numeric") {
} else if (r_type == "integer" && py_type %in% c("bigint", "long")) {
# Convert bigint/long to numeric since R integer is 32-bit
as.numeric(col)
} else if (r_type == "numeric" || r_type == "integer") {
if (py_type %in% c("tinyint", "smallint", "int")) {
ptype <- integer()
} else {
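To make the intent of the new unwrapping logic concrete, here is a minimal standalone sketch of the same idea, run against a hand-built list column rather than a reticulate-backed pandas object. The column contents and the clean_cell helper are illustrative only, not taken from the package.

# Each cell arrives as a length-one list (or an empty list / bare NaN for a
# missing value) and is reduced to a plain R value, with empties mapped to NA.
raw_col <- list(list(1), list(), list(3), NaN)

clean_cell <- function(x) {
  if (length(x) == 0 || (is.atomic(x) && is.nan(x))) {
    NA
  } else {
    x[[1]]
  }
}

purrr::map_vec(raw_col, clean_cell)
#> [1]  1 NA  3 NA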
41 changes: 36 additions & 5 deletions R/start-stop-service.R
@@ -74,12 +74,32 @@ spark_connect_service_start <- function(
}
)

output <- prs$read_all_output()
cli_bullets(c(" " = "{.info {output}}"))
error <- prs$read_all_error()
if (error != "") {
cli_abort(error)
# Wait briefly to see if there's immediate output or errors
# Don't use read_all_output() as it blocks until process closes stdout
# Spark Connect is a long-running service, so we only check initial output
prs$poll_io(timeout = 2000) # Wait max 2 seconds

# Read only what's available, don't wait for EOF
if (prs$is_alive()) {
output <- prs$read_output_lines()
if (length(output) > 0) {
cli_bullets(c(" " = "{.info {paste(output, collapse = '\n')}}"))
}
error <- prs$read_error_lines()
if (length(error) > 0 && any(nzchar(error))) {
cli_alert_warning(paste(error, collapse = "\n"))
}
} else {
# Process exited immediately, likely an error
if (prs$get_exit_status() != 0) {
error <- prs$read_all_error()
cli_abort(c("Failed to start Spark Connect service", "x" = error))
}
}

# Store the process for potential cleanup
assign("spark_connect_process", prs, envir = .GlobalEnv)

cli_end()
invisible()
}
@@ -97,6 +117,17 @@ spark_connect_service_stop <- function(version = "4.0", ...) {
stderr = "|",
stdin = "|"
)

# Wait for shutdown command to complete
prs$wait(timeout = 5000) # Wait max 5 seconds

cli_bullets(c(" " = "{.info - Shutdown command sent}"))

# Clean up stored process reference
if (exists("spark_connect_process", envir = .GlobalEnv)) {
rm("spark_connect_process", envir = .GlobalEnv)
}

cli_end()
invisible()
}
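The switch away from read_all_output() is the key change in this file: for a long-running service, reading to EOF never returns. A rough standalone illustration of the non-blocking processx pattern used above, with a placeholder command standing in for the actual Spark Connect launcher:

library(processx)

# Placeholder long-running process; the real code launches the Spark Connect start script
prs <- process$new("sleep", "30", stdout = "|", stderr = "|")

prs$poll_io(2000)                    # wait at most 2 seconds for initial output
if (prs$is_alive()) {
  out <- prs$read_output_lines()     # returns whatever is buffered, does not wait for EOF
  err <- prs$read_error_lines()
} else if (prs$get_exit_status() != 0) {
  stop(prs$read_all_error())         # process died immediately; surface its stderr
}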
80 changes: 65 additions & 15 deletions R/tune-grid.R
@@ -1,3 +1,17 @@
# Helper to get tune functions that may have dots in dev version
tune_fn <- function(name) {
# Dev tune exports with dot prefix, CRAN tune without
dot_name <- paste0(".", name)
ns <- asNamespace("tune")
if (exists(dot_name, where = ns)) {
getFromNamespace(dot_name, "tune")
} else if (exists(name, where = ns)) {
getFromNamespace(name, "tune")
} else {
stop("Function '", name, "' not found in tune package")
}
}

#' @importFrom sparklyr tune_grid_spark
#' @export
tune_grid_spark.pyspark_connection <- function(
@@ -57,7 +71,7 @@ tune_grid_spark.pyspark_connection <- function(
# section
vec_resamples <- resamples |>
vctrs::vec_split(by = 1:nrow(resamples)) |>
_$val
(\(x) x$val)()
pasted_pkgs <- paste0("'", prepped$needed_pkgs, "'", collapse = ", ")

# --------------- Prepares and uploads R objects to Spark --------------------
@@ -78,18 +92,35 @@
}
spark_session_add_file(vec_resamples, sc, hash_resamples)

# ------------------- Upload tune internal functions -------------------------
# For Spark 4.1.1+ with Python 3.13, internal functions don't serialize properly
# Capture them here and upload as an RDS file
# Use tune_fn() to handle both dev (with dot) and CRAN (without dot) versions
tune_fns <- list(
get_data_subsets = tune_fn("get_data_subsets"),
loop_over_all_stages = tune_fn("loop_over_all_stages")
)
hash_tune_fns <- "tune_fns"
spark_session_add_file(tune_fns, sc, hash_tune_fns)

# -------------------------- Creates the UDF ---------------------------------
# Uses the `loop_call` function as the base of the UDF that will be sent to
# the Spark session. It works by modifying the text of the function, specifically
# the file names it reads to load the different R object components

# Inject function loading code for Spark 4.1.1+ compatibility
# This will be inserted at the top of loop_call and will load tune functions
function_capture_code <- "library(tidymodels)"

grid_code <- loop_call |>
deparse() |>
paste0(collapse = "\n") |>
str_replace("\"rsample\"", pasted_pkgs) |>
str_replace("debug <- TRUE", "debug <- FALSE") |>
str_replace("xy <- 1", "library(tidymodels)") |>
str_replace("xy <- 1", function_capture_code) |>
str_replace("static.rds", path(hash_static, ext = "rds")) |>
str_replace("resamples.rds", path(hash_resamples, ext = "rds"))
str_replace("resamples.rds", path(hash_resamples, ext = "rds")) |>
str_replace("tune_fns.rds", path(hash_tune_fns, ext = "rds"))

# -------------------- Creates and uploads the grid -------------------------
res_id_df <- purrr::map_df(
@@ -307,23 +338,47 @@ loop_call <- function(x) {
stop("Packages ", missing_pkgs, " are missing")
}
xy <- 1

# ------------------- Reads files with needed R objects ----------------------
# Loads the needed R objects from disk
debug <- TRUE
static_fname <- "static.rds"
resample_fname <- "resamples.rds"
tune_fns_fname <- "tune_fns.rds"
if (isFALSE(debug)) {
pyspark <- reticulate::import("pyspark")
static_file <- pyspark$SparkFiles$get(static_fname)
resample_file <- pyspark$SparkFiles$get(resample_fname)
tune_fns_file <- pyspark$SparkFiles$get(tune_fns_fname)
} else {
temp_path <- Sys.getenv("TEMP_SPARK_GRID", unset = "~")
static_file <- file.path(temp_path, static_fname)
resample_file <- file.path(temp_path, resample_fname)
tune_fns_file <- file.path(temp_path, tune_fns_fname)
}
static <- readRDS(static_file)
resamples <- readRDS(resample_file)

# Load tune internal functions (or use fallback for direct test calls)
if (file.exists(tune_fns_file)) {
tune_fns <- readRDS(tune_fns_file)
get_data_subsets <- tune_fns$get_data_subsets
loop_over_all_stages <- tune_fns$loop_over_all_stages
} else {
# Fallback for direct calls in tests (debug mode without uploaded file)
# Try both naming conventions (dev uses dots, CRAN doesn't)
tryCatch({
get_data_subsets <- getFromNamespace(".get_data_subsets", "tune")
}, error = function(e) {
get_data_subsets <<- getFromNamespace("get_data_subsets", "tune")
})
tryCatch({
loop_over_all_stages <- getFromNamespace(".loop_over_all_stages", "tune")
}, error = function(e) {
loop_over_all_stages <<- getFromNamespace("loop_over_all_stages", "tune")
})
}

# ------------ Iterates through all the combinations in `x` ------------------
# Spark will most likely send more than one row (combination) in `x`. It
# will depend on how the grid data frame was partitioned inside Spark.
@@ -340,7 +395,7 @@
index <- curr_x$index
curr_resample <- resamples[[index]]

data_splits <- tune:::get_data_subsets(
data_splits <- get_data_subsets(
static$wflow,
curr_resample$splits[[1]],
static$split_args
@@ -351,13 +406,8 @@
curr_grid <- tibble::as_tibble(curr_grid)
assign(".Random.seed", c(1L, 2L, 3L), envir = .GlobalEnv)
# ------ Sends current combination to `tune` for processing ----------------
# TODO: This function check exists because the `tune` version in the Spark
# cluster may be CRAN. This needs to be removed by the time of release
if (exists(".loop_over_all_stages", where = "package:tune")) {
res <- tune::.loop_over_all_stages(curr_resample, curr_grid, static)
} else {
res <- tune:::loop_over_all_stages(curr_resample, curr_grid, static)
}
# Use the captured function to ensure it's available in Spark workers
res <- loop_over_all_stages(curr_resample, curr_grid, static)
# -------------------- Extracts metrics from results -----------------------
# Mapping function accepts only tables as output, so only the metrics are
# being sent back instead of the entire results object
@@ -451,7 +501,7 @@ prep_static <- function(
data = resamples$splits[[1]]$data,
grid_names = names(grid)
)
grid <- tune::.check_grid(
grid <- tune_fn("check_grid")(
grid = grid,
workflow = wf,
pset = param_info
@@ -475,7 +525,7 @@
control_err
))
}
control <- tune::.update_parallel_over(control, resamples, grid)
control <- tune_fn("update_parallel_over")(control, resamples, grid)
eval_time <- tune::check_eval_time_arg(eval_time, wf_metrics, call = call)
needed_pkgs <- c(
"rsample",
@@ -495,11 +545,11 @@
out$static <- list(
wflow = wf,
param_info = param_info,
configs = tune::.get_config_key(grid, wf),
configs = tune_fn("get_config_key")(grid, wf),
post_estimation = workflows::.workflow_postprocessor_requires_fit(wf),
metrics = wf_metrics,
metric_info = tibble::as_tibble(wf_metrics),
pred_types = tune::.determine_pred_types(wf, wf_metrics),
pred_types = tune_fn("determine_pred_types")(wf, wf_metrics),
eval_time = eval_time,
split_args = rsample::.get_split_args(resamples),
control = control,
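The dual-name lookup in tune_fn() is the load-bearing piece of this file: dev builds of tune export these helpers with a leading dot while CRAN builds do not. Below is a compact sketch of the same resolution strategy; lookup_tune() is a made-up name for illustration (the diff's tune_fn() is the real one), and "check_grid" matches a name used in the diff.

lookup_tune <- function(name) {
  ns <- asNamespace("tune")
  # Prefer the dev (dot-prefixed) export, fall back to the CRAN name
  for (candidate in c(paste0(".", name), name)) {
    if (exists(candidate, where = ns, inherits = FALSE)) {
      return(getFromNamespace(candidate, "tune"))
    }
  }
  stop("Function '", name, "' not found in tune package")
}

# Resolves to tune::.check_grid on dev tune, or the unexported check_grid on CRAN tune
check_grid <- lookup_tune("check_grid")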
2 changes: 1 addition & 1 deletion tests/testthat.R
@@ -7,7 +7,7 @@
# * https://testthat.r-lib.org/articles/special-files.html

# Sys.setenv("CODE_COVERAGE" = "true")
# Sys.setenv("SPARK_VERSION" = "4.0.1"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.10")
# Sys.setenv("SPARK_VERSION" = "4.1.1"); Sys.setenv("SCALA_VERSION" = "2.13"); Sys.setenv("PYTHON_VERSION" = "3.13")
# Sys.setenv("SPARK_VERSION" = "3.5.7"); Sys.setenv("SCALA_VERSION" = "2.12"); Sys.setenv("PYTHON_VERSION" = "3.10")
if (identical(Sys.getenv("CODE_COVERAGE"), "true")) {
library(testthat)
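For reproducing this configuration locally, something along these lines should mirror the commented-out settings above; testthat::test_local() is one assumed entry point, and any equivalent test runner for the standard tests/testthat layout would do.

Sys.setenv(
  SPARK_VERSION  = "4.1.1",
  SCALA_VERSION  = "2.13",
  PYTHON_VERSION = "3.13"
)
testthat::test_local()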