diff --git a/loki/analyse/__init__.py b/loki/analyse/__init__.py index 85685fede..eb29caa45 100644 --- a/loki/analyse/__init__.py +++ b/loki/analyse/__init__.py @@ -8,4 +8,4 @@ Advanced analysis utilities, such as dataflow analysis functionalities. """ -from loki.analyse.analyse_dataflow import * # noqa +from loki.analyse.data_flow_analysis import * # noqa diff --git a/loki/analyse/abstract_dfa.py b/loki/analyse/abstract_dfa.py new file mode 100644 index 000000000..10273ee45 --- /dev/null +++ b/loki/analyse/abstract_dfa.py @@ -0,0 +1,50 @@ +# (C) Copyright 2024- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from abc import ABC, abstractmethod +from contextlib import contextmanager + +from loki import Transformer + +__all__ = ['AbstractDataflowAnalysis'] + +class AbstractDataflowAnalysis(Transformer, ABC): + class _Attacher(Transformer): + pass + + class _Detacher(Transformer): + pass + + def get_attacher(self): + return self._Attacher() + + def get_detacher(self): + return self._Detacher() + + @abstractmethod + def attach_dataflow_analysis(self, module_or_routine): + pass + + def detach_dataflow_analysis(self, module_or_routine): + """ + Remove from each IR node the stored dataflow analysis metadata. + + Accessing the relevant attributes afterwards raises :py:class:`RuntimeError`. 
+ """ + + if hasattr(module_or_routine, 'spec'): + self.get_detacher().visit(module_or_routine.spec) + if hasattr(module_or_routine, 'body'): + self.get_detacher().visit(module_or_routine.body) + + @contextmanager + def dataflow_analysis_attached(self, module_or_routine): + self.attach_dataflow_analysis(module_or_routine) + try: + yield module_or_routine + finally: + self.detach_dataflow_analysis(module_or_routine) \ No newline at end of file diff --git a/loki/analyse/analyse_dataflow.py b/loki/analyse/analyse_dataflow.py deleted file mode 100644 index 5082178d9..000000000 --- a/loki/analyse/analyse_dataflow.py +++ /dev/null @@ -1,606 +0,0 @@ -# (C) Copyright 2018- ECMWF. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -""" -Collection of dataflow analysis schema routines. -""" - -from contextlib import contextmanager -from loki.expression import Array, ProcedureSymbol -from loki.tools import as_tuple, flatten -from loki.types import BasicType -from loki.ir import ( - Visitor, Transformer, FindVariables, FindInlineCalls, FindTypedSymbols -) -from loki.subroutine import Subroutine -from loki.tools.util import CaseInsensitiveDict - -__all__ = [ - 'dataflow_analysis_attached', 'read_after_write_vars', - 'loop_carried_dependencies' -] - - -class DataflowAnalysisAttacher(Transformer): - """ - Analyse and attach in-place the definition, use and live status of - symbols. 
- """ - - # group of functions that only query memory properties and don't read/write variable value - _mem_property_queries = ('size', 'lbound', 'ubound', 'present') - - def __init__(self, **kwargs): - super().__init__(inplace=True, invalidate_source=False, **kwargs) - - # Utility routines - - def _visit_body(self, body, live=None, defines=None, uses=None, **kwargs): - """ - Iterate through the tuple that is a body and update defines and - uses along the way. - """ - if live is None: - live = set() - if defines is None: - defines = set() - if uses is None: - uses = set() - visited = [] - for i in flatten(body): - visited += [self.visit(i, live_symbols=live|defines, **kwargs)] - uses |= visited[-1].uses_symbols.copy() - defines - defines |= visited[-1].defines_symbols.copy() - return as_tuple(visited), defines, uses - - @staticmethod - def _symbols_from_expr(expr, condition=None): - """ - Return set of symbols found in an expression. - """ - if condition is not None: - return {v.clone(dimensions=None) for v in FindVariables().visit(expr) if condition(v)} - return {v.clone(dimensions=None) for v in FindVariables().visit(expr)} - - @classmethod - def _symbols_from_lhs_expr(cls, expr): - """ - Determine symbol use and symbol definition from a left-hand side expression. - - Parameters - ---------- - expr : :any:`Scalar` or :any:`Array` - The left-hand side expression of an assignment. - - Returns - ------- - (defines, uses) : (set, set) - The sets of defined and used symbols (in that order). 
- """ - defines = {expr.clone(dimensions=None)} - uses = cls._symbols_from_expr(getattr(expr, 'dimensions', ())) - return defines, uses - - # Abstract node (also called from every node type for integration) - - def visit_Node(self, o, **kwargs): - # Live symbols are determined on InternalNode handler levels and - # get passed down to all child nodes - o._update(_live_symbols=kwargs.get('live_symbols', set())) - - # Symbols defined or used by this node are determined by their individual - # handler routines and passed on to visitNode from there - o._update(_defines_symbols=kwargs.get('defines_symbols', set())) - o._update(_uses_symbols=kwargs.get('uses_symbols', set())) - return o - - # Internal nodes - - def visit_Interface(self, o, **kwargs): - # Subroutines/functions calls defined in an explicit interface - defines = set() - for b in o.body: - if isinstance(b, Subroutine): - defines = defines | set(as_tuple(b.procedure_symbol)) - return self.visit_Node(o, defines_symbols=defines, **kwargs) - - def visit_InternalNode(self, o, **kwargs): - # An internal node defines all symbols defined by its body and uses all - # symbols used by its body before they are defined in the body - live = kwargs.pop('live_symbols', set()) - body, defines, uses = self._visit_body(o.body, live=live, **kwargs) - o._update(body=body) - return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_Associate(self, o, **kwargs): - # An associate block defines all symbols defined by its body and uses all - # symbols used by its body before they are defined in the body - live = kwargs.pop('live_symbols', set()) - body, defines, uses = self._visit_body(o.body, live=live, **kwargs) - o._update(body=body) - - # reverse the mapping of names before assinging lives, defines, uses sets for Associate node itself - invert_assoc = CaseInsensitiveDict({v.name: k for k, v in o.associations}) - _live = set(invert_assoc[v.name] if v.name in invert_assoc else v 
for v in live) - _defines = set(invert_assoc[v.name] if v.name in invert_assoc else v for v in defines) - _uses = set(invert_assoc[v.name] if v.name in invert_assoc else v for v in uses) - - return self.visit_Node(o, live_symbols=_live, defines_symbols=_defines, uses_symbols=_uses, **kwargs) - - def visit_Loop(self, o, **kwargs): - # A loop defines the induction variable for its body before entering it - live = kwargs.pop('live_symbols', set()) - uses = self._symbols_from_expr(o.bounds) - body, defines, uses = self._visit_body(o.body, live=live|{o.variable.clone()}, uses=uses, **kwargs) - o._update(body=body) - # Make sure the induction variable is not considered outside the loop - uses.discard(o.variable) - defines.discard(o.variable) - return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_WhileLoop(self, o, **kwargs): - # A while loop uses variables in its condition - live = kwargs.pop('live_symbols', set()) - uses = self._symbols_from_expr(o.condition) - body, defines, uses = self._visit_body(o.body, live=live, uses=uses, **kwargs) - o._update(body=body) - return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_Conditional(self, o, **kwargs): - live = kwargs.pop('live_symbols', set()) - - # exclude arguments to functions that just check the memory attributes of a variable - mem_call = as_tuple(i for i in FindInlineCalls().visit(o.condition) if i.function in self._mem_property_queries) - query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_call)) - cset = set(v for v in FindVariables().visit(o.condition) if not v in query_args) - - condition = self._symbols_from_expr(as_tuple(cset)) - body, defines, uses = self._visit_body(o.body, live=live, uses=condition, **kwargs) - else_body, else_defines, uses = self._visit_body(o.else_body, live=live, uses=uses, **kwargs) - o._update(body=body, else_body=else_body) - return 
self.visit_Node(o, live_symbols=live, defines_symbols=defines|else_defines, uses_symbols=uses, **kwargs) - - def visit_MultiConditional(self, o, **kwargs): - live = kwargs.pop('live_symbols', set()) - - # exclude arguments to functions that just check the memory attributes of a variable - mem_calls = as_tuple(i for i in FindInlineCalls().visit(o.expr) if i.function in self._mem_property_queries) - query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_calls)) - eset = set(v for v in FindVariables().visit(o.expr) if not v in query_args) - - mem_calls = as_tuple(i for i in FindInlineCalls().visit(o.values) if i.function in self._mem_property_queries) - query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_calls)) - vset = set(v for v in FindVariables().visit(o.values) if not v in query_args) - - uses = self._symbols_from_expr(as_tuple(eset)) | self._symbols_from_expr(as_tuple(vset)) - body = () - defines = set() - for b in o.bodies: - _b, _d, uses = self._visit_body(b, live=live, uses=uses, **kwargs) - body += (as_tuple(_b),) - defines |= _d - else_body, else_defines, uses = self._visit_body(o.else_body, live=live, uses=uses, **kwargs) - o._update(bodies=body, else_body=else_body) - defines = defines | else_defines - return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_MaskedStatement(self, o, **kwargs): - live = kwargs.pop('live_symbols', set()) - conditions = self._symbols_from_expr(o.conditions) - - body = () - defines = set() - uses = set(conditions) - for b in o.bodies: - _b, defines, uses = self._visit_body(b, live=live, uses=uses, defines=defines, **kwargs) - body += (_b,) - - default, default_defs, uses = self._visit_body(o.default, live=live, uses=uses, **kwargs) - o._update(bodies=body, default=default) - return self.visit_Node(o, live_symbols=live, defines_symbols=defines|default_defs, uses_symbols=uses, **kwargs) - - # Leaf nodes - - def 
visit_Assignment(self, o, **kwargs): - # exclude arguments to functions that just check the memory attributes of a variable - mem_calls = as_tuple(i for i in FindInlineCalls().visit(o.rhs) if i.function in self._mem_property_queries) - query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_calls)) - rset = set(v for v in FindVariables().visit(o.rhs) if not v in query_args) - - # The left-hand side variable is defined by this statement - defines, uses = self._symbols_from_lhs_expr(o.lhs) - - # Anything on the right-hand side is used before assigning to it - uses |= self._symbols_from_expr(as_tuple(rset)) - return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_ConditionalAssignment(self, o, **kwargs): - # The left-hand side variable is defined by this statement - defines, uses = self._symbols_from_lhs_expr(o.lhs) - # Anything on the right-hand side is used before assigning to it - uses |= self._symbols_from_expr((o.condition, o.rhs, o.else_rhs)) - return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_CallStatement(self, o, **kwargs): - if o.routine is not BasicType.DEFERRED: - # With a call context provided we can determine which arguments - # are potentially defined and which are definitely only used by - # this call - defines, uses = set(), set() - outvals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'out')] - invals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'in')] - - arrays = [v for v in FindVariables().visit(outvals) if isinstance(v, Array)] - dims = set(v for a in arrays for v in self._symbols_from_expr(a.dimensions)) - for val in outvals: - exprs = self._symbols_from_expr(val) - defines |= {e for e in exprs if not e in dims} - uses |= dims - - uses |= {s for val in invals for s in self._symbols_from_expr(val)} - else: - # We don't know the intent of any of these arguments and thus 
have - # to assume all of them are potentially used or defined by this - # statement - arrays = [v for v in FindVariables().visit(o.arguments) if isinstance(v, Array)] - arrays += [v for arg, val in o.kwarguments for v in FindVariables().visit(val) if isinstance(v, Array)] - - dims = set(v for a in arrays for v in FindVariables().visit(a.dimensions)) - defines = self._symbols_from_expr(o.arguments, condition=lambda x: x not in dims) - for arg, val in o.kwarguments: - defines |= self._symbols_from_expr(val, condition=lambda x: x not in dims) - uses = defines.copy() | dims - - return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_Allocation(self, o, **kwargs): - arrays = [v for v in FindVariables().visit(o.variables) if isinstance(v, Array)] - dims = set(v for a in arrays for v in FindVariables().visit(a.dimensions)) - defines = self._symbols_from_expr(o.variables, condition=lambda x: x not in dims) - uses = self._symbols_from_expr(o.data_source or ()) | dims - return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs) - - def visit_Deallocation(self, o, **kwargs): - defines = self._symbols_from_expr(o.variables) - return self.visit_Node(o, defines_symbols=defines, **kwargs) - - visit_Nullify = visit_Deallocation - - def visit_Import(self, o, **kwargs): - defines = set(s.clone(dimensions=None) for s in FindTypedSymbols().visit(o.symbols or ()) - if isinstance(s, ProcedureSymbol)) - return self.visit_Node(o, defines_symbols=defines, **kwargs) - - def visit_VariableDeclaration(self, o, **kwargs): - defines = self._symbols_from_expr(o.symbols, condition=lambda v: v.type.initial is not None) - uses = {v for a in o.symbols if isinstance(a, Array) for v in self._symbols_from_expr(a.dimensions)} - return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs) - - -class DataflowAnalysisDetacher(Transformer): - """ - Remove in-place any dataflow analysis properties. 
- """ - - def __init__(self, **kwargs): - super().__init__(inplace=True, invalidate_source=False, **kwargs) - - def visit_Node(self, o, **kwargs): - o._update(_live_symbols=None, _defines_symbols=None, _uses_symbols=None) - return super().visit_Node(o, **kwargs) - - -def attach_dataflow_analysis(module_or_routine): - """ - Determine and attach to each IR node dataflow analysis metadata. - - This makes for each IR node the following properties available: - - * :attr:`Node.live_symbols`: symbols defined before the node; - * :attr:`Node.defines_symbols`: symbols (potentially) defined by the - node, i.e., live in subsequent nodes; - * :attr:`Node.uses_symbols`: symbols used by the node (that had to be - defined before). - - The IR nodes are updated in-place and thus existing references to IR - nodes remain valid. - """ - live_symbols = set() - if hasattr(module_or_routine, 'arguments'): - live_symbols = DataflowAnalysisAttacher._symbols_from_expr( - module_or_routine.arguments, - condition=lambda a: a.type.intent and a.type.intent.lower() in ('in', 'inout') - ) - - if hasattr(module_or_routine, 'spec'): - DataflowAnalysisAttacher().visit(module_or_routine.spec, live_symbols=live_symbols) - live_symbols |= module_or_routine.spec.defines_symbols - - if hasattr(module_or_routine, 'body'): - DataflowAnalysisAttacher().visit(module_or_routine.body, live_symbols=live_symbols) - - -def detach_dataflow_analysis(module_or_routine): - """ - Remove from each IR node the stored dataflow analysis metadata. - - Accessing the relevant attributes afterwards raises :py:class:`RuntimeError`. 
- """ - if hasattr(module_or_routine, 'spec'): - DataflowAnalysisDetacher().visit(module_or_routine.spec) - if hasattr(module_or_routine, 'body'): - DataflowAnalysisDetacher().visit(module_or_routine.body) - - -@contextmanager -def dataflow_analysis_attached(module_or_routine): - r""" - Create a context in which information about defined, live and used symbols - is attached to each IR node - - This makes for each IR node the following properties available: - - * :attr:`Node.live_symbols`: symbols defined before the node; - * :attr:`Node.defines_symbols`: symbols (potentially) defined by the - node; - * :attr:`Node.uses_symbols`: symbols used by the node that had to be - defined before. - - This is an in-place update of nodes and thus existing references to IR - nodes remain valid. When leaving the context the information is removed - from IR nodes, while existing references remain valid. - - The analysis is based on a rather crude regions-based analysis, with the - hierarchy implied by (nested) :any:`InternalNode` IR nodes used as regions - in the reducible flow graph (cf. Chapter 9, in particular 9.7 of Aho, Lam, - Sethi, and Ulliman (2007)). Our implementation shares some similarities - with a full reaching definitions dataflow analysis but is not quite as - powerful. - - In reaching definitions dataflow analysis (cf. Chapter 9.2.4 Aho et. al.), - the transfer function of a definition :math:`d` can be expressed as: - - .. math:: f_d(x) = \operatorname{gen}_d \cup (x - \operatorname{kill}_d) - - with the set of definitions generated :math:`\operatorname{gen}_d` and the - set of definitions killed/invalidated :math:`\operatorname{kill}_d`. - - We, however, do not record definitions explicitly and instead operate on - consolidated sets of defined symbols, i.e., effectively evaluate the - chained transfer functions up to the node. This yields a set of active - definitions at this node. 
The symbols defined by these definitions are - in :any:`Node.live_symbols`, and the symbols defined by the node (i.e., - symbols defined by definitions in :math:`\operatorname{gen}_d`) are in - :any:`Node.defines_symbols`. - - The advantage of this approach is that it avoids the need to introduce - a layer for definitions and dependencies. A downside is that this focus - on symbols instead of definitions precludes, in particular, the ability - to take data space into account, which makes it less useful for arrays. - - .. note:: - The context manager operates only on the module or routine itself - (i.e., its spec and, if applicable, body), not on any contained - subroutines or functions. - - Parameters - ---------- - module_or_routine : :any:`Module` or :any:`Subroutine` - The object for which the IR is to be annotated. - """ - attach_dataflow_analysis(module_or_routine) - try: - yield module_or_routine - finally: - detach_dataflow_analysis(module_or_routine) - - -class FindReads(Visitor): - """ - Look for reads in a specified part of a control flow tree. - - Parameters - ---------- - start : (iterable of) :any:`Node`, optional - Visitor is only active after encountering one of the nodes in - :data:`start` and until encountering a node in :data:`stop`. - stop : (iterable of) :any:`Node`, optional - Visitor is no longer active after encountering one of the nodes in - :data:`stop` until it encounters again a node in :data:`start`. - active : bool, optional - Set the visitor active right from the beginning. - candidate_set : set of :any:`Node`, optional - If given, only reads for symbols in this set are considered. - clear_candidates_on_write : bool, optional - If enabled, writes of a symbol remove it from the :data:`candidate_set`. 
- """ - - def __init__(self, start=None, stop=None, active=False, - candidate_set=None, clear_candidates_on_write=False, **kwargs): - super().__init__(**kwargs) - self.start = set(as_tuple(start)) - self.stop = set(as_tuple(stop)) - self.active = active - self.candidate_set = candidate_set - self.clear_candidates_on_write = clear_candidates_on_write - self.reads = set() - - @staticmethod - def _symbols_from_expr(expr): - """ - Return set of symbols found in an expression. - """ - return {v.clone(dimensions=None) for v in FindVariables().visit(expr)} - - def _register_reads(self, read_symbols): - if self.active: - if self.candidate_set is None: - self.reads |= read_symbols - else: - self.reads |= read_symbols & self.candidate_set - - def _register_writes(self, write_symbols): - if self.active and self.clear_candidates_on_write and self.candidate_set is not None: - self.candidate_set -= write_symbols - - def visit(self, o, *args, **kwargs): - self.active = (self.active and o not in self.stop) or o in self.start - return super().visit(o, *args, **kwargs) - - def visit_object(self, o, **kwargs): # pylint: disable=unused-argument - pass - - def visit_LeafNode(self, o, **kwargs): # pylint: disable=unused-argument - self._register_reads(o.uses_symbols) - self._register_writes(o.defines_symbols) - - def visit_Conditional(self, o, **kwargs): - self._register_reads(self._symbols_from_expr(o.condition)) - # Visit each branch with the original candidate set and then take the - # union of both afterwards to include all potential read-after-writes - candidate_set = self.candidate_set.copy() if self.candidate_set is not None else None - self.visit(o.body, **kwargs) - self.candidate_set, candidate_set = candidate_set, self.candidate_set - self.visit(o.else_body, **kwargs) - if self.candidate_set is not None: - self.candidate_set |= candidate_set - - def visit_Loop(self, o, **kwargs): - self._register_reads(self._symbols_from_expr(o.bounds)) - active = self.active - if self.active 
and self.candidate_set is not None: - # remove the loop variable as a variable of interest - self.candidate_set.discard(o.variable) - self.visit(o.children, **kwargs) - if active: - self.reads.discard(o.variable) - - def visit_WhileLoop(self, o, **kwargs): - self._register_reads(self._symbols_from_expr(o.condition)) - self.visit(o.children, **kwargs) - - -class FindWrites(Visitor): - """ - Look for writes in a specified part of a control flow tree. - - Parameters - ---------- - start : (iterable of) :any:`Node`, optional - Visitor is only active after encountering one of the nodes in - :data:`start` and until encountering a node in :data:`stop`. - stop : (iterable of) :any:`Node`, optional - Visitor is no longer active after encountering one of the nodes in - :data:`stop` until it encounters again a node in :data:`start`. - active : bool, optional - Set the visitor active right from the beginning. - candidate_set : set of :any:`Node`, optional - If given, only writes for symbols in this set are considered. - """ - - def __init__(self, start=None, stop=None, active=False, - candidate_set=None, **kwargs): - super().__init__(**kwargs) - self.start = set(as_tuple(start)) - self.stop = set(as_tuple(stop)) - self.active = active - self.candidate_set = candidate_set - self.writes = set() - - @staticmethod - def _symbols_from_expr(expr): - """ - Return set of symbols found in an expression. 
- """ - return {v.clone(dimensions=None) for v in FindVariables().visit(expr)} - - def _register_writes(self, write_symbols): - if self.candidate_set is None: - self.writes |= write_symbols - else: - self.writes |= write_symbols & self.candidate_set - - def visit(self, o, *args, **kwargs): - self.active = (self.active and o not in self.stop) or o in self.start - return super().visit(o, *args, **kwargs) - - def visit_object(self, o, **kwargs): # pylint: disable=unused-argument - pass - - def visit_LeafNode(self, o, **kwargs): # pylint: disable=unused-argument - if self.active: - self._register_writes(o.defines_symbols) - - def visit_Loop(self, o, **kwargs): - if self.active: - # remove the loop variable as a variable of interest - if self.candidate_set is not None: - self.candidate_set.discard(o.variable) - self.writes.discard(o.variable) - super().visit_Node(o, **kwargs) - - -def read_after_write_vars(ir, inspection_node): - """ - Find variables that are read after being written in the given IR. - - This requires prior application of :meth:`dataflow_analysis_attached` to - the corresponding :any:`Module` or :any:`Subroutine`. - - The result is the set of variables with a data dependency across the - :data:`inspection_node`. - - See the remarks about implementation and limitations in the description of - :meth:`dataflow_analysis_attached`. In particular, this does not take into - account data space and iteration space for arrays. - - Parameters - ---------- - ir : :any:`Node` - The root of the control flow (sub-)tree to inspect. - inspection_node : :any:`Node` - Only variables with a write before and a read at or after this node - are considered. - - Returns - ------- - :any:`set` of :any:`Scalar` or :any:`Array` - The list of read-after-write variables. 
- """ - write_visitor = FindWrites(stop=inspection_node, active=True) - write_visitor.visit(ir) - read_visitor = FindReads(start=inspection_node, candidate_set=write_visitor.writes, - clear_candidates_on_write=True) - read_visitor.visit(ir) - return read_visitor.reads - - -def loop_carried_dependencies(loop): - """ - Find variables that are potentially loop-carried dependencies. - - This requires prior application of :meth:`dataflow_analysis_attached` to - the corresponding :any:`Module` or :any:`Subroutine`. - - See the remarks about implementation and limitations in the description of - :meth:`dataflow_analysis_attached`. In particular, this does not take into - account data space and iteration space for arrays. For cases with a - linear mapping from iteration to data space and no overlap, this will - falsely report loop-carried dependencies when there are in fact none. - However, the risk of false negatives should be low. - - Parameters - ---------- - loop : :any:`Loop` - The loop node to inspect. - - Returns - ------- - :any:`set` of :any:`Scalar` or :any:`Array` - The list of variables that potentially have a loop-carried dependency. - """ - return loop.uses_symbols & loop.defines_symbols diff --git a/loki/analyse/constant_propagation_analysis.py b/loki/analyse/constant_propagation_analysis.py new file mode 100644 index 000000000..85131b242 --- /dev/null +++ b/loki/analyse/constant_propagation_analysis.py @@ -0,0 +1,435 @@ +# (C) Copyright 2024- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import itertools +import math +import operator +import functools +from copy import deepcopy + +from loki import Transformer, Array, DeferredTypeSymbol, RangeIndex, \ + IntLiteral, get_pyrange, LoopRange, as_tuple, Assignment, FindNodes, LokiIdentityMapper, FindVariables, Loop, \ + is_constant +from loki.expression.symbols import _Literal, FloatLiteral, LogicLiteral, Product, StringLiteral +from loki.analyse.data_flow_analysis import DataFlowAnalysis +from loki.analyse.abstract_dfa import AbstractDataflowAnalysis +from loki.transformations.transform_loop import LoopUnrollTransformer + +__all__ = [ + 'ConstantPropagationAnalysis' +] + +class ConstantPropagationAnalysis(AbstractDataflowAnalysis): + class ConstPropMapper(LokiIdentityMapper): + def __init__(self, fold_floats=True): + self.fold_floats = fold_floats + super().__init__() + + def map_array(self, expr, *args, **kwargs): + constants_map = kwargs.get('constants_map', dict()) + return constants_map.get((expr.basename, getattr(expr, 'dimensions', ())), expr) + + map_scalar = map_array + map_deferred_type_symbol = map_array + + def map_constant(self, expr, *args, **kwargs): + if isinstance(expr, int): + return IntLiteral(expr) + elif isinstance(expr, float): + return FloatLiteral(str(expr)) + elif isinstance(expr, bool): + return LogicLiteral(expr) + + def map_sum(self, expr, *args, **kwargs): + return self.binary_num_op_helper(expr, sum, math.fsum, *args, **kwargs) + + def map_product(self, expr, *args, **kwargs): + mapped_product = self.binary_num_op_helper(expr, math.prod, math.prod, *args, **kwargs) + # Only way to get this here is if loki transformed `-expr` to `-1 * expr`, but couldn't const prop for expr + if getattr(mapped_product, 'children', (False,))[0] == IntLiteral(-1): + mapped_product = Product((-1, mapped_product.children[1])) + return mapped_product + + def map_quotient(self, expr, *args, **kwargs): + return self.binary_num_op_helper(expr, operator.floordiv, operator.truediv, + 
left_attr='numerator', right_attr='denominator', *args, **kwargs) + + def map_power(self, expr, *args, **kwargs): + return self.binary_num_op_helper(expr, operator.pow, operator.pow, + left_attr='base', right_attr='exponent', *args, **kwargs) + + def binary_num_op_helper(self, expr, int_op, float_op, left_attr=None, right_attr=None, *args, **kwargs): + lr_fields = not (left_attr is None and right_attr is None) + if lr_fields: + children = [getattr(expr, left_attr), getattr(expr, right_attr)] + else: + children = expr.children + + children = self.rec(children, *args, **kwargs) + + literals, non_literals = ConstantPropagationAnalysis._separate_literals(children) + if len(non_literals) == 0: + if any([isinstance(v, FloatLiteral) for v in literals]): + # Strange rounding possibility + if self.fold_floats: + if lr_fields: + return FloatLiteral(str(float_op(float(children[0].value), float(children[1].value)))) + else: + return FloatLiteral(str(float_op([float(c.value) for c in children]))) + else: + if lr_fields: + return IntLiteral(int_op(children[0].value, children[1].value)) + else: + return IntLiteral(int_op([c.value for c in children])) + + if lr_fields: + return expr.__class__(children[0], children[1]) + else: + return expr.__class__(children) + + def map_logical_and(self, expr, *args, **kwargs): + return self.binary_bool_op_helper(expr, lambda x, y: x and y, True, *args, **kwargs) + + def map_logical_or(self, expr, *args, **kwargs): + return self.binary_bool_op_helper(expr, lambda x, y: x or y, False, *args, **kwargs) + + def binary_bool_op_helper(self, expr, bool_op, initial, *args, **kwargs): + # This short-circuiting check is done twice as we might get lucky and not have to rec() + if LogicLiteral(not initial) in expr.children: + return LogicLiteral(not initial) + + children = tuple([self.rec(c, *args, **kwargs) for c in expr.children]) + + # Second short-circuiting check + if LogicLiteral(not initial) in children: + return LogicLiteral(not initial) + + 
literals, non_literals = ConstantPropagationAnalysis._separate_literals(children) + if len(non_literals) == 0: + return LogicLiteral(functools.reduce(bool_op, [c.value for c in children], initial)) + + return expr.__class__(children) + + def map_logical_not(self, expr, *args, **kwargs): + child = self.rec(expr.child, **kwargs) + + literals, non_literals = ConstantPropagationAnalysis._separate_literals([child]) + if len(non_literals) == 0: + return LogicLiteral(not child.value) + + return expr.__class__(child) + + def map_comparison(self, expr, *args, **kwargs): + left = self.rec(expr.left, *args, **kwargs) + right = self.rec(expr.right, *args, **kwargs) + + literals, non_literals = ConstantPropagationAnalysis._separate_literals([left, right]) + if len(non_literals) == 0: + # TODO: This should be a match statement >=3.10 + operators_map = { + 'lt': operator.lt, + 'le': operator.le, + 'eq': operator.eq, + 'ne': operator.ne, + 'ge': operator.ge, + 'gt': operator.gt, + } + operator_str = expr.operator if expr.operator in operators_map.keys() else expr.operator_to_name[ + expr.operator] + return LogicLiteral(operators_map[operator_str](left.value, right.value)) + + return expr.__class__(left, expr.operator, right) + + def map_loop_range(self, expr, *args, **kwargs): + start = self.rec(expr.start, *args, **kwargs) + stop = self.rec(expr.stop, *args, **kwargs) + step = self.rec(expr.step, *args, **kwargs) + return expr.__class__((start, stop, step)) + + def map_string_concat(self, expr, *args, **kwargs): + children = tuple([self.rec(c, *args, **kwargs) for c in expr.children]) + + literals, non_literals = ConstantPropagationAnalysis._separate_literals(children) + if len(non_literals) == 0: + return StringLiteral(''.join([c.value for c in children])) + + return expr.__class__(children) + + + class _Attacher(Transformer): + + def _pop_array_accesses(self, o, **kwargs): + # Clear out the unknown dimensions + constants_map = kwargs.get('constants_map', dict()) + new_shape = 
ConstantPropagationAnalysis.ConstPropMapper()(o.lhs.shape, constants_map=constants_map) + + # Create masks for literals and dimensions we can compute form the shape + literal_mask = [is_constant(d) for d in o.lhs.dimensions] + computable_dimension_mask = [is_constant(ns) for ns in new_shape] + + # Build list of indices to pass to _array_indices_to_accesses + masked_indices = [] + # Build mask of indices that are neither literal nor computable, + # and so can be ignored when partially matching + ignore_mask = [] + partial = False + for i, (lm, cdm) in enumerate(zip(literal_mask, computable_dimension_mask)): + if lm: + masked_indices.append(o.lhs.dimensions[i]) + ignore_mask.append(False) + elif cdm: + masked_indices.append(RangeIndex((None, None, None))) + ignore_mask.append(False) + else: + # We now need to take the scenic route of finding partial matches + partial = True + masked_indices.append(-1) + ignore_mask.append(True) + + # Expand the indices into accesses + possible_accesses = ConstantPropagationAnalysis._array_indices_to_accesses(masked_indices, new_shape) + keys = constants_map.keys() + + for access in possible_accesses: + if partial: + # Find partial matches and pop any candidates + for key in keys: + if key[0] == o.lhs.name and all(k == a or ignore for k,a,ignore in [zip(key[1], access, ignore_mask)]): + constants_map.pop(key) + else: + constants_map.pop((o.lhs.basename, access), None) + + def __init__(self, parent, **kwargs): + self.parent = parent + super().__init__(inplace=not self.parent._apply_transform, invalidate_source=self.parent._apply_transform, **kwargs) + + def visit_Assignment(self, o, **kwargs): + constants_map = kwargs.get('constants_map', dict()) + # Create a deep copy of the constants map when we enter this node. 
This is so that the node + # has the constants as they were before it mutated them, which is probably more useful + constants_map_in = deepcopy(constants_map) + o._update(_constants_map=constants_map_in) + + new_rhs = ConstantPropagationAnalysis.ConstPropMapper(self.parent.fold_floats)(o.rhs, **kwargs) + if self.parent._apply_transform: + o._update(rhs=new_rhs) + # Work with this lhs in case we're not applying transforms & can't modify o.lhs + lhs = o.lhs + + # What if the lhs isn't a scalar shape? + if isinstance(lhs, Array): + new_dimensions = [ConstantPropagationAnalysis.ConstPropMapper(self.parent.fold_floats)(d, **kwargs) for d in lhs.dimensions] + _, new_d_non_literals = ConstantPropagationAnalysis._separate_literals(new_dimensions) + + new_lhs = Array(lhs.name, lhs.scope, lhs.type, as_tuple(new_dimensions)) + if self.parent._apply_transform: + o._update(lhs=new_lhs) + lhs = new_lhs + if len(new_d_non_literals) != 0: + self._pop_array_accesses(o, **kwargs) + return o + + literals, non_literals = ConstantPropagationAnalysis._separate_literals([new_rhs]) + if len(non_literals) == 0: + if isinstance(lhs, Array): + for access in ConstantPropagationAnalysis._array_indices_to_accesses(lhs.dimensions, lhs.shape): + constants_map[(lhs.basename, access)] = new_rhs + else: + constants_map[(lhs.basename, ())] = new_rhs + else: + # TODO: What if it's a pointer + if isinstance(lhs, Array): + for access in ConstantPropagationAnalysis._array_indices_to_accesses(lhs.dimensions, lhs.shape): + constants_map.pop((lhs.basename, access), None) + else: + constants_map.pop((lhs.basename, ()), None) + + return o + + def visit(self, o, *args, **kwargs): + constants_map = kwargs.pop('constants_map', dict()) + return super().visit(o, *args, constants_map=constants_map, **kwargs) + + def visit_Conditional(self, o, **kwargs): + constants_map = kwargs.pop('constants_map', dict()) + constants_map_in = deepcopy(constants_map) + o._update(_constants_map=constants_map_in) + + new_condition = 
ConstantPropagationAnalysis.ConstPropMapper(self.parent.fold_floats)(o.condition, constants_map=constants_map, **kwargs) + body_constants_map = deepcopy(constants_map) + else_body_constants_map = deepcopy(constants_map) + new_body = self.visit(o.body, constants_map=body_constants_map, **kwargs) + new_else_body = self.visit(o.else_body, constants_map=else_body_constants_map, **kwargs) + + if self.parent._apply_transform: + o._update( + condition=new_condition, + body=new_body, + else_body=new_else_body + ) + + if isinstance(new_condition, LogicLiteral): + if new_condition.value: + constants_map.update(body_constants_map) + else: + constants_map.update(else_body_constants_map) + else: + for key in set(body_constants_map.keys()).union(else_body_constants_map): + if body_constants_map.get(key, None) == else_body_constants_map.get(key, None): + constants_map[key] = body_constants_map[key] + else: + constants_map.pop(key, None) + + return o + + def visit_Loop(self, o, **kwargs): + constants_map = kwargs.pop('constants_map', dict()) + constants_map_in = deepcopy(constants_map) + o._update(_constants_map=constants_map_in) + + constants_map.pop((o.variable.basename, ()), None) + + new_bounds = ConstantPropagationAnalysis.ConstPropMapper(self.parent.fold_floats)(o.bounds, constants_map=constants_map, **kwargs) + if self.parent._apply_transform: + o._update(bounds=new_bounds) + + if self.parent.unroll_loops: + temp_loop = o.clone() + temp_loop._update(bounds=new_bounds) + unrolled = LoopUnrollTransformer(warn_iterations_length=False).visit(temp_loop) + # If we cannot unroll, then we need to fall back to the no unroll analysis + if not isinstance(unrolled, Loop): + if self.parent._apply_transform: + o = self.visit(unrolled, constants_map=constants_map, **kwargs) + # TODO: _update each node in the new body with the const map + return o + + # TODO: could also be mutating subroutine, not just an assign + lhs_vars = {o.variable} + lhs_vars.update([l.variable for l in 
FindNodes(Loop).visit(o.body)]) + + new_body_constants_map = deepcopy(constants_map) + new_body = self.visit(o.body, constants_map=new_body_constants_map, **kwargs) + popped_keys = constants_map_in.keys() - new_body_constants_map.keys() + for key in popped_keys: + constants_map.pop(key, None) + + # Build a set of invariants + assignments = FindNodes(Assignment).visit(new_body) + for a in assignments: + lhs_vars.add(a.lhs) + + bounds_are_const = (is_constant(new_bounds.start) and is_constant(new_bounds.stop) and (is_constant(new_bounds.step) or new_bounds.step is None)) + bounds_has_steps = bounds_are_const and len(get_pyrange(LoopRange((new_bounds.start, new_bounds.stop, new_bounds.step)))) > 0 + # Then figure out which lhs are generated from only invariants + + # If all bounds are const (i.e. we can check we'll take the loop at least once), + if bounds_are_const: + # if we can guarantee we'll take the loop at least once + if bounds_has_steps: + consts_map = constants_map + # else do the transform, but don't update the consts map + else: + consts_map = deepcopy(constants_map) + + for a in assignments: + # if rhs of a has no lhs vars from loop body (i.e. consists solely of loop invariant var) + if len(set(FindVariables().visit(a.rhs)).intersection(lhs_vars)) == 0: + # Pass to visit_assignment + self.visit_Assignment(a, constants_map=consts_map, **kwargs) + + # If not all bounds are const (i.e. we don't know if we're taking the loop or not) + elif not bounds_are_const: + for a in assignments: + if isinstance(a.lhs, Array): + self._pop_array_accesses(a, constants_map=constants_map, **kwargs) + else: + constants_map.pop((a.lhs.basename, ()), None) + + if self.parent._apply_transform: + o._update(body=new_body) + + return o + + class _Detacher(Transformer): + """ + Remove in-place any dataflow analysis properties. 
+ """ + + def __init__(self, **kwargs): + super().__init__(inplace=True, invalidate_source=False, **kwargs) + + def visit_Node(self, o, **kwargs): + o._update(_constants_map=None) + return super().visit_Node(o, **kwargs) + + def __init__(self, fold_floats, unroll_loops, _apply_transform=False): + self.fold_floats = fold_floats + self.unroll_loops = unroll_loops + self._apply_transform = _apply_transform + super().__init__() + + def get_attacher(self): + return self._Attacher(self) + + def attach_dataflow_analysis(self, module_or_routine): + constants_map = self.generate_declarations_map(module_or_routine) + + # TODO: Implement + # if hasattr(module_or_routine, 'spec'): + # (self.get_attacher().visit(module_or_routine.spec, constants_map=constants_map)) + + if hasattr(module_or_routine, 'body'): + (self.get_attacher().visit(module_or_routine.body, constants_map=constants_map)) + + def generate_declarations_map(self, routine): + def index_initial_elements(i, e): + if len(i) == 1: + return e.elements[i[0].value - 1] + else: + return index_initial_elements(i[1:], e.elements[i[0].value - 1]) + + declarations_map = dict() + # TODO: What if there's a context already? 
+ with DataFlowAnalysis().dataflow_analysis_attached(routine): + for s in routine.symbols: + if isinstance(s, DeferredTypeSymbol) or s.initial is None: + continue + if isinstance(s, Array): + declarations_map.update({(s.basename, i): index_initial_elements(i, s.initial) for i in + self._array_indices_to_accesses( + [RangeIndex((None, None, None))] * len(s.shape), s.shape + )}) + else: + declarations_map[(s.basename, ())] = s.initial + return declarations_map + + @staticmethod + def _array_indices_to_accesses(dimensions, shape): + accesses = functools.partial(itertools.product) + for (count, dimension) in enumerate(dimensions): + if isinstance(dimension, RangeIndex): + start = dimension.start if dimension.start is not None else IntLiteral(1) + # TODO: shape[] might not be as nice as we want + stop = dimension.stop if dimension.stop is not None else shape[count] + accesses = functools.partial(accesses, [IntLiteral(v) for v in + get_pyrange(LoopRange((start, stop, dimension.step)))]) + else: + accesses = functools.partial(accesses, [dimension]) + + return accesses() + + @staticmethod + def _separate_literals(children): + separated = ([], []) + for c in children: + # is_constant only covers int, float, & complex + if isinstance(c, _Literal): + separated[0].append(c) + else: + separated[1].append(c) + return separated diff --git a/loki/analyse/data_flow_analysis.py b/loki/analyse/data_flow_analysis.py new file mode 100644 index 000000000..0495345e0 --- /dev/null +++ b/loki/analyse/data_flow_analysis.py @@ -0,0 +1,580 @@ +# (C) Copyright 2024- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
"""
Dataflow analysis over the Loki IR: attach live/defined/used symbol sets to
each node, plus helper visitors and convenience functions built on top.
"""

from loki import flatten, as_tuple, Transformer, Subroutine, CaseInsensitiveDict, FindInlineCalls, FindVariables, \
    BasicType, Array, FindTypedSymbols, ProcedureSymbol, Visitor
from loki.analyse.abstract_dfa import AbstractDataflowAnalysis

__all__ = [
    'DataFlowAnalysis', 'read_after_write_vars',
    'loop_carried_dependencies'
]

class DataFlowAnalysis(AbstractDataflowAnalysis):
    class _Attacher(Transformer):
        """
        Analyse and attach in-place the definition, use and live status of
        symbols.
        """

        # group of functions that only query memory properties and don't read/write variable value
        _mem_property_queries = ('size', 'lbound', 'ubound', 'present')

        def __init__(self, **kwargs):
            super().__init__(inplace=True, invalidate_source=False, **kwargs)

        # Utility routines

        def _visit_body(self, body, live=None, defines=None, uses=None, **kwargs):
            """
            Iterate through the tuple that is a body and update defines and
            uses along the way.
            """
            if live is None:
                live = set()
            if defines is None:
                defines = set()
            if uses is None:
                uses = set()
            visited = []
            for i in flatten(body):
                visited += [self.visit(i, live_symbols=live | defines, **kwargs)]
                # a symbol counts as "used" only if not defined earlier in this body
                uses |= visited[-1].uses_symbols.copy() - defines
                defines |= visited[-1].defines_symbols.copy()
            return as_tuple(visited), defines, uses

        @staticmethod
        def _symbols_from_lhs_expr(expr):
            """
            Determine symbol use and symbol definition from a left-hand side expression.

            Parameters
            ----------
            expr : :any:`Scalar` or :any:`Array`
                The left-hand side expression of an assignment.

            Returns
            -------
            (defines, uses) : (set, set)
                The sets of defined and used symbols (in that order).
            """
            defines = {expr.clone(dimensions=None)}
            uses = DataFlowAnalysis._symbols_from_expr(getattr(expr, 'dimensions', ()))
            return defines, uses

        # Abstract node (also called from every node type for integration)

        def visit_Node(self, o, **kwargs):
            # Live symbols are determined on InternalNode handler levels and
            # get passed down to all child nodes
            o._update(_live_symbols=kwargs.get('live_symbols', set()))

            # Symbols defined or used by this node are determined by their individual
            # handler routines and passed on to visitNode from there
            o._update(_defines_symbols=kwargs.get('defines_symbols', set()))
            o._update(_uses_symbols=kwargs.get('uses_symbols', set()))
            return o

        # Internal nodes

        def visit_Interface(self, o, **kwargs):
            # Subroutines/functions calls defined in an explicit interface
            defines = set()
            for b in o.body:
                if isinstance(b, Subroutine):
                    defines = defines | set(as_tuple(b.procedure_symbol))
            return self.visit_Node(o, defines_symbols=defines, **kwargs)

        def visit_InternalNode(self, o, **kwargs):
            # An internal node defines all symbols defined by its body and uses all
            # symbols used by its body before they are defined in the body
            live = kwargs.pop('live_symbols', set())
            body, defines, uses = self._visit_body(o.body, live=live, **kwargs)
            o._update(body=body)
            return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_Associate(self, o, **kwargs):
            # An associate block defines all symbols defined by its body and uses all
            # symbols used by its body before they are defined in the body
            live = kwargs.pop('live_symbols', set())
            body, defines, uses = self._visit_body(o.body, live=live, **kwargs)
            o._update(body=body)

            # reverse the mapping of names before assigning lives, defines, uses sets for Associate node itself
            invert_assoc = CaseInsensitiveDict({v.name: k for k, v in o.associations})
            _live = set(invert_assoc[v.name] if v.name in invert_assoc else v for v in live)
            _defines = set(invert_assoc[v.name] if v.name in invert_assoc else v for v in defines)
            _uses = set(invert_assoc[v.name] if v.name in invert_assoc else v for v in uses)

            return self.visit_Node(o, live_symbols=_live, defines_symbols=_defines, uses_symbols=_uses, **kwargs)

        def visit_Loop(self, o, **kwargs):
            # A loop defines the induction variable for its body before entering it
            live = kwargs.pop('live_symbols', set())
            uses = DataFlowAnalysis._symbols_from_expr(o.bounds)
            body, defines, uses = self._visit_body(o.body, live=live | {o.variable.clone()}, uses=uses, **kwargs)
            o._update(body=body)
            # Make sure the induction variable is not considered outside the loop
            uses.discard(o.variable)
            defines.discard(o.variable)
            return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_WhileLoop(self, o, **kwargs):
            # A while loop uses variables in its condition
            live = kwargs.pop('live_symbols', set())
            uses = DataFlowAnalysis._symbols_from_expr(o.condition)
            body, defines, uses = self._visit_body(o.body, live=live, uses=uses, **kwargs)
            o._update(body=body)
            return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_Conditional(self, o, **kwargs):
            live = kwargs.pop('live_symbols', set())

            # exclude arguments to functions that just check the memory attributes of a variable
            mem_call = as_tuple(
                i for i in FindInlineCalls().visit(o.condition) if i.function in self._mem_property_queries)
            query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_call))
            cset = set(v for v in FindVariables().visit(o.condition) if not v in query_args)

            condition = DataFlowAnalysis._symbols_from_expr(as_tuple(cset))
            body, defines, uses = self._visit_body(o.body, live=live, uses=condition, **kwargs)
            else_body, else_defines, uses = self._visit_body(o.else_body, live=live, uses=uses, **kwargs)
            o._update(body=body, else_body=else_body)
            return self.visit_Node(o, live_symbols=live, defines_symbols=defines | else_defines, uses_symbols=uses,
                                   **kwargs)

        def visit_MultiConditional(self, o, **kwargs):
            live = kwargs.pop('live_symbols', set())

            # exclude arguments to functions that just check the memory attributes of a variable
            mem_calls = as_tuple(i for i in FindInlineCalls().visit(o.expr) if i.function in self._mem_property_queries)
            query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_calls))
            eset = set(v for v in FindVariables().visit(o.expr) if not v in query_args)

            mem_calls = as_tuple(
                i for i in FindInlineCalls().visit(o.values) if i.function in self._mem_property_queries)
            query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_calls))
            vset = set(v for v in FindVariables().visit(o.values) if not v in query_args)

            uses = DataFlowAnalysis._symbols_from_expr(as_tuple(eset)) | DataFlowAnalysis._symbols_from_expr(as_tuple(vset))
            body = ()
            defines = set()
            for b in o.bodies:
                _b, _d, uses = self._visit_body(b, live=live, uses=uses, **kwargs)
                body += (as_tuple(_b),)
                defines |= _d
            else_body, else_defines, uses = self._visit_body(o.else_body, live=live, uses=uses, **kwargs)
            o._update(bodies=body, else_body=else_body)
            defines = defines | else_defines
            return self.visit_Node(o, live_symbols=live, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_MaskedStatement(self, o, **kwargs):
            live = kwargs.pop('live_symbols', set())
            conditions = DataFlowAnalysis._symbols_from_expr(o.conditions)

            body = ()
            defines = set()
            uses = set(conditions)
            for b in o.bodies:
                _b, defines, uses = self._visit_body(b, live=live, uses=uses, defines=defines, **kwargs)
                body += (_b,)

            default, default_defs, uses = self._visit_body(o.default, live=live, uses=uses, **kwargs)
            o._update(bodies=body, default=default)
            return self.visit_Node(o, live_symbols=live, defines_symbols=defines | default_defs, uses_symbols=uses,
                                   **kwargs)

        # Leaf nodes

        def visit_Assignment(self, o, **kwargs):
            # exclude arguments to functions that just check the memory attributes of a variable
            mem_calls = as_tuple(i for i in FindInlineCalls().visit(o.rhs) if i.function in self._mem_property_queries)
            query_args = as_tuple(flatten(FindVariables().visit(i.parameters) for i in mem_calls))
            rset = set(v for v in FindVariables().visit(o.rhs) if not v in query_args)

            # The left-hand side variable is defined by this statement
            defines, uses = self._symbols_from_lhs_expr(o.lhs)

            # Anything on the right-hand side is used before assigning to it
            uses |= DataFlowAnalysis._symbols_from_expr(as_tuple(rset))
            return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_ConditionalAssignment(self, o, **kwargs):
            # The left-hand side variable is defined by this statement
            defines, uses = self._symbols_from_lhs_expr(o.lhs)
            # Anything on the right-hand side is used before assigning to it
            uses |= DataFlowAnalysis._symbols_from_expr((o.condition, o.rhs, o.else_rhs))
            return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_CallStatement(self, o, **kwargs):
            if o.routine is not BasicType.DEFERRED:
                # With a call context provided we can determine which arguments
                # are potentially defined and which are definitely only used by
                # this call
                defines, uses = set(), set()
                outvals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'out')]
                invals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'in')]

                arrays = [v for v in FindVariables().visit(outvals) if isinstance(v, Array)]
                dims = set(v for a in arrays for v in DataFlowAnalysis._symbols_from_expr(a.dimensions))
                for val in outvals:
                    exprs = DataFlowAnalysis._symbols_from_expr(val)
                    # subscript symbols are read, not written, by the call
                    defines |= {e for e in exprs if not e in dims}
                    uses |= dims

                uses |= {s for val in invals for s in DataFlowAnalysis._symbols_from_expr(val)}
            else:
                # We don't know the intent of any of these arguments and thus have
                # to assume all of them are potentially used or defined by this
                # statement
                arrays = [v for v in FindVariables().visit(o.arguments) if isinstance(v, Array)]
                arrays += [v for arg, val in o.kwarguments for v in FindVariables().visit(val) if isinstance(v, Array)]

                dims = set(v for a in arrays for v in FindVariables().visit(a.dimensions))
                defines = DataFlowAnalysis._symbols_from_expr(o.arguments, condition=lambda x: x not in dims)
                for arg, val in o.kwarguments:
                    defines |= DataFlowAnalysis._symbols_from_expr(val, condition=lambda x: x not in dims)
                uses = defines.copy() | dims

            return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_Allocation(self, o, **kwargs):
            arrays = [v for v in FindVariables().visit(o.variables) if isinstance(v, Array)]
            dims = set(v for a in arrays for v in FindVariables().visit(a.dimensions))
            defines = DataFlowAnalysis._symbols_from_expr(o.variables, condition=lambda x: x not in dims)
            uses = DataFlowAnalysis._symbols_from_expr(o.data_source or ()) | dims
            return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

        def visit_Deallocation(self, o, **kwargs):
            defines = DataFlowAnalysis._symbols_from_expr(o.variables)
            return self.visit_Node(o, defines_symbols=defines, **kwargs)

        visit_Nullify = visit_Deallocation

        def visit_Import(self, o, **kwargs):
            defines = set(s.clone(dimensions=None) for s in FindTypedSymbols().visit(o.symbols or ())
                          if isinstance(s, ProcedureSymbol))
            return self.visit_Node(o, defines_symbols=defines, **kwargs)

        def visit_VariableDeclaration(self, o, **kwargs):
            defines = DataFlowAnalysis._symbols_from_expr(o.symbols, condition=lambda v: v.type.initial is not None)
            uses = {v for a in o.symbols if isinstance(a, Array)
                    for v in DataFlowAnalysis._symbols_from_expr(a.dimensions)}
            return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)

    class _Detacher(Transformer):
        """
        Remove in-place any dataflow analysis properties.
        """

        def __init__(self, **kwargs):
            super().__init__(inplace=True, invalidate_source=False, **kwargs)

        def visit_Node(self, o, **kwargs):
            o._update(_live_symbols=None, _defines_symbols=None, _uses_symbols=None)
            return super().visit_Node(o, **kwargs)

    @staticmethod
    def _symbols_from_expr(expr, condition=None):
        """
        Return set of symbols found in an expression.
        """
        if condition is not None:
            return {v.clone(dimensions=None) for v in FindVariables().visit(expr) if condition(v)}
        return {v.clone(dimensions=None) for v in FindVariables().visit(expr)}


    def attach_dataflow_analysis(self, module_or_routine):
        """
        Determine and attach to each IR node dataflow analysis metadata.

        This makes for each IR node the following properties available:

        * :attr:`Node.live_symbols`: symbols defined before the node;
        * :attr:`Node.defines_symbols`: symbols (potentially) defined by the
          node, i.e., live in subsequent nodes;
        * :attr:`Node.uses_symbols`: symbols used by the node (that had to be
          defined before).

        The IR nodes are updated in-place and thus existing references to IR
        nodes remain valid.
        """
        live_symbols = set()
        if hasattr(module_or_routine, 'arguments'):
            # intent(in)/intent(inout) arguments are live on entry
            live_symbols = self._symbols_from_expr(
                module_or_routine.arguments,
                condition=lambda a: a.type.intent and a.type.intent.lower() in ('in', 'inout')
            )

        if hasattr(module_or_routine, 'spec'):
            self.get_attacher().visit(module_or_routine.spec, live_symbols=live_symbols)
            live_symbols |= module_or_routine.spec.defines_symbols

        if hasattr(module_or_routine, 'body'):
            self.get_attacher().visit(module_or_routine.body, live_symbols=live_symbols)

    def dataflow_analysis_attached(self, module_or_routine):
        r"""
        Create a context in which information about defined, live and used symbols
        is attached to each IR node

        This makes for each IR node the following properties available:

        * :attr:`Node.live_symbols`: symbols defined before the node;
        * :attr:`Node.defines_symbols`: symbols (potentially) defined by the
          node;
        * :attr:`Node.uses_symbols`: symbols used by the node that had to be
          defined before.

        This is an in-place update of nodes and thus existing references to IR
        nodes remain valid. When leaving the context the information is removed
        from IR nodes, while existing references remain valid.

        The analysis is based on a rather crude regions-based analysis, with the
        hierarchy implied by (nested) :any:`InternalNode` IR nodes used as regions
        in the reducible flow graph (cf. Chapter 9, in particular 9.7 of Aho, Lam,
        Sethi, and Ullman (2007)). Our implementation shares some similarities
        with a full reaching definitions dataflow analysis but is not quite as
        powerful.

        In reaching definitions dataflow analysis (cf. Chapter 9.2.4 Aho et al.),
        the transfer function of a definition :math:`d` can be expressed as:

        .. math:: f_d(x) = \operatorname{gen}_d \cup (x - \operatorname{kill}_d)

        with the set of definitions generated :math:`\operatorname{gen}_d` and the
        set of definitions killed/invalidated :math:`\operatorname{kill}_d`.

        We, however, do not record definitions explicitly and instead operate on
        consolidated sets of defined symbols, i.e., effectively evaluate the
        chained transfer functions up to the node. This yields a set of active
        definitions at this node. The symbols defined by these definitions are
        in :any:`Node.live_symbols`, and the symbols defined by the node (i.e.,
        symbols defined by definitions in :math:`\operatorname{gen}_d`) are in
        :any:`Node.defines_symbols`.

        The advantage of this approach is that it avoids the need to introduce
        a layer for definitions and dependencies. A downside is that this focus
        on symbols instead of definitions precludes, in particular, the ability
        to take data space into account, which makes it less useful for arrays.

        .. note::
            The context manager operates only on the module or routine itself
            (i.e., its spec and, if applicable, body), not on any contained
            subroutines or functions.

        Parameters
        ----------
        module_or_routine : :any:`Module` or :any:`Subroutine`
            The object for which the IR is to be annotated.
        """
        return super().dataflow_analysis_attached(module_or_routine)

class FindReads(Visitor):
    """
    Look for reads in a specified part of a control flow tree.

    Parameters
    ----------
    start : (iterable of) :any:`Node`, optional
        Visitor is only active after encountering one of the nodes in
        :data:`start` and until encountering a node in :data:`stop`.
    stop : (iterable of) :any:`Node`, optional
        Visitor is no longer active after encountering one of the nodes in
        :data:`stop` until it encounters again a node in :data:`start`.
    active : bool, optional
        Set the visitor active right from the beginning.
    candidate_set : set of :any:`Node`, optional
        If given, only reads for symbols in this set are considered.
    clear_candidates_on_write : bool, optional
        If enabled, writes of a symbol remove it from the :data:`candidate_set`.
    """

    def __init__(self, start=None, stop=None, active=False,
                 candidate_set=None, clear_candidates_on_write=False, **kwargs):
        super().__init__(**kwargs)
        self.start = set(as_tuple(start))
        self.stop = set(as_tuple(stop))
        self.active = active
        self.candidate_set = candidate_set
        self.clear_candidates_on_write = clear_candidates_on_write
        self.reads = set()

    @staticmethod
    def _symbols_from_expr(expr):
        """
        Return set of symbols found in an expression.
        """
        return {v.clone(dimensions=None) for v in FindVariables().visit(expr)}

    def _register_reads(self, read_symbols):
        if self.active:
            if self.candidate_set is None:
                self.reads |= read_symbols
            else:
                self.reads |= read_symbols & self.candidate_set

    def _register_writes(self, write_symbols):
        if self.active and self.clear_candidates_on_write and self.candidate_set is not None:
            self.candidate_set -= write_symbols

    def visit(self, o, *args, **kwargs):
        # toggle activity when passing start/stop nodes
        self.active = (self.active and o not in self.stop) or o in self.start
        return super().visit(o, *args, **kwargs)

    def visit_object(self, o, **kwargs):  # pylint: disable=unused-argument
        pass

    def visit_LeafNode(self, o, **kwargs):  # pylint: disable=unused-argument
        self._register_reads(o.uses_symbols)
        self._register_writes(o.defines_symbols)

    def visit_Conditional(self, o, **kwargs):
        self._register_reads(self._symbols_from_expr(o.condition))
        # Visit each branch with the original candidate set and then take the
        # union of both afterwards to include all potential read-after-writes
        candidate_set = self.candidate_set.copy() if self.candidate_set is not None else None
        self.visit(o.body, **kwargs)
        self.candidate_set, candidate_set = candidate_set, self.candidate_set
        self.visit(o.else_body, **kwargs)
        if self.candidate_set is not None:
            self.candidate_set |= candidate_set

    def visit_Loop(self, o, **kwargs):
        self._register_reads(self._symbols_from_expr(o.bounds))
        active = self.active
        if self.active and self.candidate_set is not None:
            # remove the loop variable as a variable of interest
            self.candidate_set.discard(o.variable)
        self.visit(o.children, **kwargs)
        if active:
            self.reads.discard(o.variable)

    def visit_WhileLoop(self, o, **kwargs):
        self._register_reads(self._symbols_from_expr(o.condition))
        self.visit(o.children, **kwargs)


class FindWrites(Visitor):
    """
    Look for writes in a specified part of a control flow tree.

    Parameters
    ----------
    start : (iterable of) :any:`Node`, optional
        Visitor is only active after encountering one of the nodes in
        :data:`start` and until encountering a node in :data:`stop`.
    stop : (iterable of) :any:`Node`, optional
        Visitor is no longer active after encountering one of the nodes in
        :data:`stop` until it encounters again a node in :data:`start`.
    active : bool, optional
        Set the visitor active right from the beginning.
    candidate_set : set of :any:`Node`, optional
        If given, only writes for symbols in this set are considered.
    """

    def __init__(self, start=None, stop=None, active=False,
                 candidate_set=None, **kwargs):
        super().__init__(**kwargs)
        self.start = set(as_tuple(start))
        self.stop = set(as_tuple(stop))
        self.active = active
        self.candidate_set = candidate_set
        self.writes = set()

    @staticmethod
    def _symbols_from_expr(expr):
        """
        Return set of symbols found in an expression.
        """
        return {v.clone(dimensions=None) for v in FindVariables().visit(expr)}

    def _register_writes(self, write_symbols):
        if self.candidate_set is None:
            self.writes |= write_symbols
        else:
            self.writes |= write_symbols & self.candidate_set

    def visit(self, o, *args, **kwargs):
        # toggle activity when passing start/stop nodes
        self.active = (self.active and o not in self.stop) or o in self.start
        return super().visit(o, *args, **kwargs)

    def visit_object(self, o, **kwargs):  # pylint: disable=unused-argument
        pass

    def visit_LeafNode(self, o, **kwargs):  # pylint: disable=unused-argument
        if self.active:
            self._register_writes(o.defines_symbols)

    def visit_Loop(self, o, **kwargs):
        if self.active:
            # remove the loop variable as a variable of interest
            if self.candidate_set is not None:
                self.candidate_set.discard(o.variable)
            self.writes.discard(o.variable)
        super().visit_Node(o, **kwargs)


def read_after_write_vars(ir, inspection_node):
    """
    Find variables that are read after being written in the given IR.

    This requires prior application of :meth:`dataflow_analysis_attached` to
    the corresponding :any:`Module` or :any:`Subroutine`.

    The result is the set of variables with a data dependency across the
    :data:`inspection_node`.

    See the remarks about implementation and limitations in the description of
    :meth:`dataflow_analysis_attached`. In particular, this does not take into
    account data space and iteration space for arrays.

    Parameters
    ----------
    ir : :any:`Node`
        The root of the control flow (sub-)tree to inspect.
    inspection_node : :any:`Node`
        Only variables with a write before and a read at or after this node
        are considered.

    Returns
    -------
    :any:`set` of :any:`Scalar` or :any:`Array`
        The list of read-after-write variables.
    """
    write_visitor = FindWrites(stop=inspection_node, active=True)
    write_visitor.visit(ir)
    read_visitor = FindReads(start=inspection_node, candidate_set=write_visitor.writes,
                             clear_candidates_on_write=True)
    read_visitor.visit(ir)
    return read_visitor.reads


def loop_carried_dependencies(loop):
    """
    Find variables that are potentially loop-carried dependencies.

    This requires prior application of :meth:`dataflow_analysis_attached` to
    the corresponding :any:`Module` or :any:`Subroutine`.

    See the remarks about implementation and limitations in the description of
    :meth:`dataflow_analysis_attached`. In particular, this does not take into
    account data space and iteration space for arrays. For cases with a
    linear mapping from iteration to data space and no overlap, this will
    falsely report loop-carried dependencies when there are in fact none.
    However, the risk of false negatives should be low.

    Parameters
    ----------
    loop : :any:`Loop`
        The loop node to inspect.

    Returns
    -------
    :any:`set` of :any:`Scalar` or :any:`Array`
        The list of variables that potentially have a loop-carried dependency.
    """
    return loop.uses_symbols & loop.defines_symbols
read_after_write_vars(routine.body, pragma) == vars_at_inspection_node[pragma.content] @@ -237,7 +237,7 @@ def test_loop_carried_dependencies(frontend): loops = FindNodes(Loop).visit(routine.body) assert len(loops) == 1 - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert loop_carried_dependencies(loops[0]) == {variable_map['b'], variable_map['c']} @pytest.mark.parametrize('frontend', available_frontends()) @@ -273,7 +273,7 @@ def test_analyse_interface(frontend): source = Sourcefile.from_source(fcode, frontend=frontend) routine = source['test'] - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert len(routine.body.defines_symbols) == 0 assert len(routine.body.uses_symbols) == 0 assert len(routine.spec.uses_symbols) == 0 @@ -311,7 +311,7 @@ def test_analyse_imports(frontend, tmp_path): module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path]) routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module, xmods=[tmp_path]) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert len(routine.spec.defines_symbols) == 1 assert 'random_call' in routine.spec.defines_symbols @@ -346,7 +346,7 @@ def test_analyse_enriched_call(frontend): routine.enrich(source.all_subroutines) call = FindNodes(CallStatement).visit(routine.body)[0] - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert all(i in call.defines_symbols for i in ('v_out', 'v_inout')) assert all(i in call.uses_symbols for i in ('v_in', 'v_inout')) @@ -370,7 +370,7 @@ def test_analyse_unenriched_call(frontend): routine = source['test'] call = FindNodes(CallStatement).visit(routine.body)[0] - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert all(i in call.defines_symbols for i in 
('v_out', 'v_inout', 'v_in')) assert all(i in call.uses_symbols for i in ('v_in', 'v_inout', 'v_in')) @@ -394,7 +394,7 @@ def test_analyse_allocate_statement(frontend): """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert all(i not in routine.body.defines_symbols for i in ['m', 'n']) assert all(i in routine.body.uses_symbols for i in ['m', 'n']) assert 'a' in routine.body.defines_symbols @@ -417,7 +417,7 @@ def test_analyse_import_kind(frontend): """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert 'real64' in routine.body.uses_symbols assert 'real64' not in routine.body.defines_symbols assert 'a' in routine.body.defines_symbols @@ -446,7 +446,7 @@ def test_analyse_query_memory_attributes(frontend): """.strip() routine = Subroutine.from_source(fcode, frontend=frontend) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert not 'a' in routine.body.uses_symbols assert 'a' in routine.body.defines_symbols assert not 'b' in routine.body.uses_symbols @@ -483,7 +483,7 @@ def test_analyse_call_args_array_slicing(frontend): calls = FindNodes(CallStatement).visit(routine.body) routine.enrich(source.all_subroutines) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert 'n' in calls[0].uses_symbols assert not 'n' in calls[0].defines_symbols assert 'b' in calls[1].uses_symbols @@ -510,7 +510,7 @@ def test_analyse_multiconditional(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) mcond = FindNodes(MultiConditional).visit(routine.body)[0] - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert len(mcond.bodies) == 2 assert 
len(mcond.else_body) == 1 for b in mcond.bodies: @@ -549,7 +549,7 @@ def test_analyse_maskedstatement(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) mask = FindNodes(MaskedStatement).visit(routine.body)[0] num_bodies = len(mask.bodies) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert len(mask.uses_symbols) == 1 assert len(mask.defines_symbols) == 1 assert 'mask' in mask.uses_symbols @@ -582,7 +582,7 @@ def test_analyse_whileloop(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) loop = FindNodes(WhileLoop).visit(routine.body)[0] cond = FindNodes(Conditional).visit(routine.body)[0] - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): assert len(cond.uses_symbols) == 1 assert 'flag' in cond.uses_symbols assert len(loop.uses_symbols) == 1 @@ -590,7 +590,7 @@ def test_analyse_whileloop(frontend): assert 'ij' in loop.uses_symbols assert all(v in loop.defines_symbols for v in ('ij', 'a')) - with dataflow_analysis_attached(cond): + with DataFlowAnalysis().dataflow_analysis_attached(cond): assert len(loop.uses_symbols) == 1 assert len(loop.defines_symbols) == 2 assert 'ij' in loop.uses_symbols @@ -621,7 +621,7 @@ def test_analyse_associate(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) associates = FindNodes(Associate).visit(routine.body) assigns = FindNodes(Assignment).visit(routine.body) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): # check that associates use variables names in outer scope assert associates[0].uses_symbols == {'in_var'} assert associates[0].defines_symbols == {'a', 'b', 'c'} diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 448eeb493..2120d285f 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -187,6 +187,12 @@ def ir_graph(self, show_comments=False, show_expressions=False, linewidth=40, sy return 
ir_graph(self, show_comments, show_expressions,linewidth, symgen) + @property + def constants_map(self): + if self.__dict__['_constants_map'] is None: + raise RuntimeError('Need to run constant propagation analysis on the IR first.') + return self.__dict__['_constants_map'] + @property def live_symbols(self): """ @@ -195,9 +201,9 @@ def live_symbols(self): graph. This property is attached to the Node by - :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or + :py:func:`loki.analyse.LiveVariableAnalysis.attach_dataflow_analysis` or when using the - :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` + :py:func:`loki.analyse.LiveVariableAnalysis.dataflow_analysis_attached` context manager. """ if self.__dict__['_live_symbols'] is None: @@ -210,9 +216,9 @@ def defines_symbols(self): Yield the list of symbols (potentially) defined by this node. This property is attached to the Node by - :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or + :py:func:`loki.analyse.LiveVariableAnalysis.attach_dataflow_analysis` or when using the - :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` + :py:func:`loki.analyse.LiveVariableAnalysis.dataflow_analysis_attached` context manager. """ if self.__dict__['_defines_symbols'] is None: @@ -226,9 +232,9 @@ def uses_symbols(self): Yield the list of symbols used by this node before defining it. This property is attached to the Node by - :py:func:`loki.analyse.analyse_dataflow.attach_dataflow_analysis` or + :py:func:`loki.analyse.LiveVariableAnalysis.attach_dataflow_analysis` or when using the - :py:func:`loki.analyse.analyse_dataflow.dataflow_analysis_attached` + :py:func:`loki.analyse.LiveVariableAnalysis.dataflow_analysis_attached` context manager. 
""" if self.__dict__['_uses_symbols'] is None: diff --git a/loki/ir/tests/test_ir_graph.py b/loki/ir/tests/test_ir_graph.py index b44bd4a15..ee9e8fd36 100644 --- a/loki/ir/tests/test_ir_graph.py +++ b/loki/ir/tests/test_ir_graph.py @@ -10,7 +10,7 @@ import pytest from loki import Sourcefile, graphviz_present -from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.ir import Node, FindNodes, ir_graph, GraphCollector @@ -324,7 +324,7 @@ def test_ir_graph_writes_correct_graphs(testdir, test_file, tmp_path): @pytest.mark.parametrize("test_file", test_files) -def test_ir_graph_dataflow_analysis_attached(testdir, test_file, tmp_path): +def test_ir_graph_live_variable_analysis_attached(testdir, test_file, tmp_path): source = Sourcefile.from_file(testdir / test_file, xmods=[tmp_path]) def find_lives_defines_uses(text): @@ -349,7 +349,7 @@ def apply_conversion(text): ) for routine in source.all_subroutines: - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): for node in FindNodes(Node).visit(routine.body): node_info, _ = GraphCollector(show_comments=True).visit(node)[0] lives, defines, uses = find_lives_defines_uses(node_info["label"]) diff --git a/loki/transformations/__init__.py b/loki/transformations/__init__.py index 7d6fe4176..9b32975b2 100644 --- a/loki/transformations/__init__.py +++ b/loki/transformations/__init__.py @@ -16,6 +16,7 @@ from loki.transformations.build_system import * # noqa from loki.transformations.argument_shape import * # noqa from loki.transformations.data_offload import * # noqa +from loki.transformations.constant_propagation import * # noqa from loki.transformations.drhook import * # noqa from loki.transformations.extract import * # noqa from loki.transformations.field_api import * # noqa diff --git a/loki/transformations/array_indexing.py b/loki/transformations/array_indexing.py index b01e666a0..b9c05cbf6 100644 --- 
a/loki/transformations/array_indexing.py +++ b/loki/transformations/array_indexing.py @@ -15,7 +15,7 @@ from loki.batch import Transformation, ProcedureItem from loki.logging import info -from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.expression import symbols as sym, simplify, symbolic_op, is_constant from loki.ir import ( nodes as ir, Assignment, Loop, VariableDeclaration, FindNodes, @@ -305,7 +305,7 @@ def promote_variables(routine, variable_names, pos, index=None, size=None): # Create a copy of the tree and apply promotion in-place routine.body = Transformer().visit(routine.body) - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): for node, var_list in FindVariables(unique=False, with_ir_node=True).visit(routine.body): # All the variables marked for promotion that appear in this IR node var_list = [v for v in var_list if v.name.lower() in variable_names] diff --git a/loki/transformations/constant_propagation.py b/loki/transformations/constant_propagation.py new file mode 100644 index 000000000..e8721c3df --- /dev/null +++ b/loki/transformations/constant_propagation.py @@ -0,0 +1,41 @@ +# (C) Copyright 2024- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +from loki.analyse.constant_propagation_analysis import ConstantPropagationAnalysis +from loki import Transformer, Subroutine + +__all__ = ['ConstantPropagationTransformer'] + +class ConstantPropagationTransformer(Transformer): + + def __init__(self, fold_floats=True, unroll_loops=True): + self.fold_floats = fold_floats + self.unroll_loops = unroll_loops + super().__init__() + + def visit(self, expr, *args, **kwargs): + const_prop = ConstantPropagationAnalysis(self.fold_floats, self.unroll_loops, True) + constants_map = kwargs.get('constants_map', dict()) + try: + declarations_map = const_prop.generate_declarations_map(expr) + # If a user specifies their own map, they probably want it to override these + declarations_map.update(constants_map) + constants_map = declarations_map + except AttributeError: + pass + + is_routine = isinstance(expr, Subroutine) + target = expr.body if is_routine else expr + + target = const_prop.get_attacher().visit(target, constants_map=constants_map) + target = const_prop.get_detacher().visit(target) + + if is_routine: + expr.body = target + return expr + + return target \ No newline at end of file diff --git a/loki/transformations/data_offload/field_offload.py b/loki/transformations/data_offload/field_offload.py index b895de272..d5f400e9b 100644 --- a/loki/transformations/data_offload/field_offload.py +++ b/loki/transformations/data_offload/field_offload.py @@ -5,7 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.batch import Transformation from loki.expression import Array, symbols as sym from loki.ir import ( @@ -72,7 +72,7 @@ def process_driver(self, driver): remove_field_api_view_updates(driver, self.field_group_types) with pragma_regions_attached(driver): - with dataflow_analysis_attached(driver): + with DataFlowAnalysis().dataflow_analysis_attached(driver): for region in FindNodes(ir.PragmaRegion).visit(driver.body): # Only work on active `!$loki data` regions if not region.pragma or not is_loki_pragma(region.pragma, starts_with='data'): diff --git a/loki/transformations/data_offload/global_var.py b/loki/transformations/data_offload/global_var.py index 2bd4b6de8..428cbd00f 100644 --- a/loki/transformations/data_offload/global_var.py +++ b/loki/transformations/data_offload/global_var.py @@ -8,7 +8,7 @@ from collections import defaultdict from itertools import chain -from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.batch import Transformation, ProcedureItem, ModuleItem from loki.expression import Scalar, Array from loki.ir import ( @@ -113,7 +113,7 @@ def transform_subroutine(self, routine, **kwargs): import_map.update(scope.import_map) scope = scope.parent - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): # Gather read and written symbols that have been imported uses_imported_symbols = { var for var in routine.body.uses_symbols diff --git a/loki/transformations/extract/outline.py b/loki/transformations/extract/outline.py index bd97d2155..6889963e1 100644 --- a/loki/transformations/extract/outline.py +++ b/loki/transformations/extract/outline.py @@ -5,7 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.expression import symbols as sym, Variable from loki.ir import ( CallStatement, PragmaRegion, Section, FindNodes, @@ -190,7 +190,7 @@ def outline_pragma_regions(routine): parent_vmap = routine.variable_map mapper = {} with pragma_regions_attached(routine): - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): for region in FindNodes(PragmaRegion).visit(routine.body): if not is_loki_pragma(region.pragma, starts_with='outline'): continue diff --git a/loki/transformations/parallel/openmp_region.py b/loki/transformations/parallel/openmp_region.py index d3a50f996..c3b99d7d3 100644 --- a/loki/transformations/parallel/openmp_region.py +++ b/loki/transformations/parallel/openmp_region.py @@ -9,7 +9,7 @@ Sub-package with utilities to remove and manipulate parallel OpenMP regions. """ -from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.expression import symbols as sym, parse_expr from loki.ir import ( nodes as ir, FindNodes, FindVariables, Transformer, @@ -132,7 +132,7 @@ def add_openmp_regions( ) with pragma_regions_attached(routine): - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): for region in FindNodes(ir.PragmaRegion).visit(routine.body): if not is_loki_pragma(region.pragma, starts_with='parallel'): return diff --git a/loki/transformations/pool_allocator.py b/loki/transformations/pool_allocator.py index f5e2d3208..e92defd37 100644 --- a/loki/transformations/pool_allocator.py +++ b/loki/transformations/pool_allocator.py @@ -9,7 +9,7 @@ from collections import defaultdict from loki.batch import Transformation -from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.expression import ( Quotient, IntLiteral, LogicLiteral, Variable, Array, Sum, Literal, 
Product, InlineCall, Comparison, RangeIndex, Cast, @@ -675,7 +675,7 @@ def apply_pool_allocator_to_temporaries(self, routine, item=None): ] # Filter out unused vars - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): temporary_arrays = [ var for var in temporary_arrays if var.name.lower() in routine.body.defines_symbols diff --git a/loki/transformations/raw_stack_allocator.py b/loki/transformations/raw_stack_allocator.py index e4e68067b..50374ad56 100644 --- a/loki/transformations/raw_stack_allocator.py +++ b/loki/transformations/raw_stack_allocator.py @@ -7,7 +7,7 @@ import re -from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.backend.fgen import fgen from loki.batch.item import ProcedureItem from loki.batch.transformation import Transformation @@ -498,7 +498,7 @@ def _filter_temporary_arrays(self, routine): ] # Filter out unused vars - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): temporary_arrays = [ var for var in temporary_arrays if var.name.lower() in routine.body.defines_symbols diff --git a/loki/transformations/single_column/vector.py b/loki/transformations/single_column/vector.py index 0ad4dcc01..bf9304ecc 100644 --- a/loki/transformations/single_column/vector.py +++ b/loki/transformations/single_column/vector.py @@ -9,7 +9,7 @@ from more_itertools import split_at -from loki.analyse import dataflow_analysis_attached +from loki.analyse import DataFlowAnalysis from loki.batch import Transformation from loki.expression import symbols as sym, is_dimension_constant from loki.ir import ( @@ -162,7 +162,7 @@ def get_trimmed_sections(cls, routine, horizontal, sections): """ trimmed_sections = () - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): for sec in sections: vec_nodes = [node for node in sec if horizontal.index.lower() in 
node.uses_symbols] start = sec.index(vec_nodes[0]) diff --git a/loki/transformations/tests/test_constant_propagation.py b/loki/transformations/tests/test_constant_propagation.py new file mode 100644 index 000000000..66fbb86ec --- /dev/null +++ b/loki/transformations/tests/test_constant_propagation.py @@ -0,0 +1,1154 @@ +# (C) Copyright 2024- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from loki import ( + Subroutine, FindNodes, Loop, Assignment, Conditional +) +from loki.build import jit_compile +from loki.frontend import available_frontends + +from loki.transformations.constant_propagation import ConstantPropagationTransformer + + +# Basic Types +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_literals(tmp_path, frontend): + fcode = """ +subroutine const_prop_literals + integer :: a, a1 + real :: b, b1 + character (len = 3) :: c, c1 + logical :: d, d1 + + a1 = 1 + a = a1 + + b1 = 1.5 + b = b1 + + c1 = "foo" + c = c1 + + d1 = .true. 
+ d = d1 + +end subroutine const_prop_literals +""" + + routine = Subroutine.from_source(fcode, frontend=frontend) + + filepath = tmp_path/(f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Proof that it compiles & runs, although no runtime testing here + function() + + # Apply transformation + routine.body = ConstantPropagationTransformer().visit(routine.body) + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert 'Assignment:: a = 1' in assignments + assert 'Assignment:: b = 1.5' in assignments + assert 'Assignment:: c = \'foo\'' in assignments + assert 'Assignment:: d = True' in assignments + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_ops_int(tmp_path, frontend): + fcode = """ +subroutine const_prop_ops_int(a_add, a_sub, a_mul, a_pow, a_div, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt) + integer :: a = {a_val} + integer :: b = {b_val} + integer, intent(out) :: a_add, a_sub, a_mul, a_pow, a_div + logical, intent(out) :: a_lt, a_leq, a_eq, a_neq, a_geq, a_gt + + a_add = a + b + a_sub = a - b + a_mul = a * b + a_pow = a ** b + a_div = a / b + a_lt = a < b + a_leq = a <= b + a_eq = a == b + a_neq = a /= b + a_geq = a >= b + a_gt = a > b + +end subroutine const_prop_ops_int +""" + + a_val = 1 + b_val = 2 + routine = Subroutine.from_source(fcode.format(a_val=a_val, b_val=b_val), frontend=frontend) + + filepath = tmp_path/(f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + a_add, a_sub, a_mul, a_pow, a_div, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt = function() + + assert a_add == a_val + b_val + assert a_sub == a_val - b_val + assert a_mul == a_val * b_val + assert a_pow == a_val ** b_val + # Fortran uses integer division by default + assert a_div == a_val // b_val + assert a_lt == a_val < b_val + assert a_leq == a_val <= b_val + 
assert a_eq == (a_val == b_val) + assert a_neq == (a_val != b_val) + assert a_geq == (a_val >= b_val) + assert a_gt == (a_val > b_val) + + assert len(FindNodes(Assignment).visit(routine.body)) == 11 + + # Apply transformation + routine = ConstantPropagationTransformer().visit(routine) + assert len(FindNodes(Assignment).visit(routine.body)) == 11 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: a_add = {a_val + b_val}' in assignments + assert f'Assignment:: a_sub = {a_val - b_val}' in assignments + assert f'Assignment:: a_mul = {a_val * b_val}' in assignments + assert f'Assignment:: a_pow = {a_val ** b_val}' in assignments + assert f'Assignment:: a_div = {a_val // b_val}' in assignments + assert f'Assignment:: a_lt = {a_val < b_val}' in assignments + assert f'Assignment:: a_leq = {a_val <= b_val}' in assignments + assert f'Assignment:: a_eq = {a_val == b_val}' in assignments + assert f'Assignment:: a_neq = {a_val != b_val}' in assignments + assert f'Assignment:: a_geq = {a_val >= b_val}' in assignments + assert f'Assignment:: a_gt = {a_val > b_val}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + a_add, a_sub, a_mul, a_pow, a_div, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt = new_function() + + assert a_add == a_val + b_val + assert a_sub == a_val - b_val + assert a_mul == a_val * b_val + assert a_pow == a_val ** b_val + # Fortran uses integer division by default + assert a_div == a_val // b_val + assert a_lt == a_val < b_val + assert a_leq == a_val <= b_val + assert a_eq == (a_val == b_val) + assert a_neq == (a_val != b_val) + assert a_geq == (a_val >= b_val) + assert a_gt == (a_val > b_val) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_ops_float(tmp_path, frontend): + fcode = """ +subroutine 
const_prop_ops_float(a_add, a_sub, a_mul, a_pow, a_div, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt) + real :: a = {a_val} + real :: b = {b_val} + real, intent(out) :: a_add, a_sub, a_mul, a_pow, a_div + logical, intent(out) :: a_lt, a_leq, a_eq, a_neq, a_geq, a_gt + + a_add = a + b + a_sub = a - b + a_mul = a * b + a_pow = a ** b + a_div = a / b + a_lt = a < b + a_leq = a <= b + a_eq = a == b + a_neq = a /= b + a_geq = a >= b + a_gt = a > b + +end subroutine const_prop_ops_float +""" + + a_val = 1.5 + b_val = 2.5 + routine = Subroutine.from_source(fcode.format(a_val=a_val, b_val=b_val), frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + + a_add, a_sub, a_mul, a_pow, a_div, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt = function() + + assert a_add == a_val + b_val + assert a_sub == a_val - b_val + assert a_mul == a_val * b_val + assert a_pow - a_val ** b_val < 1e-6 + assert a_div - a_val / b_val < 1e-6 + assert bool(a_lt) == (a_val < b_val) + assert bool(a_leq) == (a_val <= b_val) + assert bool(a_eq) == (a_val == b_val) + assert bool(a_neq) == (a_val != b_val) + assert bool(a_geq) == (a_val >= b_val) + assert bool(a_gt) == (a_val > b_val) + + assert len(FindNodes(Assignment).visit(routine.body)) == 11 + + # Apply transformation + body = ConstantPropagationTransformer().visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 11 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: a_add = {a_val + b_val}' in assignments + assert f'Assignment:: a_sub = {a_val - b_val}' in assignments + assert f'Assignment:: a_mul = {a_val * b_val}' in assignments + assert f'Assignment:: a_pow = {a_val ** b_val}' in assignments + assert f'Assignment:: a_div = {a_val / b_val}' in assignments + assert f'Assignment:: a_lt = {a_val < b_val}' in assignments + assert f'Assignment:: a_leq = {a_val <= 
b_val}' in assignments + assert f'Assignment:: a_eq = {a_val == b_val}' in assignments + assert f'Assignment:: a_neq = {a_val != b_val}' in assignments + assert f'Assignment:: a_geq = {a_val >= b_val}' in assignments + assert f'Assignment:: a_gt = {a_val > b_val}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + a_add, a_sub, a_mul, a_pow, a_div, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt = new_function() + + assert a_add == a_val + b_val + assert a_sub == a_val - b_val + assert a_mul == a_val * b_val + assert a_pow - a_val ** b_val < 1e-6 + assert a_div - a_val / b_val < 1e-6 + assert bool(a_lt) == (a_val < b_val) + assert bool(a_leq) == (a_val <= b_val) + assert bool(a_eq) == (a_val == b_val) + assert bool(a_neq) == (a_val != b_val) + assert bool(a_geq) == (a_val >= b_val) + assert bool(a_gt) == (a_val > b_val) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_ops_string(tmp_path, frontend): + fcode = """ +subroutine const_prop_ops_string(a_concat, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt) + character (len = {a_len}) :: a = '{a_val}' + character (len = {b_len}) :: b = '{b_val}' + character (len = {concat_len}), intent(out) :: a_concat + logical, intent(out) :: a_lt, a_leq, a_eq, a_neq, a_geq, a_gt + + a_concat = a // b + a_lt = a < b + a_leq = a <= b + a_eq = a == b + a_neq = a /= b + a_geq = a >= b + a_gt = a > b + +end subroutine const_prop_ops_string +""" + + a_val = 'foo' + b_val = 'bar' + routine = Subroutine.from_source(fcode.format( + a_val=a_val, a_len=len(a_val), b_val=b_val, b_len=len(b_val), concat_len=len(a_val)+len(b_val)), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + a_concat, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt 
= function() + + assert a_concat.decode('UTF-8') == a_val + b_val + assert bool(a_lt) == (a_val < b_val) + assert bool(a_leq) == (a_val <= b_val) + assert bool(a_eq) == (a_val == b_val) + assert bool(a_neq) == (a_val != b_val) + assert bool(a_geq) == (a_val >= b_val) + assert bool(a_gt) == (a_val > b_val) + + assert len(FindNodes(Assignment).visit(routine.body)) == 7 + + # Apply transformation + routine = ConstantPropagationTransformer().visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 7 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: a_concat = \'{a_val + b_val}\'' in assignments + assert f'Assignment:: a_lt = {a_val < b_val}' in assignments + assert f'Assignment:: a_leq = {a_val <= b_val}' in assignments + assert f'Assignment:: a_eq = {a_val == b_val}' in assignments + assert f'Assignment:: a_neq = {a_val != b_val}' in assignments + assert f'Assignment:: a_geq = {a_val >= b_val}' in assignments + assert f'Assignment:: a_gt = {a_val > b_val}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + a_concat, a_lt, a_leq, a_eq, a_neq, a_geq, a_gt = new_function() + + assert a_concat.decode('UTF-8') == a_val + b_val + assert bool(a_lt) == (a_val < b_val) + assert bool(a_leq) == (a_val <= b_val) + assert bool(a_eq) == (a_val == b_val) + assert bool(a_neq) == (a_val != b_val) + assert bool(a_geq) == (a_val >= b_val) + assert bool(a_gt) == (a_val > b_val) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_ops_bool(tmp_path, frontend): + fcode = """ +subroutine const_prop_ops_bool(a_and, a_or, a_not, a_eqv, a_neqv) + logical :: a = {a_val} + logical :: b = {b_val} + logical, intent(out) :: a_and, a_or, a_not, a_eqv, a_neqv + + a_and = a .and. b + a_or = a .or. b + a_not = .not. a + a_eqv = a .eqv. 
b + a_neqv = a .neqv. b + +end subroutine const_prop_ops_bool +""" + + a = True + b = False + a_val = '.True.' if a else '.False.' + b_val = '.True.' if b else '.False.' + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + print(routine.to_fortran()) + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + a_and, a_or, a_not, a_eqv, a_neqv = function() + + assert bool(a_and) == (a and b) + assert bool(a_or) == (a or b) + assert bool(a_not) == (not a) + assert bool(a_eqv) == (a == b) + assert bool(a_neqv) == (a != b) + + assert len(FindNodes(Assignment).visit(routine.body)) == 5 + + # Apply transformation + routine = ConstantPropagationTransformer().visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 5 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: a_and = {a and b}' in assignments + assert f'Assignment:: a_or = {a or b}' in assignments + assert f'Assignment:: a_not = {not a}' in assignments + assert f'Assignment:: a_eqv = {a == b}' in assignments + assert f'Assignment:: a_neqv = {a != b}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + a_and, a_or, a_not, a_eqv, a_neqv = new_function() + + assert bool(a_and) == (a and b) + assert bool(a_or) == (a or b) + assert bool(a_not) == (not a) + assert bool(a_eqv) == (a == b) + assert bool(a_neqv) == (a != b) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_ops_bool_short_circuiting(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_ops_bool_short_circuiting(a_and, a_or) + logical :: a = {a_val} + logical :: b + logical, intent(out) :: a_and, a_or + + integer 
:: n + integer :: i + real :: r + integer, allocatable :: seed(:) + + call random_seed(size = n) + allocate(seed(n)) + seed(:) = 1 + call random_seed(put=seed) + call random_number(r) + + ! floor(r) will be 0, but this is only known at runtime. Statically, it is unknown + b = floor(r) == 0 + + a_and = .not. a .and. b + a_or = a .or. b + +end subroutine test_transform_region_const_prop_ops_bool_short_circuiting +""" + + a = True + a_val = '.True.' if a else '.False.' + routine = Subroutine.from_source(fcode.format( + a_val=a_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + a_and, a_or = function() + + assert (a_and == 1) == (a and False) + assert (a_or == 1) == (a or True) + + # Apply transformation + routine = ConstantPropagationTransformer().visit(routine) + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert len(assignments) == 4 + assert f'Assignment:: a_and = False' in assignments + assert f'Assignment:: a_or = True' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + a_and, a_or = new_function() + + assert (a_and == 1) == (a and False) + assert (a_or == 1) == (a or True) + + +# For loops +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_basic(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_for_loop_basic(c) + integer :: a = {a_val} + integer :: b = {b_val} + integer :: i + integer, intent(out) :: c + + c = 0 + do i = 1, a + c = c + b + end do + +end subroutine test_transform_region_const_prop_for_loop_basic +""" + + a_val = 5 + b_val = 3 + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + filepath = 
tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == a_val * b_val + + # Apply transformation + routine = ConstantPropagationTransformer().visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == a_val + 1 + assert len(FindNodes(Loop).visit(routine.body)) == 0 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + for i in range(1, a_val+1): + assert f'Assignment:: c = {b_val*i}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == a_val * b_val + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_basic_no_unroll(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_for_loop_basic_no_unroll(c) + integer :: a = {a_val} + integer :: b = {b_val} + integer :: i, d + integer, intent(out) :: c + + c = 0 + d = 0 + do i = 1, a + c = a * b + d = a * i + end do + c = c * 2 + d = d * 2 + +end subroutine test_transform_region_const_prop_for_loop_basic_no_unroll +""" + + a_val = 5 + b_val = 3 + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == a_val * b_val * 2 + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 6 + assert len(FindNodes(Loop).visit(routine.body)) == 1 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: c = {a_val * 
b_val}' in assignments + assert f'Assignment:: d = {a_val}*i' in assignments + + assert f'Assignment:: c = {a_val * b_val * 2}' in assignments + assert f'Assignment:: d = d*2' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == a_val * b_val * 2 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_neg_range_no_unroll(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_for_loop_neg_range_no_unroll(c) + integer :: a = {a_val} + integer :: b = {b_val} + integer :: i, d + integer, intent(out) :: c + + c = 0 + do i = 1, a + c = b + end do + c = c + +end subroutine test_transform_region_const_prop_for_loop_neg_range_no_unroll +""" + + a_val = -1 + b_val = 3 + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == 0 + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 3 + assert len(FindNodes(Loop).visit(routine.body)) == 1 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: c = {b_val}' in assignments + assert len([a for a in assignments if f'Assignment:: c = 0' == a]) == 2 + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == 0 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def 
test_transform_region_const_prop_for_loop_never_taken_no_unroll(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_for_loop_never_taken_no_unroll(c) + integer :: a = {a_val} + integer :: n, b + integer :: i + real :: r + integer, allocatable :: seed(:) + integer, intent(out) :: c + + call random_seed(size = n) + allocate(seed(n)) + seed(:) = 1 + call random_seed(put=seed) + call random_number(r) + + b = 0 + c = 0 + ! floor(r) will be 0, but this is only known at runtime. Statically, it is unknown + do i = 1, floor(r) + b = a + end do + + c = b + +end subroutine test_transform_region_const_prop_for_loop_never_taken_no_unroll +""" + + a_val = 5 + routine = Subroutine.from_source(fcode.format( + a_val=a_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == 0 + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert len(assignments) == 5 + assert f'Assignment:: b = {a_val}' in assignments + + assert len(FindNodes(Loop).visit(routine.body)) == 1 + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == 0 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_never_taken_nested_no_unroll(tmp_path, frontend): + fcode = """ +subroutine test_const_prop_for_loop_never_taken_nested_no_unroll(c) + integer :: a = {a_val} + integer :: n, b + integer :: i, j + real :: r + integer, allocatable :: seed(:) + integer, intent(out) :: c + + call random_seed(size = n) + allocate(seed(n)) + seed(:) = 1 + call random_seed(put=seed) + call 
random_number(r) + + b = 0 + c = 0 + ! floor(r) will be 0, but this is only known at runtime. Statically, it is unknown + do i = 1, floor(r) + b = a + do j = 1,5 + b = 6 + end do + b = b + end do + + c = b + +end subroutine test_const_prop_for_loop_never_taken_nested_no_unroll +""" + + a_val = 5 + routine = Subroutine.from_source(fcode.format( + a_val=a_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == 0 + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert len(assignments) == 7 + assert f'Assignment:: b = {a_val}' in assignments + assert f'Assignment:: b = {6}' in assignments + + assert len(FindNodes(Loop).visit(routine.body)) == 2 + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == 0 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_double_never_taken_nested_no_unroll(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_for_loop_never_taken_no_unroll(c) + integer :: a = {a_val} + integer :: n, b + integer :: i + real :: r + integer, allocatable :: seed(:) + integer, intent(out) :: c + + call random_seed(size = n) + allocate(seed(n)) + seed(:) = 1 + call random_seed(put=seed) + call random_number(r) + + b = 0 + c = 0 + ! floor(r) will be 0, but this is only known at runtime. 
Statically, it is unknown + do i = 1, floor(r) + b = a + do j = 1, floor(r) + b = 6 + end do + b = b + end do + + c = b + +end subroutine test_transform_region_const_prop_for_loop_never_taken_no_unroll +""" + + a_val = 5 + routine = Subroutine.from_source(fcode.format( + a_val=a_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == 0 + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert len(assignments) == 7 + assert f'Assignment:: b = {a_val}' in assignments + assert f'Assignment:: b = b' in assignments + + assert len(FindNodes(Loop).visit(routine.body)) == 2 + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == 0 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_nested(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_for_loop_nested(c) + integer :: a = {a_val} + integer :: b = {b_val} + integer :: i, j + integer, intent(out) :: c + + c = 0 + do i = 1, a + do j = 1, b + c = c + i + j + end do + end do + +end subroutine test_transform_region_const_prop_for_loop_nested +""" + + a_val = 5 + b_val = 3 + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == b_val*(a_val*(a_val+1))/2 + a_val*(b_val*(b_val+1))/2 + + # Apply transformation + routine = 
ConstantPropagationTransformer().visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == a_val * b_val + 1 + assert len(FindNodes(Loop).visit(routine.body)) == 0 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + tmp = 0 + for i in range(1, a_val+1): + for j in range(1, b_val+1): + tmp = tmp + i + j + assert f'Assignment:: c = {tmp}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == b_val*(a_val*(a_val+1))/2 + a_val*(b_val*(b_val+1))/2 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_nested_no_unroll(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_for_loop_nested_no_unroll(c) + integer :: a = {a_val} + integer :: b = {b_val} + integer :: i, j + integer, intent(out) :: c + + c = 0 + do i = 1, a + do j = 1, b + c = b + end do + c = c + end do + + c = c + +end subroutine test_transform_region_const_prop_for_loop_nested_no_unroll +""" + + a_val = 5 + b_val = 3 + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == b_val + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 4 + assert len(FindNodes(Loop).visit(routine.body)) == 2 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert len([a for a in assignments if f'Assignment:: c = {b_val}' == a]) == 3 + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = 
jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == b_val + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_nested_siblings(tmp_path, frontend): + fcode = """ + subroutine test_transform_region_const_prop_loop_nested_siblings(c) + integer :: a = {a_val} + integer :: b = {b_val} + integer :: i, j ,k + integer, intent(out) :: c + + c = 0 + do i = 1, a + do j = 1, b + c = c + i + j + end do + do k = 1, b + c = c + i + k + end do + end do + + end subroutine test_transform_region_const_prop_loop_nested_siblings + """ + + a_val = 5 + b_val = 3 + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == b_val*(a_val*(a_val+1)) + a_val*(b_val*(b_val+1)) + + # Apply transformation + routine = ConstantPropagationTransformer().visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 2 * b_val * a_val + 1 + assert len(FindNodes(Loop).visit(routine.body)) == 0 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + tmp = 0 + for i in range(1, a_val+1): + for j in range(1, b_val+1): + tmp = tmp + i + j + assert f'Assignment:: c = {tmp}' in assignments + for j in range(1, b_val+1): + tmp = tmp + i + j + assert f'Assignment:: c = {tmp}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == b_val*(a_val*(a_val+1)) + a_val*(b_val*(b_val+1)) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_for_loop_nested_siblings_no_unroll(tmp_path, frontend): + fcode = """ 
+subroutine test_transform_region_const_prop_loop_nested_siblings_no_unroll(c) + integer :: a = {a_val} + integer :: b = {b_val} + integer :: i, j ,k + integer, intent(out) :: c + + c = 0 + do i = 1, a + do j = 1, b + c = a + end do + c = c + do k = 1, b + c = b + end do + c = c + end do + c = c +end subroutine test_transform_region_const_prop_loop_nested_siblings_no_unroll +""" + + a_val = 5 + b_val = 3 + routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == b_val + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 6 + assert len(FindNodes(Loop).visit(routine.body)) == 3 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: c = {a_val}' in assignments + assert f'Assignment:: c = {b_val}' in assignments + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == b_val + + +# Conditionals +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_region_const_prop_conditional_basic(tmp_path, frontend): + fcode = """ +subroutine test_transform_region_const_prop_conditional_basic(c) + integer :: a = {a_val} + integer :: b = {b_val} + logical :: cond = {cond_val} + integer, intent(out) :: c + + if (cond) then + c = a + else + c = b + endif + + c = c + +end subroutine test_transform_region_const_prop_conditional_basic +""" + + a_val = 5 + b_val = 3 + cond = True + + cond_val = '.True.' if cond else '.False.' 
+ routine = Subroutine.from_source(fcode.format(
+ a_val=a_val, b_val=b_val, cond_val=cond_val),
+ frontend=frontend)
+
+ filepath = tmp_path / (f'{routine.name}_{frontend}.f90')
+ function = jit_compile(routine, filepath=filepath, objname=routine.name)
+
+ # Test the reference solution
+ c = function()
+
+ assert c == (a_val if cond else b_val)
+
+ # Apply transformation
+ routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine)
+
+ assert len(FindNodes(Assignment).visit(routine.body)) == 3
+ assert len(FindNodes(Conditional).visit(routine.body)) == 1
+
+ assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)]
+ assert f'Assignment:: c = {a_val}' in assignments
+ assert f'Assignment:: c = {b_val}' in assignments
+ assert len([a for a in assignments if a == f'Assignment:: c = {a_val if cond else b_val}']) == 2
+
+ # Test transformation
+
+ new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90'
+ new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name)
+
+ c = new_function()
+
+ assert c == (a_val if cond else b_val)
+
+
+@pytest.mark.parametrize('frontend', available_frontends())
+def test_transform_region_const_prop_conditional_dynamic_condition(tmp_path, frontend):
+ fcode = """
+subroutine test_transform_region_const_prop_conditional_dynamic_condition(c)
+ integer :: a = {a_val}
+ integer :: b = {b_val}
+ logical :: cond
+ integer, intent(out) :: c
+
+ integer :: n
+ integer :: i
+ real :: r
+ integer, allocatable :: seed(:)
+
+ call random_seed(size = n)
+ allocate(seed(n))
+ seed(:) = 1
+ call random_seed(put=seed)
+ call random_number(r)
+
+ ! floor(r) will be 0, but this is only known at runtime. Statically, it is unknown
+ cond = floor(r) == 0
+
+ if (cond) then
+ c = a
+ else
+ c = b
+ endif
+
+ c = c
+
+end subroutine test_transform_region_const_prop_conditional_dynamic_condition
+"""
+
+ a_val = 5
+ b_val = 3
+ cond = True
+
+ cond_val = '.True.' if cond else '.False.'
+ routine = Subroutine.from_source(fcode.format( + a_val=a_val, b_val=b_val, cond_val=cond_val), + frontend=frontend) + + filepath = tmp_path / (f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + # Test the reference solution + c = function() + + assert c == a_val if cond else b_val + + # Apply transformation + routine = ConstantPropagationTransformer(unroll_loops=False).visit(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 5 + assert len(FindNodes(Conditional).visit(routine.body)) == 1 + + assignments = [str(a) for a in FindNodes(Assignment).visit(routine.body)] + assert f'Assignment:: c = {a_val}' in assignments + assert f'Assignment:: c = {b_val}' in assignments + assert f'Assignment:: c = c' in assignments + assert len([a for a in assignments if a == f'Assignment:: c = {a_val if cond else b_val}']) == 1 + + # Test transformation + + new_filepath = tmp_path / f'{routine.name}_proped_{frontend}.f90' + new_function = jit_compile(routine, filepath=new_filepath, objname=routine.name) + + c = new_function() + + assert c == a_val if cond else b_val + +# TODO: Conditionals & Loop interactions +# TODO: Assignments \ No newline at end of file diff --git a/loki/transformations/transform_loop.py b/loki/transformations/transform_loop.py index d94b8dfe9..48970f252 100644 --- a/loki/transformations/transform_loop.py +++ b/loki/transformations/transform_loop.py @@ -15,7 +15,7 @@ import numpy as np from loki.analyse import ( - dataflow_analysis_attached, read_after_write_vars, + DataFlowAnalysis, read_after_write_vars, loop_carried_dependencies ) from loki.expression import ( @@ -566,7 +566,7 @@ def do_loop_fission(routine, promote=True, warn_loop_carries=True): if not pragma_loops: return - with optional(promote or warn_loop_carries, dataflow_analysis_attached, routine): + with optional(promote or warn_loop_carries, DataFlowAnalysis().dataflow_analysis_attached, routine): for pragma in 
pragma_loops: # Now, sort the loops enclosing each pragma from outside to inside and # keep only the ones relevant for fission @@ -599,7 +599,7 @@ def do_loop_fission(routine, promote=True, warn_loop_carries=True): # Warn about broken loop-carried dependencies if warn_loop_carries: - with dataflow_analysis_attached(routine): + with DataFlowAnalysis().dataflow_analysis_attached(routine): for pragma, loop_carries in loop_carried_vars.items(): loop, *remainder = fission_trafo.split_loops[pragma] if not remainder: @@ -665,19 +665,33 @@ def visit_Loop(self, o, depth=None): if self.warn_iterations_length and len(unroll_range) > 32: warning(f"Unrolling loop over 32 iterations ({len(unroll_range)}), this may take a long time & " f"provide few performance benefits.") + neighbour_loops = len([c for c in o.body if isinstance(c, Loop)]) > 1 + counter_in_bounds = o.variable in [v for l in FindNodes(Loop).visit(o.body) + for v in FindVariables().visit(l.bounds)] + + if not neighbour_loops and not counter_in_bounds: + # Use depth first (faster) + # TODO: test + if depth is None or depth >= 1: + o = Loop( + variable=o.variable, + body=self.visit(o.body, depth=depth), + bounds=o.bounds + ) + acc = functools.reduce(op.add, + [SubstituteExpressions({o.variable: sym.IntLiteral(i)}).visit(o.body) for i in unroll_range], + ()) + return as_tuple(flatten(acc)) + else: + # Use breadth first (slower) + acc = functools.reduce(op.add, + [SubstituteExpressions({o.variable: sym.IntLiteral(i)}).visit(o.body) for i in unroll_range], + ()) - acc = functools.reduce(op.add, - [ - # Create a copy of the loop body for every value of the iterator - SubstituteExpressions({o.variable: sym.IntLiteral(i)}).visit(o.body) - for i in unroll_range - ], - ()) - - if depth is None or depth >= 1: - acc = [self.visit(a, depth=depth) for a in acc] + if depth is None or depth >= 1: + acc = [self.visit(a, depth=depth) for a in acc] - return as_tuple(flatten(acc)) + return as_tuple(flatten(acc)) _pragma = tuple( p for 
p in o.pragma if not is_loki_pragma(p, starts_with='loop-unroll')