Skip to content

Commit 9abd336

Browse files
committed
Add test for FixedShapeTensorArray::MakeTensor
1 parent 27cebb3 commit 9abd336

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

cpp/src/arrow/extension/fixed_shape_tensor_test.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,29 @@ TEST_F(TestExtensionType, CreateFromArray) {
152152
ASSERT_EQ(ext_arr->null_count(), 0);
153153
}
154154

155+
TEST_F(TestExtensionType, MakeArrayCanGetCorrectScalarType) {
156+
ASSERT_OK_AND_ASSIGN(std::shared_ptr<Tensor> tensor,
157+
Tensor::Make(value_type_, Buffer::Wrap(values_), shape_));
158+
159+
auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
160+
ASSERT_OK_AND_ASSIGN(auto ext_arr, FixedShapeTensorArray::FromTensor(tensor));
161+
162+
std::shared_ptr<ArrayData> data = ext_arr->data();
163+
std::shared_ptr<FixedShapeTensorArray> array =
164+
internal::checked_pointer_cast<FixedShapeTensorArray>(
165+
exact_ext_type->MakeArray(data));
166+
ASSERT_EQ(array->length(), shape_[0]);
167+
ASSERT_EQ(array->null_count(), 0);
168+
169+
// Check that we can get the first element of the array
170+
ASSERT_OK_AND_ASSIGN(auto first_element, array->GetScalar(0));
171+
ASSERT_EQ(*(first_element->type),
172+
*(fixed_shape_tensor(value_type_, element_shape_, {0, 1})));
173+
174+
ASSERT_OK_AND_ASSIGN(auto tensor_from_array, array->ToTensor());
175+
ASSERT_TRUE(tensor->Equals(*tensor_from_array));
176+
}
177+
155178
void CheckSerializationRoundtrip(const std::shared_ptr<DataType>& ext_type) {
156179
auto fst_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type);
157180
auto serialized = fst_type->Serialize();

0 commit comments

Comments
 (0)