From a61313c9a34d4bdb88d7be491e6e44e0abb7e764 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Fri, 16 Jan 2026 12:16:25 -0500 Subject: [PATCH] [feat] Implemented filter rewriter that executes copies of the rule on model matching the filter function. Added flag in FixedPoint and Filter rewriter for reusing rule or creating copies of the rule (default) --- src/oqd_compiler_infrastructure/__init__.py | 3 +- src/oqd_compiler_infrastructure/rewriter.py | 70 ++++++++++--- src/oqd_compiler_infrastructure/rule.py | 2 - src/oqd_compiler_infrastructure/walk.py | 3 - tests/examples/advanced_rewrite.py | 105 ++++++++++---------- tests/test_walk.py | 11 +- 6 files changed, 122 insertions(+), 72 deletions(-) diff --git a/src/oqd_compiler_infrastructure/__init__.py b/src/oqd_compiler_infrastructure/__init__.py index f45617f..8735fe5 100644 --- a/src/oqd_compiler_infrastructure/__init__.py +++ b/src/oqd_compiler_infrastructure/__init__.py @@ -14,7 +14,7 @@ from .base import PassBase from .interface import TypeReflectBaseModel, VisitableBaseModel -from .rewriter import Chain, FixedPoint, RewriterBase +from .rewriter import Chain, FixedPoint, RewriterBase, Filter from .rule import ConversionRule, PrettyPrint, RewriteRule, RuleBase from .walk import In, Level, Post, Pre, WalkBase @@ -34,4 +34,5 @@ "Chain", "FixedPoint", "RewriterBase", + "Filter", ] diff --git a/src/oqd_compiler_infrastructure/rewriter.py b/src/oqd_compiler_infrastructure/rewriter.py index 3242454..530617d 100644 --- a/src/oqd_compiler_infrastructure/rewriter.py +++ b/src/oqd_compiler_infrastructure/rewriter.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable from oqd_compiler_infrastructure.base import PassBase +import copy ######################################################################################## -__all__ = [ - "RewriterBase", - "Chain", - "FixedPoint", -] +__all__ = ["RewriterBase", "Chain", "FixedPoint", "Filter"] ######################################################################################## @@ -34,8 +32,6 @@ class RewriterBase(PassBase): This code was inspired by [SynbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl), [Liang.jl](https://github.com/Roger-luo/Liang.jl/tree/main/src/rewrite). """ - pass - ######################################################################################## @@ -52,7 +48,6 @@ def __init__(self, *rules): super().__init__() self.rules = list(rules) - pass @property def children(self): @@ -74,16 +69,29 @@ class FixedPoint(RewriterBase): This code was inspired by [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl#L117C8-L117C16), [Liang.jl](https://github.com/Roger-luo/Liang.jl/blob/main/src/rewrite/fixpoint.jl). """ - def __init__(self, rule, *, max_iter=1000): + def __init__(self, rule, *, max_iter=1000, reuse=False): super().__init__() - self.rule = rule + self._rule = rule self.max_iter = max_iter - pass + self.reuse = reuse + + self._rule_copies = [] + + @property + def rule(self): + if self.reuse: + return self._rule + + self._rule_copies.append(copy.deepcopy(self._rule)) + return self._rule_copies[-1] @property def children(self): - return [self.rule] + if self.reuse: + return [self._rule] + + return self._rule_copies def map(self, model): i = 0 @@ -96,3 +104,41 @@ def map(self, model): new_model = _model i += 1 + + +######################################################################################## + + +class Filter(RewriterBase): + def __init__(self, function: Callable[[Any], bool], rule: PassBase, *, reuse=False): + super().__init__() + + self._rule = rule + self.function = function + self.reuse = reuse + + self._rule_copies = [] + + @property + def rule(self): + if self.reuse: + return self._rule + + self._rule_copies.append(copy.deepcopy(self._rule)) + return self._rule_copies[-1] + + @property + def children(self): + if self.reuse: + return [self._rule] + + return self._rule_copies + + def map(self, model): + return self.filter(model) + + def filter(self, model): + if not self.function(model): + return model + + return self.rule(model) diff --git a/src/oqd_compiler_infrastructure/rule.py b/src/oqd_compiler_infrastructure/rule.py index d0ae5af..a67be45 100644 --- a/src/oqd_compiler_infrastructure/rule.py +++ b/src/oqd_compiler_infrastructure/rule.py @@ -60,8 +60,6 @@ def map(self, model): def generic_map(self, model): return model - pass - class ConversionRule(RuleBase): """ diff --git a/src/oqd_compiler_infrastructure/walk.py b/src/oqd_compiler_infrastructure/walk.py index 9b65cef..2d74bec 100644 --- a/src/oqd_compiler_infrastructure/walk.py +++ b/src/oqd_compiler_infrastructure/walk.py @@ -41,7 +41,6 @@ def __init__(self, rule: PassBase, *, reverse: bool = False): self.rule = rule self.reverse = reverse - pass @staticmethod def controlled_reverse(iterable, reverse, *, restore_type=False): @@ -72,8 +71,6 @@ def walk(self, model): def generic_walk(self, model): return self.rule(model) - pass - ######################################################################################## diff --git a/tests/examples/advanced_rewrite.py b/tests/examples/advanced_rewrite.py index e68fc0c..821316a 100644 --- a/tests/examples/advanced_rewrite.py +++ b/tests/examples/advanced_rewrite.py @@ -12,32 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from oqd_compiler_infrastructure.interface import TypeReflectBaseModel from oqd_compiler_infrastructure.rule import RewriteRule from oqd_compiler_infrastructure.walk import Post + # AST data structures (same as before) class Expression(TypeReflectBaseModel): """Base class for arithmetic expressions. This class serves as the foundation for all expression types in the AST. """ - pass + class Number(Expression): """Represents a numeric literal. Attributes: value (float): The numeric value of the literal. """ + value: float + class Variable(Expression): """Represents a variable in an expression. Attributes: name (str): The name of the variable. """ + name: str + class BinaryOp(Expression): """Represents a binary operation. Attributes: @@ -45,10 +49,12 @@ class BinaryOp(Expression): left (Expression): The left operand. right (Expression): The right operand. """ + op: str # '+', '-', '*', '/' left: Expression right: Expression + class AdvancedAlgebraicSimplifier(RewriteRule): """Applies advanced algebraic simplification rules. Rules implemented: @@ -58,6 +64,7 @@ class AdvancedAlgebraicSimplifier(RewriteRule): - (x + y) - y = x - Distributive law: a * (b + c) = (a * b) + (a * c) """ + # Implements additional algebraic identities like subtraction and distribution def map_BinaryOp(self, model): @@ -69,37 +76,43 @@ def map_BinaryOp(self, model): """ # x - x = 0 # Rule: subtracting identical terms yields zero - if model.op == '-' and self._expressions_equal(model.left, model.right): + if model.op == "-" and self._expressions_equal(model.left, model.right): return Number(value=0) # x + (-x) = 0 # Rule: x + (-1 * x) => 0 - if (model.op == '+' and - isinstance(model.right, BinaryOp) and - model.right.op == '*' and - isinstance(model.right.left, Number) and - model.right.left.value == -1 and - self._expressions_equal(model.left, model.right.right)): + if ( + model.op == "+" + and isinstance(model.right, BinaryOp) + and model.right.op == "*" + and isinstance(model.right.left, Number) + and model.right.left.value == -1 + and self._expressions_equal(model.left, model.right.right) + ): return Number(value=0) # Distributive law: a * (b + c) = (a * b) + (a * c) # Rule: a * (b + c) => (a*b) + (a*c) - if (model.op == '*' and - isinstance(model.right, BinaryOp) and - model.right.op in ['+', '-']): + if ( + model.op == "*" + and isinstance(model.right, BinaryOp) + and model.right.op in ["+", "-"] + ): # a * (b + c) -> (a * b) + (a * c) return BinaryOp( op=model.right.op, - left=BinaryOp(op='*', left=model.left, right=model.right.left), - right=BinaryOp(op='*', left=model.left, right=model.right.right) + left=BinaryOp(op="*", left=model.left, right=model.right.left), + right=BinaryOp(op="*", left=model.left, right=model.right.right), ) # (x + y) - y = x # Rule: (x + y) - y => x - if (model.op == '-' and - isinstance(model.left, BinaryOp) and - model.left.op == '+' and - self._expressions_equal(model.left.right, model.right)): + if ( + model.op == "-" + and isinstance(model.left, BinaryOp) + and model.left.op == "+" + and self._expressions_equal(model.left.right, model.right) + ): return model.left.left return model @@ -113,7 +126,7 @@ def _expressions_equal(self, expr1, expr2): bool: True if structurally equal, False otherwise. """ # Compare types and recursively compare sub-expressions - if type(expr1) != type(expr2): + if type(expr1) is not type(expr2): return False if isinstance(expr1, Number): @@ -123,12 +136,15 @@ def _expressions_equal(self, expr1, expr2): return expr1.name == expr2.name if isinstance(expr1, BinaryOp): - return (expr1.op == expr2.op and - self._expressions_equal(expr1.left, expr2.left) and - self._expressions_equal(expr1.right, expr2.right)) + return ( + expr1.op == expr2.op + and self._expressions_equal(expr1.left, expr2.left) + and self._expressions_equal(expr1.right, expr2.right) + ) return False + def print_expr(expr): """Convert an expression into a readable string. Args: @@ -145,50 +161,32 @@ def print_expr(expr): return f"({print_expr(expr.left)} {expr.op} {print_expr(expr.right)})" return str(expr) + def main(): """Main function to demonstrate advanced algebraic simplification.""" # Prepare test cases and run the AdvancedAlgebraicSimplifier # Create test expressions test_cases = [ # x - x = 0 - BinaryOp( - op='-', - left=Variable(name='x'), - right=Variable(name='x') - ), - + BinaryOp(op="-", left=Variable(name="x"), right=Variable(name="x")), # x + (-1 * x) = 0 BinaryOp( - op='+', - left=Variable(name='x'), - right=BinaryOp( - op='*', - left=Number(value=-1), - right=Variable(name='x') - ) + op="+", + left=Variable(name="x"), + right=BinaryOp(op="*", left=Number(value=-1), right=Variable(name="x")), ), - # a * (b + c) -> (a * b) + (a * c) BinaryOp( - op='*', - left=Variable(name='a'), - right=BinaryOp( - op='+', - left=Variable(name='b'), - right=Variable(name='c') - ) + op="*", + left=Variable(name="a"), + right=BinaryOp(op="+", left=Variable(name="b"), right=Variable(name="c")), ), - # (x + y) - y = x BinaryOp( - op='-', - left=BinaryOp( - op='+', - left=Variable(name='x'), - right=Variable(name='y') - ), - right=Variable(name='y') - ) + op="-", + left=BinaryOp(op="+", left=Variable(name="x"), right=Variable(name="y")), + right=Variable(name="y"), + ), ] # Create simplifier with Post traversal @@ -202,5 +200,6 @@ def main(): result = simplifier(expr) print(f"Simplified: {print_expr(result)}") + if __name__ == "__main__": - main() + main() diff --git a/tests/test_walk.py b/tests/test_walk.py index 979106e..79c9df0 100644 --- a/tests/test_walk.py +++ b/tests/test_walk.py @@ -34,7 +34,6 @@ def __init__(self): def generic_map(self, model): self.string += f"\n{self.current_index}: {model}" self.current_index += 1 - pass class X(VisitableBaseModel): @@ -485,6 +484,7 @@ def test_in_list(): printer(inp) assert printer.children[0].string == "\n0: a\n1: ['a', 'b']\n2: b" + def test_in_dict(): "Test of In Walk on a dict" inp = {"a": "a", "b": "b"} @@ -494,6 +494,7 @@ def test_in_dict(): printer(inp) assert printer.children[0].string == "\n0: a\n1: {'a': 'a', 'b': 'b'}\n2: b" + def test_in_VisitableBaseModel(): "Test of In Walk on a VisitableBaseModel" inp = X(a="a", b="b") @@ -503,6 +504,7 @@ def test_in_VisitableBaseModel(): printer(inp) assert printer.children[0].string == "\n0: a\n1: a='a' b='b'\n2: b" + def test_in_nested_list(): "Test of In Walk on a nested list" inp = [["a", ["b", "c"]], ["d", "e", "f"]] @@ -515,6 +517,7 @@ def test_in_nested_list(): == "\n0: a\n1: ['a', ['b', 'c']]\n2: b\n3: ['b', 'c']\n4: c\n5: [['a', ['b', 'c']], ['d', 'e', 'f']]\n6: d\n7: e\n8: ['d', 'e', 'f']\n9: f" ) + def test_reversed_in_list(): "Test of reversed In Walk on a list" inp = ["a", "b"] @@ -524,6 +527,7 @@ def test_reversed_in_list(): printer(inp) assert printer.children[0].string == "\n0: b\n1: ['a', 'b']\n2: a" + def test_reversed_in_dict(): "Test of reversed In Walk on a dict" inp = {"a": "a", "b": "b"} @@ -533,6 +537,7 @@ def test_reversed_in_dict(): printer(inp) assert printer.children[0].string == "\n0: b\n1: {'a': 'a', 'b': 'b'}\n2: a" + def test_reversed_in_VisitableBaseModel(): "Test of reversed In Walk on a VisitableBaseModel" inp = X(a="a", b="b") @@ -542,6 +547,7 @@ def test_reversed_in_VisitableBaseModel(): printer(inp) assert printer.children[0].string == "\n0: b\n1: a='a' b='b'\n2: a" + def test_reversed_in_nested_list(): "Test of reversed In Walk on a nested list" inp = [["a", ["b", "c"]], ["d", "e", "f"]] @@ -554,6 +560,7 @@ def test_reversed_in_nested_list(): == "\n0: f\n1: e\n2: ['d', 'e', 'f']\n3: d\n4: [['a', ['b', 'c']], ['d', 'e', 'f']]\n5: c\n6: ['b', 'c']\n7: b\n8: ['a', ['b', 'c']]\n9: a" ) + def test_in_TypeReflectBaseModel(): "Test of In Walk on a TypeReflectBaseModel" inp = Y(a="a", b="b") @@ -563,6 +570,7 @@ def test_in_TypeReflectBaseModel(): printer(inp) assert printer.children[0].string == "\n0: a\n1: class_='Y' a='a' b='b'\n2: b" + def test_reversed_in_TypeReflectBaseModel(): "Test of reversed In Walk on a TypeReflectBaseModel" inp = Y(a="a", b="b") @@ -572,6 +580,7 @@ def test_reversed_in_TypeReflectBaseModel(): printer(inp) assert printer.children[0].string == "\n0: b\n1: class_='Y' a='a' b='b'\n2: a" + def test_reversed_in_TypeReflectBaseModel_no_attribute(): "Test of reversed In Walk on a TypeReflectBaseModel with no attribute for N" x = X(a="x1", b="x2")