@@ -21,10 +21,6 @@ use std::any::Any;
21
21
use std:: sync:: Arc ;
22
22
23
23
use super :: { DisplayAs , ExecutionPlanProperties , PlanProperties } ;
24
- use crate :: aggregates:: {
25
- no_grouping:: AggregateStream , row_hash:: GroupedHashAggregateStream ,
26
- topk_stream:: GroupedTopKAggregateStream ,
27
- } ;
28
24
use crate :: execution_plan:: { CardinalityEffect , EmissionType } ;
29
25
use crate :: metrics:: { ExecutionPlanMetricsSet , MetricsSet } ;
30
26
use crate :: windows:: get_ordered_partition_by_indices;
@@ -358,21 +354,10 @@ impl PartialEq for PhysicalGroupBy {
358
354
}
359
355
}
360
356
361
- #[ allow( clippy:: large_enum_variant) ]
362
357
enum StreamType {
363
- AggregateStream ( AggregateStream ) ,
364
- GroupedHash ( GroupedHashAggregateStream ) ,
365
- GroupedPriorityQueue ( GroupedTopKAggregateStream ) ,
366
- }
367
-
368
- impl From < StreamType > for SendableRecordBatchStream {
369
- fn from ( stream : StreamType ) -> Self {
370
- match stream {
371
- StreamType :: AggregateStream ( stream) => Box :: pin ( stream) ,
372
- StreamType :: GroupedHash ( stream) => Box :: pin ( stream) ,
373
- StreamType :: GroupedPriorityQueue ( stream) => Box :: pin ( stream) ,
374
- }
375
- }
358
+ AggregateStream ( SendableRecordBatchStream ) ,
359
+ GroupedHash ( SendableRecordBatchStream ) ,
360
+ GroupedPriorityQueue ( SendableRecordBatchStream ) ,
376
361
}
377
362
378
363
/// Hash aggregate execution plan
@@ -608,7 +593,7 @@ impl AggregateExec {
608
593
) -> Result < StreamType > {
609
594
// no group by at all
610
595
if self . group_by . expr . is_empty ( ) {
611
- return Ok ( StreamType :: AggregateStream ( AggregateStream :: new (
596
+ return Ok ( StreamType :: AggregateStream ( no_grouping :: aggregate_stream (
612
597
self , context, partition,
613
598
) ?) ) ;
614
599
}
@@ -617,13 +602,13 @@ impl AggregateExec {
617
602
if let Some ( limit) = self . limit {
618
603
if !self . is_unordered_unfiltered_group_by_distinct ( ) {
619
604
return Ok ( StreamType :: GroupedPriorityQueue (
620
- GroupedTopKAggregateStream :: new ( self , context, partition, limit) ?,
605
+ topk_stream :: aggregate_stream ( self , context, partition, limit) ?,
621
606
) ) ;
622
607
}
623
608
}
624
609
625
610
// grouping by something else and we need to just materialize all results
626
- Ok ( StreamType :: GroupedHash ( GroupedHashAggregateStream :: new (
611
+ Ok ( StreamType :: GroupedHash ( row_hash :: aggregate_stream (
627
612
self , context, partition,
628
613
) ?) )
629
614
}
@@ -998,8 +983,11 @@ impl ExecutionPlan for AggregateExec {
998
983
partition : usize ,
999
984
context : Arc < TaskContext > ,
1000
985
) -> Result < SendableRecordBatchStream > {
1001
- self . execute_typed ( partition, context)
1002
- . map ( |stream| stream. into ( ) )
986
+ match self . execute_typed ( partition, context) ? {
987
+ StreamType :: AggregateStream ( s) => Ok ( s) ,
988
+ StreamType :: GroupedHash ( s) => Ok ( s) ,
989
+ StreamType :: GroupedPriorityQueue ( s) => Ok ( s) ,
990
+ }
1003
991
}
1004
992
1005
993
fn metrics ( & self ) -> Option < MetricsSet > {
@@ -1274,7 +1262,7 @@ pub fn create_accumulators(
1274
1262
/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
1275
1263
pub fn finalize_aggregation (
1276
1264
accumulators : & mut [ AccumulatorItem ] ,
1277
- mode : & AggregateMode ,
1265
+ mode : AggregateMode ,
1278
1266
) -> Result < Vec < ArrayRef > > {
1279
1267
match mode {
1280
1268
AggregateMode :: Partial => {
@@ -2105,20 +2093,20 @@ mod tests {
2105
2093
let stream = partial_aggregate. execute_typed ( 0 , Arc :: clone ( & task_ctx) ) ?;
2106
2094
2107
2095
// ensure that we really got the version we wanted
2108
- match version {
2109
- 0 => {
2110
- assert ! ( matches!( stream, StreamType :: AggregateStream ( _) ) ) ;
2096
+ let stream = match stream {
2097
+ StreamType :: AggregateStream ( s) => {
2098
+ assert_eq ! ( version, 0 ) ;
2099
+ s
2111
2100
}
2112
- 1 => {
2113
- assert ! ( matches!( stream, StreamType :: GroupedHash ( _) ) ) ;
2101
+ StreamType :: GroupedHash ( s) => {
2102
+ assert ! ( version == 1 || version == 2 ) ;
2103
+ s
2114
2104
}
2115
- 2 => {
2116
- assert ! ( matches! ( stream, StreamType :: GroupedHash ( _ ) ) ) ;
2105
+ StreamType :: GroupedPriorityQueue ( _ ) => {
2106
+ panic ! ( "Unexpected GroupedPriorityQueue stream type" ) ;
2117
2107
}
2118
- _ => panic ! ( "Unknown version: {version}" ) ,
2119
- }
2108
+ } ;
2120
2109
2121
- let stream: SendableRecordBatchStream = stream. into ( ) ;
2122
2110
let err = collect ( stream) . await . unwrap_err ( ) ;
2123
2111
2124
2112
// error root cause traversal is a bit complicated, see #4172.
0 commit comments