Skip to content
This repository was archived by the owner on Jan 20, 2023. It is now read-only.

Commit 4b1e9e6

Browse files
authored
add window expression stream, delegated window aggregation to aggregate functions, and implement row_number (apache#375)
* Squashed commit of the following: commit 7fb3640 Author: Jiayu Liu <[email protected]> Date: Fri May 21 16:38:25 2021 +0800 row number done commit 1723926 Author: Jiayu Liu <[email protected]> Date: Fri May 21 16:05:50 2021 +0800 add row number commit bf5b8a5 Author: Jiayu Liu <[email protected]> Date: Fri May 21 15:04:49 2021 +0800 save commit d2ce852 Author: Jiayu Liu <[email protected]> Date: Fri May 21 14:53:05 2021 +0800 add streams commit 0a861a7 Author: Jiayu Liu <[email protected]> Date: Thu May 20 22:28:34 2021 +0800 save stream commit a9121af Author: Jiayu Liu <[email protected]> Date: Thu May 20 22:01:51 2021 +0800 update unit test commit 2af2a27 Author: Jiayu Liu <[email protected]> Date: Fri May 21 14:25:12 2021 +0800 fix unit test commit bb57c76 Author: Jiayu Liu <[email protected]> Date: Fri May 21 14:23:34 2021 +0800 use upper case commit 5d96e52 Author: Jiayu Liu <[email protected]> Date: Fri May 21 14:16:16 2021 +0800 fix unit test commit 1ecae8f Author: Jiayu Liu <[email protected]> Date: Fri May 21 12:27:26 2021 +0800 fix unit test commit bc2271d Author: Jiayu Liu <[email protected]> Date: Fri May 21 10:04:29 2021 +0800 fix error commit 880b94f Author: Jiayu Liu <[email protected]> Date: Fri May 21 08:24:00 2021 +0800 fix unit test commit 4e792e1 Author: Jiayu Liu <[email protected]> Date: Fri May 21 08:05:17 2021 +0800 fix test commit c36c04a Author: Jiayu Liu <[email protected]> Date: Fri May 21 00:07:54 2021 +0800 add more tests commit f5e64de Author: Jiayu Liu <[email protected]> Date: Thu May 20 23:41:36 2021 +0800 update commit a1eae86 Author: Jiayu Liu <[email protected]> Date: Thu May 20 23:36:15 2021 +0800 enrich unit test commit 0d2a214 Author: Jiayu Liu <[email protected]> Date: Thu May 20 23:25:43 2021 +0800 adding filter by todo commit 8b486d5 Author: Jiayu Liu <[email protected]> Date: Thu May 20 23:17:22 2021 +0800 adding more built-in functions commit abf08cd Author: Jiayu Liu <[email protected]> Date: Thu May 20 22:36:27 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb <[email protected]> commit 0cbca53 Author: Jiayu Liu <[email protected]> Date: Thu May 20 22:34:57 2021 +0800 Update datafusion/src/physical_plan/window_functions.rs Co-authored-by: Andrew Lamb <[email protected]> commit 831c069 Author: Jiayu Liu <[email protected]> Date: Thu May 20 22:34:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb <[email protected]> commit f70c739 Author: Jiayu Liu <[email protected]> Date: Thu May 20 22:33:04 2021 +0800 Update datafusion/src/logical_plan/builder.rs Co-authored-by: Andrew Lamb <[email protected]> commit 3ee87aa Author: Jiayu Liu <[email protected]> Date: Wed May 19 22:55:08 2021 +0800 fix unit test commit 5c4d92d Author: Jiayu Liu <[email protected]> Date: Wed May 19 22:48:26 2021 +0800 fix clippy commit a0b7526 Author: Jiayu Liu <[email protected]> Date: Wed May 19 22:46:38 2021 +0800 fix unused imports commit 1d3b076 Author: Jiayu Liu <[email protected]> Date: Thu May 13 18:51:14 2021 +0800 add window expr * fix unit test
1 parent 3593d1f commit 4b1e9e6

File tree

11 files changed

+736
-75
lines changed

11 files changed

+736
-75
lines changed

datafusion/src/execution/context.rs

+29
Original file line numberDiff line numberDiff line change
@@ -1268,6 +1268,35 @@ mod tests {
12681268
Ok(())
12691269
}
12701270

1271+
#[tokio::test]
1272+
async fn window() -> Result<()> {
1273+
let results = execute(
1274+
"SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5",
1275+
4,
1276+
)
1277+
.await?;
1278+
// result in one batch, although e.g. having 2 batches do not change
1279+
// result semantics, having a len=1 assertion upfront keeps surprises
1280+
// at bay
1281+
assert_eq!(results.len(), 1);
1282+
1283+
let expected = vec![
1284+
"+----+----+---------+-----------+---------+---------+---------+",
1285+
"| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
1286+
"+----+----+---------+-----------+---------+---------+---------+",
1287+
"| 0 | 1 | 220 | 40 | 10 | 1 | 5.5 |",
1288+
"| 0 | 2 | 220 | 40 | 10 | 1 | 5.5 |",
1289+
"| 0 | 3 | 220 | 40 | 10 | 1 | 5.5 |",
1290+
"| 0 | 4 | 220 | 40 | 10 | 1 | 5.5 |",
1291+
"| 0 | 5 | 220 | 40 | 10 | 1 | 5.5 |",
1292+
"+----+----+---------+-----------+---------+---------+---------+",
1293+
];
1294+
1295+
// window function shall respect ordering
1296+
assert_batches_eq!(expected, &results);
1297+
Ok(())
1298+
}
1299+
12711300
#[tokio::test]
12721301
async fn aggregate() -> Result<()> {
12731302
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;

datafusion/src/physical_plan/expressions/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ mod min_max;
4141
mod negative;
4242
mod not;
4343
mod nullif;
44+
mod row_number;
4445
mod sum;
4546
mod try_cast;
4647

@@ -58,6 +59,7 @@ pub use min_max::{Max, Min};
5859
pub use negative::{negative, NegativeExpr};
5960
pub use not::{not, NotExpr};
6061
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
62+
pub use row_number::RowNumber;
6163
pub use sum::{sum_return_type, Sum};
6264
pub use try_cast::{try_cast, TryCastExpr};
6365
/// returns the name of the state
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Defines physical expression for `row_number` that can evaluated at runtime during query execution
19+
20+
use crate::error::Result;
21+
use crate::physical_plan::{
22+
window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
23+
};
24+
use crate::scalar::ScalarValue;
25+
use arrow::array::{ArrayRef, UInt64Array};
26+
use arrow::datatypes::{DataType, Field};
27+
use std::any::Any;
28+
use std::sync::Arc;
29+
30+
/// row_number expression
31+
#[derive(Debug)]
32+
pub struct RowNumber {
33+
name: String,
34+
}
35+
36+
impl RowNumber {
37+
/// Create a new ROW_NUMBER function
38+
pub fn new(name: String) -> Self {
39+
Self { name }
40+
}
41+
}
42+
43+
impl BuiltInWindowFunctionExpr for RowNumber {
44+
/// Return a reference to Any that can be used for downcasting
45+
fn as_any(&self) -> &dyn Any {
46+
self
47+
}
48+
49+
fn field(&self) -> Result<Field> {
50+
let nullable = false;
51+
let data_type = DataType::UInt64;
52+
Ok(Field::new(&self.name(), data_type, nullable))
53+
}
54+
55+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
56+
vec![]
57+
}
58+
59+
fn name(&self) -> &str {
60+
self.name.as_str()
61+
}
62+
63+
fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
64+
Ok(Box::new(RowNumberAccumulator::new()))
65+
}
66+
}
67+
68+
#[derive(Debug)]
69+
struct RowNumberAccumulator {
70+
row_number: u64,
71+
}
72+
73+
impl RowNumberAccumulator {
74+
/// new row_number accumulator
75+
pub fn new() -> Self {
76+
// row number is 1 based
77+
Self { row_number: 1 }
78+
}
79+
}
80+
81+
impl WindowAccumulator for RowNumberAccumulator {
82+
fn scan(&mut self, _values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
83+
let result = Some(ScalarValue::UInt64(Some(self.row_number)));
84+
self.row_number += 1;
85+
Ok(result)
86+
}
87+
88+
fn scan_batch(
89+
&mut self,
90+
num_rows: usize,
91+
_values: &[ArrayRef],
92+
) -> Result<Option<ArrayRef>> {
93+
let new_row_number = self.row_number + (num_rows as u64);
94+
// TODO: probably would be nice to have a (optimized) kernel for this at some point to
95+
// generate an array like this.
96+
let result = UInt64Array::from_iter_values(self.row_number..new_row_number);
97+
self.row_number = new_row_number;
98+
Ok(Some(Arc::new(result)))
99+
}
100+
101+
fn evaluate(&self) -> Result<Option<ScalarValue>> {
102+
Ok(None)
103+
}
104+
}
105+
106+
#[cfg(test)]
107+
mod tests {
108+
use super::*;
109+
use crate::error::Result;
110+
use arrow::record_batch::RecordBatch;
111+
use arrow::{array::*, datatypes::*};
112+
113+
#[test]
114+
fn row_number_all_null() -> Result<()> {
115+
let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
116+
None, None, None, None, None, None, None, None,
117+
]));
118+
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
119+
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
120+
121+
let row_number = Arc::new(RowNumber::new("row_number".to_owned()));
122+
123+
let mut acc = row_number.create_accumulator()?;
124+
let expr = row_number.expressions();
125+
let values = expr
126+
.iter()
127+
.map(|e| e.evaluate(&batch))
128+
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
129+
.collect::<Result<Vec<_>>>()?;
130+
131+
let result = acc.scan_batch(batch.num_rows(), &values)?;
132+
assert_eq!(true, result.is_some());
133+
134+
let result = result.unwrap();
135+
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
136+
let result = result.values();
137+
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
138+
139+
let result = acc.evaluate()?;
140+
assert_eq!(false, result.is_some());
141+
Ok(())
142+
}
143+
144+
#[test]
145+
fn row_number_all_values() -> Result<()> {
146+
let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
147+
true, false, true, false, false, true, false, true,
148+
]));
149+
let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]);
150+
let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;
151+
152+
let row_number = Arc::new(RowNumber::new("row_number".to_owned()));
153+
154+
let mut acc = row_number.create_accumulator()?;
155+
let expr = row_number.expressions();
156+
let values = expr
157+
.iter()
158+
.map(|e| e.evaluate(&batch))
159+
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
160+
.collect::<Result<Vec<_>>>()?;
161+
162+
let result = acc.scan_batch(batch.num_rows(), &values)?;
163+
assert_eq!(true, result.is_some());
164+
165+
let result = result.unwrap();
166+
let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
167+
let result = result.values();
168+
assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);
169+
170+
let result = acc.evaluate()?;
171+
assert_eq!(false, result.is_some());
172+
Ok(())
173+
}
174+
}

