Skip to content

Commit

Permalink
Memory corruption in PartitionedOutput when keys are not a prefix of …
Browse files Browse the repository at this point in the history
…input (facebookincubator#10075)

Summary:
Pull Request resolved: facebookincubator#10075

In PartitionedOutput's collectNullRows() function, it assumes that the key channels are a prefix of the
input channels, i.e. the keys appear at the beginning of the input type.  It allocates an std::vector of
size equal to the number of key channels to hold DecodedVectors and assumes it can access these
using the key channels as indices.

When that assumption does not hold it accesses a DecodedVector off the end of that std::vector and
writes to it, leading to memory corruption as it writes to arbitrary memory.

The fix is to access the std::vector using the index of the keyChannel rather than the value of the
keyChannel.  This guarantees the std::vector is of minimal sufficient size and we do not read off the
end of it.

Note, this bug only happens if some of the keys are not a prefix of the input and
replicateNullsAndAny is set and there are nulls one of the key columns that is not a prefix of the
input.

Reviewed By: xiaoxmeng

Differential Revision: D58216159

fbshipit-source-id: 9809b12895369d3413485ab09eaafd9ddcac723b
  • Loading branch information
Kevin Wilfong authored and facebook-github-bot committed Jun 6, 2024
1 parent 179b108 commit c600d09
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
14 changes: 9 additions & 5 deletions velox/exec/PartitionedOutput.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,18 @@ void PartitionedOutput::collectNullRows() {

decodedVectors_.resize(keyChannels_.size());

for (auto i : keyChannels_) {
if (i == kConstantChannel) {
for (size_t keyChannelIndex = 0; keyChannelIndex < keyChannels_.size();
++keyChannelIndex) {
column_index_t keyChannel = keyChannels_[keyChannelIndex];
// Skip constant channel.
if (keyChannel == kConstantChannel) {
continue;
}
auto& keyVector = input_->childAt(i);
auto& keyVector = input_->childAt(keyChannel);
if (keyVector->mayHaveNulls()) {
decodedVectors_[i].decode(*keyVector, rows_);
if (auto* rawNulls = decodedVectors_[i].nulls(&rows_)) {
DecodedVector& decodedVector = decodedVectors_[keyChannelIndex];
decodedVector.decode(*keyVector, rows_);
if (auto* rawNulls = decodedVector.nulls(&rows_)) {
bits::orWithNegatedBits(
nullRows_.asMutableRange().bits(), rawNulls, 0, size);
}
Expand Down
40 changes: 40 additions & 0 deletions velox/exec/tests/PartitionedOutputTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,44 @@ TEST_F(PartitionedOutputTest, flush) {
EXPECT_EQ(partition1.size(), 2);
}

TEST_F(PartitionedOutputTest, keyChannelNotAtBeginningWithNulls) {
// This test verifies that PartitionedOutput can handle the case where a key
// channel is not at the beginning of the input type when nulls are present in
// the key channel. This triggers collectNullRows() to run which has special
// handling logic for the key channels.

auto input = makeRowVector(
// The key column p1 is the second column.
{"v1", "p1"},
{makeFlatVector<std::string>({"0", "1", "2", "3"}),
// Add nulls to the key column.
makeNullableFlatVector<int32_t>(std::vector<std::optional<int32_t>>{
0, std::nullopt, 1, std::nullopt})});

auto plan =
PlanBuilder()
.values({input}, false, 13)
// Set replicateNullsAndAny to true so we trigger the null path.
.partitionedOutput({"p1"}, 2, true, std::vector<std::string>{"v1"})
.planNode();

auto taskId = "local://test-partitioned-output-0";
auto task = Task::create(
taskId,
core::PlanFragment{plan},
0,
createQueryContext({}),
Task::ExecutionMode::kParallel);
task->start(1);

const auto partition0 = getAllData(taskId, 0);
const auto partition1 = getAllData(taskId, 1);

ASSERT_TRUE(waitForTaskCompletion(
task.get(),
std::chrono::duration_cast<std::chrono::microseconds>(
std::chrono::seconds(10))
.count()));
}

} // namespace facebook::velox::exec::test

0 comments on commit c600d09

Please sign in to comment.