Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add argsort methods #84

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! - [order statistics] (minimum, maximum, median, quantiles, etc.);
//! - [summary statistics] (mean, skewness, kurtosis, central moments, etc.)
//! - [partitioning];
//! - [sorting](SortExt);
//! - [correlation analysis] (covariance, pearson correlation);
//! - [measures from information theory] (entropy, KL divergence, etc.);
//! - [measures of deviation] (count equal, L1, L2 distances, mean squared error, etc.)
Expand Down Expand Up @@ -35,7 +36,7 @@ pub use crate::entropy::EntropyExt;
pub use crate::histogram::HistogramExt;
pub use crate::maybe_nan::{MaybeNan, MaybeNanExt};
pub use crate::quantile::{interpolate, Quantile1dExt, QuantileExt};
pub use crate::sort::Sort1dExt;
pub use crate::sort::{Sort1dExt, SortExt};
pub use crate::summary_statistics::SummaryStatisticsExt;

#[cfg(test)]
Expand Down
116 changes: 116 additions & 0 deletions src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use ndarray::prelude::*;
use ndarray::{Data, DataMut, Slice};
use rand::prelude::*;
use rand::thread_rng;
use std::cmp::Ordering;

/// Methods for sorting and partitioning 1-D arrays.
pub trait Sort1dExt<A, S>
Expand Down Expand Up @@ -296,3 +297,118 @@ fn _get_many_from_sorted_mut_unchecked<A>(
bigger_values,
);
}

/// Extension trait adding argsort-style methods to n-dimensional arrays.
///
/// Every method works lane by lane along a caller-chosen axis and yields an
/// array of `usize` positions rather than reordering the data itself.
pub trait SortExt<A, S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Computes, for each lane along `axis`, the index permutation that
    /// arranges that lane in ascending order.
    ///
    /// Lanes are handled independently of one another. The underlying sort
    /// is stable, so equal elements keep their original relative order.
    ///
    /// **Panics** if `axis` is out-of-bounds.
    ///
    /// # Example
    ///
    /// ```
    /// use ndarray::{array, Axis};
    /// use ndarray_stats::SortExt;
    ///
    /// let a = array![
    ///     [0, 5, 2],
    ///     [3, 3, 1],
    /// ];
    /// assert_eq!(
    ///     a.argsort_axis(Axis(0)),
    ///     array![
    ///         [0, 1, 1],
    ///         [1, 0, 0],
    ///     ],
    /// );
    /// assert_eq!(
    ///     a.argsort_axis(Axis(1)),
    ///     array![
    ///         [0, 2, 1],
    ///         [2, 0, 1],
    ///     ],
    /// );
    /// ```
    fn argsort_axis(&self, axis: Axis) -> Array<usize, D>
    where
        A: Ord;

    /// Computes, for each lane along `axis`, the index permutation that
    /// orders that lane according to the comparator `compare`.
    ///
    /// Lanes are handled independently of one another. The underlying sort
    /// is stable, so elements that `compare` reports as equal keep their
    /// original relative order.
    ///
    /// **Panics** if `axis` is out-of-bounds.
    ///
    /// # Example
    ///
    /// ```
    /// use ndarray::{array, Axis};
    /// use ndarray_stats::SortExt;
    ///
    /// let a = array![
    ///     [0, 5, 2],
    ///     [3, 3, 1],
    /// ];
    /// // Descending order within each row.
    /// assert_eq!(
    ///     a.argsort_axis_by(Axis(1), |x, y| y.cmp(x)),
    ///     array![
    ///         [1, 2, 0],
    ///         [0, 1, 2],
    ///     ],
    /// );
    /// ```
    fn argsort_axis_by<F>(&self, axis: Axis, compare: F) -> Array<usize, D>
    where
        F: FnMut(&A, &A) -> Ordering;

    /// Computes, for each lane along `axis`, the index permutation that
    /// orders that lane by the keys produced by `f`.
    ///
    /// Lanes are handled independently of one another. The underlying sort
    /// is stable, so elements with equal keys keep their original relative
    /// order.
    ///
    /// **Panics** if `axis` is out-of-bounds.
    ///
    /// # Example
    ///
    /// ```
    /// use ndarray::{array, Axis};
    /// use ndarray_stats::SortExt;
    /// use std::cmp::Reverse;
    ///
    /// let a = array![
    ///     [0, 5, 2],
    ///     [3, 3, 1],
    /// ];
    /// // Descending order within each row, expressed through the key.
    /// assert_eq!(
    ///     a.argsort_axis_by_key(Axis(1), |&x| Reverse(x)),
    ///     array![
    ///         [1, 2, 0],
    ///         [0, 1, 2],
    ///     ],
    /// );
    /// ```
    fn argsort_axis_by_key<K, F>(&self, axis: Axis, f: F) -> Array<usize, D>
    where
        K: Ord,
        F: FnMut(&A) -> K;
}

