-
Notifications
You must be signed in to change notification settings - Fork 1
Implement hermitian and unitary analysis and canonicalization pass #475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
38dee18
Add Hermitian trait to squin ops
david-pl f8e398a
Fix typo
david-pl 5df8a57
Implement hermitian and unitary analysis and canonicalization pass
david-pl 4e2e0d8
Add tests
david-pl 1a7b104
Loosen constraint on flaky test
david-pl 4ea2acb
Merge branch 'main' into david/471-hermitian-trait
david-pl 145b4e5
Update src/bloqade/squin/analysis/hermitian_and_unitary.py
david-pl b4f036f
Merge branch 'main' into david/471-hermitian-trait
david-pl ef98ba3
Implement a proper lattice for the analyses
david-pl f072013
Update src/bloqade/squin/rewrite/canonicalize.py
david-pl b6d4f04
Fix analysis for PauliStrings
david-pl ef35d3a
Fix lattice for Hermitian
david-pl 2599ba3
Fix lattice for unitary analysis
david-pl 29f3ba7
Better fallback for hermitian analysis on primitive ops
david-pl cab82c7
Merge branch 'main' into david/471-hermitian-trait
david-pl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from . import address_impl as address_impl | ||
from . import unitary as unitary, hermitian as hermitian, address_impl as address_impl |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Need this for impl registration to work properly! | ||
from . import impls as impls | ||
from .lattice import ( | ||
Hermitian as Hermitian, | ||
NotHermitian as NotHermitian, | ||
NotAnOperator as NotAnOperator, | ||
HermitianLattice as HermitianLattice, | ||
PossiblyHermitian as PossiblyHermitian, | ||
) | ||
from .analysis import HermitianAnalysis as HermitianAnalysis |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from kirin import ir | ||
from kirin.analysis import Forward, ForwardFrame | ||
|
||
from ... import op | ||
from .lattice import Hermitian, NotHermitian, HermitianLattice | ||
|
||
|
||
class HermitianAnalysis(Forward): | ||
keys = ["squin.hermitian"] | ||
lattice = HermitianLattice | ||
|
||
def run_method(self, method: ir.Method, args: tuple[HermitianLattice, ...]): | ||
return self.run_callable(method.code, (self.lattice.bottom(),) + args) | ||
|
||
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): | ||
if not isinstance(stmt, op.stmts.Operator): | ||
return (self.lattice.bottom(),) | ||
|
||
if stmt.has_trait(op.traits.Hermitian): | ||
return (Hermitian(),) | ||
|
||
if ( | ||
trait := stmt.get_trait(op.traits.MaybeHermitian) | ||
) is not None and trait.is_hermitian(stmt): | ||
return (Hermitian(),) | ||
|
||
if isinstance(stmt, op.stmts.PrimitiveOp): | ||
# NOTE: simple operator without the hermitian trait, so we know it's non-hermitian | ||
return (NotHermitian(),) | ||
|
||
return (self.lattice.top(),) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from kirin import types, interp | ||
from kirin.analysis import ForwardFrame | ||
from kirin.dialects import scf, func | ||
|
||
from ... import op | ||
from .lattice import Hermitian | ||
from .analysis import HermitianAnalysis | ||
|
||
|
||
@op.dialect.register(key="squin.hermitian") | ||
class HermitianMethods(interp.MethodTable): | ||
@interp.impl(op.stmts.Control) | ||
@interp.impl(op.stmts.Adjoint) | ||
def simple_container( | ||
self, | ||
interp: HermitianAnalysis, | ||
frame: ForwardFrame, | ||
stmt: op.stmts.Control | op.stmts.Adjoint, | ||
): | ||
return (frame.get(stmt.op),) | ||
|
||
@interp.impl(op.stmts.Scale) | ||
def scale( | ||
self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale | ||
): | ||
is_hermitian = frame.get(stmt.op) | ||
|
||
if not is_hermitian.is_subseteq(Hermitian): | ||
return (is_hermitian,) | ||
|
||
# NOTE: need to check if the factor is a real number | ||
if stmt.factor.type.is_subseteq(types.Float | types.Int | types.Bool): | ||
return (Hermitian(),) | ||
|
||
# NOTE: could still be a complex number type with zero imaginary part | ||
return (interp.lattice.top(),) | ||
|
||
@interp.impl(op.stmts.Kron) | ||
def kron(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Kron): | ||
is_hermitian = frame.get(stmt.lhs).join(frame.get(stmt.rhs)) | ||
return (is_hermitian,) | ||
|
||
@interp.impl(op.stmts.Mult) | ||
def mult(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Mult): | ||
# NOTE: this could be smarter here and check whether lhs == adjoint(rhs) | ||
if stmt.lhs != stmt.rhs: | ||
return (interp.lattice.top(),) | ||
|
||
return (frame.get(stmt.lhs),) | ||
|
||
|
||
@scf.dialect.register(key="squin.hermitian") | ||
class ScfHermitianMethods(scf.typeinfer.TypeInfer): | ||
pass | ||
|
||
|
||
@func.dialect.register(key="squin.hermitian") | ||
class FuncHermitianMethods(func.typeinfer.TypeInfer): | ||
pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import final | ||
from dataclasses import dataclass | ||
|
||
from kirin.lattice import ( | ||
SingletonMeta, | ||
BoundedLattice, | ||
IsSubsetEqMixin, | ||
SimpleJoinMixin, | ||
SimpleMeetMixin, | ||
) | ||
|
||
|
||
@dataclass | ||
class HermitianLattice( | ||
SimpleJoinMixin["HermitianLattice"], | ||
SimpleMeetMixin["HermitianLattice"], | ||
IsSubsetEqMixin["HermitianLattice"], | ||
BoundedLattice["HermitianLattice"], | ||
): | ||
@classmethod | ||
def bottom(cls) -> "HermitianLattice": | ||
return NotAnOperator() | ||
|
||
@classmethod | ||
def top(cls) -> "HermitianLattice": | ||
return PossiblyHermitian() | ||
|
||
|
||
@final | ||
@dataclass | ||
class NotAnOperator(HermitianLattice, metaclass=SingletonMeta): | ||
pass | ||
|
||
|
||
@final | ||
@dataclass | ||
class NotHermitian(HermitianLattice, metaclass=SingletonMeta): | ||
def is_subseteq(self, other: HermitianLattice) -> bool: | ||
return isinstance(other, NotHermitian) | ||
|
||
|
||
@final | ||
@dataclass | ||
class Hermitian(HermitianLattice, metaclass=SingletonMeta): | ||
def is_subseteq(self, other: HermitianLattice) -> bool: | ||
return isinstance(other, Hermitian) | ||
|
||
|
||
@final | ||
@dataclass | ||
class PossiblyHermitian(HermitianLattice, metaclass=SingletonMeta): | ||
pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from . import impls as impls | ||
from .lattice import ( | ||
Unitary as Unitary, | ||
NotUnitary as NotUnitary, | ||
NotAnOperator as NotAnOperator, | ||
UnitaryLattice as UnitaryLattice, | ||
) | ||
from .analysis import UnitaryAnalysis as UnitaryAnalysis |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from kirin import ir | ||
from kirin.analysis import Forward, ForwardFrame | ||
|
||
from ... import op | ||
from .lattice import Unitary, NotUnitary, UnitaryLattice | ||
from ..hermitian import HermitianLattice, HermitianAnalysis | ||
|
||
|
||
class UnitaryAnalysis(Forward): | ||
keys = ["squin.unitary"] | ||
lattice = UnitaryLattice | ||
hermitian_values: dict[ir.SSAValue, HermitianLattice] = dict() | ||
|
||
def run_method(self, method: ir.Method, args: tuple[UnitaryLattice, ...]): | ||
hermitian_frame, _ = HermitianAnalysis(method.dialects).run_analysis(method) | ||
self.hermitian_values.update(hermitian_frame.entries) | ||
return self.run_callable(method.code, (self.lattice.bottom(),) + args) | ||
|
||
def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): | ||
if not isinstance(stmt, op.stmts.Operator): | ||
return (self.lattice.bottom(),) | ||
|
||
if stmt.has_trait(op.traits.Unitary): | ||
return (Unitary(),) | ||
|
||
if ( | ||
trait := stmt.get_trait(op.traits.MaybeUnitary) | ||
) is not None and trait.is_unitary(stmt): | ||
return (Unitary(),) | ||
|
||
if isinstance(stmt, op.stmts.PrimitiveOp): | ||
# NOTE: simple operator that doesn't have the trait or an impl so it's known not to be unitary | ||
return (NotUnitary(),) | ||
|
||
return (self.lattice.top(),) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from kirin import interp | ||
from kirin.analysis import ForwardFrame, const | ||
from kirin.dialects import scf, func | ||
|
||
from ... import op | ||
from .lattice import Unitary, NotUnitary | ||
from .analysis import UnitaryAnalysis | ||
from ..hermitian import Hermitian, NotHermitian | ||
|
||
|
||
@op.dialect.register(key="squin.unitary") | ||
class UnitaryMethods(interp.MethodTable): | ||
@interp.impl(op.stmts.Control) | ||
@interp.impl(op.stmts.Adjoint) | ||
def simple_container( | ||
self, | ||
interp: UnitaryAnalysis, | ||
frame: ForwardFrame, | ||
stmt: op.stmts.Control | op.stmts.Adjoint, | ||
): | ||
return (frame.get(stmt.op),) | ||
|
||
@interp.impl(op.stmts.Scale) | ||
def scale(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale): | ||
is_unitary = frame.get(stmt.op) | ||
|
||
if not is_unitary.is_subseteq(Unitary): | ||
return (is_unitary,) | ||
|
||
# NOTE: need to check if the factor has absolute value squared of 1 | ||
constant_value = stmt.factor.hints.get("const") | ||
if not isinstance(constant_value, const.Value): | ||
return (interp.lattice.top(),) | ||
|
||
num = constant_value.data | ||
if not isinstance(num, (float, int, bool, complex)): | ||
return (interp.lattice.top(),) | ||
|
||
if abs(num) ** 2 == 1: | ||
return (Unitary(),) | ||
|
||
return (NotUnitary(),) | ||
|
||
@interp.impl(op.stmts.Kron) | ||
@interp.impl(op.stmts.Mult) | ||
def binary_op( | ||
self, | ||
interp: UnitaryAnalysis, | ||
frame: ForwardFrame, | ||
stmt: op.stmts.Kron | op.stmts.Mult, | ||
): | ||
return (frame.get(stmt.lhs).join(frame.get(stmt.rhs)),) | ||
|
||
@interp.impl(op.stmts.Rot) | ||
def rot(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Rot): | ||
if interp.hermitian_values.get(stmt.axis, NotHermitian()).is_equal(Hermitian()): | ||
return (Unitary(),) | ||
else: | ||
return (NotUnitary(),) | ||
|
||
|
||
@scf.dialect.register(key="squin.unitary") | ||
class ScfUnitaryMethods(scf.typeinfer.TypeInfer): | ||
pass | ||
|
||
|
||
@func.dialect.register(key="squin.unitary") | ||
class FuncUnitaryMethods(func.typeinfer.TypeInfer): | ||
pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import final | ||
from dataclasses import dataclass | ||
|
||
from kirin.lattice import ( | ||
SingletonMeta, | ||
BoundedLattice, | ||
IsSubsetEqMixin, | ||
SimpleJoinMixin, | ||
SimpleMeetMixin, | ||
) | ||
|
||
|
||
@dataclass | ||
class UnitaryLattice( | ||
SimpleJoinMixin["UnitaryLattice"], | ||
SimpleMeetMixin["UnitaryLattice"], | ||
IsSubsetEqMixin["UnitaryLattice"], | ||
BoundedLattice["UnitaryLattice"], | ||
): | ||
@classmethod | ||
def bottom(cls) -> "UnitaryLattice": | ||
return NotAnOperator() | ||
|
||
@classmethod | ||
def top(cls) -> "UnitaryLattice": | ||
return PossiblyUnitary() | ||
|
||
|
||
@final | ||
@dataclass | ||
class NotAnOperator(UnitaryLattice, metaclass=SingletonMeta): | ||
pass | ||
|
||
|
||
@final | ||
@dataclass | ||
class NotUnitary(UnitaryLattice, metaclass=SingletonMeta): | ||
|
||
def is_subseteq(self, other: UnitaryLattice) -> bool: | ||
return isinstance(other, NotUnitary) | ||
|
||
|
||
@final | ||
@dataclass | ||
class Unitary(UnitaryLattice, metaclass=SingletonMeta): | ||
|
||
def is_subseteq(self, other: UnitaryLattice) -> bool: | ||
return isinstance(other, Unitary) | ||
|
||
|
||
@final | ||
@dataclass | ||
class PossiblyUnitary(UnitaryLattice, metaclass=SingletonMeta): | ||
pass |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.