diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 3b14ea221..03823136f 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -411,6 +411,25 @@ pub fn offset_from_ptr_to_memory(dim: &D, strides: &D) -> isize { offset } +/// Return the index of the dimension from a specific distance from first index. +/// +/// Panics if dis is greater than or equal to the count of dim elements. +pub(crate) fn index_from_distance(dim: &D, dis: usize) -> D { + let mut dis = dis; + let mut index = D::zeros(dim.ndim()); + let len = dim.ndim(); + for i in 0..len { + let m = len - i -1; + index[m] = dis % dim[m]; + dis = dis / dim[m]; + if dis == 0 { + break; + } + } + assert!(dis == 0); + index +} + /// Modify dimension, stride and return data pointer offset /// /// **Panics** if stride is 0 or if any index is out of bounds. diff --git a/src/iterators/lanes.rs b/src/iterators/lanes.rs index 2163c58a6..fd8d03056 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -4,6 +4,7 @@ use super::LanesIter; use super::LanesIterMut; use crate::imp_prelude::*; use crate::{Layout, NdProducer}; +use crate::iterators::LanesIterCore; impl_ndproducer! { ['a, A, D: Dimension] @@ -84,7 +85,9 @@ where type IntoIter = LanesIter<'a, A, D>; fn into_iter(self) -> Self::IntoIter { LanesIter { - iter: self.base.into_base_iter(), + iter: unsafe { + LanesIterCore::new(self.base.ptr.as_ptr(), self.base.dim, self.base.strides) + }, inner_len: self.inner_len, inner_stride: self.inner_stride, life: PhantomData, @@ -135,7 +138,9 @@ where type IntoIter = LanesIterMut<'a, A, D>; fn into_iter(self) -> Self::IntoIter { LanesIterMut { - iter: self.base.into_base_iter(), + iter: unsafe { + LanesIterCore::new(self.base.ptr.as_ptr(), self.base.dim, self.base.strides) + }, inner_len: self.inner_len, inner_stride: self.inner_stride, life: PhantomData, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 595f0897d..35df03c20 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -676,6 +676,184 @@ where } } +pub struct LanesIterCore { + ptr: *mut A, + dim: D, + strides: D, + dis: usize, + end: usize, +} + +clone_bounds!( + [A, D: Clone] + LanesIterCore[A, D] { + @copy { + ptr, + dis, + end, + } + dim, + strides, + } +); + +impl LanesIterCore { + #[inline] + pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> LanesIterCore { + LanesIterCore { + ptr, + dis: 0, + end: len.size(), + dim: len, + strides: stride, + } + } +} + +impl Iterator for LanesIterCore { + type Item = *mut A; + + #[inline] + fn next(&mut self) -> Option<*mut A> { + if self.dis < self.end { + let index = index_from_distance(&self.dim, self.dis); + let offset = D::stride_offset(&index, &self.strides); + self.dis += 1; + unsafe { Some(self.ptr.offset(offset)) } + } else { + None + } + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.len(); + (len, Some(len)) + } + + fn fold(mut self, init: Acc, mut g: G) -> Acc + where + G: FnMut(Acc, *mut A) -> Acc, + { + let ndim = self.dim.ndim(); + debug_assert_ne!(ndim, 0); + let mut accum = init; + while self.dis < self.end { + let index = index_from_distance(&self.dim, self.dis); + let stride = self.strides.last_elem() as isize; + let elem_index = index.last_elem(); + let mut len = self.dim.last_elem(); + if self.len() < len { + len = self.len(); + } + let offset = D::stride_offset(&index, &self.strides); + unsafe { + let row_ptr = self.ptr.offset(offset); + let mut i = 0; + let i_end = len - elem_index; + while i < i_end { + accum = g(accum, row_ptr.offset(i as isize * stride)); + i += 1; + } + self.dis += i_end; + } + } + accum + } +} + +impl<'a, A, D: Dimension> ExactSizeIterator for LanesIterCore { + fn len(&self) -> usize { + if let Some(len) = self.end.checked_sub(self.dis) { + len + } else { + 0 + } + } +} + +impl DoubleEndedIterator for LanesIterCore { + #[inline] + fn next_back(&mut self) -> Option<*mut A> { + if self.end <= self.dis { + return None + } + let index = index_from_distance(&self.dim, self.end - 1); + self.end -= 1; + let offset = D::stride_offset(&index, &self.strides); + + unsafe { Some(self.ptr.offset(offset)) } + } + + fn nth_back(&mut self, n: usize) -> Option<*mut A> { + let len = self.len(); + if n < len { + self.end -= n + 1; + let index = index_from_distance(&self.dim, self.end); + let offset = D::stride_offset(&index, &self.strides); + unsafe { Some(self.ptr.offset(offset)) } + } else { + None + } + } + + fn rfold(mut self, init: Acc, mut g: G) -> Acc + where + G: FnMut(Acc, *mut A) -> Acc, + { + let mut accum = init; + while self.dis < self.end { + let index = index_from_distance(&self.dim, self.end - 1); + let stride = self.strides.last_elem() as isize; + let elem_index = index.last_elem(); + let mut len = self.dim.last_elem(); + if self.len() < len { + len = self.len(); + } + let offset = D::stride_offset(&index, &self.strides); + unsafe { + let row_ptr = self.ptr.offset(offset); + let mut i = 0; + let i_end = len - elem_index; + while i < i_end { + accum = g(accum, row_ptr.offset(i as isize * -stride)); + i += 1; + } + self.end -= i_end; + } + } + accum + } +} + +impl LanesIterCore { + /// Splits the iterator at `index`, yielding two disjoint iterators. + /// + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. + fn split_at(self, index: usize) -> (Self, Self) { + assert!(index <= self.len()); + let mid = self.dis + index; + let left = LanesIterCore { + dim: self.dim.clone(), + strides: self.strides.clone(), + dis: self.dis, + end: mid, + ptr: self.ptr, + }; + let right = LanesIterCore{ + dim: self.dim, + strides: self.strides, + dis: mid, + end: self.end, + ptr: self.ptr, + }; + (left, right) + } +} + /// An iterator that traverses over all axes but one, and yields a view for /// each lane along that axis. /// @@ -683,7 +861,7 @@ where pub struct LanesIter<'a, A, D> { inner_len: Ix, inner_stride: Ixs, - iter: Baseiter, + iter: LanesIterCore, life: PhantomData<&'a A>, } @@ -715,6 +893,17 @@ where } } +impl<'a, A, D> DoubleEndedIterator for LanesIter<'a, A, D> + where + D: Dimension, +{ + fn next_back(&mut self) -> Option { + self.iter.next_back().map(|ptr| unsafe { + ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) + }) + } +} + impl<'a, A, D> ExactSizeIterator for LanesIter<'a, A, D> where D: Dimension, @@ -724,6 +913,33 @@ where } } +impl<'a, A, D: Dimension> LanesIter<'a, A, D> { + /// Splits the iterator at `index`, yielding two disjoint iterators. + /// + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. + pub fn split_at(self, index: usize) -> (Self, Self) { + let (left, right) = self.iter.split_at(index); + ( + LanesIter { + inner_len: self.inner_len, + inner_stride: self.inner_stride, + iter: left, + life: self.life, + }, + LanesIter { + inner_len: self.inner_len, + inner_stride: self.inner_stride, + iter: right, + life: self.life, + }, + ) + } +} + // NOTE: LanesIterMut is a mutable iterator and must not expose aliasing // pointers. Due to this we use an empty slice for the raw data (it's unused // anyway). @@ -735,7 +951,7 @@ where pub struct LanesIterMut<'a, A, D> { inner_len: Ix, inner_stride: Ixs, - iter: Baseiter, + iter: LanesIterCore, life: PhantomData<&'a mut A>, } @@ -755,6 +971,17 @@ where } } +impl<'a, A, D> DoubleEndedIterator for LanesIterMut<'a, A, D> + where + D: Dimension, +{ + fn next_back(&mut self) -> Option { + self.iter.next_back().map(|ptr| unsafe { + ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) + }) + } +} + impl<'a, A, D> ExactSizeIterator for LanesIterMut<'a, A, D> where D: Dimension, @@ -764,6 +991,33 @@ where } } +impl<'a, A, D: Dimension> LanesIterMut<'a, A, D> { + /// Splits the iterator at `index`, yielding two disjoint iterators. + /// + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. + pub fn split_at(self, index: usize) -> (Self, Self) { + let (left, right) = self.iter.split_at(index); + ( + LanesIterMut { + inner_len: self.inner_len, + inner_stride: self.inner_stride, + iter: left, + life: self.life, + }, + LanesIterMut { + inner_len: self.inner_len, + inner_stride: self.inner_stride, + iter: right, + life: self.life, + }, + ) + } +} + #[derive(Debug)] pub struct AxisIterCore { /// Index along the axis of the value of `.next()`, relative to the start @@ -1449,6 +1703,8 @@ use crate::indexes::IndicesIterF; use crate::iter::IndicesIter; #[cfg(feature = "std")] use crate::{geomspace::Geomspace, linspace::Linspace, logspace::Logspace}; +use crate::dimension::index_from_distance; + #[cfg(feature = "std")] unsafe impl TrustedIterator for Linspace {} #[cfg(feature = "std")] diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 878444c26..66328daae 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -7,7 +7,7 @@ // except according to those terms. use crate::imp_prelude::*; -use crate::numeric_util; +use crate::{numeric_util, IntoDimension}; use crate::{LinalgScalar, Zip}; @@ -36,6 +36,11 @@ const GEMM_BLAS_CUTOFF: usize = 7; #[allow(non_camel_case_types)] type blas_index = c_int; // blas index type +#[cfg(feature = "rayon")] +use crate::rayon::iter::IntoParallelIterator; +#[cfg(feature = "rayon")] +use crate::parallel::prelude::*; + impl ArrayBase where S: Data, @@ -155,15 +160,24 @@ unsafe fn blas_1d_params( } } + /// Matrix Multiplication -/// -/// For two-dimensional arrays, the dot method computes the matrix -/// multiplication. pub trait Dot { /// The result of the operation. /// - /// For two-dimensional arrays: a rectangular array. + /// If the dimensions of the two arrays are m and n respectively, + /// the result is an Array with dimensions m+n-2. type Output; + /// Dot product of two arrays. Specifically, + /// + /// If both a and b are 1-D arrays, it is inner product of vectors (without complex conjugation). + /// If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred. + /// If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred. + /// If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b. + /// If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b: + /// dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]) + /// + /// **Panics** if shapes are incompatible. fn dot(&self, rhs: &Rhs) -> Self::Output; } @@ -356,10 +370,158 @@ where } } +macro_rules! multi_dot{ + ($lhs:ty, $rhs:ty, $out:ty) => { + impl Dot> for ArrayBase + where + S: Data, + S2: Data, + A: LinalgScalar, + { + type Output = Array; + + fn dot(&self, rhs: &ArrayBase) -> Array { + let (dim1, dim2) = (self.raw_dim(), rhs.raw_dim()); + let l1 = dim1.ndim(); + let l2 = dim2.ndim(); + let match_dim = if l2 > 1 { + l2 - 2 + } else { + 0 + }; + if dim1[l1-1] != dim2[match_dim] { + panic!( + "ndarray: shapes {:?} and {:?} not aligned for matrix multiplication", + dim1.slice(), dim2.slice() + ); + } + let mut new_dim =Vec::new(); + new_dim.extend_from_slice(&dim1.slice()[0..l1-1]); + new_dim.extend_from_slice(&dim2.slice()[0..match_dim]); + if l2 > 1 { + new_dim.push(dim2[l2-1]); + } + let new_dim = <$out>::from_dimension(&new_dim.into_dimension()).unwrap(); + if let Err(_) = size_of_shape_checked(&new_dim){ + panic!("ndarray: shape {:?} overflows isize", new_dim) + } + let lhs_s0 = self.strides()[0]; + let rhs_s0 = rhs.strides()[0]; + let column_major = lhs_s0 == 1 && rhs_s0 == 1; + let v = self.lanes(Axis(l1-1)) + .into_iter() + .map(|l1| { rhs.lanes(Axis(match_dim)) + .into_iter() + .map(|l2| l2.dot_impl(&l1)) + .collect::>() + }) + .flatten() + .collect::>(); + Array::from_shape_vec(new_dim.into_shape().set_f(column_major), v).unwrap() + } + } + }; +} + +macro_rules! dot_declare { + ($method:ident, $lhs:ty, $([$rhs:ty, $out:ty])*) => { + $($method!($lhs, $rhs, $out);)* + }; +} +dot_declare!(multi_dot, Ix1, [Ix3, Ix2] [Ix4, Ix3] [Ix5, Ix4] [Ix6, Ix5] [IxDyn, IxDyn]); +dot_declare!(multi_dot, Ix2, [Ix3, Ix3] [Ix4, Ix4] [Ix5, Ix5] [Ix6, Ix6] [IxDyn, IxDyn]); +dot_declare!(multi_dot, Ix3, [Ix1, Ix2] [Ix2, Ix3] [Ix3, Ix4] [Ix4, Ix5] [Ix5, Ix6] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_dot, Ix4, [Ix1, Ix3] [Ix2, Ix4] [Ix3, Ix5] [Ix4, Ix6] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_dot, Ix5, [Ix1, Ix4] [Ix2, Ix5] [Ix3, Ix6] [Ix4, IxDyn] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_dot, Ix6, [Ix1, Ix5] [Ix2, Ix6] [Ix3, IxDyn] [Ix4, IxDyn] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_dot, IxDyn, [Ix1, IxDyn] [Ix2, IxDyn] [Ix3, IxDyn] [Ix4, IxDyn] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); + +/// Parallel Matrix Multiplication +/// +/// If both a and b are 1-D arrays, it is inner product of vectors (without complex conjugation). +/// If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred. +/// If either a or b is 0-D (scalar), it is equivalent to multiply and using numpy.multiply(a, b) or a * b is preferred. +/// If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b. +/// If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b: +#[cfg(feature = "rayon")] +pub trait ParDot { + /// The result of the operation. + /// + /// If the dimensions of the two arrays are m and n respectively, + /// the result is an Array with dimensions m+n-2. + type Output; + fn par_dot(&self, rhs: &Rhs) -> Self::Output; +} + +macro_rules! multi_par_dot{ + ($lhs:ty, $rhs:ty, $out:ty) => { + #[cfg(feature = "rayon")] + impl ParDot> for ArrayBase + where + S: Data + Sync, + S2: Data + Sync, + A: LinalgScalar + Sync + Send, + { + type Output = Array; + + fn par_dot(&self, rhs: &ArrayBase) -> Array { + let (dim1, dim2) = (self.raw_dim(), rhs.raw_dim()); + let l1 = dim1.ndim(); + let l2 = dim2.ndim(); + let match_dim = if l2 > 1 { + l2 - 2 + } else { + 0 + }; + if dim1[l1-1] != dim2[match_dim] { + panic!( + "ndarray: shapes {:?} and {:?} not aligned for matrix multiplication", + dim1.slice(), dim2.slice() + ); + } + let mut new_dim =Vec::new(); + new_dim.extend_from_slice(&dim1.slice()[0..l1-1]); + new_dim.extend_from_slice(&dim2.slice()[0..match_dim]); + if l2 > 1 { + new_dim.push(dim2[l2-1]); + } + let new_dim = <$out>::from_dimension(&new_dim.into_dimension()).unwrap(); + if let Err(_) = size_of_shape_checked(&new_dim){ + panic!("ndarray: shape {:?} overflows isize", new_dim) + } + let lhs_s0 = self.strides()[0]; + let rhs_s0 = rhs.strides()[0]; + let column_major = lhs_s0 == 1 && rhs_s0 == 1; + let v = self.lanes(Axis(l1-1)) + .into_iter() + .into_par_iter() + .map(|l1| { rhs.lanes(Axis(match_dim)) + .into_iter() + .into_par_iter() + .map(|l2| l2.dot_impl(&l1)) + .collect::>() + }) + .flatten() + .collect::>(); + Array::from_shape_vec(new_dim.into_shape().set_f(column_major), v).unwrap() + } + } + }; +} + +dot_declare!(multi_par_dot, Ix1, [Ix3, Ix2] [Ix4, Ix3] [Ix5, Ix4] [Ix6, Ix5] [IxDyn, IxDyn]); +dot_declare!(multi_par_dot, Ix2, [Ix3, Ix3] [Ix4, Ix4] [Ix5, Ix5] [Ix6, Ix6] [IxDyn, IxDyn]); +dot_declare!(multi_par_dot, Ix3, [Ix1, Ix2] [Ix2, Ix3] [Ix3, Ix4] [Ix4, Ix5] [Ix5, Ix6] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_par_dot, Ix4, [Ix1, Ix3] [Ix2, Ix4] [Ix3, Ix5] [Ix4, Ix6] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_par_dot, Ix5, [Ix1, Ix4] [Ix2, Ix5] [Ix3, Ix6] [Ix4, IxDyn] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_par_dot, Ix6, [Ix1, Ix5] [Ix2, Ix6] [Ix3, IxDyn] [Ix4, IxDyn] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); +dot_declare!(multi_par_dot, IxDyn, [Ix1, IxDyn] [Ix2, IxDyn] [Ix3, IxDyn] [Ix4, IxDyn] [Ix5, IxDyn] [Ix6, IxDyn] [IxDyn, IxDyn]); + // mat_mul_impl uses ArrayView arguments to send all array kinds into // the same instantiated implementation. #[cfg(not(feature = "blas"))] use self::mat_mul_general as mat_mul_impl; +use crate::dimension::size_of_shape_checked; #[cfg(feature = "blas")] fn mat_mul_impl( diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs index 8575905cd..2b2d9d2c8 100644 --- a/src/linalg/mod.rs +++ b/src/linalg/mod.rs @@ -11,5 +11,7 @@ pub use self::impl_linalg::general_mat_mul; pub use self::impl_linalg::general_mat_vec_mul; pub use self::impl_linalg::Dot; +#[cfg(feature = "rayon")] +pub use self::impl_linalg::ParDot; mod impl_linalg; diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs index 12059cee4..2f23575c1 100644 --- a/src/parallel/mod.rs +++ b/src/parallel/mod.rs @@ -134,6 +134,8 @@ use crate::{ use crate::iter::{ AxisIter, AxisIterMut, + LanesIter, + LanesIterMut, AxisChunksIter, AxisChunksIterMut, }; diff --git a/src/parallel/par.rs b/src/parallel/par.rs index d9d592af6..17e2acb1f 100644 --- a/src/parallel/par.rs +++ b/src/parallel/par.rs @@ -13,6 +13,8 @@ use crate::iter::AxisChunksIter; use crate::iter::AxisChunksIterMut; use crate::iter::AxisIter; use crate::iter::AxisIterMut; +use crate::iter::LanesIter; +use crate::iter::LanesIterMut; use crate::Dimension; use crate::{ArrayView, ArrayViewMut}; use crate::split_at::SplitPreference; @@ -115,6 +117,8 @@ macro_rules! par_iter_wrapper { par_iter_wrapper!(AxisIter, [Sync]); par_iter_wrapper!(AxisIterMut, [Send + Sync]); +par_iter_wrapper!(LanesIter, [Sync]); +par_iter_wrapper!(LanesIterMut, [Send + Sync]); par_iter_wrapper!(AxisChunksIter, [Sync]); par_iter_wrapper!(AxisChunksIterMut, [Send + Sync]); diff --git a/tests/array.rs b/tests/array.rs index 6581e572e..b5841b904 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -14,6 +14,7 @@ use ndarray::prelude::*; use ndarray::{arr3, rcarr2}; use ndarray::{Slice, SliceInfo, SliceOrIndex}; use std::iter::FromIterator; +use ndarray::linalg::Dot; macro_rules! assert_panics { ($body:expr) => { @@ -54,6 +55,35 @@ fn test_matmul_arcarray() { } } +#[test] +#[cfg(feature = "std")] +fn test_muti_dot() { + let a = Array::range(0., 3., 1.); + let b = Array::range(0., 12., 1.).into_shape((2, 3, 2)).unwrap(); + let c =a.dot(&b); + assert_eq!(c, arr2(&[[10.0, 13.0], [28.0, 31.0]])); + + let a = Array::range(0., 6., 1.).into_shape((2, 3)).unwrap(); + let b = Array::range(0., 24., 1.).into_shape((2, 2, 3, 2)).unwrap(); + let c =a.dot(&b); + let v = vec![10.0, 13.0, 28.0, 31.0, 46.0, 49.0, 64.0, 67.0, 28.0, 40.0, 100.0, 112.0, 172.0, 184.0, 244.0, 256.0]; + assert_eq!(c, Array4::from_shape_vec((2, 2, 2, 2), v).unwrap()); + + let a = Array::range(0., 6., 1.).into_shape((2, 3)).unwrap(); + let v =vec![1, 2, 1, 2, 1, 3, 2]; + let b = Array::range(0., 24., 1.).into_shape(v.clone()).unwrap(); + let c =a.dot(&b); + let v2 = vec![10.0, 13.0, 28.0, 31.0, 46.0, 49.0, 64.0, 67.0, 28.0, 40.0, 100.0, 112.0, 172.0, 184.0, 244.0, 256.0]; + assert_eq!(c, Array::from_shape_vec(vec![2, 1, 2, 1, 2, 1, 2], v2).unwrap()); + + let a = Array::range(0., 6., 1.).into_shape(vec![2, 3]).unwrap(); + let v =vec![1, 2, 1, 2, 1, 3, 2]; + let b = Array::range(0., 24., 1.).into_shape(v.clone()).unwrap(); + let c =a.dot(&b); + let v2 = vec![10.0, 13.0, 28.0, 31.0, 46.0, 49.0, 64.0, 67.0, 28.0, 40.0, 100.0, 112.0, 172.0, 184.0, 244.0, 256.0]; + assert_eq!(c, Array::from_shape_vec(vec![2, 1, 2, 1, 2, 1, 2], v2).unwrap()); +} + #[allow(unused)] fn arrayview_shrink_lifetime<'a, 'b: 'a>(view: ArrayView1<'b, f64>) -> ArrayView1<'a, f64> { view.reborrow() diff --git a/tests/par_dot.rs b/tests/par_dot.rs new file mode 100644 index 000000000..fe39c16cf --- /dev/null +++ b/tests/par_dot.rs @@ -0,0 +1,31 @@ +#![cfg(feature = "rayon")] +use ndarray::{Array, arr2, Array4}; +use ndarray::linalg::ParDot; + +#[test] +fn test_muti_par_dot() { + let a = Array::range(0., 3., 1.); + let b = Array::range(0., 12., 1.).into_shape((2, 3, 2)).unwrap(); + let c =a.par_dot(&b); + assert_eq!(c, arr2(&[[10.0, 13.0], [28.0, 31.0]])); + + let a = Array::range(0., 6., 1.).into_shape((2, 3)).unwrap(); + let b = Array::range(0., 24., 1.).into_shape((2, 2, 3, 2)).unwrap(); + let c =a.par_dot(&b); + let v = vec![10.0, 13.0, 28.0, 31.0, 46.0, 49.0, 64.0, 67.0, 28.0, 40.0, 100.0, 112.0, 172.0, 184.0, 244.0, 256.0]; + assert_eq!(c, Array4::from_shape_vec((2, 2, 2, 2), v).unwrap()); + + let a = Array::range(0., 6., 1.).into_shape((2, 3)).unwrap(); + let v =vec![1, 2, 1, 2, 1, 3, 2]; + let b = Array::range(0., 24., 1.).into_shape(v.clone()).unwrap(); + let c =a.par_dot(&b); + let v2 = vec![10.0, 13.0, 28.0, 31.0, 46.0, 49.0, 64.0, 67.0, 28.0, 40.0, 100.0, 112.0, 172.0, 184.0, 244.0, 256.0]; + assert_eq!(c, Array::from_shape_vec(vec![2, 1, 2, 1, 2, 1, 2], v2).unwrap()); + + let a = Array::range(0., 6., 1.).into_shape(vec![2, 3]).unwrap(); + let v =vec![1, 2, 1, 2, 1, 3, 2]; + let b = Array::range(0., 24., 1.).into_shape(v.clone()).unwrap(); + let c =a.par_dot(&b); + let v2 = vec![10.0, 13.0, 28.0, 31.0, 46.0, 49.0, 64.0, 67.0, 28.0, 40.0, 100.0, 112.0, 172.0, 184.0, 244.0, 256.0]; + assert_eq!(c, Array::from_shape_vec(vec![2, 1, 2, 1, 2, 1, 2], v2).unwrap()); +} \ No newline at end of file