Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/gt4py/next/ffront/field_operator_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ class TupleExpr(Expr):
elts: list[Expr]


# TODO: give a good error for tuple(... for el in iter if ...) so that users understand that and why we don't support conditionals
# TODO: should this have SymbolTableTrait since target declares a new symbol. Write test that has two comprehensions using the same target name.
class TupleComprehension(Expr):
"""
tuple(element_expr for target in iterable)
"""

element_expr: Expr
target: DataSymbol # TODO: how about `tuple(el1+el2 for el1, el2 in var_arg)`?
iterable: Expr


class UnaryOp(Expr):
op: dialect_ast_enums.UnaryOperator
operand: Expr
Expand Down
26 changes: 25 additions & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import collections
from typing import Any, Optional, TypeAlias, TypeVar, cast

import gt4py.next.ffront.field_operator_ast as foast
Expand Down Expand Up @@ -501,6 +501,10 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> foast.Subscri
f"Tuples need to be indexed with literal integers, got '{node.index}'.",
) from ex
new_type = types[index]
case ts.VarArgType(element_type=element_type):
new_type = (
element_type # TODO: we only temporarily allow any index for vararg types
)
case ts.OffsetType(source=source, target=(target1, target2)):
if not target2.kind == DimensionKind.LOCAL:
raise errors.DSLError(
Expand Down Expand Up @@ -747,6 +751,26 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> foast.TupleEx
new_type = ts.TupleType(types=[element.type for element in new_elts])
return foast.TupleExpr(elts=new_elts, type=new_type, location=node.location)

def visit_TupleComprehension(
self, node: foast.TupleComprehension, **kwargs: Any
) -> foast.TupleComprehension:
symtable: collections.ChainMap = kwargs["symtable"] # todo annotation
iterable = self.visit(node.iterable, **kwargs)
target = self.visit(node.target, **kwargs)
assert isinstance(iterable.type, ts.VarArgType)
target.type = iterable.type.element_type
element_expr = self.visit(
node.element_expr,
**{**kwargs, "symtable": symtable.new_child({node.target.id: target})},
)
return foast.TupleComprehension(
element_expr=element_expr,
target=target,
iterable=iterable,
location=node.location,
type=ts.VarArgType(element_type=element_expr.type),
)

def visit_Call(self, node: foast.Call, **kwargs: Any) -> foast.Call:
new_func = self.visit(node.func, **kwargs)
new_args = self.visit(node.args, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/ffront/foast_pretty_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def apply(cls, node: foast.LocatedNode, **kwargs: Any) -> str: # type: ignore[o

TupleExpr = as_fmt("({', '.join(elts)}{',' if len(elts)==1 else ''})")

TupleComprehension = as_fmt("tuple(({element_expr} for {target} in {iterable}))")

UnaryOp = as_fmt("{op}{operand}")

def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> str:
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,15 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr:
def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr:
return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts])

def visit_TupleComprehension(self, node: foast.TupleComprehension, **kwargs: Any) -> itir.Expr:
return im.call(
im.call("map_tuple")(
im.lambda_(self.visit(node.target, **kwargs))(
self.visit(node.element_expr, **kwargs)
)
)
)(self.visit(node.iterable, **kwargs))

def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
# TODO(tehrengruber): extend iterator ir to support unary operators
dtype = type_info.extract_dtype(node.type)
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gt4py.next.ffront.stages import ConcreteFOASTOperatorDef, ConcretePASTProgramDef
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import toolchain, workflow
from gt4py.next.type_system import type_info, type_specifications as ts
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -113,9 +113,9 @@ def __call__(self, inp: ConcreteFOASTOperatorDef) -> ConcretePASTProgramDef:
*partial_program_type.definition.kw_only_args.keys(),
]
assert isinstance(type_, ts.CallableType)
assert arg_types[-1] == type_info.return_type(
type_, with_args=list(arg_types), with_kwargs=kwarg_types
)
# assert arg_types[-1] == type_info.return_type(
# type_, with_args=list(arg_types), with_kwargs=kwarg_types
# )
assert args_names[-1] == "out"

