Skip to content

Commit 51b1a4c

Browse files
committed
Support FixedSizeList in list_length; use AnyList matcher
Signed-off-by: Matt Katz <mhkatz97@gmail.com>
1 parent f6e936a commit 51b1a4c

2 files changed

Lines changed: 124 additions & 32 deletions

File tree

vortex-array/src/expr/exprs.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -754,11 +754,9 @@ pub fn ext_storage(input: Expression) -> Expression {
754754

755755
// ---- ListLength ----
756756

757-
/// Creates an expression that computes the number of elements in each list.
758-
///
759-
/// This is akin to ANSI SQL `CARDINALITY()`, or DuckDB's `len()`/`array_length()`. The result is
760-
/// a `U64` array; a null list yields a null length. It reads only the list's offsets, not the
761-
/// element values.
757+
/// Creates an expression that computes the number of elements in each list
758+
/// for `List` and `FixedSizeList` inputs. This is akin to ANSI SQL `CARDINALITY()`,
759+
/// or DuckDB's `len()`/`array_length()`.
762760
///
763761
/// ```rust
764762
/// # use vortex_array::expr::{list_length, root};

vortex-array/src/scalar_fn/fns/list_length.rs

Lines changed: 121 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,37 @@ use vortex_session::registry::CachedId;
1010
use crate::ArrayRef;
1111
use crate::ExecutionCtx;
1212
use crate::IntoArray;
13+
use crate::array::ArrayView;
1314
use crate::arrays::ConstantArray;
1415
use crate::arrays::List;
1516
use crate::arrays::ListView;
16-
use crate::arrays::ListViewArray;
1717
use crate::arrays::list::ListArrayExt;
1818
use crate::arrays::listview::ListViewArrayExt;
1919
use crate::builtins::ArrayBuiltins;
2020
use crate::dtype::DType;
2121
use crate::dtype::Nullability;
2222
use crate::dtype::PType;
2323
use crate::expr::Expression;
24+
use crate::matcher::Matcher;
2425
use crate::scalar::Scalar;
2526
use crate::scalar_fn::Arity;
2627
use crate::scalar_fn::ChildName;
2728
use crate::scalar_fn::EmptyOptions;
2829
use crate::scalar_fn::ExecutionArgs;
30+
use crate::scalar_fn::ReduceCtx;
31+
use crate::scalar_fn::ReduceNode;
32+
use crate::scalar_fn::ReduceNodeRef;
2933
use crate::scalar_fn::ScalarFnId;
3034
use crate::scalar_fn::ScalarFnVTable;
35+
use crate::scalar_fn::ScalarFnVTableExt;
36+
use crate::scalar_fn::fns::literal::Literal;
3137
use crate::scalar_fn::fns::operators::Operator;
3238

33-
/// Number of elements in each list of a `List` typed array.
39+
/// Number of elements in each list of a `List` or `FixedSizeList` typed array.
3440
///
35-
/// This is computed purely from the list's offsets/sizes (the per-list length), never reading the
36-
/// element *values*. Validity is carried over from the original array.
41+
/// This is computed purely from the list's offsets (`ListArray`), sizes (`ListViewArray`), or
42+
/// dtype (`FixedSizeListArray`) without reading the element *values*. Validity is carried over
43+
/// from the original array.
3744
#[derive(Clone)]
3845
pub struct ListLength;
3946

@@ -70,8 +77,10 @@ impl ScalarFnVTable for ListLength {
7077

7178
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
7279
match &arg_dtypes[0] {
73-
DType::List(_, nullable) => Ok(DType::Primitive(PType::U64, *nullable)),
74-
other => vortex_bail!("list_length() requires List, got {other}"),
80+
DType::List(_, nullable) | DType::FixedSizeList(_, _, nullable) => {
81+
Ok(DType::Primitive(PType::U64, *nullable))
82+
}
83+
other => vortex_bail!("list_length() requires List or FixedSizeList, got {other}"),
7584
}
7685
}
7786

@@ -89,10 +98,23 @@ impl ScalarFnVTable for ListLength {
8998
return Ok(ConstantArray::new(len_scalar, args.row_count()).into_array());
9099
}
91100

92-
match input.dtype() {
93-
DType::List(..) => list_length(&input, nullability, ctx),
94-
other => vortex_bail!("list_length() requires List, got {other}"),
101+
list_length(&input, nullability, ctx)
102+
}
103+
104+
/// Rewrites `list_length` over a non-nullable `FixedSizeList` child into a literal node.
///
/// When the dtype of child 0 is `DType::FixedSizeList` with `Nullability::NonNullable`, returns
/// a new node that binds `Literal` to the dtype's `size` as a non-nullable `u64` scalar, with no
/// children. For any other child dtype, returns `Ok(None)`.
///
/// Errors from `node_dtype()` and `ctx.new_node(..)` are propagated to the caller.
fn reduce(
105+
&self,
106+
_options: &Self::Options,
107+
node: &dyn ReduceNode,
108+
ctx: &dyn ReduceCtx,
109+
) -> VortexResult<Option<ReduceNodeRef>> {
110+
// Every element of a non-nullable fixed-size list has the same length, namely the `size`
// recorded in the dtype, so the result is known from the dtype alone. Nullable fixed-size
// lists do not match this pattern and fall through to `Ok(None)`.
111+
if let DType::FixedSizeList(_, size, Nullability::NonNullable) =
112+
node.child(0).node_dtype()?
113+
{
114+
let length = Scalar::primitive(size as u64, Nullability::NonNullable);
115+
return Ok(Some(ctx.new_node(Literal.bind(length), &[])?));
95116
}
117+
Ok(None)
96118
}
97119

98120
fn validity(
@@ -126,26 +148,33 @@ pub(crate) fn list_length(
126148
nullability: Nullability,
127149
ctx: &mut ExecutionCtx,
128150
) -> VortexResult<ArrayRef> {
129-
// TODO(mk): short-circuit when array is all null
130-
131-
let (lengths, validity) = if let Some(list) = array.as_opt::<ListView>() {
132-
// Length array is exactly size child
133-
(list.sizes().clone(), list.listview_validity())
134-
} else if let Some(list) = array.as_opt::<List>() {
135-
// `length[i] = offsets[i + 1] - offsets[i]`
136-
let offsets = list.offsets();
137-
let n = offsets.len().saturating_sub(1);
138-
let lengths = offsets
139-
.slice(1..offsets.len())?
140-
.binary(offsets.slice(0..n)?, Operator::Sub)?;
141-
(lengths, list.list_validity())
142-
} else {
143-
// Otherwise execute to `ListViewArray` and take size child
144-
let list = array.clone().execute::<ListViewArray>(ctx)?;
145-
(list.sizes().clone(), list.listview_validity())
151+
let (lengths, validity) = match array.dtype() {
152+
// The length of a fixed-size list is constant, so we only need to carry over the validity
153+
DType::FixedSizeList(_, size, _) => {
154+
let lengths = ConstantArray::new(
155+
Scalar::primitive(*size as u64, Nullability::NonNullable),
156+
array.len(),
157+
)
158+
.into_array();
159+
(lengths, array.validity()?)
160+
}
161+
DType::List(..) => {
162+
let list = array.clone().execute_until::<AnyList>(ctx)?;
163+
164+
if let Some(list) = list.as_opt::<List>() {
165+
let lengths = list_length_from_offsets(list)?;
166+
(lengths, list.list_validity())
167+
} else if let Some(list_view) = list.as_opt::<ListView>() {
168+
// Length array is exactly the sizes child
169+
(list_view.sizes().clone(), list_view.listview_validity())
170+
} else {
171+
unreachable!("AnyList matcher guarantees List or ListView")
172+
}
173+
}
174+
other => vortex_bail!("list_length() requires List or FixedSizeList, got {other}"),
146175
};
147176

148-
// Cast to the declared `U64` result
177+
// Cast to `U64`
149178
let len = lengths.len();
150179
let lengths = lengths.cast(DType::Primitive(PType::U64, nullability))?;
151180

@@ -157,6 +186,28 @@ pub(crate) fn list_length(
157186
}
158187
}
159188

189+
/// Calculates the per-list lengths of a `ListArray` from its `offsets` child:
190+
/// `length[i] = offsets[i + 1] - offsets[i]`.
///
/// The result is `offsets[1..]` minus `offsets[..n]`, combined with `Operator::Sub`, where
/// `n = offsets.len() - 1`. Errors from slicing or from the binary operation are propagated.
191+
fn list_length_from_offsets(list: ArrayView<'_, List>) -> VortexResult<ArrayRef> {
192+
let offsets = list.offsets();
193+
// One fewer than the number of offsets; `saturating_sub` avoids underflow when
// `offsets.len()` is 0.
let n = offsets.len().saturating_sub(1);
194+
195+
// Left operand: offsets[1..len]; right operand: offsets[0..n].
offsets
196+
.slice(1..offsets.len())?
197+
.binary(offsets.slice(0..n)?, Operator::Sub)
198+
}
199+
200+
/// Matches an `Array<List>` or `Array<ListView>`.
///
/// The match carries no payload (`Match` is `()`); it only reports whether the array is one of
/// the two encodings. `list_length` uses it as the target of `execute_until` and then
/// re-inspects the array with `as_opt`.
201+
struct AnyList;
202+
203+
impl Matcher for AnyList {
204+
// No data is extracted on a match; only the success/failure of the match is used.
type Match<'a> = ();
205+
206+
/// Returns `Some(())` when `array` is a `List` or a `ListView`, and `None` otherwise.
fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
207+
(array.as_opt::<List>().is_some() || array.as_opt::<ListView>().is_some()).then_some(())
208+
}
209+
}
210+
160211
#[cfg(test)]
161212
mod tests {
162213
use std::sync::Arc;
@@ -171,16 +222,20 @@ mod tests {
171222
use crate::VortexSessionExecute;
172223
use crate::arrays::BoolArray;
173224
use crate::arrays::ConstantArray;
225+
use crate::arrays::FixedSizeListArray;
174226
use crate::arrays::ListArray;
175227
use crate::arrays::ListViewArray;
176228
use crate::arrays::PrimitiveArray;
229+
use crate::arrays::ScalarFn;
230+
use crate::arrays::scalar_fn::ScalarFnArrayExt;
177231
use crate::assert_arrays_eq;
178232
use crate::dtype::DType;
179233
use crate::dtype::Nullability;
180234
use crate::dtype::PType;
181235
use crate::expr::list_length;
182236
use crate::expr::root;
183237
use crate::scalar::Scalar;
238+
use crate::scalar_fn::fns::literal::Literal;
184239
use crate::validity::Validity;
185240

186241
fn create_list_elements() -> ArrayRef {
@@ -299,6 +354,45 @@ mod tests {
299354
Ok(())
300355
}
301356

357+
/// Test fixture: builds a `FixedSizeListArray` over the `i32` elements `1..=8` with the given
/// `validity`, and returns it as an `ArrayRef`.
fn create_fixed_size_list(validity: Validity) -> ArrayRef {
358+
// 4 lists of size 2 over 8 primitive elements.
359+
let elements = PrimitiveArray::from_iter([1i32, 2, 3, 4, 5, 6, 7, 8]).into_array();
360+
FixedSizeListArray::new(elements, 2, validity, 4).into_array()
361+
}
362+
363+
/// Applying `list_length` to a non-nullable fixed-size list must produce a `ScalarFn` array
/// whose function is `Literal`, and its values must equal `[2, 2, 2, 2]` as `u64`.
#[test]
364+
fn test_fixed_size_list_length() -> VortexResult<()> {
365+
let fsl = create_fixed_size_list(Validity::NonNullable);
366+
let result = fsl.apply(&list_length(root()))?;
367+
368+
// A non-nullable fixed-size list reduces to a constant literal length, never touching the
369+
// `ListLength` execution path.
370+
assert!(
371+
result
372+
.as_opt::<ScalarFn>()
373+
.is_some_and(|f| f.scalar_fn().as_opt::<Literal>().is_some()),
374+
"list_length over a non-nullable FixedSizeList must reduce to a constant literal"
375+
);
376+
// Every list in the fixture has size 2.
assert_arrays_eq!(result, PrimitiveArray::from_iter([2u64, 2, 2, 2]));
377+
Ok(())
378+
}
379+
380+
/// Applying `list_length` to a fixed-size list with validity `[true, false, true, false]` must
/// yield `[Some(2), None, Some(2), None]` as `u64` once executed to a `PrimitiveArray`.
#[test]
381+
fn test_fixed_size_list_length_nullable() -> VortexResult<()> {
382+
// Lists at positions 1 and 3 are marked invalid.
let fsl = create_fixed_size_list(Validity::Array(
383+
BoolArray::from_iter([true, false, true, false]).into_array(),
384+
));
385+
let result = fsl.apply(&list_length(root()))?;
386+
387+
// Execute the applied expression to a `PrimitiveArray` so it can be compared directly.
let session = VortexSession::empty();
388+
let mut ctx = session.create_execution_ctx();
389+
let result = result.execute::<PrimitiveArray>(&mut ctx)?;
390+
391+
// Valid positions report the fixed size 2; invalid positions are null.
let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(2), None]);
392+
assert_arrays_eq!(result, expected);
393+
Ok(())
394+
}
395+
302396
#[test]
303397
fn test_display() {
304398
let expr = list_length(root());

0 commit comments

Comments
 (0)