Skip to content

Commit

Permalink
Add MultivariateNormalError
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Sep 6, 2024
1 parent 4a041b6 commit 858928c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/distribution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub use self::log_normal::{LogNormal, LogNormalError};
#[cfg(feature = "nalgebra")]
pub use self::multinomial::Multinomial;
#[cfg(feature = "nalgebra")]
pub use self::multivariate_normal::MultivariateNormal;
pub use self::multivariate_normal::{MultivariateNormal, MultivariateNormalError};
#[cfg(feature = "nalgebra")]
pub use self::multivariate_students_t::MultivariateStudent;
pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError};
Expand Down
65 changes: 54 additions & 11 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,51 @@ where
pdf_const: f64,
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum MultivariateNormalError {
/// The covariance matrix is asymmetric or contains a NaN.
CovInvalid,

/// The mean vector contains a NaN.
MeanInvalid,

/// The amount of rows in the vector of means is not equal to the amount
/// of rows in the covariance matrix.
DimensionMismatch,

/// After all other validation, computing the Cholesky decomposition failed.
CholeskyFailed,
}

impl std::fmt::Display for MultivariateNormalError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
MultivariateNormalError::CovInvalid => {
write!(f, "Covariance matrix is asymmetric or contains a NaN")
}
MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"),
MultivariateNormalError::DimensionMismatch => write!(
f,
"Mean vector and covariance matrix do not have the same number of rows"
),
MultivariateNormalError::CholeskyFailed => {
write!(f, "Computing the Cholesky decomposition failed")
}
}
}
}

impl std::error::Error for MultivariateNormalError {}

impl MultivariateNormal<Dyn> {
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// and covariance matrix `cov`
///
/// # Errors
///
/// Returns an error if the given covariance matrix is not
/// symmetric or positive-definite
pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self, StatsError> {
pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self, MultivariateNormalError> {
let mean = DVector::from_vec(mean);
let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
MultivariateNormal::new_from_nalgebra(mean, cov)
Expand All @@ -141,24 +177,31 @@ where
pub fn new_from_nalgebra(
mean: OVector<f64, D>,
cov: OMatrix<f64, D, D>,
) -> Result<Self, StatsError> {
// Check that the provided covariance matrix is symmetric
if cov.lower_triangle() != cov.upper_triangle().transpose()
// Check that mean and covariance do not contain NaN
|| mean.iter().any(|f| f.is_nan())
) -> Result<Self, MultivariateNormalError> {
if mean.iter().any(|f| f.is_nan()) {
return Err(MultivariateNormalError::MeanInvalid);
}

if !cov.is_square()
|| cov.lower_triangle() != cov.upper_triangle().transpose()
|| cov.iter().any(|f| f.is_nan())
// Check that the dimensions match
|| mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols()
{
return Err(StatsError::BadParams);
return Err(MultivariateNormalError::CovInvalid);
}

// Compare number of rows
if mean.shape_generic().0 != cov.shape_generic().0 {
return Err(MultivariateNormalError::DimensionMismatch);
}

// Store the Cholesky decomposition of the covariance matrix
// for sampling
match Cholesky::new(cov.clone()) {
None => Err(StatsError::BadParams),
None => Err(MultivariateNormalError::CholeskyFailed),
Some(cholesky_decomp) => {
let precision = cholesky_decomp.inverse();
Ok(MultivariateNormal {
// .unwrap() because prerequisites are already checked above
pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(),
cov_chol_decomp: cholesky_decomp.unpack(),
mu: mean,
Expand Down

0 comments on commit 858928c

Please sign in to comment.