diff --git a/NAMESPACE b/NAMESPACE index b2e724c54..215993945 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -407,6 +407,9 @@ S3method(sql_values_subquery,Redshift) S3method(sql_values_subquery,RedshiftConnection) S3method(src_tbls,src_sql) S3method(summarise,tbl_lazy) +S3method(supports_distinct_on,DBIConnection) +S3method(supports_distinct_on,PostgreSQL) +S3method(supports_distinct_on,PqConnection) S3method(supports_window_clause,"Spark SQL") S3method(supports_window_clause,ACCESS) S3method(supports_window_clause,DBIConnection) @@ -586,6 +589,7 @@ export(src_dbi) export(src_memdb) export(src_sql) export(src_test) +export(supports_distinct_on) export(supports_window_clause) export(table_path_components) export(table_path_name) diff --git a/R/backend-postgres.R b/R/backend-postgres.R index 23fbf47bd..f1bc30436 100644 --- a/R/backend-postgres.R +++ b/R/backend-postgres.R @@ -481,6 +481,16 @@ supports_window_clause.PostgreSQL <- function(con) { TRUE } +#' @export +supports_distinct_on.PqConnection <- function(con) { + TRUE +} + +#' @export +supports_distinct_on.PostgreSQL <- function(con) { + TRUE +} + #' @export db_supports_table_alias_with_as.PqConnection <- function(con) { TRUE diff --git a/R/db-sql.R b/R/db-sql.R index 6c3186c69..849812361 100644 --- a/R/db-sql.R +++ b/R/db-sql.R @@ -348,6 +348,16 @@ db_supports_table_alias_with_as.TestConnection <- function(con) { TRUE } +#' @rdname db-sql +#' @export +supports_distinct_on <- function(con) { + UseMethod("supports_distinct_on") +} + +#' @export +supports_distinct_on.DBIConnection <- function(con) { + FALSE +} # Query generation -------------------------------------------------------- diff --git a/R/lazy-select-query.R b/R/lazy-select-query.R index f28eef794..ab6753124 100644 --- a/R/lazy-select-query.R +++ b/R/lazy-select-query.R @@ -21,7 +21,10 @@ lazy_select_query <- function( # stopifnot(is.character(group_by)) stopifnot(is_lazy_sql_part(order_by)) check_number_whole(limit, allow_infinite = TRUE, allow_null = TRUE) - check_bool(distinct) + # distinct = FALSE -> no distinct + # distinct = TRUE -> distinct over all columns + # distinct = columns -> DISTINCT ON (...) + stopifnot(is_bool(distinct) || is_lazy_sql_part(distinct)) select <- select %||% syms(set_names(op_vars(x))) select_operation <- arg_match0( @@ -152,7 +155,12 @@ is_select_identity <- function(select, vars_prev) { #' @export print.lazy_select_query <- function(x, ...) { - cat_line("") + cat_line( + "" + ) cat_line("From:") cat_line(indent_print(sql_build(x$x, simulate_dbi()))) @@ -160,6 +168,9 @@ print.lazy_select_query <- function(x, ...) { if (length(select)) { cat_line("Select: ", named_commas(select)) } + if (is_lazy_sql_part(x$distinct)) { + cat_line("Dist on: ", named_commas(x$distinct)) + } if (length(x$where)) { cat_line("Where: ", named_commas(x$where)) } @@ -195,6 +206,20 @@ sql_build.lazy_select_query <- function(op, con, ..., sql_options = NULL) { alias <- remote_name(op$x, null_if_local = FALSE) %||% unique_subquery_name() from <- sql_build(op$x, con, sql_options = sql_options) + + if (is_lazy_sql_part(op$distinct)) { + distinct <- get_select_sql( + select = new_lazy_select(op$distinct), + select_operation = op$select_operation, + in_vars = op_vars(op$x), + table_alias = alias, + con = con, + use_star = sql_options$use_star + )$select_sql + } else { + distinct <- op$distinct + } + select_sql_list <- get_select_sql( select = op$select, select_operation = op$select_operation, @@ -217,7 +242,7 @@ sql_build.lazy_select_query <- function(op, con, ..., sql_options = NULL) { having = translate_sql_(op$having, con = con, window = FALSE), window = select_sql_list$window_sql, order_by = translate_sql_(op$order_by, con = con), - distinct = op$distinct, + distinct = distinct, limit = op$limit, from_alias = alias ) diff --git a/R/query-select.R b/R/query-select.R index 5720fabe0..70f889eef 100644 --- a/R/query-select.R +++ b/R/query-select.R @@ -19,7 +19,7 @@ select_query <- function( check_character(window) check_character(order_by) check_number_whole(limit, allow_infinite = TRUE, allow_null = TRUE) - check_bool(distinct) + stopifnot(is_bool(distinct) || is.character(distinct)) check_string(from_alias, allow_null = TRUE) structure( @@ -41,10 +41,18 @@ select_query <- function( #' @export print.select_query <- function(x, ...) { - cat_line("") + cat_line( + "" + ) cat_line("From:") cat_line(indent_print(x$from)) + if (!is.logical(x$distinct)) { + cat_line("Dist On: ", named_commas(x$distinct)) + } if (length(x$select)) { cat_line("Select: ", named_commas(x$select)) } diff --git a/R/sql-clause.R b/R/sql-clause.R index 37c749d3f..f30a73a3d 100644 --- a/R/sql-clause.R +++ b/R/sql-clause.R @@ -56,13 +56,23 @@ sql_clause_select <- function( clause <- glue_sql2( con, "SELECT", - if (distinct) " DISTINCT", + sql_distinct(con, distinct), if (!is.null(top)) " TOP {.val top}" ) sql_clause(clause, select) } +sql_distinct <- function(con, distinct) { + if (isTRUE(distinct)) { + " DISTINCT" + } else if (isFALSE(distinct)) { + "" + } else { + glue_sql2(con, " DISTINCT ON ({.col distinct})") + } +} + sql_clause_from <- function(from, lvl = 0) { sql_clause("FROM", from, lvl = lvl) } diff --git a/R/verb-distinct.R b/R/verb-distinct.R index 5f76b8b96..9cbabac39 100644 --- a/R/verb-distinct.R +++ b/R/verb-distinct.R @@ -17,8 +17,31 @@ distinct.tbl_lazy <- function(.data, ..., .keep_all = FALSE) { grps <- syms(op_grps(.data)) empty_dots <- dots_n(...) == 0 - can_use_distinct <- !.keep_all || (empty_dots && is_empty(grps)) - if (!can_use_distinct) { + can_use_distinct <- + !.keep_all || + (empty_dots && is_empty(grps)) || + supports_distinct_on(.data$src$con) + + if (can_use_distinct) { + if (empty_dots) { + dots <- quos(!!!syms(colnames(.data))) + } else { + dots <- partial_eval_dots(.data, ..., .named = FALSE) + dots <- quos(!!!dots) + } + prep <- distinct_prepare_compat(.data, dots, group_vars = group_vars(.data)) + + if (!.keep_all) { + out <- dplyr::select(prep$data, prep$keep) + out$lazy_query <- add_distinct(out, distinct = TRUE) + } else { + out <- prep$data + out$lazy_query <- add_distinct( + out, + distinct = set_names(names(quos_auto_name(dots))) + ) + } + } else { needs_dummy_order <- is.null(op_sort(.data)) if (needs_dummy_order) { @@ -26,28 +49,15 @@ distinct.tbl_lazy <- function(.data, ..., .keep_all = FALSE) { .data <- .data %>% window_order(!!sym(dummy_order_vars)) } - .data <- .data %>% + out <- .data %>% group_by(..., .add = TRUE) %>% filter(row_number() == 1L) %>% group_by(!!!grps) if (needs_dummy_order) { - .data <- .data %>% window_order() + out <- out %>% window_order() } - - return(.data) - } - - if (empty_dots) { - dots <- quos(!!!syms(colnames(.data))) - } else { - dots <- partial_eval_dots(.data, ..., .named = FALSE) - dots <- quos(!!!dots) } - prep <- distinct_prepare_compat(.data, dots, group_vars = group_vars(.data)) - out <- dplyr::select(prep$data, prep$keep) - - out$lazy_query <- add_distinct(out) out } @@ -148,12 +158,16 @@ quo_is_variable_reference <- function(quo) { } -add_distinct <- function(.data) { +add_distinct <- function(.data, distinct) { lazy_query <- .data$lazy_query + if (!is_bool(distinct)) { + distinct <- syms(distinct) + } + out <- lazy_select_query( x = lazy_query, - distinct = TRUE + distinct = distinct ) # TODO this could also work for joins if (!is_lazy_select_query(lazy_query)) { @@ -173,6 +187,7 @@ add_distinct <- function(.data) { return(out) } - lazy_query$distinct <- TRUE + lazy_query$distinct <- distinct + lazy_query } diff --git a/R/verb-select.R b/R/verb-select.R index ef93594f3..39cc208b1 100644 --- a/R/verb-select.R +++ b/R/verb-select.R @@ -140,7 +140,8 @@ select_can_be_inlined <- function(lazy_query, vars) { ordered_present <- all(intersect(computed_columns, order_vars) %in% vars) # if the projection is distinct, it must be bijective - is_distinct <- is_true(lazy_query$distinct) + is_distinct <- !is.logical(lazy_query$distinct) || + is_true(lazy_query$distinct) is_bijective_projection <- identical(sort(unname(vars)), op_vars(lazy_query)) distinct_is_bijective <- !is_distinct || is_bijective_projection diff --git a/man/db-sql.Rd b/man/db-sql.Rd index c1aec3dea..4b1588daa 100644 --- a/man/db-sql.Rd +++ b/man/db-sql.Rd @@ -15,6 +15,7 @@ \alias{sql_query_rows} \alias{supports_window_clause} \alias{db_supports_table_alias_with_as} +\alias{supports_distinct_on} \alias{sql_query_select} \alias{sql_query_join} \alias{sql_query_multi_join} @@ -58,6 +59,8 @@ supports_window_clause(con) db_supports_table_alias_with_as(con) +supports_distinct_on(con) + sql_query_select( con, select, diff --git a/tests/testthat/test-backend-postgres.R b/tests/testthat/test-backend-postgres.R index f24ffb1e9..8a663d526 100644 --- a/tests/testthat/test-backend-postgres.R +++ b/tests/testthat/test-backend-postgres.R @@ -573,3 +573,33 @@ test_that("correctly escapes dates", { dd <- as.Date("2022-03-04") expect_equal(escape(dd, con = con), sql("'2022-03-04'::date")) }) + +# we test that again because postgres uses the DISTINCT ON feature +test_that("distinct can compute variables when .keep_all is TRUE", { + con <- src_test("postgres") + + out <- + local_db_table(con, data.frame(x = c(2, 1), y = c(1, 2)), "df_x") %>% + distinct(z = x + y, .keep_all = TRUE) %>% + collect() + + expect_named(out, c("x", "y", "z")) + expect_equal(out$z, 3) +}) + +test_that("distinct respects window_order when .keep_all is TRUE", { + con <- src_test("postgres") + + mf <- local_db_table(con, data.frame(x = c(1, 1, 2, 2), y = 1:4), "mf") + out <- mf %>% + window_order(desc(y)) %>% + distinct(x, .keep_all = TRUE) + + expect_equal(out %>% collect(), tibble(x = 1:2, y = c(2, 4))) + + expect_snapshot( + mf %>% + window_order(desc(y)) %>% + distinct(x, .keep_all = TRUE) + ) +})