Skip to content

Commit

Permalink
Add UniformError
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Sep 6, 2024
1 parent af7f6df commit 1581ec6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub use self::pareto::{Pareto, ParetoError};
pub use self::poisson::{Poisson, PoissonError};
pub use self::students_t::{StudentsT, StudentsTError};
pub use self::triangular::{Triangular, TriangularError};
pub use self::uniform::Uniform;
pub use self::uniform::{Uniform, UniformError};
pub use self::weibull::Weibull;

mod bernoulli;
Expand Down
72 changes: 53 additions & 19 deletions src/distribution/uniform.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::distribution::{Continuous, ContinuousCDF};
use crate::statistics::*;
use crate::{Result, StatsError};
use rand::distributions::Uniform as RandUniform;
use rand::Rng;
use std::f64;
Expand All @@ -26,13 +25,40 @@ pub struct Uniform {
max: f64,
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum UniformError {
/// The minimum is NaN or infinite.
MinInvalid,

/// The maximum is NaN or infinite.
MaxInvalid,

/// The maximum is not greater than the minimum.
MaxNotGreaterThanMin,
}

impl std::fmt::Display for UniformError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"),
UniformError::MaxInvalid => write!(f, "Maximum is NaN or infinite"),
UniformError::MaxNotGreaterThanMin => {
write!(f, "Maximum is not greater than the minimum")
}
}
}
}

impl std::error::Error for UniformError {}

impl Uniform {
/// Constructs a new uniform distribution with a min of `min` and a max
/// of `max`
/// of `max`.
///
/// # Errors
///
/// Returns an error if `min` or `max` are `NaN` or unbounded
/// Returns an error if `min` or `max` are `NaN` or infinite.
/// Returns an error if `min >= max`.
///
/// # Examples
///
Expand All @@ -49,17 +75,19 @@ impl Uniform {
/// result = Uniform::new(f64::NEG_INFINITY, 1.0);
/// assert!(result.is_err());
/// ```
pub fn new(min: f64, max: f64) -> Result<Uniform> {
if min.is_nan() || max.is_nan() {
return Err(StatsError::BadParams);
pub fn new(min: f64, max: f64) -> Result<Uniform, UniformError> {
if !min.is_finite() {
return Err(UniformError::MinInvalid);
}

match (min.is_finite(), max.is_finite(), min < max) {
(false, false, _) => Err(StatsError::ArgFinite("min and max")),
(false, true, _) => Err(StatsError::ArgFinite("min")),
(true, false, _) => Err(StatsError::ArgFinite("max")),
(true, true, false) => Err(StatsError::ArgLteArg("min", "max")),
(true, true, true) => Ok(Uniform { min, max }),
if !max.is_finite() {
return Err(UniformError::MaxInvalid);
}

if min < max {
Ok(Uniform { min, max })
} else {
Err(UniformError::MaxNotGreaterThanMin)
}
}

Expand Down Expand Up @@ -288,7 +316,7 @@ mod tests {
use crate::distribution::internal::*;
use crate::testing_boiler;

testing_boiler!(min: f64, max: f64; Uniform; StatsError);
testing_boiler!(min: f64, max: f64; Uniform; UniformError);

#[test]
fn test_create() {
Expand All @@ -300,12 +328,18 @@ mod tests {

#[test]
fn test_bad_create() {
create_err(0.0, 0.0);
create_err(f64::NAN, 1.0);
create_err(1.0, f64::NAN);
create_err(f64::NAN, f64::NAN);
create_err(0.0, f64::INFINITY);
create_err(1.0, 0.0);
let invalid = [
(0.0, 0.0, UniformError::MaxNotGreaterThanMin),
(f64::NAN, 1.0, UniformError::MinInvalid),
(1.0, f64::NAN, UniformError::MaxInvalid),
(f64::NAN, f64::NAN, UniformError::MinInvalid),
(0.0, f64::INFINITY, UniformError::MaxInvalid),
(1.0, 0.0, UniformError::MaxNotGreaterThanMin),
];

for (min, max, err) in invalid {
test_create_err(min, max, err);
}
}

#[test]
Expand Down

0 comments on commit 1581ec6

Please sign in to comment.