diff --git a/lint_rules/lint_rules/ifs_arpege_coding_standards.py b/lint_rules/lint_rules/ifs_arpege_coding_standards.py index 8257d4b38..64aa0b755 100644 --- a/lint_rules/lint_rules/ifs_arpege_coding_standards.py +++ b/lint_rules/lint_rules/ifs_arpege_coding_standards.py @@ -13,7 +13,6 @@ """ from collections import defaultdict -import re try: from fparser.two.Fortran2003 import Intrinsic_Name @@ -48,15 +47,13 @@ class MissingImplicitNoneRule(GenericRule): ), } - _regex = re.compile(r'implicit\s+none\b', re.I) - @classmethod def check_for_implicit_none(cls, ir_): """ Check for intrinsic nodes that match the regex. """ - for intr in FindNodes(ir.Intrinsic).visit(ir_): - if cls._regex.match(intr.text): + for intr in FindNodes(ir.ImplicitStmt).visit(ir_): + if not intr.text or intr.text.lower() == 'none': break else: return False diff --git a/lint_rules/lint_rules/ifs_coding_standards_2011.py b/lint_rules/lint_rules/ifs_coding_standards_2011.py index b07052bda..836d6bdbc 100644 --- a/lint_rules/lint_rules/ifs_coding_standards_2011.py +++ b/lint_rules/lint_rules/ifs_coding_standards_2011.py @@ -215,7 +215,7 @@ class LimitSubroutineStatementsRule(GenericRule): # Coding standards 2.2 # List of nodes that are considered executable statements exec_nodes = ( - ir.Assignment, ir.MaskedStatement, ir.Intrinsic, ir.Allocation, + ir.Assignment, ir.MaskedStatement, ir.GenericStmt, ir.Allocation, ir.Deallocation, ir.Nullify, ir.CallStatement ) @@ -231,7 +231,7 @@ def check_subroutine(cls, subroutine, rule_report, config, **kwargs): nodes = FindNodes(cls.exec_nodes).visit(subroutine.ir) num_nodes = len(nodes) # Subtract number of non-exec intrinsic nodes - intrinsic_nodes = filter(lambda node: isinstance(node, ir.Intrinsic), nodes) + intrinsic_nodes = filter(lambda node: isinstance(node, ir.GenericStmt), nodes) num_nodes -= sum(1 for _ in filter( lambda node: cls.match_non_exec_intrinsic_node.match(node.text), intrinsic_nodes)) @@ -298,15 +298,13 @@ class ImplicitNoneRule(GenericRule): # Coding standards 4.4 'title': '"IMPLICIT NONE" is mandatory in all routines.', } - _regex = re.compile(r'implicit\s+none\b', re.I) - @staticmethod def check_for_implicit_none(ast): """ Check for intrinsic nodes that match the regex. """ - for intr in FindNodes(ir.Intrinsic).visit(ast): - if ImplicitNoneRule._regex.match(intr.text): + for intr in FindNodes(ir.ImplicitStmt).visit(ast): + if not intr.text or intr.text.lower() == 'none': break else: return False @@ -450,9 +448,9 @@ class BannedStatementsRule(GenericRule): # Coding standards 4.11 @classmethod def check_subroutine(cls, subroutine, rule_report, config, **kwargs): '''Check for banned statements in intrinsic nodes.''' - for intr in FindNodes(ir.Intrinsic).visit(subroutine.ir): + for intr in FindNodes(ir.GenericStmt).visit(subroutine.ir): for keyword in config['banned']: - if keyword.lower() in intr.text.lower(): + if keyword.upper() in intr.text.upper() or keyword.upper() == intr.keyword: rule_report.add(f'Banned keyword "{keyword}"', intr) diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 463addebf..0439278e1 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -313,7 +313,7 @@ def visit_Node(self, o, **kwargs): # Handler for IR nodes - def visit_Intrinsic(self, o, **kwargs): # pylint: disable=unused-argument + def visit_GenericStmt(self, o, **kwargs): # pylint: disable=unused-argument """ Format intrinsic nodes. """ diff --git a/loki/backend/fgen.py b/loki/backend/fgen.py index 8836d8f95..f80b44a87 100644 --- a/loki/backend/fgen.py +++ b/loki/backend/fgen.py @@ -323,11 +323,13 @@ def visit_str(self, o, **kwargs): # Handler for IR nodes - def visit_Intrinsic(self, o, **kwargs): + def visit_GenericStmt(self, o, **kwargs): """ Format intrinsic nodes. """ - return self.format_line(str(o.text).lstrip()) + keyword = f'{o.keyword} ' if o.keyword else '' + text = ', '.join(self.visit_all(as_tuple(o.text), **kwargs)) if o.text else '' + return self.format_line(keyword, str(text).lstrip()) def visit_RawSource(self, o, **kwargs): """ diff --git a/loki/backend/pygen.py b/loki/backend/pygen.py index 032c8008a..f66757f52 100644 --- a/loki/backend/pygen.py +++ b/loki/backend/pygen.py @@ -165,7 +165,7 @@ def (): # Handler for IR nodes - def visit_Intrinsic(self, o, **kwargs): # pylint: disable=unused-argument + def visit_GenericStmt(self, o, **kwargs): # pylint: disable=unused-argument """ Format intrinsic nodes. """ diff --git a/loki/backend/tests/test_fgen.py b/loki/backend/tests/test_fgen.py index 89fd758f0..78e02ec35 100644 --- a/loki/backend/tests/test_fgen.py +++ b/loki/backend/tests/test_fgen.py @@ -69,7 +69,7 @@ def test_fgen_literal_list_linebreak(frontend, tmp_path): # Make sure all lines are continued correctly code = module.to_fortran() code_lines = code.splitlines() - assert len(code_lines) in (35, 36) # OMNI produces an extra line + assert len(code_lines) in (35, 37) # OMNI produces an extra line assert all(line.strip(' &\n') for line in code_lines) assert all(len(line) < 132 for line in code_lines) diff --git a/loki/backend/tests/test_stringifier.py b/loki/backend/tests/test_stringifier.py index 06d5fb985..bd959a384 100644 --- a/loki/backend/tests/test_stringifier.py +++ b/loki/backend/tests/test_stringifier.py @@ -21,6 +21,7 @@ def test_stringifier(frontend, tmp_path): """ fcode = """ MODULE some_mod + IMPLICIT NONE INTEGER :: n !$loki dimension(klon) REAL :: arr(:) @@ -83,13 +84,14 @@ def test_stringifier(frontend, tmp_path): ref_lines = [ "", # l. 1 "#", + "##", "##", "##", "##", "#", "##", "##", - "###", + "###", "###", # l. 10 "###", "###", @@ -112,7 +114,7 @@ def test_stringifier(frontend, tmp_path): "... 1. + 1.>", "#", # l. 30 "##", - "###", + "###", "###", "###", "##", @@ -120,7 +122,7 @@ def test_stringifier(frontend, tmp_path): "#", "##", "##", - "###", # l. 40 + "###", # l. 40 "###", "###", "##", @@ -129,11 +131,11 @@ def test_stringifier(frontend, tmp_path): "####", "#####", "####", - "#####", + "#####", "####", # l. 50 "#####", "####", - "#####", + "#####", "###", "####", "###", @@ -144,12 +146,12 @@ def test_stringifier(frontend, tmp_path): if frontend == OMNI: # Some string inconsistencies - ref_lines[15] = ref_lines[15].replace('1E-8', '1e-8') - ref_lines[35] = ref_lines[35].replace('SQRT', 'sqrt') - ref_lines[48] = ref_lines[48].replace('PRINT', 'print') - ref_lines[52] = ref_lines[52].replace('PRINT', 'print') + ref_lines[16] = ref_lines[16].replace('1E-8', '1e-8') + ref_lines[36] = ref_lines[36].replace('SQRT', 'sqrt') + ref_lines[49] = ref_lines[49].replace('PRINT', 'print') + ref_lines[53] = ref_lines[53].replace('PRINT', 'print') - cont_index = 27 # line number where line continuation is happening + cont_index = 28 # line number where line continuation is happening ref = '\n'.join(ref_lines) module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) diff --git a/loki/expression/tests/test_expression.py b/loki/expression/tests/test_expression.py index e4923fa1b..ec3f8de6d 100644 --- a/loki/expression/tests/test_expression.py +++ b/loki/expression/tests/test_expression.py @@ -602,7 +602,7 @@ def test_output_intrinsics(frontend): ref[1] = ref[1].replace(' * ', '*') ref[1] = ref[1].replace('- 1', '-1') - intrinsics = FindNodes(ir.Intrinsic).visit(routine.body) + intrinsics = FindNodes(ir.GenericStmt).visit(routine.body) assert len(intrinsics) == 2 assert intrinsics[0].text.lower() == ref[0] assert intrinsics[1].text.lower() == ref[1] diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index b4e455d12..ff825e1ab 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -322,13 +322,13 @@ def visit_List(self, o, **kwargs): """ return tuple(self.visit(i, **kwargs) for i in o.children) - def visit_Intrinsic_Stmt(self, o, **kwargs): + def visit_Generic_Stmt(self, o, **kwargs): """ Universal routine to capture nodes as plain string in the IR """ label = kwargs.get('label') label = str(label) if label else label # Ensure srting labels - return ir.Intrinsic(text=o.tostr(), label=label, source=kwargs.get('source')) + return ir.GenericStmt(text=o.tostr(), label=label, source=kwargs.get('source')) # # Base blocks @@ -1526,12 +1526,18 @@ def visit_Final_Binding(self, o, **kwargs): symbols=symbols, final=True, source=kwargs.get('source'), label=kwargs.get('label') ) + def visit_Contains_Stmt(self, o, **kwargs): + return ir.ContainsStmt(source=kwargs.get('source')) + + def visit_Private_Components_Stmt(self, o, **kwargs): + return ir.PrivateStmt(source=kwargs.get('source')) + + def visit_Binding_Private_Stmt(self, o, **kwargs): + return ir.PrivateStmt(source=kwargs.get('source')) + visit_Binding_Name_List = visit_List visit_Final_Subroutine_Name_List = visit_List - visit_Contains_Stmt = visit_Intrinsic_Stmt - visit_Binding_Private_Stmt = visit_Intrinsic_Stmt - visit_Private_Components_Stmt = visit_Intrinsic_Stmt - visit_Sequence_Stmt = visit_Intrinsic_Stmt + visit_Sequence_Stmt = visit_Generic_Stmt # # ASSOCIATE blocks @@ -3245,14 +3251,18 @@ def visit_Include_Stmt(self, o, **kwargs): label=kwargs.get('label')) def visit_Implicit_Stmt(self, o, **kwargs): - return ir.Intrinsic(text=f'IMPLICIT {o.items[0]}', source=kwargs.get('source'), - label=kwargs.get('label')) + if len(o.items) == 1 and isinstance(o.items[0], str): + return ir.ImplicitStmt(text=o.items[0], **kwargs) + content = tuple(i if isinstance(i, str) else self.visit(i, **kwargs) for i in o.items) + return ir.ImplicitStmt(text=content, **kwargs) def visit_Print_Stmt(self, o, **kwargs): # NOTE: fparser returns None for an empty print (`PRINT *`) instead of # the usual `Output_Item_List` entity. - return ir.Intrinsic(text=f'PRINT {", ".join(str(i) for i in o.items if i is not None)}', - source=kwargs.get('source'), label=kwargs.get('label')) + return ir.GenericStmt( + text=f'PRINT {", ".join(str(i) for i in o.items if i is not None)}', + source=kwargs.get('source'), label=kwargs.get('label') + ) # TODO: Deal with line-continuation pragmas! _re_pragma = re.compile(r'^\s*\!\$(?P\w+)\s*(?P.*)', re.IGNORECASE) @@ -3491,29 +3501,43 @@ def visit_Parenthesis(self, o, **kwargs): expression = ParenthesisedPow(expression.base, expression.exponent) return expression - visit_Format_Stmt = visit_Intrinsic_Stmt - visit_Write_Stmt = visit_Intrinsic_Stmt - visit_Goto_Stmt = visit_Intrinsic_Stmt - visit_Return_Stmt = visit_Intrinsic_Stmt - visit_Continue_Stmt = visit_Intrinsic_Stmt - visit_Cycle_Stmt = visit_Intrinsic_Stmt - visit_Exit_Stmt = visit_Intrinsic_Stmt - visit_Save_Stmt = visit_Intrinsic_Stmt - visit_Read_Stmt = visit_Intrinsic_Stmt - visit_Open_Stmt = visit_Intrinsic_Stmt - visit_Close_Stmt = visit_Intrinsic_Stmt - visit_Inquire_Stmt = visit_Intrinsic_Stmt - visit_Namelist_Stmt = visit_Intrinsic_Stmt - visit_Parameter_Stmt = visit_Intrinsic_Stmt - visit_Dimension_Stmt = visit_Intrinsic_Stmt - visit_Equivalence_Stmt = visit_Intrinsic_Stmt - visit_Common_Stmt = visit_Intrinsic_Stmt - visit_Stop_Stmt = visit_Intrinsic_Stmt - visit_Error_Stop_Stmt = visit_Intrinsic_Stmt - visit_Backspace_Stmt = visit_Intrinsic_Stmt - visit_Rewind_Stmt = visit_Intrinsic_Stmt - visit_Entry_Stmt = visit_Intrinsic_Stmt - visit_Cray_Pointer_Stmt = visit_Intrinsic_Stmt + # + # Remaining internal Fortran statements + # + + def visit_Save_Stmt(self, o, **kwargs): + return ir.SaveStmt(source=kwargs.get('source')) + + def visit_Return_Stmt(self, o, **kwargs): + return ir.ReturnStmt(source=kwargs.get('source')) + + def visit_Cycle_Stmt(self, o, **kwargs): + return ir.CycleStmt(source=kwargs.get('source')) + + def visit_Goto_Stmt(self, o, **kwargs): + label = o.items[0].tostr() + return ir.GotoStmt(text=label, source=kwargs.get('source')) + + visit_Intrinsic_Stmt = visit_Generic_Stmt + visit_Format_Stmt = visit_Generic_Stmt + visit_Write_Stmt = visit_Generic_Stmt + visit_Continue_Stmt = visit_Generic_Stmt + visit_Exit_Stmt = visit_Generic_Stmt + visit_Read_Stmt = visit_Generic_Stmt + visit_Open_Stmt = visit_Generic_Stmt + visit_Close_Stmt = visit_Generic_Stmt + visit_Inquire_Stmt = visit_Generic_Stmt + visit_Namelist_Stmt = visit_Generic_Stmt + visit_Parameter_Stmt = visit_Generic_Stmt + visit_Dimension_Stmt = visit_Generic_Stmt + visit_Equivalence_Stmt = visit_Generic_Stmt + visit_Common_Stmt = visit_Generic_Stmt + visit_Stop_Stmt = visit_Generic_Stmt + visit_Error_Stop_Stmt = visit_Generic_Stmt + visit_Backspace_Stmt = visit_Generic_Stmt + visit_Rewind_Stmt = visit_Generic_Stmt + visit_Entry_Stmt = visit_Generic_Stmt + visit_Cray_Pointer_Stmt = visit_Generic_Stmt def visit_Cpp_If_Stmt(self, o, **kwargs): return ir.PreprocessorDirective(text=o.tostr(), source=kwargs.get('source')) diff --git a/loki/frontend/omni.py b/loki/frontend/omni.py index 260455364..534c51aa9 100644 --- a/loki/frontend/omni.py +++ b/loki/frontend/omni.py @@ -412,9 +412,9 @@ def visit_FfunctionDefinition(self, o, **kwargs): # Insert the `implicit none` statement OMNI omits (slightly hacky!) f_imports = [im for im in FindNodes(ir.Import).visit(spec) if not im.c_import] if not f_imports: - spec.prepend(ir.Intrinsic(text='IMPLICIT NONE')) + spec.prepend(ir.ImplicitStmt()) else: - spec.insert(spec.body.index(f_imports[-1])+1, ir.Intrinsic(text='IMPLICIT NONE')) + spec.insert(spec.body.index(f_imports[-1])+1, ir.ImplicitStmt()) # Parse member functions body_ast = o.find('body') @@ -462,7 +462,7 @@ def visit_FfunctionDefinition(self, o, **kwargs): def visit_FcontainsStatement(self, o, **kwargs): body = [self.visit(c, **kwargs) for c in o] body = [c for c in body if c is not None] - body = [ir.Intrinsic('CONTAINS', source=kwargs['source'])] + body + body = [ir.ContainsStmt(source=kwargs['source'])] + body return ir.Section(body=as_tuple(body)) def visit_FmoduleProcedureDecl(self, o, **kwargs): @@ -526,6 +526,13 @@ def visit_FmoduleDefinition(self, o, **kwargs): docstring = as_tuple(docstring) spec = Transformer(spec_map, invalidate_source=False).visit(spec) + # Insert the `implicit none` statement OMNI omits (slightly hacky!) + f_imports = [im for im in FindNodes(ir.Import).visit(spec) if not im.c_import] + if not f_imports: + spec.prepend(ir.ImplicitStmt()) + else: + spec.insert(spec.body.index(f_imports[-1])+1, ir.ImplicitStmt()) + # Parse member functions if contains_ast is not None: contains = self.visit(contains_ast, **kwargs) @@ -669,7 +676,7 @@ def visit_FstructDecl(self, o, **kwargs): # Check if the type is marked as sequence if struct_type.get('is_sequence') == 'true': - body += [ir.Intrinsic('SEQUENCE')] + body += [ir.GenericStmt('SEQUENCE')] # Build the list of derived type members and individual body for each if struct_type.find('symbols') is not None: @@ -687,9 +694,9 @@ def visit_FstructDecl(self, o, **kwargs): if struct_type.find('typeBoundProcedures') is not None: # See if components are marked private - body += [ir.Intrinsic('CONTAINS')] + body += [ir.ContainsStmt()] if struct_type.attrib.get('is_internal_private') == 'true': - body += [ir.Intrinsic('PRIVATE')] + body += [ir.PrivateStmt()] body += self.visit(struct_type.find('typeBoundProcedures'), **kwargs) # Finally: update the typedef with its body @@ -1286,49 +1293,49 @@ def visit_FstructConstructor(self, o, **kwargs): def visit_FcycleStatement(self, o, **kwargs): # TODO: do-construct-name is not preserved - return ir.Intrinsic(text='cycle', source=kwargs['source']) + return ir.CycleStmt(source=kwargs['source']) def visit_continueStatement(self, o, **kwargs): - return ir.Intrinsic(text='continue', source=kwargs['source']) + return ir.GenericStmt(text='continue', source=kwargs['source']) def visit_FexitStatement(self, o, **kwargs): # TODO: do-construct-name is not preserved - return ir.Intrinsic(text='exit', source=kwargs['source']) + return ir.GenericStmt(text='exit', source=kwargs['source']) def visit_FopenStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) - return ir.Intrinsic(text=f'open({nargs})', source=kwargs['source']) + return ir.GenericStmt(text=f'open({nargs})', source=kwargs['source']) def visit_FcloseStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) - return ir.Intrinsic(text=f'close({nargs})', source=kwargs['source']) + return ir.GenericStmt(text=f'close({nargs})', source=kwargs['source']) def visit_FreadStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] values = [self.visit(v, **kwargs) for v in o.find('valueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) args = ', '.join(f'{v}' for v in values) - return ir.Intrinsic(text=f'read({nargs}) {args}', source=kwargs['source']) + return ir.GenericStmt(text=f'read({nargs}) {args}', source=kwargs['source']) def visit_FwriteStatement(self, o, **kwargs): nvalues = [self.visit(nv, **kwargs) for nv in o.find('namedValueList')] values = [self.visit(v, **kwargs) for v in o.find('valueList')] nargs = ', '.join(f'{k}={v}' for k, v in nvalues) args = ', '.join(f'{v}' for v in values) - return ir.Intrinsic(text=f'write({nargs}) {args}', source=kwargs['source']) + return ir.GenericStmt(text=f'write({nargs}) {args}', source=kwargs['source']) def visit_FprintStatement(self, o, **kwargs): values = [self.visit(v, **kwargs) for v in o.find('valueList')] args = ', '.join(f'{v}' for v in values) args = f", {args}" if values else "" fmt = o.attrib['format'] - return ir.Intrinsic(text=f'print {fmt}{args}', source=kwargs['source']) + return ir.GenericStmt(text=f'print {fmt}{args}', source=kwargs['source']) def visit_FformatDecl(self, o, **kwargs): fmt = f'FORMAT{o.attrib["format"]}' - return ir.Intrinsic(text=fmt, source=kwargs['source']) + return ir.GenericStmt(text=fmt, source=kwargs['source']) def visit_namedValue(self, o, **kwargs): name = o.attrib['name'] @@ -1446,15 +1453,15 @@ def visit_FconcatExpr(self, o, **kwargs): return StringConcat(exprs) def visit_gotoStatement(self, o, **kwargs): - label = int(o.attrib['label_name']) - return ir.Intrinsic(text=f'go to {label: d}', source=kwargs['source']) + label = str(int(o.attrib['label_name'])) + return ir.GotoStmt(text=label, source=kwargs['source']) def visit_FstopStatement(self, o, **kwargs): code = o.attrib['code'] - return ir.Intrinsic(text=f'stop {code!s}', source=kwargs['source']) + return ir.GenericStmt(text=f'stop {code!s}', source=kwargs['source']) def visit_statementLabel(self, o, **kwargs): return ir.Comment('__STATEMENT_LABEL__', label=o.attrib['label_name'], source=kwargs['source']) def visit_FreturnStatement(self, o, **kwargs): - return ir.Intrinsic(text='return', source=kwargs['source']) + return ir.ReturnStmt(source=kwargs['source']) diff --git a/loki/frontend/preprocessing.py b/loki/frontend/preprocessing.py index a29f933a0..59bf94daf 100644 --- a/loki/frontend/preprocessing.py +++ b/loki/frontend/preprocessing.py @@ -18,7 +18,7 @@ from loki.logging import debug, detail from loki.config import config from loki.tools import as_tuple, gettempdir, filehash -from loki.ir import Intrinsic, FindNodes +from loki.ir import GenericStmt, FindNodes from loki.frontend.util import OMNI, FP, REGEX @@ -132,7 +132,7 @@ def reinsert_convert_endian(ir, pp_info): into calls to OPEN. """ if pp_info: - for intr in FindNodes(Intrinsic).visit(ir): + for intr in FindNodes(GenericStmt).visit(ir): match = pp_info.get(intr.source.lines[0], [None])[0] if match is not None: source = intr.source @@ -151,7 +151,7 @@ def reinsert_open_newunit(ir, pp_info): Reinsert the NEWUNIT=... arguments into calls to OPEN. """ if pp_info: - for intr in FindNodes(Intrinsic).visit(ir): + for intr in FindNodes(GenericStmt).visit(ir): match = pp_info.get(intr.source.lines[0], [None])[0] if match is not None: source = intr.source diff --git a/loki/frontend/regex.py b/loki/frontend/regex.py index 5fb1178f0..34a81155c 100644 --- a/loki/frontend/regex.py +++ b/loki/frontend/regex.py @@ -441,7 +441,7 @@ def match(self, reader, parser_classes, scope): spec = None if match['contains']: - contains = [ir.Intrinsic(text='CONTAINS')] + contains = [ir.ContainsStmt()] span = match.span('contains') span = (span[0] + 8, span[1]) # Skip the "contains" keyword as it has been added candidates = ['SubroutineFunctionPattern'] @@ -533,7 +533,7 @@ def match(self, reader, parser_classes, scope): spec = None if match['contains']: - contains = [ir.Intrinsic(text='CONTAINS')] + contains = [ir.ContainsStmt()] span = match.span('contains') span = (span[0] + 8, span[1]) # Skip the "contains" keyword as it has been added block_children = ['SubroutineFunctionPattern'] @@ -716,7 +716,7 @@ def match(self, reader, parser_classes, scope): spec = [] if match['contains']: - contains = [ir.Intrinsic(text='CONTAINS')] + contains = [ir.ContainsStmt()] span = match.span('contains') span = (span[0] + 8, span[1]) # Skip the "contains" keyword as it has been added diff --git a/loki/frontend/tests/test_fparser_source.py b/loki/frontend/tests/test_fparser_source.py index 122d96d3e..038a0244f 100644 --- a/loki/frontend/tests/test_fparser_source.py +++ b/loki/frontend/tests/test_fparser_source.py @@ -202,7 +202,7 @@ def test_fparser_source_parsing(store_source): assert not assigns[0].source imprts = FindNodes(ir.Import).visit(module.spec) - intrs = FindNodes(ir.Intrinsic).visit(module.spec) + intrs = FindNodes(ir.GenericStmt).visit(module.spec) tdefs = FindNodes(ir.TypeDef).visit(module.spec) assert len(imprts) == 1 and len(tdefs) == 1 and len(intrs) == 1 tdecls = FindNodes(ir.VariableDeclaration).visit(tdefs[0].body) diff --git a/loki/frontend/tests/test_frontends.py b/loki/frontend/tests/test_frontends.py index 938a1ee32..f1c834134 100644 --- a/loki/frontend/tests/test_frontends.py +++ b/loki/frontend/tests/test_frontends.py @@ -995,7 +995,7 @@ def test_empty_print_statement(frontend): """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) print_stmts = [ - intr for intr in FindNodes(ir.Intrinsic).visit(routine.ir) + intr for intr in FindNodes(ir.GenericStmt).visit(routine.ir) if 'print' in intr.text.lower() ] assert print_stmts[0].text.lower() == "print *" diff --git a/loki/frontend/util.py b/loki/frontend/util.py index 1d8ae2985..aebd53502 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -17,7 +17,7 @@ from loki.ir import ( NestedTransformer, FindNodes, PatternFinder, Transformer, Assignment, Comment, CommentBlock, VariableDeclaration, - ProcedureDeclaration, Loop, Intrinsic, Pragma + ProcedureDeclaration, Loop, GenericStmt, Pragma ) from loki.frontend.source import join_source_list from loki.logging import detail, warning, error @@ -202,7 +202,7 @@ def inline_labels(ir): any connection between both. """ pairs = PatternFinder(pattern=(Comment, Assignment)).visit(ir) - pairs += PatternFinder(pattern=(Comment, Intrinsic)).visit(ir) + pairs += PatternFinder(pattern=(Comment, GenericStmt)).visit(ir) pairs += PatternFinder(pattern=(Comment, Loop)).visit(ir) mapper = {} for pair in pairs: diff --git a/loki/ir/nodes/__init__.py b/loki/ir/nodes/__init__.py new file mode 100644 index 000000000..36433c8d0 --- /dev/null +++ b/loki/ir/nodes/__init__.py @@ -0,0 +1,13 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" The node defininition classes for the Loki IR. """ + +from loki.ir.nodes.abstract_nodes import * # noqaOA +from loki.ir.nodes.internal_nodes import * # noqa +from loki.ir.nodes.leaf_nodes import * # noqa +from loki.ir.nodes.stmt_nodes import * # noqa diff --git a/loki/ir/nodes/abstract_nodes.py b/loki/ir/nodes/abstract_nodes.py new file mode 100644 index 000000000..b2d9e107e --- /dev/null +++ b/loki/ir/nodes/abstract_nodes.py @@ -0,0 +1,359 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" Abstract base classes for node definitions in the Loki IR. """ + +from abc import abstractmethod +from collections import OrderedDict +from typing import Tuple, Union, Optional + +from pydantic import field_validator + +from loki.expression import Variable, parse_expr +from loki.frontend.source import Source +from loki.tools import ( + dataclass_strict, sanitize_tuple, CaseInsensitiveDict +) +from loki.types import Scope + + +__all__ = ['Node', 'InternalNode', 'LeafNode', 'ScopedNode'] + + +@dataclass_strict(frozen=True) +class Node: + """ + Base class for all node types in Loki's internal representation. + + Provides the common functionality shared by all node types; specifically, + this comprises functionality to update or rebuild a node, and source + metadata. + + Attributes + ---------- + traversable : list of str + The traversable fields of the Node; that is, fields walked over by + a :any:`Visitor`. All arguments in :py:meth:`__init__` whose + name appear in this list are treated as traversable fields. + + Parameters + ---------- + source : :any:`Source`, optional + the information about the original source for the Node. + label : str, optional + the label assigned to the statement in the original source + corresponding to the Node. + + """ + + source: Optional[Union[Source, str]] = None + label: Optional[str] = None + + _traversable = [] + + def __post_init__(self): + # Create private placeholders for dataflow analysis fields that + # do not show up in the dataclass field definitions, as these + # are entirely transient. + self._update(_live_symbols=None, _defines_symbols=None, _uses_symbols=None) + + @property + def children(self): + """ + The traversable children of the node. + """ + return tuple(getattr(self, i) for i in self._traversable) + + def _rebuild(self, *args, **kwargs): + """ + Rebuild the node. + + Constructs an identical copy of the node from when it was first + created. Optionally, some or all of the arguments for it can + be overwritten. + + Parameters + ---------- + *args : optional + The traversable arguments used to create the node. By default, + ``args`` are used. + **kwargs : optional + The non-traversable arguments used to create the node, By + default, ``args_frozen`` are used. + """ + handle = self.args + argnames = [i for i in self._traversable if i not in kwargs] + handle.update(OrderedDict(zip(argnames, args))) + handle.update(kwargs) + return type(self)(**handle) + + clone = _rebuild + + def _update(self, *args, **kwargs): + """ + In-place update that modifies (re-initializes) the node + without rebuilding it. Use with care! + + Parameters + ---------- + *args : optional + The traversable arguments used to create the node. By default, + ``args`` are used. + **kwargs : optional + The non-traversable arguments used to create the node, By + default, ``args_frozen`` are used. + + """ + argnames = [i for i in self._traversable if i not in kwargs] + kwargs.update(zip(argnames, args)) + self.__dict__.update(kwargs) + + @property + def args(self): + """ + Arguments used to construct the Node. + """ + return {k: v for k, v in self.__dict__.items() if k in self.__dataclass_fields__.keys()} # pylint: disable=no-member + + @property + def args_frozen(self): + """ + Arguments used to construct the Node that cannot be traversed. + """ + return {k: v for k, v in self.args.items() if k not in self._traversable} + + def __repr__(self): + raise NotImplementedError + + def view(self): + """ + Pretty-print the node hierachy under this node. + """ + # pylint: disable=import-outside-toplevel,cyclic-import + from loki.backend.pprint import pprint + pprint(self) + + def ir_graph(self, show_comments=False, show_expressions=False, linewidth=40, symgen=str): + """ + Get the IR graph to visualize the node hierachy under this node. + """ + # pylint: disable=import-outside-toplevel,cyclic-import + from loki.ir.ir_graph import ir_graph + + return ir_graph(self, show_comments, show_expressions,linewidth, symgen) + + @property + def live_symbols(self): + """ + Yield the list of live symbols at this node, i.e., variables that + have been defined (potentially) prior to this point in the control flow + graph. + + This property is attached to the Node by + :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or + when using the + :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` + context manager. + """ + if self.__dict__['_live_symbols'] is None: + raise RuntimeError('Need to run dataflow analysis on the IR first.') + return self.__dict__['_live_symbols'] + + @property + def defines_symbols(self): + """ + Yield the list of symbols (potentially) defined by this node. + + This property is attached to the Node by + :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or + when using the + :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` + context manager. + """ + if self.__dict__['_defines_symbols'] is None: + raise RuntimeError('Need to run dataflow analysis on the IR first.') + return self.__dict__['_defines_symbols'] + + @property + def uses_symbols(self): + """ + + Yield the list of symbols used by this node before defining it. + + This property is attached to the Node by + :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or + when using the + :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` + context manager. + """ + if self.__dict__['_uses_symbols'] is None: + raise RuntimeError('Need to run dataflow analysis on the IR first.') + return self.__dict__['_uses_symbols'] + + +@dataclass_strict(frozen=True) +class _InternalNode(): + """ Type definitions for :any:`InternalNode` node type. """ + + body: Tuple[Union[Node, Scope], ...] = () + + +@dataclass_strict(frozen=True) +class InternalNode(Node, _InternalNode): + """ + Internal representation of a control flow node that has a traversable + `body` property. + + Parameters + ---------- + body : tuple + The nodes that make up the body. + """ + + _traversable = ['body'] + + @field_validator('body', mode='before') + @classmethod + def ensure_tuple(cls, value): + return sanitize_tuple(value) + + def __repr__(self): + raise NotImplementedError + + +@dataclass_strict(frozen=True) +class LeafNode(Node): + """ + Internal representation of a control flow node without a `body`. + """ + + def __repr__(self): + raise NotImplementedError + + +# Mix-ins + +class ScopedNode(Scope): + """ + Mix-in to attache a scope to an IR :any:`Node` + + Additionally, this specializes the node's :meth:`_update` and + :meth:`_rebuild` methods to make sure that an existing symbol table + is carried over correctly. + """ + + @property + def args(self): + """ + Arguments used to construct the :any:`ScopedNode`, excluding + the symbol table. + """ + keys = tuple(k for k in self.__dataclass_fields__.keys() if k not in ('symbol_attrs', )) # pylint: disable=no-member + return {k: v for k, v in self.__dict__.items() if k in keys} + + def _update(self, *args, **kwargs): + if 'symbol_attrs' not in kwargs: + # Retain the symbol table (unless given explicitly) + kwargs['symbol_attrs'] = self.symbol_attrs + super()._update(*args, **kwargs) # pylint: disable=no-member + + def _rebuild(self, *args, **kwargs): + # Retain the symbol table (unless given explicitly) + symbol_attrs = kwargs.pop('symbol_attrs', self.symbol_attrs) + rescope_symbols = kwargs.pop('rescope_symbols', False) + + # Ensure 'parent' is always explicitly set + kwargs['parent'] = kwargs.get('parent', None) + + new_obj = super()._rebuild(*args, **kwargs) # pylint: disable=no-member + new_obj.symbol_attrs.update(symbol_attrs) + + if rescope_symbols: + new_obj.rescope_symbols() + return new_obj + + def __getstate__(self): + s = self.args + s['symbol_attrs'] = self.symbol_attrs + return s + + def __setstate__(self, s): + symbol_attrs = s.pop('symbol_attrs', None) + self._update(**s, symbol_attrs=symbol_attrs, rescope_symbols=True) + + @property + @abstractmethod + def variables(self): + """ + Return the variables defined in this :any:`ScopedNode`. + """ + + @property + def variable_map(self): + """ + Map of variable names to :any:`Variable` objects + """ + return CaseInsensitiveDict((v.name, v) for v in self.variables) + + def get_symbol(self, name): + """ + Returns the symbol for a given name as defined in its declaration. + + The returned symbol might include dimension symbols if it was + declared as an array. + + Parameters + ---------- + name : str + Base name of the symbol to be retrieved + """ + return self.get_symbol_scope(name).variable_map.get(name) + + def Variable(self, **kwargs): + """ + Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes. + + This invokes the :any:`Variable` with this node as the scope. + + Parameters + ---------- + name : str + The name of the variable. + type : optional + The type of that symbol. Defaults to :any:`BasicType.DEFERRED`. + parent : :any:`Scalar` or :any:`Array`, optional + The derived type variable this variable belongs to. + dimensions : :any:`ArraySubscript`, optional + The array subscript expression. + """ + kwargs['scope'] = self + return Variable(**kwargs) + + def parse_expr(self, expr_str, strict=False, evaluate=False, context=None): + """ + Uses :meth:`parse_expr` to convert expression(s) represented + in a string to Loki expression(s)/IR. + + Parameters + ---------- + expr_str : str + The expression as a string + strict : bool, optional + Whether to raise exception for unknown variables/symbols when + evaluating an expression (default: `False`) + evaluate : bool, optional + Whether to evaluate the expression or not (default: `False`) + context : dict, optional + Symbol context, defining variables/symbols/procedures to help/support + evaluating an expression + + Returns + ------- + :any:`Expression` + The expression tree corresponding to the expression + """ + return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context) diff --git a/loki/ir/nodes/internal_nodes.py b/loki/ir/nodes/internal_nodes.py new file mode 100644 index 000000000..b60bf83df --- /dev/null +++ b/loki/ir/nodes/internal_nodes.py @@ -0,0 +1,487 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" Intermediate node classes for nested node definitions in the Loki IR. """ + +from typing import Any, Tuple, Union, Optional + +from pymbolic.primitives import Expression +from pydantic import field_validator + +from loki.ir.nodes.abstract_nodes import ( + Node, InternalNode, ScopedNode +) +from loki.expression import ( + symbols as sym, AttachScopesMapper, ExpressionDimensionsMapper +) +from loki.tools import ( + as_tuple, dataclass_strict, flatten, sanitize_tuple, CaseInsensitiveDict +) +from loki.types import BasicType, SymbolAttributes + + +__all__ = [ + 'Section', 'Associate', 'Loop', 'WhileLoop', 'Conditional', + 'PragmaRegion', 'Interface', +] + + +@dataclass_strict(frozen=True) +class _SectionBase(): + """ Type definitions for :any:`Section` node type. """ + + +@dataclass_strict(frozen=True) +class Section(InternalNode, _SectionBase): + """ + Internal representation of a single code region. + """ + + def append(self, node): + """ + Append the given node(s) to the section's body. + + Parameters + ---------- + node : :any:`Node` or tuple of :any:`Node` + The node(s) to append to the section. + """ + self._update(body=self.body + as_tuple(node)) + + def insert(self, pos, node): + """ + Insert the given node(s) into the section's body at a specific + position. + + Parameters + ---------- + pos : int + The position at which the node(s) should be inserted. Any existing + nodes at this or after this position are shifted back. + node : :any:`Node` or tuple of :any:`Node` + The node(s) to append to the section. + """ + self._update(body=self.body[:pos] + as_tuple(node) + self.body[pos:]) # pylint: disable=unsubscriptable-object + + def prepend(self, node): + """ + Insert the given node(s) at the beginning of the section's body. + + Parameters + ---------- + node : :any:`Node` or tuple of :any:`Node` + The node(s) to insert into the section. + """ + self._update(body=as_tuple(node) + self.body) + + def __repr__(self): + if self.label is not None: + return f'Section:: {self.label}' + return 'Section::' + + +@dataclass_strict(frozen=True) +class _AssociateBase(): + """ Type definitions for :any:`Associate` node type. """ + + associations: Tuple[Tuple[Expression, Expression], ...] + + +@dataclass_strict(frozen=True) +class Associate(ScopedNode, Section, _AssociateBase): # pylint: disable=too-many-ancestors + """ + Internal representation of a code region in which names are associated + with expressions or variables. + + Parameters + ---------- + body : tuple + The associate's body. + associations : dict or collections.OrderedDict + The mapping of names to expressions or variables valid inside the + associate's body. + parent : :any:`Scope`, optional + The parent scope in which the associate appears + symbol_attrs : :any:`SymbolTable`, optional + An existing symbol table to use + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + _traversable = ['body', 'associations'] + + def __post_init__(self, parent=None): + super(ScopedNode, self).__post_init__(parent=parent) + super(Section, self).__post_init__() + + assert self.associations is None or isinstance(self.associations, tuple) + + @property + def association_map(self): + """ + An :any:`collections.OrderedDict` of associated expressions. + """ + return CaseInsensitiveDict((k, v) for k, v in self.associations) + + @property + def inverse_map(self): + """ + An :any:`collections.OrderedDict` of associated expressions. + """ + return CaseInsensitiveDict((v, k) for k, v in self.associations) + + @property + def variables(self): + return tuple(v for _, v in self.associations) + + def _derive_local_symbol_types(self, parent_scope): + """ Derive the types of locally defined symbols from their associations. """ + + rescoped_associations = () + for expr, name in self.associations: + # Put symbols in associated expression into the right scope + expr = AttachScopesMapper()(expr, scope=parent_scope) + + # Determine type of new names + if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)): + # Use the type of the associated variable + _type = expr.type.clone(parent=None) + if isinstance(expr, sym.Array) and expr.dimensions is not None: + shape = ExpressionDimensionsMapper()(expr) + if shape == (sym.IntLiteral(1),): + # For a scalar expression, we remove the shape + shape = None + _type = _type.clone(shape=shape) + else: + # TODO: Handle data type and shape of complex expressions + shape = ExpressionDimensionsMapper()(expr) + if shape == (sym.IntLiteral(1),): + # For a scalar expression, we remove the shape + shape = None + _type = SymbolAttributes(BasicType.DEFERRED, shape=shape) + name = name.clone(scope=self, type=_type) + rescoped_associations += ((expr, name),) + + self._update(associations=rescoped_associations) + + def __repr__(self): + if self.associations: + associations = ', '.join(f'{str(var)}={str(expr)}' + for var, expr in self.associations) + return f'Associate:: {associations}' + return 'Associate::' + + +@dataclass_strict(frozen=True) +class _LoopBase(): + """ Type definitions for :any:`Loop` node type. """ + + variable: Expression + bounds: Expression + body: Tuple[Node, ...] + pragma: Optional[Tuple[Node, ...]] = None + pragma_post: Optional[Tuple[Node, ...]] = None + loop_label: Optional[Any] = None + name: Optional[str] = None + has_end_do: Optional[bool] = True + + +@dataclass_strict(frozen=True) +class Loop(InternalNode, _LoopBase): + """ + Internal representation of a loop with induction variable and range. + + Parameters + ---------- + variable : :any:`Scalar` + The induction variable of the loop. + bounds : :any:`LoopRange` + The range of the loop, defining the iteration space. + body : tuple + The loop body. + pragma : tuple of :any:`Pragma`, optional + Pragma(s) that appear in front of the loop. By default :any:`Pragma` + nodes appear as standalone nodes in the IR before the :any:`Loop` node. + Only a bespoke context created by :py:func:`pragmas_attached` + attaches them for convenience. + pragma_post : tuple of :any:`Pragma`, optional + Pragma(s) that appear after the loop. The same applies as for `pragma`. + loop_label : str, optional + The Fortran label for that loop. Importantly, this is an intrinsic + Fortran feature and different from the statement label that can be + attached to other nodes. + name : str, optional + The Fortran construct name for that loop. + has_end_do : bool, optional + In Fortran, loop blocks can be closed off by a ``CONTINUE`` statement + (which we retain as an :any:`Intrinsic` node) and therefore ``END DO`` + can be omitted. For string reproducibility this parameter can be set + `False` to indicate that this loop did not have an ``END DO`` + statement in the original source. + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + _traversable = ['variable', 'bounds', 'body'] + + def __post_init__(self): + super().__post_init__() + assert self.variable is not None + + def __repr__(self): + label = ', '.join(l for l in [self.name, self.loop_label] if l is not None) + if label: + label = ' ' + label + control = f'{str(self.variable)}={str(self.bounds)}' + return f'Loop::{label} {control}' + + +@dataclass_strict(frozen=True) +class _WhileLoopBase(): + """ Type definitions for :any:`WhileLoop` node type. """ + + condition: Optional[Expression] + body: Tuple[Node, ...] + pragma: Optional[Node] = None + pragma_post: Optional[Node] = None + loop_label: Optional[Any] = None + name: Optional[str] = None + has_end_do: Optional[bool] = True + + +@dataclass_strict(frozen=True) +class WhileLoop(InternalNode, _WhileLoopBase): + """ + Internal representation of a while loop in source code. + + Importantly, this is different from a ``DO`` (Fortran) or ``for`` (C) loop, + as we do not have a specified induction variable with explicit iteration + range. + + Parameters + ---------- + condition : :any:`pymbolic.primitives.Expression` + The condition evaluated before executing the loop body. + body : tuple + The loop body. + pragma : tuple of :any:`Pragma`, optional + Pragma(s) that appear in front of the loop. By default :any:`Pragma` + nodes appear as standalone nodes in the IR before the :any:`Loop` node. + Only a bespoke context created by :py:func:`pragmas_attached` + attaches them for convenience. + pragma_post : tuple of :any:`Pragma`, optional + Pragma(s) that appear after the loop. The same applies as for `pragma`. + loop_label : str, optional + The Fortran label for that loop. Importantly, this is an intrinsic + Fortran feature and different from the statement label that can be + attached to other nodes. + name : str, optional + The Fortran construct name for that loop. + has_end_do : bool, optional + In Fortran, loop blocks can be closed off by a ``CONTINUE`` statement + (which we retain as an :any:`Intrinsic` node) and therefore ``END DO`` + can be omitted. For string reproducibility this parameter can be set + `False` to indicate that this loop did not have an ``END DO`` + statement in the original source. + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + _traversable = ['condition', 'body'] + + def __repr__(self): + label = ', '.join(l for l in [self.name, self.loop_label] if l is not None) + if label: + label = ' ' + label + control = str(self.condition) if self.condition else '' + return f'WhileLoop::{label} {control}' + + +@dataclass_strict(frozen=True) +class _ConditionalBase(): + """ Type definitions for :any:`Conditional` node type. """ + + condition: Expression + body: Tuple[Node, ...] + else_body: Optional[Tuple[Node, ...]] = () + inline: bool = False + has_elseif: bool = False + name: Optional[str] = None + + +@dataclass_strict(frozen=True) +class Conditional(InternalNode, _ConditionalBase): + """ + Internal representation of a conditional branching construct. + + Parameters + ---------- + condition : :any:`pymbolic.primitives.Expression` + The condition evaluated before executing the body. + body : tuple + The conditional's body. + else_body : tuple + The body of the else branch. Can be empty. + inline : bool, optional + Flag that marks this conditional as inline, i.e., it s body consists + only of a single statement that appeared immediately after the + ``IF`` statement and it does not have an ``else_body``. + has_elseif : bool, optional + Flag that indicates that this conditional has an ``ELSE IF`` branch + in the original source. In Loki's IR these are represented as a chain + of :any:`Conditional` but for string reproducibility this flag can be + provided to enable backends to reproduce the original appearance. + name : str, optional + The Fortran construct name for that conditional. + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + _traversable = ['condition', 'body', 'else_body'] + + @field_validator('body', 'else_body', mode='before') + @classmethod + def ensure_tuple(cls, value): + return sanitize_tuple(value) + + def __post_init__(self): + super().__post_init__() + assert self.condition is not None + + if self.has_elseif: + assert len(self.else_body) == 1 + assert isinstance(self.else_body[0], Conditional) # pylint: disable=unsubscriptable-object + + def __repr__(self): + if self.name: + return f'Conditional:: {self.name}' + return 'Conditional::' + + @property + def else_bodies(self): + """ + Return all nested node tuples in the ``ELSEIF``/``ELSE`` part + of the conditional chain. + """ + if self.has_elseif: + return (self.else_body[0].body,) + self.else_body[0].else_bodies + return (self.else_body,) if self.else_body else () + + +@dataclass_strict(frozen=True) +class _PragmaRegionBase(): + """ Type definitions for :any:`PragmaRegion` node type. """ + + body: Tuple[Node, ...] + pragma: Node = None + pragma_post: Node = None + + +@dataclass_strict(frozen=True) +class PragmaRegion(InternalNode, _PragmaRegionBase): + """ + Internal representation of a block of code defined by two matching pragmas. + + Generally, the pair of pragmas are assumed to be of the form + ``!$ `` and ``!$ end ``. + + This node type is injected into the IR within a context created by + :py:func:`pragma_regions_attached`. + + Parameters + ---------- + body : tuple + The statements appearing between opening and closing pragma. + pragma : :any:`Pragma` + The opening pragma declaring that region. + pragma_post : :any:`Pragma` + The closing pragma for that region. + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + _traversable = ['body'] + + def append(self, node): + self._update(body=self.body + as_tuple(node)) + + def insert(self, pos, node): + '''Insert at given position''' + self._update(body=self.body[:pos] + as_tuple(node) + self.body[pos:]) # pylint: disable=unsubscriptable-object + + def prepend(self, node): + self._update(body=as_tuple(node) + self.body) + + def __repr__(self): + return 'PragmaRegion::' + + +@dataclass_strict(frozen=True) +class _InterfaceBase(): + """ Type definitions for :any:`Interface` node type. """ + + body: Tuple[Any, ...] + abstract: bool = False + spec: Optional[Union[Expression, str]] = None + + +@dataclass_strict(frozen=True) +class Interface(InternalNode, _InterfaceBase): + """ + Internal representation of a Fortran interface block. + + Parameters + ---------- + body : tuple + The body of the interface block, containing function and subroutine + specifications or procedure statements + abstract : bool, optional + Flag to indicate that this is an abstract interface + spec : str, optional + A generic name, operator, assignment, or I/O specification + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + _traversable = ['body'] + + def __post_init__(self): + super().__post_init__() + assert not (self.abstract and self.spec) + + @property + def symbols(self): + """ + The list of symbol names declared by this interface + """ + symbols = as_tuple(flatten( + getattr(node, 'procedure_symbol', getattr(node, 'symbols', ())) + for node in self.body # pylint: disable=not-an-iterable + )) + if self.spec: + return (self.spec,) + symbols + return symbols + + @property + def symbol_map(self): + """ + Map symbol name to symbol declared by this interface + """ + return CaseInsensitiveDict( + (s.name.lower(), s) for s in self.symbols + ) + + def __contains__(self, name): + return name in self.symbol_map + + def __repr__(self): + symbols = ', '.join(str(var) for var in self.symbols) + if self.abstract: + return f'Abstract Interface:: {symbols}' + if self.spec: + return f'Interface {self.spec}:: {symbols}' + return f'Interface:: {symbols}' diff --git a/loki/ir/nodes.py b/loki/ir/nodes/leaf_nodes.py similarity index 58% rename from loki/ir/nodes.py rename to loki/ir/nodes/leaf_nodes.py index 9e2444255..056a2e1c8 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes/leaf_nodes.py @@ -5,860 +5,36 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# pylint: disable=too-many-lines """ Control flow node classes for :ref:`internal_representation:Control flow tree` """ -from abc import abstractmethod -from collections import OrderedDict -from dataclasses import dataclass -from functools import partial from itertools import chain from typing import Any, Tuple, Union, Optional from pymbolic.primitives import Expression - -from pydantic.dataclasses import dataclass as dataclass_validated from pydantic import field_validator -from loki.expression import ( - symbols as sym, Variable, parse_expr, AttachScopesMapper, - ExpressionDimensionsMapper +from loki.ir.nodes.abstract_nodes import ( + Node, InternalNode, LeafNode, ScopedNode +) +from loki.tools import ( + as_tuple, dataclass_strict, flatten, is_iterable, sanitize_tuple, + truncate_string, CaseInsensitiveDict ) -from loki.frontend.source import Source -from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict -from loki.types import DataType, BasicType, DerivedType, SymbolAttributes, Scope +from loki.types import DataType, BasicType, DerivedType, SymbolAttributes __all__ = [ - # Abstract base classes - 'Node', 'InternalNode', 'LeafNode', 'ScopedNode', - # Internal node classes - 'Section', 'Associate', 'Loop', 'WhileLoop', 'Conditional', - 'PragmaRegion', 'Interface', # Leaf node classes 'Assignment', 'ConditionalAssignment', 'CallStatement', 'Allocation', 'Deallocation', 'Nullify', 'Comment', 'CommentBlock', 'Pragma', 'PreprocessorDirective', 'Import', 'VariableDeclaration', 'ProcedureDeclaration', 'DataDeclaration', 'StatementFunction', 'TypeDef', 'MultiConditional', 'TypeConditional', - 'Forall', 'MaskedStatement', - 'Intrinsic', 'Enumeration', 'RawSource', + 'Forall', 'MaskedStatement', 'Enumeration', 'RawSource', ] -# Configuration for validation mechanism via pydantic -dataclass_validation_config = { - 'arbitrary_types_allowed': True, -} - -# Using this decorator, we can force strict validation -dataclass_strict = partial(dataclass_validated, config=dataclass_validation_config) - - -def _sanitize_tuple(t): - """ - Small helper method to ensure non-nested tuples without ``None``. - """ - return tuple(n for n in flatten(as_tuple(t)) if n is not None) - - -# Abstract base classes - -@dataclass_strict(frozen=True) -class Node: - """ - Base class for all node types in Loki's internal representation. - - Provides the common functionality shared by all node types; specifically, - this comprises functionality to update or rebuild a node, and source - metadata. - - Attributes - ---------- - traversable : list of str - The traversable fields of the Node; that is, fields walked over by - a :any:`Visitor`. All arguments in :py:meth:`__init__` whose - name appear in this list are treated as traversable fields. - - Parameters - ---------- - source : :any:`Source`, optional - the information about the original source for the Node. - label : str, optional - the label assigned to the statement in the original source - corresponding to the Node. - - """ - - source: Optional[Union[Source, str]] = None - label: Optional[str] = None - - _traversable = [] - - def __post_init__(self): - # Create private placeholders for dataflow analysis fields that - # do not show up in the dataclass field definitions, as these - # are entirely transient. - self._update(_live_symbols=None, _defines_symbols=None, _uses_symbols=None) - - @property - def children(self): - """ - The traversable children of the node. - """ - return tuple(getattr(self, i) for i in self._traversable) - - def _rebuild(self, *args, **kwargs): - """ - Rebuild the node. - - Constructs an identical copy of the node from when it was first - created. Optionally, some or all of the arguments for it can - be overwritten. - - Parameters - ---------- - *args : optional - The traversable arguments used to create the node. By default, - ``args`` are used. - **kwargs : optional - The non-traversable arguments used to create the node, By - default, ``args_frozen`` are used. - """ - handle = self.args - argnames = [i for i in self._traversable if i not in kwargs] - handle.update(OrderedDict(zip(argnames, args))) - handle.update(kwargs) - return type(self)(**handle) - - clone = _rebuild - - def _update(self, *args, **kwargs): - """ - In-place update that modifies (re-initializes) the node - without rebuilding it. Use with care! - - Parameters - ---------- - *args : optional - The traversable arguments used to create the node. By default, - ``args`` are used. - **kwargs : optional - The non-traversable arguments used to create the node, By - default, ``args_frozen`` are used. - - """ - argnames = [i for i in self._traversable if i not in kwargs] - kwargs.update(zip(argnames, args)) - self.__dict__.update(kwargs) - - @property - def args(self): - """ - Arguments used to construct the Node. - """ - return {k: v for k, v in self.__dict__.items() if k in self.__dataclass_fields__.keys()} # pylint: disable=no-member - - @property - def args_frozen(self): - """ - Arguments used to construct the Node that cannot be traversed. - """ - return {k: v for k, v in self.args.items() if k not in self._traversable} - - def __repr__(self): - raise NotImplementedError - - def view(self): - """ - Pretty-print the node hierachy under this node. - """ - # pylint: disable=import-outside-toplevel,cyclic-import - from loki.backend.pprint import pprint - pprint(self) - - def ir_graph(self, show_comments=False, show_expressions=False, linewidth=40, symgen=str): - """ - Get the IR graph to visualize the node hierachy under this node. - """ - # pylint: disable=import-outside-toplevel,cyclic-import - from loki.ir.ir_graph import ir_graph - - return ir_graph(self, show_comments, show_expressions,linewidth, symgen) - - @property - def live_symbols(self): - """ - Yield the list of live symbols at this node, i.e., variables that - have been defined (potentially) prior to this point in the control flow - graph. - - This property is attached to the Node by - :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or - when using the - :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` - context manager. - """ - if self.__dict__['_live_symbols'] is None: - raise RuntimeError('Need to run dataflow analysis on the IR first.') - return self.__dict__['_live_symbols'] - - @property - def defines_symbols(self): - """ - Yield the list of symbols (potentially) defined by this node. - - This property is attached to the Node by - :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or - when using the - :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` - context manager. - """ - if self.__dict__['_defines_symbols'] is None: - raise RuntimeError('Need to run dataflow analysis on the IR first.') - return self.__dict__['_defines_symbols'] - - @property - def uses_symbols(self): - """ - - Yield the list of symbols used by this node before defining it. - - This property is attached to the Node by - :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or - when using the - :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` - context manager. - """ - if self.__dict__['_uses_symbols'] is None: - raise RuntimeError('Need to run dataflow analysis on the IR first.') - return self.__dict__['_uses_symbols'] - - -@dataclass_strict(frozen=True) -class _InternalNode(): - """ Type definitions for :any:`InternalNode` node type. """ - - body: Tuple[Union[Node, Scope], ...] = () - - -@dataclass_strict(frozen=True) -class InternalNode(Node, _InternalNode): - """ - Internal representation of a control flow node that has a traversable - `body` property. - - Parameters - ---------- - body : tuple - The nodes that make up the body. - """ - - _traversable = ['body'] - - @field_validator('body', mode='before') - @classmethod - def ensure_tuple(cls, value): - return _sanitize_tuple(value) - - def __repr__(self): - raise NotImplementedError - - -@dataclass_strict(frozen=True) -class LeafNode(Node): - """ - Internal representation of a control flow node without a `body`. - """ - - def __repr__(self): - raise NotImplementedError - - -# Mix-ins - -class ScopedNode(Scope): - """ - Mix-in to attache a scope to an IR :any:`Node` - - Additionally, this specializes the node's :meth:`_update` and - :meth:`_rebuild` methods to make sure that an existing symbol table - is carried over correctly. - """ - - @property - def args(self): - """ - Arguments used to construct the :any:`ScopedNode`, excluding - the symbol table. - """ - keys = tuple(k for k in self.__dataclass_fields__.keys() if k not in ('symbol_attrs', )) # pylint: disable=no-member - return {k: v for k, v in self.__dict__.items() if k in keys} - - def _update(self, *args, **kwargs): - if 'symbol_attrs' not in kwargs: - # Retain the symbol table (unless given explicitly) - kwargs['symbol_attrs'] = self.symbol_attrs - super()._update(*args, **kwargs) # pylint: disable=no-member - - def _rebuild(self, *args, **kwargs): - # Retain the symbol table (unless given explicitly) - symbol_attrs = kwargs.pop('symbol_attrs', self.symbol_attrs) - rescope_symbols = kwargs.pop('rescope_symbols', False) - - # Ensure 'parent' is always explicitly set - kwargs['parent'] = kwargs.get('parent', None) - - new_obj = super()._rebuild(*args, **kwargs) # pylint: disable=no-member - new_obj.symbol_attrs.update(symbol_attrs) - - if rescope_symbols: - new_obj.rescope_symbols() - return new_obj - - def __getstate__(self): - s = self.args - s['symbol_attrs'] = self.symbol_attrs - return s - - def __setstate__(self, s): - symbol_attrs = s.pop('symbol_attrs', None) - self._update(**s, symbol_attrs=symbol_attrs, rescope_symbols=True) - - @property - @abstractmethod - def variables(self): - """ - Return the variables defined in this :any:`ScopedNode`. - """ - - @property - def variable_map(self): - """ - Map of variable names to :any:`Variable` objects - """ - return CaseInsensitiveDict((v.name, v) for v in self.variables) - - def get_symbol(self, name): - """ - Returns the symbol for a given name as defined in its declaration. - - The returned symbol might include dimension symbols if it was - declared as an array. - - Parameters - ---------- - name : str - Base name of the symbol to be retrieved - """ - return self.get_symbol_scope(name).variable_map.get(name) - - def Variable(self, **kwargs): - """ - Factory method for :any:`TypedSymbol` or :any:`MetaSymbol` classes. - - This invokes the :any:`Variable` with this node as the scope. - - Parameters - ---------- - name : str - The name of the variable. - type : optional - The type of that symbol. Defaults to :any:`BasicType.DEFERRED`. - parent : :any:`Scalar` or :any:`Array`, optional - The derived type variable this variable belongs to. - dimensions : :any:`ArraySubscript`, optional - The array subscript expression. - """ - kwargs['scope'] = self - return Variable(**kwargs) - - def parse_expr(self, expr_str, strict=False, evaluate=False, context=None): - """ - Uses :meth:`parse_expr` to convert expression(s) represented - in a string to Loki expression(s)/IR. - - Parameters - ---------- - expr_str : str - The expression as a string - strict : bool, optional - Whether to raise exception for unknown variables/symbols when - evaluating an expression (default: `False`) - evaluate : bool, optional - Whether to evaluate the expression or not (default: `False`) - context : dict, optional - Symbol context, defining variables/symbols/procedures to help/support - evaluating an expression - - Returns - ------- - :any:`Expression` - The expression tree corresponding to the expression - """ - return parse_expr(expr_str, scope=self, strict=strict, evaluate=evaluate, context=context) - - -# Intermediate node types - - -@dataclass_strict(frozen=True) -class _SectionBase(): - """ Type definitions for :any:`Section` node type. """ - - -@dataclass_strict(frozen=True) -class Section(InternalNode, _SectionBase): - """ - Internal representation of a single code region. - """ - - def append(self, node): - """ - Append the given node(s) to the section's body. - - Parameters - ---------- - node : :any:`Node` or tuple of :any:`Node` - The node(s) to append to the section. - """ - self._update(body=self.body + as_tuple(node)) - - def insert(self, pos, node): - """ - Insert the given node(s) into the section's body at a specific - position. - - Parameters - ---------- - pos : int - The position at which the node(s) should be inserted. Any existing - nodes at this or after this position are shifted back. - node : :any:`Node` or tuple of :any:`Node` - The node(s) to append to the section. - """ - self._update(body=self.body[:pos] + as_tuple(node) + self.body[pos:]) # pylint: disable=unsubscriptable-object - - def prepend(self, node): - """ - Insert the given node(s) at the beginning of the section's body. - - Parameters - ---------- - node : :any:`Node` or tuple of :any:`Node` - The node(s) to insert into the section. - """ - self._update(body=as_tuple(node) + self.body) - - def __repr__(self): - if self.label is not None: - return f'Section:: {self.label}' - return 'Section::' - - -@dataclass_strict(frozen=True) -class _AssociateBase(): - """ Type definitions for :any:`Associate` node type. """ - - associations: Tuple[Tuple[Expression, Expression], ...] - - -@dataclass_strict(frozen=True) -class Associate(ScopedNode, Section, _AssociateBase): # pylint: disable=too-many-ancestors - """ - Internal representation of a code region in which names are associated - with expressions or variables. - - Parameters - ---------- - body : tuple - The associate's body. - associations : dict or collections.OrderedDict - The mapping of names to expressions or variables valid inside the - associate's body. - parent : :any:`Scope`, optional - The parent scope in which the associate appears - symbol_attrs : :any:`SymbolTable`, optional - An existing symbol table to use - **kwargs : optional - Other parameters that are passed on to the parent class constructor. - """ - - _traversable = ['body', 'associations'] - - def __post_init__(self, parent=None): - super(ScopedNode, self).__post_init__(parent=parent) - super(Section, self).__post_init__() - - assert self.associations is None or isinstance(self.associations, tuple) - - @property - def association_map(self): - """ - An :any:`collections.OrderedDict` of associated expressions. - """ - return CaseInsensitiveDict((k, v) for k, v in self.associations) - - @property - def inverse_map(self): - """ - An :any:`collections.OrderedDict` of associated expressions. - """ - return CaseInsensitiveDict((v, k) for k, v in self.associations) - - @property - def variables(self): - return tuple(v for _, v in self.associations) - - def _derive_local_symbol_types(self, parent_scope): - """ Derive the types of locally defined symbols from their associations. """ - - rescoped_associations = () - for expr, name in self.associations: - # Put symbols in associated expression into the right scope - expr = AttachScopesMapper()(expr, scope=parent_scope) - - # Determine type of new names - if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)): - # Use the type of the associated variable - _type = expr.type.clone(parent=None) - if isinstance(expr, sym.Array) and expr.dimensions is not None: - shape = ExpressionDimensionsMapper()(expr) - if shape == (sym.IntLiteral(1),): - # For a scalar expression, we remove the shape - shape = None - _type = _type.clone(shape=shape) - else: - # TODO: Handle data type and shape of complex expressions - shape = ExpressionDimensionsMapper()(expr) - if shape == (sym.IntLiteral(1),): - # For a scalar expression, we remove the shape - shape = None - _type = SymbolAttributes(BasicType.DEFERRED, shape=shape) - name = name.clone(scope=self, type=_type) - rescoped_associations += ((expr, name),) - - self._update(associations=rescoped_associations) - - def __repr__(self): - if self.associations: - associations = ', '.join(f'{str(var)}={str(expr)}' - for var, expr in self.associations) - return f'Associate:: {associations}' - return 'Associate::' - - -@dataclass_strict(frozen=True) -class _LoopBase(): - """ Type definitions for :any:`Loop` node type. """ - - variable: Expression - bounds: Expression - body: Tuple[Node, ...] - pragma: Optional[Tuple[Node, ...]] = None - pragma_post: Optional[Tuple[Node, ...]] = None - loop_label: Optional[Any] = None - name: Optional[str] = None - has_end_do: Optional[bool] = True - - -@dataclass_strict(frozen=True) -class Loop(InternalNode, _LoopBase): - """ - Internal representation of a loop with induction variable and range. - - Parameters - ---------- - variable : :any:`Scalar` - The induction variable of the loop. - bounds : :any:`LoopRange` - The range of the loop, defining the iteration space. - body : tuple - The loop body. - pragma : tuple of :any:`Pragma`, optional - Pragma(s) that appear in front of the loop. By default :any:`Pragma` - nodes appear as standalone nodes in the IR before the :any:`Loop` node. - Only a bespoke context created by :py:func:`pragmas_attached` - attaches them for convenience. - pragma_post : tuple of :any:`Pragma`, optional - Pragma(s) that appear after the loop. The same applies as for `pragma`. - loop_label : str, optional - The Fortran label for that loop. Importantly, this is an intrinsic - Fortran feature and different from the statement label that can be - attached to other nodes. - name : str, optional - The Fortran construct name for that loop. - has_end_do : bool, optional - In Fortran, loop blocks can be closed off by a ``CONTINUE`` statement - (which we retain as an :any:`Intrinsic` node) and therefore ``END DO`` - can be omitted. For string reproducibility this parameter can be set - `False` to indicate that this loop did not have an ``END DO`` - statement in the original source. - **kwargs : optional - Other parameters that are passed on to the parent class constructor. - """ - - _traversable = ['variable', 'bounds', 'body'] - - def __post_init__(self): - super().__post_init__() - assert self.variable is not None - - def __repr__(self): - label = ', '.join(l for l in [self.name, self.loop_label] if l is not None) - if label: - label = ' ' + label - control = f'{str(self.variable)}={str(self.bounds)}' - return f'Loop::{label} {control}' - - -@dataclass_strict(frozen=True) -class _WhileLoopBase(): - """ Type definitions for :any:`WhileLoop` node type. """ - - condition: Optional[Expression] - body: Tuple[Node, ...] - pragma: Optional[Node] = None - pragma_post: Optional[Node] = None - loop_label: Optional[Any] = None - name: Optional[str] = None - has_end_do: Optional[bool] = True - - -@dataclass_strict(frozen=True) -class WhileLoop(InternalNode, _WhileLoopBase): - """ - Internal representation of a while loop in source code. - - Importantly, this is different from a ``DO`` (Fortran) or ``for`` (C) loop, - as we do not have a specified induction variable with explicit iteration - range. - - Parameters - ---------- - condition : :any:`pymbolic.primitives.Expression` - The condition evaluated before executing the loop body. - body : tuple - The loop body. - pragma : tuple of :any:`Pragma`, optional - Pragma(s) that appear in front of the loop. By default :any:`Pragma` - nodes appear as standalone nodes in the IR before the :any:`Loop` node. - Only a bespoke context created by :py:func:`pragmas_attached` - attaches them for convenience. - pragma_post : tuple of :any:`Pragma`, optional - Pragma(s) that appear after the loop. The same applies as for `pragma`. - loop_label : str, optional - The Fortran label for that loop. Importantly, this is an intrinsic - Fortran feature and different from the statement label that can be - attached to other nodes. - name : str, optional - The Fortran construct name for that loop. - has_end_do : bool, optional - In Fortran, loop blocks can be closed off by a ``CONTINUE`` statement - (which we retain as an :any:`Intrinsic` node) and therefore ``END DO`` - can be omitted. For string reproducibility this parameter can be set - `False` to indicate that this loop did not have an ``END DO`` - statement in the original source. - **kwargs : optional - Other parameters that are passed on to the parent class constructor. - """ - - _traversable = ['condition', 'body'] - - def __repr__(self): - label = ', '.join(l for l in [self.name, self.loop_label] if l is not None) - if label: - label = ' ' + label - control = str(self.condition) if self.condition else '' - return f'WhileLoop::{label} {control}' - - -@dataclass_strict(frozen=True) -class _ConditionalBase(): - """ Type definitions for :any:`Conditional` node type. """ - - condition: Expression - body: Tuple[Node, ...] - else_body: Optional[Tuple[Node, ...]] = () - inline: bool = False - has_elseif: bool = False - name: Optional[str] = None - - -@dataclass_strict(frozen=True) -class Conditional(InternalNode, _ConditionalBase): - """ - Internal representation of a conditional branching construct. - - Parameters - ---------- - condition : :any:`pymbolic.primitives.Expression` - The condition evaluated before executing the body. - body : tuple - The conditional's body. - else_body : tuple - The body of the else branch. Can be empty. - inline : bool, optional - Flag that marks this conditional as inline, i.e., it s body consists - only of a single statement that appeared immediately after the - ``IF`` statement and it does not have an ``else_body``. - has_elseif : bool, optional - Flag that indicates that this conditional has an ``ELSE IF`` branch - in the original source. In Loki's IR these are represented as a chain - of :any:`Conditional` but for string reproducibility this flag can be - provided to enable backends to reproduce the original appearance. - name : str, optional - The Fortran construct name for that conditional. - **kwargs : optional - Other parameters that are passed on to the parent class constructor. - """ - - _traversable = ['condition', 'body', 'else_body'] - - @field_validator('body', 'else_body', mode='before') - @classmethod - def ensure_tuple(cls, value): - return _sanitize_tuple(value) - - def __post_init__(self): - super().__post_init__() - assert self.condition is not None - - if self.has_elseif: - assert len(self.else_body) == 1 - assert isinstance(self.else_body[0], Conditional) # pylint: disable=unsubscriptable-object - - def __repr__(self): - if self.name: - return f'Conditional:: {self.name}' - return 'Conditional::' - - @property - def else_bodies(self): - """ - Return all nested node tuples in the ``ELSEIF``/``ELSE`` part - of the conditional chain. - """ - if self.has_elseif: - return (self.else_body[0].body,) + self.else_body[0].else_bodies - return (self.else_body,) if self.else_body else () - - -@dataclass_strict(frozen=True) -class _PragmaRegionBase(): - """ Type definitions for :any:`PragmaRegion` node type. """ - - body: Tuple[Node, ...] - pragma: Node = None - pragma_post: Node = None - - -@dataclass_strict(frozen=True) -class PragmaRegion(InternalNode, _PragmaRegionBase): - """ - Internal representation of a block of code defined by two matching pragmas. - - Generally, the pair of pragmas are assumed to be of the form - ``!$ `` and ``!$ end ``. - - This node type is injected into the IR within a context created by - :py:func:`pragma_regions_attached`. - - Parameters - ---------- - body : tuple - The statements appearing between opening and closing pragma. - pragma : :any:`Pragma` - The opening pragma declaring that region. - pragma_post : :any:`Pragma` - The closing pragma for that region. - **kwargs : optional - Other parameters that are passed on to the parent class constructor. - """ - - _traversable = ['body'] - - def append(self, node): - self._update(body=self.body + as_tuple(node)) - - def insert(self, pos, node): - '''Insert at given position''' - self._update(body=self.body[:pos] + as_tuple(node) + self.body[pos:]) # pylint: disable=unsubscriptable-object - - def prepend(self, node): - self._update(body=as_tuple(node) + self.body) - - def __repr__(self): - return 'PragmaRegion::' - - -@dataclass_strict(frozen=True) -class _InterfaceBase(): - """ Type definitions for :any:`Interface` node type. """ - - body: Tuple[Any, ...] - abstract: bool = False - spec: Optional[Union[Expression, str]] = None - - -@dataclass_strict(frozen=True) -class Interface(InternalNode, _InterfaceBase): - """ - Internal representation of a Fortran interface block. - - Parameters - ---------- - body : tuple - The body of the interface block, containing function and subroutine - specifications or procedure statements - abstract : bool, optional - Flag to indicate that this is an abstract interface - spec : str, optional - A generic name, operator, assignment, or I/O specification - **kwargs : optional - Other parameters that are passed on to the parent class constructor. - """ - - _traversable = ['body'] - - def __post_init__(self): - super().__post_init__() - assert not (self.abstract and self.spec) - - @property - def symbols(self): - """ - The list of symbol names declared by this interface - """ - symbols = as_tuple(flatten( - getattr(node, 'procedure_symbol', getattr(node, 'symbols', ())) - for node in self.body # pylint: disable=not-an-iterable - )) - if self.spec: - return (self.spec,) + symbols - return symbols - - @property - def symbol_map(self): - """ - Map symbol name to symbol declared by this interface - """ - return CaseInsensitiveDict( - (s.name.lower(), s) for s in self.symbols - ) - - def __contains__(self, name): - return name in self.symbol_map - - def __repr__(self): - symbols = ', '.join(str(var) for var in self.symbols) - if self.abstract: - return f'Abstract Interface:: {symbols}' - if self.spec: - return f'Interface {self.spec}:: {symbols}' - return f'Interface:: {symbols}' - # Leaf node types @dataclass_strict(frozen=True) @@ -991,12 +167,12 @@ class CallStatement(LeafNode, _CallStatementBase): @field_validator('arguments', mode='before') @classmethod def ensure_tuple(cls, value): - return _sanitize_tuple(value) + return sanitize_tuple(value) @field_validator('kwarguments', mode='before') @classmethod def ensure_nested_tuple(cls, value): - return tuple(_sanitize_tuple(pair) for pair in as_tuple(value)) + return tuple(sanitize_tuple(pair) for pair in as_tuple(value)) def __post_init__(self): super().__post_init__() @@ -1789,12 +965,12 @@ class MultiConditional(LeafNode, _MultiConditionalBase): @field_validator('else_body', mode='before') @classmethod def ensure_tuple(cls, value): - return _sanitize_tuple(value) + return sanitize_tuple(value) @field_validator('values', 'bodies', mode='before') @classmethod def ensure_nested_tuple(cls, value): - return tuple(_sanitize_tuple(pair) for pair in as_tuple(value)) + return tuple(sanitize_tuple(pair) for pair in as_tuple(value)) def __post_init__(self): super().__post_init__() @@ -1954,40 +1130,6 @@ def __repr__(self): return f'MaskedStatement:: {str(self.conditions[0])}' -@dataclass(frozen=True) -class _IntrinsicBase(): - """ Type definitions for :any:`Intrinsic` node type. """ - - text: str - - -@dataclass_strict(frozen=True) -class Intrinsic(LeafNode, _IntrinsicBase): - """ - Catch-all generic node for corner-cases. - - This is provided as a fallback for any statements that do not have - an appropriate representation in the IR. These can either be language - features for which support was not yet added, or statements that are not - relevant in Loki's scope of applications. This node retains the text of - the statement in the original source as-is. - - Parameters - ---------- - text : str - The statement as a string. - **kwargs : optional - Other parameters that are passed on to the parent class constructor. - """ - - def __post_init__(self): - super().__post_init__() - assert isinstance(self.text, str) - - def __repr__(self): - return f'Intrinsic:: {truncate_string(self.text)}' - - @dataclass_strict(frozen=True) class _EnumerationBase(): """ Type definitions for :any:`Enumeration` node type. """ diff --git a/loki/ir/nodes/stmt_nodes.py b/loki/ir/nodes/stmt_nodes.py new file mode 100644 index 000000000..f1636229a --- /dev/null +++ b/loki/ir/nodes/stmt_nodes.py @@ -0,0 +1,252 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" Generic and specific statement node type definitions. """ + +from typing import Optional, Union, Tuple + +from pymbolic.primitives import Expression +from pydantic import field_validator, ValidationError + +from loki.ir.nodes.abstract_nodes import LeafNode +from loki.tools import dataclass_strict, truncate_string, sanitize_tuple + + +__all__ = [ + 'GenericStmt', 'ImplicitStmt', 'SaveStmt', 'PublicStmt', + 'PrivateStmt', 'ContainsStmt', 'ReturnStmt', 'CycleStmt', + 'GotoStmt' +] + + +@dataclass_strict(frozen=True) +class _GenericStmtBase(): + """ Type definitions for :any:`GenericStmt` node type. """ + + text: str + + +@dataclass_strict(frozen=True) +class GenericStmt(LeafNode, _GenericStmtBase): + """ + Catch-all generic node for corner-cases. + + This is provided as a fallback for any statements that do not have + an appropriate representation in the IR. These can either be language + features for which support was not yet added, or statements that are not + relevant in Loki's scope of applications. This node retains the text of + the statement in the original source as-is. + + Parameters + ---------- + text : str + The statement as a string. + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = None + + def __repr__(self): + return f'GenericStmt:: {truncate_string(self.text)}' + + +@dataclass_strict(frozen=True) +class ImplicitStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``IMPLICIT`` statement. + + Parameters + ---------- + text : str or :any:`Expression`, optional + Either a tuple of variable specifiers or a string; default: ``NONE`` + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'IMPLICIT' + + text: Optional[Union[str, Tuple[Expression, ...]]] = 'NONE' + + @field_validator('text', mode='before') + @classmethod + def ensure_str_or_tuple(cls, value): + if isinstance(value, str): + return value + return sanitize_tuple(value) + + def __repr__(self): + return f'Implicit:: {truncate_string(self.text)}' + + +@dataclass_strict(frozen=True) +class SaveStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``SAVE`` statement. + + Parameters + ---------- + text : str or :any:`Expression`, optional + Either a tuple of variable specifiers or a string; default: ``NONE`` + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'SAVE' + + text: Optional[Tuple[Expression, ...]] = () + + @field_validator('text', mode='before') + @classmethod + def ensure_tuple(cls, value): + return sanitize_tuple(value) + + def __repr__(self): + return f'Save:: {truncate_string(self.text)}' + + +@dataclass_strict(frozen=True) +class PublicStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``PUBLIC`` specifier. + + Parameters + ---------- + text : str or tuple of :any:`Expression`, optional + Either a tuple of variable specifiers or a string; default: ``NONE`` + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'PUBLIC' + + text: Optional[Tuple[Expression, ...]] = () + + @field_validator('text', mode='before') + @classmethod + def ensure_tuple(cls, value): + return sanitize_tuple(value) + + def __repr__(self): + return f'Public:: {truncate_string(self.text)}' + + +@dataclass_strict(frozen=True) +class PrivateStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``PRIVATE`` specifier. + + Parameters + ---------- + text : str or tuple of :any:`Expression`, optional + Either a tuple of variable specifiers or a string; default: ``NONE`` + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'PRIVATE' + + text: Optional[Tuple[Expression, ...]] = () + + @field_validator('text', mode='before') + @classmethod + def ensure_tuple(cls, value): + return sanitize_tuple(value) + + def __repr__(self): + return f'Private:: {truncate_string(self.text)}' + + +@dataclass_strict(frozen=True) +class ContainsStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``CONTAINS`` specifier. + + Parameters + ---------- + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'CONTAINS' + + text: Optional[None] = None + + def __post_init__(self): + super().__post_init__() + if not self.text is None: + raise ValidationError('[Loki] ContainsStmt takes no constructor arguments') + + def __repr__(self): + return 'Contains::' + + +@dataclass_strict(frozen=True) +class ReturnStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``RETURN`` specifier. + + Parameters + ---------- + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'RETURN' + + text: Optional[None] = None + + def __post_init__(self): + super().__post_init__() + if not self.text is None: + raise ValidationError('[Loki] ReturnStmt takes no constructor arguments') + + def __repr__(self): + return 'Return::' + + +@dataclass_strict(frozen=True) +class CycleStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``CYCLE`` specifier. + + Parameters + ---------- + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'CYCLE' + + text: Optional[None] = None + + def __post_init__(self): + super().__post_init__() + if not self.text is None: + raise ValidationError('[Loki] CycleStmt takes no constructor arguments') + + def __repr__(self): + return 'Cycle::' + + +@dataclass_strict(frozen=True) +class GotoStmt(GenericStmt): + """ + :any:`GenericStmt` node that represents the ``GO TO`` specifier. + + Parameters + ---------- + **kwargs : optional + Other parameters that are passed on to the parent class constructor. + """ + + keyword = 'GO TO' + + text: str + + def __repr__(self): + return f'Goto:: {truncate_string(self.text)}' diff --git a/loki/ir/tests/test_ir_nodes.py b/loki/ir/nodes/tests/test_nodes.py similarity index 91% rename from loki/ir/tests/test_ir_nodes.py rename to loki/ir/nodes/tests/test_nodes.py index 93c0e5a79..311404fe0 100644 --- a/loki/ir/tests/test_ir_nodes.py +++ b/loki/ir/nodes/tests/test_nodes.py @@ -11,6 +11,7 @@ from pymbolic.primitives import Expression from pydantic import ValidationError +from loki.backend import fgen from loki.expression import symbols as sym, parse_expr from loki.function import Function from loki.ir import nodes as ir @@ -413,3 +414,48 @@ def test_multiconditional(scope, a_i, i): expr=i, values=(()), bodies=(()), else_body=((assign3,), assign2) ) assert multicond.else_body == (assign3, assign2) + + +def test_stmt_nodes(scope, n, i): + """ + Test constructors and scoping behaviour of various :any:`GenericStmt` nodes. + """ + + assert fgen(ir.GenericStmt('Hello World')) == 'Hello World' + with pytest.raises(ValidationError): + ir.GenericStmt(n) + + # Potential symbol quantifiers + assert fgen(ir.ImplicitStmt()) == 'IMPLICIT NONE' + assert fgen(ir.ImplicitStmt('NONE')) == 'IMPLICIT NONE' + assert fgen(ir.ImplicitStmt(i)) == 'IMPLICIT i' + assert fgen(ir.ImplicitStmt((n, i))) == 'IMPLICIT n, i' + + assert fgen(ir.SaveStmt()) == 'SAVE' + assert fgen(ir.SaveStmt(i)) == 'SAVE i' + assert fgen(ir.SaveStmt((n, i))) == 'SAVE n, i' + + assert fgen(ir.PublicStmt()) == 'PUBLIC' + assert fgen(ir.PublicStmt(i)) == 'PUBLIC i' + assert fgen(ir.PublicStmt((n, i))) == 'PUBLIC n, i' + + assert fgen(ir.PrivateStmt()) == 'PRIVATE' + assert fgen(ir.PrivateStmt(i)) == 'PRIVATE i' + assert fgen(ir.PrivateStmt((n, i))) == 'PRIVATE n, i' + + # Control flow statements + assert fgen(ir.ContainsStmt()) == 'CONTAINS' + with pytest.raises(ValidationError): + ir.ContainsStmt(n) + + assert fgen(ir.ReturnStmt()) == 'RETURN' + with pytest.raises(ValidationError): + ir.ReturnStmt(n) + + assert fgen(ir.CycleStmt()) == 'CYCLE' + with pytest.raises(ValidationError): + ir.CycleStmt(n) + + assert fgen(ir.GotoStmt('2345')) == 'GO TO 2345' + with pytest.raises(ValidationError): + assert ir.GotoStmt() diff --git a/loki/ir/tests/test_control_flow.py b/loki/ir/tests/test_control_flow.py index 295571acc..598354715 100644 --- a/loki/ir/tests/test_control_flow.py +++ b/loki/ir/tests/test_control_flow.py @@ -8,7 +8,7 @@ import pytest import numpy as np -from loki import Subroutine +from loki import Module, Subroutine from loki.backend import fgen from loki.jit_build import jit_compile, clean_test from loki.frontend import available_frontends, OMNI @@ -282,28 +282,6 @@ def test_multi_body_conditionals(tmp_path, frontend): clean_test(filepath) -@pytest.mark.parametrize('frontend', available_frontends()) -def test_goto_stmt(tmp_path, frontend): - fcode = """ -subroutine goto_stmt(var) - implicit none - integer, intent(out) :: var - var = 3 - go to 1234 - var = 5 - 1234 return - var = 7 -end subroutine goto_stmt -""" - filepath = tmp_path/(f'control_flow_goto_stmt_{frontend}.f90') - routine = Subroutine.from_source(fcode, frontend=frontend) - function = jit_compile(routine, filepath=filepath, objname='goto_stmt') - - result = function() - assert result == 3 - clean_test(filepath) - - @pytest.mark.parametrize('frontend', available_frontends()) def test_select_case(tmp_path, frontend): fcode = """ @@ -384,30 +362,6 @@ def test_select_case_nested(tmp_path, frontend): assert routine.to_fortran().count('! comment') == 7 -@pytest.mark.parametrize('frontend', available_frontends()) -def test_cycle_stmt(tmp_path, frontend): - fcode = """ -subroutine cycle_stmt(var) - implicit none - integer, intent(out) :: var - integer :: i - - var = 0 - do i=1,10 - if (var > 5) cycle - var = var + 1 - end do -end subroutine cycle_stmt -""" - filepath = tmp_path/(f'control_flow_cycle_stmt_{frontend}.f90') - routine = Subroutine.from_source(fcode, frontend=frontend) - function = jit_compile(routine, filepath=filepath, objname='cycle_stmt') - - result = function() - assert result == 6 - clean_test(filepath) - - @pytest.mark.parametrize('frontend', available_frontends()) def test_conditional_bodies(frontend): """Verify that conditional bodies and else-bodies are tuples of :class:`Node`""" @@ -487,13 +441,95 @@ def test_conditional_else_body_return(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) conditionals = FindNodes(ir.Conditional).visit(routine.body) assert len(conditionals) == 2 - assert isinstance(conditionals[0].body[-1], ir.Intrinsic) - assert conditionals[0].body[-1].text.upper() == 'RETURN' + assert isinstance(conditionals[0].body[-1], ir.ReturnStmt) assert conditionals[0].else_body == (conditionals[1],) - assert isinstance(conditionals[1].body[-1], ir.Intrinsic) - assert conditionals[1].body[-1].text.upper() == 'RETURN' - assert isinstance(conditionals[1].else_body[-1], ir.Intrinsic) - assert conditionals[1].else_body[-1].text.upper() == 'RETURN' + assert isinstance(conditionals[1].body[-1], ir.ReturnStmt) + assert isinstance(conditionals[1].else_body[-1], ir.ReturnStmt) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_fortran_statements(tmp_path, frontend): + fcode = """ +module my_mod + implicit none + real, dimension(6, 42) :: array + save +contains + + subroutine test_fortran_stmts(m, n, var) + integer, intent(in) :: m, n + integer, intent(inout) :: var + integer :: i + + var = 3 + if (m > 3) then + go to 1234 + end if + + var = 0 + do i=1, 10 + if (var > 5) cycle + var = var + 1 + end do + + 1234 return + var = 7 + return + end subroutine test_fortran_stmts +end module my_mod +""" + module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + routine = module['test_fortran_stmts'] + + # Module spec statements + m_spec_stmt = FindNodes(ir.GenericStmt).visit(module.spec) + if frontend == OMNI: + assert len(m_spec_stmt) == 1 + assert isinstance(m_spec_stmt[0], ir.ImplicitStmt) and m_spec_stmt[0].text == 'NONE' + else: + assert len(m_spec_stmt) == 2 + assert isinstance(m_spec_stmt[0], ir.ImplicitStmt) and m_spec_stmt[0].text == 'NONE' + assert isinstance(m_spec_stmt[1], ir.SaveStmt) + assert 'IMPLICIT NONE' in module.to_fortran() + + # Module contains statements + m_cont_stmt = FindNodes(ir.GenericStmt).visit(module.contains) + assert len(m_cont_stmt) == 1 + assert isinstance(m_cont_stmt[0], ir.ContainsStmt) + assert 'CONTAINS' in module.to_fortran() + + # Subroutine body statements + r_body_stmt = FindNodes(ir.GenericStmt).visit(routine.body) + assert len(r_body_stmt) == 4 + assert isinstance(r_body_stmt[0], ir.GotoStmt) and r_body_stmt[0].text == '1234' + assert isinstance(r_body_stmt[1], ir.CycleStmt) + assert isinstance(r_body_stmt[2], ir.ReturnStmt) + assert isinstance(r_body_stmt[3], ir.ReturnStmt) + assert 'GO TO 1234' in routine.to_fortran() + assert 'CYCLE' in routine.to_fortran() + assert routine.to_fortran().count('RETURN') == 2 + + +@pytest.mark.parametrize('frontend', available_frontends( + xfail=[(OMNI, 'No support for Cray Pointers')] +)) +def test_cray_pointers(frontend): + fcode = """ +SUBROUTINE SUBROUTINE_WITH_CRAY_POINTER (KLON,KLEV,POOL) + IMPLICIT NONE + INTEGER, INTENT(IN) :: KLON, KLEV + REAL, INTENT(INOUT) :: POOL(:) + REAL, DIMENSION(KLON,KLEV) :: ZQ + POINTER(IP_ZQ, ZQ) + IP_ZQ = LOC(POOL) +END SUBROUTINE + """.strip() + routine = Subroutine.from_source(fcode, frontend=frontend) + stmts = FindNodes(ir.GenericStmt).visit(routine.spec) + assert len(stmts) == 2 + assert isinstance(stmts[0], ir.ImplicitStmt) + assert 'POINTER(IP_ZQ, ZQ)' in stmts[1].text + assert 'POINTER(IP_ZQ, ZQ)' in routine.to_fortran() @pytest.mark.parametrize('frontend', available_frontends( @@ -675,25 +711,3 @@ def test_multi_line_forall_construct(tmp_path, frontend): assert regenerated_code[8].strip() == "c(i, j) = c(i, j + 2) + c(i, j - 2) + c(i + 2, j) + c(i - 2, j)" assert regenerated_code[9].strip() == "d(i, j) = c(i, j)" assert regenerated_code[10].strip() == "END FORALL" - - -@pytest.mark.parametrize('frontend', available_frontends( - xfail=[(OMNI, 'No support for Cray Pointers')] -)) -def test_cray_pointers(frontend): - fcode = """ -SUBROUTINE SUBROUTINE_WITH_CRAY_POINTER (KLON,KLEV,POOL) -IMPLICIT NONE -INTEGER, INTENT(IN) :: KLON, KLEV -REAL, INTENT(INOUT) :: POOL(:) -REAL, DIMENSION(KLON,KLEV) :: ZQ -POINTER(IP_ZQ, ZQ) -IP_ZQ = LOC(POOL) -END SUBROUTINE - """.strip() - routine = Subroutine.from_source(fcode, frontend=frontend) - intrinsics = FindNodes(ir.Intrinsic).visit(routine.spec) - assert len(intrinsics) == 2 - assert 'IMPLICIT NONE' in intrinsics[0].text - assert 'POINTER(IP_ZQ, ZQ)' in intrinsics[1].text - assert 'POINTER(IP_ZQ, ZQ)' in routine.to_fortran() diff --git a/loki/ir/tests/test_ir_graph.py b/loki/ir/tests/test_ir_graph.py index b44bd4a15..62a37e552 100644 --- a/loki/ir/tests/test_ir_graph.py +++ b/loki/ir/tests/test_ir_graph.py @@ -42,12 +42,12 @@ def fixture_testdir(here): "3": "", "4": "", "5": "", - "6": "", - "7": "", - "8": "", - "9": "", - "10": "", - "11": "", + "6": "", + "7": "", + "8": "", + "9": "", + "10": "", + "11": "", }, "connectivity_list": { "0": ["1"], @@ -67,8 +67,8 @@ def fixture_testdir(here): "3": "", "4": "", "5": "x > 0.0", - "6": "", - "7": "", + "6": "", + "7": "", }, "connectivity_list": { "0": ["1"], @@ -85,9 +85,9 @@ def fixture_testdir(here): "0": "", "1": "", "2": "", - "3": "", + "3": "", "4": "", - "5": "", + "5": "", "6": "", "7": "", "8": "", @@ -134,11 +134,11 @@ def fixture_testdir(here): "4": "", "5": "x > 0", "6": "y > 0", - "7": "", - "8": "", + "7": "", + "8": "", "9": "y > 0", - "10": "", - "11": "", + "10": "", + "11": "", }, "connectivity_list": { "0": ["1"], diff --git a/loki/ir/tests/test_visitor.py b/loki/ir/tests/test_visitor.py index 85cfb7ca5..5254a5f69 100644 --- a/loki/ir/tests/test_visitor.py +++ b/loki/ir/tests/test_visitor.py @@ -69,10 +69,10 @@ def test_find_scopes(frontend): """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) - intrinsics = FindNodes(ir.Intrinsic).visit(routine.body) - assert len(intrinsics) == 2 - inner = [i for i in intrinsics if 'Inner' in i.text][0] - outer = [i for i in intrinsics if 'Outer' in i.text][0] + stmts = FindNodes(ir.GenericStmt).visit(routine.body) + assert len(stmts) == 2 + inner = [i for i in stmts if 'Inner' in i.text][0] + outer = [i for i in stmts if 'Outer' in i.text][0] conditionals = FindNodes(ir.Conditional).visit(routine.body) assert len(conditionals) == 2 diff --git a/loki/lint/tests/test_reporter.py b/loki/lint/tests/test_reporter.py index aaf59bb72..ded93d181 100644 --- a/loki/lint/tests/test_reporter.py +++ b/loki/lint/tests/test_reporter.py @@ -16,7 +16,7 @@ except ImportError: HAVE_YAML = False -from loki.ir import Intrinsic +from loki.ir import GenericStmt from loki.lint.linter import lint_files from loki.lint.reporter import ( ProblemReport, RuleReport, FileReport, @@ -55,8 +55,8 @@ def dummy_file_fixture(here): def fixture_dummy_file_report(dummy_file): file_report = FileReport(str(dummy_file)) rule_report = RuleReport(GenericRule) - rule_report.add('Some message', Intrinsic('foobar')) - rule_report.add('Other message', Intrinsic('baz')) + rule_report.add('Some message', GenericStmt('foobar')) + rule_report.add('Other message', GenericStmt('baz')) file_report.add(rule_report) return file_report @@ -79,8 +79,8 @@ class SomeRule(GenericRule): rule_report = RuleReport(SomeRule) assert not rule_report.problem_reports and rule_report.problem_reports is not None - rule_report.add('Some message', Intrinsic('foobar')) - rule_report.add('Other message', Intrinsic('baz')) + rule_report.add('Some message', GenericStmt('foobar')) + rule_report.add('Other message', GenericStmt('baz')) assert len(rule_report.problem_reports) == 2 assert isinstance(rule_report.problem_reports[0], ProblemReport) assert rule_report.problem_reports[0].msg == 'Some message' diff --git a/loki/program_unit.py b/loki/program_unit.py index 45696b145..5a9aad732 100644 --- a/loki/program_unit.py +++ b/loki/program_unit.py @@ -79,10 +79,10 @@ def __initialize__(self, name, docstring=None, spec=None, contains=None, if not isinstance(contains, ir.Section): contains = ir.Section(body=as_tuple(contains)) for node in contains.body: - if isinstance(node, ir.Intrinsic) and 'contains' in node.text.lower(): # pylint: disable=no-member + if isinstance(node, ir.ContainsStmt): break if isinstance(node, ProgramUnit): - contains.prepend(ir.Intrinsic(text='CONTAINS')) + contains.prepend(ir.ContainsStmt()) break # Primary IR components @@ -759,8 +759,7 @@ def spec_parts(self): if not self.spec: return ((),(),()) - intrinsic_nodes = FindNodes(ir.Intrinsic).visit(self.spec) - implicit_nodes = [node for node in intrinsic_nodes if node.text.lstrip().lower().startswith('implicit')] + implicit_nodes = FindNodes(ir.ImplicitStmt).visit(self.spec) if implicit_nodes: # Use 'IMPLICIT' statements as divider diff --git a/loki/tests/test_modules.py b/loki/tests/test_modules.py index 98107291b..26f553dd6 100644 --- a/loki/tests/test_modules.py +++ b/loki/tests/test_modules.py @@ -422,6 +422,7 @@ def test_module_variables_add_remove(frontend, tmp_path): if frontend == OMNI: # OMNI frontend inserts a few peculiarities assert fgen(module.spec).lower() == """ +implicit none integer, parameter :: jprb = selected_real_kind(13, 300) integer :: x integer :: y @@ -1080,8 +1081,7 @@ def test_module_spec_parts(frontend, spec, part_lengths, tmp_path): assert all(isinstance(p, tuple) for p in module.spec_parts) if frontend == OMNI: - # OMNI removes any 'IMPLICIT' statements so the middle part is always empty - part_lengths = (part_lengths[0], 0, part_lengths[2]) + part_lengths = (part_lengths[0], 1, part_lengths[2]) else: # OMNI _conveniently_ puts any use statements _before_ the docstring for # absolutely zero sensible reasons, so it would be purely based on good luck @@ -1195,13 +1195,11 @@ def test_module_contains_auto_insert(frontend, tmp_path): routine1 = routine1.clone(contains=routine2) assert isinstance(routine1.contains, ir.Section) - assert isinstance(routine1.contains.body[0], ir.Intrinsic) - assert routine1.contains.body[0].text == 'CONTAINS' + assert isinstance(routine1.contains.body[0], ir.ContainsStmt) module = module.clone(contains=routine1) assert isinstance(module.contains, ir.Section) - assert isinstance(module.contains.body[0], ir.Intrinsic) - assert module.contains.body[0].text == 'CONTAINS' + assert isinstance(module.contains.body[0], ir.ContainsStmt) @pytest.mark.parametrize('frontend', available_frontends()) diff --git a/loki/tests/test_source_identity.py b/loki/tests/test_source_identity.py index 07a2ae841..977840c89 100644 --- a/loki/tests/test_source_identity.py +++ b/loki/tests/test_source_identity.py @@ -56,7 +56,7 @@ def test_raw_source_loop(tmp_path, frontend): # Check the intrinsics intrinsic_lines = (9, 11) - for node in FindNodes(ir.Intrinsic).visit(routine.body): + for node in FindNodes(ir.GenericStmt).visit(routine.body): # Verify that source string is subset of the relevant line in the original source assert node.source is not None assert node.source.lines in ((l, l) for l in intrinsic_lines) @@ -120,7 +120,7 @@ def test_raw_source_conditional(tmp_path, frontend): # Check the intrinsics intrinsic_lines = (5, 7, 9, 11) - for node in FindNodes(ir.Intrinsic).visit(routine.body): + for node in FindNodes(ir.GenericStmt).visit(routine.body): # Verify that source string is subset of the relevant line in the original source assert node.source is not None assert node.source.lines in ((l, l) for l in intrinsic_lines) @@ -177,7 +177,7 @@ def test_raw_source_multicond(tmp_path, frontend): # Check the intrinsics intrinsic_lines = (6, 8, 10) - for node in FindNodes(ir.Intrinsic).visit(routine.body): + for node in FindNodes(ir.GenericStmt).visit(routine.body): # Verify that source string is subset of the relevant line in the original source assert node.source is not None assert node.source.lines in ((l, l) for l in intrinsic_lines) diff --git a/loki/tests/test_sourcefile.py b/loki/tests/test_sourcefile.py index 8d890ee53..13a91ed68 100644 --- a/loki/tests/test_sourcefile.py +++ b/loki/tests/test_sourcefile.py @@ -11,10 +11,9 @@ import numpy as np from loki import ( - Sourcefile, FindNodes, PreprocessorDirective, Intrinsic, - Assignment, Import, fgen, ProcedureType, ProcedureSymbol, - StatementFunction, Comment, CommentBlock, RawSource, Scalar + Sourcefile, fgen, ProcedureType, ProcedureSymbol, Scalar ) +from loki.ir import nodes as ir, FindNodes from loki.jit_build import jit_compile, clean_test from loki.frontend import available_frontends, OMNI, FP, REGEX @@ -105,7 +104,7 @@ def test_sourcefile_from_source(frontend, tmp_path): assert 'contained_c' not in [routine.name.lower() for routine in source.subroutines] assert 'contained_c' not in [routine.name.lower() for routine in source.all_subroutines] - comments = FindNodes((Comment, CommentBlock)).visit(source.ir) + comments = FindNodes((ir.Comment, ir.CommentBlock)).visit(source.ir) assert len(comments) == 4 assert all(comment.text.strip() in ['! Some comment', '! Other comment'] for comment in comments) @@ -114,7 +113,7 @@ def test_sourcefile_from_source(frontend, tmp_path): def test_sourcefile_pp_macros(here, frontend): filepath = here/'sources/sourcefile_pp_macros.F90' routine = Sourcefile.from_file(filepath, frontend=frontend)['routine_pp_macros'] - directives = FindNodes(PreprocessorDirective).visit(routine.ir) + directives = FindNodes(ir.PreprocessorDirective).visit(routine.ir) assert len(directives) == 8 assert all(node.text.startswith('#') for node in directives) @@ -128,14 +127,14 @@ def test_sourcefile_pp_directives(here, frontend): # Note: these checks are rather loose as we currently do not restore the original version but # simply replace the PP constants by strings - directives = FindNodes(PreprocessorDirective).visit(routine.body) + directives = FindNodes(ir.PreprocessorDirective).visit(routine.body) assert len(directives) == 1 assert directives[0].text == '#define __FILENAME__ __FILE__' - intrinsics = FindNodes(Intrinsic).visit(routine.body) + intrinsics = FindNodes(ir.GenericStmt).visit(routine.body) assert '__FILENAME__' in intrinsics[0].text and '__DATE__' in intrinsics[0].text assert '__FILE__' in intrinsics[1].text and '__VERSION__' in intrinsics[1].text - statements = FindNodes(Assignment).visit(routine.body) + statements = FindNodes(ir.Assignment).visit(routine.body) assert len(statements) == 1 assert fgen(statements[0]) == 'y = 0*5 + 0' @@ -146,7 +145,7 @@ def test_sourcefile_pp_include(here, frontend): sourcefile = Sourcefile.from_file(filepath, frontend=frontend, includes=[here/'include']) routine = sourcefile['routine_pp_include'] - statements = FindNodes(Assignment).visit(routine.body) + statements = FindNodes(ir.Assignment).visit(routine.body) assert len(statements) == 1 if frontend == OMNI: # OMNI resolves that statement function! @@ -156,7 +155,7 @@ def test_sourcefile_pp_include(here, frontend): if frontend is not OMNI: # OMNI resolves the import in the frontend - imports = FindNodes(Import).visit([routine.spec, routine.body]) + imports = FindNodes(ir.Import).visit([routine.spec, routine.body]) assert len(imports) == 1 assert imports[0].c_import assert imports[0].module == 'some_header.h' @@ -171,11 +170,11 @@ def test_sourcefile_cpp_preprocessing(here, frontend): source = Sourcefile.from_file(filepath, preprocess=True, frontend=frontend) routine = source['sourcefile_external_preprocessing'] - directives = FindNodes(PreprocessorDirective).visit(routine.ir) + directives = FindNodes(ir.PreprocessorDirective).visit(routine.ir) if frontend is not OMNI: # OMNI skips the import in the frontend - imports = FindNodes(Import).visit([routine.spec, routine.body]) + imports = FindNodes(ir.Import).visit([routine.spec, routine.body]) assert len(imports) == 1 assert imports[0].c_import assert imports[0].module == 'some_header.h' @@ -187,7 +186,7 @@ def test_sourcefile_cpp_preprocessing(here, frontend): source = Sourcefile.from_file(filepath, preprocess=True, defines='FLAG_SMALL', frontend=frontend) routine = source['sourcefile_external_preprocessing'] - directives = FindNodes(PreprocessorDirective).visit(routine.ir) + directives = FindNodes(ir.PreprocessorDirective).visit(routine.ir) assert len(directives) == 0 assert 'b = 6' in fgen(routine) @@ -209,7 +208,7 @@ def test_sourcefile_cpp_stmt_func(here, frontend, tmp_path): # OMNI inlines statement functions, so we can't check the representation if frontend != OMNI: routine = source['cpp_stmt_func'] - stmt_func_decls = FindNodes(StatementFunction).visit(routine.spec) + stmt_func_decls = FindNodes(ir.StatementFunction).visit(routine.spec) assert len(stmt_func_decls) == 4 for decl in stmt_func_decls: @@ -286,8 +285,8 @@ def test_sourcefile_lazy_construction(frontend, tmp_path): # Make sure we have an incomplete parse tree until now assert source._incomplete - assert len(FindNodes(RawSource).visit(source.ir)) == 5 - assert len(FindNodes(RawSource).visit(source['routine_a'].ir)) == 1 + assert len(FindNodes(ir.RawSource).visit(source.ir)) == 5 + assert len(FindNodes(ir.RawSource).visit(source['routine_a'].ir)) == 1 # Trigger the full parse try: @@ -299,19 +298,19 @@ def test_sourcefile_lazy_construction(frontend, tmp_path): assert not source._incomplete # Make sure no RawSource nodes are left - assert not FindNodes(RawSource).visit(source.ir) + assert not FindNodes(ir.RawSource).visit(source.ir) if frontend == FP: # Some newlines are also treated as comments - assert len(FindNodes(Comment).visit(source.ir)) == 2 + assert len(FindNodes(ir.Comment).visit(source.ir)) == 2 else: - assert len(FindNodes(Comment).visit(source.ir)) == 1 + assert len(FindNodes(ir.Comment).visit(source.ir)) == 1 if frontend == OMNI: - assert not FindNodes(PreprocessorDirective).visit(source.ir) + assert not FindNodes(ir.PreprocessorDirective).visit(source.ir) else: - assert len(FindNodes(PreprocessorDirective).visit(source.ir)) == 2 + assert len(FindNodes(ir.PreprocessorDirective).visit(source.ir)) == 2 for routine in source.all_subroutines: - assert not FindNodes(RawSource).visit(routine.ir) - assert len(FindNodes(Assignment).visit(routine.ir)) == 1 + assert not FindNodes(ir.RawSource).visit(routine.ir) + assert len(FindNodes(ir.Assignment).visit(routine.ir)) == 1 # The previously generated ProgramUnit objects should be the same as before assert routine_b is source['routine_b'] @@ -336,20 +335,20 @@ def test_sourcefile_lazy_comments(frontend): """.strip() source = Sourcefile.from_source(fcode, frontend=REGEX) - assert isinstance(source.ir.body[0], RawSource) - assert isinstance(source.ir.body[2], RawSource) + assert isinstance(source.ir.body[0], ir.RawSource) + assert isinstance(source.ir.body[2], ir.RawSource) myroutine = source['myroutine'] - assert isinstance(myroutine.spec.body[0], RawSource) + assert isinstance(myroutine.spec.body[0], ir.RawSource) source.make_complete(frontend=frontend) - assert isinstance(source.ir.body[0], Comment) - assert isinstance(source.ir.body[2], Comment) + assert isinstance(source.ir.body[0], ir.Comment) + assert isinstance(source.ir.body[2], ir.Comment) if frontend == OMNI: - assert isinstance(myroutine.body.body[0], Comment) + assert isinstance(myroutine.body.body[0], ir.Comment) else: - assert isinstance(myroutine.docstring[0], Comment) + assert isinstance(myroutine.docstring[0], ir.Comment) code = source.to_fortran() assert '! Comment outside' in code @@ -409,7 +408,7 @@ def test_sourcefile_clone(frontend, tmp_path): assert 'new_mod_routine' in new_new_source['my_mod'] if not source._incomplete: - assert isinstance(source.ir.body[0], Comment) + assert isinstance(source.ir.body[0], ir.Comment) comment_text = source.ir.body[0].text new_comment_text = comment_text + ' some more text' source.ir.body[0]._update(text=new_comment_text) diff --git a/loki/tests/test_subroutine.py b/loki/tests/test_subroutine.py index 1ae20262a..c6afe7e0f 100644 --- a/loki/tests/test_subroutine.py +++ b/loki/tests/test_subroutine.py @@ -84,7 +84,7 @@ def test_routine_simple(frontend): assert isinstance(routine.body, ir.Section) if frontend == OMNI: assert len(routine.spec.body) == 9 - assert isinstance(routine.spec.body[0], ir.Intrinsic) + assert isinstance(routine.spec.body[0], ir.GenericStmt) assert isinstance(routine.spec.body[1], ir.Pragma) assert all(isinstance(n, ir.VariableDeclaration) for n in routine.spec.body[2:]) assert routine.spec.body[2].symbols == ('jprb',) diff --git a/loki/tools/__init__.py b/loki/tools/__init__.py index 001cb5c2d..975dc7ce5 100644 --- a/loki/tools/__init__.py +++ b/loki/tools/__init__.py @@ -8,6 +8,7 @@ Collection of tools and utility methods used throughout Loki. """ -from loki.tools.util import * # noqa +from loki.tools.dataclass import * # noqa from loki.tools.files import * # noqa from loki.tools.strings import * # noqa +from loki.tools.util import * # noqa diff --git a/loki/tools/dataclass.py b/loki/tools/dataclass.py new file mode 100644 index 000000000..115ca4033 --- /dev/null +++ b/loki/tools/dataclass.py @@ -0,0 +1,22 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from functools import partial + +from pydantic.dataclasses import dataclass as dataclass_validated + + +__all__ = ['dataclass_strict'] + + +# Configuration for validation mechanism via pydantic +dataclass_validation_config = { + 'arbitrary_types_allowed': True, +} + +# Using this decorator, we can force strict validation +dataclass_strict = partial(dataclass_validated, config=dataclass_validation_config) diff --git a/loki/tools/util.py b/loki/tools/util.py index 40c24a797..7b648e70c 100644 --- a/loki/tools/util.py +++ b/loki/tools/util.py @@ -33,15 +33,14 @@ __all__ = [ 'as_tuple', 'is_iterable', 'is_subset', 'flatten', 'chunks', - 'execute', 'CaseInsensitiveDict', 'CaseInsensitiveDefaultDict', - 'strip_inline_comments', + 'sanitize_tuple', 'execute', 'CaseInsensitiveDict', + 'CaseInsensitiveDefaultDict', 'strip_inline_comments', 'binary_insertion_sort', 'cached_func', 'optional', 'LazyNodeLookup', 'yaml_include_constructor', 'auto_post_mortem_debugger', 'set_excepthook', 'timeout', 'WeakrefProperty', 'group_by_class', 'replace_windowed', 'dict_override', 'stdchannel_redirected', - 'stdchannel_is_captured', 'graphviz_present', - 'OrderedSet' + 'stdchannel_is_captured', 'graphviz_present', 'OrderedSet' ] @@ -184,6 +183,13 @@ def chunks(l, n): yield l[i:i + n] +def sanitize_tuple(t): + """ + Small helper method to ensure non-nested tuples without ``None``. + """ + return tuple(n for n in flatten(as_tuple(t)) if n is not None) + + def execute(command, silent=True, **kwargs): """ Execute a single command within a given directory or environment diff --git a/loki/transformations/build_system/tests/test_dependency.py b/loki/transformations/build_system/tests/test_dependency.py index 5440f9b51..134d19a18 100644 --- a/loki/transformations/build_system/tests/test_dependency.py +++ b/loki/transformations/build_system/tests/test_dependency.py @@ -11,9 +11,7 @@ from loki import Sourcefile from loki.batch import Scheduler, SchedulerConfig from loki.frontend import available_frontends, OMNI -from loki.ir import ( - FindNodes, CallStatement, Import, Interface, Intrinsic, FindInlineCalls -) +from loki.ir import nodes as ir, FindNodes, FindInlineCalls from loki.transformations import ( DependencyTransformation, ModuleWrapTransformation @@ -107,17 +105,17 @@ def test_dependency_transformation_globalvar_imports(frontend, use_scheduler, tm assert kernel.modules[0].variables[0].name == 'some_const' # Check that calls and matching import have been diverted to the re-generated routine - calls = FindNodes(CallStatement).visit(driver['driver'].body) + calls = FindNodes(ir.CallStatement).visit(driver['driver'].body) assert len(calls) == 1 assert calls[0].name == 'kernel_test' - imports = FindNodes(Import).visit(driver['driver'].spec) + imports = FindNodes(ir.Import).visit(driver['driver'].spec) assert len(imports) == 2 - assert isinstance(imports[0], Import) + assert isinstance(imports[0], ir.Import) assert driver['driver'].spec.body[0].module == 'kernel_test_mod' assert 'kernel_test' in [str(s) for s in driver['driver'].spec.body[0].symbols] # Check that global variable import remains unchanged - assert isinstance(imports[1], Import) + assert isinstance(imports[1], ir.Import) assert driver['driver'].spec.body[1].module == 'kernel_mod' assert 'some_const' in [str(s) for s in driver['driver'].spec.body[1].symbols] @@ -229,17 +227,17 @@ def test_dependency_transformation_access_spec_names(frontend, use_scheduler, tm assert kernel.modules[0].typedefs[1].name == 't_type_2' # Check that calls and matching import have been diverted to the re-generated routine - calls = FindNodes(CallStatement).visit(driver['driver'].body) + calls = FindNodes(ir.CallStatement).visit(driver['driver'].body) assert len(calls) == 1 assert calls[0].name == 'kernel_test' - imports = FindNodes(Import).visit(driver['driver'].spec) + imports = FindNodes(ir.Import).visit(driver['driver'].spec) assert len(imports) == 3 - assert isinstance(imports[0], Import) + assert isinstance(imports[0], ir.Import) assert driver['driver'].spec.body[0].module == 'kernel_access_spec_test_mod' assert 'kernel_test' in [str(s) for s in driver['driver'].spec.body[0].symbols] # Check that global variable import remains unchanged - assert isinstance(imports[1], Import) + assert isinstance(imports[1], ir.Import) assert driver['driver'].spec.body[1].module == 'kernel_access_spec_mod' assert 'another_const' in [str(s) for s in driver['driver'].spec.body[1].symbols] assert 'some_const' in [str(s) for s in driver['driver'].spec.body[2].symbols] @@ -311,17 +309,17 @@ def test_dependency_transformation_globalvar_imports_driver_mod(frontend, use_sc assert kernel.modules[0].variables[0].name == 'some_const' # Check that calls and matching import have been diverted to the re-generated routine - calls = FindNodes(CallStatement).visit(driver['driver'].body) + calls = FindNodes(ir.CallStatement).visit(driver['driver'].body) assert len(calls) == 1 assert calls[0].name == 'kernel_test' - imports = FindNodes(Import).visit(driver['driver_mod'].spec) + imports = FindNodes(ir.Import).visit(driver['driver_mod'].spec) assert len(imports) == 2 - assert isinstance(imports[0], Import) + assert isinstance(imports[0], ir.Import) assert driver['driver_mod'].spec.body[0].module == 'kernel_test_mod' assert 'kernel_test' in [str(s) for s in driver['driver_mod'].spec.body[0].symbols] # Check that global variable import remains unchanged - assert isinstance(imports[1], Import) + assert isinstance(imports[1], ir.Import) assert driver['driver_mod'].spec.body[1].module == 'kernel_mod' assert 'some_const' in [str(s) for s in driver['driver_mod'].spec.body[1].symbols] @@ -499,9 +497,9 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, replace_ assert driver.subroutines[0].name == 'driver' # Check that calls and imports have been diverted to the re-generated routine - calls = FindNodes(CallStatement).visit(driver['driver'].body) + calls = FindNodes(ir.CallStatement).visit(driver['driver'].body) assert len(calls) == 2 - imports = FindNodes(Import).visit(driver['driver'].ir) + imports = FindNodes(ir.Import).visit(driver['driver'].ir) assert len(imports) == 3 _imported_symbols = driver['driver'].imported_symbols @@ -606,26 +604,26 @@ def test_dependency_transformation_replace_interface(frontend, use_scheduler, mo assert driver.subroutines[0].name == 'driver' # Check that calls have been diverted to the re-generated routine - calls = FindNodes(CallStatement).visit(driver['driver'].body) + calls = FindNodes(ir.CallStatement).visit(driver['driver'].body) assert len(calls) == 1 assert calls[0].name == 'kernel_test' if module_wrap: # Check that imports have been generated - imports = FindNodes(Import).visit(driver['driver'].spec) + imports = FindNodes(ir.Import).visit(driver['driver'].spec) assert len(imports) == 1 assert imports[0].module.lower() == 'kernel_test_mod' assert 'kernel_test' in imports[0].symbols # Check that the newly generated USE statement appears before IMPLICIT NONE - nodes = FindNodes((Intrinsic, Import)).visit(driver['driver'].spec) + nodes = FindNodes((ir.ImplicitStmt, ir.Import)).visit(driver['driver'].spec) assert len(nodes) == 2 - assert isinstance(nodes[1], Intrinsic) - assert nodes[1].text.lower() == 'implicit none' + assert isinstance(nodes[1], ir.ImplicitStmt) + assert nodes[1].text.lower() == 'none' else: # Check that the interface has been updated - intfs = FindNodes(Interface).visit(driver['driver'].spec) + intfs = FindNodes(ir.Interface).visit(driver['driver'].spec) assert len(intfs) == 1 assert intfs[0].symbols == ('kernel_test',) @@ -697,7 +695,7 @@ def test_dependency_transformation_inline_call(frontend): calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body)) assert len(calls) == 3 assert calls[0].name == 'kernel_test' - imports = FindNodes(Import).visit(driver['driver'].spec) + imports = FindNodes(ir.Import).visit(driver['driver'].spec) assert len(imports) == 1 assert imports[0].module == 'kernel_test_mod' assert 'kernel_test' in [str(s) for s in imports[0].symbols] @@ -768,7 +766,7 @@ def test_dependency_transformation_inline_call_result_var(frontend): calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body)) assert len(calls) == 3 assert calls[0].name == 'kernel_test' - imports = FindNodes(Import).visit(driver['driver'].spec) + imports = FindNodes(ir.Import).visit(driver['driver'].spec) assert len(imports) == 1 assert imports[0].module == 'kernel_test_mod' assert 'kernel_test' in [str(s) for s in imports[0].symbols] @@ -839,10 +837,10 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tmp driver['driver'].apply(transformation, role='driver', targets=('kernel', 'kernel_mod')) # Check that calls and matching import have been diverted to the re-generated routine - calls = FindNodes(CallStatement).visit(driver['driver'].body) + calls = FindNodes(ir.CallStatement).visit(driver['driver'].body) assert len(calls) == 1 assert calls[0].name == 'kernel_test' - imports = FindNodes(Import).visit(driver['driver'].spec) + imports = FindNodes(ir.Import).visit(driver['driver'].spec) assert len(imports) == 1 assert imports[0].module.lower() == 'kernel_test_mod' assert imports[0].symbols == ('kernel_test',) @@ -856,7 +854,7 @@ def test_dependency_transformation_contained_member(frontend, use_scheduler, tmp assert kernel['kernel_test'].subroutines[1].name.lower() == 'get_b' # Check if kernel calls have been renamed - calls = FindNodes(CallStatement).visit(kernel['kernel_test'].body) + calls = FindNodes(ir.CallStatement).visit(kernel['kernel_test'].body) assert len(calls) == 1 assert calls[0].name == 'set_a' @@ -956,7 +954,7 @@ def test_dependency_transformation_item_filter(frontend, tmp_path, config): calls = tuple(FindInlineCalls(unique=False).visit(driver['driver'].body)) assert len(calls) == 3 assert all(call.name == 'kernel_test' for call in calls) - imports = FindNodes(Import).visit(driver['driver'].spec) + imports = FindNodes(ir.Import).visit(driver['driver'].spec) imports = driver['driver'].import_map assert len(imports) == 2 assert 'header_var' in imports and imports['header_var'].module.lower() == 'header_mod' diff --git a/loki/transformations/extract/tests/test_outline.py b/loki/transformations/extract/tests/test_outline.py index a74d497c3..64aed9ba3 100644 --- a/loki/transformations/extract/tests/test_outline.py +++ b/loki/transformations/extract/tests/test_outline.py @@ -11,7 +11,7 @@ from loki import Module, Subroutine from loki.jit_build import jit_compile, jit_compile_lib, Builder, Obj from loki.frontend import available_frontends -from loki.ir import FindNodes, Section, Assignment, CallStatement, Intrinsic +from loki.ir import nodes as ir, FindNodes from loki.tools import as_tuple from loki.types import BasicType @@ -51,19 +51,19 @@ def test_outline_pragma_regions(tmp_path, frontend): a, b, c = function() assert a == 1 and b == 1 and c == 2 - assert len(FindNodes(Assignment).visit(routine.body)) == 4 - assert len(FindNodes(CallStatement).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 4 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 0 # Apply transformation routines = outline_pragma_regions(routine) assert len(routines) == 1 and routines[0].name == f'{routine.name}_outlined_0' - assert len(FindNodes(Assignment).visit(routine.body)) == 3 - assert len(FindNodes(Assignment).visit(routines[0].body)) == 1 - assert len(FindNodes(CallStatement).visit(routine.body)) == 1 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3 + assert len(FindNodes(ir.Assignment).visit(routines[0].body)) == 1 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 # Test transformation - contains = Section(body=as_tuple([Intrinsic('CONTAINS'), *routines, routine])) + contains = ir.Section(body=as_tuple([ir.ContainsStmt(), *routines, routine])) module = Module(name=f'{routine.name}_mod', spec=None, contains=contains) mod_filepath = tmp_path/(f'{module.name}_converted_{frontend}.f90') mod = jit_compile(module, filepath=mod_filepath, objname=module.name) @@ -107,8 +107,8 @@ def test_outline_pragma_regions_multiple(tmp_path, frontend): a, b, c = function() assert a == 5 and b == 5 and c == 10 - assert len(FindNodes(Assignment).visit(routine.body)) == 7 - assert len(FindNodes(CallStatement).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 7 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 0 # Apply transformation routines = outline_pragma_regions(routine) @@ -116,12 +116,12 @@ def test_outline_pragma_regions_multiple(tmp_path, frontend): assert routines[0].name == 'oiwjfklsf' assert all(routines[i].name == f'{routine.name}_outlined_{i}' for i in (1,2)) - assert len(FindNodes(Assignment).visit(routine.body)) == 4 - assert all(len(FindNodes(Assignment).visit(r.body)) == 1 for r in routines) - assert len(FindNodes(CallStatement).visit(routine.body)) == 3 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 4 + assert all(len(FindNodes(ir.Assignment).visit(r.body)) == 1 for r in routines) + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 3 # Test transformation - contains = Section(body=as_tuple([Intrinsic('CONTAINS'), *routines, routine])) + contains = ir.Section(body=as_tuple([ir.ContainsStmt(), *routines, routine])) module = Module(name=f'{routine.name}_mod', spec=None, contains=contains) mod_filepath = tmp_path/(f'{module.name}_converted_{frontend}.f90') mod = jit_compile(module, filepath=mod_filepath, objname=module.name) @@ -167,8 +167,8 @@ def test_outline_pragma_regions_arguments(tmp_path, frontend): a, b, c = function() assert a == 5 and b == 5 and c == 10 - assert len(FindNodes(Assignment).visit(routine.body)) == 7 - assert len(FindNodes(CallStatement).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 7 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 0 # Apply transformation routines = outline_pragma_regions(routine) @@ -187,12 +187,12 @@ def test_outline_pragma_regions_arguments(tmp_path, frontend): assert routines[2].variable_map['b'].type.intent == 'inout' assert routines[2].variable_map['c'].type.intent == 'out' - assert len(FindNodes(Assignment).visit(routine.body)) == 4 - assert all(len(FindNodes(Assignment).visit(r.body)) == 1 for r in routines) - assert len(FindNodes(CallStatement).visit(routine.body)) == 3 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 4 + assert all(len(FindNodes(ir.Assignment).visit(r.body)) == 1 for r in routines) + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 3 # Test transformation - contains = Section(body=as_tuple([Intrinsic('CONTAINS'), *routines, routine])) + contains = ir.Section(body=as_tuple([ir.ContainsStmt(), *routines, routine])) module = Module(name=f'{routine.name}_mod', spec=None, contains=contains) mod_filepath = tmp_path/(f'{module.name}_converted_{frontend}.f90') mod = jit_compile(module, filepath=mod_filepath, objname=module.name) @@ -246,14 +246,14 @@ def test_outline_pragma_regions_arrays(tmp_path, frontend): assert np.all(a == range(1,n+1)) assert np.all(b == [1] * n) - assert len(FindNodes(Assignment).visit(routine.body)) == 4 - assert len(FindNodes(CallStatement).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 4 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 0 # Apply transformation routines = outline_pragma_regions(routine) - assert len(FindNodes(Assignment).visit(routine.body)) == 0 - assert len(FindNodes(CallStatement).visit(routine.body)) == 3 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 0 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 3 assert len(routines) == 3 @@ -263,7 +263,7 @@ def test_outline_pragma_regions_arrays(tmp_path, frontend): assert routines[0].variable_map['a'].dimensions[0].scope is routines[0] # Test transformation - contains = Section(body=as_tuple([Intrinsic('CONTAINS'), *routines, routine])) + contains = ir.Section(body=as_tuple([ir.ContainsStmt(), *routines, routine])) module = Module(name=f'{routine.name}_mod', spec=None, contains=contains) mod_filepath = tmp_path/(f'{module.name}_converted_{frontend}.f90') mod = jit_compile(module, filepath=mod_filepath, objname=module.name) @@ -335,14 +335,14 @@ def test_outline_pragma_regions_imports(tmp_path, builder, frontend): assert np.all(b == range(1,11)) (tmp_path/f'{module.name}.f90').unlink() - assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 4 - assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 0 + assert len(FindNodes(ir.Assignment).visit(module.subroutines[0].body)) == 4 + assert len(FindNodes(ir.CallStatement).visit(module.subroutines[0].body)) == 0 # Apply transformation routines = outline_pragma_regions(module.subroutines[0]) - assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 1 - assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 3 + assert len(FindNodes(ir.Assignment).visit(module.subroutines[0].body)) == 1 + assert len(FindNodes(ir.CallStatement).visit(module.subroutines[0].body)) == 3 assert len(routines) == 3 @@ -413,14 +413,14 @@ def test_outline_pragma_regions_derived_args(tmp_path, builder, frontend): assert np.all(b == 42) (tmp_path/f'{module.name}.f90').unlink() - assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 6 - assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 0 + assert len(FindNodes(ir.Assignment).visit(module.subroutines[0].body)) == 6 + assert len(FindNodes(ir.CallStatement).visit(module.subroutines[0].body)) == 0 # Apply transformation routines = outline_pragma_regions(module.subroutines[0]) - assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 4 - assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 1 + assert len(FindNodes(ir.Assignment).visit(module.subroutines[0].body)) == 4 + assert len(FindNodes(ir.CallStatement).visit(module.subroutines[0].body)) == 1 # Check for a single derived-type argument assert len(routines) == 1 @@ -496,14 +496,14 @@ def test_outline_pragma_regions_associates(tmp_path, builder, frontend): assert np.all(b == 42) (tmp_path/f'{module.name}.f90').unlink() - assert len(FindNodes(Assignment).visit(routine.body)) == 6 - assert len(FindNodes(CallStatement).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 6 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 0 # Apply transformation outlined = outline_pragma_regions(routine) - assert len(FindNodes(Assignment).visit(routine.body)) == 4 - calls = FindNodes(CallStatement).visit(routine.body) + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 4 + calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments == ('c', 'd') diff --git a/loki/transformations/parametrise.py b/loki/transformations/parametrise.py index 0f247caa8..cc93f07ca 100644 --- a/loki/transformations/parametrise.py +++ b/loki/transformations/parametrise.py @@ -175,7 +175,7 @@ class ParametriseTransformation(Transformation): def error_stop(**kwargs): msg = kwargs.get("msg") - return ir.Intrinsic(text=f'error stop "{msg}"'), + return ir.GenericStmt(text=f'error stop "{msg}"'), dic2p = {'a': 12, 'b': 11} @@ -286,8 +286,8 @@ def transform_subroutine(self, routine, **kwargs): parametrised_var = routine.variable_map[f'parametrised_{key}'] # use default abort mechanism if self.abort_callback is None: - abort = (ir.Intrinsic(text=f'PRINT *, "{error_msg}: ", {parametrised_var.name}'), - ir.Intrinsic(text="STOP 1")) + abort = (ir.GenericStmt(text=f'PRINT *, "{error_msg}: ", {parametrised_var.name}'), + ir.GenericStmt(text="STOP 1")) # use user define abort/warn mechanism else: kwargs = {"msg": error_msg, "routine": routine.name, "var": parametrised_var, diff --git a/loki/transformations/remove_code.py b/loki/transformations/remove_code.py index f0f147702..d813bd411 100644 --- a/loki/transformations/remove_code.py +++ b/loki/transformations/remove_code.py @@ -72,7 +72,7 @@ class RemoveCodeTransformation(Transformation): List of subroutine names against which to match :any:`CallStatement` nodes during :meth:`remove_calls`. intrinsic_names : list of str - List of module names against which to match :any:`Intrinsic` + List of module names against which to match :any:`GenericStmt` nodes during :meth:`remove_calls`. remove_imports : boolean Flag indicating whether to remove symbols from :any:`Import` @@ -489,7 +489,7 @@ def do_remove_calls( List of subroutine names against which to match :any:`CallStatement` nodes. intrinsic_names : list of str - List of module names against which to match :any:`Intrinsic` + List of module names against which to match :any:`GenericStmt` nodes. remove_imports : boolean Flag indicating whether to remove the respective procedure @@ -514,7 +514,7 @@ class RemoveCallsTransformer(Transformer): (flag) call named_procedure()``. This :any:`Transformer` will also attempt to match and remove - :any:`Intrinsic` nodes against a given list of name strings. This + :any:`GenericStmt` nodes against a given list of name strings. This allows removing intrinsic calls like ``write (*,*) "..."``. In addition, this :any:`Transformer` can also attempt to match and @@ -528,7 +528,7 @@ class RemoveCallsTransformer(Transformer): List of subroutine names against which to match :any:`CallStatement` nodes. intrinsic_names : list of str - List of module names against which to match :any:`Intrinsic` + List of module names against which to match :any:`GenericStmt` nodes. remove_imports : boolean Flag indicating whether to remove the respective procedure @@ -565,8 +565,8 @@ def visit_Conditional(self, o, **kwargs): return self._rebuild(o, (cond, body, else_body)) - def visit_Intrinsic(self, o, **kwargs): - """ Match and remove :any:`Intrinsic` nodes against name patterns """ + def visit_GenericStmt(self, o, **kwargs): + """ Match and remove :any:`GenericStmt` nodes against name patterns """ if self.intrinsic_names: if any(str(c).lower() in o.text.lower() for c in self.intrinsic_names): return None diff --git a/loki/transformations/temporaries/pool_allocator.py b/loki/transformations/temporaries/pool_allocator.py index 2604516ba..30d9568f5 100644 --- a/loki/transformations/temporaries/pool_allocator.py +++ b/loki/transformations/temporaries/pool_allocator.py @@ -16,7 +16,7 @@ DetachScopesMapper ) from loki.ir import ( - FindNodes, FindVariables, FindInlineCalls, Transformer, Intrinsic, + FindNodes, FindVariables, FindInlineCalls, Transformer, GenericStmt, Assignment, Conditional, CallStatement, Import, Allocation, Deallocation, Loop, Pragma, Interface, get_pragma_parameters, SubstituteExpressions @@ -660,7 +660,7 @@ def _create_stack_allocation(self, stack_ptr, stack_end, ptr_var, arr, stack_siz if self.check_bounds: stack_size_check = Conditional( condition=Comparison(stack_ptr, '>', stack_end), inline=True, - body=(Intrinsic('STOP'),), else_body=None + body=(GenericStmt('STOP'),), else_body=None ) return ([ptr_assignment, ptr_increment, stack_size_check], stack_size) return ([ptr_assignment, ptr_increment], stack_size) @@ -752,7 +752,7 @@ def apply_pool_allocator_to_temporaries(self, routine, item=None): stack_end = self._get_stack_end(routine) for arr in temporary_arrays: ptr_var = Variable(name=self.local_ptr_var_name_pattern.format(name=arr.name), scope=routine) - declarations += [Intrinsic(f'POINTER({ptr_var.name}, {arr.name})')] # pylint: disable=no-member + declarations += [GenericStmt(f'POINTER({ptr_var.name}, {arr.name})')] # pylint: disable=no-member allocation, stack_size = self._create_stack_allocation(stack_ptr, stack_end, ptr_var, arr, stack_size, stack_storage) allocations += allocation diff --git a/loki/transformations/temporaries/tests/test_pool_allocator.py b/loki/transformations/temporaries/tests/test_pool_allocator.py index 1cc71fc27..ebaba6722 100644 --- a/loki/transformations/temporaries/tests/test_pool_allocator.py +++ b/loki/transformations/temporaries/tests/test_pool_allocator.py @@ -16,9 +16,8 @@ ) from loki.frontend import available_frontends, OMNI, FP from loki.ir import ( - FindNodes, CallStatement, Assignment, Allocation, Deallocation, - Loop, Pragma, get_pragma_parameters, FindVariables, FindInlineCalls, - Intrinsic + nodes as ir, FindNodes, FindVariables, FindInlineCalls, + get_pragma_parameters ) from loki.transformations.pragma_model import PragmaModelTransformation @@ -58,21 +57,21 @@ def check_stack_created_in_driver( assert 'ylstack_l' in driver.variables # Is there an allocation and deallocation for the stack storage? - allocations = FindNodes(Allocation).visit(driver.body) + allocations = FindNodes(ir.Allocation).visit(driver.body) assert len(allocations) == 1 and 'zstack(istsz,nb)' in allocations[0].variables - deallocations = FindNodes(Deallocation).visit(driver.body) + deallocations = FindNodes(ir.Deallocation).visit(driver.body) assert len(deallocations) == 1 and 'zstack' in deallocations[0].variables # # Check the stack size - assignments = FindNodes(Assignment).visit(driver.body) + assignments = FindNodes(ir.Assignment).visit(driver.body) for assignment in assignments: if assignment.lhs == 'istsz': assert simplify(assignment.rhs) == simplify(stack_size) # # Check for stack assignment inside loop - loops = FindNodes(Loop).visit(driver.body) + loops = FindNodes(ir.Loop).visit(driver.body) assert len(loops) == num_block_loops - assignments = FindNodes(Assignment).visit(loops[0].body) + assignments = FindNodes(ir.Assignment).visit(loops[0].body) assert assignments[0].lhs == 'ylstack_l' if cray_ptr_loc_rhs: assert assignments[0].rhs == '1' @@ -272,7 +271,7 @@ def test_pool_allocator_temporaries(tmp_path, frontend, generate_driver_stack, b driver = scheduler['#driver'].ir check_c_sizeof_import(driver) check_real64_import(driver) - calls = FindNodes(CallStatement).visit(driver.body) + calls = FindNodes(ir.CallStatement).visit(driver.body) assert len(calls) == 1 if trafo == TemporariesPoolAllocatorTransformation else 2 if nclv_param: expected_args = ('1', 'nlon', 'nlon', 'nz', 'field1(:,b)', 'field2(:,:,b)') @@ -336,7 +335,7 @@ def test_pool_allocator_temporaries(tmp_path, frontend, generate_driver_stack, b else: tmp_indices = (1, 2, 3, 5) assign_idx = {} - for idx, assign in enumerate(FindNodes(Assignment).visit(kernel.body)): + for idx, assign in enumerate(FindNodes(ir.Assignment).visit(kernel.body)): if assign.lhs == 'ylstack_l' and assign.rhs == 'ydstack_l': # Local copy of stack status assign_idx['stack_assign'] = idx @@ -478,9 +477,11 @@ def test_pool_allocator_unused_temporaries(tmp_path, frontend, horizontal, block # check that the correct variables end up on the stack #  look for 'POINTER(IP_tmp<...>, tmp<...>)' Intrinsics - pointers = [intrinsic.text.split(',')[1].replace(')', '').replace(' ', '') for intrinsic - in FindNodes(Intrinsic).visit(kernel_item.ir.spec) - if 'pointer' in intrinsic.text.lower()] + pointers = [ + intrinsic.text.split(',')[1].replace(')', '').replace(' ', '') + for intrinsic in FindNodes(ir.GenericStmt).visit(kernel_item.ir.spec) + if 'pointer' in intrinsic.text.lower() + ] assert 'tmp1' in pointers assert 'tmp2' in pointers assert 'tmp6' in pointers @@ -682,7 +683,7 @@ def test_pool_allocator_temporaries_kernel_sequence(tmp_path, frontend, block_di # driver = scheduler['#driver'].ir - stack_order = FindNodes(Assignment).visit(driver.body) + stack_order = FindNodes(ir.Assignment).visit(driver.body) if stack_insert_pragma: assert stack_order[0].lhs == "a" else: @@ -696,7 +697,7 @@ def test_pool_allocator_temporaries_kernel_sequence(tmp_path, frontend, block_di assert 'jprb' not in driver.import_map['jpim'].symbols # Has the stack been added to the call statements? - calls = FindNodes(CallStatement).visit(driver.body) + calls = FindNodes(ir.CallStatement).visit(driver.body) expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_U')) if cray_ptr_loc_rhs: expected_kwarguments += (('ZSTACK', 'zstack(:,b)'),) @@ -725,7 +726,7 @@ def test_pool_allocator_temporaries_kernel_sequence(tmp_path, frontend, block_di if directive in ['openmp', 'openacc']: keyword = {'openmp': 'omp', 'openacc': 'acc'}[directive] pragmas = [ - p for p in FindNodes(Pragma).visit(driver.body) + p for p in FindNodes(ir.Pragma).visit(driver.body) if p.keyword.lower() == keyword and p.content.startswith('parallel') ] assert len(pragmas) == 2 @@ -737,7 +738,7 @@ def test_pool_allocator_temporaries_kernel_sequence(tmp_path, frontend, block_di # Are there data regions for the stack? if directive == ['openacc']: pragmas = [ - p for p in FindNodes(Pragma).visit(driver.body) + p for p in FindNodes(ir.Pragma).visit(driver.body) if p.keyword.lower() == 'acc' and 'data' in p.content ] assert len(pragmas) == 2 @@ -780,7 +781,7 @@ def test_pool_allocator_temporaries_kernel_sequence(tmp_path, frontend, block_di # Let's check for the relevant "allocations" happening in the right order assign_idx = {} - for idx, ass in enumerate(FindNodes(Assignment).visit(kernel.body)): + for idx, ass in enumerate(FindNodes(ir.Assignment).visit(kernel.body)): if ass.lhs == 'ylstack_l' and ass.rhs == 'ydstack_l': # Local copy of stack status @@ -996,7 +997,7 @@ def test_pool_allocator_temporaries_kernel_nested(tmp_path, frontend, block_dim, assert driver.import_map['jpim'] == driver.import_map['jplm'] # Has the stack been added to the call statements? - calls = FindNodes(CallStatement).visit(driver.body) + calls = FindNodes(ir.CallStatement).visit(driver.body) expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u')) if cray_ptr_loc_rhs: expected_kwarguments += (('ZSTACK', 'zstack(:,b)'),) @@ -1025,7 +1026,7 @@ def test_pool_allocator_temporaries_kernel_nested(tmp_path, frontend, block_dim, if directive in ['openmp', 'openacc']: keyword = {'openmp': 'omp', 'openacc': 'acc'}[directive] pragmas = [ - p for p in FindNodes(Pragma).visit(driver.body) + p for p in FindNodes(ir.Pragma).visit(driver.body) if p.keyword.lower() == keyword and p.content.startswith('parallel') ] assert len(pragmas) == 1 @@ -1038,7 +1039,7 @@ def test_pool_allocator_temporaries_kernel_nested(tmp_path, frontend, block_dim, # Are there data regions for the stack? if directive == ['openacc']: pragmas = [ - p for p in FindNodes(Pragma).visit(driver.body) + p for p in FindNodes(ir.Pragma).visit(driver.body) if p.keyword.lower() == 'acc' and 'data' in p.content ] assert len(pragmas) == 2 @@ -1048,7 +1049,7 @@ def test_pool_allocator_temporaries_kernel_nested(tmp_path, frontend, block_dim, # # A few checks on the kernels # - calls = FindNodes(CallStatement).visit(kernel_item.ir.body) + calls = FindNodes(ir.CallStatement).visit(kernel_item.ir.body) expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u')) if cray_ptr_loc_rhs: expected_kwarguments += (('ZSTACK', 'zstack'),) @@ -1085,7 +1086,7 @@ def test_pool_allocator_temporaries_kernel_nested(tmp_path, frontend, block_dim, # Let's check for the relevant "allocations" happening in the right order assign_idx = {} - for idx, ass in enumerate(FindNodes(Assignment).visit(kernel.body)): + for idx, ass in enumerate(FindNodes(ir.Assignment).visit(kernel.body)): if ass.lhs == 'ylstack_l' and ass.rhs == 'ydstack_l': # Local copy of stack status @@ -1213,7 +1214,7 @@ def test_pool_allocator_more_call_checks(tmp_path, frontend, block_dim, caplog, assert 'ylstack_u' in kernel.variables # Has the stack been added to the call statement at the correct location? - calls = FindNodes(CallStatement).visit(kernel.body) + calls = FindNodes(ir.CallStatement).visit(kernel.body) expected_kwarguments = (('YDSTACK_L', 'ylstack_l'), ('YDSTACK_U', 'ylstack_u')) if cray_ptr_loc_rhs: expected_kwarguments += (('ZSTACK', 'zstack'),) @@ -1387,7 +1388,7 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p assert 'ydstack_l' in kernel2.arguments assert 'ydstack_u' in kernel2.arguments - calls = FindNodes(CallStatement).visit(driver.body) + calls = FindNodes(ir.CallStatement).visit(driver.body) additional_kwargs = (('ZSTACK', 'zstack(:,b)'),) if cray_ptr_loc_rhs else () assert calls[0].arguments == () assert calls[0].kwarguments == ( @@ -1424,7 +1425,7 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p ) + additional_kwargs # check stack size allocation - allocations = FindNodes(Allocation).visit(driver.body) + allocations = FindNodes(ir.Allocation).visit(driver.body) assert len(allocations) == 1 and 'zstack(istsz,geom%blk_dim%nb)' in allocations[0].variables # check that array size was imported to the driver diff --git a/loki/transformations/tests/test_parametrise.py b/loki/transformations/tests/test_parametrise.py index 3fcbdae09..26c9e1775 100644 --- a/loki/transformations/tests/test_parametrise.py +++ b/loki/transformations/tests/test_parametrise.py @@ -346,7 +346,7 @@ def test_parametrise_modified_callback(tmp_path, testdir, frontend, config): def error_stop(**kwargs): msg = kwargs.get("msg") - abort = (ir.Intrinsic(text=f'error stop "{msg}"'),) + abort = (ir.GenericStmt(text=f'error stop "{msg}"'),) return abort def stop_execution(**kwargs): @@ -383,7 +383,7 @@ def test_parametrise_modified_callback_wrong_input(tmp_path, testdir, frontend, def only_warn(**kwargs): msg = kwargs.get("msg") - abort = (ir.Intrinsic(text=f'print *, "This is just a warning: {msg}"'),) + abort = (ir.GenericStmt(text=f'print *, "This is just a warning: {msg}"'),) return abort scheduler = Scheduler( diff --git a/loki/transformations/tests/test_remove_code.py b/loki/transformations/tests/test_remove_code.py index 4af6288df..81669ff49 100644 --- a/loki/transformations/tests/test_remove_code.py +++ b/loki/transformations/tests/test_remove_code.py @@ -570,7 +570,7 @@ def test_transform_remove_calls(frontend, remove_imports, tmp_path): assert len(conditionals) == (4 if frontend == OMNI else 0) # Check that all intrinsic calls to WRITE have been removed - intrinsics = FindNodes(ir.Intrinsic).visit(routine.body) + intrinsics = FindNodes(ir.GenericStmt).visit(routine.body) assert len(intrinsics) == 1 assert 'never gonna let you down' in intrinsics[0].text diff --git a/loki/transformations/tests/test_split_read_write.py b/loki/transformations/tests/test_split_read_write.py index 95610cd08..41b25d0c3 100644 --- a/loki/transformations/tests/test_split_read_write.py +++ b/loki/transformations/tests/test_split_read_write.py @@ -125,7 +125,7 @@ def test_split_read_write(frontend, horizontal, vertical): assert not 'var2(jl,jk)' in FindVariables().visit(outer_loops[1]) # check print statement is only present in first copy of region - assert len(FindNodes(ir.Intrinsic).visit(region)) == 1 + assert len(FindNodes(ir.GenericStmt).visit(region)) == 1 # check correctness of split reads assigns = FindNodes(ir.Assignment).visit(outer_loops[0].body) diff --git a/loki/transformations/transpile/fortran_c.py b/loki/transformations/transpile/fortran_c.py index eb17c0e64..18b389499 100644 --- a/loki/transformations/transpile/fortran_c.py +++ b/loki/transformations/transpile/fortran_c.py @@ -15,8 +15,7 @@ SubstituteExpressionsMapper ) from loki.ir import ( - Import, Intrinsic, Interface, CallStatement, Assignment, - Transformer, FindNodes, Comment, SubstituteExpressions, + nodes as ir, Transformer, FindNodes, SubstituteExpressions, FindInlineCalls ) from loki.logging import debug @@ -163,7 +162,7 @@ def transform_subroutine(self, routine, **kwargs): c_kernel = self.generate_c_kernel(routine, targets=targets) for successor in successors: - c_kernel.spec.prepend(Import(module=f'{successor.ir.name.lower()}_c.h', c_import=True)) + c_kernel.spec.prepend(ir.Import(module=f'{successor.ir.name.lower()}_c.h', c_import=True)) if depth == 1: if self.language != 'c': @@ -172,14 +171,14 @@ def transform_subroutine(self, routine, **kwargs): c_path = (path/c_kernel_launch.name.lower()).with_suffix('.h') Sourcefile.to_file(source=self.codegen(c_kernel_launch, extern=True), path=c_path) - assignments = FindNodes(Assignment).visit(c_kernel.body) + assignments = FindNodes(ir.Assignment).visit(c_kernel.body) assignments2remove = ['griddim', 'blockdim'] assignment_map = {assignment: None for assignment in assignments if assignment.lhs.name.lower() in assignments2remove} c_kernel.body = Transformer(assignment_map).visit(c_kernel.body) if depth > 1: - c_kernel.spec.prepend(Import(module=f'{c_kernel.name.lower()}.h', c_import=True)) + c_kernel.spec.prepend(ir.Import(module=f'{c_kernel.name.lower()}.h', c_import=True)) c_path = (path/c_kernel.name.lower()).with_suffix(self.file_suffix()) Sourcefile.to_file(source=self.codegen(c_kernel, extern=self.language=='cpp'), path=c_path) header_path = (path/c_kernel.name.lower()).with_suffix('.h') @@ -187,7 +186,7 @@ def transform_subroutine(self, routine, **kwargs): def convert_kwargs_to_args(self, routine, targets): # calls (to subroutines) - for call in as_tuple(FindNodes(CallStatement).visit(routine.body)): + for call in as_tuple(FindNodes(ir.CallStatement).visit(routine.body)): if str(call.name).lower() in as_tuple(targets): call.convert_kwargs_to_args() # inline calls (to functions) @@ -202,10 +201,10 @@ def interface_to_import(self, routine, targets): """ Convert interface to import. """ - for call in FindNodes(CallStatement).visit(routine.body): + for call in FindNodes(ir.CallStatement).visit(routine.body): if str(call.name).lower() in as_tuple(targets): call.convert_kwargs_to_args() - intfs = FindNodes(Interface).visit(routine.spec) + intfs = FindNodes(ir.Interface).visit(routine.spec) removal_map = {} for i in intfs: for s in i.symbols: @@ -213,7 +212,7 @@ def interface_to_import(self, routine, targets): # Create a new module import with explicitly qualified symbol new_symbol = s.clone(name=f'{s.name}_FC', scope=routine) modname = f'{new_symbol.name}_MOD' - new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) + new_import = ir.Import(module=modname, c_import=False, symbols=(new_symbol,)) routine.spec.prepend(new_import) # Mark current import for removal removal_map[i] = None @@ -279,7 +278,7 @@ def generate_c_kernel(self, routine, targets, **kwargs): for module, variables in module_variables.items(): for var in variables: getter = f'{module}__get__{var.name.lower()}' - vget = Assignment(lhs=var, rhs=InlineCall(ProcedureSymbol(getter, scope=var.scope))) + vget = ir.Assignment(lhs=var, rhs=InlineCall(ProcedureSymbol(getter, scope=var.scope))) getter_calls += [vget] kernel.body.prepend(getter_calls) @@ -300,8 +299,7 @@ def generate_c_kernel(self, routine, targets, **kwargs): kernel.spec = Transformer(import_map).visit(kernel.spec) # Remove intrinsics from spec (eg. implicit none) - intrinsic_map = {i: None for i in FindNodes(Intrinsic).visit(kernel.spec) - if 'implicit' in i.text.lower()} + intrinsic_map = {i: None for i in FindNodes(ir.ImplicitStmt).visit(kernel.spec)} kernel.spec = Transformer(intrinsic_map).visit(kernel.spec) # Resolve implicit struct mappings through "associates" @@ -339,7 +337,7 @@ def generate_c_kernel(self, routine, targets, **kwargs): def convert_call_names(self, routine, targets): # calls (to subroutines) - calls = FindNodes(CallStatement).visit(routine.body) + calls = FindNodes(ir.CallStatement).visit(routine.body) for call in calls: if call.name not in as_tuple(targets): continue @@ -353,7 +351,7 @@ def convert_call_names(self, routine, targets): def generate_c_kernel_launch(self, kernel_launch, kernel, **kwargs): import_map = {} - for im in FindNodes(Import).visit(kernel_launch.spec): + for im in FindNodes(ir.Import).visit(kernel_launch.spec): import_map[im] = None kernel_launch.spec = Transformer(import_map).visit(kernel_launch.spec) @@ -368,7 +366,7 @@ def generate_c_kernel_launch(self, kernel_launch, kernel, **kwargs): griddim = kernel_launch.variable_map['griddim'] if 'blockdim' in kernel_launch.variable_map: blockdim = kernel_launch.variable_map['blockdim'] - assignments = FindNodes(Assignment).visit(kernel_launch.body) + assignments = FindNodes(ir.Assignment).visit(kernel_launch.body) griddim_assignment = None blockdim_assignment = None for assignment in assignments: @@ -376,7 +374,7 @@ def generate_c_kernel_launch(self, kernel_launch, kernel, **kwargs): griddim_assignment = assignment.clone() if assignment.lhs == blockdim: blockdim_assignment = assignment.clone() - kernel_launch.body = (Comment(text="! here should be the launcher ...."), - griddim_assignment, blockdim_assignment, CallStatement(name=Variable(name=kernel.name), + kernel_launch.body = (ir.Comment(text="! here should be the launcher ...."), + griddim_assignment, blockdim_assignment, ir.CallStatement(name=Variable(name=kernel.name), arguments=call_arguments, chevron=(sym.Variable(name="griddim"), sym.Variable(name="blockdim")))) diff --git a/loki/transformations/transpile/fortran_iso_c_wrapper.py b/loki/transformations/transpile/fortran_iso_c_wrapper.py index 4759b97f9..83d8cbcdb 100644 --- a/loki/transformations/transpile/fortran_iso_c_wrapper.py +++ b/loki/transformations/transpile/fortran_iso_c_wrapper.py @@ -116,7 +116,7 @@ def transform_subroutine(self, routine, **kwargs): routine, c_structs, bind_name=bind_name, use_c_ptr=self.use_c_ptr, language=self.language ) - contains = ir.Section(body=(ir.Intrinsic('CONTAINS'), wrapper)) + contains = ir.Section(body=(ir.ContainsStmt(), wrapper)) wrapperpath = (path/wrapper.name.lower()).with_suffix('.F90') module = Module(name=f'{wrapper.name.upper()}_MOD', contains=contains) module.spec = ir.Section(body=(ir.Import(module='iso_c_binding'),)) @@ -252,7 +252,7 @@ def generate_iso_c_interface(routine, bind_name, c_structs, scope, use_c_ptr=Fal if not im.c_import: im_symbols = tuple(s.clone(scope=intf_routine) for s in im.symbols) intf_spec.append(im.clone(symbols=im_symbols)) - intf_spec.append(ir.Intrinsic(text='implicit none')) + intf_spec.append(ir.ImplicitStmt()) intf_spec.append(c_structs.values()) intf_routine.spec = intf_spec @@ -450,7 +450,7 @@ def generate_iso_c_wrapper_module(module, use_c_ptr=False, language='c'): ) getter.variables = as_tuple(sym.Variable(name=gettername, type=isoctype, scope=getter)) wrappers += [getter] - wrapper_module.contains = ir.Section(body=(ir.Intrinsic('CONTAINS'), *wrappers)) + wrapper_module.contains = ir.Section(body=(ir.ContainsStmt(), *wrappers)) # Remove any unused imports sanitise_imports(wrapper_module) @@ -482,7 +482,7 @@ def generate_c_header(module): continue ctype = c_intrinsic_kind(v.type, scope=module) tmpl_function = f'{ctype} {module.name.lower()}__get__{v.name.lower()}();' - spec += [ir.Intrinsic(text=tmpl_function)] + spec += [ir.GenericStmt(text=tmpl_function)] # Re-create type definitions with range indices (``:``) replaced by pointers for td in FindNodes(ir.TypeDef).visit(module.spec): diff --git a/loki/transformations/transpile/fortran_python.py b/loki/transformations/transpile/fortran_python.py index 8d14902e4..7ddcbf505 100644 --- a/loki/transformations/transpile/fortran_python.py +++ b/loki/transformations/transpile/fortran_python.py @@ -61,8 +61,7 @@ def transform_subroutine(self, routine, **kwargs): # Remove all "IMPLICT" intrinsic statements mapper = { - i: None for i in FindNodes(ir.Intrinsic).visit(routine.spec) - if 'implicit' in i.text.lower() + i: None for i in FindNodes(ir.ImplicitStmt).visit(routine.spec) } routine.spec = Transformer(mapper).visit(routine.spec) diff --git a/loki/types/tests/test_derived_types.py b/loki/types/tests/test_derived_types.py index 23095748a..bc7742e71 100644 --- a/loki/types/tests/test_derived_types.py +++ b/loki/types/tests/test_derived_types.py @@ -506,10 +506,10 @@ def test_derived_type_private_comp(frontend, tmp_path): some_private_comp_type = module.typedef_map['some_private_comp_type'] type_bound_proc_type = module.typedef_map['type_bound_proc_type'] - intrinsic_nodes = FindNodes(ir.Intrinsic).visit(type_bound_proc_type.body) + intrinsic_nodes = FindNodes(ir.GenericStmt).visit(type_bound_proc_type.body) assert len(intrinsic_nodes) == 2 - assert intrinsic_nodes[0].text.lower() == 'contains' - assert intrinsic_nodes[1].text.lower() == 'private' + assert isinstance(intrinsic_nodes[0], ir.ContainsStmt) + assert isinstance(intrinsic_nodes[1], ir.PrivateStmt) assert re.search( r'^\s+contains$\s+private', fgen(type_bound_proc_type), re.I | re.MULTILINE @@ -517,10 +517,10 @@ def test_derived_type_private_comp(frontend, tmp_path): # OMNI gets the below wrong as it doesn't retain the private statement for components if frontend != OMNI: - intrinsic_nodes = FindNodes(ir.Intrinsic).visit(some_private_comp_type.body) + intrinsic_nodes = FindNodes(ir.GenericStmt).visit(some_private_comp_type.body) assert len(intrinsic_nodes) == 2 - assert intrinsic_nodes[0].text.lower() == 'private' - assert intrinsic_nodes[1].text.lower() == 'contains' + assert isinstance(intrinsic_nodes[0], ir.PrivateStmt) + assert isinstance(intrinsic_nodes[1], ir.ContainsStmt) assert re.search( r'^\s+private*$(\s.*?){2}\s+contains', fgen(some_private_comp_type), re.I | re.MULTILINE