From bb7b82e5954d0aa1688efe407c98d17b8345ecea Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Tue, 17 Jun 2025 13:39:07 -0500
Subject: [PATCH 01/22] wip

---
 datafusion/datasource-parquet/src/opener.rs | 62 ++++++++++++++++++++-
 1 file changed, 60 insertions(+), 2 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 285044803d73..6d9b30f44fef 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -25,6 +25,8 @@ use crate::{
     apply_file_schema_type_coercions, coerce_int96_to_resolution, row_filter,
     ParquetAccessPlan, ParquetFileMetrics, ParquetFileReaderFactory,
 };
+use arrow::compute::can_cast_types;
+use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_datasource::file_meta::FileMeta;
 use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener};
 use datafusion_datasource::schema_adapter::SchemaAdapterFactory;
@@ -35,7 +37,7 @@ use datafusion_common::pruning::{
     CompositePruningStatistics, PartitionPruningStatistics, PrunableStatistics,
     PruningStatistics,
 };
-use datafusion_common::{exec_err, Result};
+use datafusion_common::{exec_err, Result, ScalarValue};
 use datafusion_datasource::PartitionedFile;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_optimizer::pruning::PruningPredicate;
@@ -248,10 +250,16 @@ impl FileOpener for ParquetOpener {
             }
         }
 
+        let predicate = predicate
+            .map(|p| {
+                cast_expr_to_schema(p, &physical_file_schema, &logical_file_schema)
+            })
+            .transpose()?;
+
         // Build predicates for this specific file
         let (pruning_predicate, page_pruning_predicate) = build_pruning_predicates(
             predicate.as_ref(),
-            &logical_file_schema,
+            &physical_file_schema,
             &predicate_creation_errors,
         );
 
@@ -524,6 +532,56 @@ fn should_enable_page_index(
         .unwrap_or(false)
 }
 
+use datafusion_physical_expr::expressions;
+
+/// Given a [`PhysicalExpr`] and a [`SchemaRef`], returns a new [`PhysicalExpr`] that
+/// is cast to the specified data type.
+/// Preference is always given to casting literal values to the data type of the column
+/// since casting the column to the literal value's data type can be significantly more expensive.
+/// Given two columns the cast is applied arbitrarily to the first column.
+pub fn cast_expr_to_schema(
+    expr: Arc<dyn PhysicalExpr>,
+    physical_file_schema: &Schema,
+    logical_file_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    expr.transform(|expr| {
+        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+            let logical_field = logical_file_schema.field_with_name(column.name())?;
+            let Ok(physical_field) = physical_file_schema.field_with_name(column.name())
+            else {
+                // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
+                let value = ScalarValue::Null.cast_to(logical_field.data_type())?;
+                return Ok(Transformed::yes(expressions::lit(value)));
+            };
+
+            if logical_field.data_type() == physical_field.data_type() {
+                return Ok(Transformed::no(expr));
+            }
+
+            // If the logical field and physical field are different, we need to cast
+            // the column to the logical field's data type.
+            // We will try later to move the cast to literal values if possible, which is computationally cheaper.
+            if !can_cast_types(logical_field.data_type(), physical_field.data_type()) {
+                return exec_err!(
+                    "Cannot cast column '{}' from '{}' to '{}'",
+                    column.name(),
+                    logical_field.data_type(),
+                    physical_field.data_type()
+                );
+            }
+            let casted_expr = Arc::new(expressions::CastExpr::new(
+                expr,
+                logical_field.data_type().clone(),
+                None,
+            ));
+            return Ok(Transformed::yes(casted_expr));
+        }
+
+        Ok(Transformed::no(expr))
+    })
+    .data()
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::Arc;

From 317f76764d2ccf8f1bc20d3b46f2f60dfcf2e81f Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 13:13:23 -0500
Subject: [PATCH 02/22] adapt filter expressions to file schema during parquet scan

---
 datafusion/datasource-parquet/src/opener.rs | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 6d9b30f44fef..a8624c6d56e5 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -549,6 +549,12 @@ pub fn cast_expr_to_schema(
             let logical_field = logical_file_schema.field_with_name(column.name())?;
             let Ok(physical_field) = physical_file_schema.field_with_name(column.name())
             else {
+                if !logical_field.is_nullable() {
+                    return exec_err!(
+                        "Non-nullable column '{}' is missing from the physical schema",
+                        column.name()
+                    );
+                }
                 // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
                 let value = ScalarValue::Null.cast_to(logical_field.data_type())?;
                 return Ok(Transformed::yes(expressions::lit(value)));

From 30e171f607a6e0ac0100c0248905a48344d9e1cf Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 13:41:58 -0500
Subject: [PATCH 03/22] handle partition values

---
 datafusion/datasource-parquet/src/opener.rs | 32 +++++++++++++++++++--
 1 file changed, 29 insertions(+), 3 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index a8624c6d56e5..bd51ff1e7ea4 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -161,7 +161,7 @@ impl FileOpener for ParquetOpener {
         if let Some(pruning_predicate) = pruning_predicate {
             // The partition column schema is the schema of the table - the schema of the file
             let mut pruning = Box::new(PartitionPruningStatistics::try_new(
-                vec![file.partition_values],
+                vec![file.partition_values.clone()],
                 partition_fields.clone(),
             )?) as Box<dyn PruningStatistics>;
@@ -252,7 +252,14 @@ impl FileOpener for ParquetOpener {
         let predicate = predicate
             .map(|p| {
-                cast_expr_to_schema(p, &physical_file_schema, &logical_file_schema)
+                cast_expr_to_schema(
+                    p,
+                    &physical_file_schema,
+                    &logical_file_schema,
+                    file.partition_values,
+                    &partition_fields,
+                )
+                .map_err(ArrowError::from)
             })
             .transpose()?;
@@ -543,10 +550,29 @@ pub fn cast_expr_to_schema(
     expr: Arc<dyn PhysicalExpr>,
     physical_file_schema: &Schema,
     logical_file_schema: &Schema,
+    partition_values: Vec<ScalarValue>,
+    partition_fields: &[FieldRef],
 ) -> Result<Arc<dyn PhysicalExpr>> {
     expr.transform(|expr| {
         if let Some(column) = expr.as_any().downcast_ref::<Column>() {
-            let logical_field = logical_file_schema.field_with_name(column.name())?;
+            let logical_field = match logical_file_schema.field_with_name(column.name()) {
+                Ok(field) => field,
+                Err(e) => {
+                    // Is this a partition field?
+                    for (partition_field, partition_value) in
+                        partition_fields.iter().zip(partition_values.iter())
+                    {
+                        if partition_field.name() == column.name() {
+                            // If the column is a partition field, we can use the partition value
+                            return Ok(Transformed::yes(expressions::lit(
+                                partition_value.clone(),
+                            )));
+                        }
+                    }
+                    // If the column is not found in the logical schema, return an error
+                    return Err(e.into());
+                }
+            };
             let Ok(physical_field) = physical_file_schema.field_with_name(column.name())
             else {
                 if !logical_field.is_nullable() {
From 3f79f05cf985eb95388580e3708ab52246b34137 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 13:47:18 -0500
Subject: [PATCH 04/22] add more comments

---
 datafusion/datasource-parquet/src/opener.rs | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index bd51ff1e7ea4..09a596b23af8 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -558,18 +558,19 @@ pub fn cast_expr_to_schema(
             let logical_field = match logical_file_schema.field_with_name(column.name()) {
                 Ok(field) => field,
                 Err(e) => {
-                    // Is this a partition field?
+                    // If the column is a partition field, we can use the partition value
                     for (partition_field, partition_value) in
                         partition_fields.iter().zip(partition_values.iter())
                     {
                         if partition_field.name() == column.name() {
-                            // If the column is a partition field, we can use the partition value
                             return Ok(Transformed::yes(expressions::lit(
                                 partition_value.clone(),
                             )));
                         }
                     }
                     // If the column is not found in the logical schema, return an error
+                    // This should probably never be hit unless something upstream broke, but nonetheless it's better
+                    // for us to return a handleable error than to panic / do something unexpected.
                     return Err(e.into());
                 }
             };
@@ -582,6 +583,8 @@ pub fn cast_expr_to_schema(
                     );
                 }
                 // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
+                // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
+                // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else!
                 let value = ScalarValue::Null.cast_to(logical_field.data_type())?;
                 return Ok(Transformed::yes(expressions::lit(value)));
             };
@@ -595,7 +598,7 @@ pub fn cast_expr_to_schema(
             // We will try later to move the cast to literal values if possible, which is computationally cheaper.
             if !can_cast_types(logical_field.data_type(), physical_field.data_type()) {
                 return exec_err!(
-                    "Cannot cast column '{}' from '{}' to '{}'",
+                    "Cannot cast column '{}' from '{}' (file data type) to '{}' (table data type)",
                     column.name(),
                     logical_field.data_type(),
                     physical_field.data_type()

From e85f3d9be8ba204b6bda502c39d65ad3fb2d197d Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 13:53:54 -0500
Subject: [PATCH 05/22] add a new test

---
 datafusion/datasource-parquet/src/opener.rs | 94 +++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 09a596b23af8..22951c8e23b1 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -972,4 +972,98 @@ mod test {
         assert_eq!(num_batches, 0);
         assert_eq!(num_rows, 0);
     }
+
+    #[tokio::test]
+    async fn test_prune_on_partition_value_and_data_value() {
+        let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
+
+        // Note: number 3 is missing!
+        let batch = record_batch!(("a", Int32, vec![Some(1), Some(2), Some(4)])).unwrap();
+        let data_size =
+            write_parquet(Arc::clone(&store), "part=1/file.parquet", batch.clone()).await;
+
+        let file_schema = batch.schema();
+        let mut file = PartitionedFile::new(
+            "part=1/file.parquet".to_string(),
+            u64::try_from(data_size).unwrap(),
+        );
+        file.partition_values = vec![ScalarValue::Int32(Some(1))];
+
+        let table_schema = Arc::new(Schema::new(vec![
+            Field::new("part", DataType::Int32, false),
+            Field::new("a", DataType::Int32, false),
+        ]));
+
+        let make_opener = |predicate| {
+            ParquetOpener {
+                partition_index: 0,
+                projection: Arc::new([0]),
+                batch_size: 1024,
+                limit: None,
+                predicate: Some(predicate),
+                logical_file_schema: file_schema.clone(),
+                metadata_size_hint: None,
+                metrics: ExecutionPlanMetricsSet::new(),
+                parquet_file_reader_factory: Arc::new(
+                    DefaultParquetFileReaderFactory::new(Arc::clone(&store)),
+                ),
+                partition_fields: vec![Arc::new(Field::new(
+                    "part",
+                    DataType::Int32,
+                    false,
+                ))],
+                pushdown_filters: true, // note that this is true!
+                reorder_filters: true,
+                enable_page_index: false,
+                enable_bloom_filter: false,
+                schema_adapter_factory: Arc::new(DefaultSchemaAdapterFactory),
+                enable_row_group_stats_pruning: false, // note that this is false!
+                coerce_int96: None,
+            }
+        };
+
+        let make_meta = || FileMeta {
+            object_meta: ObjectMeta {
+                location: Path::from("part=1/file.parquet"),
+                last_modified: Utc::now(),
+                size: u64::try_from(data_size).unwrap(),
+                e_tag: None,
+                version: None,
+            },
+            range: None,
+            extensions: None,
+            metadata_size_hint: None,
+        };
+
+        // Filter should match the partition value and data value
+        let expr = col("part").eq(lit(1)).and(col("a").eq(lit(1)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 1);
+
+        // Filter should match the partition value but not the data value
+        let expr = col("part").eq(lit(1)).and(col("a").eq(lit(3)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener.open(make_meta(), file.clone()).unwrap().await.unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 0);
+        assert_eq!(num_rows, 0);
+
+        // Filter should not match the partition value but match the data value
+        let expr = col("part").eq(lit(2)).and(col("a").eq(lit(1)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener.open(make_meta(), file.clone()).unwrap().await.unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 0);
+        assert_eq!(num_rows, 0);
+    }
 }

From e6e94c94b8be8d20c7064101688e8ee50b888152 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 13:58:25 -0500
Subject: [PATCH 06/22] better test?

---
 datafusion/datasource-parquet/src/opener.rs | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 22951c8e23b1..ffff424c6d54 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -1036,7 +1036,7 @@ mod test {
         };
 
         // Filter should match the partition value and data value
-        let expr = col("part").eq(lit(1)).and(col("a").eq(lit(1)));
+        let expr = col("part").eq(lit(1)).or(col("a").eq(lit(1)));
         let predicate = logical2physical(&expr, &table_schema);
         let opener = make_opener(predicate);
         let stream = opener
             .open(make_meta(), file.clone())
             .unwrap()
             .await
             .unwrap();
         let (num_batches, num_rows) = count_batches_and_rows(stream).await;
         assert_eq!(num_batches, 1);
-        assert_eq!(num_rows, 1);
+        assert_eq!(num_rows, 3);
 
         // Filter should match the partition value but not the data value
-        let expr = col("part").eq(lit(1)).and(col("a").eq(lit(3)));
+        let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3)));
         let predicate = logical2physical(&expr, &table_schema);
         let opener = make_opener(predicate);
         let stream = opener.open(make_meta(), file.clone()).unwrap().await.unwrap();
         let (num_batches, num_rows) = count_batches_and_rows(stream).await;
-        assert_eq!(num_batches, 0);
-        assert_eq!(num_rows, 0);
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 3);
 
         // Filter should not match the partition value but match the data value
-        let expr = col("part").eq(lit(2)).and(col("a").eq(lit(1)));
+        let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1)));
         let predicate = logical2physical(&expr, &table_schema);
         let opener = make_opener(predicate);
         let stream = opener.open(make_meta(), file.clone()).unwrap().await.unwrap();
         let (num_batches, num_rows) = count_batches_and_rows(stream).await;
+        assert_eq!(num_batches, 1);
+        assert_eq!(num_rows, 1);
+
+        // Filter should not match the partition value or the data value
+        let expr = col("part").eq(lit(2)).or(col("a").eq(lit(3)));
+        let predicate = logical2physical(&expr, &table_schema);
+        let opener = make_opener(predicate);
+        let stream = opener.open(make_meta(), file).unwrap().await.unwrap();
+        let (num_batches, num_rows) = count_batches_and_rows(stream).await;
         assert_eq!(num_batches, 0);
         assert_eq!(num_rows, 0);
     }

From 1d820c18e787e8aac0d650fa8da62e1c28631331 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 14:19:03 -0500
Subject: [PATCH 07/22] fmt

---
 datafusion/datasource-parquet/src/opener.rs | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index ffff424c6d54..16e86110ede4 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -1052,7 +1052,11 @@ mod test {
         let expr = col("part").eq(lit(1)).or(col("a").eq(lit(3)));
         let predicate = logical2physical(&expr, &table_schema);
         let opener = make_opener(predicate);
-        let stream = opener.open(make_meta(), file.clone()).unwrap().await.unwrap();
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
         let (num_batches, num_rows) = count_batches_and_rows(stream).await;
         assert_eq!(num_batches, 1);
         assert_eq!(num_rows, 3);
@@ -1061,7 +1065,11 @@ mod test {
         // Filter should not match the partition value but match the data value
         let expr = col("part").eq(lit(2)).or(col("a").eq(lit(1)));
         let predicate = logical2physical(&expr, &table_schema);
         let opener = make_opener(predicate);
-        let stream = opener.open(make_meta(), file.clone()).unwrap().await.unwrap();
+        let stream = opener
+            .open(make_meta(), file.clone())
+            .unwrap()
+            .await
+            .unwrap();
         let (num_batches, num_rows) = count_batches_and_rows(stream).await;
         assert_eq!(num_batches, 1);
         assert_eq!(num_rows, 1);

From 752351b3cf477611a71bc0e68c0dcafb2a878d30 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 23:55:53 -0500
Subject: [PATCH 08/22] remove schema adapters

---
 datafusion/datasource-parquet/src/opener.rs |   6 +-
 .../datasource-parquet/src/row_filter.rs    | 132 +-----------------
 2 files changed, 8 insertions(+), 130 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 16e86110ede4..c584b1ae060e 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -39,6 +39,7 @@ use datafusion_common::pruning::{
 };
 use datafusion_common::{exec_err, Result, ScalarValue};
 use datafusion_datasource::PartitionedFile;
+use datafusion_physical_expr::utils::reassign_predicate_columns;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_optimizer::pruning::PruningPredicate;
@@ -119,7 +120,6 @@ impl FileOpener for ParquetOpener {
         let projected_schema =
             SchemaRef::from(self.logical_file_schema.project(&self.projection)?);
-        let schema_adapter_factory = Arc::clone(&self.schema_adapter_factory);
         let schema_adapter = self
             .schema_adapter_factory
             .create(projected_schema, Arc::clone(&self.logical_file_schema));
@@ -260,7 +260,9 @@ impl FileOpener for ParquetOpener {
                     &partition_fields,
                 )
                 .map_err(ArrowError::from)
+                .map(|p| reassign_predicate_columns(p, &physical_file_schema, false))
             })
+            .transpose()?
             .transpose()?;
 
         // Build predicates for this specific file
@@ -303,11 +305,9 @@ impl FileOpener for ParquetOpener {
                 let row_filter = row_filter::build_row_filter(
                     &predicate,
                     &physical_file_schema,
-                    &logical_file_schema,
                     builder.metadata(),
                     reorder_predicates,
                     &file_metrics,
-                    &schema_adapter_factory,
                 );
 
                 match row_filter {

diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs
index db455fed6160..9dac0a89b489 100644
--- a/datafusion/datasource-parquet/src/row_filter.rs
+++ b/datafusion/datasource-parquet/src/row_filter.rs
@@ -67,6 +67,7 @@ use arrow::array::BooleanArray;
 use arrow::datatypes::{DataType, Schema, SchemaRef};
 use arrow::error::{ArrowError, Result as ArrowResult};
 use arrow::record_batch::RecordBatch;
+use itertools::Itertools;
 use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter};
 use parquet::arrow::ProjectionMask;
 use parquet::file::metadata::ParquetMetaData;
@@ -74,9 +75,8 @@ use parquet::file::metadata::ParquetMetaData;
 use datafusion_common::cast::as_boolean_array;
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
 use datafusion_common::Result;
-use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper};
 use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::utils::reassign_predicate_columns;
+use datafusion_physical_expr::utils::{collect_columns, reassign_predicate_columns};
 use datafusion_physical_expr::{split_conjunction, PhysicalExpr};
 use datafusion_physical_plan::metrics;
@@ -106,8 +106,6 @@ pub(crate) struct DatafusionArrowPredicate {
     rows_matched: metrics::Count,
     /// how long was spent evaluating this predicate
     time: metrics::Time,
-    /// used to perform type coercion while filtering rows
-    schema_mapper: Arc<dyn SchemaMapper>,
 }
 
 impl DatafusionArrowPredicate {
@@ -132,7 +130,6 @@ impl DatafusionArrowPredicate {
             rows_pruned,
             rows_matched,
             time,
-            schema_mapper: candidate.schema_mapper,
         })
     }
 }
@@ -143,8 +140,6 @@ impl ArrowPredicate for DatafusionArrowPredicate {
     }
 
     fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult<BooleanArray> {
-        let batch = self.schema_mapper.map_batch(batch)?;
-
         // scoped timer updates on drop
         let mut timer = self.time.timer();
@@ -187,9 +182,6 @@ pub(crate) struct FilterCandidate {
     /// required to pass through a `SchemaMapper` to the table schema
     /// upon which we then evaluate the filter expression.
     projection: Vec<usize>,
-    /// A `SchemaMapper` used to map batches read from the file schema to
-    /// the filter's projection of the table schema.
-    schema_mapper: Arc<dyn SchemaMapper>,
     /// The projected table schema that this filter references
     filter_schema: SchemaRef,
 }
@@ -230,25 +222,16 @@ struct FilterCandidateBuilder {
     /// columns in the file schema that are not in the table schema or columns that
     /// are in the table schema that are not in the file schema.
     file_schema: SchemaRef,
-    /// The schema of the table (merged schema) -- columns may be in different
-    /// order than in the file and have columns that are not in the file schema
-    table_schema: SchemaRef,
-    /// A `SchemaAdapterFactory` used to map the file schema to the table schema.
-    schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
 }
 
 impl FilterCandidateBuilder {
     pub fn new(
         expr: Arc<dyn PhysicalExpr>,
         file_schema: Arc<Schema>,
-        table_schema: Arc<Schema>,
-        schema_adapter_factory: Arc<dyn SchemaAdapterFactory>,
     ) -> Self {
         Self {
             expr,
             file_schema,
-            table_schema,
-            schema_adapter_factory,
         }
     }
@@ -261,20 +244,17 @@ impl FilterCandidateBuilder {
     /// * `Err(e)` if an error occurs while building the candidate
     pub fn build(self, metadata: &ParquetMetaData) -> Result<Option<FilterCandidate>> {
         let Some(required_indices_into_table_schema) =
-            pushdown_columns(&self.expr, &self.table_schema)?
+            pushdown_columns(&self.expr, &self.file_schema)?
         else {
             return Ok(None);
         };
 
         let projected_table_schema = Arc::new(
-            self.table_schema
+            self.file_schema
                 .project(&required_indices_into_table_schema)?,
         );
 
-        let (schema_mapper, projection_into_file_schema) = self
-            .schema_adapter_factory
-            .create(Arc::clone(&projected_table_schema), self.table_schema)
-            .map_schema(&self.file_schema)?;
+        let projection_into_file_schema = collect_columns(&self.expr).iter().map(|c| c.index()).sorted_unstable().collect_vec();
 
         let required_bytes = size_of_columns(&projection_into_file_schema, metadata)?;
         let can_use_index = columns_sorted(&projection_into_file_schema, metadata)?;
@@ -284,7 +264,6 @@ impl FilterCandidateBuilder {
             required_bytes,
             can_use_index,
             projection: projection_into_file_schema,
-            schema_mapper: Arc::clone(&schema_mapper),
             filter_schema: Arc::clone(&projected_table_schema),
         }))
     }
@@ -426,11 +405,9 @@ fn columns_sorted(_columns: &[usize], _metadata: &ParquetMetaData) -> Result<bool> {
 pub fn build_row_filter(
     expr: &Arc<dyn PhysicalExpr>,
     physical_file_schema: &SchemaRef,
-    logical_file_schema: &SchemaRef,
     metadata: &ParquetMetaData,
     reorder_predicates: bool,
     file_metrics: &ParquetFileMetrics,
-    schema_adapter_factory: &Arc<dyn SchemaAdapterFactory>,
 ) -> Result<Option<RowFilter>> {
     let rows_pruned = &file_metrics.pushdown_rows_pruned;
     let rows_matched = &file_metrics.pushdown_rows_matched;
@@ -447,8 +424,6 @@ pub fn build_row_filter(
             FilterCandidateBuilder::new(
                 Arc::clone(expr),
                 Arc::clone(physical_file_schema),
-                Arc::clone(logical_file_schema),
-                Arc::clone(schema_adapter_factory),
             )
             .build(metadata)
         })
@@ -492,13 +467,9 @@ mod test {
     use super::*;
     use datafusion_common::ScalarValue;
 
-    use arrow::datatypes::{Field, TimeUnit::Nanosecond};
-    use datafusion_datasource::schema_adapter::DefaultSchemaAdapterFactory;
     use datafusion_expr::{col, Expr};
     use datafusion_physical_expr::planner::logical2physical;
-    use datafusion_physical_plan::metrics::{Count, Time};
 
-    use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
     use parquet::arrow::parquet_to_arrow_schema;
     use parquet::file::reader::{FileReader, SerializedFileReader};
@@ -520,14 +491,11 @@ mod test {
         let expr = col("int64_list").is_not_null();
         let expr = logical2physical(&expr, &table_schema);
 
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
         let table_schema = Arc::new(table_schema.clone());
 
         let candidate = FilterCandidateBuilder::new(
             expr,
             table_schema.clone(),
-            table_schema,
-            schema_adapter_factory,
         )
         .build(metadata)
         .expect("building candidate");
 
         assert!(candidate.is_none());
     }
 
@@ -535,96 +503,6 @@ mod test {
-    #[test]
-    fn test_filter_type_coercion() {
-        let testdata = datafusion_common::test_util::parquet_test_data();
-        let file = std::fs::File::open(format!("{testdata}/alltypes_plain.parquet"))
-            .expect("opening file");
-
-        let parquet_reader_builder =
-            ParquetRecordBatchReaderBuilder::try_new(file).expect("creating reader");
-        let metadata = parquet_reader_builder.metadata().clone();
-        let file_schema = parquet_reader_builder.schema().clone();
-
-        // This is the schema we would like to coerce to,
-        // which is different from the physical schema of the file.
-        let table_schema = Schema::new(vec![Field::new(
-            "timestamp_col",
-            DataType::Timestamp(Nanosecond, Some(Arc::from("UTC"))),
-            false,
-        )]);
-
-        // Test all should fail
-        let expr = col("timestamp_col").lt(Expr::Literal(
-            ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))),
-            None,
-        ));
-        let expr = logical2physical(&expr, &table_schema);
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
-        let table_schema = Arc::new(table_schema.clone());
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            file_schema.clone(),
-            table_schema.clone(),
-            schema_adapter_factory,
-        )
-        .build(&metadata)
-        .expect("building candidate")
-        .expect("candidate expected");
-
-        let mut row_filter = DatafusionArrowPredicate::try_new(
-            candidate,
-            &metadata,
-            Count::new(),
-            Count::new(),
-            Time::new(),
-        )
-        .expect("creating filter predicate");
-
-        let mut parquet_reader = parquet_reader_builder
-            .with_projection(row_filter.projection().clone())
-            .build()
-            .expect("building reader");
-
-        // Parquet file is small, we only need 1 record batch
-        let first_rb = parquet_reader
-            .next()
-            .expect("expected record batch")
-            .expect("expected error free record batch");
-
-        let filtered = row_filter.evaluate(first_rb.clone());
-        assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8])));
-
-        // Test all should pass
-        let expr = col("timestamp_col").gt(Expr::Literal(
-            ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))),
-            None,
-        ));
-        let expr = logical2physical(&expr, &table_schema);
-        let schema_adapter_factory = Arc::new(DefaultSchemaAdapterFactory);
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            file_schema,
-            table_schema,
-            schema_adapter_factory,
-        )
-        .build(&metadata)
-        .expect("building candidate")
-        .expect("candidate expected");
-
-        let mut row_filter = DatafusionArrowPredicate::try_new(
-            candidate,
-            &metadata,
-            Count::new(),
-            Count::new(),
-            Time::new(),
-        )
-        .expect("creating filter predicate");
-
-        let filtered = row_filter.evaluate(first_rb);
-        assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![true; 8])));
-    }
-
     #[test]
     fn nested_data_structures_prevent_pushdown() {
         let table_schema = Arc::new(get_lists_table_schema());
From 6faa180d74360599d7b37fe7a03ff4d705eff691 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Thu, 19 Jun 2025 23:56:02 -0500
Subject: [PATCH 09/22] fmt

---
 .../datasource-parquet/src/row_filter.rs | 25 ++++++++-----------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs
index 9dac0a89b489..5626f83186e3 100644
--- a/datafusion/datasource-parquet/src/row_filter.rs
+++ b/datafusion/datasource-parquet/src/row_filter.rs
@@ -225,14 +225,8 @@ struct FilterCandidateBuilder {
 }
 
 impl FilterCandidateBuilder {
-    pub fn new(
-        expr: Arc<dyn PhysicalExpr>,
-        file_schema: Arc<Schema>,
-    ) -> Self {
-        Self {
-            expr,
-            file_schema,
-        }
+    pub fn new(expr: Arc<dyn PhysicalExpr>, file_schema: Arc<Schema>) -> Self {
+        Self { expr, file_schema }
     }
 
     /// Attempt to build a `FilterCandidate` from the expression
@@ -254,7 +248,11 @@ impl FilterCandidateBuilder {
                 .project(&required_indices_into_table_schema)?,
         );
 
-        let projection_into_file_schema = collect_columns(&self.expr).iter().map(|c| c.index()).sorted_unstable().collect_vec();
+        let projection_into_file_schema = collect_columns(&self.expr)
+            .iter()
+            .map(|c| c.index())
+            .sorted_unstable()
+            .collect_vec();
 
         let required_bytes = size_of_columns(&projection_into_file_schema, metadata)?;
         let can_use_index = columns_sorted(&projection_into_file_schema, metadata)?;
@@ -493,12 +491,9 @@ mod test {
         let table_schema = Arc::new(table_schema.clone());
 
-        let candidate = FilterCandidateBuilder::new(
-            expr,
-            table_schema.clone(),
-        )
-        .build(metadata)
-        .expect("building candidate");
+        let candidate = FilterCandidateBuilder::new(expr, table_schema.clone())
+            .build(metadata)
+            .expect("building candidate");
 
         assert!(candidate.is_none());
     }

From bdb10c156f48c126665c9266c469a603a17fd435 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 10:41:46 -0500
Subject: [PATCH 10/22] address PR feedback

---
 datafusion/datasource-parquet/src/opener.rs  |  69 +---
 datafusion/physical-expr/src/lib.rs          |   2 +
 .../physical-expr/src/schema_rewriter.rs     | 295 ++++++++++++++++++
 3 files changed, 302 insertions(+), 64 deletions(-)
 create mode 100644 datafusion/physical-expr/src/schema_rewriter.rs

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index c584b1ae060e..2911adedca05 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -25,8 +25,6 @@ use crate::{
     apply_file_schema_type_coercions, coerce_int96_to_resolution, row_filter,
     ParquetAccessPlan, ParquetFileMetrics, ParquetFileReaderFactory,
 };
-use arrow::compute::can_cast_types;
-use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
 use datafusion_datasource::file_meta::FileMeta;
 use datafusion_datasource::file_stream::{FileOpenFuture, FileOpener};
 use datafusion_datasource::schema_adapter::SchemaAdapterFactory;
@@ -539,7 +537,7 @@ fn should_enable_page_index(
         .unwrap_or(false)
 }
 
-use datafusion_physical_expr::expressions;
+use datafusion_physical_expr::PhysicalExprSchemaRewriter;
 
 /// Given a [`PhysicalExpr`] and a [`SchemaRef`], returns a new [`PhysicalExpr`] that
 /// is cast to the specified data type.
@@ -553,68 +551,11 @@ pub fn cast_expr_to_schema(
     partition_values: Vec<ScalarValue>,
     partition_fields: &[FieldRef],
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    expr.transform(|expr| {
-        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
-            let logical_field = match logical_file_schema.field_with_name(column.name()) {
-                Ok(field) => field,
-                Err(e) => {
-                    // If the column is a partition field, we can use the partition value
-                    for (partition_field, partition_value) in
-                        partition_fields.iter().zip(partition_values.iter())
-                    {
-                        if partition_field.name() == column.name() {
-                            return Ok(Transformed::yes(expressions::lit(
-                                partition_value.clone(),
-                            )));
-                        }
-                    }
-                    // If the column is not found in the logical schema, return an error
-                    // This should probably never be hit unless something upstream broke, but nonetheless it's better
-                    // for us to return a handleable error than to panic / do something unexpected.
-                    return Err(e.into());
-                }
-            };
-            let Ok(physical_field) = physical_file_schema.field_with_name(column.name())
-            else {
-                if !logical_field.is_nullable() {
-                    return exec_err!(
-                        "Non-nullable column '{}' is missing from the physical schema",
-                        column.name()
-                    );
-                }
-                // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
-                // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
-                // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else!
-                let value = ScalarValue::Null.cast_to(logical_field.data_type())?;
-                return Ok(Transformed::yes(expressions::lit(value)));
-            };
-
-            if logical_field.data_type() == physical_field.data_type() {
-                return Ok(Transformed::no(expr));
-            }
-
-            // If the logical field and physical field are different, we need to cast
-            // the column to the logical field's data type.
-            // We will try later to move the cast to literal values if possible, which is computationally cheaper.
-            if !can_cast_types(logical_field.data_type(), physical_field.data_type()) {
-                return exec_err!(
-                    "Cannot cast column '{}' from '{}' (file data type) to '{}' (table data type)",
-                    column.name(),
-                    logical_field.data_type(),
-                    physical_field.data_type()
-                );
-            }
-            let casted_expr = Arc::new(expressions::CastExpr::new(
-                expr,
-                logical_field.data_type().clone(),
-                None,
-            ));
-            return Ok(Transformed::yes(casted_expr));
-        }
-
-        Ok(Transformed::no(expr))
-    })
-    .data()
+    let rewriter =
+        PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema)
+            .with_partition_columns(partition_fields.to_vec(), partition_values);
+
+    rewriter.rewrite(expr)
 }

diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs
index 6741f94c9545..f74b739d15a4 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -37,6 +37,7 @@ mod partitioning;
 mod physical_expr;
 pub mod planner;
 mod scalar_function;
+pub mod schema_rewriter;
 pub mod statistics;
 pub mod utils;
 pub mod window;
@@ -67,6 +68,7 @@ pub use datafusion_physical_expr_common::sort_expr::{
 pub use planner::{create_physical_expr, create_physical_exprs};
 pub use scalar_function::ScalarFunctionExpr;
+pub use schema_rewriter::PhysicalExprSchemaRewriter;
 pub use utils::{conjunction, conjunction_opt, split_conjunction};
 
 // For backwards compatibility

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
new file mode 100644
index 000000000000..7d51042e134b
--- /dev/null
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -0,0 +1,295 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Physical expression schema rewriting utilities
+
+use std::sync::Arc;
+
+use arrow::compute::can_cast_types;
+use arrow::datatypes::{FieldRef, Schema};
+use datafusion_common::{
+    exec_err,
+    tree_node::{Transformed, TransformedResult, TreeNode},
+    Result, ScalarValue,
+};
+use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
+
+use crate::expressions::{self, CastExpr, Column};
+
+/// Builder for rewriting physical expressions to match different schemas.
+///
+/// # Example
+///
+/// ```rust
+/// use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriter;
+/// use arrow::datatypes::Schema;
+///
+/// # fn example(
+/// #     predicate: std::sync::Arc<dyn datafusion_physical_expr::PhysicalExpr>,
+/// #     physical_file_schema: &Schema,
+/// #     logical_file_schema: &Schema,
+/// # ) -> datafusion_common::Result<()> {
+/// let rewriter = PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema);
+/// let adapted_predicate = rewriter.rewrite(predicate)?;
+/// # Ok(())
+/// # }
+/// ```
+pub struct PhysicalExprSchemaRewriter<'a> {
+    physical_file_schema: &'a Schema,
+    logical_file_schema: &'a Schema,
+    partition_fields: Vec<FieldRef>,
+    partition_values: Vec<ScalarValue>,
+}
+
+impl<'a> PhysicalExprSchemaRewriter<'a> {
+    /// Create a new schema rewriter with the given schemas
+    pub fn new(
+        physical_file_schema: &'a Schema,
+        logical_file_schema: &'a Schema,
+    ) -> Self {
+        Self {
+            physical_file_schema,
+            logical_file_schema,
+            partition_fields: Vec::new(),
+            partition_values: Vec::new(),
+        }
+    }
+
+    /// Add partition columns and their corresponding values
+    ///
+    /// When a column reference matches a partition field, it will be replaced
+    /// with the corresponding literal value from partition_values.
+    pub fn with_partition_columns(
+        mut self,
+        partition_fields: Vec<FieldRef>,
+        partition_values: Vec<ScalarValue>,
+    ) -> Self {
+        self.partition_fields = partition_fields;
+        self.partition_values = partition_values;
+        self
+    }
+
+    /// Rewrite the given physical expression to match the target schema
+    ///
+    /// This method applies the following transformations:
+    /// 1. Replaces partition column references with literal values
+    /// 2. Handles missing columns by inserting null literals
+    /// 3. Casts columns when logical and physical schemas have different types
+    pub fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
+        expr.transform(|expr| self.rewrite_expr(expr)).data()
+    }
+
+    fn rewrite_expr(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
+        if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+            return self.rewrite_column(Arc::clone(&expr), column);
+        }
+
+        Ok(Transformed::no(expr))
+    }
+
+    fn rewrite_column(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+        column: &Column,
+    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
+        // Check if this is a partition column
+        if let Some(partition_value) = self.get_partition_value(column.name()) {
+            return Ok(Transformed::yes(expressions::lit(partition_value)));
+        }
+
+        // Get the logical field for this column
+        let logical_field = match self.logical_file_schema.field_with_name(column.name())
+        {
+            Ok(field) => field,
+            Err(e) => {
+                return Err(e.into());
+            }
+        };
+
+        // Check if the column exists in the physical schema
+        let physical_field =
+            match self.physical_file_schema.field_with_name(column.name()) {
+                Ok(field) => field,
+                Err(_) => {
+                    // Column is missing from physical schema
+                    if !logical_field.is_nullable() {
+                        return exec_err!(
+                            "Non-nullable column '{}' is missing from the physical schema",
+                            column.name()
+                        );
+                    }
+                    // Fill in with null value
+                    let null_value =
+                        ScalarValue::Null.cast_to(logical_field.data_type())?;
+                    return Ok(Transformed::yes(expressions::lit(null_value)));
+                }
+            };
+
+        // Check if casting is needed
+        if logical_field.data_type() == physical_field.data_type() {
+            return Ok(Transformed::no(expr));
+        }
+
+        // Perform type casting
+        if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
+            return exec_err!(
+                "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
+                column.name(),
+                physical_field.data_type(),
+                logical_field.data_type()
+            );
+        }
+
+        let cast_expr =
+            Arc::new(CastExpr::new(expr, logical_field.data_type().clone(), None));
+
+        Ok(Transformed::yes(cast_expr))
+    }
+
+    fn get_partition_value(&self, column_name: &str) -> Option<ScalarValue> {
+        self.partition_fields
+            .iter()
+            .zip(self.partition_values.iter())
+            .find(|(field, _)| field.name() == column_name)
+            .map(|(_, value)| value.clone())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::ScalarValue;
+    use std::sync::Arc;
+
+    fn create_test_schema() -> (Schema, Schema) {
+        let physical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, true),
+        ]);
+
+        let logical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int64, false), // Different type
+            Field::new("b", DataType::Utf8, true),
+            Field::new("c", DataType::Float64, true), // Missing from physical
+        ]);
+
+        (physical_schema, logical_schema)
+    }
+
+    #[test]
+    fn test_rewrite_column_with_type_cast() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("a", 0));
+
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be wrapped in a cast expression
+        assert!(result.as_any().downcast_ref::<CastExpr>().is_some());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_missing_column() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("c", 2));
+
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be replaced with a literal null
+        if let Some(literal) = result.as_any().downcast_ref::<Literal>() {
+            assert_eq!(*literal.value(), ScalarValue::Float64(None));
+        } else {
+            panic!("Expected literal expression");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_partition_column() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let partition_fields =
+            vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))];
+        let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))];
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema)
+            .with_partition_columns(partition_fields, partition_values);
+
+        let column_expr = Arc::new(Column::new("partition_col", 0));
+        let result = rewriter.rewrite(column_expr)?;
+
+        // Should be replaced with the partition value
+        if let Some(literal) = result.as_any().downcast_ref::<Literal>() {
+            assert_eq!(
+                *literal.value(),
+                ScalarValue::Utf8(Some("test_value".to_string()))
+            );
+        } else {
+            panic!("Expected literal expression");
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_rewrite_no_change_needed() -> Result<()> {
+        let (physical_schema, logical_schema) = create_test_schema();
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("b", 1));
+
+        let result = rewriter.rewrite(column_expr.clone())?;
+
+        // Should be the same expression (no transformation needed)
+        // We compare the underlying pointer through the trait object
+        assert!(std::ptr::eq(
+            column_expr.as_ref() as *const dyn PhysicalExpr,
+            result.as_ref() as *const dyn PhysicalExpr
+        ));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_non_nullable_missing_column_error() {
+        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        let logical_schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Utf8, false), // Non-nullable missing column
+        ]);
+
+        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
+        let column_expr = Arc::new(Column::new("b", 1));
+
+        let result = rewriter.rewrite(column_expr);
+        assert!(result.is_err());
+        assert!(result
+            .unwrap_err()
+            .to_string()
+            .contains("Non-nullable column 'b' is missing"));
+    }
+}

From c7b40712baa004b8a225b352cae12b3a92a095fc Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:04:28 -0500
Subject: [PATCH 11/22] cleanup

---
 datafusion/datasource-parquet/src/opener.rs | 36 ++++-----------------
 1 file changed, 7 insertions(+), 29 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 2911adedca05..7b0ad9ab90ce 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -35,9 +35,10 @@ use datafusion_common::pruning::{
     CompositePruningStatistics, PartitionPruningStatistics, PrunableStatistics,
     PruningStatistics,
 };
-use datafusion_common::{exec_err, Result, ScalarValue};
+use datafusion_common::{exec_err, Result};
 use datafusion_datasource::PartitionedFile;
 use datafusion_physical_expr::utils::reassign_predicate_columns;
+use datafusion_physical_expr::PhysicalExprSchemaRewriter;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_optimizer::pruning::PruningPredicate;
@@ -250,13 +251,11 @@ impl FileOpener for ParquetOpener {
         let predicate = predicate
             .map(|p| {
-                cast_expr_to_schema(
-                    p,
-                    &physical_file_schema,
-                    &logical_file_schema,
-                    file.partition_values,
-                    &partition_fields,
-                )
+                let rewriter =
+                    PhysicalExprSchemaRewriter::new(&physical_file_schema, &logical_file_schema)
+                        .with_partition_columns(partition_fields.to_vec(), file.partition_values);
+
+                rewriter.rewrite(p)
                 .map_err(ArrowError::from)
                 .map(|p| reassign_predicate_columns(p, &physical_file_schema, false))
             })
             .transpose()?
             .transpose()?;
@@ -537,27 +536,6 @@ fn should_enable_page_index(
         .unwrap_or(false)
 }
 
-use datafusion_physical_expr::PhysicalExprSchemaRewriter;
-
-/// Given a [`PhysicalExpr`] and a [`SchemaRef`], returns a new [`PhysicalExpr`] that
-/// is cast to the specified data type.
-/// Preference is always given to casting literal values to the data type of the column
-/// since casting the column to the literal value's data type can be significantly more expensive.
-/// Given two columns the cast is applied arbitrarily to the first column.
-pub fn cast_expr_to_schema(
-    expr: Arc<dyn PhysicalExpr>,
-    physical_file_schema: &Schema,
-    logical_file_schema: &Schema,
-    partition_values: Vec<ScalarValue>,
-    partition_fields: &[FieldRef],
-) -> Result<Arc<dyn PhysicalExpr>> {
-    let rewriter =
-        PhysicalExprSchemaRewriter::new(physical_file_schema, logical_file_schema)
-            .with_partition_columns(partition_fields.to_vec(), partition_values);
-
-    rewriter.rewrite(expr)
-}
-
 #[cfg(test)]
 mod test {
     use std::sync::Arc;

From 891a3cc799a5e0462145919b5d627f8e2043f3d1 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:05:26 -0500
Subject: [PATCH 12/22] remove unnecessary reassign

---
 datafusion/datasource-parquet/src/opener.rs | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 7b0ad9ab90ce..b41090305222 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -37,7 +37,6 @@ use datafusion_common::pruning::{
 };
 use datafusion_common::{exec_err, Result};
 use datafusion_datasource::PartitionedFile;
-use datafusion_physical_expr::utils::reassign_predicate_columns;
 use datafusion_physical_expr::PhysicalExprSchemaRewriter;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_optimizer::pruning::PruningPredicate;
@@ -257,9 +256,7 @@ impl FileOpener for ParquetOpener {
                 rewriter.rewrite(p)
                 .map_err(ArrowError::from)
-                .map(|p| reassign_predicate_columns(p, &physical_file_schema, false))
             })
-            .transpose()?
             .transpose()?;

From 7d649c054b2e743cccdb3abf4d61909627f0c231 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:06:08 -0500
Subject: [PATCH 13/22] fmt

---
 datafusion/datasource-parquet/src/opener.rs | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index b41090305222..9ea917b56a0b 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -250,12 +250,16 @@ impl FileOpener for ParquetOpener {
         let predicate = predicate
             .map(|p| {
-                let rewriter =
-                    PhysicalExprSchemaRewriter::new(&physical_file_schema, &logical_file_schema)
-                        .with_partition_columns(partition_fields.to_vec(), file.partition_values);
+                let rewriter = PhysicalExprSchemaRewriter::new(
+                    &physical_file_schema,
+                    &logical_file_schema,
+                )
+                .with_partition_columns(
+                    partition_fields.to_vec(),
+                    file.partition_values,
+                );
 
-                rewriter.rewrite(p).map_err(ArrowError::from)
+                rewriter.rewrite(p).map_err(ArrowError::from)
             })
             .transpose()?;

From 484257d40589f2c807499d4922a6aafa535a4c76 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:10:11 -0500
Subject: [PATCH 14/22] better comments

---
 .../physical-expr/src/schema_rewriter.rs | 23 +++++++++++--------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index 7d51042e134b..ba058587680e 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -109,16 +109,18 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         expr: Arc<dyn PhysicalExpr>,
         column: &Column,
     ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
-        // Check if this is a partition column
-        if let Some(partition_value) = self.get_partition_value(column.name()) {
-            return Ok(Transformed::yes(expressions::lit(partition_value)));
-        }
-
         // Get the logical field for this column
         let logical_field = match self.logical_file_schema.field_with_name(column.name())
         {
             Ok(field) => field,
             Err(e) => {
+                // If the column is a partition field, we can use the partition value
+                if let Some(partition_value) = self.get_partition_value(column.name()) {
+                    return Ok(Transformed::yes(expressions::lit(partition_value)));
+                }
+                // If the column is not found in the logical schema and is not a partition value, return an error
+                // This should probably never be hit unless something upstream broke, but nonetheless it's better
+                // for us to return a handleable error than to panic / do something unexpected.
                 return Err(e.into());
             }
         };
@@ -128,26 +130,27 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         // Check if the column exists in the physical schema
         let physical_field =
             match self.physical_file_schema.field_with_name(column.name()) {
                 Ok(field) => field,
                 Err(_) => {
-                    // Column is missing from physical schema
                     if !logical_field.is_nullable() {
                         return exec_err!(
                             "Non-nullable column '{}' is missing from the physical schema",
                             column.name()
                         );
                     }
-                    // Fill in with null value
+                    // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
+                    // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
+                    // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else!
                     let null_value =
                         ScalarValue::Null.cast_to(logical_field.data_type())?;
                     return Ok(Transformed::yes(expressions::lit(null_value)));
                 }
             };
 
-        // Check if casting is needed
+        // If the logical field and physical field are different, we need to cast
+        // the column to the logical field's data type.
+        // We will try later to move the cast to literal values if possible, which is computationally cheaper.
         if logical_field.data_type() == physical_field.data_type() {
             return Ok(Transformed::no(expr));
         }
-
-        // Perform type casting
         if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
             return exec_err!(
                 "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",

From 893489ba4e187c2129b68a145f9fde5326f47a10 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:10:39 -0500
Subject: [PATCH 15/22] Revert "remove unnecessary reassign"

This reverts commit 4cef38fb7f8909f5492bd989af4a20f60ca7fa62.

---
 datafusion/datasource-parquet/src/opener.rs | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 9ea917b56a0b..01900f853302 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -37,6 +37,7 @@ use datafusion_common::pruning::{
 };
 use datafusion_common::{exec_err, Result};
 use datafusion_datasource::PartitionedFile;
+use datafusion_physical_expr::utils::reassign_predicate_columns;
 use datafusion_physical_expr::PhysicalExprSchemaRewriter;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_optimizer::pruning::PruningPredicate;
@@ -259,8 +260,11 @@ impl FileOpener for ParquetOpener {
                     file.partition_values,
                 );
 
-                rewriter.rewrite(p).map_err(ArrowError::from)
+                rewriter.rewrite(p)
+                .map_err(ArrowError::from)
+                .map(|p| reassign_predicate_columns(p, &physical_file_schema, false))
             })
+            .transpose()?
             .transpose()?;
 
         // Build predicates for this specific file

From cb23671ef2e7df43b20b8790d036f558711d26a7 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:31:21 -0500
Subject: [PATCH 16/22] handle indexes internally

---
 datafusion/datasource-parquet/src/opener.rs | 10 +++----
 .../physical-expr/src/schema_rewriter.rs    | 27 +++++++++++++------
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/datafusion/datasource-parquet/src/opener.rs b/datafusion/datasource-parquet/src/opener.rs
index 01900f853302..0e8e19b69294 100644
--- a/datafusion/datasource-parquet/src/opener.rs
+++ b/datafusion/datasource-parquet/src/opener.rs
@@ -37,7 +37,6 @@ use datafusion_common::pruning::{
 };
 use datafusion_common::{exec_err, Result};
 use datafusion_datasource::PartitionedFile;
-use datafusion_physical_expr::utils::reassign_predicate_columns;
 use datafusion_physical_expr::PhysicalExprSchemaRewriter;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_optimizer::pruning::PruningPredicate;
@@ -251,20 +250,17 @@ impl FileOpener for ParquetOpener {
         let predicate = predicate
             .map(|p| {
-                let rewriter = PhysicalExprSchemaRewriter::new(
+                PhysicalExprSchemaRewriter::new(
                     &physical_file_schema,
                     &logical_file_schema,
                 )
                 .with_partition_columns(
                     partition_fields.to_vec(),
                     file.partition_values,
-                );
-
-                rewriter.rewrite(p)
+                )
+                .rewrite(p)
                 .map_err(ArrowError::from)
-                .map(|p| reassign_predicate_columns(p, &physical_file_schema, false))
             })
-            .transpose()?
             .transpose()?;
 
         // Build predicates for this specific file

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index ba058587680e..9a7949402f86 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -98,7 +98,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         expr: Arc<dyn PhysicalExpr>,
     ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
         if let Some(column) = expr.as_any().downcast_ref::<Column>() {
-            return self.rewrite_column(Arc::clone(&expr), column);
+            return self.rewrite_column(Arc::clone(&expr), column)
         }
 
         Ok(Transformed::no(expr))
@@ -125,26 +125,36 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         };
 
         // Check if the column exists in the physical schema
-        let physical_field =
-            match self.physical_file_schema.field_with_name(column.name()) {
-                Ok(field) => field,
+        let physical_column_index = match self.physical_file_schema.index_of(column.name()) {
+                Ok(index) => index,
                 Err(_) => {
                     if !logical_field.is_nullable() {
                         return exec_err!(
                             "Non-nullable column '{}' is missing from the physical schema",
                             column.name()
                         );
                     }
                     // If the column is missing from the physical schema fill it in with nulls as `SchemaAdapter` would do.
                     // TODO: do we need to sync this with what the `SchemaAdapter` actually does?
                     // While the default implementation fills in nulls in theory a custom `SchemaAdapter` could do something else!
                     let null_value =
                         ScalarValue::Null.cast_to(logical_field.data_type())?;
                     return Ok(Transformed::yes(expressions::lit(null_value)));
                 }
             };
+        let physical_field = self.physical_file_schema.field(physical_column_index);
+
+        let column = match (column.index() == physical_column_index, logical_field.data_type() == physical_field.data_type()) {
+            // If the column index matches and the data types match, we can use the column as is
+            (true, true) => return Ok(Transformed::no(expr)),
+            // If the indexes or data types do not match, we need to create a new column expression
+            (true, _) => column.clone(),
+            (false, _) => Column::new_with_schema(logical_field.name(), self.logical_file_schema)?
+        };
 
-        // If the logical field and physical field are different, we need to cast
-        // the column to the logical field's data type.
-        // We will try later to move the cast to literal values if possible, which is computationally cheaper.
         if logical_field.data_type() == physical_field.data_type() {
-            return Ok(Transformed::no(expr));
+            // If the data types match, we can use the column as is
+            return Ok(Transformed::yes(Arc::new(column)));
         }
+
+        // We need to cast the column to the logical data type
+        // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123`
+        // since that's much cheaper to evaluate.
+        // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928
         if !can_cast_types(physical_field.data_type(), logical_field.data_type()) {
             return exec_err!(
                 "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)",
From 732e326308ca6e3c517aae80255b4e342cdcc407 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:34:01 -0500
Subject: [PATCH 17/22] refactor

---
 datafusion/physical-expr/src/schema_rewriter.rs | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index 9a7949402f86..8924a76860bb 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -98,7 +98,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         expr: Arc<dyn PhysicalExpr>,
     ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
         if let Some(column) = expr.as_any().downcast_ref::<Column>() {
-            return self.rewrite_column(Arc::clone(&expr), column)
+            return self.rewrite_column(Arc::clone(&expr), column);
         }
 
         Ok(Transformed::no(expr))
@@ -126,7 +126,8 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
         };
 
         // Check if the column exists in the physical schema
-        let physical_column_index = match self.physical_file_schema.index_of(column.name()) {
+        let physical_column_index =
+            match self.physical_file_schema.index_of(column.name()) {
                 Ok(index) => index,
                 Err(_) => {
                     if !logical_field.is_nullable() {
@@ -145,13 +146,18 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
                 }
             };
         let physical_field = self.physical_file_schema.field(physical_column_index);
 
-        let column = match (column.index() == physical_column_index, logical_field.data_type() == physical_field.data_type()) {
+        let column = match (
+            column.index() == physical_column_index,
+            logical_field.data_type() == physical_field.data_type(),
+        ) {
             // If the column index matches and the data types match, we can use the column as is
             (true, true) => return Ok(Transformed::no(expr)),
             // If the indexes or data types do not match, we need to create a new column expression
             (true, _) => column.clone(),
-            (false, _) => Column::new_with_schema(logical_field.name(), self.logical_file_schema)?
+            (false, _) => {
+                Column::new_with_schema(logical_field.name(), self.logical_file_schema)?
+            }
         };

From 6815e35319094cde421761395e69dd382d142262 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sat, 21 Jun 2025 11:43:32 -0500
Subject: [PATCH 18/22] fix

---
 datafusion/physical-expr/src/schema_rewriter.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index 8924a76860bb..bccc85b1e5c8 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -155,7 +155,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
             // If the indexes or data types do not match, we need to create a new column expression
             (true, _) => column.clone(),
             (false, _) => {
-                Column::new_with_schema(logical_field.name(), self.logical_file_schema)?
+                Column::new_with_schema(logical_field.name(), self.physical_file_schema)?
             }
         };

From 0de4f9b25fd776c4fa8c8a244d7440bc59a61e7a Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sun, 22 Jun 2025 00:45:12 -0500
Subject: [PATCH 19/22] fix

---
 datafusion/physical-expr/src/schema_rewriter.rs | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index bccc85b1e5c8..53af90862435 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -177,8 +177,11 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
             );
         }
 
-        let cast_expr =
-            Arc::new(CastExpr::new(expr, logical_field.data_type().clone(), None));
+        let cast_expr = Arc::new(CastExpr::new(
+            Arc::new(column),
+            logical_field.data_type().clone(),
+            None,
+        ));
 
         Ok(Transformed::yes(cast_expr))
     }
@@ -280,9 +283,9 @@ mod tests {
         let (physical_schema, logical_schema) = create_test_schema();
 
         let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
-        let column_expr = Arc::new(Column::new("b", 1));
+        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
 
-        let result = rewriter.rewrite(column_expr.clone())?;
+        let result = rewriter.rewrite(Arc::clone(&column_expr))?;
 
         // Should be the same expression (no transformation needed)

From 617d66ebd092ac89921ee832e2a4d636ebd324c1 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Sun, 22 Jun 2025 09:15:51 -0500
Subject: [PATCH 20/22] add cast unwraps

---
 .../physical-expr/src/schema_rewriter.rs | 568 +++++++++++++++++-
 1 file changed, 561 insertions(+), 7 deletions(-)

Physical expression schema rewriting utilities use std::sync::Arc; +use std::cmp::Ordering; use arrow::compute::can_cast_types; -use arrow::datatypes::{FieldRef, Schema}; +use arrow::datatypes::{ + DataType, FieldRef, Schema, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, + MIN_DECIMAL128_FOR_EACH_PRECISION, +}; +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; use datafusion_common::{ exec_err, tree_node::{Transformed, TransformedResult, TreeNode}, Result, ScalarValue, }; +use datafusion_expr::Operator; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use crate::expressions::{self, CastExpr, Column}; +use crate::expressions::{self, BinaryExpr, CastExpr, Column, Literal}; /// Builder for rewriting physical expressions to match different schemas. /// @@ -89,6 +95,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { /// 1. Replaces partition column references with literal values /// 2. Handles missing columns by inserting null literals /// 3. Casts columns when logical and physical schemas have different types + /// 4. Optimizes cast expressions in binary comparisons by unwrapping casts pub fn rewrite(&self, expr: Arc) -> Result> { expr.transform(|expr| self.rewrite_expr(expr)).data() } @@ -97,6 +104,16 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { &self, expr: Arc, ) -> Result>> { + // Check for binary expressions that can be optimized by unwrapping casts FIRST + // before we rewrite the children, since child rewriting might add casts + if let Some(binary_expr) = expr.as_any().downcast_ref::() { + if let Some(optimized) = self.try_unwrap_cast_in_comparison(binary_expr)? { + // Don't recursively transform the optimized expression here since it might + // cause double-casting. Instead just return it and let the parent transform handle it. + return Ok(Transformed::yes(optimized)); + } + } + if let Some(column) = expr.as_any().downcast_ref::() { return self.rewrite_column(Arc::clone(&expr), column); } @@ -165,9 +182,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { } // We need to cast the column to the logical data type - // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` - // since that's much cheaper to evalaute. - // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 + // Note: Binary expressions with casts are optimized separately in try_unwrap_cast_in_comparison + // to move the cast from the column to literal expressions when possible (e.g., col = 123) + // since that's much cheaper to evaluate. if !can_cast_types(physical_field.data_type(), logical_field.data_type()) { return exec_err!( "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", @@ -193,6 +210,414 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { .find(|(field, _)| field.name() == column_name) .map(|(_, value)| value.clone()) } + + /// Attempt to optimize cast expressions in binary comparisons by unwrapping the cast + /// and applying it to the literal instead. + /// + /// For example: `cast(column as INT64) = 123i64` becomes `column = 123i32` + /// This is much more efficient as the cast is applied once to the literal rather + /// than to every row value. 
+    fn try_unwrap_cast_in_comparison(
+        &self,
+        binary_expr: &BinaryExpr,
+    ) -> Result<Option<Arc<dyn PhysicalExpr>>> {
+        let op = binary_expr.op();
+        let left = binary_expr.left();
+        let right = binary_expr.right();
+
+        // Check if left side is a cast and right side is a literal
+        if let (Some(cast_expr), Some(literal)) = (
+            left.as_any().downcast_ref::<CastExpr>(),
+            right.as_any().downcast_ref::<Literal>(),
+        ) {
+            if let Some(optimized) = self.unwrap_cast_with_literal(cast_expr, literal, *op)? {
+                return Ok(Some(Arc::new(BinaryExpr::new(
+                    optimized.0,
+                    *op,
+                    optimized.1,
+                ))));
+            }
+        }
+
+        // Check if right side is a cast and left side is a literal
+        if let (Some(literal), Some(cast_expr)) = (
+            left.as_any().downcast_ref::<Literal>(),
+            right.as_any().downcast_ref::<CastExpr>(),
+        ) {
+            if let Some(optimized) = self.unwrap_cast_with_literal(cast_expr, literal, *op)? {
+                return Ok(Some(Arc::new(BinaryExpr::new(
+                    optimized.1,
+                    *op,
+                    optimized.0,
+                ))));
+            }
+        }
+
+        Ok(None)
+    }
+
+    /// Unwrap a cast expression when used with a literal in a comparison
+    fn unwrap_cast_with_literal(
+        &self,
+        cast_expr: &CastExpr,
+        literal: &Literal,
+        op: Operator,
+    ) -> Result<Option<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>> {
+        // Get the inner expression (what's being cast)
+        let inner_expr = cast_expr.expr();
+
+        // Handle the case where inner expression might be another cast (due to schema rewriting)
+        // This can happen when the schema rewriter adds a cast to a column, and then we have
+        // an original cast on top of that.
+        let (final_inner_expr, column) = if let Some(inner_cast) = inner_expr.as_any().downcast_ref::<CastExpr>() {
+            // We have a nested cast, check if the inner cast's expression is a column
+            let inner_inner_expr = inner_cast.expr();
+            if let Some(col) = inner_inner_expr.as_any().downcast_ref::<Column>() {
+                (inner_inner_expr, col)
+            } else {
+                return Ok(None);
+            }
+        } else if let Some(col) = inner_expr.as_any().downcast_ref::<Column>() {
+            (inner_expr, col)
+        } else {
+            return Ok(None);
+        };
+
+        // Get the column's data type from the physical schema
+        let column_data_type = match self.physical_file_schema.field_with_name(column.name()) {
+            Ok(field) => field.data_type(),
+            Err(_) => return Ok(None), // Column not found, can't optimize
+        };
+
+        // Try to cast the literal to the column's data type
+        if let Some(casted_literal) = try_cast_literal_to_type(literal.value(), column_data_type, op) {
+            return Ok(Some((
+                Arc::clone(final_inner_expr),
+                expressions::lit(casted_literal),
+            )));
+        }
+
+        Ok(None)
+    }
+}
+
+/// Try to cast a literal value to a target type, considering the comparison operator
+/// This is adapted from the logical layer unwrap_cast functionality
+fn try_cast_literal_to_type(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+    op: Operator,
+) -> Option<ScalarValue> {
+    // First try operator-specific casting (e.g., string to int for equality)
+    if let Some(result) = cast_literal_to_type_with_op(lit_value, target_type, op) {
+        return Some(result);
+    }
+
+    // Fall back to general casting
+    try_cast_literal_to_type_general(lit_value, target_type)
+}
+
+/// Cast literal with operator-specific logic
+fn cast_literal_to_type_with_op(
+    lit_value: &ScalarValue,
+    target_type: &DataType,
+    op: Operator,
+) -> Option<ScalarValue> {
+
+    match (op, lit_value) {
+        (
+            Operator::Eq | Operator::NotEq,
+            ScalarValue::Utf8(Some(_))
+            | ScalarValue::Utf8View(Some(_))
+            | ScalarValue::LargeUtf8(Some(_)),
+        ) => {
+            // Only try for integer types
+            use DataType::*;
+            if matches!(
+                target_type,
+                Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
+            ) {
+                let casted = lit_value.cast_to(target_type).ok()?;
+                let
round_tripped = casted.cast_to(&lit_value.data_type()).ok()?; + if lit_value != &round_tripped { + return None; + } + Some(casted) + } else { + None + } + } + _ => None, + } +} + +/// General literal casting logic adapted from the logical layer +fn try_cast_literal_to_type_general( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_data_type = lit_value.data_type(); + if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { + return None; + } + if lit_value.is_null() { + // null value can be cast to any type of null value + return ScalarValue::try_from(target_type).ok(); + } + try_cast_numeric_literal(lit_value, target_type) + .or_else(|| try_cast_string_literal(lit_value, target_type)) + .or_else(|| try_cast_dictionary(lit_value, target_type)) + .or_else(|| try_cast_binary(lit_value, target_type)) +} + +/// Returns true if unwrap_cast_in_comparison supports this data type +fn is_supported_type(data_type: &DataType) -> bool { + is_supported_numeric_type(data_type) + || is_supported_string_type(data_type) + || is_supported_dictionary_type(data_type) + || is_supported_binary_type(data_type) +} + +/// Returns true if unwrap_cast_in_comparison support this numeric type +fn is_supported_numeric_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Decimal128(_, _) + | DataType::Timestamp(_, _) + ) +} + +/// Returns true if unwrap_cast_in_comparison supports casting this value as a string +fn is_supported_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + +/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary +fn is_supported_dictionary_type(data_type: &DataType) -> bool { + matches!(data_type, + DataType::Dictionary(_, inner) if is_supported_type(inner)) +} + +fn is_supported_binary_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) +} + +/// Convert a numeric value from one numeric data type to another +fn try_cast_numeric_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_data_type = lit_value.data_type(); + if !is_supported_numeric_type(&lit_data_type) + || !is_supported_numeric_type(target_type) + { + return None; + } + + let mul = match target_type { + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 => 1_i128, + DataType::Timestamp(_, _) => 1_i128, + DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + _ => return None, + }; + let (target_min, target_max) = match target_type { + DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), + DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), + DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), + DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), + DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), + DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), + DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), + DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), + DataType::Decimal128(precision, _) => ( + MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], + 
MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], + ), + _ => return None, + }; + let lit_value_target_type = match lit_value { + ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::Decimal128(Some(v), _, scale) => { + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + (*v).checked_mul(mul / lit_scale_mul) + } else if (*v) % (lit_scale_mul / mul) == 0 { + Some(*v / (lit_scale_mul / mul)) + } else { + None + } + } + _ => None, + }; + + match lit_value_target_type { + None => None, + Some(value) => { + if value >= target_min && value <= target_max { + let result_scalar = match target_type { + DataType::Int8 => ScalarValue::Int8(Some(value as i8)), + DataType::Int16 => ScalarValue::Int16(Some(value as i16)), + DataType::Int32 => ScalarValue::Int32(Some(value as i32)), + DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), + DataType::Timestamp(TimeUnit::Second, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Second, tz.clone()), + value, + ); + ScalarValue::TimestampSecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + value, + ); + ScalarValue::TimestampMillisecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + value, + ); + ScalarValue::TimestampMicrosecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + value, + ); + ScalarValue::TimestampNanosecond(value, tz.clone()) + } + DataType::Decimal128(p, s) => { + ScalarValue::Decimal128(Some(value), *p, *s) + } + _ => { + return None; + } + }; + Some(result_scalar) + } else { + None + } + } + } +} + +fn try_cast_string_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let string_value = lit_value.try_as_str()?.map(|s| s.to_string()); + let scalar_value = match target_type { + DataType::Utf8 => ScalarValue::Utf8(string_value), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), + DataType::Utf8View => ScalarValue::Utf8View(string_value), + _ => return None, + }; + Some(scalar_value) +} + +/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the 
dictionary +fn try_cast_dictionary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_value_type = lit_value.data_type(); + let result_scalar = match (lit_value, target_type) { + // Unwrap dictionary when inner type matches target type + (ScalarValue::Dictionary(_, inner_value), _) + if inner_value.data_type() == *target_type => + { + (**inner_value).clone() + } + // Wrap type when target type is dictionary + (_, DataType::Dictionary(index_type, inner_type)) + if **inner_type == lit_value_type => + { + ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) + } + _ => { + return None; + } + }; + Some(result_scalar) +} + +/// Cast a timestamp value from one unit to another +fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { + let value = value as i64; + let from_scale = match from { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + let to_scale = match to { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + match from_scale.cmp(&to_scale) { + Ordering::Less => value.checked_mul(to_scale / from_scale), + Ordering::Greater => Some(value / (from_scale / to_scale)), + Ordering::Equal => Some(value), + } +} + +fn try_cast_binary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + match (lit_value, target_type) { + (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n)) + if v.len() == *n as usize => + { + Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone()))) + } + _ => None, + } } #[cfg(test)] @@ -200,6 +625,7 @@ mod tests { use super::*; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; + use datafusion_expr::Operator; use std::sync::Arc; fn create_test_schema() -> (Schema, Schema) { @@ -242,7 +668,7 @@ mod tests { let result = rewriter.rewrite(column_expr)?; // Should be replaced with a literal null - if let Some(literal) = result.as_any().downcast_ref::() { + if let Some(literal) = result.as_any().downcast_ref::() { assert_eq!(*literal.value(), ScalarValue::Float64(None)); } else { panic!("Expected literal expression"); @@ -266,7 +692,7 @@ mod tests { let result = rewriter.rewrite(column_expr)?; // Should be replaced with the partition value - if let Some(literal) = result.as_any().downcast_ref::() { + if let Some(literal) = result.as_any().downcast_ref::() { assert_eq!( *literal.value(), ScalarValue::Utf8(Some("test_value".to_string())) @@ -315,4 +741,132 @@ mod tests { .to_string() .contains("Non-nullable column 'b' is missing")); } + + #[test] + fn test_unwrap_cast_optimization() -> Result<()> { + // Test case: cast(int32_column as int64) = 123i64 should become int32_column = 123i32 + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Create: cast(column("a") as Int64) = 123i64 + let column_expr = Arc::new(Column::new("a", 0)); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, 
None)); + let literal_expr = expressions::lit(ScalarValue::Int64(Some(123))); + let binary_expr = Arc::new(BinaryExpr::new( + cast_expr, + Operator::Eq, + literal_expr, + )); + + let result = rewriter.rewrite(binary_expr.clone() as Arc)?; + + // The result should be a binary expression with the cast unwrapped + let result_binary = result.as_any().downcast_ref::().unwrap(); + + // Left side should be the original column (no cast) + assert!(result_binary.left().as_any().downcast_ref::().is_some()); + + // Right side should be a literal with the value cast to Int32 + let right_literal = result_binary.right().as_any().downcast_ref::().unwrap(); + assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123))); + + Ok(()) + } + + #[test] + fn test_unwrap_cast_optimization_reverse_order() -> Result<()> { + // Test case: 123i64 = cast(int32_column as int64) should become 123i32 = int32_column + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Create: 123i64 = cast(column("a") as Int64) + let literal_expr = expressions::lit(ScalarValue::Int64(Some(123))); + let column_expr = Arc::new(Column::new("a", 0)); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let binary_expr = Arc::new(BinaryExpr::new( + literal_expr, + Operator::Eq, + cast_expr, + )); + + let result = rewriter.rewrite(binary_expr)?; + + // The result should be a binary expression with the cast unwrapped + let result_binary = result.as_any().downcast_ref::().unwrap(); + + // Left side should be a literal with the value cast to Int32 + let left_literal = result_binary.left().as_any().downcast_ref::().unwrap(); + assert_eq!(*left_literal.value(), ScalarValue::Int32(Some(123))); + + // Right side should be the original column (no cast) + assert!(result_binary.right().as_any().downcast_ref::().is_some()); + + Ok(()) + } + + #[test] + fn test_unwrap_cast_optimization_string_to_int() -> Result<()> { + // Test case: cast(int32_column as Utf8) = "123" should become int32_column = 123i32 + let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let logical_schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Create: cast(column("a") as Utf8) = "123" + let column_expr = Arc::new(Column::new("a", 0)); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Utf8, None)); + let literal_expr = expressions::lit(ScalarValue::Utf8(Some("123".to_string()))); + let binary_expr = Arc::new(BinaryExpr::new( + cast_expr, + Operator::Eq, + literal_expr, + )); + + let result = rewriter.rewrite(binary_expr)?; + + // The result should be a binary expression with the cast unwrapped + let result_binary = result.as_any().downcast_ref::().unwrap(); + + // Left side should be the original column (no cast) + assert!(result_binary.left().as_any().downcast_ref::().is_some()); + + // Right side should be a literal with the value cast to Int32 + let right_literal = result_binary.right().as_any().downcast_ref::().unwrap(); + assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123))); + + Ok(()) + } + + #[test] + fn test_no_unwrap_cast_optimization_when_not_applicable() -> Result<()> { + // Test case where optimization should not apply - unsupported cast + let physical_schema = 
Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + + let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); + + // Create: cast(column("a") as Int64) = 123i64 + // Float32 to Int64 casting might not be optimizable due to precision + let column_expr = Arc::new(Column::new("a", 0)); + let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); + let literal_expr = expressions::lit(ScalarValue::Int64(Some(123))); + let binary_expr = Arc::new(BinaryExpr::new( + cast_expr, + Operator::Eq, + literal_expr, + )); + + let result = rewriter.rewrite(binary_expr)?; + + // The result should still be a binary expression with a cast on the left side + // since Float32 is not in our supported types for unwrap cast optimization + let result_binary = result.as_any().downcast_ref::().unwrap(); + assert!(result_binary.left().as_any().downcast_ref::().is_some()); + + Ok(()) + } } From 7a1bb68f38a7cf18368fde320c95c1799b48e135 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 22 Jun 2025 09:18:00 -0500 Subject: [PATCH 21/22] fmt --- .../physical-expr/src/schema_rewriter.rs | 134 +++++++++++------- 1 file changed, 79 insertions(+), 55 deletions(-) diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs index 4feb18405bee..cb627293d4cb 100644 --- a/datafusion/physical-expr/src/schema_rewriter.rs +++ b/datafusion/physical-expr/src/schema_rewriter.rs @@ -17,8 +17,8 @@ //! Physical expression schema rewriting utilities -use std::sync::Arc; use std::cmp::Ordering; +use std::sync::Arc; use arrow::compute::can_cast_types; use arrow::datatypes::{ @@ -230,7 +230,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { left.as_any().downcast_ref::(), right.as_any().downcast_ref::(), ) { - if let Some(optimized) = self.unwrap_cast_with_literal(cast_expr, literal, *op)? { + if let Some(optimized) = + self.unwrap_cast_with_literal(cast_expr, literal, *op)? + { return Ok(Some(Arc::new(BinaryExpr::new( optimized.0, *op, @@ -244,7 +246,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { left.as_any().downcast_ref::(), right.as_any().downcast_ref::(), ) { - if let Some(optimized) = self.unwrap_cast_with_literal(cast_expr, literal, *op)? { + if let Some(optimized) = + self.unwrap_cast_with_literal(cast_expr, literal, *op)? + { return Ok(Some(Arc::new(BinaryExpr::new( optimized.1, *op, @@ -265,32 +269,36 @@ impl<'a> PhysicalExprSchemaRewriter<'a> { ) -> Result, Arc)>> { // Get the inner expression (what's being cast) let inner_expr = cast_expr.expr(); - + // Handle the case where inner expression might be another cast (due to schema rewriting) // This can happen when the schema rewriter adds a cast to a column, and then we have // an original cast on top of that. 
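+        // Example (hypothetical expression): `CAST(CAST(a@0 AS Int64) AS Int64) = 123`,
+        // where the inner cast was injected by `rewrite_column` and the outer cast
+        // came from the original predicate; both layers unwrap back to the column `a`.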
- let (final_inner_expr, column) = if let Some(inner_cast) = inner_expr.as_any().downcast_ref::() { - // We have a nested cast, check if the inner cast's expression is a column - let inner_inner_expr = inner_cast.expr(); - if let Some(col) = inner_inner_expr.as_any().downcast_ref::() { - (inner_inner_expr, col) + let (final_inner_expr, column) = + if let Some(inner_cast) = inner_expr.as_any().downcast_ref::() { + // We have a nested cast, check if the inner cast's expression is a column + let inner_inner_expr = inner_cast.expr(); + if let Some(col) = inner_inner_expr.as_any().downcast_ref::() { + (inner_inner_expr, col) + } else { + return Ok(None); + } + } else if let Some(col) = inner_expr.as_any().downcast_ref::() { + (inner_expr, col) } else { return Ok(None); - } - } else if let Some(col) = inner_expr.as_any().downcast_ref::() { - (inner_expr, col) - } else { - return Ok(None); - }; + }; // Get the column's data type from the physical schema - let column_data_type = match self.physical_file_schema.field_with_name(column.name()) { - Ok(field) => field.data_type(), - Err(_) => return Ok(None), // Column not found, can't optimize - }; + let column_data_type = + match self.physical_file_schema.field_with_name(column.name()) { + Ok(field) => field.data_type(), + Err(_) => return Ok(None), // Column not found, can't optimize + }; // Try to cast the literal to the column's data type - if let Some(casted_literal) = try_cast_literal_to_type(literal.value(), column_data_type, op) { + if let Some(casted_literal) = + try_cast_literal_to_type(literal.value(), column_data_type, op) + { return Ok(Some(( Arc::clone(final_inner_expr), expressions::lit(casted_literal), @@ -323,7 +331,6 @@ fn cast_literal_to_type_with_op( target_type: &DataType, op: Operator, ) -> Option { - match (op, lit_value) { ( Operator::Eq | Operator::NotEq, @@ -754,22 +761,27 @@ mod tests { let column_expr = Arc::new(Column::new("a", 0)); let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); let literal_expr = expressions::lit(ScalarValue::Int64(Some(123))); - let binary_expr = Arc::new(BinaryExpr::new( - cast_expr, - Operator::Eq, - literal_expr, - )); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr)); let result = rewriter.rewrite(binary_expr.clone() as Arc)?; // The result should be a binary expression with the cast unwrapped let result_binary = result.as_any().downcast_ref::().unwrap(); - + // Left side should be the original column (no cast) - assert!(result_binary.left().as_any().downcast_ref::().is_some()); - + assert!(result_binary + .left() + .as_any() + .downcast_ref::() + .is_some()); + // Right side should be a literal with the value cast to Int32 - let right_literal = result_binary.right().as_any().downcast_ref::().unwrap(); + let right_literal = result_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123))); Ok(()) @@ -787,23 +799,28 @@ mod tests { let literal_expr = expressions::lit(ScalarValue::Int64(Some(123))); let column_expr = Arc::new(Column::new("a", 0)); let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); - let binary_expr = Arc::new(BinaryExpr::new( - literal_expr, - Operator::Eq, - cast_expr, - )); + let binary_expr = + Arc::new(BinaryExpr::new(literal_expr, Operator::Eq, cast_expr)); let result = rewriter.rewrite(binary_expr)?; // The result should be a binary expression with the cast unwrapped let result_binary = 
result.as_any().downcast_ref::().unwrap(); - + // Left side should be a literal with the value cast to Int32 - let left_literal = result_binary.left().as_any().downcast_ref::().unwrap(); + let left_literal = result_binary + .left() + .as_any() + .downcast_ref::() + .unwrap(); assert_eq!(*left_literal.value(), ScalarValue::Int32(Some(123))); - + // Right side should be the original column (no cast) - assert!(result_binary.right().as_any().downcast_ref::().is_some()); + assert!(result_binary + .right() + .as_any() + .downcast_ref::() + .is_some()); Ok(()) } @@ -820,22 +837,27 @@ mod tests { let column_expr = Arc::new(Column::new("a", 0)); let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Utf8, None)); let literal_expr = expressions::lit(ScalarValue::Utf8(Some("123".to_string()))); - let binary_expr = Arc::new(BinaryExpr::new( - cast_expr, - Operator::Eq, - literal_expr, - )); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr)); let result = rewriter.rewrite(binary_expr)?; // The result should be a binary expression with the cast unwrapped let result_binary = result.as_any().downcast_ref::().unwrap(); - + // Left side should be the original column (no cast) - assert!(result_binary.left().as_any().downcast_ref::().is_some()); - + assert!(result_binary + .left() + .as_any() + .downcast_ref::() + .is_some()); + // Right side should be a literal with the value cast to Int32 - let right_literal = result_binary.right().as_any().downcast_ref::().unwrap(); + let right_literal = result_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123))); Ok(()) @@ -844,7 +866,8 @@ mod tests { #[test] fn test_no_unwrap_cast_optimization_when_not_applicable() -> Result<()> { // Test case where optimization should not apply - unsupported cast - let physical_schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + let physical_schema = + Schema::new(vec![Field::new("a", DataType::Float32, false)]); let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema); @@ -854,18 +877,19 @@ mod tests { let column_expr = Arc::new(Column::new("a", 0)); let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None)); let literal_expr = expressions::lit(ScalarValue::Int64(Some(123))); - let binary_expr = Arc::new(BinaryExpr::new( - cast_expr, - Operator::Eq, - literal_expr, - )); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr)); let result = rewriter.rewrite(binary_expr)?; // The result should still be a binary expression with a cast on the left side // since Float32 is not in our supported types for unwrap cast optimization let result_binary = result.as_any().downcast_ref::().unwrap(); - assert!(result_binary.left().as_any().downcast_ref::().is_some()); + assert!(result_binary + .left() + .as_any() + .downcast_ref::() + .is_some()); Ok(()) } From 4e2af3eb81925e2277ff6702a0ac1fc31126295e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 23 Jun 2025 12:03:10 -0500 Subject: [PATCH 22/22] move to shared impl --- datafusion/common/src/lib.rs | 1 + datafusion/common/src/scalar_literal_cast.rs | 315 +++++++++++ .../src/simplify_expressions/unwrap_cast.rs | 374 ++----------- .../physical-expr/src/schema_rewriter.rs | 507 ++---------------- 4 files changed, 397 insertions(+), 800 deletions(-) 
create mode 100644 datafusion/common/src/scalar_literal_cast.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index d89e08c7d4a6..a517be2552a1 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -51,6 +51,7 @@ pub mod parsers; pub mod pruning; pub mod rounding; pub mod scalar; +pub mod scalar_literal_cast; pub mod spans; pub mod stats; pub mod test_util; diff --git a/datafusion/common/src/scalar_literal_cast.rs b/datafusion/common/src/scalar_literal_cast.rs new file mode 100644 index 000000000000..25145e46a95a --- /dev/null +++ b/datafusion/common/src/scalar_literal_cast.rs @@ -0,0 +1,315 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for casting scalar literals to different data types +//! +//! This module contains functions for casting ScalarValue literals +//! to different data types, originally extracted from the optimizer's +//! unwrap_cast module to be shared between logical and physical layers. + +use std::cmp::Ordering; + +use arrow::datatypes::{ + DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, + MIN_DECIMAL128_FOR_EACH_PRECISION, +}; +use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; + +use crate::ScalarValue; + +/// Convert a literal value from one data type to another +pub fn try_cast_literal_to_type( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_data_type = lit_value.data_type(); + if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { + return None; + } + if lit_value.is_null() { + // null value can be cast to any type of null value + return ScalarValue::try_from(target_type).ok(); + } + try_cast_numeric_literal(lit_value, target_type) + .or_else(|| try_cast_string_literal(lit_value, target_type)) + .or_else(|| try_cast_dictionary(lit_value, target_type)) + .or_else(|| try_cast_binary(lit_value, target_type)) +} + +/// Returns true if unwrap_cast_in_comparison supports this data type +pub fn is_supported_type(data_type: &DataType) -> bool { + is_supported_numeric_type(data_type) + || is_supported_string_type(data_type) + || is_supported_dictionary_type(data_type) + || is_supported_binary_type(data_type) +} + +/// Returns true if unwrap_cast_in_comparison support this numeric type +pub fn is_supported_numeric_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Decimal128(_, _) + | DataType::Timestamp(_, _) + ) +} + +/// Returns true if unwrap_cast_in_comparison supports casting this value as a string +pub fn is_supported_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + 
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + +/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary +pub fn is_supported_dictionary_type(data_type: &DataType) -> bool { + matches!(data_type, + DataType::Dictionary(_, inner) if is_supported_type(inner)) +} + +pub fn is_supported_binary_type(data_type: &DataType) -> bool { + matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) +} + +/// Convert a numeric value from one numeric data type to another +pub fn try_cast_numeric_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_data_type = lit_value.data_type(); + if !is_supported_numeric_type(&lit_data_type) + || !is_supported_numeric_type(target_type) + { + return None; + } + + let mul = match target_type { + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 => 1_i128, + DataType::Timestamp(_, _) => 1_i128, + DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), + _ => return None, + }; + let (target_min, target_max) = match target_type { + DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), + DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), + DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), + DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), + DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), + DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), + DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), + DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), + DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), + DataType::Decimal128(precision, _) => ( + // Different precision for decimal128 can store different range of value. 
+ // For example, the precision is 3, the max of value is `999` and the min + // value is `-999` + MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], + MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], + ), + _ => return None, + }; + let lit_value_target_type = match lit_value { + ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), + ScalarValue::Decimal128(Some(v), _, scale) => { + let lit_scale_mul = 10_i128.pow(*scale as u32); + if mul >= lit_scale_mul { + // Example: + // lit is decimal(123,3,2) + // target type is decimal(5,3) + // the lit can be converted to the decimal(1230,5,3) + (*v).checked_mul(mul / lit_scale_mul) + } else if (*v) % (lit_scale_mul / mul) == 0 { + // Example: + // lit is decimal(123000,10,3) + // target type is int32: the lit can be converted to INT32(123) + // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) + Some(*v / (lit_scale_mul / mul)) + } else { + // can't convert the lit decimal to the target data type + None + } + } + _ => None, + }; + + match lit_value_target_type { + None => None, + Some(value) => { + if value >= target_min && value <= target_max { + // the value casted from lit to the target type is in the range of target type. 
+ // return the target type of scalar value + let result_scalar = match target_type { + DataType::Int8 => ScalarValue::Int8(Some(value as i8)), + DataType::Int16 => ScalarValue::Int16(Some(value as i16)), + DataType::Int32 => ScalarValue::Int32(Some(value as i32)), + DataType::Int64 => ScalarValue::Int64(Some(value as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), + DataType::Timestamp(TimeUnit::Second, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Second, tz.clone()), + value, + ); + ScalarValue::TimestampSecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + value, + ); + ScalarValue::TimestampMillisecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + value, + ); + ScalarValue::TimestampMicrosecond(value, tz.clone()) + } + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + let value = cast_between_timestamp( + &lit_data_type, + &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + value, + ); + ScalarValue::TimestampNanosecond(value, tz.clone()) + } + DataType::Decimal128(p, s) => { + ScalarValue::Decimal128(Some(value), *p, *s) + } + _ => { + return None; + } + }; + Some(result_scalar) + } else { + None + } + } + } +} + +pub fn try_cast_string_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let string_value = lit_value.try_as_str()?.map(|s| s.to_string()); + let scalar_value = match target_type { + DataType::Utf8 => ScalarValue::Utf8(string_value), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), + DataType::Utf8View => ScalarValue::Utf8View(string_value), + _ => return None, + }; + Some(scalar_value) +} + +/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary +pub fn try_cast_dictionary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_value_type = lit_value.data_type(); + let result_scalar = match (lit_value, target_type) { + // Unwrap dictionary when inner type matches target type + (ScalarValue::Dictionary(_, inner_value), _) + if inner_value.data_type() == *target_type => + { + (**inner_value).clone() + } + // Wrap type when target type is dictionary + (_, DataType::Dictionary(index_type, inner_type)) + if **inner_type == lit_value_type => + { + ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) + } + _ => { + return None; + } + }; + Some(result_scalar) +} + +/// Cast a timestamp value from one unit to another +pub fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { + let value = value as i64; + let from_scale = match from { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + let to_scale = match to { + DataType::Timestamp(TimeUnit::Second, _) => 1, + DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, + DataType::Timestamp(TimeUnit::Microsecond, 
_) => MICROSECONDS, + DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, + _ => return Some(value), + }; + + match from_scale.cmp(&to_scale) { + Ordering::Less => value.checked_mul(to_scale / from_scale), + Ordering::Greater => Some(value / (from_scale / to_scale)), + Ordering::Equal => Some(value), + } +} + +pub fn try_cast_binary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + match (lit_value, target_type) { + (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n)) + if v.len() == *n as usize => + { + Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone()))) + } + _ => None, + } +} \ No newline at end of file diff --git a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs index 7c8ff8305e84..d9737b03b345 100644 --- a/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs +++ b/datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs @@ -55,18 +55,50 @@ //! ``` //! -use std::cmp::Ordering; - -use arrow::datatypes::{ - DataType, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION, - MIN_DECIMAL128_FOR_EACH_PRECISION, -}; -use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; +use arrow::datatypes::{DataType, TimeUnit}; +use datafusion_common::scalar_literal_cast::is_supported_type; use datafusion_common::{internal_err, tree_node::Transformed}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{lit, BinaryExpr}; use datafusion_expr::{simplify::SimplifyInfo, Cast, Expr, Operator, TryCast}; +// Re-export the shared function for backward compatibility +pub(super) use datafusion_common::scalar_literal_cast::try_cast_literal_to_type; + +/// Cast literal with operator-specific logic for comparisons +fn cast_literal_to_type_with_op( + lit_value: &ScalarValue, + target_type: &DataType, + op: Operator, +) -> Option { + match (op, lit_value) { + ( + Operator::Eq | Operator::NotEq, + ScalarValue::Utf8(Some(_)) + | ScalarValue::Utf8View(Some(_)) + | ScalarValue::LargeUtf8(Some(_)), + ) => { + // Only try for integer types (TODO can we do this for other types + // like timestamps)? 
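+            // Round-trip guard (illustrative values): '123' casts to Int32(123)
+            // and back to '123', so the cast can be unwrapped safely; '0123' also
+            // casts to Int32(123) but round-trips to '123' != '0123', so the cast
+            // is left in place.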
+ use DataType::*; + if matches!( + target_type, + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + ) { + let casted = lit_value.cast_to(target_type).ok()?; + let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?; + if lit_value != &round_tripped { + return None; + } + Some(casted) + } else { + None + } + } + _ => None, + } +} + pub(super) fn unwrap_cast_in_comparison_for_binary( info: &S, cast_expr: Expr, @@ -192,334 +224,6 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist< true } -/// Returns true if unwrap_cast_in_comparison supports this data type -fn is_supported_type(data_type: &DataType) -> bool { - is_supported_numeric_type(data_type) - || is_supported_string_type(data_type) - || is_supported_dictionary_type(data_type) - || is_supported_binary_type(data_type) -} - -/// Returns true if unwrap_cast_in_comparison support this numeric type -fn is_supported_numeric_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - | DataType::Timestamp(_, _) - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a string -fn is_supported_string_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary -fn is_supported_dictionary_type(data_type: &DataType) -> bool { - matches!(data_type, - DataType::Dictionary(_, inner) if is_supported_type(inner)) -} - -fn is_supported_binary_type(data_type: &DataType) -> bool { - matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) -} - -///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./ -/// -/// Specifically, rewrites -/// ```sql -/// cast(col) -/// ``` -/// -/// To -/// -/// ```sql -/// col cast() -/// col -/// ``` -fn cast_literal_to_type_with_op( - lit_value: &ScalarValue, - target_type: &DataType, - op: Operator, -) -> Option { - match (op, lit_value) { - ( - Operator::Eq | Operator::NotEq, - ScalarValue::Utf8(Some(_)) - | ScalarValue::Utf8View(Some(_)) - | ScalarValue::LargeUtf8(Some(_)), - ) => { - // Only try for integer types (TODO can we do this for other types - // like timestamps)? 
- use DataType::*; - if matches!( - target_type, - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 - ) { - let casted = lit_value.cast_to(target_type).ok()?; - let round_tripped = casted.cast_to(&lit_value.data_type()).ok()?; - if lit_value != &round_tripped { - return None; - } - Some(casted) - } else { - None - } - } - _ => None, - } -} - -/// Convert a literal value from one data type to another -pub(super) fn try_cast_literal_to_type( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { - return None; - } - if lit_value.is_null() { - // null value can be cast to any type of null value - return ScalarValue::try_from(target_type).ok(); - } - try_cast_numeric_literal(lit_value, target_type) - .or_else(|| try_cast_string_literal(lit_value, target_type)) - .or_else(|| try_cast_dictionary(lit_value, target_type)) - .or_else(|| try_cast_binary(lit_value, target_type)) -} - -/// Convert a numeric value from one numeric data type to another -fn try_cast_numeric_literal( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_numeric_type(&lit_data_type) - || !is_supported_numeric_type(target_type) - { - return None; - } - - let mul = match target_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => 1_i128, - DataType::Timestamp(_, _) => 1_i128, - DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), - _ => return None, - }; - let (target_min, target_max) = match target_type { - DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), - DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), - DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), - DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), - DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), - DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), - DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), - DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), - DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), - DataType::Decimal128(precision, _) => ( - // Different precision for decimal128 can store different range of value. 
- // For example, the precision is 3, the max of value is `999` and the min - // value is `-999` - MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - ), - _ => return None, - }; - let lit_value_target_type = match lit_value { - ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul), - ScalarValue::Decimal128(Some(v), _, scale) => { - let lit_scale_mul = 10_i128.pow(*scale as u32); - if mul >= lit_scale_mul { - // Example: - // lit is decimal(123,3,2) - // target type is decimal(5,3) - // the lit can be converted to the decimal(1230,5,3) - (*v).checked_mul(mul / lit_scale_mul) - } else if (*v) % (lit_scale_mul / mul) == 0 { - // Example: - // lit is decimal(123000,10,3) - // target type is int32: the lit can be converted to INT32(123) - // target type is decimal(10,2): the lit can be converted to decimal(12300,10,2) - Some(*v / (lit_scale_mul / mul)) - } else { - // can't convert the lit decimal to the target data type - None - } - } - _ => None, - }; - - match lit_value_target_type { - None => None, - Some(value) => { - if value >= target_min && value <= target_max { - // the value casted from lit to the target type is in the range of target type. 
- // return the target type of scalar value - let result_scalar = match target_type { - DataType::Int8 => ScalarValue::Int8(Some(value as i8)), - DataType::Int16 => ScalarValue::Int16(Some(value as i16)), - DataType::Int32 => ScalarValue::Int32(Some(value as i32)), - DataType::Int64 => ScalarValue::Int64(Some(value as i64)), - DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)), - DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)), - DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)), - DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), - DataType::Timestamp(TimeUnit::Second, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Second, tz.clone()), - value, - ); - ScalarValue::TimestampSecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Millisecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), - value, - ); - ScalarValue::TimestampMillisecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Microsecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), - value, - ); - ScalarValue::TimestampMicrosecond(value, tz.clone()) - } - DataType::Timestamp(TimeUnit::Nanosecond, tz) => { - let value = cast_between_timestamp( - &lit_data_type, - &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), - value, - ); - ScalarValue::TimestampNanosecond(value, tz.clone()) - } - DataType::Decimal128(p, s) => { - ScalarValue::Decimal128(Some(value), *p, *s) - } - _ => { - return None; - } - }; - Some(result_scalar) - } else { - None - } - } - } -} - -fn try_cast_string_literal( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let string_value = lit_value.try_as_str()?.map(|s| s.to_string()); - let scalar_value = match target_type { - DataType::Utf8 => ScalarValue::Utf8(string_value), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), - DataType::Utf8View => ScalarValue::Utf8View(string_value), - _ => return None, - }; - Some(scalar_value) -} - -/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary -fn try_cast_dictionary( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_value_type = lit_value.data_type(); - let result_scalar = match (lit_value, target_type) { - // Unwrap dictionary when inner type matches target type - (ScalarValue::Dictionary(_, inner_value), _) - if inner_value.data_type() == *target_type => - { - (**inner_value).clone() - } - // Wrap type when target type is dictionary - (_, DataType::Dictionary(index_type, inner_type)) - if **inner_type == lit_value_type => - { - ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) - } - _ => { - return None; - } - }; - Some(result_scalar) -} - -/// Cast a timestamp value from one unit to another -fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { - let value = value as i64; - let from_scale = match from { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS, - DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS, - _ => return Some(value), - }; - - let to_scale = match to { - DataType::Timestamp(TimeUnit::Second, _) => 1, - DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS, - DataType::Timestamp(TimeUnit::Microsecond, _) => 
-        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
-        _ => return Some(value),
-    };
-
-    match from_scale.cmp(&to_scale) {
-        Ordering::Less => value.checked_mul(to_scale / from_scale),
-        Ordering::Greater => Some(value / (from_scale / to_scale)),
-        Ordering::Equal => Some(value),
-    }
-}
-
-fn try_cast_binary(
-    lit_value: &ScalarValue,
-    target_type: &DataType,
-) -> Option<ScalarValue> {
-    match (lit_value, target_type) {
-        (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
-            if v.len() == *n as usize =>
-        {
-            Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
-        }
-        _ => None,
-    }
-}
 
 #[cfg(test)]
 mod tests {
diff --git a/datafusion/physical-expr/src/schema_rewriter.rs b/datafusion/physical-expr/src/schema_rewriter.rs
index cb627293d4cb..71e5190d04ac 100644
--- a/datafusion/physical-expr/src/schema_rewriter.rs
+++ b/datafusion/physical-expr/src/schema_rewriter.rs
@@ -17,21 +17,17 @@
 
 //! Physical expression schema rewriting utilities
 
-use std::cmp::Ordering;
 use std::sync::Arc;
 
 use arrow::compute::can_cast_types;
-use arrow::datatypes::{
-    DataType, FieldRef, Schema, TimeUnit, MAX_DECIMAL128_FOR_EACH_PRECISION,
-    MIN_DECIMAL128_FOR_EACH_PRECISION,
-};
-use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
+use arrow::datatypes::{DataType, FieldRef, Schema};
 use datafusion_common::{
     exec_err,
     tree_node::{Transformed, TransformedResult, TreeNode},
     Result, ScalarValue,
 };
 use datafusion_expr::Operator;
+use datafusion_common::scalar_literal_cast::try_cast_literal_to_type;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 
 use crate::expressions::{self, BinaryExpr, CastExpr, Column, Literal};
@@ -297,7 +293,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
 
         // Try to cast the literal to the column's data type
         if let Some(casted_literal) =
-            try_cast_literal_to_type(literal.value(), column_data_type, op)
+            try_cast_literal_to_type_with_operator(literal.value(), column_data_type, op)
         {
             return Ok(Some((
                 Arc::clone(final_inner_expr),
@@ -309,23 +305,7 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
     }
 }
 
-/// Try to cast a literal value to a target type, considering the comparison operator
-/// This is adapted from the logical layer unwrap_cast functionality
-fn try_cast_literal_to_type(
-    lit_value: &ScalarValue,
-    target_type: &DataType,
-    op: Operator,
-) -> Option<ScalarValue> {
-    // First try operator-specific casting (e.g., string to int for equality)
-    if let Some(result) = cast_literal_to_type_with_op(lit_value, target_type, op) {
-        return Some(result);
-    }
-
-    // Fall back to general casting
-    try_cast_literal_to_type_general(lit_value, target_type)
-}
-
-/// Cast literal with operator-specific logic
+/// Cast literal with operator-specific logic for comparisons
 fn cast_literal_to_type_with_op(
     lit_value: &ScalarValue,
     target_type: &DataType,
@@ -338,7 +318,8 @@ fn cast_literal_to_type_with_op(
                 | ScalarValue::Utf8View(Some(_))
                 | ScalarValue::LargeUtf8(Some(_)),
         ) => {
-            // Only try for integer types
+            // Only try for integer types (TODO: can we do this for other types,
+            // like timestamps?)
use DataType::*; if matches!( target_type, @@ -358,273 +339,20 @@ fn cast_literal_to_type_with_op( } } -/// General literal casting logic adapted from the logical layer -fn try_cast_literal_to_type_general( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { - return None; - } - if lit_value.is_null() { - // null value can be cast to any type of null value - return ScalarValue::try_from(target_type).ok(); - } - try_cast_numeric_literal(lit_value, target_type) - .or_else(|| try_cast_string_literal(lit_value, target_type)) - .or_else(|| try_cast_dictionary(lit_value, target_type)) - .or_else(|| try_cast_binary(lit_value, target_type)) -} - -/// Returns true if unwrap_cast_in_comparison supports this data type -fn is_supported_type(data_type: &DataType) -> bool { - is_supported_numeric_type(data_type) - || is_supported_string_type(data_type) - || is_supported_dictionary_type(data_type) - || is_supported_binary_type(data_type) -} - -/// Returns true if unwrap_cast_in_comparison support this numeric type -fn is_supported_numeric_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Decimal128(_, _) - | DataType::Timestamp(_, _) - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a string -fn is_supported_string_type(data_type: &DataType) -> bool { - matches!( - data_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View - ) -} - -/// Returns true if unwrap_cast_in_comparison supports casting this value as a dictionary -fn is_supported_dictionary_type(data_type: &DataType) -> bool { - matches!(data_type, - DataType::Dictionary(_, inner) if is_supported_type(inner)) -} - -fn is_supported_binary_type(data_type: &DataType) -> bool { - matches!(data_type, DataType::Binary | DataType::FixedSizeBinary(_)) -} - -/// Convert a numeric value from one numeric data type to another -fn try_cast_numeric_literal( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_data_type = lit_value.data_type(); - if !is_supported_numeric_type(&lit_data_type) - || !is_supported_numeric_type(target_type) - { - return None; - } - - let mul = match target_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => 1_i128, - DataType::Timestamp(_, _) => 1_i128, - DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), - _ => return None, - }; - let (target_min, target_max) = match target_type { - DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), - DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128), - DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128), - DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128), - DataType::Int8 => (i8::MIN as i128, i8::MAX as i128), - DataType::Int16 => (i16::MIN as i128, i16::MAX as i128), - DataType::Int32 => (i32::MIN as i128, i32::MAX as i128), - DataType::Int64 => (i64::MIN as i128, i64::MAX as i128), - DataType::Timestamp(_, _) => (i64::MIN as i128, i64::MAX as i128), - DataType::Decimal128(precision, _) => ( - MIN_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - MAX_DECIMAL128_FOR_EACH_PRECISION[*precision as usize], - ), - _ => return None, - }; - let 
lit_value_target_type = match lit_value {
-        ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
-        ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
-        ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
-        ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
-        ScalarValue::TimestampNanosecond(Some(v), _) => (*v as i128).checked_mul(mul),
-        ScalarValue::Decimal128(Some(v), _, scale) => {
-            let lit_scale_mul = 10_i128.pow(*scale as u32);
-            if mul >= lit_scale_mul {
-                (*v).checked_mul(mul / lit_scale_mul)
-            } else if (*v) % (lit_scale_mul / mul) == 0 {
-                Some(*v / (lit_scale_mul / mul))
-            } else {
-                None
-            }
-        }
-        _ => None,
-    };
-
-    match lit_value_target_type {
-        None => None,
-        Some(value) => {
-            if value >= target_min && value <= target_max {
-                let result_scalar = match target_type {
-                    DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
-                    DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
-                    DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
-                    DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
-                    DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
-                    DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
-                    DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
-                    DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
-                    DataType::Timestamp(TimeUnit::Second, tz) => {
-                        let value = cast_between_timestamp(
-                            &lit_data_type,
-                            &DataType::Timestamp(TimeUnit::Second, tz.clone()),
-                            value,
-                        );
-                        ScalarValue::TimestampSecond(value, tz.clone())
-                    }
-                    DataType::Timestamp(TimeUnit::Millisecond, tz) => {
-                        let value = cast_between_timestamp(
-                            &lit_data_type,
-                            &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
-                            value,
-                        );
-                        ScalarValue::TimestampMillisecond(value, tz.clone())
-                    }
-                    DataType::Timestamp(TimeUnit::Microsecond, tz) => {
-                        let value = cast_between_timestamp(
-                            &lit_data_type,
-                            &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
-                            value,
-                        );
-                        ScalarValue::TimestampMicrosecond(value, tz.clone())
-                    }
-                    DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
-                        let value = cast_between_timestamp(
-                            &lit_data_type,
-                            &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
-                            value,
-                        );
-                        ScalarValue::TimestampNanosecond(value, tz.clone())
-                    }
-                    DataType::Decimal128(p, s) => {
-                        ScalarValue::Decimal128(Some(value), *p, *s)
-                    }
-                    _ => {
-                        return None;
-                    }
-                };
-                Some(result_scalar)
-            } else {
-                None
-            }
-        }
-    }
-}
-
-fn try_cast_string_literal(
-    lit_value: &ScalarValue,
-    target_type: &DataType,
-) -> Option<ScalarValue> {
-    let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
-    let scalar_value = match target_type {
-        DataType::Utf8 => ScalarValue::Utf8(string_value),
-        DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
-        DataType::Utf8View => ScalarValue::Utf8View(string_value),
-        _ => return None,
-    };
-    Some(scalar_value)
-}
-
-/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary
-fn try_cast_dictionary(
+/// Try to cast a literal value to a target type, considering the comparison operator
+/// This is adapted from the logical layer unwrap_cast functionality
+fn try_cast_literal_to_type_with_operator(
     lit_value: &ScalarValue,
     target_type: &DataType,
+    op: Operator,
 ) -> Option<ScalarValue> {
-    let lit_value_type = lit_value.data_type();
-    let result_scalar = match (lit_value, target_type) {
-        // Unwrap dictionary when inner type matches target type
-        (ScalarValue::Dictionary(_, inner_value), _)
-            if inner_value.data_type() == *target_type =>
-        {
-            (**inner_value).clone()
-        }
-        // Wrap type when target type is dictionary
-        (_, DataType::Dictionary(index_type, inner_type))
-            if **inner_type == lit_value_type =>
-        {
-            ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone()))
-        }
-        _ => {
-            return None;
-        }
-    };
-    Some(result_scalar)
-}
-
-/// Cast a timestamp value from one unit to another
-fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option<i64> {
-    let value = value as i64;
-    let from_scale = match from {
-        DataType::Timestamp(TimeUnit::Second, _) => 1,
-        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
-        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
-        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
-        _ => return Some(value),
-    };
-
-    let to_scale = match to {
-        DataType::Timestamp(TimeUnit::Second, _) => 1,
-        DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
-        DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
-        DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
-        _ => return Some(value),
-    };
-
-    match from_scale.cmp(&to_scale) {
-        Ordering::Less => value.checked_mul(to_scale / from_scale),
-        Ordering::Greater => Some(value / (from_scale / to_scale)),
-        Ordering::Equal => Some(value),
+    // First try operator-specific casting (e.g., string to int for equality)
+    if let Some(result) = cast_literal_to_type_with_op(lit_value, target_type, op) {
+        return Some(result);
     }
-}
 
-fn try_cast_binary(
-    lit_value: &ScalarValue,
-    target_type: &DataType,
-) -> Option<ScalarValue> {
-    match (lit_value, target_type) {
-        (ScalarValue::Binary(Some(v)), DataType::FixedSizeBinary(n))
-            if v.len() == *n as usize =>
-        {
-            Some(ScalarValue::FixedSizeBinary(*n, Some(v.clone())))
-        }
-        _ => None,
-    }
+    // Fall back to general casting using shared function
+    try_cast_literal_to_type(lit_value, target_type)
 }
 
 #[cfg(test)]
@@ -678,18 +406,17 @@ mod tests {
         if let Some(literal) = result.as_any().downcast_ref::<Literal>() {
             assert_eq!(*literal.value(), ScalarValue::Float64(None));
         } else {
-            panic!("Expected literal expression");
+            panic!("Expected literal expression for missing column");
         }
 
         Ok(())
     }
 
     #[test]
-    fn test_rewrite_partition_column() -> Result<()> {
+    fn test_rewrite_with_partition_columns() -> Result<()> {
         let (physical_schema, logical_schema) = create_test_schema();
 
-        let partition_fields =
-            vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))];
+        let partition_fields = vec![Arc::new(Field::new("partition_col", DataType::Utf8, false))];
         let partition_values = vec![ScalarValue::Utf8(Some("test_value".to_string()))];
 
         let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema)
             .with_partition_columns(partition_fields, partition_values);
 
         let column_expr = Arc::new(Column::new("partition_col", 0));
         let result = rewriter.rewrite(column_expr)?;
 
         // Should be replaced with the partition value
         if let Some(literal) = result.as_any().downcast_ref::<Literal>() {
             assert_eq!(
                 *literal.value(),
                 ScalarValue::Utf8(Some("test_value".to_string()))
             );
         } else {
-            panic!("Expected literal expression");
+            panic!("Expected literal expression for partition column");
         }
 
         Ok(())
     }
 
-    #[test]
-    fn test_rewrite_no_change_needed() -> Result<()> {
-        let (physical_schema, logical_schema) = create_test_schema();
-
-        let rewriter =
-            PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
-        let column_expr = Arc::new(Column::new("b", 1)) as Arc<dyn PhysicalExpr>;
-
-        let result = rewriter.rewrite(Arc::clone(&column_expr))?;
-
-        // Should be the same expression (no transformation needed)
-        // We compare the underlying pointer through the trait object
-        assert!(std::ptr::eq(
-            column_expr.as_ref() as *const dyn PhysicalExpr,
-            result.as_ref() as *const dyn PhysicalExpr
-        ));
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_non_nullable_missing_column_error() {
-        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
-        let logical_schema = Schema::new(vec![
-            Field::new("a", DataType::Int32, false),
-            Field::new("b", DataType::Utf8, false), // Non-nullable missing column
-        ]);
-
-        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
-        let column_expr = Arc::new(Column::new("b", 1));
-
-        let result = rewriter.rewrite(column_expr);
-        assert!(result.is_err());
-        assert!(result
-            .unwrap_err()
-            .to_string()
-            .contains("Non-nullable column 'b' is missing"));
-    }
-
     #[test]
     fn test_unwrap_cast_optimization() -> Result<()> {
-        // Test case: cast(int32_column as int64) = 123i64 should become int32_column = 123i32
-        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
-        let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
-
-        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
-
-        // Create: cast(column("a") as Int64) = 123i64
-        let column_expr = Arc::new(Column::new("a", 0));
-        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
-        let literal_expr = expressions::lit(ScalarValue::Int64(Some(123)));
-        let binary_expr =
-            Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
-
-        let result = rewriter.rewrite(binary_expr.clone() as Arc<dyn PhysicalExpr>)?;
-
-        // The result should be a binary expression with the cast unwrapped
-        let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
-
-        // Left side should be the original column (no cast)
-        assert!(result_binary
-            .left()
-            .as_any()
-            .downcast_ref::<Column>()
-            .is_some());
-
-        // Right side should be a literal with the value cast to Int32
-        let right_literal = result_binary
-            .right()
-            .as_any()
-            .downcast_ref::<Literal>()
-            .unwrap();
-        assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123)));
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_unwrap_cast_optimization_reverse_order() -> Result<()> {
-        // Test case: 123i64 = cast(int32_column as int64) should become 123i32 = int32_column
-        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
-        let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
-
-        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
-
-        // Create: 123i64 = cast(column("a") as Int64)
-        let literal_expr = expressions::lit(ScalarValue::Int64(Some(123)));
-        let column_expr = Arc::new(Column::new("a", 0));
-        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
-        let binary_expr =
-            Arc::new(BinaryExpr::new(literal_expr, Operator::Eq, cast_expr));
-
-        let result = rewriter.rewrite(binary_expr)?;
-
-        // The result should be a binary expression with the cast unwrapped
-        let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
-
-        // Left side should be a literal with the value cast to Int32
-        let left_literal = result_binary
-            .left()
-            .as_any()
-            .downcast_ref::<Literal>()
-            .unwrap();
-        assert_eq!(*left_literal.value(), ScalarValue::Int32(Some(123)));
-
-        // Right side should be the original column (no cast)
-        assert!(result_binary
-            .right()
-            .as_any()
-            .downcast_ref::<Column>()
-            .is_some());
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_unwrap_cast_optimization_string_to_int() -> Result<()> {
-        // Test case: cast(int32_column as Utf8) = "123" should become int32_column = 123i32
-        let physical_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
-        let logical_schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
-
-        let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
-
-        // Create: cast(column("a") as Utf8) = "123"
-        let column_expr = Arc::new(Column::new("a", 0));
-        let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Utf8, None));
-        let literal_expr = expressions::lit(ScalarValue::Utf8(Some("123".to_string())));
-        let binary_expr =
-            Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
-
-        let result = rewriter.rewrite(binary_expr)?;
-
-        // The result should be a binary expression with the cast unwrapped
-        let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
-
-        // Left side should be the original column (no cast)
-        assert!(result_binary
-            .left()
-            .as_any()
-            .downcast_ref::<Column>()
-            .is_some());
-
-        // Right side should be a literal with the value cast to Int32
-        let right_literal = result_binary
-            .right()
-            .as_any()
-            .downcast_ref::<Literal>()
-            .unwrap();
-        assert_eq!(*right_literal.value(), ScalarValue::Int32(Some(123)));
-
-        Ok(())
-    }
-
-    #[test]
-    fn test_no_unwrap_cast_optimization_when_not_applicable() -> Result<()> {
-        // Test case where optimization should not apply - unsupported cast
-        let physical_schema =
-            Schema::new(vec![Field::new("a", DataType::Float32, false)]);
-        let logical_schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]);
-
+        let (physical_schema, logical_schema) = create_test_schema();
         let rewriter = PhysicalExprSchemaRewriter::new(&physical_schema, &logical_schema);
 
-        // Create: cast(column("a") as Int64) = 123i64
-        // Float32 to Int64 casting might not be optimizable due to precision
+        // Create a cast expression: cast(column_a as INT64) = 123i64
         let column_expr = Arc::new(Column::new("a", 0));
         let cast_expr = Arc::new(CastExpr::new(column_expr, DataType::Int64, None));
         let literal_expr = expressions::lit(ScalarValue::Int64(Some(123)));
-        let binary_expr =
-            Arc::new(BinaryExpr::new(cast_expr, Operator::Eq, literal_expr));
-
-        let result = rewriter.rewrite(binary_expr)?;
-
-        // The result should still be a binary expression with a cast on the left side
-        // since Float32 is not in our supported types for unwrap cast optimization
-        let result_binary = result.as_any().downcast_ref::<BinaryExpr>().unwrap();
-        assert!(result_binary
-            .left()
-            .as_any()
-            .downcast_ref::<CastExpr>()
-            .is_some());
+        let binary_expr = BinaryExpr::new(cast_expr, Operator::Eq, literal_expr);
+
+        let result = rewriter.try_unwrap_cast_in_comparison(&binary_expr)?;
+
+        assert!(result.is_some());
+        if let Some(optimized_expr) = result {
+            if let Some(binary_expr) = optimized_expr.as_any().downcast_ref::<BinaryExpr>() {
+                // The left side should be the unwrapped column
+                assert!(binary_expr.left().as_any().downcast_ref::<Column>().is_some());
+
+                // The right side should be the literal cast to the column's type (Int32)
+                if let Some(literal) = binary_expr.right().as_any().downcast_ref::<Literal>() {
+                    assert_eq!(*literal.value(), ScalarValue::Int32(Some(123)));
+                } else {
+                    panic!("Expected literal on right side");
+                }
+            } else {
+                panic!("Expected binary expression");
+            }
+        }
 
         Ok(())
     }
-}
+}
\ No newline at end of file
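
For illustration, here is a minimal sketch of the end-to-end behavior this series is building: a predicate planned against the table's logical schema is rewritten against a file's physical schema, and the cast migrates from the column onto the literal. This is a sketch only; it assumes `PhysicalExprSchemaRewriter` is publicly reachable at `datafusion_physical_expr::schema_rewriter` (the re-export path is not shown in these patches) and uses only the constructors that appear in the diffs above.

    use std::sync::Arc;

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::Operator;
    use datafusion_physical_expr::expressions::{self, BinaryExpr, CastExpr, Column};
    // NOTE: module path assumed for illustration; adjust to the actual re-export.
    use datafusion_physical_expr::schema_rewriter::PhysicalExprSchemaRewriter;
    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

    fn main() -> Result<()> {
        // The file on disk stores `a` as Int32; the table schema declares Int64.
        let physical = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let logical = Schema::new(vec![Field::new("a", DataType::Int64, false)]);

        // Predicate as planned against the logical schema: CAST(a AS Int64) = 123i64
        let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
            Arc::new(CastExpr::new(
                Arc::new(Column::new("a", 0)),
                DataType::Int64,
                None,
            )),
            Operator::Eq,
            expressions::lit(ScalarValue::Int64(Some(123))),
        ));

        // Rewriting against the physical schema unwraps the cast and casts the
        // literal instead, yielding `a = 123i32`.
        let rewritten =
            PhysicalExprSchemaRewriter::new(&physical, &logical).rewrite(predicate)?;
        println!("{rewritten}");
        Ok(())
    }

The reason the cast moves to the literal side is cost: casting one literal once is far cheaper than casting a column (or its per-row-group statistics), and the resulting `column = literal` shape is what row-group and page-level pruning can consume directly. The same rewrite path also substitutes partition values and typed nulls for columns missing from the file, as the tests above exercise.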