Skip to content

Commit 464f49a

Browse files
committed
Add __binsparse_descriptor__ and __binsparse_dlpack__.
1 parent fb0affe commit 464f49a

File tree

4 files changed

+127
-0
lines changed

4 files changed

+127
-0
lines changed

sparse/numba_backend/_compressed/compressed.py

+58
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,15 @@ def isinf(self):
844844
def isnan(self):
    """Evaluate ``isnan`` through the COO form and return the result as ``gcxs``.

    The array is converted with ``tocoo()``, ``isnan()`` is applied to that
    result, and the outcome is converted back with ``asformat("gcxs", ...)``
    using this array's ``compressed_axes``.
    """
    coo_view = self.tocoo()
    nan_mask = coo_view.isnan()
    return nan_mask.asformat("gcxs", compressed_axes=self.compressed_axes)
846846

847+
# `GCXS` is a reshaped/transposed `CSR`, but it can't (usually)
848+
# be expressed in the `binsparse` 0.1 language.
849+
# We are missing index maps.
850+
def __binsparse_descriptor__(self) -> dict:
    """Delegate to the parent class's ``__binsparse_descriptor__``.

    Returns
    -------
    dict
        Whatever the parent implementation returns.
    """
    parent_descriptor = super().__binsparse_descriptor__()
    return parent_descriptor
852+
853+
def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
    """Delegate to the parent class's ``__binsparse_dlpack__``.

    Returns
    -------
    dict[str, np.ndarray]
        Whatever the parent implementation returns.
    """
    parent_arrays = super().__binsparse_dlpack__()
    return parent_arrays
855+
847856

848857
class _Compressed2d(GCXS):
849858
class_compressed_axes: tuple[int]
@@ -883,6 +892,34 @@ def from_numpy(cls, x, fill_value=0, idx_dtype=None):
883892
coo = COO.from_numpy(x, fill_value=fill_value, idx_dtype=idx_dtype)
884893
return cls.from_coo(coo, cls.class_compressed_axes, idx_dtype)
885894

895+
def __binsparse_descriptor__(self) -> dict:
    """Return the parsed `binsparse` descriptor of this 2-D compressed array.

    Returns
    -------
    dict
        A ``"binsparse"`` entry holding the version, the upper-cased
        ``self.format``, the shape, the number of stored values and the
        data types of the constituent arrays, plus an
        ``"original_source"`` string naming this library and its version.
    """
    from sparse._version import __version__

    values_dtype = self.data.dtype
    if np.issubdtype(values_dtype, np.complexfloating):
        # A complex item is a (real, imag) pair, so each component takes
        # half of the item's bits: complex128 -> complex[float64].
        data_dt = f"complex[float{values_dtype.itemsize * 8 // 2}]"
    else:
        data_dt = str(values_dtype)
    return {
        "binsparse": {
            "version": "0.1",
            "format": self.format.upper(),
            "shape": list(self.shape),
            "number_of_stored_values": self.nnz,
            "data_types": {
                # `indptr` is the pointer array; `indices` holds the indices.
                "pointers_to_1": str(self.indptr.dtype),
                "indices_1": str(self.indices.dtype),
                "values": data_dt,
            },
        },
        "original_source": f"`sparse`, version {__version__}",
    }

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
    """Return the constituent arrays keyed by their `binsparse` names.

    Returns
    -------
    dict[str, np.ndarray]
        ``"pointers_to_1"`` maps to ``self.indptr``, ``"indices_1"`` to
        ``self.indices`` and ``"values"`` to ``self.data``. The arrays are
        returned as-is, without copying.
    """
    return {
        "pointers_to_1": self.indptr,
        "indices_1": self.indices,
        "values": self.data,
    }
922+
886923

887924
class CSR(_Compressed2d):
888925
"""
@@ -915,6 +952,27 @@ def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"
915952
return self
916953
return CSC((self.data, self.indices, self.indptr), self.shape[::-1])
917954

955+
def __binsparse_descriptor__(self) -> dict:
    """Return the parsed `binsparse` descriptor of this array, with format ``"CSR"``.

    Returns
    -------
    dict
        A ``"binsparse"`` entry holding the version, the ``"CSR"`` format
        name, the shape, the number of stored values and the data types of
        the constituent arrays, plus an ``"original_source"`` string naming
        this library and its version.
    """
    from sparse._version import __version__

    values_dtype = self.data.dtype
    if np.issubdtype(values_dtype, np.complexfloating):
        # A complex item is a (real, imag) pair, so each component takes
        # half of the item's bits: complex128 -> complex[float64].
        data_dt = f"complex[float{values_dtype.itemsize * 8 // 2}]"
    else:
        data_dt = str(values_dtype)
    return {
        "binsparse": {
            "version": "0.1",
            "format": "CSR",
            "shape": list(self.shape),
            "number_of_stored_values": self.nnz,
            "data_types": {
                # `indptr` is the pointer array; `indices` holds the indices.
                "pointers_to_1": str(self.indptr.dtype),
                "indices_1": str(self.indices.dtype),
                "values": data_dt,
            },
        },
        "original_source": f"`sparse`, version {__version__}",
    }
975+
918976

919977
class CSC(_Compressed2d):
920978
"""

