Skip to content

Commit 7c748a8

Browse files
committed
comments
Signed-off-by: Onur Satici <onur@spiraldb.com>
1 parent 59d7be7 commit 7c748a8

7 files changed

Lines changed: 193 additions & 157 deletions

File tree

vortex-array/src/aggregate_fn/accumulator_grouped.rs

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ use crate::aggregate_fn::AggregateFn;
1414
use crate::aggregate_fn::AggregateFnRef;
1515
use crate::aggregate_fn::AggregateFnVTable;
1616
use crate::aggregate_fn::DynAccumulator;
17-
use crate::aggregate_fn::kernels::GroupedAggregateKernelResult;
1817
use crate::aggregate_fn::session::AggregateFnSessionExt;
1918
use crate::array::ArrayId;
2019
use crate::arrays::PrimitiveArray;
@@ -67,22 +66,6 @@ impl GroupIds {
6766
Self::from_buffer(Buffer::from_iter(ids), num_groups)
6867
}
6968

70-
/// Create group ids containing `0..num_groups`.
71-
pub fn range(num_groups: usize) -> VortexResult<Self> {
72-
validate_num_groups(num_groups)?;
73-
if num_groups == 0 {
74-
return Self::from_buffer(Buffer::<u32>::empty(), num_groups);
75-
}
76-
77-
let last = u32::try_from(num_groups - 1).map_err(|_| {
78-
vortex_err!(
79-
"num_groups {} exceeds dense u32 group id capacity",
80-
num_groups
81-
)
82-
})?;
83-
Self::from_buffer((0..=last).collect(), num_groups)
84-
}
85-
8669
/// Return the encoded ids array.
8770
pub fn ids(&self) -> &ArrayRef {
8871
&self.ids
@@ -178,14 +161,6 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
178161
Ok(())
179162
}
180163

181-
fn accumulate_kernel_result(
182-
&mut self,
183-
result: GroupedAggregateKernelResult,
184-
ctx: &mut ExecutionCtx,
185-
) -> VortexResult<()> {
186-
self.accumulate_partials(result.partials(), result.group_ids(), ctx)
187-
}
188-
189164
fn try_accumulate_kernel(
190165
&mut self,
191166
batch: &ArrayRef,
@@ -198,10 +173,13 @@ impl<V: AggregateFnVTable> GroupedAccumulator<V> {
198173
self.aggregate_fn.id(),
199174
batch.encoding_id(),
200175
group_ids.encoding_id(),
201-
) && let Some(result) =
202-
kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, ctx)?
203-
{
204-
self.accumulate_kernel_result(result, ctx)?;
176+
) && kernel.grouped_accumulate(
177+
&self.aggregate_fn,
178+
batch,
179+
group_ids,
180+
&mut self.partials,
181+
ctx,
182+
)? {
205183
return Ok(true);
206184
}
207185

vortex-array/src/aggregate_fn/fns/count/grouped.rs

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,48 +6,33 @@ use vortex_error::VortexResult;
66
use super::Count;
77
use crate::ArrayRef;
88
use crate::ExecutionCtx;
9-
use crate::IntoArray;
10-
use crate::aggregate_fn::AggregateFnRef;
9+
use crate::aggregate_fn::EmptyOptions;
1110
use crate::aggregate_fn::GroupIds;
12-
use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
13-
use crate::aggregate_fn::kernels::GroupedAggregateKernelResult;
14-
use crate::arrays::PrimitiveArray;
11+
use crate::aggregate_fn::kernels::GroupedAggregateKernel;
12+
use crate::aggregate_fn::kernels::GroupedAggregateKernelAdapter;
13+
14+
pub(crate) static COUNT_GROUPED_KERNEL: GroupedAggregateKernelAdapter<Count, CountGroupedKernel> =
15+
GroupedAggregateKernelAdapter::new(CountGroupedKernel);
1516

1617
#[derive(Debug)]
1718
pub(crate) struct CountGroupedKernel;
1819

