@@ -5,7 +5,6 @@ use num_traits::AsPrimitive;
55use num_traits:: ToPrimitive ;
66use vortex_error:: VortexExpect ;
77use vortex_error:: VortexResult ;
8- use vortex_error:: vortex_bail;
98use vortex_error:: vortex_panic;
109use vortex_mask:: AllOr ;
1110use vortex_mask:: Mask ;
@@ -19,74 +18,51 @@ use super::primitive::sum_float_all;
1918use super :: primitive:: sum_signed_all;
2019use super :: primitive:: sum_unsigned_all;
2120use crate :: ArrayRef ;
22- use crate :: Canonical ;
23- use crate :: Columnar ;
2421use crate :: ExecutionCtx ;
25- use crate :: aggregate_fn:: AggregateFnRef ;
26- use crate :: aggregate_fn:: AggregateFnVTable ;
22+ use crate :: aggregate_fn:: EmptyOptions ;
2723use crate :: aggregate_fn:: GroupIds ;
28- use crate :: aggregate_fn:: kernels:: DynGroupedAggregateKernel ;
29- use crate :: aggregate_fn:: kernels:: GroupedAggregateKernelResult ;
24+ use crate :: aggregate_fn:: kernels:: GroupedAggregateKernel ;
25+ use crate :: aggregate_fn:: kernels:: GroupedAggregateKernelAdapter ;
26+ use crate :: arrays:: Bool ;
3027use crate :: arrays:: BoolArray ;
28+ use crate :: arrays:: Primitive ;
3129use crate :: arrays:: PrimitiveArray ;
3230use crate :: arrays:: bool:: BoolArrayExt ;
3331use crate :: dtype:: NativePType ;
3432use crate :: match_each_native_ptype;
3533
3634const MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS : usize = 4 ;
3735
36+ pub ( crate ) static SUM_GROUPED_KERNEL : GroupedAggregateKernelAdapter < Sum , SumGroupedKernel > =
37+ GroupedAggregateKernelAdapter :: new ( SumGroupedKernel ) ;
38+
3839#[ derive( Debug ) ]
3940pub ( crate ) struct SumGroupedKernel ;
4041
41- impl DynGroupedAggregateKernel for SumGroupedKernel {
42- fn grouped_aggregate (
42+ impl GroupedAggregateKernel < Sum > for SumGroupedKernel {
43+ fn grouped_accumulate (
4344 & self ,
44- aggregate_fn : & AggregateFnRef ,
45+ _options : & EmptyOptions ,
46+ partials : & mut [ SumPartial ] ,
4547 batch : & ArrayRef ,
4648 group_ids : & GroupIds ,
4749 ctx : & mut ExecutionCtx ,
48- ) -> VortexResult < Option < GroupedAggregateKernelResult > > {
49- let Some ( options) = aggregate_fn. as_opt :: < Sum > ( ) else {
50- return Ok ( None ) ;
51- } ;
52-
53- let columnar = batch. clone ( ) . execute :: < Columnar > ( ctx) ?;
54- match & columnar {
55- Columnar :: Canonical ( Canonical :: Primitive ( _) )
56- | Columnar :: Canonical ( Canonical :: Bool ( _) ) => { }
57- // Decimal and constants still use the universal grouped fallback.
58- Columnar :: Canonical ( Canonical :: Decimal ( _) ) | Columnar :: Constant ( _) => return Ok ( None ) ,
59- Columnar :: Canonical ( _) => {
60- vortex_bail ! ( "Unsupported canonical type for sum: {}" , columnar. dtype( ) )
61- }
50+ ) -> VortexResult < bool > {
51+ if let Some ( primitive) = batch. as_opt :: < Primitive > ( ) {
52+ let group_ids = group_ids. validated_ids ( ctx) ?;
53+ let primitive = primitive. into_owned ( ) ;
54+ accumulate_grouped_primitive ( partials, & primitive, group_ids. as_ref ( ) , ctx) ?;
55+ return Ok ( true ) ;
6256 }
6357
64- let partial_dtype = Sum
65- . partial_dtype ( options, batch. dtype ( ) )
66- . ok_or_else ( || vortex_error:: vortex_err!( "Unsupported sum dtype: {}" , batch. dtype( ) ) ) ?;
67- let ids = group_ids. validated_ids ( ctx) ?;
68- let mut partials = ( 0 ..group_ids. num_groups ( ) )
69- . map ( |_| Sum . empty_partial ( options, batch. dtype ( ) ) )
70- . collect :: < VortexResult < Vec < _ > > > ( ) ?;
71-
72- match & columnar {
73- Columnar :: Canonical ( Canonical :: Primitive ( p) ) => {
74- accumulate_grouped_primitive ( & mut partials, p, ids. as_ref ( ) , ctx) ?;
75- }
76- Columnar :: Canonical ( Canonical :: Bool ( b) ) => {
77- accumulate_grouped_bool ( & mut partials, b, ids. as_ref ( ) , ctx) ?;
78- }
79- Columnar :: Canonical ( Canonical :: Decimal ( _) ) | Columnar :: Constant ( _) => unreachable ! ( ) ,
80- Columnar :: Canonical ( _) => unreachable ! ( ) ,
58+ if let Some ( bools) = batch. as_opt :: < Bool > ( ) {
59+ let group_ids = group_ids. validated_ids ( ctx) ?;
60+ let bools = bools. into_owned ( ) ;
61+ accumulate_grouped_bool ( partials, & bools, group_ids. as_ref ( ) , ctx) ?;
62+ return Ok ( true ) ;
8163 }
8264
83- let Some ( partials) = Sum . partials_to_array ( & partials, & partial_dtype) ? else {
84- return Ok ( None ) ;
85- } ;
86- Ok ( Some ( GroupedAggregateKernelResult :: dense (
87- partials,
88- group_ids. num_groups ( ) ,
89- ) ?) )
65+ Ok ( false )
9066 }
9167}
9268
0 commit comments