From 8d34ab9bffd13023af546e35c4d5d569a263756f Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 13:57:31 -0400 Subject: [PATCH 01/14] + updating attribute Traits --- src/kirin/dialects/py/indexing.py | 2 +- src/kirin/ir/__init__.py | 2 +- src/kirin/ir/attrs/abc.py | 22 ++++++++++++++++++++-- src/kirin/ir/nodes/stmt.py | 8 ++++---- src/kirin/ir/traits/__init__.py | 2 +- src/kirin/ir/traits/abc.py | 11 +++++++---- src/kirin/ir/traits/basic.py | 16 ++++++++-------- src/kirin/ir/traits/callable.py | 6 +++--- src/kirin/ir/traits/symbol.py | 6 +++--- 9 files changed, 48 insertions(+), 27 deletions(-) diff --git a/src/kirin/dialects/py/indexing.py b/src/kirin/dialects/py/indexing.py index c4c65bcc7..77970f307 100644 --- a/src/kirin/dialects/py/indexing.py +++ b/src/kirin/dialects/py/indexing.py @@ -30,7 +30,7 @@ @dataclass(frozen=True, eq=False) -class GetItemLike(ir.StmtTrait, Generic[GetItemLikeStmt]): +class GetItemLike(ir.Trait, Generic[GetItemLikeStmt]): @abstractmethod def get_object(self, stmt: GetItemLikeStmt) -> ir.SSAValue: ... 
diff --git a/src/kirin/ir/__init__.py b/src/kirin/ir/__init__.py index c43b107d5..55fdb43cc 100644 --- a/src/kirin/ir/__init__.py +++ b/src/kirin/ir/__init__.py @@ -23,9 +23,9 @@ from kirin.ir.method import Method as Method from kirin.ir.traits import ( Pure as Pure, + Trait as Trait, HasParent as HasParent, MaybePure as MaybePure, - StmtTrait as StmtTrait, RegionTrait as RegionTrait, SymbolTable as SymbolTable, ConstantLike as ConstantLike, diff --git a/src/kirin/ir/attrs/abc.py b/src/kirin/ir/attrs/abc.py index 53cc75654..72ce15aef 100644 --- a/src/kirin/ir/attrs/abc.py +++ b/src/kirin/ir/attrs/abc.py @@ -1,8 +1,9 @@ from abc import ABC, ABCMeta, abstractmethod -from typing import TYPE_CHECKING, ClassVar, Optional +from typing import TYPE_CHECKING, TypeVar, ClassVar, Optional from dataclasses import field, dataclass from kirin.print import Printable +from kirin.ir.traits import Trait from kirin.lattice.abc import LatticeMeta, SingletonMeta if TYPE_CHECKING: @@ -41,10 +42,27 @@ class Attribute(ABC, Printable, metaclass=AttributeMeta): """Dialect of the attribute. (default: None)""" name: ClassVar[str] = field(init=False, repr=False) """Name of the attribute in printing and other text format.""" - traits: ClassVar[frozenset[str]] = field( + traits: ClassVar[frozenset[Trait["Attribute"]]] = field( default=frozenset(), init=False, repr=False ) """Set of Attribute traits.""" @abstractmethod def __hash__(self) -> int: ... + + TraitType = TypeVar("TraitType", bound=Trait["Attribute"]) + + def get_trait(self, trait: type[TraitType]) -> Optional[TraitType]: + """Get the trait of the attribute. 
+ + Args: + trait (type[Trait]): the trait to get + + Returns: + Optional[Trait]: the trait if found, None otherwise + """ + for t in self.traits: + if isinstance(t, trait): + return t + + return None diff --git a/src/kirin/ir/nodes/stmt.py b/src/kirin/ir/nodes/stmt.py index 5d90293a0..c10f35f11 100644 --- a/src/kirin/ir/nodes/stmt.py +++ b/src/kirin/ir/nodes/stmt.py @@ -8,7 +8,7 @@ from kirin.print import Printer, Printable from kirin.ir.ssa import SSAValue, ResultValue from kirin.ir.use import Use -from kirin.ir.traits import StmtTrait +from kirin.ir.traits import Trait from kirin.ir.attrs.abc import Attribute from kirin.ir.nodes.base import IRNode from kirin.ir.nodes.view import MutableSequenceView @@ -132,7 +132,7 @@ class Statement(IRNode["Block"]): name: ClassVar[str] dialect: ClassVar[Dialect | None] = field(default=None, init=False, repr=False) - traits: ClassVar[frozenset[StmtTrait]] + traits: ClassVar[frozenset[Trait["Statement"]]] = frozenset() _arg_groups: ClassVar[frozenset[str]] = frozenset() _args: tuple[SSAValue, ...] = field(init=False) @@ -665,7 +665,7 @@ def get_attr_or_prop(self, key: str) -> Attribute | None: return self.attributes.get(key) @classmethod - def has_trait(cls, trait_type: type[StmtTrait]) -> bool: + def has_trait(cls, trait_type: type[Trait["Statement"]]) -> bool: """Check if the Statement has a specific trait. 
Args: @@ -679,7 +679,7 @@ def has_trait(cls, trait_type: type[StmtTrait]) -> bool: return True return False - TraitType = TypeVar("TraitType", bound=StmtTrait) + TraitType = TypeVar("TraitType", bound=Trait["Statement"]) @classmethod def get_trait(cls, trait: type[TraitType]) -> TraitType | None: diff --git a/src/kirin/ir/traits/__init__.py b/src/kirin/ir/traits/__init__.py index 9aca37d15..2232522c9 100644 --- a/src/kirin/ir/traits/__init__.py +++ b/src/kirin/ir/traits/__init__.py @@ -10,7 +10,7 @@ """ from .abc import ( - StmtTrait as StmtTrait, + Trait as Trait, RegionTrait as RegionTrait, PythonLoweringTrait as PythonLoweringTrait, ) diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index b1ce2fbe4..6dea840cf 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -9,11 +9,14 @@ from kirin.graph import Graph +IRNodeType = TypeVar("IRNodeType") + + @dataclass(frozen=True) -class StmtTrait(ABC): +class Trait(ABC, Generic[IRNodeType]): """Base class for all statement traits.""" - def verify(self, stmt: "Statement"): + def verify(self, node: IRNodeType): pass @@ -21,7 +24,7 @@ def verify(self, stmt: "Statement"): @dataclass(frozen=True) -class RegionTrait(StmtTrait, Generic[GraphType]): +class RegionTrait(Trait["Region"], Generic[GraphType]): """A trait that indicates the properties of the statement's region.""" @abstractmethod @@ -33,7 +36,7 @@ def get_graph(self, region: "Region") -> GraphType: ... 
@dataclass(frozen=True) -class PythonLoweringTrait(StmtTrait, Generic[StatementType, ASTNode]): +class PythonLoweringTrait(Trait[StatementType], Generic[StatementType, ASTNode]): """A trait that indicates that a statement can be lowered from Python AST.""" @abstractmethod diff --git a/src/kirin/ir/traits/basic.py b/src/kirin/ir/traits/basic.py index 89a8594e1..c97a87d17 100644 --- a/src/kirin/ir/traits/basic.py +++ b/src/kirin/ir/traits/basic.py @@ -1,14 +1,14 @@ from typing import TYPE_CHECKING from dataclasses import dataclass -from .abc import StmtTrait +from .abc import Trait if TYPE_CHECKING: from kirin.ir import Statement @dataclass(frozen=True) -class Pure(StmtTrait): +class Pure(Trait): """A trait that indicates that a statement is pure, i.e., it has no side effects. """ @@ -17,7 +17,7 @@ class Pure(StmtTrait): @dataclass(frozen=True) -class MaybePure(StmtTrait): +class MaybePure(Trait): """A trait that indicates the statement may be pure, i.e., a call statement can be pure if the callee is pure. """ @@ -40,7 +40,7 @@ def set_pure(cls, stmt: "Statement") -> None: @dataclass(frozen=True) -class ConstantLike(StmtTrait): +class ConstantLike(Trait): """A trait that indicates that a statement is constant-like, i.e., it represents a constant value. """ @@ -49,7 +49,7 @@ class ConstantLike(StmtTrait): @dataclass(frozen=True) -class IsTerminator(StmtTrait): +class IsTerminator(Trait): """A trait that indicates that a statement is a terminator, i.e., it terminates a block. """ @@ -58,19 +58,19 @@ class IsTerminator(StmtTrait): @dataclass(frozen=True) -class NoTerminator(StmtTrait): +class NoTerminator(Trait): """A trait that indicates that the region of a statement has no terminator.""" pass @dataclass(frozen=True) -class IsolatedFromAbove(StmtTrait): +class IsolatedFromAbove(Trait): pass @dataclass(frozen=True) -class HasParent(StmtTrait): +class HasParent(Trait): """A trait that indicates that a statement has a parent statement. 
""" diff --git a/src/kirin/ir/traits/callable.py b/src/kirin/ir/traits/callable.py index 6358cb918..12c32a1dc 100644 --- a/src/kirin/ir/traits/callable.py +++ b/src/kirin/ir/traits/callable.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar from dataclasses import dataclass -from kirin.ir.traits.abc import StmtTrait +from kirin.ir.traits.abc import Trait if TYPE_CHECKING: from kirin.ir import Region, Statement @@ -12,7 +12,7 @@ @dataclass(frozen=True) -class CallableStmtInterface(StmtTrait, Generic[StmtType]): +class CallableStmtInterface(Trait, Generic[StmtType]): """A trait that indicates that a statement is a callable statement. A callable statement is a statement that can be called as a function. @@ -26,7 +26,7 @@ def get_callable_region(cls, stmt: "StmtType") -> "Region": @dataclass(frozen=True) -class HasSignature(StmtTrait, ABC): +class HasSignature(Trait, ABC): """A trait that indicates that a statement has a function signature attribute. """ diff --git a/src/kirin/ir/traits/symbol.py b/src/kirin/ir/traits/symbol.py index 3d7d2bb8e..a652389c7 100644 --- a/src/kirin/ir/traits/symbol.py +++ b/src/kirin/ir/traits/symbol.py @@ -3,14 +3,14 @@ from kirin.exceptions import VerificationError from kirin.ir.attrs.py import PyAttr -from kirin.ir.traits.abc import StmtTrait +from kirin.ir.traits.abc import Trait if TYPE_CHECKING: from kirin.ir import Statement @dataclass(frozen=True) -class SymbolOpInterface(StmtTrait): +class SymbolOpInterface(Trait): """A trait that indicates that a statement is a symbol operation. A symbol operation is a statement that has a symbol name attribute. @@ -32,7 +32,7 @@ def verify(self, stmt: "Statement"): @dataclass(frozen=True) -class SymbolTable(StmtTrait): +class SymbolTable(Trait): """ Statement with SymbolTable trait can only have one region with one block. 
""" From 560a5373790b8330e0401d9675d0c65439997abd Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 14:04:25 -0400 Subject: [PATCH 02/14] fixing type hints --- src/kirin/dialects/py/indexing.py | 2 +- src/kirin/ir/traits/basic.py | 14 +++++++------- src/kirin/ir/traits/callable.py | 12 ++++++------ src/kirin/ir/traits/symbol.py | 16 +++++++++------- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/src/kirin/dialects/py/indexing.py b/src/kirin/dialects/py/indexing.py index 77970f307..0bf11c21e 100644 --- a/src/kirin/dialects/py/indexing.py +++ b/src/kirin/dialects/py/indexing.py @@ -30,7 +30,7 @@ @dataclass(frozen=True, eq=False) -class GetItemLike(ir.Trait, Generic[GetItemLikeStmt]): +class GetItemLike(ir.Trait[ir.Statement], Generic[GetItemLikeStmt]): @abstractmethod def get_object(self, stmt: GetItemLikeStmt) -> ir.SSAValue: ... diff --git a/src/kirin/ir/traits/basic.py b/src/kirin/ir/traits/basic.py index c97a87d17..55ae38cb9 100644 --- a/src/kirin/ir/traits/basic.py +++ b/src/kirin/ir/traits/basic.py @@ -8,7 +8,7 @@ @dataclass(frozen=True) -class Pure(Trait): +class Pure(Trait["Statement"]): """A trait that indicates that a statement is pure, i.e., it has no side effects. """ @@ -17,7 +17,7 @@ class Pure(Trait): @dataclass(frozen=True) -class MaybePure(Trait): +class MaybePure(Trait["Statement"]): """A trait that indicates the statement may be pure, i.e., a call statement can be pure if the callee is pure. """ @@ -40,7 +40,7 @@ def set_pure(cls, stmt: "Statement") -> None: @dataclass(frozen=True) -class ConstantLike(Trait): +class ConstantLike(Trait["Statement"]): """A trait that indicates that a statement is constant-like, i.e., it represents a constant value. """ @@ -49,7 +49,7 @@ class ConstantLike(Trait): @dataclass(frozen=True) -class IsTerminator(Trait): +class IsTerminator(Trait["Statement"]): """A trait that indicates that a statement is a terminator, i.e., it terminates a block. 
""" @@ -58,19 +58,19 @@ class IsTerminator(Trait): @dataclass(frozen=True) -class NoTerminator(Trait): +class NoTerminator(Trait["Statement"]): """A trait that indicates that the region of a statement has no terminator.""" pass @dataclass(frozen=True) -class IsolatedFromAbove(Trait): +class IsolatedFromAbove(Trait["Statement"]): pass @dataclass(frozen=True) -class HasParent(Trait): +class HasParent(Trait["Statement"]): """A trait that indicates that a statement has a parent statement. """ diff --git a/src/kirin/ir/traits/callable.py b/src/kirin/ir/traits/callable.py index 12c32a1dc..920bc49d3 100644 --- a/src/kirin/ir/traits/callable.py +++ b/src/kirin/ir/traits/callable.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, TypeVar from dataclasses import dataclass from kirin.ir.traits.abc import Trait @@ -12,7 +12,7 @@ @dataclass(frozen=True) -class CallableStmtInterface(Trait, Generic[StmtType]): +class CallableStmtInterface(Trait[StmtType]): """A trait that indicates that a statement is a callable statement. A callable statement is a statement that can be called as a function. @@ -26,13 +26,13 @@ def get_callable_region(cls, stmt: "StmtType") -> "Region": @dataclass(frozen=True) -class HasSignature(Trait, ABC): +class HasSignature(Trait[StmtType], ABC): """A trait that indicates that a statement has a function signature attribute. 
""" @classmethod - def get_signature(cls, stmt: "Statement"): + def get_signature(cls, stmt: StmtType): signature: Signature | None = stmt.attributes.get("signature") # type: ignore if signature is None: raise ValueError(f"Statement {stmt.name} does not have a function type") @@ -40,10 +40,10 @@ def get_signature(cls, stmt: "Statement"): return signature @classmethod - def set_signature(cls, stmt: "Statement", signature: "Signature"): + def set_signature(cls, stmt: StmtType, signature: "Signature"): stmt.attributes["signature"] = signature - def verify(self, stmt: "Statement"): + def verify(self, stmt: StmtType): from kirin.dialects.func.attrs import Signature signature = self.get_signature(stmt) diff --git a/src/kirin/ir/traits/symbol.py b/src/kirin/ir/traits/symbol.py index a652389c7..2a8ded8ce 100644 --- a/src/kirin/ir/traits/symbol.py +++ b/src/kirin/ir/traits/symbol.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar from dataclasses import dataclass from kirin.exceptions import VerificationError @@ -8,22 +8,24 @@ if TYPE_CHECKING: from kirin.ir import Statement +StmtType = TypeVar("StmtType", bound="Statement") + @dataclass(frozen=True) -class SymbolOpInterface(Trait): +class SymbolOpInterface(Trait[StmtType]): """A trait that indicates that a statement is a symbol operation. A symbol operation is a statement that has a symbol name attribute. 
""" - def get_sym_name(self, stmt: "Statement") -> "PyAttr[str]": + def get_sym_name(self, stmt: StmtType) -> "PyAttr[str]": sym_name: PyAttr[str] | None = stmt.get_attr_or_prop("sym_name") # type: ignore # NOTE: unlike MLIR or xDSL we do not allow empty symbol names if sym_name is None: raise ValueError(f"Statement {stmt.name} does not have a symbol name") return sym_name - def verify(self, stmt: "Statement"): + def verify(self, stmt: StmtType): from kirin.types import String sym_name = self.get_sym_name(stmt) @@ -32,16 +34,16 @@ def verify(self, stmt: "Statement"): @dataclass(frozen=True) -class SymbolTable(Trait): +class SymbolTable(Trait[StmtType]): """ Statement with SymbolTable trait can only have one region with one block. """ @staticmethod - def walk(stmt: "Statement"): + def walk(stmt: StmtType): return stmt.regions[0].blocks[0].stmts - def verify(self, stmt: "Statement"): + def verify(self, stmt: StmtType): if len(stmt.regions) != 1: raise VerificationError( stmt, From 502ef1adce6dd86630dfceaa6fc76fdc97925e00 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 14:06:10 -0400 Subject: [PATCH 03/14] renaming TypeVar --- src/kirin/ir/traits/abc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index 6dea840cf..044f74088 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -32,14 +32,14 @@ def get_graph(self, region: "Region") -> GraphType: ... 
ASTNode = TypeVar("ASTNode", bound=ast.AST) -StatementType = TypeVar("StatementType", bound="Statement") +StmtType = TypeVar("StmtType", bound="Statement") @dataclass(frozen=True) -class PythonLoweringTrait(Trait[StatementType], Generic[StatementType, ASTNode]): +class PythonLoweringTrait(Trait[StmtType], Generic[StmtType, ASTNode]): """A trait that indicates that a statement can be lowered from Python AST.""" @abstractmethod def lower( - self, stmt: type[StatementType], state: "lowering.LoweringState", node: ASTNode + self, stmt: type[StmtType], state: "lowering.LoweringState", node: ASTNode ) -> "lowering.Result": ... From c84cc3d245f3637ba124ff2d868c47ce1201cf0a Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 14:21:08 -0400 Subject: [PATCH 04/14] migrating work from another branch --- src/kirin/ir/traits/abc.py | 19 ++++ src/kirin/parse/__init__.py | 0 src/kirin/parse/grammer.py | 172 ++++++++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+) create mode 100644 src/kirin/parse/__init__.py create mode 100644 src/kirin/parse/grammer.py diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index 044f74088..4f7efe00e 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -3,10 +3,13 @@ from typing import TYPE_CHECKING, Generic, TypeVar from dataclasses import dataclass +import lark + if TYPE_CHECKING: from kirin import lowering from kirin.ir import Block, Region, Statement from kirin.graph import Graph + from kirin.parse.grammer import Grammer, LarkParser IRNodeType = TypeVar("IRNodeType") @@ -43,3 +46,19 @@ class PythonLoweringTrait(Trait[StmtType], Generic[StmtType, ASTNode]): def lower( self, stmt: type[StmtType], state: "lowering.LoweringState", node: ASTNode ) -> "lowering.Result": ... + + +@dataclass(frozen=True) +class LarkLoweringTrait(Trait[IRNodeType]): + + @abstractmethod + def lark_rule(self, rules: "Grammer", node: IRNodeType) -> str: ... 
+ + @abstractmethod + def lower( + self, + parser: "LarkParser", + state: "lowering.LoweringState", + node: type[IRNodeType], + tree: lark.Tree, + ) -> "lowering.Result": ... diff --git a/src/kirin/parse/__init__.py b/src/kirin/parse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/kirin/parse/grammer.py b/src/kirin/parse/grammer.py new file mode 100644 index 000000000..1f17a1a36 --- /dev/null +++ b/src/kirin/parse/grammer.py @@ -0,0 +1,172 @@ +import textwrap +from typing import ClassVar +from dataclasses import field, dataclass + +import lark + +from kirin import ir, types +from kirin.idtable import IdTable +from kirin.ir.traits import LarkLoweringTrait +from kirin.exceptions import LarkLoweringError +from kirin.lowering.state import LoweringState +from kirin.lowering.result import Result + +SSA_IDENTIFIER: str = "ssa_identifier" +BLOCK_IDENTIFIER: str = "block_identifier" +BLOCK: str = "block" +REGION: str = "region" +SIGNATURE: str = "signature" +TYPE: str = "type" +DIALECT: str = "dialect" +ATTR: str = "attr" + + +@dataclass +class Grammer: + rule_ids: IdTable[type[ir.Statement | ir.Attribute] | types.PyClass] = field( + default_factory=IdTable, init=False + ) + stmt_ids: list[str] = field(default_factory=list, init=False) + attr_ids: list[str] = field(default_factory=list, init=False) + rules: list[str] = field(default_factory=list, init=False) + stmt_traits: dict[str, LarkLoweringTrait[ir.Statement]] = field( + default_factory=dict, init=False + ) + attr_traits: dict[str, LarkLoweringTrait[ir.Attribute] | types.PyClass] = field( + default_factory=dict, init=False + ) + + header: ClassVar[str] = textwrap.dedent( + """ + %import common.NEWLINE + %import common.CNAME -> IDENTIFIER + %import common.INT + %import common.FLOAT + %import common.ESCAPED_STRING -> STRING + %import common.WS + %ignore WS + %ignore "│" + + region: "{{" newline (newline block)* "}}" newline* + block: block_identifier block_args newline (stmt newline)* + + 
stmt = {stmt_rule} + attr = {attr_rule} + + block_identifier: "^" INT + block_args: '(' ssa_identifier (',' ssa_identifier)* ')' + ssa_identifier: '%' (IDENTIFIER | INT) | '%' (IDENTIFIER | INT) ":" type + newline: NEWLINE | "//" NEWLINE | "//" /.+/ NEWLINE + """ + ) + + def add_attr(self, node: type[ir.Attribute]) -> str: + trait: LarkLoweringTrait[ir.Attribute] = node.get_trait(LarkLoweringTrait) + + if trait is None: + raise LarkLoweringError( + f"Attribute {node} does not have a LarkLoweringTrait" + ) + + self.attr_ids(rule_id := self.rule_ids[node]) + self.rules.append(f"{rule_id}: {trait.lark_rule(self, node)}") + return rule_id, trait + + def add_stmt(self, node: type[ir.Statement]) -> str: + trait: LarkLoweringTrait[ir.Statement] = node.get_trait(LarkLoweringTrait) + + if trait is None: + raise LarkLoweringError( + f"Statement {node} does not have a LarkLoweringTrait" + ) + + self.stmt_ids(rule_id := self.rule_ids[node]) + self.rules.append(f"{rule_id}: {trait.lark_rule(self, node)}") + return rule_id, trait + + def add_pyclass(self, node: types.PyClass) -> str: + rule = f'"{node.prefix}.{node.display_name}"' + self.attr_ids(rule_id := self.rule_ids[node]) + self.rules.append(f"{rule_id}: {rule}") + return rule_id + + def emit(self) -> str: + stmt = " | ".join(self.stmt_ids) + attr = " | ".join(self.attr_ids) + return self.header.format(stmt_rule=stmt, attr_rule=attr) + "\n".join( + self.rules + ) + + +@dataclass(init=False) +class LarkParser: + dialects: ir.DialectGroup + lark_parser: lark.Lark + stmt_traits: dict[str, LarkLoweringTrait[ir.Statement]] + attr_traits: dict[str, LarkLoweringTrait[ir.Attribute] | types.PyClass] + state: LoweringState | None = None + + def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): + self.dialects = dialects + + start = None + grammer = Grammer() + + for dialect in dialects.data: + for attr in dialect.attrs: + rule_id, trait = grammer.add_attr(attr) + self.attr_traits[rule_id] = trait + + for 
type_binding in dialect.python_types.keys(): + rule_id = grammer.add_pyclass(type_binding) + self.attr_traits[rule_id] = type_binding + + for stmt in dialect.stmts: + rule_id, trait = grammer.add_attr(attr) + self.stmt_traits[rule_id] = trait + + if stmt is start_node: + start = rule_id + + if start is None: + raise LarkLoweringError(f"Start node {start_node} is not in the dialects") + + self.lark_parser = lark.Lark(grammer.emit(), start=start) + + def lower(self, tree: lark.Tree) -> Result: + node_type = tree.data + + if node_type == "newline": + return None + elif node_type == "region": + return self.lower_region(tree) + elif node_type == "block": + return self.lower_block(tree) + elif node_type == "stmt": + return self.lower_stmt(tree) + elif node_type == "attr": + return self.lower_attr(tree) + elif node_type == "type": + return self.lower_type(tree) + else: + raise LarkLoweringError(f"Unknown node type {node_type}") + + def lower_region(self, tree: lark.Tree) -> ir.Region: + + for child in tree.children: + self.lower(child) + + return Result() + + def lower_block(self, tree: lark.Tree) -> ir.Block: + block = self.state.current_frame.curr_block + + block_args = tree.children[1] + assert block_args.data == "block_args" + for arg in block_args.children: + block.args.append(self.lower(arg).expect_one()) + + for stmt in tree.children[2:]: + self.lower(stmt) + + self.state.current_frame.curr_block From 019b586f5a3a593305fbde85ff6cd8a8220d33e6 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 14:37:47 -0400 Subject: [PATCH 05/14] adding lark trait for typical statement --- src/kirin/ir/traits/lark/stmt.py | 60 ++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/kirin/ir/traits/lark/stmt.py diff --git a/src/kirin/ir/traits/lark/stmt.py b/src/kirin/ir/traits/lark/stmt.py new file mode 100644 index 000000000..29651d868 --- /dev/null +++ b/src/kirin/ir/traits/lark/stmt.py @@ -0,0 +1,60 @@ +from kirin import ir 
+from kirin.decl import fields +from kirin.exceptions import LarkLoweringError +from kirin.ir.nodes.stmt import Statement +from kirin.parse.grammer import LarkParser +from kirin.lowering.state import LoweringState +from kirin.lowering.result import Result as Result + +from ..abc import LarkLoweringTrait + + +class FromLark(LarkLoweringTrait): + def lark_rule( + self, + grammer_rules: dict[ir.IRNode | ir.Attribute, str], + stmt_type: type[Statement], + ) -> str: + assert ( + stmt_type.dialect is not None + ), f"Statement {stmt_type} must have a dialect" + + stmt_fields = fields(stmt_type) + + if len(stmt_fields.regions) > 0: + raise LarkLoweringError( + f"Statement {stmt_type} has regions, which are not supported by FromLark trait. create a custom trait for this statement" + ) + + if len(stmt_fields.blocks) > 0: + raise LarkLoweringError( + f"Statement {stmt_type} has blocks, which are not supported by FromLark trait. create a custom trait for this statement" + ) + + num_results = len(stmt_fields.results) + + stmt_body = f'"{stmt_type.dialect.name}.{stmt_type.name}" ' + return_match = ", ".join("ssa_identifier" for _ in range(num_results)) + type_match = ", ".join(' "!" 
attr' for _ in range(num_results)) + stmt_args = ", ".join( + f'"{arg.name}" "=" ssa_identifier' for arg in stmt_fields.args + ) + attr_args = ", ".join( + f'"{name}" "=" {grammer_rules[attr.type]}' + for name, attr in stmt_fields.attributes.items() + ) + + stmt_rule = f'{stmt_body} "(" {stmt_args} ")"' + + if len(attr_args) > 0: + stmt_rule = f'{stmt_rule} "{{" {attr_args} "}}"' + + if len(return_match) > 0: + stmt_rule = f'"{return_match} "=" {stmt_rule} ":" {type_match}' + + return stmt_rule + + def lower( + self, parser: LarkParser, state: LoweringState, stmt: Statement + ) -> Result: + pass From e33f343cbb6f47fe09240641f74c7a7071778e0d Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 14:38:27 -0400 Subject: [PATCH 06/14] adding LarkLoweringError --- src/kirin/exceptions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/kirin/exceptions.py b/src/kirin/exceptions.py index 76582990d..a452d1316 100644 --- a/src/kirin/exceptions.py +++ b/src/kirin/exceptions.py @@ -57,3 +57,7 @@ def __init__(self, node: "IRNode", *messages: str) -> None: class DuplicatedDefinitionError(Exception): pass + + +class LarkLoweringError(Exception): + pass From a46c18a793e2098cf7bebc31bef574c4549e8034 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 14:54:49 -0400 Subject: [PATCH 07/14] fixing spelling error --- src/kirin/ir/traits/abc.py | 4 ++-- src/kirin/ir/traits/lark/stmt.py | 8 ++++---- src/kirin/parse/{grammer.py => grammar.py} | 7 +++++-- 3 files changed, 11 insertions(+), 8 deletions(-) rename src/kirin/parse/{grammer.py => grammar.py} (97%) diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index 5249e74f9..be885bdf9 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -9,7 +9,7 @@ from kirin import lowering from kirin.ir import Block, Region, Statement from kirin.graph import Graph - from kirin.parse.grammer import Grammer, LarkParser + from kirin.parse.grammar import Grammar, 
LarkParser IRNodeType = TypeVar("IRNodeType") @@ -55,7 +55,7 @@ def lower( class LarkLoweringTrait(Trait[IRNodeType]): @abstractmethod - def lark_rule(self, rules: "Grammer", node: IRNodeType) -> str: ... + def lark_rule(self, grammar: "Grammar", node: IRNodeType) -> str: ... @abstractmethod def lower( diff --git a/src/kirin/ir/traits/lark/stmt.py b/src/kirin/ir/traits/lark/stmt.py index 29651d868..ddc7d4a63 100644 --- a/src/kirin/ir/traits/lark/stmt.py +++ b/src/kirin/ir/traits/lark/stmt.py @@ -1,8 +1,7 @@ -from kirin import ir from kirin.decl import fields from kirin.exceptions import LarkLoweringError from kirin.ir.nodes.stmt import Statement -from kirin.parse.grammer import LarkParser +from kirin.parse.grammar import Grammar, LarkParser from kirin.lowering.state import LoweringState from kirin.lowering.result import Result as Result @@ -12,7 +11,7 @@ class FromLark(LarkLoweringTrait): def lark_rule( self, - grammer_rules: dict[ir.IRNode | ir.Attribute, str], + grammar: Grammar, stmt_type: type[Statement], ) -> str: assert ( @@ -31,6 +30,7 @@ def lark_rule( f"Statement {stmt_type} has blocks, which are not supported by FromLark trait. 
create a custom trait for this statement" ) + # TODO: replace global rules like: ssa_identifier, attr, etc with module constants: kirin.parse.grammar.SSA_IDENTIFIER, kirin.parse.grammar.ATTR, etc num_results = len(stmt_fields.results) stmt_body = f'"{stmt_type.dialect.name}.{stmt_type.name}" ' @@ -40,7 +40,7 @@ def lark_rule( f'"{arg.name}" "=" ssa_identifier' for arg in stmt_fields.args ) attr_args = ", ".join( - f'"{name}" "=" {grammer_rules[attr.type]}' + f'"{name}" "=" {grammar.attr_rules[type(attr.type)]}' for name, attr in stmt_fields.attributes.items() ) diff --git a/src/kirin/parse/grammer.py b/src/kirin/parse/grammar.py similarity index 97% rename from src/kirin/parse/grammer.py rename to src/kirin/parse/grammar.py index 1f17a1a36..071fed046 100644 --- a/src/kirin/parse/grammer.py +++ b/src/kirin/parse/grammar.py @@ -22,12 +22,15 @@ @dataclass -class Grammer: +class Grammar: rule_ids: IdTable[type[ir.Statement | ir.Attribute] | types.PyClass] = field( default_factory=IdTable, init=False ) stmt_ids: list[str] = field(default_factory=list, init=False) attr_ids: list[str] = field(default_factory=list, init=False) + attr_rules: dict[type[ir.Attribute] | types.PyClass, str] = field( + default_factory=dict, init=False + ) rules: list[str] = field(default_factory=list, init=False) stmt_traits: dict[str, LarkLoweringTrait[ir.Statement]] = field( default_factory=dict, init=False @@ -110,7 +113,7 @@ def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): self.dialects = dialects start = None - grammer = Grammer() + grammer = Grammar() for dialect in dialects.data: for attr in dialect.attrs: From d1474b77ff90fe6dfbf7405d7868b6e8530b0bc4 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 15:01:24 -0400 Subject: [PATCH 08/14] rename variables --- src/kirin/ir/traits/lark/stmt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/kirin/ir/traits/lark/stmt.py b/src/kirin/ir/traits/lark/stmt.py index 
ddc7d4a63..3967991d0 100644 --- a/src/kirin/ir/traits/lark/stmt.py +++ b/src/kirin/ir/traits/lark/stmt.py @@ -36,18 +36,18 @@ def lark_rule( stmt_body = f'"{stmt_type.dialect.name}.{stmt_type.name}" ' return_match = ", ".join("ssa_identifier" for _ in range(num_results)) type_match = ", ".join(' "!" attr' for _ in range(num_results)) - stmt_args = ", ".join( + stmt_args_rule = ", ".join( f'"{arg.name}" "=" ssa_identifier' for arg in stmt_fields.args ) - attr_args = ", ".join( + attr_args_rule = ", ".join( f'"{name}" "=" {grammar.attr_rules[type(attr.type)]}' for name, attr in stmt_fields.attributes.items() ) - stmt_rule = f'{stmt_body} "(" {stmt_args} ")"' + stmt_rule = f'{stmt_body} "(" {stmt_args_rule} ")"' - if len(attr_args) > 0: - stmt_rule = f'{stmt_rule} "{{" {attr_args} "}}"' + if len(attr_args_rule) > 0: + stmt_rule = f'{stmt_rule} "{{" {attr_args_rule} "}}"' if len(return_match) > 0: stmt_rule = f'"{return_match} "=" {stmt_rule} ":" {type_match}' From 95442ccde47e265be0910d31015adfd40dd27ff5 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 15:02:23 -0400 Subject: [PATCH 09/14] renaming class --- src/kirin/ir/traits/abc.py | 4 ++-- src/kirin/parse/grammar.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index be885bdf9..f886a4b97 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -9,7 +9,7 @@ from kirin import lowering from kirin.ir import Block, Region, Statement from kirin.graph import Graph - from kirin.parse.grammar import Grammar, LarkParser + from kirin.parse.grammar import Grammar, DialectGroupParser IRNodeType = TypeVar("IRNodeType") @@ -60,7 +60,7 @@ def lark_rule(self, grammar: "Grammar", node: IRNodeType) -> str: ... 
@abstractmethod def lower( self, - parser: "LarkParser", + parser: "DialectGroupParser", state: "lowering.LoweringState", node: type[IRNodeType], tree: lark.Tree, diff --git a/src/kirin/parse/grammar.py b/src/kirin/parse/grammar.py index 071fed046..56685f645 100644 --- a/src/kirin/parse/grammar.py +++ b/src/kirin/parse/grammar.py @@ -102,7 +102,7 @@ def emit(self) -> str: @dataclass(init=False) -class LarkParser: +class DialectGroupParser: dialects: ir.DialectGroup lark_parser: lark.Lark stmt_traits: dict[str, LarkLoweringTrait[ir.Statement]] From 4414aff113a8ee2d1ef2ce7e7c6ebdd7bd527399 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Thu, 13 Mar 2025 15:36:27 -0400 Subject: [PATCH 10/14] finish visitor --- src/kirin/ir/traits/__init__.py | 1 + src/kirin/ir/traits/abc.py | 4 +- src/kirin/parse/grammar.py | 116 +++++++++++++++++++++++--------- 3 files changed, 86 insertions(+), 35 deletions(-) diff --git a/src/kirin/ir/traits/__init__.py b/src/kirin/ir/traits/__init__.py index 2232522c9..f21494507 100644 --- a/src/kirin/ir/traits/__init__.py +++ b/src/kirin/ir/traits/__init__.py @@ -12,6 +12,7 @@ from .abc import ( Trait as Trait, RegionTrait as RegionTrait, + LarkLoweringTrait as LarkLoweringTrait, PythonLoweringTrait as PythonLoweringTrait, ) from .basic import ( diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index f886a4b97..a21df2050 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -9,7 +9,7 @@ from kirin import lowering from kirin.ir import Block, Region, Statement from kirin.graph import Graph - from kirin.parse.grammar import Grammar, DialectGroupParser + from kirin.parse.grammar import Grammar, LarkLowerResult, DialectGroupParser IRNodeType = TypeVar("IRNodeType") @@ -64,4 +64,4 @@ def lower( state: "lowering.LoweringState", node: type[IRNodeType], tree: lark.Tree, - ) -> "lowering.Result": ... + ) -> "LarkLowerResult[IRNodeType]": ... 
diff --git a/src/kirin/parse/grammar.py b/src/kirin/parse/grammar.py index 56685f645..29d0ac470 100644 --- a/src/kirin/parse/grammar.py +++ b/src/kirin/parse/grammar.py @@ -1,5 +1,5 @@ import textwrap -from typing import ClassVar +from typing import Generic, TypeVar, ClassVar from dataclasses import field, dataclass import lark @@ -9,7 +9,6 @@ from kirin.ir.traits import LarkLoweringTrait from kirin.exceptions import LarkLoweringError from kirin.lowering.state import LoweringState -from kirin.lowering.result import Result SSA_IDENTIFIER: str = "ssa_identifier" BLOCK_IDENTIFIER: str = "block_identifier" @@ -20,6 +19,40 @@ DIALECT: str = "dialect" ATTR: str = "attr" +NodeType = TypeVar("NodeType", bound=ir.Statement | ir.Attribute | None) + + +@dataclass +class LarkLowerResult(Generic[NodeType]): + result: NodeType = None + + def expect_none(self): + if self.result is not None: + raise LarkLoweringError(f"Expected None, got {self.result}") + + def expect_stmt(self) -> ir.Statement: + if not isinstance(self.result, ir.Statement): + raise LarkLoweringError(f"Expected statement, got {self.result}") + + return self.result + + def expect_attr(self) -> ir.Attribute: + if not isinstance(self.result, ir.Attribute): + raise LarkLoweringError(f"Expected attribute, got {self.result}") + + return self.result + + +@dataclass +class LarkTraitWrapper(Generic[NodeType]): + node_type: type[NodeType] + trait: LarkLoweringTrait[NodeType] + + def lower( + self, parser: "DialectGroupParser", state: LoweringState, tree: lark.Tree + ): + return self.trait.lower(parser, state, self.node_type, tree) + @dataclass class Grammar: @@ -28,14 +61,12 @@ class Grammar: ) stmt_ids: list[str] = field(default_factory=list, init=False) attr_ids: list[str] = field(default_factory=list, init=False) - attr_rules: dict[type[ir.Attribute] | types.PyClass, str] = field( - default_factory=dict, init=False - ) + rules: list[str] = field(default_factory=list, init=False) - stmt_traits: dict[str, 
LarkLoweringTrait[ir.Statement]] = field( + stmt_traits: dict[str, LarkTraitWrapper[ir.Statement]] = field( default_factory=dict, init=False ) - attr_traits: dict[str, LarkLoweringTrait[ir.Attribute] | types.PyClass] = field( + attr_traits: dict[str, LarkTraitWrapper[ir.Attribute] | types.PyClass] = field( default_factory=dict, init=False ) @@ -63,35 +94,35 @@ class Grammar: """ ) - def add_attr(self, node: type[ir.Attribute]) -> str: - trait: LarkLoweringTrait[ir.Attribute] = node.get_trait(LarkLoweringTrait) + def add_attr(self, node_type: type[ir.Attribute]) -> str: + trait: LarkLoweringTrait[ir.Attribute] = node_type.get_trait(LarkLoweringTrait) if trait is None: raise LarkLoweringError( - f"Attribute {node} does not have a LarkLoweringTrait" + f"Attribute {node_type} does not have a LarkLoweringTrait" ) - self.attr_ids(rule_id := self.rule_ids[node]) - self.rules.append(f"{rule_id}: {trait.lark_rule(self, node)}") - return rule_id, trait + self.attr_ids(rule_id := self.rule_ids[node_type]) + self.rules.append(f"{rule_id}: {trait.lark_rule(self, node_type)}") + return rule_id, LarkTraitWrapper(node_type, trait) - def add_stmt(self, node: type[ir.Statement]) -> str: - trait: LarkLoweringTrait[ir.Statement] = node.get_trait(LarkLoweringTrait) + def add_stmt(self, node_type: type[ir.Statement]) -> str: + trait: LarkLoweringTrait[ir.Statement] = node_type.get_trait(LarkLoweringTrait) if trait is None: raise LarkLoweringError( - f"Statement {node} does not have a LarkLoweringTrait" + f"Statement {node_type} does not have a LarkLoweringTrait" ) - self.stmt_ids(rule_id := self.rule_ids[node]) - self.rules.append(f"{rule_id}: {trait.lark_rule(self, node)}") - return rule_id, trait + self.stmt_ids(rule_id := self.rule_ids[node_type]) + self.rules.append(f"{rule_id}: {trait.lark_rule(self, node_type)}") + return rule_id, LarkTraitWrapper(node_type, trait) def add_pyclass(self, node: types.PyClass) -> str: rule = f'"{node.prefix}.{node.display_name}"' 
self.attr_ids(rule_id := self.rule_ids[node]) self.rules.append(f"{rule_id}: {rule}") - return rule_id + return rule_id, node def emit(self) -> str: stmt = " | ".join(self.stmt_ids) @@ -105,8 +136,8 @@ def emit(self) -> str: class DialectGroupParser: dialects: ir.DialectGroup lark_parser: lark.Lark - stmt_traits: dict[str, LarkLoweringTrait[ir.Statement]] - attr_traits: dict[str, LarkLoweringTrait[ir.Attribute] | types.PyClass] + stmt_registry: dict[str, LarkTraitWrapper[ir.Statement]] + attr_registry: dict[str, LarkTraitWrapper[ir.Attribute] | types.PyClass] state: LoweringState | None = None def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): @@ -126,7 +157,7 @@ def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): for stmt in dialect.stmts: rule_id, trait = grammer.add_attr(attr) - self.stmt_traits[rule_id] = trait + self.stmt_registry[rule_id] = trait if stmt is start_node: start = rule_id @@ -136,11 +167,11 @@ def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): self.lark_parser = lark.Lark(grammer.emit(), start=start) - def lower(self, tree: lark.Tree) -> Result: + def lower(self, tree: lark.Tree): node_type = tree.data if node_type == "newline": - return None + return LarkLowerResult() elif node_type == "region": return self.lower_region(tree) elif node_type == "block": @@ -149,19 +180,15 @@ def lower(self, tree: lark.Tree) -> Result: return self.lower_stmt(tree) elif node_type == "attr": return self.lower_attr(tree) - elif node_type == "type": - return self.lower_type(tree) else: raise LarkLoweringError(f"Unknown node type {node_type}") - def lower_region(self, tree: lark.Tree) -> ir.Region: - + def lower_region(self, tree: lark.Tree): for child in tree.children: self.lower(child) + return LarkLowerResult() - return Result() - - def lower_block(self, tree: lark.Tree) -> ir.Block: + def lower_block(self, tree: lark.Tree): block = self.state.current_frame.curr_block block_args = tree.children[1] @@ 
-172,4 +199,27 @@ def lower_block(self, tree: lark.Tree) -> ir.Block: for stmt in tree.children[2:]: self.lower(stmt) - self.state.current_frame.curr_block + self.state.current_frame.append_block() + return LarkLowerResult() + + def lower_stmt(self, tree: lark.Tree): + if tree.data not in self.stmt_registry: + raise LarkLoweringError(f"Unknown statement type {tree.data}") + + stmt = self.stmt_registry[tree.data].lower(self, self.state, tree).expect_stmt() + self.state.current_frame.append_stmt(stmt) + + return LarkLowerResult() + + def lower_attr(self, tree: lark.Tree): + if tree.data not in self.attr_registry: + raise LarkLoweringError(f"Unknown statement type {tree.data}") + + reg_result = self.attr_registry[tree.data] + if isinstance(reg_result, types.PyClass): + return LarkLowerResult(reg_result) + else: + return reg_result.lower(self, self.state, tree) + + def run(self, body: str, entry: type[NodeType]) -> NodeType: + raise NotImplementedError("TODO: implement run method") From 5b37f45bef1d66aef52cc07d93ffc33f62ce61a9 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Fri, 14 Mar 2025 15:23:22 -0400 Subject: [PATCH 11/14] WIP:refactor visitor --- src/kirin/parse/grammar.py | 66 ++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/src/kirin/parse/grammar.py b/src/kirin/parse/grammar.py index 29d0ac470..9092065a6 100644 --- a/src/kirin/parse/grammar.py +++ b/src/kirin/parse/grammar.py @@ -1,5 +1,5 @@ import textwrap -from typing import Generic, TypeVar, ClassVar +from typing import Any, Generic, TypeVar, ClassVar from dataclasses import field, dataclass import lark @@ -9,6 +9,7 @@ from kirin.ir.traits import LarkLoweringTrait from kirin.exceptions import LarkLoweringError from kirin.lowering.state import LoweringState +from kirin.lowering.result import Result SSA_IDENTIFIER: str = "ssa_identifier" BLOCK_IDENTIFIER: str = "block_identifier" @@ -23,8 +24,8 @@ @dataclass -class LarkLowerResult(Generic[NodeType]): 
- result: NodeType = None +class LarkLowerResult: + result: Any = None def expect_none(self): if self.result is not None: @@ -42,6 +43,12 @@ def expect_attr(self) -> ir.Attribute: return self.result + def expect_ssa(self) -> ir.SSAValue: + if not isinstance(self.result, ir.SSAValue): + raise LarkLoweringError(f"Expected SSA, got {self.result}") + + return self.result + @dataclass class LarkTraitWrapper(Generic[NodeType]): @@ -83,13 +90,17 @@ class Grammar: region: "{{" newline (newline block)* "}}" newline* block: block_identifier block_args newline (stmt newline)* + stmt_ssa_args: "(" ssa_assign ("," ssa_assign)* ")" | "(" ")" + stmt_attr_args: "{" attr_assign (",", attr_assign)* "}" stmt = {stmt_rule} attr = {attr_rule} block_identifier: "^" INT + ssa_assign: IDENTIFIER "=" ssa_identifier + attr_assign: IDENTIFIER "=" attr block_args: '(' ssa_identifier (',' ssa_identifier)* ')' - ssa_identifier: '%' (IDENTIFIER | INT) | '%' (IDENTIFIER | INT) ":" type + ssa_identifier: '%' (IDENTIFIER | INT) newline: NEWLINE | "//" NEWLINE | "//" /.+/ NEWLINE """ ) @@ -167,51 +178,43 @@ def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): self.lark_parser = lark.Lark(grammer.emit(), start=start) - def lower(self, tree: lark.Tree): + def visit(self, tree: lark.Tree) -> Result: node_type = tree.data + visitor = getattr(self, f"visit_{node_type}", self.default_visit) + return visitor(tree) - if node_type == "newline": - return LarkLowerResult() - elif node_type == "region": - return self.lower_region(tree) - elif node_type == "block": - return self.lower_block(tree) - elif node_type == "stmt": - return self.lower_stmt(tree) - elif node_type == "attr": - return self.lower_attr(tree) - else: - raise LarkLoweringError(f"Unknown node type {node_type}") + def default_visit(self, tree: lark.Tree): + raise LarkLoweringError(f"Unknown node type {tree.data}") - def lower_region(self, tree: lark.Tree): + def visit_region(self, tree: lark.Tree): for child in tree.children: 
- self.lower(child) - return LarkLowerResult() + self.visit(child) + return Result() - def lower_block(self, tree: lark.Tree): + def visit_block(self, tree: lark.Tree): block = self.state.current_frame.curr_block block_args = tree.children[1] assert block_args.data == "block_args" for arg in block_args.children: - block.args.append(self.lower(arg).expect_one()) + block.args.append(self.visit(arg).expect_one()) for stmt in tree.children[2:]: - self.lower(stmt) + self.visit(stmt) self.state.current_frame.append_block() - return LarkLowerResult() + return Result() - def lower_stmt(self, tree: lark.Tree): + def visit_stmt(self, tree: lark.Tree): if tree.data not in self.stmt_registry: raise LarkLoweringError(f"Unknown statement type {tree.data}") stmt = self.stmt_registry[tree.data].lower(self, self.state, tree).expect_stmt() self.state.current_frame.append_stmt(stmt) - return LarkLowerResult() + return Result() - def lower_attr(self, tree: lark.Tree): + def visit_attr(self, tree: lark.Tree): if tree.data not in self.attr_registry: raise LarkLoweringError(f"Unknown statement type {tree.data}") @@ -221,5 +224,14 @@ def lower_attr(self, tree: lark.Tree): else: return reg_result.lower(self, self.state, tree) + def visit_stmt_ssa_args(self, tree: lark.Tree): + return Result([self.visit(child).expect_one() for child in tree.children]) + + def visit_ssa_assign(self, tree: lark.Tree): + return Result([self.visit(tree.children[1]).expect_one()]) + + def visit_ssa_identifier(self, tree: lark.Tree): + + def run(self, body: str, entry: type[NodeType]) -> NodeType: raise NotImplementedError("TODO: implement run method") From 0816428ffeecdae0287cbd44f3ccb3e3d54d38c5 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 17 Mar 2025 20:24:07 -0400 Subject: [PATCH 12/14] WIP: sketch func + general trait --- src/kirin/dialects/lowering/cf.py | 8 +- src/kirin/dialects/lowering/func.py | 2 +- src/kirin/dialects/scf/lowering.py | 8 +- src/kirin/ir/traits/abc.py | 8 +- 
src/kirin/ir/traits/lark/func.py | 39 +++ src/kirin/ir/traits/lark/stmt.py | 63 +++-- src/kirin/ir/traits/lowering/context.py | 2 +- src/kirin/lowering/__init__.py | 5 +- src/kirin/lowering/frame.py | 51 +++- src/kirin/lowering/state.py | 2 +- src/kirin/parse/grammar.py | 340 ++++++++++++++++-------- src/kirin/source.py | 20 ++ 12 files changed, 393 insertions(+), 155 deletions(-) create mode 100644 src/kirin/ir/traits/lark/func.py diff --git a/src/kirin/dialects/lowering/cf.py b/src/kirin/dialects/lowering/cf.py index 28377eb86..03e018de8 100644 --- a/src/kirin/dialects/lowering/cf.py +++ b/src/kirin/dialects/lowering/cf.py @@ -34,7 +34,7 @@ def new_block_arg_if_inside_loop(frame: Frame, capture: ir.SSAValue): none_stmt = frame.append_stmt(py.Constant(None)) body_frame = state.push_frame( - Frame.from_stmts( + Frame.from_ast( node.body, state, region=state.current_frame.curr_region, @@ -87,7 +87,7 @@ def lower_If(self, state: LoweringState, node: ast.If) -> Result: frame = state.current_frame before_block = frame.curr_block if_frame = state.push_frame( - Frame.from_stmts( + Frame.from_ast( node.body, state, region=frame.curr_region, @@ -102,7 +102,7 @@ def lower_If(self, state: LoweringState, node: ast.If) -> Result: state.pop_frame() else_frame = state.push_frame( - Frame.from_stmts( + Frame.from_ast( node.orelse, state, region=frame.curr_region, @@ -117,7 +117,7 @@ def lower_If(self, state: LoweringState, node: ast.If) -> Result: state.pop_frame() after_frame = state.push_frame( - Frame.from_stmts( + Frame.from_ast( frame.stream.split(), state, region=frame.curr_region, diff --git a/src/kirin/dialects/lowering/func.py b/src/kirin/dialects/lowering/func.py index 20474fb21..e2a28d577 100644 --- a/src/kirin/dialects/lowering/func.py +++ b/src/kirin/dialects/lowering/func.py @@ -60,7 +60,7 @@ def callback(frame: lowering.Frame, value: ir.SSAValue): return stmt.result func_frame = state.push_frame( - lowering.Frame.from_stmts( + lowering.Frame.from_ast( node.body, 
state, entr_block=entr_block, diff --git a/src/kirin/dialects/scf/lowering.py b/src/kirin/dialects/scf/lowering.py index 06d196900..434d50db9 100644 --- a/src/kirin/dialects/scf/lowering.py +++ b/src/kirin/dialects/scf/lowering.py @@ -14,7 +14,7 @@ class Lowering(lowering.FromPythonAST): def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Result: cond = state.visit(node.test).expect_one() frame = state.current_frame - body_frame = lowering.Frame.from_stmts(node.body, state, globals=frame.globals) + body_frame = lowering.Frame.from_ast(node.body, state, globals=frame.globals) then_cond = body_frame.curr_block.args.append_from(types.Bool, cond.name) if cond.name: body_frame.defs[cond.name] = then_cond @@ -22,9 +22,7 @@ def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Resu state.exhaust(body_frame) state.pop_frame(finalize_next=False) # NOTE: scf does not have multiple blocks - else_frame = lowering.Frame.from_stmts( - node.orelse, state, globals=frame.globals - ) + else_frame = lowering.Frame.from_ast(node.orelse, state, globals=frame.globals) else_cond = else_frame.curr_block.args.append_from(types.Bool, cond.name) if cond.name: else_frame.defs[cond.name] = else_cond @@ -91,7 +89,7 @@ def new_block_arg_if_inside_loop(frame: lowering.Frame, capture: ir.SSAValue): return frame.curr_block.args.append_from(capture.type, capture.name) body_frame = state.push_frame( - lowering.Frame.from_stmts( + lowering.Frame.from_ast( node.body, state, globals=state.current_frame.globals, diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index a21df2050..fe68bdd7c 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -9,7 +9,7 @@ from kirin import lowering from kirin.ir import Block, Region, Statement from kirin.graph import Graph - from kirin.parse.grammar import Grammar, LarkLowerResult, DialectGroupParser + from kirin.parse.grammar import Grammar, LarkLowerResult IRNodeType = 
TypeVar("IRNodeType") @@ -59,9 +59,5 @@ def lark_rule(self, grammar: "Grammar", node: IRNodeType) -> str: ... @abstractmethod def lower( - self, - parser: "DialectGroupParser", - state: "lowering.LoweringState", - node: type[IRNodeType], - tree: lark.Tree, + self, state: "lowering.LoweringState", stmt: type[IRNodeType], tree: "lark.Tree" ) -> "LarkLowerResult[IRNodeType]": ... diff --git a/src/kirin/ir/traits/lark/func.py b/src/kirin/ir/traits/lark/func.py new file mode 100644 index 000000000..b1a22c94b --- /dev/null +++ b/src/kirin/ir/traits/lark/func.py @@ -0,0 +1,39 @@ +from typing import TYPE_CHECKING + +import lark + +from kirin.lowering import Frame +from kirin.parse.grammar import LarkLoweringState + +from ..abc import LarkLoweringTrait + +if TYPE_CHECKING: + from kirin.dialects import func + + +class FunctionLowerTrait(LarkLoweringTrait["func.Function"]): + + def lark_rule(self, _, __): + return '"func.func" IDENTIFIER signature region' + + def lower( + self, + state: LarkLoweringState, + func_type: type["func.Function"], + tree: lark.Tree, + ): + from kirin.dialects.func import Signature + + _, sym_name_tree, signature_tree, region_tree = tree.children + + sym_name = state.visit(sym_name_tree).expect(str) + signature = state.visit(signature_tree).expect(Signature) + state.push_frame(Frame.from_lark(state)) + + state.visit(region_tree) + + return func_type( + sym_name=sym_name, + signature=signature, + body=state.pop_frame(finalize_next=False), + ) diff --git a/src/kirin/ir/traits/lark/stmt.py b/src/kirin/ir/traits/lark/stmt.py index 3967991d0..320dfd52f 100644 --- a/src/kirin/ir/traits/lark/stmt.py +++ b/src/kirin/ir/traits/lark/stmt.py @@ -1,8 +1,9 @@ +from lark import Tree, Token + from kirin.decl import fields from kirin.exceptions import LarkLoweringError from kirin.ir.nodes.stmt import Statement -from kirin.parse.grammar import Grammar, LarkParser -from kirin.lowering.state import LoweringState +from kirin.parse.grammar import Grammar, 
LarkLowerResult, LarkLoweringState from kirin.lowering.result import Result as Result from ..abc import LarkLoweringTrait @@ -30,31 +31,41 @@ def lark_rule( f"Statement {stmt_type} has blocks, which are not supported by FromLark trait. create a custom trait for this statement" ) - # TODO: replace global rules like: ssa_identifier, attr, etc with module constants: kirin.parse.grammar.SSA_IDENTIFIER, kirin.parse.grammar.ATTR, etc - num_results = len(stmt_fields.results) - - stmt_body = f'"{stmt_type.dialect.name}.{stmt_type.name}" ' - return_match = ", ".join("ssa_identifier" for _ in range(num_results)) - type_match = ", ".join(' "!" attr' for _ in range(num_results)) - stmt_args_rule = ", ".join( - f'"{arg.name}" "=" ssa_identifier' for arg in stmt_fields.args - ) - attr_args_rule = ", ".join( - f'"{name}" "=" {grammar.attr_rules[type(attr.type)]}' - for name, attr in stmt_fields.attributes.items() - ) - - stmt_rule = f'{stmt_body} "(" {stmt_args_rule} ")"' + results = 'stmt_return_args "=" ' if len(stmt_fields.results) > 0 else "" + attrs = "stmt_attr_args" if len(stmt_fields.attrs) > 0 else "" - if len(attr_args_rule) > 0: - stmt_rule = f'{stmt_rule} "{{" {attr_args_rule} "}}"' + return f'{results} "{stmt_type.dialect.name}.{stmt_type.name}" stmt_ssa_args {attrs}' - if len(return_match) > 0: - stmt_rule = f'"{return_match} "=" {stmt_rule} ":" {type_match}' + def lower( + self, state: LarkLoweringState, stmt_type: type[Statement], tree: Tree + ) -> LarkLowerResult: + results = [] + attrs = {} + match tree.children: + case [ + Tree() as results_tree, + Token(), + Token(), + Tree() as ssa_args_tree, + Tree() as attrs_tree, + ]: + results = state.visit(results_tree).expect(list) + ssa_args = state.visit(ssa_args_tree).expect(dict) + attrs = state.visit(attrs_tree).expect(dict) + case [Tree() as results_tree, Token(), Token(), Tree() as ssa_args_tree]: + results = state.visit(results_tree).expect(list) + ssa_args = state.visit(ssa_args_tree).expect(dict) + case 
[Token(), Tree() as ssa_args_tree, Tree() as attrs_tree]: + ssa_args = state.visit(ssa_args_tree).expect(dict) + attrs = state.visit(attrs_tree).expect(dict) + case [Token(), Tree() as ssa_args_tree]: + ssa_args = state.visit(ssa_args_tree).expect(dict) + case _: + raise ValueError(f"Unexpected tree shape: {tree}") - return stmt_rule + stmt = state.append_stmt(stmt_type(**ssa_args, **attrs)) + state.current_frame.defs.update( + {result: ssa for result, ssa in zip(results, stmt.results)} + ) - def lower( - self, parser: LarkParser, state: LoweringState, stmt: Statement - ) -> Result: - pass + return LarkLowerResult(stmt) diff --git a/src/kirin/ir/traits/lowering/context.py b/src/kirin/ir/traits/lowering/context.py index 1c849df9b..4e79655b1 100644 --- a/src/kirin/ir/traits/lowering/context.py +++ b/src/kirin/ir/traits/lowering/context.py @@ -71,7 +71,7 @@ def lower( f"Expected context expression to be a call for with {stmt.name}" ) - body_frame = lowering.Frame.from_stmts(body, state, parent=state.current_frame) + body_frame = lowering.Frame.from_ast(body, state, parent=state.current_frame) state.push_frame(body_frame) state.exhaust() region_name, region_info = next(iter(fs.regions.items())) diff --git a/src/kirin/lowering/__init__.py b/src/kirin/lowering/__init__.py index 78358b065..c8ea07c04 100644 --- a/src/kirin/lowering/__init__.py +++ b/src/kirin/lowering/__init__.py @@ -1,6 +1,9 @@ from kirin.lowering.core import Lowering as Lowering from kirin.lowering.frame import Frame as Frame -from kirin.lowering.state import LoweringState as LoweringState +from kirin.lowering.state import ( + SourceInfo as SourceInfo, + LoweringState as LoweringState, +) from kirin.lowering.result import Result as Result from kirin.lowering.stream import StmtStream as StmtStream from kirin.lowering.binding import wraps as wraps diff --git a/src/kirin/lowering/frame.py b/src/kirin/lowering/frame.py index 71cb95f1d..ed7603e5b 100644 --- a/src/kirin/lowering/frame.py +++ 
b/src/kirin/lowering/frame.py @@ -1,5 +1,5 @@ import ast -from typing import TYPE_CHECKING, Any, TypeVar, Callable, Optional, Sequence +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Optional, Sequence from dataclasses import field, dataclass from kirin.ir import Block, Region, SSAValue, Statement @@ -7,19 +7,22 @@ from kirin.lowering.stream import StmtStream if TYPE_CHECKING: + from kirin.parse.grammar import LarkLoweringState from kirin.lowering.state import LoweringState CallbackFn = Callable[["Frame", SSAValue], SSAValue] +NodeAST = TypeVar("NodeAST") + @dataclass -class Frame: +class Frame(Generic[NodeAST]): state: "LoweringState" """lowering state""" parent: Optional["Frame"] """parent frame, if any""" - stream: StmtStream[ast.stmt] + stream: StmtStream[NodeAST] """stream of statements to be lowered""" curr_region: Region @@ -42,7 +45,47 @@ class Frame: """callback function that creates a local SSAValue value when an captured value was used.""" @classmethod - def from_stmts( + def from_lark( + cls, + state: "LarkLoweringState", + parent: Optional["Frame"] = None, + region: Optional[Region] = None, + entr_block: Optional[Block] = None, + next_block: Optional[Block] = None, + ): + """Create a new frame from a lark lowering state. + + - `state`: lark lowering state. + - `region`: `Region` to append the new block to, `None` to create a new one, default `None`. + - `entr_block`: `Block` to append the new statements to, + `None` to create a new one and attached to the region, default `None`. + - `next_block`: `Block` to use if branching to a new block, if `None` to create + a new one without attaching to the region. 
(note: this should not attach to + the region at frame construction) + """ + if region is None: + entr_block = entr_block or Block() + region = Region(entr_block) + + if entr_block is None: + entr_block = region.blocks[0] + + assert ( + region.blocks[0] is entr_block + ), "entr_block must be the first block in the region" + + return cls( + state=state, + parent=parent, + stream=StmtStream([]), + curr_region=region, + entr_block=entr_block, + curr_block=entr_block, + next_block=next_block or Block(), + ) + + @classmethod + def from_ast( cls, stmts: Sequence[ast.stmt] | StmtStream[ast.stmt], state: "LoweringState", diff --git a/src/kirin/lowering/state.py b/src/kirin/lowering/state.py index eae9451f2..2b73047d3 100644 --- a/src/kirin/lowering/state.py +++ b/src/kirin/lowering/state.py @@ -62,7 +62,7 @@ def from_stmt( max_lines=max_lines, ) - frame = Frame.from_stmts([stmt], state, globals=globals) + frame = Frame.from_ast([stmt], state, globals=globals) state.push_frame(frame) return state diff --git a/src/kirin/parse/grammar.py b/src/kirin/parse/grammar.py index 9092065a6..2a312021e 100644 --- a/src/kirin/parse/grammar.py +++ b/src/kirin/parse/grammar.py @@ -1,15 +1,14 @@ import textwrap -from typing import Any, Generic, TypeVar, ClassVar +from typing import Generic, TypeVar, ClassVar, Optional from dataclasses import field, dataclass import lark -from kirin import ir, types +from kirin import ir, types, lowering from kirin.idtable import IdTable +from kirin.dialects import func from kirin.ir.traits import LarkLoweringTrait from kirin.exceptions import LarkLoweringError -from kirin.lowering.state import LoweringState -from kirin.lowering.result import Result SSA_IDENTIFIER: str = "ssa_identifier" BLOCK_IDENTIFIER: str = "block_identifier" @@ -22,58 +21,34 @@ NodeType = TypeVar("NodeType", bound=ir.Statement | ir.Attribute | None) +T = TypeVar("T") -@dataclass -class LarkLowerResult: - result: Any = None - - def expect_none(self): - if self.result is not None: - 
raise LarkLoweringError(f"Expected None, got {self.result}") - def expect_stmt(self) -> ir.Statement: - if not isinstance(self.result, ir.Statement): - raise LarkLoweringError(f"Expected statement, got {self.result}") - - return self.result - - def expect_attr(self) -> ir.Attribute: - if not isinstance(self.result, ir.Attribute): - raise LarkLoweringError(f"Expected attribute, got {self.result}") - - return self.result +@dataclass +class LarkLowerResult(Generic[T]): + result: T = field(default=None) - def expect_ssa(self) -> ir.SSAValue: - if not isinstance(self.result, ir.SSAValue): - raise LarkLoweringError(f"Expected SSA, got {self.result}") + def expect[T](self, typ: type[T]) -> T: + if not isinstance(self.result, typ): + raise ValueError(f"Expected {typ}, got {self.result}") return self.result -@dataclass -class LarkTraitWrapper(Generic[NodeType]): - node_type: type[NodeType] - trait: LarkLoweringTrait[NodeType] - - def lower( - self, parser: "DialectGroupParser", state: LoweringState, tree: lark.Tree - ): - return self.trait.lower(parser, state, self.node_type, tree) - - @dataclass class Grammar: - rule_ids: IdTable[type[ir.Statement | ir.Attribute] | types.PyClass] = field( - default_factory=IdTable, init=False - ) + rule_ids: IdTable = field(default_factory=IdTable, init=False) stmt_ids: list[str] = field(default_factory=list, init=False) attr_ids: list[str] = field(default_factory=list, init=False) rules: list[str] = field(default_factory=list, init=False) - stmt_traits: dict[str, LarkTraitWrapper[ir.Statement]] = field( + stmt_traits: dict[str, LarkLoweringTrait[ir.Statement]] = field( + default_factory=dict, init=False + ) + attr_traits: dict[str, LarkLoweringTrait[ir.Attribute] | types.PyClass] = field( default_factory=dict, init=False ) - attr_traits: dict[str, LarkTraitWrapper[ir.Attribute] | types.PyClass] = field( + type_map: dict[str, type[ir.Statement] | type[ir.Attribute]] = field( default_factory=dict, init=False ) @@ -89,19 +64,25 @@ class 
Grammar: %ignore "│" region: "{{" newline (newline block)* "}}" newline* - block: block_identifier block_args newline (stmt newline)* - stmt_ssa_args: "(" ssa_assign ("," ssa_assign)* ")" | "(" ")" - stmt_attr_args: "{" attr_assign (",", attr_assign)* "}" - - stmt = {stmt_rule} - attr = {attr_rule} + block: block_identifier "(" block_args ")" ":" newline (stmt newline)* + signature: ( "(" ")" | "(" attr ("," attr)* ")" ) "->" attr + stmt_ssa_args: "(" kwarg_ssa ("," kwarg_ssa)* ")" | "(" ")" + stmt_attr_args: "{" kwarg_attr (",", kwarg_attr)* "}" + stmt_return_args: ssa_value ("," ssa_value)* + block_args: block_argument ("," block_argument)* block_identifier: "^" INT - ssa_assign: IDENTIFIER "=" ssa_identifier - attr_assign: IDENTIFIER "=" attr - block_args: '(' ssa_identifier (',' ssa_identifier)* ')' + kwarg_ssa: IDENTIFIER "=" ssa_value + kwarg_attr: IDENTIFIER "=" attr + block_argument: ssa_identifier | annotated_ssa_identifier + annotated_ssa_identifier: ssa_identifier ":" attr + ssa_identifier: '%' (IDENTIFIER | INT) + ssa_value: '%' (IDENTIFIER | INT) newline: NEWLINE | "//" NEWLINE | "//" /.+/ NEWLINE + + stmt = {stmt_rule} + attr = {attr_rule} """ ) @@ -115,7 +96,8 @@ def add_attr(self, node_type: type[ir.Attribute]) -> str: self.attr_ids(rule_id := self.rule_ids[node_type]) self.rules.append(f"{rule_id}: {trait.lark_rule(self, node_type)}") - return rule_id, LarkTraitWrapper(node_type, trait) + self.type_map[rule_id] = node_type + return rule_id, trait def add_stmt(self, node_type: type[ir.Statement]) -> str: trait: LarkLoweringTrait[ir.Statement] = node_type.get_trait(LarkLoweringTrait) @@ -127,7 +109,8 @@ def add_stmt(self, node_type: type[ir.Statement]) -> str: self.stmt_ids(rule_id := self.rule_ids[node_type]) self.rules.append(f"{rule_id}: {trait.lark_rule(self, node_type)}") - return rule_id, LarkTraitWrapper(node_type, trait) + self.type_map[rule_id] = node_type + return rule_id, node_type def add_pyclass(self, node: types.PyClass) -> str: rule = 
f'"{node.prefix}.{node.display_name}"' @@ -143,95 +126,240 @@ def emit(self) -> str: ) -@dataclass(init=False) -class DialectGroupParser: +@dataclass +class LarkLoweringState: dialects: ir.DialectGroup - lark_parser: lark.Lark - stmt_registry: dict[str, LarkTraitWrapper[ir.Statement]] - attr_registry: dict[str, LarkTraitWrapper[ir.Attribute] | types.PyClass] - state: LoweringState | None = None + source_info: lowering.SourceInfo + registry: dict[ + str, + LarkLoweringTrait[ir.Statement] + | LarkLoweringTrait[ir.Attribute] + | types.PyClass, + ] + type_map: dict[str, type[ir.Statement] | type[ir.Attribute]] + + _current_frame: Optional[lowering.Frame[lark.Tree]] = field( + default=None, init=False + ) - def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): - self.dialects = dialects + @classmethod + def from_stmt( + cls, + stmt: lark.Tree, + dialect_group_parser: "DialectGroupParser", + ): + return cls( + dialect_group_parser.dialects, + lowering.SourceInfo.from_lark_tree(stmt), + registry=dialect_group_parser.registry, + type_map=dialect_group_parser.type_map, + ) - start = None - grammer = Grammar() + @property + def current_frame(self) -> lowering.Frame[lark.Tree]: + if self._current_frame is None: + raise ValueError("No frame") + return self._current_frame - for dialect in dialects.data: - for attr in dialect.attrs: - rule_id, trait = grammer.add_attr(attr) - self.attr_traits[rule_id] = trait + @property + def code(self): + stmt = self.current_frame.curr_region.blocks[0].first_stmt + if stmt: + return stmt + raise ValueError("No code generated") - for type_binding in dialect.python_types.keys(): - rule_id = grammer.add_pyclass(type_binding) - self.attr_traits[rule_id] = type_binding + StmtType = TypeVar("StmtType", bound=ir.Statement) - for stmt in dialect.stmts: - rule_id, trait = grammer.add_attr(attr) - self.stmt_registry[rule_id] = trait + def append_stmt(self, stmt: StmtType) -> StmtType: + """Shorthand for appending a statement to the 
current block of current frame.""" + return self.current_frame.append_stmt(stmt) - if stmt is start_node: - start = rule_id + def push_frame(self, frame: lowering.Frame): + frame.parent = self._current_frame + self._current_frame = frame + return frame - if start is None: - raise LarkLoweringError(f"Start node {start_node} is not in the dialects") + def pop_frame(self, finalize_next: bool = True): + """Pop the current frame and return it. - self.lark_parser = lark.Lark(grammer.emit(), start=start) + Args: + finalize_next(bool): If True, append the next block of the current frame. - def visit(self, tree: lark.Tree) -> Result: - node_type = tree.data - visitor = getattr(self, f"visit_{node_type}", self.default_visit) - return visitor(tree) + Returns: + Frame: The popped frame. + """ + if self._current_frame is None: + raise ValueError("No frame to pop") + frame = self._current_frame + + if finalize_next and frame.next_block.parent is None: + frame.append_block(frame.next_block) + self._current_frame = frame.parent + return frame + + def update_lineno(self, node: lark.Tree): + self.source = lowering.SourceInfo.from_lark_tree(node) + + def visit(self, node: lark.Tree | lark.Token) -> LarkLowerResult: + if isinstance(node, lark.Tree): + self.source_info = lowering.SourceInfo.from_lark_tree(node) + return getattr(self, f"visit_{node.data}", self.default)(node) + elif isinstance(node, lark.Token): + self.source_info = lowering.SourceInfo.from_lark_token(node) + return LarkLowerResult(node.value) + else: + raise ValueError(f"Unknown node type {node}") - def default_visit(self, tree: lark.Tree): + def default(self, tree: lark.Tree): raise LarkLoweringError(f"Unknown node type {tree.data}") def visit_region(self, tree: lark.Tree): for child in tree.children: self.visit(child) - return Result() + + return LarkLowerResult(self.current_frame.curr_region) def visit_block(self, tree: lark.Tree): - block = self.state.current_frame.curr_block + self.current_frame.append_block() - 
block_args = tree.children[1] - assert block_args.data == "block_args" - for arg in block_args.children: - block.args.append(self.visit(arg).expect_one()) + for child in tree.children: + self.visit(child) - for stmt in tree.children[2:]: - self.visit(stmt) + return LarkLowerResult(self.current_frame.curr_block) - self.state.current_frame.append_block() - return Result() + def visit_signature(self, tree: lark.Tree) -> func.Signature: + *inputs, ret = [ + self.visit(child).expect(types.TypeAttribute) + for child in tree.children + if isinstance(child, lark.Tree) + ] + return LarkLowerResult(func.Signature(inputs, ret=ret)) def visit_stmt(self, tree: lark.Tree): - if tree.data not in self.stmt_registry: + if tree.data not in self.registry: raise LarkLoweringError(f"Unknown statement type {tree.data}") - stmt = self.stmt_registry[tree.data].lower(self, self.state, tree).expect_stmt() - self.state.current_frame.append_stmt(stmt) - - return Result() + return self.registry[tree.data].lower(self, self.type_map[tree.data], tree) def visit_attr(self, tree: lark.Tree): - if tree.data not in self.attr_registry: + if tree.data not in self.registry: raise LarkLoweringError(f"Unknown statement type {tree.data}") - reg_result = self.attr_registry[tree.data] + reg_result = self.registry[tree.data] if isinstance(reg_result, types.PyClass): return LarkLowerResult(reg_result) else: - return reg_result.lower(self, self.state, tree) + return reg_result.lower(self, self.type_map[tree.data], tree) + + def visit_ssa_stmt_args(self, tree: lark.Tree): + return LarkLowerResult( + dict( + self.visit(child).expect(tuple) + for child in tree.children + if isinstance(child, lark.Tree) + ) + ) - def visit_stmt_ssa_args(self, tree: lark.Tree): - return Result([self.visit(child).expect_one() for child in tree.children]) + def visit_stmt_attr_args(self, tree: lark.Tree): + return LarkLowerResult( + dict( + self.visit(child).expect(tuple) + for child in tree.children + if isinstance(child, lark.Tree) 
+ ) + ) - def visit_ssa_assign(self, tree: lark.Tree): - return Result([self.visit(tree.children[1]).expect_one()]) + def visit_kwarg_ssa(self, tree: lark.Tree): + name = self.visit(tree.children[0]).expect(str) + value = self.visit(tree.children[2]).expect(ir.SSAValue) + return LarkLowerResult((name, value)) + + def visit_kwarg_attr(self, tree: lark.Tree): + name = self.visit(tree.children[0]).expect(str) + value = self.visit(tree.children[2]).expect(ir.Attribute) + return LarkLowerResult((name, value)) def visit_ssa_identifier(self, tree: lark.Tree): + return LarkLowerResult( + "".join(str(self.visit(child).result) for child in tree.children) + ) + + def visit_block_identifier(self, tree: lark.Tree): + return LarkLowerResult(self.visit(tree.children[1]).expect(int)) + + def visit_block_argument(self, tree: lark.Tree): + results = list(map(self.visit, tree.children)) + + if len(results) == 2: + ident = results[0].expect(str) + attr = types.Any + elif len(results) == 3: + ident = results[0].expect(str) + attr = results[2].expect(ir.Attribute) + else: + raise ValueError(f"Expected 2 or 3 results, got {len(results)}") + assert ident.startswith("%") - def run(self, body: str, entry: type[NodeType]) -> NodeType: - raise NotImplementedError("TODO: implement run method") + self.current_frame.defs[ident] = ( + block_arg := self.current_frame.curr_block.args.append_from( + attr, (None if ident[1:].isnumeric() else ident[1:]) + ) + ) + + return LarkLowerResult(block_arg) + + def visit_stmt_return_args(self, tree: lark.Tree): + return LarkLowerResult( + [ + self.visit(child).expect(str) + for child in tree.children + if isinstance(child, lark.Tree) + ] + ) + + def visit_ssa_value(self, tree: lark.Tree): + ident = self.visit_ssa_identifier(tree).expect(str) + return LarkLowerResult(self.current_frame.get_scope(ident)) + + +@dataclass(init=False) +class DialectGroupParser: + dialects: ir.DialectGroup + lark_parser: lark.Lark + registry: dict[ + str, + 
LarkLoweringTrait[ir.Statement] + | LarkLoweringTrait[ir.Attribute] + | types.PyClass, + ] + type_map: dict[str, type[ir.Statement] | type[ir.Attribute]] + + def __init__(self, dialects: ir.DialectGroup, start_node: ir.Statement): + self.dialects = dialects + + start = None + grammer = Grammar() + + for dialect in dialects.data: + for attr in dialect.attrs: + rule_id, trait = grammer.add_attr(attr) + self.registry[rule_id] = trait + self.type_map[rule_id] = attr + + for type_binding in dialect.python_types.keys(): + rule_id = grammer.add_pyclass(type_binding) + self.registry[rule_id] = type_binding + + for stmt in dialect.stmts: + rule_id, trait = grammer.add_attr(attr) + self.registry[rule_id] = trait + self.type_map[rule_id] = stmt + + if stmt is start_node: + start = rule_id + + if start is None: + raise LarkLoweringError(f"Start node {start_node} is not in the dialects") + + self.lark_parser = lark.Lark(grammer.emit(), start=start) diff --git a/src/kirin/source.py b/src/kirin/source.py index de96be847..8d01161eb 100644 --- a/src/kirin/source.py +++ b/src/kirin/source.py @@ -1,6 +1,8 @@ import ast from dataclasses import dataclass +import lark + @dataclass class SourceInfo: @@ -19,3 +21,21 @@ def from_ast(cls, node: ast.AST, lineno_offset: int = 0, col_offset: int = 0): end_lineno + lineno_offset if end_lineno is not None else None, end_col_offset + col_offset if end_col_offset is not None else None, ) + + @classmethod + def from_lark_tree(cls, node: lark.Tree | lark.Token): + return cls( + node.meta.line, + node.meta.column, + node.meta.end_line, + node.meta.end_column, + ) + + @classmethod + def from_lark_token(cls, token: lark.Token): + return cls( + token.line, + token.column, + token.end_line, + token.end_column, + ) From 15960db29ddec6e58d3a8f137d3b49ef9734e3d5 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 18 Mar 2025 09:02:17 -0400 Subject: [PATCH 13/14] removing unneeded change --- src/kirin/dialects/lowering/cf.py | 8 ++++---- 
src/kirin/dialects/lowering/func.py | 2 +- src/kirin/dialects/scf/lowering.py | 8 +++++--- src/kirin/ir/traits/lowering/context.py | 2 +- src/kirin/lowering/frame.py | 2 +- src/kirin/lowering/state.py | 2 +- 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/kirin/dialects/lowering/cf.py b/src/kirin/dialects/lowering/cf.py index 03e018de8..28377eb86 100644 --- a/src/kirin/dialects/lowering/cf.py +++ b/src/kirin/dialects/lowering/cf.py @@ -34,7 +34,7 @@ def new_block_arg_if_inside_loop(frame: Frame, capture: ir.SSAValue): none_stmt = frame.append_stmt(py.Constant(None)) body_frame = state.push_frame( - Frame.from_ast( + Frame.from_stmts( node.body, state, region=state.current_frame.curr_region, @@ -87,7 +87,7 @@ def lower_If(self, state: LoweringState, node: ast.If) -> Result: frame = state.current_frame before_block = frame.curr_block if_frame = state.push_frame( - Frame.from_ast( + Frame.from_stmts( node.body, state, region=frame.curr_region, @@ -102,7 +102,7 @@ def lower_If(self, state: LoweringState, node: ast.If) -> Result: state.pop_frame() else_frame = state.push_frame( - Frame.from_ast( + Frame.from_stmts( node.orelse, state, region=frame.curr_region, @@ -117,7 +117,7 @@ def lower_If(self, state: LoweringState, node: ast.If) -> Result: state.pop_frame() after_frame = state.push_frame( - Frame.from_ast( + Frame.from_stmts( frame.stream.split(), state, region=frame.curr_region, diff --git a/src/kirin/dialects/lowering/func.py b/src/kirin/dialects/lowering/func.py index e2a28d577..20474fb21 100644 --- a/src/kirin/dialects/lowering/func.py +++ b/src/kirin/dialects/lowering/func.py @@ -60,7 +60,7 @@ def callback(frame: lowering.Frame, value: ir.SSAValue): return stmt.result func_frame = state.push_frame( - lowering.Frame.from_ast( + lowering.Frame.from_stmts( node.body, state, entr_block=entr_block, diff --git a/src/kirin/dialects/scf/lowering.py b/src/kirin/dialects/scf/lowering.py index 434d50db9..06d196900 100644 --- 
a/src/kirin/dialects/scf/lowering.py +++ b/src/kirin/dialects/scf/lowering.py @@ -14,7 +14,7 @@ class Lowering(lowering.FromPythonAST): def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Result: cond = state.visit(node.test).expect_one() frame = state.current_frame - body_frame = lowering.Frame.from_ast(node.body, state, globals=frame.globals) + body_frame = lowering.Frame.from_stmts(node.body, state, globals=frame.globals) then_cond = body_frame.curr_block.args.append_from(types.Bool, cond.name) if cond.name: body_frame.defs[cond.name] = then_cond @@ -22,7 +22,9 @@ def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Resu state.exhaust(body_frame) state.pop_frame(finalize_next=False) # NOTE: scf does not have multiple blocks - else_frame = lowering.Frame.from_ast(node.orelse, state, globals=frame.globals) + else_frame = lowering.Frame.from_stmts( + node.orelse, state, globals=frame.globals + ) else_cond = else_frame.curr_block.args.append_from(types.Bool, cond.name) if cond.name: else_frame.defs[cond.name] = else_cond @@ -89,7 +91,7 @@ def new_block_arg_if_inside_loop(frame: lowering.Frame, capture: ir.SSAValue): return frame.curr_block.args.append_from(capture.type, capture.name) body_frame = state.push_frame( - lowering.Frame.from_ast( + lowering.Frame.from_stmts( node.body, state, globals=state.current_frame.globals, diff --git a/src/kirin/ir/traits/lowering/context.py b/src/kirin/ir/traits/lowering/context.py index 4e79655b1..1c849df9b 100644 --- a/src/kirin/ir/traits/lowering/context.py +++ b/src/kirin/ir/traits/lowering/context.py @@ -71,7 +71,7 @@ def lower( f"Expected context expression to be a call for with {stmt.name}" ) - body_frame = lowering.Frame.from_ast(body, state, parent=state.current_frame) + body_frame = lowering.Frame.from_stmts(body, state, parent=state.current_frame) state.push_frame(body_frame) state.exhaust() region_name, region_info = next(iter(fs.regions.items())) diff --git 
a/src/kirin/lowering/frame.py b/src/kirin/lowering/frame.py index ed7603e5b..bd36f1b8e 100644 --- a/src/kirin/lowering/frame.py +++ b/src/kirin/lowering/frame.py @@ -85,7 +85,7 @@ def from_lark( ) @classmethod - def from_ast( + def from_stmts( cls, stmts: Sequence[ast.stmt] | StmtStream[ast.stmt], state: "LoweringState", diff --git a/src/kirin/lowering/state.py b/src/kirin/lowering/state.py index 2b73047d3..eae9451f2 100644 --- a/src/kirin/lowering/state.py +++ b/src/kirin/lowering/state.py @@ -62,7 +62,7 @@ def from_stmt( max_lines=max_lines, ) - frame = Frame.from_ast([stmt], state, globals=globals) + frame = Frame.from_stmts([stmt], state, globals=globals) state.push_frame(frame) return state From 06c8d0261c61e7ecb8511c199e75192657e08a9a Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Tue, 18 Mar 2025 09:03:07 -0400 Subject: [PATCH 14/14] Removing redundant declaration --- src/kirin/ir/traits/abc.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/kirin/ir/traits/abc.py b/src/kirin/ir/traits/abc.py index fe68bdd7c..9c410c97d 100644 --- a/src/kirin/ir/traits/abc.py +++ b/src/kirin/ir/traits/abc.py @@ -15,9 +15,6 @@ IRNodeType = TypeVar("IRNodeType") -IRNodeType = TypeVar("IRNodeType") - - @dataclass(frozen=True) class Trait(ABC, Generic[IRNodeType]): """Base class for all statement traits."""