-
Notifications
You must be signed in to change notification settings - Fork 55
feat[next-dace]: Use SDFG library node for lowering of broadcast and reduce #2386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
f8180e2
ef9ef92
bd1b766
25abd36
071f512
331bcd3
b90976b
21a79d2
a1f6f1a
6e97232
08484df
da12bdb
5382dce
b04586c
4647c7d
779b164
f250040
68a417f
9d774e4
c72e75a
93b1ffd
0915185
cad3fa0
2947655
ab1f400
007f435
16dc60f
84b821d
9e10019
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Final | ||
|
|
||
| import dace | ||
| from dace import library as dace_library, nodes as dace_nodes | ||
| from dace.transformation import transformation as dace_transform | ||
|
|
||
|
|
||
| _INPUT_NAME: Final[str] = "_input" | ||
| _OUTPUT_NAME: Final[str] = "_output" | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @dace_library.expansion | ||
| class ExpandPure(dace_transform.ExpandTransformation): | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Implements pure expansion of the Fill library node.""" | ||
|
|
||
| environments: Final[list[Any]] = [] | ||
|
|
||
| @staticmethod | ||
| def expansion(node: Fill, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG) -> dace.SDFG: | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| sdfg = dace.SDFG(f"{node.label}_sdfg") | ||
|
|
||
| assert len(parent_state.out_edges(node)) == 1 | ||
| outedge = parent_state.out_edges(node)[0] | ||
| out_desc = parent_sdfg.arrays[outedge.data.data] | ||
| inner_out_desc = out_desc.clone() | ||
| inner_out_desc.transient = False | ||
| out = sdfg.add_datadesc(_OUTPUT_NAME, inner_out_desc) | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| outedge._src_conn = _OUTPUT_NAME | ||
| node.add_out_connector(_OUTPUT_NAME) | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| state = sdfg.add_state(f"{node.label}_state") | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| map_params = [f"__i{i}" for i in range(len(out_desc.shape))] | ||
| map_rng = {i: f"0:{s}" for i, s in zip(map_params, out_desc.shape)} | ||
| out_mem = dace.Memlet(expr=f"{out}[{','.join(map_params)}]") | ||
| outputs = {"_out": out_mem} | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| assert len(parent_state.in_edges(node)) == 1 | ||
| inedge = parent_state.in_edges(node)[0] | ||
| inp_desc = parent_sdfg.arrays[inedge.data.data] | ||
| inner_inp_desc = inp_desc.clone() | ||
| inner_inp_desc.transient = False | ||
| inp = sdfg.add_datadesc(_INPUT_NAME, inner_inp_desc) | ||
| inedge._dst_conn = _INPUT_NAME | ||
| node.add_in_connector(_INPUT_NAME) | ||
| inputs = {"_in": dace.Memlet(data=inp, subset="0")} | ||
| code = "_out = _in" | ||
edopao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| state.add_mapped_tasklet( | ||
| f"{node.label}_tasklet", map_rng, inputs, code, outputs, external_edges=True | ||
| ) | ||
|
|
||
| return sdfg | ||
|
|
||
|
|
||
| @dace_library.node | ||
| class Fill(dace_nodes.LibraryNode): | ||
|
||
| """Implements filling data containers with a single value""" | ||
|
|
||
| implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {"pure": ExpandPure} | ||
| default_implementation: Final[str] = "pure" | ||
|
|
||
| def __init__(self, name: str): | ||
| super().__init__(name) | ||
Uh oh!
There was an error while loading. Please reload this page.