|
1 | 1 | import ctypes
|
| 2 | +from types import EllipsisType |
2 | 3 |
|
3 | 4 | import mlir.execution_engine
|
4 | 5 | import mlir.passmanager
|
@@ -85,12 +86,39 @@ def get_reshape_module(
|
85 | 86 | def reshape(a, shape):
|
86 | 87 | return tensor.reshape(out_tensor_type, a, shape)
|
87 | 88 |
|
88 |
| - reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() |
89 |
| - if DEBUG: |
90 |
| - (CWD / "reshape_module.mlir").write_text(str(module)) |
91 |
| - pm.run(module.operation) |
92 |
| - if DEBUG: |
93 |
| - (CWD / "reshape_module_opt.mlir").write_text(str(module)) |
| 89 | + reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() |
| 90 | + if DEBUG: |
| 91 | + (CWD / "reshape_module.mlir").write_text(str(module)) |
| 92 | + pm.run(module.operation) |
| 93 | + if DEBUG: |
| 94 | + (CWD / "reshape_module_opt.mlir").write_text(str(module)) |
| 95 | + |
| 96 | + return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS]) |
| 97 | + |
| 98 | + |
@fn_cache
def get_slice_module(
    in_tensor_type: ir.RankedTensorType,
    out_tensor_type: ir.RankedTensorType,
    offsets: tuple[int, ...],
    sizes: tuple[int, ...],
    strides: tuple[int, ...],
) -> ir.Module:
    """Build and JIT-compile an MLIR module that extracts a strided slice.

    The module contains a single function ``getitem`` that applies
    ``tensor.extract_slice`` with the given static ``offsets``, ``sizes``
    and ``strides`` to a tensor of type ``in_tensor_type``, producing a
    ``out_tensor_type`` result.  Results are memoized via ``fn_cache`` so
    repeated slicing with the same static parameters reuses the compiled
    engine.

    Args:
        in_tensor_type: MLIR type of the input tensor.
        out_tensor_type: MLIR type of the sliced result.
        offsets: Static per-dimension start indices.
        sizes: Static per-dimension extents passed to ``extract_slice``.
        strides: Static per-dimension step sizes.

    Returns:
        A ready-to-invoke ``ExecutionEngine`` wrapping the compiled module.
        (NOTE(review): annotated ``-> ir.Module`` but actually returns an
        ``ExecutionEngine`` — same as the sibling ``get_reshape_module``.)
    """
    with ir.Location.unknown(ctx):
        module = ir.Module.create()

        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(in_tensor_type)
            def getitem(a):
                # Empty lists mean no dynamic offsets/sizes/strides; all
                # slice parameters are supplied statically.
                return tensor.extract_slice(out_tensor_type, a, [], [], [], offsets, sizes, strides)

            # Emit the C interface wrapper so the engine can be invoked
            # from ctypes via ``invoke("getitem", ...)``.
            getitem.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        if DEBUG:
            # Dump the module before running the lowering pipeline.
            (CWD / "getitem_module.mlir").write_text(str(module))
        pm.run(module.operation)
        if DEBUG:
            # Dump the module after the pass pipeline for inspection.
            (CWD / "getitem_module_opt.mlir").write_text(str(module))

    return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
|
96 | 124 |
|
@@ -195,3 +223,80 @@ def broadcast_to(x: Tensor, /, shape: tuple[int, ...], dimensions: list[int]) ->
|
195 | 223 | )
|
196 | 224 |
|
197 | 225 | return Tensor(ret_obj, shape=shape)
|
| 226 | + |
| 227 | + |
| 228 | +def _add_missing_dims(key: tuple, ndim: int) -> tuple: |
| 229 | + if len(key) < ndim and Ellipsis not in key: |
| 230 | + return key + (...,) |
| 231 | + return key |
| 232 | + |
| 233 | + |
| 234 | +def _expand_ellipsis(key: tuple, ndim: int) -> tuple: |
| 235 | + if Ellipsis in key: |
| 236 | + if len([e for e in key if e is Ellipsis]) > 1: |
| 237 | + raise Exception(f"Ellipsis should be used once: {key}") |
| 238 | + to_expand = ndim - len(key) + 1 |
| 239 | + if to_expand <= 0: |
| 240 | + raise Exception(f"Invalid use of Ellipsis in {key}") |
| 241 | + idx = key.index(Ellipsis) |
| 242 | + return key[:idx] + tuple(slice(None) for _ in range(to_expand)) + key[idx + 1 :] |
| 243 | + return key |
| 244 | + |
| 245 | + |
| 246 | +def _decompose_slices( |
| 247 | + key: tuple, |
| 248 | + shape: tuple[int, ...], |
| 249 | +) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: |
| 250 | + offsets = [] |
| 251 | + sizes = [] |
| 252 | + strides = [] |
| 253 | + |
| 254 | + for key_elem, size in zip(key, shape, strict=False): |
| 255 | + if isinstance(key_elem, slice): |
| 256 | + offset = key_elem.start if key_elem.start is not None else 0 |
| 257 | + size = key_elem.stop - offset if key_elem.stop is not None else size - offset |
| 258 | + stride = key_elem.step if key_elem.step is not None else 1 |
| 259 | + elif isinstance(key_elem, int): |
| 260 | + offset = key_elem |
| 261 | + size = key_elem + 1 |
| 262 | + stride = 1 |
| 263 | + offsets.append(offset) |
| 264 | + sizes.append(size) |
| 265 | + strides.append(stride) |
| 266 | + |
| 267 | + return tuple(offsets), tuple(sizes), tuple(strides) |
| 268 | + |
| 269 | + |
| 270 | +def _get_new_shape(sizes, strides) -> tuple[int, ...]: |
| 271 | + return tuple(size // stride for size, stride in zip(sizes, strides, strict=False)) |
| 272 | + |
| 273 | + |
def getitem(
    x: Tensor,
    key: int | slice | EllipsisType | tuple[int | slice | EllipsisType, ...],
) -> Tensor:
    """Extract the sub-tensor of ``x`` selected by ``key``.

    The key is normalized to one int/slice per dimension, decomposed into
    static (offsets, sizes, strides), and executed through a cached,
    compiled MLIR slice module.
    """
    # Normalize scalar keys to a 1-tuple so the rest handles only tuples.
    if not isinstance(key, tuple):
        key = (key,)
    if None in key:
        raise Exception(f"Lazy indexing isn't supported: {key}")

    ret_obj = x._format_class()

    # Pad short keys with an Ellipsis, then expand it to explicit slices.
    normalized = _expand_ellipsis(_add_missing_dims(key, x.ndim), x.ndim)
    offsets, sizes, strides = _decompose_slices(normalized, x.shape)

    result_shape = _get_new_shape(sizes, strides)
    out_tensor_type = x._obj.get_tensor_definition(result_shape)
    in_tensor_type = x._obj.get_tensor_definition(x.shape)

    slice_module = get_slice_module(
        in_tensor_type,
        out_tensor_type,
        offsets,
        sizes,
        strides,
    )

    # The compiled entry point fills ret_obj through a double pointer.
    slice_module.invoke("getitem", ctypes.pointer(ctypes.pointer(ret_obj)), *x._obj.to_module_arg())

    return Tensor(ret_obj, shape=out_tensor_type.shape)
0 commit comments