 import ctypes
 import ctypes.util
+import functools
+import weakref

 import mlir.execution_engine
 import mlir.passmanager

 import numpy as np
 import scipy.sparse as sps

-from ._core import DEBUG, MLIR_C_RUNNER_UTILS, SCRIPT_PATH, ctx
-from ._dtypes import DType, Float64, Index
-from ._memref import MemrefF64_1D, MemrefIdx_1D
+from ._common import fn_cache
+from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx
+from ._dtypes import DType, Index, asdtype
+from ._memref import make_memref_ctype, ranked_memref_from_np
+
+
+def _hold_self_ref_in_ret(fn):
+    @functools.wraps(fn)
+    def wrapped(self, *a, **kw):
+        ptr = ctypes.py_object(self)
+        ctypes.pythonapi.Py_IncRef(ptr)
+        ret = fn(self, *a, **kw)
+
+        def finalizer(ptr):
+            ctypes.pythonapi.Py_DecRef(ptr)
+
+        weakref.finalize(ret, finalizer, ptr)
+        return ret
+
+    return wrapped


 class Tensor:
@@ -26,21 +45,21 @@ def __init__(self, obj, module, tensor_type, disassemble_fn, values_dtype, index
     def __del__(self):
         self.module.invoke("free_tensor", ctypes.pointer(self.obj))

+    @_hold_self_ref_in_ret
     def to_scipy_sparse(self):
         """
         Returns scipy.sparse or ndarray
         """
-        return self.disassemble_fn(self.module, self.obj)
+        return self.disassemble_fn(self.module, self.obj, self.values_dtype)


 class DenseFormat:
-    modules = {}
-
+    @fn_cache
     def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
         with ir.Location.unknown(ctx):
             module = ir.Module.create()
-            values_dtype = values_dtype.get()
-            index_dtype = index_dtype.get()
+            values_dtype = values_dtype.get_mlir_type()
+            index_dtype = index_dtype.get_mlir_type()
             index_width = getattr(index_dtype, "width", 0)
             levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.dense)
             ordering = ir.AffineMap.get_permutation([0, 1])
@@ -78,18 +97,19 @@ def free_tensor(tensor_shaped):
             disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             if DEBUG:
-                (SCRIPT_PATH / "dense_module.mlir").write_text(str(module))
+                (CWD / "dense_module.mlir").write_text(str(module))
             pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
             pm.run(module.operation)
             if DEBUG:
-                (SCRIPT_PATH / "dense_module_opt.mlir").write_text(str(module))
+                (CWD / "dense_module_opt.mlir").write_text(str(module))

             module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
             return (module, dense_shaped)

     @classmethod
     def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
-        data = MemrefF64_1D.from_numpy(arr.flatten())
+        assert arr.ndim == 2
+        data = ranked_memref_from_np(arr.flatten())
         out = ctypes.c_void_p()
         module.invoke(
             "assemble",
@@ -99,18 +119,18 @@ def assemble(cls, module, arr: np.ndarray) -> ctypes.c_void_p:
         return out

     @classmethod
-    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> np.ndarray:
+    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> np.ndarray:
         class Dense(ctypes.Structure):
             _fields_ = [
-                ("data", MemrefF64_1D),
+                ("data", make_memref_ctype(dtype, 1)),
                 ("data_len", np.ctypeslib.c_intp),
                 ("shape_x", np.ctypeslib.c_intp),
                 ("shape_y", np.ctypeslib.c_intp),
             ]

             def to_np(self) -> np.ndarray:
                 data = self.data.to_numpy()[: self.data_len]
-                return data.copy().reshape((self.shape_x, self.shape_y))
+                return data.reshape((self.shape_x, self.shape_y))

         arr = Dense()
         module.invoke(
@@ -122,18 +142,17 @@ def to_np(self) -> np.ndarray:


 class COOFormat:
-    modules = {}
     # TODO: implement
+    ...


 class CSRFormat:
-    modules = {}
-
-    def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
+    @fn_cache
+    def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[DType]):
         with ir.Location.unknown(ctx):
             module = ir.Module.create()
-            values_dtype = values_dtype.get()
-            index_dtype = index_dtype.get()
+            values_dtype = values_dtype.get_mlir_type()
+            index_dtype = index_dtype.get_mlir_type()
             index_width = getattr(index_dtype, "width", 0)
             levels = (sparse_tensor.LevelType.dense, sparse_tensor.LevelType.compressed)
             ordering = ir.AffineMap.get_permutation([0, 1])
@@ -175,11 +194,11 @@ def free_tensor(tensor_shaped):
             disassemble.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             free_tensor.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
             if DEBUG:
-                (SCRIPT_PATH / "scr_module.mlir").write_text(str(module))
+                (CWD / "csr_module.mlir").write_text(str(module))
             pm = mlir.passmanager.PassManager.parse("builtin.module(sparsifier{create-sparse-deallocs=1})")
             pm.run(module.operation)
             if DEBUG:
-                (SCRIPT_PATH / "csr_module_opt.mlir").write_text(str(module))
+                (CWD / "csr_module_opt.mlir").write_text(str(module))

             module = mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
             return (module, csr_shaped)
@@ -189,20 +208,20 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
         out = ctypes.c_void_p()
         module.invoke(
             "assemble",
-            ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indptr))),
-            ctypes.pointer(ctypes.pointer(MemrefIdx_1D.from_numpy(arr.indices))),
-            ctypes.pointer(ctypes.pointer(MemrefF64_1D.from_numpy(arr.data))),
+            ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indptr))),
+            ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.indices))),
+            ctypes.pointer(ctypes.pointer(ranked_memref_from_np(arr.data))),
             ctypes.pointer(out),
         )
         return out

     @classmethod
-    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p) -> sps.csr_array:
+    def disassemble(cls, module: ir.Module, ptr: ctypes.c_void_p, dtype: type[DType]) -> sps.csr_array:
         class Csr(ctypes.Structure):
             _fields_ = [
-                ("data", MemrefF64_1D),
-                ("pos", MemrefIdx_1D),
-                ("crd", MemrefIdx_1D),
+                ("data", make_memref_ctype(dtype, 1)),
+                ("pos", make_memref_ctype(Index, 1)),
+                ("crd", make_memref_ctype(Index, 1)),
                 ("data_len", np.ctypeslib.c_intp),
                 ("pos_len", np.ctypeslib.c_intp),
                 ("crd_len", np.ctypeslib.c_intp),
@@ -214,7 +233,7 @@ def to_sps(self) -> sps.csr_array:
                 pos = self.pos.to_numpy()[: self.pos_len]
                 crd = self.crd.to_numpy()[: self.crd_len]
                 data = self.data.to_numpy()[: self.data_len]
-                return sps.csr_array((data.copy(), crd.copy(), pos.copy()), shape=(self.shape_x, self.shape_y))
+                return sps.csr_array((data, crd, pos), shape=(self.shape_x, self.shape_y))

         arr = Csr()
         module.invoke(
@@ -235,23 +254,21 @@ def _is_numpy_obj(x) -> bool:

 def asarray(obj) -> Tensor:
     # TODO: discover obj's dtype
-    values_dtype = Float64
-    index_dtype = Index
+    values_dtype = asdtype(obj.dtype)

     # TODO: support other scipy formats
     if _is_scipy_sparse_obj(obj):
         format_class = CSRFormat
+        # This can be int32 or int64
+        index_dtype = asdtype(obj.indptr.dtype)
     elif _is_numpy_obj(obj):
         format_class = DenseFormat
+        index_dtype = Index
     else:
         raise Exception(f"{type(obj)} not supported.")

     # TODO: support proper caching
-    if hash(obj.shape) in format_class.modules:
-        module, tensor_type = format_class.modules[hash(obj.shape)]
-    else:
-        module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)
-        format_class.modules[hash(obj.shape)] = module, tensor_type
+    module, tensor_type = format_class.get_module(obj.shape, values_dtype, index_dtype)

     assembled_obj = format_class.assemble(module, obj)
     return Tensor(assembled_obj, module, tensor_type, format_class.disassemble, values_dtype, index_dtype)
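
For context, a minimal round-trip sketch of the API this diff touches; the import path and the SciPy input here are assumptions for illustration, not part of the change. `asarray` now discovers the value dtype from the input (and the index dtype from `indptr` for CSR), `get_module` results are cached via `@fn_cache`, and since `disassemble` no longer copies the buffers, `@_hold_self_ref_in_ret` keeps the `Tensor` alive for as long as the returned object is referenced.

import numpy as np
import scipy.sparse as sps

from sparse.mlir_backend import asarray  # assumed import path

csr = sps.csr_array(np.array([[0.0, 1.0], [2.0, 0.0]]))

tensor = asarray(csr)              # CSR path: index dtype taken from csr.indptr.dtype
result = tensor.to_scipy_sparse()  # may alias tensor's buffers; the decorator keeps `tensor` alive

assert np.array_equal(result.todense(), csr.todense())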