Skip to content

Commit c12b29e

Browse files
authored
Add SDDMM example (#674)
1 parent 79b9d71 commit c12b29e

File tree

5 files changed

+100
-7
lines changed

5 files changed

+100
-7
lines changed

.github/workflows/ci.yml

+15
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,21 @@ jobs:
122122
- name: Run benchmarks
123123
run: |
124124
asv run --quick
125+
examples:
126+
runs-on: ubuntu-latest
127+
steps:
128+
- name: Checkout Repo
129+
uses: actions/checkout@v4
130+
- name: Set up Python
131+
uses: actions/[email protected]
132+
with:
133+
python-version: '3.11'
134+
- name: Build and install Sparse
135+
run: |
136+
python -m pip install '.[finch]' scipy
137+
- name: Run examples
138+
run: |
139+
source ci/test_examples.sh
125140
array_api_tests:
126141
runs-on: ubuntu-latest
127142
steps:

benchmarks/benchmark_backends.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from .utils import SkipNotImplemented
66

7-
TIMEOUT: float = 200.0
7+
TIMEOUT: float = 500.0
88
BACKEND: sparse.BackendType = sparse.backend_var.get()
99

1010

@@ -42,8 +42,7 @@ def time_tensordot(self):
4242

4343
class SpMv:
4444
timeout = TIMEOUT
45-
# NOTE: https://github.com/willow-ahrens/Finch.jl/issues/488
46-
params = [[True, False], [(10, 0.01)]] # (1000, 0.01), (1_000_000, 1e-05)
45+
params = [[True, False], [(50, 0.1)]] # (1000, 0.01), (1_000_000, 1e-05)
4746
param_names = ["lazy_mode", "size_and_density"]
4847

4948
def setup(self, lazy_mode, size_and_density):
@@ -55,9 +54,8 @@ def setup(self, lazy_mode, size_and_density):
5554
random_kwargs["format"] = "gcxs"
5655

5756
self.M = sparse.random((size, size), **random_kwargs)
58-
# NOTE: Once https://github.com/willow-ahrens/Finch.jl/issues/487 is fixed change to (size, 1).
59-
self.v1 = rng.normal(size=(size, 2))
60-
self.v2 = rng.normal(size=(size, 2))
57+
self.v1 = rng.normal(size=(size, 1))
58+
self.v2 = rng.normal(size=(size, 1))
6159

6260
if sparse.BackendType.Finch == BACKEND:
6361
import finch

ci/test_examples.sh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
for example in $(find ./examples/ -iname *.py); do
2+
python $example
3+
done

examples/sddmm_example.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import time
2+
3+
import sparse
4+
5+
import numpy as np
6+
import scipy.sparse as sps
7+
8+
LEN = 10000
9+
DENSITY = 0.0001
10+
ITERS = 3
11+
rng = np.random.default_rng(0)
12+
13+
14+
def benchmark(func, info, args):
15+
print(info)
16+
start = time.time()
17+
for _ in range(ITERS):
18+
func(*args)
19+
elapsed = time.time() - start
20+
print(f"Took {elapsed / ITERS} s.\n")
21+
22+
23+
if __name__ == "__main__":
24+
a_sps = rng.random((LEN, LEN - 10)) * 10
25+
b_sps = rng.random((LEN - 10, LEN)) * 10
26+
s_sps = sps.random(LEN, LEN, format="coo", density=DENSITY, random_state=rng) * 10
27+
s_sps.sum_duplicates()
28+
29+
# Finch
30+
with sparse.Backend(backend=sparse.BackendType.Finch):
31+
s = sparse.asarray(s_sps)
32+
a = sparse.asarray(np.array(a_sps, order="F"))
33+
b = sparse.asarray(np.array(b_sps, order="C"))
34+
35+
@sparse.compiled
36+
def sddmm_finch(s, a, b):
37+
return sparse.sum(
38+
s[:, :, None] * (a[:, None, :] * sparse.permute_dims(b, (1, 0))[None, :, :]),
39+
axis=-1,
40+
)
41+
42+
# Compile
43+
result_finch = sddmm_finch(s, a, b)
44+
assert sparse.nonzero(result_finch)[0].size > 5
45+
# Benchmark
46+
benchmark(sddmm_finch, info="Finch", args=[s, a, b])
47+
48+
# Numba
49+
with sparse.Backend(backend=sparse.BackendType.Numba):
50+
s = sparse.asarray(s_sps)
51+
a = a_sps
52+
b = b_sps
53+
54+
def sddmm_numba(s, a, b):
55+
return s * (a @ b)
56+
57+
# Compile
58+
result_numba = sddmm_numba(s, a, b)
59+
assert sparse.nonzero(result_numba)[0].size > 5
60+
# Benchmark
61+
benchmark(sddmm_numba, info="Numba", args=[s, a, b])
62+
63+
# SciPy
64+
def sddmm_scipy(s, a, b):
65+
return s.multiply(a @ b)
66+
67+
s = s_sps.asformat("csr")
68+
a = a_sps
69+
b = b_sps
70+
71+
result_scipy = sddmm_scipy(s, a, b)
72+
# Benchmark
73+
benchmark(sddmm_scipy, info="SciPy", args=[s, a, b])
74+
75+
np.testing.assert_allclose(result_numba.todense(), result_scipy.toarray())
76+
np.testing.assert_allclose(result_finch.todense(), result_numba.todense())
77+
np.testing.assert_allclose(result_finch.todense(), result_scipy.toarray())

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ tests = [
3838
]
3939
tox = ["sparse[tests]", "tox"]
4040
all = ["sparse[docs,tox]", "matrepr"]
41-
finch = ["finch-tensor>=0.1.14"]
41+
finch = ["finch-tensor>=0.1.19"]
4242

4343
[project.urls]
4444
Documentation = "https://sparse.pydata.org/"

0 commit comments

Comments
 (0)