Skip to content

Commit 83d99d5

Browse files
committed
Fixes for Arrow STL iterator for custom types
1 parent 64cfa7c commit 83d99d5

File tree

2 files changed

+143
-17
lines changed

2 files changed

+143
-17
lines changed

cpp/src/arrow/stl_iterator.h

Lines changed: 24 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -64,12 +64,13 @@ class ArrayIterator {
6464
// Value access
6565
value_type operator*() const {
6666
assert(array_);
67-
return array_->IsNull(index_) ? value_type{} : array_->GetView(index_);
67+
return array_->IsNull(index_) ? value_type{} : ValueAccessor{}(*array_, index_);
6868
}
6969

7070
value_type operator[](difference_type n) const {
7171
assert(array_);
72-
return array_->IsNull(index_ + n) ? value_type{} : array_->GetView(index_ + n);
72+
return array_->IsNull(index_ + n) ? value_type{}
73+
: ValueAccessor{}(*array_, index_ + n);
7374
}
7475

7576
int64_t index() const { return index_; }
@@ -154,7 +155,7 @@ class ChunkedArrayIterator {
154155
// Value access
155156
value_type operator*() const {
156157
auto chunk_location = GetChunkLocation(index_);
157-
ArrayIterator<ArrayType> target_iterator{
158+
ArrayIterator<ArrayType, ValueAccessor> target_iterator{
158159
arrow::internal::checked_cast<const ArrayType&>(
159160
*chunked_array_->chunk(static_cast<int>(chunk_location.chunk_index)))};
160161
return target_iterator[chunk_location.index_in_chunk];
@@ -247,33 +248,39 @@ class ChunkedArrayIterator {
247248
};
248249

249250
/// Return an iterator to the beginning of the chunked array
250-
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType>
251-
ChunkedArrayIterator<ArrayType> Begin(const ChunkedArray& chunked_array) {
252-
return ChunkedArrayIterator<ArrayType>(chunked_array);
251+
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType,
252+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
253+
ChunkedArrayIterator<ArrayType, ValueAccessor> Begin(const ChunkedArray& chunked_array) {
254+
return ChunkedArrayIterator<ArrayType, ValueAccessor>(chunked_array);
253255
}
254256

255257
/// Return an iterator to the end of the chunked array
256-
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType>
257-
ChunkedArrayIterator<ArrayType> End(const ChunkedArray& chunked_array) {
258-
return ChunkedArrayIterator<ArrayType>(chunked_array, chunked_array.length());
258+
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType,
259+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
260+
ChunkedArrayIterator<ArrayType, ValueAccessor> End(const ChunkedArray& chunked_array) {
261+
return ChunkedArrayIterator<ArrayType, ValueAccessor>(chunked_array,
262+
chunked_array.length());
259263
}
260264

261-
template <typename ArrayType>
265+
template <typename ArrayType,
266+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
262267
struct ChunkedArrayRange {
263268
const ChunkedArray* chunked_array;
264269

265-
ChunkedArrayIterator<ArrayType> begin() {
266-
return stl::ChunkedArrayIterator<ArrayType>(*chunked_array);
270+
ChunkedArrayIterator<ArrayType, ValueAccessor> begin() {
271+
return stl::ChunkedArrayIterator<ArrayType, ValueAccessor>(*chunked_array);
267272
}
268-
ChunkedArrayIterator<ArrayType> end() {
269-
return stl::ChunkedArrayIterator<ArrayType>(*chunked_array, chunked_array->length());
273+
ChunkedArrayIterator<ArrayType, ValueAccessor> end() {
274+
return stl::ChunkedArrayIterator<ArrayType, ValueAccessor>(*chunked_array,
275+
chunked_array->length());
270276
}
271277
};
272278

273279
/// Return an iterable range over the chunked array
274-
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType>
275-
ChunkedArrayRange<ArrayType> Iterate(const ChunkedArray& chunked_array) {
276-
return stl::ChunkedArrayRange<ArrayType>{&chunked_array};
280+
template <typename Type, typename ArrayType = typename TypeTraits<Type>::ArrayType,
281+
typename ValueAccessor = detail::DefaultValueAccessor<ArrayType>>
282+
ChunkedArrayRange<ArrayType, ValueAccessor> Iterate(const ChunkedArray& chunked_array) {
283+
return stl::ChunkedArrayRange<ArrayType, ValueAccessor>{&chunked_array};
277284
}
278285

279286
} // namespace stl

cpp/src/arrow/stl_iterator_test.cc

Lines changed: 119 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -248,6 +248,68 @@ TEST(ArrayIterator, StdMerge) {
248248
ASSERT_EQ(values, expected);
249249
}
250250

251+
// Custom ValueAccessor for DictionaryArray that decodes values
252+
struct TestDictionaryValueAccessor {
253+
using ValueType = std::string_view;
254+
255+
inline ValueType operator()(const DictionaryArray& array, int64_t index) {
256+
// Get the dictionary index for this position
257+
int64_t dict_index = array.GetValueIndex(index);
258+
259+
// Get the dictionary and cast it to StringArray
260+
auto dict = checked_pointer_cast<StringArray>(array.dictionary());
261+
262+
// Return the decoded string value
263+
return dict->GetView(dict_index);
264+
}
265+
};
266+
267+
TEST(ArrayIterator, CustomValueAccessorDictionary) {
268+
// Create a dictionary array with string values
269+
auto dict = ArrayFromJSON(utf8(), R"(["apple", "banana", "cherry", "date"])");
270+
auto indices = ArrayFromJSON(int32(), "[0, 1, 2, 3, 2, 1, 0, null, 3]");
271+
272+
auto dict_type = dictionary(int32(), utf8());
273+
auto dict_array = std::make_shared<DictionaryArray>(dict_type, indices, dict);
274+
275+
// Use custom accessor to iterate over decoded values
276+
ArrayIterator<DictionaryArray, TestDictionaryValueAccessor> it(*dict_array);
277+
278+
// Test basic access
279+
ASSERT_EQ(*it, "apple");
280+
ASSERT_EQ(it[1], "banana");
281+
ASSERT_EQ(it[2], "cherry");
282+
ASSERT_EQ(it[3], "date");
283+
ASSERT_EQ(it[4], "cherry");
284+
ASSERT_EQ(it[5], "banana");
285+
ASSERT_EQ(it[6], "apple");
286+
ASSERT_EQ(it[7], nullopt); // null index
287+
ASSERT_EQ(it[8], "date");
288+
289+
// Test iteration
290+
std::vector<optional<std::string_view>> values;
291+
for (auto end = it + 9; it != end; ++it) {
292+
values.push_back(*it);
293+
}
294+
295+
std::vector<optional<std::string_view>> expected{
296+
"apple", "banana", "cherry", "date", "cherry", "banana", "apple", nullopt, "date"};
297+
ASSERT_EQ(values, expected);
298+
299+
// Test with algorithms - find a specific value
300+
ArrayIterator<DictionaryArray, TestDictionaryValueAccessor> begin(*dict_array);
301+
ArrayIterator<DictionaryArray, TestDictionaryValueAccessor> end(*dict_array,
302+
dict_array->length());
303+
304+
auto found = std::find(begin, end, "cherry");
305+
ASSERT_NE(found, end);
306+
ASSERT_EQ(found.index(), 2); // First occurrence of "cherry"
307+
308+
// Count occurrences of "banana"
309+
auto count = std::count(begin, end, "banana");
310+
ASSERT_EQ(count, 2);
311+
}
312+
251313
TEST(ChunkedArrayIterator, Basics) {
252314
auto result = ChunkedArrayFromJSON(int32(), {R"([4, 5, null])", R"([6])"});
253315
auto it = Begin<Int32Type>(*result);
@@ -545,5 +607,62 @@ TEST(ChunkedArrayIterator, ForEachIterator) {
545607
ASSERT_EQ(values, expected);
546608
}
547609

610+
TEST(ChunkedArrayIterator, CustomValueAccessorDictionary) {
611+
// Create multiple dictionary arrays with the same dictionary
612+
auto dict = ArrayFromJSON(utf8(), R"(["red", "green", "blue", "yellow"])");
613+
614+
auto indices1 = ArrayFromJSON(int32(), "[0, 1, 2]");
615+
auto indices2 = ArrayFromJSON(int32(), "[3, 2, null]");
616+
auto indices3 = ArrayFromJSON(int32(), "[1, 0, 3, 2]");
617+
618+
auto dict_type = dictionary(int32(), utf8());
619+
auto dict_array1 = std::make_shared<DictionaryArray>(dict_type, indices1, dict);
620+
auto dict_array2 = std::make_shared<DictionaryArray>(dict_type, indices2, dict);
621+
auto dict_array3 = std::make_shared<DictionaryArray>(dict_type, indices3, dict);
622+
623+
// Create chunked array from dictionary arrays
624+
auto chunked_array = std::make_shared<ChunkedArray>(
625+
std::vector<std::shared_ptr<Array>>{dict_array1, dict_array2, dict_array3},
626+
dict_type);
627+
628+
// Use custom accessor to iterate over decoded values across chunks
629+
auto it =
630+
Begin<DictionaryType, DictionaryArray, TestDictionaryValueAccessor>(*chunked_array);
631+
auto end =
632+
End<DictionaryType, DictionaryArray, TestDictionaryValueAccessor>(*chunked_array);
633+
634+
// Test sequential access across chunks
635+
ASSERT_EQ(*it, "red"); // chunk 0, index 0
636+
ASSERT_EQ(*(it + 1), "green"); // chunk 0, index 1
637+
ASSERT_EQ(*(it + 2), "blue"); // chunk 0, index 2
638+
ASSERT_EQ(*(it + 3), "yellow"); // chunk 1, index 0
639+
ASSERT_EQ(*(it + 4), "blue"); // chunk 1, index 1
640+
ASSERT_EQ(*(it + 5), nullopt); // chunk 1, index 2 (null)
641+
ASSERT_EQ(*(it + 6), "green"); // chunk 2, index 0
642+
ASSERT_EQ(*(it + 7), "red"); // chunk 2, index 1
643+
ASSERT_EQ(*(it + 8), "yellow"); // chunk 2, index 2
644+
ASSERT_EQ(*(it + 9), "blue"); // chunk 2, index 3
645+
646+
// Collect all values
647+
std::vector<optional<std::string_view>> values;
648+
for (auto iter = it; iter != end; ++iter) {
649+
values.push_back(*iter);
650+
}
651+
652+
std::vector<optional<std::string_view>> expected{"red", "green", "blue", "yellow",
653+
"blue", nullopt, "green", "red",
654+
"yellow", "blue"};
655+
ASSERT_EQ(values, expected);
656+
657+
// Test with algorithms - count occurrences of "blue"
658+
auto count = std::count(it, end, "blue");
659+
ASSERT_EQ(count, 3);
660+
661+
// Find first occurrence of "yellow"
662+
auto found = std::find(it, end, "yellow");
663+
ASSERT_NE(found, end);
664+
ASSERT_EQ(found.index(), 3);
665+
}
666+
548667
} // namespace stl
549668
} // namespace arrow

0 commit comments

Comments
 (0)