@@ -152,6 +152,29 @@ TEST_F(TestExtensionType, CreateFromArray) {
152
152
ASSERT_EQ (ext_arr->null_count (), 0 );
153
153
}
154
154
155
+ TEST_F (TestExtensionType, MakeArrayCanGetCorrectScalarType) {
156
+ ASSERT_OK_AND_ASSIGN (std::shared_ptr<Tensor> tensor,
157
+ Tensor::Make (value_type_, Buffer::Wrap (values_), shape_));
158
+
159
+ auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
160
+ ASSERT_OK_AND_ASSIGN (auto ext_arr, FixedShapeTensorArray::FromTensor (tensor));
161
+
162
+ std::shared_ptr<ArrayData> data = ext_arr->data ();
163
+ std::shared_ptr<FixedShapeTensorArray> array =
164
+ internal::checked_pointer_cast<FixedShapeTensorArray>(
165
+ exact_ext_type->MakeArray (data));
166
+ ASSERT_EQ (array->length (), shape_[0 ]);
167
+ ASSERT_EQ (array->null_count (), 0 );
168
+
169
+ // Check that we can get the first element of the array
170
+ ASSERT_OK_AND_ASSIGN (auto first_element, array->GetScalar (0 ));
171
+ ASSERT_EQ (*(first_element->type ),
172
+ *(fixed_shape_tensor (value_type_, element_shape_, {0 , 1 })));
173
+
174
+ ASSERT_OK_AND_ASSIGN (auto tensor_from_array, array->ToTensor ());
175
+ ASSERT_TRUE (tensor->Equals (*tensor_from_array));
176
+ }
177
+
155
178
void CheckSerializationRoundtrip (const std::shared_ptr<DataType>& ext_type) {
156
179
auto fst_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type);
157
180
auto serialized = fst_type->Serialize ();
0 commit comments