diff --git a/docs/development/ADRs/cartesian/backend-cuda-feature-freeze.md b/docs/development/ADRs/cartesian/backend-cuda-feature-freeze.md new file mode 100644 index 0000000000..0568cd9e07 --- /dev/null +++ b/docs/development/ADRs/cartesian/backend-cuda-feature-freeze.md @@ -0,0 +1,17 @@ +# Cuda backend: Feature freeze + +In the context of (backend) feature development, facing maintainability/duplication concerns, we decided to put a feature freeze on the `cuda` backend and focus on the `dace:gpu` backends instead to keep the number of backends manageable. + +## Context + +The introduction of the [`dace:*`](./backend-dace.md) backends brought up the question of backend redundancy. In particular, it seems that `cuda` and `dace:gpu` backends serve similar purposes. + +`dace:gpu` backends not only generate code for different graphics cards, they also share substantial code paths with the `dace:cpu` backend. This simplifies (backend) feature development. + +## Decision + +We decided to put a feature freeze on the `cuda` backend, focusing on the `dace:*` backends instead. While we don't drop the backend, new DSL features won't be available in the `cuda` backend. New features will error out cleanly and suggest to use the `dace:gpu` backend instead. + +## Consequences + +While the `cuda` backend only targets NVIDIA cards, the `dace:*` backends allow to generate code for NVIDIA and AMD graphics cards. Furthermore, `dace:cpu` and `dace:gpu` backends share large parts of the transpilation layers because code generation is deferred to DaCe and only depending on the SDFG. This allows us to develop many (backend) features for the `dace:*` backends in one place. 
diff --git a/docs/development/ADRs/cartesian/backend-dace-schedule-tree.md b/docs/development/ADRs/cartesian/backend-dace-schedule-tree.md new file mode 100644 index 0000000000..3fa36389ee --- /dev/null +++ b/docs/development/ADRs/cartesian/backend-dace-schedule-tree.md @@ -0,0 +1,73 @@ +# DaCe backends: Schedule tree + +In the context of [DaCe backends](./backend-dace.md), facing tech-debt, a lack of understanding of the current stack, and under performing map- & state fusion, we decided to rewrite substantial parts of the DaCe backends with so called "Schedule Trees" to achieve hardware dependent macro-level optimizations (e.g. loop merging and loop re-ordering) at a new IR level before going down to SDFGs. We considered writing custom SDFG fusion passes and accept that we have to contribute a conversion from Schedule Tree to SDFG in DaCe. + +## Context + +Basically, three forces were driving this drastic change: + +1. We were unhappy with the performance of the DaCe backends, especially on CPU. +2. We had little understanding of the previous GT4Py-DaCe bridge. +3. The previous GT4Py-DaCe bridge accumulated a lot of tech debt, making it clumsy to work with and hard to inject major changes. + +## Decision + +We chose to directly translate GT4Py's optimization IR (OIR) to DaCe's schedule tree (and from there to SDFG and code generation) because this allows to separate macro-level and data-specific optimizations. DaCe's schedule tree is ideally suited for schedule-level optimizations like loop re-ordering or loop merges with over-computation. The (simplified) pipeline looks like this: + +```mermaid +flowchart LR +oir[" + OIR + (GT4Py) +"] +treeir[" + Tree IR + (GT4Py) +"] +stree[" + Schedule tree + (DaCe) +"] +sdfg[" + SDFG + (DaCe) +"] +codegen[" + Code generation + (per target) +"] + +oir --> treeir --> stree --> sdfg --> codegen +``` + +OIR to Tree IR conversion has two visitors in separate files: + +1. `dace/oir_to_treeir.py` transpiles control flow elements +2. 
`dace/oir_to_tasklet.py` transpiles computations (i.e. bodies of control flow elements) + +While this incurs a bit of code duplication (e.g. resolving index accesses), it allows for separation of concerns: Everything that is related to the schedule is handled in `oir_to_treeir.py`. Note, for example, that we keep the distinction between horizontal mask and general `if` statements. This distinction is kept because horizontal regions might influence scheduling decisions, while general `if` statements do not. + +The subsequent conversion from Tree IR to schedule tree is a straightforward visitor located in `dace/treeir_to_stree.py`. Notice the simplicity of that visitor. + +## Consequences + +The schedule tree introduces a transpilation layer ideally suited for macro-level optimizations, which are targeting the program's execution schedule. This is particularly interesting for the DaCe backends because we use the same backend pipeline to generate code for CPU and GPU targets. + +In particular, the schedule tree allows to easily re-order/modify/change the loop structure. This not only allows us to generate hardware-specific loop order and tile-sizes, but also gives us fine grained control over loop merges and/or which loops to generate in the first place. For example, going directly from OIR to Tree IR allows us to translate horizontal regions to either `if` statements inside a bigger horizontal loop (for small regions) or break them out into separate loops (for bigger regions) if that makes sense for the target architecture. + +## Alternatives considered + +### OIR -> SDFG -> schedule tree -> SDFG + +- Seems smart because it allows to keep the current OIR -> SDFG bridge, i.e. no need to write an OIR -> schedule tree bridge, +- but the first SDFG is unnecessary and translation times are a real problem +- and we were unhappy with the OIR -> SDFG bridge anyway +- and, in addition, we lose some context between OIR and schedule tree (e.g. horizontal regions). 
+ +### Improve the existing SDFG map fusion + +GT4Py next has gone this route and an improved version is merged in the mainline version of DaCe. We think we'll need a custom map fusion pass which lets us decide low-level things like under which circumstances over-computation is desirable. A general map fusion pass will never be able to allow this. + +### Write custom map fusion based on SDFG syntax + +Possible, but a lot more cumbersome than writing the same transformation based on the schedule tree syntax. diff --git a/docs/development/ADRs/cartesian/backend-dace-version.md b/docs/development/ADRs/cartesian/backend-dace-version.md new file mode 100644 index 0000000000..231bbf3e48 --- /dev/null +++ b/docs/development/ADRs/cartesian/backend-dace-version.md @@ -0,0 +1,26 @@ +# DaCe backends: DaCe version + +In the context of the [DaCe backend](./backend-dace.md) and the [schedule tree](./backend-dace-schedule-tree.md), facing time pressure, we decided to stay at the `v1.x` branch of DaCe to minimize up-front cost and deliver CPU performance as fast as possible. We considered updating to the mainline version of DaCe and accept follow-up cost of partial rewrites once DaCe `v2` releases. + +## Context + +The currently released version of DaCe is on the `v1.x` branch. However, the mainline branch moved on (with breaking changes) to what is supposed to be DaCe `v2`. All feature development is supposed to be merged against mainline. Only bug fixes are allowed on the `v1.x` branch. + +The [schedule tree](./backend-dace-schedule-tree.md) feature will need changes in DaCe, in particular to translate schedule trees into SDFG. We are unfamiliar with the breaking changes in DaCe. + +## Decision + +We decided to build a first version of the schedule tree feature against the `v1.x` version of DaCe. 
The temporary branch will live on the [GridTools fork of DaCe](https://github.com/GridTools/dace) until the Schedule Tree feature is available in DaCe `v2` and we can update back to mainline again. + +## Consequences + +- We'll be able to code against familiar API (e.g. same as the previous GT4Py-DaCe bridge). +- In DaCe, we won't be able to merge changes into `v1.x`. We'll work on a branch and later refactor the schedule tree -> SDFG transformation to code flow regions in DaCe `v2`. + +## Alternatives considered + +### Update to DaCe mainline first + +- Good because mainline DaCe is accepting new features while `v1.x` is closed for new feature development. +- Bad because it incurs an up-front cost, which we are trying to minimize to get results fast. +- Bad because we aren't trained to use the new control flow regions. diff --git a/docs/development/ADRs/cartesian/backend-dace.md b/docs/development/ADRs/cartesian/backend-dace.md new file mode 100644 index 0000000000..0c7bca95a3 --- /dev/null +++ b/docs/development/ADRs/cartesian/backend-dace.md @@ -0,0 +1,21 @@ +# DaCe backends + +In the context of performance optimization, facing the fragmentedness of Numerical Weather Prediction (NWP) codes, we decided to implement a backend based on DaCe to unlock full-program optimization. We accept the downside of having to maintain that (additional) performance backend. + +## Context + +NWP codes aren't like your typical optimization problem homework where 80% of runtime is spent within a single stencil which you can then optimize to oblivion. Instead, computations in NWP codes are fragmented and scattered all over the place with parts in-between that move memory around. Stencil-only optimizations don't cut through this. DaCe allows us to do (data-flow) optimization on the full program, not only inside stencils. As a nice side-effect, DaCe offers code generation to CPU and GPU targets. 
+ +## Decision + +We chose to add DaCe backends,`dace:cpu` and `dace:gpu`, for CPU and GPU targets because we need full-program optimization to get the best possible performance. + +## Consequences + +We will need to maintain the `dace:*` backends. If we keep adding more and more backends, maintainability will be a question down the road. We thus decided to put a [feature freeze](./backend-cuda-feature-freeze.md) on the `cuda` backend, focussing on `dace:*` backends instead. + +Compared to the [`cuda` backend](./backend-cuda-feature-freeze.md), which only targets NVIDIA cards, we get support for both, NVIDIA and AMD cards, with the `dace:gpu` backends. + +## References + +[DaCe Promo Website](http://dace.is/fast) | [DaCe GitHub](https://github.com/spcl/dace) | [DaCe Documentation](https://spcldace.readthedocs.io/en/latest/) diff --git a/pyproject.toml b/pyproject.toml index 8dd9324020..b748d232cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,9 @@ requires = ['cython>=3.0.0', 'setuptools>=70.0.0', 'versioningit>=3.1.1', 'wheel # -- Dependency groups -- [dependency-groups] build = ['cython>=3.0.0', 'pip>=22.1.1', 'setuptools>=70.0.0', 'wheel>=0.33.6'] -dace-cartesian = ['dace>=1.0.2,<2'] +dace-cartesian = [ + 'dace>=1.0.2,<2' # renfined in [tool.uv.sources] +] dace-next = [ 'dace>=1.0.0' # refined in [tool.uv.sources] ] @@ -133,8 +135,6 @@ requires-python = '>=3.10, <3.14' cartesian = ['gt4py[jax,standard,testing]'] cuda11 = ['cupy-cuda11x>=12.0'] cuda12 = ['cupy-cuda12x>=12.0'] -# 'dace' defined as dependency-group until supported versions are pushed to PyPI -# dace = ['dace>=1.0.2,<1.1.0'] jax = ['jax>=0.4.26'] jax-cuda12 = ['jax[cuda12_local]>=0.4.26', 'gt4py[cuda12]'] next = ['gt4py[jax,standard,testing]'] @@ -448,6 +448,7 @@ url = 'https://test.pypi.org/simple' [tool.uv.sources] atlas4py = {index = "test.pypi"} dace = [ + {git = "https://github.com/GridTools/dace", branch = "romanc/stree-to-sdfg", group = "dace-cartesian"}, {git = 
"https://github.com/GridTools/dace", tag = "__gt4py-next-integration_2025_07_25", group = "dace-next"} ] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 3f6e3f3072..58f1faba2c 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -16,6 +16,7 @@ from dace import SDFG, Memlet, SDFGState, config, data, dtypes, nodes, subsets, symbolic from dace.codegen import codeobject +from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg.utils import inline_sdfgs from gt4py._core import definitions as core_defs @@ -31,18 +32,13 @@ pybuffer_to_sid, ) from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir -from gt4py.cartesian.gtc import common, gtir -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.dace.nodes import StencilComputation -from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder -from gt4py.cartesian.gtc.dace.transformations import ( - NoEmptyEdgeTrivialMapElimination, - nest_sequential_map_scopes, -) +from gt4py.cartesian.gtc import gtir +from gt4py.cartesian.gtc.dace.oir_to_treeir import OIRToTreeIR +from gt4py.cartesian.gtc.dace.treeir_to_stree import TreeIRToScheduleTree from gt4py.cartesian.gtc.dace.utils import array_dimensions, replace_strides from gt4py.cartesian.gtc.gtir_to_oir import GTIRToOIR from gt4py.cartesian.gtc.passes.gtir_k_boundary import compute_k_boundary -from gt4py.cartesian.gtc.passes.gtir_pipeline import GtirPipeline +from gt4py.cartesian.gtc.passes.oir_optimizations import caches from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_fields_extents from gt4py.cartesian.gtc.passes.oir_pipeline import DefaultPipeline from gt4py.cartesian.utils import shash @@ -56,132 +52,36 @@ from gt4py.cartesian.stencil_object import StencilObject -def _specialize_transient_strides(sdfg: SDFG, layout_info: layout.LayoutInfo) -> None: - 
replacement_dictionary = replace_strides( - [array for array in sdfg.arrays.values() if array.transient], layout_info["layout_map"] +def _specialize_transient_strides( + sdfg: SDFG, layout_info: layout.LayoutInfo, replacement_dictionary: dict[str, str] | None = None +) -> None: + # Find transients in this SDFG to specialize. + stride_replacements = replace_strides( + [ + array + for array in sdfg.arrays.values() + if isinstance(array, data.Array) and array.transient + ], + layout_info["layout_map"], ) + + # In case of nested SDFGs (see below), merge with replacement dict that was passed down. + # Dev note: We shouldn't use mutable data structures as argument defaults. + replacement_dictionary = {} if replacement_dictionary is None else replacement_dictionary + replacement_dictionary.update(stride_replacements) + + # Replace in this SDFG sdfg.replace_dict(replacement_dictionary) for state in sdfg.nodes(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): - for k, v in replacement_dictionary.items(): - if k in node.symbol_mapping: - node.symbol_mapping[k] = v + # Recursively replace strides in nested SDFGs + _specialize_transient_strides(node.sdfg, layout_info, replacement_dictionary) for k in replacement_dictionary.keys(): if k in sdfg.symbols: sdfg.remove_symbol(k) -def _get_expansion_priority_cpu(node: StencilComputation): - expansion_priority = [] - if node.has_splittable_regions(): - expansion_priority.append(["Sections", "Stages", "I", "J", "K"]) - expansion_priority.extend( - [ - ["TileJ", "TileI", "IMap", "JMap", "Sections", "K", "Stages"], - ["TileJ", "TileI", "IMap", "JMap", "Sections", "Stages", "K"], - ["TileJ", "TileI", "Sections", "Stages", "IMap", "JMap", "K"], - ["TileJ", "TileI", "Sections", "K", "Stages", "JMap", "IMap"], - ] - ) - return expansion_priority - - -def _get_expansion_priority_gpu(node: StencilComputation): - expansion_priority = [] - if node.has_splittable_regions(): - expansion_priority.append(["Sections", "Stages", 
"J", "I", "K"]) - if node.oir_node.loop_order == common.LoopOrder.PARALLEL: - expansion_priority.append(["Sections", "Stages", "K", "J", "I"]) - else: - expansion_priority.append(["J", "I", "Sections", "Stages", "K"]) - expansion_priority.append(["TileJ", "TileI", "Sections", "K", "Stages", "JMap", "IMap"]) - return expansion_priority - - -def _set_expansion_orders(sdfg: SDFG): - for node, _ in filter( - lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive() - ): - if node.device == dtypes.DeviceType.GPU: - expansion_priority = _get_expansion_priority_gpu(node) - else: - expansion_priority = _get_expansion_priority_cpu(node) - is_set = False - for exp in expansion_priority: - try: - node.expansion_specification = exp - is_set = True - except ValueError: - continue - else: - break - if not is_set: - raise ValueError("No expansion compatible") - - -def _set_tile_sizes(sdfg: SDFG): - for node, _ in filter( - lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive() - ): - if node.device == dtypes.DeviceType.GPU: - node.tile_sizes = {dcir.Axis.I: 64, dcir.Axis.J: 8, dcir.Axis.K: 8} - node.tile_sizes_interpretation = "shape" - else: - node.tile_sizes = {dcir.Axis.I: 8, dcir.Axis.J: 8, dcir.Axis.K: 8} - node.tile_sizes_interpretation = "strides" - - -def _to_device(sdfg: SDFG, device: str) -> None: - """Update sdfg in place.""" - if device == "gpu": - for array in sdfg.arrays.values(): - array.storage = dtypes.StorageType.GPU_Global - for node, _ in sdfg.all_nodes_recursive(): - if isinstance(node, StencilComputation): - node.device = dtypes.DeviceType.GPU - - -def _pre_expand_transformations( - gtir_pipeline: GtirPipeline, sdfg: SDFG, layout_info: layout.LayoutInfo -): - args_data = make_args_data_from_gtir(gtir_pipeline) - - # stencils without effect - if all(info is None for info in args_data.field_info.values()): - sdfg = SDFG(gtir_pipeline.gtir.name) - sdfg.add_state(gtir_pipeline.gtir.name) - return sdfg - - 
sdfg.simplify(validate=False) - - _set_expansion_orders(sdfg) - _set_tile_sizes(sdfg) - _specialize_transient_strides(sdfg, layout_info) - return sdfg - - -def _post_expand_transformations(sdfg: SDFG): - # DaCe "standard" clean-up transformations - sdfg.simplify(validate=False) - - sdfg.apply_transformations_repeated(NoEmptyEdgeTrivialMapElimination, validate=False) - - # Control the `#pragma omp parallel` statements: Fully collapse parallel loops, - # but set 1D maps to be sequential. (Typical domains are too small to benefit from parallelism) - for node, _ in filter(lambda n: isinstance(n[0], nodes.MapEntry), sdfg.all_nodes_recursive()): - node.collapse = len(node.range) - if node.schedule == dtypes.ScheduleType.CPU_Multicore and len(node.range) <= 1: - node.schedule = dtypes.ScheduleType.Sequential - - # To be re-evaluated with https://github.com/GridTools/gt4py/issues/1896 - # sdfg.apply_transformations_repeated(InlineThreadLocalTransients, validate=False) # noqa: ERA001 - sdfg.simplify(validate=False) - nest_sequential_map_scopes(sdfg) - for sd in sdfg.all_sdfgs_recursive(): - sd.openmp_sections = False - - def _sdfg_add_arrays_and_edges( field_info: dict[str, definitions.FieldInfo], wrapper_sdfg: SDFG, @@ -193,7 +93,10 @@ def _sdfg_add_arrays_and_edges( origins, ) -> None: for name, array in inner_sdfg.arrays.items(): - if isinstance(array, data.Array) and not array.transient: + if array.transient: + continue + + if isinstance(array, data.Array): axes = field_info[name].axes shape = [f"__{name}_{axis}_size" for axis in axes] + [ @@ -201,7 +104,11 @@ def _sdfg_add_arrays_and_edges( ] wrapper_sdfg.add_array( - name, dtype=array.dtype, strides=array.strides, shape=shape, storage=array.storage + name, + dtype=array.dtype, + strides=array.strides, + shape=shape, + storage=array.storage, ) if isinstance(origins, tuple): origin = [o for a, o in zip("IJK", origins) if a in axes] @@ -210,13 +117,22 @@ def _sdfg_add_arrays_and_edges( if len(origin) == 3: origin = [o 
for a, o in zip("IJK", origin) if a in axes] - ranges = [ - (o - max(0, e), o - max(0, e) + s - 1, 1) - for o, e, s in zip( - origin, field_info[name].boundary.lower_indices, inner_sdfg.arrays[name].shape - ) - ] + # Read boundaries for axis-bound fields + if axes != (): + ranges = [ + (o - max(0, e), o - max(0, e) + s - 1, 1) + for o, e, s in zip( + origin, + field_info[name].boundary.lower_indices, + inner_sdfg.arrays[name].shape, + ) + ] + else: + ranges = [] + + # Add data dimensions to the range ranges += [(0, d, 1) for d in field_info[name].data_dims] + if name in inputs: state.add_edge( state.add_read(name), @@ -233,6 +149,29 @@ def _sdfg_add_arrays_and_edges( None, Memlet(name, subset=subsets.Range(ranges)), ) + elif isinstance(array, data.Scalar): + wrapper_sdfg.add_scalar( + name, + dtype=array.dtype, + storage=array.storage, + lifetime=array.lifetime, + ) + if name in inputs: + state.add_edge( + state.add_read(name), + None, + nsdfg, + name, + Memlet(name), + ) + if name in outputs: + state.add_edge( + nsdfg, + name, + state.add_write(name), + None, + Memlet(name), + ) def _sdfg_specialize_symbols(wrapper_sdfg: SDFG, domain: tuple[int, ...]) -> None: @@ -272,13 +211,37 @@ def _sdfg_specialize_symbols(wrapper_sdfg: SDFG, domain: tuple[int, ...]) -> Non def freeze_origin_domain_sdfg( - inner_sdfg: SDFG, + inner_sdfg_unfrozen: SDFG, arg_names: list[str], field_info: dict[str, definitions.FieldInfo], *, + layout_info: layout.LayoutInfo, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...], ) -> SDFG: + """Create a new SDFG by wrapping a _copy_ of the original SDFG and freezing it's + origin and domain. + + This wrapping is required because we do not expect any of the inner_sdfg bounds to + have been specialized, e.g. we expect "__I/J/K" symbols to still be present. We wrap + the call and specialize at top level, which will then be passed as a symbol to the + inner sdfg. 
+ + Once we move specialization of array & maps bounds upstream, this will become moot + and can be removed, see https://github.com/GridTools/gt4py/issues/2082. + + Dev note: we need to wrap a copy to make sure we can use caching with no side effects + in other parts of the SDFG making pipeline. + + Args: + inner_sdfg_unfrozen: SDFG with cartesian bounds as symbols + arg_names: names of arguments to freeze + field_info: full info stack on arguments + origin: tuple of offset into the memory + domain: tuple of size for the memory written by the stencil + """ + inner_sdfg = copy.deepcopy(inner_sdfg_unfrozen) + wrapper_sdfg = SDFG("frozen_" + inner_sdfg.name) state = wrapper_sdfg.add_state("frozen_" + inner_sdfg.name + "_state") @@ -296,7 +259,7 @@ def freeze_origin_domain_sdfg( nsdfg = state.add_nested_sdfg(inner_sdfg, None, inputs, outputs) _sdfg_add_arrays_and_edges( - field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origins=origin + field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origin ) # in special case of empty domain, remove entire SDFG. @@ -315,6 +278,7 @@ def freeze_origin_domain_sdfg( inline_sdfgs(wrapper_sdfg) _sdfg_specialize_symbols(wrapper_sdfg, domain) + _specialize_transient_strides(wrapper_sdfg, layout_info) for _, _, array in wrapper_sdfg.arrays_recursive(): if array.transient: @@ -326,11 +290,53 @@ def freeze_origin_domain_sdfg( class SDFGManager: - # Cache loaded SDFGs across all instances + # Cache loaded SDFGs across all instances (unless caching strategy is "nocaching") _loaded_sdfgs: ClassVar[dict[str | pathlib.Path, SDFG]] = dict() - def __init__(self, builder: StencilBuilder) -> None: + def __init__(self, builder: StencilBuilder, debug_stree: bool = False) -> None: + """ + Initializes the SDFGManager. + + Args: + builder: The StencilBuilder instance, used for build options and caching strategy. + debug_stree: If true, saves a string representation of the schedule tree next to the cached SDFG. 
+ """ self.builder = builder + self.debug_stree = debug_stree + + def schedule_tree(self) -> tn.ScheduleTreeRoot: + """ + Schedule tree representation of the gtir (taken from the builder). + + This function is a three-step process: + + oir = gtir_to_oir(self.builder.gtir) + tree_ir = oir_to_tree_ir(oir) + schedule_tree = tree_ir_to_schedule_tree(tree_ir) + """ + + oir = GTIRToOIR().visit(self.builder.gtir) + + # Deactivate caches. We need to extend the skip list in case users have + # specified skip as well AND we need to copy in order to not trash the + # cache hash! + oir_pipeline: DefaultPipeline = self.builder.options.backend_opts.get( + "oir_pipeline", DefaultPipeline() + ) + oir_pipeline = copy.deepcopy(oir_pipeline) + oir_pipeline.skip.extend( + [ + caches.IJCacheDetection, + caches.KCacheDetection, + caches.PruneKCacheFills, + caches.PruneKCacheFlushes, + ] + ) + oir = oir_pipeline.run(oir) + + tir = OIRToTreeIR(self.builder).visit(oir) + + return TreeIRToScheduleTree().visit(tir) @staticmethod def _strip_history(sdfg: SDFG) -> None: @@ -340,71 +346,71 @@ def _strip_history(sdfg: SDFG) -> None: tmp_sdfg.orig_sdfg = None @staticmethod - def _save_sdfg(sdfg: SDFG, path: str) -> None: + def _save_sdfg(sdfg: SDFG, path: str, validate: bool = False) -> None: + if validate: + sdfg.validate() SDFGManager._strip_history(sdfg) sdfg.save(path) - def _unexpanded_sdfg(self): - filename = self.builder.module_name + ".sdfg" + def sdfg_via_schedule_tree(self) -> SDFG: + """Lower OIR into an SDFG via Schedule Tree transpile first. + + Cache the SDFG into the manager for re-use, unless the builder has a no-caching policy. 
+ """ + filename = f"{self.builder.module_name}.sdfg" path = ( pathlib.Path(os.path.relpath(self.builder.module_path.parent, pathlib.Path.cwd())) / filename ) - if path not in SDFGManager._loaded_sdfgs: - try: - sdfg = SDFG.from_file(path) - except FileNotFoundError: - base_oir = GTIRToOIR().visit(self.builder.gtir) - oir_pipeline = self.builder.options.backend_opts.get( - "oir_pipeline", DefaultPipeline() - ) - oir_node = oir_pipeline.run(base_oir) - sdfg = OirSDFGBuilder().visit(oir_node) + do_cache = self.builder.caching.name != "nocaching" + if do_cache and path in SDFGManager._loaded_sdfgs: + return SDFGManager._loaded_sdfgs[path] - _to_device(sdfg, self.builder.backend.storage_info["device"]) - _pre_expand_transformations( - self.builder.gtir_pipeline, sdfg, self.builder.backend.storage_info - ) - self._save_sdfg(sdfg, path) - SDFGManager._loaded_sdfgs[path] = sdfg + # Create SDFG + stree = self.schedule_tree() + sdfg = stree.as_sdfg( + validate=True, + simplify=True, + skip={"ScalarToSymbolPromotion"}, + ) - return SDFGManager._loaded_sdfgs[path] + if do_cache: + self._save_sdfg(sdfg, str(path)) + SDFGManager._loaded_sdfgs[path] = sdfg - def unexpanded_sdfg(self): - return copy.deepcopy(self._unexpanded_sdfg()) + if self.debug_stree: + stree_path = path.with_suffix(".stree.txt") + with open(stree_path, "w+") as file: + file.write(stree.as_string(-1)) - def _expanded_sdfg(self): - sdfg = self._unexpanded_sdfg() - sdfg.expand_library_nodes() - _post_expand_transformations(sdfg) return sdfg - def expanded_sdfg(self): - return copy.deepcopy(self._expanded_sdfg()) - def _frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]) -> SDFG: basename = self.builder.module_path.with_suffix("") path = f"{basename}_{shash(origin, domain)}.sdfg" - # check if same sdfg already cached on disk - if path in SDFGManager._loaded_sdfgs: + # check if the same sdfg is already loaded + do_cache = self.builder.caching.name != "nocache" + if do_cache and path 
in SDFGManager._loaded_sdfgs: return SDFGManager._loaded_sdfgs[path] - # otherwise, wrap and save sdfg from scratch - inner_sdfg = self.unexpanded_sdfg() - - sdfg = freeze_origin_domain_sdfg( - inner_sdfg, + # Otherwise, wrap and save sdfg from scratch + sdfg = self.sdfg_via_schedule_tree() + frozen_sdfg = freeze_origin_domain_sdfg( + sdfg, arg_names=[arg.name for arg in self.builder.gtir.api_signature], field_info=make_args_data_from_gtir(self.builder.gtir_pipeline).field_info, + layout_info=self.builder.backend.storage_info, origin=origin, domain=domain, ) - SDFGManager._loaded_sdfgs[path] = sdfg - self._save_sdfg(sdfg, path) - return SDFGManager._loaded_sdfgs[path] + if do_cache: + SDFGManager._loaded_sdfgs[path] = frozen_sdfg + self._save_sdfg(frozen_sdfg, path) + + return frozen_sdfg def frozen_sdfg(self, *, origin: dict[str, tuple[int, ...]], domain: tuple[int, ...]) -> SDFG: return copy.deepcopy(self._frozen_sdfg(origin=origin, domain=domain)) @@ -418,13 +424,21 @@ def __init__(self, class_name: str, module_name: str, backend: BaseDaceBackend) def __call__(self) -> dict[str, dict[str, str]]: manager = SDFGManager(self.backend.builder) - sdfg = manager.expanded_sdfg() + + sdfg = manager.sdfg_via_schedule_tree() + _specialize_transient_strides( + sdfg, + self.backend.storage_info, + ) + + # NOTE + # The glue code in DaCeComputationCodegen.apply() (just below) will define all the + # symbols. Our job creating the sdfg/stree is to make sure we use the same symbols + # and to be sure that these symbols are added as dace symbols. 
implementation = DaCeComputationCodegen.apply(self.backend.builder, sdfg) - bindings = DaCeBindingsCodegen.apply( - sdfg, module_name=self.module_name, backend=self.backend - ) + bindings = DaCeBindingsCodegen.apply(sdfg, self.module_name, backend=self.backend) bindings_ext = "cu" if self.backend.storage_info["device"] == "gpu" else "cpp" return { @@ -494,7 +508,7 @@ def generate_tmp_allocs(self, sdfg: SDFG) -> list[str]: return res @staticmethod - def _postprocess_dace_code(code_objects: codeobject.CodeObject, is_gpu: bool) -> str: + def _postprocess_dace_code(code_objects: list[codeobject.CodeObject], is_gpu: bool) -> str: lines = code_objects[[co.title for co in code_objects].index("Frame")].clean_code.split( "\n" ) @@ -598,6 +612,10 @@ def generate_dace_args(self, stencil_ir: gtir.Stencil, sdfg: SDFG) -> list[str]: if array.transient: continue + if isinstance(array, data.Scalar): + # will be passed by name (as variable) by the catch all below + continue + dims = [dim for dim, select in zip("IJK", array_dimensions(array)) if select] data_ndim = len(array.shape) - len(dims) @@ -653,14 +671,20 @@ def generate_functor_args(self, sdfg: SDFG) -> list[str]: for name, array in sdfg.arrays.items(): if array.transient: continue - arguments.append(f"auto && __{name}_sid") + if isinstance(array, data.Scalar): + arguments.append(f"auto {name}") + continue + if isinstance(array, data.Array): + arguments.append(f"auto && __{name}_sid") + continue + raise NotImplementedError(f"generate_functor_args(): unexpected type {type(array)}") for name, dtype in ((n, d) for n, d in sdfg.symbols.items() if not n.startswith("__")): arguments.append(dtype.as_arg(name)) return arguments class DaCeBindingsCodegen: - def __init__(self, backend: BaseDaceBackend): + def __init__(self, backend: BaseDaceBackend) -> None: self.backend = backend self._unique_index: int = 0 @@ -676,16 +700,24 @@ def generate_entry_params(self, sdfg: SDFG) -> list[str]: for name in 
sdfg.signature_arglist(with_types=False, for_call=True): if name in sdfg.arrays: container = sdfg.arrays[name] - assert isinstance(container, data.Array) - res[name] = ( - "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type=( - "object" if self.backend.storage_info["device"] == "gpu" else "buffer" - ), - name=name, - ndim=len(container.shape), + if isinstance(container, data.Scalar): + res[name] = f"{container.ctype} {name}" + elif isinstance(container, data.Array): + res[name] = ( + "py::{pybind_type} {name}, std::array {name}_origin".format( + pybind_type=( + "object" + if self.backend.storage_info["device"] == "gpu" + else "buffer" + ), + name=name, + ndim=len(container.shape), + ) + ) + else: + raise NotImplementedError( + f"generate_entry_params(): unexpected type {type(container)}" ) - ) elif name in sdfg.symbols and not name.startswith("__"): res[name] = f"{sdfg.symbols[name].ctype} {name}" return list(res[node.name] for node in self.backend.builder.gtir.params if node.name in res) @@ -697,10 +729,16 @@ def generate_sid_params(self, sdfg: SDFG) -> list[str]: if array.transient: continue + if isinstance(array, data.Scalar): + res.append(name) + continue + + if not isinstance(array, data.Array): + raise NotImplementedError(f"generate_sid_params(): unexpected type {type(array)}") + domain_dim_flags = tuple(array_dimensions(array)) if len(domain_dim_flags) != 3: raise RuntimeError("Expected 3 cartesian array dimensions. 
Codegen error.") - data_ndim = len(array.shape) - sum(domain_dim_flags) sid_def = pybuffer_to_sid( name=name, @@ -717,7 +755,7 @@ def generate_sid_params(self, sdfg: SDFG) -> list[str]: res.append(name) return res - def generate_sdfg_bindings(self, sdfg, module_name) -> str: + def generate_sdfg_bindings(self, sdfg: SDFG, module_name: str) -> str: return self.mako_template.render_values( name=sdfg.name, module_name=module_name, @@ -727,7 +765,7 @@ def generate_sdfg_bindings(self, sdfg, module_name) -> str: @classmethod def apply(cls, sdfg: SDFG, module_name: str, *, backend: BaseDaceBackend) -> str: - generated_code = cls(backend).generate_sdfg_bindings(sdfg, module_name=module_name) + generated_code = cls(backend).generate_sdfg_bindings(sdfg, module_name) if backend.builder.options.format_source: generated_code = codegen.format_source("cpp", generated_code, style="LLVM") return generated_code @@ -784,8 +822,8 @@ class DaceCPUBackend(BaseDaceBackend): storage_info: ClassVar[layout.LayoutInfo] = { "alignment": 1, "device": "cpu", - "layout_map": layout.layout_maker_factory((0, 1, 2)), - "is_optimal_layout": layout.layout_checker_factory(layout.layout_maker_factory((0, 1, 2))), + "layout_map": layout.layout_maker_factory((1, 0, 2)), + "is_optimal_layout": layout.layout_checker_factory(layout.layout_maker_factory((1, 0, 2))), } MODULE_GENERATOR_CLASS = DaCePyExtModuleGenerator diff --git a/src/gt4py/cartesian/backend/dace_lazy_stencil.py b/src/gt4py/cartesian/backend/dace_lazy_stencil.py index 6a09258889..38a41ac466 100644 --- a/src/gt4py/cartesian/backend/dace_lazy_stencil.py +++ b/src/gt4py/cartesian/backend/dace_lazy_stencil.py @@ -51,7 +51,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.SDFG: args_data = make_args_data_from_gtir(self.builder.gtir_pipeline) arg_names = [arg.name for arg in self.builder.gtir.api_signature] assert args_data.domain_info is not None - norm_kwargs = DaCeStencilObject.normalize_args( + norm_kwargs = 
DaCeStencilObject.normalize_arg_fields( *args, backend=self.backend.name, arg_names=arg_names, diff --git a/src/gt4py/cartesian/backend/dace_stencil_object.py b/src/gt4py/cartesian/backend/dace_stencil_object.py index d727b1d901..600295fc24 100644 --- a/src/gt4py/cartesian/backend/dace_stencil_object.py +++ b/src/gt4py/cartesian/backend/dace_stencil_object.py @@ -39,22 +39,26 @@ def _extract_array_infos(field_args, device) -> Dict[str, Optional[ArgsInfo]]: def add_optional_fields( - sdfg: dace.SDFG, field_info: Dict[str, Any], parameter_info: Dict[str, Any], **kwargs: Any + sdfg: dace.SDFG, + field_info: Dict[str, Any], + parameter_info: Dict[str, Any], + **kwargs: Any, ) -> dace.SDFG: sdfg = copy.deepcopy(sdfg) for name, info in field_info.items(): if info.access == AccessKind.NONE and name in kwargs and name not in sdfg.arrays: outer_array = kwargs[name] sdfg.add_array( - name, shape=outer_array.shape, dtype=outer_array.dtype, strides=outer_array.strides + name, + shape=outer_array.shape, + dtype=outer_array.dtype, + strides=outer_array.strides, ) for name, info in parameter_info.items(): if info.access == AccessKind.NONE and name in kwargs and name not in sdfg.symbols: if isinstance(kwargs[name], dace.data.Scalar): sdfg.add_scalar(name, dtype=kwargs[name].dtype) - else: - sdfg.add_symbol(name, stype=dace.typeclass(type(kwargs[name]))) return sdfg @@ -67,7 +71,10 @@ class DaCeFrozenStencil(FrozenStencil, SDFGConvertible): def __sdfg__(self, **kwargs): return add_optional_fields( - self.sdfg, self.stencil_object.field_info, self.stencil_object.parameter_info, **kwargs + self.sdfg, + self.stencil_object.field_info, + self.stencil_object.parameter_info, + **kwargs, ) def __sdfg_signature__(self): @@ -97,7 +104,10 @@ def _get_domain_origin_key(domain, origin): return domain, origins_tuple def freeze( - self: DaCeStencilObject, *, origin: Dict[str, Tuple[int, ...]], domain: Tuple[int, ...] 
+ self: DaCeStencilObject, + *, + origin: Dict[str, Tuple[int, ...]], + domain: Tuple[int, ...], ) -> DaCeFrozenStencil: key = DaCeStencilObject._get_domain_origin_key(domain, origin) @@ -106,10 +116,12 @@ def freeze( return self._frozen_cache[key] # otherwise, wrap and save sdfg from scratch + backend_class = gt_backend.from_name(self.backend) frozen_sdfg = freeze_origin_domain_sdfg( self.sdfg(), arg_names=list(self.__sdfg_signature__()[0]), field_info=self.field_info, + layout_info=backend_class.storage_info, origin=origin, domain=domain, ) @@ -138,7 +150,7 @@ def closure_resolver( def __sdfg__(self, *args, **kwargs) -> dace.SDFG: arg_names, _ = self.__sdfg_signature__() - norm_kwargs = DaCeStencilObject.normalize_args( + norm_kwargs = DaCeStencilObject.normalize_arg_fields( *args, backend=self.backend, arg_names=arg_names, @@ -165,7 +177,7 @@ def __sdfg_signature__(self) -> Tuple[Sequence[str], Sequence[str]]: return (args, []) @staticmethod - def normalize_args( + def normalize_arg_fields( *args, backend: str, arg_names: Iterable[str], @@ -175,11 +187,19 @@ def normalize_args( origin: Optional[Dict[str, Tuple[int, ...]]] = None, **kwargs, ): + """Normalize Fields in argument list to the proper domain/origin""" backend_cls = gt_backend.from_name(backend) args_iter = iter(args) - args_as_kwargs = { - name: (kwargs[name] if name in kwargs else next(args_iter)) for name in arg_names - } + + args_as_kwargs = {} + for name in arg_names: + if name not in field_info.keys(): + continue + if name in kwargs.keys(): + args_as_kwargs[name] = kwargs[name] + else: + args_as_kwargs[name] = next(args_iter) + arg_infos = _extract_array_infos( field_args=args_as_kwargs, device=backend_cls.storage_info["device"] ) diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py deleted file mode 100644 index 8b42e1f319..0000000000 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ /dev/null @@ -1,1006 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright 
(c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from typing import Any, Dict, Generator, List, Optional, Sequence, Set, Tuple, Union - -import dace -import sympy - -from gt4py import eve -from gt4py.cartesian.gtc import common, definitions, oir -from gt4py.cartesian.gtc.common import LocNode -from gt4py.cartesian.gtc.dace import prefix -from gt4py.cartesian.gtc.dace.symbol_utils import ( - get_axis_bound_dace_symbol, - get_axis_bound_diff_str, - get_axis_bound_str, - get_dace_symbol, -) -from gt4py.eve import datamodels - - -@eve.utils.noninstantiable -class Expr(common.Expr): - dtype: common.DataType - - -@eve.utils.noninstantiable -class Stmt(common.Stmt): - pass - - -class Axis(eve.StrEnum): - I = "I" # noqa: E741 [ambiguous-variable-name] - J = "J" - K = "K" - - def domain_symbol(self) -> eve.SymbolRef: - return eve.SymbolRef("__" + self.upper()) - - def iteration_symbol(self) -> eve.SymbolRef: - return eve.SymbolRef("__" + self.lower()) - - def tile_symbol(self) -> eve.SymbolRef: - return eve.SymbolRef("__tile_" + self.lower()) - - @staticmethod - def dims_3d() -> Generator[Axis, None, None]: - yield from [Axis.I, Axis.J, Axis.K] - - @staticmethod - def dims_horizontal() -> Generator[Axis, None, None]: - yield from [Axis.I, Axis.J] - - def to_idx(self) -> int: - return [Axis.I, Axis.J, Axis.K].index(self) - - def domain_dace_symbol(self): - return get_dace_symbol(self.domain_symbol()) - - def iteration_dace_symbol(self): - return get_dace_symbol(self.iteration_symbol()) - - def tile_dace_symbol(self): - return get_dace_symbol(self.tile_symbol()) - - -class MapSchedule(eve.IntEnum): - Default = 0 - Sequential = 1 - - CPU_Multicore = 2 - - GPU_Device = 3 - GPU_ThreadBlock = 4 - - def to_dace_schedule(self): - return { - MapSchedule.Default: dace.ScheduleType.Default, - MapSchedule.Sequential: 
dace.ScheduleType.Sequential, - MapSchedule.CPU_Multicore: dace.ScheduleType.CPU_Multicore, - MapSchedule.GPU_Device: dace.ScheduleType.GPU_Device, - MapSchedule.GPU_ThreadBlock: dace.ScheduleType.GPU_ThreadBlock, - }[self] - - @classmethod - def from_dace_schedule(cls, schedule): - return { - dace.ScheduleType.Default: MapSchedule.Default, - dace.ScheduleType.Sequential: MapSchedule.Sequential, - dace.ScheduleType.CPU_Multicore: MapSchedule.CPU_Multicore, - dace.ScheduleType.GPU_Default: MapSchedule.GPU_Device, - dace.ScheduleType.GPU_Device: MapSchedule.GPU_Device, - dace.ScheduleType.GPU_ThreadBlock: MapSchedule.GPU_ThreadBlock, - }[schedule] - - -class StorageType(eve.IntEnum): - Default = 0 - - CPU_Heap = 1 - - GPU_Global = 3 - GPU_Shared = 4 - - Register = 5 - - def to_dace_storage(self): - return { - StorageType.Default: dace.StorageType.Default, - StorageType.CPU_Heap: dace.StorageType.CPU_Heap, - StorageType.GPU_Global: dace.StorageType.GPU_Global, - StorageType.GPU_Shared: dace.StorageType.GPU_Shared, - StorageType.Register: dace.StorageType.Register, - }[self] - - @classmethod - def from_dace_storage(cls, schedule): - return { - dace.StorageType.Default: StorageType.Default, - dace.StorageType.CPU_Heap: StorageType.CPU_Heap, - dace.StorageType.GPU_Global: StorageType.GPU_Global, - dace.StorageType.GPU_Shared: StorageType.GPU_Shared, - dace.StorageType.Register: StorageType.Register, - }[schedule] - - -class AxisBound(common.AxisBound): - axis: Axis - - def __str__(self) -> str: - return get_axis_bound_str(self, self.axis.domain_symbol()) - - @classmethod - def from_common(cls, axis, node): - return cls(axis=axis, level=node.level, offset=node.offset) - - def to_dace_symbolic(self): - return get_axis_bound_dace_symbol(self) - - -class IndexWithExtent(eve.Node): - axis: Axis - value: Union[AxisBound, int, str] - extent: Tuple[int, int] - - @property - def free_symbols(self) -> Set[eve.SymbolRef]: - if isinstance(self.value, AxisBound) and self.value.level 
== common.LevelMarker.END: - return {self.axis.domain_symbol()} - elif isinstance(self.value, str): - return {self.axis.iteration_symbol()} - return set() - - @classmethod - def from_axis(cls, axis: Axis, extent=(0, 0)): - return cls(axis=axis, value=axis.iteration_symbol(), extent=extent) - - @property - def size(self): - return self.extent[1] - self.extent[0] + 1 - - @property - def overapproximated_size(self): - return self.size - - def union(self, other: IndexWithExtent): - assert self.axis == other.axis - if isinstance(self.value, int) or (isinstance(self.value, str) and self.value.isdigit()): - value = other.value - elif isinstance(other.value, int) or ( - isinstance(other.value, str) and other.value.isdigit() - ): - value = self.value - elif ( - self.value == self.axis.iteration_symbol() - or other.value == self.axis.iteration_symbol() - ): - value = self.axis.iteration_symbol() - else: - assert other.value == self.value - value = self.value - return IndexWithExtent( - axis=self.axis, - value=value, - extent=(min(self.extent[0], other.extent[0]), max(self.extent[1], other.extent[1])), - ) - - @property - def idx_range(self): - return (f"{self.value}{self.extent[0]:+d}", f"{self.value}{self.extent[1] + 1:+d}") - - def to_dace_symbolic(self): - if isinstance(self.value, AxisBound): - symbolic_value = get_axis_bound_dace_symbol(self.value) - elif isinstance(self.value, str): - symbolic_value = next( - axis for axis in Axis.dims_3d() if axis.iteration_symbol() == self.value - ).iteration_dace_symbol() - else: - symbolic_value = self.value - return symbolic_value + self.extent[0], symbolic_value + self.extent[1] + 1 - - def shifted(self, offset): - extent = self.extent[0] + offset, self.extent[1] + offset - return IndexWithExtent(axis=self.axis, value=self.value, extent=extent) - - -class DomainInterval(eve.Node): - start: AxisBound - end: AxisBound - - def __init__(self, start: AxisBound, end: AxisBound): - super().__init__() - - if start.axis != end.axis: - 
raise ValueError( - f"Axis need to match for start and end bounds. Got {start.axis} and {end.axis}." - ) - - self.start = start - self.end = end - - @property - def free_symbols(self) -> Set[eve.SymbolRef]: - res = set() - if self.start.level == common.LevelMarker.END: - res.add(self.start.axis.domain_symbol()) - if self.end.level == common.LevelMarker.END: - res.add(self.end.axis.domain_symbol()) - return res - - @property - def size(self): - return get_axis_bound_diff_str( - self.end, self.start, var_name=self.start.axis.domain_symbol() - ) - - @property - def overapproximated_size(self): - return self.size - - @classmethod - def union(cls, first, second): - return cls(start=min(first.start, second.start), end=max(first.end, second.end)) - - @classmethod - def intersection(cls, axis, first, second): - first_start = first.start if first.start is not None else second.start - first_end = first.end if first.end is not None else second.end - second_start = second.start if second.start is not None else first.start - second_end = second.end if second.end is not None else first_end.end - - if hasattr(first_start, "axis") and first_start.axis != axis: - raise ValueError(f"Axis need to match: {first_start.axis} and {axis} are different.") - - if hasattr(second_start, "axis") and second_start.axis != axis: - raise ValueError(f"Axis need to match: {second_start.axis} and {axis} are different.") - - # overlapping intervals - # or first contained in second - # or second contained in first - if not ( - (first_start <= second_end and second_start <= first_end) - or (second_start <= first_start and first_end <= second_end) - or (first_start <= second_start and second_end <= first_end) - ): - raise ValueError(f"No intersection found for intervals {first} and {second}") - - start = max(first_start, second_start) - start = AxisBound(axis=axis, level=start.level, offset=start.offset) - end = min(first_end, second_end) - end = AxisBound(axis=axis, level=end.level, offset=end.offset) - 
return cls(start=start, end=end) - - @property - def idx_range(self): - return str(self.start), str(self.end) - - def to_dace_symbolic(self): - return self.start.to_dace_symbolic(), self.end.to_dace_symbolic() - - def shifted(self, offset: int): - return DomainInterval( - start=AxisBound( - axis=self.start.axis, level=self.start.level, offset=self.start.offset + offset - ), - end=AxisBound( - axis=self.end.axis, level=self.end.level, offset=self.end.offset + offset - ), - ) - - def is_subset_of(self, other: DomainInterval) -> bool: - return self.start >= other.start and self.end <= other.end - - -class TileInterval(eve.Node): - axis: Axis - start_offset: int - end_offset: int - tile_size: int - domain_limit: AxisBound - - @property - def free_symbols(self) -> Set[eve.SymbolRef]: - res = {self.axis.tile_symbol()} - if self.domain_limit.level == common.LevelMarker.END: - res.add(self.axis.domain_symbol()) - return res - - @property - def size(self): - return "min({tile_size}, {domain_limit} - {tile_symbol}){halo_size:+d}".format( - tile_size=self.tile_size, - domain_limit=self.domain_limit, - tile_symbol=self.axis.tile_symbol(), - halo_size=self.end_offset - self.start_offset, - ) - - @property - def overapproximated_size(self): - return "{tile_size}{halo_size:+d}".format( - tile_size=self.tile_size, halo_size=self.end_offset - self.start_offset - ) - - @classmethod - def union(cls, first, second): - assert first.axis == second.axis - assert first.tile_size == second.tile_size - assert first.domain_limit == second.domain_limit - return cls( - axis=first.axis, - start_offset=min(first.start_offset, second.start_offset), - end_offset=max(first.end_offset, second.end_offset), - tile_size=first.tile_size, - domain_limit=first.domain_limit, - ) - - @property - def idx_range(self): - start = f"{self.axis.tile_symbol()}{self.start_offset:+d}" - end = f"{start}+({self.size})" - return start, end - - def dace_symbolic_size(self): - return ( - sympy.Min( - self.tile_size, 
self.domain_limit.to_dace_symbolic() - self.axis.tile_dace_symbol() - ) - + self.end_offset - - self.start_offset - ) - - def to_dace_symbolic(self): - start = self.axis.tile_dace_symbol() + self.start_offset - end = start + self.dace_symbolic_size() - return start, end - - -class Range(eve.Node): - var: eve.SymbolRef - interval: Union[DomainInterval, TileInterval] - stride: int - - @classmethod - def from_axis_and_interval( - cls, axis: Axis, interval: Union[DomainInterval, TileInterval], stride=1 - ): - return cls(var=axis.iteration_symbol(), interval=interval, stride=stride) - - @property - def free_symbols(self) -> Set[eve.SymbolRef]: - return {self.var, *self.interval.free_symbols} - - -class GridSubset(eve.Node): - intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] - - def __iter__(self): - for axis in Axis.dims_3d(): - if axis in self.intervals: - yield self.intervals[axis] - - def items(self): - for axis in Axis.dims_3d(): - if axis in self.intervals: - yield axis, self.intervals[axis] - - @property - def free_symbols(self) -> Set[eve.SymbolRef]: - return set().union(*(interval.free_symbols for interval in self.intervals.values())) - - @classmethod - def single_gridpoint(cls, offset=(0, 0, 0)): - return cls( - intervals={ - axis: IndexWithExtent.from_axis(axis, extent=(offset[i], offset[i])) - for i, axis in enumerate(Axis.dims_3d()) - } - ) - - @property - def shape(self): - return tuple(interval.size for _, interval in self.items()) - - @property - def overapproximated_shape(self): - return tuple(interval.overapproximated_size for _, interval in self.items()) - - def restricted_to_index(self, axis: Axis, extent=(0, 0)) -> GridSubset: - intervals = dict(self.intervals) - intervals[axis] = IndexWithExtent.from_axis(axis, extent=extent) - return GridSubset(intervals=intervals) - - def set_interval( - self, - axis: Axis, - interval: Union[DomainInterval, IndexWithExtent, TileInterval, oir.Interval], - ) -> GridSubset: - if 
isinstance(interval, oir.Interval): - interval = DomainInterval( - start=AxisBound( - level=interval.start.level, offset=interval.start.offset, axis=Axis.K - ), - end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K), - ) - elif isinstance(interval, DomainInterval): - assert interval.start.axis == axis - intervals = dict(self.intervals) - intervals[axis] = interval - return GridSubset(intervals=intervals) - - @classmethod - def from_gt4py_extent(cls, extent: definitions.Extent): - i_interval = DomainInterval( - start=AxisBound(level=common.LevelMarker.START, offset=extent[0][0], axis=Axis.I), - end=AxisBound(level=common.LevelMarker.END, offset=extent[0][1], axis=Axis.I), - ) - j_interval = DomainInterval( - start=AxisBound(level=common.LevelMarker.START, offset=extent[1][0], axis=Axis.J), - end=AxisBound(level=common.LevelMarker.END, offset=extent[1][1], axis=Axis.J), - ) - - return cls(intervals={Axis.I: i_interval, Axis.J: j_interval}) - - @classmethod - def from_interval( - cls, - interval: Union[DomainInterval, IndexWithExtent, oir.Interval, TileInterval], - axis: Axis, - ): - res_interval: Union[DomainInterval, IndexWithExtent, TileInterval] - if isinstance(interval, (DomainInterval, oir.Interval)): - res_interval = DomainInterval( - start=AxisBound( - level=interval.start.level, offset=interval.start.offset, axis=Axis.K - ), - end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K), - ) - else: - assert isinstance(interval, (IndexWithExtent, TileInterval)) - res_interval = interval - - return cls(intervals={axis: res_interval}) - - def axes(self): - for axis in Axis.dims_3d(): - if axis in self.intervals: - yield axis - - @classmethod - def full_domain(cls, axes=None): - if axes is None: - axes = Axis.dims_3d() - res_subsets = dict() - for axis in axes: - res_subsets[axis] = DomainInterval( - start=AxisBound(axis=axis, level=common.LevelMarker.START, offset=0), - end=AxisBound(axis=axis, 
level=common.LevelMarker.END, offset=0), - ) - return GridSubset(intervals=res_subsets) - - def tile(self, tile_sizes: Dict[Axis, int]): - res_intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] = {} - for axis, interval in self.intervals.items(): - if isinstance(interval, DomainInterval) and axis in tile_sizes: - if axis == Axis.K: - res_intervals[axis] = TileInterval( - axis=axis, - tile_size=tile_sizes[axis], - start_offset=0, - end_offset=0, - domain_limit=interval.end, - ) - else: - assert ( - interval.start.level == common.LevelMarker.START - and interval.end.level == common.LevelMarker.END - ) - res_intervals[axis] = TileInterval( - axis=axis, - tile_size=tile_sizes[axis], - start_offset=interval.start.offset, - end_offset=interval.end.offset, - domain_limit=AxisBound(axis=axis, level=common.LevelMarker.END, offset=0), - ) - else: - res_intervals[axis] = interval - return GridSubset(intervals=res_intervals) - - def union(self, other): - assert list(self.axes()) == list(other.axes()) - intervals = dict() - for axis in self.axes(): - interval1 = self.intervals[axis] - interval2 = other.intervals[axis] - if isinstance(interval1, DomainInterval) and isinstance(interval2, DomainInterval): - intervals[axis] = DomainInterval.union(interval1, interval2) - elif isinstance(interval1, TileInterval) and isinstance(interval2, TileInterval): - intervals[axis] = TileInterval.union(interval1, interval2) - elif isinstance(interval1, IndexWithExtent) and isinstance(interval2, IndexWithExtent): - intervals[axis] = interval1.union(interval2) - else: - assert ( - isinstance(interval2, (DomainInterval, TileInterval)) - and isinstance(interval1, (DomainInterval, IndexWithExtent)) - ) or ( - isinstance(interval1, (DomainInterval, TileInterval)) - and isinstance(interval2, IndexWithExtent) - ) - intervals[axis] = ( - interval1 - if isinstance(interval1, (DomainInterval, TileInterval)) - else interval2 - ) - return GridSubset(intervals=intervals) - - -class 
FieldAccessInfo(eve.Node): - grid_subset: GridSubset - global_grid_subset: GridSubset - dynamic_access: bool = False - variable_offset_axes: List[Axis] = eve.field(default_factory=list) - - def axes(self): - yield from self.grid_subset.axes() - - @property - def shape(self): - return self.grid_subset.shape - - @property - def overapproximated_shape(self): - return self.grid_subset.overapproximated_shape - - def apply_iteration(self, grid_subset: GridSubset): - res_intervals = dict(self.grid_subset.intervals) - for axis, field_interval in self.grid_subset.intervals.items(): - if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval): - grid_interval = grid_subset.intervals[axis] - assert isinstance(field_interval, IndexWithExtent) - extent = field_interval.extent - if isinstance(grid_interval, DomainInterval): - if axis in self.global_grid_subset.intervals: - res_intervals[axis] = self.global_grid_subset.intervals[axis] - else: - res_intervals[axis] = DomainInterval( - start=AxisBound( - axis=axis, - level=grid_interval.start.level, - offset=grid_interval.start.offset + extent[0], - ), - end=AxisBound( - axis=axis, - level=grid_interval.end.level, - offset=grid_interval.end.offset + extent[1], - ), - ) - elif isinstance(grid_interval, TileInterval): - res_intervals[axis] = TileInterval( - axis=axis, - tile_size=grid_interval.tile_size, - start_offset=grid_interval.start_offset + extent[0], - end_offset=grid_interval.end_offset + extent[1], - domain_limit=grid_interval.domain_limit, - ) - else: - assert field_interval.value == grid_interval.value - extent = ( - min(extent) + grid_interval.extent[0], - max(extent) + grid_interval.extent[1], - ) - res_intervals[axis] = IndexWithExtent( - axis=axis, value=field_interval.value, extent=extent - ) - return FieldAccessInfo( - grid_subset=GridSubset(intervals=res_intervals), - dynamic_access=self.dynamic_access, - variable_offset_axes=self.variable_offset_axes, - 
global_grid_subset=self.global_grid_subset, - ) - - def union(self, other: FieldAccessInfo): - grid_subset = self.grid_subset.union(other.grid_subset) - global_subset = self.global_grid_subset.union(other.global_grid_subset) - variable_offset_axes = [ - axis - for axis in Axis.dims_3d() - if axis in self.variable_offset_axes or axis in other.variable_offset_axes - ] - return FieldAccessInfo( - grid_subset=grid_subset, - dynamic_access=self.dynamic_access or other.dynamic_access, - variable_offset_axes=variable_offset_axes, - global_grid_subset=global_subset, - ) - - def clamp_full_axis(self, axis): - grid_subset = GridSubset(intervals=self.grid_subset.intervals) - interval = self.grid_subset.intervals[axis] - full_interval = DomainInterval( - start=AxisBound(level=common.LevelMarker.START, offset=0, axis=axis), - end=AxisBound(level=common.LevelMarker.END, offset=0, axis=axis), - ) - res_interval = DomainInterval.union( - full_interval, self.global_grid_subset.intervals.get(axis, full_interval) - ) - if isinstance(interval, DomainInterval): - interval_union = DomainInterval.union(interval, res_interval) - grid_subset.intervals[axis] = interval_union - else: - grid_subset.intervals[axis] = res_interval - grid_subset = grid_subset.set_interval(axis, res_interval) - return FieldAccessInfo( - grid_subset=grid_subset, - dynamic_access=self.dynamic_access, - variable_offset_axes=self.variable_offset_axes, - global_grid_subset=self.global_grid_subset, - ) - - def untile(self, tile_axes: Sequence[Axis]) -> FieldAccessInfo: - res_intervals = {} - for axis, interval in self.grid_subset.intervals.items(): - if isinstance(interval, TileInterval) and axis in tile_axes: - res_intervals[axis] = self.global_grid_subset.intervals[axis] - else: - res_intervals[axis] = interval - return FieldAccessInfo( - grid_subset=GridSubset(intervals=res_intervals), - global_grid_subset=self.global_grid_subset, - dynamic_access=self.dynamic_access, - 
variable_offset_axes=self.variable_offset_axes, - ) - - def restricted_to_index(self, axis: Axis, extent: Tuple[int, int] = (0, 0)): - return FieldAccessInfo( - grid_subset=self.grid_subset.restricted_to_index(axis=axis, extent=extent), - global_grid_subset=self.global_grid_subset, - dynamic_access=self.dynamic_access, - variable_offset_axes=self.variable_offset_axes, - ) - - -class Memlet(eve.Node): - field: eve.Coerced[eve.SymbolRef] - access_info: FieldAccessInfo - connector: eve.Coerced[eve.SymbolRef] - is_read: bool - is_write: bool - - def union(self, other): - assert self.field == other.field - return Memlet( - field=self.field, - access_info=self.access_info.union(other.access_info), - connector=self.field, - is_read=self.is_read or other.is_read, - is_write=self.is_write or other.is_write, - ) - - def remove_read(self): - return Memlet( - field=self.field, - access_info=self.access_info, - connector=self.connector, - is_read=False, - is_write=self.is_write, - ) - - def remove_write(self): - return Memlet( - field=self.field, - access_info=self.access_info, - connector=self.connector, - is_read=self.is_read, - is_write=False, - ) - - -class Decl(LocNode): - name: eve.Coerced[eve.SymbolName] - dtype: common.DataType - - def __init__(self, *args: Any, **kwargs: Any) -> None: - if type(self) is Decl: - raise TypeError("Trying to instantiate `Decl` abstract class.") - super().__init__(*args, **kwargs) - - -class FieldDecl(Decl): - strides: Tuple[Union[int, str], ...] - data_dims: Tuple[int, ...] 
= eve.field(default_factory=tuple) - access_info: FieldAccessInfo - storage: StorageType - - @property - def shape(self): - access_info = self.access_info - for axis in self.access_info.variable_offset_axes: - access_info = access_info.clamp_full_axis(axis) - ijk_shape = access_info.grid_subset.shape - return ijk_shape + tuple(self.data_dims) - - def axes(self): - yield from self.access_info.grid_subset.axes() - - @property - def is_dynamic(self) -> bool: - return self.access_info.dynamic_access - - def with_set_access_info(self, access_info: FieldAccessInfo) -> FieldDecl: - return FieldDecl( - name=self.name, - dtype=self.dtype, - strides=self.strides, - data_dims=self.data_dims, - access_info=access_info, - ) - - -class Literal(common.Literal, Expr): - pass - - -class ScalarAccess(common.ScalarAccess, Expr): - is_target: bool - original_name: Optional[str] = None - - -class VariableKOffset(common.VariableKOffset[Expr]): - @datamodels.validator("k") - def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Expr) -> None: - for part in expression.walk_values(): - if isinstance(part, Cast): - raise ValueError( - "DaCe backends are currently missing support for casts in variable k offsets. See issue https://github.com/GridTools/gt4py/issues/1881." 
- ) - - -class IndexAccess(common.FieldAccess, Expr): - # ScalarAccess used for indirect addressing - offset: Optional[common.CartesianOffset | Literal | ScalarAccess | VariableKOffset] - is_target: bool - - explicit_indices: Optional[list[Literal | ScalarAccess | VariableKOffset]] = None - """Used to access as a full field with explicit indices""" - - -class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt): - _dtype_validation = common.assign_stmt_dtype_validation(strict=True) - - -class MaskStmt(Stmt): - mask: Expr - body: List[Stmt] - - @datamodels.validator("mask") - def mask_is_boolean_field_expr(self, attribute: datamodels.Attribute, v: Expr) -> None: - if v.dtype != common.DataType.BOOL: - raise ValueError("Mask must be a boolean expression.") - - -class HorizontalRestriction(common.HorizontalRestriction[Stmt], Stmt): - pass - - -class UnaryOp(common.UnaryOp[Expr], Expr): - pass - - -class BinaryOp(common.BinaryOp[Expr], Expr): - _dtype_propagation = common.binary_op_dtype_propagation(strict=True) - - -class TernaryOp(common.TernaryOp[Expr], Expr): - _dtype_propagation = common.ternary_op_dtype_propagation(strict=True) - - -class Cast(common.Cast[Expr], Expr): - pass - - -class NativeFuncCall(common.NativeFuncCall[Expr], Expr): - _dtype_propagation = common.native_func_call_dtype_propagation(strict=True) - - -class While(common.While[Stmt, Expr], Stmt): - pass - - -class ScalarDecl(Decl): - pass - - -class LocalScalarDecl(ScalarDecl): - pass - - -class SymbolDecl(ScalarDecl): - def to_dace_symbol(self): - return get_dace_symbol(self.name, self.dtype) - - -class Temporary(FieldDecl): - pass - - -class ComputationNode(LocNode): - # mapping connector names to tuple of field name and access info - read_memlets: List[Memlet] - write_memlets: List[Memlet] - - @datamodels.validator("read_memlets") - @datamodels.validator("write_memlets") - def _validator(self, attribute: datamodels.Attribute, value: List[Memlet]) -> None: - conns: 
Dict[eve.SymbolRef, Set[eve.SymbolRef]] = {} - for memlet in value: - conns.setdefault(memlet.field, set()) - if memlet.connector in conns[memlet.field]: - raise ValueError(f"Found multiple Memlets for connector '{memlet.connector}'") - conns[memlet.field].add(memlet.connector) - - @property - def input_connectors(self): - return set(ml.connector for ml in self.read_memlets) - - @property - def output_connectors(self): - return set(ml.connector for ml in self.write_memlets) - - -class IterationNode(eve.Node): - grid_subset: GridSubset - - -class Condition(eve.Node): - condition: Tasklet - true_states: list[ComputationState | Condition | WhileLoop] - - # Currently unused due to how `if` statements are parsed in `gtir_to_oir`, see - # https://github.com/GridTools/gt4py/issues/1898 - false_states: list[ComputationState | Condition | WhileLoop] = eve.field(default_factory=list) - - @datamodels.validator("condition") - def condition_has_boolean_expression( - self, attribute: datamodels.Attribute, tasklet: Tasklet - ) -> None: - assert isinstance(tasklet, Tasklet) - assert len(tasklet.stmts) == 1 - assert isinstance(tasklet.stmts[0], AssignStmt) - assert isinstance(tasklet.stmts[0].left, ScalarAccess) - if tasklet.stmts[0].left.original_name is None: - raise ValueError( - f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error." 
- ) - assert isinstance(tasklet.stmts[0].right, Expr) - if tasklet.stmts[0].right.dtype != common.DataType.BOOL: - raise ValueError("Condition must be a boolean expression.") - - -class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait): - label: str - stmts: List[Stmt] - grid_subset: GridSubset = GridSubset.single_gridpoint() - - @datamodels.validator("stmts") - def non_empty_list(self, attribute: datamodels.Attribute, v: list[Stmt]) -> None: - if len(v) < 1: - raise ValueError("Tasklet must contain at least one statement.") - - @datamodels.validator("stmts") - def read_after_write(self, attribute: datamodels.Attribute, statements: list[Stmt]) -> None: - def _remove_prefix(name: eve.SymbolRef) -> str: - return name.removeprefix(prefix.TASKLET_OUT).removeprefix(prefix.TASKLET_IN) - - class ReadAfterWriteChecker(eve.NodeVisitor): - def visit_IndexAccess(self, node: IndexAccess, writes: set[str]) -> None: - if node.is_target: - # Keep track of writes - writes.add(_remove_prefix(node.name)) - return - - # Check reads - if ( - node.name.startswith(prefix.TASKLET_OUT) - and _remove_prefix(node.name) not in writes - ): - raise ValueError(f"Reading undefined '{node.name}'. DaCe IR error.") - - if _remove_prefix(node.name) in writes and not node.name.startswith( - prefix.TASKLET_OUT - ): - raise ValueError( - f"Read after write of '{node.name}' not connected to out connector. DaCe IR error." - ) - - def visit_ScalarAccess(self, node: ScalarAccess, writes: set[str]) -> None: - # Handle stencil parameters differently because they are always available - if not node.name.startswith(prefix.TASKLET_IN) and not node.name.startswith( - prefix.TASKLET_OUT - ): - return - - # Keep track of writes - if node.is_target: - writes.add(_remove_prefix(node.name)) - return - - # Make sure we don't read uninitialized memory - if ( - node.name.startswith(prefix.TASKLET_OUT) - and _remove_prefix(node.name) not in writes - ): - raise ValueError(f"Reading undefined '{node.name}'. 
DaCe IR error.") - - if _remove_prefix(node.name) in writes and not node.name.startswith( - prefix.TASKLET_OUT - ): - raise ValueError( - f"Read after write of '{node.name}' not connected to out connector. DaCe IR error." - ) - - def visit_AssignStmt(self, node: AssignStmt, writes: Set[eve.SymbolRef]) -> None: - # Visiting order matters because `writes` must not contain the symbols from the left visit - self.visit(node.right, writes=writes) - self.visit(node.left, writes=writes) - - writes: set[str] = set() - checker = ReadAfterWriteChecker() - for statement in statements: - checker.visit(statement, writes=writes) - - -class DomainMap(ComputationNode, IterationNode): - index_ranges: List[Range] - schedule: MapSchedule - computations: List[Union[Tasklet, DomainMap, NestedSDFG]] - - -class ComputationState(IterationNode): - computations: List[Union[Tasklet, DomainMap]] - - -class DomainLoop(ComputationNode, IterationNode): - axis: Axis - index_range: Range - loop_states: list[ComputationState | Condition | DomainLoop | WhileLoop] - - -class WhileLoop(eve.Node): - condition: Tasklet - body: list[ComputationState | Condition | WhileLoop] - - @datamodels.validator("condition") - def condition_has_boolean_expression( - self, attribute: datamodels.Attribute, tasklet: Tasklet - ) -> None: - assert isinstance(tasklet, Tasklet) - assert len(tasklet.stmts) == 1 - assert isinstance(tasklet.stmts[0], AssignStmt) - assert isinstance(tasklet.stmts[0].left, ScalarAccess) - if tasklet.stmts[0].left.original_name is None: - raise ValueError( - f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error." 
- ) - assert isinstance(tasklet.stmts[0].right, Expr) - if tasklet.stmts[0].right.dtype != common.DataType.BOOL: - raise ValueError("Condition must be a boolean expression.") - - -class NestedSDFG(ComputationNode, eve.SymbolTableTrait): - label: eve.Coerced[eve.SymbolRef] - field_decls: List[FieldDecl] - symbol_decls: List[SymbolDecl] - states: list[ComputationState | Condition | DomainLoop | WhileLoop] - - -# There are circular type references with string placeholders. These statements let datamodels resolve those. -DomainMap.update_forward_refs() # type: ignore[attr-defined] -DomainLoop.update_forward_refs() # type: ignore[attr-defined] diff --git a/src/gt4py/cartesian/gtc/dace/expansion/__init__.py b/src/gt4py/cartesian/gtc/dace/expansion/__init__.py deleted file mode 100644 index abf4c3e24c..0000000000 --- a/src/gt4py/cartesian/gtc/dace/expansion/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py deleted file mode 100644 index 892909b210..0000000000 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ /dev/null @@ -1,1264 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -from dataclasses import dataclass -from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union, cast - -import dace -import dace.data - -from gt4py import eve -from gt4py.cartesian import utils as gt_utils -from gt4py.cartesian.gtc import common, oir -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.dace.expansion.utils import remove_horizontal_region -from gt4py.cartesian.gtc.dace.expansion_specification import Loop, Map, Sections, Stages -from gt4py.cartesian.gtc.dace.utils import ( - compute_tasklet_access_infos, - flatten_list, - get_tasklet_symbol, - make_dace_subset, - union_inout_memlets, - union_node_grid_subsets, - untile_memlets, -) -from gt4py.cartesian.gtc.definitions import Extent - - -if TYPE_CHECKING: - from gt4py.cartesian.gtc.dace.nodes import StencilComputation - - -class AccessType(Enum): - READ = 0 - WRITE = 1 - - -def _field_access_iterator( - code_block: oir.CodeBlock | oir.MaskStmt | oir.While, access_type: AccessType -): - if access_type == AccessType.WRITE: - return ( - code_block.walk_values() - .if_isinstance(oir.AssignStmt) - .getattr("left") - .if_isinstance(oir.FieldAccess) - ) - - def read_access_iterator(): - for node in code_block.walk_values(): - if isinstance(node, oir.AssignStmt): - yield from node.right.walk_values().if_isinstance(oir.FieldAccess) - elif isinstance(node, oir.While): - yield from node.cond.walk_values().if_isinstance(oir.FieldAccess) - elif isinstance(node, oir.MaskStmt): - yield from node.mask.walk_values().if_isinstance(oir.FieldAccess) - - return read_access_iterator() - - -def _mapped_access_iterator( - node: oir.CodeBlock | oir.MaskStmt | oir.While, access_type: AccessType -): - iterator = _field_access_iterator(node, access_type) - write_access = access_type == AccessType.WRITE - - yield from ( - eve.utils.xiter(iterator).map( - lambda 
acc: ( - acc.name, - acc.offset, - get_tasklet_symbol(acc.name, offset=acc.offset, is_target=write_access), - ) - ) - ).unique(key=lambda x: x[2]) - - -def _get_tasklet_inout_memlets( - node: oir.CodeBlock | oir.MaskStmt | oir.While, - access_type: AccessType, - *, - global_ctx: DaCeIRBuilder.GlobalContext, - horizontal_extent, - k_interval, - grid_subset: dcir.GridSubset, - dcir_statements: list[dcir.Stmt], -) -> list[dcir.Memlet]: - access_infos = compute_tasklet_access_infos( - node, - collect_read=access_type == AccessType.READ, - collect_write=access_type == AccessType.WRITE, - declarations=global_ctx.library_node.declarations, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - grid_subset=grid_subset, - ) - - names = [ - access.name - for statement in dcir_statements - for access in statement.walk_values().if_isinstance(dcir.ScalarAccess, dcir.IndexAccess) - ] - - memlets: list[dcir.Memlet] = [] - for name, offset, tasklet_symbol in _mapped_access_iterator(node, access_type): - # Avoid adding extra inputs/outputs to the tasklet - if name not in access_infos: - continue - - # Find `tasklet_symbol` in dcir_statements because we can't know (from the oir statements) - # where the tasklet boundaries will be. Consider - # - # with computation(PARALLEL), interval(...): - # statement1 - # if condition: - # statement2 - # statement3 - # - # statements 1 and 3 will end up in the same CodeBlock but aren't in the same tasklet. 
- if tasklet_symbol not in names: - continue - - access_info = access_infos[name] - if not access_info.variable_offset_axes: - offset_dict = offset.to_dict() - for axis in access_info.axes(): - access_info = access_info.restricted_to_index( - axis, extent=(offset_dict[axis.lower()], offset_dict[axis.lower()]) - ) - - memlets.append( - dcir.Memlet( - field=name, - connector=tasklet_symbol, - access_info=access_info, - is_read=access_type == AccessType.READ, - is_write=access_type == AccessType.WRITE, - ) - ) - return memlets - - -def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval: Any) -> bool: - def all_statements_in_region(scope_nodes: List[eve.Node]) -> bool: - return all( - isinstance(stmt, dcir.HorizontalRestriction) - for tasklet in eve.walk_values(scope_nodes).if_isinstance(dcir.Tasklet) - for stmt in tasklet.stmts - ) - - def all_regions_same(scope_nodes: List[eve.Node]) -> bool: - return ( - len( - set( - ( - ( - None - if mask.intervals[axis.to_idx()].start is None - else mask.intervals[axis.to_idx()].start.level - ), - ( - None - if mask.intervals[axis.to_idx()].start is None - else mask.intervals[axis.to_idx()].start.offset - ), - ( - None - if mask.intervals[axis.to_idx()].end is None - else mask.intervals[axis.to_idx()].end.level - ), - ( - None - if mask.intervals[axis.to_idx()].end is None - else mask.intervals[axis.to_idx()].end.offset - ), - ) - for mask in eve.walk_values(scope_nodes).if_isinstance(common.HorizontalMask) - ) - ) - == 1 - ) - - return ( - axis in dcir.Axis.dims_horizontal() - and isinstance(interval, dcir.DomainInterval) - and all_statements_in_region(scope_nodes) - and all_regions_same(scope_nodes) - ) - - -class DaCeIRBuilder(eve.NodeTranslator): - @dataclass - class GlobalContext: - library_node: StencilComputation - arrays: Dict[str, dace.data.Data] - - def get_dcir_decls( - self, - access_infos: Dict[eve.SymbolRef, dcir.FieldAccessInfo], - symbol_collector: DaCeIRBuilder.SymbolCollector, - ) -> 
List[dcir.FieldDecl]: - return [ - self._get_dcir_decl(field, access_info, symbol_collector=symbol_collector) - for field, access_info in access_infos.items() - ] - - def _get_dcir_decl( - self, - field: eve.SymbolRef, - access_info: dcir.FieldAccessInfo, - symbol_collector: DaCeIRBuilder.SymbolCollector, - ) -> dcir.FieldDecl: - oir_decl: oir.Decl = self.library_node.declarations[field] - assert isinstance(oir_decl, oir.FieldDecl) - dace_array = self.arrays[field] - for stride in dace_array.strides: - for symbol in dace.symbolic.symlist(stride).values(): - symbol_collector.add_symbol(str(symbol)) - for symbol in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(symbol) - - return dcir.FieldDecl( - name=field, - dtype=oir_decl.dtype, - strides=tuple(str(s) for s in dace_array.strides), - data_dims=oir_decl.data_dims, - access_info=access_info, - storage=dcir.StorageType.from_dace_storage(dace.StorageType.Default), - ) - - @dataclass - class IterationContext: - grid_subset: dcir.GridSubset - parent: Optional[DaCeIRBuilder.IterationContext] = None - - def push_axes_extents(self, axes_extents) -> DaCeIRBuilder.IterationContext: - res = self.grid_subset - for axis, extent in axes_extents.items(): - axis_interval = res.intervals[axis] - if isinstance(axis_interval, dcir.DomainInterval): - res__interval = dcir.DomainInterval( - start=dcir.AxisBound( - level=common.LevelMarker.START, offset=extent[0], axis=axis - ), - end=dcir.AxisBound( - level=common.LevelMarker.END, offset=extent[1], axis=axis - ), - ) - res = res.set_interval(axis, res__interval) - elif isinstance(axis_interval, dcir.TileInterval): - tile_interval = dcir.TileInterval( - axis=axis, - start_offset=extent[0], - end_offset=extent[1], - tile_size=axis_interval.tile_size, - domain_limit=axis_interval.domain_limit, - ) - res = res.set_interval(axis, tile_interval) - # if is IndexWithExtent, do nothing. 
- return DaCeIRBuilder.IterationContext(grid_subset=res, parent=self) - - def push_interval( - self, axis: dcir.Axis, interval: Union[dcir.DomainInterval, oir.Interval] - ) -> DaCeIRBuilder.IterationContext: - return DaCeIRBuilder.IterationContext( - grid_subset=self.grid_subset.set_interval(axis, interval), parent=self - ) - - def push_expansion_item(self, item: Union[Map, Loop]) -> DaCeIRBuilder.IterationContext: - if not isinstance(item, (Map, Loop)): - raise ValueError - - iterations = item.iterations if isinstance(item, Map) else [item] - grid_subset = self.grid_subset - for it in iterations: - axis = it.axis - if it.kind == "tiling": - assert it.stride is not None - grid_subset = grid_subset.tile(tile_sizes={axis: it.stride}) - else: - grid_subset = grid_subset.restricted_to_index(axis) - return DaCeIRBuilder.IterationContext(grid_subset=grid_subset, parent=self) - - def push_expansion_items( - self, items: Iterable[Union[Map, Loop]] - ) -> DaCeIRBuilder.IterationContext: - res = self - for item in items: - res = res.push_expansion_item(item) - return res - - def pop(self) -> DaCeIRBuilder.IterationContext: - assert self.parent is not None - return self.parent - - @dataclass - class SymbolCollector: - symbol_decls: Dict[str, dcir.SymbolDecl] = dataclasses.field(default_factory=dict) - - def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32) -> None: - if name not in self.symbol_decls: - self.symbol_decls[name] = dcir.SymbolDecl(name=name, dtype=dtype) - else: - assert self.symbol_decls[name].dtype == dtype - - def remove_symbol(self, name: eve.SymbolRef) -> None: - if name in self.symbol_decls: - del self.symbol_decls[name] - - def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> dcir.Literal: - return dcir.Literal(value=node.value, dtype=node.dtype) - - def visit_UnaryOp(self, node: oir.UnaryOp, **kwargs: Any) -> dcir.UnaryOp: - return dcir.UnaryOp(op=node.op, expr=self.visit(node.expr, **kwargs), dtype=node.dtype) - - def 
visit_BinaryOp(self, node: oir.BinaryOp, **kwargs: Any) -> dcir.BinaryOp: - return dcir.BinaryOp( - op=node.op, - left=self.visit(node.left, **kwargs), - right=self.visit(node.right, **kwargs), - dtype=node.dtype, - ) - - def visit_HorizontalRestriction( - self, - node: oir.HorizontalRestriction, - *, - symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs: Any, - ) -> dcir.HorizontalRestriction: - for axis, interval in zip(dcir.Axis.dims_horizontal(), node.mask.intervals): - for bound in (interval.start, interval.end): - if bound is not None: - symbol_collector.add_symbol(axis.iteration_symbol()) - if bound.level == common.LevelMarker.END: - symbol_collector.add_symbol(axis.domain_symbol()) - - return dcir.HorizontalRestriction( - mask=node.mask, - body=self.visit( - node.body, - symbol_collector=symbol_collector, - inside_horizontal_region=True, - **kwargs, - ), - ) - - def visit_VariableKOffset( - self, node: oir.VariableKOffset, **kwargs: Any - ) -> dcir.VariableKOffset: - return dcir.VariableKOffset(k=self.visit(node.k, **kwargs)) - - def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> dcir.LocalScalarDecl: - return dcir.LocalScalarDecl(name=node.name, dtype=node.dtype) - - def visit_FieldAccess( - self, - node: oir.FieldAccess, - *, - is_target: bool, - targets: list[oir.FieldAccess | oir.ScalarAccess], - var_offset_fields: set[eve.SymbolRef], - K_write_with_offset: set[eve.SymbolRef], - **kwargs: Any, - ) -> dcir.IndexAccess | dcir.ScalarAccess: - """Generate the relevant accessor to match the memlet that was previously setup. - - Args: - is_target (bool): true if we write to this FieldAccess - """ - - # Distinguish between writing to a variable and reading a previously written variable. - # In the latter case (read after write), we need to read from the "gtOUT__" symbol. 
- is_write = is_target - is_target = is_target or ( - # read after write (within a code block) - any( - isinstance(t, oir.FieldAccess) and t.name == node.name and t.offset == node.offset - for t in targets - ) - ) - name = get_tasklet_symbol(node.name, offset=node.offset, is_target=is_target) - - access_node: dcir.IndexAccess | dcir.ScalarAccess - if node.name in var_offset_fields.union(K_write_with_offset): - access_node = dcir.IndexAccess( - name=name, - is_target=is_target, - offset=self.visit( - node.offset, - is_target=False, - targets=targets, - var_offset_fields=var_offset_fields, - K_write_with_offset=K_write_with_offset, - **kwargs, - ), - data_index=self.visit( - node.data_index, - is_target=False, - targets=targets, - var_offset_fields=var_offset_fields, - K_write_with_offset=K_write_with_offset, - **kwargs, - ), - dtype=node.dtype, - ) - elif node.data_index: - access_node = dcir.IndexAccess( - name=name, - offset=None, - is_target=is_target, - data_index=self.visit( - node.data_index, - is_target=False, - targets=targets, - var_offset_fields=var_offset_fields, - K_write_with_offset=K_write_with_offset, - **kwargs, - ), - dtype=node.dtype, - ) - else: - access_node = dcir.ScalarAccess(name=name, dtype=node.dtype, is_target=is_write) - - if is_write and not any( - isinstance(t, oir.FieldAccess) and t.name == node.name and t.offset == node.offset - for t in targets - ): - targets.append(node) - return access_node - - def visit_ScalarAccess( - self, - node: oir.ScalarAccess, - *, - is_target: bool, - targets: list[oir.FieldAccess | oir.ScalarAccess], - global_ctx: DaCeIRBuilder.GlobalContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - **_: Any, - ) -> dcir.ScalarAccess: - if node.name in global_ctx.library_node.declarations: - # Handle stencil parameters differently because they are always available - symbol_collector.add_symbol(node.name, dtype=node.dtype) - return dcir.ScalarAccess(name=node.name, dtype=node.dtype, is_target=is_target) - - # 
Distinguish between writing to a variable and reading a previously written variable. - # In the latter case (read after write), we need to read from the "gtOUT__" symbol. - is_write = is_target - is_target = is_target or ( - # read after write (within a code block) - any(isinstance(t, oir.ScalarAccess) and t.name == node.name for t in targets) - ) - - if is_write and not any( - isinstance(t, oir.ScalarAccess) and t.name == node.name for t in targets - ): - targets.append(node) - - # Rename local scalars inside tasklets such that we can pass them from one state - # to another (same as we do for index access). - tasklet_name = get_tasklet_symbol(node.name, is_target=is_target) - return dcir.ScalarAccess( - name=tasklet_name, original_name=node.name, dtype=node.dtype, is_target=is_write - ) - - def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs: Any) -> dcir.AssignStmt: - # Visiting order matters because targets must not contain the target symbols from the left visit - right = self.visit(node.right, is_target=False, **kwargs) - left = self.visit(node.left, is_target=True, **kwargs) - return dcir.AssignStmt(left=left, right=right, loc=node.loc) - - def _condition_tasklet( - self, - node: oir.MaskStmt | oir.While, - *, - global_ctx: DaCeIRBuilder.GlobalContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - horizontal_extent, - k_interval, - iteration_ctx: DaCeIRBuilder.IterationContext, - targets: list[oir.FieldAccess | oir.ScalarAccess], - **kwargs: Any, - ) -> dcir.Tasklet: - condition_expression = node.mask if isinstance(node, oir.MaskStmt) else node.cond - prefix = "if" if isinstance(node, oir.MaskStmt) else "while" - tmp_name = f"{prefix}_expression_{id(node)}" - - # Reset the set of targets (used for detecting read after write inside a tasklet) - targets.clear() - - statement = dcir.AssignStmt( - right=self.visit( - condition_expression, - is_target=False, - targets=targets, - global_ctx=global_ctx, - symbol_collector=symbol_collector, - 
horizontal_extent=horizontal_extent, - k_interval=k_interval, - iteration_ctx=iteration_ctx, - **kwargs, - ), - left=dcir.ScalarAccess( - name=get_tasklet_symbol(tmp_name, is_target=True), - original_name=tmp_name, - dtype=common.DataType.BOOL, - loc=node.loc, - is_target=True, - ), - loc=node.loc, - ) - - read_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( - node, - AccessType.READ, - global_ctx=global_ctx, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - grid_subset=iteration_ctx.grid_subset, - dcir_statements=[statement], - ) - - tasklet = dcir.Tasklet( - label=f"eval_{prefix}_{id(node)}", - stmts=[statement], - read_memlets=read_memlets, - write_memlets=[], - ) - # See notes inside the function - self._fix_memlet_array_access( - tasklet=tasklet, - memlets=read_memlets, - global_context=global_ctx, - symbol_collector=symbol_collector, - targets=targets, - **kwargs, - ) - - return tasklet - - def visit_MaskStmt( - self, - node: oir.MaskStmt, - global_ctx: DaCeIRBuilder.GlobalContext, - iteration_ctx: DaCeIRBuilder.IterationContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - horizontal_extent, - k_interval, - targets: list[oir.FieldAccess | oir.ScalarAccess], - inside_horizontal_region: bool = False, - **kwargs: Any, - ) -> dcir.MaskStmt | dcir.Condition: - if inside_horizontal_region: - # inside horizontal regions, we use old-style mask statements that - # might translate to if statements inside the tasklet - return dcir.MaskStmt( - mask=self.visit( - node.mask, - is_target=False, - global_ctx=global_ctx, - iteration_ctx=iteration_ctx, - symbol_collector=symbol_collector, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - inside_horizontal_region=inside_horizontal_region, - targets=targets, - **kwargs, - ), - body=self.visit( - node.body, - global_ctx=global_ctx, - iteration_ctx=iteration_ctx, - symbol_collector=symbol_collector, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - 
inside_horizontal_region=inside_horizontal_region, - targets=targets, - **kwargs, - ), - ) - - tasklet = self._condition_tasklet( - node, - global_ctx=global_ctx, - symbol_collector=symbol_collector, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - iteration_ctx=iteration_ctx, - targets=targets, - **kwargs, - ) - code_block = self.visit( - oir.CodeBlock(body=node.body, loc=node.loc, label=f"condition_{id(node)}"), - global_ctx=global_ctx, - symbol_collector=symbol_collector, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - iteration_ctx=iteration_ctx, - targets=targets, - **kwargs, - ) - targets.clear() - return dcir.Condition(condition=tasklet, true_states=gt_utils.listify(code_block)) - - def visit_While( - self, - node: oir.While, - global_ctx: DaCeIRBuilder.GlobalContext, - iteration_ctx: DaCeIRBuilder.IterationContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - horizontal_extent, - k_interval, - targets: list[oir.FieldAccess | oir.ScalarAccess], - inside_horizontal_region: bool = False, - **kwargs: Any, - ) -> dcir.While | dcir.WhileLoop: - if inside_horizontal_region: - # inside horizontal regions, we use old-style while statements that - # might translate to while statements inside the tasklet - return dcir.While( - cond=self.visit( - node.cond, - is_target=False, - global_ctx=global_ctx, - iteration_ctx=iteration_ctx, - symbol_collector=symbol_collector, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - inside_horizontal_region=inside_horizontal_region, - targets=targets, - **kwargs, - ), - body=self.visit( - node.body, - global_ctx=global_ctx, - iteration_ctx=iteration_ctx, - symbol_collector=symbol_collector, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - inside_horizontal_region=inside_horizontal_region, - targets=targets, - **kwargs, - ), - ) - - tasklet = self._condition_tasklet( - node, - global_ctx=global_ctx, - symbol_collector=symbol_collector, - 
iteration_ctx=iteration_ctx, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - targets=targets, - **kwargs, - ) - code_block = self.visit( - oir.CodeBlock(body=node.body, loc=node.loc, label=f"while_{id(node)}"), - global_ctx=global_ctx, - symbol_collector=symbol_collector, - iteration_ctx=iteration_ctx, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - targets=targets, - **kwargs, - ) - targets.clear() - return dcir.WhileLoop(condition=tasklet, body=code_block) - - def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> dcir.Cast: - return dcir.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs)) - - def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> dcir.NativeFuncCall: - return dcir.NativeFuncCall( - func=node.func, args=self.visit(node.args, **kwargs), dtype=node.dtype - ) - - def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> dcir.TernaryOp: - return dcir.TernaryOp( - cond=self.visit(node.cond, **kwargs), - true_expr=self.visit(node.true_expr, **kwargs), - false_expr=self.visit(node.false_expr, **kwargs), - dtype=node.dtype, - ) - - def _fix_memlet_array_access( - self, - *, - tasklet: dcir.Tasklet, - memlets: list[dcir.Memlet], - global_context: DaCeIRBuilder.GlobalContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs: Any, - ) -> None: - for memlet in memlets: - """ - This loop handles the special case of a tasklet performing array access. - The memlet should pass the full array shape (no tiling) and - the tasklet expression for array access should use all explicit indexes. 
- """ - array_ndims = len(global_context.arrays[memlet.field].shape) - field_decl = global_context.library_node.field_decls[memlet.field] - # calculate array subset on original memlet - memlet_subset = make_dace_subset( - global_context.library_node.access_infos[memlet.field], - memlet.access_info, - field_decl.data_dims, - ) - # select index values for single-point grid access - memlet_data_index = [ - dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32) - for dim_range, dim_size in zip(memlet_subset, memlet_subset.size()) - if dim_size == 1 - ] - if len(memlet_data_index) < array_ndims: - reshape_memlet = False - for access_node in tasklet.walk_values().if_isinstance(dcir.IndexAccess): - if access_node.data_index and access_node.name == memlet.connector: - # Order matters! - # Resolve first the cartesian dimensions packed in memlet_data_index - access_node.explicit_indices = [] - for data_index in memlet_data_index: - access_node.explicit_indices.append( - self.visit( - data_index, - symbol_collector=symbol_collector, - global_ctx=global_context, - **kwargs, - ) - ) - # Separate between case where K is offset or absolute and - # where it's a regular offset (should be dealt with the above memlet_data_index) - if access_node.offset: - access_node.explicit_indices.append(access_node.offset) - # Add any remaining data dimensions indexing - for data_index in access_node.data_index: - access_node.explicit_indices.append( - self.visit( - data_index, - symbol_collector=symbol_collector, - global_ctx=global_context, - is_target=False, - **kwargs, - ) - ) - assert len(access_node.explicit_indices) == array_ndims - reshape_memlet = True - if reshape_memlet: - # ensure that memlet symbols used for array indexing are defined in context - for sym in memlet.access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) - # set full shape on memlet - memlet.access_info = global_context.library_node.access_infos[memlet.field] - - def visit_CodeBlock( - 
self, - node: oir.CodeBlock, - *, - global_ctx: DaCeIRBuilder.GlobalContext, - iteration_ctx: DaCeIRBuilder.IterationContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - horizontal_extent, - k_interval, - targets: list[oir.FieldAccess | oir.ScalarAccess], - **kwargs: Any, - ): - # Reset the set of targets (used for detecting read after write inside a tasklet) - targets.clear() - statements = [ - self.visit( - statement, - targets=targets, - global_ctx=global_ctx, - symbol_collector=symbol_collector, - iteration_ctx=iteration_ctx, - k_interval=k_interval, - horizontal_extent=horizontal_extent, - **kwargs, - ) - for statement in node.body - ] - - # Gather all statements that aren't control flow (e.g. everything except Condition and WhileLoop), - # put them in a tasklet, and call "to_state" on it. - # Then, return a new list with types that are either ComputationState, Condition, or WhileLoop. - dace_nodes: list[dcir.ComputationState | dcir.Condition | dcir.WhileLoop] = [] - current_block: list[dcir.Stmt] = [] - for index, statement in enumerate(statements): - is_control_flow = isinstance(statement, (dcir.Condition, dcir.WhileLoop)) - if not is_control_flow: - current_block.append(statement) - - last_statement = index == len(statements) - 1 - if (is_control_flow or last_statement) and len(current_block) > 0: - read_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( - node, - AccessType.READ, - global_ctx=global_ctx, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - grid_subset=iteration_ctx.grid_subset, - dcir_statements=current_block, - ) - write_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( - node, - AccessType.WRITE, - global_ctx=global_ctx, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - grid_subset=iteration_ctx.grid_subset, - dcir_statements=current_block, - ) - tasklet = dcir.Tasklet( - label=node.label, - stmts=current_block, - read_memlets=read_memlets, - write_memlets=write_memlets, - ) - # See 
notes inside the function - self._fix_memlet_array_access( - tasklet=tasklet, - memlets=[*read_memlets, *write_memlets], - global_context=global_ctx, - symbol_collector=symbol_collector, - targets=targets, - **kwargs, - ) - - dace_nodes.append(*self.to_state(tasklet, grid_subset=iteration_ctx.grid_subset)) - - # reset block scope - current_block = [] - - # append control flow statement after new tasklet (if applicable) - if is_control_flow: - dace_nodes.append(statement) - - return dace_nodes - - def visit_HorizontalExecution( - self, - node: oir.HorizontalExecution, - *, - global_ctx: DaCeIRBuilder.GlobalContext, - iteration_ctx: DaCeIRBuilder.IterationContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - k_interval, - **kwargs: Any, - ): - extent = global_ctx.library_node.get_extents(node) - - stages_idx = next( - idx - for idx, item in enumerate(global_ctx.library_node.expansion_specification) - if isinstance(item, Stages) - ) - expansion_items = global_ctx.library_node.expansion_specification[stages_idx + 1 :] - - iteration_ctx = iteration_ctx.push_axes_extents( - {k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)} - ) - iteration_ctx = iteration_ctx.push_expansion_items(expansion_items) - assert iteration_ctx.grid_subset == dcir.GridSubset.single_gridpoint() - - code_block = oir.CodeBlock(body=node.body, loc=node.loc, label=f"he_{id(node)}") - targets: list[oir.FieldAccess | oir.ScalarAccess] = [] - dcir_nodes = self.visit( - code_block, - global_ctx=global_ctx, - iteration_ctx=iteration_ctx, - symbol_collector=symbol_collector, - horizontal_extent=global_ctx.library_node.get_extents(node), - k_interval=k_interval, - targets=targets, - **kwargs, - ) - - for item in reversed(expansion_items): - iteration_ctx = iteration_ctx.pop() - dcir_nodes = self._process_iteration_item( - dcir_nodes, - item, - global_ctx=global_ctx, - iteration_ctx=iteration_ctx, - symbol_collector=symbol_collector, - **kwargs, - ) - # pop stages context (pushed with 
push_grid_subset) - iteration_ctx.pop() - - return dcir_nodes - - def visit_VerticalLoopSection( - self, - node: oir.VerticalLoopSection, - *, - iteration_ctx: DaCeIRBuilder.IterationContext, - global_ctx: DaCeIRBuilder.GlobalContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs, - ): - sections_idx, stages_idx = [ - idx - for idx, item in enumerate(global_ctx.library_node.expansion_specification) - if isinstance(item, (Sections, Stages)) - ] - expansion_items = global_ctx.library_node.expansion_specification[ - sections_idx + 1 : stages_idx - ] - - iteration_ctx = iteration_ctx.push_interval( - dcir.Axis.K, node.interval - ).push_expansion_items(expansion_items) - - dcir_nodes = self.generic_visit( - node.horizontal_executions, - iteration_ctx=iteration_ctx, - global_ctx=global_ctx, - symbol_collector=symbol_collector, - k_interval=node.interval, - **kwargs, - ) - - # if multiple horizontal executions, enforce their order by means of a state machine - if len(dcir_nodes) > 1: - dcir_nodes = [ - self.to_state([node], grid_subset=node.grid_subset) - for node in flatten_list(dcir_nodes) - ] - - for item in reversed(expansion_items): - iteration_ctx = iteration_ctx.pop() - dcir_nodes = self._process_iteration_item( - scope=dcir_nodes, - item=item, - iteration_ctx=iteration_ctx, - global_ctx=global_ctx, - symbol_collector=symbol_collector, - ) - # pop off interval - iteration_ctx.pop() - return dcir_nodes - - def to_dataflow( - self, - nodes, - *, - global_ctx: DaCeIRBuilder.GlobalContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - ): - nodes = flatten_list(nodes) - if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): - return nodes - if not all( - isinstance(n, (dcir.ComputationState, dcir.Condition, dcir.DomainLoop, dcir.WhileLoop)) - for n in nodes - ): - raise ValueError("Can't mix dataflow and state nodes on same level.") - - read_memlets, write_memlets, field_memlets = union_inout_memlets(nodes) - - 
field_decls = global_ctx.get_dcir_decls( - {memlet.field: memlet.access_info for memlet in field_memlets}, - symbol_collector=symbol_collector, - ) - read_fields = {memlet.field for memlet in read_memlets} - write_fields = {memlet.field for memlet in write_memlets} - read_memlets = [ - memlet.remove_write() for memlet in field_memlets if memlet.field in read_fields - ] - write_memlets = [ - memlet.remove_read() for memlet in field_memlets if memlet.field in write_fields - ] - - return [ - dcir.NestedSDFG( - label=global_ctx.library_node.label, - field_decls=field_decls, - # NestedSDFG must have same shape on input and output, matching corresponding - # nsdfg.sdfg's array shape - read_memlets=read_memlets, - write_memlets=write_memlets, - states=nodes, - symbol_decls=list(symbol_collector.symbol_decls.values()), - ) - ] - - def to_state(self, nodes, *, grid_subset: dcir.GridSubset): - nodes = flatten_list(nodes) - if all( - isinstance(n, (dcir.ComputationState, dcir.Condition, dcir.DomainLoop, dcir.WhileLoop)) - for n in nodes - ): - return nodes - if all(isinstance(n, (dcir.DomainMap, dcir.NestedSDFG, dcir.Tasklet)) for n in nodes): - return [dcir.ComputationState(computations=nodes, grid_subset=grid_subset)] - - raise ValueError("Can't mix dataflow and state nodes on same level.") - - def _process_map_item( - self, - scope_nodes, - item: Map, - *, - global_ctx: DaCeIRBuilder.GlobalContext, - iteration_ctx: DaCeIRBuilder.IterationContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs: Any, - ) -> List[dcir.DomainMap]: - grid_subset = iteration_ctx.grid_subset - read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) - scope_nodes = self.to_dataflow( - scope_nodes, global_ctx=global_ctx, symbol_collector=symbol_collector - ) - - ranges = [] - for iteration in item.iterations: - axis = iteration.axis - interval = iteration_ctx.grid_subset.intervals[axis] - grid_subset = grid_subset.set_interval(axis, interval) - if iteration.kind == 
"tiling": - read_memlets = untile_memlets(read_memlets, axes=[axis]) - write_memlets = untile_memlets(write_memlets, axes=[axis]) - if not axis == dcir.Axis.K: - interval = dcir.DomainInterval( - start=dcir.AxisBound.from_common(axis, oir.AxisBound.start()), - end=dcir.AxisBound.from_common(axis, oir.AxisBound.end()), - ) - symbol_collector.remove_symbol(axis.tile_symbol()) - ranges.append( - dcir.Range(var=axis.tile_symbol(), interval=interval, stride=iteration.stride) - ) - else: - if _all_stmts_same_region(scope_nodes, axis, interval): - masks = cast( - List[common.HorizontalMask], - eve.walk_values(scope_nodes).if_isinstance(common.HorizontalMask).to_list(), - ) - horizontal_mask_interval = next( - iter((mask.intervals[axis.to_idx()] for mask in masks)) - ) - interval = dcir.DomainInterval.intersection( - axis, horizontal_mask_interval, interval - ) - scope_nodes = remove_horizontal_region(scope_nodes, axis) - assert iteration.kind == "contiguous" - res_read_memlets = [] - res_write_memlets = [] - for memlet in read_memlets: - access_info = memlet.access_info.apply_iteration( - dcir.GridSubset.from_interval(interval, axis) - ) - for sym in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) - res_read_memlets.append( - dcir.Memlet( - field=memlet.field, - connector=memlet.connector, - access_info=access_info, - is_read=True, - is_write=False, - ) - ) - for memlet in write_memlets: - access_info = memlet.access_info.apply_iteration( - dcir.GridSubset.from_interval(interval, axis) - ) - for sym in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) - res_write_memlets.append( - dcir.Memlet( - field=memlet.field, - connector=memlet.connector, - access_info=access_info, - is_read=False, - is_write=True, - ) - ) - read_memlets = res_read_memlets - write_memlets = res_write_memlets - - assert not isinstance(interval, dcir.IndexWithExtent) - index_range = dcir.Range.from_axis_and_interval(axis, interval) - 
symbol_collector.remove_symbol(index_range.var) - ranges.append(index_range) - - return [ - dcir.DomainMap( - computations=scope_nodes, - index_ranges=ranges, - schedule=dcir.MapSchedule.from_dace_schedule(item.schedule), - read_memlets=read_memlets, - write_memlets=write_memlets, - grid_subset=grid_subset, - ) - ] - - def _process_loop_item( - self, - scope_nodes, - item: Loop, - *, - iteration_ctx: DaCeIRBuilder.IterationContext, - symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs: Any, - ) -> List[dcir.DomainLoop]: - grid_subset = union_node_grid_subsets(list(scope_nodes)) - read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) - scope_nodes = self.to_state(scope_nodes, grid_subset=grid_subset) - - axis = item.axis - interval = iteration_ctx.grid_subset.intervals[axis] - grid_subset = grid_subset.set_interval(axis, interval) - if item.kind == "tiling": - raise NotImplementedError("Tiling as a state machine not implemented.") - - assert item.kind == "contiguous" - res_read_memlets = [] - res_write_memlets = [] - for memlet in read_memlets: - access_info = memlet.access_info.apply_iteration( - dcir.GridSubset.from_interval(interval, axis) - ) - for sym in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) - res_read_memlets.append( - dcir.Memlet( - field=memlet.field, - connector=memlet.connector, - access_info=access_info, - is_read=True, - is_write=False, - ) - ) - for memlet in write_memlets: - access_info = memlet.access_info.apply_iteration( - dcir.GridSubset.from_interval(interval, axis) - ) - for sym in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) - res_write_memlets.append( - dcir.Memlet( - field=memlet.field, - connector=memlet.connector, - access_info=access_info, - is_read=False, - is_write=True, - ) - ) - read_memlets = res_read_memlets - write_memlets = res_write_memlets - - assert not isinstance(interval, dcir.IndexWithExtent) - index_range = 
dcir.Range.from_axis_and_interval(axis, interval, stride=item.stride) - for sym in index_range.free_symbols: - symbol_collector.add_symbol(sym, common.DataType.INT32) - symbol_collector.remove_symbol(index_range.var) - return [ - dcir.DomainLoop( - axis=axis, - loop_states=scope_nodes, - index_range=index_range, - read_memlets=read_memlets, - write_memlets=write_memlets, - grid_subset=grid_subset, - ) - ] - - def _process_iteration_item(self, scope, item, **kwargs): - if isinstance(item, Map): - return self._process_map_item(scope, item, **kwargs) - if isinstance(item, Loop): - return self._process_loop_item(scope, item, **kwargs) - - raise ValueError("Invalid expansion specification set.") - - def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs: Any - ) -> dcir.NestedSDFG: - overall_extent = Extent.zeros(2) - for he in node.walk_values().if_isinstance(oir.HorizontalExecution): - overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he)) - - iteration_ctx = DaCeIRBuilder.IterationContext( - grid_subset=dcir.GridSubset.from_gt4py_extent(overall_extent).set_interval( - axis=dcir.Axis.K, interval=node.sections[0].interval - ) - ) - - # Variable offsets - var_offset_fields = { - acc.name - for acc in node.walk_values().if_isinstance(oir.FieldAccess) - if isinstance(acc.offset, oir.VariableKOffset) - } - - # We book keep - all write offset to K - K_write_with_offset = set() - for assign_node in node.walk_values().if_isinstance(oir.AssignStmt): - if isinstance(assign_node.left, oir.FieldAccess): - if ( - isinstance(assign_node.left.offset, common.CartesianOffset) - and assign_node.left.offset.k != 0 - ): - K_write_with_offset.add(assign_node.left.name) - - sections_idx = next( - idx - for idx, item in enumerate(global_ctx.library_node.expansion_specification) - if isinstance(item, Sections) - ) - expansion_items = global_ctx.library_node.expansion_specification[:sections_idx] - iteration_ctx = 
iteration_ctx.push_expansion_items(expansion_items) - - symbol_collector = DaCeIRBuilder.SymbolCollector() - sections = flatten_list( - self.generic_visit( - node.sections, - global_ctx=global_ctx, - iteration_ctx=iteration_ctx, - symbol_collector=symbol_collector, - var_offset_fields=var_offset_fields, - K_write_with_offset=K_write_with_offset, - **kwargs, - ) - ) - if node.loop_order != common.LoopOrder.PARALLEL: - sections = [self.to_state(s, grid_subset=iteration_ctx.grid_subset) for s in sections] - computations = sections - for item in reversed(expansion_items): - iteration_ctx = iteration_ctx.pop() - computations = self._process_iteration_item( - scope=computations, - item=item, - iteration_ctx=iteration_ctx, - global_ctx=global_ctx, - symbol_collector=symbol_collector, - ) - - read_memlets, write_memlets, field_memlets = union_inout_memlets(computations) - - field_decls = global_ctx.get_dcir_decls( - global_ctx.library_node.access_infos, symbol_collector=symbol_collector - ) - - read_fields = set(memlet.field for memlet in read_memlets) - write_fields = set(memlet.field for memlet in write_memlets) - - return dcir.NestedSDFG( - label=global_ctx.library_node.label, - states=self.to_state(computations, grid_subset=iteration_ctx.grid_subset), - field_decls=field_decls, - read_memlets=[memlet for memlet in field_memlets if memlet.field in read_fields], - write_memlets=[memlet for memlet in field_memlets if memlet.field in write_fields], - symbol_decls=list(symbol_collector.symbol_decls.values()), - ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py deleted file mode 100644 index e1acee2111..0000000000 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ /dev/null @@ -1,159 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import copy -from typing import TYPE_CHECKING, ClassVar, Dict, List - -import dace -import dace.data -import dace.library -import dace.subsets -import sympy - -from gt4py.cartesian.gtc.dace import daceir as dcir, prefix -from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder -from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder -from gt4py.cartesian.gtc.dace.expansion.utils import split_horizontal_executions_regions - - -if TYPE_CHECKING: - from gt4py.cartesian.gtc.dace.nodes import StencilComputation - - -class StencilComputationExpansion(dace.library.ExpandTransformation): - environments: ClassVar[List] = [] - - @staticmethod - def _solve_for_domain(field_decls: Dict[str, dcir.FieldDecl], outer_subsets): - equations = [] - symbols = set() - - # Collect equations and symbols from arguments and shapes - for field, decl in field_decls.items(): - inner_shape = [dace.symbolic.pystr_to_symbolic(s) for s in decl.shape] - outer_shape = [ - dace.symbolic.pystr_to_symbolic(s) for s in outer_subsets[field].bounding_box_size() - ] - - for inner_dim, outer_dim in zip(inner_shape, outer_shape): - repldict = {} - for sym in dace.symbolic.symlist(inner_dim).values(): - newsym = dace.symbolic.symbol("__SOLVE_" + str(sym)) - symbols.add(newsym) - repldict[sym] = newsym - - # Replace symbols with __SOLVE_ symbols so as to allow - # the same symbol in the called SDFG - if repldict: - inner_dim = inner_dim.subs(repldict) - - equations.append(inner_dim - outer_dim) - if len(symbols) == 0: - return {} - - # Solve for all at once - results = sympy.solve(equations, *symbols, dict=True) - result = results[0] - result = {str(k)[len("__SOLVE_") :]: v for k, v in result.items()} - return result - - @staticmethod - def _fix_context( - nsdfg, node: StencilComputation, parent_state: dace.SDFGState, daceir: dcir.NestedSDFG - ): - """Apply changes to 
StencilComputation and the SDFG it is embedded in to satisfy post-expansion constraints. - - * change connector names to match inner array name (before expansion prefixed to satisfy uniqueness) - * change in- and out-edges' subsets so that they have the same shape as the corresponding array inside - * determine the domain size based on edges to StencilComputation - """ - # change connector names - for in_edge in parent_state.in_edges(node): - assert in_edge.dst_conn.startswith(prefix.CONNECTOR_IN) - in_edge.dst_conn = in_edge.dst_conn.removeprefix(prefix.CONNECTOR_IN) - for out_edge in parent_state.out_edges(node): - assert out_edge.src_conn.startswith(prefix.CONNECTOR_OUT) - out_edge.src_conn = out_edge.src_conn.removeprefix(prefix.CONNECTOR_OUT) - - # union input and output subsets - subsets = {} - for edge in parent_state.in_edges(node): - subsets[edge.dst_conn] = edge.data.subset - for edge in parent_state.out_edges(node): - subsets[edge.src_conn] = dace.subsets.union( - edge.data.subset, subsets.get(edge.src_conn, edge.data.subset) - ) - # ensure single-use of input and output subset instances - for edge in parent_state.in_edges(node): - edge.data.subset = copy.deepcopy(subsets[edge.dst_conn]) - for edge in parent_state.out_edges(node): - edge.data.subset = copy.deepcopy(subsets[edge.src_conn]) - - # determine "__I", "__J" and "__K" values based on memlets to StencilComputation's shape - symbol_mapping = StencilComputationExpansion._solve_for_domain( - { - decl.name: decl - for decl in daceir.field_decls - if decl.name - in set(memlet.field for memlet in daceir.read_memlets + daceir.write_memlets) - }, - subsets, - ) - nsdfg.symbol_mapping.update({**symbol_mapping, **node.symbol_mapping}) - - # remove unused symbols from symbol_mapping - delkeys = set() - for sym in node.symbol_mapping.keys(): - if str(sym) not in nsdfg.sdfg.free_symbols: - delkeys.add(str(sym)) - for key in delkeys: - del node.symbol_mapping[key] - if key in nsdfg.symbol_mapping: - del 
nsdfg.symbol_mapping[key] - - for edge in parent_state.in_edges(node): - if edge.dst_conn not in nsdfg.in_connectors: - # Drop connection if connector is not found in the expansion of the library node - parent_state.remove_edge(edge) - if parent_state.in_degree(edge.src) + parent_state.out_degree(edge.src) == 0: - # Remove node if it is now isolated - parent_state.remove_node(edge.src) - - @staticmethod - def _get_parent_arrays( - node: StencilComputation, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG - ) -> Dict[str, dace.data.Data]: - parent_arrays: Dict[str, dace.data.Data] = {} - for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None): - parent_arrays[edge.dst_conn.removeprefix(prefix.CONNECTOR_IN)] = parent_sdfg.arrays[ - edge.data.data - ] - for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None): - parent_arrays[edge.src_conn.removeprefix(prefix.CONNECTOR_OUT)] = parent_sdfg.arrays[ - edge.data.data - ] - return parent_arrays - - @staticmethod - def expansion( - node: StencilComputation, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG - ) -> dace.nodes.NestedSDFG: - """Expand the coarse SDFG in parent_sdfg to a NestedSDFG with all the states.""" - split_horizontal_executions_regions(node) - arrays = StencilComputationExpansion._get_parent_arrays(node, parent_state, parent_sdfg) - - daceir: dcir.NestedSDFG = DaCeIRBuilder().visit( - node.oir_node, global_ctx=DaCeIRBuilder.GlobalContext(library_node=node, arrays=arrays) - ) - - nsdfg: dace.nodes.NestedSDFG = StencilComputationSDFGBuilder().visit(daceir) - - StencilComputationExpansion._fix_context(nsdfg, node, parent_state, daceir) - return nsdfg diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py deleted file mode 100644 index cdae090b30..0000000000 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ /dev/null @@ -1,660 +0,0 @@ -# GT4Py - GridTools Framework -# -# 
Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -from dataclasses import dataclass -from typing import Any, ChainMap, Dict, List, Optional, Set, Tuple - -import dace -import dace.subsets - -from gt4py import eve -from gt4py.cartesian.gtc.dace import daceir as dcir, prefix -from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen -from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset - - -def _node_name_from_connector(connector: str) -> str: - if not connector.startswith(prefix.TASKLET_IN) and not connector.startswith(prefix.TASKLET_OUT): - raise ValueError( - f"Connector {connector} doesn't follow the in ({prefix.TASKLET_IN}) or out ({prefix.TASKLET_OUT}) prefix rule" - ) - return connector.removeprefix(prefix.TASKLET_OUT).removeprefix(prefix.TASKLET_IN) - - -def _add_empty_edges( - entry_node: dace.nodes.Node, - exit_node: dace.nodes.Node, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - node_ctx: StencilComputationSDFGBuilder.NodeContext, -) -> None: - if not sdfg_ctx.state.in_degree(entry_node) and None in node_ctx.input_node_and_conns: - sdfg_ctx.state.add_edge( - *node_ctx.input_node_and_conns[None], entry_node, None, dace.Memlet() - ) - if not sdfg_ctx.state.out_degree(exit_node) and None in node_ctx.output_node_and_conns: - sdfg_ctx.state.add_edge( - exit_node, None, *node_ctx.output_node_and_conns[None], dace.Memlet() - ) - - -class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait): - @dataclass - class NodeContext: - input_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] - output_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] - - @dataclass - class SDFGContext: - sdfg: dace.SDFG - 
state: dace.SDFGState - state_stack: List[dace.SDFGState] = dataclasses.field(default_factory=list) - - def add_state(self, label: Optional[str] = None) -> None: - new_state = self.sdfg.add_state(label=label) - for edge in self.sdfg.out_edges(self.state): - self.sdfg.remove_edge(edge) - self.sdfg.add_edge(new_state, edge.dst, edge.data) - self.sdfg.add_edge(self.state, new_state, dace.InterstateEdge()) - self.state = new_state - - def add_loop(self, index_range: dcir.Range) -> None: - loop_state = self.sdfg.add_state("loop_state") - after_state = self.sdfg.add_state("loop_after") - for edge in self.sdfg.out_edges(self.state): - self.sdfg.remove_edge(edge) - self.sdfg.add_edge(after_state, edge.dst, edge.data) - - assert isinstance(index_range.interval, dcir.DomainInterval) - if index_range.stride < 0: - initialize_expr = f"{index_range.interval.end} - 1" - end_expr = f"{index_range.interval.start} - 1" - else: - initialize_expr = str(index_range.interval.start) - end_expr = str(index_range.interval.end) - comparison_op = "<" if index_range.stride > 0 else ">" - condition_expr = f"{index_range.var} {comparison_op} {end_expr}" - _, _, after_state = self.sdfg.add_loop( - before_state=self.state, - loop_state=loop_state, - after_state=after_state, - loop_var=index_range.var, - initialize_expr=initialize_expr, - condition_expr=condition_expr, - increment_expr=f"{index_range.var}+({index_range.stride})", - ) - if index_range.var not in self.sdfg.symbols: - self.sdfg.add_symbol(index_range.var, stype=dace.int32) - - self.state_stack.append(after_state) - self.state = loop_state - - def pop_loop(self) -> None: - self._pop_last("loop_after") - - def add_condition(self, node: dcir.Condition) -> None: - """Inserts a condition after the current self.state. - - The condition consists of an initial state connected to a guard state, which branches - to a true_state and a false_state based on the given condition. Both states then merge - into a merge_state. 
- - self.state is set to init_state and the other states are pushed on the stack to be - popped with `pop_condition_*()` methods. - """ - # Data model validators enforce this to exist - assert isinstance(node.condition.stmts[0], dcir.AssignStmt) - assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess) - condition_name = node.condition.stmts[0].left.original_name - - merge_state = self.sdfg.add_state("condition_after") - for edge in self.sdfg.out_edges(self.state): - self.sdfg.remove_edge(edge) - self.sdfg.add_edge(merge_state, edge.dst, edge.data) - - # Evaluate node condition - init_state = self.sdfg.add_state("condition_init") - self.sdfg.add_edge(self.state, init_state, dace.InterstateEdge()) - - # Promote condition (from init_state) to symbol - condition_state = self.sdfg.add_state("condition_guard") - self.sdfg.add_edge( - init_state, - condition_state, - dace.InterstateEdge(assignments=dict(if_condition=condition_name)), - ) - - true_state = self.sdfg.add_state("condition_true") - self.sdfg.add_edge( - condition_state, true_state, dace.InterstateEdge(condition="if_condition") - ) - self.sdfg.add_edge(true_state, merge_state, dace.InterstateEdge()) - - false_state = self.sdfg.add_state("condition_false") - self.sdfg.add_edge( - condition_state, false_state, dace.InterstateEdge(condition="not if_condition") - ) - self.sdfg.add_edge(false_state, merge_state, dace.InterstateEdge()) - - self.state_stack.append(merge_state) - self.state_stack.append(false_state) - self.state_stack.append(true_state) - self.state_stack.append(condition_state) - self.state = init_state - - def pop_condition_guard(self) -> None: - self._pop_last("condition_guard") - - def pop_condition_true(self) -> None: - self._pop_last("condition_true") - - def pop_condition_false(self) -> None: - self._pop_last("condition_false") - - def pop_condition_after(self) -> None: - self._pop_last("condition_after") - - def add_while(self, node: dcir.WhileLoop) -> None: - """Inserts a while 
loop after the current state.""" - # Data model validators enforce this to exist - assert isinstance(node.condition.stmts[0], dcir.AssignStmt) - assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess) - condition_name = node.condition.stmts[0].left.original_name - - after_state = self.sdfg.add_state("while_after") - for edge in self.sdfg.out_edges(self.state): - self.sdfg.remove_edge(edge) - self.sdfg.add_edge(after_state, edge.dst, edge.data) - - # Evaluate loop condition - init_state = self.sdfg.add_state("while_init") - self.sdfg.add_edge(self.state, init_state, dace.InterstateEdge()) - - # Promote condition (from init_state) to symbol - guard_state = self.sdfg.add_state("while_guard") - self.sdfg.add_edge( - init_state, - guard_state, - dace.InterstateEdge(assignments=dict(loop_condition=condition_name)), - ) - - loop_state = self.sdfg.add_state("while_loop") - self.sdfg.add_edge( - guard_state, loop_state, dace.InterstateEdge(condition="loop_condition") - ) - # Loop back to init_state to re-evaluate the loop condition - self.sdfg.add_edge(loop_state, init_state, dace.InterstateEdge()) - - # Exit the loop - self.sdfg.add_edge( - guard_state, after_state, dace.InterstateEdge(condition="not loop_condition") - ) - - self.state_stack.append(after_state) - self.state_stack.append(loop_state) - self.state_stack.append(guard_state) - self.state = init_state - - def pop_while_guard(self) -> None: - self._pop_last("while_guard") - - def pop_while_loop(self) -> None: - self._pop_last("while_loop") - - def pop_while_after(self) -> None: - self._pop_last("while_after") - - def _pop_last(self, node_label: str | None = None) -> None: - if node_label is not None: - assert self.state_stack[-1].label.startswith(node_label) - - self.state = self.state_stack[-1] - del self.state_stack[-1] - - def visit_Memlet( - self, - node: dcir.Memlet, - *, - scope_node: dcir.ComputationNode, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - node_ctx: 
StencilComputationSDFGBuilder.NodeContext, - connector_prefix: str = "", - symtable: ChainMap[eve.SymbolRef, dcir.Decl], - ) -> None: - field_decl = symtable[node.field] - assert isinstance(field_decl, dcir.FieldDecl) - memlet = dace.Memlet( - node.field, - subset=make_dace_subset(field_decl.access_info, node.access_info, field_decl.data_dims), - dynamic=field_decl.is_dynamic, - ) - if node.is_read: - sdfg_ctx.state.add_edge( - *node_ctx.input_node_and_conns[memlet.data], - scope_node, - connector_prefix + node.connector, - memlet, - ) - if node.is_write: - sdfg_ctx.state.add_edge( - scope_node, - connector_prefix + node.connector, - *node_ctx.output_node_and_conns[memlet.data], - memlet, - ) - - def visit_WhileLoop( - self, - node: dcir.WhileLoop, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - node_ctx: StencilComputationSDFGBuilder.NodeContext, - **kwargs: Any, - ) -> None: - sdfg_ctx.add_while(node) - assert sdfg_ctx.state.label.startswith("while_init") - - read_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} - write_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} - for memlet in node.condition.read_memlets: - if memlet.field not in read_acc_and_conn: - read_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), - None, - ) - for memlet in node.condition.write_memlets: - if memlet.field not in write_acc_and_conn: - write_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), - None, - ) - eval_node_ctx = StencilComputationSDFGBuilder.NodeContext( - input_node_and_conns=read_acc_and_conn, output_node_and_conns=write_acc_and_conn - ) - self.visit(node.condition, sdfg_ctx=sdfg_ctx, node_ctx=eval_node_ctx, **kwargs) - - sdfg_ctx.pop_while_guard() - sdfg_ctx.pop_while_loop() - - for state in node.body: - self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, **kwargs) - - sdfg_ctx.pop_while_after() 
- - def visit_Condition( - self, - node: dcir.Condition, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - node_ctx: StencilComputationSDFGBuilder.NodeContext, - **kwargs: Any, - ) -> None: - sdfg_ctx.add_condition(node) - assert sdfg_ctx.state.label.startswith("condition_init") - - read_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} - write_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} - for memlet in node.condition.read_memlets: - if memlet.field not in read_acc_and_conn: - read_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), - None, - ) - for memlet in node.condition.write_memlets: - if memlet.field not in write_acc_and_conn: - write_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), - None, - ) - eval_node_ctx = StencilComputationSDFGBuilder.NodeContext( - input_node_and_conns=read_acc_and_conn, output_node_and_conns=write_acc_and_conn - ) - self.visit(node.condition, sdfg_ctx=sdfg_ctx, node_ctx=eval_node_ctx, **kwargs) - - sdfg_ctx.pop_condition_guard() - sdfg_ctx.pop_condition_true() - for state in node.true_states: - self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, **kwargs) - - sdfg_ctx.pop_condition_false() - for state in node.false_states: - self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, **kwargs) - - sdfg_ctx.pop_condition_after() - - def visit_Tasklet( - self, - node: dcir.Tasklet, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs: Any, - ) -> None: - code = TaskletCodegen.apply_codegen( - node, - read_memlets=node.read_memlets, - write_memlets=node.write_memlets, - symtable=symtable, - sdfg=sdfg_ctx.sdfg, - ) - - # We are breaking up vertical loops inside stencils in multiple Tasklets - # It might thus happen that we write a "local" scalar in one Tasklet and - # read it in another Tasklet 
(downstream). - # We thus create output connectors for all writes to scalar variables - # inside Tasklets. And input connectors for all scalar reads unless - # previously written in the same Tasklet. DaCe's simplify pipeline will get - # rid of any dead dataflow introduced with this general approach. - scalar_inputs: set[str] = set() - scalar_outputs: set[str] = set() - - # Gather scalar writes in this Tasklet - for access_node in node.walk_values().if_isinstance(dcir.AssignStmt): - target_name = access_node.left.name - - field_access = ( - len( - set( - memlet.connector - for memlet in [*node.write_memlets] - if memlet.connector == target_name - ) - ) - > 0 - ) - if field_access or target_name in scalar_outputs: - continue - - assert isinstance(access_node.left, dcir.ScalarAccess) - assert access_node.left.original_name is not None, ( - "Original name not found for '{access_nodes.left.name}'. DaCeIR error." - ) - - original_name = access_node.left.original_name - scalar_outputs.add(target_name) - if original_name not in sdfg_ctx.sdfg.arrays: - sdfg_ctx.sdfg.add_scalar( - original_name, - dtype=data_type_to_dace_typeclass(access_node.left.dtype), - transient=True, - ) - - # Gather scalar reads in this Tasklet - for access_node in node.walk_values().if_isinstance(dcir.ScalarAccess): - read_name = access_node.name - field_access = ( - len( - set( - memlet.connector - for memlet in [*node.read_memlets, *node.write_memlets] - if memlet.connector == read_name - ) - ) - > 0 - ) - defined_symbol = any(read_name in symbol_map for symbol_map in symtable.maps) - - if ( - not field_access - and not defined_symbol - and not access_node.is_target - and read_name.startswith(prefix.TASKLET_IN) - and read_name not in scalar_inputs - ): - scalar_inputs.add(read_name) - - inputs = set(memlet.connector for memlet in node.read_memlets).union(scalar_inputs) - outputs = set(memlet.connector for memlet in node.write_memlets).union(scalar_outputs) - - tasklet = sdfg_ctx.state.add_tasklet( 
- name=node.label, - code=code, - inputs=inputs, - outputs=outputs, - debuginfo=get_dace_debuginfo(node), - ) - - # Add memlets for scalars access (read/write) - for connector in scalar_outputs: - original_name = _node_name_from_connector(connector) - access_node = sdfg_ctx.state.add_write(original_name) - sdfg_ctx.state.add_memlet_path( - tasklet, access_node, src_conn=connector, memlet=dace.Memlet(data=original_name) - ) - for connector in scalar_inputs: - original_name = _node_name_from_connector(connector) - access_node = sdfg_ctx.state.add_read(original_name) - sdfg_ctx.state.add_memlet_path( - access_node, tasklet, dst_conn=connector, memlet=dace.Memlet(data=original_name) - ) - - # Add memlets for field access (read/write) - self.visit( - node.read_memlets, - scope_node=tasklet, - sdfg_ctx=sdfg_ctx, - symtable=symtable, - **kwargs, - ) - self.visit( - node.write_memlets, - scope_node=tasklet, - sdfg_ctx=sdfg_ctx, - symtable=symtable, - **kwargs, - ) - - def visit_Range(self, node: dcir.Range, **kwargs: Any) -> Dict[str, str]: - start, end = node.interval.to_dace_symbolic() - return {node.var: str(dace.subsets.Range([(start, end - 1, node.stride)]))} - - def visit_DomainMap( - self, - node: dcir.DomainMap, - *, - node_ctx: StencilComputationSDFGBuilder.NodeContext, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs: Any, - ) -> None: - ndranges = { - k: v - for index_range in node.index_ranges - for k, v in self.visit(index_range, **kwargs).items() - } - name = sdfg_ctx.sdfg.label + "".join(ndranges.keys()) + "_map" - map_entry, map_exit = sdfg_ctx.state.add_map( - name=name, - ndrange=ndranges, - schedule=node.schedule.to_dace_schedule(), - debuginfo=get_dace_debuginfo(node), - ) - - for scope_node in node.computations: - input_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} - output_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} - for field in set(memlet.field for memlet in 
scope_node.read_memlets): - map_entry.add_in_connector(f"{prefix.PASSTHROUGH_IN}{field}") - map_entry.add_out_connector(f"{prefix.PASSTHROUGH_OUT}{field}") - input_node_and_conns[field] = (map_entry, f"{prefix.PASSTHROUGH_OUT}{field}") - for field in set(memlet.field for memlet in scope_node.write_memlets): - map_exit.add_in_connector(f"{prefix.PASSTHROUGH_IN}{field}") - map_exit.add_out_connector(f"{prefix.PASSTHROUGH_OUT}{field}") - output_node_and_conns[field] = (map_exit, f"{prefix.PASSTHROUGH_IN}{field}") - if not input_node_and_conns: - input_node_and_conns[None] = (map_entry, None) - if not output_node_and_conns: - output_node_and_conns[None] = (map_exit, None) - inner_node_ctx = StencilComputationSDFGBuilder.NodeContext( - input_node_and_conns=input_node_and_conns, - output_node_and_conns=output_node_and_conns, - ) - self.visit(scope_node, sdfg_ctx=sdfg_ctx, node_ctx=inner_node_ctx, **kwargs) - - self.visit( - node.read_memlets, - scope_node=map_entry, - sdfg_ctx=sdfg_ctx, - node_ctx=node_ctx, - connector_prefix=prefix.PASSTHROUGH_IN, - **kwargs, - ) - self.visit( - node.write_memlets, - scope_node=map_exit, - sdfg_ctx=sdfg_ctx, - node_ctx=node_ctx, - connector_prefix=prefix.PASSTHROUGH_OUT, - **kwargs, - ) - _add_empty_edges(map_entry, map_exit, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx) - - def visit_DomainLoop( - self, - node: dcir.DomainLoop, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs: Any, - ) -> None: - sdfg_ctx.add_loop(node.index_range) - self.visit(node.loop_states, sdfg_ctx=sdfg_ctx, **kwargs) - sdfg_ctx.pop_loop() - - def visit_ComputationState( - self, - node: dcir.ComputationState, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs: Any, - ) -> None: - sdfg_ctx.add_state() - - # node_ctx is used to keep track of memlets per ComputationState. Conditions and WhileLoops - # will (recursively) introduce more than one compute state per vertical loop. 
We thus drop - # any node_ctx that is potentially passed down and instead create a new one for each - # ComputationState that we encounter. - kwargs.pop("node_ctx", None) - - read_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} - write_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} - for computation in node.computations: - assert isinstance(computation, dcir.ComputationNode) - for memlet in computation.read_memlets: - if memlet.field not in read_acc_and_conn: - read_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field), - None, - ) - for memlet in computation.write_memlets: - if memlet.field not in write_acc_and_conn: - write_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field), - None, - ) - node_ctx = StencilComputationSDFGBuilder.NodeContext( - input_node_and_conns=read_acc_and_conn, output_node_and_conns=write_acc_and_conn - ) - self.visit(computation, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, **kwargs) - - def visit_FieldDecl( - self, - node: dcir.FieldDecl, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - non_transients: Set[eve.SymbolRef], - **kwargs: Any, - ) -> None: - assert len(node.strides) == len(node.shape) - sdfg_ctx.sdfg.add_array( - node.name, - shape=node.shape, - strides=[dace.symbolic.pystr_to_symbolic(s) for s in node.strides], - dtype=data_type_to_dace_typeclass(node.dtype), - storage=node.storage.to_dace_storage(), - transient=node.name not in non_transients, - debuginfo=get_dace_debuginfo(node), - ) - - def visit_SymbolDecl( - self, - node: dcir.SymbolDecl, - *, - sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs: Any, - ) -> None: - if node.name not in sdfg_ctx.sdfg.symbols: - sdfg_ctx.sdfg.add_symbol(node.name, stype=data_type_to_dace_typeclass(node.dtype)) - - def visit_NestedSDFG( - self, - node: dcir.NestedSDFG, - *, - sdfg_ctx: Optional[StencilComputationSDFGBuilder.SDFGContext] = None, - node_ctx: 
Optional[StencilComputationSDFGBuilder.NodeContext] = None, - symtable: ChainMap[eve.SymbolRef, Any], - **kwargs: Any, - ) -> dace.nodes.NestedSDFG: - sdfg = dace.SDFG(node.label) - inner_sdfg_ctx = StencilComputationSDFGBuilder.SDFGContext( - sdfg=sdfg, state=sdfg.add_state(is_start_block=True) - ) - self.visit( - node.field_decls, - sdfg_ctx=inner_sdfg_ctx, - non_transients={memlet.connector for memlet in node.read_memlets + node.write_memlets}, - **kwargs, - ) - self.visit(node.symbol_decls, sdfg_ctx=inner_sdfg_ctx, **kwargs) - symbol_mapping = {decl.name: decl.to_dace_symbol() for decl in node.symbol_decls} - - for computation_state in node.states: - self.visit( - computation_state, - sdfg_ctx=inner_sdfg_ctx, - node_ctx=node_ctx, - symtable=symtable, - **kwargs, - ) - - if sdfg_ctx is not None and node_ctx is not None: - nsdfg = sdfg_ctx.state.add_nested_sdfg( - sdfg=sdfg, - parent=None, - inputs=node.input_connectors, - outputs=node.output_connectors, - symbol_mapping=symbol_mapping, - ) - self.visit( - node.read_memlets, - scope_node=nsdfg, - sdfg_ctx=sdfg_ctx, - node_ctx=node_ctx, - symtable=symtable.parents, - **kwargs, - ) - self.visit( - node.write_memlets, - scope_node=nsdfg, - sdfg_ctx=sdfg_ctx, - node_ctx=node_ctx, - symtable=symtable.parents, - **kwargs, - ) - _add_empty_edges(nsdfg, nsdfg, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx) - return nsdfg - - return dace.nodes.NestedSDFG( - label=sdfg.label, - sdfg=sdfg, - inputs={memlet.connector for memlet in node.read_memlets}, - outputs={memlet.connector for memlet in node.write_memlets}, - symbol_mapping=symbol_mapping, - ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py deleted file mode 100644 index cfd6c98832..0000000000 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ /dev/null @@ -1,302 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. 
-# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import copy -from typing import Any, ChainMap, List, Optional, Union - -import dace -import dace.subsets - -from gt4py import eve -from gt4py.cartesian.gtc import common -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.dace.symbol_utils import get_axis_bound_str -from gt4py.cartesian.gtc.dace.utils import make_dace_subset -from gt4py.eve.codegen import FormatTemplate as as_fmt - - -class TaskletCodegen(eve.codegen.TemplatedGenerator, eve.VisitorWithSymbolTableTrait): - ScalarAccess = as_fmt("{name}") - - def _visit_offset( - self, - node: Union[dcir.VariableKOffset, common.CartesianOffset], - *, - access_info: dcir.FieldAccessInfo, - **kwargs: Any, - ) -> str: - int_sizes: List[Optional[int]] = [] - for i, axis in enumerate(access_info.axes()): - memlet_shape = access_info.shape - if str(memlet_shape[i]).isnumeric() and axis not in access_info.variable_offset_axes: - int_sizes.append(int(memlet_shape[i])) - else: - int_sizes.append(None) - sym_offsets = [ - dace.symbolic.pystr_to_symbolic(self.visit(off, access_info=access_info, **kwargs)) - for off in (node.to_dict()["i"], node.to_dict()["j"], node.k) - ] - for axis in access_info.variable_offset_axes: - access_info = access_info.restricted_to_index(axis) - context_info = copy.deepcopy(access_info) - context_info.variable_offset_axes = [] - ranges = make_dace_subset( - access_info, - context_info, - data_dims=(), # data_index added in visit_IndexAccess - ) - ranges.offset(sym_offsets, negative=False) - res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1]) - return str(res) - - def _explicit_indexing( - self, node: common.CartesianOffset | dcir.VariableKOffset, **kwargs: Any - ) -> str: - """If called from the explicit pass we need to be add manually the relative indexing""" - return f"__k+{self.visit(node.k, **kwargs)}" - - def 
visit_CartesianOffset( - self, node: common.CartesianOffset, explicit=False, **kwargs: Any - ) -> str: - if explicit: - return self._explicit_indexing(node, **kwargs) - - return self._visit_offset(node, **kwargs) - - def visit_VariableKOffset( - self, node: dcir.VariableKOffset, explicit=False, **kwargs: Any - ) -> str: - if explicit: - return self._explicit_indexing(node, **kwargs) - - return self._visit_offset(node, **kwargs) - - def visit_IndexAccess( - self, - node: dcir.IndexAccess, - *, - is_target: bool, - sdfg: dace.SDFG, - symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs: Any, - ) -> str: - if is_target: - memlets = kwargs["write_memlets"] - else: - # if this node is not a target, it will still use the symbol of the write memlet if the - # field was previously written in the same memlet. - memlets = kwargs["read_memlets"] + kwargs["write_memlets"] - - try: - memlet = next(mem for mem in memlets if mem.connector == node.name) - except StopIteration: - raise ValueError( - "Memlet connector and tasklet variable mismatch, DaCe IR error." - ) from None - - index_strs: list[str] = [] - if node.explicit_indices: - # Full array access with every dimensions accessed in full. - # Everything was packed in `explicit_indices` in `DaCeIRBuilder._fix_memlet_array_access()` - # along the `reshape_memlet=True` code path. 
- assert len(node.explicit_indices) == len(sdfg.arrays[memlet.field].shape) - for idx in node.explicit_indices: - index_strs.append( - self.visit( - idx, - symtable=symtable, - in_idx=True, - explicit=True, - **kwargs, - ) - ) - else: - # Grid-point access, I & J are unitary, K can be offsetted with variable - # Resolve K offset (also resolves I & J) - if node.offset is not None: - index_strs.append( - self.visit( - node.offset, - access_info=memlet.access_info, - symtable=symtable, - in_idx=True, - **kwargs, - ) - ) - # Add any data dimensions - index_strs.extend( - self.visit(idx, symtable=symtable, in_idx=True, **kwargs) for idx in node.data_index - ) - # Filter empty strings - non_empty_indices = list(filter(None, index_strs)) - return ( - f"{node.name}[{','.join(non_empty_indices)}]" - if len(non_empty_indices) > 0 - else node.name - ) - - def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs: Any) -> str: - # Visiting order matters because targets must not contain the target symbols from the left visit - right = self.visit(node.right, is_target=False, **kwargs) - left = self.visit(node.left, is_target=True, **kwargs) - return f"{left} = {right}" - - BinaryOp = as_fmt("({left} {op} {right})") - - UnaryOp = as_fmt("({op}{expr})") - - TernaryOp = as_fmt("({true_expr} if {cond} else {false_expr})") - - def visit_BuiltInLiteral(self, builtin: common.BuiltInLiteral, **kwargs: Any) -> str: - if builtin == common.BuiltInLiteral.TRUE: - return "True" - if builtin == common.BuiltInLiteral.FALSE: - return "False" - raise NotImplementedError("Not implemented BuiltInLiteral encountered.") - - def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs: Any) -> str: - value = self.visit(literal.value, in_idx=in_idx, **kwargs) - if in_idx: - return str(value) - - return "{dtype}({value})".format( - dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value - ) - - Cast = as_fmt("{dtype}({expr})") - - def visit_NativeFunction(self, func: 
common.NativeFunction, **kwargs: Any) -> str: - try: - return { - common.NativeFunction.ABS: "abs", - common.NativeFunction.MIN: "min", - common.NativeFunction.MAX: "max", - common.NativeFunction.MOD: "fmod", - common.NativeFunction.SIN: "dace.math.sin", - common.NativeFunction.COS: "dace.math.cos", - common.NativeFunction.TAN: "dace.math.tan", - common.NativeFunction.ARCSIN: "asin", - common.NativeFunction.ARCCOS: "acos", - common.NativeFunction.ARCTAN: "atan", - common.NativeFunction.SINH: "dace.math.sinh", - common.NativeFunction.COSH: "dace.math.cosh", - common.NativeFunction.TANH: "dace.math.tanh", - common.NativeFunction.ARCSINH: "asinh", - common.NativeFunction.ARCCOSH: "acosh", - common.NativeFunction.ARCTANH: "atanh", - common.NativeFunction.SQRT: "dace.math.sqrt", - common.NativeFunction.POW: "dace.math.pow", - common.NativeFunction.EXP: "dace.math.exp", - common.NativeFunction.LOG: "dace.math.log", - common.NativeFunction.LOG10: "log10", - common.NativeFunction.GAMMA: "tgamma", - common.NativeFunction.CBRT: "cbrt", - common.NativeFunction.ISFINITE: "isfinite", - common.NativeFunction.ISINF: "isinf", - common.NativeFunction.ISNAN: "isnan", - common.NativeFunction.FLOOR: "dace.math.ifloor", - common.NativeFunction.CEIL: "ceil", - common.NativeFunction.TRUNC: "trunc", - }[func] - except KeyError as error: - raise NotImplementedError("Not implemented NativeFunction encountered.") from error - - def visit_NativeFuncCall(self, call: common.NativeFuncCall, **kwargs: Any) -> str: - # TODO: Unroll integer POW - return f"{self.visit(call.func, **kwargs)}({','.join([self.visit(a, **kwargs) for a in call.args])})" - - def visit_DataType(self, dtype: common.DataType, **kwargs: Any) -> str: - if dtype == common.DataType.BOOL: - return "dace.bool_" - if dtype == common.DataType.INT8: - return "dace.int8" - if dtype == common.DataType.INT16: - return "dace.int16" - if dtype == common.DataType.INT32: - return "dace.int32" - if dtype == common.DataType.INT64: - return 
"dace.int64" - if dtype == common.DataType.FLOAT32: - return "dace.float32" - if dtype == common.DataType.FLOAT64: - return "dace.float64" - raise NotImplementedError("Not implemented DataType encountered.") - - def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str: - if op == common.UnaryOperator.NOT: - return " not " - if op == common.UnaryOperator.NEG: - return "-" - if op == common.UnaryOperator.POS: - return "+" - raise NotImplementedError("Not implemented UnaryOperator encountered.") - - Arg = as_fmt("{name}") - - Param = as_fmt("{name}") - - def visit_Tasklet(self, node: dcir.Tasklet, **kwargs: Any) -> str: - return "\n".join(self.visit(node.stmts, **kwargs)) - - def _visit_conditional( - self, - cond: Optional[Union[dcir.Expr, common.HorizontalMask]], - body: List[dcir.Stmt], - keyword: str, - **kwargs: Any, - ) -> str: - mask_str = "" - indent = "" - if cond is not None and (cond_str := self.visit(cond, is_target=False, **kwargs)): - mask_str = f"{keyword} {cond_str}:" - indent = " " * 4 - body_code = [line for block in self.visit(body, **kwargs) for line in block.split("\n")] - body_code = [indent + b for b in body_code] - return "\n".join([mask_str, *body_code]) - - def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs: Any) -> str: - return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs) - - def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs: Any) -> str: - return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs) - - def visit_While(self, node: dcir.While, **kwargs: Any) -> Any: - return self._visit_conditional(cond=node.cond, body=node.body, keyword="while", **kwargs) - - def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs: Any) -> str: - clauses: List[str] = [] - - for axis, interval in zip(dcir.Axis.dims_horizontal(), node.intervals): - it_sym, dom_sym = axis.iteration_symbol(), axis.domain_symbol() - - min_val = 
get_axis_bound_str(interval.start, dom_sym) - max_val = get_axis_bound_str(interval.end, dom_sym) - if ( - min_val - and max_val - and interval.start is not None - and interval.end is not None - and interval.start.level == interval.end.level - and interval.start.offset + 1 == interval.end.offset - ): - clauses.append(f"{it_sym} == {min_val}") - else: - if min_val: - clauses.append(f"{it_sym} >= {min_val}") - if max_val: - clauses.append(f"{it_sym} < {max_val}") - - return " and ".join(clauses) - - @classmethod - def apply_codegen(cls, node: dcir.Tasklet, **kwargs: Any) -> str: - # NOTE This is not named 'apply' b/c the base class has a method with - # that name and a different type signature. - if not isinstance(node, dcir.Tasklet): - raise ValueError("apply() requires dcir.Tasklet node") - return super().apply(node, **kwargs) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py deleted file mode 100644 index 637b348a03..0000000000 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ /dev/null @@ -1,163 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from typing import TYPE_CHECKING, List - -from gt4py import eve -from gt4py.cartesian.gtc import common, oir -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.definitions import Extent - - -if TYPE_CHECKING: - from gt4py.cartesian.gtc.dace.nodes import StencilComputation - - -class HorizontalIntervalRemover(eve.NodeTranslator): - def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis): - mask_attrs = dict(i=node.i, j=node.j) - mask_attrs[axis.lower()] = self.visit(getattr(node, axis.lower())) - return common.HorizontalMask(**mask_attrs) - - def visit_HorizontalInterval(self, node: common.HorizontalInterval): - return common.HorizontalInterval(start=None, end=None) - - -class HorizontalMaskRemover(eve.NodeTranslator): - def visit_Tasklet(self, node: dcir.Tasklet): - res_body = [] - for stmt in node.stmts: - newstmt = self.visit(stmt) - if isinstance(newstmt, list): - res_body.extend(newstmt) - else: - res_body.append(newstmt) - return dcir.Tasklet( - label=f"he_remover_{id(node)}", - stmts=res_body, - read_memlets=node.read_memlets, - write_memlets=node.write_memlets, - ) - - def visit_MaskStmt(self, node: oir.MaskStmt): - if isinstance(node.mask, common.HorizontalMask): - if ( - node.mask.i.start is None - and node.mask.j.start is None - and node.mask.i.end is None - and node.mask.j.end is None - ): - return self.generic_visit(node.body) - return self.generic_visit(node) - - -def remove_horizontal_region(node, axis): - intervals_removed = HorizontalIntervalRemover().visit(node, axis=axis) - return HorizontalMaskRemover().visit(intervals_removed) - - -def mask_includes_inner_domain(mask: common.HorizontalMask): - for interval in mask.intervals: - if interval.start is None and interval.end is None: - return True - elif ( - interval.start is None - and interval.end is not None - and interval.end.level == common.LevelMarker.END - ): - 
return True - elif ( - interval.end is None - and interval.start is not None - and interval.start.level == common.LevelMarker.START - ): - return True - elif ( - interval.start is not None - and interval.end is not None - and interval.start.level != interval.end.level - ): - return True - return False - - -class HorizontalExecutionSplitter(eve.NodeTranslator): - @staticmethod - def is_horizontal_execution_splittable(he: oir.HorizontalExecution): - for stmt in he.body: - if isinstance(stmt, oir.HorizontalRestriction) and not mask_includes_inner_domain( - stmt.mask - ): - continue - elif isinstance(stmt, oir.AssignStmt) and isinstance(stmt.left, oir.ScalarAccess): - continue - return False - - # If the regions are not disjoint, then the horizontal executions are not splittable. - regions: List[common.HorizontalMask] = [] - for stmt in he.walk_values().if_isinstance(oir.HorizontalRestriction): - assert isinstance(stmt, oir.HorizontalRestriction) - for region in regions: - if region.i.overlaps(stmt.mask.i) and region.j.overlaps(stmt.mask.j): - return False - regions.append(stmt.mask) - - return True - - def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *, extents, library_node): - if not HorizontalExecutionSplitter.is_horizontal_execution_splittable(node): - extents.append(library_node.get_extents(node)) - return node - - res_he_stmts = [] - scalar_writes = [] - for stmt in node.body: - if isinstance(stmt, oir.AssignStmt): - scalar_writes.append(stmt) - else: - assert isinstance(stmt, oir.HorizontalRestriction) - new_he = oir.HorizontalRestriction( - mask=stmt.mask, body=[*scalar_writes, *stmt.body] - ) - res_he_stmts.append([new_he]) - - res_hes = [] - for stmts in res_he_stmts: - accessed_scalars = ( - eve.walk_values(stmts).if_isinstance(oir.ScalarAccess).getattr("name").to_set() - ) - declarations = [decl for decl in node.declarations if decl.name in accessed_scalars] - res_he = oir.HorizontalExecution(declarations=declarations, body=stmts) - 
res_hes.append(res_he) - extents.append(library_node.get_extents(node)) - return res_hes - - def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection, **kwargs): - res_hes = [] - for he in node.horizontal_executions: - new_he = self.visit(he, **kwargs) - if isinstance(new_he, list): - res_hes.extend(new_he) - else: - res_hes.append(new_he) - return oir.VerticalLoopSection(interval=node.interval, horizontal_executions=res_hes) - - -def split_horizontal_executions_regions(node: StencilComputation): - extents: List[Extent] = [] - - node.oir_node = HorizontalExecutionSplitter().visit( - node.oir_node, library_node=node, extents=extents - ) - ctr = 0 - for i, section in enumerate(node.oir_node.sections): - for j, _ in enumerate(section.horizontal_executions): - node.extents[j * len(node.oir_node.sections) + i] = extents[ctr] - ctr += 1 diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py deleted file mode 100644 index 9b20507ceb..0000000000 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ /dev/null @@ -1,616 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import copy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, List, Optional, Set, Union - -import dace - -from gt4py.cartesian.gtc import common, oir -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.definitions import Extent - - -if TYPE_CHECKING: - from gt4py.cartesian.gtc.dace.nodes import StencilComputation - -_EXPANSION_VALIDITY_CHECKS: List[Callable] = [] - - -def _register_validity_check(x): - _EXPANSION_VALIDITY_CHECKS.append(x) - return x - - -@dataclass -class ExpansionItem: - pass - - -@dataclass -class Iteration: - axis: dcir.Axis - kind: str # tiling, contiguous - # if stride is not specified, it is chosen based on backend (tiling) and loop order (K) - stride: Optional[int] = None - - @property - def iterations(self) -> List["Iteration"]: - return [self] - - -@dataclass -class Map(ExpansionItem): - iterations: List[Iteration] - schedule: Optional[dace.ScheduleType] = None - - -@dataclass -class Loop(Iteration, ExpansionItem): - kind: str = "contiguous" - storage: dace.StorageType = None - - @property - def iterations(self) -> List[Iteration]: - return [self] - - -@dataclass -class Stages(ExpansionItem): - pass - - -@dataclass -class Sections(ExpansionItem): - pass - - -def _get_axis_from_pattern(item, fmt): - for axis in dcir.Axis.dims_3d(): - if fmt.format(axis=axis) == item: - return axis - return "" - - -def _is_domain_loop(item): - return _get_axis_from_pattern(item, fmt="{axis}Loop") - - -def _is_domain_map(item): - return _get_axis_from_pattern(item, fmt="{axis}Map") - - -def _is_tiling(item): - return _get_axis_from_pattern(item, fmt="Tile{axis}") - - -def get_expansion_order_axis(item): - if axis := ( - _is_domain_map(item) - or _is_domain_loop(item) - or _is_tiling(item) - or _get_axis_from_pattern(item, fmt="{axis}") - ): - return dcir.Axis(axis) - raise ValueError(f"Can't get axis for item '{item}'.") 
- - -def get_expansion_order_index(expansion_order, axis): - for idx, item in enumerate(expansion_order): - if isinstance(item, Iteration) and item.axis == axis: - return idx - - if isinstance(item, Map): - for it in item.iterations: - if it.kind == "contiguous" and it.axis == axis: - return idx - - -def _is_expansion_order_implemented(expansion_specification): - for item in expansion_specification: - if isinstance(item, Sections): - break - if isinstance(item, Iteration) and item.axis == dcir.Axis.K: - return False - if isinstance(item, Map) and any(it.axis == dcir.Axis.K for it in item.iterations): - return False - - return True - - -def _choose_loop_or_map(node, eo): - if any(eo == axis for axis in dcir.Axis.dims_horizontal()): - return f"{eo}Map" - if eo == dcir.Axis.K: - if node.oir_node.loop_order == common.LoopOrder.PARALLEL: - return f"{eo}Map" - else: - return f"{eo}Loop" - return eo - - -def _order_as_spec( - computation_node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]] -) -> List[ExpansionItem]: - expansion_order = list(_choose_loop_or_map(computation_node, eo) for eo in expansion_order) - expansion_specification = [] - for item in expansion_order: - if isinstance(item, ExpansionItem): - expansion_specification.append(item) - elif axis := _is_tiling(item): - expansion_specification.append( - Map(iterations=[Iteration(axis=axis, kind="tiling", stride=None)]) - ) - elif axis := _is_domain_map(item): - expansion_specification.append( - Map(iterations=[Iteration(axis=axis, kind="contiguous", stride=1)]) - ) - elif axis := _is_domain_loop(item): - expansion_specification.append( - Loop( - axis=axis, - stride=( - -1 - if computation_node.oir_node.loop_order == common.LoopOrder.BACKWARD - else 1 - ), - ) - ) - elif item == "Sections": - expansion_specification.append(Sections()) - else: - assert item == "Stages", item - expansion_specification.append(Stages()) - - return expansion_specification - - -def _populate_strides(node: 
StencilComputation, expansion_specification: List[ExpansionItem]): - """Fill in `stride` attribute of `Iteration` and `Loop` dataclasses. - - For loops, stride is set to either -1 or 1, based on iteration order. - For tiling maps, the stride is chosen such that the resulting tile size - is that of the tile_size attribute. - Other maps get stride 1. - """ - assert all(isinstance(es, ExpansionItem) for es in expansion_specification) - - iterations = [it for item in expansion_specification for it in getattr(item, "iterations", [])] - - for it in iterations: - if isinstance(it, Loop): - if it.stride is None: - it.stride = -1 if node.oir_node.loop_order == common.LoopOrder.BACKWARD else 1 - else: - if it.stride is None: - if it.kind == "tiling": - if node.extents is not None and it.axis.to_idx() < 2: - extent = Extent.zeros(2) - for he_extent in node.extents.values(): - extent = extent.union(he_extent) - extent = extent[it.axis.to_idx()] - else: - extent = (0, 0) - it.stride = node.tile_strides.get(it.axis, 8) - else: - it.stride = 1 - - -def _populate_storages(expansion_specification: List[ExpansionItem]): - assert all(isinstance(es, ExpansionItem) for es in expansion_specification) - innermost_axes = set(dcir.Axis.dims_3d()) - tiled_axes = set() - for item in expansion_specification: - if isinstance(item, Map): - for it in item.iterations: - if it.kind == "tiling": - tiled_axes.add(it.axis) - for es in reversed(expansion_specification): - if isinstance(es, Map): - for it in es.iterations: - if it.axis in innermost_axes: - innermost_axes.remove(it.axis) - if it.kind == "tiling": - tiled_axes.remove(it.axis) - - -def _populate_cpu_schedules(expansion_specification: List[ExpansionItem]): - is_outermost = True - for es in expansion_specification: - if isinstance(es, Map): - if es.schedule is None: - if is_outermost: - es.schedule = dace.ScheduleType.CPU_Multicore - is_outermost = False - else: - es.schedule = dace.ScheduleType.Default - - -def 
_populate_gpu_schedules(expansion_specification: List[ExpansionItem]): - # On GPU if any dimension is tiled and has a contiguous map in the same axis further in - # pick those two maps as Device/ThreadBlock maps. If not, Make just device map with - # default blocksizes - is_outermost = True - tiled = False - for i, item in enumerate(expansion_specification): - if isinstance(item, Map): - for it in item.iterations: - if not tiled and it.kind == "tiling": - for inner_item in expansion_specification[i + 1 :]: - if isinstance(inner_item, Map) and any( - inner_it.kind == "contiguous" and inner_it.axis == it.axis - for inner_it in inner_item.iterations - ): - item.schedule = dace.ScheduleType.GPU_Device - inner_item.schedule = dace.ScheduleType.GPU_ThreadBlock - tiled = True - break - if not tiled: - assert any(isinstance(item, Map) for item in expansion_specification), ( - "needs at least one map to avoid dereferencing on CPU" - ) - for es in expansion_specification: - if isinstance(es, Map): - if es.schedule is None: - if is_outermost: - es.schedule = dace.ScheduleType.GPU_Device - is_outermost = False - else: - es.schedule = dace.ScheduleType.Default - - -def _populate_schedules(node: StencilComputation, expansion_specification: List[ExpansionItem]): - assert all(isinstance(es, ExpansionItem) for es in expansion_specification) - assert hasattr(node, "_device") - if node.device == dace.DeviceType.GPU: - _populate_gpu_schedules(expansion_specification) - else: - _populate_cpu_schedules(expansion_specification) - - -def _collapse_maps_gpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: - def _union_map_items(last_item, next_item): - if last_item.schedule == next_item.schedule: - return ( - Map( - iterations=last_item.iterations + next_item.iterations, - schedule=last_item.schedule, - ), - ) - - if next_item.schedule is None or next_item.schedule == dace.ScheduleType.Default: - specified_item = last_item - else: - specified_item = next_item - - if 
specified_item.schedule is not None and not specified_item == dace.ScheduleType.Default: - return ( - Map( - iterations=last_item.iterations + next_item.iterations, - schedule=specified_item.schedule, - ), - ) - - # one is default and the other None - return ( - Map( - iterations=last_item.iterations + next_item.iterations, - schedule=dace.ScheduleType.Default, - ), - ) - - res_items: List[ExpansionItem] = [] - for item in expansion_specification: - if isinstance(item, Map): - if not res_items or not isinstance(res_items[-1], Map): - res_items.append(item) - else: - res_items[-1:] = _union_map_items(last_item=res_items[-1], next_item=item) - else: - res_items.append(item) - for item in res_items: - if isinstance(item, Map) and ( - item.schedule is None or item.schedule == dace.ScheduleType.Default - ): - item.schedule = dace.ScheduleType.Sequential - return res_items - - -def _collapse_maps_cpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: - res_items: List[ExpansionItem] = [] - for item in expansion_specification: - if isinstance(item, Map): - if ( - not res_items - or not isinstance(res_items[-1], Map) - or any( - it.axis in set(outer_it.axis for outer_it in res_items[-1].iterations) - for it in item.iterations - ) - ): - res_items.append(item) - elif item.schedule == res_items[-1].schedule: - res_items[-1].iterations.extend(item.iterations) - elif item.schedule is None or item.schedule == dace.ScheduleType.Default: - if res_items[-1].schedule == dace.ScheduleType.CPU_Multicore: - res_items[-1].iterations.extend(item.iterations) - else: - res_items.append(item) - elif ( - res_items[-1].schedule is None - or res_items[-1].schedule == dace.ScheduleType.Default - ): - if item.schedule == dace.ScheduleType.CPU_Multicore: - res_items[-1].iterations.extend(item.iterations) - res_items[-1].schedule = dace.ScheduleType.CPU_Multicore - else: - res_items.append(item) - else: - res_items.append(item) - else: - res_items.append(item) - return 
res_items - - -def _collapse_maps(node: StencilComputation, expansion_specification: List[ExpansionItem]): - assert hasattr(node, "_device") - if node.device == dace.DeviceType.GPU: - res_items = _collapse_maps_gpu(expansion_specification) - else: - res_items = _collapse_maps_cpu(expansion_specification) - expansion_specification.clear() - expansion_specification.extend(res_items) - - -def make_expansion_order( - node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]] -) -> List[ExpansionItem]: - if expansion_order is None: - return None - expansion_order = copy.deepcopy(expansion_order) - expansion_specification = _order_as_spec(node, expansion_order) - - if not _is_expansion_order_implemented(expansion_specification): - raise ValueError("Provided StencilComputation.expansion_order is not supported.") - if node.oir_node is not None: - if not is_expansion_order_valid(node, expansion_specification): - raise ValueError("Provided StencilComputation.expansion_order is invalid.") - - _populate_strides(node, expansion_specification) - _populate_schedules(node, expansion_specification) - _collapse_maps(node, expansion_specification) - _populate_storages(expansion_specification) - return expansion_specification - - -def _k_inside_dims(node: StencilComputation): - # Putting K inside of i or j is valid if - # * K parallel or - # * All reads with k-offset to values modified in same HorizontalExecution are not - # to fields that are also accessed horizontally (in I or J, respectively) - # (else, race condition in other column) - - if node.oir_node.loop_order == common.LoopOrder.PARALLEL: - return {dcir.Axis.I, dcir.Axis.J} - - res = {dcir.Axis.I, dcir.Axis.J} - for section in node.oir_node.sections: - for he in section.horizontal_executions: - i_offset_fields = set( - ( - acc.name - for acc in he.walk_values().if_isinstance(oir.FieldAccess) - if acc.offset.to_dict()["i"] != 0 - ) - ) - j_offset_fields = set( - ( - acc.name - for acc in 
he.walk_values().if_isinstance(oir.FieldAccess) - if acc.offset.to_dict()["j"] != 0 - ) - ) - k_offset_fields = set( - ( - acc.name - for acc in he.walk_values().if_isinstance(oir.FieldAccess) - if isinstance(acc.offset, oir.VariableKOffset) or acc.offset.to_dict()["k"] != 0 - ) - ) - modified_fields: Set[str] = ( - he.walk_values() - .if_isinstance(oir.AssignStmt) - .getattr("left") - .if_isinstance(oir.FieldAccess) - .getattr("name") - .to_set() - ) - for name in modified_fields: - if name in k_offset_fields and name in i_offset_fields: - res.remove(dcir.Axis.I) - if name in k_offset_fields and name in j_offset_fields: - res.remove(dcir.Axis.J) - return res - - -def _k_inside_stages(node: StencilComputation): - # Putting K inside of stages is valid if - # * K parallel - # * not "ahead" in order of iteration to fields that are modified in previous - # HorizontalExecutions (else, reading updated values that should be old) - - if node.oir_node.loop_order == common.LoopOrder.PARALLEL: - return True - - for section in node.oir_node.sections: - modified_fields: Set[str] = set() - for he in section.horizontal_executions: - if modified_fields: - ahead_acc = list() - for acc in he.walk_values().if_isinstance(oir.FieldAccess): - if ( - isinstance(acc.offset, oir.VariableKOffset) - or ( - node.oir_node.loop_order == common.LoopOrder.FORWARD - and acc.offset.k > 0 - ) - or ( - node.oir_node.loop_order == common.LoopOrder.BACKWARD - and acc.offset.k < 0 - ) - ): - ahead_acc.append(acc) - if any(acc.name in modified_fields for acc in ahead_acc): - return False - - modified_fields.update( - he.walk_values() - .if_isinstance(oir.AssignStmt) - .getattr("left") - .if_isinstance(oir.FieldAccess) - .getattr("name") - .to_set() - ) - - return True - - -@_register_validity_check -def _sequential_as_loops( - node: StencilComputation, expansion_specification: List[ExpansionItem] -) -> bool: - # K can't be Map if not parallel - if node.oir_node.loop_order != common.LoopOrder.PARALLEL and 
any( - (isinstance(item, Map) and any(it.axis == dcir.Axis.K for it in item.iterations)) - for item in expansion_specification - ): - return False - return True - - -@_register_validity_check -def _stages_inside_sections(expansion_specification: List[ExpansionItem], **kwargs) -> bool: - # Oir defines that HorizontalExecutions have to be applied per VerticalLoopSection. A meaningful inversion of this - # is not possible. - sections_idx = next( - idx for idx, item in enumerate(expansion_specification) if isinstance(item, Sections) - ) - stages_idx = next( - idx for idx, item in enumerate(expansion_specification) if isinstance(item, Stages) - ) - if stages_idx < sections_idx: - return False - return True - - -@_register_validity_check -def _k_inside_ij_valid( - node: StencilComputation, expansion_specification: List[ExpansionItem] -) -> bool: - # OIR defines that horizontal maps go inside vertical K loop (i.e. all grid points are updated in a - # HorizontalExecution before the computation of the next one is executed.). Under certain conditions the semantics - # remain unchanged even if a single horizontal map is executing all contained HorizontalExecution nodes. - # Note: Opposed to e.g. Fusions in OIR, this can here be done on a per-dimension basis. See `_k_inside_dims` for - # details. - for axis in dcir.Axis.dims_horizontal(): - if get_expansion_order_index(expansion_specification, axis) < get_expansion_order_index( - expansion_specification, dcir.Axis.K - ) and axis not in _k_inside_dims(node): - return False - return True - - -@_register_validity_check -def _k_inside_stages_valid( - node: StencilComputation, expansion_specification: List[ExpansionItem] -) -> bool: - # OIR defines that all horizontal executions of a VerticalLoopSection are run per level. Under certain conditions - # the semantics remain unchanged even if the k loop is run per horizontal execution. 
See `_k_inside_stages` for - # details - stages_idx = next( - idx for idx, item in enumerate(expansion_specification) if isinstance(item, Stages) - ) - if stages_idx < get_expansion_order_index( - expansion_specification, dcir.Axis.K - ) and not _k_inside_stages(node): - return False - return True - - -@_register_validity_check -def _ij_outside_sections_valid( - node: StencilComputation, expansion_specification: List[ExpansionItem] -) -> bool: - # If there are multiple horizontal executions in any section, IJ iteration must go inside sections. - # TODO: do mergeability checks on a per-axis basis. - for item in expansion_specification: - if isinstance(item, Sections): - break - if isinstance(item, (Map, Loop, Iteration)): - for it in item.iterations: - if it.axis in dcir.Axis.dims_horizontal() and it.kind == "contiguous": - if any( - len(section.horizontal_executions) > 1 for section in node.oir_node.sections - ): - return False - - # if there are horizontal executions with different iteration ranges in an axis across sections, - # that iteration must be per section - # TODO less conservative: allow different domains if all outputs smaller than bounding box are temporaries - # TODO implement/allow this with predicates implicit regions / predicates - for item in expansion_specification: - if isinstance(item, Sections): - break - for it in getattr(item, "iterations", []): - if it.axis in dcir.Axis.dims_horizontal() and it.kind == "contiguous": - xiter = iter(node.oir_node.walk_values().if_isinstance(oir.HorizontalExecution)) - extent = node.get_extents(next(xiter)) - for he in xiter: - if node.get_extents(he)[it.axis.to_idx()] != extent[it.axis.to_idx()]: - return False - return True - - -@_register_validity_check -def _iterates_domain(expansion_specification: List[ExpansionItem], **kwargs) -> bool: - # There must be exactly one iteration per dimension, except for tiled dimensions, where a Tiling has to go outside - # and the corresponding contiguous iteration inside. 
- tiled_axes = set() - contiguous_axes = set() - for item in expansion_specification: - if isinstance(item, (Map, Loop, Iteration)): - for it in item.iterations: - if it.kind == "tiling": - if it.axis in tiled_axes or it.axis in contiguous_axes: - return False - tiled_axes.add(it.axis) - else: - if it.axis in contiguous_axes: - return False - contiguous_axes.add(it.axis) - if not all(axis in contiguous_axes for axis in dcir.Axis.dims_3d()): - return False - return True - - -def is_expansion_order_valid(node: StencilComputation, expansion_order) -> bool: - """Check if a given expansion specification valid. - - That is, it is semantically valid for the StencilComputation node that is to be configured and currently - implemented. - """ - expansion_specification = list(_choose_loop_or_map(node, eo) for eo in expansion_order) - - for check in _EXPANSION_VALIDITY_CHECKS: - if not check(node=node, expansion_specification=expansion_specification): - return False - - return _is_expansion_order_implemented(expansion_specification) diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py deleted file mode 100644 index a21ee20dcd..0000000000 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ /dev/null @@ -1,223 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import base64 -import pickle -import typing -from typing import Dict, Final, List, Optional, Set, Union - -import dace.data -import dace.dtypes -import dace.properties -import dace.subsets -import numpy as np -from dace import library - -from gt4py.cartesian.gtc import common, oir -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion -from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter -from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order -from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo -from gt4py.cartesian.gtc.definitions import Extent -from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection - - -def _set_expansion_order( - node: StencilComputation, expansion_order: Union[List[ExpansionItem], List[str]] -): - res = make_expansion_order(node, expansion_order) - node._expansion_specification = res - - -def _set_tile_sizes_interpretation(node: StencilComputation, tile_sizes_interpretation: str): - valid_values = {"shape", "strides"} - if tile_sizes_interpretation not in valid_values: - raise ValueError(f"tile_sizes_interpretation must be one in {valid_values}.") - node._tile_sizes_interpretation = tile_sizes_interpretation - - -class PickledProperty: - def to_json(self, obj): - protocol = pickle.DEFAULT_PROTOCOL - pbytes = pickle.dumps(obj, protocol=protocol) - jsonobj = dict(pickle=base64.b64encode(pbytes).decode("utf-8")) - return jsonobj - - @classmethod - def from_json(cls, d, sdfg=None): - # DaCe won't serialize attr with default values by default - # which would lead the deserializer to push a default in the - # wrong format (non pickle). 
- # Best mitigation is to give back the object plain if it does - # not contain any pickling information - if isinstance(d, dict) and "pickle" in d.keys(): - b64string = d["pickle"] - byte_repr = base64.b64decode(b64string) - return pickle.loads(byte_repr) - - return d - - -class PickledDataclassProperty(PickledProperty, dace.properties.DataclassProperty): - pass - - -class PickledListProperty(PickledProperty, dace.properties.ListProperty): - pass - - -class PickledDictProperty(PickledProperty, dace.properties.DictProperty): - pass - - -@library.node -class StencilComputation(library.LibraryNode): - implementations: Final[Dict[str, dace.library.ExpandTransformation]] = { - "default": StencilComputationExpansion - } - default_implementation = "default" - - oir_node = PickledDataclassProperty(dtype=VerticalLoop, allow_none=True) - - declarations = PickledDictProperty(key_type=str, value_type=Decl, allow_none=True) - extents = PickledDictProperty(key_type=int, value_type=Extent, allow_none=False) - access_infos = PickledDictProperty( - key_type=str, value_type=dcir.FieldAccessInfo, allow_none=True - ) - - device = dace.properties.EnumProperty( - dtype=dace.DeviceType, default=dace.DeviceType.CPU, allow_none=True - ) - expansion_specification = PickledListProperty( - element_type=ExpansionItem, allow_none=True, setter=_set_expansion_order - ) - tile_sizes = PickledDictProperty( - key_type=dcir.Axis, value_type=int, default={dcir.Axis.I: 8, dcir.Axis.J: 8, dcir.Axis.K: 8} - ) - - tile_sizes_interpretation = dace.properties.Property( - setter=_set_tile_sizes_interpretation, dtype=str, default="strides" - ) - - symbol_mapping = dace.properties.DictProperty( - key_type=str, value_type=dace.symbolic.pystr_to_symbolic, default=None, allow_none=True - ) - _dace_library_name = "StencilComputation" - - def __init__( - self, - name="unnamed_vloop", - oir_node: Optional[VerticalLoop] = None, - extents: Optional[Dict[int, Extent]] = None, - declarations: Optional[Dict[str, Decl]] 
= None, - expansion_order=None, - *args, - **kwargs, - ): - super().__init__(*args, name=name, **kwargs) - - from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos - - if oir_node is not None: - assert extents is not None - assert declarations is not None - extents_dict = dict() - for i, section in enumerate(oir_node.sections): - for j, he in enumerate(section.horizontal_executions): - extents_dict[j * len(oir_node.sections) + i] = extents[id(he)] - - self.oir_node = typing.cast(PickledDataclassProperty, oir_node) - self.extents = extents_dict # type: ignore - self.declarations = declarations # type: ignore - self.symbol_mapping = { - decl.name: dace.symbol( - decl.name, - dtype=dace.typeclass(np.dtype(common.data_type_to_typestr(decl.dtype)).type), - ) - for decl in declarations.values() - if isinstance(decl, oir.ScalarDecl) - } - self.symbol_mapping.update( - { - axis.domain_symbol(): dace.symbol(axis.domain_symbol(), dtype=dace.int32) - for axis in dcir.Axis.dims_horizontal() - } - ) - self.access_infos = compute_dcir_access_infos( - oir_node, - oir_decls=declarations, - block_extents=self.get_extents, - collect_read=True, - collect_write=True, - ) - if any( - interval.start.level == common.LevelMarker.END - or interval.end.level == common.LevelMarker.END - for interval in oir_node.walk_values() - .if_isinstance(VerticalLoopSection) - .getattr("interval") - ) or any( - decl.dimensions[dcir.Axis.K.to_idx()] - for decl in self.declarations.values() - if isinstance(decl, oir.FieldDecl) - ): - self.symbol_mapping[dcir.Axis.K.domain_symbol()] = dace.symbol( - dcir.Axis.K.domain_symbol(), dtype=dace.int32 - ) - - self.debuginfo = get_dace_debuginfo(oir_node) - - if expansion_order is None: - expansion_order = [ - "TileI", - "TileJ", - "Sections", - "K", # Expands to either Loop or Map - "Stages", - "I", - "J", - ] - _set_expansion_order(self, expansion_order) - - def get_extents(self, he): - for i, section in enumerate(self.oir_node.sections): - for j, 
cand_he in enumerate(section.horizontal_executions): - if he is cand_he: - return self.extents[j * len(self.oir_node.sections) + i] - - @property - def field_decls(self) -> Dict[str, FieldDecl]: - return { - name: decl for name, decl in self.declarations.items() if isinstance(decl, FieldDecl) - } - - @property - def free_symbols(self) -> Set[str]: - result: Set[str] = set() - for v in self.symbol_mapping.values(): - result.update(map(str, v.free_symbols)) - return result - - def has_splittable_regions(self): - for he in self.oir_node.walk_values().if_isinstance(oir.HorizontalExecution): - if not HorizontalExecutionSplitter.is_horizontal_execution_splittable(he): - return False - return True - - @property - def tile_strides(self): - if self.tile_sizes_interpretation == "strides": - return self.tile_sizes - - overall_extent: Extent = next(iter(self.extents.values())) - for extent in self.extents.values(): - overall_extent |= extent - return {key: value + overall_extent[key.to_idx()] for key, value in self.tile_sizes.items()} diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py deleted file mode 100644 index ea0a9b1c18..0000000000 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ /dev/null @@ -1,182 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Dict - -import dace - -from gt4py import eve -from gt4py.cartesian.gtc import oir -from gt4py.cartesian.gtc.dace import daceir as dcir, prefix -from gt4py.cartesian.gtc.dace.nodes import StencilComputation -from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import ( - compute_dcir_access_infos, - get_dace_debuginfo, - make_dace_subset, -) -from gt4py.cartesian.gtc.definitions import Extent -from gt4py.cartesian.gtc.passes.oir_optimizations.utils import ( - AccessCollector, - compute_horizontal_block_extents, -) - - -class OirSDFGBuilder(eve.NodeVisitor): - @dataclass - class SDFGContext: - sdfg: dace.SDFG - current_state: dace.SDFGState - decls: Dict[str, oir.Decl] - block_extents: Dict[int, Extent] - access_infos: Dict[str, dcir.FieldAccessInfo] - loop_counter: int = 0 - - def __init__(self, stencil: oir.Stencil): - self.sdfg = dace.SDFG(stencil.name) - self.current_state = self.sdfg.add_state(is_start_block=True) - self.decls = {decl.name: decl for decl in stencil.params + stencil.declarations} - self.block_extents = compute_horizontal_block_extents(stencil) - - self.access_infos = compute_dcir_access_infos( - stencil, - oir_decls=self.decls, - block_extents=lambda he: self.block_extents[id(he)], - collect_read=True, - collect_write=True, - include_full_domain=True, - ) - - def make_shape(self, field): - if field not in self.access_infos: - return [ - axis.domain_dace_symbol() - for axis in dcir.Axis.dims_3d() - if self.decls[field].dimensions[axis.to_idx()] - ] + [d for d in self.decls[field].data_dims] - return self.access_infos[field].shape + self.decls[field].data_dims - - def make_input_dace_subset(self, node, field): - local_access_info = compute_dcir_access_infos( - node, - collect_read=True, - collect_write=False, - block_extents=lambda he: 
self.block_extents[id(he)], - oir_decls=self.decls, - )[field] - for axis in local_access_info.variable_offset_axes: - local_access_info = local_access_info.clamp_full_axis(axis) - - return self._make_dace_subset(local_access_info, field) - - def make_output_dace_subset(self, node, field): - local_access_info = compute_dcir_access_infos( - node, - collect_read=False, - collect_write=True, - block_extents=lambda he: self.block_extents[id(he)], - oir_decls=self.decls, - )[field] - for axis in local_access_info.variable_offset_axes: - local_access_info = local_access_info.clamp_full_axis(axis) - - return self._make_dace_subset(local_access_info, field) - - def _make_dace_subset(self, local_access_info, field): - global_access_info = self.access_infos[field] - return make_dace_subset( - global_access_info, local_access_info, self.decls[field].data_dims - ) - - def _vloop_name(self, node: oir.VerticalLoop, ctx: OirSDFGBuilder.SDFGContext) -> str: - sdfg_name = ctx.sdfg.name - counter = ctx.loop_counter - ctx.loop_counter += 1 - - return f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}" - - def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext): - declarations = { - acc.name: ctx.decls[acc.name] - for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess) - if acc.name in ctx.decls - } - library_node = StencilComputation( - name=self._vloop_name(node, ctx), - extents=ctx.block_extents, - declarations=declarations, - oir_node=node, - ) - - state = ctx.sdfg.add_state_after(ctx.current_state) - ctx.current_state = state - state.add_node(library_node) - - access_collection = AccessCollector.apply(node) - - for field in access_collection.read_fields(): - access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) - connector_name = f"{prefix.CONNECTOR_IN}{field}" - library_node.add_in_connector(connector_name) - subset = ctx.make_input_dace_subset(node, field) - state.add_edge( - 
access_node, None, library_node, connector_name, dace.Memlet(field, subset=subset) - ) - - for field in access_collection.write_fields(): - access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) - connector_name = f"{prefix.CONNECTOR_OUT}{field}" - library_node.add_out_connector(connector_name) - subset = ctx.make_output_dace_subset(node, field) - state.add_edge( - library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset) - ) - - def visit_Stencil(self, node: oir.Stencil): - ctx = OirSDFGBuilder.SDFGContext(stencil=node) - for param in node.params: - if isinstance(param, oir.FieldDecl): - dim_strs = [d for i, d in enumerate("IJK") if param.dimensions[i]] + [ - f"d{d}" for d in range(len(param.data_dims)) - ] - ctx.sdfg.add_array( - param.name, - shape=ctx.make_shape(param.name), - strides=[ - dace.symbolic.pystr_to_symbolic(f"__{param.name}_{dim}_stride") - for dim in dim_strs - ], - dtype=data_type_to_dace_typeclass(param.dtype), - transient=False, - debuginfo=get_dace_debuginfo(param), - ) - else: - ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype)) - - for decl in node.declarations: - dim_strs = [d for i, d in enumerate("IJK") if decl.dimensions[i]] + [ - f"d{d}" for d in range(len(decl.data_dims)) - ] - ctx.sdfg.add_array( - decl.name, - shape=ctx.make_shape(decl.name), - strides=[ - dace.symbolic.pystr_to_symbolic(f"__{decl.name}_{dim}_stride") - for dim in dim_strs - ], - dtype=data_type_to_dace_typeclass(decl.dtype), - transient=True, - lifetime=dace.AllocationLifetime.Persistent, - debuginfo=get_dace_debuginfo(decl), - ) - self.visit(node.vertical_loops, ctx=ctx) - ctx.sdfg.validate() - return ctx.sdfg diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py new file mode 100644 index 0000000000..8bfb979757 --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py @@ -0,0 +1,372 @@ +# GT4Py - GridTools 
Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import operator +from dataclasses import dataclass +from functools import reduce +from typing import Any, Final + +from dace import Memlet, nodes, subsets + +from gt4py import eve +from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc.dace import treeir as tir, utils + + +# Tasklet in/out connector prefixes +TASKLET_IN: Final[str] = "gtIN_" +TASKLET_OUT: Final[str] = "gtOUT_" + + +@dataclass +class Context: + code: list[str] + """Tasklet code, line by line.""" + + targets: set[str] + """Names of fields / scalars that we've already written to. Used for read-after-write analysis.""" + + inputs: dict[str, Memlet] + """Mapping connector names to memlets flowing into the Tasklet.""" + + outputs: dict[str, Memlet] + """Mapping connector names to memlets flowing out of the Tasklet.""" + + tree: tir.TreeRoot + """Schedule tree in which this Tasklet will be inserted.""" + + +class OIRToTasklet(eve.NodeVisitor): + """ + Translate the numerical code from OIR to DaCe Tasklets. + + This visitor should neither attempt transformations nor do any control flow + work. Control flow is the responsibility of OIRToTreeIR. 
+ """ + + def visit_CodeBlock( + self, node: oir.CodeBlock, root: tir.TreeRoot + ) -> tuple[nodes.Tasklet, dict[str, Memlet], dict[str, Memlet]]: + """Entry point to gather all code, inputs and outputs.""" + ctx = Context(code=[], targets=set(), inputs={}, outputs={}, tree=root) + + self.visit(node.body, ctx=ctx) + + tasklet = nodes.Tasklet( + label=node.label, + code="\n".join(ctx.code), + inputs=ctx.inputs.keys(), + outputs=ctx.outputs.keys(), + ) + + return (tasklet, ctx.inputs, ctx.outputs) + + def visit_ScalarAccess(self, node: oir.ScalarAccess, ctx: Context, is_target: bool) -> str: + target = is_target or node.name in ctx.targets + tasklet_name = _tasklet_name(node, target) + + if ( + node.name in ctx.targets # (read or write) after write + or tasklet_name in ctx.inputs # read after read + ): + return tasklet_name + + memlet = Memlet(data=node.name, subset=subsets.Range([(0, 0, 1)])) + if is_target: + # Note: it doesn't matter if we use is_target or target here because if they + # were different, we had a read-after-write situation, which was already + # handled above. 
+ ctx.targets.add(node.name) + ctx.outputs[tasklet_name] = memlet + else: + ctx.inputs[tasklet_name] = memlet + + return tasklet_name + + def visit_FieldAccess(self, node: oir.FieldAccess, ctx: Context, is_target: bool) -> str: + # Derive tasklet name of this access + postfix = _field_offset_postfix(node) + key = f"{node.name}_{postfix}" + target = is_target or key in ctx.targets + tasklet_name = _tasklet_name(node, target, postfix) + + # Gather all parts of the variable name in this list + name_parts = [tasklet_name] + + # Variable K offset subscript + if isinstance(node.offset, oir.VariableKOffset): + symbol = tir.Axis.K.iteration_dace_symbol() + shift = ctx.tree.shift[node.name][tir.Axis.K] + offset = self.visit(node.offset.k, ctx=ctx, is_target=False) + name_parts.append(f"[({symbol}) + ({shift}) + ({offset})]") + + # Data dimension subscript + data_indices: list[str] = [] + for index in node.data_index: + data_indices.append(self.visit(index, ctx=ctx, is_target=False)) + + if data_indices: + name_parts.append(f"[{', '.join(data_indices)}]") + + # In case this is the second access (inside the same tasklet), we can just return the + # name and don't have to build a Memlet anymore. + if ( + key in ctx.targets # (read or write) after write + or tasklet_name in ctx.inputs # read after read + ): + return "".join(filter(None, name_parts)) + + # Build Memlet and add it to inputs/outputs + data_domains: list[int] = ( + ctx.tree.containers[node.name].shape[-len(node.data_index) :] if node.data_index else [] + ) + memlet = Memlet( + data=node.name, + subset=_memlet_subset(node, data_domains, ctx), + volume=reduce(operator.mul, data_domains, 1), # correct volume for VariableK offsets + ) + if is_target: + # Note: it doesn't matter if we use is_target or target here because if they + # were different, we had a read-after-write situation, which was already + # handled above. 
+ ctx.targets.add(key) + ctx.outputs[tasklet_name] = memlet + else: + ctx.inputs[tasklet_name] = memlet + + return "".join(filter(None, name_parts)) + + def visit_AssignStmt(self, node: oir.AssignStmt, ctx: Context) -> None: + # Order matters: always evaluate the right side of an assignment first + right = self.visit(node.right, ctx=ctx, is_target=False) + left = self.visit(node.left, ctx=ctx, is_target=True) + + ctx.code.append(f"{left} = {right}") + + def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> str: + condition = self.visit(node.cond, **kwargs) + if_code = self.visit(node.true_expr, **kwargs) + else_code = self.visit(node.false_expr, **kwargs) + + return f"({if_code} if {condition} else {else_code})" + + def visit_BinaryOp(self, node: oir.BinaryOp, **kwargs: Any) -> str: + left = self.visit(node.left, **kwargs) + right = self.visit(node.right, **kwargs) + + return f"({left} {node.op.value} {right})" + + def visit_UnaryOp(self, node: oir.UnaryOp, **kwargs: Any) -> str: + expr = self.visit(node.expr, **kwargs) + + return f"{node.op.value}({expr})" + + def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> str: + dtype = utils.data_type_to_dace_typeclass(node.dtype) + expression = self.visit(node.expr, **kwargs) + + return f"{dtype}({expression})" + + def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> str: + if type(node.value) is str: + # Note: isinstance(node.value, str) also matches the string enum `BuiltInLiteral` + # which we don't want to match because it returns lower-case `true`, which isn't + # defined in (python) tasklet code. 
+ return node.value + + return self.visit(node.value, **kwargs) + + def visit_BuiltInLiteral(self, node: common.BuiltInLiteral, **_kwargs: Any) -> str: + if node == common.BuiltInLiteral.TRUE: + return "True" + + if node == common.BuiltInLiteral.FALSE: + return "False" + + raise NotImplementedError(f"BuiltInLiteral '{node}' not (yet) implemented.") + + def visit_NativeFunction(self, node: common.NativeFunction, **_kwargs: Any) -> str: + native_functions = { + common.NativeFunction.ABS: "abs", + common.NativeFunction.MIN: "min", + common.NativeFunction.MAX: "max", + common.NativeFunction.MOD: "fmod", + common.NativeFunction.SIN: "dace.math.sin", + common.NativeFunction.COS: "dace.math.cos", + common.NativeFunction.TAN: "dace.math.tan", + common.NativeFunction.ARCSIN: "asin", + common.NativeFunction.ARCCOS: "acos", + common.NativeFunction.ARCTAN: "atan", + common.NativeFunction.SINH: "dace.math.sinh", + common.NativeFunction.COSH: "dace.math.cosh", + common.NativeFunction.TANH: "dace.math.tanh", + common.NativeFunction.ARCSINH: "asinh", + common.NativeFunction.ARCCOSH: "acosh", + common.NativeFunction.ARCTANH: "atanh", + common.NativeFunction.SQRT: "dace.math.sqrt", + common.NativeFunction.POW: "dace.math.pow", + common.NativeFunction.EXP: "dace.math.exp", + common.NativeFunction.LOG: "dace.math.log", + common.NativeFunction.LOG10: "log10", + common.NativeFunction.GAMMA: "tgamma", + common.NativeFunction.CBRT: "cbrt", + common.NativeFunction.ISFINITE: "isfinite", + common.NativeFunction.ISINF: "isinf", + common.NativeFunction.ISNAN: "isnan", + common.NativeFunction.FLOOR: "dace.math.ifloor", + common.NativeFunction.CEIL: "ceil", + common.NativeFunction.TRUNC: "trunc", + } + if node not in native_functions: + raise NotImplementedError(f"NativeFunction '{node}' not (yet) implemented.") + + return native_functions[node] + + def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> str: + function_name = self.visit(node.func, **kwargs) + arguments = 
",".join([self.visit(a, **kwargs) for a in node.args]) + + return f"{function_name}({arguments})" + + # Not (yet) supported section + def visit_CacheDesc(self, node: oir.CacheDesc, **kwargs: Any) -> None: + raise NotImplementedError("To be implemented: Caches") + + def visit_IJCache(self, node: oir.IJCache, **kwargs: Any) -> None: + raise NotImplementedError("To be implemented: Caches") + + def visit_KCache(self, node: oir.KCache, **kwargs: Any) -> None: + raise NotImplementedError("To be implemented: Caches") + + # Should _not_ be called + def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs: Any) -> None: + raise RuntimeError("Cartesian Offset should be dealt in Access IRs.") + + def visit_VariableKOffset(self, node: oir.VariableKOffset, **kwargs: Any) -> None: + raise RuntimeError("Variable K Offset should be dealt in Access IRs.") + + def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> None: + raise RuntimeError("visit_MaskStmt should not be called") + + def visit_While(self, node: oir.While, **kwargs: Any) -> None: + raise RuntimeError("visit_While should not be called") + + def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction, **kwargs: Any) -> None: + raise RuntimeError("visit_HorizontalRestriction: should be dealt in TreeIR") + + def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> None: + raise RuntimeError("visit_LocalScalar should not be called") + + def visit_Temporary(self, node: oir.Temporary, **kwargs: Any) -> None: + raise RuntimeError("visit_LocalScalar should not be called") + + def visit_Stencil(self, node: oir.Stencil, **kwargs: Any) -> None: + raise RuntimeError("visit_Stencil should not be called") + + def visit_Decl(self, node: oir.Decl, **kwargs: Any) -> None: + raise RuntimeError("visit_Decl should not be called") + + def visit_FieldDecl(self, node: oir.FieldDecl, **kwargs: Any) -> None: + raise RuntimeError("visit_FieldDecl should not be called") + + def visit_ScalarDecl(self, 
node: oir.ScalarDecl, **kwargs: Any) -> None: + raise RuntimeError("visit_ScalarDecl should not be called") + + def visit_Interval(self, node: oir.Interval, **kwargs: Any) -> None: + raise RuntimeError("visit_Interval should not be called") + + def visit_UnboundedInterval(self, node: oir.UnboundedInterval, **kwargs: Any) -> None: + raise RuntimeError("visit_UnboundedInterval should not be called") + + def visit_HorizontalExecution(self, node: oir.HorizontalExecution, **kwargs: Any) -> None: + raise RuntimeError("visit_HorizontalExecution should not be called") + + def visit_VerticalLoop(self, node: oir.VerticalLoop, **kwargs: Any) -> None: + raise RuntimeError("visit_VerticalLoop should not be called") + + def visit_VerticalLoopSection(self, node: oir.VerticalLoopSection, **kwargs: Any) -> None: + raise RuntimeError("visit_VerticalLoopSection should not be called") + + +def _tasklet_name( + node: oir.FieldAccess | oir.ScalarAccess, is_target: bool, postfix: str = "" +) -> str: + name_prefix = TASKLET_OUT if is_target else TASKLET_IN + return "_".join(filter(None, [name_prefix, node.name, postfix])) + + +def _field_offset_postfix(node: oir.FieldAccess) -> str: + if isinstance(node.offset, oir.VariableKOffset): + return "var_k" + + offset_indicators = [ + f"{k}{'p' if v > 0 else 'm'}{abs(v)}" for k, v in node.offset.to_dict().items() if v != 0 + ] + return "_".join(offset_indicators) + + +def _memlet_subset(node: oir.FieldAccess, data_domains: list[int], ctx: Context) -> subsets.Subset: + if isinstance(node.offset, common.CartesianOffset): + return _memlet_subset_cartesian(node, data_domains, ctx) + + if isinstance(node.offset, oir.VariableKOffset): + return _memlet_subset_variable_offset(node, data_domains, ctx) + + raise NotImplementedError(f"_memlet_subset(): unknown offset type {type(node.offset)}") + + +def _memlet_subset_cartesian( + node: oir.FieldAccess, data_domains: list[int], ctx: Context +) -> subsets.Subset: + """ + Generates the memlet subset for a 
field access with a cartesian offset. + + Note that we pass data dimensions as a full array into the Tasklet. For cases with data dimensions + we thus need a Range subset. We could use the more narrow Indices subset for cases without data + dimensions. For the sake of simplicity, we choose to always return a Range. + """ + offset_dict = node.offset.to_dict() + dimensions = ctx.tree.dimensions[node.name] + shift = ctx.tree.shift[node.name] + + ranges: list[tuple[str | int, str | int, int]] = [] + # Handle cartesian indices + for index, axis in enumerate(tir.Axis.dims_3d()): + if dimensions[index]: + i = f"({axis.iteration_dace_symbol()}) + ({shift[axis]}) + ({offset_dict[axis.lower()]})" + ranges.append((i, i, 1)) + + # Append data dimensions + for domain_size in data_domains: + ranges.append((0, domain_size - 1, 1)) # ranges are inclusive + + return subsets.Range(ranges) + + +def _memlet_subset_variable_offset( + node: oir.FieldAccess, data_domains: list[int], ctx: Context +) -> subsets.Subset: + """ + Generates the memlet subset for a field access with a variable K offset. + + While we know that we are reading at one specific i/j/k access, the K-access point is only + determined at runtime. We thus pass the K-axis as array into the Tasklet. 
+ """ + # Handle cartesian indices + shift = ctx.tree.shift[node.name] + offset_dict = node.offset.to_dict() + i = f"({tir.Axis.I.iteration_symbol()}) + ({shift[tir.Axis.I]}) + ({offset_dict[tir.Axis.I.lower()]})" + j = f"({tir.Axis.J.iteration_symbol()}) + ({shift[tir.Axis.J]}) + ({offset_dict[tir.Axis.J.lower()]})" + K = f"({tir.Axis.K.domain_symbol()}) + ({shift[tir.Axis.K]}) - 1" # ranges are inclusive + ranges: list[tuple[str | int, str | int, int]] = [(i, i, 1), (j, j, 1), (0, K, 1)] + + # Append data dimensions + for domain_size in data_domains: + ranges.append((0, domain_size - 1, 1)) # ranges are inclusive + + return subsets.Range(ranges) diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py b/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py new file mode 100644 index 0000000000..e1cbbef5ce --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/oir_to_treeir.py @@ -0,0 +1,531 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, List, TypeAlias + +from dace import data, dtypes, symbolic + +from gt4py import eve +from gt4py.cartesian.gtc import common, definitions, oir +from gt4py.cartesian.gtc.dace import oir_to_tasklet, treeir as tir, utils +from gt4py.cartesian.gtc.passes.gtir_k_boundary import compute_k_boundary +from gt4py.cartesian.gtc.passes.oir_optimizations import utils as oir_utils +from gt4py.cartesian.stencil_builder import StencilBuilder + + +ControlFlow: TypeAlias = ( + oir.HorizontalExecution | oir.While | oir.MaskStmt | oir.HorizontalRestriction +) +"""All control flow OIR nodes""" + +DEFAULT_STORAGE_TYPE = { + dtypes.DeviceType.CPU: dtypes.StorageType.Default, + dtypes.DeviceType.GPU: dtypes.StorageType.GPU_Global, +} +"""Default dace residency types per device type.""" + +DEFAULT_MAP_SCHEDULE = { + dtypes.DeviceType.CPU: dtypes.ScheduleType.Default, + dtypes.DeviceType.GPU: dtypes.ScheduleType.GPU_Device, +} +"""Default kernel target per device type.""" + + +class OIRToTreeIR(eve.NodeVisitor): + """ + Translate the GT4Py OIR into a Dace-centric TreeIR. + + TreeIR is built to be a minimum representation of DaCe's Schedule + Tree. No transformation is done on TreeIR, though should be done + once the TreeIR has been properly turned into a Schedule Tree. + + This class _does not_ deal with Tasklet representation, it defers that + work to the OIRToTasklet visitor. 
+ """ + + def __init__(self, builder: StencilBuilder) -> None: + device_type_translate = { + "CPU": dtypes.DeviceType.CPU, + "GPU": dtypes.DeviceType.GPU, + } + + device_type = builder.backend.storage_info["device"] + if device_type.upper() not in device_type_translate: + raise ValueError(f"Unknown device type {device_type}.") + + self._device_type = device_type_translate[device_type.upper()] + self._api_signature = builder.gtir.api_signature + self._k_bounds = compute_k_boundary(builder.gtir) + + def visit_CodeBlock(self, node: oir.CodeBlock, ctx: tir.Context) -> None: + dace_tasklet, inputs, outputs = oir_to_tasklet.OIRToTasklet().visit_CodeBlock( + node, root=ctx.root + ) + + tasklet = tir.Tasklet( + tasklet=dace_tasklet, + inputs=inputs, + outputs=outputs, + parent=ctx.current_scope, + ) + ctx.current_scope.children.append(tasklet) + + def _group_statements(self, node: ControlFlow) -> list[oir.CodeBlock | ControlFlow]: + """ + Group the body of a control flow node into CodeBlocks and other ControlFlow. + + This function only groups statements. The job of visiting the grouped statements is + left to the caller.
+ """ + statements: List[ControlFlow | oir.CodeBlock | common.Stmt] = [] + groups: List[ControlFlow | oir.CodeBlock] = [] + + for statement in node.body: + if isinstance(statement, ControlFlow): + if statements != []: + groups.append( + oir.CodeBlock(label=f"he_{id(node)}_{len(groups)}", body=statements) + ) + groups.append(statement) + statements = [] + else: + statements.append(statement) + + if statements != []: + groups.append(oir.CodeBlock(label=f"he_{id(node)}_{len(groups)}", body=statements)) + + return groups + + def _insert_evaluation_tasklet( + self, node: oir.MaskStmt | oir.While, ctx: tir.Context + ) -> tuple[str, oir.AssignStmt]: + """Evaluate condition in a separate tasklet to avoid sympy problems down the line.""" + + prefix = "while" if isinstance(node, oir.While) else "if" + condition_name = f"{prefix}_condition_{id(node)}" + + ctx.root.containers[condition_name] = data.Scalar( + utils.data_type_to_dace_typeclass(common.DataType.BOOL), + transient=True, + storage=dtypes.StorageType.Register, + debuginfo=utils.get_dace_debuginfo(node), + ) + + assignment = oir.AssignStmt( + left=oir.ScalarAccess(name=condition_name), + right=node.cond if isinstance(node, oir.While) else node.mask, + ) + + code_block = oir.CodeBlock(label=f"masklet_{id(node)}", body=[assignment]) + self.visit(code_block, ctx=ctx) + + return (condition_name, assignment) + + def visit_HorizontalExecution(self, node: oir.HorizontalExecution, ctx: tir.Context) -> None: + block_extent = ctx.block_extents[id(node)] + + axis_start_i = f"{block_extent[0][0]}" + axis_start_j = f"{block_extent[1][0]}" + axis_end_i = f"({tir.Axis.I.domain_dace_symbol()}) + ({block_extent[0][1]})" + axis_end_j = f"({tir.Axis.J.domain_dace_symbol()}) + ({block_extent[1][1]})" + + loop = tir.HorizontalLoop( + bounds_i=tir.Bounds(start=axis_start_i, end=axis_end_i), + bounds_j=tir.Bounds(start=axis_start_j, end=axis_end_j), + schedule=DEFAULT_MAP_SCHEDULE[self._device_type], + children=[], + 
parent=ctx.current_scope, + ) + + with loop.scope(ctx): + # Push local scalars to the tree repository + for local_scalar in node.declarations: + ctx.root.containers[local_scalar.name] = data.Scalar( + dtype=utils.data_type_to_dace_typeclass(local_scalar.dtype), + transient=True, + storage=dtypes.StorageType.Register, + debuginfo=utils.get_dace_debuginfo(local_scalar), + ) + + groups = self._group_statements(node) + self.visit(groups, ctx=ctx) + + def visit_MaskStmt(self, node: oir.MaskStmt, ctx: tir.Context) -> None: + condition_name, _ = self._insert_evaluation_tasklet(node, ctx) + + if_else = tir.IfElse( + if_condition_code=condition_name, children=[], parent=ctx.current_scope + ) + + with if_else.scope(ctx): + groups = self._group_statements(node) + self.visit(groups, ctx=ctx) + + def visit_HorizontalRestriction( + self, node: oir.HorizontalRestriction, ctx: tir.Context + ) -> None: + """Translate `region` concept into If control flow in TreeIR.""" + condition_code = self.visit(node.mask, ctx=ctx) + if_else = tir.IfElse( + if_condition_code=condition_code, children=[], parent=ctx.current_scope + ) + + with if_else.scope(ctx): + groups = self._group_statements(node) + self.visit(groups, ctx=ctx) + + def visit_HorizontalMask(self, node: common.HorizontalMask, ctx: tir.Context) -> str: + loop_i = tir.Axis.I.iteration_symbol() + loop_j = tir.Axis.J.iteration_symbol() + + axis_start_i = "0" + axis_end_i = tir.Axis.I.domain_symbol() + axis_start_j = "0" + axis_end_j = tir.Axis.J.domain_symbol() + + conditions: list[str] = [] + if node.i.start is not None: + conditions.append( + f"{loop_i} >= {self.visit(node.i.start, axis_start=axis_start_i, axis_end=axis_end_i)}" + ) + if node.i.end is not None: + conditions.append( + f"{loop_i} < {self.visit(node.i.end, axis_start=axis_start_i, axis_end=axis_end_i)}" + ) + if node.j.start is not None: + conditions.append( + f"{loop_j} >= {self.visit(node.j.start, axis_start=axis_start_j, axis_end=axis_end_j)}" + ) + if node.j.end is 
 not None: + conditions.append( + f"{loop_j} < {self.visit(node.j.end, axis_start=axis_start_j, axis_end=axis_end_j)}" + ) + + return " and ".join(conditions) + + def visit_While(self, node: oir.While, ctx: tir.Context) -> None: + condition_name, assignment = self._insert_evaluation_tasklet(node, ctx) + + # Re-evaluate the condition as last step of the while loop + node.body.append(assignment) + + # Use the mask created for conditional check + while_ = tir.While( + condition_code=condition_name, + children=[], + parent=ctx.current_scope, + ) + + with while_.scope(ctx): + groups = self._group_statements(node) + self.visit(groups, ctx=ctx) + + def visit_AxisBound(self, node: oir.AxisBound, axis_start: str, axis_end: str) -> str: + if node.level == common.LevelMarker.START: + return f"({axis_start}) + ({node.offset})" + + return f"({axis_end}) + ({node.offset})" + + def visit_Interval( + self, node: oir.Interval, loop_order: common.LoopOrder, axis_start: str, axis_end: str + ) -> tir.Bounds: + start = self.visit(node.start, axis_start=axis_start, axis_end=axis_end) + end = self.visit(node.end, axis_start=axis_start, axis_end=axis_end) + + if loop_order == common.LoopOrder.BACKWARD: + return tir.Bounds(start=f"{end} - 1", end=start) + + return tir.Bounds(start=start, end=end) + + def _vertical_loop_schedule(self) -> dtypes.ScheduleType: + """ + Defines the vertical loop schedule. + + Current strategy is to + - keep the vertical loop on the host for both CPU and GPU targets + - and run it in parallel on CPU and sequentially on GPU.
