Skip to content

Commit 7a1bb68

Browse files
committed
fmt
1 parent 617d66e commit 7a1bb68

File tree

1 file changed

+79
-55
lines changed

1 file changed

+79
-55
lines changed

datafusion/physical-expr/src/schema_rewriter.rs

Lines changed: 79 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818
//! Physical expression schema rewriting utilities
1919
20-
use std::sync::Arc;
2120
use std::cmp::Ordering;
21+
use std::sync::Arc;
2222

2323
use arrow::compute::can_cast_types;
2424
use arrow::datatypes::{
@@ -230,7 +230,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
230230
left.as_any().downcast_ref::<CastExpr>(),
231231
right.as_any().downcast_ref::<Literal>(),
232232
) {
233-
if let Some(optimized) = self.unwrap_cast_with_literal(cast_expr, literal, *op)? {
233+
if let Some(optimized) =
234+
self.unwrap_cast_with_literal(cast_expr, literal, *op)?
235+
{
234236
return Ok(Some(Arc::new(BinaryExpr::new(
235237
optimized.0,
236238
*op,
@@ -244,7 +246,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
244246
left.as_any().downcast_ref::<Literal>(),
245247
right.as_any().downcast_ref::<CastExpr>(),
246248
) {
247-
if let Some(optimized) = self.unwrap_cast_with_literal(cast_expr, literal, *op)? {
249+
if let Some(optimized) =
250+
self.unwrap_cast_with_literal(cast_expr, literal, *op)?
251+
{
248252
return Ok(Some(Arc::new(BinaryExpr::new(
249253
optimized.1,
250254
*op,
@@ -265,32 +269,36 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
265269
) -> Result<Option<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>> {
266270
// Get the inner expression (what's being cast)
267271
let inner_expr = cast_expr.expr();
268-
272+
269273
// Handle the case where inner expression might be another cast (due to schema rewriting)
270274
// This can happen when the schema rewriter adds a cast to a column, and then we have
271275
// an original cast on top of that.
272-
let (final_inner_expr, column) = if let Some(inner_cast) = inner_expr.as_any().downcast_ref::<CastExpr>() {
273-
// We have a nested cast, check if the inner cast's expression is a column
274-
let inner_inner_expr = inner_cast.expr();
275-
if let Some(col) = inner_inner_expr.as_any().downcast_ref::<Column>() {
276-
(inner_inner_expr, col)
276+
let (final_inner_expr, column) =
277+
if let Some(inner_cast) = inner_expr.as_any().downcast_ref::<CastExpr>() {
278+
// We have a nested cast, check if the inner cast's expression is a column
279+
let inner_inner_expr = inner_cast.expr();
280+
if let Some(col) = inner_inner_expr.as_any().downcast_ref::<Column>() {
281+
(inner_inner_expr, col)
282+
} else {
283+
return Ok(None);
284+
}
285+
} else if let Some(col) = inner_expr.as_any().downcast_ref::<Column>() {
286+
(inner_expr, col)
277287
} else {
278288
return Ok(None);
279-
}
280-
} else if let Some(col) = inner_expr.as_any().downcast_ref::<Column>() {
281-
(inner_expr, col)
282-
} else {
283-
return Ok(None);
284-
};
289+
};
285290

286291
// Get the column's data type from the physical schema
287-
let column_data_type = match self.physical_file_schema.field_with_name(column.name()) {
288-
Ok(field) => field.data_type(),
289-
Err(_) => return Ok(None), // Column not found, can't optimize
290-
};
292+
let column_data_type =
293+
match self.physical_file_schema.field_with_name(column.name()) {
294+
Ok(field) => field.data_type(),
295+
Err(_) => return Ok(None), // Column not found, can't optimize
296+
};
291297

292298
// Try to cast the literal to the column's data type
293-
if let Some(casted_literal) = try_cast_literal_to_type(literal.value(), column_data_type, op) {
299+
if let Some(casted_literal) =
300+
try_cast_literal_to_type(literal.value(), column_data_type, op)
301+
{
294302
return Ok(Some((
295303
Arc::clone(final_inner_expr),
296304
expressions::lit(casted_literal),
@@ -323,7 +331,6 @@ fn cast_literal_to_type_with_op(
323331
target_type: &DataType,
324332
op: Operator,
325333
) -> Option<ScalarValue> {
326-
327334
match (op, lit_value) {
328335
(
329336
Operator::Eq | Operator::NotEq,
@@ -754,22 +761,27 @@ mod tests {
754761
let column_expr = Arc::new(Column::new("a", 0));
755762
let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
756763
let literal_expr = expressions::lit(ScalarValue::Int64(Some(123)));
757-
let binary_expr = Arc::new(BinaryExpr::new(
758-
cast_expr,
759-
Operator::Eq,
760-
literal_expr,
761-
));
764+
let binary_expr =
765+
Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
762766

763767
let result = rewriter.rewrite(binary_expr.clone() as Arc<dyn PhysicalExpr>)?;
764768

765769
// The result should be a binary expression with the cast unwrapped
766770
let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
767-
771+
768772
// Left side should be the original column (no cast)
769-
assert!(result_binary.left().as_any().downcast_ref::<Column>().is_some());
770-
773+
assert!(result_binary
774+
.left()
775+
.as_any()
776+
.downcast_ref::<Column>()
777+
.is_some());
778+
771779
// Right side should be a literal with the value cast to Int32
772-
let right_literal = result_binary.right().as_any().downcast_ref::<Literal>().unwrap();
780+
let right_literal = result_binary
781+
.right()
782+
.as_any()
783+
.downcast_ref::<Literal>()
784+
.unwrap();
773785
assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123)));
774786

775787
Ok(())
@@ -787,23 +799,28 @@ mod tests {
787799
let literal_expr = expressions::lit(ScalarValue::Int64(Some(123)));
788800
let column_expr = Arc::new(Column::new("a", 0));
789801
let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
790-
let binary_expr = Arc::new(BinaryExpr::new(
791-
literal_expr,
792-
Operator::Eq,
793-
cast_expr,
794-
));
802+
let binary_expr =
803+
Arc::new(BinaryExpr::new(literal_expr, Operator::Eq, cast_expr));
795804

796805
let result = rewriter.rewrite(binary_expr)?;
797806

798807
// The result should be a binary expression with the cast unwrapped
799808
let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
800-
809+
801810
// Left side should be a literal with the value cast to Int32
802-
let left_literal = result_binary.left().as_any().downcast_ref::<Literal>().unwrap();
811+
let left_literal = result_binary
812+
.left()
813+
.as_any()
814+
.downcast_ref::<Literal>()
815+
.unwrap();
803816
assert_eq!(*left_literal.value(), ScalarValue::Int32(Some(123)));
804-
817+
805818
// Right side should be the original column (no cast)
806-
assert!(result_binary.right().as_any().downcast_ref::<Column>().is_some());
819+
assert!(result_binary
820+
.right()
821+
.as_any()
822+
.downcast_ref::<Column>()
823+
.is_some());
807824

808825
Ok(())
809826
}
@@ -820,22 +837,27 @@ mod tests {
820837
let column_expr = Arc::new(Column::new("a", 0));
821838
let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Utf8, None));
822839
let literal_expr = expressions::lit(ScalarValue::Utf8(Some("123".to_string())));
823-
let binary_expr = Arc::new(BinaryExpr::new(
824-
cast_expr,
825-
Operator::Eq,
826-
literal_expr,
827-
));
840+
let binary_expr =
841+
Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
828842

829843
let result = rewriter.rewrite(binary_expr)?;
830844

831845
// The result should be a binary expression with the cast unwrapped
832846
let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
833-
847+
834848
// Left side should be the original column (no cast)
835-
assert!(result_binary.left().as_any().downcast_ref::<Column>().is_some());
836-
849+
assert!(result_binary
850+
.left()
851+
.as_any()
852+
.downcast_ref::<Column>()
853+
.is_some());
854+
837855
// Right side should be a literal with the value cast to Int32
838-
let right_literal = result_binary.right().as_any().downcast_ref::<Literal>().unwrap();
856+
let right_literal = result_binary
857+
.right()
858+
.as_any()
859+
.downcast_ref::<Literal>()
860+
.unwrap();
839861
assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123)));
840862

841863
Ok(())
@@ -844,7 +866,8 @@ mod tests {
844866
#[test]
845867
fn test_no_unwrap_cast_optimization_when_not_applicable() -> Result<()> {
846868
// Test case where optimization should not apply - unsupported cast
847-
let physical_schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
869+
let physical_schema =
870+
Schema::new(vec![Field::new("a", DataType::Float32, false)]);
848871
let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
849872

850873
let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
@@ -854,18 +877,19 @@ mod tests {
854877
let column_expr = Arc::new(Column::new("a", 0));
855878
let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
856879
let literal_expr = expressions::lit(ScalarValue::Int64(Some(123)));
857-
let binary_expr = Arc::new(BinaryExpr::new(
858-
cast_expr,
859-
Operator::Eq,
860-
literal_expr,
861-
));
880+
let binary_expr =
881+
Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
862882

863883
let result = rewriter.rewrite(binary_expr)?;
864884

865885
// The result should still be a binary expression with a cast on the left side
866886
// since Float32 is not in our supported types for unwrap cast optimization
867887
let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
868-
assert!(result_binary.left().as_any().downcast_ref::<CastExpr>().is_some());
888+
assert!(result_binary
889+
.left()
890+
.as_any()
891+
.downcast_ref::<CastExpr>()
892+
.is_some());
869893

870894
Ok(())
871895
}

0 commit comments

Comments
 (0)