diff --git a/src/bloqade/squin/analysis/__init__.py b/src/bloqade/squin/analysis/__init__.py index dea64d4e..9466bd3e 100644 --- a/src/bloqade/squin/analysis/__init__.py +++ b/src/bloqade/squin/analysis/__init__.py @@ -1 +1 @@ -from . import address_impl as address_impl +from . import unitary as unitary, hermitian as hermitian, address_impl as address_impl diff --git a/src/bloqade/squin/analysis/hermitian/__init__.py b/src/bloqade/squin/analysis/hermitian/__init__.py new file mode 100644 index 00000000..5ea284eb --- /dev/null +++ b/src/bloqade/squin/analysis/hermitian/__init__.py @@ -0,0 +1,10 @@ +# Need this for impl registration to work properly! +from . import impls as impls +from .lattice import ( + Hermitian as Hermitian, + NotHermitian as NotHermitian, + NotAnOperator as NotAnOperator, + HermitianLattice as HermitianLattice, + PossiblyHermitian as PossiblyHermitian, +) +from .analysis import HermitianAnalysis as HermitianAnalysis diff --git a/src/bloqade/squin/analysis/hermitian/analysis.py b/src/bloqade/squin/analysis/hermitian/analysis.py new file mode 100644 index 00000000..57fcbd0e --- /dev/null +++ b/src/bloqade/squin/analysis/hermitian/analysis.py @@ -0,0 +1,31 @@ +from kirin import ir +from kirin.analysis import Forward, ForwardFrame + +from ... import op +from .lattice import Hermitian, NotHermitian, HermitianLattice + + +class HermitianAnalysis(Forward): + keys = ["squin.hermitian"] + lattice = HermitianLattice + + def run_method(self, method: ir.Method, args: tuple[HermitianLattice, ...]): + return self.run_callable(method.code, (self.lattice.bottom(),) + args) + + def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): + if not isinstance(stmt, op.stmts.Operator): + return (self.lattice.bottom(),) + + if stmt.has_trait(op.traits.Hermitian): + return (Hermitian(),) + + if ( + trait := stmt.get_trait(op.traits.MaybeHermitian) + ) is not None and trait.is_hermitian(stmt): + return (Hermitian(),) + + if isinstance(stmt, op.stmts.PrimitiveOp): + # NOTE: simple operator without the hermitian trait, so we know it's non-hermitian + return (NotHermitian(),) + + return (self.lattice.top(),) diff --git a/src/bloqade/squin/analysis/hermitian/impls.py b/src/bloqade/squin/analysis/hermitian/impls.py new file mode 100644 index 00000000..427390cf --- /dev/null +++ b/src/bloqade/squin/analysis/hermitian/impls.py @@ -0,0 +1,59 @@ +from kirin import types, interp +from kirin.analysis import ForwardFrame +from kirin.dialects import scf, func + +from ... import op +from .lattice import Hermitian +from .analysis import HermitianAnalysis + + +@op.dialect.register(key="squin.hermitian") +class HermitianMethods(interp.MethodTable): + @interp.impl(op.stmts.Control) + @interp.impl(op.stmts.Adjoint) + def simple_container( + self, + interp: HermitianAnalysis, + frame: ForwardFrame, + stmt: op.stmts.Control | op.stmts.Adjoint, + ): + return (frame.get(stmt.op),) + + @interp.impl(op.stmts.Scale) + def scale( + self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale + ): + is_hermitian = frame.get(stmt.op) + + if not is_hermitian.is_subseteq(Hermitian): + return (is_hermitian,) + + # NOTE: need to check if the factor is a real number + if stmt.factor.type.is_subseteq(types.Float | types.Int | types.Bool): + return (Hermitian(),) + + # NOTE: could still be a complex number type with zero imaginary part + return (interp.lattice.top(),) + + @interp.impl(op.stmts.Kron) + def kron(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Kron): + is_hermitian = frame.get(stmt.lhs).join(frame.get(stmt.rhs)) + return (is_hermitian,) + + @interp.impl(op.stmts.Mult) + def mult(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Mult): + # NOTE: this could be smarter here and check whether lhs == adjoint(rhs) + if stmt.lhs != stmt.rhs: + return (interp.lattice.top(),) + + return (frame.get(stmt.lhs),) + + +@scf.dialect.register(key="squin.hermitian") +class ScfHermitianMethods(scf.typeinfer.TypeInfer): + pass + + +@func.dialect.register(key="squin.hermitian") +class FuncHermitianMethods(func.typeinfer.TypeInfer): + pass diff --git a/src/bloqade/squin/analysis/hermitian/lattice.py b/src/bloqade/squin/analysis/hermitian/lattice.py new file mode 100644 index 00000000..dd50ec7d --- /dev/null +++ b/src/bloqade/squin/analysis/hermitian/lattice.py @@ -0,0 +1,52 @@ +from typing import final +from dataclasses import dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + IsSubsetEqMixin, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class HermitianLattice( + SimpleJoinMixin["HermitianLattice"], + SimpleMeetMixin["HermitianLattice"], + IsSubsetEqMixin["HermitianLattice"], + BoundedLattice["HermitianLattice"], +): + @classmethod + def bottom(cls) -> "HermitianLattice": + return NotAnOperator() + + @classmethod + def top(cls) -> "HermitianLattice": + return PossiblyHermitian() + + +@final +@dataclass +class NotAnOperator(HermitianLattice, metaclass=SingletonMeta): + pass + + +@final +@dataclass +class NotHermitian(HermitianLattice, metaclass=SingletonMeta): + def is_subseteq(self, other: HermitianLattice) -> bool: + return isinstance(other, NotHermitian) + + +@final +@dataclass +class Hermitian(HermitianLattice, metaclass=SingletonMeta): + def is_subseteq(self, other: HermitianLattice) -> bool: + return isinstance(other, Hermitian) + + +@final +@dataclass +class PossiblyHermitian(HermitianLattice, metaclass=SingletonMeta): + pass diff --git a/src/bloqade/squin/analysis/unitary/__init__.py b/src/bloqade/squin/analysis/unitary/__init__.py new file mode 100644 index 00000000..a5fc22b9 --- /dev/null +++ b/src/bloqade/squin/analysis/unitary/__init__.py @@ -0,0 +1,8 @@ +from . import impls as impls +from .lattice import ( + Unitary as Unitary, + NotUnitary as NotUnitary, + NotAnOperator as NotAnOperator, + UnitaryLattice as UnitaryLattice, +) +from .analysis import UnitaryAnalysis as UnitaryAnalysis diff --git a/src/bloqade/squin/analysis/unitary/analysis.py b/src/bloqade/squin/analysis/unitary/analysis.py new file mode 100644 index 00000000..3013bd04 --- /dev/null +++ b/src/bloqade/squin/analysis/unitary/analysis.py @@ -0,0 +1,35 @@ +from kirin import ir +from kirin.analysis import Forward, ForwardFrame + +from ... import op +from .lattice import Unitary, NotUnitary, UnitaryLattice +from ..hermitian import HermitianLattice, HermitianAnalysis + + +class UnitaryAnalysis(Forward): + keys = ["squin.unitary"] + lattice = UnitaryLattice + hermitian_values: dict[ir.SSAValue, HermitianLattice] = dict() + + def run_method(self, method: ir.Method, args: tuple[UnitaryLattice, ...]): + hermitian_frame, _ = HermitianAnalysis(method.dialects).run_analysis(method) + self.hermitian_values.update(hermitian_frame.entries) + return self.run_callable(method.code, (self.lattice.bottom(),) + args) + + def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): + if not isinstance(stmt, op.stmts.Operator): + return (self.lattice.bottom(),) + + if stmt.has_trait(op.traits.Unitary): + return (Unitary(),) + + if ( + trait := stmt.get_trait(op.traits.MaybeUnitary) + ) is not None and trait.is_unitary(stmt): + return (Unitary(),) + + if isinstance(stmt, op.stmts.PrimitiveOp): + # NOTE: simple operator that doesn't have the trait or an impl so it's known not to be unitary + return (NotUnitary(),) + + return (self.lattice.top(),) diff --git a/src/bloqade/squin/analysis/unitary/impls.py b/src/bloqade/squin/analysis/unitary/impls.py new file mode 100644 index 00000000..d829c217 --- /dev/null +++ b/src/bloqade/squin/analysis/unitary/impls.py @@ -0,0 +1,69 @@ +from kirin import interp +from kirin.analysis import ForwardFrame, const +from kirin.dialects import scf, func + +from ... import op +from .lattice import Unitary, NotUnitary +from .analysis import UnitaryAnalysis +from ..hermitian import Hermitian, NotHermitian + + +@op.dialect.register(key="squin.unitary") +class UnitaryMethods(interp.MethodTable): + @interp.impl(op.stmts.Control) + @interp.impl(op.stmts.Adjoint) + def simple_container( + self, + interp: UnitaryAnalysis, + frame: ForwardFrame, + stmt: op.stmts.Control | op.stmts.Adjoint, + ): + return (frame.get(stmt.op),) + + @interp.impl(op.stmts.Scale) + def scale(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale): + is_unitary = frame.get(stmt.op) + + if not is_unitary.is_subseteq(Unitary): + return (is_unitary,) + + # NOTE: need to check if the factor has absolute value squared of 1 + constant_value = stmt.factor.hints.get("const") + if not isinstance(constant_value, const.Value): + return (interp.lattice.top(),) + + num = constant_value.data + if not isinstance(num, (float, int, bool, complex)): + return (interp.lattice.top(),) + + if abs(num) ** 2 == 1: + return (Unitary(),) + + return (NotUnitary(),) + + @interp.impl(op.stmts.Kron) + @interp.impl(op.stmts.Mult) + def binary_op( + self, + interp: UnitaryAnalysis, + frame: ForwardFrame, + stmt: op.stmts.Kron | op.stmts.Mult, + ): + return (frame.get(stmt.lhs).join(frame.get(stmt.rhs)),) + + @interp.impl(op.stmts.Rot) + def rot(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Rot): + if interp.hermitian_values.get(stmt.axis, NotHermitian()).is_equal(Hermitian()): + return (Unitary(),) + else: + return (NotUnitary(),) + + +@scf.dialect.register(key="squin.unitary") +class ScfUnitaryMethods(scf.typeinfer.TypeInfer): + pass + + +@func.dialect.register(key="squin.unitary") +class FuncUnitaryMethods(func.typeinfer.TypeInfer): + pass diff --git a/src/bloqade/squin/analysis/unitary/lattice.py b/src/bloqade/squin/analysis/unitary/lattice.py new file mode 100644 index 00000000..ed59b30c --- /dev/null +++ b/src/bloqade/squin/analysis/unitary/lattice.py @@ -0,0 +1,54 @@ +from typing import final +from dataclasses import dataclass + +from kirin.lattice import ( + SingletonMeta, + BoundedLattice, + IsSubsetEqMixin, + SimpleJoinMixin, + SimpleMeetMixin, +) + + +@dataclass +class UnitaryLattice( + SimpleJoinMixin["UnitaryLattice"], + SimpleMeetMixin["UnitaryLattice"], + IsSubsetEqMixin["UnitaryLattice"], + BoundedLattice["UnitaryLattice"], +): + @classmethod + def bottom(cls) -> "UnitaryLattice": + return NotAnOperator() + + @classmethod + def top(cls) -> "UnitaryLattice": + return PossiblyUnitary() + + +@final +@dataclass +class NotAnOperator(UnitaryLattice, metaclass=SingletonMeta): + pass + + +@final +@dataclass +class NotUnitary(UnitaryLattice, metaclass=SingletonMeta): + + def is_subseteq(self, other: UnitaryLattice) -> bool: + return isinstance(other, NotUnitary) + + +@final +@dataclass +class Unitary(UnitaryLattice, metaclass=SingletonMeta): + + def is_subseteq(self, other: UnitaryLattice) -> bool: + return isinstance(other, Unitary) + + +@final +@dataclass +class PossiblyUnitary(UnitaryLattice, metaclass=SingletonMeta): + pass diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index f2994038..1dab34ea 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -4,6 +4,7 @@ from kirin.dialects import ilist from . import op, wire, noise, qubit +from .rewrite import CanonicalizeUnitaryAndHermitian from .op.rewrite import PyMultToSquinMult from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule @@ -15,6 +16,7 @@ def kernel(self): ilist_desugar_pass = ilist.IListDesugar(self) desugar_pass = Walk(Chain(MeasureDesugarRule(), ApplyDesugarRule())) py_mult_to_mult_pass = PyMultToSquinMult(self) + canonicalize_hermitian_and_unitary = CanonicalizeUnitaryAndHermitian(self) def run_pass(method: ir.Method, *, fold=True, typeinfer=True): method.verify() @@ -31,6 +33,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): if typeinfer: typeinfer_pass(method) # fix types after desugaring + canonicalize_hermitian_and_unitary(method) method.verify_type() return run_pass diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index c4e512df..f59c7650 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -15,7 +15,14 @@ ControlledOpType, ) from .number import NumberType -from .traits import Unitary, HasSites, FixedSites, MaybeUnitary +from .traits import ( + Unitary, + HasSites, + Hermitian, + FixedSites, + MaybeUnitary, + MaybeHermitian, +) from ._dialect import dialect @@ -46,37 +53,52 @@ class BinaryOp(CompositeOp): @statement(dialect=dialect) class Kron(BinaryOp): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) + traits = frozenset( + {ir.Pure(), lowering.FromPythonCall(), MaybeUnitary(), MaybeHermitian()} + ) is_unitary: bool = info.attribute(default=False) + is_hermitian: bool = info.attribute(default=False) result: ir.ResultValue = info.result(KronType[LhsType, RhsType]) @statement(dialect=dialect) class Mult(BinaryOp): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) + traits = frozenset( + {ir.Pure(), lowering.FromPythonCall(), MaybeUnitary(), MaybeHermitian()} + ) is_unitary: bool = info.attribute(default=False) + is_hermitian: bool = info.attribute(default=False) result: ir.ResultValue = info.result(MultType[LhsType, RhsType]) @statement(dialect=dialect) class Adjoint(CompositeOp): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) + traits = frozenset( + {ir.Pure(), lowering.FromPythonCall(), MaybeUnitary(), MaybeHermitian()} + ) is_unitary: bool = info.attribute(default=False) + is_hermitian: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) @statement(dialect=dialect) class Scale(CompositeOp): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) + traits = frozenset( + {ir.Pure(), lowering.FromPythonCall(), MaybeUnitary(), MaybeHermitian()} + ) is_unitary: bool = info.attribute(default=False) + is_hermitian: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) factor: ir.SSAValue = info.argument(NumberType) @statement(dialect=dialect) class Control(CompositeOp): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) + traits = frozenset( + {ir.Pure(), lowering.FromPythonCall(), MaybeUnitary(), MaybeHermitian()} + ) is_unitary: bool = info.attribute(default=False) + is_hermitian: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(ControlledOpType) n_controls: int = info.attribute() result: ir.ResultValue = info.result(ControlOpType[ControlledOpType]) @@ -87,15 +109,18 @@ class Control(CompositeOp): @statement(dialect=dialect) class Rot(CompositeOp): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) axis: ir.SSAValue = info.argument(RotationAxisType) angle: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(ROpType[RotationAxisType]) + is_unitary: bool = info.attribute(default=False) @statement(dialect=dialect) class Identity(CompositeOp): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()}) + traits = frozenset( + {ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites(), Hermitian()} + ) sites: int = info.attribute() @@ -106,6 +131,19 @@ class ConstantOp(PrimitiveOp): ) +@statement +class ConstantHermitian(ConstantOp): + traits = frozenset( + { + ir.Pure(), + lowering.FromPythonCall(), + ir.ConstantLike(), + FixedSites(1), + Hermitian(), + } + ) + + @statement class ConstantUnitary(ConstantOp): traits = frozenset( @@ -189,12 +227,25 @@ class CliffordOp(ConstantUnitary): @statement class PauliOp(CliffordOp): + traits = frozenset( + { + ir.Pure(), + lowering.FromPythonCall(), + ir.ConstantLike(), + Unitary(), + FixedSites(1), + Hermitian(), + } + ) result: ir.ResultValue = info.result(type=PauliOpType) @statement(dialect=dialect) class PauliString(ConstantUnitary): - traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()}) + traits = frozenset( + {ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites(), Hermitian()} + ) + is_hermitian: bool = info.attribute(default=False) string: str = info.attribute() result: ir.ResultValue = info.result(type=PauliStringType) @@ -249,7 +300,7 @@ class T(ConstantUnitary): @statement(dialect=dialect) -class P0(ConstantOp): +class P0(ConstantHermitian): """ The $P_0$ projection operator. @@ -262,7 +313,7 @@ class P0(ConstantOp): @statement(dialect=dialect) -class P1(ConstantOp): +class P1(ConstantHermitian): """ The $P_1$ projection operator. diff --git a/src/bloqade/squin/op/traits.py b/src/bloqade/squin/op/traits.py index 506fdbf2..78c5cf61 100644 --- a/src/bloqade/squin/op/traits.py +++ b/src/bloqade/squin/op/traits.py @@ -41,3 +41,22 @@ def is_unitary(self, stmt: ir.Statement): def set_unitary(self, stmt: ir.Statement, value: bool): stmt.attributes["is_unitary"] = ir.PyAttr(value) return + + +@dataclass(frozen=True) +class Hermitian(ir.StmtTrait): + pass + + +@dataclass(frozen=True) +class MaybeHermitian(ir.StmtTrait): + + def is_hermitian(self, stmt: ir.Statement): + attr = stmt.get_attr_or_prop("is_hermitian") + if attr is None: + return False + return cast(ir.PyAttr[bool], attr).data + + def set_hermitian(self, stmt: ir.Statement, value: bool): + stmt.attributes["is_hermitian"] = ir.PyAttr(value) + return diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index f45da299..cdc341da 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,3 +1,6 @@ +from .canonicalize import ( + CanonicalizeUnitaryAndHermitian as CanonicalizeUnitaryAndHermitian, +) from .wrap_analysis import ( SitesAttribute as SitesAttribute, AddressAttribute as AddressAttribute, diff --git a/src/bloqade/squin/rewrite/canonicalize.py b/src/bloqade/squin/rewrite/canonicalize.py index df6af306..eb624f0f 100644 --- a/src/bloqade/squin/rewrite/canonicalize.py +++ b/src/bloqade/squin/rewrite/canonicalize.py @@ -1,10 +1,14 @@ from typing import cast +from dataclasses import dataclass from kirin import ir -from kirin.rewrite import abc +from kirin.passes import Pass +from kirin.rewrite import Walk, abc from kirin.dialects import cf -from .. import wire +from .. import op, wire, analysis +from ..analysis.unitary import Unitary, UnitaryLattice +from ..analysis.hermitian import Hermitian, HermitianLattice class CanonicalizeWired(abc.RewriteRule): @@ -58,3 +62,62 @@ def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: node.delete() return abc.RewriteResult(has_done_something=True) + + +@dataclass +class CanonicalizeHermitian(abc.RewriteRule): + hermitian_values: dict[ir.SSAValue, HermitianLattice] + + def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: + if not isinstance(node, op.stmts.Operator): + return abc.RewriteResult() + + maybe_hermitian = node.get_trait(op.traits.MaybeHermitian) + + if maybe_hermitian is None: + return abc.RewriteResult() + + is_hermitian = self.hermitian_values.get(node.result) + if is_hermitian is None: + return abc.RewriteResult() + + maybe_hermitian.set_hermitian(node, is_hermitian.is_equal(Hermitian())) + return abc.RewriteResult(has_done_something=True) + + +@dataclass +class CanonicalizeUnitary(abc.RewriteRule): + unitary_values: dict[ir.SSAValue, UnitaryLattice] + + def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult: + if not isinstance(node, op.stmts.Operator): + return abc.RewriteResult() + + maybe_unitary = node.get_trait(op.traits.MaybeUnitary) + + if maybe_unitary is None: + return abc.RewriteResult() + + is_unitary = self.unitary_values.get(node.result) + if is_unitary is None: + return abc.RewriteResult() + + new_unitary_status = is_unitary.is_equal(Unitary()) + maybe_unitary.set_unitary(node, new_unitary_status) + + return abc.RewriteResult(has_done_something=True) + + +class CanonicalizeUnitaryAndHermitian(Pass): + def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult: + unitary_analysis = analysis.unitary.UnitaryAnalysis(mt.dialects) + unitary_frame, _ = unitary_analysis.run_analysis(mt) + + hermitian_values = unitary_analysis.hermitian_values + unitary_values = unitary_frame.entries + + hermitian_rewrite = CanonicalizeHermitian(hermitian_values=hermitian_values) + unitary_rewrite = CanonicalizeUnitary(unitary_values=unitary_values) + + result = Walk(hermitian_rewrite).rewrite(mt.code) + return Walk(unitary_rewrite).rewrite(mt.code).join(result) diff --git a/test/cirq_utils/noise/test_noisy_ghz.py b/test/cirq_utils/noise/test_noisy_ghz.py index 95c330e3..1889672e 100644 --- a/test/cirq_utils/noise/test_noisy_ghz.py +++ b/test/cirq_utils/noise/test_noisy_ghz.py @@ -103,4 +103,4 @@ def fidelity(rho: np.ndarray, sigma: np.ndarray) -> float: for n, fid_squin in zip(range(2, max_num_qubits), fidelities_squin): # NOTE: higher fidelity requires larger nshots in order for this to converge # this gates harder for more qubits and takes a lot longer, which doesn't make sense for the test here - assert math.isclose(fid_squin, 1, abs_tol=1e-2 * n) + assert math.isclose(fid_squin, 1, abs_tol=2 * 1e-2 * n) diff --git a/test/squin/rewrite/test_canonicalize.py b/test/squin/rewrite/test_canonicalize.py index 02667182..8b5e5031 100644 --- a/test/squin/rewrite/test_canonicalize.py +++ b/test/squin/rewrite/test_canonicalize.py @@ -1,8 +1,10 @@ from kirin import ir, types, rewrite -from kirin.dialects import cf, py +from kirin.dialects import cf, py, func +from bloqade import squin from bloqade.squin import wire from bloqade.test_utils import assert_nodes +from bloqade.squin.analysis import unitary, hermitian from bloqade.squin.rewrite.canonicalize import CanonicalizeWired @@ -52,3 +54,89 @@ def test_canonicalize_wired_trivial(): rewrite.Walk(CanonicalizeWired()).rewrite(outer_region) assert_nodes(outer_region, expected_region) + + +def test_hermitian_and_unitary(): + + @squin.kernel(fold=False) + def main(): + n = 1 + x = squin.op.x() + _ = n * x + squin.op.control(x, n_controls=1) + squin.op.pauli_string(string="XYZ") + squin.op.pauli_string(string="XYX") + + squin.op.cx() + + y = squin.op.y() + + rx = squin.op.rot(axis=x, angle=0.125) + _ = rx * squin.op.adjoint(rx) + + squin.op.p0() + squin.op.rot(axis=y, angle=0) + + main.print() + + hermitian_frame, _ = hermitian.HermitianAnalysis(main.dialects).run_analysis( + main, no_raise=False + ) + + main.print(analysis=hermitian_frame.entries) + + unitary_frame, _ = unitary.UnitaryAnalysis(main.dialects).run_analysis( + main, no_raise=False + ) + main.print(analysis=unitary_frame.entries) + + def is_hermitian(stmt: squin.op.stmts.Operator | func.Invoke) -> bool: + return hermitian_frame.get(stmt.result).is_equal(hermitian.Hermitian()) + + def is_not_hermitian(stmt: squin.op.stmts.Operator) -> bool: + return hermitian_frame.get(stmt.result).is_equal(hermitian.NotHermitian()) + + def maybe_hermitian(stmt: squin.op.stmts.Operator) -> bool: + return hermitian_frame.get(stmt.result).is_equal(hermitian.PossiblyHermitian()) + + def is_unitary(stmt: squin.op.stmts.Operator | func.Invoke) -> bool: + return unitary_frame.get(stmt.result).is_equal(unitary.Unitary()) + + def is_not_unitary(stmt: squin.op.stmts.Operator) -> bool: + return unitary_frame.get(stmt.result).is_equal(unitary.NotUnitary()) + + for stmt in main.callable_region.blocks[0].stmts: + match stmt: + case squin.op.stmts.X() | squin.op.stmts.Y(): + assert is_hermitian(stmt) + assert is_unitary(stmt) + + case squin.op.stmts.Scale(): + assert is_hermitian(stmt) + assert is_unitary(stmt) + assert stmt.is_hermitian + assert stmt.is_unitary + + case squin.op.stmts.PauliString(): + assert is_unitary(stmt) + assert is_hermitian(stmt) + + case func.Invoke(): + # NOTE: only cx above + assert is_unitary(stmt) + assert is_hermitian(stmt) + + case squin.op.stmts.Rot(): + assert is_unitary(stmt) + assert maybe_hermitian(stmt) + assert stmt.is_unitary + + case squin.op.stmts.Mult(): + assert is_unitary(stmt) + assert maybe_hermitian(stmt) + assert stmt.is_unitary + assert not stmt.is_hermitian + + case squin.op.stmts.P0(): + assert is_hermitian(stmt) + assert is_not_unitary(stmt)