Skip to content

Commit 1c56a0b

Browse files
feat, perf: Refactor the PoC to support multiple dtypes (#757)
Co-authored-by: Mateusz Sokół <[email protected]>
1 parent 41159c0 commit 1c56a0b

File tree

8 files changed

+242
-94
lines changed

8 files changed

+242
-94
lines changed

sparse/mlir_backend/_common.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import abc
2+
import functools
3+
4+
from mlir import ir
5+
6+
7+
class MlirType(abc.ABC):
8+
@classmethod
9+
@abc.abstractmethod
10+
def get_mlir_type(cls) -> ir.Type: ...
11+
12+
13+
def fn_cache(f, maxsize: int | None = None):
14+
return functools.wraps(f)(functools.lru_cache(maxsize=maxsize)(f))

sparse/mlir_backend/_constructors.py

+54-37
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import ctypes
22
import ctypes.util
3+
import functools
4+
import weakref
35

46
import mlir.execution_engine
57
import mlir.passmanager
@@ -9,9 +11,26 @@
911
import numpy as np
1012
import scipy.sparse as sps
1113

12-
from ._core import DEBUG, MLIR_C_RUNNER_UTILS, SCRIPT_PATH, ctx
13-
from ._dtypes import DType, Float64, Index
14-
from ._memref import MemrefF64_1D, MemrefIdx_1D
14+
from ._common import fn_cache
15+
from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx
16+
from ._dtypes import DType, Index, asdtype
17+
from ._memref import make_memref_ctype, ranked_memref_from_np
18+
19+
20+
def _hold_self_ref_in_ret(fn):
21+
@functools.wraps(fn)
22+
def wrapped(self, *a, **kw):
23+
ptr = ctypes.py_object(self)
24+
ctypes.pythonapi.Py_IncRef(ptr)
25+
ret = fn(self, *a, **kw)
26+
27+
def finalizer(ptr):
28+
ctypes.pythonapi.Py_DecRef(ptr)
29+
30+
weakref.finalize(ret, finalizer, ptr)
31+
return ret
32+
33+
return wrapped
1534

1635

1736
class Tensor:
@@ -26,21 +45,21 @@ def __init__(self, obj, module, tensor_type, disassemble_fn, values_dtype, index
2645
def __del__(self):
2746
self.module.invoke("free_tensor", ctypes.pointer(self.obj))
2847

48+
@_hold_self_ref_in_ret
2949
def to_scipy_sparse(self):
3050
"""
3151
Returns scipy.sparse or ndarray
3252
"""
33-
return self.disassemble_fn(self.module, self.obj)
53+
return self.disassemble_fn(self.module, self.obj, self.values_dtype)
3454

3555

3656
class DenseFormat:
37-
modules = {}
38-
57+
@fn_cache
3958
def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
4059
with ir.Location.unknown(ctx):
4160
module = ir.Module.create()
42-
values_dtype = values_dtype.get()
43-
index_dtype = index_dtype.get()
61+
values_dtype = values_dtype.get_mlir_type()
62+
index_dtype = index_dtype.get_mlir_type()
4463
index_width = getattr(index_dtype, "width", 0)
4564
levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.dense)
4665
ordering = ir.AffineMap.get_permutation([0, 1])
@@ -78,18 +97,19 @@ def free_tensor(tensor_shaped):
7897
disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
7998
free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
8099
if DEBUG:
81-
(SCRIPT_PATH / "dense_module.mlir").write_text(str(module))
100+
(CWD / "dense_module.mlir").write_text(str(module))
82101
pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
83102
pm.run(module.operation)
84103
if DEBUG:
85-
(SCRIPT_PATH / "dense_module_opt.mlir").write_text(str(module))
104+
(CWD / "dense_module_opt.mlir").write_text(str(module))
86105

87106
module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
88107
return (module, dense_shaped)
89108

90109
@classmethod
91110
def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
92-
data = MemrefF64_1D.from_numpy(arr.flatten())
111+
assert arr.ndim == 2
112+
data = ranked_memref_from_np(arr.flatten())
93113
out = ctypes.c_void_p()
94114
module.invoke(
95115
"assemble",
@@ -99,18 +119,18 @@ def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
99119
return out
100120

101121
@classmethod
102-
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> np.ndarray:
122+
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> np.ndarray:
103123
class Dense(ctypes.Structure):
104124
_fields_ = [
105-
("data", MemrefF64_1D),
125+
("data", make_memref_ctype(dtype, 1)),
106126
("data_len", np.ctypeslib.c_intp),
107127
("shape_x", np.ctypeslib.c_intp),
108128
("shape_y", np.ctypeslib.c_intp),
109129
]
110130

111131
def to_np(self) -> np.ndarray:
112132
data = self.data.to_numpy()[: self.data_len]
113-
return data.copy().reshape((self.shape_x, self.shape_y))
133+
return data.reshape((self.shape_x, self.shape_y))
114134

115135
arr = Dense()
116136
module.invoke(
@@ -122,18 +142,17 @@ def to_np(self) -> np.ndarray:
122142

123143

124144
class COOFormat:
125-
modules = {}
126145
# TODO: implement
146+
...
127147

128148

129149
class CSRFormat:
130-
modules = {}
131-
132-
def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
150+
@fn_cache
151+
def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[DType]):
133152
with ir.Location.unknown(ctx):
134153
module = ir.Module.create()
135-
values_dtype = values_dtype.get()
136-
index_dtype = index_dtype.get()
154+
values_dtype = values_dtype.get_mlir_type()
155+
index_dtype = index_dtype.get_mlir_type()
137156
index_width = getattr(index_dtype, "width", 0)
138157
levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.compressed)
139158
ordering = ir.AffineMap.get_permutation([0, 1])
@@ -175,11 +194,11 @@ def free_tensor(tensor_shaped):
175194
disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
176195
free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
177196
if DEBUG:
178-
(SCRIPT_PATH / "scr_module.mlir").write_text(str(module))
197+
(CWD / "csr_module.mlir").write_text(str(module))
179198
pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
180199
pm.run(module.operation)
181200
if DEBUG:
182-
(SCRIPT_PATH / "csr_module_opt.mlir").write_text(str(module))
201+
(CWD / "csr_module_opt.mlir").write_text(str(module))
183202

184203
module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
185204
return (module, csr_shaped)
@@ -189,20 +208,20 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
189208
out = ctypes.c_void_p()
190209
module.invoke(
191210
"assemble",
192-
ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indptr))),
193-
ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indices))),
194-
ctypes.pointer(ctypes.pointer(MemrefF64_1D.from_numpy(arr.data))),
211+
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indptr))),
212+
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indices))),
213+
ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.data))),
195214
ctypes.pointer(out),
196215
)
197216
return out
198217

