Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions backends/aoti/slim/core/SlimTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,34 @@ class SlimTensor {
return *this;
}

/**
* Extract the scalar value from a tensor with exactly 1 element.
* Automatically handles CUDA tensors by copying data to CPU.
*
* @tparam T The type to extract (must match tensor dtype).
* @return The scalar value.
*/
template <typename T>
T item() const {
ET_CHECK_MSG(
this->numel() == 1,
"item() requires tensor to have exactly 1 element, got %zu",
this->numel());

T result;
if (this->is_cpu()) {
result = *static_cast<const T*>(this->data_ptr());
} else {
#if defined(CUDA_AVAILABLE)
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
&result, this->data_ptr(), sizeof(T), CPU_DEVICE, this->device());
#else
ET_CHECK_MSG(false, "item(): CUDA tensor but CUDA support not available");
#endif
}
return result;
}

private:
SlimTensor _clone_impl(
c10::IntArrayRef sizes,
Expand Down
105 changes: 105 additions & 0 deletions backends/aoti/slim/core/test/test_slimtensor_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,111 @@ TEST(SlimTensorBasicTest, CopyConstructor) {
EXPECT_EQ(copy.dtype(), c10::ScalarType::Float);
}

// =============================================================================
// Item Tests (Device-Parameterized)
// =============================================================================

// Helper to set value in storage (handles both CPU and CUDA)
template <typename T>
void set_storage_value(
Storage& storage,
const T& value,
const c10::Device& dev) {
if (dev.is_cpu()) {
*static_cast<T*>(storage->data()) = value;
} else {
#if defined(CUDA_AVAILABLE)
DeviceTraits<c10::DeviceType::CUDA>::memcpy(
storage->data(), &value, sizeof(T), dev, CPU_DEVICE);
#endif
}
}

// Template function for testing item<T>() with explicit type
template <typename T>
void test_item_typed(
const c10::Device& dev,
c10::ScalarType dtype,
T input_value,
T expected_value) {
std::vector<int64_t> sizes = {1};
std::vector<int64_t> strides = {1};
Storage storage(new MaybeOwningStorage(dev, sizeof(T)));
set_storage_value(storage, input_value, dev);

SlimTensor tensor(
std::move(storage), makeArrayRef(sizes), makeArrayRef(strides), dtype);

T result = tensor.item<T>();
if constexpr (std::is_floating_point_v<T>) {
EXPECT_FLOAT_EQ(result, expected_value);
} else {
EXPECT_EQ(result, expected_value);
}
}

// Tests for item<T>() with explicit type
TEST_P(SlimTensorBasicDeviceTest, ItemTypedFloat) {
test_item_typed<float>(device(), c10::ScalarType::Float, 42.5f, 42.5f);
}

TEST_P(SlimTensorBasicDeviceTest, ItemTypedInt) {
test_item_typed<int32_t>(device(), c10::ScalarType::Int, 123, 123);
}

TEST_P(SlimTensorBasicDeviceTest, ItemTypedLong) {
test_item_typed<int64_t>(
device(), c10::ScalarType::Long, 9876543210LL, 9876543210LL);
}

TEST_P(SlimTensorBasicDeviceTest, ItemTypedShort) {
test_item_typed<int16_t>(device(), c10::ScalarType::Short, 1234, 1234);
}

TEST_P(SlimTensorBasicDeviceTest, ItemTypedChar) {
test_item_typed<int8_t>(device(), c10::ScalarType::Char, -42, -42);
}

TEST_P(SlimTensorBasicDeviceTest, ItemTypedBool) {
test_item_typed<bool>(device(), c10::ScalarType::Bool, true, true);
}

// Can't reuse test_item_typed() because we need to cast to float explictly for
// comparison.
TEST_P(SlimTensorBasicDeviceTest, ItemTypedBFloat16) {
c10::BFloat16 input{3.14f};
c10::BFloat16 expected{3.14f};
std::vector<int64_t> sizes = {1};
std::vector<int64_t> strides = {1};
Storage storage(new MaybeOwningStorage(device(), sizeof(c10::BFloat16)));
set_storage_value(storage, input, device());

SlimTensor tensor(
std::move(storage),
makeArrayRef(sizes),
makeArrayRef(strides),
c10::ScalarType::BFloat16);

c10::BFloat16 result = tensor.item<c10::BFloat16>();
EXPECT_FLOAT_EQ(static_cast<float>(result), static_cast<float>(expected));
}

// Test item() fails on non-scalar tensor (numel > 1)
TEST_P(SlimTensorBasicDeviceTest, ItemFailsOnNonScalarTensor) {
std::vector<int64_t> sizes = {2, 3};
std::vector<int64_t> strides = {3, 1};
Storage storage = make_storage(6 * sizeof(float));

SlimTensor tensor(
std::move(storage),
makeArrayRef(sizes),
makeArrayRef(strides),
c10::ScalarType::Float);

EXPECT_EQ(tensor.numel(), 6u);
EXPECT_DEATH(tensor.item<float>(), "");
}

// CPU-only test for DataPtrWithOffset (requires reading data back)
TEST(SlimTensorBasicTest, DataPtrWithOffset) {
std::vector<int64_t> sizes = {2, 3};
Expand Down
Loading