diff --git a/src/lib.rs b/src/lib.rs
index 4ae1100..5c1ff08 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,6 +5,7 @@
 //! - [order statistics] (minimum, maximum, median, quantiles, etc.);
 //! - [summary statistics] (mean, skewness, kurtosis, central moments, etc.)
 //! - [partitioning];
+//! - [sorting](SortExt);
 //! - [correlation analysis] (covariance, pearson correlation);
 //! - [measures from information theory] (entropy, KL divergence, etc.);
 //! - [measures of deviation] (count equal, L1, L2 distances, mean squared err etc.)
@@ -35,7 +36,7 @@ pub use crate::entropy::EntropyExt;
 pub use crate::histogram::HistogramExt;
 pub use crate::maybe_nan::{MaybeNan, MaybeNanExt};
 pub use crate::quantile::{interpolate, Quantile1dExt, QuantileExt};
-pub use crate::sort::Sort1dExt;
+pub use crate::sort::{Sort1dExt, SortExt};
 pub use crate::summary_statistics::SummaryStatisticsExt;
 
 #[cfg(test)]
diff --git a/src/sort.rs b/src/sort.rs
index f43a95b..8ca423d 100644
--- a/src/sort.rs
+++ b/src/sort.rs
@@ -3,6 +3,7 @@ use ndarray::prelude::*;
 use ndarray::{Data, DataMut, Slice};
 use rand::prelude::*;
 use rand::thread_rng;
+use std::cmp::Ordering;
 
 /// Methods for sorting and partitioning 1-D arrays.
 pub trait Sort1dExt<A, S>
@@ -296,3 +297,122 @@ fn _get_many_from_sorted_mut_unchecked(
         bigger_values,
     );
 }
+
+/// Methods for sorting n-D arrays.
+pub trait SortExt<A, S, D>
+where
+    S: Data<Elem = A>,
+    D: Dimension,
+{
+    /// Returns the indices which would sort the array along the specified
+    /// axis.
+    ///
+    /// Each lane of the array is sorted separately. The sort is stable.
+    ///
+    /// **Panics** if `axis` is out-of-bounds.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// use ndarray::{array, Axis};
+    /// use ndarray_stats::SortExt;
+    ///
+    /// let a = array![
+    ///     [0, 5, 2],
+    ///     [3, 3, 1],
+    /// ];
+    /// assert_eq!(
+    ///     a.argsort_axis(Axis(0)),
+    ///     array![
+    ///         [0, 1, 1],
+    ///         [1, 0, 0],
+    ///     ],
+    /// );
+    /// assert_eq!(
+    ///     a.argsort_axis(Axis(1)),
+    ///     array![
+    ///         [0, 2, 1],
+    ///         [2, 0, 1],
+    ///     ],
+    /// );
+    /// ```
+    fn argsort_axis(&self, axis: Axis) -> Array<usize, D>
+    where
+        A: Ord;
+
+    /// Returns the indices which would sort the array along the specified
+    /// axis, using the specified comparator function.
+    ///
+    /// Each lane of the array is sorted separately. The sort is stable.
+    ///
+    /// **Panics** if `axis` is out-of-bounds.
+    fn argsort_axis_by<F>(&self, axis: Axis, compare: F) -> Array<usize, D>
+    where
+        F: FnMut(&A, &A) -> Ordering;
+
+    /// Returns the indices which would sort the array along the specified
+    /// axis, using the specified key function.
+    ///
+    /// Each lane of the array is sorted separately. The sort is stable.
+    ///
+    /// **Panics** if `axis` is out-of-bounds.
+    fn argsort_axis_by_key<K, F>(&self, axis: Axis, f: F) -> Array<usize, D>
+    where
+        F: FnMut(&A) -> K,
+        K: Ord;
+}
+
+impl<A, S, D> SortExt<A, S, D> for ArrayBase<S, D>
+where
+    S: Data<Elem = A>,
+    D: Dimension,
+{
+    fn argsort_axis(&self, axis: Axis) -> Array<usize, D>
+    where
+        A: Ord,
+    {
+        self.view().argsort_axis_by(axis, Ord::cmp)
+    }
+
+    fn argsort_axis_by<F>(&self, axis: Axis, mut compare: F) -> Array<usize, D>
+    where
+        F: FnMut(&A, &A) -> Ordering,
+    {
+        // Construct an array with the same shape as `self` and with strides such
+        // that `axis` is contiguous.
+        let mut indices: Array<usize, D> = {
+            let ax = axis.index();
+            let ndim = self.ndim();
+            let mut shape = self.raw_dim();
+            shape.as_array_view_mut().swap(ax, ndim - 1);
+            let mut tmp = Array::zeros(shape);
+            tmp.swap_axes(ax, ndim - 1);
+            tmp
+        };
+        debug_assert_eq!(self.shape(), indices.shape());
+        // Empty arrays get all-zero strides from `ndarray`, so the stride check
+        // only applies when there is at least one lane to sort.
+        debug_assert!(
+            indices.is_empty() || indices.len_of(axis) <= 1 || indices.stride_of(axis) == 1
+        );
+
+        // Fill in the array with the sorted indices.
+        azip!((input_lane in self.lanes(axis), mut indices_lane in indices.lanes_mut(axis)) {
+            // This unwrap won't panic because we constructed `indices` so that
+            // `axis` is contiguous.
+            let indices_lane = indices_lane.as_slice_mut().unwrap();
+            indices_lane.iter_mut().enumerate().for_each(|(i, x)| *x = i);
+            indices_lane.sort_by(|&i, &j| compare(&input_lane[i], &input_lane[j]));
+        });
+
+        indices
+    }
+
+    fn argsort_axis_by_key<K, F>(&self, axis: Axis, mut f: F) -> Array<usize, D>
+    where
+        F: FnMut(&A) -> K,
+        K: Ord,
+    {
+        self.view().argsort_axis_by(axis, |a, b| f(a).cmp(&f(b)))
+    }
+}
diff --git a/tests/sort.rs b/tests/sort.rs
index b2bd12f..0aa7c43 100644
--- a/tests/sort.rs
+++ b/tests/sort.rs
@@ -1,5 +1,7 @@
 use ndarray::prelude::*;
