Skip to content

Commit 8769ce4

Browse files
committed
feat: use spawned tasks to reduce call stack depth and avoid busy waiting
1 parent 1daa5ed commit 8769ce4

File tree

5 files changed

+72
-25
lines changed

5 files changed

+72
-25
lines changed

datafusion/common/src/error.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ impl From<GenericError> for DataFusionError {
350350
}
351351
}
352352

353+
/// Converts a `JoinError` into a [`DataFusionError`] by wrapping it in the
/// `ExecutionJoin` variant, so `?` can propagate a `JoinError` from code that
/// returns a `DataFusionError`.
// NOTE(review): `JoinError` is presumably the error returned when awaiting a
// spawned task fails (panic or cancellation) — confirm against the import.
impl From<JoinError> for DataFusionError {
354+
fn from(e: JoinError) -> Self {
355+
DataFusionError::ExecutionJoin(e)
356+
}
357+
}
358+
353359
impl Display for DataFusionError {
354360
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
355361
let error_prefix = self.error_prefix();

datafusion/physical-plan/src/joins/cross_join.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
//! Defines the cross join plan for loading the left side of the cross join
1919
//! and producing batches in parallel for the right partitions
2020
21+
use futures::FutureExt;
2122
use std::{any::Any, sync::Arc, task::Poll};
2223

2324
use super::utils::{
@@ -47,6 +48,7 @@ use datafusion_execution::TaskContext;
4748
use datafusion_physical_expr::equivalence::join_equivalence_properties;
4849

4950
use async_trait::async_trait;
51+
use datafusion_common_runtime::SpawnedTask;
5052
use futures::{ready, Stream, StreamExt, TryStreamExt};
5153

5254
/// Data of the left side that is buffered into memory
@@ -303,12 +305,13 @@ impl ExecutionPlan for CrossJoinExec {
303305

304306
let left_fut = self.left_fut.try_once(|| {
305307
let left_stream = self.left.execute(0, context)?;
306-
307-
Ok(load_left_input(
308-
left_stream,
309-
join_metrics.clone(),
310-
reservation,
311-
))
308+
let task = load_left_input(left_stream, join_metrics.clone(), reservation);
309+
Ok(async move {
310+
// Spawn a task the first time the stream is polled for the build phase.
311+
// This ensures the consumer of the join does not poll unnecessarily
312+
// while the build is ongoing.
313+
SpawnedTask::spawn(task).map(|r| r?).await
314+
})
312315
})?;
313316

314317
if enforce_batch_size_in_joins {

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ use datafusion_physical_expr::PhysicalExprRef;
8282
use datafusion_physical_expr_common::datum::compare_op_for_nested;
8383

8484
use ahash::RandomState;
85+
use datafusion_common_runtime::SpawnedTask;
8586
use datafusion_physical_expr_common::physical_expr::fmt_sql;
86-
use futures::{ready, Stream, StreamExt, TryStreamExt};
87+
use futures::{ready, FutureExt, Stream, StreamExt, TryStreamExt};
8788
use parking_lot::Mutex;
8889

8990
/// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions.
@@ -810,15 +811,22 @@ impl ExecutionPlan for HashJoinExec {
810811
let reservation =
811812
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
812813

813-
Ok(collect_left_input(
814+
let task = collect_left_input(
814815
self.random_state.clone(),
815816
left_stream,
816817
on_left.clone(),
817818
join_metrics.clone(),
818819
reservation,
819820
need_produce_result_in_final(self.join_type),
820821
self.right().output_partitioning().partition_count(),
821-
))
822+
);
823+
824+
Ok(async move {
825+
// Spawn a task the first time the stream is polled for the build phase.
826+
// This ensures the consumer of the join does not poll unnecessarily
827+
// while the build is ongoing.
828+
SpawnedTask::spawn(task).map(|r| r?).await
829+
})
822830
})?,
823831
PartitionMode::Partitioned => {
824832
let left_stream = self.left.execute(partition, Arc::clone(&context))?;
@@ -827,15 +835,22 @@ impl ExecutionPlan for HashJoinExec {
827835
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
828836
.register(context.memory_pool());
829837

830-
OnceFut::new(collect_left_input(
838+
let task = collect_left_input(
831839
self.random_state.clone(),
832840
left_stream,
833841
on_left.clone(),
834842
join_metrics.clone(),
835843
reservation,
836844
need_produce_result_in_final(self.join_type),
837845
1,
838-
))
846+
);
847+
848+
OnceFut::new(async move {
849+
// Spawn a task the first time the stream is polled for the build phase.
850+
// This ensures the consumer of the join does not poll unnecessarily
851+
// while the build is ongoing.
852+
SpawnedTask::spawn(task).map(|r| r?).await
853+
})
839854
}
840855
PartitionMode::Auto => {
841856
return plan_err!(

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! [`NestedLoopJoinExec`]: joins without equijoin (equality predicates).
1919
20+
use futures::FutureExt;
2021
use std::any::Any;
2122
use std::fmt::Formatter;
2223
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -61,6 +62,7 @@ use datafusion_physical_expr::equivalence::{
6162
join_equivalence_properties, ProjectionMapping,
6263
};
6364

65+
use datafusion_common_runtime::SpawnedTask;
6466
use futures::{ready, Stream, StreamExt, TryStreamExt};
6567
use parking_lot::Mutex;
6668

@@ -499,13 +501,19 @@ impl ExecutionPlan for NestedLoopJoinExec {
499501
let inner_table = self.inner_table.try_once(|| {
500502
let stream = self.left.execute(0, Arc::clone(&context))?;
501503

502-
Ok(collect_left_input(
504+
let task = collect_left_input(
503505
stream,
504506
join_metrics.clone(),
505507
load_reservation,
506508
need_produce_result_in_final(self.join_type),
507509
self.right().output_partitioning().partition_count(),
508-
))
510+
);
511+
Ok(async move {
512+
// Spawn a task the first time the stream is polled for the build phase.
513+
// This ensures the consumer of the join does not poll unnecessarily
514+
// while the build is ongoing.
515+
SpawnedTask::spawn(task).map(|r| r?).await
516+
})
509517
})?;
510518

511519
let batch_size = context.session_config().batch_size();

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ use datafusion_execution::runtime_env::RuntimeEnv;
5454
use datafusion_execution::TaskContext;
5555
use datafusion_physical_expr::LexOrdering;
5656

57+
use datafusion_common_runtime::SpawnedTask;
5758
use futures::{StreamExt, TryStreamExt};
5859
use log::{debug, trace};
5960

@@ -1126,15 +1127,22 @@ impl ExecutionPlan for SortExec {
11261127
Ok(Box::pin(RecordBatchStreamAdapter::new(
11271128
self.schema(),
11281129
futures::stream::once(async move {
1129-
while let Some(batch) = input.next().await {
1130-
let batch = batch?;
1131-
topk.insert_batch(batch)?;
1132-
if topk.finished {
1133-
break;
1130+
// Spawn a task the first time the stream is polled for the sort phase.
1131+
// This ensures the consumer of the sort does not poll unnecessarily
1132+
// while the sort is ongoing.
1133+
SpawnedTask::spawn(async move {
1134+
while let Some(batch) = input.next().await {
1135+
let batch = batch?;
1136+
topk.insert_batch(batch)?;
1137+
if topk.finished {
1138+
break;
1139+
}
11341140
}
1135-
}
1136-
topk.emit()
1141+
topk.emit()
1142+
})
1143+
.await
11371144
})
1145+
.map(|s| s?)
11381146
.try_flatten(),
11391147
)))
11401148
}
@@ -1152,12 +1160,19 @@ impl ExecutionPlan for SortExec {
11521160
Ok(Box::pin(RecordBatchStreamAdapter::new(
11531161
self.schema(),
11541162
futures::stream::once(async move {
1155-
while let Some(batch) = input.next().await {
1156-
let batch = batch?;
1157-
sorter.insert_batch(batch).await?;
1158-
}
1159-
sorter.sort().await
1163+
// Spawn a task the first time the stream is polled for the sort phase.
1164+
// This ensures the consumer of the sort does not poll unnecessarily
1165+
// while the sort is ongoing.
1166+
SpawnedTask::spawn(async move {
1167+
while let Some(batch) = input.next().await {
1168+
let batch = batch?;
1169+
sorter.insert_batch(batch).await?;
1170+
}
1171+
sorter.sort().await
1172+
})
1173+
.await
11601174
})
1175+
.map(|s| s?)
11611176
.try_flatten(),
11621177
)))
11631178
}

0 commit comments

Comments
 (0)