Skip to content

Commit 9d77026

Browse files
committed
Phase 1. Implementation for RFC-0005: Scalar function stats propagation.
1. Support for annotating functions with both constant stats and propagating source stats. 2. Added tests for the same. 3. Added Scalar stats calculation based on annotation and tests for the same. Not added SQLInvokedScalarFunctions. Not annotated builtin functions, as that is covered in next implementation phase. Not added C++ changes as this phase only covers Java side of changes. Added documentation for the new properties and ... 1. Previously, if any of the source stats were missing, we would still compute the max/min/sum of argument stats etc.. now we propagate NaNs if any one of the arguments' stats are missing. 2. For distinct values count, upper bounding it to row count is as good as unknown. Therefore, the approach here is, when distinctValuesCount is greater than row count and is provided via annotation we set it to unknown. A function developer has full control here, for example developer can choose to upper bound or not by selecting the appropriate StatsPropagationBehavior value. 3. For average row size, a) If average row size is provided via ScalarFunctionConstantStats annotation, then we allow even if the size is greater than functions return type width. b) If average row size is provided via one of the StatsPropagationBehavior values, then we upper bound it to functions return type width - if available. If both (a) and (b) is unknown, then we default it to functions return type width if available. This way the function developer has greater control. Added new behaviour SUM_ARGUMENTS_UPPER_BOUND_ROW_COUNT which would upper bound the values to row count, so that summing distinct values count not exceed row counts.
1 parent 27eb666 commit 9d77026

20 files changed

+1672
-31
lines changed

presto-docs/src/main/sphinx/admin/properties.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,17 @@ This can also be specified on a per-query basis using the ``confidence_based_bro
884884
Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. This can also be specified
885885
on a per-query basis using the ``treat-low-confidence-zero-estimation-as-unknown`` session property.
886886

887+
``optimizer.scalar-function-stats-propagation-enabled``
888+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
889+
890+
891+
* **Type:** ``boolean``
892+
* **Default value:** ``false``
893+
894+
Enable scalar functions stats propagation using annotations. Annotations define the behavior of the scalar
895+
function's stats characteristics. When set to ``true``, this property enables the stats propagation through annotations.
896+
This can also be specified on a per-query basis using the ``scalar_function_stats_propagation_enabled`` session property.
897+
887898
``optimizer.retry-query-with-history-based-optimization``
888899
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
889900

presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ public final class SystemSessionProperties
371371
public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms";
372372
public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns";
373373
public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values";
374+
public static final String SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED = "scalar_function_stats_propagation_enabled";
374375

375376
private final List<PropertyMetadata<?>> sessionProperties;
376377

@@ -2077,6 +2078,10 @@ public SystemSessionProperties(
20772078
booleanProperty(INLINE_PROJECTIONS_ON_VALUES,
20782079
"Whether to evaluate project node on values node",
20792080
featuresConfig.getInlineProjectionsOnValues(),
2081+
false),
2082+
booleanProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED,
2083+
"whether or not to respect stats propagation annotation for scalar functions (or UDF)",
2084+
featuresConfig.isScalarFunctionStatsPropagationEnabled(),
20802085
false));
20812086
}
20822087