datafusion/src/physical_plan/hash_aggregate.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ impl GroupedHashAggregateStream {
712712
tx.send(result)
713713
});
714714

715-
GroupedHashAggregateStream {
715+
Self {
716716
schema,
717717
output: rx,
718718
finished: false,
@@ -825,7 +825,8 @@ fn aggregate_expressions(
825825
}
826826

827827
pin_project! {
828-
struct HashAggregateStream {
828+
/// stream struct for hash aggregation
829+
pub struct HashAggregateStream {
829830
schema: SchemaRef,
830831
#[pin]
831832
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
@@ -878,7 +879,7 @@ impl HashAggregateStream {
878879
tx.send(result)
879880
});
880881

881-
HashAggregateStream {
882+
Self {
882883
schema,
883884
output: rx,
884885
finished: false,

datafusion/src/physical_plan/mod.rs

+73-8
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@
1717

1818
//! Traits for physical query plan, supporting parallel execution for partitioned relations.
1919
20-
use std::fmt::{self, Debug, Display};
21-
use std::sync::atomic::{AtomicUsize, Ordering};
22-
use std::sync::Arc;
23-
use std::{any::Any, pin::Pin};
24-
2520
use crate::execution::context::ExecutionContextState;
2621
use crate::logical_plan::LogicalPlan;
27-
use crate::{error::Result, scalar::ScalarValue};
22+
use crate::{
23+
error::{DataFusionError, Result},
24+
scalar::ScalarValue,
25+
};
2826
use arrow::datatypes::{DataType, Schema, SchemaRef};
2927
use arrow::error::Result as ArrowResult;
3028
use arrow::record_batch::RecordBatch;
3129
use arrow::{array::ArrayRef, datatypes::Field};
32-
3330
use async_trait::async_trait;
3431
pub use display::DisplayFormatType;
3532
use futures::stream::Stream;
33+
use std::fmt::{self, Debug, Display};
34+
use std::sync::atomic::{AtomicUsize, Ordering};
35+
use std::sync::Arc;
36+
use std::{any::Any, pin::Pin};
3637

3738
use self::{display::DisplayableExecutionPlan, merge::MergeExec};
3839
use hashbrown::HashMap;
@@ -457,10 +458,22 @@ pub trait WindowExpr: Send + Sync + Debug {
457458
fn name(&self) -> &str {
458459
"WindowExpr: default name"
459460
}
461+
462+
/// the accumulator used to accumulate values from the expressions.
463+
/// the accumulator expects the same number of arguments as `expressions` and must
464+
/// return states with the same description as `state_fields`
465+
fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>>;
466+
467+
/// expressions that are passed to the WindowAccumulator.
468+
/// Functions which take a single input argument, such as `sum`, return a single [`Expr`],
469+
/// others (e.g. `cov`) return many.
470+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
460471
}
461472

462473
/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
463-
/// generically accumulates values. An accumulator knows how to:
474+
/// generically accumulates values.
475+
///
476+
/// An accumulator knows how to:
464477
/// * update its state from inputs via `update`
465478
/// * convert its internal state to a vector of scalar values
466479
/// * update its state from multiple accumulators' states via `merge`
@@ -509,6 +522,58 @@ pub trait Accumulator: Send + Sync + Debug {
509522
fn evaluate(&self) -> Result<ScalarValue>;
510523
}
511524

525+
/// A window accumulator represents a stateful object that lives throughout the evaluation of multiple
526+
/// rows and generically accumulates values.
527+
///
528+
/// An accumulator knows how to:
529+
/// * update its state from inputs via `update`
530+
/// * convert its internal state to a vector of scalar values
531+
/// * update its state from multiple accumulators' states via `merge`
532+
/// * compute the final value from its internal state via `evaluate`
533+
pub trait WindowAccumulator: Send + Sync + Debug {
534+
/// scans the accumulator's state from a vector of scalars, similar to Accumulator it also
535+
/// optionally generates values.
536+
fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>>;
537+
538+
/// scans the accumulator's state from a vector of arrays.
539+
fn scan_batch(
540+
&mut self,
541+
num_rows: usize,
542+
values: &[ArrayRef],
543+
) -> Result<Option<ArrayRef>> {
544+
if values.is_empty() {
545+
return Ok(None);
546+
};
547+
// transpose columnar to row based so that we can apply window
548+
let result = (0..num_rows)
549+
.map(|index| {
550+
let v = values
551+
.iter()
552+
.map(|array| ScalarValue::try_from_array(array, index))
553+
.collect::<Result<Vec<_>>>()?;
554+
self.scan(&v)
555+
})
556+
.collect::<Result<Vec<Option<ScalarValue>>>>()?
557+
.into_iter()
558+
.collect::<Option<Vec<ScalarValue>>>();
559+
560+
Ok(match result {
561+
Some(arr) if num_rows == arr.len() => Some(ScalarValue::iter_to_array(&arr)?),
562+
None => None,
563+
Some(arr) => {
564+
return Err(DataFusionError::Internal(format!(
565+
"expect scan batch to return {:?} rows, but got {:?}",
566+
num_rows,
567+
arr.len()
568+
)))
569+
}
570+
})
571+
}
572+
573+
/// returns its value based on its current state.
574+
fn evaluate(&self) -> Result<Option<ScalarValue>>;
575+
}
576+
512577
pub mod aggregates;
513578
pub mod array_expressions;
514579
pub mod coalesce_batches;

datafusion/src/physical_plan/planner.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ impl DefaultPhysicalPlanner {
147147
// Initially need to perform the aggregate and then merge the partitions
148148
let input_exec = self.create_initial_plan(input, ctx_state)?;
149149
let input_schema = input_exec.schema();
150-
let physical_input_schema = input_exec.as_ref().schema();
150+
151151
let logical_input_schema = input.as_ref().schema();
152+
let physical_input_schema = input_exec.as_ref().schema();
153+
152154
let window_expr = window_expr
153155
.iter()
154156
.map(|e| {

datafusion/src/physical_plan/sort.rs

+1
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ fn sort_batches(
250250
}
251251

252252
pin_project! {
253+
/// stream for sort plan
253254
struct SortStream {
254255
#[pin]
255256
output: futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,

0 commit comments

Comments
 (0)