Skip to content

Commit

Permalink
Remove allowInputShuffle_ flag in AggregationTestBase::testReadFromFi…
Browse files Browse the repository at this point in the history
…les (facebookincubator#9482)

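Summary:
Simplify AggregationTestBase by automatically determining whether any of the
aggregate functions under test is order-sensitive, in which case input
shuffling is not allowed. If the orderSensitive metadata flag is set to true
for any of the test aggregates, the test variants that reorder input are
skipped (spilling and streaming), and testReadFromFiles writes the input to a
single file instead of splitting it across two files.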

Fixes facebookincubator#9274

Pull Request resolved: facebookincubator#9482

Reviewed By: Yuhta

Differential Revision: D57340678

Pulled By: kgpai

fbshipit-source-id: 2e54fb7a2362a95cc5ca1dc01032722af23fd745
  • Loading branch information
yanngyoung authored and facebook-github-bot committed May 16, 2024
1 parent 7b8e41f commit e2c0014
Show file tree
Hide file tree
Showing 33 changed files with 62 additions and 63 deletions.
3 changes: 1 addition & 2 deletions velox/docs/develop/aggregate-functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -976,8 +976,7 @@ The following query plans are being tested.
final aggregation with forced spilling. Query runs using 4 threads.

Query run with forced spilling is enabled only for group-by aggregations and
only if `allowInputShuffle_` flag is enabled by calling allowInputShuffle
() method from the SetUp(). Spill testing requires multiple batches of input.
only if aggregate functions are not order-sensitive. Spill testing requires multiple batches of input.
To split input data into multiple batches we add local exchange with
round-robin repartitioning before the partial aggregation. This changes the order
in which aggregation inputs are processed, hence, query results with spilling
Expand Down
2 changes: 0 additions & 2 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class SimpleAverageAggregationTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();

registerSimpleAverageAggregate(kSimpleAvg);
}
Expand Down Expand Up @@ -114,7 +113,6 @@ class SimpleArrayAggAggregationTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
disallowInputShuffle();

registerSimpleArrayAggAggregate(kSimpleArrayAgg);
}
Expand Down
1 change: 0 additions & 1 deletion velox/functions/lib/aggregates/tests/SumTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class SumTestBase : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

template <
Expand Down
68 changes: 60 additions & 8 deletions velox/functions/lib/aggregates/tests/utils/AggregationTestBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,55 @@ getFunctionNamesAndArgs(const std::vector<std::string>& aggregates) {
}
return std::make_pair(functionNames, aggregateArgs);
}

// Given a list of aggregation expressions, e.g., {"avg(c0)",
// "\"$internal$count_distinct\"(c1)"}, fetch the function names from
// AggregationNode of builder.planNode().
std::vector<std::string> getFunctionNames(
std::function<void(exec::test::PlanBuilder&)> makeSource,
const std::vector<std::string>& aggregates,
memory::MemoryPool* pool) {
std::vector<std::string> functionNames;

PlanBuilder builder(pool);
makeSource(builder);

builder.singleAggregation({}, aggregates);
auto& aggregationNode =
static_cast<const core::AggregationNode&>(*builder.planNode());

for (const auto& aggregate : aggregationNode.aggregates()) {
const auto& aggregateExpr = aggregate.call;
const auto& name = aggregateExpr->name();

functionNames.push_back(name);
}

return functionNames;
}

// Given a list of aggregation expressions, returns true if at least one of the
// aggregate functions is order-sensitive according to the metadata stored in
// the aggregate function registry.
//
// 'makeSource' adds the source node(s) the aggregation is planned on top of;
// 'aggregates' are the aggregation expressions, e.g., {"avg(c0)"}; 'pool' is
// used to build the temporary plan needed to extract the function names.
//
// A function that has no registry entry is conservatively reported as
// order-sensitive: callers use this result to decide whether it is safe to
// shuffle the input, so "unknown" must not enable shuffling (and must not
// dereference a null entry).
bool hasOrderSensitive(
    std::function<void(exec::test::PlanBuilder&)> makeSource,
    const std::vector<std::string>& aggregates,
    memory::MemoryPool* pool) {
  const auto functionNames = getFunctionNames(makeSource, aggregates, pool);
  return std::any_of(
      functionNames.begin(),
      functionNames.end(),
      [](const std::string& functionName) {
        const auto* entry = exec::getAggregateFunctionEntry(functionName);
        return entry == nullptr || entry->metadata.orderSensitive;
      });
}
// Same as above, but allows to specify input data instead of a function.
bool hasOrderSensitive(
const std::vector<RowVectorPtr>& data,
const std::vector<std::string>& aggregates,
memory::MemoryPool* pool) {
return hasOrderSensitive(
[&](PlanBuilder& builder) { builder.values(data); }, aggregates, pool);
}
} // namespace

void AggregationTestBase::testAggregationsWithCompanion(
Expand Down Expand Up @@ -372,7 +421,7 @@ void AggregationTestBase::testAggregationsWithCompanion(
assertResults(queryBuilder);
}

if (!groupingKeys.empty() && allowInputShuffle_) {
if (!groupingKeys.empty() && !hasOrderSensitive(data, aggregates, pool())) {
SCOPED_TRACE("Run partial + final with spilling");
PlanBuilder builder(pool());
builder.values(dataWithExtraGroupingKey);
Expand Down Expand Up @@ -604,7 +653,6 @@ bool isTableScanSupported(const TypePtr& type) {

return true;
}

} // namespace

void AggregationTestBase::testReadFromFiles(
Expand Down Expand Up @@ -633,9 +681,10 @@ void AggregationTestBase::testReadFromFiles(
auto writerPool = rootPool_->addAggregateChild("AggregationTestBase.writer");

// Splits and writes the input vectors into two files, to some extent,
// involves shuffling of the inputs. So only split input if allowInputShuffle_
// is true. Otherwise, only write into a single file.
if (allowInputShuffle_ && input->size() >= 2) {
// involves shuffling of the inputs. So only split input if aggregate
// is non-orderSensitive. Otherwise, only write into a single file.
if (!hasOrderSensitive(makeSource, aggregates, pool()) &&
input->size() >= 2) {
auto size1 = input->size() / 2;
auto size2 = input->size() - size1;
auto input1 = input->slice(0, size1);
Expand Down Expand Up @@ -814,7 +863,8 @@ void AggregationTestBase::testAggregationsImpl(
assertResults(queryBuilder);
}

if (!groupingKeys.empty() && allowInputShuffle_) {
if (!groupingKeys.empty() &&
!hasOrderSensitive(makeSource, aggregates, pool())) {
SCOPED_TRACE("Run partial + final with spilling");
PlanBuilder builder(pool());
makeSource(builder);
Expand Down Expand Up @@ -926,7 +976,8 @@ void AggregationTestBase::testAggregationsImpl(
assertResults(queryBuilder);
}

if (!groupingKeys.empty() && allowInputShuffle_) {
if (!groupingKeys.empty() &&
!hasOrderSensitive(makeSource, aggregates, pool())) {
SCOPED_TRACE("Run single with spilling");
PlanBuilder builder(pool());
makeSource(builder);
Expand Down Expand Up @@ -1117,7 +1168,8 @@ void AggregationTestBase::testAggregations(
testIncrementalAggregation(makeSource, aggregates, config);
}

if (allowInputShuffle_ && !groupingKeys.empty()) {
if (!hasOrderSensitive(makeSource, aggregates, pool()) &&
!groupingKeys.empty()) {
testStreamingAggregationsImpl(
makeSource,
groupingKeys,
Expand Down
13 changes: 0 additions & 13 deletions velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,17 +220,6 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
const std::string& expectedMessage,
const std::unordered_map<std::string, std::string>& config = {});

/// Specifies that aggregate functions used in this test are not sensitive
/// to the order of inputs.
void allowInputShuffle() {
allowInputShuffle_ = true;
}
/// Specifies that aggregate functions used in this test are sensitive
/// to the order of inputs.
void disallowInputShuffle() {
allowInputShuffle_ = false;
}

void disableTestStreaming() {
testStreaming_ = false;
}
Expand Down Expand Up @@ -296,8 +285,6 @@ class AggregationTestBase : public exec::test::OperatorTestBase {
std::function<std::shared_ptr<exec::Task>(
exec::test::AssertQueryBuilder&)> assertResults,
const std::unordered_map<std::string, std::string>& config);

bool allowInputShuffle_{false};
};

} // namespace facebook::velox::functions::aggregate::test
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ struct ApproxMostFrequentTest : AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

std::shared_ptr<FlatVector<int>> makeGroupKeys() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ class ApproxPercentileTest : public AggregationTestBase {
void SetUp() override {
AggregationTestBase::SetUp();
random::setSeed(0);
allowInputShuffle();
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class AverageAggregationTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();

registerSimpleAverageAggregate("simple_avg");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class BitwiseAggregationTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

RowTypePtr rowType_{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ class BoolAndOrTest : public virtual AggregationTestBase,
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class CentralMomentsAggregationTest
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

void testGroupBy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class ChecksumAggregateTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class CountAggregationTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

RowTypePtr rowType_{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class CountDistinctTest : public AggregationTestBase {
void SetUp() override {
prestosql::registerInternalAggregateFunctions("");
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class CountIfAggregationTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

RowTypePtr rowType_{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class CovarianceAggregationTest
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

void testGroupBy(const std::string& aggName, const RowVectorPtr& data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class EntropyAggregationTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

void testGroupByAgg(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class GeometricMeanTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class HistogramTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

void testHistogramWithDuck(
Expand Down
2 changes: 0 additions & 2 deletions velox/functions/prestosql/aggregates/tests/MapAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@ TEST_F(MapAggTest, groupByWithNullValues) {
}

TEST_F(MapAggTest, groupByWithDuplicates) {
disallowInputShuffle();

auto data = makeRowVector({
makeFlatVector<int32_t>({0, 0, 1, 1, 2, 2, 3, 3, 4, 4}),
makeFlatVector<int32_t>({0, 0, 1, 1, 2, 2, 3, 3, 4, 4}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class MaxSizeForStatsTest : public AggregationTestBase {
public:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,6 @@ std::string asSql(bool value) {

void MinMaxByAggregationTestBase::SetUp() {
AggregationTestBase::SetUp();
AggregationTestBase::disallowInputShuffle();
std::vector<TypePtr> supportedTypes;
std::vector<std::string> columnNames;
int columnId = 0;
Expand Down Expand Up @@ -647,8 +646,6 @@ TEST_P(
}

// Enable disk spilling test with distinct comparison values.
AggregationTestBase::allowInputShuffle();

auto rowType =
ROW({"c0", "c1"},
{createScalarType(GetParam().valueType),
Expand Down Expand Up @@ -1083,8 +1080,6 @@ TEST_P(
}

// Enable disk spilling test with distinct comparison values.
AggregationTestBase::allowInputShuffle();

auto rowType =
ROW({"c0", "c1", "c2"},
{createScalarType(GetParam().valueType),
Expand Down Expand Up @@ -1387,7 +1382,6 @@ class MinMaxByNTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
AggregationTestBase::allowInputShuffle();
AggregationTestBase::enableTestStreaming();
}
};
Expand Down
2 changes: 0 additions & 2 deletions velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class MinMaxTest : public functions::aggregate::test::AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

std::vector<RowVectorPtr> fuzzData(const RowTypePtr& rowType) {
Expand Down Expand Up @@ -474,7 +473,6 @@ class MinMaxNTest : public functions::aggregate::test::AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class MultiMapAggTest : public functions::aggregate::test::AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class ReduceAggTest : public functions::aggregate::test::AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
disableTestStreaming();
}

Expand Down
1 change: 0 additions & 1 deletion velox/functions/prestosql/aggregates/tests/SetAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class SetAggTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ class SetUnionTest : public AggregationTestBase {
protected:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ class SumDataSizeForStatsTest : public AggregationTestBase {
public:
void SetUp() override {
AggregationTestBase::SetUp();
allowInputShuffle();
}
};

Expand Down
Loading

0 comments on commit e2c0014

Please sign in to comment.