Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/oqd_compiler_infrastructure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

from .base import PassBase
from .interface import TypeReflectBaseModel, VisitableBaseModel
from .rewriter import Chain, FixedPoint, RewriterBase
from .rewriter import Chain, FixedPoint, RewriterBase, Filter
from .rule import ConversionRule, PrettyPrint, RewriteRule, RuleBase
from .walk import In, Level, Post, Pre, WalkBase
from .match import Match

__all__ = [
"VisitableBaseModel",
Expand All @@ -34,4 +35,6 @@
"Chain",
"FixedPoint",
"RewriterBase",
"Filter",
"Match",
]
244 changes: 244 additions & 0 deletions src/oqd_compiler_infrastructure/match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import copy
from typing import Any, Dict, Optional

from .base import PassBase
from .rewriter import RewriterBase
from .rule import ConversionRule, RewriteRule
from pydantic import BaseModel

import ast


########################################################################################

__all__ = ["Match"]

########################################################################################


class MatchResult(BaseModel):
state: bool
variables: Optional[Dict[str, Any]]

def __bool__(self):
return self.state

def add(self, other):
if not isinstance(other, MatchResult):
raise TypeError("Unsupported addition of MatchResult and other type.")

state = self.state and other.state
self.state = state

if state:
variables = dict(**self.variables, **other.variables)
self.update(variables)

return MatchResult(state=state, variables=variables if state else None)

def update(self, variables):
self.variables.update(variables)

def __getitem__(self, key):
return self.variables[key]


########################################################################################


def _reduce_pattern(pattern):
if isinstance(pattern, str):
pattern = ast.parse(pattern)

while pattern and not isinstance(pattern, ast.Call):
if isinstance(pattern, ast.Module):
pattern = pattern.body[0]
continue

if isinstance(pattern, ast.Expr):
pattern = pattern.value
continue

raise TypeError("Unsupported AST pattern")

return pattern


class _MatchPattern(ConversionRule):
def __init__(self, pattern: ast.AST | str):
super().__init__()

self.pattern = pattern

@property
def pattern(self):
return self._pattern

@pattern.setter
def pattern(self, value):
self._pattern = _reduce_pattern(value)

def map_dict(self, model, operands):
return self.generic_map(model, operands)

def map_list(self, model, operands):
return self.generic_map(model, operands)

def map_tuple(self, model, operands):
return self.generic_map(model, operands)

def generic_map(self, model, operands):
node_names = (
{x.id for x in self.pattern.func.slice.elts}
if isinstance(self.pattern.func, ast.Subscript)
and isinstance(self.pattern.func.value, ast.Name)
and self.pattern.func.value.id == ("Union")
else {ast.unparse(self.pattern.func)}
)

pattern = self.pattern

result = MatchResult(state=True, variables={})

if node_names.intersection(map(lambda x: x.__name__, model.__class__.__mro__)):
for a in pattern.args:
if isinstance(a, ast.Name):
result.update({a.id: model})
continue
raise TypeError(
f"Unsupported type ({a.__class__.__name__}) when matching args"
)

for k in pattern.keywords:
if isinstance(k.value, ast.Call):
self.pattern = k.value

match model:
case dict():
_result = self(model.get(k.arg))
case _:
_result = self(getattr(model, k.arg))

result.add(_result)
continue

if isinstance(k.value, ast.Constant) and k.value.value == Ellipsis:
continue

if isinstance(k.value, ast.Name):
result.update({k.value.id: getattr(model, k.arg)})
continue

raise TypeError(
f"Unsupported type ({k.value.__class__.__name__}) when matching keywords"
)

return result
else:
return MatchResult(state=False, variables=None)


class _MatchSubstitute(RewriteRule):
def __init__(self, pattern: ast.AST | str, substitutions: Dict[str, Any]):
super().__init__()

self.pattern = pattern
self.substitutions = substitutions

@property
def pattern(self):
return self._pattern

@pattern.setter
def pattern(self, value):
self._pattern = _reduce_pattern(value)

def generic_map(self, model):
node_names = (
{x.id for x in self.pattern.func.slice.elts}
if isinstance(self.pattern.func, ast.Subscript)
and isinstance(self.pattern.func.value, ast.Name)
and self.pattern.func.value.id == ("Union")
else {ast.unparse(self.pattern.func)}
)

pattern = self.pattern