params_decl: list[past.Symbol] = [
Expand Down
31 changes: 27 additions & 4 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,12 @@ def visit_Expr(self, node: ast.Expr) -> foast.Expr:
return self.visit(node.value)

def visit_Name(self, node: ast.Name, **kwargs: Any) -> foast.Name:
return foast.Name(id=node.id, location=self.get_location(node))
loc = self.get_location(node)
if isinstance(node.ctx, ast.Store):
return foast.DataSymbol(id=node.id, location=loc, type=ts.DeferredType(constraint=None))
else:
assert isinstance(node.ctx, ast.Load)
return foast.Name(id=node.id, location=loc)

def visit_UnaryOp(self, node: ast.UnaryOp, **kwargs: Any) -> foast.UnaryOp:
return foast.UnaryOp(
Expand Down Expand Up @@ -469,8 +474,10 @@ def visit_NotEq(self, node: ast.NotEq, **kwargs: Any) -> foast.CompareOperator:
return foast.CompareOperator.NOTEQ

def _verify_builtin_type_constructor(self, node: ast.Call) -> None:
if len(node.args) > 0:
arg = node.args[0]
(arg,) = (
node.args
) # note for review: the change here is unrelated to the actual pr and just a small cleanup
if node.func.id == "tuple":
if not (
isinstance(arg, ast.Constant)
or (isinstance(arg, ast.UnaryOp) and isinstance(arg.operand, ast.Constant))
Expand All @@ -484,9 +491,25 @@ def _func_name(self, node: ast.Call) -> str:
return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly.

def visit_Call(self, node: ast.Call, **kwargs: Any) -> foast.Call:
# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
if isinstance(node.func, ast.Name):
func_name = self._func_name(node)

if func_name == "tuple":
(gen_expr,) = node.args
assert (
len(gen_expr.generators) == 1
) # we don't support (... for ... in ... for ... in ...)
assert (
gen_expr.generators[0].ifs == []
) # we don't support if conditions in comprehensions
return foast.TupleComprehension(
element_expr=self.visit(gen_expr.elt, **kwargs),
target=self.visit(gen_expr.generators[0].target, **kwargs),
iterable=self.visit(gen_expr.generators[0].iter, **kwargs),
location=self.get_location(node),
)

# TODO(tehrengruber): is this still needed or redundant with the checks in type deduction?
if func_name in fbuiltins.TYPE_BUILTIN_NAMES:
self._verify_builtin_type_constructor(node)

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def visit_Call(self, node: past.Call, **kwargs: Any) -> past.Call:
operator_return_type = type_info.return_type(
new_func.type, with_args=arg_types, with_kwargs=kwarg_types
)
if operator_return_type != new_kwargs["out"].type:
if not type_info.is_compatible_type(operator_return_type, new_kwargs["out"].type):
raise ValueError(
"Expected keyword argument 'out' to be of "
f"type '{operator_return_type}', got "
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ def get_domain_range(*args):
"lift",
"make_const_list",
"make_tuple",
"map_",
"map_tuple",
"map_", # TODO: rename to map_list
"named_range",
"neighbors",
"reduce",
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
prune_empty_concat_where,
remove_broadcast,
symbol_ref_utils,
unroll_map_tuple,
)
from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet
from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple
Expand Down Expand Up @@ -179,6 +180,7 @@ def apply_common_transforms(
) # domain inference does not support dynamic offsets yet
ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids)

ir = infer_domain.infer_program(
ir,
Expand Down Expand Up @@ -293,6 +295,7 @@ def apply_fieldview_transforms(

ir = infer_domain_ops.InferDomainOps.apply(ir)
ir = concat_where.canonicalize_domain_argument(ir)
ir = unroll_map_tuple.UnrollMapTuple.apply(ir, uids=uids)
ir = ConstantFolding.apply(ir) # type: ignore[assignment] # always an itir.Program

ir = infer_domain.infer_program(
Expand Down
47 changes: 47 additions & 0 deletions src/gt4py/next/iterator/transforms/unroll_map_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import dataclasses

from gt4py import eve
from gt4py.next import utils
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im
from gt4py.next.iterator.type_system import inference as itir_inference
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass
class UnrollMapTuple(eve.NodeTranslator):
PRESERVED_ANNEX_ATTRS = ("domain",)

uids: utils.IDGeneratorPool

@classmethod
def apply(cls, program: itir.Program, *, uids: utils.IDGeneratorPool):
return cls(uids=uids).visit(program)

def visit_FunCall(self, node: itir.Expr):
node = self.generic_visit(node)

if cpm.is_call_to(node.fun, "map_tuple"):
# TODO: we have to duplicate the function here since the domain inference can not handle them yet
f = node.fun.args[0]
tup = node.args[0]
itir_inference.reinfer(tup)
assert isinstance(tup.type, ts.TupleType)
tup_ref = next(self.uids["_ump"])

result = im.let(tup_ref, tup)(
im.make_tuple(
*(im.call(f)(im.tuple_get(i, tup_ref)) for i in range(len(tup.type.types)))
)
)
itir_inference.reinfer(result)

return result
return node
13 changes: 13 additions & 0 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,19 @@ def applied_map(
return applied_map


@_register_builtin_type_synthesizer
def map_tuple(op: TypeSynthesizer) -> TypeSynthesizer:
@type_synthesizer
def applied_map(
arg: ts.TupleType, offset_provider_type: common.OffsetProviderType
) -> ts.TupleType:
return ts.TupleType(
types=[op(arg_, offset_provider_type=offset_provider_type) for arg_ in arg.types]
)

return applied_map


@_register_builtin_type_synthesizer
def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer:
@type_synthesizer
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/type_system/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,14 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool:
or issubclass(type_class(to_type), symbol_type.constraint)
):
return True
if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.VarArgType):
return is_concretizable(symbol_type.element_type, to_type.element_type)
if isinstance(symbol_type, ts.VarArgType) and isinstance(to_type, ts.TupleType):
if len(to_type.types) == 0 or (
all(type_ == to_type.types[0] for type_ in to_type.types)
and is_concretizable(symbol_type.element_type, to_type.types[0])
):
return True
elif is_concrete(symbol_type):
return symbol_type == to_type
return False
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ def __len__(self) -> int:
return len(self.types)


class VarArgType(DataType):
"""Represents a variable number of arguments of the same type."""

element_type: DataType # TODO: maybe also support different DataTypes

def __str__(self) -> str:
return f"VarArg[{self.element_type}]"


class AnyPythonType:
"""Marker type representing any Python type which cannot be used for instantiation.

Expand Down
22 changes: 19 additions & 3 deletions src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,12 @@ def from_type_hint(
case builtins.tuple:
if not args:
raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.")
if Ellipsis in args:
raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.")
if len(args) == 2 and args[1] is Ellipsis:
return ts.VarArgType(element_type=from_type_hint_same_ns(args[0]))
elif Ellipsis in args:
raise ValueError(
f"Vararg tuple annotation '{type_hint}' cannot have more than one argument."
)
tuple_types = [from_type_hint_same_ns(arg) for arg in args]
assert all(isinstance(elem, ts.DataType) for elem in tuple_types)
return ts.TupleType(types=tuple_types)
Expand Down Expand Up @@ -321,7 +325,19 @@ def from_value(value: Any) -> ts.TypeSpec:
return UnknownPythonObject(value)
else:
type_ = xtyping.infer_type(value, annotate_callable_kwargs=True)
symbol_type = from_type_hint(type_)
if type_ == type[tuple]:
# TODO: this special casing here is not nice, but infer_type is also called on the annotations where
# we don't want to allow unparameterized tuples (or do we?).
symbol_type = ts.ConstructorType(
definition=ts.FunctionType(
pos_only_args=[ts.DeferredType(constraint=None)],
pos_or_kw_args={},
kw_only_args={},
returns=ts.DeferredType(constraint=ts.VarArgType),
)
)
else:
symbol_type = from_type_hint(type_)

if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)):
return symbol_type
Expand Down
11 changes: 11 additions & 0 deletions tests/next_tests/integration_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,15 @@ def _allocate_from_type(
for t in types
)
)
case ts.VarArgType(element_type=element_type):
return tuple(
(
_allocate_from_type(
case=case, arg_type=t, domain=domain, dtype=dtype, strategy=strategy
)
for t in [element_type] * 3 # TODO: revisit
)
)
case ts.NamedCollectionType(types=types) as named_collection_type_spec:
container_constructor = (
named_collections.make_named_collection_constructor_from_type_spec(
Expand Down Expand Up @@ -648,6 +657,8 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) ->
return sum([get_param_size(t, sizes=sizes) for t in types])
case ts.NamedCollectionType(types=types):
return sum([get_param_size(t, sizes=sizes) for t in types])
case ts.VarArgType(element_type=element_type):
return get_param_size(ts.TupleType(types=[element_type] * 3), sizes) # TODO: revisit
case _:
raise TypeError(f"Can not get size for parameter of type '{param_type}'.")

Expand Down
Loading
Loading