Skip to content

Commit 6be4cfd

Browse files
committed
Implement Yaroslavskiy-Bentley-Bloch Quicksort.
1 parent b6628c6 commit 6be4cfd

File tree

2 files changed

+149
-108
lines changed

2 files changed

+149
-108
lines changed

src/sort.rs

+134-100
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use indexmap::IndexMap;
22
use ndarray::prelude::*;
33
use ndarray::{Data, DataMut, Slice};
4-
use rand::prelude::*;
5-
use rand::thread_rng;
64

75
/// Methods for sorting and partitioning 1-D arrays.
86
pub trait Sort1dExt<A, S>
@@ -50,26 +48,21 @@ where
5048
S: DataMut,
5149
S2: Data<Elem = usize>;
5250

53-
/// Partitions the array in increasing order based on the value initially
54-
/// located at `pivot_index` and returns the new index of the value.
51+
/// Partitions the array in increasing order based on the values initially located at `0` and
52+
/// `n - 1` where `n` is the number of elements in the array and returns the new indexes of the
53+
/// values.
5554
///
56-
/// The elements are rearranged in such a way that the value initially
57-
/// located at `pivot_index` is moved to the position it would be in an
58-
/// array sorted in increasing order. The return value is the new index of
59-
/// the value after rearrangement. All elements smaller than the value are
60-
/// moved to its left and all elements equal or greater than the value are
61-
/// moved to its right. The ordering of the elements in the two partitions
62-
/// is undefined.
55+
/// The elements are rearranged in such a way that the values initially located at `0` and
56+
/// `n - 1` are moved to the position it would be in an array sorted in increasing order. The
57+
/// return values are the new indexes of the values after rearrangement. All elements less than
58+
/// the values are moved to their left and all elements equal or greater than the values are
59+
/// moved to their right. The ordering of the elements in the three partitions is undefined.
6360
///
64-
/// `self` is shuffled **in place** to operate the desired partition:
65-
/// no copy of the array is allocated.
61+
/// The array is shuffled **in place**, no copy of the array is allocated.
6662
///
67-
/// The method uses Hoare's partition algorithm.
68-
/// Complexity: O(`n`), where `n` is the number of elements in the array.
69-
/// Average number of element swaps: n/6 - 1/3 (see
70-
/// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
63+
/// This method implements the partitioning scheme of [Yaroslavskiy-Bentley-Bloch Quicksort].
7164
///
72-
/// **Panics** if `pivot_index` is greater than or equal to `n`.
65+
/// [Yaroslavskiy-Bentley-Bloch Quicksort]: https://api.semanticscholar.org/CorpusID:51871084
7366
///
7467
/// # Example
7568
///
@@ -78,23 +71,30 @@ where
7871
/// use ndarray_stats::Sort1dExt;
7972
///
8073
/// let mut data = array![3, 1, 4, 5, 2];
81-
/// let pivot_index = 2;
82-
/// let pivot_value = data[pivot_index];
74+
/// // Sorted pivot values.
75+
/// let (lower_value, upper_value) = (data[data.len() - 1], data[0]);
8376
///
84-
/// // Partition by the value located at `pivot_index`.
85-
/// let new_index = data.partition_mut(pivot_index);
86-
/// // The pivot value is now located at `new_index`.
87-
/// assert_eq!(data[new_index], pivot_value);
88-
/// // Elements less than that value are moved to the left.
89-
/// for i in 0..new_index {
90-
/// assert!(data[i] < pivot_value);
77+
/// // Partitions by the values located at `0` and `data.len() - 1`.
78+
/// let (lower_index, upper_index) = data.partition_mut();
79+
/// // The pivot values are now located at `lower_index` and `upper_index`.
80+
/// assert_eq!(data[lower_index], lower_value);
81+
/// assert_eq!(data[upper_index], upper_value);
82+
/// // Elements lower than the lower pivot value are moved to its left.
83+
/// for i in 0..lower_index {
84+
/// assert!(data[i] < lower_value);
85+
/// }
86+
/// // Elements greater than or equal the lower pivot value and less than or equal the upper
87+
/// // pivot value are moved between the two pivot indexes.
88+
/// for i in lower_index + 1..upper_index {
89+
/// assert!(lower_value <= data[i]);
90+
/// assert!(data[i] <= upper_value);
9191
/// }
92-
/// // Elements greater than or equal to that value are moved to the right.
93-
/// for i in (new_index + 1)..data.len() {
94-
/// assert!(data[i] >= pivot_value);
92+
/// // Elements greater than or equal the upper pivot value are moved to its right.
93+
/// for i in upper_index + 1..data.len() {
94+
/// assert!(upper_value <= data[i]);
9595
/// }
9696
/// ```
97-
fn partition_mut(&mut self, pivot_index: usize) -> usize
97+
fn partition_mut(&mut self) -> (usize, usize)
9898
where
9999
A: Ord + Clone,
100100
S: DataMut;
@@ -115,17 +115,20 @@ where
115115
if n == 1 {
116116
self[0].clone()
117117
} else {
118-
let mut rng = thread_rng();
119-
let pivot_index = rng.gen_range(0..n);
120-
let partition_index = self.partition_mut(pivot_index);
121-
if i < partition_index {
122-
self.slice_axis_mut(Axis(0), Slice::from(..partition_index))
118+
let (lower_index, upper_index) = self.partition_mut();
119+
if i < lower_index {
120+
self.slice_axis_mut(Axis(0), Slice::from(..lower_index))
123121
.get_from_sorted_mut(i)
124-
} else if i == partition_index {
122+
} else if i == lower_index {
123+
self[i].clone()
124+
} else if i < upper_index {
125+
self.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index))
126+
.get_from_sorted_mut(i - (lower_index + 1))
127+
} else if i == upper_index {
125128
self[i].clone()
126129
} else {
127-
self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..))
128-
.get_from_sorted_mut(i - (partition_index + 1))
130+
self.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..))
131+
.get_from_sorted_mut(i - (upper_index + 1))
129132
}
130133
}
131134
}
@@ -143,42 +146,51 @@ where
143146
get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
144147
}
145148

