Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dot product support for ArrayD #1483

Merged
merged 13 commits into from
Mar 18, 2025
Prev Previous commit
Next Next commit
move tests into blas-tests
NewBornRustacean committed Mar 17, 2025
commit ba3c02914d758bfb50e8db2fe235ddbc3112398d
75 changes: 75 additions & 0 deletions crates/blas-tests/tests/dyn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
extern crate blas_src;
use ndarray::{Array1, Array2, ArrayD, linalg::Dot, Ix1, Ix2};

#[test]
fn test_arrayd_dot_2d() {
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();

let result = mat1.dot(&mat2);

// Verify the result is correct
assert_eq!(result.ndim(), 2);
assert_eq!(result.shape(), &[3, 3]);

// Compare with Array2 implementation
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
let expected = mat1_2d.dot(&mat2_2d);

assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
}

#[test]
fn test_arrayd_dot_1d() {
// Test 1D array dot product
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();

let result = vec1.dot(&vec2);

// Verify scalar result
assert_eq!(result.ndim(), 0);
assert_eq!(result.shape(), &[]);
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
}

#[test]
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
fn test_arrayd_dot_3d() {
// Test that 3D arrays are not supported
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
fn test_arrayd_dot_incompatible_dims() {
// Test arrays with incompatible dimensions
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
fn test_arrayd_dot_matrix_vector() {
// Test matrix-vector multiplication
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();

let result = mat.dot(&vec);

// Verify result
assert_eq!(result.ndim(), 1);
assert_eq!(result.shape(), &[3]);

// Compare with Array2 implementation
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
let expected = mat_2d.dot(&vec_1d);

assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
}
103 changes: 1 addition & 102 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
@@ -1089,21 +1089,6 @@ mod blas_tests
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
///
/// # Examples
///
/// ```
/// use ndarray::{ArrayD, linalg::Dot};
///
/// // Matrix multiplication
/// let a = ArrayD::from_shape_vec(vec![2, 3], vec![1., 2., 3., 4., 5., 6.]).unwrap();
/// let b = ArrayD::from_shape_vec(vec![3, 2], vec![1., 2., 3., 4., 5., 6.]).unwrap();
/// let c = a.dot(&b);
///
/// // Vector dot product
/// let v1 = ArrayD::from_shape_vec(vec![3], vec![1., 2., 3.]).unwrap();
/// let v2 = ArrayD::from_shape_vec(vec![3], vec![4., 5., 6.]).unwrap();
/// let scalar = v1.dot(&v2);
/// ```
impl<A, S, S2> Dot<ArrayBase<S2, IxDyn>> for ArrayBase<S, IxDyn>
where
S: Data<Elem = A>,
@@ -1145,90 +1130,4 @@ where
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}

#[cfg(test)]
mod arrayd_dot_tests
{
use super::*;
use crate::ArrayD;

#[test]
fn test_arrayd_dot_2d()
{
// Test case from the original issue
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();

let result = mat1.dot(&mat2);

// Verify the result is correct
assert_eq!(result.ndim(), 2);
assert_eq!(result.shape(), &[3, 3]);

// Compare with Array2 implementation
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
let expected = mat1_2d.dot(&mat2_2d);

assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
}

#[test]
fn test_arrayd_dot_1d()
{
// Test 1D array dot product
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();

let result = vec1.dot(&vec2);

// Verify scalar result
assert_eq!(result.ndim(), 0);
assert_eq!(result.shape(), &[]);
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
}

#[test]
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
fn test_arrayd_dot_3d()
{
// Test that 3D arrays are not supported
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
fn test_arrayd_dot_incompatible_dims()
{
// Test arrays with incompatible dimensions
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
fn test_arrayd_dot_matrix_vector()
{
// Test matrix-vector multiplication
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();

let result = mat.dot(&vec);

// Verify result
assert_eq!(result.ndim(), 1);
assert_eq!(result.shape(), &[3]);

// Compare with Array2 implementation
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
let expected = mat_2d.dot(&vec_1d);

assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
}
}
}