Skip to content

Commit d9f991f

Browse files
committed
API: Implement tensordot and matmul
1 parent d1b33c2 commit d9f991f

File tree

4 files changed

+67
-5
lines changed

4 files changed

+67
-5
lines changed

src/finch/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
Tensor,
1616
astype,
1717
random,
18+
tensordot,
19+
matmul,
1820
permute_dims,
1921
multiply,
2022
sum,
2123
prod,
2224
add,
2325
subtract,
24-
multiply,
2526
divide,
2627
positive,
2728
negative,
@@ -65,6 +66,8 @@
6566
"DenseStorage",
6667
"astype",
6768
"random",
69+
"tensordot",
70+
"matmul",
6871
"permute_dims",
6972
"int_",
7073
"int8",
@@ -82,7 +85,6 @@
8285
"complex64",
8386
"complex128",
8487
"bool",
85-
"multiply",
8688
"lazy",
8789
"compiled",
8890
"compute",

src/finch/julia.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import juliapkg
22

3-
juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.19")
3+
juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.20")
44
import juliacall # noqa
55

66
juliapkg.resolve()

src/finch/tensor.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Optional, Union
1+
from typing import Callable, Iterable, Optional, Union
22

33
import numpy as np
44
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
@@ -113,7 +113,11 @@ def __pow__(self, other):
113113
return self._elemwise_op(".^", other)
114114

115115
def __matmul__(self, other):
116-
raise NotImplementedError
116+
# TODO: Implement and use mul instead of tensordot
117+
# https://github.com/willow-ahrens/finch-tensor/pull/22#issuecomment-2007884763
118+
if self.ndim != 2 or other.ndim != 2:
119+
raise ValueError(f"Both tensors must be 2-dimensional, but are: {self.ndim} and {other.ndim}.")
120+
return self.tensordot(other, axes=((-1,), (-2,)))
117121

118122
def __abs__(self):
119123
return self._elemwise_op("abs")
@@ -199,6 +203,15 @@ def _order(self) -> tuple[int, ...]:
199203
def is_computed(self) -> bool:
200204
return not jl.isa(self._obj, jl.Finch.LazyTensor)
201205

206+
def tensordot(self, other: "Tensor", /, *, axes=2) -> "Tensor":
207+
if isinstance(axes, Iterable):
208+
self_axes = normalize_axis_tuple(axes[0], self.ndim)
209+
other_axes = normalize_axis_tuple(axes[1], other.ndim)
210+
axes = (tuple(i + 1 for i in self_axes), tuple(i + 1 for i in other_axes))
211+
212+
result = jl.tensordot(self._obj, other._obj, axes)
213+
return Tensor(result)
214+
202215
@classmethod
203216
def preprocess_order(
204217
cls, order: OrderType, ndim: int
@@ -463,6 +476,14 @@ def prod(
463476
return _reduce(x, jl.prod, axis, dtype)
464477

465478

479+
def tensordot(x1: Tensor, x2: Tensor, /, *, axes=2) -> Tensor:
480+
return x1.tensordot(x2, axes=axes)
481+
482+
483+
def matmul(x1: Tensor, x2: Tensor) -> Tensor:
484+
return x1 @ x2
485+
486+
466487
def add(x1: Tensor, x2: Tensor, /) -> Tensor:
467488
return x1 + x2
468489

tests/test_ops.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,42 @@ def test_reductions(arr3d, func_name, axis, dtype):
8686
actual = actual.todense()
8787

8888
assert_equal(actual, expected)
89+
90+
91+
def test_tensordot(arr3d):
92+
A_finch = finch.Tensor(arr1d)
93+
B_finch = finch.Tensor(arr2d)
94+
C_finch = finch.Tensor(arr3d)
95+
96+
actual = finch.tensordot(B_finch, B_finch)
97+
expected = np.tensordot(arr2d, arr2d)
98+
assert_equal(actual.todense(), expected)
99+
100+
actual = finch.tensordot(B_finch, B_finch, axes=(1, 1))
101+
expected = np.tensordot(arr2d, arr2d, axes=(1, 1))
102+
assert_equal(actual.todense(), expected)
103+
104+
actual = finch.tensordot(C_finch, finch.permute_dims(C_finch, (2, 1, 0)), axes=((2, 0), (0, 2)))
105+
expected = np.tensordot(arr3d, arr3d.T, axes=((2, 0), (0, 2)))
106+
assert_equal(actual.todense(), expected)
107+
108+
actual = finch.tensordot(C_finch, A_finch, axes=(2, 0))
109+
expected = np.tensordot(arr3d, arr1d, axes=(2, 0))
110+
assert_equal(actual.todense(), expected)
111+
112+
113+
def test_matmul(arr2d, arr3d):
114+
A_finch = finch.Tensor(arr2d)
115+
B_finch = finch.Tensor(arr2d.T)
116+
C_finch = finch.permute_dims(A_finch, (1, 0))
117+
D_finch = finch.Tensor(arr3d)
118+
119+
actual = A_finch @ B_finch
120+
expected = arr2d @ arr2d.T
121+
assert_equal(actual.todense(), expected)
122+
123+
actual = A_finch @ C_finch
124+
assert_equal(actual.todense(), expected)
125+
126+
with pytest.raises(ValueError, match="Both tensors must be 2-dimensional"):
127+
A_finch @ D_finch

0 commit comments

Comments
 (0)