Skip to content

Commit 9cf10a3

Browse files
authored
Add missing linear algebra functions from array API (#633)
* Add matrix_transpose. * Add vecdot.
1 parent c57fed2 commit 9cf10a3

File tree

6 files changed

+124
-2
lines changed

6 files changed

+124
-2
lines changed

docs/generated/sparse.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ API
9696

9797
matmul
9898

99+
matrix_transpose
100+
99101
max
100102

101103
mean
@@ -168,6 +170,8 @@ API
168170

169171
var
170172

173+
vecdot
174+
171175
where
172176

173177
zeros

sparse/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
sum,
118118
tensordot,
119119
var,
120+
vecdot,
120121
zeros,
121122
zeros_like,
122123
)
@@ -135,6 +136,7 @@
135136
isneginf,
136137
isposinf,
137138
kron,
139+
matrix_transpose,
138140
nanmax,
139141
nanmean,
140142
nanmin,
@@ -250,6 +252,7 @@
250252
"logical_or",
251253
"logical_xor",
252254
"matmul",
255+
"matrix_transpose",
253256
"max",
254257
"mean",
255258
"min",
@@ -307,6 +310,7 @@
307310
"unique_counts",
308311
"unique_values",
309312
"var",
313+
"vecdot",
310314
"where",
311315
"zeros",
312316
"zeros_like",

sparse/_common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2037,6 +2037,7 @@ def format_to_string(format):
20372037
def asarray(obj, /, *, dtype=None, format="coo", backend="pydata", device=None, copy=False):
20382038
"""
20392039
Convert the input to a sparse array.
2040+
20402041
Parameters
20412042
----------
20422043
obj : array_like
@@ -2051,10 +2052,12 @@ def asarray(obj, /, *, dtype=None, format="coo", backend="pydata", device=None,
20512052
Device on which to place the created array.
20522053
copy : bool, optional
20532054
Boolean indicating whether or not to copy the input.
2055+
20542056
Returns
20552057
-------
20562058
out : Union[SparseArray, numpy.ndarray]
20572059
Sparse or 0-D array containing the data from `obj`.
2060+
20582061
Examples
20592062
--------
20602063
>>> x = np.eye(8, dtype="i8")
@@ -2209,3 +2212,25 @@ def isfinite(x, /):
22092212

22102213
def nonzero(x, /):
22112214
return x.nonzero()
2215+
2216+
2217+
def vecdot(x1, x2, /, *, axis=-1):
2218+
"""
2219+
Computes the (vector) dot product of two arrays.
2220+
2221+
Parameters
2222+
----------
2223+
x1, x2 : array_like
2224+
Input sparse arrays
2225+
axis : int
2226+
The axis to reduce over.
2227+
2228+
Returns
2229+
-------
2230+
out : Union[SparseArray, numpy.ndarray]
2231+
Sparse or 0-D array containing dot product.
2232+
"""
2233+
if np.issubdtype(x1.dtype, np.complexfloating):
2234+
x1 = np.conjugate(x1)
2235+
2236+
return np.sum(x1 * x2, axis=axis)

sparse/_coo/common.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,6 @@ def take(x, indices, /, *, axis=None):
13321332
------
13331333
ValueError
13341334
If the input array isn't and can't be converted to COO format.
1335-
13361335
"""
13371336

