Skip to content

Commit 8317821

Browse files
authored
ENH: CSC and CSF formats for MLIR backend (#775)
1 parent 9c36a32 commit 8317821

File tree

4 files changed

+188
-59
lines changed

4 files changed

+188
-59
lines changed

Diff for: sparse/mlir_backend/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
from ._constructors import (
1111
asarray,
1212
)
13+
from ._dtypes import (
14+
asdtype,
15+
)
1316
from ._ops import (
1417
add,
1518
)
1619

1720
__all__ = [
1821
"add",
1922
"asarray",
23+
"asdtype",
2024
]

Diff for: sparse/mlir_backend/_constructors.py

+101-17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
from typing import Any
23

34
import mlir.runtime as rt
45
from mlir import ir
@@ -48,18 +49,23 @@ def free_memref(obj: ctypes.Structure) -> None:
4849

4950

5051
@fn_cache
51-
def get_csr_class(values_dtype: type[DType], index_dtype: type[DType]) -> type:
52-
class Csr(ctypes.Structure):
52+
def get_csx_class(
53+
values_dtype: type[DType],
54+
index_dtype: type[DType],
55+
order: str,
56+
) -> type[ctypes.Structure]:
57+
class Csx(ctypes.Structure):
5358
_fields_ = [
5459
("indptr", get_nd_memref_descr(1, index_dtype)),
5560
("indices", get_nd_memref_descr(1, index_dtype)),
5661
("data", get_nd_memref_descr(1, values_dtype)),
5762
]
5863
dtype = values_dtype
5964
_index_dtype = index_dtype
65+
_order = order
6066

6167
@classmethod
62-
def from_sps(cls, arr: sps.csr_array) -> "Csr":
68+
def from_sps(cls, arr: sps.csr_array | sps.csc_array) -> "Csx":
6369
indptr = numpy_to_ranked_memref(arr.indptr)
6470
indices = numpy_to_ranked_memref(arr.indices)
6571
data = numpy_to_ranked_memref(arr.data)
@@ -69,11 +75,11 @@ def from_sps(cls, arr: sps.csr_array) -> "Csr":
6975

7076
return csr_instance
7177

72-
def to_sps(self, shape: tuple[int, ...]) -> sps.csr_array:
78+
def to_sps(self, shape: tuple[int, ...]) -> sps.csr_array | sps.csc_array:
7379
pos = ranked_memref_to_numpy(self.indptr)
7480
crd = ranked_memref_to_numpy(self.indices)
7581
data = ranked_memref_to_numpy(self.data)
76-
return sps.csr_array((data, crd, pos), shape=shape)
82+
return get_csx_scipy_class(self._order)((data, crd, pos), shape=shape)
7783

7884
def to_module_arg(self) -> list:
7985
return [
@@ -93,15 +99,15 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
9399
index_dtype = cls._index_dtype.get_mlir_type()
94100
index_width = getattr(index_dtype, "width", 0)
95101
levels = (sparse_tensor.LevelFormat.dense, sparse_tensor.LevelFormat.compressed)
96-
ordering = ir.AffineMap.get_permutation([0, 1])
102+
ordering = ir.AffineMap.get_permutation(get_order_tuple(cls._order))
97103
encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
98104
return ir.RankedTensorType.get(list(shape), values_dtype, encoding)
99105

100-
return Csr
106+
return Csx
101107

102108

103109
@fn_cache
104-
def get_coo_class(values_dtype: type[DType], index_dtype: type[DType]) -> type:
110+
def get_coo_class(values_dtype: type[DType], index_dtype: type[DType]) -> type[ctypes.Structure]:
105111
class Coo(ctypes.Structure):
106112
_fields_ = [
107113
("pos", get_nd_memref_descr(1, index_dtype)),
@@ -162,12 +168,61 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
162168

163169

164170
@fn_cache
165-
def get_csf_class(values_dtype: type[DType], index_dtype: type[DType]) -> type:
166-
raise NotImplementedError
171+
def get_csf_class(
    values_dtype: type[DType],
    index_dtype: type[DType],
) -> type[ctypes.Structure]:
    """Build the ctypes structure class describing a 3-D CSF tensor.

    The generated structure holds five rank-1 memref descriptors: an
    ``(indptr, indices)`` pair for each of the two compressed levels, followed
    by the buffer of stored values.

    Parameters
    ----------
    values_dtype : type[DType]
        Element type of the ``data`` buffer.
    index_dtype : type[DType]
        Element type of every ``indptr_*`` / ``indices_*`` buffer.

    Returns
    -------
    type[ctypes.Structure]
        The ``Csf`` class specialised for the given dtypes.
    """

    class Csf(ctypes.Structure):
        # The field order matters: ``from_sps`` fills the structure
        # positionally and ``get__fields_`` returns the members in this order.
        _fields_ = [
            ("indptr_1", get_nd_memref_descr(1, index_dtype)),
            ("indices_1", get_nd_memref_descr(1, index_dtype)),
            ("indptr_2", get_nd_memref_descr(1, index_dtype)),
            ("indices_2", get_nd_memref_descr(1, index_dtype)),
            ("data", get_nd_memref_descr(1, values_dtype)),
        ]
        dtype = values_dtype
        _index_dtype = index_dtype

        @classmethod
        def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
            """Create an instance from the constituent arrays.

            Parameters
            ----------
            arrs : list[np.ndarray]
                One array per structure field, in ``_fields_`` order
                (``indptr_1``, ``indices_1``, ``indptr_2``, ``indices_2``,
                ``data``).

            Raises
            ------
            ValueError
                If the number of arrays does not match the number of fields.
            """
            # ctypes accepts fewer positional initialisers than fields and
            # leaves the rest zeroed, so a short list would otherwise yield
            # null descriptors without any error here.
            expected = len(cls._fields_)
            if len(arrs) != expected:
                raise ValueError(f"CSF format requires {expected} arrays, got {len(arrs)}.")

            csf_instance = cls(*[numpy_to_ranked_memref(arr) for arr in arrs])
            # NOTE(review): `_take_owneship` is defined elsewhere in this file;
            # presumably it ties each array's lifetime to the instance - confirm.
            for arr in arrs:
                _take_owneship(csf_instance, arr)
            return csf_instance

        def to_sps(self, shape: tuple[int, ...]) -> list[np.ndarray]:
            """Return the five buffers converted with ``ranked_memref_to_numpy``.

            ``shape`` is not used by this method.
            """

            # NOTE(review): a `list` subclass is returned instead of a plain
            # list - presumably so attributes or weak references can be
            # attached to the result; confirm against the callers.
            class List(list):
                pass

            return List(ranked_memref_to_numpy(field) for field in self.get__fields_())

        def to_module_arg(self) -> list:
            """Return one pointer-to-pointer per field, in field order."""
            return [ctypes.pointer(ctypes.pointer(field)) for field in self.get__fields_()]

        def get__fields_(self) -> list:
            """Return the field values in the same order as ``_fields_``."""
            return [self.indptr_1, self.indices_1, self.indptr_2, self.indices_2, self.data]

        @classmethod
        @fn_cache
        def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
            """Return the ranked tensor type with the CSF sparse encoding for ``shape``."""
            with ir.Location.unknown(ctx):
                values_dtype = cls.dtype.get_mlir_type()
                index_dtype = cls._index_dtype.get_mlir_type()
                # Falls back to 0 when the MLIR index type has no ``width``.
                index_width = getattr(index_dtype, "width", 0)
                # One dense level followed by two compressed levels.
                levels = (
                    sparse_tensor.LevelFormat.dense,
                    sparse_tensor.LevelFormat.compressed,
                    sparse_tensor.LevelFormat.compressed,
                )
                # Identity permutation, passed for both mapping arguments.
                ordering = ir.AffineMap.get_permutation([0, 1, 2])
                encoding = sparse_tensor.EncodingAttr.get(levels, ordering, ordering, index_width, index_width)
                return ir.RankedTensorType.get(list(shape), values_dtype, encoding)

    return Csf
167222

168223

169224
@fn_cache
170-
def get_dense_class(values_dtype: type[DType], index_dtype: type[DType]) -> type:
225+
def get_dense_class(values_dtype: type[DType], index_dtype: type[DType]) -> type[ctypes.Structure]:
171226
class Dense(ctypes.Structure):
172227
_fields_ = [
173228
("data", get_nd_memref_descr(1, values_dtype)),
@@ -221,22 +276,42 @@ def _is_mlir_obj(x) -> bool:
221276
return isinstance(x, ctypes.Structure)
222277

223278

279+
def get_order_tuple(order: str) -> tuple[int, int]:
    """Translate an order code into a two-axis permutation.

    Parameters
    ----------
    order : str
        ``"r"`` or ``"c"``.

    Returns
    -------
    tuple[int, int]
        ``(0, 1)`` for ``"r"`` and ``(1, 0)`` for ``"c"``.

    Raises
    ------
    ValueError
        If ``order`` is neither ``"r"`` nor ``"c"``.
    """
    if order == "r":
        return (0, 1)
    if order == "c":
        return (1, 0)
    # ValueError subclasses Exception, so handlers written for the previous
    # bare Exception keep working.
    raise ValueError(f"Invalid order: {order}")
283+
284+
285+
def get_csx_scipy_class(order: str) -> type[sps.sparray]:
286+
if order in ("r", "c"):
287+
return sps.csr_array if order == "r" else sps.csc_array
288+
raise Exception(f"Invalid order: {order}")
289+
290+
224291
################
225292
# Tensor class #
226293
################
227294

228295

229296
class Tensor:
230-
def __init__(self, obj, shape=None) -> None:
297+
def __init__(
298+
self,
299+
obj: Any,
300+
shape: tuple[int, ...] | None = None,
301+
dtype: type[DType] | None = None,
302+
format: str | None = None,
303+
) -> None:
231304
self.shape = shape if shape is not None else obj.shape
232-
self._values_dtype = asdtype(obj.dtype)
305+
self.ndim = len(self.shape)
306+
self._values_dtype = dtype if dtype is not None else asdtype(obj.dtype)
233307

234308
if _is_scipy_sparse_obj(obj):
235309
self._owns_memory = False
236310

237-
if obj.format == "csr":
311+
if obj.format in ("csr", "csc"):
312+
order = "r" if obj.format == "csr" else "c"
238313
index_dtype = asdtype(obj.indptr.dtype)
239-
self._format_class = get_csr_class(self._values_dtype, index_dtype)
314+
self._format_class = get_csx_class(self._values_dtype, index_dtype, order)
240315
self._obj = self._format_class.from_sps(obj)
241316
elif obj.format == "coo":
242317
index_dtype = asdtype(obj.coords[0].dtype)
@@ -256,6 +331,15 @@ def __init__(self, obj, shape=None) -> None:
256331
self._format_class = type(obj)
257332
self._obj = obj
258333

334+
elif format is not None:
335+
if format == "csf":
336+
self._owns_memory = False
337+
index_dtype = asdtype(np.intp)
338+
self._format_class = get_csf_class(self._values_dtype, index_dtype)
339+
self._obj = self._format_class.from_sps(obj)
340+
else:
341+
raise Exception(f"Format {format} not supported.")
342+
259343
else:
260344
raise Exception(f"{type(obj)} not supported.")
261345

@@ -269,5 +353,5 @@ def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
269353
return self._obj.to_sps(self.shape)
270354

271355

272-
def asarray(obj) -> Tensor:
273-
return Tensor(obj)
356+
def asarray(obj, shape=None, dtype=None, format=None) -> Tensor:
    """Wrap ``obj`` in a ``Tensor``.

    ``shape``, ``dtype`` and ``format`` are forwarded unchanged to the
    ``Tensor`` constructor.
    """
    return Tensor(obj, shape=shape, dtype=dtype, format=format)

Diff for: sparse/mlir_backend/_ops.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@ def get_add_module(
1717
b_tensor_type: ir.RankedTensorType,
1818
out_tensor_type: ir.RankedTensorType,
1919
dtype: type[DType],
20+
rank: int,
2021
) -> ir.Module:
2122
with ir.Location.unknown(ctx):
2223
module = ir.Module.create()
2324
# TODO: add support for complex dialect/dtypes
2425
arith_op = arith.AddFOp if issubclass(dtype, FloatingDType) else arith.AddIOp
2526
dtype = dtype.get_mlir_type()
26-
ordering = ir.AffineMap.get_permutation([0, 1])
27+
ordering = ir.AffineMap.get_permutation(range(rank))
2728

2829
with ir.InsertionPoint(module.body):
2930

@@ -35,7 +36,7 @@ def add(a, b):
3536
[a, b],
3637
[out],
3738
ir.ArrayAttr.get([ir.AffineMapAttr.get(p) for p in (ordering,) * 3]),
38-
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * 2),
39+
ir.ArrayAttr.get([ir.Attribute.parse("#linalg.iterator_type<parallel>")] * rank),
3940
)
4041
block = generic_op.regions[0].blocks.append(dtype, dtype, dtype)
4142
with ir.InsertionPoint(block):
@@ -78,6 +79,7 @@ def add(x1: Tensor, x2: Tensor) -> Tensor:
7879
x2._obj.get_tensor_definition(x2.shape),
7980
out_tensor_type=out_tensor_type,
8081
dtype=x1._values_dtype,
82+
rank=x1.ndim,
8183
)
8284
add_module.invoke(
8385
"add",

0 commit comments

Comments
 (0)