diff --git a/src/fury/encoder/row_encode_trait.h b/src/fury/encoder/row_encode_trait.h index e87d2f1c2b..11519f8702 100644 --- a/src/fury/encoder/row_encode_trait.h +++ b/src/fury/encoder/row_encode_trait.h @@ -19,6 +19,7 @@ #include "fury/meta/field_info.h" #include "fury/meta/type_traits.h" #include "fury/row/writer.h" +#include #include #include #include @@ -67,8 +68,11 @@ template inline constexpr bool IsString = meta::IsOneOf::value; +template inline constexpr bool IsMap = meta::IsPairIterable; + template -inline constexpr bool IsArray = meta::IsIterable && !IsString; +inline constexpr bool IsArray = + meta::IsIterable && !IsString && !IsMap; template inline constexpr bool IsOptional = false; @@ -76,7 +80,8 @@ template inline constexpr bool IsOptional> = true; template inline constexpr bool IsClassButNotBuiltin = - std::is_class_v && !(IsString || IsArray || IsOptional); + std::is_class_v && + !(IsString || IsArray || IsOptional || IsMap); inline decltype(auto) GetChildType(RowWriter &writer, int index) { return writer.schema()->field(index)->type(); @@ -263,6 +268,74 @@ struct RowEncodeTrait +struct RowEncodeTrait>>> { + static auto Type() { + return arrow::map( + RowEncodeTrait::Type(), + RowEncodeTrait::Type()); + } + + template + static void WriteKey(V &&visitor, const T &value, ArrayWriter &writer) { + int index = 0; + for (const auto &v : value) { + RowEncodeTrait::Write( + std::forward(visitor), v.first, writer, index); + ++index; + } + } + + template + static void WriteValue(V &&visitor, const T &value, ArrayWriter &writer) { + int index = 0; + for (const auto &v : value) { + RowEncodeTrait::Write( + std::forward(visitor), v.second, writer, index); + ++index; + } + } + + template ::value, + int> = 0> + static void Write(V &&visitor, const T &value, W &writer, int index) { + auto offset = writer.cursor(); + writer.WriteDirectly(-1); + + auto map_type = std::dynamic_pointer_cast( + details::GetChildType(writer, index)); + + auto key_writer = + std::make_unique(std::static_pointer_cast( + arrow::list(map_type->key_type())), + &writer); + + key_writer->Reset(value.size()); + RowEncodeTrait::WriteKey(std::forward(visitor), value, + *key_writer.get()); + + writer.WriteDirectly(offset, key_writer->size()); + + auto value_writer = + std::make_unique(std::static_pointer_cast( + arrow::list(map_type->item_type())), + &writer); + + value_writer->Reset(value.size()); + RowEncodeTrait::WriteValue(std::forward(visitor), value, + *value_writer.get()); + + writer.SetOffsetAndSize(index, offset, writer.cursor() - offset); + + std::forward(visitor).template Visit>( + std::move(key_writer)); + std::forward(visitor).template Visit>( + std::move(value_writer)); + } +}; + } // namespace encoder } // namespace fury diff --git a/src/fury/encoder/row_encode_trait_test.cc b/src/fury/encoder/row_encode_trait_test.cc index e8a0ee4471..b8fd63b240 100644 --- a/src/fury/encoder/row_encode_trait_test.cc +++ b/src/fury/encoder/row_encode_trait_test.cc @@ -294,6 +294,70 @@ TEST(RowEncodeTrait, Optional) { } } +struct G { + std::map> a; + std::map b; +}; + +FURY_FIELD_INFO(G, a, b); + +TEST(RowEncodeTrait, Map) { + G v{{{1, {{3, 4}, {5, 6}}}, {2, {{7, 8}, {9, 10}, {11, 12}}}}, + {{"a", A{1, 1.1, true}}, {"b", A{2, 3.3, false}}}}; + + auto schema = encoder::RowEncodeTrait::Type(); + + auto a_map = + std::dynamic_pointer_cast(schema->field(0)->type()); + ASSERT_EQ(a_map->key_type()->name(), "int32"); + ASSERT_EQ(a_map->item_type()->name(), "map"); + ASSERT_EQ(std::dynamic_pointer_cast(a_map->item_type()) + ->key_type() + ->name(), + "int32"); + ASSERT_EQ(std::dynamic_pointer_cast(a_map->item_type()) + ->item_type() + ->name(), + "int32"); + + auto b_map = + std::dynamic_pointer_cast(schema->field(1)->type()); + ASSERT_EQ(b_map->key_type()->name(), "utf8"); + ASSERT_EQ(b_map->item_type()->name(), "struct"); + ASSERT_EQ(b_map->item_type()->field(0)->type()->name(), "int32"); + ASSERT_EQ(b_map->item_type()->field(1)->type()->name(), "float"); + ASSERT_EQ(b_map->item_type()->field(2)->type()->name(), "bool"); + + RowWriter writer(encoder::RowEncodeTrait::Schema()); + writer.Reset(); + + encoder::RowEncodeTrait::Write(encoder::EmptyWriteVisitor{}, v, writer); + + auto map_a = writer.ToRow()->GetMap(0); + ASSERT_EQ(map_a->keys_array()->GetInt32(0), 1); + ASSERT_EQ(map_a->keys_array()->GetInt32(1), 2); + ASSERT_EQ(map_a->values_array()->GetMap(0)->keys_array()->GetInt32(0), 3); + ASSERT_EQ(map_a->values_array()->GetMap(0)->keys_array()->GetInt32(1), 5); + ASSERT_EQ(map_a->values_array()->GetMap(0)->values_array()->GetInt32(0), 4); + ASSERT_EQ(map_a->values_array()->GetMap(0)->values_array()->GetInt32(1), 6); + ASSERT_EQ(map_a->values_array()->GetMap(1)->keys_array()->GetInt32(0), 7); + ASSERT_EQ(map_a->values_array()->GetMap(1)->keys_array()->GetInt32(1), 9); + ASSERT_EQ(map_a->values_array()->GetMap(1)->keys_array()->GetInt32(2), 11); + ASSERT_EQ(map_a->values_array()->GetMap(1)->values_array()->GetInt32(0), 8); + ASSERT_EQ(map_a->values_array()->GetMap(1)->values_array()->GetInt32(1), 10); + ASSERT_EQ(map_a->values_array()->GetMap(1)->values_array()->GetInt32(2), 12); + + auto map_b = writer.ToRow()->GetMap(1); + ASSERT_EQ(map_b->keys_array()->GetString(0), "a"); + ASSERT_EQ(map_b->keys_array()->GetString(1), "b"); + ASSERT_EQ(map_b->values_array()->GetStruct(0)->GetInt32(0), 1); + ASSERT_EQ(map_b->values_array()->GetStruct(1)->GetInt32(0), 2); + ASSERT_FLOAT_EQ(map_b->values_array()->GetStruct(0)->GetFloat(1), 1.1); + ASSERT_FLOAT_EQ(map_b->values_array()->GetStruct(1)->GetFloat(1), 3.3); + ASSERT_EQ(map_b->values_array()->GetStruct(0)->GetBoolean(2), true); + ASSERT_EQ(map_b->values_array()->GetStruct(1)->GetBoolean(2), false); +} + } // namespace test } // namespace fury diff --git a/src/fury/meta/type_traits.h b/src/fury/meta/type_traits.h index 01e440cb92..eb3df69b1a 100644 --- a/src/fury/meta/type_traits.h +++ b/src/fury/meta/type_traits.h @@ -97,6 +97,26 @@ constexpr inline bool IsIterable = template using GetValueType = typename details::GetValueTypeImpl::type; +namespace details { + +template constexpr inline bool IsPair = false; + +template +constexpr inline bool IsPair> = true; + +template std::false_type IsPairIterableImpl(...); + +template < + typename T, + std::enable_if_t && IsPair, int> = 0> +std::true_type IsPairIterableImpl(int); + +} // namespace details + +template +constexpr inline bool IsPairIterable = + decltype(details::IsPairIterableImpl(0))::value; + } // namespace meta } // namespace fury diff --git a/src/fury/meta/type_traits_test.cc b/src/fury/meta/type_traits_test.cc index 5193ca3a47..4686232f69 100644 --- a/src/fury/meta/type_traits_test.cc +++ b/src/fury/meta/type_traits_test.cc @@ -18,6 +18,8 @@ #include #include #include +#include +#include #include "fury/meta/field_info.h" #include "src/fury/meta/type_traits.h" @@ -64,6 +66,11 @@ TEST(Meta, IsUnique) { } TEST(Meta, IsIterable) { + static_assert(!IsIterable); + static_assert(!IsIterable); + static_assert(!IsIterable); + static_assert(!IsIterable); + static_assert(!IsIterable>); static_assert(IsIterable>); static_assert(IsIterable>>); static_assert(IsIterable>); @@ -77,6 +84,20 @@ TEST(Meta, IsIterable) { static_assert(IsIterable); } +TEST(Meta, IsPairIterable) { + static_assert(!IsPairIterable); + static_assert(!IsPairIterable); + static_assert(!IsPairIterable>); + static_assert(!IsPairIterable>>); + static_assert(!IsPairIterable>); + static_assert(!IsPairIterable>); + static_assert(!IsPairIterable>); + static_assert(IsPairIterable>>); + static_assert(IsPairIterable>); + static_assert(IsPairIterable>); + static_assert(IsPairIterable>); +} + } // namespace test } // namespace fury