Skip to content

Commit

Permalink
[C++] Support mapping types for RowEncodeTrait
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice committed Dec 24, 2023
1 parent fd67af5 commit 41286ee
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 2 deletions.
77 changes: 75 additions & 2 deletions src/fury/encoder/row_encode_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "fury/meta/field_info.h"
#include "fury/meta/type_traits.h"
#include "fury/row/writer.h"
#include <memory>
#include <string_view>
#include <type_traits>
#include <utility>
Expand Down Expand Up @@ -67,16 +68,20 @@ template <typename T>
inline constexpr bool IsString =
meta::IsOneOf<T, std::string, std::string_view>::value;

template <typename T> inline constexpr bool IsMap = meta::IsPairIterable<T>;

template <typename T>
inline constexpr bool IsArray = meta::IsIterable<T> && !IsString<T>;
inline constexpr bool IsArray =
meta::IsIterable<T> && !IsString<T> && !IsMap<T>;

template <typename> inline constexpr bool IsOptional = false;

template <typename T> inline constexpr bool IsOptional<std::optional<T>> = true;

template <typename T>
inline constexpr bool IsClassButNotBuiltin =
std::is_class_v<T> && !(IsString<T> || IsArray<T> || IsOptional<T>);
std::is_class_v<T> &&
!(IsString<T> || IsArray<T> || IsOptional<T> || IsMap<T>);

inline decltype(auto) GetChildType(RowWriter &writer, int index) {
return writer.schema()->field(index)->type();
Expand Down Expand Up @@ -263,6 +268,74 @@ struct RowEncodeTrait<T,
}
};

template <typename T>
struct RowEncodeTrait<T,
std::enable_if_t<details::IsMap<std::remove_cv_t<T>>>> {
static auto Type() {
return arrow::map(
RowEncodeTrait<typename T::value_type::first_type>::Type(),
RowEncodeTrait<typename T::value_type::second_type>::Type());
}

template <typename V>
static void WriteKey(V &&visitor, const T &value, ArrayWriter &writer) {
int index = 0;
for (const auto &v : value) {
RowEncodeTrait<typename T::value_type::first_type>::Write(
std::forward<V>(visitor), v.first, writer, index);
++index;
}
}

template <typename V>
static void WriteValue(V &&visitor, const T &value, ArrayWriter &writer) {
int index = 0;
for (const auto &v : value) {
RowEncodeTrait<typename T::value_type::second_type>::Write(
std::forward<V>(visitor), v.second, writer, index);
++index;
}
}

template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&visitor, const T &value, W &writer, int index) {
auto offset = writer.cursor();
writer.WriteDirectly(-1);

auto map_type = std::dynamic_pointer_cast<arrow::MapType>(
details::GetChildType(writer, index));

auto key_writer =
std::make_unique<ArrayWriter>(std::static_pointer_cast<arrow::ListType>(
arrow::list(map_type->key_type())),
&writer);

key_writer->Reset(value.size());
RowEncodeTrait<T>::WriteKey(std::forward<V>(visitor), value,
*key_writer.get());

writer.WriteDirectly(offset, key_writer->size());

auto value_writer =
std::make_unique<ArrayWriter>(std::static_pointer_cast<arrow::ListType>(
arrow::list(map_type->item_type())),
&writer);

value_writer->Reset(value.size());
RowEncodeTrait<T>::WriteValue(std::forward<V>(visitor), value,
*value_writer.get());

writer.SetOffsetAndSize(index, offset, writer.cursor() - offset);

std::forward<V>(visitor).template Visit<std::remove_cv_t<T>>(
std::move(key_writer));
std::forward<V>(visitor).template Visit<std::remove_cv_t<T>>(
std::move(value_writer));
}
};

} // namespace encoder

} // namespace fury
64 changes: 64 additions & 0 deletions src/fury/encoder/row_encode_trait_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,70 @@ TEST(RowEncodeTrait, Optional) {
}
}

struct G {
std::map<int, std::map<int, int>> a;
std::map<std::string, A> b;
};

FURY_FIELD_INFO(G, a, b);

