diff --git a/velox/exec/HashAggregation.cpp b/velox/exec/HashAggregation.cpp
index ad9fdb518d3c..05c2f73c5fd2 100644
--- a/velox/exec/HashAggregation.cpp
+++ b/velox/exec/HashAggregation.cpp
@@ -81,10 +81,8 @@ void HashAggregation::initialize() {
         core::AggregationNode::stepName(aggregationNode_->step()));
   }
 
-  if (isDistinct_) {
-    for (auto i = 0; i < hashers.size(); ++i) {
-      identityProjections_.emplace_back(hashers[i]->channel(), i);
-    }
+  for (auto i = 0; i < hashers.size(); ++i) {
+    identityProjections_.emplace_back(hashers[i]->channel(), i);
   }
 
   std::optional<column_index_t> groupIdChannel;
diff --git a/velox/exec/Operator.h b/velox/exec/Operator.h
index 051ce55c02c8..56c54daf7ca0 100644
--- a/velox/exec/Operator.h
+++ b/velox/exec/Operator.h
@@ -454,7 +454,10 @@ class Operator : public BaseRuntimeStatWriter {
   }
 
   /// Returns a list of identity projections, e.g. columns that are projected
-  /// as-is possibly after applying a filter.
+  /// as-is possibly after applying a filter. Used to allow pushdown of dynamic
+  /// filters generated by HashProbe into the TableScan. Examples of identity
+  /// projections: all columns in FilterProject(only filters), group-by keys in
+  /// HashAggregation.
   const std::vector<IdentityProjection>& identityProjections() const {
     return identityProjections_;
   }
diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp
index b5488d14f2ab..07c65c40348f 100644
--- a/velox/exec/tests/HashJoinTest.cpp
+++ b/velox/exec/tests/HashJoinTest.cpp
@@ -921,6 +921,13 @@ class HashJoinTest : public HiveConnectorTestBase {
     return stats[operatorIndex].runtimeStats[statsName];
   }
 
+  // Get the operator index from the plan node id. Only used in the probe-side
+  // pipeline. The plan node id starts from "1" and the operator index starts
+  // from 0. Plan node IDs map to operators 1:1.
+  static int32_t getOperatorIndex(const core::PlanNodeId& planNodeId) {
+    return folly::to<int32_t>(planNodeId) - 1;
+  }
+
   static core::JoinType flipJoinType(core::JoinType joinType) {
     switch (joinType) {
       case core::JoinType::kInner:
@@ -5177,6 +5184,78 @@ TEST_F(HashJoinTest, dynamicFiltersAppliedToPreloadedSplits) {
       .run();
 }
 
+// Verifies that a dynamic filter produced by the hash join is pushed down
+// through the partial aggregation on the probe side and accepted by the table
+// scan, reducing the number of rows fed into the aggregation.
+TEST_F(HashJoinTest, dynamicFiltersPushDownThroughAgg) {
+  const int32_t numRowsProbe = 300;
+  const int32_t numRowsBuild = 100;
+
+  // Create probe data
+  std::vector<RowVectorPtr> probeVectors{makeRowVector({
+      makeFlatVector<int32_t>(numRowsProbe, [&](auto row) { return row - 10; }),
+      makeFlatVector<int64_t>(numRowsProbe, folly::identity),
+  })};
+  std::shared_ptr<TempFilePath> probeFile = TempFilePath::create();
+  writeToFile(probeFile->getPath(), probeVectors);
+
+  // Create build data
+  std::vector<RowVectorPtr> buildVectors{makeRowVector(
+      {"u0"}, {makeFlatVector<int32_t>(numRowsBuild, [&](auto row) {
+        return 35 + 2 * (row + numRowsBuild / 5);
+      })})};
+
+  createDuckDbTable("t", probeVectors);
+  createDuckDbTable("u", buildVectors);
+
+  auto probeType = ROW({"c0", "c1"}, {INTEGER(), BIGINT()});
+  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
+  auto buildSide =
+      PlanBuilder(planNodeIdGenerator).values(buildVectors).planNode();
+
+  // Inner join.
+  core::PlanNodeId scanNodeId;
+  core::PlanNodeId joinNodeId;
+  core::PlanNodeId aggNodeId;
+  auto op = PlanBuilder(planNodeIdGenerator, pool_.get())
+                .tableScan(probeType)
+                .capturePlanNodeId(scanNodeId)
+                .partialAggregation({"c0"}, {"sum(c1)"})
+                .capturePlanNodeId(aggNodeId)
+                .hashJoin(
+                    {"c0"},
+                    {"u0"},
+                    buildSide,
+                    "",
+                    {"c0", "a0"},
+                    core::JoinType::kInner)
+                .capturePlanNodeId(joinNodeId)
+                .planNode();
+
+  SplitInput splitInput = {
+      {scanNodeId, {Split(makeHiveConnectorSplit(probeFile->getPath()))}}};
+  HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
+      .planNode(std::move(op))
+      .inputSplits(splitInput)
+      .injectSpill(false)
+      .checkSpillStats(false)
+      .referenceQuery("SELECT c0, sum(c1) FROM t, u WHERE c0 = u0 group by c0")
+      .verifier([&](const std::shared_ptr<Task>& task, bool hasSpill) {
+        auto planStats = toPlanStats(task->taskStats());
+        auto dynamicFilterStats = planStats.at(scanNodeId).dynamicFilterStats;
+        ASSERT_EQ(
+            1, getFiltersProduced(task, getOperatorIndex(joinNodeId)).sum);
+        ASSERT_EQ(
+            1, getFiltersAccepted(task, getOperatorIndex(scanNodeId)).sum);
+        ASSERT_LT(
+            getInputPositions(task, getOperatorIndex(aggNodeId)), numRowsProbe);
+        ASSERT_EQ(
+            dynamicFilterStats.producerNodeIds,
+            std::unordered_set<core::PlanNodeId>({joinNodeId}));
+      })
+      .run();
+}
+
 // Verify the size of the join output vectors when projecting build-side
 // variable-width column.
 TEST_F(HashJoinTest, memoryUsage) {