13381337
x = _validate_coo_input(x)
@@ -1533,3 +1532,31 @@ def _arg_minmax_common(
15331532
result = result.reshape([1 for _ in range(axis_none_original_ndim)])
15341533

15351534
return result if keepdims else result.squeeze()
1535+
1536+
1537+
def matrix_transpose(x, /):
1538+
"""
1539+
Transposes a matrix or a stack of matrices.
1540+
1541+
Parameters
1542+
----------
1543+
x : SparseArray
1544+
Input array.
1545+
1546+
Returns
1547+
-------
1548+
out : COO
1549+
Transposed COO array.
1550+
1551+
Raises
1552+
------
1553+
ValueError
1554+
If the input array isn't and can't be converted to COO format, or if ``x.ndim < 2``.
1555+
"""
1556+
if hasattr(x, "ndim") and x.ndim < 2:
1557+
raise ValueError("`x.ndim >= 2` must hold.")
1558+
x = _validate_coo_input(x)
1559+
transpose_axes = list(range(x.ndim))
1560+
transpose_axes[-2:] = transpose_axes[-2:][::-1]
1561+
1562+
return x.transpose(transpose_axes)

sparse/tests/test_coo.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1775,7 +1775,7 @@ def test_unique_values(self, arr, fill_value):
17751775

17761776
@pytest.mark.parametrize("func", [sparse.unique_counts, sparse.unique_values])
17771777
def test_input_validation(self, func):
1778-
with pytest.raises(ValueError, match=r"Input must be an instance of SparseArray"):
1778+
with pytest.raises(ValueError, match="Input must be an instance of SparseArray"):
17791779
func(self.arr)
17801780

17811781

@@ -1861,3 +1861,63 @@ def test_take(fill_value, indices, axis):
18611861
expected = np.take(arr, indices, axis)
18621862

18631863
np.testing.assert_equal(result.todense(), expected)
1864+
1865+
1866+
@pytest.mark.parametrize("ndim", [2, 3, 4, 5])
1867+
@pytest.mark.parametrize("density", [0.0, 0.1, 0.25, 1.0])
1868+
def test_matrix_transpose(ndim, density):
1869+
shape = tuple(range(2, 34)[:ndim])
1870+
xs = sparse.random(shape, density=density)
1871+
xd = xs.todense()
1872+
1873+
transpose_axes = list(range(ndim))
1874+
transpose_axes[-2:] = transpose_axes[-2:][::-1]
1875+
1876+
expected = np.transpose(xd, axes=transpose_axes)
1877+
actual = sparse.matrix_transpose(xs)
1878+
1879+
np.testing.assert_equal(actual.todense(), expected)
1880+
1881+
1882+
@pytest.mark.parametrize(
1883+
"shape1, shape2",
1884+
[
1885+
((2, 3, 4), (3, 4)),
1886+
((3, 4), (2, 3, 4)),
1887+
((3, 1, 4), (3, 2, 4)),
1888+
((1, 3, 4), (3, 4)),
1889+
((3, 4, 1), (3, 4, 2)),
1890+
((1, 5), (5, 1)),
1891+
((3, 1), (3, 4)),
1892+
((3, 1), (1, 4)),
1893+
((1, 4), (3, 4)),
1894+
((2, 2, 2), (1, 1, 1)),
1895+
],
1896+
)
1897+
@pytest.mark.parametrize("density", [0.0, 0.1, 0.25, 1.0])
1898+
@pytest.mark.parametrize("is_complex", [False, True])
1899+
def test_vecdot(shape1, shape2, density, rng, is_complex):
1900+
def data_rvs(size):
1901+
data = rng.random(size)
1902+
if is_complex:
1903+
data = data + rng.random(size) * 1j
1904+
return data
1905+
1906+
s1 = sparse.random(shape1, density=density, data_rvs=data_rvs)
1907+
s2 = sparse.random(shape2, density=density, data_rvs=data_rvs)
1908+
1909+
axis = rng.integers(max(s1.ndim, s2.ndim))
1910+
1911+
x1 = s1.todense()
1912+
x2 = s2.todense()
1913+
1914+
def np_vecdot(x1, x2, /, *, axis=-1):
1915+
if np.issubdtype(x1.dtype, np.complexfloating):
1916+
x1 = np.conjugate(x1)
1917+
1918+
return np.sum(x1 * x2, axis=axis)
1919+
1920+
expected = np_vecdot(x1, x2, axis=axis)
1921+
actual = sparse.vecdot(s1, s2, axis=axis)
1922+
1923+
np.testing.assert_allclose(actual.todense(), expected)

sparse/tests/test_namespace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def test_namespace():
9494
"logical_not",
9595
"logical_or",
9696
"logical_xor",
97+
"matrix_transpose",
9798
"matmul",
9899
"max",
99100
"mean",
@@ -152,6 +153,7 @@ def test_namespace():
152153
"unique_counts",
153154
"unique_values",
154155
"var",
156+
"vecdot",
155157
"where",
156158
"zeros",
157159
"zeros_like",

0 commit comments

Comments
 (0)