Skip to content

Commit 7c795c9

Browse files
Dandandan authored and kszucs committed
ARROW-11290: [Rust][DataFusion] Address hash aggregate performance issue with high number of groups
Currently, we loop to the hashmap for every key. However, as we receive a batch, if we a lot of groups in the group by expression (or receive sorted data, etc.) then we could create a lot of empty batches and call `update_batch` for each of the key already in the hashmap. In the PR we keep track of which keys we received in the batch and only update the accumulators with the same keys instead of all accumulators. On the db-benchmark h2oai/db-benchmark#182 this is the difference (mainly q3 and q5, others seem to be noise). It doesn't seem to completely solve the problem, but it reduces the problem already quite a bit. This PR: ``` q1 took 340 ms q2 took 1768 ms q3 took 10975 ms q4 took 337 ms q5 took 13529 ms ``` Master: ``` q1 took 330 ms q2 took 1648 ms q3 took 16408 ms q4 took 335 ms q5 took 21074 ms ``` Closes #9234 from Dandandan/hash_agg_speed2 Authored-by: Heres, Daniel <[email protected]> Signed-off-by: Andrew Lamb <[email protected]>
1 parent 0a58593 commit 7c795c9

File tree

1 file changed

+48
-39
lines changed

1 file changed

+48
-39
lines changed

rust/datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 48 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ fn group_aggregate_batch(
288288
// Make sure we can create the accumulators or otherwise return an error
289289
create_accumulators(aggr_expr).map_err(DataFusionError::into_arrow_external_error)?;
290290

291+
// Keys received in this batch
292+
let mut batch_keys = vec![];
293+
291294
for row in 0..batch.num_rows() {
292295
// 1.1
293296
create_key(&group_values, row, &mut key)
@@ -297,11 +300,17 @@ fn group_aggregate_batch(
297300
.raw_entry_mut()
298301
.from_key(&key)
299302
// 1.3
300-
.and_modify(|_, (_, _, v)| v.push(row as u32))
303+
.and_modify(|_, (_, _, v)| {
304+
if v.is_empty() {
305+
batch_keys.push(key.clone())
306+
};
307+
v.push(row as u32)
308+
})
301309
// 1.2
302310
.or_insert_with(|| {
303311
// We can safely unwrap here as we checked we can create an accumulator before
304312
let accumulator_set = create_accumulators(aggr_expr).unwrap();
313+
batch_keys.push(key.clone());
305314
let _ = create_group_by_values(&group_values, row, &mut group_by_values);
306315
(
307316
key.clone(),
@@ -310,48 +319,48 @@ fn group_aggregate_batch(
310319
});
311320
}
312321

313-
// 2.1 for each key
322+
// 2.1 for each key in this batch
314323
// 2.2 for each aggregation
315324
// 2.3 `take` from each of its arrays the keys' values
316325
// 2.4 update / merge the accumulator with the values
317326
// 2.5 clear indices
318-
accumulators
319-
.iter_mut()
320-
.try_for_each(|(_, (_, accumulator_set, indices))| {
321-
// 2.2
322-
accumulator_set
323-
.iter_mut()
324-
.zip(&aggr_input_values)
325-
.map(|(accumulator, aggr_array)| {
326-
(
327-
accumulator,
328-
aggr_array
329-
.iter()
330-
.map(|array| {
331-
// 2.3
332-
compute::take(
333-
array.as_ref(),
334-
&UInt32Array::from(indices.clone()),
335-
None, // None: no index check
336-
)
337-
.unwrap()
338-
})
339-
.collect::<Vec<ArrayRef>>(),
340-
)
341-
})
342-
.try_for_each(|(accumulator, values)| match mode {
343-
AggregateMode::Partial => accumulator.update_batch(&values),
344-
AggregateMode::Final => {
345-
// note: the aggregation here is over states, not values, thus the merge
346-
accumulator.merge_batch(&values)
347-
}
348-
})
349-
// 2.5
350-
.and({
351-
indices.clear();
352-
Ok(())
353-
})
354-
})?;
327+
batch_keys.iter_mut().try_for_each(|key| {
328+
let (_, accumulator_set, indices) = accumulators.get_mut(key).unwrap();
329+
let primitive_indices = UInt32Array::from(indices.clone());
330+
// 2.2
331+
accumulator_set
332+
.iter_mut()
333+
.zip(&aggr_input_values)
334+
.map(|(accumulator, aggr_array)| {
335+
(
336+
accumulator,
337+
aggr_array
338+
.iter()
339+
.map(|array| {
340+
// 2.3
341+
compute::take(
342+
array.as_ref(),
343+
&primitive_indices,
344+
None, // None: no index check
345+
)
346+
.unwrap()
347+
})
348+
.collect::<Vec<ArrayRef>>(),
349+
)
350+
})
351+
.try_for_each(|(accumulator, values)| match mode {
352+
AggregateMode::Partial => accumulator.update_batch(&values),
353+
AggregateMode::Final => {
354+
// note: the aggregation here is over states, not values, thus the merge
355+
accumulator.merge_batch(&values)
356+
}
357+
})
358+
// 2.5
359+
.and({
360+
indices.clear();
361+
Ok(())
362+
})
363+
})?;
355364
Ok(accumulators)
356365
}
357366

0 commit comments

Comments
 (0)