diff --git a/src/cholesky.rs b/src/cholesky.rs index 7f38e86b..c3ef4904 100644 --- a/src/cholesky.rs +++ b/src/cholesky.rs @@ -48,6 +48,7 @@ use num_traits::Float; use crate::convert::*; use crate::error::*; +use crate::lapack::cholesky_semi::cholesky_semi; use crate::layout::*; use crate::triangular::IntoTriangular; use crate::types::*; @@ -468,3 +469,32 @@ where Ok(self.factorizec_into(UPLO::Upper)?.ln_detc_into()) } } + +pub struct PsdCholeskyFactorization { + pub pivot: Vec, + pub positive_definite_factorization: CholeskyFactorized>, +} + +pub fn factorize_cholesky_semi_definite( + input_mat: &mut ArrayBase, Ix2>, +) -> Result { + let uplo = UPLO::Lower; + let mut rank: i32 = 0; + let pivot: Vec = cholesky_semi( + input_mat.square_layout()?, + uplo, + input_mat.as_allocated_mut()?, + &mut rank, + )? + .iter() + .map(|x| *x as usize) + .collect(); + let cholesky_factors = input_mat.into_triangular(uplo).slice(s![0..rank, 0..rank]); + Ok(PsdCholeskyFactorization { + pivot: pivot, + positive_definite_factorization: CholeskyFactorized::> { + factor: replicate(&cholesky_factors), + uplo: uplo, + }, + }) +} diff --git a/src/lapack/cholesky_semi.rs b/src/lapack/cholesky_semi.rs new file mode 100644 index 00000000..06ca4cf8 --- /dev/null +++ b/src/lapack/cholesky_semi.rs @@ -0,0 +1,25 @@ +use super::*; +use crate::{error::*, layout::*}; + +pub fn cholesky_semi(l: MatrixLayout, uplo: UPLO, a: &mut [f64], rank: &mut i32) -> Result> { + let (n, _) = l.size(); + let mut ipiv = vec![0; n as usize]; + let tol = 1.0e-12_f64; + let info = unsafe { + lapacke::dpstrf( + l.lapacke_layout(), + uplo as u8, + n, + a, + l.lda(), + &mut ipiv, + rank, + tol, + ) + }; + if info < 0 { + Err(LinalgError::Lapack { return_code: info }) + } else { + Ok(ipiv) + } +} diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index 18fd9eda..8f6b78e2 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -1,6 +1,7 @@ //! Define traits wrapping LAPACK routines pub mod cholesky; +pub mod cholesky_semi; pub mod eig; pub mod eigh; pub mod least_squares; @@ -14,6 +15,7 @@ pub mod triangular; pub mod tridiagonal; pub use self::cholesky::*; +pub use self::cholesky_semi::*; pub use self::eig::*; pub use self::eigh::*; pub use self::least_squares::*;