diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh
index 5d3ad3446ddb..fd99d408c4d1 100755
--- a/benchmarks/bench.sh
+++ b/benchmarks/bench.sh
@@ -212,6 +212,10 @@ main() {
                     # same data as for tpch
                     data_tpch "1"
                     ;;
+                sort_tpch_limit)
+                    # same data as for tpch
+                    data_tpch "1"
+                    ;;
                 *)
                     echo "Error: unknown benchmark '$BENCHMARK' for data generation"
                     usage
@@ -251,6 +255,7 @@ main() {
                 run_cancellation
                 run_parquet
                 run_sort
+                run_sort_tpch_limit
                 run_clickbench_1
                 run_clickbench_partitioned
                 run_clickbench_extended
@@ -320,6 +325,9 @@ main() {
                 sort_tpch)
                     run_sort_tpch
                     ;;
+                sort_tpch_limit)
+                    run_sort_tpch_limit
+                    ;;
                 *)
                     echo "Error: unknown benchmark '$BENCHMARK' for run"
                     usage
@@ -918,6 +926,15 @@ run_sort_tpch() {
     $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}"
 }
 
+# Runs the sort tpch integration benchmark with a limit
+run_sort_tpch_limit() {
+    TPCH_DIR="${DATA_DIR}/tpch_sf1"
+    RESULTS_FILE="${RESULTS_DIR}/sort_tpch_limit.json"
+    echo "RESULTS_FILE: ${RESULTS_FILE}"
+    echo "Running sort tpch benchmark with limit..."
+
+    $CARGO_COMMAND --bin dfbench -- sort-tpch --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" --limit 100
+}
 
 compare_benchmarks() {
     BASE_RESULTS_DIR="${SCRIPT_DIR}/results"
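For context, the new benchmark is presumably exercised like the existing suites through the script's `data` and `run` subcommands (the dispatch on `$BENCHMARK` above); a hedged sketch of the invocation, assuming the repository root as the working directory:

```bash
# Generate the TPC-H SF1 data shared with the other tpch benchmarks,
# then run the limited sort benchmark; results land in sort_tpch_limit.json.
./benchmarks/bench.sh data sort_tpch_limit
./benchmarks/bench.sh run sort_tpch_limit
```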
diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs
index 0b5780b9143f..685f440f3b7d 100644
--- a/datafusion/physical-plan/src/topk/mod.rs
+++ b/datafusion/physical-plan/src/topk/mod.rs
@@ -18,19 +18,22 @@
 //! TopK: Combination of Sort / LIMIT
 
 use arrow::{
-    compute::interleave_record_batch,
+    array::{BooleanArray, Scalar},
+    compute::{interleave_record_batch, is_null, or, FilterBuilder},
     row::{RowConverter, Rows, SortField},
 };
-use std::mem::size_of;
+use arrow_ord::cmp::{gt, gt_eq, lt, lt_eq};
+use datafusion_expr::ColumnarValue;
 use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};
+use std::{mem::size_of, sync::RwLock};
 
 use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
 use crate::spill::get_record_batch_memory_size;
 use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
 use arrow::array::{ArrayRef, RecordBatch};
 use arrow::datatypes::SchemaRef;
+use datafusion_common::{internal_datafusion_err, HashMap, ScalarValue};
 use datafusion_common::Result;
-use datafusion_common::{internal_datafusion_err, HashMap};
 use datafusion_execution::{
     memory_pool::{MemoryConsumer, MemoryReservation},
     runtime_env::RuntimeEnv,
@@ -117,6 +120,8 @@ pub struct TopK {
     /// to be greater (by byte order, after row conversion) than the top K,
     /// which means the top K won't change and the computation can be finished early.
     pub(crate) finished: bool,
+
+    thresholds: Arc<RwLock<Vec<Option<ScalarValue>>>>,
 }
 
 // Guesstimate for memory allocation: estimated number of bytes used per row in the RowConverter
@@ -170,7 +175,7 @@ impl TopK {
                 build_sort_fields(&common_sort_prefix, &schema)?;
             Some(RowConverter::new(input_sort_fields)?)
         };
-
+        let num_exprs = expr.len();
         Ok(Self {
             schema: Arc::clone(&schema),
             metrics: TopKMetrics::new(metrics, partition_id),
@@ -183,6 +188,7 @@ impl TopK {
             common_sort_prefix_converter: prefix_row_converter,
             common_sort_prefix: Arc::from(common_sort_prefix),
             finished: false,
+            thresholds: Arc::new(RwLock::new(vec![None; num_exprs])),
         })
     }
 
@@ -193,7 +199,7 @@ impl TopK {
         let baseline = self.metrics.baseline.clone();
         let _timer = baseline.elapsed_compute().timer();
 
-        let sort_keys: Vec<ArrayRef> = self
+        let mut sort_keys: Vec<ArrayRef> = self
             .expr
             .iter()
             .map(|expr| {
@@ -202,26 +208,117 @@ impl TopK {
             })
             .collect::<Result<Vec<_>>>()?;
 
+        // Selected indices in the input batch.
+        // Some indices may be pre-filtered if they exceed the heap's current max value.
+
+        let mut selected_rows = None;
+
+        let threshold0 = self
+            .thresholds
+            .read()
+            .expect("Read lock should succeed")[0]
+            .clone();
+
+        // If the heap doesn't have k elements yet, we can't create thresholds
+        if let Some(threshold0) = threshold0 {
+            // skip filtering if threshold is null
+            if !threshold0.is_null() {
+                // Convert the stored ScalarValue into a single-row array wrapped in
+                // `Scalar` so it can be used with the comparison kernels
+                let threshold = Scalar::new(threshold0.to_array_of_size(1)?);
+
+                // Build a boolean filter from the first sort key against the threshold
+                let is_multi_col = self.expr.len() > 1;
+
+                let mut filter = match (is_multi_col, self.expr[0].options.descending) {
+                    (true, true) => BooleanArray::new(
+                        gt_eq(&sort_keys[0], &threshold)?.values().clone(),
+                        None,
+                    ),
+                    (true, false) => BooleanArray::new(
+                        lt_eq(&sort_keys[0], &threshold)?.values().clone(),
+                        None,
+                    ),
+                    (false, true) => BooleanArray::new(
+                        gt(&sort_keys[0], &threshold)?.values().clone(),
+                        None,
+                    ),
+                    (false, false) => BooleanArray::new(
+                        lt(&sort_keys[0], &threshold)?.values().clone(),
+                        None,
+                    ),
+                };
+                if sort_keys[0].is_nullable() {
+                    // Keep any null values
+                    // TODO it is possible to optimize this based on the current threshold value
+                    // and the nulls first/last option and the number of following sort keys
+                    filter = or(&filter, &is_null(&sort_keys[0])?)?;
+                }
+                if filter.true_count() == 0 {
+                    // No rows are less than the max row, so we can skip this batch
+                    // Early completion is still possible, as the last row might be greater
+                    self.attempt_early_completion(&batch)?;
+
+                    return Ok(());
+                }
+
+                let filter_predicate = FilterBuilder::new(&filter);
+                let filter_predicate = if sort_keys.len() > 1 {
+                    // Optimize the filter when it is applied to multiple sort keys
+                    filter_predicate.optimize().build()
+                } else {
+                    filter_predicate.build()
+                };
+                selected_rows = Some(filter);
+
+                sort_keys = sort_keys
+                    .iter()
+                    .map(|key| filter_predicate.filter(key).map_err(|x| x.into()))
+                    .collect::<Result<Vec<_>>>()?;
+            }
+        }
+
         // reuse existing `Rows` to avoid reallocations
         let rows = &mut self.scratch_rows;
         rows.clear();
         self.row_converter.append(rows, &sort_keys)?;
 
-        // TODO make this algorithmically better?:
-        // Idea: filter out rows >= self.heap.max() early (before passing to `RowConverter`)
-        //       this avoids some work and also might be better vectorizable.
         let mut batch_entry = self.heap.register_batch(batch.clone());
-        for (index, row) in rows.iter().enumerate() {
-            match self.heap.max() {
-                // heap has k items, and the new row is greater than the
-                // current max in the heap ==> it is not a new topk
-                Some(max_row) if row.as_ref() >= max_row.row() => {}
-                // don't yet have k items or new item is lower than the currently k low values
-                None | Some(_) => {
-                    self.heap.add(&mut batch_entry, row, index);
-                    self.metrics.row_replacements.add(1);
-                }
-            }
+
+        let replacements = match selected_rows {
+            Some(filter) => {
+                self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry)
+            }
+            None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry),
+        };
+
+        self.metrics.row_replacements.add(replacements);
+
+        if replacements > 0 {
+            // Extract threshold values for each sort expression
+            // TODO: create a filter for each key that respects lexical ordering
+            // in the form of col0 < threshold0 || col0 == threshold0 && (col1 < threshold1 || ...)
+            // This could use BinaryExpr to benefit from short circuiting and early evaluation
+            // https://github.com/apache/datafusion/issues/15698
+            let thresholds: Vec<_> = self
+                .expr
+                .iter()
+                .map(|expr| {
+                    // Extract the value for this column from the max row
+                    let value = expr.expr.evaluate(
+                        &batch_entry.batch.slice(self.heap.max().unwrap().index, 1),
+                    )?;
+                    Ok(Some(match value {
+                        ColumnarValue::Array(array) => {
+                            ScalarValue::try_from_array(&array, 0)?
+                        }
+                        ColumnarValue::Scalar(scalar_value) => scalar_value,
+                    }))
+                })
+                .collect::<Result<_>>()?;
+            self.thresholds
+                .write()
+                .expect("Write lock should succeed")
+                .clone_from(&thresholds);
         }
 
         self.heap.insert_batch_entry(batch_entry);
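The core of the change is this pre-filter: before paying for row conversion and heap comparisons, each batch is cut down to rows that could still beat the current heap maximum on the first sort key, and nulls are retained because their rank depends on the nulls-first/last option. A minimal, self-contained sketch of the same kernel sequence (assuming the `arrow` and `arrow-ord` crates; the array values and the threshold of `10` are made up for illustration):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray, Int32Array, Scalar};
use arrow::compute::{is_null, or, FilterBuilder};
use arrow::error::ArrowError;
use arrow_ord::cmp::lt;

fn main() -> Result<(), ArrowError> {
    // First sort key of an incoming batch (ascending, nullable) and the
    // current heap threshold, i.e. the max of the top K rows seen so far.
    let sort_key: ArrayRef =
        Arc::new(Int32Array::from(vec![Some(5), Some(50), None, Some(7)]));
    let threshold = Scalar::new(Int32Array::from(vec![10]));

    // Rows strictly below the threshold can still enter the heap; drop the
    // validity mask so the comparison result is a plain boolean selection.
    let mut filter =
        BooleanArray::new(lt(&sort_key, &threshold)?.values().clone(), None);
    // Keep nulls: whether they rank before or after depends on the sort options.
    filter = or(&filter, &is_null(&sort_key)?)?;

    // The same predicate is applied to every sort key before row conversion.
    let kept = FilterBuilder::new(&filter).build().filter(&sort_key)?;
    assert_eq!(kept.len(), 3); // 5, null, and 7 survive; 50 is pruned
    Ok(())
}
```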
@@ -235,10 +332,31 @@ impl TopK {
         // subsequent batches are guaranteed to be greater (by byte order, after row conversion) than the top K,
         // which means the top K won't change and the computation can be finished early.
         self.attempt_early_completion(&batch)?;
-
         Ok(())
     }
 
+    fn find_new_topk_items(
+        &mut self,
+        items: impl Iterator<Item = usize>,
+        batch_entry: &mut RecordBatchEntry,
+    ) -> usize {
+        let mut replacements = 0;
+        let rows = &mut self.scratch_rows;
+        for (index, row) in items.zip(rows.iter()) {
+            match self.heap.max() {
+                // heap has k items, and the new row is greater than the
+                // current max in the heap ==> it is not a new topk
+                Some(max_row) if row.as_ref() >= max_row.row() => {}
+                // don't yet have k items or new item is lower than the currently k low values
+                None | Some(_) => {
+                    self.heap.add(batch_entry, row, index);
+                    replacements += 1;
+                }
+            }
+        }
+        replacements
+    }
+
     /// If input ordering shares a common sort prefix with the TopK, and if the TopK's heap is full,
     /// check if the computation can be finished early.
     /// This is the case if the last row of the current batch is strictly greater than the max row in the heap,
@@ -328,6 +446,7 @@ impl TopK {
             common_sort_prefix_converter: _,
             common_sort_prefix: _,
             finished: _,
+            thresholds: _,
         } = self;
 
         let _timer = metrics.baseline.elapsed_compute().timer(); // time updated on drop
@@ -360,6 +479,10 @@ impl TopK {
             + self.scratch_rows.size()
             + self.heap.size()
     }
+
+    pub fn thresholds(&self) -> &Arc<RwLock<Vec<Option<ScalarValue>>>> {
+        &self.thresholds
+    }
 }
 
 struct TopKMetrics {
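Since the new `thresholds()` accessor hands out the `Arc<RwLock<...>>` itself rather than a snapshot, a consumer can hold a clone and observe the thresholds tighten as batches are processed. A hypothetical reader, not part of the patch (the `topk` binding and the printout are illustrative):

```rust
// Hypothetical: `topk` is a TopK operator built elsewhere in the plan.
let shared = Arc::clone(topk.thresholds());
let guard = shared.read().expect("Read lock should succeed");
for (i, threshold) in guard.iter().enumerate() {
    match threshold {
        Some(value) => println!("sort expr {i}: rows beyond {value} can be pruned"),
        None => println!("sort expr {i}: no threshold yet"),
    }
}
```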