+ """ + if self._device_type == dtypes.DeviceType.GPU: + return dtypes.ScheduleType.Sequential + + return DEFAULT_MAP_SCHEDULE[self._device_type] + + def visit_VerticalLoopSection( + self, node: oir.VerticalLoopSection, ctx: tir.Context, loop_order: common.LoopOrder + ) -> None: + bounds = self.visit( + node.interval, + loop_order=loop_order, + axis_start="0", + axis_end=tir.Axis.K.domain_dace_symbol(), + ) + + loop = tir.VerticalLoop( + loop_order=loop_order, + bounds_k=bounds, + schedule=self._vertical_loop_schedule(), + children=[], + parent=ctx.current_scope, + ) + + with loop.scope(ctx): + self.visit(node.horizontal_executions, ctx=ctx) + + def visit_VerticalLoop(self, node: oir.VerticalLoop, ctx: tir.Context) -> None: + if node.caches: + raise NotImplementedError("Caches are not supported in this prototype.") + + self.visit(node.sections, ctx=ctx, loop_order=node.loop_order) + + def visit_Stencil(self, node: oir.Stencil) -> tir.TreeRoot: + # setup the descriptor repository + containers: dict[str, data.Data] = {} + dimensions: dict[str, tuple[bool, bool, bool]] = {} + symbols: tir.SymbolDict = {} + shift: dict[str, dict[tir.Axis, int]] = {} # dict of field_name -> (dict of axis -> shift) + + # this is ij blocks = horizontal execution + field_extents, block_extents = oir_utils.compute_extents( + node, + centered_extent=True, + ) + # When determining the shape of the array, we have to look at the field extents at large. + # GT4Py tries to give a precise measure by looking at the horizontal restriction and reduce + # the extent to the only grid points inside the mask. DaCe requires the real size of the + # data, hence the call with ignore_horizontal_mask=True. 
+ field_without_mask_extents = oir_utils.compute_fields_extents( + node, + centered_extent=True, + ignore_horizontal_mask=True, + ) + + missing_api_parameters: list[str] = [p.name for p in self._api_signature] + for param in node.params: + missing_api_parameters.remove(param.name) + if isinstance(param, oir.ScalarDecl): + containers[param.name] = data.Scalar( + dtype=utils.data_type_to_dace_typeclass(param.dtype), + debuginfo=utils.get_dace_debuginfo(param), + ) + continue + + if isinstance(param, oir.FieldDecl): + field_extent = field_extents[param.name] + k_bound = self._k_bounds[param.name] + shift[param.name] = { + tir.Axis.I: -field_extent[0][0], + tir.Axis.J: -field_extent[1][0], + tir.Axis.K: max(k_bound[0], 0), + } + containers[param.name] = data.Array( + dtype=utils.data_type_to_dace_typeclass(param.dtype), + shape=get_dace_shape( + param, + field_without_mask_extents[param.name], + k_bound, + symbols, + ), + strides=get_dace_strides(param, symbols), + storage=DEFAULT_STORAGE_TYPE[self._device_type], + debuginfo=utils.get_dace_debuginfo(param), + ) + dimensions[param.name] = param.dimensions + continue + + raise ValueError(f"Unexpected parameter type {type(param)}.") + + for field in node.declarations: + field_extent = field_extents[field.name] + k_bound = self._k_bounds[field.name] + shift[field.name] = { + tir.Axis.I: -field_extent[0][0], + tir.Axis.J: -field_extent[1][0], + tir.Axis.K: max(k_bound[0], 0), + } + # TODO / Dev Note: Persistent memory is an overkill here - we should scope + # the temporary as close to the tasklets as we can, but any lifetime lower + # than persistent will yield issues with memory leaks. 
+ containers[field.name] = data.Array( + dtype=utils.data_type_to_dace_typeclass(field.dtype), + shape=get_dace_shape(field, field_extent, k_bound, symbols), + strides=get_dace_strides(field, symbols), + transient=True, + lifetime=dtypes.AllocationLifetime.Persistent, + storage=DEFAULT_STORAGE_TYPE[self._device_type], + debuginfo=utils.get_dace_debuginfo(field), + ) + dimensions[field.name] = field.dimensions + + tree = tir.TreeRoot( + name=node.name, + containers=containers, + dimensions=dimensions, + shift=shift, + symbols=symbols, + children=[], + parent=None, + ) + + ctx = tir.Context( + root=tree, + current_scope=tree, + field_extents=field_extents, + block_extents=block_extents, + ) + + self.visit(node.vertical_loops, ctx=ctx) + + return ctx.root + + # Visit expressions for condition code in ControlFlow + def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> str: + dtype = utils.data_type_to_dace_typeclass(node.dtype) + expression = self.visit(node.expr, **kwargs) + + return f"{dtype}({expression})" + + def visit_CartesianOffset( + self, node: common.CartesianOffset, field: oir.FieldAccess, ctx: tir.Context, **_kwargs: Any + ) -> str: + shift = ctx.root.shift[field.name] + indices: list[str] = [] + + offset_dict = node.to_dict() + for index, axis in enumerate(tir.Axis.dims_3d()): + if ctx.root.dimensions[field.name][index]: + shift_str = f" + {shift[axis]}" if shift[axis] != 0 else "" + indices.append( + f"{axis.iteration_symbol()}{shift_str} + {offset_dict[axis.lower()]}" + ) + + return ", ".join(indices) + + def visit_VariableKOffset( + self, node: oir.VariableKOffset, field: oir.FieldAccess, ctx: tir.Context, **kwargs: Any + ) -> str: + shift = ctx.root.shift[field.name] + i_shift = f" + {shift[tir.Axis.I]}" if shift[tir.Axis.I] != 0 else "" + j_shift = f" + {shift[tir.Axis.J]}" if shift[tir.Axis.J] != 0 else "" + k_shift = f" + {shift[tir.Axis.K]}" if shift[tir.Axis.K] != 0 else "" + + return ( + f"{tir.Axis.I.iteration_symbol()}{i_shift}, " + 
f"{tir.Axis.J.iteration_symbol()}{j_shift}, " + f"{tir.Axis.K.iteration_symbol()}{k_shift} + {self.visit(node.k, ctx=ctx, **kwargs)}" + ) + + def visit_ScalarAccess(self, node: oir.ScalarAccess, **kwargs: Any) -> str: + return f"{node.name}" + + def visit_FieldAccess(self, node: oir.FieldAccess, **kwargs: Any) -> str: + if node.data_index: + raise NotImplementedError("Data dimensions aren't supported yet.") + + if "field" in kwargs: + kwargs.pop("field") + + field_name = node.name + offsets = self.visit(node.offset, field=node, **kwargs) + + return f"{field_name}[{offsets}]" + + def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> str: + if type(node.value) is str: + # Note: isinstance(node.value, str) also matches the string enum `BuiltInLiteral` + # which we don't want to match because it returns lower-case `true`, which isn't + # defined in (python) tasklet code. + return node.value + + return self.visit(node.value, **kwargs) + + def visit_BuiltInLiteral(self, node: common.BuiltInLiteral, **kwargs: Any) -> str: + if node == common.BuiltInLiteral.TRUE: + return "True" + + if node == common.BuiltInLiteral.FALSE: + return "False" + + raise NotImplementedError(f"Not implemented BuiltInLiteral '{node}' encountered.") + + def visit_UnaryOp(self, node: oir.UnaryOp, **kwargs: Any) -> str: + expression = self.visit(node.expr, **kwargs) + + return f"{node.op}({expression})" + + def visit_BinaryOp(self, node: oir.BinaryOp, **kwargs: Any) -> str: + left = self.visit(node.left, **kwargs) + right = self.visit(node.right, **kwargs) + + return f"({left} {node.op.value} {right})" + + def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> str: + condition = self.visit(node.cond, **kwargs) + if_code = self.visit(node.true_expr, **kwargs) + else_code = self.visit(node.false_expr, **kwargs) + + return f"({if_code} if {condition} else {else_code})" + + # visitors that should _not_ be called + + def visit_Decl(self, node: oir.Decl, **kwargs: Any) -> None: + raise 
RuntimeError("visit_Decl should not be called") + + def visit_FieldDecl(self, node: oir.FieldDecl, **kwargs: Any) -> None: + raise RuntimeError("visit_FieldDecl should not be called") + + def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> None: + raise RuntimeError("visit_LocalScalar should not be called") + + +def get_dace_shape( + field: oir.FieldDecl, + extent: definitions.Extent, + k_bound: tuple[int, int], + symbols: tir.SymbolDict, +) -> list[symbolic.symbol]: + shape = [] + for index, axis in enumerate(tir.Axis.dims_3d()): + if field.dimensions[index]: + symbol = axis.domain_dace_symbol() + symbols[axis.domain_symbol()] = dtypes.int32 + + if axis == tir.Axis.I: + i_padding = extent[0][1] - extent[0][0] + if i_padding != 0: + shape.append(symbol + i_padding) + continue + + if axis == tir.Axis.J: + j_padding = extent[1][1] - extent[1][0] + if j_padding != 0: + shape.append(symbol + j_padding) + continue + + if axis == tir.Axis.K: + k_padding = max(k_bound[0], 0) + max(k_bound[1], 0) + if k_padding != 0: + shape.append(symbol + k_padding) + continue + + shape.append(symbol) + + shape.extend([d for d in field.data_dims]) + return shape + + +def get_dace_strides(field: oir.FieldDecl, symbols: tir.SymbolDict) -> list[symbolic.symbol]: + dimension_strings = [d for i, d in enumerate("IJK") if field.dimensions[i]] + data_dimension_strings = [f"d{ddim}" for ddim in range(len(field.data_dims))] + + strides = [] + for dim in dimension_strings + data_dimension_strings: + stride = f"__{field.name}_{dim}_stride" + symbol = symbolic.pystr_to_symbolic(stride) + symbols[stride] = dtypes.int32 + strides.append(symbol) + return strides diff --git a/src/gt4py/cartesian/gtc/dace/prefix.py b/src/gt4py/cartesian/gtc/dace/prefix.py deleted file mode 100644 index 1da9eb95f3..0000000000 --- a/src/gt4py/cartesian/gtc/dace/prefix.py +++ /dev/null @@ -1,23 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. 
-# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - - -from typing import Final - - -# DaCe passthrough prefixes -PASSTHROUGH_IN: Final[str] = "IN_" -PASSTHROUGH_OUT: Final[str] = "OUT_" - -# StencilComputation in/out connector prefixes -CONNECTOR_IN: Final[str] = "__in_" -CONNECTOR_OUT: Final[str] = "__out_" - -# Tasklet in/out connector prefixes -TASKLET_IN: Final[str] = "gtIN__" -TASKLET_OUT: Final[str] = "gtOUT__" diff --git a/src/gt4py/cartesian/gtc/dace/symbol_utils.py b/src/gt4py/cartesian/gtc/dace/symbol_utils.py deleted file mode 100644 index 1e620a3dea..0000000000 --- a/src/gt4py/cartesian/gtc/dace/symbol_utils.py +++ /dev/null @@ -1,71 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from functools import lru_cache -from typing import TYPE_CHECKING - -import numpy as np -from dace import dtypes, symbolic - -from gt4py import eve -from gt4py.cartesian.gtc import common - - -if TYPE_CHECKING: - from gt4py.cartesian.gtc.dace import daceir as dcir - - -def data_type_to_dace_typeclass(data_type: common.DataType) -> dtypes.typeclass: - dtype = np.dtype(common.data_type_to_typestr(data_type)) - return dtypes.typeclass(dtype.type) - - -def get_axis_bound_str(axis_bound, var_name): - from gt4py.cartesian.gtc.common import LevelMarker - - if axis_bound is None: - return "" - elif axis_bound.level == LevelMarker.END: - return f"{var_name}{axis_bound.offset:+d}" - else: - return f"{axis_bound.offset}" - - -def get_axis_bound_dace_symbol(axis_bound: dcir.AxisBound): - from gt4py.cartesian.gtc.common import LevelMarker - - if axis_bound is None: - return - - elif axis_bound.level == LevelMarker.END: - return axis_bound.axis.domain_dace_symbol() + axis_bound.offset - else: - return axis_bound.offset - - -def 
get_axis_bound_diff_str(axis_bound1, axis_bound2, var_name: str): - if axis_bound1 <= axis_bound2: - axis_bound1, axis_bound2 = axis_bound2, axis_bound1 - sign = "-" - else: - sign = "" - - if axis_bound1.level != axis_bound2.level: - var = var_name - else: - var = "" - return f"{sign}({var}{axis_bound1.offset - axis_bound2.offset:+d})" - - -@lru_cache(maxsize=None) -def get_dace_symbol( - name: eve.SymbolRef, dtype: common.DataType = common.DataType.INT32 -) -> symbolic.symbol: - return symbolic.symbol(name, dtype=data_type_to_dace_typeclass(dtype)) diff --git a/src/gt4py/cartesian/gtc/dace/transformations.py b/src/gt4py/cartesian/gtc/dace/transformations.py deleted file mode 100644 index eccf6f97d1..0000000000 --- a/src/gt4py/cartesian/gtc/dace/transformations.py +++ /dev/null @@ -1,113 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import dace -from dace.transformation.dataflow import TrivialMapElimination -from dace.transformation.helpers import nest_state_subgraph -from dace.transformation.interstate import InlineTransients - - -class NoEmptyEdgeTrivialMapElimination(TrivialMapElimination): - """Eliminate trivial maps like TrivialMapElimination, with additional conditions in can_be_applied.""" - - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - if not super().can_be_applied(graph, expr_index, sdfg, permissive=permissive): - return False - - map_entry = self.map_entry - map_exit = graph.exit_node(map_entry) - if map_entry.map.schedule not in { - dace.ScheduleType.Sequential, - dace.ScheduleType.CPU_Multicore, - }: - return False - if any( - edge.data.is_empty() for edge in (graph.in_edges(map_entry) + graph.out_edges(map_exit)) - ): - return False - return True - - -class InlineThreadLocalTransients(dace.transformation.SingleStateTransformation): - """ - Inline and tile 
thread-local transients. - - Inlines transients like `dace.transformations.interstate.InlineTransients`, however only applies to OpenMP map - scopes but also makes the resulting local arrays persistent and thread-local. This reproduces `cpu_kfirst`-style - transient tiling. - """ - - map_entry = dace.transformation.transformation.PatternNode(dace.nodes.MapEntry) - - @classmethod - def expressions(cls): - return [dace.sdfg.utils.node_path_graph(cls.map_entry)] - - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - map_entry = self.map_entry - - if not map_entry.schedule == dace.ScheduleType.CPU_Multicore: - return False - - scope_subgraph = graph.scope_subgraph(map_entry, include_entry=False, include_exit=False) - if len(scope_subgraph) > 1 or not isinstance( - scope_subgraph.nodes()[0], dace.nodes.NestedSDFG - ): - return False - - candidates = InlineTransients._candidates(sdfg, graph, scope_subgraph.nodes()[0]) - return len(candidates) > 0 - - def apply(self, graph, sdfg): - map_entry = self.map_entry - - scope_subgraph = graph.scope_subgraph(map_entry, include_entry=False, include_exit=False) - nsdfg_node = scope_subgraph.nodes()[0] - candidates = InlineTransients._candidates(sdfg, graph, nsdfg_node) - InlineTransients.apply_to(sdfg, nsdfg=nsdfg_node, save=False) - for name in candidates: - if name in sdfg.arrays: - continue - array: dace.data.Array = nsdfg_node.sdfg.arrays[name] - shape = [dace.symbolic.overapproximate(s) for s in array.shape] - strides = [1] - total_size = shape[0] - for s in reversed(shape[1:]): - strides = [s * strides[0], *strides] - total_size *= s - array.shape = shape - array.strides = strides - array.total_size = total_size - array.storage = dace.StorageType.CPU_ThreadLocal - array.lifetime = dace.AllocationLifetime.Persistent - - -def nest_sequential_map_scopes(sdfg: dace.SDFG): - """Nest map scopes of sequential maps. 
- - Nest scope subgraphs of sequential maps in NestedSDFG's to force eagerly offsetting pointers on each iteration, to - avoid more complex pointer arithmetic on each Tasklet's invocation. - This is performed in an inner-map-first order to avoid revisiting the graph after changes. - """ - - def _process_map(sdfg: dace.SDFG, state: dace.SDFGState, map_entry: dace.nodes.MapEntry): - for node in state.scope_children()[map_entry]: - if isinstance(node, dace.nodes.NestedSDFG): - nest_sequential_map_scopes(node.sdfg) - elif isinstance(node, dace.nodes.MapEntry): - _process_map(sdfg, state, node) - if map_entry.schedule == dace.ScheduleType.Sequential: - subgraph = state.scope_subgraph(map_entry, include_entry=False, include_exit=False) - nest_state_subgraph(sdfg, state, subgraph) - - state: dace.SDFGState - for state in sdfg.nodes(): - for map_entry in filter( - lambda n: isinstance(n, dace.nodes.MapEntry), state.scope_children()[None] - ): - _process_map(sdfg, state, map_entry) diff --git a/src/gt4py/cartesian/gtc/dace/treeir.py b/src/gt4py/cartesian/gtc/dace/treeir.py new file mode 100644 index 0000000000..390ef328f1 --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/treeir.py @@ -0,0 +1,152 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from dataclasses import dataclass +from types import TracebackType +from typing import Generator, TypeAlias + +from dace import Memlet, data, dtypes, nodes + +from gt4py import eve +from gt4py.cartesian.gtc import common, definitions +from gt4py.cartesian.gtc.dace import utils + + +SymbolDict: TypeAlias = dict[str, dtypes.typeclass] + + +@dataclass +class Context: + root: TreeRoot + current_scope: TreeScope + + field_extents: dict[str, definitions.Extent] # field_name -> Extent + block_extents: dict[int, definitions.Extent] # id(horizontal execution) -> Extent + + +class ContextPushPop: + """Append the node to the scope, then push/pop the scope.""" + + def __init__(self, ctx: Context, node: TreeScope) -> None: + self._ctx = ctx + self._parent_scope = ctx.current_scope + self._node = node + + def __enter__(self) -> None: + self._node.parent = self._parent_scope + self._parent_scope.children.append(self._node) + self._ctx.current_scope = self._node + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self._ctx.current_scope = self._parent_scope + + +class Axis(eve.StrEnum): + I = "I" # noqa: E741 [ambiguous-variable-name] + J = "J" + K = "K" + + def domain_symbol(self) -> eve.SymbolRef: + return eve.SymbolRef(f"__{self.upper()}") + + def iteration_symbol(self) -> eve.SymbolRef: + return eve.SymbolRef(f"__{self.lower()}") + + @staticmethod + def dims_3d() -> Generator[Axis, None, None]: + yield from [Axis.I, Axis.J, Axis.K] + + @staticmethod + def dims_horizontal() -> Generator[Axis, None, None]: + yield from [Axis.I, Axis.J] + + def to_idx(self) -> int: + return [Axis.I, Axis.J, Axis.K].index(self) + + def domain_dace_symbol(self): + return utils.get_dace_symbol(self.domain_symbol()) + + def iteration_dace_symbol(self): + return utils.get_dace_symbol(self.iteration_symbol()) + + def 
tile_dace_symbol(self): + return utils.get_dace_symbol(self.tile_symbol()) + + +class Bounds(eve.Node): + start: str + end: str + + +class TreeNode(eve.Node): + parent: TreeScope | None + + +class TreeScope(TreeNode): + children: list[TreeScope | TreeNode] + + def scope(self, ctx: Context) -> ContextPushPop: + return ContextPushPop(ctx, self) + + +class Tasklet(TreeNode): + tasklet: nodes.Tasklet + + inputs: dict[str, Memlet] + """Mapping tasklet.in_connectors to Memlets""" + outputs: dict[str, Memlet] + """Mapping tasklet.out_connectors to Memlets""" + + +class IfElse(TreeScope): + # This should become an if/else, someday, so I am naming it if/else in hope + # to see it before my bodily demise + if_condition_code: str + """Condition as ScheduleTree worthy code""" + + +class While(TreeScope): + condition_code: str + """Condition as ScheduleTree worthy code""" + + +class HorizontalLoop(TreeScope): + bounds_i: Bounds + bounds_j: Bounds + + schedule: dtypes.ScheduleType + + +class VerticalLoop(TreeScope): + loop_order: common.LoopOrder + bounds_k: Bounds + + schedule: dtypes.ScheduleType + + +class TreeRoot(TreeScope): + name: str + + containers: dict[str, data.Data] + """Mapping field/scalar names to data descriptors.""" + + dimensions: dict[str, tuple[bool, bool, bool]] + """Mapping field names to shape-axis.""" + + shift: dict[str, dict[Axis, int]] + """Mapping field names to dict[axis] -> shift.""" + + symbols: SymbolDict + """Mapping between type and symbol name.""" diff --git a/src/gt4py/cartesian/gtc/dace/treeir_to_stree.py b/src/gt4py/cartesian/gtc/dace/treeir_to_stree.py new file mode 100644 index 0000000000..bbf6106f0d --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/treeir_to_stree.py @@ -0,0 +1,229 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from dataclasses import dataclass +from types import TracebackType + +from dace import __version__ as dace_version, dtypes, nodes, sdfg, subsets +from dace.codegen import control_flow as dcf +from dace.properties import CodeBlock +from dace.sdfg.analysis.schedule_tree import treenodes as tn + +from gt4py import eve +from gt4py.cartesian.gtc import common +from gt4py.cartesian.gtc.dace import treeir as tir + + +@dataclass +class Context: + tree: tn.ScheduleTreeRoot + """A reference to the tree root.""" + current_scope: tn.ScheduleTreeScope + """A reference to the current scope node.""" + + +class ContextPushPop: + """Append the node to the scope, then push/pop the scope""" + + def __init__(self, ctx: Context, node: tn.ScheduleTreeScope) -> None: + self._ctx = ctx + self._parent_scope = ctx.current_scope + self._node = node + + def __enter__(self) -> None: + self._node.parent = self._parent_scope + self._parent_scope.children.append(self._node) + self._ctx.current_scope = self._node + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self._ctx.current_scope = self._parent_scope + + +class TreeIRToScheduleTree(eve.NodeVisitor): + """Translate TreeIR temporary IR to DaCe's Schedule Tree. + + TreeIR should have undone most of the DSL specificity when translating + from OIR. This should be a rather direct translation. No transformation + should happen here, they should all be done on the resulting Schedule Tree. 
+ """ + + def visit_Tasklet(self, node: tir.Tasklet, ctx: Context) -> None: + tasklet = tn.TaskletNode( + node=node.tasklet, in_memlets=node.inputs, out_memlets=node.outputs + ) + tasklet.parent = ctx.current_scope + ctx.current_scope.children.append(tasklet) + + def visit_HorizontalLoop(self, node: tir.HorizontalLoop, ctx: Context) -> None: + # Define axis iteration symbols + for axis in tir.Axis.dims_horizontal(): + ctx.tree.symbols[axis.iteration_symbol()] = dtypes.int32 + + dace_map = nodes.Map( + label=f"horizontal_loop_{id(node)}", + params=[axis.iteration_symbol() for axis in tir.Axis.dims_horizontal()], + ndrange=subsets.Range( + [ + # -1 because range bounds are inclusive + (node.bounds_i.start, f"{node.bounds_i.end} - 1", 1), + (node.bounds_j.start, f"{node.bounds_j.end} - 1", 1), + ] + ), + schedule=node.schedule, + ) + map_scope = tn.MapScope(node=nodes.MapEntry(dace_map), children=[]) + + with ContextPushPop(ctx, map_scope): + self.visit(node.children, ctx=ctx) + + def visit_VerticalLoop(self, node: tir.VerticalLoop, ctx: Context) -> None: + # In any case, define the iteration symbol + ctx.tree.symbols[tir.Axis.K.iteration_symbol()] = dtypes.int32 + + # For serial loops, create a ForScope and add it to the tree + if node.loop_order != common.LoopOrder.PARALLEL: + for_scope = tn.ForScope(header=_for_scope_header(node), children=[]) + + with ContextPushPop(ctx, for_scope): + self.visit(node.children, ctx=ctx) + + return + + # For parallel loops, create a map and add it to the tree + dace_map = nodes.Map( + label=f"vertical_loop_{id(node)}", + params=[tir.Axis.K.iteration_symbol()], + ndrange=subsets.Range( + # -1 because range bounds are inclusive + [(node.bounds_k.start, f"{node.bounds_k.end} - 1", 1)] + ), + schedule=node.schedule, + ) + map_scope = tn.MapScope(node=nodes.MapEntry(dace_map), children=[]) + + with ContextPushPop(ctx, map_scope): + self.visit(node.children, ctx=ctx) + + def visit_IfElse(self, node: tir.IfElse, ctx: Context) -> None: + 
if_scope = tn.IfScope( + condition=tn.CodeBlock(node.if_condition_code), + children=[], + ) + + with ContextPushPop(ctx, if_scope): + self.visit(node.children, ctx=ctx) + + def visit_While(self, node: tir.While, ctx: Context) -> None: + while_scope = tn.WhileScope(children=[], header=_while_scope_header(node)) + + with ContextPushPop(ctx, while_scope): + self.visit(node.children, ctx=ctx) + + def visit_TreeRoot(self, node: tir.TreeRoot) -> tn.ScheduleTreeRoot: + """Construct a schedule tree from TreeIR.""" + tree = tn.ScheduleTreeRoot( + name=node.name, + containers=node.containers, + symbols=node.symbols, + constants={}, + children=[], + ) + ctx = Context(tree=tree, current_scope=tree) + + self.visit(node.children, ctx=ctx) + + return ctx.tree + + +def _for_scope_header(node: tir.VerticalLoop) -> dcf.ForScope: + """Header for the tn.ForScope re-using DaCe codegen ForScope. + + Only setup the required data, default or mock the rest. + + TODO: In DaCe 2.x this will be replaced by an SDFG concept which should + be closer and required less mockup. 
+ """ + if not dace_version.startswith("1."): + raise NotImplementedError("DaCe 2.x detected - please fix below code") + if node.loop_order == common.LoopOrder.PARALLEL: + raise ValueError("Parallel vertical loops should be translated to maps instead.") + + plus_minus = "+" if node.loop_order == common.LoopOrder.FORWARD else "-" + comparison = "<" if node.loop_order == common.LoopOrder.FORWARD else ">=" + iteration_var = tir.Axis.K.iteration_symbol() + + for_scope = dcf.ForScope( + condition=CodeBlock( + code=f"{iteration_var} {comparison} {node.bounds_k.end}", + language=dtypes.Language.Python, + ), + itervar=iteration_var, + init=node.bounds_k.start, + update=f"{iteration_var} {plus_minus} 1", + # Unused + parent=None, # not Tree parent, CF parent + dispatch_state=lambda _state: "", + last_block=False, + guard=sdfg.SDFGState(), + body=dcf.GeneralBlock( + lambda _state: "", + None, + True, + None, + [], + [], + [], + [], + [], + False, + ), + init_edges=[], + ) + # Kill the loop_range test for memlet propagation check going in + dcf.ForScope.loop_range = lambda self: None + return for_scope + + +def _while_scope_header(node: tir.While) -> dcf.WhileScope: + """Header for the tn.WhileScope re-using DaCe codegen WhileScope. + + Only setup the required data, default or mock the rest. + + TODO: In DaCe 2.x this will be replaced by an SDFG concept which should + be closer and required less mockup. 
+ """ + if not dace_version.startswith("1."): + raise NotImplementedError("DaCe 2.x detected - please fix below code") + + return dcf.WhileScope( + test=CodeBlock(node.condition_code), + # Unused + guard=sdfg.SDFGState(), + dispatch_state=lambda _state: "", + parent=None, + body=dcf.GeneralBlock( + lambda _state: "", + None, + True, + None, + [], + [], + [], + [], + [], + False, + ), + last_block=False, + ) diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index c552c9dfa1..2187d41b21 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -6,20 +6,14 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from __future__ import annotations - import re -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from functools import lru_cache import numpy as np -from dace import data, dtypes, properties, subsets, symbolic +from dace import data, dtypes, symbolic from gt4py import eve -from gt4py.cartesian.gtc import common, oir -from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset -from gt4py.cartesian.gtc.dace import daceir as dcir, prefix -from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents +from gt4py.cartesian.gtc import common def get_dace_debuginfo(node: common.LocNode) -> dtypes.DebugInfo: @@ -64,543 +58,13 @@ def replace_strides(arrays: list[data.Array], get_layout_map) -> dict[str, str]: return symbol_mapping -def get_tasklet_symbol( - name: str, - *, - offset: Optional[CartesianOffset | VariableKOffset] = None, - is_target: bool, -): - access_name = f"{prefix.TASKLET_OUT}{name}" if is_target else f"{prefix.TASKLET_IN}{name}" - if offset is None: - return access_name - - # add (per axis) offset markers, e.g. 
gtIN__A_km1 for A[0, 0, -1] - offset_strings = [] - for axis in dcir.Axis.dims_3d(): - axis_offset = offset.to_dict()[axis.lower()] - if axis_offset is not None and axis_offset != 0: - offset_strings.append( - axis.lower() + ("m" if axis_offset < 0 else "p") + f"{abs(axis_offset):d}" - ) - - return access_name + "_".join(offset_strings) - - -def axes_list_from_flags(flags): - return [ax for f, ax in zip(flags, dcir.Axis.dims_3d()) if f] - - -class AccessInfoCollector(eve.NodeVisitor): - def __init__(self, collect_read: bool, collect_write: bool, include_full_domain: bool = False): - self.collect_read: bool = collect_read - self.collect_write: bool = collect_write - self.include_full_domain: bool = include_full_domain - - @dataclass - class Context: - axes: Dict[str, List[dcir.Axis]] - access_infos: Dict[str, dcir.FieldAccessInfo] = field(default_factory=dict) - - def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, block_extents, ctx, **kwargs: Any - ) -> Dict[str, dcir.FieldAccessInfo]: - for section in reversed(node.sections): - self.visit(section, block_extents=block_extents, ctx=ctx, **kwargs) - return ctx.access_infos - - def visit_VerticalLoopSection( - self, node: oir.VerticalLoopSection, *, block_extents, ctx, grid_subset=None, **kwargs: Any - ) -> Dict[str, dcir.FieldAccessInfo]: - inner_ctx = self.Context(axes=ctx.axes) - - if grid_subset is None: - grid_subset = dcir.GridSubset.from_interval(node.interval, dcir.Axis.K) - elif dcir.Axis.K not in grid_subset.intervals: - intervals = dict(dcir.GridSubset.from_interval(node.interval, dcir.Axis.K).intervals) - intervals.update(grid_subset.intervals) - grid_subset = dcir.GridSubset(intervals=intervals) - self.visit( - node.horizontal_executions, - block_extents=block_extents, - ctx=inner_ctx, - grid_subset=grid_subset, - k_interval=node.interval, - **kwargs, - ) - inner_infos = inner_ctx.access_infos - - k_grid = dcir.GridSubset.from_interval(grid_subset.intervals[dcir.Axis.K], dcir.Axis.K) - inner_infos 
= {name: info.apply_iteration(k_grid) for name, info in inner_infos.items()} - - ctx.access_infos.update( - { - name: info.union(ctx.access_infos.get(name, info)) - for name, info in inner_infos.items() - } - ) - - return ctx.access_infos - - def visit_HorizontalExecution( - self, - node: oir.HorizontalExecution, - *, - block_extents, - ctx: Context, - k_interval, - grid_subset=None, - **kwargs, - ) -> Dict[str, dcir.FieldAccessInfo]: - horizontal_extent = block_extents(node) - - inner_ctx = self.Context(axes=ctx.axes) - inner_infos = inner_ctx.access_infos - ij_grid = dcir.GridSubset.from_gt4py_extent(horizontal_extent) - he_grid = ij_grid.set_interval(dcir.Axis.K, k_interval) - self.visit( - node.body, - horizontal_extent=horizontal_extent, - ctx=inner_ctx, - he_grid=he_grid, - grid_subset=grid_subset, - **kwargs, - ) - - if grid_subset is not None: - for axis in ij_grid.axes(): - if axis in grid_subset.intervals: - ij_grid = ij_grid.set_interval(axis, grid_subset.intervals[axis]) - - inner_infos = {name: info.apply_iteration(ij_grid) for name, info in inner_infos.items()} - - ctx.access_infos.update( - { - name: info.union(ctx.access_infos.get(name, info)) - for name, info in inner_infos.items() - } - ) - - return ctx.access_infos - - def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs): - self.visit(node.right, is_write=False, **kwargs) - self.visit(node.left, is_write=True, **kwargs) - - def visit_HorizontalRestriction( - self, node: oir.HorizontalRestriction, *, is_conditional=False, **kwargs - ): - self.visit(node.mask, is_conditional=is_conditional, **kwargs) - self.visit(node.body, is_conditional=True, region=node.mask, **kwargs) - - def visit_MaskStmt(self, node: oir.MaskStmt, *, is_conditional=False, **kwargs): - self.visit(node.mask, is_conditional=is_conditional, **kwargs) - self.visit(node.body, is_conditional=True, **kwargs) - - def visit_While(self, node: oir.While, *, is_conditional=False, **kwargs): - self.visit(node.cond, 
is_conditional=is_conditional, **kwargs) - self.visit(node.body, is_conditional=True, **kwargs) - - @staticmethod - def _global_grid_subset( - region: common.HorizontalMask, he_grid: dcir.GridSubset, offset: List[Optional[int]] - ): - res: Dict[ - dcir.Axis, Union[dcir.DomainInterval, dcir.TileInterval, dcir.IndexWithExtent] - ] = {} - if region is not None: - for axis, oir_interval in zip(dcir.Axis.dims_horizontal(), region.intervals): - he_grid_interval = he_grid.intervals[axis] - assert isinstance(he_grid_interval, dcir.DomainInterval) - start = ( - oir_interval.start if oir_interval.start is not None else he_grid_interval.start - ) - end = oir_interval.end if oir_interval.end is not None else he_grid_interval.end - dcir_interval = dcir.DomainInterval( - start=dcir.AxisBound.from_common(axis, start), - end=dcir.AxisBound.from_common(axis, end), - ) - res[axis] = dcir.DomainInterval.union(dcir_interval, res.get(axis, dcir_interval)) - if dcir.Axis.K in he_grid.intervals: - off = offset[dcir.Axis.K.to_idx()] or 0 - he_grid_k_interval = he_grid.intervals[dcir.Axis.K] - assert not isinstance(he_grid_k_interval, dcir.TileInterval) - res[dcir.Axis.K] = he_grid_k_interval.shifted(off) - for axis in dcir.Axis.dims_horizontal(): - iteration_interval = he_grid.intervals[axis] - mask_interval = res.get(axis, iteration_interval) - res[axis] = dcir.DomainInterval.intersection( - axis, iteration_interval, mask_interval - ).shifted(offset[axis.to_idx()]) - return dcir.GridSubset(intervals=res) - - def _make_access_info( - self, - offset_node: Union[CartesianOffset, oir.VariableKOffset], - axes, - is_conditional, - region, - he_grid, - grid_subset, - is_write, - ) -> dcir.FieldAccessInfo: - # Check we have expression offsets in K - offset = [offset_node.to_dict()[k] for k in "ijk"] - variable_offset_axes = [dcir.Axis.K] if isinstance(offset_node, oir.VariableKOffset) else [] - - global_subset = self._global_grid_subset(region, he_grid, offset) - intervals = {} - for axis in 
axes: - if axis in variable_offset_axes: - intervals[axis] = dcir.IndexWithExtent( - axis=axis, value=axis.iteration_symbol(), extent=(0, 0) - ) - else: - intervals[axis] = dcir.IndexWithExtent( - axis=axis, - value=axis.iteration_symbol(), - extent=(offset[axis.to_idx()], offset[axis.to_idx()]), - ) - grid_subset = dcir.GridSubset(intervals=intervals) - return dcir.FieldAccessInfo( - grid_subset=grid_subset, - global_grid_subset=global_subset, - variable_offset_axes=variable_offset_axes, - ) - - def visit_FieldAccess( - self, - node: oir.FieldAccess, - *, - he_grid, - grid_subset, - is_write: bool = False, - is_conditional: bool = False, - region=None, - ctx: AccessInfoCollector.Context, - **kwargs, - ): - self.visit( - node.offset, - is_conditional=is_conditional, - ctx=ctx, - is_write=False, - region=region, - he_grid=he_grid, - grid_subset=grid_subset, - **kwargs, - ) - - if (is_write and not self.collect_write) or (not is_write and not self.collect_read): - return - - access_info = self._make_access_info( - node.offset, - axes=ctx.axes[node.name], - is_conditional=is_conditional, - region=region, - he_grid=he_grid, - grid_subset=grid_subset, - is_write=is_write, - ) - ctx.access_infos[node.name] = access_info.union( - ctx.access_infos.get(node.name, access_info) - ) - - -def compute_dcir_access_infos( - oir_node, - *, - oir_decls=None, - block_extents=None, - collect_read=True, - collect_write=True, - include_full_domain=False, - **kwargs, -) -> properties.DictProperty: - if block_extents is None: - assert isinstance(oir_node, oir.Stencil) - block_extents = compute_horizontal_block_extents(oir_node) - - axes = { - name: axes_list_from_flags(decl.dimensions) - for name, decl in oir_decls.items() - if isinstance(decl, oir.FieldDecl) - } - ctx = AccessInfoCollector.Context(axes=axes, access_infos=dict()) - AccessInfoCollector(collect_read=collect_read, collect_write=collect_write).visit( - oir_node, block_extents=block_extents, ctx=ctx, **kwargs - ) - if 
include_full_domain: - res = dict() - for name, access_info in ctx.access_infos.items(): - res[name] = access_info.union( - dcir.FieldAccessInfo( - grid_subset=dcir.GridSubset.full_domain(axes=access_info.axes()), - global_grid_subset=access_info.global_grid_subset, - ) - ) - return res - - return ctx.access_infos - - -class TaskletAccessInfoCollector(eve.NodeVisitor): - @dataclass - class Context: - axes: dict[str, list[dcir.Axis]] - access_infos: dict[str, dcir.FieldAccessInfo] = field(default_factory=dict) - - def __init__( - self, collect_read: bool, collect_write: bool, *, horizontal_extent, k_interval, grid_subset - ): - self.collect_read: bool = collect_read - self.collect_write: bool = collect_write - - self.ij_grid = dcir.GridSubset.from_gt4py_extent(horizontal_extent) - self.he_grid = self.ij_grid.set_interval(dcir.Axis.K, k_interval) - self.grid_subset = grid_subset - - def visit_CodeBlock(self, _node: oir.CodeBlock, **_kwargs): - raise RuntimeError("We shouldn't reach code blocks anymore") - - def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs): - self.visit(node.right, is_write=False, **kwargs) - self.visit(node.left, is_write=True, **kwargs) - - def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs): - self.visit(node.mask, is_write=False, **kwargs) - self.visit(node.body, **kwargs) - - def visit_While(self, node: oir.While, **kwargs): - self.visit(node.cond, is_write=False, **kwargs) - self.visit(node.body, **kwargs) - - def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction, **kwargs): - self.visit(node.mask, is_write=False, **kwargs) - self.visit(node.body, region=node.mask, **kwargs) - - def _global_grid_subset( - self, - region: Optional[common.HorizontalMask], - offset: list[Optional[int]], - ): - res: dict[dcir.Axis, dcir.DomainInterval | dcir.IndexWithExtent | dcir.TileInterval] = {} - if region is not None: - for axis, oir_interval in zip(dcir.Axis.dims_horizontal(), region.intervals): - he_grid_interval = 
self.he_grid.intervals[axis] - assert isinstance(he_grid_interval, dcir.DomainInterval) - start = ( - oir_interval.start if oir_interval.start is not None else he_grid_interval.start - ) - end = oir_interval.end if oir_interval.end is not None else he_grid_interval.end - dcir_interval = dcir.DomainInterval( - start=dcir.AxisBound.from_common(axis, start), - end=dcir.AxisBound.from_common(axis, end), - ) - res[axis] = dcir.DomainInterval.union(dcir_interval, res.get(axis, dcir_interval)) - if dcir.Axis.K in self.he_grid.intervals: - off = offset[dcir.Axis.K.to_idx()] or 0 - he_grid_k_interval = self.he_grid.intervals[dcir.Axis.K] - assert not isinstance(he_grid_k_interval, dcir.TileInterval) - res[dcir.Axis.K] = he_grid_k_interval.shifted(off) - for axis in dcir.Axis.dims_horizontal(): - iteration_interval = self.he_grid.intervals[axis] - mask_interval = res.get(axis, iteration_interval) - res[axis] = dcir.DomainInterval.intersection( - axis, iteration_interval, mask_interval - ).shifted(offset[axis.to_idx()]) - return dcir.GridSubset(intervals=res) - - def _make_access_info( - self, - offset_node: CartesianOffset | VariableKOffset, - axes, - region: Optional[common.HorizontalMask], - ) -> dcir.FieldAccessInfo: - # Check we have expression offsets in K - offset = [offset_node.to_dict()[k] for k in "ijk"] - variable_offset_axes = [dcir.Axis.K] if isinstance(offset_node, VariableKOffset) else [] - - global_subset = self._global_grid_subset(region, offset) - intervals = {} - for axis in axes: - extent = ( - (0, 0) - if axis in variable_offset_axes - else (offset[axis.to_idx()], offset[axis.to_idx()]) - ) - intervals[axis] = dcir.IndexWithExtent( - axis=axis, value=axis.iteration_symbol(), extent=extent - ) - - return dcir.FieldAccessInfo( - grid_subset=dcir.GridSubset(intervals=intervals), - global_grid_subset=global_subset, - # Field access inside horizontal regions might or might not happen - dynamic_access=region is not None, - 
variable_offset_axes=variable_offset_axes, - ) - - def visit_FieldAccess( - self, - node: oir.FieldAccess, - *, - is_write: bool, - region: Optional[common.HorizontalMask] = None, - ctx: TaskletAccessInfoCollector.Context, - **kwargs, - ): - self.visit(node.offset, ctx=ctx, is_write=False, region=region, **kwargs) - - if (is_write and not self.collect_write) or (not is_write and not self.collect_read): - return - - access_info = self._make_access_info( - node.offset, - axes=ctx.axes[node.name], - region=region, - ) - ctx.access_infos[node.name] = access_info.union( - ctx.access_infos.get(node.name, access_info) - ) - - -def compute_tasklet_access_infos( - node: oir.CodeBlock | oir.MaskStmt | oir.While, - *, - collect_read: bool = True, - collect_write: bool = True, - declarations: dict[str, oir.Decl], - horizontal_extent, - k_interval, - grid_subset, -): - """ - Compute access information needed to build Memlets for the Tasklet - associated with the given `node`. - """ - axes = { - name: axes_list_from_flags(declaration.dimensions) - for name, declaration in declarations.items() - if isinstance(declaration, oir.FieldDecl) - } - ctx = TaskletAccessInfoCollector.Context(axes=axes, access_infos=dict()) - collector = TaskletAccessInfoCollector( - collect_read=collect_read, - collect_write=collect_write, - horizontal_extent=horizontal_extent, - k_interval=k_interval, - grid_subset=grid_subset, - ) - if isinstance(node, oir.CodeBlock): - collector.visit(node.body, ctx=ctx) - elif isinstance(node, oir.MaskStmt): - # node.mask is a simple expression. - # Pass `is_write` explicitly since we don't automatically set it in `visit_AssignStmt()` - collector.visit(node.mask, ctx=ctx, is_write=False) - elif isinstance(node, oir.While): - # node.cond is a simple expression. 
- # Pass `is_write` explicitly since we don't automatically set it in `visit_AssignStmt()` - collector.visit(node.cond, ctx=ctx, is_write=False) - else: - raise ValueError("Unexpected node type.") - - return ctx.access_infos - - -def make_dace_subset( - context_info: dcir.FieldAccessInfo, - access_info: dcir.FieldAccessInfo, - data_dims: Tuple[int, ...], -) -> subsets.Range: - clamped_access_info = access_info - clamped_context_info = context_info - for axis in access_info.axes(): - if axis in access_info.variable_offset_axes: - clamped_access_info = clamped_access_info.clamp_full_axis(axis) - if axis in context_info.variable_offset_axes: - clamped_context_info = clamped_context_info.clamp_full_axis(axis) - res_ranges = [] - - for axis in clamped_access_info.axes(): - context_start, _ = clamped_context_info.grid_subset.intervals[axis].to_dace_symbolic() - subset_start, subset_end = clamped_access_info.grid_subset.intervals[ - axis - ].to_dace_symbolic() - res_ranges.append((subset_start - context_start, subset_end - context_start - 1, 1)) - res_ranges.extend((0, dim - 1, 1) for dim in data_dims) - return subsets.Range(res_ranges) - - -def untile_memlets(memlets: Sequence[dcir.Memlet], axes: Sequence[dcir.Axis]) -> List[dcir.Memlet]: - res_memlets: List[dcir.Memlet] = [] - for memlet in memlets: - res_memlets.append( - dcir.Memlet( - field=memlet.field, - access_info=memlet.access_info.untile(axes), - connector=memlet.connector, - is_read=memlet.is_read, - is_write=memlet.is_write, - ) - ) - return res_memlets - - -def union_node_grid_subsets(nodes: List[eve.Node]): - grid_subset = None - - for node in collect_toplevel_iteration_nodes(nodes): - if grid_subset is None: - grid_subset = node.grid_subset - grid_subset = grid_subset.union(node.grid_subset) - - return grid_subset - - -def _union_memlets(*memlets: dcir.Memlet) -> List[dcir.Memlet]: - res: Dict[str, dcir.Memlet] = {} - for memlet in memlets: - res[memlet.field] = memlet.union(res.get(memlet.field, memlet)) 
- return list(res.values()) - - -def union_inout_memlets(nodes: List[eve.Node]): - read_memlets: List[dcir.Memlet] = [] - write_memlets: List[dcir.Memlet] = [] - for node in collect_toplevel_computation_nodes(nodes): - read_memlets = _union_memlets(*read_memlets, *node.read_memlets) - write_memlets = _union_memlets(*write_memlets, *node.write_memlets) - - return (read_memlets, write_memlets, _union_memlets(*read_memlets, *write_memlets)) - - -def flatten_list(list_or_node: Union[List[Any], eve.Node]): - list_or_node = [list_or_node] - while not all(isinstance(ref, eve.Node) for ref in list_or_node): - list_or_node = [r for li in list_or_node for r in li] - return list_or_node - - -def collect_toplevel_computation_nodes( - list_or_node: Union[List[Any], eve.Node], -) -> List[dcir.ComputationNode]: - class ComputationNodeCollector(eve.NodeVisitor): - def visit_ComputationNode(self, node: dcir.ComputationNode, *, collection: List): - collection.append(node) - - collection: List[dcir.ComputationNode] = [] - ComputationNodeCollector().visit(list_or_node, collection=collection) - return collection - +def data_type_to_dace_typeclass(data_type: common.DataType) -> dtypes.typeclass: + dtype = np.dtype(common.data_type_to_typestr(data_type)) + return dtypes.typeclass(dtype.type) -def collect_toplevel_iteration_nodes( - list_or_node: Union[List[Any], eve.Node], -) -> List[dcir.IterationNode]: - class IterationNodeCollector(eve.NodeVisitor): - def visit_IterationNode(self, node: dcir.IterationNode, *, collection: List): - collection.append(node) - collection: List[dcir.IterationNode] = [] - IterationNodeCollector().visit(list_or_node, collection=collection) - return collection +@lru_cache(maxsize=None) +def get_dace_symbol( + name: eve.SymbolRef, dtype: common.DataType = common.DataType.INT32 +) -> symbolic.symbol: + return symbolic.symbol(name, dtype=data_type_to_dace_typeclass(dtype)) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py 
b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py index 96faf9211a..0c78c00c17 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/utils.py @@ -15,7 +15,7 @@ from gt4py import eve from gt4py.cartesian.gtc import common, oir -from gt4py.cartesian.gtc.definitions import Extent +from gt4py.cartesian.gtc.definitions import CenteredExtent, Extent from gt4py.cartesian.gtc.passes.horizontal_masks import mask_overlap_with_extent @@ -39,21 +39,31 @@ class GenericAccess(Generic[OffsetT]): def is_read(self) -> bool: return not self.is_write - def to_extent(self, horizontal_extent: Extent) -> Optional[Extent]: + def to_extent( + self, + horizontal_extent: Extent, + centered: bool = False, + ignore_horizontal_mask: bool = False, + ) -> Optional[Extent]: """ Convert the access to an extent provided a horizontal extent for the access. - This returns None if no overlap exists between the horizontal mask and interval. + This returns None if no overlap exists between the horizontal mask and interval, unless + `ignore_horizontal_mask` is set. 
""" - offset_as_extent = Extent.from_offset(cast(Tuple[int, int, int], self.offset)[:2]) + if centered: + offset_as_extent = CenteredExtent.from_offset( + cast(Tuple[int, int, int], self.offset)[:2] + ) + else: + offset_as_extent = Extent.from_offset(cast(Tuple[int, int, int], self.offset)[:2]) zeros = Extent.zeros(ndims=2) - if self.horizontal_mask: + if self.horizontal_mask and not ignore_horizontal_mask: if dist_from_edge := mask_overlap_with_extent(self.horizontal_mask, horizontal_extent): return ((horizontal_extent - dist_from_edge) + offset_as_extent) | zeros - else: - return None - else: - return horizontal_extent + offset_as_extent + return None + + return horizontal_extent + offset_as_extent class CartesianAccess(GenericAccess[Tuple[int, int, int]]): @@ -220,13 +230,29 @@ def collect_symbol_names(node: eve.RootNode) -> Set[str]: class StencilExtentComputer(eve.NodeVisitor): + """Compute extent for fields and horizontal blocks. + + Args: + add_k: Add an extent for the K axis. Defaults to `False`. + centered_extent: Center the extent on 0 (negative left, positive right). Defaults to `False`. + ignore_horizontal_mask: When computing extent, do not restrict it by reading the masks of + horizontal regions. Defaults to `False`. 
+ """ + @dataclass class Context: fields: Dict[str, Extent] = dataclasses.field(default_factory=dict) blocks: Dict[int, Extent] = dataclasses.field(default_factory=dict) - def __init__(self, add_k: bool = False): + def __init__( + self, + add_k: bool = False, + centered_extent: bool = False, + ignore_horizontal_mask: bool = False, + ): self.add_k = add_k + self.centered_extent = centered_extent + self.ignore_horizontal_mask = ignore_horizontal_mask self.zero_extent = Extent.zeros(ndims=2) def visit_Stencil(self, node: oir.Stencil) -> Context: @@ -255,7 +281,11 @@ def visit_HorizontalExecution(self, node: oir.HorizontalExecution, *, ctx: Conte ctx.blocks[id(node)] = horizontal_extent for access in results.ordered_accesses(): - extent = access.to_extent(horizontal_extent) + extent = access.to_extent( + horizontal_extent, + centered=self.centered_extent, + ignore_horizontal_mask=self.ignore_horizontal_mask, + ) if extent is None: continue diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 8dc541f049..6429ac2b11 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -191,7 +191,7 @@ def stencil( field_3d = gt_storage.zeros( full_shape, dtype, backend=backend, aligned_index=aligned_index, dimensions=None ) - assert field_3d.shape == full_shape[:] + assert field_3d.shape == full_shape[:], "field_3d shape" field_2d = gt_storage.zeros( full_shape[:-1], @@ -200,7 +200,7 @@ def stencil( aligned_index=aligned_index[:-1], dimensions="IJ", ) - assert field_2d.shape == full_shape[:-1] + assert field_2d.shape == full_shape[:-1], "field_2d shape" field_1d = gt_storage.ones( full_shape[-1:], @@ -209,13 +209,12 @@ def stencil( aligned_index=(aligned_index[-1],), dimensions="K", ) - assert list(field_1d.shape) == 
[full_shape[-1]] + assert list(field_1d.shape) == [full_shape[-1]], "field_1d shape" stencil(field_3d, field_2d, field_1d, origin=(1, 1, 0), domain=(4, 3, 6)) res_field_3d = storage_utils.cpu_copy(field_3d) - np.testing.assert_allclose(res_field_3d[1:-1, 1:-2, :1], 2) - np.testing.assert_allclose(res_field_3d[1:-1, 1:-2, 1:], 1) - + np.testing.assert_allclose(res_field_3d[1:-1, 1:-2, :1], 2, err_msg="expected 2 from K=0") + np.testing.assert_allclose(res_field_3d[1:-1, 1:-2, 1:], 1, err_msg="expected 1 from K>=1") stencil(field_3d, field_2d, field_1d, origin=(1, 1, 0)) @@ -1117,21 +1116,7 @@ def test( assert (out_arr[:, :, :] == 388.0).all() -def _xfail_dace_backends(param): - if param.values[0].startswith("dace:"): - marks = [ - *param.marks, - pytest.mark.xfail( - raises=ValueError, - reason="Missing support in DaCe backends, see https://github.com/GridTools/gt4py/issues/1881.", - ), - ] - # make a copy because otherwise we are operating in-place - return pytest.param(*param.values, marks=marks) - return param - - -@pytest.mark.parametrize("backend", map(_xfail_dace_backends, ALL_BACKENDS)) +@pytest.mark.parametrize("backend", ALL_BACKENDS) def test_cast_in_index(backend): @gtscript.stencil(backend) def cast_in_index( diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir.py deleted file mode 100644 index dc98392536..0000000000 --- a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir.py +++ /dev/null @@ -1,111 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import pytest - -from gt4py.cartesian.gtc import common -from gt4py.cartesian.gtc.dace import daceir as dcir - -# Because "dace tests" filter by `requires_dace`, we still need to add the marker. 
-# This global variable add the marker to all test functions in this module. -pytestmark = pytest.mark.requires_dace - - -def test_DomainInterval() -> None: - I_start = dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START) - I_end = dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.END) - interval = dcir.DomainInterval(start=I_start, end=I_end) - - assert interval.start == I_start - assert interval.end == I_end - - with pytest.raises(ValueError, match=r"^Axis need to match for start and end bounds. Got *"): - dcir.DomainInterval( - start=I_start, - end=dcir.AxisBound(axis=dcir.Axis.J, level=common.LevelMarker.END), - ) - - -def test_DomainInterval_intersection() -> None: - I_0_4 = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=4), - ) - I_2_10 = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=2), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=10), - ) - I_2_5 = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=2), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=5), - ) - I_8_15 = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=8), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=15), - ) - I_full = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.END), - ) - I_end_m3 = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.END, offset=-3), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.END), - ) - - # expected results - I_2_4 = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, 
offset=2), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=4), - ) - I_8_10 = dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=8), - end=dcir.AxisBound(axis=dcir.Axis.I, level=common.LevelMarker.START, offset=10), - ) - - assert dcir.DomainInterval.intersection(dcir.Axis.I, I_0_4, I_2_10) == I_2_4, ( - "intersection left" - ) - assert dcir.DomainInterval.intersection(dcir.Axis.I, I_2_10, I_8_15) == I_8_10, ( - "intersection right" - ) - - assert dcir.DomainInterval.intersection(dcir.Axis.I, I_2_5, I_2_10) == I_2_5, ( - "first contained in second" - ) - assert dcir.DomainInterval.intersection(dcir.Axis.I, I_2_10, I_2_5) == I_2_5, ( - "second contained in first" - ) - assert dcir.DomainInterval.intersection(dcir.Axis.I, I_8_15, I_full) == I_8_15, ( - "full interval overlaps with start level" - ) - assert dcir.DomainInterval.intersection(dcir.Axis.I, I_end_m3, I_full) == I_end_m3, ( - "full interval overlaps with end level" - ) - - with pytest.raises(ValueError, match=r"^No intersection found for intervals *"): - dcir.DomainInterval.intersection(dcir.Axis.I, I_0_4, I_8_15) - - with pytest.raises(ValueError, match=r"^Axis need to match: *"): - dcir.DomainInterval.intersection( - dcir.Axis.I, - dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.J, level=common.LevelMarker.START), - end=dcir.AxisBound(axis=dcir.Axis.J, level=common.LevelMarker.END), - ), - I_full, - ) - - with pytest.raises(ValueError, match=r"^Axis need to match: *"): - dcir.DomainInterval.intersection( - dcir.Axis.I, - I_full, - dcir.DomainInterval( - start=dcir.AxisBound(axis=dcir.Axis.J, level=common.LevelMarker.START), - end=dcir.AxisBound(axis=dcir.Axis.J, level=common.LevelMarker.END), - ), - ) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py deleted file mode 100644 index af23d7056a..0000000000 --- 
a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py +++ /dev/null @@ -1,109 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import pytest - -from gt4py.cartesian.gtc.dace import daceir as dcir - -from cartesian_tests.unit_tests.test_gtc.dace import utils -from cartesian_tests.unit_tests.test_gtc.oir_utils import ( - AssignStmtFactory, - BinaryOpFactory, - HorizontalExecutionFactory, - LiteralFactory, - LocalScalarFactory, - MaskStmtFactory, - ScalarAccessFactory, - StencilFactory, - WhileFactory, -) - - -# Because "dace tests" filter by `requires_dace`, we still need to add the marker. -# This global variable add the marker to all test functions in this module. -pytestmark = pytest.mark.requires_dace - - -def test_dcir_code_structure_condition() -> None: - """Tests the following code structure: - - ComputationState - Condition - true_states: [ComputationState] - false_states: [] - ComputationState - """ - stencil = StencilFactory( - vertical_loops__0__sections__0__horizontal_executions=[ - HorizontalExecutionFactory( - body=[ - AssignStmtFactory( - left=ScalarAccessFactory(name="tmp"), - right=BinaryOpFactory( - left=LiteralFactory(value="0"), right=LiteralFactory(value="2") - ), - ), - MaskStmtFactory(), - AssignStmtFactory( - left=ScalarAccessFactory(name="other"), - right=ScalarAccessFactory(name="tmp"), - ), - ], - declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], - ), - ] - ) - expansions = utils.library_node_expansions(stencil) - assert len(expansions) == 1, "expect one vertical loop to be expanded" - - nested_SDFG = utils.nested_SDFG_inside_triple_loop(expansions[0]) - assert isinstance(nested_SDFG.states[0], dcir.ComputationState) - assert isinstance(nested_SDFG.states[1], dcir.Condition) - assert nested_SDFG.states[1].true_states - assert 
isinstance(nested_SDFG.states[1].true_states[0], dcir.ComputationState) - assert not nested_SDFG.states[1].false_states - assert isinstance(nested_SDFG.states[2], dcir.ComputationState) - - -def test_dcir_code_structure_while() -> None: - """Tests the following code structure - - ComputationState - WhileLoop - body: [ComputationState] - ComputationState - """ - stencil = StencilFactory( - vertical_loops__0__sections__0__horizontal_executions=[ - HorizontalExecutionFactory( - body=[ - AssignStmtFactory( - left=ScalarAccessFactory(name="tmp"), - right=BinaryOpFactory( - left=LiteralFactory(value="0"), right=LiteralFactory(value="2") - ), - ), - WhileFactory(), - AssignStmtFactory( - left=ScalarAccessFactory(name="other"), - right=ScalarAccessFactory(name="tmp"), - ), - ], - declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], - ), - ] - ) - expansions = utils.library_node_expansions(stencil) - assert len(expansions) == 1, "expect one vertical loop to be expanded" - - nested_SDFG = utils.nested_SDFG_inside_triple_loop(expansions[0]) - assert isinstance(nested_SDFG.states[0], dcir.ComputationState) - assert isinstance(nested_SDFG.states[1], dcir.WhileLoop) - assert nested_SDFG.states[1].body - assert isinstance(nested_SDFG.states[1].body[0], dcir.ComputationState) - assert isinstance(nested_SDFG.states[2], dcir.ComputationState) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py new file mode 100644 index 0000000000..e65d0c7f74 --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_oir_to_tasklet.py @@ -0,0 +1,53 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py.cartesian.gtc.dace import oir_to_tasklet +from gt4py.cartesian.gtc import oir, common + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable adds the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +@pytest.mark.parametrize( + "node,expected", + [ + ( + oir.FieldAccess( + name="A", + offset=oir.VariableKOffset(k=oir.Literal(value="1", dtype=common.DataType.AUTO)), + ), + "var_k", + ), + (oir.FieldAccess(name="A", offset=common.CartesianOffset(i=1, j=-1, k=0)), "ip1_jm1"), + ], +) +def test__field_offset_postfix(node: oir.FieldAccess, expected: str) -> None: + assert oir_to_tasklet._field_offset_postfix(node) == expected + + +@pytest.mark.parametrize( + "node,is_target,postfix,expected", + [ + (oir.ScalarAccess(name="A"), False, "", "gtIN__A"), + (oir.ScalarAccess(name="A"), True, "", "gtOUT__A"), + (oir.ScalarAccess(name="A"), False, "im1", "gtIN__A_im1"), + ( + oir.FieldAccess(name="A", offset=common.CartesianOffset(i=1, j=-1, k=0)), + True, + "", + "gtOUT__A", + ), + ], +) +def test__tasklet_name( + node: oir.FieldAccess | oir.ScalarAccess, is_target: bool, postfix: str, expected: str +) -> None: + assert oir_to_tasklet._tasklet_name(node, is_target, postfix) == expected diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py deleted file mode 100644 index 69329d96b2..0000000000 --- a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py +++ /dev/null @@ -1,144 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -import dace -import pytest - -from gt4py.cartesian.gtc.common import BuiltInLiteral, DataType -from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder - -from cartesian_tests.unit_tests.test_gtc.dace import utils -from cartesian_tests.unit_tests.test_gtc.oir_utils import ( - AssignStmtFactory, - BinaryOpFactory, - HorizontalExecutionFactory, - LiteralFactory, - LocalScalarFactory, - MaskStmtFactory, - ScalarAccessFactory, - StencilFactory, -) - - -# Because "dace tests" filter by `requires_dace`, we still need to add the marker. -# This global variable add the marker to all test functions in this module. -pytestmark = pytest.mark.requires_dace - - -def test_scalar_access_multiple_tasklets() -> None: - """Test scalar access if an oir.CodeBlock is split over multiple Tasklets. - - We are breaking up vertical loops inside stencils in multiple Tasklets. It might thus happen that - we write a "local" scalar in one Tasklet and read it in another Tasklet (downstream). - We thus create output connectors for all writes to scalar variables inside Tasklets. And input - connectors for all scalar reads unless previously written in the same Tasklet. DaCe's simplify - pipeline will get rid of any dead dataflow introduced with this general approach. 
- """ - stencil = StencilFactory( - vertical_loops__0__sections__0__horizontal_executions=[ - HorizontalExecutionFactory( - body=[ - AssignStmtFactory( - left=ScalarAccessFactory(name="tmp"), - right=BinaryOpFactory( - left=LiteralFactory(value="0"), right=LiteralFactory(value="2") - ), - ), - MaskStmtFactory( - mask=LiteralFactory(value=BuiltInLiteral.TRUE, dtype=DataType.BOOL), body=[] - ), - AssignStmtFactory( - left=ScalarAccessFactory(name="other"), - right=ScalarAccessFactory(name="tmp"), - ), - ], - declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], - ), - ] - ) - expansions = utils.library_node_expansions(stencil) - nsdfg = StencilComputationSDFGBuilder().visit(expansions[0]) - assert isinstance(nsdfg.sdfg, dace.SDFG) - - for node in nsdfg.sdfg.nodes()[1].nodes(): - if not isinstance(node, dace.nodes.NestedSDFG): - continue - - nested = node.sdfg - for state in nested.states(): - if state.name == "block_0": - nodes = state.nodes() - assert ( - len(list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))) == 1 - ) - assert ( - len( - list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) - and node.data == "tmp", - nodes, - ) - ) - ) - == 1 - ), "one AccessNode of tmp" - - edges = state.edges() - tasklet = list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))[0] - write_access = list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "tmp", - nodes, - ) - )[0] - assert len(edges) == 1, "one edge expected" - assert edges[0].src == tasklet and edges[0].dst == write_access, ( - "write access of 'tmp'" - ) - - if state.name == "block_1": - nodes = state.nodes() - assert ( - len(list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))) == 1 - ) - assert ( - len( - list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) - and node.data == "tmp", - nodes, - ) - ) - ) - == 1 - ), "one AccessNode of tmp" - - edges = state.edges() - 
tasklet = list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))[0] - read_access = list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "tmp", - nodes, - ) - )[0] - write_access = list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) - and node.data == "other", - nodes, - ) - )[0] - assert len(edges) == 2, "two edges expected" - assert edges[0].src == tasklet and edges[0].dst == write_access, ( - "write access of 'other'" - ) - assert edges[1].src == read_access and edges[1].dst == tasklet, ( - "read access of 'tmp'" - ) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py deleted file mode 100644 index ab501d722e..0000000000 --- a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import pytest - -from typing import Optional - -from gt4py.cartesian.gtc.common import DataType, CartesianOffset -from gt4py.cartesian.gtc.dace import daceir as dcir -from gt4py.cartesian.gtc.dace import prefix -from gt4py.cartesian.gtc.dace import utils - -# Because "dace tests" filter by `requires_dace`, we still need to add the marker. -# This global variable add the marker to all test functions in this module. 
-pytestmark = pytest.mark.requires_dace - - -@pytest.mark.parametrize( - "name,is_target,offset,expected", - [ - ("A", False, None, f"{prefix.TASKLET_IN}A"), - ("A", True, None, f"{prefix.TASKLET_OUT}A"), - ("A", True, CartesianOffset(i=0, j=0, k=-1), f"{prefix.TASKLET_OUT}Akm1"), - ("A", False, CartesianOffset(i=1, j=-2, k=3), f"{prefix.TASKLET_IN}Aip1_jm2_kp3"), - ( - "A", - True, - dcir.VariableKOffset(k=dcir.Literal(value="3", dtype=DataType.INT32)), - f"{prefix.TASKLET_OUT}A", - ), - ], -) -def test_get_tasklet_symbol( - name: str, - is_target: bool, - offset: Optional[CartesianOffset | dcir.VariableKOffset], - expected: str, -) -> None: - assert utils.get_tasklet_symbol(name, is_target=is_target, offset=offset) == expected diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py deleted file mode 100644 index 3a3238f018..0000000000 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py +++ /dev/null @@ -1,159 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import pytest -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - import dace -else: - dace = pytest.importorskip("dace") - -from gt4py.cartesian.gtc import oir -from gt4py.cartesian.gtc.common import DataType -from gt4py.cartesian.gtc.dace.nodes import StencilComputation -from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder - -from cartesian_tests.unit_tests.test_gtc.oir_utils import ( - AssignStmtFactory, - FieldAccessFactory, - FieldDeclFactory, - ScalarAccessFactory, - StencilFactory, -) - -# Because "dace tests" filter by `requires_dace`, we still need to add the marker. -# This global variable add the marker to all test functions in this module. 
-pytestmark = pytest.mark.requires_dace - - -def test_oir_sdfg_builder_copy_stencil() -> None: - stencil_name = "copy" - stencil = StencilFactory( - name=stencil_name, - params=[ - FieldDeclFactory(name="A", dtype=DataType.FLOAT32), - FieldDeclFactory(name="B", dtype=DataType.FLOAT32), - ], - vertical_loops__0__sections__0__horizontal_executions__0__body=[ - AssignStmtFactory(left=FieldAccessFactory(name="B"), right=FieldAccessFactory(name="A")) - ], - ) - sdfg = OirSDFGBuilder().visit(stencil) - - assert isinstance(sdfg, dace.SDFG), "DaCe SDFG expected" - assert sdfg.name == stencil_name, "Stencil name is preserved" - assert len(sdfg.arrays) == 2, "two arrays expected (A and B)" - - a_array = sdfg.arrays.get("A") - assert a_array is not None, "Array A expected to be defined" - assert a_array.ctype == "float", "A is of type `float`" - assert a_array.offset == (0, 0, 0), "CartesianOffset.zero() expected" - - b_array = sdfg.arrays.get("B") - assert b_array is not None, "Array B expected to be defined" - assert b_array.ctype == "float", "B is of type `float`" - assert b_array.offset == (0, 0, 0), "CartesianOffset.zero() expected" - - states = sdfg.nodes() - assert len(states) >= 1, "at least one state expected" - - # expect StencilComputation, AccessNode(A), and AccessNode(B) in the last block - last_block = states[len(states) - 1] - nodes = last_block.nodes() - assert len(list(filter(lambda node: isinstance(node, StencilComputation), nodes))) == 1, ( - "one StencilComputation library node" - ) - assert ( - len( - list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes - ) - ) - ) - == 1 - ), "one AccessNode of A" - assert ( - len( - list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "B", nodes - ) - ) - ) - == 1 - ), "one AccessNode of B" - - edges = last_block.edges() - assert len(edges) == 2, "read and write memlet path expected" - - library_node = list(filter(lambda node: 
isinstance(node, StencilComputation), nodes))[0] - read_access = list( - filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes) - )[0] - write_access = list( - filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "B", nodes) - )[0] - - assert edges[0].src == read_access and edges[0].dst == library_node, "read access expected" - assert edges[1].src == library_node and edges[1].dst == write_access, "write access expected" - - -def test_oir_sdfg_builder_assign_scalar_param() -> None: - stencil_name = "scalar_assign" - stencil = StencilFactory( - name=stencil_name, - params=[ - FieldDeclFactory(name="A", dtype=DataType.FLOAT64), - oir.ScalarDecl(name="b", dtype=DataType.INT32), - ], - vertical_loops__0__sections__0__horizontal_executions__0__body=[ - AssignStmtFactory( - left=FieldAccessFactory(name="A"), right=ScalarAccessFactory(name="b") - ) - ], - ) - sdfg = OirSDFGBuilder().visit(stencil) - - assert isinstance(sdfg, dace.SDFG), "DaCe SDFG expected" - assert sdfg.name == stencil_name, "Stencil name is preserved" - assert len(sdfg.arrays) == 1, "one array expected (A)" - - a_array = sdfg.arrays.get("A") - assert a_array is not None, "Array A expected to be defined" - assert a_array.ctype == "double", "Array A is of type `double`" - assert a_array.offset == (0, 0, 0), "CartesianOffset.zeros() expected" - assert "b" in sdfg.symbols.keys(), "expected `b` as scalar parameter" - - states = sdfg.nodes() - assert len(states) >= 1, "at least one state expected" - - last_block = states[len(states) - 1] - nodes = last_block.nodes() - assert len(list(filter(lambda node: isinstance(node, StencilComputation), nodes))) == 1, ( - "one StencilComputation library node" - ) - assert ( - len( - list( - filter( - lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes - ) - ) - ) - == 1 - ), "one AccessNode of A" - - edges = last_block.edges() - library_node = list(filter(lambda node: isinstance(node, 
StencilComputation), nodes))[0] - write_access = list( - filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes) - )[0] - assert len(edges) == 1, "write memlet path expected" - assert edges[0].src == library_node and edges[0].dst == write_access, "write access expected" diff --git a/uv.lock b/uv.lock index d06b44007c..416896c2eb 100644 --- a/uv.lock +++ b/uv.lock @@ -850,7 +850,6 @@ resolution-markers = [ dependencies = [ { name = "aenum" }, { name = "astunparse" }, - { name = "cmake" }, { name = "dill" }, { name = "fparser" }, { name = "networkx" }, @@ -859,14 +858,13 @@ dependencies = [ { name = "ply" }, { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-rocm5-0' and extra == 'extra-5-gt4py-rocm6-0')" }, { name = "pyyaml" }, - { name = "scikit-build" }, { name = "sympy" }, ] [[package]] name = "dace" version = "1.0.2" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-to-sdfg#82541a9401dcadca43edc33cf1db61a0fe21d0e5" } resolution-markers = [ "python_full_version >= '3.13'", "python_full_version == '3.12.*'", @@ -876,7 +874,6 @@ resolution-markers = [ dependencies = [ { name = "aenum" }, { name = "astunparse" }, - { name = "cmake" }, { name = "dill" }, { 
name = "fparser" }, { name = "networkx" }, @@ -885,10 +882,8 @@ dependencies = [ { name = "ply" }, { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-rocm5-0' and extra == 'extra-5-gt4py-rocm6-0')" }, { name = "pyyaml" }, - { name = "scikit-build" }, { name = "sympy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/02/1a2ece00b229710a4db8f301bba6097eacfbc2a9af84d8746089242d1cf5/dace-1.0.2.tar.gz", hash = "sha256:6728f4bcf584b9f5bbb9c9a393fbdd87364af0c6ad9120da0302b8b470f4f71c", size = 5801789, upload-time = "2025-03-20T15:17:14.034Z" } [[package]] name = "debugpy" @@ -1003,15 +998,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973, upload-time = "2024-10-09T18:35:44.272Z" }, ] -[[package]] -name = "distro" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722, upload-time = 
"2023-12-24T09:54:32.31Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, -] - [[package]] name = "docutils" version = "0.21.2" @@ -1348,7 +1334,7 @@ build = [ { name = "wheel" }, ] dace-cartesian = [ - { name = "dace", version = "1.0.2", source = { registry = "https://pypi.org/simple" } }, + { name = "dace", version = "1.0.2", source = { git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-to-sdfg#82541a9401dcadca43edc33cf1db61a0fe21d0e5" } }, ] dace-next = [ { name = "dace", version = "1.0.0", source = { git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_07_25#46a8a9f26fe433a5bc7dd0d11b284c6fe03a2a9c" } }, @@ -1490,7 +1476,7 @@ build = [ { name = "setuptools", specifier = ">=70.0.0" }, { name = "wheel", specifier = ">=0.33.6" }, ] -dace-cartesian = [{ name = "dace", specifier = ">=1.0.2,<2" }] +dace-cartesian = [{ name = "dace", git = "https://github.com/GridTools/dace?branch=romanc%2Fstree-to-sdfg" }] dace-next = [{ name = "dace", git = "https://github.com/GridTools/dace?tag=__gt4py-next-integration_2025_07_25" }] dev = [ { name = "atlas4py", specifier = ">=0.41", index = "https://test.pypi.org/simple" }, @@ -3410,22 +3396,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/d9/fde7610abd53c0c76b6af72fc679cb377b27c617ba704e25da834e0a0608/ruff-0.9.5-py3-none-win_arm64.whl", hash = "sha256:18a29f1a005bddb229e580795627d297dfa99f16b30c7039e73278cf6b5f9fa9", size = 10064595, upload-time = "2025-02-06T19:47:12.071Z" }, ] -[[package]] -name = "scikit-build" -version = "0.18.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "distro" }, - { name = "packaging" }, - { name = "setuptools" }, - { name = "tomli", marker = "python_full_version < 
'3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm6-0') or (extra == 'extra-5-gt4py-rocm5-0' and extra == 'extra-5-gt4py-rocm6-0')" }, - { name = "wheel" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/56/54/2beb41f3fcddb4ea238634c6c23fe93115090d8799a45f626a83e6934c16/scikit_build-0.18.1.tar.gz", hash = "sha256:a4152ac5a084d499c28a7797be0628d8366c336e2fb0e1a063eb32e55efcb8e7", size = 274171, upload-time = "2024-08-28T18:18:13.457Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/a3/21b519f58de90d684056c52ec4e45f744cfda7483f082dcc4dd18cc74a93/scikit_build-0.18.1-py3-none-any.whl", hash = "sha256:a6860e300f6807e76f21854163bdb9db16afc74eadf34bd6a9947d3fdfcd725a", size = 85568, upload-time = "2024-08-28T18:18:12.247Z" }, -] - [[package]] name = "scipy" version = "1.15.1"