
Commit fbd4586

ENH: Implemented __getitem__ logic
1 parent db60537 commit fbd4586

3 files changed (+153 -6)

sparse/mlir_backend/_constructors.py (+6)
@@ -364,6 +364,12 @@ def __del__(self):
         for field in self._obj.get__fields_():
             free_memref(field)
 
+    def __getitem__(self, key) -> "Tensor":
+        # imported lazily to avoid cyclic dependency
+        from ._ops import getitem
+
+        return getitem(self, key)
+
     @_hold_self_ref_in_ret
     def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
         return self._obj.to_sps(self.shape)
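
Note: with this method in place, NumPy-style indexing on `Tensor` dispatches to the new `getitem` helper in `_ops.py`. A minimal usage sketch, mirroring the test added below (it assumes the backend is importable as `sparse` with an `asarray` constructor, as in the test suite; slicing CSR/CSC/COO currently hits the upstream MLIR issue referenced in the skipped test, so this may raise):

    import scipy.sparse as sps
    import sparse

    arr = sps.random_array((20, 30), density=0.5, format="csr")
    tensor = sparse.asarray(arr)

    # Dispatches through Tensor.__getitem__ to _ops.getitem.
    sliced = tensor[1, 1::1]
    print(sliced.to_scipy_sparse())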

sparse/mlir_backend/_ops.py (+111 -6)
@@ -1,4 +1,5 @@
 import ctypes
+from types import EllipsisType
 
 import mlir.execution_engine
 import mlir.passmanager
@@ -85,12 +86,39 @@ def get_reshape_module(
             def reshape(a, shape):
                 return tensor.reshape(out_tensor_type, a, shape)
 
-        reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
-        if DEBUG:
-            (CWD / "reshape_module.mlir").write_text(str(module))
-        pm.run(module.operation)
-        if DEBUG:
-            (CWD / "reshape_module_opt.mlir").write_text(str(module))
+            reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+            if DEBUG:
+                (CWD / "reshape_module.mlir").write_text(str(module))
+            pm.run(module.operation)
+            if DEBUG:
+                (CWD / "reshape_module_opt.mlir").write_text(str(module))
+
+        return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
+
+
+@fn_cache
+def get_slice_module(
+    in_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+    offsets: tuple[int, ...],
+    sizes: tuple[int, ...],
+    strides: tuple[int, ...],
+) -> ir.Module:
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(in_tensor_type)
+            def getitem(a):
+                return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides)
+
+            getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+            if DEBUG:
+                (CWD / "getitem_module.mlir").write_text(str(module))
+            pm.run(module.operation)
+            if DEBUG:
+                (CWD / "getitem_module_opt.mlir").write_text(str(module))
 
         return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])

@@ -195,3 +223,80 @@ def broadcast_to(x: Tensor, /, shape: tuple[int, ...], dimensions: list[int]) ->
     )
 
     return Tensor(ret_obj, shape=shape)
+
+
+def _add_missing_dims(key: tuple, ndim: int) -> tuple:
+    if len(key) < ndim and Ellipsis not in key:
+        return key + (...,)
+    return key
+
+
+def _expand_ellipsis(key: tuple, ndim: int) -> tuple:
+    if Ellipsis in key:
+        if len([e for e in key if e is Ellipsis]) > 1:
+            raise Exception(f"Ellipsis should be used once: {key}")
+        to_expand = ndim - len(key) + 1
+        if to_expand <= 0:
+            raise Exception(f"Invalid use of Ellipsis in {key}")
+        idx = key.index(Ellipsis)
+        return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :]
+    return key
+
+
+def _decompose_slices(
+    key: tuple,
+    shape: tuple[int, ...],
+) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+    offsets = []
+    sizes = []
+    strides = []
+
+    for key_elem, size in zip(key, shape, strict=False):
+        if isinstance(key_elem, slice):
+            offset = key_elem.start if key_elem.start is not None else 0
+            size = key_elem.stop - offset if key_elem.stop is not None else size - offset
+            stride = key_elem.step if key_elem.step is not None else 1
+        elif isinstance(key_elem, int):
+            offset = key_elem
+            size = key_elem + 1
+            stride = 1
+        offsets.append(offset)
+        sizes.append(size)
+        strides.append(stride)
+
+    return tuple(offsets), tuple(sizes), tuple(strides)
+
+
+def _get_new_shape(sizes, strides) -> tuple[int, ...]:
+    return tuple(size // stride for size, stride in zip(sizes, strides, strict=False))
+
+
+def getitem(
+    x: Tensor,
+    key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...],
+) -> Tensor:
+    if not isinstance(key, tuple):
+        key = (key,)
+    if None in key:
+        raise Exception(f"Lazy indexing isn't supported: {key}")
+
+    ret_obj = x._format_class()
+
+    key = _add_missing_dims(key, x.ndim)
+    key = _expand_ellipsis(key, x.ndim)
+    offsets, sizes, strides = _decompose_slices(key, x.shape)
+
+    new_shape = _get_new_shape(sizes, strides)
+    out_tensor_type = x._obj.get_tensor_definition(new_shape)
+
+    slice_module = get_slice_module(
+        x._obj.get_tensor_definition(x.shape),
+        out_tensor_type,
+        offsets,
+        sizes,
+        strides,
+    )
+
+    slice_module.invoke("getitem", ctypes.pointer(ctypes.pointer(ret_obj)), *x._obj.to_module_arg())
+
+    return Tensor(ret_obj, shape=out_tensor_type.shape)
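
Note: the slicing pipeline is key normalization (`_add_missing_dims`, `_expand_ellipsis`), decomposition into `tensor.extract_slice` operands (`_decompose_slices`), and a cached compile of the slice module (`get_slice_module`). A worked trace of the pure-Python helpers on one of the test keys (these are private, underscore-prefixed functions, so importing them here is purely illustrative):

    from sparse.mlir_backend._ops import (
        _add_missing_dims,
        _decompose_slices,
        _expand_ellipsis,
        _get_new_shape,
    )

    shape = (20, 30)
    key = (..., slice(0, 4, 2))

    key = _add_missing_dims(key, len(shape))  # unchanged: key already contains an Ellipsis
    key = _expand_ellipsis(key, len(shape))   # (slice(None, None, None), slice(0, 4, 2))

    offsets, sizes, strides = _decompose_slices(key, shape)
    print(offsets, sizes, strides)            # (0, 0) (20, 4) (1, 2)
    print(_get_new_shape(sizes, strides))     # (20, 2)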

sparse/mlir_backend/tests/test_simple.py (+36)
@@ -341,3 +341,39 @@ def test_broadcast_to(dtype):
 
     assert result.format == "csr"
     np.testing.assert_allclose(result.todense(), np.repeat(np_arr[np.newaxis], 3, axis=0))
+
+
+@pytest.mark.skip(reason="https://discourse.llvm.org/t/illegal-operation-when-slicing-csr-csc-coo-tensor/81404")
+@parametrize_dtypes
+@pytest.mark.parametrize(
+    "index",
+    [
+        0,
+        (2,),
+        (2, 3),
+        (..., slice(0, 4, 2)),
+        (1, slice(1, None, 1)),
+        # TODO: For the cases below we need an update to the ownership mechanism.
+        # `tensor[:, :]` returns the same memref that was passed in.
+        # The mechanism sees the result as MLIR-allocated and frees it,
+        # while it may still be owned by SciPy/NumPy, causing a segfault
+        # when SciPy/NumPy-managed memory is freed.
+        # ...,
+        # slice(None),
+        # (slice(None), slice(None)),
+    ],
+)
+def test_indexing_2d(rng, dtype, index):
+    SHAPE = (20, 30)
+    DENSITY = 0.5
+
+    for format in ["csr", "csc", "coo"]:
+        arr = sps.random_array(SHAPE, density=DENSITY, format=format, dtype=dtype, random_state=rng)
+        arr.sum_duplicates()
+
+        tensor = sparse.asarray(arr)
+
+        actual = tensor[index].to_scipy_sparse()
+        expected = arr.todense()[index]
+
+        np.testing.assert_array_equal(actual.todense(), expected)
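
Note: the commented-out keys all normalize to the identity slice, which is exactly the case the TODO above describes: MLIR hands back the input memref unchanged, and the ownership mechanism then frees a buffer it does not own. A quick check with the helpers from `_ops.py` (same private-API caveat as above; scalar keys such as `...` are shown pre-wrapped in tuples, as `getitem` does):

    from sparse.mlir_backend._ops import _add_missing_dims, _decompose_slices, _expand_ellipsis

    shape = (20, 30)
    for key in [(...,), (slice(None),), (slice(None), slice(None))]:
        key = _expand_ellipsis(_add_missing_dims(key, 2), 2)
        print(_decompose_slices(key, shape))
        # Each prints ((0, 0), (20, 30), (1, 1)): a full, stride-1 view,
        # so the result can alias the SciPy/NumPy-owned buffer.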
