Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for issue #185 #220

Merged
merged 6 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/distribution/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,24 @@ mod tests {
test_case(0.5, 3, 0.0, sf(5));
}

#[test]
fn test_inverse_cdf() {
let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg);
test_case(0.4, 5, 2, invcdf(0.3456));

// cases in issue #185
test_case(0.018, 465, 1, invcdf(3.472e-4));
test_case(0.5, 6, 4, invcdf(0.75));
}

#[test]
fn test_cdf_inverse_cdf() {
let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg));
test_case(0.3, 10, 3, cdf_invcdf(3));
test_case(0.3, 10, 4, cdf_invcdf(4));
test_case(0.5, 6, 4, cdf_invcdf(4));
}

#[test]
fn test_discrete() {
test::check_discrete_distribution(&try_create(0.3, 5), 5);
Expand Down
63 changes: 62 additions & 1 deletion src/distribution/internal.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use num_traits::{Bounded, Float, Num};

/// Returns true if there are no elements in `x` in `arr`
/// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`.
/// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0`
Expand All @@ -12,10 +14,47 @@ pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool {
sum != 0.0
}

/// Implements univariate function bisection searching for criteria
/// ```text
/// smallest k such that f(k) >= z
/// ```
/// Evaluates to `None` if
/// - provided interval has lower bound greater than upper bound
/// - function found not semi-monotone on the provided interval containing `z`
/// Evaluates to `Some(k)`, where `k` satisfies the search criteria
pub fn integral_bisection_search<K: Num + Clone, T: Num + PartialOrd>(
f: impl Fn(&K) -> T, z: T, lb: K, ub: K,
) -> Option<K> {
if !(f(&lb)..=f(&ub)).contains(&z) {
return None;
}
let two = K::one() + K::one();
let mut lb = lb;
let mut ub = ub;
loop {
let mid = (lb.clone() + ub.clone()) / two.clone();
if !(f(&lb)..=f(&ub)).contains(&f(&mid)) {
// if f found not monotone on the interval
return None;
} else if f(&lb) == z {
return Some(lb);
} else if f(&ub) == z {
return Some(ub);
} else if (lb.clone() + K::one()) == ub {
// no more elements to search
return Some(ub);
} else if f(&mid) >= z {
ub = mid;
} else {
lb = mid;
}
}
}

#[macro_use]
#[cfg(all(test, feature = "nightly"))]
pub mod test {
use super::is_valid_multinomial;
use super::*;
use crate::consts::ACC;
use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF};

Expand Down Expand Up @@ -196,4 +235,26 @@ pub mod test {
let invalid = [5.2, 0.0, 1e-15, 1000000.12];
assert!(!is_valid_multinomial(&invalid, false));
}

#[test]
fn test_integer_bisection() {
fn search(z: usize, data: &Vec<usize>) -> Option<usize> {
integral_bisection_search(|idx: &usize| data[*idx], z, 0, data.len() - 1)
}

let needle = 3;
let data = (0..5)
.map(|n| if n >= needle { n + 1 } else { n })
.collect::<Vec<_>>();

for i in 0..(data.len()) {
assert_eq!(search(data[i], &data), Some(i),)
}
{
let infimum = search(needle, &data);
let found_element = search(needle + 1, &data); // 4 > needle && member of range
assert_eq!(found_element, Some(needle));
assert_eq!(infimum, found_element)
}
}
}
38 changes: 19 additions & 19 deletions src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
//! and provides
//! concrete implementations for a variety of distributions.
use super::statistics::{Max, Min};
use ::num_traits::{float::Float, Bounded, Num};
use ::num_traits::{Bounded, Float, Num};
use num_traits::{NumAssign, NumAssignOps, NumAssignRef};

pub use self::bernoulli::Bernoulli;
pub use self::beta::Beta;
Expand Down Expand Up @@ -145,7 +146,9 @@ pub trait ContinuousCDF<K: Float, T: Float>: Min<K> + Max<K> {

/// The `DiscreteCDF` trait is used to specify an interface for univariate
/// discrete distributions.
pub trait DiscreteCDF<K: Bounded + Clone + Num, T: Float>: Min<K> + Max<K> {
pub trait DiscreteCDF<K: Sized + Num + Ord + Clone + NumAssignOps, T: Float>:
Min<K> + Max<K>
{
/// Returns the cumulative distribution function calculated
/// at `x` for a given distribution. May panic depending
/// on the implementor.
Expand Down Expand Up @@ -177,29 +180,26 @@ pub trait DiscreteCDF<K: Bounded + Clone + Num, T: Float>: Min<K> + Max<K> {

/// Due to issues with rounding and floating-point accuracy the default implementation may be ill-behaved
/// Specialized inverse cdfs should be used whenever possible.
///
/// # Panics
/// this default impl panics if provided `p` not on interval [0.0, 1.0]
fn inverse_cdf(&self, p: T) -> K {
// TODO: fix integer implementation
if p == T::zero() {
return self.min();
};
if p == T::one() {
} else if p == T::one() {
return self.max();
};
let two = K::one() + K::one();
let mut high = two.clone();
let mut low = K::min_value();
while self.cdf(high.clone()) < p {
high = high.clone() + high.clone();
} else if !(T::zero()..=T::one()).contains(&p) {
panic!("p must be on [0, 1]")
}
while high != low {
let mid = (high.clone() + low.clone()) / two.clone();
if self.cdf(mid.clone()) >= p {
high = mid;
} else {
low = mid;
}

let two = K::one() + K::one();
let mut ub = two.clone();
let lb = self.min();
while self.cdf(ub.clone()) < p {
ub *= two.clone();
}
high

internal::integral_bisection_search(|p| self.cdf(p.clone()), p, lb, ub).unwrap()
}
}

Expand Down
Loading