Skip to content

Commit 903cb23

Browse files
authored
Extend ndarray-rand to be able to randomly sample from ArrayRef (#1540)
Prior to ndarray 0.17, the RandomExt trait exposed by ndarray-rand contained methods for both creating new arrays randomly whole-cloth (random_using) and sampling from existing arrays (sample_axis_using). With the introduction of reference types in ndarray 0.17, users should be able to sample from ArrayRef instances as well. We choose to expose an additional extension trait, RandomRefExt, that provides this functionality. We keep the methods on the old trait for backwards compatibility, but collapse the implementation and documentation to the new trait to maintain a single source of truth.
1 parent 66dc0e1 commit 903cb23

File tree

1 file changed

+65
-12
lines changed

1 file changed

+65
-12
lines changed

ndarray-rand/src/lib.rs

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,14 @@
2929
//! that the items are not compatible (e.g. that a type doesn't implement a
3030
//! necessary trait).
3131
32+
#![warn(missing_docs)]
33+
3234
use crate::rand::distr::{Distribution, Uniform};
3335
use crate::rand::rngs::SmallRng;
3436
use crate::rand::seq::index;
3537
use crate::rand::{rng, Rng, SeedableRng};
3638

37-
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
39+
use ndarray::{Array, ArrayRef, Axis, RemoveAxis, ShapeBuilder};
3840
use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData};
3941
#[cfg(feature = "quickcheck")]
4042
use quickcheck::{Arbitrary, Gen};
@@ -51,18 +53,15 @@ pub mod rand_distr
5153
pub use rand_distr::*;
5254
}
5355

54-
/// Constructors for n-dimensional arrays with random elements.
55-
///
56-
/// This trait extends ndarray’s `ArrayBase` and can not be implemented
57-
/// for other types.
56+
/// Extension trait for constructing n-dimensional arrays with random elements.
5857
///
5958
/// The default RNG is a fast automatically seeded rng (currently
60-
/// [`rand::rngs::SmallRng`], seeded from [`rand::thread_rng`]).
59+
/// [`rand::rngs::SmallRng`], seeded from [`rand::rng`]).
6160
///
6261
/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
6362
/// low-quality random numbers, and reproducibility is not guaranteed. See its
6463
/// documentation for information. You can select a different RNG with
65-
/// [`.random_using()`](Self::random_using).
64+
/// [`.random_using()`](RandomExt::random_using).
6665
pub trait RandomExt<S, A, D>
6766
where
6867
S: RawData<Elem = A>,
@@ -124,6 +123,40 @@ where
124123
S: DataOwned<Elem = A>,
125124
Sh: ShapeBuilder<Dim = D>;
126125

126+
/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
127+
///
128+
/// See [`RandomRefExt::sample_axis`] for additional information.
129+
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
130+
where
131+
A: Copy,
132+
S: Data<Elem = A>,
133+
D: RemoveAxis;
134+
135+
/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
136+
///
137+
/// See [`RandomRefExt::sample_axis_using`] for additional information.
138+
fn sample_axis_using<R>(
139+
&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R,
140+
) -> Array<A, D>
141+
where
142+
R: Rng + ?Sized,
143+
A: Copy,
144+
S: Data<Elem = A>,
145+
D: RemoveAxis;
146+
}
147+
148+
/// Extension trait for sampling from [`ArrayRef`] with random elements.
149+
///
150+
/// The default RNG is a fast, automatically seeded rng (currently
151+
/// [`rand::rngs::SmallRng`], seeded from [`rand::rng`]).
152+
///
153+
/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
154+
/// low-quality random numbers, and reproducibility is not guaranteed. See its
155+
/// documentation for information. You can select a different RNG with
156+
/// [`.sample_axis_using()`](RandomRefExt::sample_axis_using).
157+
pub trait RandomRefExt<A, D>
158+
where D: Dimension
159+
{
127160
/// Sample `n_samples` lanes slicing along `axis` using the default RNG.
128161
///
129162
/// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
@@ -168,7 +201,6 @@ where
168201
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
169202
where
170203
A: Copy,
171-
S: Data<Elem = A>,
172204
D: RemoveAxis;
173205

174206
/// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
@@ -225,7 +257,6 @@ where
225257
where
226258
R: Rng + ?Sized,
227259
A: Copy,
228-
S: Data<Elem = A>,
229260
D: RemoveAxis;
230261
}
231262

@@ -259,7 +290,7 @@ where
259290
S: Data<Elem = A>,
260291
D: RemoveAxis,
261292
{
262-
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
293+
(**self).sample_axis(axis, n_samples, strategy)
263294
}
264295

265296
fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
@@ -268,6 +299,27 @@ where
268299
A: Copy,
269300
S: Data<Elem = A>,
270301
D: RemoveAxis,
302+
{
303+
(**self).sample_axis_using(axis, n_samples, strategy, rng)
304+
}
305+
}
306+
307+
impl<A, D> RandomRefExt<A, D> for ArrayRef<A, D>
308+
where D: Dimension
309+
{
310+
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
311+
where
312+
A: Copy,
313+
D: RemoveAxis,
314+
{
315+
self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
316+
}
317+
318+
fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
319+
where
320+
R: Rng + ?Sized,
321+
A: Copy,
322+
D: RemoveAxis,
271323
{
272324
let indices: Vec<_> = match strategy {
273325
SamplingStrategy::WithReplacement => {
@@ -284,9 +336,10 @@ where
284336
/// if lanes from the original array should only be sampled once (*without replacement*) or
285337
/// multiple times (*with replacement*).
286338
///
287-
/// [`sample_axis`]: RandomExt::sample_axis
288-
/// [`sample_axis_using`]: RandomExt::sample_axis_using
339+
/// [`sample_axis`]: RandomRefExt::sample_axis
340+
/// [`sample_axis_using`]: RandomRefExt::sample_axis_using
289341
#[derive(Debug, Clone)]
342+
#[allow(missing_docs)]
290343
pub enum SamplingStrategy
291344
{
292345
WithReplacement,

0 commit comments

Comments
 (0)