1
1
import ctypes
2
+ from typing import Any
2
3
3
4
import mlir .runtime as rt
4
5
from mlir import ir
@@ -48,18 +49,23 @@ def free_memref(obj: ctypes.Structure) -> None:
48
49
49
50
50
51
@fn_cache
51
- def get_csr_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
52
- class Csr (ctypes .Structure ):
52
+ def get_csx_class (
53
+ values_dtype : type [DType ],
54
+ index_dtype : type [DType ],
55
+ order : str ,
56
+ ) -> type [ctypes .Structure ]:
57
+ class Csx (ctypes .Structure ):
53
58
_fields_ = [
54
59
("indptr" , get_nd_memref_descr (1 , index_dtype )),
55
60
("indices" , get_nd_memref_descr (1 , index_dtype )),
56
61
("data" , get_nd_memref_descr (1 , values_dtype )),
57
62
]
58
63
dtype = values_dtype
59
64
_index_dtype = index_dtype
65
+ _order = order
60
66
61
67
@classmethod
62
- def from_sps (cls , arr : sps .csr_array ) -> "Csr " :
68
+ def from_sps (cls , arr : sps .csr_array | sps . csc_array ) -> "Csx " :
63
69
indptr = numpy_to_ranked_memref (arr .indptr )
64
70
indices = numpy_to_ranked_memref (arr .indices )
65
71
data = numpy_to_ranked_memref (arr .data )
@@ -69,11 +75,11 @@ def from_sps(cls, arr: sps.csr_array) -> "Csr":
69
75
70
76
return csr_instance
71
77
72
- def to_sps (self , shape : tuple [int , ...]) -> sps .csr_array :
78
+ def to_sps (self , shape : tuple [int , ...]) -> sps .csr_array | sps . csc_array :
73
79
pos = ranked_memref_to_numpy (self .indptr )
74
80
crd = ranked_memref_to_numpy (self .indices )
75
81
data = ranked_memref_to_numpy (self .data )
76
- return sps . csr_array ((data , crd , pos ), shape = shape )
82
+ return get_csx_scipy_class ( self . _order ) ((data , crd , pos ), shape = shape )
77
83
78
84
def to_module_arg (self ) -> list :
79
85
return [
@@ -93,15 +99,15 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
93
99
index_dtype = cls ._index_dtype .get_mlir_type ()
94
100
index_width = getattr (index_dtype , "width" , 0 )
95
101
levels = (sparse_tensor .LevelFormat .dense , sparse_tensor .LevelFormat .compressed )
96
- ordering = ir .AffineMap .get_permutation ([ 0 , 1 ] )
102
+ ordering = ir .AffineMap .get_permutation (get_order_tuple ( cls . _order ) )
97
103
encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
98
104
return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
99
105
100
- return Csr
106
+ return Csx
101
107
102
108
103
109
@fn_cache
104
- def get_coo_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
110
+ def get_coo_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type [ ctypes . Structure ] :
105
111
class Coo (ctypes .Structure ):
106
112
_fields_ = [
107
113
("pos" , get_nd_memref_descr (1 , index_dtype )),
@@ -162,12 +168,61 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
162
168
163
169
164
170
@fn_cache
165
- def get_csf_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
166
- raise NotImplementedError
171
+ def get_csf_class (
172
+ values_dtype : type [DType ],
173
+ index_dtype : type [DType ],
174
+ ) -> type [ctypes .Structure ]:
175
+ class Csf (ctypes .Structure ):
176
+ _fields_ = [
177
+ ("indptr_1" , get_nd_memref_descr (1 , index_dtype )),
178
+ ("indices_1" , get_nd_memref_descr (1 , index_dtype )),
179
+ ("indptr_2" , get_nd_memref_descr (1 , index_dtype )),
180
+ ("indices_2" , get_nd_memref_descr (1 , index_dtype )),
181
+ ("data" , get_nd_memref_descr (1 , values_dtype )),
182
+ ]
183
+ dtype = values_dtype
184
+ _index_dtype = index_dtype
185
+
186
+ @classmethod
187
+ def from_sps (cls , arrs : list [np .ndarray ]) -> "Csf" :
188
+ csf_instance = cls (* [numpy_to_ranked_memref (arr ) for arr in arrs ])
189
+ for arr in arrs :
190
+ _take_owneship (csf_instance , arr )
191
+ return csf_instance
192
+
193
+ def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
194
+ class List (list ):
195
+ pass
196
+
197
+ return List (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
198
+
199
+ def to_module_arg (self ) -> list :
200
+ return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
201
+
202
+ def get__fields_ (self ) -> list :
203
+ return [self .indptr_1 , self .indices_1 , self .indptr_2 , self .indices_2 , self .data ]
204
+
205
+ @classmethod
206
+ @fn_cache
207
+ def get_tensor_definition (cls , shape : tuple [int , ...]) -> ir .RankedTensorType :
208
+ with ir .Location .unknown (ctx ):
209
+ values_dtype = cls .dtype .get_mlir_type ()
210
+ index_dtype = cls ._index_dtype .get_mlir_type ()
211
+ index_width = getattr (index_dtype , "width" , 0 )
212
+ levels = (
213
+ sparse_tensor .LevelFormat .dense ,
214
+ sparse_tensor .LevelFormat .compressed ,
215
+ sparse_tensor .LevelFormat .compressed ,
216
+ )
217
+ ordering = ir .AffineMap .get_permutation ([0 , 1 , 2 ])
218
+ encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
219
+ return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
220
+
221
+ return Csf
167
222
168
223
169
224
@fn_cache
170
- def get_dense_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
225
+ def get_dense_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type [ ctypes . Structure ] :
171
226
class Dense (ctypes .Structure ):
172
227
_fields_ = [
173
228
("data" , get_nd_memref_descr (1 , values_dtype )),
@@ -221,22 +276,42 @@ def _is_mlir_obj(x) -> bool:
221
276
return isinstance (x , ctypes .Structure )
222
277
223
278
279
+ def get_order_tuple (order : str ) -> tuple [int , int ]:
280
+ if order in ("r" , "c" ):
281
+ return (0 , 1 ) if order == "r" else (1 , 0 )
282
+ raise Exception (f"Invalid order: { order } " )
283
+
284
+
285
+ def get_csx_scipy_class (order : str ) -> type [sps .sparray ]:
286
+ if order in ("r" , "c" ):
287
+ return sps .csr_array if order == "r" else sps .csc_array
288
+ raise Exception (f"Invalid order: { order } " )
289
+
290
+
224
291
################
225
292
# Tensor class #
226
293
################
227
294
228
295
229
296
class Tensor :
230
- def __init__ (self , obj , shape = None ) -> None :
297
+ def __init__ (
298
+ self ,
299
+ obj : Any ,
300
+ shape : tuple [int , ...] | None = None ,
301
+ dtype : type [DType ] | None = None ,
302
+ format : str | None = None ,
303
+ ) -> None :
231
304
self .shape = shape if shape is not None else obj .shape
232
- self ._values_dtype = asdtype (obj .dtype )
305
+ self .ndim = len (self .shape )
306
+ self ._values_dtype = dtype if dtype is not None else asdtype (obj .dtype )
233
307
234
308
if _is_scipy_sparse_obj (obj ):
235
309
self ._owns_memory = False
236
310
237
- if obj .format == "csr" :
311
+ if obj .format in ("csr" , "csc" ):
312
+ order = "r" if obj .format == "csr" else "c"
238
313
index_dtype = asdtype (obj .indptr .dtype )
239
- self ._format_class = get_csr_class (self ._values_dtype , index_dtype )
314
+ self ._format_class = get_csx_class (self ._values_dtype , index_dtype , order )
240
315
self ._obj = self ._format_class .from_sps (obj )
241
316
elif obj .format == "coo" :
242
317
index_dtype = asdtype (obj .coords [0 ].dtype )
@@ -256,6 +331,15 @@ def __init__(self, obj, shape=None) -> None:
256
331
self ._format_class = type (obj )
257
332
self ._obj = obj
258
333
334
+ elif format is not None :
335
+ if format == "csf" :
336
+ self ._owns_memory = False
337
+ index_dtype = asdtype (np .intp )
338
+ self ._format_class = get_csf_class (self ._values_dtype , index_dtype )
339
+ self ._obj = self ._format_class .from_sps (obj )
340
+ else :
341
+ raise Exception (f"Format { format } not supported." )
342
+
259
343
else :
260
344
raise Exception (f"{ type (obj )} not supported." )
261
345
@@ -269,5 +353,5 @@ def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
269
353
return self ._obj .to_sps (self .shape )
270
354
271
355
272
- def asarray (obj ) -> Tensor :
273
- return Tensor (obj )
356
+ def asarray (obj , shape = None , dtype = None , format = None ) -> Tensor :
357
+ return Tensor (obj , shape , dtype , format )
0 commit comments