Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bloqade/squin/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import address_impl as address_impl
from . import unitary as unitary, hermitian as hermitian, address_impl as address_impl
10 changes: 10 additions & 0 deletions src/bloqade/squin/analysis/hermitian/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Need this for impl registration to work properly!
from . import impls as impls
from .lattice import (
Hermitian as Hermitian,
NotHermitian as NotHermitian,
NotAnOperator as NotAnOperator,
HermitianLattice as HermitianLattice,
PossiblyHermitian as PossiblyHermitian,
)
from .analysis import HermitianAnalysis as HermitianAnalysis
31 changes: 31 additions & 0 deletions src/bloqade/squin/analysis/hermitian/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from kirin import ir
from kirin.analysis import Forward, ForwardFrame

from ... import op
from .lattice import Hermitian, NotHermitian, HermitianLattice


class HermitianAnalysis(Forward):
keys = ["squin.hermitian"]
lattice = HermitianLattice

def run_method(self, method: ir.Method, args: tuple[HermitianLattice, ...]):
return self.run_callable(method.code, (self.lattice.bottom(),) + args)

def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
if not isinstance(stmt, op.stmts.Operator):
return (self.lattice.bottom(),)

if stmt.has_trait(op.traits.Hermitian):
return (Hermitian(),)

if (
trait := stmt.get_trait(op.traits.MaybeHermitian)
) is not None and trait.is_hermitian(stmt):
return (Hermitian(),)

if isinstance(stmt, op.stmts.PrimitiveOp):
# NOTE: simple operator without the hermitian trait, so we know it's non-hermitian
return (NotHermitian(),)

return (self.lattice.top(),)
59 changes: 59 additions & 0 deletions src/bloqade/squin/analysis/hermitian/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from kirin import types, interp
from kirin.analysis import ForwardFrame
from kirin.dialects import scf, func

from ... import op
from .lattice import Hermitian
from .analysis import HermitianAnalysis


@op.dialect.register(key="squin.hermitian")
class HermitianMethods(interp.MethodTable):
@interp.impl(op.stmts.Control)
@interp.impl(op.stmts.Adjoint)
def simple_container(
self,
interp: HermitianAnalysis,
frame: ForwardFrame,
stmt: op.stmts.Control | op.stmts.Adjoint,
):
return (frame.get(stmt.op),)

@interp.impl(op.stmts.Scale)
def scale(
self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale
):
is_hermitian = frame.get(stmt.op)

if not is_hermitian.is_subseteq(Hermitian):
return (is_hermitian,)

# NOTE: need to check if the factor is a real number
if stmt.factor.type.is_subseteq(types.Float | types.Int | types.Bool):
return (Hermitian(),)

# NOTE: could still be a complex number type with zero imaginary part
return (interp.lattice.top(),)

@interp.impl(op.stmts.Kron)
def kron(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Kron):
is_hermitian = frame.get(stmt.lhs).join(frame.get(stmt.rhs))
return (is_hermitian,)

@interp.impl(op.stmts.Mult)
def mult(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Mult):
# NOTE: this could be smarter here and check whether lhs == adjoint(rhs)
if stmt.lhs != stmt.rhs:
return (interp.lattice.top(),)

return (frame.get(stmt.lhs),)


@scf.dialect.register(key="squin.hermitian")
class ScfHermitianMethods(scf.typeinfer.TypeInfer):
pass


@func.dialect.register(key="squin.hermitian")
class FuncHermitianMethods(func.typeinfer.TypeInfer):
pass
52 changes: 52 additions & 0 deletions src/bloqade/squin/analysis/hermitian/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import final
from dataclasses import dataclass

from kirin.lattice import (
SingletonMeta,
BoundedLattice,
IsSubsetEqMixin,
SimpleJoinMixin,
SimpleMeetMixin,
)


@dataclass
class HermitianLattice(
SimpleJoinMixin["HermitianLattice"],
SimpleMeetMixin["HermitianLattice"],
IsSubsetEqMixin["HermitianLattice"],
BoundedLattice["HermitianLattice"],
):
@classmethod
def bottom(cls) -> "HermitianLattice":
return NotAnOperator()

@classmethod
def top(cls) -> "HermitianLattice":
return PossiblyHermitian()


@final
@dataclass
class NotAnOperator(HermitianLattice, metaclass=SingletonMeta):
pass


@final
@dataclass
class NotHermitian(HermitianLattice, metaclass=SingletonMeta):
def is_subseteq(self, other: HermitianLattice) -> bool:
return isinstance(other, NotHermitian)


@final
@dataclass
class Hermitian(HermitianLattice, metaclass=SingletonMeta):
def is_subseteq(self, other: HermitianLattice) -> bool:
return isinstance(other, Hermitian)


@final
@dataclass
class PossiblyHermitian(HermitianLattice, metaclass=SingletonMeta):
pass
8 changes: 8 additions & 0 deletions src/bloqade/squin/analysis/unitary/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from . import impls as impls
from .lattice import (
Unitary as Unitary,
NotUnitary as NotUnitary,
NotAnOperator as NotAnOperator,
UnitaryLattice as UnitaryLattice,
)
from .analysis import UnitaryAnalysis as UnitaryAnalysis
35 changes: 35 additions & 0 deletions src/bloqade/squin/analysis/unitary/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from kirin import ir
from kirin.analysis import Forward, ForwardFrame

from ... import op
from .lattice import Unitary, NotUnitary, UnitaryLattice
from ..hermitian import HermitianLattice, HermitianAnalysis