199218
@classmethod
200-
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> sps.csr_array:
219+
def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> sps.csr_array:
201220
class Csr(ctypes.Structure):
202221
_fields_ = [
203-
("data", MemrefF64_1D),
204-
("pos", MemrefIdx_1D),
205-
("crd", MemrefIdx_1D),
222+
("data", make_memref_ctype(dtype, 1)),
223+
("pos", make_memref_ctype(Index, 1)),
224+
("crd", make_memref_ctype(Index, 1)),
206225
("data_len", np.ctypeslib.c_intp),
207226
("pos_len", np.ctypeslib.c_intp),
208227
("crd_len", np.ctypeslib.c_intp),
@@ -214,7 +233,7 @@ def to_sps(self) -> sps.csr_array:
214233
pos = self.pos.to_numpy()[: self.pos_len]
215234
crd = self.crd.to_numpy()[: self.crd_len]
216235
data = self.data.to_numpy()[: self.data_len]
217-
return sps.csr_array((data.copy(), crd.copy(), pos.copy()), shape=(self.shape_x, self.shape_y))
236+
return sps.csr_array((data, crd, pos), shape=(self.shape_x, self.shape_y))
218237

219238
arr = Csr()
220239
module.invoke(
@@ -235,23 +254,21 @@ def _is_numpy_obj(x) -> bool:
235254

236255
def asarray(obj) -> Tensor:
237256
# TODO: discover obj's dtype
238-
values_dtype = Float64
239-
index_dtype = Index
257+
values_dtype = asdtype(obj.dtype)
240258

241259
# TODO: support other scipy formats
242260
if _is_scipy_sparse_obj(obj):
243261
format_class = CSRFormat
262+
# This can be int32 or int64
263+
index_dtype = asdtype(obj.indptr.dtype)
244264
elif _is_numpy_obj(obj):
245265
format_class = DenseFormat
266+
index_dtype = Index
246267
else:
247268
raise Exception(f"{type(obj)} not supported.")
248269

249270
# TODO: support proper caching
250-
if hash(obj.shape) in format_class.modules:
251-
module, tensor_type = format_class.modules[hash(obj.shape)]
252-
else:
253-
module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)
254-
format_class.modules[hash(obj.shape)] = module, tensor_type
271+
module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)
255272

256273
assembled_obj = format_class.assemble(module, obj)
257274
return Tensor(assembled_obj, module, tensor_type, format_class.disassemble, values_dtype, index_dtype)

sparse/mlir_backend/_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from mlir.ir import Context
66

77
DEBUG = bool(int(os.environ.get("DEBUG", "0")))
8-
SCRIPT_PATH = pathlib.Path(__file__).parent
8+
CWD = pathlib.Path(".")
99

1010
MLIR_C_RUNNER_UTILS = ctypes.util.find_library("mlir_c_runner_utils")
1111
libc = ctypes.CDLL(ctypes.util.find_library("c")) if os.name != "nt" else ctypes.cdll.msvcrt