if node_names.intersection(map(lambda x: x.__name__, model.__class__.__mro__)):
if pattern.args and isinstance(pattern.args[0], ast.Name):
return self.substitutions[pattern.args[0].id]

if pattern.args:
raise TypeError(
f"Unsupported type ({pattern.args.__class__.__name__}) when matching args"
)

new_model = copy.deepcopy(model)
for k in pattern.keywords:
if isinstance(k.value, (ast.Call, ast.Name)):
self.pattern = k.value

match new_model:
case dict():
new_model[k.arg] = self(model.get(k.arg))
case _:
setattr(new_model, k.arg, self(getattr(model, k.arg)))

continue

if isinstance(k.value, ast.Constant) and k.value.value == Ellipsis:
continue

raise TypeError(
f"Unsupported type ({k.value.__class__.__name__}) when matching keywords"
)

return new_model
else:
raise ValueError("Pattern does not match model")


########################################################################################


class Match(RewriterBase):
def __init__(self, pattern: str, rule: PassBase, *, reuse=False):
super().__init__()

self._rule = rule
self.pattern = pattern
self.reuse = reuse

self._rule_copies = []

@property
def rule(self):
if self.reuse:
return self._rule

self._rule_copies.append(copy.deepcopy(self._rule))
return self._rule_copies[-1]

@property
def children(self):
if self.reuse:
return [self._rule]

return self._rule_copies

def map(self, model):
return self.match(model)

def match(self, model):
self.match_result = _MatchPattern(pattern=self.pattern)(model)

if not self.match_result:
return model

substitutions = {
k: self.rule(v) for k, v in self.match_result.variables.items()
}

return _MatchSubstitute(pattern=self.pattern, substitutions=substitutions)(
model
)
70 changes: 58 additions & 12 deletions src/oqd_compiler_infrastructure/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable
from oqd_compiler_infrastructure.base import PassBase
import copy

########################################################################################

__all__ = [
"RewriterBase",
"Chain",
"FixedPoint",
]
__all__ = ["RewriterBase", "Chain", "FixedPoint", "Filter"]

########################################################################################

Expand All @@ -34,8 +32,6 @@ class RewriterBase(PassBase):
This code was inspired by [SynbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl), [Liang.jl](https://github.com/Roger-luo/Liang.jl/tree/main/src/rewrite).
"""

pass


########################################################################################

Expand All @@ -52,7 +48,6 @@ def __init__(self, *rules):
super().__init__()

self.rules = list(rules)
pass

@property
def children(self):
Expand All @@ -74,16 +69,29 @@ class FixedPoint(RewriterBase):
This code was inspired by [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl/blob/master/src/rewriters.jl#L117C8-L117C16), [Liang.jl](https://github.com/Roger-luo/Liang.jl/blob/main/src/rewrite/fixpoint.jl).
"""

def __init__(self, rule, *, max_iter=1000):
def __init__(self, rule, *, max_iter=1000, reuse=False):
super().__init__()

self.rule = rule
self._rule = rule
self.max_iter = max_iter
pass
self.reuse = reuse

self._rule_copies = []

@property
def rule(self):
if self.reuse:
return self._rule

self._rule_copies.append(copy.deepcopy(self._rule))
return self._rule_copies[-1]

@property
def children(self):
return [self.rule]
if self.reuse:
return [self._rule]

return self._rule_copies

def map(self, model):
i = 0
Expand All @@ -96,3 +104,41 @@ def map(self, model):

new_model = _model
i += 1


########################################################################################


class Filter(RewriterBase):
def __init__(self, function: Callable[[Any], bool], rule: PassBase, *, reuse=False):
super().__init__()

self._rule = rule
self.function = function
self.reuse = reuse

self._rule_copies = []

@property
def rule(self):
if self.reuse:
return self._rule

self._rule_copies.append(copy.deepcopy(self._rule))
return self._rule_copies[-1]

@property
def children(self):
if self.reuse:
return [self._rule]

return self._rule_copies

def map(self, model):
return self.filter(model)

def filter(self, model):
if not self.function(model):
return model

return self.rule(model)
2 changes: 0 additions & 2 deletions src/oqd_compiler_infrastructure/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ def map(self, model):
def generic_map(self, model):
return model

pass


class ConversionRule(RuleBase):
"""
Expand Down
Loading
Loading