Skip to content

Commit 0cb78cb

Browse files
committed Jan 7, 2025
Use nanoarrow C++ helpers and iterate stream in accumulations
1 parent 2325b24 commit 0cb78cb

File tree

1 file changed

+66
-63
lines changed

1 file changed

+66
-63
lines changed
 

‎pandas/_libs/arrow_string_accumulations.cc

+66-63
Original file line numberDiff line numberDiff line change
@@ -30,68 +30,73 @@ static auto ReleaseArrowSchema(void *ptr) noexcept -> void {
3030
delete schema;
3131
}
3232

33-
static auto CumSum(const struct ArrowArrayView *array_view,
33+
template <size_t OffsetSize>
34+
static auto CumSum(struct ArrowArrayStream *array_stream,
3435
struct ArrowArray *out, bool skipna) {
3536
bool seen_na = false;
3637
std::stringstream ss{};
3738

38-
for (int64_t i = 0; i < array_view->length; i++) {
39-
const bool isna = ArrowArrayViewIsNull(array_view, i);
40-
if (!skipna && (seen_na || isna)) {
41-
seen_na = true;
42-
ArrowArrayAppendNull(out, 1);
43-
} else {
44-
if (!isna) {
45-
const auto std_sv = ArrowArrayViewGetStringUnsafe(array_view, i);
46-
ss << std::string_view{std_sv.data,
47-
static_cast<size_t>(std_sv.size_bytes)};
39+
nanoarrow::ViewArrayStream array_stream_view(array_stream);
40+
for (const auto &array : array_stream_view) {
41+
for (const auto &sv : nanoarrow::ViewArrayAsBytes<OffsetSize>(&array)) {
42+
if ((!sv || seen_na) && !skipna) {
43+
seen_na = true;
44+
ArrowArrayAppendNull(out, 1);
45+
} else {
46+
if (sv) {
47+
ss << std::string_view{(*sv).data,
48+
static_cast<size_t>((*sv).size_bytes)};
49+
}
50+
const auto str = ss.str();
51+
const ArrowStringView asv{str.c_str(),
52+
static_cast<int64_t>(str.size())};
53+
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, asv));
4854
}
49-
const auto str = ss.str();
50-
const ArrowStringView asv{str.c_str(), static_cast<int64_t>(str.size())};
51-
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, asv));
5255
}
5356
}
5457
}
5558

59+
// TODO: not all compilers in CI seem to support C++20 concepts yet
5660
// template <typename T>
5761
// concept MinOrMaxOp =
5862
// std::same_as<T, std::less<>> || std::same_as<T, std::greater<>>;
5963

