diff --git a/loki/analyse/analyse_dataflow.py b/loki/analyse/analyse_dataflow.py index 67fd635f0..b333c9de2 100644 --- a/loki/analyse/analyse_dataflow.py +++ b/loki/analyse/analyse_dataflow.py @@ -327,6 +327,9 @@ def visit_VariableDeclaration(self, o, **kwargs): uses |= {o.symbols[0].type.kind} return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs) + # The definition of the function has no effect on data flow + visit_StatementFunction = visit_Node + class DataflowAnalysisDetacher(Transformer): """ diff --git a/loki/analyse/tests/test_analyse_dataflow.py b/loki/analyse/tests/test_analyse_dataflow.py index 0b4722a40..3e236b180 100644 --- a/loki/analyse/tests/test_analyse_dataflow.py +++ b/loki/analyse/tests/test_analyse_dataflow.py @@ -25,11 +25,12 @@ def test_analyse_live_symbols(frontend): integer, intent(in) :: v1 integer, intent(inout) :: v2 integer, intent(out) :: v3 - integer :: i, j, n=10, tmp, a + integer :: i, j, k, n=10, tmp, a, b + b(k) = k + 1 do i=1,n do j=1,n - tmp = j + 1 + tmp = b(j) end do a = v2 + tmp end do diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index b4e455d12..b8bfd4742 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -37,9 +37,7 @@ ) from loki.expression import AttachScopesMapper from loki.logging import debug, detail, info, warning, error -from loki.tools import ( - as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup, dict_override -) +from loki.tools import as_tuple, flatten, CaseInsensitiveDict, dict_override from loki.types import BasicType, DerivedType, ProcedureType, SymbolAttributes, Scope from loki.config import config @@ -610,7 +608,11 @@ def visit_Type_Declaration_Stmt(self, o, **kwargs): * :class:`fparser.two.Fortran2003.Attr_Spec_List` * :class:`fparser.two.Fortran2003.Entity_Decl_List` """ - # First, obtain data type and attributes + source = kwargs.get('source') + label = kwargs.get('label') + scope = kwargs.get('scope') + + # First, obtain data type and basic declaration attributes _type = self.visit(o.children[0], **kwargs) attrs = self.visit(o.children[1], **kwargs) if o.children[1] else () attrs = dict(attrs) @@ -619,7 +621,8 @@ def visit_Type_Declaration_Stmt(self, o, **kwargs): _type = _type.clone(**attrs) # Last, instantiate declared variables - variables = as_tuple(self.visit(o.children[2], **kwargs)) + with dict_override(kwargs, {'type': _type}): + variables = as_tuple(self.visit(o.children[2], **kwargs)) # DIMENSION is called shape for us if _type.dimension: @@ -628,34 +631,20 @@ def visit_Type_Declaration_Stmt(self, o, **kwargs): # representation of variables in declarations variables = as_tuple(v.clone(dimensions=_type.shape) for v in variables) - # Make sure KIND and INITIAL (which can be a name) are in the right scope - scope = kwargs['scope'] - if _type.kind is not None: - kind = AttachScopesMapper()(_type.kind, scope=scope) - _type = _type.clone(kind=kind) - if _type.initial is not None: - initial = AttachScopesMapper()(_type.initial, scope=scope) - _type = _type.clone(initial=initial) - # EXTERNAL attribute means this is actually a function or subroutine # Since every symbol refers to a different function we have to update the # type definition for every symbol individually if _type.external: for var in variables: - type_kwargs = _type.__dict__.copy() return_type = SymbolAttributes(_type.dtype) if _type.dtype is not None else None - external_type = scope.symbol_attrs.lookup(var.name) - if external_type is None: - type_kwargs['dtype'] = ProcedureType( - var.name, is_function=return_type is not None, return_type=return_type - ) - else: - type_kwargs['dtype'] = external_type.dtype - scope.symbol_attrs[var.name] = var.type.clone(**type_kwargs) + proc_type = ProcedureType( + var.name, is_function=return_type is not None, return_type=return_type + ) + scope.update(var.name, dtype=proc_type) variables = tuple(var.rescope(scope=scope) for var in variables) return ir.ProcedureDeclaration( - symbols=variables, external=True, source=kwargs.get('source'), label=kwargs.get('label') + symbols=variables, external=True, source=source, label=label ) # Update symbol table entries and rescope @@ -663,8 +652,7 @@ def visit_Type_Declaration_Stmt(self, o, **kwargs): variables = tuple(var.rescope(scope=scope) for var in variables) return ir.VariableDeclaration( - symbols=variables, dimensions=_type.shape, - source=kwargs.get('source'), label=kwargs.get('label') + symbols=variables, dimensions=_type.shape, source=source, label=label ) def visit_Intrinsic_Type_Spec(self, o, **kwargs): @@ -838,12 +826,13 @@ def visit_Entity_Decl(self, o, **kwargs): * char length (:class:`fparser.two.Fortran2003.Char_Length`) * init (:class:`fparser.two.Fortran2003.Initialization`) """ + _type = kwargs.get('type', {}) + scope = kwargs.get('scope') - # Do not pass scope down, as it might alias with previously - # created symbols. Instead, let the rescope in the Declaration - # assign the right scope, always! - with dict_override(kwargs, {'scope': None}): - var = self.visit(o.children[0], **kwargs) + # Declare basic variable type and create variable symbol + vname = o.children[0].tostr() + scope.declare(vname, **dict(_type.__dict__), fail=False) + var = self.visit(o.children[0], **kwargs) if o.children[1]: dimensions = as_tuple(self.visit(o.children[1], **kwargs)) @@ -1924,6 +1913,15 @@ def visit_Function_Subprogram(self, o, **kwargs): (routine, return_type) = self.visit(function_stmt, **kwargs) kwargs['scope'] = routine + # Define the return type in the local scope before parsing spec. + # If the return type is implicit (function name), we need to + # put a dummy declaration here, so that the spec does not see the + # ProcedureType the parent has for this Function. + if return_type: + routine.symbol_attrs[routine.result_name] = return_type + else: + routine.declare(routine.result_name, dtype=BasicType.DEFERRED, fail=False) + # Extract source object for construct source = self.get_source(function_stmt, end_node=end_function_stmt) @@ -1956,10 +1954,6 @@ def visit_Function_Subprogram(self, o, **kwargs): # symbols in the spec part to make them coherent with the symbol table spec = AttachScopes().visit(spec, scope=routine, recurse_to_declaration_attributes=True) - # If the return type is given, inject it into the symbol table - if return_type: - routine.symbol_attrs[routine.result_name] = return_type - # Now all declarations are well-defined and we can parse the member routines contains = self.visit(get_child(o, Fortran2003.Internal_Subprogram_Part), **kwargs) @@ -3363,26 +3357,18 @@ def visit_Assignment_Stmt(self, o, **kwargs): ) if could_be_a_statement_func: - def _create_stmt_func_type(stmt_func): - name = str(stmt_func.variable) - procedure = LazyNodeLookup( - anchor=kwargs['scope'], - query=lambda x: [ - f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name - ][0] - ) - proc_type = ProcedureType(is_function=True, procedure=procedure, name=name) - return SymbolAttributes(dtype=proc_type, is_stmt_func=True) - + # Create the procedure symbol and statement function IR node f_symbol = sym.ProcedureSymbol(name=lhs.name, scope=kwargs['scope']) stmt_func = ir.StatementFunction( variable=f_symbol, arguments=lhs.dimensions, rhs=rhs, return_type=symbol_attrs[lhs.name], - label=kwargs.get('label'), source=kwargs.get('source') + parent=kwargs['scope'], label=kwargs.get('label'), + source=kwargs.get('source') ) # Update the type in the local scope and return stmt func node - symbol_attrs[str(stmt_func.variable)] = _create_stmt_func_type(stmt_func) + proc_type = ProcedureType(name=lhs.name, procedure=stmt_func, is_function=True) + kwargs['scope'].declare(lhs.name, dtype=proc_type, is_stmt_func=True, fail=False) return stmt_func # Return Assignment node if we don't have to deal with the stupid side of Fortran! diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 9e2444255..8e753eb8e 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -1571,9 +1571,12 @@ class _StatementFunctionBase(): @dataclass_strict(frozen=True) -class StatementFunction(LeafNode, _StatementFunctionBase): +class StatementFunction(ScopedNode, LeafNode, _StatementFunctionBase): """ - Internal representation of Fortran statement function statements + Internal representation of Fortran statement function statements. + + Internally, this is considered a :any:`ScopedNode`, because it may + be the target of a :any:`ProcedureType`. Parameters ---------- @@ -1589,8 +1592,10 @@ class StatementFunction(LeafNode, _StatementFunctionBase): _traversable = ['variable', 'arguments', 'rhs'] - def __post_init__(self): - super().__post_init__() + def __post_init__(self, parent=None): + super(ScopedNode, self).__post_init__(parent=parent) + super(LeafNode, self).__post_init__() + assert isinstance(self.variable, Expression) assert is_iterable(self.arguments) and all(isinstance(a, Expression) for a in self.arguments) assert isinstance(self.return_type, SymbolAttributes) @@ -1599,6 +1604,10 @@ def __post_init__(self): def name(self): return str(self.variable) + @property + def variables(self): + return self.arguments + @property def is_function(self): return True diff --git a/loki/program_unit.py b/loki/program_unit.py index 45696b145..62ab9e959 100644 --- a/loki/program_unit.py +++ b/loki/program_unit.py @@ -722,7 +722,9 @@ def subroutines(self): routine for routine in self.contains.body if isinstance(routine, Subroutine) ]) + # Semantic aliases for convenience routines = subroutines + procedures = subroutines @property def subroutine_map(self): diff --git a/loki/tests/test_internal_procedures.py b/loki/tests/test_internal_procedures.py new file mode 100644 index 000000000..d499e9d50 --- /dev/null +++ b/loki/tests/test_internal_procedures.py @@ -0,0 +1,258 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from loki import Subroutine, fgen +from loki.frontend import available_frontends +from loki.ir import FindVariables, FindInlineCalls +from loki.jit_build import jit_compile, clean_test +from loki.types import INTEGER + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_internal_procedures(tmp_path, frontend): + """ + Test internal subroutine and function + """ + fcode = """ +subroutine routine_internal_procedures(in1, in2, out1, out2) + ! Test internal subroutine and function + implicit none + integer, intent(in) :: in1, in2 + integer, intent(out) :: out1, out2 + integer :: localvar + + localvar = in2 + + call internal_procedure(in1, out1) + out2 = internal_function(out1) +contains + subroutine internal_procedure(in1, out1) + ! This internal procedure shadows some variables and uses + ! a variable from the parent scope + implicit none + integer, intent(in) :: in1 + integer, intent(out) :: out1 + + out1 = 5 * in1 + localvar + internal_function(1) + end subroutine internal_procedure + + ! Below is disabled because f90wrap (wrongly) exhibits that + ! symbol to the public, which causes double defined symbols + ! upon compilation. + + function internal_function(in2) + ! This function is just included to test that functions + ! are also possible + implicit none + integer, intent(in) :: in2 + integer :: internal_function + + internal_function = 3 * in2 + 2 + end function internal_function +end subroutine routine_internal_procedures +""" + # Check that internal procedures are parsed correctly + routine = Subroutine.from_source(fcode, frontend=frontend) + assert len(routine.procedures) == 2 + + assert routine.procedures[0].name == 'internal_procedure' + assert routine.procedures[0].symbol_attrs.lookup('localvar', recursive=False) is None + assert routine.procedures[0].symbol_attrs.lookup('localvar') is not None + assert routine.procedures[0].get_symbol_scope('localvar') is routine + assert routine.procedures[0].symbol_attrs.lookup('in1') is not None + assert routine.symbol_attrs.lookup('in1') is not None + assert routine.procedures[0].get_symbol_scope('in1') is routine.procedures[0] + + # Check that inline function is correctly identified + inline_calls = list(FindInlineCalls().visit(routine.procedures[0].body)) + assert len(inline_calls) == 1 + assert inline_calls[0].function.name == 'internal_function' + assert inline_calls[0].function.type.dtype.procedure == routine.procedures[1] + + assert routine.procedures[1].name == 'internal_function' + assert routine.procedures[1].symbol_attrs.lookup('in2') is not None + assert routine.procedures[1].get_symbol_scope('in2') is routine.procedures[1] + assert routine.symbol_attrs.lookup('in2') is not None + assert routine.get_symbol_scope('in2') is routine + + # Generate code, compile and load + filepath = tmp_path/(f'routine_internal_procedures_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname='routine_internal_procedures') + + # Test results of the generated and compiled code + out1, out2 = function(1, 2) + assert out1 == 12 + assert out2 == 38 + clean_test(filepath) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_internal_routine_clone(frontend): + """ + Test that internal subroutine scopes get cloned correctly. + """ + fcode = """ +subroutine internal_routine_clone(in1, in2, out1, out2) + ! Test internal subroutine and function + implicit none + integer, intent(in) :: in1, in2 + integer, intent(out) :: out1, out2 + integer :: localvar + + localvar = in2 + + call internal_procedure(in1, out1) + out2 = 3 * out1 + 2 + +contains + subroutine internal_procedure(in1, out1) + ! This internal procedure shadows some variables and uses + ! a variable from the parent scope + implicit none + integer, intent(in) :: in1 + integer, intent(out) :: out1 + + out1 = 5 * in1 + localvar + end subroutine internal_procedure +end subroutine +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + new_routine = routine.clone() + + # Ensure we have cloned parent and internal routine + assert routine is not new_routine + assert routine.procedures[0] is not new_routine.procedures[0] + assert fgen(routine) == fgen(new_routine) + assert fgen(routine.procedures[0]) == fgen(new_routine.procedures[0]) + + # Check that the scopes are linked correctly + assert routine.procedures[0].parent is routine + assert new_routine.procedures[0].parent is new_routine + + # Check that variables are in the right scope everywhere + assert all(v.scope is routine for v in FindVariables().visit(routine.ir)) + assert all(v.scope in (routine, routine.procedures[0]) for v in FindVariables().visit(routine.procedures[0].ir)) + assert all(v.scope is new_routine for v in FindVariables().visit(new_routine.ir)) + assert all( + v.scope in (new_routine, new_routine.procedures[0]) + for v in FindVariables().visit(new_routine.procedures[0].ir) + ) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_internal_routine_clone_inplace(frontend): + """ + Test that internal subroutine scopes get cloned correctly. + """ + fcode = """ +subroutine internal_routine_clone(in1, in2, out1, out2) + ! Test internal subroutine and function + implicit none + integer, intent(in) :: in1, in2 + integer, intent(out) :: out1, out2 + integer :: localvar + + localvar = in2 + + call internal_procedure(in1, out1) + out2 = 3 * out1 + 2 + +contains + subroutine internal_procedure(in1, out1) + ! This internal procedure shadows some variables and uses + ! a variable from the parent scope + implicit none + integer, intent(in) :: in1 + integer, intent(out) :: out1 + + out1 = 5 * in1 + localvar + end subroutine internal_procedure + + subroutine other_internal(inout1) + ! Another internal routine that uses a parent symbol + implicit none + integer, intent(inout) :: inout1 + + inout1 = 2 * inout1 + localvar + end subroutine other_internal +end subroutine +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + # Make sure the initial state is as expected + internal = routine['internal_procedure'] + assert internal.parent is routine + assert internal.symbol_attrs.parent is routine.symbol_attrs + other_internal = routine['other_internal'] + assert other_internal.parent is routine + assert other_internal.symbol_attrs.parent is routine.symbol_attrs + + # Put the inherited symbol in the local scope, first with a clean clone... + internal.variables += (routine.variable_map['localvar'].clone(scope=internal),) + internal = internal.clone(parent=None) + # ...and then with a clone that preserves the symbol table + other_internal.variables += (routine.variable_map['localvar'].clone(scope=other_internal),) + other_internal = other_internal.clone(parent=None, symbol_attrs=other_internal.symbol_attrs) + # Ultimately, remove the internal routines + routine = routine.clone(contains=None) + + # Check that variables are in the right scope everywhere + assert all(v.scope is routine for v in FindVariables().visit(routine.ir)) + assert all(v.scope is internal for v in FindVariables().visit(internal.ir)) + + # Check that we aren't looking somewhere above anymore + assert internal.parent is None + assert internal.symbol_attrs.parent is None + assert internal.parent is None + assert internal.symbol_attrs._parent is None + assert other_internal.parent is None + assert other_internal.symbol_attrs.parent is None + assert other_internal.parent is None + assert other_internal.symbol_attrs.parent is None + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_internal_procedures_alias(frontend): + """ Test local variable aliases in internal subroutine and function """ + fcode = """ +subroutine outer_routine(in, out) + implicit none + integer, intent(in) :: in + integer, intent(out) :: out + integer :: a, b(2, 2) + + b(1, 1) = in + + call internal_routine(in, out) +contains + + subroutine internal_routine(in, out) + integer, intent(in) :: in + integer, intent(out) :: out + integer :: a(3, 4) + + a(1, 2) = 3 + out = a(1, 2) + b(1, 1) + in + end subroutine internal_routine +end subroutine outer_routine +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + internal = routine['internal_routine'] + + a_outer = routine.get_type('a') + assert a_outer.dtype == INTEGER + assert not a_outer.shape + + b_outer = routine.get_type('b') + assert b_outer.dtype == INTEGER + assert b_outer.shape == (2, 2) + + a_inner = internal.get_type('a') + assert a_inner.dtype == INTEGER + assert a_inner.shape == (3, 4) diff --git a/loki/tests/test_subroutine.py b/loki/tests/test_subroutine.py index 1ae20262a..07eee50c2 100644 --- a/loki/tests/test_subroutine.py +++ b/loki/tests/test_subroutine.py @@ -15,7 +15,7 @@ from loki.frontend import available_frontends, OMNI, REGEX from loki.ir import ( nodes as ir, FindNodes, FindVariables, FindTypedSymbols, - FindInlineCalls, Transformer + Transformer ) from loki.types import ( BasicType, DerivedType, ProcedureType, SymbolAttributes @@ -930,209 +930,6 @@ def test_empty_spec(frontend): assert len(routine.body.body) == 1 -@pytest.mark.parametrize('frontend', available_frontends()) -def test_member_procedures(tmp_path, frontend): - """ - Test member subroutine and function - """ - fcode = """ -subroutine routine_member_procedures(in1, in2, out1, out2) - ! Test member subroutine and function - implicit none - integer, intent(in) :: in1, in2 - integer, intent(out) :: out1, out2 - integer :: localvar - - localvar = in2 - - call member_procedure(in1, out1) - out2 = member_function(out1) -contains - subroutine member_procedure(in1, out1) - ! This member procedure shadows some variables and uses - ! a variable from the parent scope - implicit none - integer, intent(in) :: in1 - integer, intent(out) :: out1 - - out1 = 5 * in1 + localvar + member_function(1) - end subroutine member_procedure - - ! Below is disabled because f90wrap (wrongly) exhibits that - ! symbol to the public, which causes double defined symbols - ! upon compilation. - - function member_function(in2) - ! This function is just included to test that functions - ! are also possible - implicit none - integer, intent(in) :: in2 - integer :: member_function - - member_function = 3 * in2 + 2 - end function member_function -end subroutine routine_member_procedures -""" - # Check that member procedures are parsed correctly - routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(routine.members) == 2 - - assert routine.members[0].name == 'member_procedure' - assert routine.members[0].symbol_attrs.lookup('localvar', recursive=False) is None - assert routine.members[0].symbol_attrs.lookup('localvar') is not None - assert routine.members[0].get_symbol_scope('localvar') is routine - assert routine.members[0].symbol_attrs.lookup('in1') is not None - assert routine.symbol_attrs.lookup('in1') is not None - assert routine.members[0].get_symbol_scope('in1') is routine.members[0] - - # Check that inline function is correctly identified - inline_calls = list(FindInlineCalls().visit(routine.members[0].body)) - assert len(inline_calls) == 1 - assert inline_calls[0].function.name == 'member_function' - assert inline_calls[0].function.type.dtype.procedure == routine.members[1] - - assert routine.members[1].name == 'member_function' - assert routine.members[1].symbol_attrs.lookup('in2') is not None - assert routine.members[1].get_symbol_scope('in2') is routine.members[1] - assert routine.symbol_attrs.lookup('in2') is not None - assert routine.get_symbol_scope('in2') is routine - - # Generate code, compile and load - filepath = tmp_path/(f'routine_member_procedures_{frontend}.f90') - function = jit_compile(routine, filepath=filepath, objname='routine_member_procedures') - - # Test results of the generated and compiled code - out1, out2 = function(1, 2) - assert out1 == 12 - assert out2 == 38 - clean_test(filepath) - - -@pytest.mark.parametrize('frontend', available_frontends()) -def test_member_routine_clone(frontend): - """ - Test that member subroutine scopes get cloned correctly. - """ - fcode = """ -subroutine member_routine_clone(in1, in2, out1, out2) - ! Test member subroutine and function - implicit none - integer, intent(in) :: in1, in2 - integer, intent(out) :: out1, out2 - integer :: localvar - - localvar = in2 - - call member_procedure(in1, out1) - out2 = 3 * out1 + 2 - -contains - subroutine member_procedure(in1, out1) - ! This member procedure shadows some variables and uses - ! a variable from the parent scope - implicit none - integer, intent(in) :: in1 - integer, intent(out) :: out1 - - out1 = 5 * in1 + localvar - end subroutine member_procedure -end subroutine -""" - routine = Subroutine.from_source(fcode, frontend=frontend) - new_routine = routine.clone() - - # Ensure we have cloned routine and member - assert routine is not new_routine - assert routine.members[0] is not new_routine.members[0] - assert fgen(routine) == fgen(new_routine) - assert fgen(routine.members[0]) == fgen(new_routine.members[0]) - - # Check that the scopes are linked correctly - assert routine.members[0].parent is routine - assert new_routine.members[0].parent is new_routine - - # Check that variables are in the right scope everywhere - assert all(v.scope is routine for v in FindVariables().visit(routine.ir)) - assert all(v.scope in (routine, routine.members[0]) for v in FindVariables().visit(routine.members[0].ir)) - assert all(v.scope is new_routine for v in FindVariables().visit(new_routine.ir)) - assert all( - v.scope in (new_routine, new_routine.members[0]) - for v in FindVariables().visit(new_routine.members[0].ir) - ) - - -@pytest.mark.parametrize('frontend', available_frontends()) -def test_member_routine_clone_inplace(frontend): - """ - Test that member subroutine scopes get cloned correctly. - """ - fcode = """ -subroutine member_routine_clone(in1, in2, out1, out2) - ! Test member subroutine and function - implicit none - integer, intent(in) :: in1, in2 - integer, intent(out) :: out1, out2 - integer :: localvar - - localvar = in2 - - call member_procedure(in1, out1) - out2 = 3 * out1 + 2 - -contains - subroutine member_procedure(in1, out1) - ! This member procedure shadows some variables and uses - ! a variable from the parent scope - implicit none - integer, intent(in) :: in1 - integer, intent(out) :: out1 - - out1 = 5 * in1 + localvar - end subroutine member_procedure - - subroutine other_member(inout1) - ! Another member that uses a parent symbol - implicit none - integer, intent(inout) :: inout1 - - inout1 = 2 * inout1 + localvar - end subroutine other_member -end subroutine -""" - routine = Subroutine.from_source(fcode, frontend=frontend) - - # Make sure the initial state is as expected - member = routine['member_procedure'] - assert member.parent is routine - assert member.symbol_attrs.parent is routine.symbol_attrs - other_member = routine['other_member'] - assert other_member.parent is routine - assert other_member.symbol_attrs.parent is routine.symbol_attrs - - # Put the inherited symbol in the local scope, first with a clean clone... - member.variables += (routine.variable_map['localvar'].clone(scope=member),) - member = member.clone(parent=None) - # ...and then with a clone that preserves the symbol table - other_member.variables += (routine.variable_map['localvar'].clone(scope=other_member),) - other_member = other_member.clone(parent=None, symbol_attrs=other_member.symbol_attrs) - # Ultimately, remove the member routines - routine = routine.clone(contains=None) - - # Check that variables are in the right scope everywhere - assert all(v.scope is routine for v in FindVariables().visit(routine.ir)) - assert all(v.scope is member for v in FindVariables().visit(member.ir)) - - # Check that we aren't looking somewhere above anymore - assert member.parent is None - assert member.symbol_attrs.parent is None - assert member.parent is None - assert member.symbol_attrs._parent is None - assert other_member.parent is None - assert other_member.symbol_attrs.parent is None - assert other_member.parent is None - assert other_member.symbol_attrs.parent is None - - @pytest.mark.parametrize('frontend', available_frontends()) def test_external_stmt(tmp_path, frontend): """ diff --git a/loki/tools/util.py b/loki/tools/util.py index 40c24a797..8e420f1bd 100644 --- a/loki/tools/util.py +++ b/loki/tools/util.py @@ -34,14 +34,12 @@ __all__ = [ 'as_tuple', 'is_iterable', 'is_subset', 'flatten', 'chunks', 'execute', 'CaseInsensitiveDict', 'CaseInsensitiveDefaultDict', - 'strip_inline_comments', - 'binary_insertion_sort', 'cached_func', 'optional', - 'LazyNodeLookup', 'yaml_include_constructor', + 'strip_inline_comments', 'binary_insertion_sort', 'cached_func', + 'optional', 'yaml_include_constructor', 'auto_post_mortem_debugger', 'set_excepthook', 'timeout', 'WeakrefProperty', 'group_by_class', 'replace_windowed', 'dict_override', 'stdchannel_redirected', - 'stdchannel_is_captured', 'graphviz_present', - 'OrderedSet' + 'stdchannel_is_captured', 'graphviz_present', 'OrderedSet' ] @@ -433,54 +431,6 @@ def optional(condition, context_manager, *args, **kwargs): yield -class LazyNodeLookup: - """ - Utility class for indirect, :any:`weakref`-style lookups - - References to IR nodes are usually not stable as the IR may be - rebuilt at any time. This class offers a way to refer to a node - in an IR by encoding how it can be found instead. - - .. note:: - **Example:** - Reference a declaration node that contains variable "a" - - .. code-block:: - - from loki import LazyNodeLookup, FindNodes, Declaration - # Assume this has been initialized before - # routine = ... - - # Create the reference - query = lambda x: [d for d in FindNodes(VariableDeclaration).visit(x.spec) if 'a' in d.symbols][0] - decl_ref = LazyNodeLookup(routine, query) - - # Use the reference (this carries out the query) - decl = decl_ref() - - Parameters - ---------- - anchor : - The "stable" anchor object to which :attr:`query` is applied to find the object. - This is stored internally as a :any:`weakref`. - query : - A function object that accepts a single argument and should return the lookup - result. To perform the lookup, :attr:`query` is called with :attr:`anchor` - as argument. - """ - - def __init__(self, anchor, query): - self._anchor = weakref.ref(anchor) - self.query = query - - @property - def anchor(self): - return self._anchor() - - def __call__(self): - return self.query(self.anchor) - - def yaml_include_constructor(loader, node): """ Add support for ``!include`` tags to YAML load diff --git a/loki/transformations/inline/functions.py b/loki/transformations/inline/functions.py index eeed45fd4..4402b2cc5 100644 --- a/loki/transformations/inline/functions.py +++ b/loki/transformations/inline/functions.py @@ -222,6 +222,10 @@ def inline_statement_functions(routine): # Apply expression-level substitution to routine routine.body = SubstituteExpressions(exprmap).visit(routine.body) + # Rescope the routine body, as the recursive update call does not + # set scopes on the RHS of `exprmap`, and so we might miss a scope. + routine.rescope_symbols() + # remove statement function declarations as well as statement function argument(s) declarations vars_to_remove = {stmt_func.variable.name.lower() for stmt_func in stmt_func_decls} vars_to_remove |= {arg.name.lower() for stmt_func in stmt_func_decls for arg in stmt_func.arguments} diff --git a/loki/transformations/inline/tests/test_functions.py b/loki/transformations/inline/tests/test_functions.py index a9eae7f10..011585f74 100644 --- a/loki/transformations/inline/tests/test_functions.py +++ b/loki/transformations/inline/tests/test_functions.py @@ -290,6 +290,7 @@ def test_inline_statement_functions(frontend, stmt_decls): else: assert FindInlineCalls().visit(routine.body) + @pytest.mark.parametrize('frontend', available_frontends( skip={OMNI: "OMNI automatically inlines Statement Functions"} )) @@ -343,7 +344,7 @@ def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_pa real, parameter :: rtt = 1.0 real :: PTARE real :: FOEDELTA - FOEDELTA ( PTARE ) = PTARE + 1.0 + MYFUNC(PTARE) + FOEDELTA ( PTARE ) = PTARE + MYFUNC(PTARE) + MAX(1.0, 2.0) real :: FOEEW FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE) + MYFUNC(PTARE) {intf} @@ -368,13 +369,15 @@ def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_pa inline_calls = FindInlineCalls(unique=False).visit(routine.spec) if provide_myfunc in ('module', 'interface', 'routine'): # Enough information available that MYFUNC is recognized as a procedure call - assert len(inline_calls) == 3 + assert len(inline_calls) == 4 assert all(isinstance(call.function.type.dtype, ProcedureType) for call in inline_calls) else: # No information available about MYFUNC, so fparser treats it as an ArraySubscript - assert len(inline_calls) == 1 - assert inline_calls[0].function == 'foedelta' - assert isinstance(inline_calls[0].function.type.dtype, ProcedureType) + assert len(inline_calls) == 2 + assert inline_calls[0].function == 'MAX' + assert inline_calls[0].function.type.is_intrinsic + assert inline_calls[1].function == 'foedelta' + assert isinstance(inline_calls[1].function.type.dtype, ProcedureType) # Check the body inline_calls = FindInlineCalls().visit(routine.body) @@ -390,23 +393,27 @@ def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_pa if provide_myfunc in ('import', 'intfb'): # MYFUNC(arr) is misclassified as array subscript - assert len(inline_calls) == 0 + assert len(inline_calls) == 3 elif provide_myfunc in ('module', 'routine'): # MYFUNC(arr) is eliminated due to inlining - assert len(inline_calls) == 0 + assert len(inline_calls) == 3 else: - assert len(inline_calls) == 4 + assert len(inline_calls) == 7 assert assignments[0].lhs == 'ret' assert assignments[1].lhs == 'ret2' if provide_myfunc in ('module', 'routine'): # Fully inlined due to definition of myfunc available - assert assignments[0].rhs == "arr + arr + 1.0 + arr*2.0 + arr*2.0" - assert assignments[1].rhs == "3.0 + 1.0 + 3.0*2.0 + val + 1.0 + val*2.0" + assert assignments[0].rhs == "arr + arr + arr*2.0 + max(1.0, 2.0) + arr*2.0" + assert assignments[1].rhs == "3.0 + 3.0*2.0 + max(1.0, 2.0) + val + val*2.0 + max(1.0, 2.0)" else: # myfunc not inlined - assert assignments[0].rhs == "arr + arr + 1.0 + myfunc(arr) + myfunc(arr)" - assert assignments[1].rhs == "3.0 + 1.0 + myfunc(3.0) + val + 1.0 + myfunc(val)" + assert assignments[0].rhs == "arr + arr + myfunc(arr) + max(1.0, 2.0) + myfunc(arr)" + assert assignments[1].rhs == "3.0 + myfunc(3.0) + max(1.0, 2.0) + val + myfunc(val) + max(1.0, 2.0)" + + # Ensure all copies of the intrinsic call are correctly scoped + assert all(c.function.scope == routine for c in inline_calls) + assert all(isinstance(c.function.type.dtype, ProcedureType) for c in inline_calls) @pytest.mark.parametrize('frontend', available_frontends()) diff --git a/loki/types/datatypes.py b/loki/types/datatypes.py index 0622762cf..56d93872e 100644 --- a/loki/types/datatypes.py +++ b/loki/types/datatypes.py @@ -15,7 +15,10 @@ from loki.tools import flatten -__all__ = ['DataType', 'BasicType'] +__all__ = [ + 'DataType', 'BasicType', + 'DEFERRED', 'LOGICAL', 'INTEGER', 'REAL', 'CHARACTER', 'COMPLEX' +] class DataType: @@ -99,3 +102,11 @@ def from_c99_type(cls, value): type_map.update({t: cls.COMPLEX for t in complex_types}) return type_map[value] + + +DEFERRED = BasicType.DEFERRED +LOGICAL = BasicType.LOGICAL +INTEGER = BasicType.INTEGER +REAL = BasicType.REAL +CHARACTER = BasicType.CHARACTER +COMPLEX = BasicType.COMPLEX diff --git a/loki/types/module_type.py b/loki/types/module_type.py index 46b2d4250..cece9977a 100644 --- a/loki/types/module_type.py +++ b/loki/types/module_type.py @@ -9,7 +9,6 @@ import weakref -from loki.tools import LazyNodeLookup from loki.types.datatypes import BasicType, DataType @@ -25,9 +24,8 @@ class ModuleType(DataType): Parameters ---------- name : str, optional - The name of the module. Can be skipped if :data:`module` - is provided (not in the form of a :any:`LazyNodeLookup`) - module : :any:`Module` :any:`LazyNodeLookup`, optional + The name of the module. Can be skipped if :data:`module` is provided. + module : :any:`Module`, optional The procedure this type represents """ @@ -35,14 +33,9 @@ def __init__(self, name=None, module=None): from loki.module import Module # pylint: disable=import-outside-toplevel,cyclic-import super().__init__() assert name or isinstance(module, Module) - if module is None or isinstance(module, LazyNodeLookup): - self._module = module - self._name = name - else: - self._module = weakref.ref(module) - # Cache all properties for when module link becomes inactive - assert name is None or name.lower() == self.module.name.lower() - self._name = self.module.name + + self.module = module + self._name = module.name if module else name @property def name(self): @@ -68,6 +61,13 @@ def module(self): return BasicType.DEFERRED return self._module() + @module.setter + def module(self, mod): + # pylint: disable=import-outside-toplevel,cyclic-import + from loki.module import Module + assert mod is None or isinstance(mod, Module) + self._module = None if mod is None else weakref.ref(mod) + def __str__(self): return self.name diff --git a/loki/types/procedure_type.py b/loki/types/procedure_type.py index abe4898f4..9f7898f64 100644 --- a/loki/types/procedure_type.py +++ b/loki/types/procedure_type.py @@ -9,7 +9,6 @@ import weakref -from loki.tools import LazyNodeLookup from loki.types.datatypes import BasicType, DataType @@ -22,24 +21,20 @@ class ProcedureType(DataType): This serves also as the cross-link between the use of a procedure (e.g. in a :any:`CallStatement`) to the :any:`Subroutine` object that is the target of - a call. If the corresponding object is not yet available when the - :any:`ProcedureType` object is created, or its definition is transient and - subject to IR rebuilds (e.g. :any:`StatementFunction`), the :any:`LazyNodeLookup` - utility can be used to defer the actual instantiation. In that situation, - :data:`name` should be provided in addition. + a call. Parameters ---------- name : str, optional The name of the function or subroutine. Can be skipped if :data:`procedure` - is provided (not in the form of a :any:`LazyNodeLookup`) + is provided. is_function : bool, optional Indicate that this is a function is_generic : bool, optional Indicate that this is a generic function is_intrinsic : bool, optional Indicate that this is an intrinsic function - procedure : :any:`Subroutine` or :any:`StatementFunction` or :any:`LazyNodeLookup`, optional + procedure : :any:`Subroutine` or :any:`StatementFunction`, optional The procedure this type represents """ @@ -56,23 +51,11 @@ def __init__( assert isinstance(return_type, SymbolAttributes) or procedure or not is_function or is_intrinsic self.is_generic = is_generic self.is_intrinsic = is_intrinsic - if procedure is None or isinstance(procedure, LazyNodeLookup): - self._procedure = procedure - self._name = name - self._is_function = is_function or False - self._return_type = return_type - # NB: not applying an assert on the procedure name for LazyNodeLookup as - # the point of the lazy lookup is that we might not have the the procedure - # definition available at type instantiation time - else: - self._procedure = weakref.ref(procedure) - # Cache all properties for when procedure link becomes inactive - assert name is None or name.lower() == self.procedure.name.lower() - self._name = self.procedure.name - assert is_function is None or is_function == self.procedure.is_function - self._is_function = self.procedure.is_function - # TODO: compare return type once type comparison is more robust - self._return_type = self.procedure.return_type if self.procedure.is_function else None + + self.procedure = procedure + self._name = procedure.name if procedure else name + self._is_function = is_function or False + self._return_type = return_type @property def _canonical(self): @@ -110,6 +93,15 @@ def procedure(self): return BasicType.DEFERRED return self._procedure() + @procedure.setter + def procedure(self, proc): + # pylint: disable=import-outside-toplevel,cyclic-import + from loki.subroutine import Subroutine + from loki.function import Function + from loki.ir import StatementFunction + assert proc is None or isinstance(proc, (Function, Subroutine, StatementFunction)) + self._procedure = None if proc is None else weakref.ref(proc) + @property def is_function(self): """ diff --git a/loki/types/scope.py b/loki/types/scope.py index 111af306c..65319d0b9 100644 --- a/loki/types/scope.py +++ b/loki/types/scope.py @@ -14,7 +14,8 @@ import weakref from loki.tools import WeakrefProperty -from loki.types.symbol_table import SymbolTable +from loki.types.datatypes import DataType +from loki.types.symbol_table import SymbolTable, SymbolAttributes __all__ = ['Scope'] @@ -162,3 +163,138 @@ def _reset_parent(self, parent): if self.parent is not None: self.symbol_attrs.parent = self.parent.symbol_attrs + + def declare(self, name, dtype, fail=True, **kwargs): + """ + Method to add type information, including the data type + :param:`dtype`, for a new variable in a :any:`Scope` and + update the symbol table accordingly. + + This method should be used for the initial type declaration of + a variable and will by default fail if the variable has + already been declared. To update an existing variable, the + :method:`update` should be used. + + To completely re-declare an existing variable, ``fail=True`` + can be passed to this method. This case prior type information + will be removed from the symbol table before the re-declaration. + + Parameters + ---------- + name : str + Name of the variable to declare symbol information for. + dtype : :any:`DataType` or str + Basic data typer of the variable + fail : bool, optional + Flag to override the default failure on attempted re-declaration. + **kwargs : optional + Any additional attributes that should be stored as properties in + the symbol table. + """ + + # Ensure `dtype` defines a known type + assert isinstance(dtype, (DataType, str)) + + if fail and name in self.symbol_attrs: + raise ValueError(f'[Loki::Scope] Trying to re-declare already declared symbol name: {name}') + + self.symbol_attrs[name] = SymbolAttributes(dtype, **kwargs) + + def update(self, name, fail=True, **kwargs): + """ + Method to update the type information of a variable in a the + symbol table of a :any:`Scope`. + + This method should only be used update symbol attributes after + the initial declaration of the variable via the + :meth:`update`. If the symbol has not already been declared, a + :any:`ValueError` is raised, unless the ``fail=True`` + override flag is set. + + Parameters + ---------- + name : str + Name of the variable to update symbol information for. + fail : bool, optional + Flag to override the default failure when variable was not declared. + **kwargs : optional + Any additional attributes that should be stored as properties in + the symbol table. + """ + + # Ensure `dtype` defines a known type + if 'dtype' in kwargs: + assert isinstance(kwargs['dtype'], (DataType, str)) + + if fail and name not in self.symbol_attrs: + raise ValueError(f'[Loki::Scope] Trying to update undeclared symbol name: {name}') + + if name in self.symbol_attrs: + self.symbol_attrs[name] = self.symbol_attrs[name].clone(**kwargs) + else: + self.symbol_attrs[name] = SymbolAttributes(**kwargs) + + def get_type(self, name, recursive=True, fail=True): + """ + Method to retrieve the type information (set of symbol + attributes) from the internal symbol table for a given + variable name. + + This method will be default fail if no type information is + found and may recurse to the parent :any:`Scope`. This default + behaviour can be overriden with the respective flags. + + Parameters + ---------- + name : str + Name of the variable to look up + recursive : bool + Use recursive look-up in parent scopes + fail : bool, optional + Flag to override the default failure when variable was not declared. + + Returns + ------- + :any:`SymbolAttributes` + The collection of attributes associated with this symbol + """ + + _type = self.symbol_attrs.lookup(name, recursive=recursive) + + # Check results and fail hard if requested + if fail and _type is None: + raise KeyError(f'[Loki::Scope] Cannot get type for undeclared symbol name: {name}') + + return _type + + def get_dtype(self, name, recursive=True, fail=True): + """ + Method to retrieve the data type from the internal symbol + table for a given variable name. + + This convenience method is equivalent to using + ``scope.get_type(name).dtype``. + + This method will be default fail if no type information is + found and may recurse to the parent :any:`Scope`. This default + behaviour can be overriden with the respective flags. + + Parameters + ---------- + name : str + Name of the variable to look up + recursive : bool + Use recursive look-up in parent scopes + fail : bool, optional + Flag to override the default failure when variable was not declared. + + Returns + ------- + :any:`SymbolAttributes` + The collection of attributes associated with this symbol + """ + + _type = self.get_type(name, recursive=recursive, fail=fail) + if _type is None: + return None + return _type.dtype diff --git a/loki/types/tests/test_procedure_types.py b/loki/types/tests/test_procedure_types.py index 79b4674cd..83b225712 100644 --- a/loki/types/tests/test_procedure_types.py +++ b/loki/types/tests/test_procedure_types.py @@ -10,12 +10,12 @@ from loki import Function, Module, Subroutine from loki.expression import symbols as sym from loki.frontend import available_frontends, OMNI -from loki.ir import nodes as ir, FindNodes +from loki.ir import nodes as ir, FindNodes, Transformer from loki.types import ProcedureType @pytest.mark.parametrize('frontend', available_frontends()) -def test_procedure_type(tmp_path, frontend): +def test_procedure_type_procedures(tmp_path, frontend): """ Test `ProcedureType` links to the procedure when it is defined. """ fcode_mod = """ @@ -82,3 +82,47 @@ def test_procedure_type(tmp_path, frontend): assert isinstance(sftype, ProcedureType) assert str(sftype) == 'pants' and repr(sftype) == '' assert sftype.procedure == stmtfuncs[0] + + +@pytest.mark.parametrize('frontend', available_frontends( + skip=[(OMNI, 'Statement functions not supported with OMNI')] +)) +def test_procedure_type_stmt_func(frontend): + """ Test `ProcedureType` keeps its link to `StatementFunction` when transformed. """ + fcode_mod = """ +subroutine test_routine(n, a) + implicit none + integer, intent(in) :: n + real(kind=4), intent(inout) :: a(3) + real(kind=4) :: pants, on, fire + pants(on, fire) = on + fire + + a(1) = pants(a(1), a(2)) +end subroutine test_routine +""" + routine = Subroutine.from_source(fcode_mod, frontend=frontend) + + stmtfuncs = FindNodes(ir.StatementFunction).visit(routine.spec) + assert len(stmtfuncs) == 1 + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 1 + + # Check original stmt func is fine and linked + sftype = stmtfuncs[0].variable.type.dtype + assert isinstance(sftype, ProcedureType) + assert sftype.procedure == stmtfuncs[0] + + assert isinstance(assigns[0].rhs, sym.InlineCall) + assert isinstance(assigns[0].rhs.function.type.dtype, ProcedureType) + assert assigns[0].rhs.function.type.dtype.procedure == stmtfuncs[0] + + # Save the original IR node and run a no-op Transformer over the spec + original = stmtfuncs[0] + routine.spec = Transformer().visit(routine.spec) + + # Re-generate the spec via Transformer and check the re-build stmt func + new_stmtfunc = FindNodes(ir.StatementFunction).visit(routine.spec)[0] + assert original is new_stmtfunc + + assert routine.get_dtype('pants').procedure is new_stmtfunc + assert assigns[0].rhs.function.type.dtype.procedure is new_stmtfunc diff --git a/loki/types/tests/test_scope.py b/loki/types/tests/test_scope.py index e45a51237..134e6db1f 100644 --- a/loki/types/tests/test_scope.py +++ b/loki/types/tests/test_scope.py @@ -9,7 +9,9 @@ A collection of tests for :any:`SymbolAttrs`, :any:`SymbolTable` and :any:`Scope`. """ -from loki.types import SymbolAttributes, BasicType +import pytest + +from loki.types import INTEGER, REAL, Scope, SymbolAttributes def test_symbol_attributes(): @@ -18,7 +20,7 @@ def test_symbol_attributes(): :any:`SymbolAttributes` """ _type = SymbolAttributes('integer', a='a', b=True, c=None) - assert _type.dtype == BasicType.INTEGER + assert _type.dtype == INTEGER assert _type.a == 'a' assert _type.b assert _type.c is None @@ -48,3 +50,60 @@ def test_symbol_attributes_compare(): assert someint.compare(another, ignore='b') assert another.compare(someint, ignore=['b']) assert not someint.compare(somereal) + + +def test_scope_setter(): + """ Test basic declaration and update behaviour of :any:`Scope` """ + scope = Scope() + + # Check basic type declaration + scope.declare('a', dtype='integer', kind=4, intent='in') + assert 'a' in scope.symbol_attrs + assert scope.symbol_attrs['a'].dtype == INTEGER + assert scope.symbol_attrs['a'].kind == 4 + assert scope.symbol_attrs['a'].intent == 'in' + + # Test erroneous and intentional re-declaration + with pytest.raises(ValueError): + scope.declare('a', dtype='real', kind=8) + + scope.declare('a', dtype='real', kind=8, fail=False) + assert 'a' in scope.symbol_attrs + assert scope.symbol_attrs['a'].dtype == REAL + assert scope.symbol_attrs['a'].kind == 8 + assert not scope.symbol_attrs['a'].intent # Wiped previous value + + # Check type declaration updates + scope.update('a', dtype='integer', intent='inout') + assert 'a' in scope.symbol_attrs + assert scope.symbol_attrs['a'].dtype == INTEGER + assert scope.symbol_attrs['a'].kind == 8 # Previous not wiped + assert scope.symbol_attrs['a'].intent == 'inout' + + with pytest.raises(ValueError): + scope.update('b', dtype='integer', intent='inout') + + # Override fail-safe, acts as another `declare()` call + scope.update('b', dtype='integer', intent='inout', fail=False) + assert 'b' in scope.symbol_attrs + assert scope.symbol_attrs['b'].dtype == INTEGER + assert scope.symbol_attrs['b'].intent == 'inout' + + +def test_scope_getter(): + """ Test basic :method:`get_type`/:method:`get_dtype` behaviour of :any:`Scope` """ + parent = Scope() + scope = Scope(parent=parent) + + scope.declare('a', dtype='real', kind=8, intent='inout') + parent.declare('b', dtype='integer', kind=4, intent='in') + + assert scope.get_type('a').dtype == REAL + assert scope.get_type('a').kind == 8 + assert scope.get_type('a').intent == 'inout' + + # Non-recursive and recursive lookups through parent + with pytest.raises(KeyError): + scope.get_type('b', recursive=False) + + assert scope.get_type('b', recursive=False, fail=False) is None