@@ -10,30 +10,37 @@ use vortex_session::registry::CachedId;
1010use crate :: ArrayRef ;
1111use crate :: ExecutionCtx ;
1212use crate :: IntoArray ;
13+ use crate :: array:: ArrayView ;
1314use crate :: arrays:: ConstantArray ;
1415use crate :: arrays:: List ;
1516use crate :: arrays:: ListView ;
16- use crate :: arrays:: ListViewArray ;
1717use crate :: arrays:: list:: ListArrayExt ;
1818use crate :: arrays:: listview:: ListViewArrayExt ;
1919use crate :: builtins:: ArrayBuiltins ;
2020use crate :: dtype:: DType ;
2121use crate :: dtype:: Nullability ;
2222use crate :: dtype:: PType ;
2323use crate :: expr:: Expression ;
24+ use crate :: matcher:: Matcher ;
2425use crate :: scalar:: Scalar ;
2526use crate :: scalar_fn:: Arity ;
2627use crate :: scalar_fn:: ChildName ;
2728use crate :: scalar_fn:: EmptyOptions ;
2829use crate :: scalar_fn:: ExecutionArgs ;
30+ use crate :: scalar_fn:: ReduceCtx ;
31+ use crate :: scalar_fn:: ReduceNode ;
32+ use crate :: scalar_fn:: ReduceNodeRef ;
2933use crate :: scalar_fn:: ScalarFnId ;
3034use crate :: scalar_fn:: ScalarFnVTable ;
35+ use crate :: scalar_fn:: ScalarFnVTableExt ;
36+ use crate :: scalar_fn:: fns:: literal:: Literal ;
3137use crate :: scalar_fn:: fns:: operators:: Operator ;
3238
/// Scalar function yielding the element count of every list in a `List` or `FixedSizeList`
/// typed array.
///
/// Only list metadata is consulted: the offsets of a `ListArray`, the sizes of a
/// `ListViewArray`, or the dtype of a `FixedSizeListArray`. The element *values* are never
/// read, and the validity of the input array is carried over to the result.
#[derive(Clone)]
pub struct ListLength;
3946
@@ -70,8 +77,10 @@ impl ScalarFnVTable for ListLength {
7077
7178 fn return_dtype ( & self , _options : & Self :: Options , arg_dtypes : & [ DType ] ) -> VortexResult < DType > {
7279 match & arg_dtypes[ 0 ] {
73- DType :: List ( _, nullable) => Ok ( DType :: Primitive ( PType :: U64 , * nullable) ) ,
74- other => vortex_bail ! ( "list_length() requires List, got {other}" ) ,
80+ DType :: List ( _, nullable) | DType :: FixedSizeList ( _, _, nullable) => {
81+ Ok ( DType :: Primitive ( PType :: U64 , * nullable) )
82+ }
83+ other => vortex_bail ! ( "list_length() requires List or FixedSizeList, got {other}" ) ,
7584 }
7685 }
7786
@@ -89,10 +98,23 @@ impl ScalarFnVTable for ListLength {
8998 return Ok ( ConstantArray :: new ( len_scalar, args. row_count ( ) ) . into_array ( ) ) ;
9099 }
91100
92- match input. dtype ( ) {
93- DType :: List ( ..) => list_length ( & input, nullability, ctx) ,
94- other => vortex_bail ! ( "list_length() requires List, got {other}" ) ,
101+ list_length ( & input, nullability, ctx)
102+ }
103+
104+ fn reduce (
105+ & self ,
106+ _options : & Self :: Options ,
107+ node : & dyn ReduceNode ,
108+ ctx : & dyn ReduceCtx ,
109+ ) -> VortexResult < Option < ReduceNodeRef > > {
110+ // The length of nonnullable fixed-size list is constant
111+ if let DType :: FixedSizeList ( _, size, Nullability :: NonNullable ) =
112+ node. child ( 0 ) . node_dtype ( ) ?
113+ {
114+ let length = Scalar :: primitive ( size as u64 , Nullability :: NonNullable ) ;
115+ return Ok ( Some ( ctx. new_node ( Literal . bind ( length) , & [ ] ) ?) ) ;
95116 }
117+ Ok ( None )
96118 }
97119
98120 fn validity (
@@ -126,26 +148,33 @@ pub(crate) fn list_length(
126148 nullability : Nullability ,
127149 ctx : & mut ExecutionCtx ,
128150) -> VortexResult < ArrayRef > {
129- // TODO(mk): short-circuit when array is all null
130-
131- let ( lengths, validity) = if let Some ( list) = array. as_opt :: < ListView > ( ) {
132- // Length array is exactly size child
133- ( list. sizes ( ) . clone ( ) , list. listview_validity ( ) )
134- } else if let Some ( list) = array. as_opt :: < List > ( ) {
135- // `length[i] = offsets[i + 1] - offsets[i]`
136- let offsets = list. offsets ( ) ;
137- let n = offsets. len ( ) . saturating_sub ( 1 ) ;
138- let lengths = offsets
139- . slice ( 1 ..offsets. len ( ) ) ?
140- . binary ( offsets. slice ( 0 ..n) ?, Operator :: Sub ) ?;
141- ( lengths, list. list_validity ( ) )
142- } else {
143- // Otherwise execute to `ListViewArray` and take size child
144- let list = array. clone ( ) . execute :: < ListViewArray > ( ctx) ?;
145- ( list. sizes ( ) . clone ( ) , list. listview_validity ( ) )
151+ let ( lengths, validity) = match array. dtype ( ) {
152+ // The length of fixed-size list is constant, so just need to carry over validity
153+ DType :: FixedSizeList ( _, size, _) => {
154+ let lengths = ConstantArray :: new (
155+ Scalar :: primitive ( * size as u64 , Nullability :: NonNullable ) ,
156+ array. len ( ) ,
157+ )
158+ . into_array ( ) ;
159+ ( lengths, array. validity ( ) ?)
160+ }
161+ DType :: List ( ..) => {
162+ let list = array. clone ( ) . execute_until :: < AnyList > ( ctx) ?;
163+
164+ if let Some ( list) = list. as_opt :: < List > ( ) {
165+ let lengths = list_length_from_offsets ( list) ?;
166+ ( lengths, list. list_validity ( ) )
167+ } else if let Some ( list_view) = list. as_opt :: < ListView > ( ) {
168+ // Length array is exactly the sizes child
169+ ( list_view. sizes ( ) . clone ( ) , list_view. listview_validity ( ) )
170+ } else {
171+ unreachable ! ( "AnyList matcher guarantees List or ListView" )
172+ }
173+ }
174+ other => vortex_bail ! ( "list_length() requires List or FixedSizeList, got {other}" ) ,
146175 } ;
147176
148- // Cast to the declared `U64` result
177+ // Cast to `U64`
149178 let len = lengths. len ( ) ;
150179 let lengths = lengths. cast ( DType :: Primitive ( PType :: U64 , nullability) ) ?;
151180
@@ -157,6 +186,28 @@ pub(crate) fn list_length(
157186 }
158187}
159188
189+ /// Calculate the lengths of `ListArray` elements via the `offsets` child:
190+ /// `length[i] = offsets[i + 1] - offsets[i]`.
191+ fn list_length_from_offsets ( list : ArrayView < ' _ , List > ) -> VortexResult < ArrayRef > {
192+ let offsets = list. offsets ( ) ;
193+ let n = offsets. len ( ) . saturating_sub ( 1 ) ;
194+
195+ offsets
196+ . slice ( 1 ..offsets. len ( ) ) ?
197+ . binary ( offsets. slice ( 0 ..n) ?, Operator :: Sub )
198+ }
199+
200+ /// Matches an `Array<List>` or `Array<ListView>`.
201+ struct AnyList ;
202+
203+ impl Matcher for AnyList {
204+ type Match < ' a > = ( ) ;
205+
206+ fn try_match ( array : & ArrayRef ) -> Option < Self :: Match < ' _ > > {
207+ ( array. as_opt :: < List > ( ) . is_some ( ) || array. as_opt :: < ListView > ( ) . is_some ( ) ) . then_some ( ( ) )
208+ }
209+ }
210+
160211#[ cfg( test) ]
161212mod tests {
162213 use std:: sync:: Arc ;
@@ -171,16 +222,20 @@ mod tests {
171222 use crate :: VortexSessionExecute ;
172223 use crate :: arrays:: BoolArray ;
173224 use crate :: arrays:: ConstantArray ;
225+ use crate :: arrays:: FixedSizeListArray ;
174226 use crate :: arrays:: ListArray ;
175227 use crate :: arrays:: ListViewArray ;
176228 use crate :: arrays:: PrimitiveArray ;
229+ use crate :: arrays:: ScalarFn ;
230+ use crate :: arrays:: scalar_fn:: ScalarFnArrayExt ;
177231 use crate :: assert_arrays_eq;
178232 use crate :: dtype:: DType ;
179233 use crate :: dtype:: Nullability ;
180234 use crate :: dtype:: PType ;
181235 use crate :: expr:: list_length;
182236 use crate :: expr:: root;
183237 use crate :: scalar:: Scalar ;
238+ use crate :: scalar_fn:: fns:: literal:: Literal ;
184239 use crate :: validity:: Validity ;
185240
186241 fn create_list_elements ( ) -> ArrayRef {
@@ -299,6 +354,45 @@ mod tests {
299354 Ok ( ( ) )
300355 }
301356
357+ fn create_fixed_size_list ( validity : Validity ) -> ArrayRef {
358+ // 4 lists of size 2 over 8 primitive elements.
359+ let elements = PrimitiveArray :: from_iter ( [ 1i32 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] ) . into_array ( ) ;
360+ FixedSizeListArray :: new ( elements, 2 , validity, 4 ) . into_array ( )
361+ }
362+
/// `list_length` over a non-nullable fixed-size list must reduce to a literal and evaluate to
/// the fixed size for every row.
#[test]
fn test_fixed_size_list_length() -> VortexResult<()> {
    let input = create_fixed_size_list(Validity::NonNullable);
    let lengths = input.apply(&list_length(root()))?;

    // The reduction replaces `ListLength` with a constant literal, so the execution path of
    // `ListLength` itself is never exercised here.
    let is_literal = match lengths.as_opt::<ScalarFn>() {
        Some(scalar_fn_array) => scalar_fn_array.scalar_fn().as_opt::<Literal>().is_some(),
        None => false,
    };
    assert!(
        is_literal,
        "list_length over a non-nullable FixedSizeList must reduce to a constant literal"
    );

    assert_arrays_eq!(lengths, PrimitiveArray::from_iter([2u64, 2, 2, 2]));
    Ok(())
}
379+
/// `list_length` over a nullable fixed-size list yields the fixed size for valid rows and null
/// for invalid ones.
#[test]
fn test_fixed_size_list_length_nullable() -> VortexResult<()> {
    let row_validity = BoolArray::from_iter([true, false, true, false]).into_array();
    let input = create_fixed_size_list(Validity::Array(row_validity));
    let lengths = input.apply(&list_length(root()))?;

    let session = VortexSession::empty();
    let mut ctx = session.create_execution_ctx();
    let actual = lengths.execute::<PrimitiveArray>(&mut ctx)?;

    // Rows 1 and 3 are invalid in the input, so their lengths are null.
    let expected = PrimitiveArray::from_option_iter::<u64, _>([Some(2), None, Some(2), None]);
    assert_arrays_eq!(actual, expected);
    Ok(())
}
395+
302396 #[ test]
303397 fn test_display ( ) {
304398 let expr = list_length ( root ( ) ) ;
0 commit comments