Skip to content

Commit db60537

Browse files
authored
ENH: Implement broadcast_to function (#782)
1 parent 1593a0f commit db60537

File tree

3 files changed

+96
-1
lines changed

3 files changed

+96
-1
lines changed

sparse/mlir_backend/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
)
1717
from ._ops import (
1818
add,
19+
broadcast_to,
1920
reshape,
2021
)
2122

2223
__all__ = [
2324
"add",
25+
"broadcast_to",
2426
"asarray",
2527
"asdtype",
2628
"reshape",

sparse/mlir_backend/_ops.py

+43
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,32 @@ def reshape(a, shape):
9595
return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
9696

9797

98+
@fn_cache
def get_broadcast_to_module(
    in_tensor_type: ir.RankedTensorType,
    out_tensor_type: ir.RankedTensorType,
    dimensions: tuple[int, ...],
) -> ir.Module:
    """Build, lower and JIT-compile an MLIR module exposing a ``broadcast_to`` kernel.

    The result is cached via ``@fn_cache``, which is why ``dimensions`` must be
    a (hashable) tuple rather than a list.

    Parameters
    ----------
    in_tensor_type:
        MLIR ranked-tensor type of the input.
    out_tensor_type:
        MLIR ranked-tensor type of the broadcast result.
    dimensions:
        Forwarded to ``linalg.broadcast`` — the output dimensions added by the
        broadcast.

    NOTE(review): despite the ``-> ir.Module`` annotation this returns an
    ``ExecutionEngine`` (same pattern as the sibling module builders).
    """
    with ir.Location.unknown(ctx):
        module = ir.Module.create()

        with ir.InsertionPoint(module.body):

            @func.FuncOp.from_py_func(in_tensor_type)
            def broadcast_to(in_tensor):
                # Allocate the destination tensor and broadcast the input into it.
                out = tensor.empty(out_tensor_type, [])
                return linalg.broadcast(in_tensor, outs=[out], dimensions=dimensions)

        # Emit the C wrapper so the function can be invoked through ctypes.
        broadcast_to.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
        if DEBUG:
            # Dump the pre-lowering IR for inspection.
            (CWD / "broadcast_to_module.mlir").write_text(str(module))
        pm.run(module.operation)  # run the shared lowering pass pipeline
        if DEBUG:
            # Dump the lowered IR as well.
            (CWD / "broadcast_to_module_opt.mlir").write_text(str(module))

        return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
122+
123+
98124
def add(x1: Tensor, x2: Tensor) -> Tensor:
99125
ret_obj = x1._format_class()
100126
out_tensor_type = x1._obj.get_tensor_definition(x1.shape)
@@ -152,3 +178,20 @@ def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
152178
)
153179

154180
return Tensor(ret_obj, shape=out_tensor_type.shape)
181+
182+
183+
def broadcast_to(x: Tensor, /, shape: tuple[int, ...], dimensions: list[int]) -> Tensor:
    """Broadcast ``x`` to ``shape``, adding the output axes listed in ``dimensions``.

    ``dimensions`` is forwarded (as a tuple, for caching) to the compiled
    ``linalg.broadcast`` kernel; the result format is inferred from the target
    rank and the input's dtypes.
    """
    input_type = x._obj.get_tensor_definition(x.shape)
    out_format = _infer_format_class(len(shape), x._values_dtype, x._index_dtype)
    output_type = out_format.get_tensor_definition(shape)
    result_obj = out_format()

    # Compile (or fetch the cached) kernel specialized for these tensor types.
    engine = get_broadcast_to_module(input_type, output_type, tuple(dimensions))

    # The C interface expects a pointer-to-pointer out-param followed by the
    # flattened input tensor arguments.
    engine.invoke(
        "broadcast_to",
        ctypes.pointer(ctypes.pointer(result_obj)),
        *x._obj.to_module_arg(),
    )

    return Tensor(result_obj, shape=shape)

sparse/mlir_backend/tests/test_simple.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -289,5 +289,55 @@ def test_reshape(rng, dtype):
289289
np.testing.assert_array_equal(actual, expected)
290290

291291
# DENSE
292-
# NOTE: dense reshape is probably broken in MLIR
292+
# NOTE: dense reshape is probably broken in MLIR in 19.x branch
293293
# dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)
294+
295+
296+
@parametrize_dtypes
def test_broadcast_to(dtype):
    """Check ``sparse.broadcast_to`` against hand-computed constituent arrays.

    Each sparse case lists (input shape, broadcast shape, broadcast dimensions,
    dense input, expected constituent arrays of the result format).
    """
    # CSR, CSC, COO
    for shape, new_shape, dimensions, input_arr, expected_arrs in [
        (
            (3, 4),
            (2, 3, 4),
            [0],
            np.array([[0, 1, 0, 3], [0, 0, 4, 5], [6, 7, 0, 0]]),
            [
                np.array([0, 3, 6]),
                np.array([0, 1, 2, 0, 1, 2]),
                np.array([0, 2, 4, 6, 8, 10, 12]),
                np.array([1, 3, 2, 3, 0, 1, 1, 3, 2, 3, 0, 1]),
                np.array([1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0]),
            ],
        ),
        (
            (4, 2),
            (4, 2, 2),
            [1],
            np.array([[0, 1], [0, 0], [2, 3], [4, 0]]),
            [
                np.array([0, 2, 2, 4, 6]),
                np.array([0, 1, 0, 1, 0, 1]),
                np.array([0, 1, 2, 4, 6, 7, 8]),
                np.array([1, 1, 0, 1, 0, 1, 0, 0]),
                np.array([1.0, 1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 4.0]),
            ],
        ),
    ]:
        for fn_format in [sps.csr_array, sps.csc_array, sps.coo_array]:
            arr = fn_format(input_arr, shape=shape, dtype=dtype)
            # Canonicalize before conversion so constituent arrays are deterministic.
            arr.sum_duplicates()
            tensor = sparse.asarray(arr)
            result = sparse.broadcast_to(tensor, new_shape, dimensions=dimensions).to_scipy_sparse()

            # NOTE(review): strict=False silently ignores a length mismatch
            # between the result's constituent arrays and the expected list —
            # consider strict=True once all formats yield the same arity.
            for actual, expected in zip(result, expected_arrs, strict=False):
                np.testing.assert_allclose(actual, expected)

    # DENSE
    np_arr = np.array([0, 0, 2, 3, 0, 1])
    arr = np.asarray(np_arr, dtype=dtype)
    tensor = sparse.asarray(arr)
    result = sparse.broadcast_to(tensor, (3, 6), dimensions=[0]).to_scipy_sparse()

    # Broadcasting a dense 1-D input yields a CSR result; compare densified
    # output against NumPy's own row replication.
    assert result.format == "csr"
    np.testing.assert_allclose(result.todense(), np.repeat(np_arr[np.newaxis], 3, axis=0))

0 commit comments

Comments
 (0)