Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dot product support for ArrayD #1483

Merged
merged 13 commits into from
Mar 18, 2025
80 changes: 80 additions & 0 deletions crates/blas-tests/tests/dyn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
extern crate blas_src;
use ndarray::{linalg::Dot, Array1, Array2, ArrayD, Ix1, Ix2};

#[test]
fn test_arrayd_dot_2d()
{
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();

let result = mat1.dot(&mat2);

// Verify the result is correct
assert_eq!(result.ndim(), 2);
assert_eq!(result.shape(), &[3, 3]);

// Compare with Array2 implementation
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
let expected = mat1_2d.dot(&mat2_2d);

assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
}

#[test]
fn test_arrayd_dot_1d()
{
// Test 1D array dot product
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();

let result = vec1.dot(&vec2);

// Verify scalar result
assert_eq!(result.ndim(), 0);
assert_eq!(result.shape(), &[]);
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
}

#[test]
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
fn test_arrayd_dot_3d()
{
// Test that 3D arrays are not supported
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
fn test_arrayd_dot_incompatible_dims()
{
// Test arrays with incompatible dimensions
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
fn test_arrayd_dot_matrix_vector()
{
// Test matrix-vector multiplication
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();

let result = mat.dot(&vec);

// Verify result
assert_eq!(result.ndim(), 1);
assert_eq!(result.shape(), &[3]);

// Compare with Array2 implementation
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
let expected = mat_2d.dot(&vec_1d);

assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
}
66 changes: 65 additions & 1 deletion src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ use crate::numeric_util;

use crate::{LinalgScalar, Zip};

#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use std::any::TypeId;
use std::mem::MaybeUninit;

Expand Down Expand Up @@ -353,7 +356,7 @@ where
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isnt possible.
/// **Panics** if broadcasting isn't possible.
#[track_caller]
pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
where
Expand Down Expand Up @@ -1067,3 +1070,64 @@ mod blas_tests
}
}
}

/// Dot product for dynamic-dimensional arrays (`ArrayD`).
///
/// For one-dimensional arrays, computes the vector dot product, which is the sum
/// of the elementwise products (no conjugation of complex operands).
/// Both arrays must have the same length.
///
/// For two-dimensional arrays, performs matrix multiplication. The array shapes
/// must be compatible in the following ways:
/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication
/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M*
/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N*
/// - If both arrays are one-dimensional of length *N*, returns a scalar
///
/// **Panics** if:
/// - The arrays have dimensions other than 1 or 2
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
///
impl<A, S, S2> Dot<ArrayBase<S2, IxDyn>> for ArrayBase<S, IxDyn>
where
S: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
type Output = Array<A, IxDyn>;

fn dot(&self, rhs: &ArrayBase<S2, IxDyn>) -> Self::Output
{
match (self.ndim(), rhs.ndim()) {
(1, 1) => {
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
ArrayD::from_elem(vec![], result)
}
(2, 2) => {
// Matrix-matrix multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(2, 1) => {
// Matrix-vector multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(1, 2) => {
// Vector-matrix multiplication
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}
Loading