Transform scalar correlated subqueries in Where to DependentJoin #16174

Draft · wants to merge 3 commits into base: main
29 changes: 26 additions & 3 deletions datafusion/expr/src/logical_plan/builder.rs
@@ -900,7 +900,7 @@ impl LogicalPlanBuilder {
join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
filter: Option<Expr>,
) -> Result<Self> {
self.join_detailed(right, join_type, join_keys, filter, false)
self.join_detailed(right, join_type, join_keys, filter, false, vec![])
}

/// Apply a join using the specified expressions.
@@ -957,6 +957,26 @@ impl LogicalPlanBuilder {
(Vec::<Column>::new(), Vec::<Column>::new()),
filter,
false,
vec![],
)
}

pub fn dependent_join_on(
self,
right: LogicalPlan,
join_type: JoinType,
on_exprs: impl IntoIterator<Item = Expr>,
outer_ref_columns: Vec<Expr>,
) -> Result<Self> {
let filter = on_exprs.into_iter().reduce(Expr::and);

self.join_detailed(
right,
join_type,
(Vec::<Column>::new(), Vec::<Column>::new()),
filter,
false,
outer_ref_columns,
)
}

@@ -994,6 +1014,7 @@ impl LogicalPlanBuilder {
join_keys: (Vec<impl Into<Column>>, Vec<impl Into<Column>>),
filter: Option<Expr>,
null_equals_null: bool,
outer_ref_columns: Vec<Expr>,
) -> Result<Self> {
if join_keys.0.len() != join_keys.1.len() {
return plan_err!("left_keys and right_keys were not the same length");
@@ -1111,6 +1132,8 @@
join_constraint: JoinConstraint::On,
schema: DFSchemaRef::new(join_schema),
null_equals_null,
dependent_join: false,
outer_ref_columns,
})))
}

@@ -1337,12 +1360,12 @@ impl LogicalPlanBuilder {
.unzip();
if is_all {
LogicalPlanBuilder::from(left_plan)
.join_detailed(right_plan, join_type, join_keys, None, true)?
.join_detailed(right_plan, join_type, join_keys, None, true, vec![])?
.build()
} else {
LogicalPlanBuilder::from(left_plan)
.distinct()?
.join_detailed(right_plan, join_type, join_keys, None, true)?
.join_detailed(right_plan, join_type, join_keys, None, true, vec![])?
.build()
}
}
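
For orientation, a minimal usage sketch of the new `dependent_join_on` builder method might look like the following. This is an editorial sketch, not part of the PR; `outer_plan`, `subquery_plan`, and `outer_refs` are assumed placeholder inputs supplied by the caller.

```rust
use datafusion_common::Result;
use datafusion_expr::{lit, Expr, JoinType, LogicalPlan, LogicalPlanBuilder};

// Build a dependent (correlated) join between an outer plan and a subquery plan.
fn dependent_join_example(
    outer_plan: LogicalPlan,
    subquery_plan: LogicalPlan,
    outer_refs: Vec<Expr>,
) -> Result<LogicalPlan> {
    LogicalPlanBuilder::from(outer_plan)
        .dependent_join_on(
            subquery_plan,
            JoinType::Left,
            vec![lit(true)], // placeholder "always true" join filter
            outer_refs,      // outer references used inside the subquery
        )?
        .build()
}
```
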
17 changes: 17 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -659,6 +659,8 @@ impl LogicalPlan {
on,
schema: _,
null_equals_null,
dependent_join,
outer_ref_columns,
}) => {
let schema =
build_join_schema(left.schema(), right.schema(), &join_type)?;
@@ -680,6 +682,8 @@
filter,
schema: DFSchemaRef::new(schema),
null_equals_null,
dependent_join,
outer_ref_columns,
}))
}
LogicalPlan::Subquery(_) => Ok(self),
@@ -937,6 +941,8 @@ impl LogicalPlan {
filter: filter_expr,
schema: DFSchemaRef::new(schema),
null_equals_null: *null_equals_null,
dependent_join: false,
outer_ref_columns: vec![],
}))
}
LogicalPlan::Subquery(Subquery {
@@ -3706,6 +3712,11 @@ pub struct Join {
pub schema: DFSchemaRef,
/// If null_equals_null is true, null == null else null != null
pub null_equals_null: bool,
// TODO: maybe it's better to add a new logical plan: DependentJoin.
/// DependentJoin is an intermediate state of correlated subquery rewriting.
pub dependent_join: bool,
/// The outer references used in the subquery
pub outer_ref_columns: Vec<Expr>,
}

impl Join {
@@ -3747,6 +3758,8 @@ impl Join {
join_constraint,
schema: Arc::new(join_schema),
null_equals_null,
dependent_join: false,
outer_ref_columns: vec![],
})
}

@@ -3780,6 +3793,8 @@ impl Join {
join_constraint: original_join.join_constraint,
schema: Arc::new(join_schema),
null_equals_null: original_join.null_equals_null,
dependent_join: false,
outer_ref_columns: vec![],
})
}
}
@@ -4879,6 +4894,8 @@ digraph {
join_constraint: JoinConstraint::On,
schema: Arc::new(left_schema.join(&right_schema)?),
null_equals_null: false,
dependent_join: false,
outer_ref_columns: vec![],
}))
}

8 changes: 8 additions & 0 deletions datafusion/expr/src/logical_plan/tree_node.rs
@@ -141,6 +141,8 @@ impl TreeNode for LogicalPlan {
join_constraint,
schema,
null_equals_null,
dependent_join,
outer_ref_columns,
}) => (left, right).map_elements(f)?.update_data(|(left, right)| {
LogicalPlan::Join(Join {
left,
@@ -151,6 +153,8 @@
join_constraint,
schema,
null_equals_null,
dependent_join,
outer_ref_columns,
})
}),
LogicalPlan::Limit(Limit { skip, fetch, input }) => input
@@ -577,6 +581,8 @@ impl LogicalPlan {
join_constraint,
schema,
null_equals_null,
dependent_join,
outer_ref_columns,
}) => (on, filter).map_elements(f)?.update_data(|(on, filter)| {
LogicalPlan::Join(Join {
left,
@@ -587,6 +593,8 @@
join_constraint,
schema,
null_equals_null,
dependent_join,
outer_ref_columns,
})
}),
LogicalPlan::Sort(Sort { expr, input, fetch }) => expr
171 changes: 171 additions & 0 deletions datafusion/optimizer/src/create_dependent_join.rs
@@ -0,0 +1,171 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion_common::tree_node::Transformed;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Subquery};

use crate::{ApplyOrder, OptimizerConfig, OptimizerRule};

/// (Temporary) optimizer rule that rewrites a plan whose `Filter` predicate
/// contains a scalar correlated subquery into a plan with a dependent join.
#[derive(Default, Debug)]
pub struct CreateDependentJoin {}

impl CreateDependentJoin {
#[allow(missing_docs)]
pub fn new() -> Self {
Self::default()
}
}

impl OptimizerRule for CreateDependentJoin {
fn supports_rewrite(&self) -> bool {
true
}

fn name(&self) -> &str {
"create_dependent_join"
}

fn apply_order(&self) -> Option<ApplyOrder> {
Some(ApplyOrder::TopDown)
}

fn rewrite(
&self,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>> {
if let LogicalPlan::Filter(ref filter) = plan {
match &filter.predicate {
duongcongtoai (Contributor) commented on May 24, 2025:
Here are more cases I can think of:

1. A predicate can be a complex expression, such as `where column1 = (scalar_subquery) or column2 = (exists_subquery)`. In this case two nested dependent joins will be generated.
2. The scalar subquery expression is sometimes not a direct child of the predicate, for example `where column1 > 1 + (subquery)`.
3. We can have two subqueries in the same binary expression: `where (subquery1) > (subquery2) + 1`.
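
(Editorial aside, not part of this PR: a minimal sketch of how subqueries could be located anywhere in a predicate tree, covering the three cases above. It assumes the `TreeNode::apply` traversal that DataFusion exposes on `Expr`; exact signatures may differ across versions.)

```rust
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::{Expr, Subquery};

/// Collect every subquery referenced by `predicate`, no matter how deeply it
/// is nested inside OR branches, arithmetic, or binary comparisons.
fn find_subqueries(predicate: &Expr) -> Result<Vec<Subquery>> {
    let mut found = Vec::new();
    predicate.apply(|e| {
        match e {
            Expr::ScalarSubquery(sq) => found.push(sq.clone()),
            Expr::Exists(exists) => found.push(exists.subquery.clone()),
            Expr::InSubquery(in_sq) => found.push(in_sq.subquery.clone()),
            _ => {}
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(found)
}
```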

Expr::BinaryExpr(binary) => {
// Check if right hand side is a scalar subquery
if let Expr::ScalarSubquery(subquery) = binary.right.as_ref() {
let new_plan = build_dependent_join(
subquery,
filter.input.as_ref().clone(),
JoinType::Left,
)?;
return Ok(Transformed::yes(new_plan));
}
// Continue searching in children if no subquery found
return Ok(Transformed::no(plan));
}
_ => {
// TODO: add other types of subqueries.
return Ok(Transformed::no(plan));
}
}
}

// No Filter found, continue searching in children
Ok(Transformed::no(plan))
}
}

fn build_dependent_join(
subquery: &Subquery,
root: LogicalPlan,
join_type: JoinType,
) -> Result<LogicalPlan> {
let subquery_plan = (subquery.subquery).as_ref().clone();

let new_plan = LogicalPlanBuilder::from(root)
.dependent_join_on(
subquery_plan,
join_type,
vec![Expr::Literal(ScalarValue::Boolean(Some(true)))],
Contributor commented:
nit: we have a shorter syntax:

use datafusion_expr::lit;
let some_exprs = vec![lit(true)];

subquery.outer_ref_columns.clone(),
duongcongtoai (Contributor) commented on May 24, 2025:
If the subquery has some nested subquery underneath, I believe this function won't be able to return the outer_ref_columns from the lower level. For example:

where column1 = (select count(*) from inner_table_lv1 lv1 where lv1.column2 = lv0.column2 and exists (
  select * from inner_table_lv2 lv2 where lv2.column1 = lv1.column1 and lv2.column2 = lv0.column3
))

In this case, the call to subquery.outer_ref_columns will only return lv0.column2, while the general framework needs to be aware of lv0.column3 as well.

irenjj (Author) replied:
Cases where depth > 1 aren't supported by DataFusion at the planner stage. The reason is that each time parse_subquery is called, it uses the outer_query_schema, which is the schema from the previous layer of the query:

pub(super) fn parse_scalar_subquery(
        &self,
        subquery: Query,
        input_schema: &DFSchema,
        planner_context: &mut PlannerContext,
    ) -> Result<Expr> {
        let old_outer_query_schema =
            planner_context.set_outer_query_schema(Some(input_schema.clone().into()));
         ...

In #16060, I attempted to layer the schemas of query blocks at different depths within the PlannerContext, record the depth of the subquery's own layer within the Subquery, and then pass the PlannerContext into the optimizer. What are your thoughts on this approach? I'd welcome discussion of your ideas. For multi-layer cases, more detailed design and discussion may be needed; currently, I'm more inclined to handle simple use cases between adjacent layers first.

Contributor replied:
I wonder, would it be simpler to let the decorrelation optimizer be aware of the depth and handle the recursion itself? 🤔

irenjj (Author) replied:
> I wonder, would it be simpler to let the decorrelation optimizer be aware of the depth and handle the recursion itself?

Since there are multiple optimizer rules, I'm wondering if the depth will change because of rewrites by other, higher-priority rules. 🤔

Contributor replied:
> Since there are multiple optimizer rules

In the final stage of this epic we only let one optimizer handle the decorrelation, right?

Also, in the middle of the implementation, even if we maintain multiple decorrelating rules, existing rules such as DecorrelatePredicateSubquery or ScalarSubqueryToJoin will back off if they detect any depth > 1 and leave the whole query untouched.

duongcongtoai (Contributor) commented on May 24, 2025:
#16016

I also implemented something like this, but inside an optimizer (a lot of details still need to be added, but at least it is capable of detecting the correlated columns (including the ones with depth > 1), the correlated exprs, and the depth of the dependent join node).

irenjj (Author) commented on May 24, 2025:
Thanks @duongcongtoai, I've seen your PR. It's much more comprehensive than mine.

irenjj (Author) replied:
> #16016: I also implemented something like this, but inside an optimizer (a lot of details still need to be added, but at least it is capable of detecting the correlated columns (including the ones with depth > 1), the correlated exprs, and the depth of the dependent join node).

Maybe we could implement an initial version first, then list some of the pending work as tracking issues? I'm actually quite eager to contribute and help out as well.

Contributor replied:
Yep, I'll try to wrap it up with some basic use cases and ask for review soon.
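
(Editorial aside: a rough sketch of the transitive collection discussed in this thread, gathering outer references not only from the immediate subquery but also from subqueries nested inside it. The traversal calls (`LogicalPlan::apply`, `LogicalPlan::expressions`, `Expr::apply`) follow DataFusion's TreeNode conventions; treat this as an assumption-laden illustration rather than working PR code.)

```rust
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::Result;
use datafusion_expr::{Expr, LogicalPlan};

/// Collect outer-reference columns from `plan`, descending into nested
/// subqueries so deeper correlations (e.g. lv0.column3 above) are not missed.
/// Note: references to intermediate levels (e.g. lv1.column1) are collected
/// too; filtering them out by depth is left out of this sketch.
fn collect_outer_refs(plan: &LogicalPlan, out: &mut Vec<Expr>) -> Result<()> {
    plan.apply(|node| {
        for expr in node.expressions() {
            expr.apply(|e| {
                match e {
                    Expr::OuterReferenceColumn(_, _) => out.push(e.clone()),
                    // Recurse into nested subqueries of every flavor.
                    Expr::ScalarSubquery(sq) => {
                        collect_outer_refs(sq.subquery.as_ref(), out)?
                    }
                    Expr::Exists(exists) => {
                        collect_outer_refs(exists.subquery.subquery.as_ref(), out)?
                    }
                    Expr::InSubquery(in_sq) => {
                        collect_outer_refs(in_sq.subquery.subquery.as_ref(), out)?
                    }
                    _ => {}
                }
                Ok(TreeNodeRecursion::Continue)
            })?;
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(())
}
```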

)?
.build()?;

Ok(new_plan)
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use datafusion_common::{Result, Spans};
use datafusion_expr::{col, Expr, LogicalPlanBuilder, Subquery};
use datafusion_functions_aggregate::expr_fn::avg;

use crate::assert_optimized_plan_eq_display_indent_snapshot;
use crate::create_dependent_join::CreateDependentJoin;
use crate::test::test_table_scan_with_name;

macro_rules! assert_optimized_plan_equal {
(
$plan:expr,
@ $expected:literal $(,)?
) => {{
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(CreateDependentJoin::new());
assert_optimized_plan_eq_display_indent_snapshot!(
rule,
$plan,
@ $expected,
)
}};
}

#[test]
fn test_correlated_scalar_subquery() -> Result<()> {
// outer table
let employees = test_table_scan_with_name("employees")?;
// inner table
let salary = test_table_scan_with_name("salary")?;

// SELECT employees.a
// FROM employees
// WHERE employees.b > (
// SELECT avg(salary.a)
// FROM salary
// WHERE salary.c = employees.c
// );

// SELECT AVG(salary.a) FROM salary WHERE salary.c = employees.c
let subquery = Arc::new(
LogicalPlanBuilder::from(salary)
.filter(col("salary.c").eq(col("employees.c")))?
.aggregate(Vec::<Expr>::new(), vec![avg(col("salary.a"))])?
.build()?,
);

// SELECT employees.a FROM employees WHERE employees.b > (subquery)
let plan = LogicalPlanBuilder::from(employees)
.filter(col("employees.a").gt(Expr::ScalarSubquery(Subquery {
subquery,
outer_ref_columns: vec![col("employees.c")],
spans: Spans::new(),
})))?
.project(vec![col("employees.a")])?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Projection: employees.a [a:UInt32]
Left Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, avg(salary.a):Float64;N]
TableScan: employees [a:UInt32, b:UInt32, c:UInt32]
Aggregate: groupBy=[[]], aggr=[[avg(salary.a)]] [avg(salary.a):Float64;N]
Filter: salary.c = employees.c [a:UInt32, b:UInt32, c:UInt32]
TableScan: salary [a:UInt32, b:UInt32, c:UInt32]
"
)
}
}
4 changes: 4 additions & 0 deletions datafusion/optimizer/src/eliminate_cross_join.rs
@@ -329,6 +329,8 @@ fn find_inner_join(
filter: None,
schema: join_schema,
null_equals_null: false,
dependent_join: false,
outer_ref_columns: vec![],
}));
}
}
@@ -351,6 +353,8 @@
join_type: JoinType::Inner,
join_constraint: JoinConstraint::On,
null_equals_null: false,
dependent_join: false,
Contributor commented:
Can't we use something like JoinType::DependentJoin instead of a boolean to separate it?

irenjj (Author) replied:
Using a bool is a little strange, so I added the comment TODO: maybe it's better to add a new logical plan: DependentJoin.
But if we mark a dependent join with JoinType::DependentJoin, how can we know the real JoinType?

Contributor replied:
I thought DependentJoin's actual type was to be decided later by the decorrelation optimizer; hence the suggestion, though I'm not sure anymore.
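
(Editorial aside: for comparison, the dedicated node hinted at by the TODO in plan.rs could look roughly like the struct below, keeping the eventual JoinType alongside the correlation metadata instead of a boolean on Join. This is a hypothetical shape, not DataFusion's API or this PR's code.)

```rust
use std::sync::Arc;

use datafusion_common::DFSchemaRef;
use datafusion_expr::{Expr, JoinType, LogicalPlan};

/// Hypothetical dedicated node for the intermediate decorrelation state.
#[derive(Debug, Clone)]
pub struct DependentJoin {
    /// Outer (correlating) side of the join.
    pub left: Arc<LogicalPlan>,
    /// The correlated subquery's plan.
    pub right: Arc<LogicalPlan>,
    /// The join type the decorrelation rewrite should eventually produce.
    pub join_type: JoinType,
    /// Outer-reference columns the subquery depends on.
    pub outer_ref_columns: Vec<Expr>,
    /// Optional join filter carried through the rewrite.
    pub filter: Option<Expr>,
    /// Output schema of the node.
    pub schema: DFSchemaRef,
}
```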

outer_ref_columns: vec![],
}))
}

2 changes: 2 additions & 0 deletions datafusion/optimizer/src/eliminate_outer_join.rs
@@ -119,6 +119,8 @@ impl OptimizerRule for EliminateOuterJoin {
filter: join.filter.clone(),
schema: Arc::clone(&join.schema),
null_equals_null: join.null_equals_null,
dependent_join: false,
outer_ref_columns: vec![],
}));
Filter::try_new(filter.predicate, new_join)
.map(|f| Transformed::yes(LogicalPlan::Filter(f)))
6 changes: 6 additions & 0 deletions datafusion/optimizer/src/extract_equijoin_predicate.rs
@@ -76,6 +76,8 @@ impl OptimizerRule for ExtractEquijoinPredicate {
join_constraint,
schema,
null_equals_null,
dependent_join,
outer_ref_columns,
}) => {
let left_schema = left.schema();
let right_schema = right.schema();
@@ -93,6 +95,8 @@
join_constraint,
schema,
null_equals_null,
dependent_join,
outer_ref_columns,
})))
} else {
Ok(Transformed::no(LogicalPlan::Join(Join {
@@ -104,6 +108,8 @@
join_constraint,
schema,
null_equals_null,
dependent_join,
outer_ref_columns,
})))
}
}