Skip to content
6 changes: 5 additions & 1 deletion src/kirin/analysis/const/prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def try_eval_const_pure(
_frame.set_values(stmt.args, tuple(x.data for x in values))
method = self._interp.lookup_registry(frame, stmt)
if method is not None:
value = method(self._interp, _frame, stmt)
try:
value = method(self._interp, _frame, stmt)
except NotImplementedError:
# the concrete interpreter doesn't have the implementation so we cannot evaluate it
return tuple(Unknown() for _ in stmt.results)
else:
return tuple(Unknown() for _ in stmt.results)
match value:
Expand Down
8 changes: 6 additions & 2 deletions src/kirin/dialects/ilist/constprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def one_args(
# 1. if the function is a constant method, and the method is pure, then the map is pure
if isinstance(fn, const.Value) and isinstance(method := fn.data, ir.Method):
self.detect_purity(interp_, frame, stmt, method.code, (fn, const.Unknown()))
if isinstance(collection, const.Value):
if isinstance(collection, const.Value) and stmt in frame.should_be_pure:
return interp_.try_eval_const_pure(frame, stmt, (fn, collection))
elif isinstance(fn, const.PartialLambda):
self.detect_purity(interp_, frame, stmt, fn.code, (fn, const.Unknown()))
Expand All @@ -57,7 +57,11 @@ def two_args(self, interp_: const.Propagate, frame: const.Frame, stmt: Foldl):
method.code,
(fn, const.Unknown(), const.Unknown()),
)
if isinstance(collection, const.Value) and isinstance(init, const.Value):
if (
isinstance(collection, const.Value)
and isinstance(init, const.Value)
and stmt in frame.should_be_pure
):
return interp_.try_eval_const_pure(frame, stmt, (fn, collection, init))
elif isinstance(fn, const.PartialLambda):
self.detect_purity(
Expand Down
52 changes: 52 additions & 0 deletions test/analysis/dataflow/constprop/test_missing_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from kirin import ir, types, passes, lowering
from kirin.decl import info, statement
from kirin.prelude import basic_no_opt
from kirin.analysis import const
from kirin.dialects import ilist

new_dialect = ir.Dialect("test")


@statement(dialect=new_dialect)
class DefaultInit(ir.Statement):
name = "test"

traits = frozenset({lowering.FromPythonCall(), ir.Pure()})

result: ir.ResultValue = info.result(types.Int)


dialect_group = basic_no_opt.add(new_dialect)


def test_missing_impl_try_eval_const_pure():
# this test is trying to trigger the code path in propagate.py
# where a statement has no concrete implementation but is pure
# in this case, the ilist will attempt to evaluate the closure
# which contains a call to DefaultInit, which has no implementation
# in the concrete interpreter. In this case we should still be able
# to mark the result as Unknown, rather than failing the analysis.
# In other words, if a statement has no implementation, but is pure,
# the function `try_eval_const_pure` will catch the exception and
# return Unknown for the result.
@dialect_group
def test():
n = 10

def _inner(val: int) -> int:
return DefaultInit() * val # type: ignore

return ilist.map(_inner, ilist.range(n))

passes.HintConst(dialect_group)(test)

for i in range(5):
stmt = test.callable_region.blocks[0].stmts.at(i)
assert all(
isinstance(result.hints.get("const"), const.Value)
for result in stmt.results
)

call_stmt = test.callable_region.blocks[0].stmts.at(5)
assert isinstance(call_stmt, ilist.Map)
assert isinstance(call_stmt.result.hints.get("const"), const.Unknown)
27 changes: 27 additions & 0 deletions test/dialects/test_ilist.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Any, Literal

from kirin import ir, types, rewrite
from kirin.decl import info, statement
from kirin.passes import aggressive
from kirin.prelude import basic_no_opt, python_basic
from kirin.analysis import const
from kirin.dialects import py, func, ilist, lowering
from kirin.lowering import FromPythonCall
from kirin.passes.typeinfer import TypeInfer


Expand Down Expand Up @@ -386,6 +388,31 @@ def main2():
assert target.data == (6, 6)


def test_ilist_constprop_non_pure():

new_dialect = ir.Dialect("test")

@statement(dialect=new_dialect)
class DefaultInit(ir.Statement):
name = "test"
traits = frozenset({FromPythonCall()})
result: ir.ResultValue = info.result(types.Float)

dialect_group = basic_no_opt.add(new_dialect)

@dialect_group
def test():

def inner(_: int):
return DefaultInit()

return ilist.map(inner, ilist.range(10))

_, res = const.Propagate(dialect_group).run(test)

assert isinstance(res, const.Unknown)


rule = rewrite.Fixpoint(rewrite.Walk(ilist.rewrite.Unroll()))
xs = ilist.IList([1, 2, 3])

Expand Down