sparse/mlir_backend/_dtypes.py

+80-32
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,119 @@
1+
import inspect
2+
import math
3+
import sys
4+
import typing
5+
16
from mlir import ir
27

38
import numpy as np
49

10+
from ._common import MlirType
11+
12+
13+
def _get_pointer_width() -> int:
14+
return round(math.log2(sys.maxsize + 1.0)) + 1
15+
16+
17+
_PTR_WIDTH = _get_pointer_width()
18+
19+
20+
def _make_int_classes(namespace: dict[str, object], bit_widths: typing.Iterable[int]) -> None:
21+
for bw in bit_widths:
22+
23+
class SignedBW(SignedIntegerDType):
24+
np_dtype = getattr(np, f"int{bw}")
25+
bit_width = bw
26+
27+
@classmethod
28+
def get_mlir_type(cls):
29+
return ir.IntegerType.get_signless(cls.bit_width)
30+
31+
SignedBW.__name__ = f"Int{bw}"
32+
SignedBW.__module__ = __name__
33+
34+
class UnsignedBW(UnsignedIntegerDType):
35+
np_dtype = getattr(np, f"uint{bw}")
36+
bit_width = bw
37+
38+
@classmethod
39+
def get_mlir_type(cls):
40+
return ir.IntegerType.get_signless(cls.bit_width)
41+
42+
UnsignedBW.__name__ = f"UInt{bw}"
43+
UnsignedBW.__module__ = __name__
44+
45+
namespace[SignedBW.__name__] = SignedBW
46+
namespace[UnsignedBW.__name__] = UnsignedBW
547

6-
class DType:
7-
pass
848

49+
class DType(MlirType):
50+
np_dtype: np.dtype
51+
bit_width: int
952

10-
class Float64(DType):
53+
54+
class FloatingDType(DType): ...
55+
56+
57+
class Float64(FloatingDType):
1158
np_dtype = np.float64
59+
bit_width = 64
1260

1361
@classmethod
14-
def get(cls):
62+
def get_mlir_type(cls):
1563
return ir.F64Type.get()
1664

1765

18-
class Float32(DType):
66+
class Float32(FloatingDType):
1967
np_dtype = np.float32
68+
bit_width = 32
2069

2170
@classmethod
22-
def get(cls):
71+
def get_mlir_type(cls):
2372
return ir.F32Type.get()
2473

2574

26-
class Int64(DType):
27-
np_dtype = np.int64
75+
class Float16(FloatingDType):
76+
np_dtype = np.float16
77+
bit_width = 16
2878

2979
@classmethod
30-
def get(cls):
31-
return ir.IntegerType.get_signed(64)
80+
def get_mlir_type(cls):
81+
return ir.F16Type.get()
3282

3383

34-
class UInt64(DType):
35-
np_dtype = np.uint64
84+
class IntegerDType(DType): ...
3685

37-
@classmethod
38-
def get(cls):
39-
return ir.IntegerType.get_unsigned(64)
4086

87+
class UnsignedIntegerDType(IntegerDType): ...
4188

42-
class Int32(DType):
43-
np_dtype = np.int32
44-
45-
@classmethod
46-
def get(cls):
47-
return ir.IntegerType.get_signed(32)
4889

90+
class SignedIntegerDType(IntegerDType): ...
4991

50-
class UInt32(DType):
51-
np_dtype = np.uint32
5292

53-
@classmethod
54-
def get(cls):
55-
return ir.IntegerType.get_unsigned(32)
93+
_make_int_classes(locals(), [8, 16, 32, 64])
5694

5795

5896
class Index(DType):
5997
np_dtype = np.intp
6098

6199
@classmethod
62-
def get(cls):
100+
def get_mlir_type(cls):
63101
return ir.IndexType.get()
64102

65103

66-
class SignlessInt64(DType):
67-
np_dtype = np.int64
104+
IntP: type[SignedIntegerDType] = locals()[f"Int{_PTR_WIDTH}"]
105+
UIntP: type[UnsignedIntegerDType] = locals()[f"UInt{_PTR_WIDTH}"]
68106

69-
@classmethod
70-
def get(cls):
71-
return ir.IntegerType.get_signless(64)
107+
108+
def isdtype(dt, /) -> bool:
109+
return isinstance(dt, type) and issubclass(dt, DType) and not inspect.isabstract(dt)
110+
111+
112+
NUMPY_DTYPE_MAP = {np.dtype(dt.np_dtype): dt for dt in locals().values() if isdtype(dt)}
113+
114+
115+
def asdtype(dt, /) -> type[DType]:
116+
if isdtype(dt):
117+
return dt
118+
119+
return NUMPY_DTYPE_MAP[np.dtype(dt)]

0 commit comments

Comments
 (0)