8
8
// except according to those terms.
9
9
10
10
//! The dirichlet distribution.
11
-
12
- use crate :: utils :: Float ;
11
+ #! [ cfg ( feature = "alloc" ) ]
12
+ use num_traits :: Float ;
13
13
use crate :: { Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
14
14
use rand:: Rng ;
15
- use std:: { error, fmt} ;
15
+ use core:: fmt;
16
+ use alloc:: { boxed:: Box , vec, vec:: Vec } ;
16
17
17
18
/// The Dirichlet distribution `Dirichlet(alpha)`.
18
19
///
@@ -26,14 +27,20 @@ use std::{error, fmt};
26
27
/// use rand::prelude::*;
27
28
/// use rand_distr::Dirichlet;
28
29
///
29
- /// let dirichlet = Dirichlet::new(vec! [1.0, 2.0, 3.0]).unwrap();
30
+ /// let dirichlet = Dirichlet::new(& [1.0, 2.0, 3.0]).unwrap();
30
31
/// let samples = dirichlet.sample(&mut rand::thread_rng());
31
32
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
32
33
/// ```
33
34
#[ derive( Clone , Debug ) ]
34
- pub struct Dirichlet < N > {
35
+ pub struct Dirichlet < F >
36
+ where
37
+ F : Float ,
38
+ StandardNormal : Distribution < F > ,
39
+ Exp1 : Distribution < F > ,
40
+ Open01 : Distribution < F > ,
41
+ {
35
42
/// Concentration parameters (alpha)
36
- alpha : Vec < N > ,
43
+ alpha : Box < [ F ] > ,
37
44
}
38
45
39
46
/// Error type returned from `Dirchlet::new`.
@@ -58,68 +65,70 @@ impl fmt::Display for Error {
58
65
}
59
66
}
60
67
61
- impl error:: Error for Error { }
68
+ #[ cfg( feature = "std" ) ]
69
+ impl std:: error:: Error for Error { }
62
70
63
- impl < N : Float > Dirichlet < N >
71
+ impl < F > Dirichlet < F >
64
72
where
65
- StandardNormal : Distribution < N > ,
66
- Exp1 : Distribution < N > ,
67
- Open01 : Distribution < N > ,
73
+ F : Float ,
74
+ StandardNormal : Distribution < F > ,
75
+ Exp1 : Distribution < F > ,
76
+ Open01 : Distribution < F > ,
68
77
{
69
78
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
70
79
///
71
80
/// Requires `alpha.len() >= 2`.
72
81
#[ inline]
73
- pub fn new < V : Into < Vec < N > > > ( alpha : V ) -> Result < Dirichlet < N > , Error > {
74
- let a = alpha. into ( ) ;
75
- if a. len ( ) < 2 {
82
+ pub fn new ( alpha : & [ F ] ) -> Result < Dirichlet < F > , Error > {
83
+ if alpha. len ( ) < 2 {
76
84
return Err ( Error :: AlphaTooShort ) ;
77
85
}
78
- for & ai in & a {
79
- if !( ai > N :: from ( 0.0 ) ) {
86
+ for & ai in alpha . iter ( ) {
87
+ if !( ai > F :: zero ( ) ) {
80
88
return Err ( Error :: AlphaTooSmall ) ;
81
89
}
82
90
}
83
91
84
- Ok ( Dirichlet { alpha : a } )
92
+ Ok ( Dirichlet { alpha : alpha . to_vec ( ) . into_boxed_slice ( ) } )
85
93
}
86
94
87
95
/// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
88
96
///
89
97
/// Requires `size >= 2`.
90
98
#[ inline]
91
- pub fn new_with_size ( alpha : N , size : usize ) -> Result < Dirichlet < N > , Error > {
92
- if !( alpha > N :: from ( 0.0 ) ) {
99
+ pub fn new_with_size ( alpha : F , size : usize ) -> Result < Dirichlet < F > , Error > {
100
+ if !( alpha > F :: zero ( ) ) {
93
101
return Err ( Error :: AlphaTooSmall ) ;
94
102
}
95
103
if size < 2 {
96
104
return Err ( Error :: SizeTooSmall ) ;
97
105
}
98
106
Ok ( Dirichlet {
99
- alpha : vec ! [ alpha; size] ,
107
+ alpha : vec ! [ alpha; size] . into_boxed_slice ( ) ,
100
108
} )
101
109
}
102
110
}
103
111
104
- impl < N : Float > Distribution < Vec < N > > for Dirichlet < N >
112
+ impl < F > Distribution < Vec < F > > for Dirichlet < F >
105
113
where
106
- StandardNormal : Distribution < N > ,
107
- Exp1 : Distribution < N > ,
108
- Open01 : Distribution < N > ,
114
+ F : Float ,
115
+ StandardNormal : Distribution < F > ,
116
+ Exp1 : Distribution < F > ,
117
+ Open01 : Distribution < F > ,
109
118
{
110
- fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < N > {
119
+ fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < F > {
111
120
let n = self . alpha . len ( ) ;
112
- let mut samples = vec ! [ N :: from ( 0.0 ) ; n] ;
113
- let mut sum = N :: from ( 0.0 ) ;
121
+ let mut samples = vec ! [ F :: zero ( ) ; n] ;
122
+ let mut sum = F :: zero ( ) ;
114
123
115
124
for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
116
- let g = Gamma :: new ( a, N :: from ( 1.0 ) ) . unwrap ( ) ;
125
+ let g = Gamma :: new ( a, F :: one ( ) ) . unwrap ( ) ;
117
126
* s = g. sample ( rng) ;
118
- sum += * s ;
127
+ sum = sum + ( * s ) ;
119
128
}
120
- let invacc = N :: from ( 1.0 ) / sum;
129
+ let invacc = F :: one ( ) / sum;
121
130
for s in samples. iter_mut ( ) {
122
- * s *= invacc;
131
+ * s = ( * s ) * invacc;
123
132
}
124
133
samples
125
134
}
@@ -131,7 +140,7 @@ mod test {
131
140
132
141
#[ test]
133
142
fn test_dirichlet ( ) {
134
- let d = Dirichlet :: new ( vec ! [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
143
+ let d = Dirichlet :: new ( & [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
135
144
let mut rng = crate :: test:: rng ( 221 ) ;
136
145
let samples = d. sample ( & mut rng) ;
137
146
let _: Vec < f64 > = samples
@@ -170,20 +179,4 @@ mod test {
170
179
fn test_dirichlet_invalid_alpha ( ) {
171
180
Dirichlet :: new_with_size ( 0.0f64 , 2 ) . unwrap ( ) ;
172
181
}
173
-
174
- #[ test]
175
- fn value_stability ( ) {
176
- let mut rng = crate :: test:: rng ( 223 ) ;
177
- assert_eq ! (
178
- rng. sample( Dirichlet :: new( vec![ 1.0 , 2.0 , 3.0 ] ) . unwrap( ) ) ,
179
- vec![ 0.12941567177708177 , 0.4702121891675036 , 0.4003721390554146 ]
180
- ) ;
181
- assert_eq ! ( rng. sample( Dirichlet :: new_with_size( 8.0 , 5 ) . unwrap( ) ) , vec![
182
- 0.17684200044809556 ,
183
- 0.29915953935953055 ,
184
- 0.1832858056608014 ,
185
- 0.1425623503573967 ,
186
- 0.19815030417417595
187
- ] ) ;
188
- }
189
182
}
0 commit comments