diff --git a/NEWS.md b/NEWS.md index 365cbeeb1..de0ee2b36 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # dbplyr (development version) +* `filter()` after a sequence of `left_join()` and `inner_joins()` no longer generates a subquery (#722). * `summarise()` now reports grouping immediately, rather than when you summarise. * `sql_optimise()` has been removed. It was only used for two cases (filter + summarise and arrange + summarise), and these are now handled at a higher level (#1720). * `distinct()` after a join no longer creates a subquery (#722). diff --git a/R/db-sql.R b/R/db-sql.R index f5e9840c3..bd6b49f87 100644 --- a/R/db-sql.R +++ b/R/db-sql.R @@ -375,6 +375,7 @@ sql_query_join <- function( by = NULL, na_matches = FALSE, ..., + where = NULL, lvl = 0 ) { check_dots_used() @@ -390,6 +391,7 @@ sql_query_join.DBIConnection <- function( by = NULL, na_matches = FALSE, ..., + where = NULL, lvl = 0 ) { JOIN <- switch( @@ -412,7 +414,8 @@ sql_query_join.DBIConnection <- function( sql_clause_select(con, select), sql_clause_from(x), sql_clause(JOIN, y), - sql_clause("ON", on, sep = " AND", parens = TRUE, lvl = 1) + sql_clause("ON", on, sep = " AND", parens = TRUE, lvl = 1), + sql_clause_where(where) ) sql_format_clauses(clauses, lvl, con) } @@ -426,6 +429,7 @@ dbplyr_query_join <- function( na_matches = FALSE, ..., select = NULL, + where = NULL, lvl = 0 ) { check_2ed(con) @@ -437,6 +441,7 @@ dbplyr_query_join <- function( type = type, by = by, na_matches = na_matches, + where = where, ..., lvl = lvl ) @@ -452,6 +457,7 @@ sql_query_multi_join <- function( by_list, select, ..., + where = NULL, distinct = FALSE, lvl = 0 ) { @@ -497,6 +503,7 @@ sql_query_multi_join.DBIConnection <- function( by_list, select, ..., + where = NULL, distinct = FALSE, lvl = 0 ) { @@ -524,7 +531,8 @@ sql_query_multi_join.DBIConnection <- function( clauses <- list2( sql_clause_select(con, select, distinct), sql_clause_from(from), - !!!out + !!!out, + sql_clause_where(where) ) sql_format_clauses(clauses, lvl = lvl, con = con) } diff --git a/R/lazy-join-query.R b/R/lazy-join-query.R index 5a1bd354f..742d2529d 100644 --- a/R/lazy-join-query.R +++ b/R/lazy-join-query.R @@ -5,6 +5,7 @@ lazy_multi_join_query <- function( joins, table_names, vars, + where = list(), distinct = FALSE, group_vars = op_grps(x), order_vars = op_sort(x), @@ -31,6 +32,7 @@ lazy_multi_join_query <- function( joins = joins, table_names = table_names, vars = vars, + where = where, distinct = distinct, group_vars = group_vars, order_vars = order_vars, @@ -47,6 +49,7 @@ lazy_rf_join_query <- function( by, table_names, vars, + where = list(), group_vars = op_grps(x), order_vars = op_sort(x), frame = op_frame(x), @@ -74,6 +77,7 @@ lazy_rf_join_query <- function( by = by, table_names = table_names, vars = vars, + where = where, group_vars = group_vars, order_vars = order_vars, frame = frame @@ -190,11 +194,34 @@ sql_build.lazy_multi_join_query <- function(op, con, ..., sql_options = NULL) { } ) + join_vars <- sql_multi_join_vars( + con, + op$vars, + table_vars, + use_star = FALSE, + qualify_all_columns = sql_options$qualify_all_columns + ) + # WHERE happens after SELECT, but columns names are disambiguated using + # SELECT expressions, so need to backtransform + where <- lapply(op$where, \(expr) { + replace_sym( + expr, + names(join_vars), + lapply(unname(join_vars), \(x) sql(x[[1]])) + ) + }) + where_sql <- translate_sql_( + where, + con = con, + context = list(clause = "WHERE") + ) + multi_join_query( x = sql_build(op$x, con, sql_options = sql_options), joins = op$joins, table_names = table_names_out, select = select_sql, + where = where_sql, distinct = op$distinct ) } diff --git a/R/query-join.R b/R/query-join.R index e164f66dd..df172061a 100644 --- a/R/query-join.R +++ b/R/query-join.R @@ -10,7 +10,8 @@ join_query <- function( type = "inner", by = NULL, suffix = c(".x", ".y"), - na_matches = FALSE + na_matches = FALSE, + where = NULL ) { structure( list( @@ -19,19 +20,28 @@ join_query <- function( select = select, type = type, by = by, - na_matches = na_matches + na_matches = na_matches, + where = where ), class = c("join_query", "query") ) } -multi_join_query <- function(x, joins, table_names, select, distinct = FALSE) { +multi_join_query <- function( + x, + joins, + table_names, + select, + where = NULL, + distinct = FALSE +) { structure( list( x = x, joins = joins, table_names = table_names, select = select, + where = where, distinct = distinct ), class = c("multi_join_query", "query") @@ -91,6 +101,7 @@ sql_render.join_query <- function( by = query$by, na_matches = query$na_matches, select = query$select, + where = query$where, lvl = lvl ) } @@ -117,6 +128,7 @@ sql_render.multi_join_query <- function( table_names = query$table_names, by_list = query$by_list, select = query$select, + where = query$where, distinct = query$distinct, lvl = lvl ) diff --git a/R/verb-filter.R b/R/verb-filter.R index b873ca3d7..550441169 100644 --- a/R/verb-filter.R +++ b/R/verb-filter.R @@ -44,83 +44,74 @@ filter.tbl_lazy <- function(.data, ..., .by = NULL, .preserve = FALSE) { add_filter <- function(.data, dots) { con <- remote_con(.data) - lazy_query <- .data$lazy_query dots <- unname(dots) - dots_use_window_fun <- uses_window_fun(dots, con) - - if (filter_can_use_having(lazy_query, dots_use_window_fun)) { - return(filter_via_having(lazy_query, dots)) - } - - if (!dots_use_window_fun) { - if (filter_needs_new_query(dots, lazy_query, con)) { - lazy_select_query( - x = lazy_query, - where = dots - ) - } else { - exprs <- lazy_query$select$expr - nms <- lazy_query$select$name - projection <- purrr::map2_lgl( - exprs, - nms, - \(expr, name) is_symbol(expr) && !identical(expr, sym(name)) - ) - - if (any(projection)) { - dots <- purrr::map( - dots, - replace_sym, - nms[projection], - exprs[projection] - ) - } - - lazy_query$where <- c(lazy_query$where, dots) - lazy_query - } - } else { - # Do partial evaluation, then extract out window functions - where <- translate_window_where_all( - dots, - env_names(dbplyr_sql_translation(con)$window) - ) + # Handle window functions by adding an intermediate mutate + # by definition this has to create a subquery + if (uses_window_fun(dots, con)) { + window_funs <- env_names(dbplyr_sql_translation(con)$window) + where <- translate_window_where_all(dots, window_funs) # Add extracted window expressions as columns mutated <- mutate(.data, !!!where$comp) # And filter with the modified `where` using the new columns original_vars <- op_vars(.data) - lazy_select_query( + return(lazy_select_query( x = mutated$lazy_query, select = syms(set_names(original_vars)), where = where$expr - ) + )) + } + + lazy_query <- .data$lazy_query + if (filter_can_use_having(lazy_query)) { + names <- lazy_query$select$name + exprs <- purrr::map_if(lazy_query$select$expr, is_quosure, quo_get_expr) + dots <- purrr::map(dots, replace_sym, names, exprs) + + lazy_query$having <- c(lazy_query$having, dots) + lazy_query + } else if (filter_can_inline(dots, lazy_query, con)) { + # WHERE happens before SELECT so can't refer to aliases + # might be either a lazy_select_query or a lazy_multi_join_query + if (is_lazy_select_query(lazy_query)) { + dots <- rename_aliases(dots, lazy_query$select) + } + + lazy_query$where <- c(lazy_query$where, dots) + lazy_query + } else { + lazy_select_query(x = lazy_query, where = dots) } } -filter_needs_new_query <- function(dots, lazy_query, con) { - if (!is_lazy_select_query(lazy_query)) { +filter_can_inline <- function(dots, lazy_query, con) { + if (inherits(lazy_query, "lazy_multi_join_query")) { + # can't use mutated variables, window funs, or SQL return(TRUE) } + if (!is_lazy_select_query(lazy_query)) { + return(FALSE) + } + if (uses_mutated_vars(dots, lazy_query$select)) { - return(TRUE) + return(FALSE) } if (uses_window_fun(lazy_query$select$expr, con)) { - return(TRUE) + return(FALSE) } if (any_expr_uses_sql(lazy_query$select$expr)) { - return(TRUE) + return(FALSE) } - FALSE + TRUE } -filter_can_use_having <- function(lazy_query, dots_use_window_fun) { +filter_can_use_having <- function(lazy_query) { # From the Postgres documentation: https://www.postgresql.org/docs/current/sql-select.html#SQL-HAVING # Each column referenced in condition must unambiguously reference a grouping # column, unless the reference appears within an aggregate function or the @@ -134,24 +125,23 @@ filter_can_use_having <- function(lazy_query, dots_use_window_fun) { # Therefore, if `filter()` does not use a window function, then we only use # grouping or aggregated columns - if (dots_use_window_fun) { - return(FALSE) - } - - if (!is_lazy_select_query(lazy_query)) { - return(FALSE) + if (is_lazy_select_query(lazy_query)) { + lazy_query$select_operation == "summarise" + } else { + FALSE } - - lazy_query$select_operation == "summarise" } -filter_via_having <- function(lazy_query, dots) { - names <- lazy_query$select$name - exprs <- purrr::map_if(lazy_query$select$expr, is_quosure, quo_get_expr) - dots <- purrr::map(dots, replace_sym, names, exprs) +rename_aliases <- function(dots, select) { + exprs <- select$expr + nms <- select$name + projection <- purrr::map_lgl(exprs, is_symbol) + + if (!any(projection)) { + return(dots) + } - lazy_query$having <- c(lazy_query$having, dots) - lazy_query + purrr::map(dots, \(dot) replace_sym(dot, nms[projection], exprs[projection])) } check_filter <- function(...) { diff --git a/man/db-sql.Rd b/man/db-sql.Rd index f88522721..e535ef974 100644 --- a/man/db-sql.Rd +++ b/man/db-sql.Rd @@ -83,6 +83,7 @@ sql_query_join( by = NULL, na_matches = FALSE, ..., + where = NULL, lvl = 0 ) @@ -94,6 +95,7 @@ sql_query_multi_join( by_list, select, ..., + where = NULL, distinct = FALSE, lvl = 0 ) diff --git a/man/sql_build.Rd b/man/sql_build.Rd index 4ba414bf9..879ae326d 100644 --- a/man/sql_build.Rd +++ b/man/sql_build.Rd @@ -24,6 +24,7 @@ lazy_multi_join_query( joins, table_names, vars, + where = list(), distinct = FALSE, group_vars = op_grps(x), order_vars = op_sort(x), @@ -38,6 +39,7 @@ lazy_rf_join_query( by, table_names, vars, + where = list(), group_vars = op_grps(x), order_vars = op_sort(x), frame = op_frame(x), @@ -104,7 +106,8 @@ join_query( type = "inner", by = NULL, suffix = c(".x", ".y"), - na_matches = FALSE + na_matches = FALSE, + where = NULL ) select_query( diff --git a/tests/testthat/_snaps/backend-mssql.md b/tests/testthat/_snaps/backend-mssql.md index 9bb3b2621..e8612ba7f 100644 --- a/tests/testthat/_snaps/backend-mssql.md +++ b/tests/testthat/_snaps/backend-mssql.md @@ -759,6 +759,8 @@ EOF within quoted string Warning in `scan()`: EOF within quoted string + Warning in `scan()`: + EOF within quoted string Output SELECT [LHS]] diff --git a/tests/testthat/_snaps/verb-filter.md b/tests/testthat/_snaps/verb-filter.md index e77c9c7df..3a4dd5d34 100644 --- a/tests/testthat/_snaps/verb-filter.md +++ b/tests/testthat/_snaps/verb-filter.md @@ -52,6 +52,42 @@ ! `.preserve = TRUE` isn't supported on database backends. i It must be FALSE instead. +# filter() inlined after join + + Code + show_query(out) + Output + + SELECT "df1".*, "z" + FROM "df1" + LEFT JOIN "df2" + ON ("df1"."x" = "df2"."x") + WHERE ("z" = 1.0) + +--- + + Code + show_query(out) + Output + + SELECT "df1".*, "z" + FROM "df1" + LEFT JOIN "df2" + ON ("df1"."x" = "df2"."x") + WHERE ("y" = 1.0) AND ("z" = 2.0) + +--- + + Code + show_query(out) + Output + + SELECT "df1"."x" AS "x", "df1"."y" AS "y.x", "df3"."y" AS "y.y" + FROM "df1" + LEFT JOIN "df3" + ON ("df1"."x" = "df3"."x") + WHERE ("df3"."y" = 1.0) + # catches `.by` with grouped-df Code diff --git a/tests/testthat/test-lazy-join-query.R b/tests/testthat/test-lazy-join-query.R index 1057eec58..056050d42 100644 --- a/tests/testthat/test-lazy-join-query.R +++ b/tests/testthat/test-lazy-join-query.R @@ -1,3 +1,17 @@ +test_that("sql_build.lazy_multi_join_query() includes where", { + lf1 <- lazy_frame(x = 1, y = 1) + lf2 <- lazy_frame(x = 1, z = 2) + + out <- lf1 |> + left_join(lf2, by = "x") |> + filter(y == 1, z == 2) + query <- out$lazy_query + + expect_s3_class(query, "lazy_multi_join_query") + built <- sql_build(query, simulate_dbi()) + expect_equal(built$where, sql('"y" = 1.0', '"z" = 2.0')) +}) + test_that("sql_build.lazy_multi_join_query() includes distinct", { lf1 <- lazy_frame(x = 1, y = 1) lf2 <- lazy_frame(x = 1, z = 2) diff --git a/tests/testthat/test-verb-filter.R b/tests/testthat/test-verb-filter.R index d4ba432c2..b15336559 100644 --- a/tests/testthat/test-verb-filter.R +++ b/tests/testthat/test-verb-filter.R @@ -32,7 +32,7 @@ test_that("correctly inlines across all verbs", { # two table verbs lf2 <- lazy_frame(x = 1) - expect_selects(lf |> left_join(lf2, by = "x") |> filter(x == 1), 2) + expect_selects(lf |> left_join(lf2, by = "x") |> filter(x == 1), 1) expect_selects(lf |> right_join(lf2, by = "x") |> filter(x == 1), 2) expect_selects(lf |> semi_join(lf2, by = "x") |> filter(x == 1), 3) expect_selects(lf |> union(lf2) |> filter(x == 1), 3) @@ -140,6 +140,40 @@ test_that("filter() inlined after mutate()", { expect_equal(lq3$where, list(quo(y == sql("1"))), ignore_formula_env = TRUE) }) +test_that("filter() inlined after join", { + lf1 <- lazy_frame(x = 1, y = 1, .name = "df1") + lf2 <- lazy_frame(x = 1, z = 2, .name = "df2") + + out <- lf1 |> + left_join(lf2, by = "x") |> + filter(z == 1) + expect_s3_class(out$lazy_query, "lazy_multi_join_query") + expect_snapshot(show_query(out)) + + # multiple filters are combined + out <- lf1 |> + left_join(lf2, by = "x") |> + filter(y == 1) |> + filter(z == 2) + expect_s3_class(out$lazy_query, "lazy_multi_join_query") + expect_snapshot(show_query(out)) + + # Handles aliasing from join + lf3 <- lazy_frame(x = 1, y = 2, .name = "df3") + out <- lf1 |> + left_join(lf3, by = "x") |> + filter(y.y == 1) + expect_snapshot(show_query(out)) +}) + +test_that("inlined join works", { + # single integration test of the most complicated case + lf1 <- memdb_frame(x = 1:2, y = 1:2) + lf2 <- memdb_frame(x = 1:2, y = 3:4) + jf <- lf1 |> left_join(lf2, by = "x") |> filter(y.x == 1) + expect_equal(collect(jf), tibble(x = 1, y.x = 1, y.y = 3)) +}) + test_that("filter isn't inlined after mutate with window function #1135", { lf <- lazy_frame(x = 1L, y = 1:2) out <- lf |>