diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7c4a02678899..3b9d64567229 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -941,7 +941,11 @@ impl OptimizerRule for PushDownFilter { let group_expr_columns = agg .group_expr .iter() - .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string()))) + .map(|e| { + Ok(Column::from_qualified_name_ignore_case( + e.schema_name().to_string(), + )) + }) .collect::>>()?; let predicates = split_conjunction_owned(filter.predicate); @@ -4123,4 +4127,55 @@ mod tests { " ) } + + /// Create a test table scan with uppercase column names for case sensitivity testing + fn test_table_scan_with_uppercase_columns() -> Result { + let schema = Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("A", DataType::UInt32, false), + Field::new("B", DataType::UInt32, false), + Field::new("C", DataType::UInt32, false), + ]); + table_scan(Some("test"), &schema, None)?.build() + } + + #[test] + fn filter_agg_case_insensitive() -> Result<()> { + let table_scan = test_table_scan_with_uppercase_columns()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col(r#""A""#)], + vec![sum(col(r#""B""#)).alias("total_salary")], + )? + .filter(col(r#""A""#).gt(lit(10i64)))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.A]], aggr=[[sum(test.B) AS total_salary]] + TableScan: test, full_filters=[test.A > Int64(10)] + " + ) + } + + #[test] + fn filter_agg_mix_case_insensitive() -> Result<()> { + let table_scan = test_table_scan_with_uppercase_columns()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate( + vec![col("a")], + vec![sum(col(r#""B""#)).alias("total_salary")], + )? + .filter(col("a").gt(lit(10i64)))? + .build()?; + + assert_optimized_plan_equal!( + plan, + @r" + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.B) AS total_salary]] + TableScan: test, full_filters=[test.a > Int64(10)] + " + ) + } }