Skip to content

Commit 9db27d4

Browse files
committed
ENH: Implemented __getitem__ logic
1 parent df50a8d commit 9db27d4

File tree

3 files changed

+153
-6
lines changed

3 files changed

+153
-6
lines changed

sparse/mlir_backend/_constructors.py

+6
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,12 @@ def __del__(self):
364364
for field in self._obj.get__fields_():
365365
free_memref(field)
366366

367+
def __getitem__(self, key) -> "Tensor":
    """Index this tensor with ints/slices/Ellipsis, returning a new Tensor."""
    # Deferred import: `_ops` itself imports from this module, so a
    # top-level import would create a circular dependency.
    from ._ops import getitem as _getitem

    return _getitem(self, key)
372+
367373
@_hold_self_ref_in_ret
368374
def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
369375
return self._obj.to_sps(self.shape)

sparse/mlir_backend/_ops.py

+111-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import ctypes
2+
from types import EllipsisType
23

34
import mlir.execution_engine
45
import mlir.passmanager
@@ -85,12 +86,39 @@ def get_reshape_module(
8586
def reshape(a, shape):
8687
return tensor.reshape(out_tensor_type, a, shape)
8788

88-
reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
89-
if DEBUG:
90-
(CWD / "reshape_module.mlir").write_text(str(module))
91-
pm.run(module.operation)
92-
if DEBUG:
93-
(CWD / "reshape_module_opt.mlir").write_text(str(module))
89+
reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
90+
if DEBUG:
91+
(CWD / "reshape_module.mlir").write_text(str(module))
92+
pm.run(module.operation)
93+
if DEBUG:
94+
(CWD / "reshape_module_opt.mlir").write_text(str(module))
95+
96+
return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
97+
98+
99+
@fn_cache
def get_slice_module(
    in_tensor_type: ir.RankedTensorType,
    out_tensor_type: ir.RankedTensorType,
    offsets: tuple[int, ...],
    sizes: tuple[int, ...],
    strides: tuple[int, ...],
) -> ir.Module:
    """Build, lower and JIT-compile an MLIR module containing a single
    `getitem` function that extracts a static slice from `in_tensor_type`.

    The result is cached by `fn_cache`, so one module is compiled per
    unique (types, offsets, sizes, strides) combination.
    """
    with ir.Location.unknown(ctx):
        module = ir.Module.create()

        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(in_tensor_type)
            def getitem(a):
                # Static variant: the three empty lists are the *dynamic*
                # offset/size/stride operands; all values here are static.
                return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides)

            # Required so the ExecutionEngine can call the function via
            # the C interface (`invoke` below in the caller).
            getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
            if DEBUG:
                # Dump the module before lowering for debugging.
                (CWD / "getitem_module.mlir").write_text(str(module))
            # Run the shared lowering pass pipeline in place.
            pm.run(module.operation)
            if DEBUG:
                (CWD / "getitem_module_opt.mlir").write_text(str(module))

        return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
96124

@@ -152,3 +180,80 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
152180
)
153181

