1 change: 1 addition & 0 deletions .Rbuildignore
@@ -12,3 +12,4 @@
^derby\.log$
^[.]?air[.]toml$
^\.vscode$
+^\.claude$
20 changes: 19 additions & 1 deletion .github/workflows/spark-tests.yaml
@@ -1,8 +1,26 @@
on:
  push:
    branches: main
+    paths:
+      - '.github/workflows/spark-tests.yaml'
+      - 'R/**'
+      - 'tests/**'
+      - 'DESCRIPTION'
+      - 'NAMESPACE'
+      - 'man/**'
+      - 'inst/**'
+      - 'src/**'
  pull_request:
    branches: main
+    paths:
+      - '.github/workflows/spark-tests.yaml'
+      - 'R/**'
+      - 'tests/**'
+      - 'DESCRIPTION'
+      - 'NAMESPACE'
+      - 'man/**'
+      - 'inst/**'
+      - 'src/**'

name: Spark-Connect

@@ -17,7 +35,7 @@ jobs:
      fail-fast: false
      matrix:
        config:
-          - {spark: '3.5.7', pyspark: '3.5.7', hadoop: '3', scala: '2.12', python: '3.10', name: 'PySpark 3'}
+          - {spark: '3.5.8', pyspark: '3.5.8', hadoop: '3', scala: '2.12', python: '3.10', name: 'PySpark 3'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
18 changes: 18 additions & 0 deletions .github/workflows/test-coverage.yaml
@@ -1,8 +1,26 @@
on:
  push:
    branches: main
+    paths:
+      - '.github/workflows/test-coverage.yaml'
+      - 'R/**'
+      - 'tests/**'
+      - 'DESCRIPTION'
+      - 'NAMESPACE'
+      - 'man/**'
+      - 'inst/**'
+      - 'src/**'
  pull_request:
    branches: main
+    paths:
+      - '.github/workflows/test-coverage.yaml'
+      - 'R/**'
+      - 'tests/**'
+      - 'DESCRIPTION'
+      - 'NAMESPACE'
+      - 'man/**'
+      - 'inst/**'
+      - 'src/**'

name: test-coverage

1 change: 1 addition & 0 deletions .gitignore
@@ -10,3 +10,4 @@
derby.log
spark-warehouse
requirements.txt
+.claude
9 changes: 3 additions & 6 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: pysparklyr
Title: Provides a 'PySpark' Back-End for the 'sparklyr' Package
-Version: 0.2.0.9003
+Version: 0.2.1
Authors@R: c(
person("Edgar", "Ruiz", , "[email protected]", role = c("aut", "cre")),
person("Posit Software, PBC", role = c("cph", "fnd"),
@@ -30,7 +30,7 @@ Imports:
reticulate (>= 1.44.0),
rlang,
rstudioapi,
-    sparklyr (>= 1.9.3.9000),
+    sparklyr (>= 1.9.4),
tidyr,
tidyselect,
vctrs,
@@ -46,7 +46,7 @@ Suggests:
rsconnect,
rsample,
workflows,
-    tune (>= 2.0.1.9002),
+    tune,
parsnip,
dials,
tailor,
@@ -56,6 +56,3 @@ Config/usethis/last-upkeep: 2025-11-12
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.3
-Remotes:
-    tidymodels/tune,
-    sparklyr/sparklyr
17 changes: 14 additions & 3 deletions NEWS.md
@@ -1,14 +1,25 @@
-# pysparklyr (dev)
+# pysparklyr 0.2.1

-### Improvements
+### New

+- Adds support for `tune_grid_spark()`. It enables running a Tidymodels tune
+  grid inside Spark Connect clusters.

+### Improvements

- Databricks Connect now auto-detects the latest library version from PyPI when
no `version` parameter is specified. When the auto-detected version differs from
the cluster's DBR version, a warning is displayed with suggestions for ensuring
  version compatibility.

- Adds `profile` argument support to the Databricks SDK connection call.

- When no cluster version is provided, uses the latest available main library
version from PyPI.

### Fixes

- Fixes conversion of Pandas NULL columns and date types (#178 - @tobiasdut)

# pysparklyr 0.2.0

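For the `tune_grid_spark()` entry announced above, a minimal usage sketch follows. The call pattern is an assumption: the generic comes from sparklyr and, per the method name in this diff, dispatches on the `pyspark_connection` class, but the exact signature is not shown here. The model and data setup are illustrative only.

```r
library(sparklyr)
library(tidymodels)

# Spark Connect session; the sc:// URL is a placeholder.
sc <- spark_connect(master = "sc://localhost", method = "spark_connect")

# Ordinary tidymodels objects; nothing Spark-specific is needed here.
wf <- workflow(mpg ~ ., parsnip::linear_reg(penalty = tune(), engine = "glmnet"))
folds <- rsample::vfold_cv(mtcars, v = 5)

# Hypothetical call: with a pyspark_connection as input, each grid/resample
# combination is evaluated inside the Spark Connect cluster.
res <- tune_grid_spark(sc, wf, resamples = folds, grid = 10)
```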
6 changes: 3 additions & 3 deletions R/start-stop-service.R
@@ -98,7 +98,7 @@ spark_connect_service_start <- function(
}

# Store the process for potential cleanup
-  assign("spark_connect_process", prs, envir = .GlobalEnv)
+  pysparklyr_env$spark_connect_process <- prs

cli_end()
invisible()
@@ -124,8 +124,8 @@ spark_connect_service_stop <- function(version = "4.0", ...) {
cli_bullets(c(" " = "{.info - Shutdown command sent}"))

# Clean up stored process reference
-  if (exists("spark_connect_process", envir = .GlobalEnv)) {
-    rm("spark_connect_process", envir = .GlobalEnv)
+  if (!is.null(pysparklyr_env$spark_connect_process)) {
+    pysparklyr_env$spark_connect_process <- NULL
}

cli_end()
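The `R/start-stop-service.R` change above swaps `.GlobalEnv` storage for a package-level environment. A short sketch of the pattern, assuming `pysparklyr_env` is defined at the package top level (its definition is not part of this diff; `prs` is the process handle from the surrounding code):

```r
# Package-scoped state container, created once when the package is loaded.
# An emptyenv() parent keeps lookups from falling through to other scopes.
pysparklyr_env <- new.env(parent = emptyenv())

# Storing the handle here instead of assign(..., envir = .GlobalEnv) leaves
# the user's workspace untouched, which CRAN policy requires of packages.
pysparklyr_env$spark_connect_process <- prs

# Cleanup becomes a plain NULL assignment; no exists()/rm() bookkeeping.
pysparklyr_env$spark_connect_process <- NULL
```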
65 changes: 7 additions & 58 deletions R/tune-grid.R
@@ -1,17 +1,3 @@
-# Helper to get tune functions that may have dots in dev version
-tune_fn <- function(name) {
-  # Dev tune exports with dot prefix, CRAN tune without
-  dot_name <- paste0(".", name)
-  ns <- asNamespace("tune")
-  if (exists(dot_name, where = ns)) {
-    getFromNamespace(dot_name, "tune")
-  } else if (exists(name, where = ns)) {
-    getFromNamespace(name, "tune")
-  } else {
-    stop("Function '", name, "' not found in tune package")
-  }
-}

#' @importFrom sparklyr tune_grid_spark
#' @export
tune_grid_spark.pyspark_connection <- function(
@@ -92,17 +78,6 @@ tune_grid_spark.pyspark_connection <- function(
}
spark_session_add_file(vec_resamples, sc, hash_resamples)

-  # ------------------- Upload tune internal functions -------------------------
-  # For Spark 4.1.1+ with Python 3.13, internal functions don't serialize properly
-  # Capture them here and upload as an RDS file
-  # Use tune_fn() to handle both dev (with dot) and CRAN (without dot) versions
-  tune_fns <- list(
-    get_data_subsets = tune_fn("get_data_subsets"),
-    loop_over_all_stages = tune_fn("loop_over_all_stages")
-  )
-  hash_tune_fns <- "tune_fns"
-  spark_session_add_file(tune_fns, sc, hash_tune_fns)

# -------------------------- Creates the UDF ---------------------------------
# Uses the `loop_call` function as the base of the UDF that will be sent to
# the Spark session. It works by modifying the text of the function, specifically
@@ -119,8 +94,7 @@ tune_grid_spark.pyspark_connection <- function(
str_replace("debug <- TRUE", "debug <- FALSE") |>
str_replace("xy <- 1", function_capture_code) |>
str_replace("static.rds", path(hash_static, ext = "rds")) |>
-    str_replace("resamples.rds", path(hash_resamples, ext = "rds")) |>
-    str_replace("tune_fns.rds", path(hash_tune_fns, ext = "rds"))
+    str_replace("resamples.rds", path(hash_resamples, ext = "rds"))

# -------------------- Creates and uploads the grid -------------------------
res_id_df <- purrr::map_df(
@@ -344,41 +318,18 @@ loop_call <- function(x) {
debug <- TRUE
static_fname <- "static.rds"
resample_fname <- "resamples.rds"
-  tune_fns_fname <- "tune_fns.rds"
if (isFALSE(debug)) {
pyspark <- reticulate::import("pyspark")
static_file <- pyspark$SparkFiles$get(static_fname)
resample_file <- pyspark$SparkFiles$get(resample_fname)
-    tune_fns_file <- pyspark$SparkFiles$get(tune_fns_fname)
} else {
temp_path <- Sys.getenv("TEMP_SPARK_GRID", unset = "~")
static_file <- file.path(temp_path, static_fname)
resample_file <- file.path(temp_path, resample_fname)
-    tune_fns_file <- file.path(temp_path, tune_fns_fname)
}
static <- readRDS(static_file)
resamples <- readRDS(resample_file)

-  # Load tune internal functions (or use fallback for direct test calls)
-  if (file.exists(tune_fns_file)) {
-    tune_fns <- readRDS(tune_fns_file)
-    get_data_subsets <- tune_fns$get_data_subsets
-    loop_over_all_stages <- tune_fns$loop_over_all_stages
-  } else {
-    # Fallback for direct calls in tests (debug mode without uploaded file)
-    # Try both naming conventions (dev uses dots, CRAN doesn't)
-    tryCatch({
-      get_data_subsets <- getFromNamespace(".get_data_subsets", "tune")
-    }, error = function(e) {
-      get_data_subsets <<- getFromNamespace("get_data_subsets", "tune")
-    })
-    tryCatch({
-      loop_over_all_stages <- getFromNamespace(".loop_over_all_stages", "tune")
-    }, error = function(e) {
-      loop_over_all_stages <<- getFromNamespace("loop_over_all_stages", "tune")
-    })
-  }

# ------------ Iterates through all the combinations in `x` ------------------
# Spark will more likely send more than one row (combination) in `x`. It
# will depend on how the grid data frame was partitioned inside Spark.
@@ -395,7 +346,7 @@
index <- curr_x$index
curr_resample <- resamples[[index]]

-    data_splits <- get_data_subsets(
+    data_splits <- tune::.get_data_subsets(
static$wflow,
curr_resample$splits[[1]],
static$split_args
@@ -405,9 +356,7 @@
# loop_over_all_stages() requires the grid to be a tibble
curr_grid <- tibble::as_tibble(curr_grid)
assign(".Random.seed", c(1L, 2L, 3L), envir = .GlobalEnv)
-    # ------ Sends current combination to `tune` for processing ----------------
-    # Use the captured function to ensure it's available in Spark workers
-    res <- loop_over_all_stages(curr_resample, curr_grid, static)
+    res <- tune::.loop_over_all_stages(curr_resample, curr_grid, static)
# -------------------- Extracts metrics from results -----------------------
# Mapping function accepts only tables as output, so only the metrics are
# being sent back instead of the entire results object
@@ -501,7 +450,7 @@ prep_static <- function(
data = resamples$splits[[1]]$data,
grid_names = names(grid)
)
-  grid <- tune_fn("check_grid")(
+  grid <- tune::.check_grid(
grid = grid,
workflow = wf,
pset = param_info
@@ -525,7 +474,7 @@
control_err
))
}
-  control <- tune_fn("update_parallel_over")(control, resamples, grid)
+  control <- tune::.update_parallel_over(control, resamples, grid)
eval_time <- tune::check_eval_time_arg(eval_time, wf_metrics, call = call)
needed_pkgs <- c(
"rsample",
@@ -545,11 +494,11 @@
out$static <- list(
wflow = wf,
param_info = param_info,
-    configs = tune_fn("get_config_key")(grid, wf),
+    configs = tune::.get_config_key(grid, wf),
post_estimation = workflows::.workflow_postprocessor_requires_fit(wf),
metrics = wf_metrics,
metric_info = tibble::as_tibble(wf_metrics),
-    pred_types = tune_fn("determine_pred_types")(wf, wf_metrics),
+    pred_types = tune::.determine_pred_types(wf, wf_metrics),
eval_time = eval_time,
split_args = rsample::.get_split_args(resamples),
control = control,
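The thread running through the `R/tune-grid.R` changes: the dot-prefixed tune internals are now available in the released package (hence the DESCRIPTION version pin and `Remotes:` entries going away), so the `tune_fn()` lookup shim and the RDS round-trip of captured functions can be dropped in favor of direct namespaced calls. A before/after sketch, using the arguments that appear in the diff:

```r
# Before: resolve the internal at run time, trying the dev (dotted) name
# and falling back to the CRAN (undotted) one.
get_data_subsets <- tune_fn("get_data_subsets")
data_splits <- get_data_subsets(
  static$wflow,
  curr_resample$splits[[1]],
  static$split_args
)

# After: a direct call to the exported internal. This also survives the
# deparse/re-parse trip that ships loop_call() to Spark workers, since no
# locally captured closure is involved.
data_splits <- tune::.get_data_subsets(
  static$wflow,
  curr_resample$splits[[1]],
  static$split_args
)
```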
4 changes: 0 additions & 4 deletions tests/testthat/_snaps/ml-feature-transformers.md
@@ -611,7 +611,6 @@
dplyr::pull(ft_ngram(ft_tokenizer(use_test_table_reviews(), "x", "token_x"),
"token_x", "ngram_x"))
Output
-      <list<character>[1]>
[[1]]
[1] "this has" "has been" "been the" "the best"
[5] "best tv" "tv i've" "i've ever" "ever used."
Expand Down Expand Up @@ -889,7 +888,6 @@
Code
dplyr::pull(ft_regex_tokenizer(use_test_table_reviews(), "x", "new_x"))
Output
-      <list<character>[1]>
[[1]]
[1] "this" "has" "been" "the" "best" "tv" "i've"
[8] "ever" "used." "great" "screen," "and" "sound."
Expand Down Expand Up @@ -1065,7 +1063,6 @@
dplyr::pull(ft_stop_words_remover(ft_tokenizer(use_test_table_reviews(),
input_col = "x", output_col = "token_x"), input_col = "token_x", output_col = "stop_x"))
Output
-      <list<character>[1]>
[[1]]
[1] "best" "tv" "ever" "used." "great" "screen," "sound."

@@ -1117,7 +1114,6 @@
Code
dplyr::pull(ft_tokenizer(use_test_table_reviews(), input_col = "x", output_col = "token_x"))
Output
-      <list<character>[1]>
[[1]]
[1] "this" "has" "been" "the" "best" "tv" "i've"
[8] "ever" "used." "great" "screen," "and" "sound."