Skip to content

Commit

Permalink
Add MultivariateNormalError
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Sep 4, 2024
1 parent 896c5ad commit 04bd36d
Showing 1 changed file with 54 additions and 11 deletions.
65 changes: 54 additions & 11 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,51 @@ where
pdf_const: f64,
}

#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum MultivariateNormalError {
/// The covariance matrix is asymmetric or contains a NaN.
CovInvalid,

/// The mean vector contains a NaN.
MeanInvalid,

/// The amount of rows in the vector of means is not equal to the amount
/// of rows in the covariance matrix.
DimensionMismatch,

/// After all other validation, computing the Cholesky decomposition failed.
CholeskyFailed,
}

impl std::fmt::Display for MultivariateNormalError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {

Check warning on line 129 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L128-L129

Added lines #L128 - L129 were not covered by tests
MultivariateNormalError::CovInvalid => {
write!(f, "Covariance matrix is asymmetric or contains a NaN")

Check warning on line 131 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L131

Added line #L131 was not covered by tests
}
MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"),
MultivariateNormalError::DimensionMismatch => write!(
f,
"Mean vector and covariance matrix do not have the same number of rows"
),

Check warning on line 137 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L133-L137

Added lines #L133 - L137 were not covered by tests
MultivariateNormalError::CholeskyFailed => {
write!(f, "Computing the Cholesky decomposition failed")

Check warning on line 139 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L139

Added line #L139 was not covered by tests
}
}
}

Check warning on line 142 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L142

Added line #L142 was not covered by tests
}

impl std::error::Error for MultivariateNormalError {}

impl MultivariateNormal<Dyn> {
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// Constructs a new multivariate normal distribution with a mean of `mean`
/// and covariance matrix `cov`
///
/// # Errors
///
/// Returns an error if the given covariance matrix is not
/// symmetric or positive-definite
pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self, StatsError> {
pub fn new(mean: Vec<f64>, cov: Vec<f64>) -> Result<Self, MultivariateNormalError> {
let mean = DVector::from_vec(mean);
let cov = DMatrix::from_vec(mean.len(), mean.len(), cov);
MultivariateNormal::new_from_nalgebra(mean, cov)
Expand All @@ -141,24 +177,31 @@ where
pub fn new_from_nalgebra(
mean: OVector<f64, D>,
cov: OMatrix<f64, D, D>,
) -> Result<Self, StatsError> {
// Check that the provided covariance matrix is symmetric
if cov.lower_triangle() != cov.upper_triangle().transpose()
// Check that mean and covariance do not contain NaN
|| mean.iter().any(|f| f.is_nan())
) -> Result<Self, MultivariateNormalError> {
if mean.iter().any(|f| f.is_nan()) {
return Err(MultivariateNormalError::MeanInvalid);
}

if !cov.is_square()
|| cov.lower_triangle() != cov.upper_triangle().transpose()
|| cov.iter().any(|f| f.is_nan())
// Check that the dimensions match
|| mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols()
{
return Err(StatsError::BadParams);
return Err(MultivariateNormalError::CovInvalid);
}

// Compare number of rows
if mean.shape_generic().0 != cov.shape_generic().0 {
return Err(MultivariateNormalError::DimensionMismatch);

Check warning on line 194 in src/distribution/multivariate_normal.rs

View check run for this annotation

Codecov / codecov/patch

src/distribution/multivariate_normal.rs#L194

Added line #L194 was not covered by tests
}

// Store the Cholesky decomposition of the covariance matrix
// for sampling
match Cholesky::new(cov.clone()) {
None => Err(StatsError::BadParams),
None => Err(MultivariateNormalError::CholeskyFailed),
Some(cholesky_decomp) => {
let precision = cholesky_decomp.inverse();
Ok(MultivariateNormal {
// .unwrap() because prerequisites are already checked above
pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(),
cov_chol_decomp: cholesky_decomp.unpack(),
mu: mean,
Expand Down

0 comments on commit 04bd36d

Please sign in to comment.