Skip to content
Draft
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import dace
from dace import subsets as dace_subsets

from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import (
common_pattern_matcher as cpm,
Expand All @@ -29,53 +28,14 @@
gtir_to_sdfg,
gtir_to_sdfg_types,
gtir_to_sdfg_utils,
sdfg_library_nodes,
)
from gt4py.next.type_system import type_specifications as ts


def _make_concat_field_slice(
ctx: gtir_to_sdfg.SubgraphContext,
field: gtir_to_sdfg_types.FieldopData,
field_desc: dace.data.Array,
concat_dim: gtx_common.Dimension,
concat_dim_index: int,
concat_dim_origin: dace.symbolic.SymbolicType,
) -> tuple[gtir_to_sdfg_types.FieldopData, dace.data.Array]:
"""
Helper function called by `_translate_concat_where_impl` to create a slice along
the concat dimension, that is a new array with an extra dimension and a single
level. This allows to concatanate the input fields along the concat dimension.
"""
assert isinstance(field.gt_type, ts.FieldType)
assert concat_dim not in field.gt_type.dims
dims = [
*field.gt_type.dims[:concat_dim_index],
concat_dim,
*field.gt_type.dims[concat_dim_index:],
]
origin = tuple(
[*field.origin[:concat_dim_index], concat_dim_origin, *field.origin[concat_dim_index:]]
)
shape = tuple([*field_desc.shape[:concat_dim_index], 1, *field_desc.shape[concat_dim_index:]])
extended_field_data, extended_field_desc = ctx.sdfg.add_temp_transient(shape, field_desc.dtype)
extended_field_node = ctx.state.add_access(extended_field_data)
ctx.state.add_nedge(
field.dc_node,
extended_field_node,
dace.Memlet(
data=field.dc_node.data,
subset=dace_subsets.Range.from_array(field_desc),
other_subset=dace_subsets.Range.from_array(extended_field_desc),
),
)
extended_field = gtir_to_sdfg_types.FieldopData(
extended_field_node, ts.FieldType(dims=dims, dtype=field.gt_type.dtype), origin
)
return extended_field, extended_field_desc


