Skip to content

Commit db259ff

Browse files
branch-4.0: [fix](nereids) adjust distribute expr lists after project common sub expression for aggregation #57258 (#57538)
Cherry-picked from #57258 Co-authored-by: minghong <[email protected]>
1 parent c74f134 commit db259ff

File tree

4 files changed

+174
-15
lines changed

4 files changed

+174
-15
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/ProjectAggregateExpressionsForCse.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.doris.nereids.processor.post;
1919

2020
import org.apache.doris.nereids.CascadesContext;
21+
import org.apache.doris.nereids.properties.ChildOutputPropertyDeriver;
2122
import org.apache.doris.nereids.properties.DataTrait;
2223
import org.apache.doris.nereids.properties.LogicalProperties;
2324
import org.apache.doris.nereids.properties.PhysicalProperties;
@@ -33,6 +34,7 @@
3334
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalPlan;
3435
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
3536
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
37+
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
3638
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
3739
import org.apache.doris.nereids.util.ExpressionUtils;
3840

@@ -144,7 +146,11 @@ public Plan visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg
144146
}
145147
}
146148
newProjections.addAll(cseCandidates.values());
147-
project = project.withProjectionsAndChild(newProjections, (Plan) project.child());
149+
150+
project = project.withProjectionsAndChild(newProjections, project.child());
151+
PhysicalProperties projectPhysicalProperties = ChildOutputPropertyDeriver.computeProjectOutputProperties(
152+
project.getProjects(), ((PhysicalPlan) project.child()).getPhysicalProperties());
153+
project = project.withPhysicalPropertiesAndStats(projectPhysicalProperties, project.getStats());
148154
aggregate = (PhysicalHashAggregate<? extends Plan>) aggregate
149155
.withAggOutput(aggOutputReplaced)
150156
.withChildren(project);
@@ -161,9 +167,8 @@ public Plan visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg
161167
() -> DataTrait.EMPTY_TRAIT
162168
);
163169
AbstractPhysicalPlan child = ((AbstractPhysicalPlan) aggregate.child());
164-
PhysicalProperties projectPhysicalProperties = new PhysicalProperties(
165-
child.getPhysicalProperties().getDistributionSpec(),
166-
child.getPhysicalProperties().getOrderSpec());
170+
PhysicalProperties projectPhysicalProperties = ChildOutputPropertyDeriver.computeProjectOutputProperties(
171+
projections, child.getPhysicalProperties());
167172
PhysicalProject<? extends Plan> project = new PhysicalProject<>(projections, Optional.empty(),
168173
projectLogicalProperties,
169174
projectPhysicalProperties,

fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriver.java

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -319,17 +319,18 @@ public PhysicalProperties visitPhysicalNestedLoopJoin(
319319
return new PhysicalProperties(leftOutputProperty.getDistributionSpec());
320320
}
321321

322-
@Override
323-
public PhysicalProperties visitPhysicalProject(PhysicalProject<? extends Plan> project, PlanContext context) {
324-
// TODO: order spec do not process since we do not use it.
325-
Preconditions.checkState(childrenOutputProperties.size() == 1);
326-
PhysicalProperties childProperties = childrenOutputProperties.get(0);
322+
/**
323+
* Derive the output physical properties of a physical project from its projections and its child's properties.
324+
*/
325+
public static PhysicalProperties computeProjectOutputProperties(
326+
List<NamedExpression> projects,
327+
PhysicalProperties childProperties) {
327328
DistributionSpec childDistributionSpec = childProperties.getDistributionSpec();
328329
OrderSpec childOrderSpec = childProperties.getOrderSpec();
329330
if (childDistributionSpec instanceof DistributionSpecHash) {
330331
Map<ExprId, ExprId> projections = Maps.newHashMap();
331332
Set<ExprId> obstructions = Sets.newHashSet();
332-
for (NamedExpression namedExpression : project.getProjects()) {
333+
for (NamedExpression namedExpression : projects) {
333334
if (namedExpression instanceof Alias) {
334335
Alias alias = (Alias) namedExpression;
335336
Expression child = alias.child();
@@ -345,22 +346,42 @@ && isSameHashValue(child.child(0).getDataType(), child.getDataType())) {
345346
.map(NamedExpression::getExprId)
346347
.collect(Collectors.toSet()));
347348
}
349+
} else {
350+
// namedExpression is a slot, so it maps to itself
351+
projections.put(namedExpression.getExprId(), namedExpression.getExprId());
348352
}
349353
}
350-
if (projections.entrySet().stream().allMatch(kv -> kv.getKey().equals(kv.getValue()))) {
351-
return childrenOutputProperties.get(0);
352-
}
354+
353355
DistributionSpecHash childDistributionSpecHash = (DistributionSpecHash) childDistributionSpec;
356+
boolean canUseChildProperties = true;
357+
for (ExprId exprId : childDistributionSpecHash.getOrderedShuffledColumns()) {
358+
if (!projections.containsKey(exprId) || !projections.get(exprId).equals(exprId)) {
359+
canUseChildProperties = false;
360+
break;
361+
}
362+
}
363+
364+
if (canUseChildProperties) {
365+
return childProperties;
366+
}
354367
DistributionSpec defaultAnySpec = childDistributionSpecHash.getShuffleType() == ShuffleType.NATURAL
355368
? DistributionSpecStorageAny.INSTANCE : DistributionSpecAny.INSTANCE;
356369
DistributionSpec outputDistributionSpec = childDistributionSpecHash.project(
357370
projections, obstructions, defaultAnySpec);
358371
return new PhysicalProperties(outputDistributionSpec, childOrderSpec);
359372
} else {
360-
return childrenOutputProperties.get(0);
373+
return childProperties;
361374
}
362375
}
363376

377+
@Override
378+
public PhysicalProperties visitPhysicalProject(PhysicalProject<? extends Plan> project, PlanContext context) {
379+
// TODO: order spec do not process since we do not use it.
380+
Preconditions.checkState(childrenOutputProperties.size() == 1);
381+
PhysicalProperties childProperties = childrenOutputProperties.get(0);
382+
return computeProjectOutputProperties(project.getProjects(), childProperties);
383+
}
384+
364385
@Override
365386
public PhysicalProperties visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, PlanContext context) {
366387
Preconditions.checkState(childrenOutputProperties.size() == 1);
@@ -622,7 +643,7 @@ private DistributionSpecHash mockAnotherSideSpecFromConjuncts(
622643
return new DistributionSpecHash(anotherSideOrderedExprIds, oneSideSpec.getShuffleType());
623644
}
624645

625-
private boolean isSameHashValue(DataType originType, DataType castType) {
646+
private static boolean isSameHashValue(DataType originType, DataType castType) {
626647
if (originType.isStringLikeType() && (castType.isVarcharType() || castType.isStringType())
627648
&& (castType.width() >= originType.width() || castType.width() < 0)) {
628649
return true;

fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@
3131
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
3232
import org.apache.doris.nereids.trees.expressions.EqualTo;
3333
import org.apache.doris.nereids.trees.expressions.ExprId;
34+
import org.apache.doris.nereids.trees.expressions.NamedExpression;
3435
import org.apache.doris.nereids.trees.expressions.Slot;
3536
import org.apache.doris.nereids.trees.expressions.SlotReference;
3637
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
38+
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
3739
import org.apache.doris.nereids.trees.expressions.literal.Literal;
3840
import org.apache.doris.nereids.trees.plans.AggMode;
3941
import org.apache.doris.nereids.trees.plans.AggPhase;
@@ -928,4 +930,27 @@ void testRepeatReturnChild2() {
928930
PhysicalProperties result = deriver.getOutputProperties(null, groupExpression);
929931
Assertions.assertEquals(child, result);
930932
}
933+
934+
@Test
935+
void testComputeProjectOutputProperties() {
936+
SlotReference c1 = new SlotReference(
937+
new ExprId(1), "c1", TinyIntType.INSTANCE, true, ImmutableList.of());
938+
PhysicalProperties hashC1 = PhysicalProperties.createHash(
939+
ImmutableList.of(new ExprId(1)), ShuffleType.EXECUTION_BUCKETED);
940+
List<NamedExpression> projects1 = new ArrayList<>();
941+
projects1.add(c1);
942+
PhysicalProperties phyProp = ChildOutputPropertyDeriver.computeProjectOutputProperties(projects1, hashC1);
943+
Assertions.assertEquals(hashC1, phyProp);
944+
945+
List<NamedExpression> projects2 = new ArrayList<>();
946+
projects2.add(new Alias(new Abs(c1)));
947+
PhysicalProperties phyProp2 = ChildOutputPropertyDeriver.computeProjectOutputProperties(projects2, hashC1);
948+
Assertions.assertEquals(DistributionSpecAny.INSTANCE, phyProp2.getDistributionSpec());
949+
950+
List<NamedExpression> projects3 = new ArrayList<>();
951+
projects3.add(new Alias(new Abs(c1)));
952+
projects3.add(c1);
953+
PhysicalProperties phyProp3 = ChildOutputPropertyDeriver.computeProjectOutputProperties(projects3, hashC1);
954+
Assertions.assertEquals(hashC1, phyProp3);
955+
}
931956
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
// The cases are copied from https://github.com/trinodb/trino/tree/master
19+
// /testing/trino-product-tests/src/main/resources/sql-tests/testcases/aggregate
20+
// and modified by Doris.
21+
22+
suite("dist_expr_list") {
23+
sql """
24+
drop table if exists agg_cse_shuffle;
25+
create table agg_cse_shuffle(
26+
a int, b int, c int
27+
)distributed by hash(a) buckets 3
28+
properties("replication_num" = "1")
29+
;
30+
insert into agg_cse_shuffle values (1, 2, 3), (3, 4, 5);
31+
"""
32+
33+
explain {
34+
sql """
35+
select max(
36+
case
37+
when (abs(b) > 10) then a+1
38+
when (abs(b) < 10) then a+2
39+
else NULL
40+
end
41+
) x
42+
from agg_cse_shuffle
43+
where c > 0;
44+
"""
45+
notContains "distribute expr lists: a"
46+
}
47+
/*
48+
After ProjectAggregateExpressionsForCse, a#6 is not an output slot of the scan node, and hence it should not be a distribute expr.
49+
50+
expected explain string
51+
| 1:VAGGREGATE (update serialize)(114) |
52+
| | output: partial_max(CASE WHEN (abs(b) > 10) THEN (cast(a as BIGINT) + 1) WHEN (abs(b) < 10) THEN (cast(a as BIGINT) + 2) ELSE NULL END[#8])[#9] |
53+
| | group by: |
54+
| | sortByGroupKey:false |
55+
| | cardinality=1 |
56+
| | distribute expr lists: |
57+
| | |
58+
| 0:VOlapScanNode(95)
59+
60+
61+
the bad case: distribute expr lists: a[#6]
62+
explain string
63+
| 1:VAGGREGATE (update serialize)(114) |
64+
| | output: partial_max(CASE WHEN (abs(b) > 10) THEN (cast(a as BIGINT) + 1) WHEN (abs(b) < 10) THEN (cast(a as BIGINT) + 2) ELSE NULL END[#8])[#9] |
65+
| | group by: |
66+
| | sortByGroupKey:false |
67+
| | cardinality=1 |
68+
| | distribute expr lists: a[#6] <== |
69+
| | |
70+
| 0:VOlapScanNode(95)
71+
| TABLE: rqg.agg_cse_shuffle(agg_cse_shuffle), PREAGGREGATION: ON |
72+
| partitions=1/1 (agg_cse_shuffle) |
73+
| tablets=3/3, tabletList=1761200234884,1761200234886,1761200234888 |
74+
| cardinality=2, avgRowSize=1182.5, numNodes=2 |
75+
| pushAggOp=NONE |
76+
| final projections: a[#2], b[#3], CASE WHEN (abs(b)[#4] > 10) THEN (cast(a as BIGINT)[#5] + 1) WHEN (abs(b)[#4] < 10) THEN (cast(a as BIGINT)[#5] + 2) ELSE NULL END |
77+
| final project output tuple id: 2 |
78+
| intermediate projections: a[#0], b[#1], abs(b[#1]), CAST(a[#0] AS bigint) |
79+
| intermediate tuple id: 1
80+
*/
81+
82+
explain {
83+
sql """
84+
select max(
85+
case
86+
when (abs(b) > 10) then a+1
87+
when (abs(b) < 10) then a+2
88+
else NULL
89+
end
90+
),
91+
max(a)
92+
from agg_cse_shuffle;
93+
"""
94+
contains "distribute expr lists: a"
95+
/*
96+
expected explain string
97+
| 1:VAGGREGATE (update serialize)(100) |
98+
| | output: partial_max(CASE WHEN (abs(b) > 10) THEN (cast(a as BIGINT) + 1) WHEN (abs(b) < 10) THEN (cast(a as BIGINT) + 2) ELSE NULL END[#8])[#9], partial_max(a[#6])[#10] |
99+
| | group by: |
100+
| | sortByGroupKey:false |
101+
| | cardinality=1 |
102+
| | distribute expr lists: a[#6] |
103+
| | tuple ids: 3 |
104+
| | |
105+
| 0:VOlapScanNode(81)
106+
*/
107+
}
108+
}

0 commit comments

Comments
 (0)