Skip to content

Commit 9d5ad8f

Browse files
committed
Make YieldStream public to allow static dispatch
1 parent 4211db7 commit 9d5ad8f

File tree

6 files changed

+101
-130
lines changed

6 files changed

+101
-130
lines changed

datafusion/physical-plan/src/aggregates/no_grouping.rs

Lines changed: 66 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,29 @@
1717

1818
//! Aggregate without grouping columns
1919
20+
use super::AggregateExec;
2021
use crate::aggregates::{
2122
aggregate_expressions, create_accumulators, finalize_aggregation, AccumulatorItem,
2223
AggregateMode,
2324
};
25+
use crate::filter::batch_filter;
2426
use crate::metrics::{BaselineMetrics, RecordOutput};
27+
use crate::poll_budget::PollBudget;
2528
use crate::{RecordBatchStream, SendableRecordBatchStream};
2629
use arrow::datatypes::SchemaRef;
2730
use arrow::record_batch::RecordBatch;
2831
use datafusion_common::Result;
32+
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
2933
use datafusion_execution::TaskContext;
3034
use datafusion_physical_expr::PhysicalExpr;
31-
use futures::stream::BoxStream;
35+
use futures::stream::{Stream, StreamExt};
36+
use futures::FutureExt;
3237
use std::borrow::Cow;
3338
use std::sync::Arc;
34-
use std::task::{Context, Poll};
35-
36-
use super::AggregateExec;
37-
use crate::filter::batch_filter;
38-
use crate::poll_budget::PollBudget;
39-
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
40-
use futures::stream::{Stream, StreamExt};
39+
use std::task::{ready, Context, Poll};
4140

