Skip to content

Commit

Permalink
tests: 3d matrices test cases for pdf.
Browse files Browse the repository at this point in the history
Also improves documentation for multivariate t minorly.
  • Loading branch information
henryjac authored and YeungOnion committed Aug 16, 2024
1 parent 612355b commit b8261a0
Showing 1 changed file with 48 additions and 47 deletions.
95 changes: 48 additions & 47 deletions src/distribution/multivariate_students_t.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ use rand::Rng;
use std::f64::consts::PI;

/// Implements the [Multivariate Student's t-distribution](https://en.wikipedia.org/wiki/Multivariate_t-distribution)
/// distribution using the "nalgebra" crate for matrix operations
/// distribution using the "nalgebra" crate for matrix operations.
///
/// Assumes all the marginal distributions have the same degree of freedom, ν
/// Assumes all the marginal distributions have the same degree of freedom, ν.
///
/// # Examples
///
Expand All @@ -38,7 +38,7 @@ pub struct MultivariateStudent {

impl MultivariateStudent {
/// Constructs a new multivariate students t distribution with a location of `location`,
/// scale matrix `scale` and `freedom` degrees of freedom
/// scale matrix `scale` and `freedom` degrees of freedom.
///
/// # Errors
///
Expand Down Expand Up @@ -91,34 +91,34 @@ impl MultivariateStudent {
}
}

/// Returns the dimension of the distribution
/// Returns the dimension of the distribution.
pub fn dim(&self) -> usize {
self.dim
}

Check warning on line 97 in src/distribution/multivariate_students_t.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_students_t.rs#L95-L97

Added lines #L95 - L97 were not covered by tests
/// Returns the cholesky decomposiiton matrix of the scale matrix
/// Returns the cholesky decomposiiton matrix of the scale matrix.
///
/// Returns A where Σ = AAᵀ
/// Returns A where Σ = AAᵀ.
pub fn scale_chol_decomp(&self) -> DMatrix<f64> {
self.scale_chol_decomp.clone()
}

Check warning on line 103 in src/distribution/multivariate_students_t.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_students_t.rs#L101-L103

Added lines #L101 - L103 were not covered by tests
/// Returns the location of the distribution
/// Returns the location of the distribution.
pub fn location(&self) -> DVector<f64> {
self.location.clone()
}

Check warning on line 107 in src/distribution/multivariate_students_t.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_students_t.rs#L105-L107

Added lines #L105 - L107 were not covered by tests
/// Returns the scale matrix of the distribution
/// Returns the scale matrix of the distribution.
pub fn scale(&self) -> DMatrix<f64> {
self.scale.clone()
}

Check warning on line 111 in src/distribution/multivariate_students_t.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_students_t.rs#L109-L111

Added lines #L109 - L111 were not covered by tests
/// Returns the degrees of freedom of the distribution
/// Returns the degrees of freedom of the distribution.
pub fn freedom(&self) -> f64 {
self.freedom
}

Check warning on line 115 in src/distribution/multivariate_students_t.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_students_t.rs#L113-L115

Added lines #L113 - L115 were not covered by tests
/// Returns the inverse of the cholesky decomposition matrix
/// Returns the inverse of the cholesky decomposition matrix.
pub fn precision(&self) -> DMatrix<f64> {
self.precision.clone()
}

Check warning on line 119 in src/distribution/multivariate_students_t.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_students_t.rs#L117-L119

Added lines #L117 - L119 were not covered by tests
/// Returns the logarithmed constant part of the probability
/// distribution function
/// distribution function.
pub fn ln_pdf_const(&self) -> f64 {
self.ln_pdf_const
}

Check warning on line 124 in src/distribution/multivariate_students_t.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_students_t.rs#L122-L124

Added lines #L122 - L124 were not covered by tests
Expand All @@ -129,8 +129,8 @@ impl ::rand::distributions::Distribution<DVector<f64>> for MultivariateStudent {
///
/// # Formula
///
///```ignore
/// W * L * Z + μ
///```math
/// W L Z + μ
///```
///
/// where `W` has √(ν/Sν) distribution, Sν has Chi-squared
Expand Down Expand Up @@ -164,35 +164,32 @@ impl Max<DVector<f64>> for MultivariateStudent {
}

impl MeanN<DVector<f64>> for MultivariateStudent {
/// Returns the mean of the student distribution
/// Returns the mean of the student distribution.
///
/// # Remarks
///
/// This is the same mean used to construct the distribution if
/// the degrees of freedom is larger than 1.
fn mean(&self) -> Option<DVector<f64>> {
if self.freedom > 1. {
let mut vec = vec![];
for elt in self.location.clone().into_iter() {
vec.push(*elt);
}
Some(DVector::from_vec(vec))
Some(self.location.clone())
} else {
None
}
}
}

impl VarianceN<DMatrix<f64>> for MultivariateStudent {
/// Returns the covariance matrix of the multivariate student distribution
/// Returns the covariance matrix of the multivariate student distribution.
///
/// # Formula
/// ```ignore
///
/// ```math
/// Σ ⋅ ν / (ν - 2)
/// ```
///
/// where `Σ` is the scale matrix and `ν` is the degrees of freedom.
/// Only defined if freedom is larger than 2
/// Only defined if freedom is larger than 2.
fn variance(&self) -> Option<DMatrix<f64>> {
if self.freedom > 2. {
Some(self.scale.clone() * self.freedom / (self.freedom - 2.))
Expand All @@ -203,39 +200,34 @@ impl VarianceN<DMatrix<f64>> for MultivariateStudent {
}

impl Mode<DVector<f64>> for MultivariateStudent {
/// Returns the mode of the multivariate student distribution
/// Returns the mode of the multivariate student distribution.
///
/// # Formula
///
/// ```ignore
/// ```math
/// μ
/// ```
///
/// where `μ` is the location
/// where `μ` is the location.
fn mode(&self) -> DVector<f64> {
self.location.clone()
}
}

impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateStudent {
/// Calculates the probability density function for the multivariate
/// student distribution at `x`
/// Calculates the probability density function for the multivariate.
/// student distribution at `x`.
///
/// # Formula
///
/// ```ignore
/// Gamma[(ν+p)/2] / [Gamma(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν transpose(x - μ) inv(Σ) (x - μ)]^(-(ν+p)/2)
/// ```math
/// (ν+p)/2] / [Γ(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν (x - μ) inv(Σ) (x - μ)]^(-(ν+p)/2)
/// ```
///
/// where `ν` is the degrees of freedom, `μ` is the mean, `Gamma`
/// where `ν` is the degrees of freedom, `μ` is the mean, `Γ`
/// is the Gamma function, `inv(Σ)`
/// is the precision matrix, `det(Σ)` is the determinant
/// of the scale matrix, and `k` is the dimension of the distribution
///
/// TODO: Make this converge for large degrees of freedom
/// Current commented code beneath fails since `MultivariateNormal::new` accepts Vec<f64> and
/// not DVector or DMatrix. Should implement that instead of changing back to Vec<f64>, or
/// even have a constructor `MultivariateNormal::from_student`.
/// of the scale matrix, and `k` is the dimension of the distribution.
fn pdf(&self, x: &'a DVector<f64>) -> f64 {
if self.freedom == f64::INFINITY {
let mvn = MultivariateNormal::from_students(self.clone()).unwrap();
Expand Down Expand Up @@ -267,7 +259,6 @@ impl<'a> Continuous<&'a DVector<f64>, f64> for MultivariateStudent {
}
}

// TODO: Add more tests for other matrices than really straightforward symmetric positive
#[rustfmt::skip]
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -364,18 +355,19 @@ mod tests {

#[test]
fn test_bad_create() {
// scale not symmetric
// scale not symmetric.
bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.], 1.);
// scale not positive-definite
// scale not positive-definite.
bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.], 1.);
// NaN in location
// NaN in location.
bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.], 1.);
// NaN in scale Matrix
// NaN in scale Matrix.
bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN], 1.);
// NaN in freedom
// NaN in freedom.
bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], f64::NAN);
// Non-positive freedom
// Non-positive freedom.
bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.);
bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], -1.);
}