19-
impl DynGroupedAggregateKernel for CountGroupedKernel {
20-
fn grouped_aggregate(
20+
impl GroupedAggregateKernel<Count> for CountGroupedKernel {
21+
fn grouped_accumulate(
2122
&self,
22-
aggregate_fn: &AggregateFnRef,
23+
_options: &EmptyOptions,
24+
states: &mut [u64],
2325
batch: &ArrayRef,
2426
group_ids: &GroupIds,
2527
ctx: &mut ExecutionCtx,
26-
) -> VortexResult<Option<GroupedAggregateKernelResult>> {
27-
if aggregate_fn.as_opt::<Count>().is_none() {
28-
return Ok(None);
29-
}
30-
31-
let partials = accumulate_grouped(batch, group_ids, ctx)?;
32-
Ok(Some(GroupedAggregateKernelResult::dense(
33-
PrimitiveArray::from_iter(partials).into_array(),
34-
group_ids.num_groups(),
35-
)?))
36-
}
37-
}
38-
39-
fn accumulate_grouped(
40-
batch: &ArrayRef,
41-
group_ids: &GroupIds,
42-
ctx: &mut ExecutionCtx,
43-
) -> VortexResult<Vec<u64>> {
44-
let ids = group_ids.validated_ids(ctx)?;
45-
let mut partials = vec![0u64; group_ids.num_groups()];
46-
let validity = batch.validity()?.execute_mask(batch.len(), ctx)?;
47-
for (&group_id, valid) in ids.iter().zip(validity.iter()) {
48-
if valid {
49-
partials[group_id as usize] += 1;
28+
) -> VortexResult<bool> {
29+
let group_ids = group_ids.validated_ids(ctx)?;
30+
let validity = batch.validity()?.execute_mask(batch.len(), ctx)?;
31+
for (&group_id, valid) in group_ids.iter().zip(validity.iter()) {
32+
if valid {
33+
states[group_id as usize] += 1;
34+
}
5035
}
36+
Ok(true)
5137
}
52-
Ok(partials)
5338
}

vortex-array/src/aggregate_fn/fns/count/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
mod grouped;
5-
pub(crate) use grouped::CountGroupedKernel;
5+
pub(crate) use grouped::COUNT_GROUPED_KERNEL;
66
use vortex_error::VortexExpect;
77
use vortex_error::VortexResult;
88

vortex-array/src/aggregate_fn/fns/sum/grouped.rs

Lines changed: 24 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ use num_traits::AsPrimitive;
55
use num_traits::ToPrimitive;
66
use vortex_error::VortexExpect;
77
use vortex_error::VortexResult;
8-
use vortex_error::vortex_bail;
98
use vortex_error::vortex_panic;
109
use vortex_mask::AllOr;
1110
use vortex_mask::Mask;
@@ -19,74 +18,51 @@ use super::primitive::sum_float_all;
1918
use super::primitive::sum_signed_all;
2019
use super::primitive::sum_unsigned_all;
2120
use crate::ArrayRef;
22-
use crate::Canonical;
23-
use crate::Columnar;
2421
use crate::ExecutionCtx;
25-
use crate::aggregate_fn::AggregateFnRef;
26-
use crate::aggregate_fn::AggregateFnVTable;
22+
use crate::aggregate_fn::EmptyOptions;
2723
use crate::aggregate_fn::GroupIds;
28-
use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
29-
use crate::aggregate_fn::kernels::GroupedAggregateKernelResult;
24+
use crate::aggregate_fn::kernels::GroupedAggregateKernel;
25+
use crate::aggregate_fn::kernels::GroupedAggregateKernelAdapter;
26+
use crate::arrays::Bool;
3027
use crate::arrays::BoolArray;
28+
use crate::arrays::Primitive;
3129
use crate::arrays::PrimitiveArray;
3230
use crate::arrays::bool::BoolArrayExt;
3331
use crate::dtype::NativePType;
3432
use crate::match_each_native_ptype;
3533

3634
const MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS: usize = 4;
3735

36+
pub(crate) static SUM_GROUPED_KERNEL: GroupedAggregateKernelAdapter<Sum, SumGroupedKernel> =
37+
GroupedAggregateKernelAdapter::new(SumGroupedKernel);
38+
3839
#[derive(Debug)]
3940
pub(crate) struct SumGroupedKernel;
4041

41-
impl DynGroupedAggregateKernel for SumGroupedKernel {
42-
fn grouped_aggregate(
42+
impl GroupedAggregateKernel<Sum> for SumGroupedKernel {
43+
fn grouped_accumulate(
4344
&self,
44-
aggregate_fn: &AggregateFnRef,
45+
_options: &EmptyOptions,
46+
partials: &mut [SumPartial],
4547
batch: &ArrayRef,
4648
group_ids: &GroupIds,
4749
ctx: &mut ExecutionCtx,
48-
) -> VortexResult<Option<GroupedAggregateKernelResult>> {
49-
let Some(options) = aggregate_fn.as_opt::<Sum>() else {
50-
return Ok(None);
51-
};
52-
53-
let columnar = batch.clone().execute::<Columnar>(ctx)?;
54-
match &columnar {
55-
Columnar::Canonical(Canonical::Primitive(_))
56-
| Columnar::Canonical(Canonical::Bool(_)) => {}
57-
// Decimal and constants still use the universal grouped fallback.
58-
Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => return Ok(None),
59-
Columnar::Canonical(_) => {
60-
vortex_bail!("Unsupported canonical type for sum: {}", columnar.dtype())
61-
}
50+
) -> VortexResult<bool> {
51+
if let Some(primitive) = batch.as_opt::<Primitive>() {
52+
let group_ids = group_ids.validated_ids(ctx)?;
53+
let primitive = primitive.into_owned();
54+
accumulate_grouped_primitive(partials, &primitive, group_ids.as_ref(), ctx)?;
55+
return Ok(true);
6256
}
6357

64-
let partial_dtype = Sum
65-
.partial_dtype(options, batch.dtype())
66-
.ok_or_else(|| vortex_error::vortex_err!("Unsupported sum dtype: {}", batch.dtype()))?;
67-
let ids = group_ids.validated_ids(ctx)?;
68-
let mut partials = (0..group_ids.num_groups())
69-
.map(|_| Sum.empty_partial(options, batch.dtype()))
70-
.collect::<VortexResult<Vec<_>>>()?;
71-
72-
match &columnar {
73-
Columnar::Canonical(Canonical::Primitive(p)) => {
74-
accumulate_grouped_primitive(&mut partials, p, ids.as_ref(), ctx)?;
75-
}
76-
Columnar::Canonical(Canonical::Bool(b)) => {
77-
accumulate_grouped_bool(&mut partials, b, ids.as_ref(), ctx)?;
78-
}
79-
Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => unreachable!(),
80-
Columnar::Canonical(_) => unreachable!(),
58+
if let Some(bools) = batch.as_opt::<Bool>() {
59+
let group_ids = group_ids.validated_ids(ctx)?;
60+
let bools = bools.into_owned();
61+
accumulate_grouped_bool(partials, &bools, group_ids.as_ref(), ctx)?;
62+
return Ok(true);
8163
}
8264

83-
let Some(partials) = Sum.partials_to_array(&partials, &partial_dtype)? else {
84-
return Ok(None);
85-
};
86-
Ok(Some(GroupedAggregateKernelResult::dense(
87-
partials,
88-
group_ids.num_groups(),
89-
)?))
65+
Ok(false)
9066
}
9167
}
9268

vortex-array/src/aggregate_fn/fns/sum/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mod decimal;
77
mod grouped;
88
mod primitive;
99

10-
pub(crate) use grouped::SumGroupedKernel;
10+
pub(crate) use grouped::SUM_GROUPED_KERNEL;
1111
use vortex_buffer::Buffer;
1212
use vortex_error::VortexExpect;
1313
use vortex_error::VortexResult;

0 commit comments

Comments
 (0)