@@ -139,6 +139,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
139
139
// If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
140
140
// TODO: do we need to sync this with what the `SchemaAdapter` actually does?
141
141
// While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else!
142
+ // See https://github.com/apache/datafusion/issues/16527
142
143
let null_value =
143
144
ScalarValue :: Null . cast_to ( logical_field. data_type ( ) ) ?;
144
145
return Ok ( Transformed :: yes ( expressions:: lit ( null_value) ) ) ;
@@ -197,9 +198,12 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
197
198
198
199
#[ cfg( test) ]
199
200
mod tests {
201
+ use crate :: expressions:: lit;
202
+
200
203
use super :: * ;
201
204
use arrow:: datatypes:: { DataType , Field , Schema } ;
202
205
use datafusion_common:: ScalarValue ;
206
+ use datafusion_expr:: Operator ;
203
207
use std:: sync:: Arc ;
204
208
205
209
fn create_test_schema ( ) -> ( Schema , Schema ) {
@@ -218,18 +222,68 @@ mod tests {
218
222
}
219
223
220
224
#[ test]
221
- fn test_rewrite_column_with_type_cast ( ) -> Result < ( ) > {
225
+ fn test_rewrite_column_with_type_cast ( ) {
222
226
let ( physical_schema, logical_schema) = create_test_schema ( ) ;
223
227
224
228
let rewriter = PhysicalExprSchemaRewriter :: new ( & physical_schema, & logical_schema) ;
225
229
let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
226
230
227
- let result = rewriter. rewrite ( column_expr) ? ;
231
+ let result = rewriter. rewrite ( column_expr) . unwrap ( ) ;
228
232
229
233
// Should be wrapped in a cast expression
230
234
assert ! ( result. as_any( ) . downcast_ref:: <CastExpr >( ) . is_some( ) ) ;
235
+ }
231
236
232
- Ok ( ( ) )
237
+ #[ test]
238
+ fn test_rewrite_mulit_column_expr_with_type_cast ( ) {
239
+ let ( physical_schema, logical_schema) = create_test_schema ( ) ;
240
+ let rewriter = PhysicalExprSchemaRewriter :: new ( & physical_schema, & logical_schema) ;
241
+
242
+ // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter
243
+ let column_a = Arc :: new ( Column :: new ( "a" , 0 ) ) as Arc < dyn PhysicalExpr > ;
244
+ let column_c = Arc :: new ( Column :: new ( "c" , 2 ) ) as Arc < dyn PhysicalExpr > ;
245
+ let expr = expressions:: BinaryExpr :: new (
246
+ Arc :: clone ( & column_a) ,
247
+ Operator :: Plus ,
248
+ Arc :: new ( expressions:: Literal :: new ( ScalarValue :: Int64 ( Some ( 5 ) ) ) ) ,
249
+ ) ;
250
+ let expr = expressions:: BinaryExpr :: new (
251
+ Arc :: new ( expr) ,
252
+ Operator :: Or ,
253
+ Arc :: new ( expressions:: BinaryExpr :: new (
254
+ Arc :: clone ( & column_c) ,
255
+ Operator :: Gt ,
256
+ Arc :: new ( expressions:: Literal :: new ( ScalarValue :: Float64 ( Some ( 0.0 ) ) ) ) ,
257
+ ) ) ,
258
+ ) ;
259
+
260
+ let result = rewriter. rewrite ( Arc :: new ( expr) ) . unwrap ( ) ;
261
+ println ! ( "Rewritten expression: {}" , result) ;
262
+
263
+ let expected = expressions:: BinaryExpr :: new (
264
+ Arc :: new ( CastExpr :: new (
265
+ Arc :: new ( Column :: new ( "a" , 0 ) ) ,
266
+ DataType :: Int64 ,
267
+ None ,
268
+ ) ) ,
269
+ Operator :: Plus ,
270
+ Arc :: new ( expressions:: Literal :: new ( ScalarValue :: Int64 ( Some ( 5 ) ) ) ) ,
271
+ ) ;
272
+ let expected = Arc :: new ( expressions:: BinaryExpr :: new (
273
+ Arc :: new ( expected) ,
274
+ Operator :: Or ,
275
+ Arc :: new ( expressions:: BinaryExpr :: new (
276
+ lit ( ScalarValue :: Null ) ,
277
+ Operator :: Gt ,
278
+ Arc :: new ( expressions:: Literal :: new ( ScalarValue :: Float64 ( Some ( 0.0 ) ) ) ) ,
279
+ ) ) ,
280
+ ) ) as Arc < dyn PhysicalExpr > ;
281
+
282
+ assert_eq ! (
283
+ result. to_string( ) ,
284
+ expected. to_string( ) ,
285
+ "The rewritten expression did not match the expected output"
286
+ ) ;
233
287
}
234
288
235
289
#[ test]
0 commit comments