sparse/numba_backend/_coo/core.py

+38
Original file line numberDiff line numberDiff line change
@@ -1537,6 +1537,44 @@ def isnan(self):
15371537
prune=True,
15381538
)
15391539

1540+
def __binsparse_descriptor__(self) -> dict:
    """Return the parsed `binsparse` descriptor of this COO array.

    The format is a custom one: a single ``"sparse"`` level of rank
    ``self.ndim`` wrapping an ``"element"`` level.

    Returns
    -------
    dict
        A ``"binsparse"`` entry holding the version, the custom format,
        the shape, the number of stored values and the data types of the
        constituent arrays, plus an ``"original_source"`` string naming
        this library and its version.
    """
    from sparse._version import __version__

    values_dtype = self.data.dtype
    if np.issubdtype(values_dtype, np.complexfloating):
        # A complex item is a (real, imag) pair, so each component takes
        # half of the item's bits: complex128 -> complex[float64].
        data_dt = f"complex[float{values_dtype.itemsize * 8 // 2}]"
    else:
        data_dt = str(values_dtype)
    # Must match the dtype chosen in `__binsparse_dlpack__`: the smallest
    # unsigned type able to hold `nnz` (still `uint8` while nnz <= 255).
    pointer_dtype = np.min_scalar_type(int(self.nnz))
    return {
        "binsparse": {
            "version": "0.1",
            "format": {
                "custom": {
                    "level": {
                        "level_desc": "sparse",
                        "rank": self.ndim,
                        "level": {
                            "level_desc": "element",
                        },
                    }
                }
            },
            "shape": list(self.shape),
            "number_of_stored_values": self.nnz,
            "data_types": {
                "pointers_to_1": str(pointer_dtype),
                "indices_1": str(self.coords.dtype),
                "values": data_dt,
            },
        },
        "original_source": f"`sparse`, version {__version__}",
    }

def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
    """Return the constituent arrays keyed by their `binsparse` names.

    Returns
    -------
    dict[str, np.ndarray]
        ``"pointers_to_1"`` is a freshly built two-element array
        ``[0, nnz]``; ``"indices_1"`` is ``self.coords`` and ``"values"``
        is ``self.data``, both returned as-is.
    """
    stored = int(self.nnz)
    # A fixed `uint8` cannot represent nnz > 255; size the pointer dtype to
    # the value instead (same rule as in `__binsparse_descriptor__`).
    pointers = np.array([0, stored], dtype=np.min_scalar_type(stored))
    return {
        "pointers_to_1": pointers,
        "indices_1": self.coords,
        "values": self.data,
    }
1577+
15401578

15411579
def as_coo(x, shape=None, fill_value=None, idx_dtype=None):
15421580
"""

sparse/numba_backend/_dok.py

+6
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,12 @@ def reshape(self, shape, order="C"):
548548

549549
return DOK.from_coo(self.to_coo().reshape(shape))
550550

551+
def __binsparse_descriptor__(self) -> dict:
    """Reject the call: ``DOK`` does not implement this protocol.

    Raises
    ------
    RuntimeError
        On every call.
    """
    message = "`DOK` doesn't support the `__binsparse_descriptor__` protocol."
    raise RuntimeError(message)
553+
554+
def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
    """Reject the call: ``DOK`` does not implement this protocol.

    Raises
    ------
    RuntimeError
        On every call.
    """
    message = "`DOK` doesn't support the `__binsparse_dlpack__` protocol."
    raise RuntimeError(message)
556+
551557

552558
def to_slice(k):
553559
"""Convert integer indices to one-element slices for consistency"""

sparse/numba_backend/_sparse_array.py

+25
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,31 @@ def _str_impl(self, summary):
218218
except (ImportError, ValueError):
219219
return summary
220220

221+
@abstractmethod
def __binsparse_descriptor__(self) -> dict:
    """Describe this array as a parsed `binsparse` descriptor.

    The result is a `dict` equivalent to the parsed JSON form of a
    [`binsparse` descriptor](https://graphblas.org/binsparse-specification/#descriptor).
    This base implementation is abstract and only raises.

    Returns
    -------
    dict
        Parsed `binsparse` descriptor.

    Raises
    ------
    NotImplementedError
        Always, when this base implementation is invoked.
    """
    raise NotImplementedError
232+
233+
@abstractmethod
def __binsparse_dlpack__(self) -> dict[str, np.ndarray]:
    """Expose the constituent arrays of this sparse array as a `dict`.

    Keys follow the [`binsparse`](https://graphblas.org/binsparse-specification/)
    scheme; values are objects compatible with
    [`__dlpack__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html).
    This base implementation is abstract and only raises.

    Returns
    -------
    dict[str, np.ndarray]
        The constituent arrays.

    Raises
    ------
    NotImplementedError
        Always, when this base implementation is invoked.
    """
    raise NotImplementedError
245+
221246
@abstractmethod
222247
def asformat(self, format):
223248
"""

0 commit comments

Comments
 (0)