4241
/// stream struct for aggregation without grouping columns
4342
pub(crate) struct AggregateStream {
44-
stream: BoxStream<'static, Result<RecordBatch>>,
45-
schema: SchemaRef,
46-
}
47-
48-
/// Actual implementation of [`AggregateStream`].
49-
///
50-
/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem
51-
/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with
52-
/// [`futures::stream::unfold`].
53-
///
54-
/// The latter requires a state object, which is [`AggregateStreamInner`].
55-
struct AggregateStreamInner {
5643
schema: SchemaRef,
5744
mode: AggregateMode,
5845
input: SendableRecordBatchStream,
@@ -62,6 +49,7 @@ struct AggregateStreamInner {
6249
accumulators: Vec<AccumulatorItem>,
6350
reservation: MemoryReservation,
6451
finished: bool,
52+
poll_budget: PollBudget,
6553
}
6654

6755
impl AggregateStream {
@@ -71,7 +59,6 @@ impl AggregateStream {
7159
context: Arc<TaskContext>,
7260
partition: usize,
7361
) -> Result<Self> {
74-
let agg_schema = Arc::clone(&agg.schema);
7562
let agg_filter_expr = agg.filter_expr.clone();
7663

7764
let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
@@ -91,81 +78,17 @@ impl AggregateStream {
9178
let reservation = MemoryConsumer::new(format!("AggregateStream[{partition}]"))
9279
.register(context.memory_pool());
9380

94-
let inner = AggregateStreamInner {
81+
Ok(AggregateStream {
9582
schema: Arc::clone(&agg.schema),
9683
mode: agg.mode,
97-
input: PollBudget::from(context.as_ref()).wrap_stream(input),
84+
input,
9885
baseline_metrics,
9986
aggregate_expressions,
10087
filter_expressions,
10188
accumulators,
10289
reservation,
10390
finished: false,
104-
};
105-
let stream = futures::stream::unfold(inner, |mut this| async move {
106-
if this.finished {
107-
return None;
108-
}
109-
110-
let elapsed_compute = this.baseline_metrics.elapsed_compute();
111-
112-
loop {
113-
let result = match this.input.next().await {
114-
Some(Ok(batch)) => {
115-
let timer = elapsed_compute.timer();
116-
let result = aggregate_batch(
117-
&this.mode,
118-
batch,
119-
&mut this.accumulators,
120-
&this.aggregate_expressions,
121-
&this.filter_expressions,
122-
);
123-
124-
timer.done();
125-
126-
// allocate memory
127-
// This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
128-
// overshooting a bit. Also this means we either store the whole record batch or not.
129-
match result
130-
.and_then(|allocated| this.reservation.try_grow(allocated))
131-
{
132-
Ok(_) => continue,
133-
Err(e) => Err(e),
134-
}
135-
}
136-
Some(Err(e)) => Err(e),
137-
None => {
138-
this.finished = true;
139-
let timer = this.baseline_metrics.elapsed_compute().timer();
140-
let result =
141-
finalize_aggregation(&mut this.accumulators, &this.mode)
142-
.and_then(|columns| {
143-
RecordBatch::try_new(
144-
Arc::clone(&this.schema),
145-
columns,
146-
)
147-
.map_err(Into::into)
148-
})
149-
.record_output(&this.baseline_metrics);
150-
151-
timer.done();
152-
153-
result
154-
}
155-
};
156-
157-
this.finished = true;
158-
return Some((result, this));
159-
}
160-
});
161-
162-
// seems like some consumers call this stream even after it returned `None`, so let's fuse the stream.
163-
let stream = stream.fuse();
164-
let stream = Box::pin(stream);
165-
166-
Ok(Self {
167-
schema: agg_schema,
168-
stream,
91+
poll_budget: PollBudget::from(context.as_ref()),
16992
})
17093
}
17194
}
@@ -178,7 +101,61 @@ impl Stream for AggregateStream {
178101
cx: &mut Context<'_>,
179102
) -> Poll<Option<Self::Item>> {
180103
let this = &mut *self;
181-
this.stream.poll_next_unpin(cx)
104+
105+
if this.finished {
106+
return Poll::Ready(None);
107+
}
108+
109+
let elapsed_compute = this.baseline_metrics.elapsed_compute();
110+
111+
let mut consume_budget = this.poll_budget.consume_budget();
112+
113+
loop {
114+
let result = match ready!(this.input.poll_next_unpin(cx)) {
115+
Some(Ok(batch)) => {
116+
let timer = elapsed_compute.timer();
117+
let result = aggregate_batch(
118+
&this.mode,
119+
batch,
120+
&mut this.accumulators,
121+
&this.aggregate_expressions,
122+
&this.filter_expressions,
123+
);
124+
125+
timer.done();
126+
127+
// allocate memory
128+
// This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with
129+
// overshooting a bit. Also this means we either store the whole record batch or not.
130+
match result
131+
.and_then(|allocated| this.reservation.try_grow(allocated))
132+
{
133+
Ok(_) => {
134+
ready!(consume_budget.poll_unpin(cx));
135+
continue;
136+
}
137+
Err(e) => Err(e),
138+
}
139+
}
140+
Some(Err(e)) => Err(e),
141+
None => {
142+
let timer = this.baseline_metrics.elapsed_compute().timer();
143+
let result = finalize_aggregation(&mut this.accumulators, &this.mode)
144+
.and_then(|columns| {
145+
RecordBatch::try_new(Arc::clone(&this.schema), columns)
146+
.map_err(Into::into)
147+
})
148+
.record_output(&this.baseline_metrics);
149+
150+
timer.done();
151+
152+
result
153+
}
154+
};
155+
156+
this.finished = true;
157+
return Poll::Ready(Some(result));
158+
}
182159
}
183160
}
184161

datafusion/physical-plan/src/joins/cross_join.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
4646
use datafusion_execution::TaskContext;
4747
use datafusion_physical_expr::equivalence::join_equivalence_properties;
4848

49-
use crate::poll_budget::PollBudget;
49+
use crate::poll_budget::{PollBudget, YieldStream};
5050
use async_trait::async_trait;
5151
use futures::{ready, Stream, StreamExt, TryStreamExt};
5252

@@ -189,7 +189,7 @@ impl CrossJoinExec {
189189

190190
/// Asynchronously collect the result of the left child
191191
async fn load_left_input(
192-
stream: SendableRecordBatchStream,
192+
stream: YieldStream,
193193
metrics: BuildProbeJoinMetrics,
194194
reservation: MemoryReservation,
195195
) -> Result<JoinLeftData> {

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ use datafusion_physical_expr::equivalence::{
8181
use datafusion_physical_expr::PhysicalExprRef;
8282
use datafusion_physical_expr_common::datum::compare_op_for_nested;
8383

84-
use crate::poll_budget::PollBudget;
84+
use crate::poll_budget::{PollBudget, YieldStream};
8585
use ahash::RandomState;
8686
use datafusion_physical_expr_common::physical_expr::fmt_sql;
8787
use futures::{ready, Stream, StreamExt, TryStreamExt};
@@ -953,7 +953,7 @@ impl ExecutionPlan for HashJoinExec {
953953
/// hash table (`LeftJoinData`)
954954
async fn collect_left_input(
955955
random_state: RandomState,
956-
left_stream: SendableRecordBatchStream,
956+
left_stream: YieldStream,
957957
on_left: Vec<PhysicalExprRef>,
958958
metrics: BuildProbeJoinMetrics,
959959
reservation: MemoryReservation,

datafusion/physical-plan/src/joins/nested_loop_join.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ use datafusion_physical_expr::equivalence::{
6161
join_equivalence_properties, ProjectionMapping,
6262
};
6363

64-
use crate::poll_budget::PollBudget;
64+
use crate::poll_budget::{PollBudget, YieldStream};
6565
use futures::{ready, Stream, StreamExt, TryStreamExt};
6666
use parking_lot::Mutex;
6767

@@ -626,7 +626,7 @@ impl ExecutionPlan for NestedLoopJoinExec {
626626

627627
/// Asynchronously collect input into a single batch, and creates `JoinLeftData` from it
628628
async fn collect_left_input(
629-
stream: SendableRecordBatchStream,
629+
stream: YieldStream,
630630
join_metrics: BuildProbeJoinMetrics,
631631
reservation: MemoryReservation,
632632
with_visited_left_side: bool,

datafusion/physical-plan/src/poll_budget.rs

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,8 @@ impl PollBudget {
5454
}
5555
}
5656

57-
pub fn wrap_stream(
58-
&self,
59-
inner: SendableRecordBatchStream,
60-
) -> SendableRecordBatchStream {
61-
match self.budget {
62-
None => inner,
63-
Some(budget) => {
64-
Box::pin(YieldStream::new(inner, budget)) as SendableRecordBatchStream
65-
}
66-
}
57+
pub fn wrap_stream(&self, inner: SendableRecordBatchStream) -> YieldStream {
58+
YieldStream::new(inner, self.budget)
6759
}
6860
}
6961

@@ -102,18 +94,18 @@ impl Future for ConsumeBudget {
10294
}
10395
}
10496

105-
struct YieldStream {
97+
pub struct YieldStream {
10698
inner: SendableRecordBatchStream,
107-
budget: u8,
108-
remaining: u8,
99+
budget: Option<u8>,
100+
remaining: Option<u8>,
109101
}
110102

111103
impl YieldStream {
112-
pub fn new(inner: SendableRecordBatchStream, budget: u8) -> Self {
104+
pub fn new(inner: SendableRecordBatchStream, budget: Option<u8>) -> Self {
113105
Self {
114106
inner,
115107
budget,
116-
remaining: 0,
108+
remaining: budget,
117109
}
118110
}
119111
}
@@ -125,21 +117,23 @@ impl Stream for YieldStream {
125117
mut self: Pin<&mut Self>,
126118
cx: &mut Context<'_>,
127119
) -> Poll<Option<Self::Item>> {
128-
if self.remaining == 0 {
129-
self.remaining = self.budget;
130-
cx.waker().wake_by_ref();
131-
return Pending;
132-
}
133-
134-
match self.inner.poll_next_unpin(cx) {
135-
ready @ Ready(Some(_)) => {
136-
self.remaining -= 1;
137-
ready
138-
}
139-
other => {
120+
match self.remaining {
121+
None => self.inner.poll_next_unpin(cx),
122+
Some(0) => {
140123
self.remaining = self.budget;
141-
other
124+
cx.waker().wake_by_ref();
125+
Pending
142126
}
127+
Some(remaining) => match self.inner.poll_next_unpin(cx) {
128+
ready @ Ready(Some(_)) => {
129+
self.remaining = Some(remaining - 1);
130+
ready
131+
}
132+
other => {
133+
self.remaining = self.budget;
134+
other
135+
}
136+
},
143137
}
144138
}
145139
}

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919
//! It will do in-memory sorting if it has enough memory budget
2020
//! but spills to disk if needed.
2121
22-
use std::any::Any;
23-
use std::fmt;
24-
use std::fmt::{Debug, Formatter};
25-
use std::sync::Arc;
26-
2722
use crate::common::spawn_buffered;
2823
use crate::execution_plan::{Boundedness, CardinalityEffect, EmissionType};
2924
use crate::expressions::PhysicalSortExpr;
@@ -43,6 +38,10 @@ use crate::{
4338
ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream,
4439
Statistics,
4540
};
41+
use std::any::Any;
42+
use std::fmt;
43+
use std::fmt::{Debug, Formatter};
44+
use std::sync::Arc;
4645

4746
use arrow::array::{Array, RecordBatch, RecordBatchOptions, StringViewArray};
4847
use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays};
@@ -1107,10 +1106,9 @@ impl ExecutionPlan for SortExec {
11071106
.equivalence_properties()
11081107
.ordering_satisfy_requirement(requirement);
11091108

1110-
let mut input = PollBudget::from(context.as_ref()).wrap_stream(input);
1111-
11121109
match (sort_satisfied, self.fetch.as_ref()) {
11131110
(true, Some(fetch)) => Ok(Box::pin(LimitStream::new(
1111+
// limit is not a pipeline breaking stream, so poll budget is not required
11141112
input,
11151113
0,
11161114
Some(*fetch),
@@ -1128,6 +1126,7 @@ impl ExecutionPlan for SortExec {
11281126
context.runtime_env(),
11291127
&self.metrics_set,
11301128
)?;
1129+
let mut input = PollBudget::from(context.as_ref()).wrap_stream(input);
11311130
Ok(Box::pin(RecordBatchStreamAdapter::new(
11321131
self.schema(),
11331132
futures::stream::once(async move {
@@ -1154,6 +1153,7 @@ impl ExecutionPlan for SortExec {
11541153
&self.metrics_set,
11551154
context.runtime_env(),
11561155
)?;
1156+
let mut input = PollBudget::from(context.as_ref()).wrap_stream(input);
11571157
Ok(Box::pin(RecordBatchStreamAdapter::new(
11581158
self.schema(),
11591159
futures::stream::once(async move {

0 commit comments

Comments
 (0)