diff --git a/src/lib.rs b/src/lib.rs
index 4ae1100..5c1ff08 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,6 +5,7 @@
//! - [order statistics] (minimum, maximum, median, quantiles, etc.);
//! - [summary statistics] (mean, skewness, kurtosis, central moments, etc.)
//! - [partitioning];
+//! - [sorting](SortExt);
//! - [correlation analysis] (covariance, pearson correlation);
//! - [measures from information theory] (entropy, KL divergence, etc.);
//! - [measures of deviation] (count equal, L1, L2 distances, mean squared err etc.)
@@ -35,7 +36,7 @@ pub use crate::entropy::EntropyExt;
pub use crate::histogram::HistogramExt;
pub use crate::maybe_nan::{MaybeNan, MaybeNanExt};
pub use crate::quantile::{interpolate, Quantile1dExt, QuantileExt};
-pub use crate::sort::Sort1dExt;
+pub use crate::sort::{Sort1dExt, SortExt};
pub use crate::summary_statistics::SummaryStatisticsExt;
#[cfg(test)]
diff --git a/src/sort.rs b/src/sort.rs
index f43a95b..8ca423d 100644
--- a/src/sort.rs
+++ b/src/sort.rs
@@ -3,6 +3,7 @@ use ndarray::prelude::*;
use ndarray::{Data, DataMut, Slice};
use rand::prelude::*;
use rand::thread_rng;
+use std::cmp::Ordering;
/// Methods for sorting and partitioning 1-D arrays.
pub trait Sort1dExt
@@ -296,3 +297,118 @@ fn _get_many_from_sorted_mut_unchecked(
bigger_values,
);
}
+
+/// Methods for sorting n-D arrays.
+pub trait SortExt
+where
+ S: Data,
+ D: Dimension,
+{
+ /// Returns the indices which would sort the array along the specified
+ /// axis.
+ ///
+ /// Each lane of the array is sorted separately. The sort is stable.
+ ///
+ /// **Panics** if `axis` is out-of-bounds.
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use ndarray::{array, Axis};
+ /// use ndarray_stats::SortExt;
+ ///
+ /// let a = array![
+ /// [0, 5, 2],
+ /// [3, 3, 1],
+ /// ];
+ /// assert_eq!(
+ /// a.argsort_axis(Axis(0)),
+ /// array![
+ /// [0, 1, 1],
+ /// [1, 0, 0],
+ /// ],
+ /// );
+ /// assert_eq!(
+ /// a.argsort_axis(Axis(1)),
+ /// array![
+ /// [0, 2, 1],
+ /// [2, 0, 1],
+ /// ],
+ /// );
+ /// ```
+ fn argsort_axis(&self, axis: Axis) -> Array
+ where
+ A: Ord;
+
+ /// Returns the indices which would sort the array along the specified
+ /// axis, using the specified comparator function.
+ ///
+ /// Each lane of the array is sorted separately. The sort is stable.
+ ///
+ /// **Panics** if `axis` is out-of-bounds.
+ fn argsort_axis_by(&self, axis: Axis, compare: F) -> Array
+ where
+ F: FnMut(&A, &A) -> Ordering;
+
+ /// Returns the indices which would sort the array along the specified
+ /// axis, using the specified key function.
+ ///
+ /// Each lane of the array is sorted separately. The sort is stable.
+ ///
+ /// **Panics** if `axis` is out-of-bounds.
+ fn argsort_axis_by_key(&self, axis: Axis, f: F) -> Array
+ where
+ F: FnMut(&A) -> K,
+ K: Ord;
+}
+
+impl SortExt for ArrayBase
+where
+ S: Data,
+ D: Dimension,
+{
+ fn argsort_axis(&self, axis: Axis) -> Array
+ where
+ A: Ord,
+ {
+ self.view().argsort_axis_by(axis, Ord::cmp)
+ }
+
+ fn argsort_axis_by(&self, axis: Axis, mut compare: F) -> Array
+ where
+ F: FnMut(&A, &A) -> Ordering,
+ {
+ // Construct an array with the same shape as `self` and with strides such
+ // that `axis` is contiguous.
+ let mut indices: Array = {
+ let ax = axis.index();
+ let ndim = self.ndim();
+ let mut shape = self.raw_dim();
+ shape.as_array_view_mut().swap(ax, ndim - 1);
+ let mut tmp = Array::zeros(shape);
+ tmp.swap_axes(ax, ndim - 1);
+ tmp
+ };
+ debug_assert_eq!(self.shape(), indices.shape());
+ debug_assert!(indices.len_of(axis) <= 1 || indices.stride_of(axis) == 1);
+
+ // Fill in the array with the sorted indices.
+ azip!((input_lane in self.lanes(axis), mut indices_lane in indices.lanes_mut(axis)) {
+ // This unwrap won't panic because we constructed `indices` so that
+ // `axis` is contiguous.
+ let indices_lane = indices_lane.as_slice_mut().unwrap();
+ indices_lane.iter_mut().enumerate().for_each(|(i, x)| *x = i);
+ indices_lane.sort_by(|&i, &j| compare(&input_lane[i], &input_lane[j]));
+ });
+
+ indices
+ }
+
+ fn argsort_axis_by_key(&self, axis: Axis, mut f: F) -> Array
+ where
+ F: FnMut(&A) -> K,
+ K: Ord,
+ {
+ self.view().argsort_axis_by(axis, |a, b| f(a).cmp(&f(b)))
+ }
+}
diff --git a/tests/sort.rs b/tests/sort.rs
index b2bd12f..0aa7c43 100644
--- a/tests/sort.rs
+++ b/tests/sort.rs
@@ -1,5 +1,7 @@
use ndarray::prelude::*;
-use ndarray_stats::Sort1dExt;
+use ndarray_rand::rand_distr::Uniform;
+use ndarray_rand::RandomExt;
+use ndarray_stats::{Sort1dExt, SortExt};
use quickcheck_macros::quickcheck;
#[test]
@@ -84,3 +86,87 @@ fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec) -> bool {
xs == sorted_v
}
}
+
+#[test]
+fn argsort_1d() {
+ let a = array![5, 2, 0, 7, 3, 2, 8, 9];
+ let correct = array![2, 1, 5, 4, 0, 3, 6, 7];
+ assert_eq!(a.argsort_axis(Axis(0)), &correct);
+ assert_eq!(a.argsort_axis_by(Axis(0), |a, b| a.cmp(b)), &correct);
+ assert_eq!(a.argsort_axis_by_key(Axis(0), |&x| x), &correct);
+}
+
+#[test]
+fn argsort_2d() {
+ let a = array![[3, 5, 1, 2], [2, 0, 1, 3], [9, 4, 6, 1]];
+ for (axis, correct) in [
+ (Axis(0), array![[1, 1, 0, 2], [0, 2, 1, 0], [2, 0, 2, 1]]),
+ (Axis(1), array![[2, 3, 0, 1], [1, 2, 0, 3], [3, 1, 2, 0]]),
+ ] {
+ assert_eq!(a.argsort_axis(axis), &correct);
+ assert_eq!(a.argsort_axis_by(axis, |a, b| a.cmp(b)), &correct);
+ assert_eq!(a.argsort_axis_by_key(axis, |&x| x), &correct);
+ }
+}
+
+#[test]
+fn argsort_3d() {
+ let a = array![
+ [[3, 5, 1, 2], [9, 7, 6, 8]],
+ [[2, 0, 1, 3], [1, 2, 3, 4]],
+ [[9, 4, 6, 1], [8, 5, 3, 2]],
+ ];
+ for (axis, correct) in [
+ (
+ Axis(0),
+ array![
+ [[1, 1, 0, 2], [1, 1, 1, 2]],
+ [[0, 2, 1, 0], [2, 2, 2, 1]],
+ [[2, 0, 2, 1], [0, 0, 0, 0]],
+ ],
+ ),
+ (
+ Axis(1),
+ array![
+ [[0, 0, 0, 0], [1, 1, 1, 1]],
+ [[1, 0, 0, 0], [0, 1, 1, 1]],
+ [[1, 0, 1, 0], [0, 1, 0, 1]],
+ ],
+ ),
+ (
+ Axis(2),
+ array![
+ [[2, 3, 0, 1], [2, 1, 3, 0]],
+ [[1, 2, 0, 3], [0, 1, 2, 3]],
+ [[3, 1, 2, 0], [3, 2, 1, 0]],
+ ],
+ ),
+ ] {
+ assert_eq!(a.argsort_axis(axis), &correct);
+ assert_eq!(a.argsort_axis_by(axis, |a, b| a.cmp(b)), &correct);
+ assert_eq!(a.argsort_axis_by_key(axis, |&x| x), &correct);
+ }
+}
+
+#[test]
+fn argsort_len_0_or_1_axis() {
+ fn test_shape(base_shape: impl ndarray::IntoDimension) {
+ let base_shape = base_shape.into_dimension();
+ for ax in 0..base_shape.ndim() {
+ for axis_len in [0, 1] {
+ let mut shape = base_shape.clone();
+ shape[ax] = axis_len;
+ let a = Array::random(shape.clone(), Uniform::new(0, 100));
+ let axis = Axis(ax);
+ let correct = Array::zeros(shape);
+ assert_eq!(a.argsort_axis(axis), &correct);
+ assert_eq!(a.argsort_axis_by(axis, |a, b| a.cmp(b)), &correct);
+ assert_eq!(a.argsort_axis_by_key(axis, |&x| x), &correct);
+ }
+ }
+ }
+ test_shape([1]);
+ test_shape([2, 3]);
+ test_shape([3, 2, 4]);
+ test_shape([2, 4, 3, 2]);
+}