From 0a70ccca4880fbcd088b905e29a19a8d181a219f Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Wed, 4 Aug 2021 00:02:53 +1000 Subject: [PATCH 1/4] Define and implement a cov_to_corr() method in the CorrelationExt trait --- src/correlation.rs | 87 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/src/correlation.rs b/src/correlation.rs index 5ae194b..9592f55 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -1,7 +1,9 @@ use crate::errors::EmptyInput; use ndarray::prelude::*; -use ndarray::Data; +use ndarray::{Data, ShapeError}; use num_traits::{Float, FromPrimitive}; +use itertools::Itertools; +use std::error::Error; /// Extension trait for `ArrayBase` providing functions /// to compute different correlation measures. @@ -123,6 +125,22 @@ where A: Float + FromPrimitive; private_decl! {} + + /// Given that self is a covariance matrix, returns the appropriate correlation matrix + /// + /// # Example + /// ``` + /// use::ndarray::array; + /// use ndarray_stats::CorrelationExt; + /// + /// let a = array![[1., 3., 5.], + /// [2., 4., 6.]]; + /// let covariance = a.cov(1.).unwrap(); + /// let corr = a.pearson_correlation().unwrap(); + /// assert_eq!(covariance.cov_to_corr().unwrap(), corr); + /// ``` + fn cov_to_corr(&self) -> Result, Box> + where A: Float + FromPrimitive; } impl CorrelationExt for ArrayBase @@ -179,6 +197,27 @@ where } private_impl! {} + + fn cov_to_corr<'a>(&self) -> Result, Box> + where A: Float + FromPrimitive + { + if !self.is_square() { + return Err("A covariance matrix must be square".into()); + } + + let vals = self + .indexed_iter() + .map(|((x, y), v)| { + // rho_ij = sigma_ij / sqrt(sigma_ii * sigma_jj) + *v / (self[[x, x]] * self[[y, y]]).powf(A::from_f64(0.5).unwrap()) + }) + .collect_vec(); + + match Array2::from_shape_vec(self.raw_dim(), vals){ + Ok(x) => Ok(x), + Err(e) => Err(Box::new(e)) + } + } } #[cfg(test)] @@ -274,6 +313,7 @@ mod cov_tests { assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8); } + #[test] #[should_panic] // We lose precision, hence the failing assert @@ -367,3 +407,48 @@ mod pearson_correlation_tests { ); } } + +#[cfg(test)] +mod cov_to_corr_tests{ + use ndarray::array; + use super::*; + + #[test] + fn test_cov_2_corr_known(){ + // Very basic maths that can be done in your head + let cov = array![ + [ 4., 1. ], + [ 3., 4. ], + ]; + assert_eq!(cov.cov_to_corr().unwrap(), array![ + [1., 0.25], + [0.75, 1.] + ]) + } + + #[test] + fn test_cov_2_corr_failure(){ + // A 1D array can't be a covariance matrix + let cov = array![ + [ 4., 1. ], + ]; + cov.cov_to_corr().unwrap_err(); + } + + #[test] + fn test_cov_2_corr_random(){ + let a = array![ + [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457], + [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245], + [0.59044195, 0.10720363, 0.76573717, 0.54693675, 0.95923036], + [0.24102952, 0.131347, 0.11118028, 0.21451351, 0.30515539], + [0.26952473, 0.93079841, 0.8080893, 0.42814155, 0.24642258] + ]; + + // Calculating cov, and then corr should always be equivalent to calculating corr directly + let cov = a.cov(1.).unwrap(); + let corr = a.pearson_correlation().unwrap(); + let corr_2 = cov.cov_to_corr().unwrap(); + assert!(corr.abs_diff_eq(&corr_2, 0.001)); + } +} \ No newline at end of file From 3a3778e107ed588489605f3e84583b8d660ae9a6 Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Wed, 4 Aug 2021 00:04:45 +1000 Subject: [PATCH 2/4] Fix cargo check --- src/correlation.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/correlation.rs b/src/correlation.rs index 9592f55..b330dd2 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -1,6 +1,6 @@ use crate::errors::EmptyInput; use ndarray::prelude::*; -use ndarray::{Data, ShapeError}; +use ndarray::{Data}; use num_traits::{Float, FromPrimitive}; use itertools::Itertools; use std::error::Error; From b3f11ac9fd19e4b11b01f5b269c419bd96860eba Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Wed, 4 Aug 2021 00:05:40 +1000 Subject: [PATCH 3/4] rustfmt --- src/correlation.rs | 41 +++++++++++++++++------------------------ 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/src/correlation.rs b/src/correlation.rs index b330dd2..5ac1ae8 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -1,8 +1,8 @@ use crate::errors::EmptyInput; +use itertools::Itertools; use ndarray::prelude::*; -use ndarray::{Data}; +use ndarray::Data; use num_traits::{Float, FromPrimitive}; -use itertools::Itertools; use std::error::Error; /// Extension trait for `ArrayBase` providing functions @@ -140,7 +140,8 @@ where /// assert_eq!(covariance.cov_to_corr().unwrap(), corr); /// ``` fn cov_to_corr(&self) -> Result, Box> - where A: Float + FromPrimitive; + where + A: Float + FromPrimitive; } impl CorrelationExt for ArrayBase @@ -199,7 +200,8 @@ where private_impl! {} fn cov_to_corr<'a>(&self) -> Result, Box> - where A: Float + FromPrimitive + where + A: Float + FromPrimitive, { if !self.is_square() { return Err("A covariance matrix must be square".into()); @@ -213,9 +215,9 @@ where }) .collect_vec(); - match Array2::from_shape_vec(self.raw_dim(), vals){ + match Array2::from_shape_vec(self.raw_dim(), vals) { Ok(x) => Ok(x), - Err(e) => Err(Box::new(e)) + Err(e) => Err(Box::new(e)), } } } @@ -313,7 +315,6 @@ mod cov_tests { assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8); } - #[test] #[should_panic] // We lose precision, hence the failing assert @@ -409,34 +410,26 @@ mod pearson_correlation_tests { } #[cfg(test)] -mod cov_to_corr_tests{ - use ndarray::array; +mod cov_to_corr_tests { use super::*; + use ndarray::array; #[test] - fn test_cov_2_corr_known(){ + fn test_cov_2_corr_known() { // Very basic maths that can be done in your head - let cov = array![ - [ 4., 1. ], - [ 3., 4. ], - ]; - assert_eq!(cov.cov_to_corr().unwrap(), array![ - [1., 0.25], - [0.75, 1.] - ]) + let cov = array![[4., 1.], [3., 4.],]; + assert_eq!(cov.cov_to_corr().unwrap(), array![[1., 0.25], [0.75, 1.]]) } #[test] - fn test_cov_2_corr_failure(){ + fn test_cov_2_corr_failure() { // A 1D array can't be a covariance matrix - let cov = array![ - [ 4., 1. ], - ]; + let cov = array![[4., 1.],]; cov.cov_to_corr().unwrap_err(); } #[test] - fn test_cov_2_corr_random(){ + fn test_cov_2_corr_random() { let a = array![ [0.72009497, 0.12568055, 0.55705966, 0.5959984, 0.69471457], [0.56717131, 0.47619486, 0.21526298, 0.88915366, 0.91971245], @@ -451,4 +444,4 @@ mod cov_to_corr_tests{ let corr_2 = cov.cov_to_corr().unwrap(); assert!(corr.abs_diff_eq(&corr_2, 0.001)); } -} \ No newline at end of file +} From 43cf9806427fecbd03c11ec7b9c5200631468ad7 Mon Sep 17 00:00:00 2001 From: Michael Milton Date: Wed, 4 Aug 2021 00:07:35 +1000 Subject: [PATCH 4/4] Don't use itertools --- src/correlation.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/correlation.rs b/src/correlation.rs index 5ac1ae8..9ab8e5e 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -1,5 +1,4 @@ use crate::errors::EmptyInput; -use itertools::Itertools; use ndarray::prelude::*; use ndarray::Data; use num_traits::{Float, FromPrimitive}; @@ -213,7 +212,7 @@ where // rho_ij = sigma_ij / sqrt(sigma_ii * sigma_jj) *v / (self[[x, x]] * self[[y, y]]).powf(A::from_f64(0.5).unwrap()) }) - .collect_vec(); + .collect(); match Array2::from_shape_vec(self.raw_dim(), vals) { Ok(x) => Ok(x),