diff --git a/src/oqd_compiler_infrastructure/__init__.py b/src/oqd_compiler_infrastructure/__init__.py index f45617f..8735fe5 100644 --- a/src/oqd_compiler_infrastructure/__init__.py +++ b/src/oqd_compiler_infrastructure/__init__.py @@ -14,7 +14,7 @@ from .base import PassBase from .interface import TypeReflectBaseModel, VisitableBaseModel -from .rewriter import Chain, FixedPoint, RewriterBase +from .rewriter import Chain, FixedPoint, RewriterBase, Filter from .rule import ConversionRule, PrettyPrint, RewriteRule, RuleBase from .walk import In, Level, Post, Pre, WalkBase @@ -34,4 +34,5 @@ "Chain", "FixedPoint", "RewriterBase", + "Filter", ] diff --git a/src/oqd_compiler_infrastructure/rewriter.py b/src/oqd_compiler_infrastructure/rewriter.py index 3242454..530617d 100644 --- a/src/oqd_compiler_infrastructure/rewriter.py +++ b/src/oqd_compiler_infrastructure/rewriter.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable from oqd_compiler_infrastructure.base import PassBase +import copy ######################################################################################## -__all__ = [ - "RewriterBase", - "Chain", - "FixedPoint", -] +__all__ = ["RewriterBase", "Chain", "FixedPoint", "Filter"] ######################################################################################## @@ -34,8 +32,6 @@ class RewriterBase(PassBase): This code was inspired by [SynbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl), [Liang.jl](https://github.com/Roger-luo/Liang.jl/tree/main/src/rewrite). """ - pass - ######################################################################################## @@ -52,7 +48,6 @@ def __init__(self, *rules): super().__init__() self.rules = list(rules) - pass @property def children(self): @@ -74,16 +69,29 @@ class FixedPoint(RewriterBase): This code was inspired by [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl#L117C8-L117C16), [Liang.jl](https://github.com/Roger-luo/Liang.jl/blob/main/src/rewrite/fixpoint.jl). """ - def __init__(self, rule, *, max_iter=1000): + def __init__(self, rule, *, max_iter=1000, reuse=False): super().__init__() - self.rule = rule + self._rule = rule self.max_iter = max_iter - pass + self.reuse = reuse + + self._rule_copies = [] + + @property + def rule(self): + if self.reuse: + return self._rule + + self._rule_copies.append(copy.deepcopy(self._rule)) + return self._rule_copies[-1] @property def children(self): - return [self.rule] + if self.reuse: + return [self._rule] + + return self._rule_copies def map(self, model): i = 0 @@ -96,3 +104,41 @@ def map(self, model): new_model = _model i += 1 + + +######################################################################################## + + +class Filter(RewriterBase): + def __init__(self, function: Callable[[Any], bool], rule: PassBase, *, reuse=False): + super().__init__() + + self._rule = rule + self.function = function + self.reuse = reuse + + self._rule_copies = [] + + @property + def rule(self): + if self.reuse: + return self._rule + + self._rule_copies.append(copy.deepcopy(self._rule)) + return self._rule_copies[-1] + + @property + def children(self): + if self.reuse: + return [self._rule] + + return self._rule_copies + + def map(self, model): + return self.filter(model) + + def filter(self, model): + if not self.function(model): + return model + + return self.rule(model) diff --git a/src/oqd_compiler_infrastructure/rule.py b/src/oqd_compiler_infrastructure/rule.py index d0ae5af..a67be45 100644 --- a/src/oqd_compiler_infrastructure/rule.py +++ b/src/oqd_compiler_infrastructure/rule.py @@ -60,8 +60,6 @@ def map(self, model): def generic_map(self, model): return model - pass - class ConversionRule(RuleBase): """ diff --git a/src/oqd_compiler_infrastructure/walk.py b/src/oqd_compiler_infrastructure/walk.py index 9b65cef..2d74bec 100644 --- a/src/oqd_compiler_infrastructure/walk.py +++ b/src/oqd_compiler_infrastructure/walk.py @@ -41,7 +41,6 @@ def __init__(self, rule: PassBase, *, reverse: bool = False): self.rule = rule self.reverse = reverse - pass @staticmethod def controlled_reverse(iterable, reverse, *, restore_type=False): @@ -72,8 +71,6 @@ def walk(self, model): def generic_walk(self, model): return self.rule(model) - pass - ######################################################################################## diff --git a/tests/examples/advanced_rewrite.py b/tests/examples/advanced_rewrite.py index e68fc0c..821316a 100644 --- a/tests/examples/advanced_rewrite.py +++ b/tests/examples/advanced_rewrite.py @@ -12,32 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from oqd_compiler_infrastructure.interface import TypeReflectBaseModel from oqd_compiler_infrastructure.rule import RewriteRule from oqd_compiler_infrastructure.walk import Post + # AST data structures (same as before) class Expression(TypeReflectBaseModel): """Base class for arithmetic expressions. This class serves as the foundation for all expression types in the AST. """ - pass + class Number(Expression): """Represents a numeric literal. Attributes: value (float): The numeric value of the literal. """ + value: float + class Variable(Expression): """Represents a variable in an expression. Attributes: name (str): The name of the variable. """ + name: str + class BinaryOp(Expression): """Represents a binary operation. Attributes: @@ -45,10 +49,12 @@ class BinaryOp(Expression): left (Expression): The left operand. right (Expression): The right operand. """ + op: str # '+', '-', '*', '/' left: Expression right: Expression + class AdvancedAlgebraicSimplifier(RewriteRule): """Applies advanced algebraic simplification rules. Rules implemented: @@ -58,6 +64,7 @@ class AdvancedAlgebraicSimplifier(RewriteRule): - (x + y) - y = x - Distributive law: a * (b + c) = (a * b) + (a * c) """ + # Implements additional algebraic identities like subtraction and distribution def map_BinaryOp(self, model): @@ -69,37 +76,43 @@ def map_BinaryOp(self, model): """ # x - x = 0 # Rule: subtracting identical terms yields zero - if model.op == '-' and self._expressions_equal(model.left, model.right): + if model.op == "-" and self._expressions_equal(model.left, model.right): return Number(value=0) # x + (-x) = 0 # Rule: x + (-1 * x) => 0 - if (model.op == '+' and - isinstance(model.right, BinaryOp) and - model.right.op == '*' and - isinstance(model.right.left, Number) and - model.right.left.value == -1 and - self._expressions_equal(model.left, model.right.right)): + if ( + model.op == "+" + and isinstance(model.right, BinaryOp) + and model.right.op == "*" + and isinstance(model.right.left, Number) + and model.right.left.value == -1 + and self._expressions_equal(model.left, model.right.right) + ): return Number(value=0) # Distributive law: a * (b + c) = (a * b) + (a * c) # Rule: a * (b + c) => (a*b) + (a*c) - if (model.op == '*' and - isinstance(model.right, BinaryOp) and - model.right.op in ['+', '-']): + if ( + model.op == "*" + and isinstance(model.right, BinaryOp) + and model.right.op in ["+", "-"] + ): # a * (b + c) -> (a * b) + (a * c) return BinaryOp( op=model.right.op, - left=BinaryOp(op='*', left=model.left, right=model.right.left), - right=BinaryOp(op='*', left=model.left, right=model.right.right) + left=BinaryOp(op="*", left=model.left, right=model.right.left), + right=BinaryOp(op="*", left=model.left, right=model.right.right), ) # (x + y) - y = x # Rule: (x + y) - y => x - if (model.op == '-' and - isinstance(model.left, BinaryOp) and - model.left.op == '+' and - self._expressions_equal(model.left.right, model.right)): + if ( + model.op == "-" + and isinstance(model.left, BinaryOp) + and model.left.op == "+" + and self._expressions_equal(model.left.right, model.right) + ): return model.left.left return model @@ -113,7 +126,7 @@ def _expressions_equal(self, expr1, expr2): bool: True if structurally equal, False otherwise. """ # Compare types and recursively compare sub-expressions - if type(expr1) != type(expr2): + if type(expr1) is not type(expr2): return False if isinstance(expr1, Number): @@ -123,12 +136,15 @@ def _expressions_equal(self, expr1, expr2): return expr1.name == expr2.name if isinstance(expr1, BinaryOp): - return (expr1.op == expr2.op and - self._expressions_equal(expr1.left, expr2.left) and - self._expressions_equal(expr1.right, expr2.right)) + return ( + expr1.op == expr2.op + and self._expressions_equal(expr1.left, expr2.left) + and self._expressions_equal(expr1.right, expr2.right) + ) return False + def print_expr(expr): """Convert an expression into a readable string. Args: @@ -145,50 +161,32 @@ def print_expr(expr): return f"({print_expr(expr.left)} {expr.op} {print_expr(expr.right)})" return str(expr) + def main(): """Main function to demonstrate advanced algebraic simplification.""" # Prepare test cases and run the AdvancedAlgebraicSimplifier # Create test expressions test_cases = [ # x - x = 0 - BinaryOp( - op='-', - left=Variable(name='x'), - right=Variable(name='x') - ), - + BinaryOp(op="-", left=Variable(name="x"), right=Variable(name="x")), # x + (-1 * x) = 0 BinaryOp( - op='+', - left=Variable(name='x'), - right=BinaryOp( - op='*', - left=Number(value=-1), - right=Variable(name='x') - ) + op="+", + left=Variable(name="x"), + right=BinaryOp(op="*", left=Number(value=-1), right=Variable(name="x")), ), - # a * (b + c) -> (a * b) + (a * c) BinaryOp( - op='*', - left=Variable(name='a'), - right=BinaryOp( - op='+', - left=Variable(name='b'), - right=Variable(name='c') - ) + op="*", + left=Variable(name="a"), + right=BinaryOp(op="+", left=Variable(name="b"), right=Variable(name="c")), ), - # (x + y) - y = x BinaryOp( - op='-', - left=BinaryOp( - op='+', - left=Variable(name='x'), - right=Variable(name='y') - ), - right=Variable(name='y') - ) + op="-", + left=BinaryOp(op="+", left=Variable(name="x"), right=Variable(name="y")), + right=Variable(name="y"), + ), ] # Create simplifier with Post traversal @@ -202,5 +200,6 @@ def main(): result = simplifier(expr) print(f"Simplified: {print_expr(result)}") + if __name__ == "__main__": - main() + main() diff --git a/tests/test_walk.py b/tests/test_walk.py index 979106e..79c9df0 100644 --- a/tests/test_walk.py +++ b/tests/test_walk.py @@ -34,7 +34,6 @@ def __init__(self): def generic_map(self, model): self.string += f"\n{self.current_index}: {model}" self.current_index += 1 - pass class X(VisitableBaseModel): @@ -485,6 +484,7 @@ def test_in_list(): printer(inp) assert printer.children[0].string == "\n0: a\n1: ['a', 'b']\n2: b" + def test_in_dict(): "Test of In Walk on a dict" inp = {"a": "a", "b": "b"} @@ -494,6 +494,7 @@ def test_in_dict(): printer(inp) assert printer.children[0].string == "\n0: a\n1: {'a': 'a', 'b': 'b'}\n2: b" + def test_in_VisitableBaseModel(): "Test of In Walk on a VisitableBaseModel" inp = X(a="a", b="b") @@ -503,6 +504,7 @@ def test_in_VisitableBaseModel(): printer(inp) assert printer.children[0].string == "\n0: a\n1: a='a' b='b'\n2: b" + def test_in_nested_list(): "Test of In Walk on a nested list" inp = [["a", ["b", "c"]], ["d", "e", "f"]] @@ -515,6 +517,7 @@ def test_in_nested_list(): == "\n0: a\n1: ['a', ['b', 'c']]\n2: b\n3: ['b', 'c']\n4: c\n5: [['a', ['b', 'c']], ['d', 'e', 'f']]\n6: d\n7: e\n8: ['d', 'e', 'f']\n9: f" ) + def test_reversed_in_list(): "Test of reversed In Walk on a list" inp = ["a", "b"] @@ -524,6 +527,7 @@ def test_reversed_in_list(): printer(inp) assert printer.children[0].string == "\n0: b\n1: ['a', 'b']\n2: a" + def test_reversed_in_dict(): "Test of reversed In Walk on a dict" inp = {"a": "a", "b": "b"} @@ -533,6 +537,7 @@ def test_reversed_in_dict(): printer(inp) assert printer.children[0].string == "\n0: b\n1: {'a': 'a', 'b': 'b'}\n2: a" + def test_reversed_in_VisitableBaseModel(): "Test of reversed In Walk on a VisitableBaseModel" inp = X(a="a", b="b") @@ -542,6 +547,7 @@ def test_reversed_in_VisitableBaseModel(): printer(inp) assert printer.children[0].string == "\n0: b\n1: a='a' b='b'\n2: a" + def test_reversed_in_nested_list(): "Test of reversed In Walk on a nested list" inp = [["a", ["b", "c"]], ["d", "e", "f"]] @@ -554,6 +560,7 @@ def test_reversed_in_nested_list(): == "\n0: f\n1: e\n2: ['d', 'e', 'f']\n3: d\n4: [['a', ['b', 'c']], ['d', 'e', 'f']]\n5: c\n6: ['b', 'c']\n7: b\n8: ['a', ['b', 'c']]\n9: a" ) + def test_in_TypeReflectBaseModel(): "Test of In Walk on a TypeReflectBaseModel" inp = Y(a="a", b="b") @@ -563,6 +570,7 @@ def test_in_TypeReflectBaseModel(): printer(inp) assert printer.children[0].string == "\n0: a\n1: class_='Y' a='a' b='b'\n2: b" + def test_reversed_in_TypeReflectBaseModel(): "Test of reversed In Walk on a TypeReflectBaseModel" inp = Y(a="a", b="b") @@ -572,6 +580,7 @@ def test_reversed_in_TypeReflectBaseModel(): printer(inp) assert printer.children[0].string == "\n0: b\n1: class_='Y' a='a' b='b'\n2: a" + def test_reversed_in_TypeReflectBaseModel_no_attribute(): "Test of reversed In Walk on a TypeReflectBaseModel with no attribute for N" x = X(a="x1", b="x2")