// Provides `SortExt` for any array type whose storage can be read. The
// `argsort_axis` and `argsort_axis_by_key` methods are thin adapters that
// forward to `argsort_axis_by` on a borrowed view of `self`.
impl<A, S, D> SortExt<A, S, D> for ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Argsorts each lane along `axis` using the natural ordering of `A`.
    ///
    /// **Panics** if `axis` is out-of-bounds.
    fn argsort_axis(&self, axis: Axis) -> Array<usize, D>
    where
        A: Ord,
    {
        self.view().argsort_axis_by(axis, Ord::cmp)
    }

    /// Argsorts each lane along `axis` using `compare` to order elements.
    ///
    /// Returns an array with the same shape as `self` in which every lane
    /// along `axis` holds the positions `0..len` of that lane, arranged so
    /// that visiting the input lane in that order is sorted. The sort is
    /// stable.
    ///
    /// **Panics** if `axis` is out-of-bounds.
    fn argsort_axis_by<F>(&self, axis: Axis, mut compare: F) -> Array<usize, D>
    where
        F: FnMut(&A, &A) -> Ordering,
    {
        let ax = axis.index();
        let ndim = self.ndim();
        // Fail up front with a clear message. Without this check a 0-D array
        // would reach `ndim - 1` below and panic on integer underflow (or on
        // an unrelated index error in release builds).
        assert!(
            ax < ndim,
            "axis {} is out of bounds for array of dimension {}",
            ax,
            ndim,
        );
        let last = ndim - 1;

        // Construct an array with the same shape as `self` and with strides
        // such that `axis` is contiguous: allocate with `axis` swapped into
        // the last position, then swap the axes back.
        let mut indices: Array<usize, D> = {
            let mut shape = self.raw_dim();
            shape.as_array_view_mut().swap(ax, last);
            let mut tmp = Array::zeros(shape);
            tmp.swap_axes(ax, last);
            tmp
        };
        debug_assert_eq!(self.shape(), indices.shape());
        // The unit-stride expectation only matters when there is at least one
        // lane with two or more elements. An array that is empty because some
        // *other* axis has length zero has no lanes to visit, and its strides
        // need not be the usual contiguous ones, so it is exempted here.
        debug_assert!(
            indices.is_empty() || indices.len_of(axis) <= 1 || indices.stride_of(axis) == 1
        );

        // Fill in the array with the sorted indices, one lane at a time.
        azip!((input_lane in self.lanes(axis), mut indices_lane in indices.lanes_mut(axis)) {
            // `indices` was constructed so that `axis` is contiguous, hence
            // each lane can be borrowed as a plain slice.
            let lane = indices_lane
                .as_slice_mut()
                .expect("lanes of `indices` along `axis` are contiguous by construction");
            // Start from the identity permutation 0, 1, 2, ...
            for (position, slot) in lane.iter_mut().enumerate() {
                *slot = position;
            }
            // `sort_by` is stable, so ties keep their original order.
            lane.sort_by(|&i, &j| compare(&input_lane[i], &input_lane[j]));
        });

        indices
    }

    /// Argsorts each lane along `axis` by the keys that `f` extracts.
    ///
    /// `f` is invoked on both operands of every comparison, so it may run
    /// more than once per element.
    ///
    /// **Panics** if `axis` is out-of-bounds.
    fn argsort_axis_by_key<K, F>(&self, axis: Axis, mut f: F) -> Array<usize, D>
    where
        F: FnMut(&A) -> K,
        K: Ord,
    {
        self.view().argsort_axis_by(axis, |a, b| f(a).cmp(&f(b)))
    }
}
88 changes: 87 additions & 1 deletion tests/sort.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use ndarray::prelude::*;
use ndarray_stats::Sort1dExt;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use ndarray_stats::{Sort1dExt, SortExt};
use quickcheck_macros::quickcheck;

#[test]
Expand Down Expand Up @@ -84,3 +86,87 @@ fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec<i64>) -> bool {
xs == sorted_v
}
}

