@@ -7,7 +7,6 @@ use vortex_error::vortex_ensure;
77use vortex_error:: vortex_err;
88
99use crate :: ArrayRef ;
10- use crate :: Columnar ;
1110use crate :: ExecutionCtx ;
1211use crate :: IntoArray ;
1312use crate :: aggregate_fn:: Accumulator ;
@@ -17,16 +16,107 @@ use crate::aggregate_fn::AggregateFnVTable;
1716use crate :: aggregate_fn:: DynAccumulator ;
1817use crate :: aggregate_fn:: kernels:: GroupedAggregateKernelResult ;
1918use crate :: aggregate_fn:: session:: AggregateFnSessionExt ;
19+ use crate :: array:: ArrayId ;
20+ use crate :: arrays:: PrimitiveArray ;
2021use crate :: builders:: builder_with_capacity;
2122use crate :: columnar:: AnyColumnar ;
2223use crate :: dtype:: DType ;
24+ use crate :: dtype:: Nullability ;
25+ use crate :: dtype:: PType ;
2326use crate :: executor:: max_iterations;
2427use crate :: scalar:: Scalar ;
28+ use crate :: validity:: Validity ;
2529
2630/// Reference-counted type-erased grouped accumulator.
2731pub type GroupedAccumulatorRef = Box < dyn DynGroupedAccumulator > ;
2832
29- /// An accumulator used for computing aggregates over dense group ids.
33+ /// Encoded group ids parallel to a grouped aggregate input batch.
34+ ///
35+ /// The array must contain non-null `u32` ordinals. The ordinals are dense state slots in
36+ /// `0..num_groups`, not raw group keys. Range validation may require executing the encoded array,
37+ /// so kernels that can prove the invariant from encoded metadata should avoid materializing and
38+ /// otherwise call [`Self::validated_ids`] before indexing group state.
39+ #[ derive( Clone , Debug ) ]
40+ pub struct GroupIds {
41+ ids : ArrayRef ,
42+ num_groups : usize ,
43+ }
44+
45+ impl GroupIds {
46+ /// Create group ids from an encoded non-null `u32` array.
47+ pub fn new ( ids : ArrayRef , num_groups : usize ) -> VortexResult < Self > {
48+ validate_num_groups ( num_groups) ?;
49+ vortex_ensure ! (
50+ ids. dtype( ) == & DType :: Primitive ( PType :: U32 , Nullability :: NonNullable ) ,
51+ "Group ids must be non-nullable u32, got {}" ,
52+ ids. dtype( )
53+ ) ;
54+ Ok ( Self { ids, num_groups } )
55+ }
56+
57+ /// Create group ids from a materialized buffer.
58+ pub fn from_buffer ( ids : Buffer < u32 > , num_groups : usize ) -> VortexResult < Self > {
59+ Self :: new (
60+ PrimitiveArray :: new ( ids, Validity :: NonNullable ) . into_array ( ) ,
61+ num_groups,
62+ )
63+ }
64+
65+ /// Create group ids from materialized values.
66+ pub fn from_iter ( ids : impl IntoIterator < Item = u32 > , num_groups : usize ) -> VortexResult < Self > {
67+ Self :: from_buffer ( Buffer :: from_iter ( ids) , num_groups)
68+ }
69+
70+ /// Create group ids containing `0..num_groups`.
71+ pub fn range ( num_groups : usize ) -> VortexResult < Self > {
72+ validate_num_groups ( num_groups) ?;
73+ if num_groups == 0 {
74+ return Self :: from_buffer ( Buffer :: < u32 > :: empty ( ) , num_groups) ;
75+ }
76+
77+ let last = u32:: try_from ( num_groups - 1 ) . map_err ( |_| {
78+ vortex_err ! (
79+ "num_groups {} exceeds dense u32 group id capacity" ,
80+ num_groups
81+ )
82+ } ) ?;
83+ Self :: from_buffer ( ( 0 ..=last) . collect ( ) , num_groups)
84+ }
85+
86+ /// Return the encoded ids array.
87+ pub fn ids ( & self ) -> & ArrayRef {
88+ & self . ids
89+ }
90+
91+ /// Return the number of dense group state slots.
92+ pub fn num_groups ( & self ) -> usize {
93+ self . num_groups
94+ }
95+
96+ /// Return the number of ids.
97+ pub fn len ( & self ) -> usize {
98+ self . ids . len ( )
99+ }
100+
101+ /// Return whether there are no ids.
102+ pub fn is_empty ( & self ) -> bool {
103+ self . ids . is_empty ( )
104+ }
105+
106+ /// Return the encoding id for kernel dispatch.
107+ pub fn encoding_id ( & self ) -> ArrayId {
108+ self . ids . encoding_id ( )
109+ }
110+
111+ /// Execute the ids to a native buffer and validate every id is in range.
112+ pub fn validated_ids ( & self , ctx : & mut ExecutionCtx ) -> VortexResult < Buffer < u32 > > {
113+ let ids = self . ids . clone ( ) . execute :: < Buffer < u32 > > ( ctx) ?;
114+ validate_group_ids ( ids. as_ref ( ) , self . num_groups ) ?;
115+ Ok ( ids)
116+ }
117+ }
118+
119+ /// An accumulator used for computing aggregates over group ids.
30120///
31121/// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches
32122/// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a
@@ -88,54 +178,30 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
88178 Ok ( ( ) )
89179 }
90180
91- fn validate_group_ids ( & self , group_ids : & [ u32 ] , num_groups : usize ) -> VortexResult < ( ) > {
92- validate_num_groups ( num_groups) ?;
93- for & group_id in group_ids {
94- vortex_ensure ! (
95- ( group_id as usize ) < num_groups,
96- "Group id {} out of range for {} groups" ,
97- group_id,
98- num_groups
99- ) ;
100- }
101- Ok ( ( ) )
102- }
103-
104181 fn accumulate_kernel_result (
105182 & mut self ,
106183 result : GroupedAggregateKernelResult ,
107- num_groups : usize ,
108184 ctx : & mut ExecutionCtx ,
109185 ) -> VortexResult < ( ) > {
110- self . accumulate_partials ( result. partials ( ) , result. group_ids ( ) , num_groups , ctx)
186+ self . accumulate_partials ( result. partials ( ) , result. group_ids ( ) , ctx)
111187 }
112188
113189 fn try_accumulate_kernel (
114190 & mut self ,
115191 batch : & ArrayRef ,
116- group_ids : & [ u32 ] ,
117- num_groups : usize ,
192+ group_ids : & GroupIds ,
118193 ctx : & mut ExecutionCtx ,
119194 ) -> VortexResult < bool > {
120195 let session = ctx. session ( ) . clone ( ) ;
121196
122- if let Some ( kernel) = session
123- . aggregate_fns ( )
124- . find_grouped_encoding_kernel ( batch. encoding_id ( ) , self . aggregate_fn . id ( ) )
125- && let Some ( result) =
126- kernel. grouped_aggregate ( & self . aggregate_fn , batch, group_ids, num_groups, ctx) ?
127- {
128- self . accumulate_kernel_result ( result, num_groups, ctx) ?;
129- return Ok ( true ) ;
130- }
131-
132- if let Some ( kernel) = session
133- . aggregate_fns ( )
134- . find_grouped_kernel ( self . aggregate_fn . id ( ) )
135- && let Some ( result) =
136- kernel. grouped_aggregate ( & self . aggregate_fn , batch, group_ids, num_groups, ctx) ?
197+ if let Some ( kernel) = session. aggregate_fns ( ) . find_grouped_kernel (
198+ self . aggregate_fn . id ( ) ,
199+ batch. encoding_id ( ) ,
200+ group_ids. encoding_id ( ) ,
201+ ) && let Some ( result) =
202+ kernel. grouped_aggregate ( & self . aggregate_fn , batch, group_ids, ctx) ?
137203 {
138- self . accumulate_kernel_result ( result, num_groups , ctx) ?;
204+ self . accumulate_kernel_result ( result, ctx) ?;
139205 return Ok ( true ) ;
140206 }
141207
@@ -198,18 +264,31 @@ fn validate_num_groups(num_groups: usize) -> VortexResult<()> {
198264 Ok ( ( ) )
199265}
200266
267+ fn validate_group_ids ( group_ids : & [ u32 ] , num_groups : usize ) -> VortexResult < ( ) > {
268+ validate_num_groups ( num_groups) ?;
269+ for & group_id in group_ids {
270+ vortex_ensure ! (
271+ ( group_id as usize ) < num_groups,
272+ "Group id {} out of range for {} groups" ,
273+ group_id,
274+ num_groups
275+ ) ;
276+ }
277+ Ok ( ( ) )
278+ }
279+
201280/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the
202281/// aggregate function is not known at compile time.
203282pub trait DynGroupedAccumulator : ' static + Send {
204283 /// Accumulate a values batch into dense group state.
205284 ///
206285 /// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in
207- /// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch.
286+ /// `0..group_ids.num_groups()`; ids may repeat, appear out of order, or be absent from a
287+ /// given batch.
208288 fn accumulate (
209289 & mut self ,
210290 batch : & ArrayRef ,
211- group_ids : & [ u32 ] ,
212- num_groups : usize ,
291+ group_ids : & GroupIds ,
213292 ctx : & mut ExecutionCtx ,
214293 ) -> VortexResult < ( ) > ;
215294
@@ -220,8 +299,7 @@ pub trait DynGroupedAccumulator: 'static + Send {
220299 fn accumulate_partials (
221300 & mut self ,
222301 partials : & ArrayRef ,
223- group_ids : & [ u32 ] ,
224- num_groups : usize ,
302+ group_ids : & GroupIds ,
225303 ctx : & mut ExecutionCtx ,
226304 ) -> VortexResult < ( ) > ;
227305
@@ -254,10 +332,10 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
254332 fn accumulate (
255333 & mut self ,
256334 batch : & ArrayRef ,
257- group_ids : & [ u32 ] ,
258- num_groups : usize ,
335+ group_ids : & GroupIds ,
259336 ctx : & mut ExecutionCtx ,
260337 ) -> VortexResult < ( ) > {
338+ let num_groups = group_ids. num_groups ( ) ;
261339 vortex_ensure ! (
262340 batch. dtype( ) == & self . dtype,
263341 "Input DType mismatch: expected {}, got {}" ,
@@ -271,56 +349,43 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
271349 group_ids. len( )
272350 ) ;
273351
274- self . validate_group_ids ( group_ids, num_groups) ?;
275352 self . ensure_groups ( num_groups) ?;
276353
277- if self . try_accumulate_kernel ( batch, group_ids, num_groups, ctx) ? {
278- return Ok ( ( ) ) ;
279- }
280-
281- if self . vtable . try_accumulate_grouped (
282- & mut self . partials [ ..num_groups] ,
283- batch,
284- group_ids,
285- ctx,
286- ) ? {
354+ if self . try_accumulate_kernel ( batch, group_ids, ctx) ? {
287355 return Ok ( ( ) ) ;
288356 }
289357
290358 let input = batch. clone ( ) ;
291359 let mut batch = batch. clone ( ) ;
360+ let mut tried_current = true ;
292361 for _ in 0 ..max_iterations ( ) {
293362 if batch. is :: < AnyColumnar > ( ) {
294363 break ;
295364 }
296365
297- if self . try_accumulate_kernel ( & batch, group_ids, num_groups , ctx) ? {
366+ if !tried_current && self . try_accumulate_kernel ( & batch, group_ids, ctx) ? {
298367 return Ok ( ( ) ) ;
299368 }
300369
301370 batch = batch. execute ( ctx) ?;
371+ tried_current = false ;
302372 }
303373
304- let columnar = batch. clone ( ) . execute :: < Columnar > ( ctx) ?;
305- if self . vtable . accumulate_grouped (
306- & mut self . partials [ ..num_groups] ,
307- & columnar,
308- group_ids,
309- ctx,
310- ) ? {
374+ if !tried_current && self . try_accumulate_kernel ( & batch, group_ids, ctx) ? {
311375 return Ok ( ( ) ) ;
312376 }
313377
314- self . accumulate_fallback ( & input, group_ids, ctx)
378+ let group_ids = group_ids. validated_ids ( ctx) ?;
379+ self . accumulate_fallback ( & input, group_ids. as_ref ( ) , ctx)
315380 }
316381
317382 fn accumulate_partials (
318383 & mut self ,
319384 partials : & ArrayRef ,
320- group_ids : & [ u32 ] ,
321- num_groups : usize ,
385+ group_ids : & GroupIds ,
322386 ctx : & mut ExecutionCtx ,
323387 ) -> VortexResult < ( ) > {
388+ let num_groups = group_ids. num_groups ( ) ;
324389 vortex_ensure ! (
325390 partials. dtype( ) == & self . partial_dtype,
326391 "Partial DType mismatch: expected {}, got {}" ,
@@ -334,7 +399,7 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
334399 group_ids. len( )
335400 ) ;
336401
337- self . validate_group_ids ( group_ids, num_groups ) ?;
402+ let group_ids = group_ids . validated_ids ( ctx ) ?;
338403 self . ensure_groups ( num_groups) ?;
339404
340405 for ( row_idx, & group_id) in group_ids. iter ( ) . enumerate ( ) {
0 commit comments