From 6faaabf7b3a2801f12354029621f56061a976462 Mon Sep 17 00:00:00 2001
From: FreezyLemon
Date: Wed, 4 Sep 2024 21:48:14 +0200
Subject: [PATCH] Add MultinomialError

Move the parameter check from the separate `check_multinomial` fn into
`Multinomial::new_from_nalgebra`, which `Multinomial::new` delegates to.

`ProbabilityInvalid` is returned for NaN, infinite, and negative
probabilities, so the inlined check keeps rejecting infinities the way
the removed `check_multinomial` did.
---
 src/distribution/internal.rs    | 52 -------------------------
 src/distribution/mod.rs         |  2 +-
 src/distribution/multinomial.rs | 69 ++++++++++++++++++++++++++-------
 3 files changed, 57 insertions(+), 66 deletions(-)

diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs
index 85ad0e73..c3b9f22c 100644
--- a/src/distribution/internal.rs
+++ b/src/distribution/internal.rs
@@ -1,42 +1,5 @@
 use num_traits::Num;
 
-#[cfg(feature = "nalgebra")]
-use nalgebra::{Dim, OVector};
-
-#[cfg(feature = "nalgebra")]
-pub fn check_multinomial(arr: &OVector, accept_zeroes: bool) -> crate::Result<()>
-where
-    D: Dim,
-    nalgebra::DefaultAllocator: nalgebra::allocator::Allocator,
-{
-    use crate::StatsError;
-
-    if arr.len() < 2 {
-        return Err(StatsError::BadParams);
-    }
-    let mut sum = 0.0;
-    for &x in arr.iter() {
-        #[allow(clippy::if_same_then_else)]
-        if x.is_nan() {
-            return Err(StatsError::BadParams);
-        } else if x.is_infinite() {
-            return Err(StatsError::BadParams);
-        } else if x < 0.0 {
-            return Err(StatsError::BadParams);
-        } else if x == 0.0 && !accept_zeroes {
-            return Err(StatsError::BadParams);
-        } else {
-            sum += x;
-        }
-    }
-
-    if sum != 0.0 {
-        Ok(())
-    } else {
-        Err(StatsError::BadParams)
-    }
-}
-
 /// Implements univariate function bisection searching for criteria
 /// ```text
 /// smallest k such that f(k) >= z
@@ -485,21 +448,6 @@ pub mod test {
         check_sum_pmf_is_cdf(dist, x_max);
     }
 
