Skip to content

Commit 7dbbbc5

Browse files
committed
feature: support sequential matrix multiply
1 parent 3a90550 commit 7dbbbc5

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

examples/matrix.rs

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
use anyhow::Result;
2+
fn main() -> Result<()> {
3+
println!("f64 default: {}", f64::default());
4+
Ok(())
5+
}

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
mod matrix;
2+
3+
pub use matrix::{multiply, Matrix};

src/matrix.rs

+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
use anyhow::{anyhow, Result};
2+
use std::{
3+
fmt,
4+
ops::{Add, AddAssign, Mul},
5+
};
6+
7+
// [[1, 2], [1, 2], [1, 2]] => [1, 2, 1, 2, 1, 2]
8+
pub struct Matrix<T> {
9+
data: Vec<T>,
10+
row: usize,
11+
col: usize,
12+
}
13+
14+
pub fn multiply<T>(a: &Matrix<T>, b: &Matrix<T>) -> Result<Matrix<T>>
15+
where
16+
T: Copy + Default + Add<Output = T> + AddAssign + Mul<Output = T>,
17+
{
18+
if a.col != b.row {
19+
return Err(anyhow!("Matrix multiply error: a.col != b.row"));
20+
}
21+
22+
let mut data = vec![T::default(); a.row * b.col];
23+
for i in 0..a.row {
24+
for j in 0..b.col {
25+
for k in 0..a.col {
26+
data[i * b.col + j] += a.data[i * a.col + k] * b.data[k * b.col + j];
27+
}
28+
}
29+
}
30+
31+
Ok(Matrix {
32+
data,
33+
row: a.row,
34+
col: b.col,
35+
})
36+
}
37+
38+
impl<T: fmt::Debug> Matrix<T> {
39+
pub fn new(data: impl Into<Vec<T>>, row: usize, col: usize) -> Self {
40+
Self {
41+
data: data.into(),
42+
row,
43+
col,
44+
}
45+
}
46+
}
47+
48+
impl<T> fmt::Display for Matrix<T>
49+
where
50+
T: fmt::Display,
51+
{
52+
// display a 2x3 as {1 2 3, 4 5 6}, 3x2 as {1 2, 3 4, 5 6}
53+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
54+
write!(f, "{{")?;
55+
for i in 0..self.row {
56+
for j in 0..self.col {
57+
write!(f, "{}", self.data[i * self.col + j])?;
58+
if j != self.col - 1 {
59+
write!(f, " ")?;
60+
}
61+
}
62+
63+
if i != self.row - 1 {
64+
write!(f, ", ")?;
65+
}
66+
}
67+
write!(f, "}}")?;
68+
Ok(())
69+
}
70+
}
71+
72+
impl<T> fmt::Debug for Matrix<T>
73+
where
74+
T: fmt::Display,
75+
{
76+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77+
write!(f, "Matrix(row={}, col={}, {})", self.row, self.col, self)
78+
}
79+
}
80+
81+
#[cfg(test)]
82+
mod tests {
83+
use super::*;
84+
85+
#[test]
86+
fn test_matrix_multiply() -> Result<()> {
87+
let a = Matrix::new([1, 2, 3, 4, 5, 6], 2, 3);
88+
let b = Matrix::new([1, 2, 3, 4, 5, 6], 3, 2);
89+
let c = multiply(&a, &b)?;
90+
assert_eq!(c.col, 2);
91+
assert_eq!(c.row, 2);
92+
assert_eq!(c.data, vec![22, 28, 49, 64]);
93+
assert_eq!(format!("{:?}", c), "Matrix(row=2, col=2, {22 28, 49 64})");
94+
95+
Ok(())
96+
}
97+
98+
#[test]
99+
fn test_matrix_display() -> Result<()> {
100+
let a = Matrix::new([1, 2, 3, 4], 2, 2);
101+
let b = Matrix::new([1, 2, 3, 4], 2, 2);
102+
let c = multiply(&a, &b)?;
103+
assert_eq!(c.data, vec![7, 10, 15, 22]);
104+
assert_eq!(format!("{}", c), "{7 10, 15 22}");
105+
Ok(())
106+
}
107+
}

0 commit comments

Comments
 (0)