-use ndarray_stats::Sort1dExt;
+use ndarray_rand::rand_distr::Uniform;
+use ndarray_rand::RandomExt;
+use ndarray_stats::{Sort1dExt, SortExt};
 use quickcheck_macros::quickcheck;
 
 #[test]
@@ -84,3 +86,103 @@ fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec) -> bool {
         xs == sorted_v
     }
 }
+
+#[test]
+fn argsort_1d() {
+    let a = array![5, 2, 0, 7, 3, 2, 8, 9];
+    let correct = array![2, 1, 5, 4, 0, 3, 6, 7];
+    assert_eq!(a.argsort_axis(Axis(0)), &correct);
+    assert_eq!(a.argsort_axis_by(Axis(0), |a, b| a.cmp(b)), &correct);
+    assert_eq!(a.argsort_axis_by_key(Axis(0), |&x| x), &correct);
+}
+
+#[test]
+fn argsort_2d() {
+    let a = array![[3, 5, 1, 2], [2, 0, 1, 3], [9, 4, 6, 1]];
+    for (axis, correct) in [
+        (Axis(0), array![[1, 1, 0, 2], [0, 2, 1, 0], [2, 0, 2, 1]]),
+        (Axis(1), array![[2, 3, 0, 1], [1, 2, 0, 3], [3, 1, 2, 0]]),
+    ] {
+        assert_eq!(a.argsort_axis(axis), &correct);
+        assert_eq!(a.argsort_axis_by(axis, |a, b| a.cmp(b)), &correct);
+        assert_eq!(a.argsort_axis_by_key(axis, |&x| x), &correct);
+    }
+}
+
+#[test]
+fn argsort_3d() {
+    let a = array![
+        [[3, 5, 1, 2], [9, 7, 6, 8]],
+        [[2, 0, 1, 3], [1, 2, 3, 4]],
+        [[9, 4, 6, 1], [8, 5, 3, 2]],
+    ];
+    for (axis, correct) in [
+        (
+            Axis(0),
+            array![
+                [[1, 1, 0, 2], [1, 1, 1, 2]],
+                [[0, 2, 1, 0], [2, 2, 2, 1]],
+                [[2, 0, 2, 1], [0, 0, 0, 0]],
+            ],
+        ),
+        (
+            Axis(1),
+            array![
+                [[0, 0, 0, 0], [1, 1, 1, 1]],
+                [[1, 0, 0, 0], [0, 1, 1, 1]],
+                [[1, 0, 1, 0], [0, 1, 0, 1]],
+            ],
+        ),
+        (
+            Axis(2),
+            array![
+                [[2, 3, 0, 1], [2, 1, 3, 0]],
+                [[1, 2, 0, 3], [0, 1, 2, 3]],
+                [[3, 1, 2, 0], [3, 2, 1, 0]],
+            ],
+        ),
+    ] {
+        assert_eq!(a.argsort_axis(axis), &correct);
+        assert_eq!(a.argsort_axis_by(axis, |a, b| a.cmp(b)), &correct);
+        assert_eq!(a.argsort_axis_by_key(axis, |&x| x), &correct);
+    }
+}
+
+#[test]
+fn argsort_len_0_or_1_axis() {
+    fn test_shape(base_shape: impl ndarray::IntoDimension) {
+        let base_shape = base_shape.into_dimension();
+        for ax in 0..base_shape.ndim() {
+            for axis_len in [0, 1] {
+                let mut shape = base_shape.clone();
+                shape[ax] = axis_len;
+                let a = Array::random(shape.clone(), Uniform::new(0, 100));
+                let axis = Axis(ax);
+                let correct = Array::zeros(shape);
+                assert_eq!(a.argsort_axis(axis), &correct);
+                assert_eq!(a.argsort_axis_by(axis, |a, b| a.cmp(b)), &correct);
+                assert_eq!(a.argsort_axis_by_key(axis, |&x| x), &correct);
+            }
+        }
+    }
+    test_shape([1]);
+    test_shape([2, 3]);
+    test_shape([3, 2, 4]);
+    test_shape([2, 4, 3, 2]);
+}
+
+#[test]
+fn argsort_empty_array_nonempty_axis() {
+    // One axis has length 0, so there are no lanes to sort, but the sort axis
+    // itself has length > 1.
+    let a = Array2::<i32>::zeros((0, 3));
+    let correct = Array2::<usize>::zeros((0, 3));
+    assert_eq!(a.argsort_axis(Axis(1)), &correct);
+    assert_eq!(a.argsort_axis_by(Axis(1), |a, b| a.cmp(b)), &correct);
+    assert_eq!(a.argsort_axis_by_key(Axis(1), |&x| x), &correct);
+
+    let a = Array3::<i32>::zeros((4, 0, 3));
+    let correct = Array3::<usize>::zeros((4, 0, 3));
+    assert_eq!(a.argsort_axis(Axis(0)), &correct);
+    assert_eq!(a.argsort_axis(Axis(2)), &correct);
+}