Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
gtir_to_sdfg,
gtir_to_sdfg_types,
gtir_to_sdfg_utils,
sdfg_library_nodes,
)
from gt4py.next.type_system import type_specifications as ts

Expand Down Expand Up @@ -99,24 +100,38 @@ def _make_concat_scalar_broadcast(
out_name, out_desc = ctx.sdfg.add_temp_transient(out_shape, inp_desc.dtype)
out_node = ctx.state.add_access(out_name)

map_variables = [gtir_to_sdfg_utils.get_map_variable(dim) for dim in out_dims]
inp_index = (
"0"
if isinstance(inp.dc_node.desc(ctx.sdfg), dace.data.Scalar)
else (
if isinstance(inp.dc_node.desc(ctx.sdfg), dace.data.Scalar):
# Use a 'Fill' library node to write the scalar value to the result field.
fill_node = sdfg_library_nodes.Fill("fill")
ctx.state.add_node(fill_node)
ctx.state.add_nedge(
inp.dc_node,
fill_node,
dace.Memlet(data=inp.dc_node.data, subset="0"),
)
ctx.state.add_nedge(
fill_node,
out_node,
dace.Memlet(data=out_name, subset=dace_subsets.Range.from_array(out_desc)),
)
else:
# Create a map to copy one level on the field domain.
map_variables = [gtir_to_sdfg_utils.get_map_variable(dim) for dim in out_dims]
inp_index = (
f"({map_variables[concat_dim_index]} + {out_origin[concat_dim_index] - inp.origin[0]})"
)
)
ctx.state.add_mapped_tasklet(
"broadcast",
map_ranges=dict(zip(map_variables, dace_subsets.Range.from_array(out_desc), strict=True)),
code="__out = __inp",
inputs={"__inp": dace.Memlet(data=inp.dc_node.data, subset=inp_index)},
outputs={"__out": dace.Memlet(data=out_name, subset=",".join(map_variables))},
input_nodes={inp.dc_node},
output_nodes={out_node},
external_edges=True,
)
ctx.state.add_mapped_tasklet(
"broadcast",
map_ranges=dict(
zip(map_variables, dace_subsets.Range.from_array(out_desc), strict=True)
),
code="__out = __inp",
inputs={"__inp": dace.Memlet(data=inp.dc_node.data, subset=inp_index)},
outputs={"__out": dace.Memlet(data=out_name, subset=",".join(map_variables))},
input_nodes={inp.dc_node},
output_nodes={out_node},
external_edges=True,
)

out_type = ts.FieldType(dims=out_dims, dtype=inp.gt_type.dtype)
out_field = gtir_to_sdfg_types.FieldopData(out_node, out_type, tuple(out_origin))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import dace
from dace import subsets as dace_subsets

from gt4py.eve.extended_typing import MaybeNestedInTuple
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import (
Expand All @@ -29,6 +28,7 @@
gtir_to_sdfg,
gtir_to_sdfg_types,
gtir_to_sdfg_utils,
sdfg_library_nodes,
utils as gtx_dace_utils,
)
from gt4py.next.program_processors.runners.dace.gtir_to_sdfg_concat_where import (
Expand Down Expand Up @@ -73,7 +73,7 @@ def _parse_fieldop_arg(
ctx: gtir_to_sdfg.SubgraphContext,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
domain: gtir_domain.FieldopDomain,
) -> MaybeNestedInTuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr]:
) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr:
"""
Helper method to visit an expression passed as argument to a field operator
and create the local view for the field argument.
Expand Down Expand Up @@ -251,11 +251,16 @@ def translate_as_fieldop(
raise NotImplementedError("Unexpected 'as_filedop' with tuple output in SDFG lowering.")

if cpm.is_ref_to(fieldop_expr, "deref"):
# Special usage of 'deref' as argument to fieldop expression, to pass a scalar
# value to 'as_fieldop' function. It results in broadcasting the scalar value
# over the field domain.
stencil_expr = im.lambda_("a")(im.deref("a"))
stencil_expr.expr.type = node.type.dtype
if isinstance(node.args[0].type, ts.ScalarType):
# Special usage of 'deref' as argument to fieldop expression, to broadcast
# a scalar value on the field domain.
return translate_broadcast(node, ctx, sdfg_builder)
else:
# Special usage of 'deref' with field argument, to return a subset of
# the full field domain.
# TODO(edopao): Lower this case to a memlet edge, planned for next PR.
stencil_expr = im.lambda_("a")(im.deref("a"))
stencil_expr.expr.type = node.type.dtype
elif isinstance(fieldop_expr, gtir.Lambda):
# Default case, handled below: the argument expression is a lambda function
# representing the stencil operation to be computed over the field domain.
Expand Down Expand Up @@ -285,6 +290,66 @@ def translate_as_fieldop(
)


def translate_broadcast(
node: gtir.Node,
ctx: gtir_to_sdfg.SubgraphContext,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
) -> gtir_to_sdfg_types.FieldopData:
"""Translates a broadcast expression which writes a scalar value on the field domain."""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")

if not isinstance(node.type, ts.FieldType):
raise NotImplementedError("Unexpected 'as_filedop' with tuple output in SDFG lowering.")

assert isinstance(node.type.dtype, ts.ScalarType)
field_dtype = gtx_dace_utils.as_dace_type(node.type.dtype)

assert len(node.args) == 1
assert isinstance(node.args[0].type, ts.ScalarType)
scalar_arg = node.args[0]

fun_node = node.fun
assert len(fun_node.args) == 2
fieldop_expr, fieldop_domain_expr = fun_node.args
assert cpm.is_ref_to(fieldop_expr, "deref")

# Parse the domain of the field operator.
assert isinstance(fieldop_domain_expr.type, ts.DomainType)
field_domain = gtir_domain.get_field_domain(
domain_utils.SymbolicDomain.from_expr(fieldop_domain_expr)
)

# The memory layout of the output field follows the field operator compute domain.
field_dims, field_origin, field_shape = gtir_domain.get_field_layout(field_domain)
assert field_dims == node.type.dims
field_name, field_desc = sdfg_builder.add_temp_array(ctx.sdfg, field_shape, field_dtype)
field_node = ctx.state.add_access(field_name)

# Retrieve the scalar argument, which could be either a literal value or the
# result of a scalar expression.
arg = _parse_fieldop_arg(scalar_arg, ctx, sdfg_builder, field_domain)
assert isinstance(arg, gtir_dataflow.MemletExpr)
assert arg.subset.num_elements() == 1

# Use a 'Fill' library node to write the scalar value to the result field.
fill_node = sdfg_library_nodes.Fill("fill")
ctx.state.add_node(fill_node)
ctx.state.add_nedge(
arg.dc_node,
fill_node,
dace.Memlet(data=arg.dc_node.data, subset=arg.subset),
)

ctx.state.add_nedge(
fill_node,
field_node,
dace.Memlet(data=field_name, subset=dace_subsets.Range.from_array(field_desc)),
)

return gtir_to_sdfg_types.FieldopData(field_node, node.type, tuple(field_origin))


def _construct_if_branch_output(
ctx: gtir_to_sdfg.SubgraphContext,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
Expand Down Expand Up @@ -715,6 +780,7 @@ def translate_symbol_ref(
# Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol
__primitive_translators: list[PrimitiveTranslator] = [
translate_as_fieldop,
translate_broadcast,
translate_concat_where,
translate_if,
translate_index,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from typing import Any, Final

import dace
from dace import library as dace_library, nodes as dace_nodes
from dace.transformation import transformation as dace_transform


_INPUT_NAME: Final[str] = "_input"
_OUTPUT_NAME: Final[str] = "_output"


@dace_library.expansion
class ExpandPure(dace_transform.ExpandTransformation):
"""Implements pure expansion of the Fill library node."""

environments: Final[list[Any]] = []

@staticmethod
def expansion(node: Fill, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG) -> dace.SDFG:
sdfg = dace.SDFG(f"{node.label}_sdfg")

assert len(parent_state.out_edges(node)) == 1
outedge = parent_state.out_edges(node)[0]
out_desc = parent_sdfg.arrays[outedge.data.data]
inner_out_desc = out_desc.clone()
inner_out_desc.transient = False
out = sdfg.add_datadesc(_OUTPUT_NAME, inner_out_desc)
outedge._src_conn = _OUTPUT_NAME
node.add_out_connector(_OUTPUT_NAME)

state = sdfg.add_state(f"{node.label}_state")
map_params = [f"__i{i}" for i in range(len(out_desc.shape))]
map_rng = {i: f"0:{s}" for i, s in zip(map_params, out_desc.shape)}
out_mem = dace.Memlet(expr=f"{out}[{','.join(map_params)}]")
outputs = {"_out": out_mem}

assert len(parent_state.in_edges(node)) == 1
inedge = parent_state.in_edges(node)[0]
inp_desc = parent_sdfg.arrays[inedge.data.data]
inner_inp_desc = inp_desc.clone()
inner_inp_desc.transient = False
inp = sdfg.add_datadesc(_INPUT_NAME, inner_inp_desc)
inedge._dst_conn = _INPUT_NAME
node.add_in_connector(_INPUT_NAME)
inputs = {"_in": dace.Memlet(data=inp, subset="0")}
code = "_out = _in"

state.add_mapped_tasklet(
f"{node.label}_tasklet", map_rng, inputs, code, outputs, external_edges=True
)

return sdfg


@dace_library.node
class Fill(dace_nodes.LibraryNode):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add some more semantic, i.e. an input connector, that collects the value that should be broadcasted and an output connector for the output.

I am also wondering if it would make sense to have two different library nodes.
One where the value that is broadcast is a literal, like 0.0 and one, which is probably the current one, where the value is read from another data descriptor (might be hard to integrate into the lowering).

"""Implements filling data containers with a single value"""

implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {"pure": ExpandPure}
default_implementation: Final[str] = "pure"

def __init__(self, name: str):
super().__init__(name)