@@ -495,6 +495,38 @@ where
495
495
}
496
496
}
497
497
498
+ /// The algorithm used for sampling the Beta distribution.
499
+ ///
500
+ /// Reference:
501
+ ///
502
+ /// R. C. H. Cheng (1978).
503
+ /// Generating beta variates with nonintegral shape parameters.
504
+ /// Communications of the ACM 21, 317-322.
505
+ /// https://doi.org/10.1145/359460.359482
506
+ #[ derive( Clone , Copy , Debug ) ]
507
+ enum BetaAlgorithm < N > {
508
+ BB ( BB < N > ) ,
509
+ BC ( BC < N > ) ,
510
+ }
511
+
512
+ /// Algorithm BB for `min(alpha, beta) > 1`.
513
+ #[ derive( Clone , Copy , Debug ) ]
514
+ struct BB < N > {
515
+ alpha : N ,
516
+ beta : N ,
517
+ gamma : N ,
518
+ }
519
+
520
+ /// Algorithm BC for `min(alpha, beta) <= 1`.
521
+ #[ derive( Clone , Copy , Debug ) ]
522
+ struct BC < N > {
523
+ alpha : N ,
524
+ beta : N ,
525
+ delta : N ,
526
+ kappa1 : N ,
527
+ kappa2 : N ,
528
+ }
529
+
498
530
/// The Beta distribution with shape parameters `alpha` and `beta`.
499
531
///
500
532
/// # Example
@@ -510,12 +542,11 @@ where
510
542
pub struct Beta < F >
511
543
where
512
544
F : Float ,
513
- StandardNormal : Distribution < F > ,
514
- Exp1 : Distribution < F > ,
515
545
Open01 : Distribution < F > ,
516
546
{
517
- gamma_a : Gamma < F > ,
518
- gamma_b : Gamma < F > ,
547
+ a : F , a0 : F ,
548
+ b : F , b0 : F ,
549
+ algorithm : BetaAlgorithm < F > ,
519
550
}
520
551
521
552
/// Error type returned from `Beta::new`.
@@ -542,31 +573,140 @@ impl std::error::Error for BetaError {}
542
573
impl < F > Beta < F >
543
574
where
544
575
F : Float ,
545
- StandardNormal : Distribution < F > ,
546
- Exp1 : Distribution < F > ,
547
576
Open01 : Distribution < F > ,
548
577
{
549
578
/// Construct an object representing the `Beta(alpha, beta)`
550
579
/// distribution.
551
580
pub fn new ( alpha : F , beta : F ) -> Result < Beta < F > , BetaError > {
552
- Ok ( Beta {
553
- gamma_a : Gamma :: new ( alpha, F :: one ( ) ) . map_err ( |_| BetaError :: AlphaTooSmall ) ?,
554
- gamma_b : Gamma :: new ( beta, F :: one ( ) ) . map_err ( |_| BetaError :: BetaTooSmall ) ?,
555
- } )
581
+ if !( alpha > F :: zero ( ) ) {
582
+ return Err ( BetaError :: AlphaTooSmall ) ;
583
+ }
584
+ if !( beta > F :: zero ( ) ) {
585
+ return Err ( BetaError :: BetaTooSmall ) ;
586
+ }
587
+ // From now on, we use the notation from the reference,
588
+ // i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
589
+ let ( a0, b0) = ( alpha, beta) ;
590
+ let ( a, b) = if a0 < b0 { ( a0, b0) } else { ( b0, a0) } ;
591
+ if alpha > F :: one ( ) {
592
+ let alpha = a + b;
593
+ let beta = ( ( alpha - F :: from ( 2. ) . unwrap ( ) )
594
+ / ( F :: from ( 2. ) . unwrap ( ) * a* b - alpha) ) . sqrt ( ) ;
595
+ let gamma = a + F :: one ( ) / beta;
596
+
597
+ Ok ( Beta {
598
+ a, a0, b, b0,
599
+ algorithm : BetaAlgorithm :: BB ( BB {
600
+ alpha, beta, gamma,
601
+ } )
602
+ } )
603
+ } else {
604
+ let alpha = a + b;
605
+ let beta = F :: one ( ) / b;
606
+ let delta = F :: one ( ) + a - b;
607
+ let kappa1 = delta
608
+ * ( F :: from ( 0.0138889 ) . unwrap ( ) + F :: from ( 0.0416667 ) . unwrap ( ) * b)
609
+ / ( a* beta - F :: from ( 0.777778 ) . unwrap ( ) ) ;
610
+ let kappa2 = F :: from ( 0.25 ) . unwrap ( )
611
+ + ( F :: from ( 0.5 ) . unwrap ( ) + F :: from ( 0.25 ) . unwrap ( ) /delta) * b;
612
+
613
+ Ok ( Beta {
614
+ a, a0, b, b0,
615
+ algorithm : BetaAlgorithm :: BC ( BC {
616
+ alpha, beta, delta, kappa1, kappa2,
617
+ } )
618
+ } )
619
+ }
556
620
}
557
621
}
558
622
559
623
impl < F > Distribution < F > for Beta < F >
560
624
where
561
625
F : Float ,
562
- StandardNormal : Distribution < F > ,
563
- Exp1 : Distribution < F > ,
564
626
Open01 : Distribution < F > ,
565
627
{
566
628
fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> F {
567
- let x = self . gamma_a . sample ( rng) ;
568
- let y = self . gamma_b . sample ( rng) ;
569
- x / ( x + y)
629
+ match self . algorithm {
630
+ BetaAlgorithm :: BB ( algo) => {
631
+ let mut w;
632
+ loop {
633
+ // 1.
634
+ let u1 = rng. sample ( Open01 ) ;
635
+ let u2 = rng. sample ( Open01 ) ;
636
+ let v = algo. beta * ( u1 / ( F :: one ( ) - u1) ) . ln ( ) ;
637
+ w = self . a * v. exp ( ) ;
638
+ let z = u1* u1 * u2;
639
+ let r = algo. gamma * v - F :: from ( 4. ) . unwrap ( ) . ln ( ) ;
640
+ let s = self . a + r - w;
641
+ // 2.
642
+ if s + F :: one ( ) + F :: from ( 5. ) . unwrap ( ) . ln ( )
643
+ >= F :: from ( 5. ) . unwrap ( ) * z {
644
+ break ;
645
+ }
646
+ // 3.
647
+ let t = z. ln ( ) ;
648
+ if s >= t {
649
+ break ;
650
+ }
651
+ // 4.
652
+ if !( r + algo. alpha * ( algo. alpha / ( self . b + w) ) . ln ( ) < t) {
653
+ break ;
654
+ }
655
+ }
656
+ // 5.
657
+ if self . a == self . a0 {
658
+ w / ( self . b + w)
659
+ } else {
660
+ self . b / ( self . b + w)
661
+ }
662
+ } ,
663
+ BetaAlgorithm :: BC ( algo) => {
664
+ let mut w;
665
+ loop {
666
+ let z;
667
+ // 1.
668
+ let u1 = rng. sample ( Open01 ) ;
669
+ let u2 = rng. sample ( Open01 ) ;
670
+ if u1 < F :: from ( 0.5 ) . unwrap ( ) {
671
+ // 2.
672
+ let y = u1 * u2;
673
+ z = u1 * y;
674
+ if F :: from ( 0.25 ) . unwrap ( ) * u2 + z - y >= algo. kappa1 {
675
+ continue ;
676
+ }
677
+ } else {
678
+ // 3.
679
+ z = u1 * u1 * u2;
680
+ if z <= F :: from ( 0.25 ) . unwrap ( ) {
681
+ let v = algo. beta * ( u1 / ( F :: one ( ) - u1) ) . ln ( ) ;
682
+ w = self . a * v. exp ( ) ;
683
+ break ;
684
+ }
685
+ // 4.
686
+ if z >= algo. kappa2 {
687
+ continue ;
688
+ }
689
+ }
690
+ // 5.
691
+ let v = algo. beta * ( u1 / ( F :: one ( ) - u1) ) . ln ( ) ;
692
+ w = self . a * v. exp ( ) ;
693
+ if !( algo. alpha * ( ( algo. alpha / ( self . b + w) ) . ln ( ) + v)
694
+ - F :: from ( 1.3862944 ) . unwrap ( ) < z. ln ( ) ) {
695
+ break ;
696
+ } ;
697
+ }
698
+ // 6.
699
+ if self . a == self . a0 {
700
+ if w == F :: infinity ( ) {
701
+ // Assuming `b` is finite, for large `w`:
702
+ return F :: one ( ) ;
703
+ }
704
+ w / ( self . b + w)
705
+ } else {
706
+ self . b / ( self . b + w)
707
+ }
708
+ } ,
709
+ }
570
710
}
571
711
}
572
712
@@ -636,4 +776,13 @@ mod test {
636
776
fn test_beta_invalid_dof ( ) {
637
777
Beta :: new ( 0. , 0. ) . unwrap ( ) ;
638
778
}
779
+
780
+ #[ test]
781
+ fn test_beta_small_param ( ) {
782
+ let beta = Beta :: < f64 > :: new ( 1e-3 , 1e-3 ) . unwrap ( ) ;
783
+ let mut rng = crate :: test:: rng ( 206 ) ;
784
+ for i in 0 ..1000 {
785
+ assert ! ( !beta. sample( & mut rng) . is_nan( ) , "failed at i={}" , i) ;
786
+ }
787
+ }
639
788
}
0 commit comments