Skip to content

Commit 59d7be7

Browse files
committed
group ids as array ref, multi encoding kernel lookup
Signed-off-by: Onur Satici <onur@spiraldb.com>
1 parent 2ef64b2 commit 59d7be7

9 files changed

Lines changed: 405 additions & 208 deletions

File tree

vortex-array/benches/aggregate_grouped.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use vortex_array::VortexSessionExecute;
1515
use vortex_array::aggregate_fn::AggregateFnVTable;
1616
use vortex_array::aggregate_fn::DynGroupedAccumulator;
1717
use vortex_array::aggregate_fn::EmptyOptions;
18+
use vortex_array::aggregate_fn::GroupIds;
1819
use vortex_array::aggregate_fn::GroupedAccumulator;
1920
use vortex_array::aggregate_fn::fns::count::Count;
2021
use vortex_array::aggregate_fn::fns::sum::Sum;
@@ -45,24 +46,22 @@ fn total_element_count(group_sizes: &[usize]) -> usize {
4546

4647
struct DenseGroupedInput {
4748
values: ArrayRef,
48-
group_ids: Vec<u32>,
49-
num_groups: usize,
49+
group_ids: GroupIds,
5050
}
5151

5252
fn dense_grouped_input(values: ArrayRef, group_sizes: &[usize]) -> DenseGroupedInput {
5353
assert_eq!(values.len(), total_element_count(group_sizes));
5454

55-
let group_ids = group_sizes
56-
.iter()
57-
.enumerate()
58-
.flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size))
59-
.collect();
55+
let group_ids = GroupIds::from_iter(
56+
group_sizes
57+
.iter()
58+
.enumerate()
59+
.flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)),
60+
group_sizes.len(),
61+
)
62+
.unwrap();
6063

61-
DenseGroupedInput {
62-
values,
63-
group_ids,
64-
num_groups: group_sizes.len(),
65-
}
64+
DenseGroupedInput { values, group_ids }
6665
}
6766

6867
fn i32_nullable_all_valid_input() -> DenseGroupedInput {
@@ -142,14 +141,14 @@ where
142141
{
143142
let mut acc =
144143
GroupedAccumulator::try_new(vtable, EmptyOptions, input.values.dtype().clone()).unwrap();
144+
let num_groups = input.group_ids.num_groups();
145145
acc.accumulate(
146146
&input.values,
147147
&input.group_ids,
148-
input.num_groups,
149148
&mut LEGACY_SESSION.create_execution_ctx(),
150149
)
151150
.unwrap();
152-
divan::black_box(acc.finish(input.num_groups).unwrap())
151+
divan::black_box(acc.finish(num_groups).unwrap())
153152
}
154153

155154
#[divan::bench]

vortex-array/src/aggregate_fn/accumulator_grouped.rs

Lines changed: 130 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use vortex_error::vortex_ensure;
77
use vortex_error::vortex_err;
88

99
use crate::ArrayRef;
10-
use crate::Columnar;
1110
use crate::ExecutionCtx;
1211
use crate::IntoArray;
1312
use crate::aggregate_fn::Accumulator;
@@ -17,16 +16,107 @@ use crate::aggregate_fn::AggregateFnVTable;
1716
use crate::aggregate_fn::DynAccumulator;
1817
use crate::aggregate_fn::kernels::GroupedAggregateKernelResult;
1918
use crate::aggregate_fn::session::AggregateFnSessionExt;
19+
use crate::array::ArrayId;
20+
use crate::arrays::PrimitiveArray;
2021
use crate::builders::builder_with_capacity;
2122
use crate::columnar::AnyColumnar;
2223
use crate::dtype::DType;
24+
use crate::dtype::Nullability;
25+
use crate::dtype::PType;
2326
use crate::executor::max_iterations;
2427
use crate::scalar::Scalar;
28+
use crate::validity::Validity;
2529

2630
/// Boxed, type-erased grouped accumulator.
2731
pub type GroupedAccumulatorRef = Box<dyn DynGroupedAccumulator>;
2832

29-
/// An accumulator used for computing aggregates over dense group ids.
33+
/// Encoded group ids parallel to a grouped aggregate input batch.
34+
///
35+
/// The array must contain non-null `u32` ordinals. The ordinals are dense state slots in
36+
/// `0..num_groups`, not raw group keys. Range validation may require executing the encoded array,
37+
/// so kernels that can prove the invariant from encoded metadata should avoid materializing. All
38+
/// other kernels must call [`Self::validated_ids`] before indexing group state.
39+
#[derive(Clone, Debug)]
40+
pub struct GroupIds {
41+
ids: ArrayRef,
42+
num_groups: usize,
43+
}
44+
45+
impl GroupIds {
46+
/// Create group ids from an encoded non-null `u32` array.
47+
pub fn new(ids: ArrayRef, num_groups: usize) -> VortexResult<Self> {
48+
validate_num_groups(num_groups)?;
49+
vortex_ensure!(
50+
ids.dtype() == &DType::Primitive(PType::U32, Nullability::NonNullable),
51+
"Group ids must be non-nullable u32, got {}",
52+
ids.dtype()
53+
);
54+
Ok(Self { ids, num_groups })
55+
}
56+
57+
/// Create group ids from a materialized buffer.
58+
pub fn from_buffer(ids: Buffer<u32>, num_groups: usize) -> VortexResult<Self> {
59+
Self::new(
60+
PrimitiveArray::new(ids, Validity::NonNullable).into_array(),
61+
num_groups,
62+
)
63+
}
64+
65+
/// Create group ids from materialized values.
66+
pub fn from_iter(ids: impl IntoIterator<Item = u32>, num_groups: usize) -> VortexResult<Self> {
67+
Self::from_buffer(Buffer::from_iter(ids), num_groups)
68+
}
69+
70+
/// Create group ids containing `0..num_groups`.
71+
pub fn range(num_groups: usize) -> VortexResult<Self> {
72+
validate_num_groups(num_groups)?;
73+
if num_groups == 0 {
74+
return Self::from_buffer(Buffer::<u32>::empty(), num_groups);
75+
}
76+
77+
let last = u32::try_from(num_groups - 1).map_err(|_| {
78+
vortex_err!(
79+
"num_groups {} exceeds dense u32 group id capacity",
80+
num_groups
81+
)
82+
})?;
83+
Self::from_buffer((0..=last).collect(), num_groups)
84+
}
85+
86+
/// Return the encoded ids array.
87+
pub fn ids(&self) -> &ArrayRef {
88+
&self.ids
89+
}
90+
91+
/// Return the number of dense group state slots.
92+
pub fn num_groups(&self) -> usize {
93+
self.num_groups
94+
}
95+
96+
/// Return the number of ids.
97+
pub fn len(&self) -> usize {
98+
self.ids.len()
99+
}
100+
101+
/// Return whether there are no ids.
102+
pub fn is_empty(&self) -> bool {
103+
self.ids.is_empty()
104+
}
105+
106+
/// Return the encoding id for kernel dispatch.
107+
pub fn encoding_id(&self) -> ArrayId {
108+
self.ids.encoding_id()
109+
}
110+
111+
/// Execute the ids to a native buffer and validate every id is in range.
112+
pub fn validated_ids(&self, ctx: &mut ExecutionCtx) -> VortexResult<Buffer<u32>> {
113+
let ids = self.ids.clone().execute::<Buffer<u32>>(ctx)?;
114+
validate_group_ids(ids.as_ref(), self.num_groups)?;
115+
Ok(ids)
116+
}
117+
}
118+
119+
/// An accumulator used for computing aggregates over group ids.
30120
///
31121
/// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches
32122
/// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a
@@ -88,54 +178,30 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
88178
Ok(())
89179
}
90180

