Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
transformation as dace_transformation,
)
from dace.codegen.targets import cpp as dace_cpp
from dace.sdfg import nodes as dace_nodes
from dace.sdfg import memlet_utils as dace_mutils, nodes as dace_nodes

from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations

Expand Down Expand Up @@ -262,7 +262,7 @@ def _gt_expand_non_standard_memlets_sdfg(
new_maps: set[dace_nodes.MapEntry] = set()
# The implementation is based on DaCe's code generator, see `dace/codegen/targets/cuda.py`
# in the function `preprocess()`
# NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/1976
# NOTE: This implementation needs a DaCe version that includes https://github.com/spcl/dace/pull/2033
for state in sdfg.states():
for e in state.edges():
# We are only interested in edges that connects two access nodes of GPU memory.
Expand All @@ -289,6 +289,20 @@ def _gt_expand_non_standard_memlets_sdfg(
is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1
if is_c_order or is_fortran_order:
continue

# NOTE: Special case of continuous copy
# Example: dcol[0:I, 0:J, k] -> datacol[0:I, 0:J]
# with copy shape [I, J] and strides [J*K, K], [J, 1]
if src_strides[-1] != 1 or dst_strides[-1] != 1:
try:
is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
except (TypeError, ValueError):
is_src_cont = False
is_dst_cont = False
if is_src_cont and is_dst_cont:
continue

elif dims > 2:
if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
continue
Expand All @@ -298,27 +312,27 @@ def _gt_expand_non_standard_memlets_sdfg(
edge.dst for edge in state.out_edges(a)
]

# Turn unsupported copy to a map
try:
dace_transformation.dataflow.CopyToMap.apply_to(
sdfg,
save=False,
annotate=False,
a=a,
b=b,
options={
"ignore_strides": True
}, # apply 'CopyToMap' even if src/dst strides are different
)
except ValueError: # If transformation doesn't match, continue normally
continue
if not dace_mutils.can_memlet_be_turned_into_a_map(
edge=e, state=state, sdfg=sdfg, ignore_strides=True
):
# NOTE: In DaCe, they simply ignore that case and continue to the
# code generator. In GT4Py we generate an error.
raise RuntimeError(f"Unable to turn the not supported edge '{e}' into a copy Map.")

# Turn the not supported Memlet into a copy Map. We have to do it here,
# such that we can then set their iteration order correctly.
dace_mutils.memlet_to_map(
edge=e,
state=state,
sdfg=sdfg,
ignore_strides=True,
)

# We find the new map by comparing the new neighborhood of `a` with the old one.
new_nodes: set[dace_nodes.MapEntry] = {
edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a
}
assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes)
assert len(new_nodes) == 1
new_maps.update(new_nodes)
return new_maps

Expand Down
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.