diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index fa5bc4889f..f4950ea402 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -113,6 +113,18 @@ class TupleExpr(Expr): elts: list[Expr] +# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals +# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name. +class TupleComprehension(Expr): + """ + tuple(element_expr for target in iterable) + """ + + element_expr: Expr + target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`? + iterable: Expr + + class UnaryOp(Expr): op: dialect_ast_enums.UnaryOperator operand: Expr diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 68bf108a0a..e545f9e002 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +import collections from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast @@ -501,6 +501,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri f"Tuples need to be indexed with literal integers, got '{node.index}'.", ) from ex new_type = types[index] + case ts.VarArgType(element_type=element_type): + new_type = ( + element_type # TODO: we only temporarily allow any index for vararg types + ) case ts.OffsetType(source=source, target=(target1, target2)): if not target2.kind == DimensionKind.LOCAL: raise errors.DSLError( @@ -747,6 +751,26 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx new_type = ts.TupleType(types=[element.type for element in new_elts]) return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location) + def visit_TupleComprehension( + self, node: foast.TupleComprehension, **kwargs: Any + ) -> foast.TupleComprehension: + symtable: collections.ChainMap = kwargs["symtable"] # todo annotation + iterable = self.visit(node.iterable, **kwargs) + target = self.visit(node.target, **kwargs) + assert isinstance(iterable.type, ts.VarArgType) + target.type = iterable.type.element_type + element_expr = self.visit( + node.element_expr, + **{**kwargs, "symtable": symtable.new_child({node.target.id: target})}, + ) + return foast.TupleComprehension( + element_expr=element_expr, + target=target, + iterable=iterable, + location=node.location, + type=ts.VarArgType(element_type=element_expr.type), + ) + def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call: new_func = self.visit(node.func, **kwargs) new_args = self.visit(node.args, **kwargs) diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 8b2e369501..77495d78f7 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -118,6 +118,8 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})") + TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))") + UnaryOp = as_fmt("{op}{operand}") def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 3825072cb7..2e587c346e 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -257,6 +257,15 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) + def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr: + return im.call( + im.call("map_tuple")( + im.lambda_(self.visit(node.target, **kwargs))( + self.visit(node.element_expr, **kwargs) + ) + ) + )(self.visit(node.iterable, **kwargs)) + def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 05b080b70b..c37cba5a78 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -21,7 +21,7 @@ from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef from gt4py.next.iterator import ir as itir from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts +from gt4py.next.type_system import type_specifications as ts @dataclasses.dataclass(frozen=True) @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef: *partial_program_type.definition.kw_only_args.keys(), ] assert isinstance(type_, ts.CallableType) - assert arg_types[-1] == type_info.return_type( - type_, with_args=list(arg_types), with_kwargs=kwarg_types - ) + # assert arg_types[-1] == type_info.return_type( + # type_, with_args=list(arg_types), with_kwargs=kwarg_types + # ) assert args_names[-1] == "out" params_decl: list[past.Symbol] = [ diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ced0ff3905..adefa7ba9e 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -337,7 +337,12 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr: return self.visit(node.value) def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name: - return foast.Name(id=node.id, location=self.get_location(node)) + loc = self.get_location(node) + if isinstance(node.ctx, ast.Store): + return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None)) + else: + assert isinstance(node.ctx, ast.Load) + return foast.Name(id=node.id, location=loc) def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp: return foast.UnaryOp( @@ -469,8 +474,10 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator: return foast.CompareOperator.NOTEQ def _verify_builtin_type_constructor(self, node: ast.Call) -> None: - if len(node.args) > 0: - arg = node.args[0] + (arg,) = ( + node.args + ) # note for review: the change here is unrelated to the actual pr and just a small cleanup + if node.func.id == "tuple": if not ( isinstance(arg, ast.Constant) or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant)) @@ -484,9 +491,25 @@ def _func_name(self, node: ast.Call) -> str: return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call: - # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if isinstance(node.func, ast.Name): func_name = self._func_name(node) + + if func_name == "tuple": + (gen_expr,) = node.args + assert ( + len(gen_expr.generators) == 1 + ) # we don't support (... for ... in ... for ... in ...) + assert ( + gen_expr.generators[0].ifs == [] + ) # we don't support if conditions in comprehensions + return foast.TupleComprehension( + element_expr=self.visit(gen_expr.elt, **kwargs), + target=self.visit(gen_expr.generators[0].target, **kwargs), + iterable=self.visit(gen_expr.generators[0].iter, **kwargs), + location=self.get_location(node), + ) + + # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? if func_name in fbuiltins.TYPE_BUILTIN_NAMES: self._verify_builtin_type_constructor(node) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 9d021ceb51..530d407459 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call: operator_return_type = type_info.return_type( new_func.type, with_args=arg_types, with_kwargs=kwarg_types ) - if operator_return_type != new_kwargs["out"].type: + if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type): raise ValueError( "Expected keyword argument 'out' to be of " f"type '{operator_return_type}', got " diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index e54c6ea3d7..7b24c91884 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -498,7 +498,8 @@ def get_domain_range(*args): "lift", "make_const_list", "make_tuple", - "map_", + "map_tuple", + "map_", # TODO: rename to map_list "named_range", "neighbors", "reduce", diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 08ca9d94e0..4102790129 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -24,6 +24,7 @@ prune_empty_concat_where, remove_broadcast, symbol_ref_utils, + unroll_map_tuple, ) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple @@ -179,6 +180,7 @@ def apply_common_transforms( ) # domain inference does not support dynamic offsets yet ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = infer_domain.infer_program( ir, @@ -293,6 +295,7 @@ def apply_fieldview_transforms( ir = infer_domain_ops.InferDomainOps.apply(ir) ir = concat_where.canonicalize_domain_argument(ir) + ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids) ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program ir = infer_domain.infer_program( diff --git a/src/gt4py/next/iterator/transforms/unroll_map_tuple.py b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py new file mode 100644 index 0000000000..66f96d66fa --- /dev/null +++ b/src/gt4py/next/iterator/transforms/unroll_map_tuple.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses + +from gt4py import eve +from gt4py.next import utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass +class UnrollMapTuple(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + uids: utils.IDGeneratorPool + + @classmethod + def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool): + return cls(uids=uids).visit(program) + + def visit_FunCall(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_call_to(node.fun, "map_tuple"): + # TODO: we have to duplicate the function here since the domain inference can not handle them yet + f = node.fun.args[0] + tup = node.args[0] + itir_inference.reinfer(tup) + assert isinstance(tup.type, ts.TupleType) + tup_ref = next(self.uids["_ump"]) + + result = im.let(tup_ref, tup)( + im.make_tuple( + *(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types))) + ) + ) + itir_inference.reinfer(result) + + return result + return node diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 6d77c70375..4406dd9aa8 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -633,6 +633,19 @@ def applied_map( return applied_map +@_register_builtin_type_synthesizer +def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer: + @type_synthesizer + def applied_map( + arg: ts.TupleType, offset_provider_type: common.OffsetProviderType + ) -> ts.TupleType: + return ts.TupleType( + types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types] + ) + + return applied_map + + @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @type_synthesizer diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index eb70d15947..69fccd33da 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -566,6 +566,14 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: or issubclass(type_class(to_type), symbol_type.constraint) ): return True + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType): + return is_concretizable(symbol_type.element_type, to_type.element_type) + if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType): + if len(to_type.types) == 0 or ( + all(type_ == to_type.types[0] for type_ in to_type.types) + and is_concretizable(symbol_type.element_type, to_type.types[0]) + ): + return True elif is_concrete(symbol_type): return symbol_type == to_type return False diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 59ac40f0f3..409138d593 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -148,6 +148,15 @@ def __len__(self) -> int: return len(self.types) +class VarArgType(DataType): + """Represents a variable number of arguments of the same type.""" + + element_type: DataType # TODO: maybe also support different DataTypes + + def __str__(self) -> str: + return f"VarArg[{self.element_type}]" + + class AnyPythonType: """Marker type representing any Python type which cannot be used for instantiation. diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 0f145e04aa..0ca020625a 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -180,8 +180,12 @@ def from_type_hint( case builtins.tuple: if not args: raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") - if Ellipsis in args: - raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") + if len(args) == 2 and args[1] is Ellipsis: + return ts.VarArgType(element_type=from_type_hint_same_ns(args[0])) + elif Ellipsis in args: + raise ValueError( + f"Vararg tuple annotation '{type_hint}' cannot have more than one argument." + ) tuple_types = [from_type_hint_same_ns(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) return ts.TupleType(types=tuple_types) @@ -321,7 +325,19 @@ def from_value(value: Any) -> ts.TypeSpec: return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) - symbol_type = from_type_hint(type_) + if type_ == type[tuple]: + # TODO: this special casing here is not nice, but infer_type is also called on the annotations where + # we don't want to allow unparameterized tuples (or do we?). + symbol_type = ts.ConstructorType( + definition=ts.FunctionType( + pos_only_args=[ts.DeferredType(constraint=None)], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.DeferredType(constraint=ts.VarArgType), + ) + ) + else: + symbol_type = from_type_hint(type_) if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)): return symbol_type diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 78e6c62781..e723c963de 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -603,6 +603,15 @@ def _allocate_from_type( for t in types ) ) + case ts.VarArgType(element_type=element_type): + return tuple( + ( + _allocate_from_type( + case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy + ) + for t in [element_type] * 3 # TODO: revisit + ) + ) case ts.NamedCollectionType(types=types) as named_collection_type_spec: container_constructor = ( named_collections.make_named_collection_constructor_from_type_spec( @@ -648,6 +657,8 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> return sum([get_param_size(t, sizes=sizes) for t in types]) case ts.NamedCollectionType(types=types): return sum([get_param_size(t, sizes=sizes) for t in types]) + case ts.VarArgType(element_type=element_type): + return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit case _: raise TypeError(f"Can not get size for parameter of type '{param_type}'.") diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 8060d5bb36..14f14b3ffb 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -336,6 +336,36 @@ def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: ) +@pytest.mark.uses_tuple_args +def test_tuple_comprehension(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, ...]: + return tuple(tracer * factor for tracer in tracers) + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t), + ) + + +@pytest.mark.uses_tuple_args +def test_tuple_vararg(cartesian_case): + @gtx.field_operator + def testee( + tracers: tuple[cases.IFloatField, ...], factor: float + ) -> tuple[cases.IFloatField, cases.IFloatField]: + return tracers[0] * factor, tracers[1] * factor + + cases.verify_with_default_data( + cartesian_case, + testee, + ref=lambda t, f: tuple(el * f for el in t[:2]), + ) + + @pytest.mark.uses_tuple_args @pytest.mark.xfail(reason="Iterator of tuple approach in lowering does not allow this.") def test_tuple_arg_with_unpromotable_dims(unstructured_case):