2929//! that the items are not compatible (e.g. that a type doesn't implement a
3030//! necessary trait).
3131
32+ #![ warn( missing_docs) ]
33+
3234use crate :: rand:: distr:: { Distribution , Uniform } ;
3335use crate :: rand:: rngs:: SmallRng ;
3436use crate :: rand:: seq:: index;
3537use crate :: rand:: { rng, Rng , SeedableRng } ;
3638
37- use ndarray:: { Array , Axis , RemoveAxis , ShapeBuilder } ;
39+ use ndarray:: { Array , ArrayRef , Axis , RemoveAxis , ShapeBuilder } ;
3840use ndarray:: { ArrayBase , Data , DataOwned , Dimension , RawData } ;
3941#[ cfg( feature = "quickcheck" ) ]
4042use quickcheck:: { Arbitrary , Gen } ;
@@ -51,18 +53,15 @@ pub mod rand_distr
5153 pub use rand_distr:: * ;
5254}
5355
54- /// Constructors for n-dimensional arrays with random elements.
55- ///
56- /// This trait extends ndarray’s `ArrayBase` and can not be implemented
57- /// for other types.
56+ /// Extension trait for constructing n-dimensional arrays with random elements.
5857///
5958/// The default RNG is a fast automatically seeded rng (currently
60- /// [`rand::rngs::SmallRng`], seeded from [`rand::thread_rng `]).
59+ /// [`rand::rngs::SmallRng`], seeded from [`rand::rng `]).
6160///
6261/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
6362/// low-quality random numbers, and reproducibility is not guaranteed. See its
6463/// documentation for information. You can select a different RNG with
65- /// [`.random_using()`](Self ::random_using).
64+ /// [`.random_using()`](RandomExt ::random_using).
6665pub trait RandomExt < S , A , D >
6766where
6867 S : RawData < Elem = A > ,
@@ -124,6 +123,40 @@ where
124123 S : DataOwned < Elem = A > ,
125124 Sh : ShapeBuilder < Dim = D > ;
126125
126+ /// Sample `n_samples` lanes slicing along `axis` using the default RNG.
127+ ///
128+ /// See [`RandomRefExt::sample_axis`] for additional information.
129+ fn sample_axis ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy ) -> Array < A , D >
130+ where
131+ A : Copy ,
132+ S : Data < Elem = A > ,
133+ D : RemoveAxis ;
134+
135+ /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
136+ ///
137+ /// See [`RandomRefExt::sample_axis_using`] for additional information.
138+ fn sample_axis_using < R > (
139+ & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy , rng : & mut R ,
140+ ) -> Array < A , D >
141+ where
142+ R : Rng + ?Sized ,
143+ A : Copy ,
144+ S : Data < Elem = A > ,
145+ D : RemoveAxis ;
146+ }
147+
148+ /// Extension trait for sampling from [`ArrayRef`] with random elements.
149+ ///
150+ /// The default RNG is a fast, automatically seeded rng (currently
151+ /// [`rand::rngs::SmallRng`], seeded from [`rand::rng`]).
152+ ///
153+ /// Note that `SmallRng` is cheap to initialize and fast, but it may generate
154+ /// low-quality random numbers, and reproducibility is not guaranteed. See its
155+ /// documentation for information. You can select a different RNG with
156+ /// [`.sample_axis_using()`](RandomRefExt::sample_axis_using).
157+ pub trait RandomRefExt < A , D >
158+ where D : Dimension
159+ {
127160 /// Sample `n_samples` lanes slicing along `axis` using the default RNG.
128161 ///
129162 /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
@@ -168,7 +201,6 @@ where
168201 fn sample_axis ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy ) -> Array < A , D >
169202 where
170203 A : Copy ,
171- S : Data < Elem = A > ,
172204 D : RemoveAxis ;
173205
174206 /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
@@ -225,7 +257,6 @@ where
225257 where
226258 R : Rng + ?Sized ,
227259 A : Copy ,
228- S : Data < Elem = A > ,
229260 D : RemoveAxis ;
230261}
231262
@@ -259,7 +290,7 @@ where
259290 S : Data < Elem = A > ,
260291 D : RemoveAxis ,
261292 {
262- self . sample_axis_using ( axis, n_samples, strategy, & mut get_rng ( ) )
293+ ( * * self ) . sample_axis ( axis, n_samples, strategy)
263294 }
264295
265296 fn sample_axis_using < R > ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy , rng : & mut R ) -> Array < A , D >
@@ -268,6 +299,27 @@ where
268299 A : Copy ,
269300 S : Data < Elem = A > ,
270301 D : RemoveAxis ,
302+ {
303+ ( * * self ) . sample_axis_using ( axis, n_samples, strategy, rng)
304+ }
305+ }
306+
307+ impl < A , D > RandomRefExt < A , D > for ArrayRef < A , D >
308+ where D : Dimension
309+ {
310+ fn sample_axis ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy ) -> Array < A , D >
311+ where
312+ A : Copy ,
313+ D : RemoveAxis ,
314+ {
315+ self . sample_axis_using ( axis, n_samples, strategy, & mut get_rng ( ) )
316+ }
317+
318+ fn sample_axis_using < R > ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy , rng : & mut R ) -> Array < A , D >
319+ where
320+ R : Rng + ?Sized ,
321+ A : Copy ,
322+ D : RemoveAxis ,
271323 {
272324 let indices: Vec < _ > = match strategy {
273325 SamplingStrategy :: WithReplacement => {
@@ -284,9 +336,10 @@ where
284336/// if lanes from the original array should only be sampled once (*without replacement*) or
285337/// multiple times (*with replacement*).
286338///
287- /// [`sample_axis`]: RandomExt ::sample_axis
288- /// [`sample_axis_using`]: RandomExt ::sample_axis_using
339+ /// [`sample_axis`]: RandomRefExt ::sample_axis
340+ /// [`sample_axis_using`]: RandomRefExt ::sample_axis_using
289341#[ derive( Debug , Clone ) ]
342+ #[ allow( missing_docs) ]
290343pub enum SamplingStrategy
291344{
292345 WithReplacement ,
0 commit comments