@@ -3414,4 +3419,9 @@ public static boolean isInlineProjectionsOnValues(Session session)
34143419
{
34153420
return session.getSystemProperty(INLINE_PROJECTIONS_ON_VALUES, Boolean.class);
34163421
}
3422+
3423+
public static boolean shouldEnableScalarFunctionStatsPropagation(Session session)
3424+
{
3425+
return session.getSystemProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, Boolean.class);
3426+
}
34173427
}
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package com.facebook.presto.cost;
16+
17+
import com.facebook.presto.common.type.FixedWidthType;
18+
import com.facebook.presto.common.type.Type;
19+
import com.facebook.presto.common.type.VarcharType;
20+
import com.facebook.presto.spi.function.ScalarPropagateSourceStats;
21+
import com.facebook.presto.spi.function.ScalarStatsHeader;
22+
import com.facebook.presto.spi.function.StatsPropagationBehavior;
23+
import com.facebook.presto.spi.relation.CallExpression;
24+
import com.facebook.presto.spi.relation.RowExpression;
25+
26+
import java.util.List;
27+
import java.util.Map;
28+
29+
import static com.facebook.presto.spi.function.StatsPropagationBehavior.NON_NULL_ROW_COUNT;
30+
import static com.facebook.presto.spi.function.StatsPropagationBehavior.ROW_COUNT;
31+
import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS;
32+
import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT;
33+
import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN;
34+
import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS;
35+
import static com.facebook.presto.util.MoreMath.max;
36+
import static com.facebook.presto.util.MoreMath.min;
37+
import static com.facebook.presto.util.MoreMath.minExcludingNaNs;
38+
import static com.facebook.presto.util.MoreMath.nearlyEqual;
39+
import static com.google.common.base.Preconditions.checkArgument;
40+
import static com.google.common.collect.ImmutableList.toImmutableList;
41+
import static java.lang.Double.NaN;
42+
import static java.lang.Double.isFinite;
43+
import static java.lang.Double.isNaN;
44+
45+
public final class ScalarStatsAnnotationProcessor
46+
{
47+
private ScalarStatsAnnotationProcessor()
48+
{
49+
}
50+
51+
public static VariableStatsEstimate process(
52+
double outputRowCount,
53+
CallExpression callExpression,
54+
List<VariableStatsEstimate> sourceStats,
55+
ScalarStatsHeader scalarStatsHeader)
56+
{
57+
double nullFraction = scalarStatsHeader.getNullFraction();
58+
double distinctValuesCount = NaN;
59+
double averageRowSize = NaN;
60+
double maxValue = scalarStatsHeader.getMax();
61+
double minValue = scalarStatsHeader.getMin();
62+
for (Map.Entry<Integer, ScalarPropagateSourceStats> paramIndexToStatsMap : scalarStatsHeader.getArgumentStats().entrySet()) {
63+
ScalarPropagateSourceStats scalarPropagateSourceStats = paramIndexToStatsMap.getValue();
64+
boolean propagateAllStats = scalarPropagateSourceStats.propagateAllStats();
65+
nullFraction = min(firstFiniteValue(nullFraction, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
66+
sourceStats.stream().map(VariableStatsEstimate::getNullsFraction).collect(toImmutableList()),
67+
paramIndexToStatsMap.getKey(),
68+
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.nullFraction()))), 1.0);
69+
distinctValuesCount = firstFiniteValue(distinctValuesCount, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
70+
sourceStats.stream().map(VariableStatsEstimate::getDistinctValuesCount).collect(toImmutableList()),
71+
paramIndexToStatsMap.getKey(),
72+
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.distinctValuesCount())));
73+
StatsPropagationBehavior averageRowSizeStatsBehaviour = applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.avgRowSize());
74+
averageRowSize = minExcludingNaNs(firstFiniteValue(averageRowSize, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
75+
sourceStats.stream().map(VariableStatsEstimate::getAverageRowSize).collect(toImmutableList()),
76+
paramIndexToStatsMap.getKey(),
77+
averageRowSizeStatsBehaviour)), returnNaNIfTypeWidthUnknown(getReturnTypeWidth(callExpression, averageRowSizeStatsBehaviour)));
78+
maxValue = firstFiniteValue(maxValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
79+
sourceStats.stream().map(VariableStatsEstimate::getHighValue).collect(toImmutableList()),
80+
paramIndexToStatsMap.getKey(),
81+
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.maxValue())));
82+
minValue = firstFiniteValue(minValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
83+
sourceStats.stream().map(VariableStatsEstimate::getLowValue).collect(toImmutableList()),
84+
paramIndexToStatsMap.getKey(),
85+
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.minValue())));
86+
}
87+
if (isNaN(maxValue) || isNaN(minValue)) {
88+
minValue = NaN;
89+
maxValue = NaN;
90+
}
91+
return VariableStatsEstimate.builder()
92+
.setLowValue(minValue)
93+
.setHighValue(maxValue)
94+
.setNullsFraction(nullFraction)
95+
.setAverageRowSize(firstFiniteValue(scalarStatsHeader.getAvgRowSize(), averageRowSize, returnNaNIfTypeWidthUnknown(getReturnTypeWidth(callExpression, UNKNOWN))))
96+
.setDistinctValuesCount(processDistinctValuesCount(outputRowCount, nullFraction, scalarStatsHeader.getDistinctValuesCount(), distinctValuesCount)).build();
97+
}
98+
99+
private static double processDistinctValuesCount(double outputRowCount, double nullFraction, double distinctValuesCountFromConstant, double distinctValuesCount)
100+
{
101+
if (isFinite(distinctValuesCountFromConstant)) {
102+
if (nearlyEqual(distinctValuesCountFromConstant, NON_NULL_ROW_COUNT.getValue(), 0.1)) {
103+
distinctValuesCountFromConstant = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0));
104+
}
105+
else if (nearlyEqual(distinctValuesCount, ROW_COUNT.getValue(), 0.1)) {
106+
distinctValuesCountFromConstant = outputRowCount;
107+
}
108+
}
109+
double distinctValuesCountFinal = firstFiniteValue(distinctValuesCountFromConstant, distinctValuesCount);
110+
if (distinctValuesCountFinal > outputRowCount) {
111+
distinctValuesCountFinal = NaN;
112+
}
113+
return distinctValuesCountFinal;
114+
}
115+
116+
private static double processSingleArgumentStatistic(
117+
double outputRowCount,
118+
double nullFraction,
119+
CallExpression callExpression,
120+
List<Double> sourceStats,
121+
int sourceStatsArgumentIndex,
122+
StatsPropagationBehavior operation)
123+
{
124+
// sourceStatsArgumentIndex is index of the argument on which
125+
// ScalarPropagateSourceStats annotation was applied.
126+
double statValue = NaN;
127+
if (operation.isMultiArgumentStat()) {
128+
for (int i = 0; i < sourceStats.size(); i++) {
129+
if (i == 0 && operation.isSourceStatsDependentStats() && isFinite(sourceStats.get(i))) {
130+
statValue = sourceStats.get(i);
131+
}
132+
else {
133+
switch (operation) {
134+
case MAX_TYPE_WIDTH_VARCHAR:
135+
statValue = returnNaNIfTypeWidthUnknown(getTypeWidthVarchar(callExpression.getArguments().get(i).getType()));
136+
break;
137+
case USE_MIN_ARGUMENT:
138+
statValue = min(statValue, sourceStats.get(i));
139+
break;
140+
case USE_MAX_ARGUMENT:
141+
statValue = max(statValue, sourceStats.get(i));
142+
break;
143+
case SUM_ARGUMENTS:
144+
statValue = statValue + sourceStats.get(i);
145+
break;
146+
case SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT:
147+
statValue = min(statValue + sourceStats.get(i), outputRowCount);
148+
break;
149+
}
150+
}
151+
}
152+
}
153+
else {
154+
switch (operation) {
155+
case USE_SOURCE_STATS:
156+
statValue = sourceStats.get(sourceStatsArgumentIndex);
157+
break;
158+
case ROW_COUNT:
159+
statValue = outputRowCount;
160+
break;
161+
case NON_NULL_ROW_COUNT:
162+
statValue = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0));
163+
break;
164+
case USE_TYPE_WIDTH_VARCHAR:
165+
statValue = returnNaNIfTypeWidthUnknown(getTypeWidthVarchar(callExpression.getArguments().get(sourceStatsArgumentIndex).getType()));
166+
break;
167+
}
168+
}
169+
return statValue;
170+
}
171+
172+
private static int getTypeWidthVarchar(Type argumentType)
173+
{
174+
if (argumentType instanceof VarcharType) {
175+
if (!((VarcharType) argumentType).isUnbounded()) {
176+
return ((VarcharType) argumentType).getLengthSafe();
177+
}
178+
}
179+
return -VarcharType.MAX_LENGTH;
180+
}
181+
182+
private static double returnNaNIfTypeWidthUnknown(int typeWidthValue)
183+
{
184+
if (typeWidthValue <= 0) {
185+
return NaN;
186+
}
187+
return typeWidthValue;
188+
}
189+
190+
private static int getReturnTypeWidth(CallExpression callExpression, StatsPropagationBehavior operation)
191+
{
192+
if (callExpression.getType() instanceof FixedWidthType) {
193+
return ((FixedWidthType) callExpression.getType()).getFixedSize();
194+
}
195+
if (callExpression.getType() instanceof VarcharType) {
196+
VarcharType returnType = (VarcharType) callExpression.getType();
197+
if (!returnType.isUnbounded()) {
198+
return returnType.getLengthSafe();
199+
}
200+
if (operation == SUM_ARGUMENTS || operation == SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT) {
201+
// since return type is an unbounded varchar and operation is SUM_ARGUMENTS,
202+
// calculating the type width by doing a SUM of each argument's varchar type bounds - if available.
203+
int sum = 0;
204+
for (RowExpression r : callExpression.getArguments()) {
205+
int typeWidth;
206+
if (r instanceof CallExpression) { // argument is another function call
207+
typeWidth = getReturnTypeWidth((CallExpression) r, UNKNOWN);
208+
}
209+
else {
210+
typeWidth = getTypeWidthVarchar(r.getType());
211+
}
212+
if (typeWidth < 0) {
213+
return -VarcharType.MAX_LENGTH;
214+
}
215+
sum += typeWidth;
216+
}
217+
return sum;
218+
}
219+
}
220+
return -VarcharType.MAX_LENGTH;
221+
}
222+
223+
// Return first 'finite' value from values, else return values[0]
224+
private static double firstFiniteValue(double... values)
225+
{
226+
checkArgument(values.length > 1);
227+
for (double v : values) {
228+
if (isFinite(v)) {
229+
return v;
230+
}
231+
}
232+
return values[0];
233+
}
234+
235+
private static StatsPropagationBehavior applyPropagateAllStats(
236+
boolean propagateAllStats, StatsPropagationBehavior operation)
237+
{
238+
if (operation == UNKNOWN && propagateAllStats) {
239+
return USE_SOURCE_STATS;
240+
}
241+
return operation;
242+
}
243+
}

0 commit comments

Comments
 (0)