Skip to content

Commit 3a93242

Browse files
committed
better tests
1 parent 47a38f7 commit 3a93242

3 files changed

Lines changed: 79 additions & 63 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/physical-expr-adapter/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ itertools = { workspace = true }
2727
[dev-dependencies]
2828
datafusion-expr = { workspace = true }
2929
rstest = { workspace = true }
30+
insta = { workspace = true }

datafusion/physical-expr-adapter/src/schema_rewriter.rs

Lines changed: 77 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ mod tests {
397397
use datafusion_common::ScalarValue;
398398
use datafusion_expr::Operator;
399399
use datafusion_physical_expr::expressions::{col, BinaryExpr};
400+
use datafusion_physical_expr_common::physical_expr::fmt_sql;
400401
use std::sync::Arc;
401402

402403
fn create_test_schema() -> (Schema, Schema) {
@@ -414,17 +415,24 @@ mod tests {
414415
(physical_schema, logical_schema)
415416
}
416417

418+
/// Formats `expr` through `fmt_sql` and returns the rendered text as a `String`.
///
/// The tests in this module pass the result to `insta::assert_snapshot!` to
/// compare an expression before and after it has been rewritten.
fn expression_to_sql(expr: &Arc<dyn PhysicalExpr>) -> String {
419+
format!("{}", fmt_sql(expr.as_ref()))
420+
}
421+
417422
#[test]
418423
fn test_rewrite_column_with_type_cast() {
419424
let (physical_schema, logical_schema) = create_test_schema();
420425

421426
let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
422427
let column_expr = Arc::new(Column::new("a", 0));
423428

429+
// Capture input expression
430+
insta::assert_snapshot!(expression_to_sql(&(column_expr.clone() as Arc<dyn PhysicalExpr>)), @"a");
431+
424432
let result = rewriter.rewrite(column_expr).unwrap();
425433

426-
// Should be wrapped in a cast expression
427-
assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
434+
// Capture output expression
435+
insta::assert_snapshot!(expression_to_sql(&result), @"CAST(a AS Int64)");
428436
}
429437

430438
#[test]
@@ -564,7 +572,7 @@ mod tests {
564572

565573
/// Test end-to-end struct evolution with simple field additions
566574
#[test]
567-
fn test_evolved_schema_struct_field_addition() -> Result<()> {
575+
fn test_evolved_schema_struct_field_addition() {
568576
// Physical schema: {user_info: {id: i32, name: string}}
569577
let physical_schema = Schema::new(vec![Field::new(
570578
"user_info",
@@ -597,33 +605,34 @@ mod tests {
597605

598606
// Test that we can rewrite a column reference
599607
let column_expr = Arc::new(Column::new("user_info", 0));
600-
let result = rewriter.rewrite(column_expr)?;
601608

602-
// Should be a struct function expression
603-
assert!(result
604-
.as_any()
605-
.downcast_ref::<ScalarFunctionExpr>()
606-
.is_some());
609+
// Capture input expression
610+
insta::assert_snapshot!(expression_to_sql(&(column_expr.clone() as Arc<dyn PhysicalExpr>)), @"user_info");
611+
612+
let result = rewriter.rewrite(column_expr).unwrap();
613+
614+
// Capture output expression
615+
insta::assert_snapshot!(expression_to_sql(&result), @"struct(get_field(user_info, id), get_field(user_info, name), NULL)");
607616

608617
// Test that we can rewrite a predicate on existing fields
609618
let predicate = Arc::new(BinaryExpr::new(
610-
col("user_info", &logical_schema)?,
619+
col("user_info", &logical_schema).unwrap(),
611620
Operator::IsNotDistinctFrom,
612621
expressions::lit(ScalarValue::Null),
613622
)) as Arc<dyn PhysicalExpr>;
614623

615-
let rewritten_predicate = rewriter.rewrite(predicate)?;
616-
assert!(rewritten_predicate
617-
.as_any()
618-
.downcast_ref::<BinaryExpr>()
619-
.is_some());
624+
// Capture input predicate
625+
insta::assert_snapshot!(expression_to_sql(&predicate), @"user_info IS NOT DISTINCT FROM NULL");
620626

621-
Ok(())
627+
let rewritten_predicate = rewriter.rewrite(predicate).unwrap();
628+
629+
// Capture output predicate
630+
insta::assert_snapshot!(expression_to_sql(&rewritten_predicate), @"struct(get_field(user_info, id), get_field(user_info, name), NULL) IS NOT DISTINCT FROM NULL");
622631
}
623632

624633
/// Test end-to-end struct evolution with field type changes
625634
#[test]
626-
fn test_evolved_schema_struct_field_type_evolution() -> Result<()> {
635+
fn test_evolved_schema_struct_field_type_evolution() {
627636
// Physical schema: {event_data: {timestamp: i64, count: i32}}
628637
let physical_schema = Schema::new(vec![Field::new(
629638
"event_data",
@@ -659,20 +668,19 @@ mod tests {
659668

660669
// Test column rewriting
661670
let column_expr = Arc::new(Column::new("event_data", 0));
662-
let result = rewriter.rewrite(column_expr)?;
663671

664-
// Should be a struct function expression that handles the type conversions
665-
assert!(result
666-
.as_any()
667-
.downcast_ref::<ScalarFunctionExpr>()
668-
.is_some());
672+
// Capture input expression
673+
insta::assert_snapshot!(expression_to_sql(&(column_expr.clone() as Arc<dyn PhysicalExpr>)), @"event_data");
669674

670-
Ok(())
675+
let result = rewriter.rewrite(column_expr).unwrap();
676+
677+
// Capture output expression
678+
insta::assert_snapshot!(expression_to_sql(&result), @"struct(CAST(get_field(event_data, timestamp) AS Timestamp(Millisecond, None)), CAST(get_field(event_data, count) AS Int64))");
671679
}
672680

673681
/// Test end-to-end struct evolution with nested structs
674682
#[test]
675-
fn test_evolved_schema_nested_struct_evolution() -> Result<()> {
683+
fn test_evolved_schema_nested_struct_evolution() {
676684
// Physical schema: {
677685
// metadata: {
678686
// user: {id: i32, name: string},
@@ -740,20 +748,19 @@ mod tests {
740748

741749
// Test that we can handle deeply nested struct evolution
742750
let column_expr = Arc::new(Column::new("metadata", 0));
743-
let result = rewriter.rewrite(column_expr)?;
744751

745-
// Should be a struct function expression
746-
assert!(result
747-
.as_any()
748-
.downcast_ref::<ScalarFunctionExpr>()
749-
.is_some());
752+
// Capture input expression
753+
insta::assert_snapshot!(expression_to_sql(&(column_expr.clone() as Arc<dyn PhysicalExpr>)), @"metadata");
750754

751-
Ok(())
755+
let result = rewriter.rewrite(column_expr).unwrap();
756+
757+
// Capture output expression
758+
insta::assert_snapshot!(expression_to_sql(&result), @"struct(struct(CAST(get_field(get_field(metadata, user), id) AS Int64), get_field(get_field(metadata, user), name), NULL), CAST(get_field(metadata, created_at) AS Timestamp(Millisecond, None)), NULL)");
752759
}
753760

754761
/// Test end-to-end struct evolution with field removal (extra fields in source)
755762
#[test]
756-
fn test_evolved_schema_struct_field_removal() -> Result<()> {
763+
fn test_evolved_schema_struct_field_removal() {
757764
// Physical schema: {config: {debug_mode: bool, log_level: string, deprecated_flag: bool}}
758765
let physical_schema = Schema::new(vec![Field::new(
759766
"config",
@@ -786,20 +793,19 @@ mod tests {
786793

787794
// Test that extra fields are properly ignored
788795
let column_expr = Arc::new(Column::new("config", 0));
789-
let result = rewriter.rewrite(column_expr)?;
790796

791-
// Should be a struct function expression that ignores the deprecated field
792-
assert!(result
793-
.as_any()
794-
.downcast_ref::<ScalarFunctionExpr>()
795-
.is_some());
797+
// Capture input expression
798+
insta::assert_snapshot!(expression_to_sql(&(column_expr.clone() as Arc<dyn PhysicalExpr>)), @"config");
796799

797-
Ok(())
800+
let result = rewriter.rewrite(column_expr).unwrap();
801+
802+
// Capture output expression
803+
insta::assert_snapshot!(expression_to_sql(&result), @"struct(get_field(config, debug_mode), get_field(config, log_level))");
798804
}
799805

800806
/// Test end-to-end struct evolution with mixed scenarios (realistic data evolution)
801807
#[test]
802-
fn test_evolved_schema_complex_struct_evolution() -> Result<()> {
808+
fn test_evolved_schema_complex_struct_evolution() {
803809
// Simulate a realistic data evolution scenario:
804810
// Physical schema represents an older version of the data
805811
let physical_schema = Schema::new(vec![
@@ -855,31 +861,41 @@ mod tests {
855861

856862
// Test rewriting of simple field with type change
857863
let id_expr = Arc::new(Column::new("id", 0));
858-
let id_result = rewriter.rewrite(id_expr)?;
859-
assert!(id_result.as_any().downcast_ref::<CastExpr>().is_some());
864+
865+
// Capture input expression
866+
insta::assert_snapshot!(expression_to_sql(&(id_expr.clone() as Arc<dyn PhysicalExpr>)), @"id");
867+
868+
let id_result = rewriter.rewrite(id_expr).unwrap();
869+
870+
// Capture output expression
871+
insta::assert_snapshot!(expression_to_sql(&id_result), @"CAST(id AS Int64)");
860872

861873
// Test rewriting of complex struct field
862874
let profile_expr = Arc::new(Column::new("profile", 1));
863-
let profile_result = rewriter.rewrite(profile_expr)?;
864-
assert!(profile_result
865-
.as_any()
866-
.downcast_ref::<ScalarFunctionExpr>()
867-
.is_some());
875+
876+
// Capture input expression
877+
insta::assert_snapshot!(expression_to_sql(&(profile_expr.clone() as Arc<dyn PhysicalExpr>)), @"profile");
878+
879+
let profile_result = rewriter.rewrite(profile_expr).unwrap();
880+
881+
// Capture output expression
882+
insta::assert_snapshot!(expression_to_sql(&profile_result), @"struct(get_field(profile, username), CAST(get_field(profile, age) AS Int64), NULL, NULL)");
868883

869884
// Test rewriting of missing field (should become null)
870885
let created_at_expr = Arc::new(Column::new("created_at", 2));
871-
let created_at_result = rewriter.rewrite(created_at_expr)?;
872-
assert!(created_at_result
873-
.as_any()
874-
.downcast_ref::<datafusion_physical_expr::expressions::Literal>()
875-
.is_some());
876886

877-
Ok(())
887+
// Capture input expression
888+
insta::assert_snapshot!(expression_to_sql(&(created_at_expr.clone() as Arc<dyn PhysicalExpr>)), @"created_at");
889+
890+
let created_at_result = rewriter.rewrite(created_at_expr).unwrap();
891+
892+
// Capture output expression
893+
insta::assert_snapshot!(expression_to_sql(&created_at_result), @"NULL");
878894
}
879895

880896
/// Test that struct evolution works correctly with predicates
881897
#[test]
882-
fn test_evolved_schema_struct_with_predicates() -> Result<()> {
898+
fn test_evolved_schema_struct_with_predicates() {
883899
// Physical schema: {event: {type: string, data: {count: i32}}}
884900
let physical_schema = Schema::new(vec![Field::new(
885901
"event",
@@ -926,19 +942,17 @@ mod tests {
926942

927943
// Create a complex predicate that references the struct
928944
let predicate = Arc::new(BinaryExpr::new(
929-
col("event", &logical_schema)?,
945+
col("event", &logical_schema).unwrap(),
930946
Operator::IsNotDistinctFrom,
931947
expressions::lit(ScalarValue::Null),
932948
)) as Arc<dyn PhysicalExpr>;
933949

934-
let rewritten_predicate = rewriter.rewrite(predicate)?;
950+
// Capture input predicate
951+
insta::assert_snapshot!(expression_to_sql(&predicate), @"event IS NOT DISTINCT FROM NULL");
935952

936-
// The predicate should be successfully rewritten
937-
assert!(rewritten_predicate
938-
.as_any()
939-
.downcast_ref::<BinaryExpr>()
940-
.is_some());
953+
let rewritten_predicate = rewriter.rewrite(predicate).unwrap();
941954

942-
Ok(())
955+
// Capture output predicate
956+
insta::assert_snapshot!(expression_to_sql(&rewritten_predicate), @"struct(get_field(event, type), struct(CAST(get_field(get_field(event, data), count) AS Int64), NULL)) IS NOT DISTINCT FROM NULL");
943957
}
944958
}

0 commit comments

Comments
 (0)