class UnitaryAnalysis(Forward):
keys = ["squin.unitary"]
lattice = UnitaryLattice
hermitian_values: dict[ir.SSAValue, HermitianLattice] = dict()

def run_method(self, method: ir.Method, args: tuple[UnitaryLattice, ...]):
hermitian_frame, _ = HermitianAnalysis(method.dialects).run_analysis(method)
self.hermitian_values.update(hermitian_frame.entries)
return self.run_callable(method.code, (self.lattice.bottom(),) + args)

def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement):
if not isinstance(stmt, op.stmts.Operator):
return (self.lattice.bottom(),)

if stmt.has_trait(op.traits.Unitary):
return (Unitary(),)

if (
trait := stmt.get_trait(op.traits.MaybeUnitary)
) is not None and trait.is_unitary(stmt):
return (Unitary(),)

if isinstance(stmt, op.stmts.PrimitiveOp):
# NOTE: simple operator that doesn't have the trait or an impl so it's known not to be unitary
return (NotUnitary(),)

return (self.lattice.top(),)
69 changes: 69 additions & 0 deletions src/bloqade/squin/analysis/unitary/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from kirin import interp
from kirin.analysis import ForwardFrame, const
from kirin.dialects import scf, func

from ... import op
from .lattice import Unitary, NotUnitary
from .analysis import UnitaryAnalysis
from ..hermitian import Hermitian, NotHermitian


@op.dialect.register(key="squin.unitary")
class UnitaryMethods(interp.MethodTable):
@interp.impl(op.stmts.Control)
@interp.impl(op.stmts.Adjoint)
def simple_container(
self,
interp: UnitaryAnalysis,
frame: ForwardFrame,
stmt: op.stmts.Control | op.stmts.Adjoint,
):
return (frame.get(stmt.op),)

@interp.impl(op.stmts.Scale)
def scale(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale):
is_unitary = frame.get(stmt.op)

if not is_unitary.is_subseteq(Unitary):
return (is_unitary,)

# NOTE: need to check if the factor has absolute value squared of 1
constant_value = stmt.factor.hints.get("const")
if not isinstance(constant_value, const.Value):
return (interp.lattice.top(),)

num = constant_value.data
if not isinstance(num, (float, int, bool, complex)):
return (interp.lattice.top(),)

if abs(num) ** 2 == 1:
return (Unitary(),)

return (NotUnitary(),)

@interp.impl(op.stmts.Kron)
@interp.impl(op.stmts.Mult)
def binary_op(
self,
interp: UnitaryAnalysis,
frame: ForwardFrame,
stmt: op.stmts.Kron | op.stmts.Mult,
):
return (frame.get(stmt.lhs).join(frame.get(stmt.rhs)),)

@interp.impl(op.stmts.Rot)
def rot(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Rot):
if interp.hermitian_values.get(stmt.axis, NotHermitian()).is_equal(Hermitian()):
return (Unitary(),)
else:
return (NotUnitary(),)


@scf.dialect.register(key="squin.unitary")
class ScfUnitaryMethods(scf.typeinfer.TypeInfer):
pass


@func.dialect.register(key="squin.unitary")
class FuncUnitaryMethods(func.typeinfer.TypeInfer):
pass
54 changes: 54 additions & 0 deletions src/bloqade/squin/analysis/unitary/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import final
from dataclasses import dataclass

from kirin.lattice import (
SingletonMeta,
BoundedLattice,
IsSubsetEqMixin,
SimpleJoinMixin,
SimpleMeetMixin,
)


@dataclass
class UnitaryLattice(
SimpleJoinMixin["UnitaryLattice"],
SimpleMeetMixin["UnitaryLattice"],
IsSubsetEqMixin["UnitaryLattice"],
BoundedLattice["UnitaryLattice"],
):
@classmethod
def bottom(cls) -> "UnitaryLattice":
return NotAnOperator()

@classmethod
def top(cls) -> "UnitaryLattice":
return PossiblyUnitary()


@final
@dataclass
class NotAnOperator(UnitaryLattice, metaclass=SingletonMeta):
pass


@final
@dataclass
class NotUnitary(UnitaryLattice, metaclass=SingletonMeta):

def is_subseteq(self, other: UnitaryLattice) -> bool:
return isinstance(other, NotUnitary)


@final
@dataclass
class Unitary(UnitaryLattice, metaclass=SingletonMeta):

def is_subseteq(self, other: UnitaryLattice) -> bool:
return isinstance(other, Unitary)


@final
@dataclass
class PossiblyUnitary(UnitaryLattice, metaclass=SingletonMeta):
pass
3 changes: 3 additions & 0 deletions src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from kirin.dialects import ilist

from . import op, wire, noise, qubit
from .rewrite import CanonicalizeUnitaryAndHermitian
from .op.rewrite import PyMultToSquinMult
from .rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule

Expand All @@ -15,6 +16,7 @@ def kernel(self):
ilist_desugar_pass = ilist.IListDesugar(self)
desugar_pass = Walk(Chain(MeasureDesugarRule(), ApplyDesugarRule()))
py_mult_to_mult_pass = PyMultToSquinMult(self)
canonicalize_hermitian_and_unitary = CanonicalizeUnitaryAndHermitian(self)

def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
method.verify()
Expand All @@ -31,6 +33,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):

if typeinfer:
typeinfer_pass(method) # fix types after desugaring
canonicalize_hermitian_and_unitary(method)
method.verify_type()

return run_pass
Expand Down
Loading