def _make_concat_scalar_broadcast(
ctx: gtir_to_sdfg.SubgraphContext,
sdfg_builder: gtir_to_sdfg.SDFGBuilder,
inp: gtir_to_sdfg_types.FieldopData,
inp_desc: dace.data.Array,
out_domain: domain_utils.SymbolicDomain,
Expand All @@ -87,38 +47,45 @@ def _make_concat_scalar_broadcast(
The scalar value can come from either a scalar node or from a 1D-array (assuming
the array represents a field in the concat dimension).
"""
assert isinstance(inp.gt_type, ts.FieldType)
assert len(inp.gt_type.dims) == 1
concat_dim = inp.gt_type.dims[0]

out_dims, out_origin, out_shape = gtir_domain.get_field_layout(
gtir_domain.get_field_domain(out_domain)
)
concat_dim_index = out_dims.index(concat_dim)

out_name, out_desc = ctx.sdfg.add_temp_transient(out_shape, inp_desc.dtype)
out_node = ctx.state.add_access(out_name)

map_variables = [gtir_to_sdfg_utils.get_map_variable(dim) for dim in out_dims]
inp_index = (
"0"
if isinstance(inp.dc_node.desc(ctx.sdfg), dace.data.Scalar)
else (
f"({map_variables[concat_dim_index]} + {out_origin[concat_dim_index] - inp.origin[0]})"
)
inp_desc = inp.dc_node.desc(ctx.sdfg)

if isinstance(inp.gt_type, ts.FieldType):
assert isinstance(inp.gt_type.dtype, ts.ScalarType)
inp_axes = [out_dims.index(dim) for dim in inp.gt_type.dims]
inp_origin = inp.origin
dtype = inp.gt_type.dtype
else:
inp_axes = None
inp_origin = None
dtype = inp.gt_type

# Use a 'Broadcast' library node to write the scalar value to the result field.
name = sdfg_builder.unique_tasklet_name("broadcast")
bcast_node = sdfg_library_nodes.Broadcast(name, inp_axes, inp_origin, out_origin)
ctx.state.add_node(bcast_node)
ctx.state.add_edge(
inp.dc_node,
None,
bcast_node,
"_inp",
dace.Memlet(data=inp.dc_node.data, subset=dace_subsets.Range.from_array(inp_desc)),
)
ctx.state.add_mapped_tasklet(
"broadcast",
map_ranges=dict(zip(map_variables, dace_subsets.Range.from_array(out_desc), strict=True)),
code="__out = __inp",
inputs={"__inp": dace.Memlet(data=inp.dc_node.data, subset=inp_index)},
outputs={"__out": dace.Memlet(data=out_name, subset=",".join(map_variables))},
input_nodes={inp.dc_node},
output_nodes={out_node},
external_edges=True,
ctx.state.add_edge(
bcast_node,
"_outp",
out_node,
None,
dace.Memlet(data=out_name, subset=dace_subsets.Range.from_array(out_desc)),
)

out_type = ts.FieldType(dims=out_dims, dtype=inp.gt_type.dtype)
out_type = ts.FieldType(dims=out_dims, dtype=dtype)
out_field = gtir_to_sdfg_types.FieldopData(out_node, out_type, tuple(out_origin))
return out_field, out_desc

Expand Down Expand Up @@ -184,165 +151,85 @@ def _translate_concat_where_impl(
else:
raise ValueError(f"Unexpected concat mask {mask_domain} with finite domain.")

# We use the concat domain, stored in the annex, as the domain of output field.
output_domain = gtir_domain.get_field_domain(node_domain)
output_dims, output_origin, output_shape = gtir_domain.get_field_layout(output_domain)
concat_dim_index = output_dims.index(concat_dim)

"""
In case one of the arguments is a scalar value, for example:
We broadcast the argument field on dimensions on the output field, in case the
argument is a scalar value, for example:
```python
@gtx.field_operator
def testee(a: np.int32, b: IJKField) -> IJKField:
return concat_where(KDim == 0, a, b)
```

Similarly, we broadcast a field defined as a slice of the output domain, i.e.
a field with a smaller number of dimensions than the out field.
Consider for example the following IR, where the the 'a' IJ-field is used for
the vertical boundary (`KDim == 0`):
```python
@gtx.field_operator
def testee(a: np.int32, b: cases.IJKField) -> cases.IJKField:
return concat_where(KDim < 1, a, b)
def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
return concat_where(KDim == 0, boundary, interior)
```
we convert it to a single-element 1D field with the dimension of the concat expression.
"""
if isinstance(lower.gt_type, ts.ScalarType):
assert len(lower_domain.ranges) == 0
assert isinstance(upper.gt_type, ts.FieldType)
lower_origin_expr = node_domain.ranges[concat_dim].start
lower = gtir_to_sdfg_types.FieldopData(
lower.dc_node,
ts.FieldType(dims=[concat_dim], dtype=lower.gt_type),
origin=(gtir_to_sdfg_utils.get_symbolic(lower_origin_expr),),
)
lower_domain.ranges[concat_dim] = domain_utils.SymbolicRange(
start=lower_origin_expr, stop=concat_dim_bound_expr
)
elif isinstance(upper.gt_type, ts.ScalarType):
assert len(upper_domain.ranges) == 0
assert isinstance(lower.gt_type, ts.FieldType)
upper_origin_expr = concat_dim_bound_expr
upper = gtir_to_sdfg_types.FieldopData(
upper.dc_node,
ts.FieldType(dims=[concat_dim], dtype=upper.gt_type),
origin=(gtir_to_sdfg_utils.get_symbolic(upper_origin_expr),),
)
upper_domain.ranges[concat_dim] = domain_utils.SymbolicRange(
start=upper_origin_expr,
stop=node_domain.ranges[concat_dim].stop,
)

# we use the concat domain, stored in the annex, as the domain of output field
output_domain = gtir_domain.get_field_domain(node_domain)
output_dims, output_origin, output_shape = gtir_domain.get_field_layout(output_domain)
concat_dim_index = output_dims.index(concat_dim)

if concat_dim not in lower.gt_type.dims: # type: ignore[union-attr]
"""
The field on the lower domain is to be treated as a slice to add as one
level in the concat dimension, on the lower bound.
Consider for example the following IR, where a horizontal field is added
as level zero in K-dimension:
```python
@gtx.field_operator
def testee(interior: cases.IJKField, boundary: cases.IJField) -> cases.IJKField:
return concat_where(KDim == 0, boundary, interior)
```
"""
assert (
lower.gt_type.dims # type: ignore[union-attr]
== [
*upper.gt_type.dims[0:concat_dim_index], # type: ignore[union-attr]
*upper.gt_type.dims[concat_dim_index + 1 :], # type: ignore[union-attr]
]
)
lower_origin_expr = node_domain.ranges[concat_dim].start
lower, lower_desc = _make_concat_field_slice(
ctx=ctx,
field=lower,
field_desc=lower_desc,
concat_dim=concat_dim,
concat_dim_index=concat_dim_index,
concat_dim_origin=gtir_to_sdfg_utils.get_symbolic(lower_origin_expr),
)
lower_domain.ranges[concat_dim] = domain_utils.SymbolicRange(
start=lower_origin_expr,
stop=concat_dim_bound_expr,
)
elif concat_dim not in upper.gt_type.dims: # type: ignore[union-attr]
# Same as previous case, but the field slice is added on the upper bound.
assert (
upper.gt_type.dims # type: ignore[union-attr]
== [
*lower.gt_type.dims[0:concat_dim_index], # type: ignore[union-attr]
*lower.gt_type.dims[concat_dim_index + 1 :], # type: ignore[union-attr]
]
)
upper_origin_expr = concat_dim_bound_expr
upper, upper_desc = _make_concat_field_slice(
ctx=ctx,
field=upper,
field_desc=upper_desc,
concat_dim=concat_dim,
concat_dim_index=concat_dim_index,
concat_dim_origin=gtir_to_sdfg_utils.get_symbolic(upper_origin_expr),
)
upper_domain.ranges[concat_dim] = domain_utils.SymbolicRange(
start=upper_origin_expr,
stop=node_domain.ranges[concat_dim].stop,
)
elif isinstance(lower_desc, dace.data.Scalar) or (
len(lower.gt_type.dims) == 1 and len(node_domain.ranges) > 1 # type: ignore[union-attr]
):
"""
The input on the lower domain is either a scalar or a 1d field, representing
the value(s) to be added as one level in the concat dimension below the upper domain.
Consider for example the following IR, where the scalar value is one level
(`KDim == 0`) taken from lower input 'a':
```python
@gtx.field_operator
def testee(a: cases.KField, b: cases.IJKField) -> cases.IJKField:
return concat_where(KDim == 0, a, b)
```
"""
assert lower_domain.ranges.keys() == {concat_dim}
if isinstance(lower.gt_type, ts.ScalarType) or len(lower.gt_type.dims) < len(output_dims):
if concat_dim not in lower_domain.ranges:
lower_domain.ranges[concat_dim] = domain_utils.SymbolicRange(
start=node_domain.ranges[concat_dim].start, stop=concat_dim_bound_expr
)
lower_domain = domain_utils.promote_domain(lower_domain, node_domain.ranges.keys())
lower_domain = domain_utils.domain_intersection(lower_domain, node_domain)
lower, lower_desc = _make_concat_scalar_broadcast(
ctx=ctx,
sdfg_builder=sdfg_builder,
inp=lower,
inp_desc=lower_desc,
out_domain=node_domain,
out_domain=lower_domain,
)
elif isinstance(upper_desc, dace.data.Scalar) or (
len(upper.gt_type.dims) == 1 and len(node_domain.ranges) > 1 # type: ignore[union-attr]
):
# Same as previous case, but the scalar value is taken from `upper` input.
assert upper_domain.ranges.keys() == {concat_dim}

elif isinstance(upper.gt_type, ts.ScalarType) or len(upper.gt_type.dims) < len(output_dims):
if concat_dim not in upper_domain.ranges:
upper_domain.ranges[concat_dim] = domain_utils.SymbolicRange(
start=concat_dim_bound_expr,
stop=node_domain.ranges[concat_dim].stop,
)
upper_domain = domain_utils.promote_domain(upper_domain, node_domain.ranges.keys())
upper_domain = domain_utils.domain_intersection(upper_domain, node_domain)
upper, upper_desc = _make_concat_scalar_broadcast(
ctx=ctx,
sdfg_builder=sdfg_builder,
inp=upper,
inp_desc=upper_desc,
out_domain=node_domain,
out_domain=upper_domain,
)

else:
"""
Handle here the _regular_ case, that is concat_where applied to two fields
with same domain:
```python
@gtx.field_operator
def testee(a: cases.IJKField, b: cases.IJKField) -> cases.IJKField:
return concat_where(KDim <=10 , a, b)
return concat_where(KDim < 10, a, b)
```
"""
assert isinstance(lower.gt_type, ts.FieldType)
assert isinstance(lower_desc, dace.data.Array)
assert isinstance(upper.gt_type, ts.FieldType)
assert isinstance(upper_desc, dace.data.Array)
if lower.gt_type.dims != upper.gt_type.dims:
raise NotImplementedError(
raise ValueError(
"Lowering concat_where on fields with different domain is not supported."
)

# ensure that the arguments have the same domain as the concat result
assert all(ftype.dims == output_dims for ftype in (lower.gt_type, upper.gt_type)) # type: ignore[union-attr]

lower_domain = domain_utils.domain_intersection(lower_domain, node_domain)
lower_domain_range = lower_domain.ranges[concat_dim]
lower_range_0 = gtir_to_sdfg_utils.get_symbolic(lower_domain_range.start)
lower_range_1 = gtir_to_sdfg_utils.get_symbolic(
im.maximum(lower_domain_range.start, lower_domain_range.stop)
)
lower_range_size = lower_range_1 - lower_range_0

upper_domain = domain_utils.domain_intersection(upper_domain, node_domain)
upper_domain_range = upper_domain.ranges[concat_dim]
upper_range_0 = gtir_to_sdfg_utils.get_symbolic(upper_domain_range.start)
upper_range_1 = gtir_to_sdfg_utils.get_symbolic(
Expand Down
Loading