diff --git a/crates/blas-tests/tests/dyn.rs b/crates/blas-tests/tests/dyn.rs new file mode 100644 index 000000000..6c0fd975e --- /dev/null +++ b/crates/blas-tests/tests/dyn.rs @@ -0,0 +1,80 @@ +extern crate blas_src; +use ndarray::{linalg::Dot, Array1, Array2, ArrayD, Ix1, Ix2}; + +#[test] +fn test_arrayd_dot_2d() +{ + let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); + let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + + let result = mat1.dot(&mat2); + + // Verify the result is correct + assert_eq!(result.ndim(), 2); + assert_eq!(result.shape(), &[3, 3]); + + // Compare with Array2 implementation + let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); + let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); + let expected = mat1_2d.dot(&mat2_2d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); +} + +#[test] +fn test_arrayd_dot_1d() +{ + // Test 1D array dot product + let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); + let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); + + let result = vec1.dot(&vec2); + + // Verify scalar result + assert_eq!(result.ndim(), 0); + assert_eq!(result.shape(), &[]); + assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6 +} + +#[test] +#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] +fn test_arrayd_dot_3d() +{ + // Test that 3D arrays are not supported + let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic +} + +#[test] +#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] +fn test_arrayd_dot_incompatible_dims() +{ + // Test arrays with incompatible dimensions + let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic +} + +#[test] +fn test_arrayd_dot_matrix_vector() +{ + // Test matrix-vector multiplication + let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); + + let result = mat.dot(&vec); + + // Verify result + assert_eq!(result.ndim(), 1); + assert_eq!(result.shape(), &[3]); + + // Compare with Array2 implementation + let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec_1d = Array1::from_vec(vec![1.0, 2.0]); + let expected = mat_2d.dot(&vec_1d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); +} diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 7472d8292..e05740378 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -14,8 +14,11 @@ use crate::numeric_util; use crate::{LinalgScalar, Zip}; +#[cfg(not(feature = "std"))] +use alloc::vec; #[cfg(not(feature = "std"))] use alloc::vec::Vec; + use std::any::TypeId; use std::mem::MaybeUninit; @@ -353,7 +356,7 @@ where /// /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// - /// **Panics** if broadcasting isn’t possible. + /// **Panics** if broadcasting isn't possible. #[track_caller] pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayBase) where @@ -1067,3 +1070,64 @@ mod blas_tests } } } + +/// Dot product for dynamic-dimensional arrays (`ArrayD`). +/// +/// For one-dimensional arrays, computes the vector dot product, which is the sum +/// of the elementwise products (no conjugation of complex operands). +/// Both arrays must have the same length. +/// +/// For two-dimensional arrays, performs matrix multiplication. The array shapes +/// must be compatible in the following ways: +/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication +/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M* +/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N* +/// - If both arrays are one-dimensional of length *N*, returns a scalar +/// +/// **Panics** if: +/// - The arrays have dimensions other than 1 or 2 +/// - The array shapes are incompatible for the operation +/// - For vector dot product: the vectors have different lengths +/// +impl Dot> for ArrayBase +where + S: Data, + S2: Data, + A: LinalgScalar, +{ + type Output = Array; + + fn dot(&self, rhs: &ArrayBase) -> Self::Output + { + match (self.ndim(), rhs.ndim()) { + (1, 1) => { + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + ArrayD::from_elem(vec![], result) + } + (2, 2) => { + // Matrix-matrix multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + (2, 1) => { + // Matrix-vector multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + (1, 2) => { + // Vector-matrix multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"), + } + } +}