
Commit dc1eeb0

API: Implement tensordot and matmul
1 parent: d1b33c2


4 files changed: +82 −5 lines


src/finch/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -15,13 +15,14 @@
     Tensor,
     astype,
     random,
+    tensordot,
+    matmul,
     permute_dims,
     multiply,
     sum,
     prod,
     add,
     subtract,
-    multiply,
     divide,
     positive,
     negative,
@@ -65,6 +66,8 @@
     "DenseStorage",
     "astype",
     "random",
+    "tensordot",
+    "matmul",
     "permute_dims",
     "int_",
     "int8",
@@ -82,7 +85,6 @@
     "complex64",
     "complex128",
     "bool",
-    "multiply",
     "lazy",
     "compiled",
     "compute",

src/finch/julia.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import juliapkg
 
-juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.19")
+juliapkg.add("Finch", "9177782c-1635-4eb9-9bfb-d9dfa25e6bce", version="0.6.20")
 import juliacall  # noqa
 
 juliapkg.resolve()

src/finch/tensor.py

Lines changed: 20 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import Callable, Iterable, Optional, Union
 
 import numpy as np
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
@@ -113,7 +113,11 @@ def __pow__(self, other):
         return self._elemwise_op(".^", other)
 
     def __matmul__(self, other):
-        raise NotImplementedError
+        # TODO: Implement and use mul instead of tensordot
+        # https://github.com/willow-ahrens/finch-tensor/pull/22#issuecomment-2007884763
+        if self.ndim != 2 or other.ndim != 2:
+            raise ValueError(f"Both tensors must be 2-dimensional, but are: {self.ndim=} and {other.ndim=}.")
+        return tensordot(self, other, axes=((-1,), (-2,)))
 
     def __abs__(self):
         return self._elemwise_op("abs")
@@ -463,6 +467,20 @@ def prod(
     return _reduce(x, jl.prod, axis, dtype)
 
 
+def tensordot(x1: Tensor, x2: Tensor, /, *, axes=2) -> Tensor:
+    if isinstance(axes, Iterable):
+        self_axes = normalize_axis_tuple(axes[0], x1.ndim)
+        other_axes = normalize_axis_tuple(axes[1], x2.ndim)
+        axes = (tuple(i + 1 for i in self_axes), tuple(i + 1 for i in other_axes))
+
+    result = jl.tensordot(x1._obj, x2._obj, axes)
+    return Tensor(result)
+
+
+def matmul(x1: Tensor, x2: Tensor) -> Tensor:
+    return x1 @ x2
+
+
 def add(x1: Tensor, x2: Tensor, /) -> Tensor:
     return x1 + x2
 
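
Note on the axes handling in the new tensordot: it bridges Python's 0-based indexing and Julia's 1-based indexing. An iterable axes argument is normalized with normalize_axis_tuple and shifted by one before being handed to jl.tensordot, while a plain integer (e.g. the default axes=2) is passed through unchanged. A minimal standalone sketch of that translation follows; the helper name to_julia_axes and the example values are illustrative only, not part of this commit.

    from collections.abc import Iterable

    # Same import that tensor.py already uses for axis normalization.
    from numpy.core.numeric import normalize_axis_tuple

    def to_julia_axes(axes, ndim1, ndim2):
        # Integer form is forwarded as-is; only tuple/list axes need
        # normalization of negative indices plus the 0-based -> 1-based shift.
        if not isinstance(axes, Iterable):
            return axes
        self_axes = normalize_axis_tuple(axes[0], ndim1)
        other_axes = normalize_axis_tuple(axes[1], ndim2)
        return (tuple(i + 1 for i in self_axes), tuple(i + 1 for i in other_axes))

    # __matmul__ contracts the last axis of x1 with the second-to-last of x2:
    print(to_julia_axes(((-1,), (-2,)), 2, 2))  # ((2,), (1,)) in Julia's 1-based terms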

tests/test_ops.py

Lines changed: 57 additions & 0 deletions
@@ -86,3 +86,60 @@ def test_reductions(arr3d, func_name, axis, dtype):
     actual = actual.todense()
 
     assert_equal(actual, expected)
+
+
+@pytest.mark.parametrize(
+    "storage",
+    [
+        None,
+        (
+            finch.Storage(finch.SparseList(finch.Element(np.int64(0))), order="C"),
+            finch.Storage(finch.Dense(finch.SparseList(finch.Element(np.int64(0)))), order="C"),
+            finch.Storage(
+                finch.Dense(finch.SparseList(finch.SparseList(finch.Element(np.int64(0))))),
+                order="C",
+            ),
+        )
+    ]
+)
+def test_tensordot(arr3d, storage):
+    A_finch = finch.Tensor(arr1d)
+    B_finch = finch.Tensor(arr2d)
+    C_finch = finch.Tensor(arr3d)
+    if storage is not None:
+        A_finch = A_finch.to_device(storage[0])
+        B_finch = B_finch.to_device(storage[1])
+        C_finch = C_finch.to_device(storage[2])
+
+    actual = finch.tensordot(B_finch, B_finch)
+    expected = np.tensordot(arr2d, arr2d)
+    assert_equal(actual.todense(), expected)
+
+    actual = finch.tensordot(B_finch, B_finch, axes=(1, 1))
+    expected = np.tensordot(arr2d, arr2d, axes=(1, 1))
+    assert_equal(actual.todense(), expected)
+
+    actual = finch.tensordot(C_finch, finch.permute_dims(C_finch, (2, 1, 0)), axes=((2, 0), (0, 2)))
+    expected = np.tensordot(arr3d, arr3d.T, axes=((2, 0), (0, 2)))
+    assert_equal(actual.todense(), expected)
+
+    actual = finch.tensordot(C_finch, A_finch, axes=(2, 0))
+    expected = np.tensordot(arr3d, arr1d, axes=(2, 0))
+    assert_equal(actual.todense(), expected)
+
+
+def test_matmul(arr2d, arr3d):
+    A_finch = finch.Tensor(arr2d)
+    B_finch = finch.Tensor(arr2d.T)
+    C_finch = finch.permute_dims(A_finch, (1, 0))
+    D_finch = finch.Tensor(arr3d)
+
+    actual = A_finch @ B_finch
+    expected = arr2d @ arr2d.T
+    assert_equal(actual.todense(), expected)
+
+    actual = A_finch @ C_finch
+    assert_equal(actual.todense(), expected)
+
+    with pytest.raises(ValueError, match="Both tensors must be 2-dimensional"):
+        A_finch @ D_finch
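
For reference, a condensed usage sketch of the new public API, modeled on the tests above; the small arr2d array here is made up for illustration and is not the test module's actual fixture.

    import numpy as np
    import finch

    arr2d = np.array([[1, 0, 2],
                      [0, 3, 0]])

    A = finch.Tensor(arr2d)      # shape (2, 3)
    B = finch.Tensor(arr2d.T)    # shape (3, 2)

    # Contract A's last axis with B's first axis, i.e. an ordinary matrix product.
    C = finch.tensordot(A, B, axes=(1, 0))
    np.testing.assert_equal(C.todense(), np.tensordot(arr2d, arr2d.T, axes=(1, 0)))

    # The same product through the new matmul wrapper / @ operator (2-D only).
    D = finch.matmul(A, B)
    np.testing.assert_equal(D.todense(), arr2d @ arr2d.T)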