-    #[cfg(feature = "nalgebra")]
-    #[test]
-    fn test_is_valid_multinomial() {
-        use std::f64;
-
-        let invalid = [1.0, f64::NAN, 3.0];
-        assert!(check_multinomial(&invalid.to_vec().into(), true).is_err());
-        let invalid2 = [-2.0, 5.0, 1.0, 6.2];
-        assert!(check_multinomial(&invalid2.to_vec().into(), true).is_err());
-        let invalid3 = [0.0, 0.0, 0.0];
-        assert!(check_multinomial(&invalid3.to_vec().into(), true).is_err());
-        let valid = [5.2, 0.0, 1e-15, 1000000.12];
-        assert!(check_multinomial(&valid.to_vec().into(), true).is_ok());
-    }
-
     #[test]
     fn test_integer_bisection() {
         fn search(z: usize, data: &[usize]) -> Option {
diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs
index efc3c5f3..8955ed63 100644
--- a/src/distribution/mod.rs
+++ b/src/distribution/mod.rs
@@ -27,7 +27,7 @@ pub use self::inverse_gamma::{InverseGamma, InverseGammaError};
 pub use self::laplace::{Laplace, LaplaceError};
 pub use self::log_normal::{LogNormal, LogNormalError};
 #[cfg(feature = "nalgebra")]
-pub use self::multinomial::Multinomial;
+pub use self::multinomial::{Multinomial, MultinomialError};
 #[cfg(feature = "nalgebra")]
 pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError};
 #[cfg(feature = "nalgebra")]
diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs
index dc402050..d1d78863 100644
--- a/src/distribution/multinomial.rs
+++ b/src/distribution/multinomial.rs
@@ -1,7 +1,6 @@
 use crate::distribution::Discrete;
 use crate::function::factorial;
 use crate::statistics::*;
-use crate::Result;
 use nalgebra::{Const, DVector, Dim, Dyn, OMatrix, OVector};
 use rand::Rng;
 
@@ -33,6 +32,33 @@ where
     n: u64,
 }
 
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
+pub enum MultinomialError {
+    /// Fewer than two probabilities.
+    NotEnoughProbabilities,
+
+    /// The sum of all probabilities is zero.
+    ProbabilitySumZero,
+
+    /// At least one probability is NaN, infinite, or less than zero.
+    ProbabilityInvalid,
+}
+
+impl std::fmt::Display for MultinomialError {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            MultinomialError::NotEnoughProbabilities => write!(f, "Fewer than two probabilities"),
+            MultinomialError::ProbabilitySumZero => write!(f, "The probabilities sum up to zero"),
+            MultinomialError::ProbabilityInvalid => write!(
+                f,
+                "The probabilities contain at least one NaN, infinity, or value less than zero"
+            ),
+        }
+    }
+}
+
+impl std::error::Error for MultinomialError {}
+
 impl Multinomial {
     /// Constructs a new multinomial distribution with probabilities `p`
     /// and `n` number of trials.
@@ -57,7 +83,7 @@ impl Multinomial {
     /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3);
     /// assert!(result.is_err());
     /// ```
-    pub fn new(p: Vec, n: u64) -> Result {
+    pub fn new(p: Vec, n: u64) -> Result {
         Self::new_from_nalgebra(p.into(), n)
     }
 }
@@ -67,14 +93,26 @@ where
     D: Dim,
     nalgebra::DefaultAllocator: nalgebra::allocator::Allocator,
 {
-    pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result {
-        match super::internal::check_multinomial(&p, true) {
-            Err(e) => Err(e),
-            Ok(_) => {
-                p.unscale_mut(p.lp_norm(1));
-                Ok(Self { p, n })
+    pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result {
+        if p.len() < 2 {
+            return Err(MultinomialError::NotEnoughProbabilities);
+        }
+
+        let mut sum = 0.0;
+        for &val in &p {
+            if !val.is_finite() || val < 0.0 {
+                return Err(MultinomialError::ProbabilityInvalid);
             }
+
+            sum += val;
+        }
+
+        if sum == 0.0 {
+            return Err(MultinomialError::ProbabilitySumZero);
         }
+
+        p.unscale_mut(p.lp_norm(1));
+        Ok(Self { p, n })
     }
 
     /// Returns the probabilities of the multinomial
@@ -295,7 +333,7 @@ where
 #[cfg(test)]
 mod tests {
     use crate::{
-        distribution::{Discrete, Multinomial},
+        distribution::{Discrete, Multinomial, MultinomialError},
         statistics::{MeanN, VarianceN},
     };
     use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector};
@@ -311,7 +349,7 @@ mod tests {
         mvn.unwrap()
     }
 
-    fn bad_create_case(p: OVector, n: u64) -> crate::StatsError
+    fn bad_create_case(p: OVector, n: u64) -> MultinomialError
     where
         D: DimMin,
         nalgebra::DefaultAllocator: nalgebra::allocator::Allocator,
@@ -344,18 +382,23 @@ mod tests {
 
     #[test]
     fn test_bad_create() {
+        assert_eq!(
+            bad_create_case(vector![0.5], 4),
+            MultinomialError::NotEnoughProbabilities,
+        );
+
         assert_eq!(
             bad_create_case(vector![-1.0, 2.0], 4),
-            crate::StatsError::BadParams
+            MultinomialError::ProbabilityInvalid,
         );
 
         assert_eq!(
             bad_create_case(vector![0.0, 0.0], 4),
-            crate::StatsError::BadParams
+            MultinomialError::ProbabilitySumZero,
         );
 
         assert_eq!(
             bad_create_case(vector![1.0, f64::NAN], 4),
-            crate::StatsError::BadParams
+            MultinomialError::ProbabilityInvalid,
         );
     }