91-
fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> {
92-
validate_num_groups(num_groups)?;
93-
for &group_id in group_ids {
94-
vortex_ensure!(
95-
(group_id as usize) < num_groups,
96-
"Group id {} out of range for {} groups",
97-
group_id,
98-
num_groups
99-
);
100-
}
101-
Ok(())
102-
}
103-
104181
fn accumulate_kernel_result(
105182
&mut self,
106183
result: GroupedAggregateKernelResult,
107-
num_groups: usize,
108184
ctx: &mut ExecutionCtx,
109185
) -> VortexResult<()> {
110-
self.accumulate_partials(result.partials(), result.group_ids(), num_groups, ctx)
186+
self.accumulate_partials(result.partials(), result.group_ids(), ctx)
111187
}
112188

113189
fn try_accumulate_kernel(
114190
&mut self,
115191
batch: &ArrayRef,
116-
group_ids: &[u32],
117-
num_groups: usize,
192+
group_ids: &GroupIds,
118193
ctx: &mut ExecutionCtx,
119194
) -> VortexResult<bool> {
120195
let session = ctx.session().clone();
121196

122-
if let Some(kernel) = session
123-
.aggregate_fns()
124-
.find_grouped_encoding_kernel(batch.encoding_id(), self.aggregate_fn.id())
125-
&& let Some(result) =
126-
kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)?
127-
{
128-
self.accumulate_kernel_result(result, num_groups, ctx)?;
129-
return Ok(true);
130-
}
131-
132-
if let Some(kernel) = session
133-
.aggregate_fns()
134-
.find_grouped_kernel(self.aggregate_fn.id())
135-
&& let Some(result) =
136-
kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)?
197+
if let Some(kernel) = session.aggregate_fns().find_grouped_kernel(
198+
self.aggregate_fn.id(),
199+
batch.encoding_id(),
200+
group_ids.encoding_id(),
201+
) && let Some(result) =
202+
kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, ctx)?
137203
{
138-
self.accumulate_kernel_result(result, num_groups, ctx)?;
204+
self.accumulate_kernel_result(result, ctx)?;
139205
return Ok(true);
140206
}
141207