#[test]
Expand All @@ -385,6 +377,7 @@ mod tests {
test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 3., mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance);
}

// Variance is only defined for freedom > 2.
#[test]
fn test_bad_variance() {
let variance = |x: MultivariateStudent| x.variance();
Expand All @@ -405,6 +398,7 @@ mod tests {
test_case(vec![-1., 1., 3.], vec![1., 0., 0.5, 0., 2.0, 0., 0.5, 0., 3.0], 2., dvec![-1., 1., 3.], mean);
}

// Mean is only defined if freedom > 1.
#[test]
fn test_bad_mean() {
let mean = |x: MultivariateStudent| x.mean();
Expand All @@ -425,9 +419,14 @@ mod tests {
fn test_pdf() {
let pdf = |arg: DVector<f64>| move |x: MultivariateStudent| x.pdf(&arg);
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, 1e-15, pdf(dvec![1., 1.]));
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, 1e-15, pdf(dvec![1., 2.]));
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, 1e-17, pdf(dvec![1., 2.]));
test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, 1e-19, pdf(dvec![1., 10.]));
test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, 1e-19, pdf(dvec![10., 10.]));
// These three are crossed checked against both python's scipy.multivariate_t.pdf and octave's mvtpdf.
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, 1e-30, pdf(dvec![0.9718, 0.1298, 0.8134]));
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, 1e-30, pdf(dvec![0.4922, 0.5522, 0.7185]));
test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8.,6.952297846610382e-16, 1e-30, pdf(dvec![0.3010, 0.1491, 0.5008]));
test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., 0., pdf(dvec![10., 10.]));
}

Expand All @@ -447,14 +446,16 @@ mod tests {
// let pdf_mvn = |mv: MultivariateNormal, arg: DVector<f64>| mv.pdf(&arg);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-15, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn);
// }

// #[test]
// fn test_ln_pdf_freedom_large() {
// let pdf_mvs = |mv: MultivariateStudent, arg: DVector<f64>| mv.ln_pdf(&arg);
// let pdf_mvn = |mv: MultivariateNormal, arg: DVector<f64>| mv.ln_pdf(&arg);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-50, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
// test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn);
// }
}

0 comments on commit b8261a0

Please sign in to comment.