TEST(RowEncodeTrait, Map) {
G v{{{1, {{3, 4}, {5, 6}}}, {2, {{7, 8}, {9, 10}, {11, 12}}}},
{{"a", A{1, 1.1, true}}, {"b", A{2, 3.3, false}}}};

auto schema = encoder::RowEncodeTrait<G>::Type();

auto a_map =
std::dynamic_pointer_cast<arrow::MapType>(schema->field(0)->type());
ASSERT_EQ(a_map->key_type()->name(), "int32");
ASSERT_EQ(a_map->item_type()->name(), "map");
ASSERT_EQ(std::dynamic_pointer_cast<arrow::MapType>(a_map->item_type())
->key_type()
->name(),
"int32");
ASSERT_EQ(std::dynamic_pointer_cast<arrow::MapType>(a_map->item_type())
->item_type()
->name(),
"int32");

auto b_map =
std::dynamic_pointer_cast<arrow::MapType>(schema->field(1)->type());
ASSERT_EQ(b_map->key_type()->name(), "utf8");
ASSERT_EQ(b_map->item_type()->name(), "struct");
ASSERT_EQ(b_map->item_type()->field(0)->type()->name(), "int32");
ASSERT_EQ(b_map->item_type()->field(1)->type()->name(), "float");
ASSERT_EQ(b_map->item_type()->field(2)->type()->name(), "bool");

RowWriter writer(encoder::RowEncodeTrait<G>::Schema());
writer.Reset();

encoder::RowEncodeTrait<G>::Write(encoder::EmptyWriteVisitor{}, v, writer);

auto map_a = writer.ToRow()->GetMap(0);
ASSERT_EQ(map_a->keys_array()->GetInt32(0), 1);
ASSERT_EQ(map_a->keys_array()->GetInt32(1), 2);
ASSERT_EQ(map_a->values_array()->GetMap(0)->keys_array()->GetInt32(0), 3);
ASSERT_EQ(map_a->values_array()->GetMap(0)->keys_array()->GetInt32(1), 5);
ASSERT_EQ(map_a->values_array()->GetMap(0)->values_array()->GetInt32(0), 4);
ASSERT_EQ(map_a->values_array()->GetMap(0)->values_array()->GetInt32(1), 6);
ASSERT_EQ(map_a->values_array()->GetMap(1)->keys_array()->GetInt32(0), 7);
ASSERT_EQ(map_a->values_array()->GetMap(1)->keys_array()->GetInt32(1), 9);
ASSERT_EQ(map_a->values_array()->GetMap(1)->keys_array()->GetInt32(2), 11);
ASSERT_EQ(map_a->values_array()->GetMap(1)->values_array()->GetInt32(0), 8);
ASSERT_EQ(map_a->values_array()->GetMap(1)->values_array()->GetInt32(1), 10);
ASSERT_EQ(map_a->values_array()->GetMap(1)->values_array()->GetInt32(2), 12);

auto map_b = writer.ToRow()->GetMap(1);
ASSERT_EQ(map_b->keys_array()->GetString(0), "a");
ASSERT_EQ(map_b->keys_array()->GetString(1), "b");
ASSERT_EQ(map_b->values_array()->GetStruct(0)->GetInt32(0), 1);
ASSERT_EQ(map_b->values_array()->GetStruct(1)->GetInt32(0), 2);
ASSERT_FLOAT_EQ(map_b->values_array()->GetStruct(0)->GetFloat(1), 1.1);
ASSERT_FLOAT_EQ(map_b->values_array()->GetStruct(1)->GetFloat(1), 3.3);
ASSERT_EQ(map_b->values_array()->GetStruct(0)->GetBoolean(2), true);
ASSERT_EQ(map_b->values_array()->GetStruct(1)->GetBoolean(2), false);
}

} // namespace test

} // namespace fury
Expand Down
20 changes: 20 additions & 0 deletions src/fury/meta/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,26 @@ constexpr inline bool IsIterable =
template <typename T>
using GetValueType = typename details::GetValueTypeImpl<T>::type;

namespace details {

template <typename> constexpr inline bool IsPair = false;

template <typename T1, typename T2>
constexpr inline bool IsPair<std::pair<T1, T2>> = true;

template <typename> std::false_type IsPairIterableImpl(...);

template <
typename T,
std::enable_if_t<IsIterable<T> && IsPair<typename T::value_type>, int> = 0>
std::true_type IsPairIterableImpl(int);

} // namespace details

template <typename T>
constexpr inline bool IsPairIterable =
decltype(details::IsPairIterableImpl<T>(0))::value;

} // namespace meta

} // namespace fury
21 changes: 21 additions & 0 deletions src/fury/meta/type_traits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <deque>
#include <initializer_list>
#include <list>
#include <queue>
#include <type_traits>

#include "fury/meta/field_info.h"
#include "src/fury/meta/type_traits.h"
Expand Down Expand Up @@ -64,6 +66,11 @@ TEST(Meta, IsUnique) {
}

TEST(Meta, IsIterable) {
static_assert(!IsIterable<int>);
static_assert(!IsIterable<const bool>);
static_assert(!IsIterable<int &>);
static_assert(!IsIterable<std::false_type>);
static_assert(!IsIterable<std::queue<int>>);
static_assert(IsIterable<std::vector<int>>);
static_assert(IsIterable<std::vector<std::vector<int>>>);
static_assert(IsIterable<std::deque<float>>);
Expand All @@ -77,6 +84,20 @@ TEST(Meta, IsIterable) {
static_assert(IsIterable<std::string_view>);
}

TEST(Meta, IsPairIterable) {
static_assert(!IsPairIterable<int>);
static_assert(!IsPairIterable<std::string>);
static_assert(!IsPairIterable<std::vector<int>>);
static_assert(!IsPairIterable<std::vector<std::vector<int>>>);
static_assert(!IsPairIterable<std::deque<float>>);
static_assert(!IsPairIterable<std::list<int>>);
static_assert(!IsPairIterable<std::set<int>>);
static_assert(IsPairIterable<std::map<int, std::vector<unsigned>>>);
static_assert(IsPairIterable<std::map<std::string, int>>);
static_assert(IsPairIterable<std::multimap<std::string, bool>>);
static_assert(IsPairIterable<std::unordered_map<std::string, float>>);
}

} // namespace test

} // namespace fury
Expand Down

0 comments on commit 41286ee

Please sign in to comment.