Skip to content

Commit b22ebfe

Browse files
committed
Add Min and Max aggregate functions
Signed-off-by: Nicholas Gates <nick@nickgates.com>
1 parent 2138a72 commit b22ebfe

8 files changed

Lines changed: 543 additions & 4 deletions

File tree

vortex-array/public-api.lock

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,56 @@ pub struct vortex_array::aggregate_fn::fns::last::LastPartial
538538

539539
pub fn vortex_array::aggregate_fn::fns::last::last(&vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
540540

541+
pub mod vortex_array::aggregate_fn::fns::max
542+
543+
pub struct vortex_array::aggregate_fn::fns::max::Max
544+
545+
impl core::clone::Clone for vortex_array::aggregate_fn::fns::max::Max
546+
547+
pub fn vortex_array::aggregate_fn::fns::max::Max::clone(&self) -> vortex_array::aggregate_fn::fns::max::Max
548+
549+
impl core::fmt::Debug for vortex_array::aggregate_fn::fns::max::Max
550+
551+
pub fn vortex_array::aggregate_fn::fns::max::Max::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result
552+
553+
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::max::Max
554+
555+
pub type vortex_array::aggregate_fn::fns::max::Max::Options = vortex_array::aggregate_fn::EmptyOptions
556+
557+
pub type vortex_array::aggregate_fn::fns::max::Max::Partial = vortex_array::aggregate_fn::fns::max::MaxPartial
558+
559+
pub fn vortex_array::aggregate_fn::fns::max::Max::accumulate(&self, &mut Self::Partial, &vortex_array::Columnar, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
560+
561+
pub fn vortex_array::aggregate_fn::fns::max::Max::coerce_args(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
562+
563+
pub fn vortex_array::aggregate_fn::fns::max::Max::combine_partials(&self, &mut Self::Partial, vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
564+
565+
pub fn vortex_array::aggregate_fn::fns::max::Max::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
566+
567+
pub fn vortex_array::aggregate_fn::fns::max::Max::empty_partial(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
568+
569+
pub fn vortex_array::aggregate_fn::fns::max::Max::finalize(&self, vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
570+
571+
pub fn vortex_array::aggregate_fn::fns::max::Max::finalize_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
572+
573+
pub fn vortex_array::aggregate_fn::fns::max::Max::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
574+
575+
pub fn vortex_array::aggregate_fn::fns::max::Max::is_saturated(&self, &Self::Partial) -> bool
576+
577+
pub fn vortex_array::aggregate_fn::fns::max::Max::partial_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
578+
579+
pub fn vortex_array::aggregate_fn::fns::max::Max::reset(&self, &mut Self::Partial)
580+
581+
pub fn vortex_array::aggregate_fn::fns::max::Max::return_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
582+
583+
pub fn vortex_array::aggregate_fn::fns::max::Max::serialize(&self, &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
584+
585+
pub fn vortex_array::aggregate_fn::fns::max::Max::to_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
586+
587+
pub fn vortex_array::aggregate_fn::fns::max::Max::try_accumulate(&self, &mut Self::Partial, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<bool>
588+
589+
pub struct vortex_array::aggregate_fn::fns::max::MaxPartial
590+
541591
pub mod vortex_array::aggregate_fn::fns::mean
542592

543593
pub struct vortex_array::aggregate_fn::fns::mean::Mean
@@ -586,6 +636,56 @@ pub fn vortex_array::aggregate_fn::fns::mean::Mean::serialize(&self, &vortex_arr
586636

587637
pub fn vortex_array::aggregate_fn::fns::mean::mean(&vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
588638

639+
pub mod vortex_array::aggregate_fn::fns::min
640+
641+
pub struct vortex_array::aggregate_fn::fns::min::Min
642+
643+
impl core::clone::Clone for vortex_array::aggregate_fn::fns::min::Min
644+
645+
pub fn vortex_array::aggregate_fn::fns::min::Min::clone(&self) -> vortex_array::aggregate_fn::fns::min::Min
646+
647+
impl core::fmt::Debug for vortex_array::aggregate_fn::fns::min::Min
648+
649+
pub fn vortex_array::aggregate_fn::fns::min::Min::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result
650+
651+
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::min::Min
652+
653+
pub type vortex_array::aggregate_fn::fns::min::Min::Options = vortex_array::aggregate_fn::EmptyOptions
654+
655+
pub type vortex_array::aggregate_fn::fns::min::Min::Partial = vortex_array::aggregate_fn::fns::min::MinPartial
656+
657+
pub fn vortex_array::aggregate_fn::fns::min::Min::accumulate(&self, &mut Self::Partial, &vortex_array::Columnar, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
658+
659+
pub fn vortex_array::aggregate_fn::fns::min::Min::coerce_args(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
660+
661+
pub fn vortex_array::aggregate_fn::fns::min::Min::combine_partials(&self, &mut Self::Partial, vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
662+
663+
pub fn vortex_array::aggregate_fn::fns::min::Min::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
664+
665+
pub fn vortex_array::aggregate_fn::fns::min::Min::empty_partial(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
666+
667+
pub fn vortex_array::aggregate_fn::fns::min::Min::finalize(&self, vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
668+
669+
pub fn vortex_array::aggregate_fn::fns::min::Min::finalize_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
670+
671+
pub fn vortex_array::aggregate_fn::fns::min::Min::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
672+
673+
pub fn vortex_array::aggregate_fn::fns::min::Min::is_saturated(&self, &Self::Partial) -> bool
674+
675+
pub fn vortex_array::aggregate_fn::fns::min::Min::partial_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
676+
677+
pub fn vortex_array::aggregate_fn::fns::min::Min::reset(&self, &mut Self::Partial)
678+
679+
pub fn vortex_array::aggregate_fn::fns::min::Min::return_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
680+
681+
pub fn vortex_array::aggregate_fn::fns::min::Min::serialize(&self, &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
682+
683+
pub fn vortex_array::aggregate_fn::fns::min::Min::to_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
684+
685+
pub fn vortex_array::aggregate_fn::fns::min::Min::try_accumulate(&self, &mut Self::Partial, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<bool>
686+
687+
pub struct vortex_array::aggregate_fn::fns::min::MinPartial
688+
589689
pub mod vortex_array::aggregate_fn::fns::min_max
590690

591691
pub struct vortex_array::aggregate_fn::fns::min_max::MinMax
@@ -1354,6 +1454,78 @@ pub fn vortex_array::aggregate_fn::fns::last::Last::to_scalar(&self, &Self::Part
13541454

13551455
pub fn vortex_array::aggregate_fn::fns::last::Last::try_accumulate(&self, &mut Self::Partial, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<bool>
13561456

1457+
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::max::Max
1458+
1459+
pub type vortex_array::aggregate_fn::fns::max::Max::Options = vortex_array::aggregate_fn::EmptyOptions
1460+
1461+
pub type vortex_array::aggregate_fn::fns::max::Max::Partial = vortex_array::aggregate_fn::fns::max::MaxPartial
1462+
1463+
pub fn vortex_array::aggregate_fn::fns::max::Max::accumulate(&self, &mut Self::Partial, &vortex_array::Columnar, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
1464+
1465+
pub fn vortex_array::aggregate_fn::fns::max::Max::coerce_args(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
1466+
1467+
pub fn vortex_array::aggregate_fn::fns::max::Max::combine_partials(&self, &mut Self::Partial, vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
1468+
1469+
pub fn vortex_array::aggregate_fn::fns::max::Max::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
1470+
1471+
pub fn vortex_array::aggregate_fn::fns::max::Max::empty_partial(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
1472+
1473+
pub fn vortex_array::aggregate_fn::fns::max::Max::finalize(&self, vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1474+
1475+
pub fn vortex_array::aggregate_fn::fns::max::Max::finalize_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
1476+
1477+
pub fn vortex_array::aggregate_fn::fns::max::Max::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
1478+
1479+
pub fn vortex_array::aggregate_fn::fns::max::Max::is_saturated(&self, &Self::Partial) -> bool
1480+
1481+
pub fn vortex_array::aggregate_fn::fns::max::Max::partial_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
1482+
1483+
pub fn vortex_array::aggregate_fn::fns::max::Max::reset(&self, &mut Self::Partial)
1484+
1485+
pub fn vortex_array::aggregate_fn::fns::max::Max::return_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
1486+
1487+
pub fn vortex_array::aggregate_fn::fns::max::Max::serialize(&self, &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
1488+
1489+
pub fn vortex_array::aggregate_fn::fns::max::Max::to_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
1490+
1491+
pub fn vortex_array::aggregate_fn::fns::max::Max::try_accumulate(&self, &mut Self::Partial, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<bool>
1492+
1493+
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::min::Min
1494+
1495+
pub type vortex_array::aggregate_fn::fns::min::Min::Options = vortex_array::aggregate_fn::EmptyOptions
1496+
1497+
pub type vortex_array::aggregate_fn::fns::min::Min::Partial = vortex_array::aggregate_fn::fns::min::MinPartial
1498+
1499+
pub fn vortex_array::aggregate_fn::fns::min::Min::accumulate(&self, &mut Self::Partial, &vortex_array::Columnar, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<()>
1500+
1501+
pub fn vortex_array::aggregate_fn::fns::min::Min::coerce_args(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<vortex_array::dtype::DType>
1502+
1503+
pub fn vortex_array::aggregate_fn::fns::min::Min::combine_partials(&self, &mut Self::Partial, vortex_array::scalar::Scalar) -> vortex_error::VortexResult<()>
1504+
1505+
pub fn vortex_array::aggregate_fn::fns::min::Min::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult<Self::Options>
1506+
1507+
pub fn vortex_array::aggregate_fn::fns::min::Min::empty_partial(&self, &Self::Options, &vortex_array::dtype::DType) -> vortex_error::VortexResult<Self::Partial>
1508+
1509+
pub fn vortex_array::aggregate_fn::fns::min::Min::finalize(&self, vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1510+
1511+
pub fn vortex_array::aggregate_fn::fns::min::Min::finalize_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
1512+
1513+
pub fn vortex_array::aggregate_fn::fns::min::Min::id(&self) -> vortex_array::aggregate_fn::AggregateFnId
1514+
1515+
pub fn vortex_array::aggregate_fn::fns::min::Min::is_saturated(&self, &Self::Partial) -> bool
1516+
1517+
pub fn vortex_array::aggregate_fn::fns::min::Min::partial_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
1518+
1519+
pub fn vortex_array::aggregate_fn::fns::min::Min::reset(&self, &mut Self::Partial)
1520+
1521+
pub fn vortex_array::aggregate_fn::fns::min::Min::return_dtype(&self, &Self::Options, &vortex_array::dtype::DType) -> core::option::Option<vortex_array::dtype::DType>
1522+
1523+
pub fn vortex_array::aggregate_fn::fns::min::Min::serialize(&self, &Self::Options) -> vortex_error::VortexResult<core::option::Option<alloc::vec::Vec<u8>>>
1524+
1525+
pub fn vortex_array::aggregate_fn::fns::min::Min::to_scalar(&self, &Self::Partial) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
1526+
1527+
pub fn vortex_array::aggregate_fn::fns::min::Min::try_accumulate(&self, &mut Self::Partial, &vortex_array::ArrayRef, &mut vortex_array::ExecutionCtx) -> vortex_error::VortexResult<bool>
1528+
13571529
impl vortex_array::aggregate_fn::AggregateFnVTable for vortex_array::aggregate_fn::fns::min_max::MinMax
13581530

13591531
pub type vortex_array::aggregate_fn::fns::min_max::MinMax::Options = vortex_array::aggregate_fn::EmptyOptions
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexExpect;
5+
use vortex_error::VortexResult;
6+
7+
use crate::ArrayRef;
8+
use crate::Columnar;
9+
use crate::ExecutionCtx;
10+
use crate::IntoArray;
11+
use crate::aggregate_fn::AggregateFnId;
12+
use crate::aggregate_fn::AggregateFnVTable;
13+
use crate::aggregate_fn::EmptyOptions;
14+
use crate::aggregate_fn::fns::min_max::MinMax;
15+
use crate::aggregate_fn::fns::min_max::min_max;
16+
use crate::dtype::DType;
17+
use crate::partial_ord::partial_max;
18+
use crate::scalar::Scalar;
19+
20+
/// Aggregate function that yields the largest non-null value of its input.
///
/// This is a stateless marker type; the running state lives in [`MaxPartial`].
#[derive(Debug, Clone)]
pub struct Max;
23+
24+
/// Partial accumulator state for the maximum aggregate.
25+
pub struct MaxPartial {
26+
max: Option<Scalar>,
27+
element_dtype: DType,
28+
}
29+
30+
impl MaxPartial {
31+
fn merge(&mut self, max: Scalar) {
32+
if max.is_null() {
33+
return;
34+
}
35+
36+
self.max = Some(match self.max.take() {
37+
Some(current) => partial_max(max, current).vortex_expect("incomparable max scalars"),
38+
None => max,
39+
});
40+
}
41+
}
42+
43+
impl AggregateFnVTable for Max {
44+
type Options = EmptyOptions;
45+
type Partial = MaxPartial;
46+
47+
fn id(&self) -> AggregateFnId {
48+
AggregateFnId::new("vortex.max")
49+
}
50+
51+
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
52+
Ok(None)
53+
}
54+
55+
fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
56+
MinMax
57+
.return_dtype(&EmptyOptions, input_dtype)
58+
.map(|_| input_dtype.as_nonnullable())
59+
}
60+
61+
fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
62+
self.return_dtype(options, input_dtype)
63+
}
64+
65+
fn empty_partial(
66+
&self,
67+
_options: &Self::Options,
68+
input_dtype: &DType,
69+
) -> VortexResult<Self::Partial> {
70+
Ok(MaxPartial {
71+
max: None,
72+
element_dtype: input_dtype.clone(),
73+
})
74+
}
75+
76+
fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
77+
partial.merge(other);
78+
Ok(())
79+
}
80+
81+
fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
82+
Ok(partial
83+
.max
84+
.clone()
85+
.unwrap_or_else(|| Scalar::null(partial.element_dtype.as_nonnullable())))
86+
}
87+
88+
fn reset(&self, partial: &mut Self::Partial) {
89+
partial.max = None;
90+
}
91+
92+
fn is_saturated(&self, _partial: &Self::Partial) -> bool {
93+
false
94+
}
95+
96+
fn accumulate(
97+
&self,
98+
partial: &mut Self::Partial,
99+
batch: &Columnar,
100+
ctx: &mut ExecutionCtx,
101+
) -> VortexResult<()> {
102+
// Delegate to the existing min_max implementation for now. A dedicated max aggregate
103+
// would avoid computing min when only max is needed.
104+
let array = match batch {
105+
Columnar::Canonical(canonical) => canonical.clone().into_array(),
106+
Columnar::Constant(constant) => constant.clone().into_array(),
107+
};
108+
if let Some(result) = min_max(&array, ctx)? {
109+
partial.merge(result.max);
110+
}
111+
Ok(())
112+
}
113+
114+
fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
115+
Ok(partials)
116+
}
117+
118+
fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
119+
self.to_scalar(partial)
120+
}
121+
}
122+
123+
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::IntoArray as _;
    use crate::LEGACY_SESSION;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::max::Max;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Accumulating two non-nullable i32 batches yields the largest value across both.
    #[test]
    fn max_aggregate_fn() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let mut acc = Accumulator::try_new(
            Max,
            EmptyOptions,
            DType::Primitive(PType::I32, Nullability::NonNullable),
        )?;

        // The overall maximum (25) sits in the second batch, so the result depends on
        // state carried across batches.
        for values in [buffer![10i32, 20, 5], buffer![3i32, 25]] {
            let batch = PrimitiveArray::new(values, Validity::NonNullable).into_array();
            acc.accumulate(&batch, &mut ctx)?;
        }

        assert_eq!(acc.finish()?, Scalar::from(25i32));
        Ok(())
    }
}

0 commit comments

Comments
 (0)