1
1
use indexmap:: IndexMap ;
2
2
use ndarray:: prelude:: * ;
3
3
use ndarray:: { Data , DataMut , Slice } ;
4
- use rand:: prelude:: * ;
5
- use rand:: thread_rng;
6
4
7
5
/// Methods for sorting and partitioning 1-D arrays.
8
6
pub trait Sort1dExt < A , S >
@@ -50,26 +48,21 @@ where
50
48
S : DataMut ,
51
49
S2 : Data < Elem = usize > ;
52
50
53
- /// Partitions the array in increasing order based on the value initially
54
- /// located at `pivot_index` and returns the new index of the value.
51
+ /// Partitions the array in increasing order based on the values initially located at `0` and
52
+ /// `n - 1` where `n` is the number of elements in the array and returns the new indexes of the
53
+ /// values.
55
54
///
56
- /// The elements are rearranged in such a way that the value initially
57
- /// located at `pivot_index` is moved to the position it would be in an
58
- /// array sorted in increasing order. The return value is the new index of
59
- /// the value after rearrangement. All elements smaller than the value are
60
- /// moved to its left and all elements equal or greater than the value are
61
- /// moved to its right. The ordering of the elements in the two partitions
62
- /// is undefined.
55
+ /// The elements are rearranged in such a way that the values initially located at `0` and
56
+ /// `n - 1` are moved to the positions they would be in an array sorted in increasing order. The
57
+ /// return values are the new indexes of the values after rearrangement. All elements left of the
58
+ /// lower of the two values are less than it, all elements right of the upper one are equal to or
59
+ /// greater than it, and all elements in between are bounded inclusively by both values. The
+ /// ordering of the elements in the three partitions is undefined.
63
60
///
64
- /// `self` is shuffled **in place** to operate the desired partition:
65
- /// no copy of the array is allocated.
61
+ /// The array is shuffled **in place**; no copy of the array is allocated.
66
62
///
67
- /// The method uses Hoare's partition algorithm.
68
- /// Complexity: O(`n`), where `n` is the number of elements in the array.
69
- /// Average number of element swaps: n/6 - 1/3 (see
70
- /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
63
+ /// This method implements the partitioning scheme of [Yaroslavskiy-Bentley-Bloch Quicksort].
71
64
///
72
- /// **Panics** if `pivot_index` is greater than or equal to `n`.
65
+ /// [Yaroslavskiy-Bentley-Bloch Quicksort]: https://api.semanticscholar.org/CorpusID:51871084
73
66
///
74
67
/// # Example
75
68
///
@@ -78,23 +71,30 @@ where
78
71
/// use ndarray_stats::Sort1dExt;
79
72
///
80
73
/// let mut data = array![3, 1, 4, 5, 2];
81
- /// let pivot_index = 2;
82
- /// let pivot_value = data[pivot_index] ;
74
+ /// // Sorted pivot values.
75
+ /// let (lower_value, upper_value) = ( data[data.len() - 1], data[0]) ;
83
76
///
84
- /// // Partition by the value located at `pivot_index`.
85
- /// let new_index = data.partition_mut(pivot_index);
86
- /// // The pivot value is now located at `new_index`.
87
- /// assert_eq!(data[new_index], pivot_value);
88
- /// // Elements less than that value are moved to the left.
89
- /// for i in 0..new_index {
90
- /// assert!(data[i] < pivot_value);
77
+ /// // Partitions by the values located at `0` and `data.len() - 1`.
78
+ /// let (lower_index, upper_index) = data.partition_mut();
79
+ /// // The pivot values are now located at `lower_index` and `upper_index`.
80
+ /// assert_eq!(data[lower_index], lower_value);
81
+ /// assert_eq!(data[upper_index], upper_value);
82
+ /// // Elements lower than the lower pivot value are moved to its left.
83
+ /// for i in 0..lower_index {
84
+ /// assert!(data[i] < lower_value);
85
+ /// }
86
+ /// // Elements greater than or equal to the lower pivot value and less than or equal to the upper
87
+ /// // pivot value are moved between the two pivot indexes.
88
+ /// for i in lower_index + 1..upper_index {
89
+ /// assert!(lower_value <= data[i]);
90
+ /// assert!(data[i] <= upper_value);
91
91
/// }
92
- /// // Elements greater than or equal to that value are moved to the right.
93
- /// for i in (new_index + 1) ..data.len() {
94
- /// assert!(data[i] >= pivot_value );
92
+ /// // Elements greater than or equal to the upper pivot value are moved to its right.
93
+ /// for i in upper_index + 1..data.len() {
94
+ /// assert!(upper_value <= data[i]);
95
95
/// }
96
96
/// ```
97
- fn partition_mut ( & mut self , pivot_index : usize ) -> usize
97
+ fn partition_mut ( & mut self ) -> ( usize , usize )
98
98
where
99
99
A : Ord + Clone ,
100
100
S : DataMut ;
@@ -115,17 +115,20 @@ where
115
115
if n == 1 {
116
116
self [ 0 ] . clone ( )
117
117
} else {
118
- let mut rng = thread_rng ( ) ;
119
- let pivot_index = rng. gen_range ( 0 ..n) ;
120
- let partition_index = self . partition_mut ( pivot_index) ;
121
- if i < partition_index {
122
- self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..partition_index) )
118
+ let ( lower_index, upper_index) = self . partition_mut ( ) ;
119
+ if i < lower_index {
120
+ self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..lower_index) )
123
121
. get_from_sorted_mut ( i)
124
- } else if i == partition_index {
122
+ } else if i == lower_index {
123
+ self [ i] . clone ( )
124
+ } else if i < upper_index {
125
+ self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( lower_index + 1 ..upper_index) )
126
+ . get_from_sorted_mut ( i - ( lower_index + 1 ) )
127
+ } else if i == upper_index {
125
128
self [ i] . clone ( )
126
129
} else {
127
- self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( partition_index + 1 ..) )
128
- . get_from_sorted_mut ( i - ( partition_index + 1 ) )
130
+ self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( upper_index + 1 ..) )
131
+ . get_from_sorted_mut ( i - ( upper_index + 1 ) )
129
132
}
130
133
}
131
134
}
@@ -143,42 +146,51 @@ where
143
146
get_many_from_sorted_mut_unchecked ( self , & deduped_indexes)
144
147
}
145
148
146
- fn partition_mut ( & mut self , pivot_index : usize ) -> usize
149
+ fn partition_mut ( & mut self ) -> ( usize , usize )
147
150
where
148
151
A : Ord + Clone ,
149
152
S : DataMut ,
150
153
{
151
- let pivot_value = self [ pivot_index] . clone ( ) ;
152
- self . swap ( pivot_index, 0 ) ;
153
- let n = self . len ( ) ;
154
- let mut i = 1 ;
155
- let mut j = n - 1 ;
156
- loop {
157
- loop {
158
- if i > j {
159
- break ;
160
- }
161
- if self [ i] >= pivot_value {
162
- break ;
154
+ // Sort `lowermost` and `uppermost` elements and use them as dual pivot.
155
+ let lowermost = 0 ;
156
+ let uppermost = self . len ( ) - 1 ;
157
+ if self [ lowermost] > self [ uppermost] {
158
+ self . swap ( lowermost, uppermost) ;
159
+ }
160
+ // Increasing running and partition index starting after lower pivot.
161
+ let mut index = lowermost + 1 ;
162
+ let mut lower = lowermost + 1 ;
163
+ // Decreasing partition index starting before upper pivot.
164
+ let mut upper = uppermost - 1 ;
165
+ // Swap elements at `index` into their partitions.
166
+ while index <= upper {
167
+ if self [ index] < self [ lowermost] {
168
+ // Swap elements into lower partition.
169
+ self . swap ( index, lower) ;
170
+ lower += 1 ;
171
+ } else if self [ index] >= self [ uppermost] {
172
+ // Search first element of upper partition.
173
+ while self [ upper] > self [ uppermost] && index < upper {
174
+ upper -= 1 ;
163
175
}
164
- i += 1 ;
165
- }
166
- while pivot_value <= self [ j] {
167
- if j == 1 {
168
- break ;
176
+ // Swap elements into upper partition.
177
+ self . swap ( index, upper) ;
178
+ if self [ index] < self [ lowermost] {
179
+ // Swap swapped elements into lower partition.
180
+ self . swap ( index, lower) ;
181
+ lower += 1 ;
169
182
}
170
- j -= 1 ;
171
- }
172
- if i >= j {
173
- break ;
174
- } else {
175
- self . swap ( i, j) ;
176
- i += 1 ;
177
- j -= 1 ;
183
+ upper -= 1 ;
178
184
}
185
+ index += 1 ;
179
186
}
180
- self . swap ( 0 , i - 1 ) ;
181
- i - 1
187
+ lower -= 1 ;
188
+ upper += 1 ;
189
+ // Swap pivots to their new indexes.
190
+ self . swap ( lowermost, lower) ;
191
+ self . swap ( uppermost, upper) ;
192
+ // Lower and upper pivot index.
193
+ ( lower, upper)
182
194
}
183
195
184
196
private_impl ! { }
@@ -249,50 +261,72 @@ fn _get_many_from_sorted_mut_unchecked<A>(
249
261
return ;
250
262
}
251
263
252
- // We pick a random pivot index: the corresponding element is the pivot value
253
- let mut rng = thread_rng ( ) ;
254
- let pivot_index = rng. gen_range ( 0 ..n) ;
264
+ // We partition the array with respect to the two pivot values. The pivot values move to
265
+ // `lower_index` and `upper_index`.
266
+ //
267
+ // Elements strictly less than the lower pivot value have indexes < `lower_index`.
268
+ //
269
+ // Elements greater than or equal to the lower pivot value and less than or equal to the upper
270
+ // pivot value have indexes > `lower_index` and < `upper_index`.
271
+ //
272
+ // Elements greater than or equal to the upper pivot value have indexes > `upper_index`.
273
+ let ( lower_index, upper_index) = array. partition_mut ( ) ;
255
274
256
- // We partition the array with respect to the pivot value.
257
- // The pivot value moves to `array_partition_index`.
258
- // Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
259
- // Elements greater or equal to the pivot value have indexes > `array_partition_index`.
260
- let array_partition_index = array. partition_mut ( pivot_index) ;
275
+ // We use a divide-and-conquer strategy, splitting the indexes we are searching for (`indexes`)
276
+ // and the corresponding portions of the output slice (`values`) into partitions with respect to
277
+ // `lower_index` and `upper_index`.
278
+ let ( found_exact, split_index) = match indexes. binary_search ( & lower_index) {
279
+ Ok ( index) => ( true , index) ,
280
+ Err ( index) => ( false , index) ,
281
+ } ;
282
+ let ( lower_indexes, inner_indexes) = indexes. split_at_mut ( split_index) ;
283
+ let ( lower_values, inner_values) = values. split_at_mut ( split_index) ;
284
+ let ( upper_indexes, upper_values) = if found_exact {
285
+ inner_values[ 0 ] = array[ lower_index] . clone ( ) ; // Write exactly found value.
286
+ ( & mut inner_indexes[ 1 ..] , & mut inner_values[ 1 ..] )
287
+ } else {
288
+ ( inner_indexes, inner_values)
289
+ } ;
261
290
262
- // We use a divide-and-conquer strategy, splitting the indexes we are
263
- // searching for (`indexes`) and the corresponding portions of the output
264
- // slice (`values`) into pieces with respect to `array_partition_index`.
265
- let ( found_exact, index_split) = match indexes. binary_search ( & array_partition_index) {
291
+ let ( found_exact, split_index) = match upper_indexes. binary_search ( & upper_index) {
266
292
Ok ( index) => ( true , index) ,
267
293
Err ( index) => ( false , index) ,
268
294
} ;
269
- let ( smaller_indexes , other_indexes ) = indexes . split_at_mut ( index_split ) ;
270
- let ( smaller_values , other_values ) = values . split_at_mut ( index_split ) ;
271
- let ( bigger_indexes , bigger_values ) = if found_exact {
272
- other_values [ 0 ] = array[ array_partition_index ] . clone ( ) ; // Write exactly found value.
273
- ( & mut other_indexes [ 1 ..] , & mut other_values [ 1 ..] )
295
+ let ( inner_indexes , upper_indexes ) = upper_indexes . split_at_mut ( split_index ) ;
296
+ let ( inner_values , upper_values ) = upper_values . split_at_mut ( split_index ) ;
297
+ let ( upper_indexes , upper_values ) = if found_exact {
298
+ upper_values [ 0 ] = array[ upper_index ] . clone ( ) ; // Write exactly found value.
299
+ ( & mut upper_indexes [ 1 ..] , & mut upper_values [ 1 ..] )
274
300
} else {
275
- ( other_indexes , other_values )
301
+ ( upper_indexes , upper_values )
276
302
} ;
277
303
278
- // We search recursively for the values corresponding to strictly smaller
279
- // indexes to the left of `partition_index`.
304
+ // We search recursively for the values corresponding to indexes strictly less than
305
+ // `lower_index` in the lower partition.
306
+ _get_many_from_sorted_mut_unchecked (
307
+ array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..lower_index) ) ,
308
+ lower_indexes,
309
+ lower_values,
310
+ ) ;
311
+
312
+ // We search recursively for the values corresponding to indexes strictly between `lower_index`
313
+ // and `upper_index` in the inner partition, that is between the lower and upper partition. Since
314
+ // only the inner partition of the array is passed in, the indexes need to be shifted by the
315
+ // length of the lower partition plus one for the lower pivot.
316
+ inner_indexes. iter_mut ( ) . for_each ( |x| * x -= lower_index + 1 ) ;
280
317
_get_many_from_sorted_mut_unchecked (
281
- array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..array_partition_index ) ) ,
282
- smaller_indexes ,
283
- smaller_values ,
318
+ array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( lower_index + 1 ..upper_index ) ) ,
319
+ inner_indexes ,
320
+ inner_values ,
284
321
) ;
285
322
286
- // We search recursively for the values corresponding to strictly bigger
287
- // indexes to the right of `partition_index`. Since only the right portion
288
- // of the array is passed in, the indexes need to be shifted by length of
289
- // the removed portion.
290
- bigger_indexes
291
- . iter_mut ( )
292
- . for_each ( |x| * x -= array_partition_index + 1 ) ;
323
+ // We search recursively for the values corresponding to indexes strictly greater than
324
+ // `upper_index` in the upper partition. Since only the upper partition of the array is passed
325
+ // in, the indexes need to be shifted by the combined length of the lower and inner partition
+ // plus two for both pivots.
326
+ upper_indexes. iter_mut ( ) . for_each ( |x| * x -= upper_index + 1 ) ;
293
327
_get_many_from_sorted_mut_unchecked (
294
- array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( array_partition_index + 1 ..) ) ,
295
- bigger_indexes ,
296
- bigger_values ,
328
+ array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( upper_index + 1 ..) ) ,
329
+ upper_indexes ,
330
+ upper_values ,
297
331
) ;
298
332
}
0 commit comments