diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index d9a1478f0238..e6a90db2dc3e 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -109,6 +109,8 @@ impl JoinType { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) } } diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 7250a263d89c..2a112c8bbb7b 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -305,7 +305,6 @@ async fn test_left_mark_join_1k_filtered() { .await } -// todo: add JoinTestType::HjSmj after Right mark SortMergeJoin support #[tokio::test] async fn test_right_mark_join_1k() { JoinFuzzTestCase::new( @@ -314,7 +313,7 @@ async fn test_right_mark_join_1k() { JoinType::RightMark, None, ) - .run_test(&[NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } @@ -326,7 +325,7 @@ async fn test_right_mark_join_1k_filtered() { JoinType::RightMark, Some(Box::new(col_lt_col_filter)), ) - .run_test(&[NljHj], false) + .run_test(&[HjSmj, NljHj], false) .await } diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index 3477ac77123c..cfe4d33fd69a 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -371,6 +371,61 @@ async fn test_join_with_swap_semi() { } } +#[tokio::test] +async fn test_join_with_swap_mark() { + let join_types = [JoinType::LeftMark]; + for join_type in join_types { + let (big, small) = create_big_and_small(); + + let join = HashJoinExec::try_new( + Arc::clone(&big), + Arc::clone(&small), + vec![( + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + )], + None, + &join_type, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + ) + .unwrap(); + + let original_schema = join.schema(); + + let optimized_join = JoinSelection::new() + .optimize(Arc::new(join), &ConfigOptions::new()) + .unwrap(); + + let swapped_join = optimized_join + .as_any() + .downcast_ref::() + .expect( + "A proj is not required to swap columns back to their original order", + ); + + assert_eq!(swapped_join.schema().fields().len(), 2); + assert_eq!( + swapped_join + .left() + .partition_statistics(None) + .unwrap() + .total_byte_size, + Precision::Inexact(8192) + ); + assert_eq!( + swapped_join + .right() + .partition_statistics(None) + .unwrap() + .total_byte_size, + Precision::Inexact(2097152) + ); + assert_eq!(original_schema, swapped_join.schema()); + } +} + /// Compare the input plan with the plan after running the probe order optimizer. macro_rules! assert_optimized { ($EXPECTED_LINES: expr, $PLAN: expr) => { @@ -577,8 +632,10 @@ async fn test_nl_join_with_swap(join_type: JoinType) { join_type, case::left_semi(JoinType::LeftSemi), case::left_anti(JoinType::LeftAnti), + case::left_mark(JoinType::LeftMark), case::right_semi(JoinType::RightSemi), - case::right_anti(JoinType::RightAnti) + case::right_anti(JoinType::RightAnti), + case::right_mark(JoinType::RightMark) )] #[tokio::test] async fn test_nl_join_with_swap_no_proj(join_type: JoinType) { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 93dd6c2b89fc..64c107c7afbd 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1652,7 +1652,10 @@ pub fn build_join_schema( ); let (schema1, schema2) = match join_type { - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right), + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => (left, right), _ => (right, left), }; diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index dc220332141b..d8e85a8d15d8 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -539,6 +539,7 @@ pub fn hash_join_swap_subrule( | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti + | JoinType::LeftMark ) { input = swap_join_according_to_unboundedness(hash_join)?; @@ -549,10 +550,10 @@ pub fn hash_join_swap_subrule( /// This function swaps sides of a hash join to make it runnable even if one of /// its inputs are infinite. Note that this is not always possible; i.e. -/// [`JoinType::Full`], [`JoinType::Right`], [`JoinType::RightAnti`] and -/// [`JoinType::RightSemi`] can not run with an unbounded left side, even if -/// we swap join sides. Therefore, we do not consider them here. -/// This function is crate public as it is useful for downstream projects +/// [`JoinType::Full`], [`JoinType::Right`], [`JoinType::RightAnti`], +/// [`JoinType::RightSemi`], and [`JoinType::RightMark`] can not run with an +/// unbounded left side, even if we swap join sides. Therefore, we do not consider +/// them here. This function is crate public as it is useful for downstream projects /// to implement, or experiment with, their own join selection rules. pub(crate) fn swap_join_according_to_unboundedness( hash_join: &HashJoinExec, @@ -562,7 +563,11 @@ pub(crate) fn swap_join_according_to_unboundedness( match (*partition_mode, *join_type) { ( _, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full, + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark + | JoinType::Full, ) => internal_err!("{join_type} join cannot be swapped for unbounded input."), (PartitionMode::Partitioned, _) => { hash_join.swap_inputs(PartitionMode::Partitioned) diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 770399290dca..148a25ceb2c0 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -618,13 +618,16 @@ impl HashJoinExec { partition_mode, self.null_equality(), )?; - // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again + + // In case of Anti/Semi/Mark joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( self.join_type(), JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) || self.projection.is_some() { Ok(Arc::new(new_join)) diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index fcc1107a0e26..3ffc13275524 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -371,7 +371,7 @@ impl NestedLoopJoinExec { ), )?; - // For Semi/Anti joins, swap result will produce same output schema, + // For Semi/Anti/Mark joins, swap result will produce same output schema, // no need to wrap them into additional projection let plan: Arc = if matches!( self.join_type(), @@ -379,6 +379,8 @@ impl NestedLoopJoinExec { | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti + | JoinType::LeftMark + | JoinType::RightMark ) || self.projection.is_some() { Arc::new(new_join) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index a8c209a492ba..c94433716d64 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -230,7 +230,6 @@ impl SortMergeJoinExec { // When output schema contains only the right side, probe side is right. // Otherwise probe side is the left side. match join_type { - // TODO: sort merge support for right mark (tracked here: https://github.com/apache/datafusion/issues/16226) JoinType::Right | JoinType::RightSemi | JoinType::RightAnti @@ -1010,7 +1009,7 @@ fn get_corrected_filter_mask( corrected_mask.append_n(expected_size - corrected_mask.len(), false); Some(corrected_mask.finish()) } - JoinType::LeftMark => { + JoinType::LeftMark | JoinType::RightMark => { for i in 0..row_indices_length { let last_index = last_index_for_row(i, row_indices, batch_ids, row_indices_length); @@ -1160,6 +1159,7 @@ impl Stream for SortMergeJoinStream { JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark + | JoinType::RightMark | JoinType::Right | JoinType::RightSemi | JoinType::LeftAnti @@ -1271,6 +1271,7 @@ impl Stream for SortMergeJoinStream { | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark + | JoinType::RightMark | JoinType::Full ) { @@ -1298,6 +1299,7 @@ impl Stream for SortMergeJoinStream { | JoinType::RightAnti | JoinType::Full | JoinType::LeftMark + | JoinType::RightMark ) { let record_batch = self.filter_joined_batch()?; @@ -1623,6 +1625,7 @@ impl SortMergeJoinStream { | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark + | JoinType::RightMark ) { join_streamed = !self.streamed_joined; } @@ -1630,9 +1633,15 @@ impl SortMergeJoinStream { Ordering::Equal => { if matches!( self.join_type, - JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi + JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::RightSemi + | JoinType::RightMark ) { - mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); + mark_row_as_match = matches!( + self.join_type, + JoinType::LeftMark | JoinType::RightMark + ); // if the join filter is specified then its needed to output the streamed index // only if it has not been emitted before // the `join_filter_matched_idxs` keeps track on if streamed index has a successful @@ -1847,31 +1856,32 @@ impl SortMergeJoinStream { // The row indices of joined buffered batch let right_indices: UInt64Array = chunk.buffered_indices.finish(); - let mut right_columns = if matches!(self.join_type, JoinType::LeftMark) { - vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] - } else if matches!( - self.join_type, - JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::RightSemi - ) { - vec![] - } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - fetch_right_columns_by_idxs( - &self.buffered_data, - buffered_idx, - &right_indices, - )? - } else { - // If buffered batch none, meaning it is null joined batch. - // We need to create null arrays for buffered columns to join with streamed rows. - create_unmatched_columns( + let mut right_columns = + if matches!(self.join_type, JoinType::LeftMark | JoinType::RightMark) { + vec![Arc::new(is_not_null(&right_indices)?) as ArrayRef] + } else if matches!( self.join_type, - &self.buffered_schema, - right_indices.len(), - ) - }; + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::RightSemi + ) { + vec![] + } else if let Some(buffered_idx) = chunk.buffered_batch_idx { + fetch_right_columns_by_idxs( + &self.buffered_data, + buffered_idx, + &right_indices, + )? + } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. + create_unmatched_columns( + self.join_type, + &self.buffered_schema, + right_indices.len(), + ) + }; // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. @@ -1890,7 +1900,7 @@ impl SortMergeJoinStream { get_filter_column(&self.filter, &left_columns, &right_cols) } else if matches!( self.join_type, - JoinType::RightAnti | JoinType::RightSemi + JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark ) { let right_cols = fetch_right_columns_by_idxs( &self.buffered_data, @@ -1956,6 +1966,7 @@ impl SortMergeJoinStream { | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark + | JoinType::RightMark | JoinType::Full ) { self.staging_output_record_batches @@ -2054,6 +2065,7 @@ impl SortMergeJoinStream { | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftMark + | JoinType::RightMark | JoinType::Full )) { @@ -2115,7 +2127,7 @@ impl SortMergeJoinStream { if matches!( self.join_type, - JoinType::Left | JoinType::LeftMark | JoinType::Right + JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark ) { let null_mask = compute::not(corrected_mask)?; let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; @@ -2236,7 +2248,7 @@ fn create_unmatched_columns( schema: &SchemaRef, size: usize, ) -> Vec { - if matches!(join_type, JoinType::LeftMark) { + if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) { vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef] } else { schema @@ -3830,6 +3842,38 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_right_mark() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, RightMark).await?; + // The output order is important as SMJ preserves sortedness + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+-------+ + | a2 | b1 | c2 | mark | + +----+----+----+-------+ + | 10 | 4 | 60 | true | + | 20 | 4 | 70 | true | + | 30 | 5 | 80 | true | + | 40 | 6 | 90 | false | + +----+----+----+-------+ + "#); + Ok(()) + } + #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { let left = build_table( @@ -4158,7 +4202,7 @@ mod tests { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = vec![ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, ]; // Disable DiskManager to prevent spilling @@ -4240,7 +4284,7 @@ mod tests { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = vec![ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, ]; // Disable DiskManager to prevent spilling @@ -4300,7 +4344,7 @@ mod tests { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = [ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, ]; // Enable DiskManager to allow spilling @@ -4405,7 +4449,7 @@ mod tests { let sort_options = vec![SortOptions::default(); on.len()]; let join_types = [ - Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, + Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark, ]; // Enable DiskManager to allow spilling diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 6dbe75cc0ae4..259ded353990 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -1018,6 +1018,7 @@ pub(crate) fn join_with_probe_batch( | JoinType::LeftSemi | JoinType::LeftMark | JoinType::RightSemi + | JoinType::RightMark ) { Ok(None) } else { @@ -1859,6 +1860,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] @@ -1947,6 +1949,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] @@ -2015,6 +2018,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] @@ -2068,6 +2072,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] @@ -2096,6 +2101,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] @@ -2480,6 +2486,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] @@ -2566,6 +2573,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] @@ -2644,6 +2652,7 @@ mod tests { JoinType::LeftSemi, JoinType::LeftAnti, JoinType::LeftMark, + JoinType::RightMark, JoinType::RightAnti, JoinType::Full )] diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c5f7087ac195..4d860c56e9d7 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -297,7 +297,10 @@ pub fn build_join_schema( }; let (schema1, schema2) = match join_type { - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (left, right), + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => (left, right), _ => (right, left), }; @@ -1489,13 +1492,15 @@ pub(super) fn swap_join_projection( join_type: &JoinType, ) -> Option> { match join_type { - // For Anti/Semi join types, projection should remain unmodified, + // For Anti/Semi/Mark join types, projection should remain unmodified, // since these joins output schema remains the same after swap JoinType::LeftAnti | JoinType::LeftSemi | JoinType::RightAnti - | JoinType::RightSemi => projection.cloned(), - + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark => projection.cloned(), + // For everything else we need to shift the column indices _ => projection.map(|p| { p.iter() .map(|i| { diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 796570633f67..671dcfdac507 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1192,7 +1192,7 @@ physical_plan 01)CoalesceBatchesExec: target_batch_size=2 02)--FilterExec: t1_id@0 > 40 OR NOT mark@3, projection=[t1_id@0, t1_name@1, t1_int@2] 03)----CoalesceBatchesExec: target_batch_size=2 -04)------HashJoinExec: mode=CollectLeft, join_type=LeftMark, on=[(t1_id@0, t2_id@0)] +04)------HashJoinExec: mode=CollectLeft, join_type=RightMark, on=[(t2_id@0, t1_id@0)] 05)--------DataSourceExec: partitions=1, partition_sizes=[1] 06)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 07)----------DataSourceExec: partitions=1, partition_sizes=[1]