Skip to content

Lark text parser #319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions src/kirin/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ def __init__(self, node: "IRNode", *messages: str) -> None:

class DuplicatedDefinitionError(Exception):
    """Error type for duplicated definitions.

    NOTE(review): the raise sites are outside this view; the meaning is taken
    from the class name only -- confirm against callers.
    """


class LarkLoweringError(Exception):
    """Error raised by the Lark-based lowering traits.

    In this change it is raised by `FromLark.lark_rule` when a statement
    declares regions or blocks, which that generic trait does not support.
    """
1 change: 1 addition & 0 deletions src/kirin/ir/traits/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .abc import (
Trait as Trait,
RegionTrait as RegionTrait,
LarkLoweringTrait as LarkLoweringTrait,
PythonLoweringTrait as PythonLoweringTrait,
)
from .basic import (
Expand Down
15 changes: 15 additions & 0 deletions src/kirin/ir/traits/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from typing import TYPE_CHECKING, Generic, TypeVar
from dataclasses import dataclass

import lark

if TYPE_CHECKING:
from kirin import lowering
from kirin.ir import Block, Region, Statement
from kirin.graph import Graph
from kirin.parse.grammar import Grammar, LarkLowerResult


IRNodeType = TypeVar("IRNodeType")
Expand Down Expand Up @@ -43,3 +46,15 @@ class PythonLoweringTrait(Trait[StmtType], Generic[StmtType, ASTNode]):
def lower(
self, stmt: type[StmtType], state: "lowering.LoweringState", node: ASTNode
) -> "lowering.Result": ...


@dataclass(frozen=True)
class LarkLoweringTrait(Trait[IRNodeType]):
    """Abstract trait tying an IR node type to the Lark text parser.

    An implementation supplies two things: the text of the Lark grammar rule
    that recognises the node (`lark_rule`) and the routine that turns a parsed
    `lark.Tree` for that rule back into IR (`lower`).
    """

    @abstractmethod
    def lark_rule(self, grammar: "Grammar", node: type[IRNodeType]) -> str:
        """Return the body of the Lark rule for `node`, as a string.

        - `grammar`: the grammar the rule is being generated for.
        - `node`: the IR node type the rule describes (the implementations in
          this change receive the statement *type*, not an instance).
        """
        ...

    @abstractmethod
    def lower(
        self, state: "lowering.LoweringState", stmt: type[IRNodeType], tree: "lark.Tree"
    ) -> "LarkLowerResult[IRNodeType]":
        """Lower a parsed `tree` into IR for the node type `stmt`.

        - `state`: the lowering state used while building IR.
        - `stmt`: the IR node type to construct.
        - `tree`: the `lark.Tree` matched by the rule from `lark_rule`.

        Returns a `LarkLowerResult` for the lowered node.
        """
        ...
39 changes: 39 additions & 0 deletions src/kirin/ir/traits/lark/func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import TYPE_CHECKING

import lark

from kirin.lowering import Frame
from kirin.parse.grammar import LarkLoweringState

from ..abc import LarkLoweringTrait

if TYPE_CHECKING:
from kirin.dialects import func


class FunctionLowerTrait(LarkLoweringTrait["func.Function"]):
    """Lark lowering for `func.Function`.

    The textual form is `func.func <IDENTIFIER> <signature> <region>`.
    """

    def lark_rule(self, _, __):
        """Return the rule body for a function; both arguments are ignored."""
        rule_body = '"func.func" IDENTIFIER signature region'
        return rule_body

    def lower(
        self,
        state: LarkLoweringState,
        func_type: type["func.Function"],
        tree: lark.Tree,
    ):
        """Build a `func_type` instance from a parsed function `tree`.

        The tree must have exactly four children. The first is discarded; the
        remaining three are visited as the symbol name (expected `str`), the
        signature (expected `Signature`) and the body region. The region is
        visited inside a fresh frame created with `Frame.from_lark`.

        NOTE(review): the abstract `lower` is annotated as returning a
        `LarkLowerResult`, but this returns the constructed statement
        directly -- confirm which one callers expect.
        """
        # Imported here rather than at module level (the module only imports
        # `kirin.dialects.func` under TYPE_CHECKING).
        from kirin.dialects.func import Signature

        # First child is presumably the literal "func.func" token -- this
        # relies on the parser keeping anonymous tokens; confirm.
        _keyword, name_node, signature_node, body_node = tree.children

        name = state.visit(name_node).expect(str)
        sig = state.visit(signature_node).expect(Signature)

        # Lower the region into its own frame, then pop that frame again.
        state.push_frame(Frame.from_lark(state))
        state.visit(body_node)
        # NOTE(review): `body` receives whatever `pop_frame` returns; confirm
        # that this is the region rather than the frame object.
        body = state.pop_frame(finalize_next=False)

        return func_type(sym_name=name, signature=sig, body=body)
71 changes: 71 additions & 0 deletions src/kirin/ir/traits/lark/stmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from lark import Tree, Token

from kirin.decl import fields
from kirin.exceptions import LarkLoweringError
from kirin.ir.nodes.stmt import Statement
from kirin.parse.grammar import Grammar, LarkLowerResult, LarkLoweringState
from kirin.lowering.result import Result as Result

from ..abc import LarkLoweringTrait


class FromLark(LarkLoweringTrait):
def lark_rule(
self,
grammar: Grammar,
stmt_type: type[Statement],
) -> str:
assert (
stmt_type.dialect is not None
), f"Statement {stmt_type} must have a dialect"

stmt_fields = fields(stmt_type)

if len(stmt_fields.regions) > 0:
raise LarkLoweringError(
f"Statement {stmt_type} has regions, which are not supported by FromLark trait. create a custom trait for this statement"
)

if len(stmt_fields.blocks) > 0:
raise LarkLoweringError(
f"Statement {stmt_type} has blocks, which are not supported by FromLark trait. create a custom trait for this statement"
)

results = 'stmt_return_args "=" ' if len(stmt_fields.results) > 0 else ""
attrs = "stmt_attr_args" if len(stmt_fields.attrs) > 0 else ""

return f'{results} "{stmt_type.dialect.name}.{stmt_type.name}" stmt_ssa_args {attrs}'

def lower(
self, state: LarkLoweringState, stmt_type: type[Statement], tree: Tree
) -> LarkLowerResult:
results = []
attrs = {}
match tree.children:
case [
Tree() as results_tree,
Token(),
Token(),
Tree() as ssa_args_tree,
Tree() as attrs_tree,
]:
results = state.visit(results_tree).expect(list)
ssa_args = state.visit(ssa_args_tree).expect(dict)
attrs = state.visit(attrs_tree).expect(dict)
case [Tree() as results_tree, Token(), Token(), Tree() as ssa_args_tree]:
results = state.visit(results_tree).expect(list)
ssa_args = state.visit(ssa_args_tree).expect(dict)
case [Token(), Tree() as ssa_args_tree, Tree() as attrs_tree]:
ssa_args = state.visit(ssa_args_tree).expect(dict)
attrs = state.visit(attrs_tree).expect(dict)
case [Token(), Tree() as ssa_args_tree]:
ssa_args = state.visit(ssa_args_tree).expect(dict)
case _:
raise ValueError(f"Unexpected tree shape: {tree}")

stmt = state.append_stmt(stmt_type(**ssa_args, **attrs))
state.current_frame.defs.update(
{result: ssa for result, ssa in zip(results, stmt.results)}
)

return LarkLowerResult(stmt)
5 changes: 4 additions & 1 deletion src/kirin/lowering/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from kirin.lowering.core import Lowering as Lowering
from kirin.lowering.frame import Frame as Frame
from kirin.lowering.state import LoweringState as LoweringState
from kirin.lowering.state import (
SourceInfo as SourceInfo,
LoweringState as LoweringState,
)
from kirin.lowering.result import Result as Result
from kirin.lowering.stream import StmtStream as StmtStream
from kirin.lowering.binding import wraps as wraps
Expand Down
49 changes: 46 additions & 3 deletions src/kirin/lowering/frame.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import ast
from typing import TYPE_CHECKING, Any, TypeVar, Callable, Optional, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Optional, Sequence
from dataclasses import field, dataclass

from kirin.ir import Block, Region, SSAValue, Statement
from kirin.exceptions import DialectLoweringError
from kirin.lowering.stream import StmtStream

if TYPE_CHECKING:
from kirin.parse.grammar import LarkLoweringState
from kirin.lowering.state import LoweringState


CallbackFn = Callable[["Frame", SSAValue], SSAValue]

NodeAST = TypeVar("NodeAST")


@dataclass
class Frame:
class Frame(Generic[NodeAST]):
state: "LoweringState"
"""lowering state"""
parent: Optional["Frame"]
"""parent frame, if any"""
stream: StmtStream[ast.stmt]
stream: StmtStream[NodeAST]
"""stream of statements to be lowered"""

curr_region: Region
Expand All @@ -41,6 +44,46 @@ class Frame:
capture_callback: Optional[CallbackFn] = None
"""callback function that creates a local SSAValue value when an captured value was used."""

@classmethod
def from_lark(
    cls,
    state: "LarkLoweringState",
    parent: Optional["Frame"] = None,
    region: Optional[Region] = None,
    entr_block: Optional[Block] = None,
    next_block: Optional[Block] = None,
):
    """Create a new frame from a lark lowering state.

    The frame starts with an empty statement stream and with `curr_block`
    set to the entry block.

    - `state`: lark lowering state.
    - `parent`: parent frame, if any, default `None`.
    - `region`: `Region` the frame works in, `None` to create a new one
        whose first block is `entr_block`, default `None`.
    - `entr_block`: `Block` to append the new statements to; it must be the
        first block of `region`. If `None`, the first block of the given
        `region` is used, or a new block is created when `region` is also
        `None`. Default `None`.
    - `next_block`: `Block` to use if branching to a new block, if `None` to create
        a new one without attaching to the region. (note: this should not attach to
        the region at frame construction)

    Raises `AssertionError` if `entr_block` is not the first block of `region`.
    """
    # Compare against None explicitly: a block passed in by the caller must
    # never be swapped for a fresh one just because it evaluates as falsy.
    if region is None:
        if entr_block is None:
            entr_block = Block()
        region = Region(entr_block)
    elif entr_block is None:
        entr_block = region.blocks[0]

    assert (
        region.blocks[0] is entr_block
    ), "entr_block must be the first block in the region"

    return cls(
        state=state,
        parent=parent,
        stream=StmtStream([]),
        curr_region=region,
        entr_block=entr_block,
        curr_block=entr_block,
        next_block=next_block if next_block is not None else Block(),
    )

@classmethod
def from_stmts(
cls,
Expand Down
Empty file added src/kirin/parse/__init__.py
Empty file.
Loading
Loading