Skip to content

Commit 49cad1f

Browse files
committed
Implement Cheng algorithm for sampling Beta
This should be faster than the gamma variate transformation we are currently using, and it seems to work better for parameters smaller than one. The algorithm is also used by the R language, however I did not consult their implementation in order to avoid licensing problems. Reference: R. C. H. Cheng (1978). Generating beta variates with nonintegral shape parameters. Communications of the ACM 21, 317-322. https://doi.org/10.1145/359460.359482
1 parent dca9cb5 commit 49cad1f

File tree

4 files changed

+188
-28
lines changed

4 files changed

+188
-28
lines changed

rand_distr/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- All error types now implement `std::error::Error` (#919)
1515
- Re-exported `rand::distributions::BernoulliError` (#919)
1616
- Add case `lambda = 0` in the parametrixation of `Exp` (#972)
17+
- Improve algorithm for sampling `Beta` (#1000)
1718

1819
## [0.2.2] - 2019-09-10
1920
- Fix version requirement on rand lib (#847)

rand_distr/benches/distributions.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ distr_float!(distr_normal, f64, Normal::new(-1.23, 4.56).unwrap());
112112
distr_float!(distr_log_normal, f64, LogNormal::new(-1.23, 4.56).unwrap());
113113
distr_float!(distr_gamma_large_shape, f64, Gamma::new(10., 1.0).unwrap());
114114
distr_float!(distr_gamma_small_shape, f64, Gamma::new(0.1, 1.0).unwrap());
115+
distr_float!(distr_beta_small_param, f64, Beta::new(0.1, 0.1).unwrap());
116+
distr_float!(distr_beta_large_param_similar, f64, Beta::new(101., 95.).unwrap());
117+
distr_float!(distr_beta_large_param_different, f64, Beta::new(10., 1000.).unwrap());
118+
distr_float!(distr_beta_mixed_param, f64, Beta::new(0.5, 100.).unwrap());
115119
distr_float!(distr_cauchy, f64, Cauchy::new(4.2, 6.9).unwrap());
116120
distr_float!(distr_triangular, f64, Triangular::new(0., 1., 0.9).unwrap());
117121
distr_int!(distr_binomial, u64, Binomial::new(20, 0.7).unwrap());

rand_distr/src/gamma.rs

Lines changed: 164 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,38 @@ where
495495
}
496496
}
497497

498+
/// The algorithm used for sampling the Beta distribution.
499+
///
500+
/// Reference:
501+
///
502+
/// R. C. H. Cheng (1978).
503+
/// Generating beta variates with nonintegral shape parameters.
504+
/// Communications of the ACM 21, 317-322.
505+
/// https://doi.org/10.1145/359460.359482
506+
#[derive(Clone, Copy, Debug)]
507+
enum BetaAlgorithm<N> {
508+
BB(BB<N>),
509+
BC(BC<N>),
510+
}
511+
512+
/// Algorithm BB for `min(alpha, beta) > 1`.
513+
#[derive(Clone, Copy, Debug)]
514+
struct BB<N> {
515+
alpha: N,
516+
beta: N,
517+
gamma: N,
518+
}
519+
520+
/// Algorithm BC for `min(alpha, beta) <= 1`.
521+
#[derive(Clone, Copy, Debug)]
522+
struct BC<N> {
523+
alpha: N,
524+
beta: N,
525+
delta: N,
526+
kappa1: N,
527+
kappa2: N,
528+
}
529+
498530
/// The Beta distribution with shape parameters `alpha` and `beta`.
499531
///
500532
/// # Example
@@ -510,12 +542,11 @@ where
510542
pub struct Beta<F>
511543
where
512544
F: Float,
513-
StandardNormal: Distribution<F>,
514-
Exp1: Distribution<F>,
515545
Open01: Distribution<F>,
516546
{
517-
gamma_a: Gamma<F>,
518-
gamma_b: Gamma<F>,
547+
a: F, a0: F,
548+
b: F, b0: F,
549+
algorithm: BetaAlgorithm<F>,
519550
}
520551

521552
/// Error type returned from `Beta::new`.
@@ -542,31 +573,140 @@ impl std::error::Error for BetaError {}
542573
impl<F> Beta<F>
543574
where
544575
F: Float,
545-
StandardNormal: Distribution<F>,
546-
Exp1: Distribution<F>,
547576
Open01: Distribution<F>,
548577
{
549578
/// Construct an object representing the `Beta(alpha, beta)`
550579
/// distribution.
551580
pub fn new(alpha: F, beta: F) -> Result<Beta<F>, BetaError> {
552-
Ok(Beta {
553-
gamma_a: Gamma::new(alpha, F::one()).map_err(|_| BetaError::AlphaTooSmall)?,
554-
gamma_b: Gamma::new(beta, F::one()).map_err(|_| BetaError::BetaTooSmall)?,
555-
})
581+
if !(alpha > F::zero()) {
582+
return Err(BetaError::AlphaTooSmall);
583+
}
584+
if !(beta > F::zero()) {
585+
return Err(BetaError::BetaTooSmall);
586+
}
587+
// From now on, we use the notation from the reference,
588+
// i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
589+
let (a0, b0) = (alpha, beta);
590+
let (a, b) = if a0 < b0 { (a0, b0) } else { (b0, a0) };
591+
if alpha > F::one() {
592+
let alpha = a + b;
593+
let beta = ((alpha - F::from(2.).unwrap())
594+
/ (F::from(2.).unwrap()*a*b - alpha)).sqrt();
595+
let gamma = a + F::one() / beta;
596+
597+
Ok(Beta {
598+
a, a0, b, b0,
599+
algorithm: BetaAlgorithm::BB(BB {
600+
alpha, beta, gamma,
601+
})
602+
})
603+
} else {
604+
let alpha = a + b;
605+
let beta = F::one() / b;
606+
let delta = F::one() + a - b;
607+
let kappa1 = delta
608+
* (F::from(0.0138889).unwrap() + F::from(0.0416667).unwrap()*b)
609+
/ (a*beta - F::from(0.777778).unwrap());
610+
let kappa2 = F::from(0.25).unwrap()
611+
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b;
612+
613+
Ok(Beta {
614+
a, a0, b, b0,
615+
algorithm: BetaAlgorithm::BC(BC {
616+
alpha, beta, delta, kappa1, kappa2,
617+
})
618+
})
619+
}
556620
}
557621
}
558622

559623
impl<F> Distribution<F> for Beta<F>
560624
where
561625
F: Float,
562-
StandardNormal: Distribution<F>,
563-
Exp1: Distribution<F>,
564626
Open01: Distribution<F>,
565627
{
566628
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
567-
let x = self.gamma_a.sample(rng);
568-
let y = self.gamma_b.sample(rng);
569-
x / (x + y)
629+
match self.algorithm {
630+
BetaAlgorithm::BB(algo) => {
631+
let mut w;
632+
loop {
633+
// 1.
634+
let u1 = rng.sample(Open01);
635+
let u2 = rng.sample(Open01);
636+
let v = algo.beta * (u1 / (F::one() - u1)).ln();
637+
w = self.a * v.exp();
638+
let z = u1*u1 * u2;
639+
let r = algo.gamma * v - F::from(4.).unwrap().ln();
640+
let s = self.a + r - w;
641+
// 2.
642+
if s + F::one() + F::from(5.).unwrap().ln()
643+
>= F::from(5.).unwrap() * z {
644+
break;
645+
}
646+
// 3.
647+
let t = z.ln();
648+
if s >= t {
649+
break;
650+
}
651+
// 4.
652+
if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
653+
break;
654+
}
655+
}
656+
// 5.
657+
if self.a == self.a0 {
658+
w / (self.b + w)
659+
} else {
660+
self.b / (self.b + w)
661+
}
662+
},
663+
BetaAlgorithm::BC(algo) => {
664+
let mut w;
665+
loop {
666+
let z;
667+
// 1.
668+
let u1 = rng.sample(Open01);
669+
let u2 = rng.sample(Open01);
670+
if u1 < F::from(0.5).unwrap() {
671+
// 2.
672+
let y = u1 * u2;
673+
z = u1 * y;
674+
if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
675+
continue;
676+
}
677+
} else {
678+
// 3.
679+
z = u1 * u1 * u2;
680+
if z <= F::from(0.25).unwrap() {
681+
let v = algo.beta * (u1 / (F::one() - u1)).ln();
682+
w = self.a * v.exp();
683+
break;
684+
}
685+
// 4.
686+
if z >= algo.kappa2 {
687+
continue;
688+
}
689+
}
690+
// 5.
691+
let v = algo.beta * (u1 / (F::one() - u1)).ln();
692+
w = self.a * v.exp();
693+
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
694+
- F::from(1.3862944).unwrap() < z.ln()) {
695+
break;
696+
};
697+
}
698+
// 6.
699+
if self.a == self.a0 {
700+
if w == F::infinity() {
701+
// Assuming `b` is finite, for large `w`:
702+
return F::one();
703+
}
704+
w / (self.b + w)
705+
} else {
706+
self.b / (self.b + w)
707+
}
708+
},
709+
}
570710
}
571711
}
572712

