From ecc39cc5eeddbd33ee37026b84802e25c3949415 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Tue, 2 Apr 2024 18:03:00 -0500 Subject: [PATCH] fix: mode for gamma distribution is 0 for shape<=1 --- src/distribution/gamma.rs | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 7a36a30f..9c497487 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -133,17 +133,13 @@ impl ContinuousCDF for Gamma { fn sf(&self, x: f64) -> f64 { if x <= 0.0 { 1.0 - } - else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { + } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { 0.0 - } - else if self.rate.is_infinite() { + } else if self.rate.is_infinite() { 1.0 - } - else if x.is_infinite() { + } else if x.is_infinite() { 0.0 - } - else { + } else { gamma::gamma_ur(self.shape, x * self.rate) } } @@ -239,13 +235,17 @@ impl Mode> for Gamma { /// /// # Formula /// - /// ```ignore - /// (α - 1) / β + /// ```text + /// max{(α - 1) / β, 0} /// ``` /// /// where `α` is the shape and `β` is the rate fn mode(&self) -> Option { - Some((self.shape - 1.0) / self.rate) + if self.shape > 1.0 { + Some((self.shape - 1.0) / self.rate) + } else { + Some(0.0) + } } } @@ -353,6 +353,7 @@ pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> #[cfg(all(test, feature = "nightly"))] mod tests { + use super::*; use crate::consts::ACC; use crate::distribution::internal::*; @@ -452,10 +453,12 @@ mod tests { #[test] fn test_mode() { + use rand::distributions::uniform::UniformFloat; + use rand::rngs; let f = |x: Gamma| x.mode().unwrap(); - let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)]; - for &(arg, res) in test.iter() { - test_case_special(arg, res, 10e-6, f); + let test = [(1.0, 0.1), (1.0, 1.0)]; + for &arg in test.iter() { + test_case_special(arg, 0.0, 10e-6, f); } let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, INF), 0.0)]; for &(arg, res) in test.iter() {