Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/oqd_compiler_infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .base import PassBase
from .interface import TypeReflectBaseModel, VisitableBaseModel
from .rewriter import Chain, FixedPoint, RewriterBase
from .rewriter import Chain, FixedPoint, RewriterBase, Filter
from .rule import ConversionRule, PrettyPrint, RewriteRule, RuleBase
from .walk import In, Level, Post, Pre, WalkBase

Expand All @@ -34,4 +34,5 @@
"Chain",
"FixedPoint",
"RewriterBase",
"Filter",
]
70 changes: 58 additions & 12 deletions src/oqd_compiler_infrastructure/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable
from oqd_compiler_infrastructure.base import PassBase
import copy

########################################################################################

__all__ = [
"RewriterBase",
"Chain",
"FixedPoint",
]
__all__ = ["RewriterBase", "Chain", "FixedPoint", "Filter"]

########################################################################################

Expand All @@ -34,8 +32,6 @@ class RewriterBase(PassBase):
This code was inspired by [SynbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl), [Liang.jl](https://github.com/Roger-luo/Liang.jl/tree/main/src/rewrite).
"""

pass


########################################################################################

Expand All @@ -52,7 +48,6 @@ def __init__(self, *rules):
super().__init__()

self.rules = list(rules)
pass

@property
def children(self):
Expand All @@ -74,16 +69,29 @@ class FixedPoint(RewriterBase):
This code was inspired by [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl#L117C8-L117C16), [Liang.jl](https://github.com/Roger-luo/Liang.jl/blob/main/src/rewrite/fixpoint.jl).
"""

def __init__(self, rule, *, max_iter=1000):
def __init__(self, rule, *, max_iter=1000, reuse=False):
super().__init__()

self.rule = rule
self._rule = rule
self.max_iter = max_iter
pass
self.reuse = reuse

self._rule_copies = []

@property
def rule(self):
if self.reuse:
return self._rule

self._rule_copies.append(copy.deepcopy(self._rule))
return self._rule_copies[-1]

@property
def children(self):
return [self.rule]
if self.reuse:
return [self._rule]

return self._rule_copies

def map(self, model):
i = 0
Expand All @@ -96,3 +104,41 @@ def map(self, model):

new_model = _model
i += 1


########################################################################################


class Filter(RewriterBase):
def __init__(self, function: Callable[[Any], bool], rule: PassBase, *, reuse=False):
super().__init__()

self._rule = rule
self.function = function
self.reuse = reuse

self._rule_copies = []

@property
def rule(self):
if self.reuse:
return self._rule

self._rule_copies.append(copy.deepcopy(self._rule))
return self._rule_copies[-1]

@property
def children(self):
if self.reuse:
return [self._rule]

return self._rule_copies

def map(self, model):
return self.filter(model)

def filter(self, model):
if not self.function(model):
return model

return self.rule(model)
2 changes: 0 additions & 2 deletions src/oqd_compiler_infrastructure/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ def map(self, model):
def generic_map(self, model):
return model

pass


class ConversionRule(RuleBase):
"""
Expand Down
3 changes: 0 additions & 3 deletions src/oqd_compiler_infrastructure/walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(self, rule: PassBase, *, reverse: bool = False):

self.rule = rule
self.reverse = reverse
pass

@staticmethod
def controlled_reverse(iterable, reverse, *, restore_type=False):
Expand Down Expand Up @@ -72,8 +71,6 @@ def walk(self, model):
def generic_walk(self, model):
return self.rule(model)

pass


########################################################################################

Expand Down
105 changes: 52 additions & 53 deletions tests/examples/advanced_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from oqd_compiler_infrastructure.interface import TypeReflectBaseModel
from oqd_compiler_infrastructure.rule import RewriteRule
from oqd_compiler_infrastructure.walk import Post


# AST data structures (same as before)
class Expression(TypeReflectBaseModel):
"""Base class for arithmetic expressions.
This class serves as the foundation for all expression types in the AST.
"""
pass


class Number(Expression):
"""Represents a numeric literal.
Attributes:
value (float): The numeric value of the literal.
"""

value: float


class Variable(Expression):
"""Represents a variable in an expression.
Attributes:
name (str): The name of the variable.
"""

name: str


class BinaryOp(Expression):
"""Represents a binary operation.
Attributes:
op (str): The operator (e.g., '+', '-', '*', '/').
left (Expression): The left operand.
right (Expression): The right operand.
"""

op: str # '+', '-', '*', '/'
left: Expression
right: Expression


class AdvancedAlgebraicSimplifier(RewriteRule):
"""Applies advanced algebraic simplification rules.
Rules implemented:
Expand All @@ -58,6 +64,7 @@ class AdvancedAlgebraicSimplifier(RewriteRule):
- (x + y) - y = x
- Distributive law: a * (b + c) = (a * b) + (a * c)
"""

# Implements additional algebraic identities like subtraction and distribution

def map_BinaryOp(self, model):
Expand All @@ -69,37 +76,43 @@ def map_BinaryOp(self, model):
"""
# x - x = 0
# Rule: subtracting identical terms yields zero
if model.op == '-' and self._expressions_equal(model.left, model.right):
if model.op == "-" and self._expressions_equal(model.left, model.right):
return Number(value=0)

# x + (-x) = 0
# Rule: x + (-1 * x) => 0
if (model.op == '+' and
isinstance(model.right, BinaryOp) and
model.right.op == '*' and
isinstance(model.right.left, Number) and
model.right.left.value == -1 and
self._expressions_equal(model.left, model.right.right)):
if (
model.op == "+"
and isinstance(model.right, BinaryOp)
and model.right.op == "*"
and isinstance(model.right.left, Number)
and model.right.left.value == -1
and self._expressions_equal(model.left, model.right.right)
):
return Number(value=0)

# Distributive law: a * (b + c) = (a * b) + (a * c)
# Rule: a * (b + c) => (a*b) + (a*c)
if (model.op == '*' and
isinstance(model.right, BinaryOp) and
model.right.op in ['+', '-']):
if (
model.op == "*"
and isinstance(model.right, BinaryOp)
and model.right.op in ["+", "-"]
):
# a * (b + c) -> (a * b) + (a * c)
return BinaryOp(
op=model.right.op,
left=BinaryOp(op='*', left=model.left, right=model.right.left),
right=BinaryOp(op='*', left=model.left, right=model.right.right)
left=BinaryOp(op="*", left=model.left, right=model.right.left),
right=BinaryOp(op="*", left=model.left, right=model.right.right),
)

# (x + y) - y = x
# Rule: (x + y) - y => x
if (model.op == '-' and
isinstance(model.left, BinaryOp) and
model.left.op == '+' and
self._expressions_equal(model.left.right, model.right)):
if (
model.op == "-"
and isinstance(model.left, BinaryOp)
and model.left.op == "+"
and self._expressions_equal(model.left.right, model.right)
):
return model.left.left

return model
Expand All @@ -113,7 +126,7 @@ def _expressions_equal(self, expr1, expr2):
bool: True if structurally equal, False otherwise.
"""
# Compare types and recursively compare sub-expressions
if type(expr1) != type(expr2):
if type(expr1) is not type(expr2):
return False

if isinstance(expr1, Number):
Expand All @@ -123,12 +136,15 @@ def _expressions_equal(self, expr1, expr2):
return expr1.name == expr2.name

if isinstance(expr1, BinaryOp):
return (expr1.op == expr2.op and
self._expressions_equal(expr1.left, expr2.left) and
self._expressions_equal(expr1.right, expr2.right))
return (
expr1.op == expr2.op
and self._expressions_equal(expr1.left, expr2.left)
and self._expressions_equal(expr1.right, expr2.right)
)

return False


def print_expr(expr):
"""Convert an expression into a readable string.
Args:
Expand All @@ -145,50 +161,32 @@ def print_expr(expr):
return f"({print_expr(expr.left)} {expr.op} {print_expr(expr.right)})"
return str(expr)


def main():
"""Main function to demonstrate advanced algebraic simplification."""
# Prepare test cases and run the AdvancedAlgebraicSimplifier
# Create test expressions
test_cases = [
# x - x = 0
BinaryOp(
op='-',
left=Variable(name='x'),
right=Variable(name='x')
),

BinaryOp(op="-", left=Variable(name="x"), right=Variable(name="x")),
# x + (-1 * x) = 0
BinaryOp(
op='+',
left=Variable(name='x'),
right=BinaryOp(
op='*',
left=Number(value=-1),
right=Variable(name='x')
)
op="+",
left=Variable(name="x"),
right=BinaryOp(op="*", left=Number(value=-1), right=Variable(name="x")),
),

# a * (b + c) -> (a * b) + (a * c)
BinaryOp(
op='*',
left=Variable(name='a'),
right=BinaryOp(
op='+',
left=Variable(name='b'),
right=Variable(name='c')
)
op="*",
left=Variable(name="a"),
right=BinaryOp(op="+", left=Variable(name="b"), right=Variable(name="c")),
),

# (x + y) - y = x
BinaryOp(
op='-',
left=BinaryOp(
op='+',
left=Variable(name='x'),
right=Variable(name='y')
),
right=Variable(name='y')
)
op="-",
left=BinaryOp(op="+", left=Variable(name="x"), right=Variable(name="y")),
right=Variable(name="y"),
),
]

# Create simplifier with Post traversal
Expand All @@ -202,5 +200,6 @@ def main():
result = simplifier(expr)
print(f"Simplified: {print_expr(result)}")


if __name__ == "__main__":
main()
main()
Loading
Loading