@@ -198,18 +264,31 @@ fn validate_num_groups(num_groups: usize) -> VortexResult<()> {
198264
Ok(())
199265
}
200266

267+
fn validate_group_ids(group_ids: &[u32], num_groups: usize) -> VortexResult<()> {
268+
validate_num_groups(num_groups)?;
269+
for &group_id in group_ids {
270+
vortex_ensure!(
271+
(group_id as usize) < num_groups,
272+
"Group id {} out of range for {} groups",
273+
group_id,
274+
num_groups
275+
);
276+
}
277+
Ok(())
278+
}
279+
201280
/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the
202281
/// aggregate function is not known at compile time.
203282
pub trait DynGroupedAccumulator: 'static + Send {
204283
/// Accumulate a values batch into dense group state.
205284
///
206285
/// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in
207-
/// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch.
286+
/// `0..group_ids.num_groups()`; ids may repeat, appear out of order, or be absent from a
287+
/// given batch.
208288
fn accumulate(
209289
&mut self,
210290
batch: &ArrayRef,
211-
group_ids: &[u32],
212-
num_groups: usize,
291+
group_ids: &GroupIds,
213292
ctx: &mut ExecutionCtx,
214293
) -> VortexResult<()>;
215294

@@ -220,8 +299,7 @@ pub trait DynGroupedAccumulator: 'static + Send {
220299
fn accumulate_partials(
221300
&mut self,
222301
partials: &ArrayRef,
223-
group_ids: &[u32],
224-
num_groups: usize,
302+
group_ids: &GroupIds,
225303
ctx: &mut ExecutionCtx,
226304
) -> VortexResult<()>;
227305

@@ -254,10 +332,10 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
254332
fn accumulate(
255333
&mut self,
256334
batch: &ArrayRef,
257-
group_ids: &[u32],
258-
num_groups: usize,
335+
group_ids: &GroupIds,
259336
ctx: &mut ExecutionCtx,
260337
) -> VortexResult<()> {
338+
let num_groups = group_ids.num_groups();
261339
vortex_ensure!(
262340
batch.dtype() == &self.dtype,
263341
"Input DType mismatch: expected {}, got {}",
@@ -271,56 +349,43 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
271349
group_ids.len()
272350
);
273351

274-
self.validate_group_ids(group_ids, num_groups)?;
275352
self.ensure_groups(num_groups)?;
276353

277-
if self.try_accumulate_kernel(batch, group_ids, num_groups, ctx)? {
278-
return Ok(());
279-
}
280-
281-
if self.vtable.try_accumulate_grouped(
282-
&mut self.partials[..num_groups],
283-
batch,
284-
group_ids,
285-
ctx,
286-
)? {
354+
if self.try_accumulate_kernel(batch, group_ids, ctx)? {
287355
return Ok(());
288356
}
289357

290358
let input = batch.clone();
291359
let mut batch = batch.clone();
360+
let mut tried_current = true;
292361
for _ in 0..max_iterations() {
293362
if batch.is::<AnyColumnar>() {
294363
break;
295364
}
296365

297-
if self.try_accumulate_kernel(&batch, group_ids, num_groups, ctx)? {
366+
if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? {
298367
return Ok(());
299368
}
300369

301370
batch = batch.execute(ctx)?;
371+
tried_current = false;
302372
}
303373

304-
let columnar = batch.clone().execute::<Columnar>(ctx)?;
305-
if self.vtable.accumulate_grouped(
306-
&mut self.partials[..num_groups],
307-
&columnar,
308-
group_ids,
309-
ctx,
310-
)? {
374+
if !tried_current && self.try_accumulate_kernel(&batch, group_ids, ctx)? {
311375
return Ok(());
312376
}
313377

314-
self.accumulate_fallback(&input, group_ids, ctx)
378+
let group_ids = group_ids.validated_ids(ctx)?;
379+
self.accumulate_fallback(&input, group_ids.as_ref(), ctx)
315380
}
316381

317382
fn accumulate_partials(
318383
&mut self,
319384
partials: &ArrayRef,
320-
group_ids: &[u32],
321-
num_groups: usize,
385+
group_ids: &GroupIds,
322386
ctx: &mut ExecutionCtx,
323387
) -> VortexResult<()> {
388+
let num_groups = group_ids.num_groups();
324389
vortex_ensure!(
325390
partials.dtype() == &self.partial_dtype,
326391
"Partial DType mismatch: expected {}, got {}",
@@ -334,7 +399,7 @@ impl<V: AggregateFnVTable> DynGroupedAccumulator for GroupedAccumulator<V> {
334399
group_ids.len()
335400
);
336401

337-
self.validate_group_ids(group_ids, num_groups)?;
402+
let group_ids = group_ids.validated_ids(ctx)?;
338403
self.ensure_groups(num_groups)?;
339404

340405
for (row_idx, &group_id) in group_ids.iter().enumerate() {

0 commit comments

Comments
 (0)