146-
fn partition_mut(&mut self, pivot_index: usize) -> usize
149+
fn partition_mut(&mut self) -> (usize, usize)
147150
where
148151
A: Ord + Clone,
149152
S: DataMut,
150153
{
151-
let pivot_value = self[pivot_index].clone();
152-
self.swap(pivot_index, 0);
153-
let n = self.len();
154-
let mut i = 1;
155-
let mut j = n - 1;
156-
loop {
157-
loop {
158-
if i > j {
159-
break;
160-
}
161-
if self[i] >= pivot_value {
162-
break;
154+
// Sort `lowermost` and `uppermost` elements and use them as dual pivot.
155+
let lowermost = 0;
156+
let uppermost = self.len() - 1;
157+
if self[lowermost] > self[uppermost] {
158+
self.swap(lowermost, uppermost);
159+
}
160+
// Increasing running and partition index starting after lower pivot.
161+
let mut index = lowermost + 1;
162+
let mut lower = lowermost + 1;
163+
// Decreasing partition index starting before upper pivot.
164+
let mut upper = uppermost - 1;
165+
// Swap elements at `index` into their partitions.
166+
while index <= upper {
167+
if self[index] < self[lowermost] {
168+
// Swap elements into lower partition.
169+
self.swap(index, lower);
170+
lower += 1;
171+
} else if self[index] >= self[uppermost] {
172+
// Search first element of upper partition.
173+
while self[upper] > self[uppermost] && index < upper {
174+
upper -= 1;
163175
}
164-
i += 1;
165-
}
166-
while pivot_value <= self[j] {
167-
if j == 1 {
168-
break;
176+
// Swap elements into upper partition.
177+
self.swap(index, upper);
178+
if self[index] < self[lowermost] {
179+
// Swap swapped elements into lower partition.
180+
self.swap(index, lower);
181+
lower += 1;
169182
}
170-
j -= 1;
171-
}
172-
if i >= j {
173-
break;
174-
} else {
175-
self.swap(i, j);
176-
i += 1;
177-
j -= 1;
183+
upper -= 1;
178184
}
185+
index += 1;
179186
}
180-
self.swap(0, i - 1);
181-
i - 1
187+
lower -= 1;
188+
upper += 1;
189+
// Swap pivots to their new indexes.
190+
self.swap(lowermost, lower);
191+
self.swap(uppermost, upper);
192+
// Lower and upper pivot index.
193+
(lower, upper)
182194
}
183195

184196
private_impl! {}
@@ -249,50 +261,72 @@ fn _get_many_from_sorted_mut_unchecked<A>(
249261
return;
250262
}
251263

252-
// We pick a random pivot index: the corresponding element is the pivot value
253-
let mut rng = thread_rng();
254-
let pivot_index = rng.gen_range(0..n);
264+
// We partition the array with respect to the two pivot values. The pivot values move to
265+
// `lower_index` and `upper_index`.
266+
//
267+
// Elements strictly less than the lower pivot value have indexes < `lower_index`.
268+
//
269+
// Elements greater than or equal the lower pivot value and less than or equal the upper pivot
270+
// value have indexes > `lower_index` and < `upper_index`.
271+
//
272+
// Elements less than or equal the upper pivot value have indexes > `upper_index`.
273+
let (lower_index, upper_index) = array.partition_mut();
255274

256-
// We partition the array with respect to the pivot value.
257-
// The pivot value moves to `array_partition_index`.
258-
// Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
259-
// Elements greater or equal to the pivot value have indexes > `array_partition_index`.
260-
let array_partition_index = array.partition_mut(pivot_index);
275+
// We use a divide-and-conquer strategy, splitting the indexes we are searching for (`indexes`)
276+
// and the corresponding portions of the output slice (`values`) into partitions with respect to
277+
// `lower_index` and `upper_index`.
278+
let (found_exact, split_index) = match indexes.binary_search(&lower_index) {
279+
Ok(index) => (true, index),
280+
Err(index) => (false, index),
281+
};
282+
let (lower_indexes, inner_indexes) = indexes.split_at_mut(split_index);
283+
let (lower_values, inner_values) = values.split_at_mut(split_index);
284+
let (upper_indexes, upper_values) = if found_exact {
285+
inner_values[0] = array[lower_index].clone(); // Write exactly found value.
286+
(&mut inner_indexes[1..], &mut inner_values[1..])
287+
} else {
288+
(inner_indexes, inner_values)
289+
};
261290

262-
// We use a divide-and-conquer strategy, splitting the indexes we are
263-
// searching for (`indexes`) and the corresponding portions of the output
264-
// slice (`values`) into pieces with respect to `array_partition_index`.
265-
let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) {
291+
let (found_exact, split_index) = match upper_indexes.binary_search(&upper_index) {
266292
Ok(index) => (true, index),
267293
Err(index) => (false, index),
268294
};
269-
let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split);
270-
let (smaller_values, other_values) = values.split_at_mut(index_split);
271-
let (bigger_indexes, bigger_values) = if found_exact {
272-
other_values[0] = array[array_partition_index].clone(); // Write exactly found value.
273-
(&mut other_indexes[1..], &mut other_values[1..])
295+
let (inner_indexes, upper_indexes) = upper_indexes.split_at_mut(split_index);
296+
let (inner_values, upper_values) = upper_values.split_at_mut(split_index);
297+
let (upper_indexes, upper_values) = if found_exact {
298+
upper_values[0] = array[upper_index].clone(); // Write exactly found value.
299+
(&mut upper_indexes[1..], &mut upper_values[1..])
274300
} else {
275-
(other_indexes, other_values)
301+
(upper_indexes, upper_values)
276302
};
277303

278-
// We search recursively for the values corresponding to strictly smaller
279-
// indexes to the left of `partition_index`.
304+
// We search recursively for the values corresponding to indexes strictly less than
305+
// `lower_index` in the lower partition.
306+
_get_many_from_sorted_mut_unchecked(
307+
array.slice_axis_mut(Axis(0), Slice::from(..lower_index)),
308+
lower_indexes,
309+
lower_values,
310+
);
311+
312+
// We search recursively for the values corresponding to indexes greater than or equal
313+
// `lower_index` in the inner partition, that is between the lower and upper partition. Since
314+
// only the inner partition of the array is passed in, the indexes need to be shifted by length
315+
// of the lower partition.
316+
inner_indexes.iter_mut().for_each(|x| *x -= lower_index + 1);
280317
_get_many_from_sorted_mut_unchecked(
281-
array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)),
282-
smaller_indexes,
283-
smaller_values,
318+
array.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index)),
319+
inner_indexes,
320+
inner_values,
284321
);
285322

