Skip to content

Commit c71a5fb

Browse files
authored
Add matmul example (#701)
1 parent 05cf589 commit c71a5fb

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

Diff for: examples/matmul_example.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import importlib
2+
import os
3+
4+
import sparse
5+
6+
from utils import benchmark
7+
8+
import numpy as np
9+
import scipy.sparse as sps
10+
11+
LEN = 100000
12+
DENSITY = 0.00001
13+
ITERS = 3
14+
rng = np.random.default_rng(0)
15+
16+
17+
if __name__ == "__main__":
18+
print("Matmul Example:\n")
19+
20+
a_sps = sps.random(LEN, LEN - 10, format="csr", density=DENSITY, random_state=rng) * 10
21+
a_sps.sum_duplicates()
22+
b_sps = sps.random(LEN - 10, LEN, format="csr", density=DENSITY, random_state=rng) * 10
23+
b_sps.sum_duplicates()
24+
25+
# ======= Finch =======
26+
os.environ[sparse._ENV_VAR_NAME] = "Finch"
27+
importlib.reload(sparse)
28+
29+
a = sparse.asarray(a_sps)
30+
b = sparse.asarray(b_sps)
31+
32+
@sparse.compiled
33+
def sddmm_finch(a, b):
34+
return sparse.sum(a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :], axis=-1)
35+
36+
# Compile
37+
result_finch = sddmm_finch(a, b)
38+
# Benchmark
39+
benchmark(sddmm_finch, args=[a, b], info="Finch", iters=ITERS)
40+
41+
# ======= Numba =======
42+
os.environ[sparse._ENV_VAR_NAME] = "Numba"
43+
importlib.reload(sparse)
44+
45+
a = sparse.asarray(a_sps)
46+
b = sparse.asarray(b_sps)
47+
48+
def sddmm_numba(a, b):
49+
return a @ b
50+
51+
# Compile
52+
result_numba = sddmm_numba(a, b)
53+
# Benchmark
54+
benchmark(sddmm_numba, args=[a, b], info="Numba", iters=ITERS)
55+
56+
# ======= SciPy =======
57+
def sddmm_scipy(a, b):
58+
return a @ b
59+
60+
a = a_sps
61+
b = b_sps
62+
63+
result_scipy = sddmm_scipy(a, b)
64+
# Benchmark
65+
benchmark(sddmm_scipy, args=[a, b], info="SciPy", iters=ITERS)
66+
67+
# np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
68+
# np.testing.assert_allclose(result_finch.todense(), result_numba.todense())
69+
# np.testing.assert_allclose(result_finch.todense(), result_scipy.toarray())

0 commit comments

Comments
 (0)