@@ -12,6 +12,7 @@ use vortex_array::ExecutionCtx;
1212use vortex_array:: IntoArray ;
1313use vortex_array:: LEGACY_SESSION ;
1414use vortex_array:: VortexSessionExecute ;
15+ use vortex_array:: aggregate_fn:: AggregateFnRef ;
1516use vortex_array:: aggregate_fn:: fns:: sum:: sum;
1617use vortex_array:: arrays:: ConstantArray ;
1718use vortex_array:: arrays:: StructArray ;
@@ -25,7 +26,6 @@ use vortex_array::dtype::Nullability;
2526use vortex_array:: dtype:: PType ;
2627use vortex_array:: expr:: stats:: Precision ;
2728use vortex_array:: expr:: stats:: Stat ;
28- use vortex_array:: expr:: stats:: StatsProvider ;
2929use vortex_array:: scalar:: Scalar ;
3030use vortex_array:: scalar:: ScalarTruncation ;
3131use vortex_array:: scalar:: lower_bound;
@@ -35,9 +35,12 @@ use vortex_array::validity::Validity;
3535use vortex_buffer:: BufferString ;
3636use vortex_buffer:: ByteBuffer ;
3737use vortex_error:: VortexResult ;
38+ use vortex_error:: vortex_ensure_eq;
3839
3940use crate :: layouts:: zoned:: schema:: MAX_IS_TRUNCATED ;
4041use crate :: layouts:: zoned:: schema:: MIN_IS_TRUNCATED ;
42+ use crate :: layouts:: zoned:: schema:: aggregate_descriptor;
43+ use crate :: layouts:: zoned:: schema:: aggregate_state_dtype;
4144
4245/// Accumulates write-time statistics for each logical zone.
4346pub struct StatsAccumulator {
@@ -67,18 +70,6 @@ impl StatsAccumulator {
6770 }
6871 }
6972
70- pub fn push_chunk_without_compute ( & mut self , array : & ArrayRef ) -> VortexResult < ( ) > {
71- for builder in & mut self . builders {
72- if let Some ( Precision :: Exact ( value) ) = array. statistics ( ) . get ( builder. stat ( ) ) {
73- builder. append_scalar ( value. cast ( & value. dtype ( ) . as_nullable ( ) ) ?) ?;
74- } else {
75- builder. append_null ( ) ;
76- }
77- }
78- self . length += 1 ;
79- Ok ( ( ) )
80- }
81-
8273 pub fn push_chunk ( & mut self , array : & ArrayRef , ctx : & mut ExecutionCtx ) -> VortexResult < ( ) > {
8374 for builder in & mut self . builders {
8475 if let Some ( value) = array. statistics ( ) . compute_stat ( builder. stat ( ) , ctx) ? {
@@ -165,6 +156,102 @@ impl StatsAccumulator {
165156 }
166157}
167158
159+ /// Accumulates aggregate-function partials for each logical zone.
160+ pub ( crate ) struct AggregateStatsAccumulator {
161+ builders : Vec < AggregateStatsArrayBuilder > ,
162+ length : usize ,
163+ }
164+
165+ impl AggregateStatsAccumulator {
166+ pub ( crate ) fn new ( dtype : & DType , aggregate_fns : & [ AggregateFnRef ] ) -> Self {
167+ let builders = aggregate_fns
168+ . iter ( )
169+ . filter_map ( |aggregate_fn| {
170+ aggregate_state_dtype ( dtype, aggregate_fn) . map ( |partial_dtype| {
171+ AggregateStatsArrayBuilder :: new (
172+ aggregate_fn. clone ( ) ,
173+ & partial_dtype. as_nullable ( ) ,
174+ 1024 ,
175+ )
176+ } )
177+ } )
178+ . collect :: < Vec < _ > > ( ) ;
179+
180+ Self {
181+ builders,
182+ length : 0 ,
183+ }
184+ }
185+
186+ pub ( crate ) fn aggregate_fns ( & self ) -> Arc < [ AggregateFnRef ] > {
187+ self . builders
188+ . iter ( )
189+ . map ( |builder| builder. aggregate_fn . clone ( ) )
190+ . collect :: < Vec < _ > > ( )
191+ . into ( )
192+ }
193+
194+ pub ( crate ) fn push_partials ( & mut self , partials : Vec < Scalar > ) -> VortexResult < ( ) > {
195+ vortex_ensure_eq ! (
196+ partials. len( ) ,
197+ self . builders. len( ) ,
198+ "aggregate partial count must match zone stats builder count"
199+ ) ;
200+
201+ for ( builder, value) in self . builders . iter_mut ( ) . zip_eq ( partials) {
202+ builder. append_scalar ( value) ?;
203+ }
204+ self . length += 1 ;
205+ Ok ( ( ) )
206+ }
207+
208+ pub ( crate ) fn as_array (
209+ & mut self ,
210+ ) -> VortexResult < Option < ( StructArray , Arc < [ AggregateFnRef ] > ) > > {
211+ let mut names = Vec :: new ( ) ;
212+ let mut fields = Vec :: new ( ) ;
213+ let mut aggregate_fns = Vec :: new ( ) ;
214+
215+ for builder in self
216+ . builders
217+ . iter_mut ( )
218+ . sorted_unstable_by ( |lhs, rhs| lhs. descriptor . cmp ( & rhs. descriptor ) )
219+ {
220+ let values = builder. finish ( ) ;
221+
222+ if values. all_invalid ( ) ? {
223+ continue ;
224+ }
225+
226+ aggregate_fns. push ( builder. aggregate_fn . clone ( ) ) ;
227+ names. extend ( values. names ) ;
228+ fields. extend ( values. arrays ) ;
229+ }
230+
231+ if names. is_empty ( ) {
232+ return Ok ( None ) ;
233+ }
234+
235+ let array = StructArray :: try_new ( names. into ( ) , fields, self . length , Validity :: NonNullable ) ?;
236+ Ok ( Some ( ( array, aggregate_fns. into ( ) ) ) )
237+ }
238+ }
239+
240+ pub ( crate ) fn aggregate_partials (
241+ array : & ArrayRef ,
242+ aggregate_fns : & [ AggregateFnRef ] ,
243+ ctx : & mut ExecutionCtx ,
244+ ) -> VortexResult < Vec < Scalar > > {
245+ aggregate_fns
246+ . iter ( )
247+ . map ( |aggregate_fn| {
248+ let mut accumulator = aggregate_fn. accumulator ( array. dtype ( ) ) ?;
249+ accumulator. accumulate ( array, ctx) ?;
250+ accumulator. partial_scalar ( )
251+ } )
252+ . collect ( )
253+ }
254+
168255fn stats_builder_with_capacity (
169256 stat : Stat ,
170257 dtype : & DType ,
@@ -203,6 +290,35 @@ fn stats_builder_with_capacity(
203290 }
204291}
205292
293+ struct AggregateStatsArrayBuilder {
294+ aggregate_fn : AggregateFnRef ,
295+ descriptor : String ,
296+ dtype : DType ,
297+ builder : Box < dyn ArrayBuilder > ,
298+ }
299+
300+ impl AggregateStatsArrayBuilder {
301+ fn new ( aggregate_fn : AggregateFnRef , dtype : & DType , capacity : usize ) -> Self {
302+ Self {
303+ descriptor : aggregate_descriptor ( & aggregate_fn) ,
304+ aggregate_fn,
305+ dtype : dtype. clone ( ) ,
306+ builder : builder_with_capacity ( dtype, capacity) ,
307+ }
308+ }
309+
310+ fn append_scalar ( & mut self , value : Scalar ) -> VortexResult < ( ) > {
311+ self . builder . append_scalar ( & value. cast ( & self . dtype ) ?)
312+ }
313+
314+ fn finish ( & mut self ) -> NamedArrays {
315+ NamedArrays {
316+ names : vec ! [ self . descriptor. clone( ) . into( ) ] ,
317+ arrays : vec ! [ self . builder. finish( ) ] ,
318+ }
319+ }
320+ }
321+
206322/// Arrays with their associated names, reduced version of a `StructArray`.
207323struct NamedArrays {
208324 names : Vec < FieldName > ,
0 commit comments