Skip to content

Commit 7345e85

Browse files
committed
feat: use spawned tasks to reduce call stack depth and avoid busy waiting
1 parent 85eebcd commit 7345e85

File tree

12 files changed

+354
-232
lines changed

12 files changed

+354
-232
lines changed

datafusion/common/src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -350,6 +350,12 @@ impl From<GenericError> for DataFusionError {
350350
}
351351
}
352352

353+
impl From<JoinError> for DataFusionError {
354+
fn from(e: JoinError) -> Self {
355+
DataFusionError::ExecutionJoin(e)
356+
}
357+
}
358+
353359
impl Display for DataFusionError {
354360
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
355361
let error_prefix = self.error_prefix();

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 22 additions & 34 deletions
Original file line number · Diff line number · Diff line change
@@ -21,10 +21,6 @@ use std::any::Any;
2121
use std::sync::Arc;
2222

2323
use super::{DisplayAs, ExecutionPlanProperties, PlanProperties};
24-
use crate::aggregates::{
25-
no_grouping::AggregateStream, row_hash::GroupedHashAggregateStream,
26-
topk_stream::GroupedTopKAggregateStream,
27-
};
2824
use crate::execution_plan::{CardinalityEffect, EmissionType};
2925
use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
3026
use crate::windows::get_ordered_partition_by_indices;
@@ -358,21 +354,10 @@ impl PartialEq for PhysicalGroupBy {
358354
}
359355
}
360356

361-
#[allow(clippy::large_enum_variant)]
362357
enum StreamType {
363-
AggregateStream(AggregateStream),
364-
GroupedHash(GroupedHashAggregateStream),
365-
GroupedPriorityQueue(GroupedTopKAggregateStream),
366-
}
367-
368-
impl From<StreamType> for SendableRecordBatchStream {
369-
fn from(stream: StreamType) -> Self {
370-
match stream {
371-
StreamType::AggregateStream(stream) => Box::pin(stream),
372-
StreamType::GroupedHash(stream) => Box::pin(stream),
373-
StreamType::GroupedPriorityQueue(stream) => Box::pin(stream),
374-
}
375-
}
358+
AggregateStream(SendableRecordBatchStream),
359+
GroupedHash(SendableRecordBatchStream),
360+
GroupedPriorityQueue(SendableRecordBatchStream),
376361
}
377362

378363
/// Hash aggregate execution plan
@@ -608,7 +593,7 @@ impl AggregateExec {
608593
) -> Result<StreamType> {
609594
// no group by at all
610595
if self.group_by.expr.is_empty() {
611-
return Ok(StreamType::AggregateStream(AggregateStream::new(
596+
return Ok(StreamType::AggregateStream(no_grouping::aggregate_stream(
612597
self, context, partition,
613598
)?));
614599
}
@@ -617,13 +602,13 @@ impl AggregateExec {
617602
if let Some(limit) = self.limit {
618603
if !self.is_unordered_unfiltered_group_by_distinct() {
619604
return Ok(StreamType::GroupedPriorityQueue(
620-
GroupedTopKAggregateStream::new(self, context, partition, limit)?,
605+
topk_stream::aggregate_stream(self, context, partition, limit)?,
621606
));
622607
}
623608
}
624609

625610
// grouping by something else and we need to just materialize all results
626-
Ok(StreamType::GroupedHash(GroupedHashAggregateStream::new(
611+
Ok(StreamType::GroupedHash(row_hash::aggregate_stream(
627612
self, context, partition,
628613
)?))
629614
}
@@ -998,8 +983,11 @@ impl ExecutionPlan for AggregateExec {
998983
partition: usize,
999984
context: Arc<TaskContext>,
1000985
) -> Result<SendableRecordBatchStream> {
1001-
self.execute_typed(partition, context)
1002-
.map(|stream| stream.into())
986+
match self.execute_typed(partition, context)? {
987+
StreamType::AggregateStream(s) => Ok(s),
988+
StreamType::GroupedHash(s) => Ok(s),
989+
StreamType::GroupedPriorityQueue(s) => Ok(s),
990+
}
1003991
}
1004992

1005993
fn metrics(&self) -> Option<MetricsSet> {
@@ -1274,7 +1262,7 @@ pub fn create_accumulators(
12741262
/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
12751263
pub fn finalize_aggregation(
12761264
accumulators: &mut [AccumulatorItem],
1277-
mode: &AggregateMode,
1265+
mode: AggregateMode,
12781266
) -> Result<Vec<ArrayRef>> {
12791267
match mode {
12801268
AggregateMode::Partial => {
@@ -2105,20 +2093,20 @@ mod tests {
21052093
let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?;
21062094

21072095
// ensure that we really got the version we wanted
2108-
match version {
2109-
0 => {
2110-
assert!(matches!(stream, StreamType::AggregateStream(_)));
2096+
let stream = match stream {
2097+
StreamType::AggregateStream(s) => {
2098+
assert_eq!(version, 0);
2099+
s
21112100
}
2112-
1 => {
2113-
assert!(matches!(stream, StreamType::GroupedHash(_)));
2101+
StreamType::GroupedHash(s) => {
2102+
assert!(version == 1 || version == 2);
2103+
s
21142104
}
2115-
2 => {
2116-
assert!(matches!(stream, StreamType::GroupedHash(_)));
2105+
StreamType::GroupedPriorityQueue(_) => {
2106+
panic!("Unexpected GroupedPriorityQueue stream type");
21172107
}
2118-
_ => panic!("Unknown version: {version}"),
2119-
}
2108+
};
21202109

2121-
let stream: SendableRecordBatchStream = stream.into();
21222110
let err = collect(stream).await.unwrap_err();
21232111

21242112
// error root cause traversal is a bit complicated, see #4172.

0 commit comments

Comments (0)