286-
// We search recursively for the values corresponding to strictly bigger
287-
// indexes to the right of `partition_index`. Since only the right portion
288-
// of the array is passed in, the indexes need to be shifted by length of
289-
// the removed portion.
290-
bigger_indexes
291-
.iter_mut()
292-
.for_each(|x| *x -= array_partition_index + 1);
323+
// We search recursively for the values corresponding to indexes greater than or equal
324+
// `upper_index` in the upper partition. Since only the upper partition of the array is passed
325+
// in, the indexes need to be shifted by the combined length of the lower and inner partition.
326+
upper_indexes.iter_mut().for_each(|x| *x -= upper_index + 1);
293327
_get_many_from_sorted_mut_unchecked(
294-
array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)),
295-
bigger_indexes,
296-
bigger_values,
328+
array.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..)),
329+
upper_indexes,
330+
upper_values,
297331
);
298332
}

tests/sort.rs

+15-8
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,22 @@ fn test_partition_mut() {
1919
];
2020
for a in l.iter_mut() {
2121
let n = a.len();
22-
let pivot_index = n - 1;
23-
let pivot_value = a[pivot_index].clone();
24-
let partition_index = a.partition_mut(pivot_index);
25-
for i in 0..partition_index {
26-
assert!(a[i] < pivot_value);
22+
let (mut lower_value, mut upper_value) = (a[0].clone(), a[n - 1].clone());
23+
if lower_value > upper_value {
24+
std::mem::swap(&mut lower_value, &mut upper_value);
2725
}
28-
assert_eq!(a[partition_index], pivot_value);
29-
for j in (partition_index + 1)..n {
30-
assert!(pivot_value <= a[j]);
26+
let (lower_index, upper_index) = a.partition_mut();
27+
for i in 0..lower_index {
28+
assert!(a[i] < lower_value);
29+
}
30+
assert_eq!(a[lower_index], lower_value);
31+
for i in lower_index + 1..upper_index {
32+
assert!(lower_value <= a[i]);
33+
assert!(a[i] <= upper_value);
34+
}
35+
assert_eq!(a[upper_index], upper_value);
36+
for i in (upper_index + 1)..n {
37+
assert!(upper_value <= a[i]);
3138
}
3239
}
3340
}

0 commit comments

Comments
 (0)