Skip to content

Commit 64fc60b

Browse files
authored
Merge branch 'main' into pagerank-example
2 parents 1bf25ea + 8cb2bb9 commit 64fc60b

File tree

2 files changed

+70
-1
lines changed

2 files changed

+70
-1
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ defaults:
33
shell: bash -leo pipefail {0}
44

55
concurrency:
6-
group: ${{ github.head_ref }}
6+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
77
cancel-in-progress: true
88

99
jobs:

examples/matmul_example.py

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import importlib
2+
import os
3+
4+
import sparse
5+
6+
from utils import benchmark
7+
8+
import numpy as np
9+
import scipy.sparse as sps
10+
11+
LEN = 100000
12+
DENSITY = 0.00001
13+
ITERS = 3
14+
rng = np.random.default_rng(0)
15+
16+
17+
if __name__ == "__main__":
18+
print("Matmul Example:\n")
19+
20+
a_sps = sps.random(LEN, LEN - 10, format="csr", density=DENSITY, random_state=rng) * 10
21+
a_sps.sum_duplicates()
22+
b_sps = sps.random(LEN - 10, LEN, format="csr", density=DENSITY, random_state=rng) * 10
23+
b_sps.sum_duplicates()
24+
25+
# ======= Finch =======
26+
os.environ[sparse._ENV_VAR_NAME] = "Finch"
27+
importlib.reload(sparse)
28+
29+
a = sparse.asarray(a_sps)
30+
b = sparse.asarray(b_sps)
31+
32+
@sparse.compiled
33+
def sddmm_finch(a, b):
34+
return sparse.sum(a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :], axis=-1)
35+
36+
# Compile
37+
result_finch = sddmm_finch(a, b)
38+
# Benchmark
39+
benchmark(sddmm_finch, args=[a, b], info="Finch", iters=ITERS)
40+
41+
# ======= Numba =======
42+
os.environ[sparse._ENV_VAR_NAME] = "Numba"
43+
importlib.reload(sparse)
44+
45+
a = sparse.asarray(a_sps)
46+
b = sparse.asarray(b_sps)
47+
48+
def sddmm_numba(a, b):
49+
return a @ b
50+
51+
# Compile
52+
result_numba = sddmm_numba(a, b)
53+
# Benchmark
54+
benchmark(sddmm_numba, args=[a, b], info="Numba", iters=ITERS)
55+
56+
# ======= SciPy =======
57+
def sddmm_scipy(a, b):
58+
return a @ b
59+
60+
a = a_sps
61+
b = b_sps
62+
63+
result_scipy = sddmm_scipy(a, b)
64+
# Benchmark
65+
benchmark(sddmm_scipy, args=[a, b], info="SciPy", iters=ITERS)
66+
67+
# np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
68+
# np.testing.assert_allclose(result_finch.todense(), result_numba.todense())
69+
# np.testing.assert_allclose(result_finch.todense(), result_scipy.toarray())

0 commit comments

Comments
 (0)