/// All three argsort entry points agree on a 1-D input that contains a
/// repeated value (the two `2`s must come out in their original order).
#[test]
fn argsort_1d() {
    let input = array![5, 2, 0, 7, 3, 2, 8, 9];
    let expected = array![2, 1, 5, 4, 0, 3, 6, 7];
    let axis = Axis(0);

    let by_ord = input.argsort_axis(axis);
    assert_eq!(by_ord, &expected);

    let by_cmp = input.argsort_axis_by(axis, |lhs, rhs| lhs.cmp(rhs));
    assert_eq!(by_cmp, &expected);

    let by_key = input.argsort_axis_by_key(axis, |&value| value);
    assert_eq!(by_key, &expected);
}

/// Checks every axis of a 2-D input against hand-computed index arrays,
/// through each of the three argsort entry points.
#[test]
fn argsort_2d() {
    let input = array![[3, 5, 1, 2], [2, 0, 1, 3], [9, 4, 6, 1]];
    let cases = [
        // Sorting down the columns.
        (Axis(0), array![[1, 1, 0, 2], [0, 2, 1, 0], [2, 0, 2, 1]]),
        // Sorting along the rows.
        (Axis(1), array![[2, 3, 0, 1], [1, 2, 0, 3], [3, 1, 2, 0]]),
    ];
    for (axis, expected) in cases {
        assert_eq!(input.argsort_axis(axis), &expected);
        assert_eq!(
            input.argsort_axis_by(axis, |lhs, rhs| lhs.cmp(rhs)),
            &expected
        );
        assert_eq!(input.argsort_axis_by_key(axis, |&value| value), &expected);
    }
}

/// Checks every axis of a 3-D input against hand-computed index arrays,
/// through each of the three argsort entry points.
#[test]
fn argsort_3d() {
    let input = array![
        [[3, 5, 1, 2], [9, 7, 6, 8]],
        [[2, 0, 1, 3], [1, 2, 3, 4]],
        [[9, 4, 6, 1], [8, 5, 3, 2]],
    ];

    let expected_axis0 = array![
        [[1, 1, 0, 2], [1, 1, 1, 2]],
        [[0, 2, 1, 0], [2, 2, 2, 1]],
        [[2, 0, 2, 1], [0, 0, 0, 0]],
    ];
    let expected_axis1 = array![
        [[0, 0, 0, 0], [1, 1, 1, 1]],
        [[1, 0, 0, 0], [0, 1, 1, 1]],
        [[1, 0, 1, 0], [0, 1, 0, 1]],
    ];
    let expected_axis2 = array![
        [[2, 3, 0, 1], [2, 1, 3, 0]],
        [[1, 2, 0, 3], [0, 1, 2, 3]],
        [[3, 1, 2, 0], [3, 2, 1, 0]],
    ];

    let cases = [
        (Axis(0), expected_axis0),
        (Axis(1), expected_axis1),
        (Axis(2), expected_axis2),
    ];
    for (axis, expected) in cases {
        assert_eq!(input.argsort_axis(axis), &expected);
        assert_eq!(
            input.argsort_axis_by(axis, |lhs, rhs| lhs.cmp(rhs)),
            &expected
        );
        assert_eq!(input.argsort_axis_by_key(axis, |&value| value), &expected);
    }
}

/// When the sorted axis has length 0 or 1 the only possible index is `0`,
/// so the result must be an all-zero array of the input's shape.
#[test]
fn argsort_len_0_or_1_axis() {
    // For each axis of `base_shape`, shrinks that axis to length 0 and then
    // to length 1, fills an array of that shape with random values, and
    // checks all three argsort entry points along the shrunken axis.
    fn test_shape(base_shape: impl ndarray::IntoDimension) {
        let template = base_shape.into_dimension();
        for ax in 0..template.ndim() {
            for axis_len in [0, 1] {
                let mut shape = template.clone();
                shape[ax] = axis_len;
                let input = Array::random(shape.clone(), Uniform::new(0, 100));
                let axis = Axis(ax);
                let expected = Array::zeros(shape);
                assert_eq!(input.argsort_axis(axis), &expected);
                assert_eq!(
                    input.argsort_axis_by(axis, |lhs, rhs| lhs.cmp(rhs)),
                    &expected
                );
                assert_eq!(input.argsort_axis_by_key(axis, |&value| value), &expected);
            }
        }
    }

    // One representative shape per dimensionality from 1-D to 4-D.
    test_shape([1]);
    test_shape([2, 3]);
    test_shape([3, 2, 4]);
    test_shape([2, 4, 3, 2]);
}