diff --git a/R/lazy-join-query.R b/R/lazy-join-query.R index 5a1bd354f..d57a35190 100644 --- a/R/lazy-join-query.R +++ b/R/lazy-join-query.R @@ -286,7 +286,7 @@ sql_build.lazy_semi_join_query <- function(op, con, ..., sql_options = NULL) { y_vars <- op_vars(op$y) y_as <- op$by$y_as replacements <- lapply(y_vars, \(var) sql_glue2(con, "{y_as}.{.id var}")) - where <- lapply(op$where, \(expr) replace_sym(expr, y_vars, replacements)) + where <- replace_sym(op$where, y_vars, replacements) where_sql <- translate_sql_( where, con = con, diff --git a/R/tidyeval-across.R b/R/tidyeval-across.R index 3ca06bd0f..2850365ae 100644 --- a/R/tidyeval-across.R +++ b/R/tidyeval-across.R @@ -241,7 +241,7 @@ partial_eval_fun <- function(fun, env, fn) { partial_eval_prepare_fun <- function(call, sym, env) { # First resolve any .data/.env pronouns before symbol replacement call <- resolve_mask_pronouns(call, env) - call <- replace_sym(call, sym, replace = quote(!!.x)) + call <- replace_sym1(call, sym, replace = quote(!!.x)) call <- replace_call(call, replace = quote(!!.cur_col)) function(x, .cur_col) { inject( diff --git a/R/tidyeval.R b/R/tidyeval.R index 7ba410f0e..ab72b1767 100644 --- a/R/tidyeval.R +++ b/R/tidyeval.R @@ -210,7 +210,7 @@ partial_eval_call <- function(call, data, env) { if (inherits(fun, "inline_colwise_function")) { vars <- colnames(tidyselect_data_proxy(data)) dot_var <- vars[[attr(call, "position")]] - call <- replace_sym(attr(fun, "formula")[[2]], c(".", ".x"), sym(dot_var)) + call <- replace_sym1(attr(fun, "formula")[[2]], c(".", ".x"), sym(dot_var)) env <- get_env(attr(fun, "formula")) } else if (is.function(fun)) { fun_name <- find_fun(fun) @@ -306,7 +306,18 @@ fun_name <- function(fun) { NULL } -replace_sym <- function(call, sym, replace) { + +replace_sym <- function(exprs, old, new) { + check_list(exprs, allow_null = TRUE) + check_character(old) + check_list(new) + # Allow new to be a list of quosures too + new <- purrr::map_if(new, is_quosure, quo_get_expr) + + purrr::map(exprs, \(expr) replace_sym1(expr, old, new)) +} + +replace_sym1 <- function(call, sym, replace) { if (is_symbol(call, sym)) { if (is_list(replace)) { replace[[match(as_string(call), sym)]] @@ -314,7 +325,7 @@ replace_sym <- function(call, sym, replace) { replace } } else if (is_call(call)) { - call[] <- lapply(call, replace_sym, sym = sym, replace = replace) + call[] <- lapply(call, replace_sym1, sym = sym, replace = replace) call } else { call diff --git a/R/translate-sql-window.R b/R/translate-sql-window.R index 81978aceb..19f9cfbc6 100644 --- a/R/translate-sql-window.R +++ b/R/translate-sql-window.R @@ -403,8 +403,11 @@ uses_window_fun <- function(x, con, lq) { check_list(x) calls <- unlist(lapply(x, all_calls)) - win_f <- env_names(dbplyr_sql_translation(con)$window) - any(calls %in% win_f) + any(calls %in% window_funs(con)) +} + +window_funs <- function(con = simulate_dbi()) { + env_names(sql_translation(con)$window) } is_aggregating <- function(x, non_group_cols, agg_f) { @@ -429,10 +432,6 @@ is_aggregating <- function(x, non_group_cols, agg_f) { return(TRUE) } -common_window_funs <- function() { - env_names(dbplyr_sql_translation(NULL)$window) # nocov -} - #' @noRd #' @examples #' translate_window_where(quote(1)) @@ -441,7 +440,9 @@ common_window_funs <- function() { #' translate_window_where(quote(x == 1 && y == 2)) #' translate_window_where(quote(n() > 10)) #' translate_window_where(quote(rank() > cumsum(AB))) -translate_window_where <- function(expr, window_funs = common_window_funs()) { +translate_window_where <- function(expr, window_funs = NULL) { + window_funs <- window_funs %||% window_funs() + switch( typeof(expr), logical = , @@ -475,12 +476,12 @@ translate_window_where <- function(expr, window_funs = common_window_funs()) { ) } - #' @noRd #' @examples #' translate_window_where_all(list(quote(x == 1), quote(n() > 2))) #' translate_window_where_all(list(quote(cumsum(x) == 10), quote(n() > 2))) -translate_window_where_all <- function(x, window_funs = common_window_funs()) { +translate_window_where_all <- function(x, window_funs = NULL) { + window_funs <- window_funs %||% window_funs() out <- lapply(x, translate_window_where, window_funs = window_funs) list( diff --git a/R/utils-check.R b/R/utils-check.R index 6cd42ae51..e92c18738 100644 --- a/R/utils-check.R +++ b/R/utils-check.R @@ -10,6 +10,9 @@ check_list <- function( if (vctrs::vec_is_list(x)) { return() } + if (allow_null && is_null(x)) { + return() + } stop_input_type( x, c("a list"), diff --git a/R/verb-filter.R b/R/verb-filter.R index b873ca3d7..a6de658c2 100644 --- a/R/verb-filter.R +++ b/R/verb-filter.R @@ -25,16 +25,15 @@ filter.tbl_lazy <- function(.data, ..., .by = NULL, .preserve = FALSE) { data_arg = ".data", error_call = caller_env() ) - if (by$from_by) { - .data$lazy_query$group_vars <- by$names - } - dots <- partial_eval_dots(.data, ..., .named = FALSE) - - if (is_empty(dots)) { + if (missing(...)) { return(.data) } + if (by$from_by) { + .data$lazy_query$group_vars <- by$names + } + dots <- partial_eval_dots(.data, ..., .named = FALSE) .data$lazy_query <- add_filter(.data, dots) if (by$from_by) { .data$lazy_query$group_vars <- character() @@ -47,45 +46,9 @@ add_filter <- function(.data, dots) { lazy_query <- .data$lazy_query dots <- unname(dots) - dots_use_window_fun <- uses_window_fun(dots, con) - - if (filter_can_use_having(lazy_query, dots_use_window_fun)) { - return(filter_via_having(lazy_query, dots)) - } - - if (!dots_use_window_fun) { - if (filter_needs_new_query(dots, lazy_query, con)) { - lazy_select_query( - x = lazy_query, - where = dots - ) - } else { - exprs <- lazy_query$select$expr - nms <- lazy_query$select$name - projection <- purrr::map2_lgl( - exprs, - nms, - \(expr, name) is_symbol(expr) && !identical(expr, sym(name)) - ) - - if (any(projection)) { - dots <- purrr::map( - dots, - replace_sym, - nms[projection], - exprs[projection] - ) - } - - lazy_query$where <- c(lazy_query$where, dots) - lazy_query - } - } else { + if (uses_window_fun(dots, con)) { # Do partial evaluation, then extract out window functions - where <- translate_window_where_all( - dots, - env_names(dbplyr_sql_translation(con)$window) - ) + where <- translate_window_where_all(dots, window_funs(con)) # Add extracted window expressions as columns mutated <- mutate(.data, !!!where$comp) @@ -97,30 +60,40 @@ add_filter <- function(.data, dots) { select = syms(set_names(original_vars)), where = where$expr ) + } else if (filter_can_use_having(lazy_query)) { + filter_via_having(lazy_query, dots) + } else if (filter_can_inline(dots, lazy_query, con)) { + # WHERE processed before SELECT + dots <- replace_sym(dots, lazy_query$select$name, lazy_query$select$expr) + + lazy_query$where <- c(lazy_query$where, dots) + lazy_query + } else { + lazy_select_query(x = lazy_query, where = dots) } } -filter_needs_new_query <- function(dots, lazy_query, con) { +filter_can_inline <- function(dots, lazy_query, con) { if (!is_lazy_select_query(lazy_query)) { - return(TRUE) + return(FALSE) } if (uses_mutated_vars(dots, lazy_query$select)) { - return(TRUE) + return(FALSE) } if (uses_window_fun(lazy_query$select$expr, con)) { - return(TRUE) + return(FALSE) } if (any_expr_uses_sql(lazy_query$select$expr)) { - return(TRUE) + return(FALSE) } - FALSE + TRUE } -filter_can_use_having <- function(lazy_query, dots_use_window_fun) { +filter_can_use_having <- function(lazy_query) { # From the Postgres documentation: https://www.postgresql.org/docs/current/sql-select.html#SQL-HAVING # Each column referenced in condition must unambiguously reference a grouping # column, unless the reference appears within an aggregate function or the @@ -133,22 +106,16 @@ filter_can_use_having <- function(lazy_query, dots_use_window_fun) { # # Therefore, if `filter()` does not use a window function, then we only use # grouping or aggregated columns - - if (dots_use_window_fun) { - return(FALSE) - } - if (!is_lazy_select_query(lazy_query)) { - return(FALSE) + FALSE + } else { + lazy_query$select_operation == "summarise" } - - lazy_query$select_operation == "summarise" } filter_via_having <- function(lazy_query, dots) { - names <- lazy_query$select$name - exprs <- purrr::map_if(lazy_query$select$expr, is_quosure, quo_get_expr) - dots <- purrr::map(dots, replace_sym, names, exprs) + # ANSI SQL processes HAVING before SELECT + dots <- replace_sym(dots, lazy_query$select$name, lazy_query$select$expr) lazy_query$having <- c(lazy_query$having, dots) lazy_query diff --git a/R/verb-select.R b/R/verb-select.R index bd0a1083a..f10485521 100644 --- a/R/verb-select.R +++ b/R/verb-select.R @@ -149,9 +149,7 @@ rename_order <- function(order_vars, select_vars) { order_vars <- order_vars[order_names %in% select_vars] # Rename the remaining - order_vars[] <- lapply(order_vars, \(expr) { - replace_sym(expr, select_vars, syms(names(select_vars))) - }) + order_vars[] <- replace_sym(order_vars, select_vars, syms(names(select_vars))) order_vars }