60-
template <auto Op>
64+
template <size_t OffsetSize, auto Op>
6165
// requires MinOrMaxOp<decltype(Op)>
62-
static auto CumMinOrMax(const struct ArrowArrayView *array_view,
66+
static auto CumMinOrMax(struct ArrowArrayStream *array_stream,
6367
struct ArrowArray *out, bool skipna) {
6468
bool seen_na = false;
6569
std::optional<std::string> current_str{};
6670

67-
for (int64_t i = 0; i < array_view->length; i++) {
68-
const bool isna = ArrowArrayViewIsNull(array_view, i);
69-
if (!skipna && (seen_na || isna)) {
70-
seen_na = true;
71-
ArrowArrayAppendNull(out, 1);
72-
} else {
73-
if (!isna || current_str) {
74-
if (!isna) {
75-
const auto asv = ArrowArrayViewGetStringUnsafe(array_view, i);
76-
const nb::str pyval{asv.data, static_cast<size_t>(asv.size_bytes)};
77-
78-
if (current_str) {
79-
const nb::str pycurrent{current_str->data(), current_str->size()};
80-
if (Op(pyval, pycurrent)) {
81-
current_str =
82-
std::string{asv.data, static_cast<size_t>(asv.size_bytes)};
71+
nanoarrow::ViewArrayStream array_stream_view(array_stream);
72+
for (const auto &array : array_stream_view) {
73+
for (const auto &sv : nanoarrow::ViewArrayAsBytes<OffsetSize>(&array)) {
74+
if ((!sv || seen_na) && !skipna) {
75+
seen_na = true;
76+
ArrowArrayAppendNull(out, 1);
77+
} else {
78+
if (sv || current_str) {
79+
if (sv) {
80+
const nb::str pyval{(*sv).data,
81+
static_cast<size_t>((*sv).size_bytes)};
82+
if (current_str) {
83+
const nb::str pycurrent{current_str->data(), current_str->size()};
84+
if (Op(pyval, pycurrent)) {
85+
current_str = std::string{
86+
(*sv).data, static_cast<size_t>((*sv).size_bytes)};
87+
}
88+
} else {
89+
current_str = std::string{(*sv).data,
90+
static_cast<size_t>((*sv).size_bytes)};
8391
}
84-
} else {
85-
current_str =
86-
std::string{asv.data, static_cast<size_t>(asv.size_bytes)};
8792
}
88-
}
8993

90-
struct ArrowStringView out_sv{
91-
current_str->data(), static_cast<int64_t>(current_str->size())};
92-
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, out_sv));
93-
} else {
94-
ArrowArrayAppendEmpty(out, 1);
94+
struct ArrowStringView out_sv{
95+
current_str->data(), static_cast<int64_t>(current_str->size())};
96+
NANOARROW_THROW_NOT_OK(ArrowArrayAppendString(out, out_sv));
97+
} else {
98+
ArrowArrayAppendEmpty(out, 1);
99+
}
95100
}
96101
}
97102
}
@@ -131,7 +136,6 @@ class ArrowStringAccumulation {
131136
switch (schema_view.type) {
132137
case NANOARROW_TYPE_STRING:
133138
case NANOARROW_TYPE_LARGE_STRING:
134-
case NANOARROW_TYPE_STRING_VIEW:
135139
break;
136140
default:
137141
const auto error_message =
@@ -159,30 +163,29 @@ class ArrowStringAccumulation {
159163

160164
NANOARROW_THROW_NOT_OK(ArrowArrayStartAppending(uarray_out.get()));
161165

162-
nanoarrow::UniqueArray chunk{};
163-
int errcode{};
164-
165-
while ((errcode = ArrowArrayStreamGetNext(stream_.get(), chunk.get(),
166-
nullptr) == 0) &&
167-
chunk->release != nullptr) {
168-
struct ArrowArrayView array_view{};
169-
NANOARROW_THROW_NOT_OK(
170-
ArrowArrayViewInitFromSchema(&array_view, schema_.get(), nullptr));
171-
172-
NANOARROW_THROW_NOT_OK(
173-
ArrowArrayViewSetArray(&array_view, chunk.get(), nullptr));
174-
175-
if (accumulation_ == "cumsum") {
176-
CumSum(&array_view, uarray_out.get(), skipna_);
177-
} else if (accumulation_ == "cummin") {
178-
CumMinOrMax<std::less{}>(&array_view, uarray_out.get(), skipna_);
179-
} else if (accumulation_ == "cummax") {
180-
CumMinOrMax<std::greater{}>(&array_view, uarray_out.get(), skipna_);
166+
if (accumulation_ == "cumsum") {
167+
if (schema_view.type == NANOARROW_TYPE_STRING) {
168+
CumSum<32>(stream_.get(), uarray_out.get(), skipna_);
181169
} else {
182-
throw std::runtime_error("Unexpected branch");
170+
CumSum<64>(stream_.get(), uarray_out.get(), skipna_);
183171
}
184172

185-
chunk.reset();
173+
} else if (accumulation_ == "cummin") {
174+
if (schema_view.type == NANOARROW_TYPE_STRING) {
175+
CumMinOrMax<32, std::less{}>(stream_.get(), uarray_out.get(), skipna_);
176+
} else {
177+
CumMinOrMax<64, std::less{}>(stream_.get(), uarray_out.get(), skipna_);
178+
}
179+
} else if (accumulation_ == "cummax") {
180+
if (schema_view.type == NANOARROW_TYPE_STRING) {
181+
CumMinOrMax<32, std::greater{}>(stream_.get(), uarray_out.get(),
182+
skipna_);
183+
} else {
184+
CumMinOrMax<64, std::greater{}>(stream_.get(), uarray_out.get(),
185+
skipna_);
186+
}
187+
} else {
188+
throw std::runtime_error("Unexpected branch");
186189
}
187190

188191
NANOARROW_THROW_NOT_OK(

0 commit comments

Comments
 (0)
Please sign in to comment.