Skip to content

Commit 5d9b045

Browse files
authored
BUG: Pass stream to get_element() and others in cudf (#147)
Maybe `get_element()` didn't support it when this was first added (not sure). These failures are now very reliable on multi-gpu runs only, though. I audited them and the tests are passing now locally with multi-gpu. Signed-off-by: Sebastian Berg <[email protected]>
1 parent a7e3d4d commit 5d9b045

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

cpp/src/core/column.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ std::unique_ptr<cudf::scalar> LogicalColumn::get_cudf_scalar(
311311
if (col->size() != 1) {
312312
throw std::invalid_argument("only length 1/scalar columns can be converted to scalar.");
313313
}
314-
return std::move(cudf::get_element(col->view(), 0));
314+
return std::move(cudf::get_element(col->view(), 0, stream, mr));
315315
}
316316

317317
namespace task {
@@ -367,7 +367,7 @@ std::unique_ptr<cudf::scalar> PhysicalColumn::cudf_scalar() const
367367
if (num_rows() != 1) {
368368
throw std::invalid_argument("can only convert length one columns to scalar.");
369369
}
370-
return cudf::get_element(column_view(), 0);
370+
return cudf::get_element(column_view(), 0, ctx_->stream(), ctx_->mr());
371371
}
372372

373373
void PhysicalColumn::copy_into(std::unique_ptr<cudf::column> column)

cpp/src/sort.cu

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ std::unique_ptr<cudf::column> create_column(cudf::size_type num_rows,
6868
{
6969
if (num_rows == 0) { return cudf::make_empty_column(cudf::data_type{cudf::type_to_id<T>()}); }
7070
return cudf::sequence(num_rows,
71-
*cudf::make_fixed_width_scalar(fill_value),
72-
*cudf::make_fixed_width_scalar(int32_t{0}));
71+
*cudf::make_fixed_width_scalar(fill_value, ctx.stream(), ctx.mr()),
72+
*cudf::make_fixed_width_scalar(int32_t{0}, ctx.stream(), ctx.mr()),
73+
ctx.stream(),
74+
ctx.mr());
7375
}
7476

7577
template <typename T>
@@ -270,15 +272,18 @@ std::vector<cudf::size_type> find_splits_for_distribution(
270272
ctx, sorted_table, global_split_values->view(), keys_idx, column_order, null_precedence);
271273
}
272274

273-
static std::unique_ptr<cudf::table> apply_limit(std::unique_ptr<cudf::table> tbl, int64_t limit)
275+
static std::unique_ptr<cudf::table> apply_limit(TaskContext& ctx,
276+
std::unique_ptr<cudf::table> tbl,
277+
int64_t limit)
274278
{
275279
if (limit != INT64_MIN && std::abs(limit) < tbl->num_rows()) {
276280
cudf::size_type cudf_limit = static_cast<cudf::size_type>(limit);
277281
cudf::table_view slice;
278282
if (limit < 0) {
279-
slice = cudf::slice(tbl->view(), {tbl->num_rows() + cudf_limit, tbl->num_rows()})[0];
283+
slice =
284+
cudf::slice(tbl->view(), {tbl->num_rows() + cudf_limit, tbl->num_rows()}, ctx.stream())[0];
280285
} else {
281-
slice = cudf::slice(tbl->view(), {0, cudf_limit})[0];
286+
slice = cudf::slice(tbl->view(), {0, cudf_limit}, ctx.stream())[0];
282287
}
283288
tbl = std::make_unique<cudf::table>(slice);
284289
}
@@ -324,7 +329,7 @@ static std::unique_ptr<cudf::table> apply_limit(std::unique_ptr<cudf::table> tbl
324329
auto sorted_table =
325330
sort_func(cudf_tbl, key, column_order, null_precedence, ctx.stream(), ctx.mr());
326331

327-
sorted_table = apply_limit(std::move(sorted_table), limit);
332+
sorted_table = apply_limit(ctx, std::move(sorted_table), limit);
328333

329334
if (ctx.nranks == 1) {
330335
output.move_into(sorted_table->release());

cpp/src/strings.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace legate::dataframe::task {
4141
}
4242

4343
std::unique_ptr<cudf::column> ret;
44-
auto cudf_pattern = cudf::string_scalar(pattern);
44+
auto cudf_pattern = cudf::string_scalar(pattern, true, ctx.stream(), ctx.mr());
4545

4646
if (match_func == "starts_with") {
4747
ret = cudf::strings::starts_with(input.column_view(), cudf_pattern, ctx.stream(), ctx.mr());

0 commit comments

Comments
 (0)