Skip to content

Commit 851e82c

Browse files
authored
Merge pull request #849 from pydata/ms/csr-csc-mT-fix
Fix CSC/CSR `.mT` fill-value result
2 parents 60c552f + 769a5cd commit 851e82c

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

ci/Finch-array-api-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
162162
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
163163
array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__]
164164
array_api_tests/test_has_names.py::test_has_names[array_attribute-T]
165-
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
166165

167166
# test_indexing_functions
168167

sparse/numba_backend/_compressed/compressed.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,7 @@ def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"
913913
self = self.copy()
914914
if axes == (0, 1):
915915
return self
916-
return CSC((self.data, self.indices, self.indptr), self.shape[::-1])
916+
return CSC((self.data, self.indices, self.indptr), self.shape[::-1], fill_value=self.fill_value)
917917

918918

919919
class CSC(_Compressed2d):
@@ -945,4 +945,4 @@ def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"
945945
self = self.copy()
946946
if axes == (0, 1):
947947
return self
948-
return CSR((self.data, self.indices, self.indptr), self.shape[::-1])
948+
return CSR((self.data, self.indices, self.indptr), self.shape[::-1], fill_value=self.fill_value)

sparse/numba_backend/tests/test_compressed_2d.py

+7
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ def test_transpose(random_sparse, copy):
117117
random_sparse.transpose(axes=0)
118118

119119

120+
@pytest.mark.parametrize("format", ["csr", "csc"])
121+
def test_mT_fill_value(format):
122+
fv = 1.0
123+
arr = sparse.full((10, 20), fill_value=fv, format=format)
124+
assert_eq(arr.mT, sparse.full((20, 10), fill_value=fv))
125+
126+
120127
def test_transpose_error(random_sparse):
121128
with pytest.raises(ValueError):
122129
random_sparse.transpose(axes=1)

0 commit comments

Comments
 (0)