Skip to content
Merged
80 changes: 80 additions & 0 deletions crates/blas-tests/tests/dyn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
extern crate blas_src;
use ndarray::{linalg::Dot, Array1, Array2, ArrayD, Ix1, Ix2};

#[test]
fn test_arrayd_dot_2d()
{
let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap();
let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();

let result = mat1.dot(&mat2);

// Verify the result is correct
assert_eq!(result.ndim(), 2);
assert_eq!(result.shape(), &[3, 3]);

// Compare with Array2 implementation
let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap();
let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
let expected = mat1_2d.dot(&mat2_2d);

assert_eq!(result.into_dimensionality::<Ix2>().unwrap(), expected);
}

#[test]
fn test_arrayd_dot_1d()
{
// Test 1D array dot product
let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap();

let result = vec1.dot(&vec2);

// Verify scalar result
assert_eq!(result.ndim(), 0);
assert_eq!(result.shape(), &[]);
assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6
}

#[test]
#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")]
fn test_arrayd_dot_3d()
{
// Test that 3D arrays are not supported
let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")]
fn test_arrayd_dot_incompatible_dims()
{
// Test arrays with incompatible dimensions
let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap();
let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap();

let _result = arr1.dot(&arr2); // Should panic
}

#[test]
fn test_arrayd_dot_matrix_vector()
{
// Test matrix-vector multiplication
let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap();

let result = mat.dot(&vec);

// Verify result
assert_eq!(result.ndim(), 1);
assert_eq!(result.shape(), &[3]);

// Compare with Array2 implementation
let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
let vec_1d = Array1::from_vec(vec![1.0, 2.0]);
let expected = mat_2d.dot(&vec_1d);

assert_eq!(result.into_dimensionality::<Ix1>().unwrap(), expected);
}
66 changes: 65 additions & 1 deletion src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ use crate::numeric_util;

use crate::{LinalgScalar, Zip};

#[cfg(not(feature = "std"))]
use alloc::vec;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use std::any::TypeId;
use std::mem::MaybeUninit;

Expand Down Expand Up @@ -353,7 +356,7 @@ where
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isnt possible.
/// **Panics** if broadcasting isn't possible.
#[track_caller]
pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
where
Expand Down Expand Up @@ -1067,3 +1070,64 @@ mod blas_tests
}
}
}

/// Dot product for dynamic-dimensional arrays (`ArrayD`).
///
/// For one-dimensional arrays, computes the vector dot product, which is the sum
/// of the elementwise products (no conjugation of complex operands).
/// Both arrays must have the same length.
///
/// For two-dimensional arrays, performs matrix multiplication. The array shapes
/// must be compatible in the following ways:
/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication
/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M*
/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N*
/// - If both arrays are one-dimensional of length *N*, returns a scalar
///
/// **Panics** if:
/// - The arrays have dimensions other than 1 or 2
/// - The array shapes are incompatible for the operation
/// - For vector dot product: the vectors have different lengths
///
impl<A, S, S2> Dot<ArrayBase<S2, IxDyn>> for ArrayBase<S, IxDyn>
where
S: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
type Output = Array<A, IxDyn>;

fn dot(&self, rhs: &ArrayBase<S2, IxDyn>) -> Self::Output
{
match (self.ndim(), rhs.ndim()) {
(1, 1) => {
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
ArrayD::from_elem(vec![], result)
}
(2, 2) => {
// Matrix-matrix multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(2, 1) => {
// Matrix-vector multiplication
let a = self.view().into_dimensionality::<Ix2>().unwrap();
let b = rhs.view().into_dimensionality::<Ix1>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
(1, 2) => {
// Vector-matrix multiplication
let a = self.view().into_dimensionality::<Ix1>().unwrap();
let b = rhs.view().into_dimensionality::<Ix2>().unwrap();
let result = a.dot(&b);
result.into_dimensionality::<IxDyn>().unwrap()
}
_ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"),
}
}
}
Loading