154182
return Tensor(ret_obj, shape=out_tensor_type.shape)
183+
184+
185+
def _add_missing_dims(key: tuple, ndim: int) -> tuple:
186+
if len(key) < ndim and Ellipsis not in key:
187+
return key + (...,)
188+
return key
189+
190+
191+
def _expand_ellipsis(key: tuple, ndim: int) -> tuple:
192+
if Ellipsis in key:
193+
if len([e for e in key if e is Ellipsis]) > 1:
194+
raise Exception(f"Ellipsis should be used once: {key}")
195+
to_expand = ndim - len(key) + 1
196+
if to_expand <= 0:
197+
raise Exception(f"Invalid use of Ellipsis in {key}")
198+
idx = key.index(Ellipsis)
199+
return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :]
200+
return key
201+
202+
203+
def _decompose_slices(
204+
key: tuple,
205+
shape: tuple[int, ...],
206+
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
207+
offsets = []
208+
sizes = []
209+
strides = []
210+
211+
for key_elem, size in zip(key, shape, strict=False):
212+
if isinstance(key_elem, slice):
213+
offset = key_elem.start if key_elem.start is not None else 0
214+
size = key_elem.stop - offset if key_elem.stop is not None else size - offset
215+
stride = key_elem.step if key_elem.step is not None else 1
216+
elif isinstance(key_elem, int):
217+
offset = key_elem
218+
size = key_elem + 1
219+
stride = 1
220+
offsets.append(offset)
221+
sizes.append(size)
222+
strides.append(stride)
223+
224+
return tuple(offsets), tuple(sizes), tuple(strides)
225+
226+
227+
def _get_new_shape(sizes, strides) -> tuple[int, ...]:
228+
return tuple(size // stride for size, stride in zip(sizes, strides, strict=False))
229+
230+
231+
def getitem(
    x: Tensor,
    key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...],
) -> Tensor:
    """Slice `x` by compiling and invoking an MLIR slice-extraction kernel.

    The key is normalized (tuple form, missing dims padded, Ellipsis
    expanded), decomposed into static offsets/sizes/strides, and used to
    build a cached `get_slice_module`, which is then invoked on `x`'s
    underlying storage.
    """
    # Normalize a scalar index to the 1-tuple form used below.
    if not isinstance(key, tuple):
        key = (key,)
    # `None` (np.newaxis) would add a dimension; not supported here.
    if None in key:
        raise Exception(f"Lazy indexing isn't supported: {key}")

    # Output storage object; filled in by the kernel through the
    # pointer-to-pointer passed to `invoke` below.
    ret_obj = x._format_class()

    key = _add_missing_dims(key, x.ndim)
    key = _expand_ellipsis(key, x.ndim)
    offsets, sizes, strides = _decompose_slices(key, x.shape)

    new_shape = _get_new_shape(sizes, strides)
    out_tensor_type = x._obj.get_tensor_definition(new_shape)

    # Cached per (in-type, out-type, offsets, sizes, strides) by @fn_cache.
    slice_module = get_slice_module(
        x._obj.get_tensor_definition(x.shape),
        out_tensor_type,
        offsets,
        sizes,
        strides,
    )

    slice_module.invoke("getitem", ctypes.pointer(ctypes.pointer(ret_obj)), *x._obj.to_module_arg())

    return Tensor(ret_obj, shape=out_tensor_type.shape)

sparse/mlir_backend/tests/test_simple.py

+36
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,39 @@ def test_reshape(rng, dtype):
291291
# DENSE
292292
# NOTE: dense reshape is probably broken in MLIR
293293
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
294+
295+
296+
# Exercises Tensor.__getitem__ for CSR/CSC/COO inputs against dense NumPy
# indexing. Skipped until the upstream MLIR sparse-slicing issue is fixed.
@pytest.mark.skip(reason="https://discourse.llvm.org/t/illegal-operation-when-slicing-csr-csc-coo-tensor/81404")
@parametrize_dtypes
@pytest.mark.parametrize(
    "index",
    [
        0,
        (2,),
        (2, 3),
        (..., slice(0, 4, 2)),
        (1, slice(1, None, 1)),
        # TODO: For below cases we need an update to ownership mechanism.
        # `tensor[:, :]` returns the same memref that was passed.
        # The mechanism sees the result as MLIR-allocated and frees
        # it, while it still can be owned by SciPy/NumPy causing a
        # segfault when it frees SciPy/NumPy managed memory.
        # ...,
        # slice(None),
        # (slice(None), slice(None)),
    ],
)
def test_indexing_2d(rng, dtype, index):
    SHAPE = (20, 30)
    DENSITY = 0.5

    for format in ["csr", "csc", "coo"]:
        arr = sps.random_array(SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng)
        # Canonicalize: COO inputs may carry duplicate coordinates.
        arr.sum_duplicates()

        tensor = sparse.asarray(arr)

        # NOTE(review): `to_scipy_sparse` is annotated elsewhere as
        # returning `sps.sparray | np.ndarray`; `.todense()` below assumes
        # the sparse branch — confirm once the skip is lifted.
        actual = tensor[index].to_scipy_sparse()
        expected = arr.todense()[index]

        np.testing.assert_array_equal(actual.todense(), expected)

0 commit comments

Comments
 (0)