@@ -636,4 +776,13 @@ mod test {
636776
fn test_beta_invalid_dof() {
637777
Beta::new(0., 0.).unwrap();
638778
}
779+
780+
#[test]
781+
fn test_beta_small_param() {
782+
let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
783+
let mut rng = crate::test::rng(206);
784+
for i in 0..1000 {
785+
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
786+
}
787+
}
639788
}

rand_distr/tests/value_stability.rs

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,11 @@ fn normal_inverse_gaussian_stability() {
121121
fn pert_stability() {
122122
// mean = 4, var = 12/7
123123
test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[
124-
4.631484136029422f64,
125-
3.307201472321789f64,
126-
3.29995019556348f64,
127-
3.66835483991721f64,
128-
3.514246139933899f64,
124+
4.908681667460367,
125+
4.014196196158352,
126+
2.6489397149197234,
127+
3.4569780580044727,
128+
4.242864311947118,
129129
]);
130130
}
131131

@@ -200,15 +200,21 @@ fn gamma_stability() {
200200
-2.377641221169782,
201201
]);
202202

203-
// Beta has same special cases as Gamma on each param
203+
// Beta has two special cases:
204+
//
205+
// 1. min(alpha, beta) <= 1
206+
// 2. min(alpha, beta > 1
204207
test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[
205-
0.6444564f32, 0.357635, 0.4110078, 0.7347192,
206-
]);
207-
test_samples(223, Beta::new(0.7, 1.2).unwrap(), &[
208-
0.6433129944095513f64,
209-
0.5373371199711573,
210-
0.10313293199269491,
211-
0.002472280249144378,
208+
0.2958284085602274,
209+
0.9384411906056516,
210+
0.3151361582723264,
211+
0.6150273348630618,
212+
]);
213+
test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[
214+
0.49563509121756827,
215+
0.9551305482256759,
216+
0.5151181353461637,
217+
0.7551732971235077,
212218
]);
213219
}
214220

0 commit comments

Comments
 (0)