Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions velox/exec/SimpleAggregateAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,19 @@ class SimpleAggregateAdapter : public Aggregate {
std::declval<const TypePtr&>(),
std::declval<const core::QueryConfig&>()))>> : std::true_type {};

// Whether the function defines setConstantInputs(). AggregateInfo discovers
// constant arguments after the aggregate is constructed and calls
// Aggregate::setConstantInputs() once before processing input rows. The
// adapter forwards that hook to simple aggregates that opt in.
template <typename T, typename = void>
struct support_set_constant_inputs : std::false_type {};

template <typename T>
struct support_set_constant_inputs<
T,
std::void_t<decltype(std::declval<T&>().setConstantInputs(
std::declval<const std::vector<VectorPtr>&>()))>> : std::true_type {};

// Whether the accumulator requires aligned access. If it is defined,
// SimpleAggregateAdapter::accumulatorAlignmentSize() returns
// alignof(typename FUNC::AccumulatorType).
Expand Down Expand Up @@ -248,6 +261,9 @@ class SimpleAggregateAdapter : public Aggregate {
static constexpr bool accumulator_is_aligned_ =
accumulator_is_aligned<typename FUNC::AccumulatorType>::value;

static constexpr bool support_set_constant_inputs_ =
support_set_constant_inputs<FUNC>::value;

bool isFixedSize() const override {
return accumulator_is_fixed_size_;
}
Expand All @@ -267,6 +283,13 @@ class SimpleAggregateAdapter : public Aggregate {
return Aggregate::accumulatorAlignmentSize();
}

void setConstantInputs(
const std::vector<VectorPtr>& constantInputs) override {
if constexpr (support_set_constant_inputs_) {
fn_->setConstantInputs(constantInputs);
}
}

// Add raw input to accumulators. If the simple aggregation function has
// default null behavior, input rows that has nulls are skipped. Otherwise,
// the accumulator type's addInput() method handles null inputs.
Expand Down
101 changes: 101 additions & 0 deletions velox/exec/tests/SimpleAggregateAdapterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,107 @@ TEST_F(SimpleCountNullsAggregationTest, basic) {
testAggregations({vectors}, {}, {"simple_count_nulls(c2)"}, {expected});
}

class ConstantInputForwardingAggregate {
public:
using InputType = Row<int64_t, int64_t>;
using IntermediateType = int64_t;
using OutputType = int64_t;

void setConstantInputs(const std::vector<VectorPtr>& constantInputs) {
VELOX_CHECK_EQ(constantInputs.size(), 2);
VELOX_CHECK_NULL(constantInputs[0]);
VELOX_CHECK_NOT_NULL(constantInputs[1]);
auto* constant = constantInputs[1]->as<ConstantVector<int64_t>>();
VELOX_CHECK_NOT_NULL(constant);

offset_ = constant->valueAt(0);
hasOffset_ = true;
}

struct Accumulator {
int64_t sum{0};
ConstantInputForwardingAggregate* fn;

explicit Accumulator(
HashStringAllocator* /*allocator*/,
ConstantInputForwardingAggregate* fn)
: fn(fn) {}

void addInput(
HashStringAllocator* /*allocator*/,
exec::arg_type<int64_t> value,
exec::arg_type<int64_t> /*constantValue*/) {
VELOX_CHECK(fn->hasOffset_);
sum += value + fn->offset_;
}

void combine(
HashStringAllocator* /*allocator*/,
exec::arg_type<int64_t> other) {
sum += other;
}

bool writeIntermediateResult(exec::out_type<IntermediateType>& out) {
out = sum;
return true;
}

bool writeFinalResult(exec::out_type<OutputType>& out) {
out = sum;
return true;
}
};

using AccumulatorType = Accumulator;

bool hasOffset_{false};
int64_t offset_{0};
};

class SimpleConstantInputForwardingAggregationTest
: public AggregationTestBase {};

TEST_F(SimpleConstantInputForwardingAggregationTest, forwardsConstantInputs) {
SimpleAggregateAdapter<ConstantInputForwardingAggregate> aggregate(
core::AggregationNode::Step::kSingle, {BIGINT(), BIGINT()}, BIGINT());

HashStringAllocator stringAllocator{pool()};
aggregate.setAllocator(&stringAllocator);

int32_t rowSizeOffset = bits::nbytes(1);
int32_t offset = rowSizeOffset + sizeof(uint32_t);
offset = bits::roundUp(offset, aggregate.accumulatorAlignmentSize());
aggregate.setOffsets(
offset,
RowContainer::nullByte(0),
RowContainer::nullMask(0),
RowContainer::initializedByte(0),
RowContainer::initializedMask(0),
rowSizeOffset);

auto constantInput = makeConstant<int64_t>(10, 1);
aggregate.setConstantInputs(std::vector<VectorPtr>{nullptr, constantInput});

std::vector<char> group(offset + aggregate.accumulatorFixedWidthSize());
std::vector<char*> groups{group.data()};
std::vector<vector_size_t> indices{0};
aggregate.initializeNewGroups(groups.data(), indices);

auto input = makeFlatVector<int64_t>({1, 2, 3});
auto constantArg =
BaseVector::wrapInConstant(input->size(), 0, constantInput);
aggregate.addSingleGroupRawInput(
group.data(),
SelectivityVector(input->size()),
{input, constantArg},
false);

auto result = BaseVector::create(BIGINT(), 1, pool());
aggregate.extractValues(groups.data(), 1, &result);

ASSERT_EQ(result->as<FlatVector<int64_t>>()->valueAt(0), 36);
}

// A testing simple avg aggregate function, and it is used to check for
// expectations for function-level variables. The validation logic is in the
// Accumulator::addInput method.
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ velox_add_library(
UnscaledValueFunction.h
Uuid.h
VarcharTypeWriteSideCheck.h
XxHash64.h
)

velox_link_libraries(
Expand Down
150 changes: 2 additions & 148 deletions velox/functions/sparksql/Hash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

#include <folly/CPortability.h>

#include "velox/common/base/BitUtil.h"
#include "velox/expression/DecodedArgs.h"
#include "velox/functions/lib/Murmur3Hash32Base.h"
#include "velox/functions/sparksql/XxHash64.h"
#include "velox/type/DecimalUtil.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::functions::sparksql {
Expand Down Expand Up @@ -470,153 +471,6 @@ class Murmur3HashFunction final : public exec::VectorFunction {
const std::optional<int32_t> seed_;
};

class XxHash64 final {
public:
using SeedType = int64_t;
using ReturnType = int64_t;

static uint64_t hashInt32(const int32_t input, uint64_t seed) {
int64_t hash = seed + PRIME64_5 + 4L;
hash ^= static_cast<int64_t>((input & 0xFFFFFFFFL) * PRIME64_1);
hash = bits::rotateLeft64(hash, 23) * PRIME64_2 + PRIME64_3;
return fmix(hash);
}

static uint64_t hashInt64(int64_t input, uint64_t seed) {
int64_t hash = seed + PRIME64_5 + 8L;
hash ^= bits::rotateLeft64(input * PRIME64_2, 31) * PRIME64_1;
hash = bits::rotateLeft64(hash, 27) * PRIME64_1 + PRIME64_4;
return fmix(hash);
}

// Floating point numbers are hashed as if they are integers, with
// -0f defined to have the same output as +0f.
static uint64_t hashFloat(float input, uint64_t seed) {
return hashInt32(
input == -0.f ? 0 : *reinterpret_cast<uint32_t*>(&input), seed);
}

static uint64_t hashDouble(double input, uint64_t seed) {
return hashInt64(
input == -0. ? 0 : *reinterpret_cast<uint64_t*>(&input), seed);
}

static uint64_t hashBytes(const StringView& input, uint64_t seed) {
const char* i = input.data();
const char* const end = input.data() + input.size();

uint64_t hash = hashBytesByWords(input, seed);
uint32_t length = input.size();
auto offset = i + (length & -8);
if (offset + 4L <= end) {
hash ^= (*reinterpret_cast<const uint64_t*>(offset) & 0xFFFFFFFFL) *
PRIME64_1;
hash = bits::rotateLeft64(hash, 23) * PRIME64_2 + PRIME64_3;
offset += 4L;
}

while (offset < end) {
hash ^= (*reinterpret_cast<const uint64_t*>(offset) & 0xFFL) * PRIME64_5;
hash = bits::rotateLeft64(hash, 11) * PRIME64_1;
offset++;
}
return fmix(hash);
}

static uint64_t hashLongDecimal(int128_t input, uint64_t seed) {
char out[sizeof(int128_t)];
int32_t length = DecimalUtil::toByteArray(input, out);
return hashBytes(StringView(out, length), seed);
}

static uint64_t hashTimestamp(Timestamp input, uint64_t seed) {
return hashInt64(input.toMicros(), seed);
}

private:
static const uint64_t PRIME64_1 = 0x9E3779B185EBCA87L;
static const uint64_t PRIME64_2 = 0xC2B2AE3D27D4EB4FL;
static const uint64_t PRIME64_3 = 0x165667B19E3779F9L;
static const uint64_t PRIME64_4 = 0x85EBCA77C2B2AE63L;
static const uint64_t PRIME64_5 = 0x27D4EB2F165667C5L;

static uint64_t fmix(uint64_t hash) {
hash ^= hash >> 33;
hash *= PRIME64_2;
hash ^= hash >> 29;
hash *= PRIME64_3;
hash ^= hash >> 32;
return hash;
}

static uint64_t hashBytesByWords(const StringView& input, uint64_t seed) {
const char* i = input.data();
const char* const end = input.data() + input.size();
uint32_t length = input.size();
uint64_t hash;
if (length >= 32) {
uint64_t v1 = seed + PRIME64_1 + PRIME64_2;
uint64_t v2 = seed + PRIME64_2;
uint64_t v3 = seed;
uint64_t v4 = seed - PRIME64_1;
for (; i <= end - 32; i += 32) {
v1 = bits::rotateLeft64(
v1 + (*reinterpret_cast<const uint64_t*>(i) * PRIME64_2), 31) *
PRIME64_1;
v2 = bits::rotateLeft64(
v2 + (*reinterpret_cast<const uint64_t*>(i + 8) * PRIME64_2),
31) *
PRIME64_1;
v3 = bits::rotateLeft64(
v3 + (*reinterpret_cast<const uint64_t*>(i + 16) * PRIME64_2),
31) *
PRIME64_1;
v4 = bits::rotateLeft64(
v4 + (*reinterpret_cast<const uint64_t*>(i + 24) * PRIME64_2),
31) *
PRIME64_1;
}
hash = bits::rotateLeft64(v1, 1) + bits::rotateLeft64(v2, 7) +
bits::rotateLeft64(v3, 12) + bits::rotateLeft64(v4, 18);
v1 *= PRIME64_2;
v1 = bits::rotateLeft64(v1, 31);
v1 *= PRIME64_1;
hash ^= v1;
hash = hash * PRIME64_1 + PRIME64_4;

v2 *= PRIME64_2;
v2 = bits::rotateLeft64(v2, 31);
v2 *= PRIME64_1;
hash ^= v2;
hash = hash * PRIME64_1 + PRIME64_4;

v3 *= PRIME64_2;
v3 = bits::rotateLeft64(v3, 31);
v3 *= PRIME64_1;
hash ^= v3;
hash = hash * PRIME64_1 + PRIME64_4;

v4 *= PRIME64_2;
v4 = bits::rotateLeft64(v4, 31);
v4 *= PRIME64_1;
hash ^= v4;
hash = hash * PRIME64_1 + PRIME64_4;
} else {
hash = seed + PRIME64_5;
}

hash += length;

for (; i <= end - 8; i += 8) {
hash ^= bits::rotateLeft64(
*reinterpret_cast<const uint64_t*>(i) * PRIME64_2, 31) *
PRIME64_1;
hash = bits::rotateLeft64(hash, 27) * PRIME64_1 + PRIME64_4;
}
return hash;
}
};

class XxHash64Function final : public exec::VectorFunction {
public:
XxHash64Function() = default;
Expand Down
Loading
Loading