diff --git a/onnxruntime/core/providers/cpu/tensor/identity_op.h b/onnxruntime/core/providers/cpu/tensor/identity_op.h
index 4e2a97d58861e..e2b0bf4e2c09a 100644
--- a/onnxruntime/core/providers/cpu/tensor/identity_op.h
+++ b/onnxruntime/core/providers/cpu/tensor/identity_op.h
@@ -49,21 +49,8 @@ class IdentityOp final : public OpKernel {
       const auto* X = &input_ort_value->Get<Tensor>();
       const TensorShape& shape = X->Shape();
       Tensor* Y = context->Output(0, shape);
-      auto X_type = X->DataType();
-      const void* source = X->DataRaw(X_type);
-      void* target = Y->MutableDataRaw(X_type);
-      // If source and target pointers are not equal, we need to copy the data.
-      if (target != source) {
-        if (!X->IsDataTypeString()) {
-          memcpy(target, source, SafeInt<size_t>(shape.Size()) * X_type->Size());
-        } else {
-          // handle std::string
-          const auto* src = X->Data<std::string>();
-          auto* dst = Y->MutableData<std::string>();
-          std::copy(src, src + shape.Size(), dst);
-        }
-      }
+      CopyCpuTensor(X, Y);
 
       if (is_dropout) {
         Tensor* mask = context->Output(1, shape);
diff --git a/onnxruntime/core/providers/cpu/tensor/utils.h b/onnxruntime/core/providers/cpu/tensor/utils.h
index 313e9ea4b9948..0bd608e842092 100644
--- a/onnxruntime/core/providers/cpu/tensor/utils.h
+++ b/onnxruntime/core/providers/cpu/tensor/utils.h
@@ -452,9 +452,7 @@ inline void CopyCpuTensor(const Tensor* src, Tensor* tgt) {
       auto* dst_string = tgt->MutableData<std::string>();
       std::copy(src_span.begin(), src_span.end(), dst_string);
     } else {
-      const auto element_size = src->DataType()->Size();
-      const auto elements = src->Shape().Size();
-      memcpy(target, source, SafeInt<size_t>(elements) * element_size);
+      memcpy(target, source, src->SizeInBytes());
     }
   }
 }
diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
index 67bb5d780ad2d..d5b6630668000 100644
--- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc
@@ -10,6 +10,9 @@
 #include "gtest/gtest.h"
 
 #include "core/framework/data_types_internal.h"
+#include "core/framework/int2.h"
+#include "core/framework/tensor.h"
+#include "core/providers/cpu/tensor/utils.h"
 #include "test/common/cuda_op_test_utils.h"
 #include "test/providers/provider_test_utils.h"
@@ -2731,5 +2734,219 @@ TEST(CastOpTest, Float4E2M1x2ToFloat) {
 #endif
 
+// Regression tests for sub-byte same-type cast (CopyCpuTensor heap overflow fix).
+// When src and dst types are the same, Cast::Compute calls CopyCpuTensor which must
+// use SizeInBytes() (not shape.Size() * DataType()->Size()) for the memcpy byte count.
+
+TEST(CastOpTest, Int4x2ToInt4x2_SameType) {
+  const std::vector<int64_t> shape{3, 3};  // 9 elements (odd, tests ceil-division)
+  const std::vector<Int4x2> input = {
+      Int4x2(-8, 7),
+      Int4x2(0, -1),
+      Int4x2(3, -5),
+      Int4x2(6, 2),
+      Int4x2(1, 0)  // 9th element in low nibble of 5th byte (padding in high nibble)
+  };
+
+  TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToUInt4x2_SameType) {
+  const std::vector<int64_t> shape{2, 5};  // 10 elements (even)
+  const std::vector<UInt4x2> input = {
+      UInt4x2(0, 15),
+      UInt4x2(1, 14),
+      UInt4x2(7, 8),
+      UInt4x2(3, 6),
+      UInt4x2(9, 11)};
+
+  TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
+}
+
+TEST(CastOpTest, Int4x2ToInt4x2_LargeShape) {
+  // Large shape (16464 elements)
+  const std::vector<int64_t> shape{28, 6, 14, 7};
+  const int64_t num_elements = 28 * 6 * 14 * 7;  // 16464
+  const size_t num_storage = static_cast<size_t>((num_elements + 1) / 2);
+
+  std::vector<Int4x2> input_vec(num_storage);
+  for (size_t i = 0; i < num_storage; ++i) {
+    input_vec[i] = Int4x2(static_cast<int8_t>(i % 8), static_cast<int8_t>(-(static_cast<int8_t>(i % 7))));
+  }
+  const auto& input = input_vec;
+
+  TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
+}
+
+TEST(CastOpTest, Int2x4ToInt2x4_SameType) {
+  const std::vector<int64_t> shape{5};  // 5 elements (not multiple of 4, tests ceil-division)
+  const std::vector<Int2x4> input = {
+      Int2x4(-2, 1, 0, -1),
+      Int2x4(1, 0, 0, 0)  // 5th element in first position (padding in positions 2-4)
+  };
+
+  TestCastOpInt2(gsl::make_span(input), gsl::make_span(input), shape);
+}
+
+TEST(CastOpTest, UInt4x2ToUInt4x2_LargeShape) {
+  const std::vector<int64_t> shape{28, 6, 14, 7};
+  const int64_t num_elements = 28 * 6 * 14 * 7;  // 16464
+  const size_t num_storage = static_cast<size_t>((num_elements + 1) / 2);
+
+  std::vector<UInt4x2> input_vec(num_storage);
+  for (size_t i = 0; i < num_storage; ++i) {
+    input_vec[i] = UInt4x2(static_cast<uint8_t>(i % 16), static_cast<uint8_t>((i + 3) % 16));
+  }
+  const auto& input = input_vec;
+
+  TestCastOp(gsl::make_span(input), gsl::make_span(input), shape);
+}
+
+TEST(CastOpTest, Int2x4ToInt2x4_LargeShape) {
+  const std::vector<int64_t> shape{100, 100};  // 10000 elements
+  const int64_t num_elements = 100 * 100;
+  const size_t num_storage = static_cast<size_t>((num_elements + 3) / 4);
+
+  std::vector<Int2x4> input_vec(num_storage);
+  for (size_t i = 0; i < num_storage; ++i) {
+    input_vec[i] = Int2x4(static_cast<int8_t>(i % 2), static_cast<int8_t>(-(static_cast<int8_t>(i % 2))),
+                          static_cast<int8_t>((i + 1) % 2), static_cast<int8_t>(0));
+  }
+  const auto& input = input_vec;
+
+  TestCastOpInt2(gsl::make_span(input), gsl::make_span(input), shape);
+}
+
+TEST(CastOpTest, UInt2x4ToUInt2x4_LargeShape) {
+  const std::vector<int64_t> shape{100, 101};  // 10100 elements
+  const int64_t num_elements = 100 * 101;
+  const size_t num_storage = static_cast<size_t>((num_elements + 3) / 4);
+
+  std::vector<UInt2x4> input_vec(num_storage);
+  for (size_t i = 0; i < num_storage; ++i) {
+    input_vec[i] = UInt2x4(static_cast<uint8_t>(i % 4), static_cast<uint8_t>((i + 1) % 4),
+                           static_cast<uint8_t>((i + 2) % 4), static_cast<uint8_t>((i + 3) % 4));
+  }
+  const auto& input = input_vec;
+
+  TestCastOpInt2(gsl::make_span(input), gsl::make_span(input), shape);
+}
+
+// Direct CopyCpuTensor test with guaranteed distinct buffers to exercise the memcpy path.
+// This bypasses the MayInplace optimization that can alias input/output in OpTester.
+// Uses guard bytes after the valid buffer region to detect overflow deterministically
+// without relying on ASan; the pre-fix code would overwrite these sentinel bytes.
+TEST(CastOpTest, CopyCpuTensor_SubByteTypes_DistinctBuffers) {
+  constexpr uint8_t kGuardByte = 0xCD;
+  constexpr size_t kGuardSize = 64;
+
+  // Helper: allocate a buffer of `valid_bytes` + guard region, fill guard with sentinel,
+  // then construct a non-owning Tensor over the valid portion.
+  auto make_guarded_tensor = [&](MLDataType dtype, const TensorShape& shape,
+                                 size_t valid_bytes, std::vector<uint8_t>& backing) {
+    backing.resize(valid_bytes + kGuardSize);
+    std::memset(backing.data() + valid_bytes, kGuardByte, kGuardSize);
+    return Tensor(dtype, shape, backing.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator));
+  };
+
+  auto check_guard = [&](const std::vector<uint8_t>& backing, size_t valid_bytes,
+                         const char* label) {
+    for (size_t i = 0; i < kGuardSize; ++i) {
+      EXPECT_EQ(backing[valid_bytes + i], kGuardByte)
+          << label << ": guard byte at offset " << i << " was overwritten (heap overflow detected)";
+    }
+  };
+
+  // Test Int4x2 with odd element count (ceil-division edge case)
+  {
+    const int64_t num_logical_elements = 17;  // odd: requires ceil(17/2) = 9 storage bytes
+    TensorShape shape({num_logical_elements});
+    auto int4_type = DataTypeImpl::GetType<Int4x2>();
+    constexpr size_t expected_bytes = 9;
+
+    std::vector<uint8_t> src_backing, dst_backing;
+    Tensor src = make_guarded_tensor(int4_type, shape, expected_bytes, src_backing);
+    Tensor dst = make_guarded_tensor(int4_type, shape, expected_bytes, dst_backing);
+
+    ASSERT_EQ(src.SizeInBytes(), expected_bytes);
+
+    // Fill source with known pattern
+    for (size_t i = 0; i < expected_bytes; ++i) {
+      src_backing[i] = static_cast<uint8_t>(0xA0 + i);
+    }
+    // Fill destination valid region with different pattern
+    std::memset(dst_backing.data(), 0xFF, expected_bytes);
+
+    ASSERT_NE(src.DataRaw(), dst.MutableDataRaw());
+
+    CopyCpuTensor(&src, &dst);
+
+    // Verify copy correctness
+    for (size_t i = 0; i < expected_bytes; ++i) {
+      EXPECT_EQ(dst_backing[i], src_backing[i]) << "Int4x2: mismatch at byte " << i;
+    }
+    // Verify no overflow past the valid region
+    check_guard(src_backing, expected_bytes, "Int4x2 src");
+    check_guard(dst_backing, expected_bytes, "Int4x2 dst");
+  }
+
+  // Test UInt4x2 with large even element count (matches PoC shape)
+  {
+    const int64_t num_logical_elements = 16464;  // from PoC: ceil(16464/2) = 8232 bytes
+    TensorShape shape({num_logical_elements});
+    auto uint4_type = DataTypeImpl::GetType<UInt4x2>();
+    constexpr size_t expected_bytes = 8232;
+
+    std::vector<uint8_t> src_backing, dst_backing;
+    Tensor src = make_guarded_tensor(uint4_type, shape, expected_bytes, src_backing);
+    Tensor dst = make_guarded_tensor(uint4_type, shape, expected_bytes, dst_backing);
+
+    ASSERT_EQ(src.SizeInBytes(), expected_bytes);
+
+    for (size_t i = 0; i < expected_bytes; ++i) {
+      src_backing[i] = static_cast<uint8_t>(i & 0xFF);
+    }
+    std::memset(dst_backing.data(), 0xFF, expected_bytes);
+
+    ASSERT_NE(src.DataRaw(), dst.MutableDataRaw());
+
+    CopyCpuTensor(&src, &dst);
+
+    for (size_t i = 0; i < expected_bytes; ++i) {
+      EXPECT_EQ(dst_backing[i], src_backing[i]) << "UInt4x2: mismatch at byte " << i;
+    }
+    check_guard(src_backing, expected_bytes, "UInt4x2 src");
+    check_guard(dst_backing, expected_bytes, "UInt4x2 dst");
+  }
+
+  // Test Int2x4 (4 elements per byte; would be 4x overflow with old code)
+  {
+    const int64_t num_logical_elements = 7;  // ceil(7/4) = 2 storage bytes
+    TensorShape shape({num_logical_elements});
+    auto int2_type = DataTypeImpl::GetType<Int2x4>();
+    constexpr size_t expected_bytes = 2;
+
+    std::vector<uint8_t> src_backing, dst_backing;
+    Tensor src = make_guarded_tensor(int2_type, shape, expected_bytes, src_backing);
+    Tensor dst = make_guarded_tensor(int2_type, shape, expected_bytes, dst_backing);
+
+    ASSERT_EQ(src.SizeInBytes(), expected_bytes);
+
+    src_backing[0] = 0xAB;
+    src_backing[1] = 0xCD;
+    std::memset(dst_backing.data(), 0xFF, expected_bytes);
+
+    ASSERT_NE(src.DataRaw(), dst.MutableDataRaw());
+
+    CopyCpuTensor(&src, &dst);
+
+    for (size_t i = 0; i < expected_bytes; ++i) {
+      EXPECT_EQ(dst_backing[i], src_backing[i]) << "Int2x4: mismatch at byte " << i;
+    }
+    check_guard(src_backing, expected_bytes, "Int2x4 src");
+    check_guard(dst_backing, expected_bytes, "Int2x4 dst");
+  }
+}
+
 }  // namespace test
 }  // namespace onnxruntime
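Note on the byte-count arithmetic behind the fix: for packed sub-byte types the shape counts logical elements while the element type's Size() is the size of the packed struct (one byte holding two int4 or four int2 values), so shape.Size() * DataType()->Size() over-counts by the packing factor, whereas SizeInBytes() ceil-divides by it. The snippet below is a minimal standalone sketch of that arithmetic only, not ONNX Runtime code; the helper name packed_size_in_bytes and the constants are illustrative, assuming Int4x2 packs two elements per byte as in the tests above.

#include <cstdint>
#include <cstdio>

// Illustrative ceil-division used for packed sub-byte storage
// (two elements per byte for Int4x2/UInt4x2, four for Int2x4/UInt2x4).
constexpr int64_t packed_size_in_bytes(int64_t logical_elements, int64_t elems_per_byte) {
  return (logical_elements + elems_per_byte - 1) / elems_per_byte;
}

int main() {
  constexpr int64_t kLogicalElements = 16464;  // PoC shape 28 * 6 * 14 * 7
  constexpr int64_t kElemsPerByte = 2;         // Int4x2: two 4-bit values per byte

  // Pre-fix byte count: logical element count * sizeof(packed struct) = 16464 * 1.
  const int64_t old_bytes = kLogicalElements * 1;
  // Post-fix byte count: ceil(16464 / 2) = 8232, matching the actual allocation.
  const int64_t new_bytes = packed_size_in_bytes(kLogicalElements, kElemsPerByte);

  std::printf("old memcpy size: %lld bytes, actual buffer: %lld bytes\n",
              static_cast<long long>(old_bytes), static_cast<long long>(new_bytes));
  return 0;
}

With these inputs the old computation yields 16464 bytes against an 8232-byte buffer, i.e. the pre-fix memcpy read and wrote twice the allocated region.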