diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index 61ccfa57..909a7d17 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -108,15 +108,51 @@ where pdf_const: f64, } +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +pub enum MultivariateNormalError { + /// The covariance matrix is asymmetric or contains a NaN. + CovInvalid, + + /// The mean vector contains a NaN. + MeanInvalid, + + /// The amount of rows in the vector of means is not equal to the amount + /// of rows in the covariance matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateNormalError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateNormalError::CovInvalid => { + write!(f, "Covariance matrix is asymmetric or contains a NaN") + } + MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"), + MultivariateNormalError::DimensionMismatch => write!( + f, + "Mean vector and covariance matrix do not have the same number of rows" + ), + MultivariateNormalError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateNormalError {} + impl MultivariateNormal { - /// Constructs a new multivariate normal distribution with a mean of `mean` + /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new(mean: Vec, cov: Vec) -> Result { + pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); MultivariateNormal::new_from_nalgebra(mean, cov) @@ -141,24 +177,31 @@ where pub fn new_from_nalgebra( mean: OVector, cov: OMatrix, - ) -> Result { - // Check that the provided covariance matrix is symmetric - if cov.lower_triangle() != cov.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || mean.iter().any(|f| f.is_nan()) + ) -> Result { + if mean.iter().any(|f| f.is_nan()) { + return Err(MultivariateNormalError::MeanInvalid); + } + + if !cov.is_square() + || cov.lower_triangle() != cov.upper_triangle().transpose() || cov.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols() { - return Err(StatsError::BadParams); + return Err(MultivariateNormalError::CovInvalid); + } + + // Compare number of rows + if mean.shape_generic().0 != cov.shape_generic().0 { + return Err(MultivariateNormalError::DimensionMismatch); } + // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { - None => Err(StatsError::BadParams), + None => Err(MultivariateNormalError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal { + // .unwrap() because prerequisites are already checked above pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), cov_chol_decomp: cholesky_decomp.unpack(), mu: mean,