15 changes: 1 addition & 14 deletions onnxruntime/core/providers/cpu/tensor/identity_op.h
@@ -49,21 +49,8 @@ class IdentityOp final : public OpKernel {
     const auto* X = &input_ort_value->Get<Tensor>();
     const TensorShape& shape = X->Shape();
     Tensor* Y = context->Output(0, shape);
-    auto X_type = X->DataType();
-    const void* source = X->DataRaw(X_type);
-    void* target = Y->MutableDataRaw(X_type);
-    // If source and target pointers are not equal, we need to copy the data.
-    if (target != source) {
-      if (!X->IsDataTypeString()) {
-        memcpy(target, source, SafeInt<size_t>(shape.Size()) * X_type->Size());
-      } else {
-        // handle std::string
-        const auto* src = X->Data<std::string>();
-        auto* dst = Y->MutableData<std::string>();
-        std::copy(src, src + shape.Size(), dst);
-      }
-    }
+    CopyCpuTensor(X, Y);

     if (is_dropout) {
       Tensor* mask = context->Output(1, shape);
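The deleted block above duplicated the string/raw-memory dispatch that CopyCpuTensor already implements, so the kernel now defers to the shared helper (fixed in the next file). A minimal standalone sketch of that dispatch, using a hypothetical copy_tensor_sketch over std::vector rather than ORT's Tensor API:

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <string>
#include <type_traits>
#include <vector>

// Sketch only: mirrors the two copy paths, with the byte count passed in explicitly.
template <typename T>
void copy_tensor_sketch(const std::vector<T>& src, std::vector<T>& dst, std::size_t size_in_bytes) {
  if (src.data() == dst.data()) return;  // nothing to copy when the buffers alias
  if constexpr (std::is_same_v<T, std::string>) {
    // std::string is not trivially copyable; copy element-wise so operator= runs.
    std::copy(src.begin(), src.end(), dst.begin());
  } else {
    // Trivially copyable payloads can be copied with one memcpy. The count must be the
    // true storage size in bytes, which is what SizeInBytes() supplies in the real code.
    std::memcpy(dst.data(), src.data(), size_in_bytes);
  }
}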
4 changes: 1 addition & 3 deletions onnxruntime/core/providers/cpu/tensor/utils.h
@@ -452,9 +452,7 @@ inline void CopyCpuTensor(const Tensor* src, Tensor* tgt) {
     auto* dst_string = tgt->MutableData<std::string>();
     std::copy(src_span.begin(), src_span.end(), dst_string);
   } else {
-    const auto element_size = src->DataType()->Size();
-    const auto elements = src->Shape().Size();
-    memcpy(target, source, SafeInt<size_t>(elements) * element_size);
+    memcpy(target, source, src->SizeInBytes());
   }
 }
}
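The root cause: Shape().Size() counts logical elements, while DataType()->Size() is the size of one packed struct, so for sub-byte types (one byte stores two int4 or four int2 values) their product over-counts the allocation; SizeInBytes() applies the ceil-division instead. A standalone sketch of the arithmetic for the 9-element Int4x2 case exercised in the tests below:

#include <cstddef>
#include <cstdio>

int main() {
  // Int4x2: one storage byte packs two 4-bit elements.
  const std::size_t logical_elements = 9;                               // shape {3, 3}
  const std::size_t packed_struct_size = 1;                             // one Int4x2 pair per byte
  const std::size_t old_count = logical_elements * packed_struct_size;  // 9 bytes (over-counts)
  const std::size_t actual_storage = (logical_elements + 1) / 2;        // ceil(9/2) = 5 bytes
  std::printf("memcpy of %zu bytes into a %zu-byte buffer: %zu-byte heap overflow\n",
              old_count, actual_storage, old_count - actual_storage);
  return 0;
}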
217 changes: 217 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -10,6 +10,9 @@
#include "gtest/gtest.h"

#include "core/framework/data_types_internal.h"
#include "core/framework/int2.h"
#include "core/framework/tensor.h"
#include "core/providers/cpu/tensor/utils.h"

#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
@@ -2731,5 +2734,219 @@ TEST(CastOpTest, Float4E2M1x2ToFloat) {

#endif

// Regression tests for sub-byte same-type cast (CopyCpuTensor heap overflow fix).
// When src and dst types are the same, Cast::Compute calls CopyCpuTensor which must
// use SizeInBytes() (not shape.Size() * DataType()->Size()) for the memcpy byte count.
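// For reference, the worked overflow numbers implied by the packing for the shapes below
// (old count = Shape().Size() * DataType()->Size(); actual storage = SizeInBytes()):
//   Int4x2,  9 elements:     old 9 bytes,     actual ceil(9/2) = 5    -> 4 bytes written past the end
//   UInt4x2, 16464 elements: old 16464 bytes, actual 16464/2 = 8232   -> 2x the buffer
//   Int2x4,  7 elements:     old 7 bytes,     actual ceil(7/4) = 2    -> roughly 4x the buffer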

TEST(CastOpTest, Int4x2ToInt4x2_SameType) {
const std::vector<int64_t> shape{3, 3}; // 9 elements (odd, tests ceil-division)
const std::vector<Int4x2> input = {
Int4x2(-8, 7),
Int4x2(0, -1),
Int4x2(3, -5),
Int4x2(6, 2),
Int4x2(1, 0) // 9th element in low nibble of 5th byte (padding in high nibble)
};

TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
}

TEST(CastOpTest, UInt4x2ToUInt4x2_SameType) {
const std::vector<int64_t> shape{2, 5}; // 10 elements (even)
const std::vector<UInt4x2> input = {
UInt4x2(0, 15),
UInt4x2(1, 14),
UInt4x2(7, 8),
UInt4x2(3, 6),
UInt4x2(9, 11)};

TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
}

TEST(CastOpTest, Int4x2ToInt4x2_LargeShape) {
// Large shape (16464 elements)
const std::vector<int64_t> shape{28, 6, 14, 7};
const int64_t num_elements = 28 * 6 * 14 * 7; // 16464
const size_t num_storage = static_cast<size_t>((num_elements + 1) / 2);

std::vector<Int4x2> input_vec(num_storage);
for (size_t i = 0; i < num_storage; ++i) {
input_vec[i] = Int4x2(static_cast<int8_t>(i % 8), static_cast<int8_t>(-(static_cast<int8_t>(i % 7))));
}
const auto& input = input_vec;

TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
}

TEST(CastOpTest, Int2x4ToInt2x4_SameType) {
const std::vector<int64_t> shape{5}; // 5 elements (not multiple of 4, tests ceil-division)
const std::vector<Int2x4> input = {
Int2x4(-2, 1, 0, -1),
Int2x4(1, 0, 0, 0) // 5th element in first position (padding in positions 2-4)
};

TestCastOpInt2(gsl::make_span(input), gsl::make_span(input), shape);
}

TEST(CastOpTest, UInt4x2ToUInt4x2_LargeShape) {
const std::vector<int64_t> shape{28, 6, 14, 7};
const int64_t num_elements = 28 * 6 * 14 * 7; // 16464
const size_t num_storage = static_cast<size_t>((num_elements + 1) / 2);

std::vector<UInt4x2> input_vec(num_storage);
for (size_t i = 0; i < num_storage; ++i) {
input_vec[i] = UInt4x2(static_cast<uint8_t>(i % 16), static_cast<uint8_t>((i + 3) % 16));
}
const auto& input = input_vec;

TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
}

TEST(CastOpTest, Int2x4ToInt2x4_LargeShape) {
const std::vector<int64_t> shape{100, 100}; // 10000 elements (multiple of 4)
const int64_t num_elements = 100 * 100;
const size_t num_storage = static_cast<size_t>((num_elements + 3) / 4);

std::vector<Int2x4> input_vec(num_storage);
for (size_t i = 0; i < num_storage; ++i) {
input_vec[i] = Int2x4(static_cast<int8_t>(i % 2), static_cast<int8_t>(-(static_cast<int8_t>(i % 2))),
static_cast<int8_t>((i + 1) % 2), static_cast<int8_t>(0));
}
const auto& input = input_vec;

TestCastOpInt2(gsl::make_span(input), gsl::make_span(input), shape);
}

TEST(CastOpTest, UInt2x4ToUInt2x4_LargeShape) {
const std::vector<int64_t> shape{100, 101}; // 10100 elements (multiple of 4)
const int64_t num_elements = 100 * 101;
const size_t num_storage = static_cast<size_t>((num_elements + 3) / 4);

std::vector<UInt2x4> input_vec(num_storage);
for (size_t i = 0; i < num_storage; ++i) {
input_vec[i] = UInt2x4(static_cast<uint8_t>(i % 4), static_cast<uint8_t>((i + 1) % 4),
static_cast<uint8_t>((i + 2) % 4), static_cast<uint8_t>((i + 3) % 4));
}
const auto& input = input_vec;

TestCastOpInt2(gsl::make_span(input), gsl::make_span(input), shape);
}

// Direct CopyCpuTensor test with guaranteed distinct buffers to exercise the memcpy path.
// This bypasses the MayInplace optimization that can alias input/output in OpTester.
// Uses guard bytes after the valid buffer region to detect overflow deterministically
// without relying on ASan — the pre-fix code would overwrite these sentinel bytes.
TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) {
constexpr uint8_t kGuardByte = 0xCD;
constexpr size_t kGuardSize = 64;

// Helper: allocate a buffer of `valid_bytes` + guard region, fill guard with sentinel,
// then construct a non-owning Tensor over the valid portion.
auto make_guarded_tensor = [&](MLDataType dtype, const TensorShape& shape,
size_t valid_bytes, std::vector<uint8_t>& backing) {
backing.resize(valid_bytes + kGuardSize);
std::memset(backing.data() + valid_bytes, kGuardByte, kGuardSize);
return Tensor(dtype, shape, backing.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
};

auto check_guard = [&](const std::vector<uint8_t>& backing, size_t valid_bytes,
const char* label) {
for (size_t i = 0; i < kGuardSize; ++i) {
EXPECT_EQ(backing[valid_bytes + i], kGuardByte)
<< label << ": guard byte at offset " << i << " was overwritten (heap overflow detected)";
}
};

// Test Int4x2 with odd element count (ceil-division edge case)
{
const int64_t num_logical_elements = 17; // odd: requires ceil(17/2) = 9 storage bytes
TensorShape shape({num_logical_elements});
auto int4_type = DataTypeImpl::GetType<Int4x2>();
constexpr size_t expected_bytes = 9;

std::vector<uint8_t> src_backing, dst_backing;
Tensor src = make_guarded_tensor(int4_type, shape, expected_bytes, src_backing);
Tensor dst = make_guarded_tensor(int4_type, shape, expected_bytes, dst_backing);

ASSERT_EQ(src.SizeInBytes(), expected_bytes);

// Fill source with known pattern
for (size_t i = 0; i < expected_bytes; ++i) {
src_backing[i] = static_cast<uint8_t>(0xA0 + i);
}
// Fill destination valid region with different pattern
std::memset(dst_backing.data(), 0xFF, expected_bytes);

ASSERT_NE(src.DataRaw(), dst.MutableDataRaw());

CopyCpuTensor(&src, &dst);

// Verify copy correctness
for (size_t i = 0; i < expected_bytes; ++i) {
EXPECT_EQ(dst_backing[i], src_backing[i]) << "Int4x2: mismatch at byte " << i;
}
// Verify no overflow past the valid region
check_guard(src_backing, expected_bytes, "Int4x2 src");
check_guard(dst_backing, expected_bytes, "Int4x2 dst");
}

// Test UInt4x2 with large even element count (matches PoC shape)
{
const int64_t num_logical_elements = 16464; // from PoC: ceil(16464/2) = 8232 bytes
TensorShape shape({num_logical_elements});
auto uint4_type = DataTypeImpl::GetType<UInt4x2>();
constexpr size_t expected_bytes = 8232;

std::vector<uint8_t> src_backing, dst_backing;
Tensor src = make_guarded_tensor(uint4_type, shape, expected_bytes, src_backing);
Tensor dst = make_guarded_tensor(uint4_type, shape, expected_bytes, dst_backing);

ASSERT_EQ(src.SizeInBytes(), expected_bytes);

for (size_t i = 0; i < expected_bytes; ++i) {
src_backing[i] = static_cast<uint8_t>(i & 0xFF);
}
std::memset(dst_backing.data(), 0xFF, expected_bytes);

ASSERT_NE(src.DataRaw(), dst.MutableDataRaw());

CopyCpuTensor(&src, &dst);

for (size_t i = 0; i < expected_bytes; ++i) {
EXPECT_EQ(dst_backing[i], src_backing[i]) << "UInt4x2: mismatch at byte " << i;
}
check_guard(src_backing, expected_bytes, "UInt4x2 src");
check_guard(dst_backing, expected_bytes, "UInt4x2 dst");
}

// Test Int2x4 (4 elements per byte — would be 4x overflow with old code)
{
const int64_t num_logical_elements = 7; // ceil(7/4) = 2 storage bytes
TensorShape shape({num_logical_elements});
auto int2_type = DataTypeImpl::GetType<Int2x4>();
constexpr size_t expected_bytes = 2;

std::vector<uint8_t> src_backing, dst_backing;
Tensor src = make_guarded_tensor(int2_type, shape, expected_bytes, src_backing);
Tensor dst = make_guarded_tensor(int2_type, shape, expected_bytes, dst_backing);

ASSERT_EQ(src.SizeInBytes(), expected_bytes);

src_backing[0] = 0xAB;
src_backing[1] = 0xCD;
std::memset(dst_backing.data(), 0xFF, expected_bytes);

ASSERT_NE(src.DataRaw(), dst.MutableDataRaw());

CopyCpuTensor(&src, &dst);

for (size_t i = 0; i < expected_bytes; ++i) {
EXPECT_EQ(dst_backing[i], src_backing[i]) << "Int2x4: mismatch at byte " << i;
}
check_guard(src_backing, expected_bytes, "Int2x4 src");
check_guard(dst_backing, expected_bytes, "Int2x4 dst");
}
}

} // namespace test
} // namespace onnxruntime