17
17
18
18
//! Physical expression schema rewriting utilities
19
19
20
- use std:: sync:: Arc ;
21
20
use std:: cmp:: Ordering ;
21
+ use std:: sync:: Arc ;
22
22
23
23
use arrow:: compute:: can_cast_types;
24
24
use arrow:: datatypes:: {
@@ -230,7 +230,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
230
230
left. as_any ( ) . downcast_ref :: < CastExpr > ( ) ,
231
231
right. as_any ( ) . downcast_ref :: < Literal > ( ) ,
232
232
) {
233
- if let Some ( optimized) = self . unwrap_cast_with_literal ( cast_expr, literal, * op) ? {
233
+ if let Some ( optimized) =
234
+ self . unwrap_cast_with_literal ( cast_expr, literal, * op) ?
235
+ {
234
236
return Ok ( Some ( Arc :: new ( BinaryExpr :: new (
235
237
optimized. 0 ,
236
238
* op,
@@ -244,7 +246,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
244
246
left. as_any ( ) . downcast_ref :: < Literal > ( ) ,
245
247
right. as_any ( ) . downcast_ref :: < CastExpr > ( ) ,
246
248
) {
247
- if let Some ( optimized) = self . unwrap_cast_with_literal ( cast_expr, literal, * op) ? {
249
+ if let Some ( optimized) =
250
+ self . unwrap_cast_with_literal ( cast_expr, literal, * op) ?
251
+ {
248
252
return Ok ( Some ( Arc :: new ( BinaryExpr :: new (
249
253
optimized. 1 ,
250
254
* op,
@@ -265,32 +269,36 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
265
269
) -> Result < Option < ( Arc < dyn PhysicalExpr > , Arc < dyn PhysicalExpr > ) > > {
266
270
// Get the inner expression (what's being cast)
267
271
let inner_expr = cast_expr. expr ( ) ;
268
-
272
+
269
273
// Handle the case where inner expression might be another cast (due to schema rewriting)
270
274
// This can happen when the schema rewriter adds a cast to a column, and then we have
271
275
// an original cast on top of that.
272
- let ( final_inner_expr, column) = if let Some ( inner_cast) = inner_expr. as_any ( ) . downcast_ref :: < CastExpr > ( ) {
273
- // We have a nested cast, check if the inner cast's expression is a column
274
- let inner_inner_expr = inner_cast. expr ( ) ;
275
- if let Some ( col) = inner_inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
276
- ( inner_inner_expr, col)
276
+ let ( final_inner_expr, column) =
277
+ if let Some ( inner_cast) = inner_expr. as_any ( ) . downcast_ref :: < CastExpr > ( ) {
278
+ // We have a nested cast, check if the inner cast's expression is a column
279
+ let inner_inner_expr = inner_cast. expr ( ) ;
280
+ if let Some ( col) = inner_inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
281
+ ( inner_inner_expr, col)
282
+ } else {
283
+ return Ok ( None ) ;
284
+ }
285
+ } else if let Some ( col) = inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
286
+ ( inner_expr, col)
277
287
} else {
278
288
return Ok ( None ) ;
279
- }
280
- } else if let Some ( col) = inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
281
- ( inner_expr, col)
282
- } else {
283
- return Ok ( None ) ;
284
- } ;
289
+ } ;
285
290
286
291
// Get the column's data type from the physical schema
287
- let column_data_type = match self . physical_file_schema . field_with_name ( column. name ( ) ) {
288
- Ok ( field) => field. data_type ( ) ,
289
- Err ( _) => return Ok ( None ) , // Column not found, can't optimize
290
- } ;
292
+ let column_data_type =
293
+ match self . physical_file_schema . field_with_name ( column. name ( ) ) {
294
+ Ok ( field) => field. data_type ( ) ,
295
+ Err ( _) => return Ok ( None ) , // Column not found, can't optimize
296
+ } ;
291
297
292
298
// Try to cast the literal to the column's data type
293
- if let Some ( casted_literal) = try_cast_literal_to_type ( literal. value ( ) , column_data_type, op) {
299
+ if let Some ( casted_literal) =
300
+ try_cast_literal_to_type ( literal. value ( ) , column_data_type, op)
301
+ {
294
302
return Ok ( Some ( (
295
303
Arc :: clone ( final_inner_expr) ,
296
304
expressions:: lit ( casted_literal) ,
@@ -323,7 +331,6 @@ fn cast_literal_to_type_with_op(
323
331
target_type : & DataType ,
324
332
op : Operator ,
325
333
) -> Option < ScalarValue > {
326
-
327
334
match ( op, lit_value) {
328
335
(
329
336
Operator :: Eq | Operator :: NotEq ,
@@ -754,22 +761,27 @@ mod tests {
754
761
let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
755
762
let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Int64 , None ) ) ;
756
763
let literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ;
757
- let binary_expr = Arc :: new ( BinaryExpr :: new (
758
- cast_expr,
759
- Operator :: Eq ,
760
- literal_expr,
761
- ) ) ;
764
+ let binary_expr =
765
+ Arc :: new ( BinaryExpr :: new ( cast_expr, Operator :: Eq , literal_expr) ) ;
762
766
763
767
let result = rewriter. rewrite ( binary_expr. clone ( ) as Arc < dyn PhysicalExpr > ) ?;
764
768
765
769
// The result should be a binary expression with the cast unwrapped
766
770
let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
767
-
771
+
768
772
// Left side should be the original column (no cast)
769
- assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ;
770
-
773
+ assert ! ( result_binary
774
+ . left( )
775
+ . as_any( )
776
+ . downcast_ref:: <Column >( )
777
+ . is_some( ) ) ;
778
+
771
779
// Right side should be a literal with the value cast to Int32
772
- let right_literal = result_binary. right ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ;
780
+ let right_literal = result_binary
781
+ . right ( )
782
+ . as_any ( )
783
+ . downcast_ref :: < Literal > ( )
784
+ . unwrap ( ) ;
773
785
assert_eq ! ( * right_literal. value( ) , ScalarValue :: Int32 ( Some ( 123 ) ) ) ;
774
786
775
787
Ok ( ( ) )
@@ -787,23 +799,28 @@ mod tests {
787
799
let literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ;
788
800
let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
789
801
let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Int64 , None ) ) ;
790
- let binary_expr = Arc :: new ( BinaryExpr :: new (
791
- literal_expr,
792
- Operator :: Eq ,
793
- cast_expr,
794
- ) ) ;
802
+ let binary_expr =
803
+ Arc :: new ( BinaryExpr :: new ( literal_expr, Operator :: Eq , cast_expr) ) ;
795
804
796
805
let result = rewriter. rewrite ( binary_expr) ?;
797
806
798
807
// The result should be a binary expression with the cast unwrapped
799
808
let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
800
-
809
+
801
810
// Left side should be a literal with the value cast to Int32
802
- let left_literal = result_binary. left ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ;
811
+ let left_literal = result_binary
812
+ . left ( )
813
+ . as_any ( )
814
+ . downcast_ref :: < Literal > ( )
815
+ . unwrap ( ) ;
803
816
assert_eq ! ( * left_literal. value( ) , ScalarValue :: Int32 ( Some ( 123 ) ) ) ;
804
-
817
+
805
818
// Right side should be the original column (no cast)
806
- assert ! ( result_binary. right( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ;
819
+ assert ! ( result_binary
820
+ . right( )
821
+ . as_any( )
822
+ . downcast_ref:: <Column >( )
823
+ . is_some( ) ) ;
807
824
808
825
Ok ( ( ) )
809
826
}
@@ -820,22 +837,27 @@ mod tests {
820
837
let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
821
838
let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Utf8 , None ) ) ;
822
839
let literal_expr = expressions:: lit ( ScalarValue :: Utf8 ( Some ( "123" . to_string ( ) ) ) ) ;
823
- let binary_expr = Arc :: new ( BinaryExpr :: new (
824
- cast_expr,
825
- Operator :: Eq ,
826
- literal_expr,
827
- ) ) ;
840
+ let binary_expr =
841
+ Arc :: new ( BinaryExpr :: new ( cast_expr, Operator :: Eq , literal_expr) ) ;
828
842
829
843
let result = rewriter. rewrite ( binary_expr) ?;
830
844
831
845
// The result should be a binary expression with the cast unwrapped
832
846
let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
833
-
847
+
834
848
// Left side should be the original column (no cast)
835
- assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ;
836
-
849
+ assert ! ( result_binary
850
+ . left( )
851
+ . as_any( )
852
+ . downcast_ref:: <Column >( )
853
+ . is_some( ) ) ;
854
+
837
855
// Right side should be a literal with the value cast to Int32
838
- let right_literal = result_binary. right ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ;
856
+ let right_literal = result_binary
857
+ . right ( )
858
+ . as_any ( )
859
+ . downcast_ref :: < Literal > ( )
860
+ . unwrap ( ) ;
839
861
assert_eq ! ( * right_literal. value( ) , ScalarValue :: Int32 ( Some ( 123 ) ) ) ;
840
862
841
863
Ok ( ( ) )
@@ -844,7 +866,8 @@ mod tests {
844
866
#[ test]
845
867
fn test_no_unwrap_cast_optimization_when_not_applicable ( ) -> Result < ( ) > {
846
868
// Test case where optimization should not apply - unsupported cast
847
- let physical_schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float32 , false ) ] ) ;
869
+ let physical_schema =
870
+ Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float32 , false ) ] ) ;
848
871
let logical_schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int64 , false ) ] ) ;
849
872
850
873
let rewriter = PhysicalExprSchemaRewriter :: new ( & physical_schema, & logical_schema) ;
@@ -854,18 +877,19 @@ mod tests {
854
877
let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
855
878
let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Int64 , None ) ) ;
856
879
let literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ;
857
- let binary_expr = Arc :: new ( BinaryExpr :: new (
858
- cast_expr,
859
- Operator :: Eq ,
860
- literal_expr,
861
- ) ) ;
880
+ let binary_expr =
881
+ Arc :: new ( BinaryExpr :: new ( cast_expr, Operator :: Eq , literal_expr) ) ;
862
882
863
883
let result = rewriter. rewrite ( binary_expr) ?;
864
884
865
885
// The result should still be a binary expression with a cast on the left side
866
886
// since Float32 is not in our supported types for unwrap cast optimization
867
887
let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
868
- assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <CastExpr >( ) . is_some( ) ) ;
888
+ assert ! ( result_binary
889
+ . left( )
890
+ . as_any( )
891
+ . downcast_ref:: <CastExpr >( )
892
+ . is_some( ) ) ;
869
893
870
894
Ok ( ( ) )
871
895
}
0 commit comments