diff --git a/src/query/sql/src/planner/binder/bind_mutation/update.rs b/src/query/sql/src/planner/binder/bind_mutation/update.rs index b7666ba3b5fbd..1d9b26468ad85 100644 --- a/src/query/sql/src/planner/binder/bind_mutation/update.rs +++ b/src/query/sql/src/planner/binder/bind_mutation/update.rs @@ -28,10 +28,12 @@ use crate::binder::bind_mutation::bind::MutationStrategy; use crate::binder::bind_mutation::mutation_expression::MutationExpression; use crate::binder::util::TableIdentifier; use crate::binder::Binder; +use crate::optimizer::ir::Matcher; use crate::optimizer::ir::SExpr; use crate::plans::AggregateFunction; use crate::plans::BoundColumnRef; use crate::plans::Plan; +use crate::plans::RelOp; use crate::plans::RelOperator; use crate::plans::ScalarItem; use crate::plans::VisitorMut; @@ -118,17 +120,20 @@ impl Binder { let Plan::DataMutation { box s_expr, .. } = &plan else { return Ok(plan); }; - let RelOperator::Mutation(mutation) = &*s_expr.plan else { + let RelOperator::Mutation(mutation) = s_expr.plan() else { return Ok(plan); }; - let filter_expr = &s_expr.children[0]; - let RelOperator::Filter(_) = &*filter_expr.plan else { - return Ok(plan); + let input_expr = s_expr.unary_child(); + let matcher = Matcher::MatchOp { + op_type: RelOp::Filter, + children: vec![Matcher::MatchOp { + op_type: RelOp::Join, + children: vec![Matcher::Leaf, Matcher::Leaf], + }], }; - let input = &filter_expr.children[0]; - let RelOperator::Join(_) = &*input.plan else { + if !matcher.matches(input_expr) { return Ok(plan); - }; + } let mut mutation = mutation.clone(); @@ -176,7 +181,6 @@ impl Binder { .flat_map(|expr| expr.used_columns().into_iter()) }) }) - .chain(mutation.required_columns.iter().copied()) .collect::>(); let used_columns = used_columns @@ -201,7 +205,7 @@ impl Binder { let display_name = format!("any({})", binding.index); let old = binding.index; - let mut aggr_func = ScalarExpr::AggregateFunction(AggregateFunction { + let mut aggr_func: ScalarExpr = AggregateFunction { span: None, func_name: "any".to_string(), distinct: false, @@ -213,7 +217,8 @@ impl Binder { return_type: binding.data_type.clone(), sort_descs: vec![], display_name: display_name.clone(), - }); + } + .into(); let mut rewriter = AggregateRewriter::new(&mut mutation.bind_context, self.metadata.clone()); @@ -242,14 +247,30 @@ impl Binder { for eval in &mut mutation.matched_evaluators { if let Some(expr) = &mut eval.condition { for (_, old, new) in &aggr_columns { - expr.replace_column(*old, *new)? + expr.replace_column_datatype_to_nullable(*old, *new)? } } if let Some(update) = &mut eval.update { for (_, expr) in update.iter_mut() { for (_, old, new) in &aggr_columns { - expr.replace_column(*old, *new)? + expr.replace_column_datatype_to_nullable(*old, *new)? + } + } + + for (field_index, expr) in update.iter_mut() { + if let Some(target_column) = + mutation.bind_context.columns.iter().find(|binding| { + binding.table_index == Some(mutation.target_table_index) + && binding.column_name == field_index.to_string() + }) + { + let columns_used = expr.used_columns(); + for col_idx in columns_used { + if col_idx != target_column.index { + expr.replace_column_datatype_to_nullable(col_idx, col_idx)?; + } + } } } } @@ -270,12 +291,12 @@ impl Binder { .collect(), ); - let aggr_expr = self.bind_aggregate(&mut mutation.bind_context, (**filter_expr).clone())?; + let aggr_expr = self.bind_aggregate(&mut mutation.bind_context, input_expr.clone())?; - let s_expr = SExpr::create_unary( + let s_expr = Box::new(SExpr::create_unary( Arc::new(RelOperator::Mutation(mutation)), Arc::new(aggr_expr), - ); + )); let Plan::DataMutation { schema, metadata, .. @@ -283,13 +304,10 @@ impl Binder { else { unreachable!() }; - - let plan = Plan::DataMutation { - s_expr: Box::new(s_expr), + Ok(Plan::DataMutation { + s_expr, schema, metadata, - }; - - Ok(plan) + }) } } diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs index 3b0ff44f71195..f25d217108f43 100644 --- a/src/query/sql/src/planner/plans/scalar_expr.rs +++ b/src/query/sql/src/planner/plans/scalar_expr.rs @@ -293,6 +293,33 @@ impl ScalarExpr { Ok(()) } + pub fn replace_column_datatype_to_nullable( + &mut self, + old: IndexType, + new: IndexType, + ) -> Result<()> { + struct Replace { + old: IndexType, + new: IndexType, + } + + impl VisitorMut<'_> for Replace { + fn visit_bound_column_ref(&mut self, col: &mut BoundColumnRef) -> Result<()> { + if col.column.index == self.old { + col.column.index = self.new; + if !col.column.data_type.is_nullable() { + col.column.data_type = Box::new(col.column.data_type.wrap_nullable()) + } + } + Ok(()) + } + } + + let mut visitor = Replace { old, new }; + visitor.visit(self)?; + Ok(()) + } + pub fn columns_and_data_types(&self, metadata: MetadataRef) -> HashMap { struct UsedColumnsVisitor { columns: HashMap, diff --git a/tests/sqllogictests/suites/query/cte/update_cte.test b/tests/sqllogictests/suites/query/cte/update_cte.test index fe0f2f299d3e4..63f0d585db76d 100644 --- a/tests/sqllogictests/suites/query/cte/update_cte.test +++ b/tests/sqllogictests/suites/query/cte/update_cte.test @@ -79,6 +79,23 @@ select * from t2; statement error (?s)1065.*?column a doesn't exist with tt1 as (select * from t1) update t2 set a = tt1.a; - statement ok drop table t2; + +statement ok +create or replace table test_merge(col1 varchar, col2 varchar, col3 varchar); + +statement ok +insert into test_merge values(2,'abc',2),(3,'abc',3),(4,'abc',4); + +statement ok +with tbb("col1", "col2", "col3") as (values ('1', 'add', '11'), ('4', 'add', '44')) update test_merge tba set tba.col1 =tbb.col1, tba.col2 = 'update', tba.col3 = tbb.col3 from tbb where tba.col1 = tbb.col1; + +statement ok +with tbb as (select col0::string null col1,col1::string null col2,col2::string null col3 from (values ('1', 'add', '11'), ('4', 'add', '44'))) update test_merge tba set tba.col1 =tbb.col1, tba.col2 = 'update', tba.col3 = tbb.col3 from tbb where tba.col1 = tbb.col1; + +statement ok +with tbb("col1", "col2", "col3") as (values ('1', 'add', '11'), ('4', 'add', '44')) update test_merge tba set tba.col1 =tbb.col1::string null, tba.col2 = 'update', tba.col3 = tbb.col3::string null from tbb where tba.col1 = tbb.col1::string null; + +statement ok +drop table test_merge; \ No newline at end of file