- 
                Notifications
    
You must be signed in to change notification settings  - Fork 1
 
Implement hermitian and unitary analysis and canonicalization pass #475
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
          
     Closed
      
      
    
  
     Closed
                    Changes from all commits
      Commits
    
    
            Show all changes
          
          
            15 commits
          
        
        Select commit
          Hold shift + click to select a range
      
      38dee18
              
                Add Hermitian trait to squin ops
              
              
                david-pl f8e398a
              
                Fix typo
              
              
                david-pl 5df8a57
              
                Implement hermitian and unitary analysis and canonicalization pass
              
              
                david-pl 4e2e0d8
              
                Add tests
              
              
                david-pl 1a7b104
              
                Loosen constraint on flaky test
              
              
                david-pl 4ea2acb
              
                Merge branch 'main' into david/471-hermitian-trait
              
              
                david-pl 145b4e5
              
                Update src/bloqade/squin/analysis/hermitian_and_unitary.py
              
              
                david-pl b4f036f
              
                Merge branch 'main' into david/471-hermitian-trait
              
              
                david-pl ef98ba3
              
                Implement a proper lattice for the analyses
              
              
                david-pl f072013
              
                Update src/bloqade/squin/rewrite/canonicalize.py
              
              
                david-pl b6d4f04
              
                Fix analysis for PauliStrings
              
              
                david-pl ef35d3a
              
                Fix lattice for Hermitian
              
              
                david-pl 2599ba3
              
                Fix lattice for unitary analysis
              
              
                david-pl 29f3ba7
              
                Better fallback for hermitian analysis on primitive ops
              
              
                david-pl cab82c7
              
                Merge branch 'main' into david/471-hermitian-trait
              
              
                david-pl File filter
Filter by extension
Conversations
          Failed to load comments.   
        
        
          
      Loading
        
  Jump to
        
          Jump to file
        
      
      
          Failed to load files.   
        
        
          
      Loading
        
  Diff view
Diff view
There are no files selected for viewing
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -1 +1 @@ | ||
| from . import address_impl as address_impl | ||
| from . import unitary as unitary, hermitian as hermitian, address_impl as address_impl | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| # Need this for impl registration to work properly! | ||
| from . import impls as impls | ||
| from .lattice import ( | ||
| Hermitian as Hermitian, | ||
| NotHermitian as NotHermitian, | ||
| NotAnOperator as NotAnOperator, | ||
| HermitianLattice as HermitianLattice, | ||
| PossiblyHermitian as PossiblyHermitian, | ||
| ) | ||
| from .analysis import HermitianAnalysis as HermitianAnalysis | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| from kirin import ir | ||
| from kirin.analysis import Forward, ForwardFrame | ||
| 
     | 
||
| from ... import op | ||
| from .lattice import Hermitian, NotHermitian, HermitianLattice | ||
| 
     | 
||
| 
     | 
||
| class HermitianAnalysis(Forward): | ||
| keys = ["squin.hermitian"] | ||
| lattice = HermitianLattice | ||
| 
     | 
||
| def run_method(self, method: ir.Method, args: tuple[HermitianLattice, ...]): | ||
| return self.run_callable(method.code, (self.lattice.bottom(),) + args) | ||
| 
     | 
||
| def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): | ||
| if not isinstance(stmt, op.stmts.Operator): | ||
| return (self.lattice.bottom(),) | ||
| 
     | 
||
| if stmt.has_trait(op.traits.Hermitian): | ||
| return (Hermitian(),) | ||
| 
     | 
||
| if ( | ||
| trait := stmt.get_trait(op.traits.MaybeHermitian) | ||
| ) is not None and trait.is_hermitian(stmt): | ||
| return (Hermitian(),) | ||
| 
     | 
||
| if isinstance(stmt, op.stmts.PrimitiveOp): | ||
| # NOTE: simple operator without the hermitian trait, so we know it's non-hermitian | ||
| return (NotHermitian(),) | ||
| 
     | 
||
| return (self.lattice.top(),) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from kirin import types, interp | ||
| from kirin.analysis import ForwardFrame | ||
| from kirin.dialects import scf, func | ||
| 
     | 
||
| from ... import op | ||
| from .lattice import Hermitian | ||
| from .analysis import HermitianAnalysis | ||
| 
     | 
||
| 
     | 
||
| @op.dialect.register(key="squin.hermitian") | ||
| class HermitianMethods(interp.MethodTable): | ||
| @interp.impl(op.stmts.Control) | ||
| @interp.impl(op.stmts.Adjoint) | ||
| def simple_container( | ||
| self, | ||
| interp: HermitianAnalysis, | ||
| frame: ForwardFrame, | ||
| stmt: op.stmts.Control | op.stmts.Adjoint, | ||
| ): | ||
| return (frame.get(stmt.op),) | ||
| 
     | 
||
| @interp.impl(op.stmts.Scale) | ||
| def scale( | ||
| self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale | ||
| ): | ||
| is_hermitian = frame.get(stmt.op) | ||
| 
     | 
||
| if not is_hermitian.is_subseteq(Hermitian): | ||
| return (is_hermitian,) | ||
| 
     | 
||
| # NOTE: need to check if the factor is a real number | ||
| if stmt.factor.type.is_subseteq(types.Float | types.Int | types.Bool): | ||
| return (Hermitian(),) | ||
| 
     | 
||
| # NOTE: could still be a complex number type with zero imaginary part | ||
| return (interp.lattice.top(),) | ||
| 
     | 
||
| @interp.impl(op.stmts.Kron) | ||
| def kron(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Kron): | ||
| is_hermitian = frame.get(stmt.lhs).join(frame.get(stmt.rhs)) | ||
| return (is_hermitian,) | ||
| 
     | 
||
| @interp.impl(op.stmts.Mult) | ||
| def mult(self, interp: HermitianAnalysis, frame: ForwardFrame, stmt: op.stmts.Mult): | ||
| # NOTE: this could be smarter here and check whether lhs == adjoint(rhs) | ||
| if stmt.lhs != stmt.rhs: | ||
| return (interp.lattice.top(),) | ||
| 
     | 
||
| return (frame.get(stmt.lhs),) | ||
| 
     | 
||
| 
     | 
||
| @scf.dialect.register(key="squin.hermitian") | ||
| class ScfHermitianMethods(scf.typeinfer.TypeInfer): | ||
| pass | ||
| 
     | 
||
| 
     | 
||
| @func.dialect.register(key="squin.hermitian") | ||
| class FuncHermitianMethods(func.typeinfer.TypeInfer): | ||
| pass | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| from typing import final | ||
| from dataclasses import dataclass | ||
| 
     | 
||
| from kirin.lattice import ( | ||
| SingletonMeta, | ||
| BoundedLattice, | ||
| IsSubsetEqMixin, | ||
| SimpleJoinMixin, | ||
| SimpleMeetMixin, | ||
| ) | ||
| 
     | 
||
| 
     | 
||
| @dataclass | ||
| class HermitianLattice( | ||
| SimpleJoinMixin["HermitianLattice"], | ||
| SimpleMeetMixin["HermitianLattice"], | ||
| IsSubsetEqMixin["HermitianLattice"], | ||
| BoundedLattice["HermitianLattice"], | ||
| ): | ||
| @classmethod | ||
| def bottom(cls) -> "HermitianLattice": | ||
| return NotAnOperator() | ||
| 
     | 
||
| @classmethod | ||
| def top(cls) -> "HermitianLattice": | ||
| return PossiblyHermitian() | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class NotAnOperator(HermitianLattice, metaclass=SingletonMeta): | ||
| pass | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class NotHermitian(HermitianLattice, metaclass=SingletonMeta): | ||
| def is_subseteq(self, other: HermitianLattice) -> bool: | ||
| return isinstance(other, NotHermitian) | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class Hermitian(HermitianLattice, metaclass=SingletonMeta): | ||
| def is_subseteq(self, other: HermitianLattice) -> bool: | ||
| return isinstance(other, Hermitian) | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class PossiblyHermitian(HermitianLattice, metaclass=SingletonMeta): | ||
| pass | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| from . import impls as impls | ||
| from .lattice import ( | ||
| Unitary as Unitary, | ||
| NotUnitary as NotUnitary, | ||
| NotAnOperator as NotAnOperator, | ||
| UnitaryLattice as UnitaryLattice, | ||
| ) | ||
| from .analysis import UnitaryAnalysis as UnitaryAnalysis | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| from kirin import ir | ||
| from kirin.analysis import Forward, ForwardFrame | ||
| 
     | 
||
| from ... import op | ||
| from .lattice import Unitary, NotUnitary, UnitaryLattice | ||
| from ..hermitian import HermitianLattice, HermitianAnalysis | ||
| 
     | 
||
| 
     | 
||
| class UnitaryAnalysis(Forward): | ||
| keys = ["squin.unitary"] | ||
| lattice = UnitaryLattice | ||
| hermitian_values: dict[ir.SSAValue, HermitianLattice] = dict() | ||
| 
     | 
||
| def run_method(self, method: ir.Method, args: tuple[UnitaryLattice, ...]): | ||
| hermitian_frame, _ = HermitianAnalysis(method.dialects).run_analysis(method) | ||
| self.hermitian_values.update(hermitian_frame.entries) | ||
| return self.run_callable(method.code, (self.lattice.bottom(),) + args) | ||
| 
     | 
||
| def eval_stmt_fallback(self, frame: ForwardFrame, stmt: ir.Statement): | ||
| if not isinstance(stmt, op.stmts.Operator): | ||
| return (self.lattice.bottom(),) | ||
| 
     | 
||
| if stmt.has_trait(op.traits.Unitary): | ||
| return (Unitary(),) | ||
| 
     | 
||
| if ( | ||
| trait := stmt.get_trait(op.traits.MaybeUnitary) | ||
| ) is not None and trait.is_unitary(stmt): | ||
| return (Unitary(),) | ||
| 
     | 
||
| if isinstance(stmt, op.stmts.PrimitiveOp): | ||
| # NOTE: simple operator that doesn't have the trait or an impl so it's known not to be unitary | ||
| return (NotUnitary(),) | ||
| 
     | 
||
| return (self.lattice.top(),) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| from kirin import interp | ||
| from kirin.analysis import ForwardFrame, const | ||
| from kirin.dialects import scf, func | ||
| 
     | 
||
| from ... import op | ||
| from .lattice import Unitary, NotUnitary | ||
| from .analysis import UnitaryAnalysis | ||
| from ..hermitian import Hermitian, NotHermitian | ||
| 
     | 
||
| 
     | 
||
| @op.dialect.register(key="squin.unitary") | ||
| class UnitaryMethods(interp.MethodTable): | ||
| @interp.impl(op.stmts.Control) | ||
| @interp.impl(op.stmts.Adjoint) | ||
| def simple_container( | ||
| self, | ||
| interp: UnitaryAnalysis, | ||
| frame: ForwardFrame, | ||
| stmt: op.stmts.Control | op.stmts.Adjoint, | ||
| ): | ||
| return (frame.get(stmt.op),) | ||
| 
     | 
||
| @interp.impl(op.stmts.Scale) | ||
| def scale(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Scale): | ||
| is_unitary = frame.get(stmt.op) | ||
| 
     | 
||
| if not is_unitary.is_subseteq(Unitary): | ||
| return (is_unitary,) | ||
| 
     | 
||
| # NOTE: need to check if the factor has absolute value squared of 1 | ||
| constant_value = stmt.factor.hints.get("const") | ||
| if not isinstance(constant_value, const.Value): | ||
| return (interp.lattice.top(),) | ||
| 
     | 
||
| num = constant_value.data | ||
| if not isinstance(num, (float, int, bool, complex)): | ||
| return (interp.lattice.top(),) | ||
| 
     | 
||
| if abs(num) ** 2 == 1: | ||
| return (Unitary(),) | ||
| 
     | 
||
| return (NotUnitary(),) | ||
| 
     | 
||
| @interp.impl(op.stmts.Kron) | ||
| @interp.impl(op.stmts.Mult) | ||
| def binary_op( | ||
| self, | ||
| interp: UnitaryAnalysis, | ||
| frame: ForwardFrame, | ||
| stmt: op.stmts.Kron | op.stmts.Mult, | ||
| ): | ||
| return (frame.get(stmt.lhs).join(frame.get(stmt.rhs)),) | ||
| 
     | 
||
| @interp.impl(op.stmts.Rot) | ||
| def rot(self, interp: UnitaryAnalysis, frame: ForwardFrame, stmt: op.stmts.Rot): | ||
| if interp.hermitian_values.get(stmt.axis, NotHermitian()).is_equal(Hermitian()): | ||
| return (Unitary(),) | ||
| else: | ||
| return (NotUnitary(),) | ||
| 
     | 
||
| 
     | 
||
| @scf.dialect.register(key="squin.unitary") | ||
| class ScfUnitaryMethods(scf.typeinfer.TypeInfer): | ||
| pass | ||
| 
     | 
||
| 
     | 
||
| @func.dialect.register(key="squin.unitary") | ||
| class FuncUnitaryMethods(func.typeinfer.TypeInfer): | ||
| pass | ||
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              | Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| from typing import final | ||
| from dataclasses import dataclass | ||
| 
     | 
||
| from kirin.lattice import ( | ||
| SingletonMeta, | ||
| BoundedLattice, | ||
| IsSubsetEqMixin, | ||
| SimpleJoinMixin, | ||
| SimpleMeetMixin, | ||
| ) | ||
| 
     | 
||
| 
     | 
||
| @dataclass | ||
| class UnitaryLattice( | ||
| SimpleJoinMixin["UnitaryLattice"], | ||
| SimpleMeetMixin["UnitaryLattice"], | ||
| IsSubsetEqMixin["UnitaryLattice"], | ||
| BoundedLattice["UnitaryLattice"], | ||
| ): | ||
| @classmethod | ||
| def bottom(cls) -> "UnitaryLattice": | ||
| return NotAnOperator() | ||
| 
     | 
||
| @classmethod | ||
| def top(cls) -> "UnitaryLattice": | ||
| return PossiblyUnitary() | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class NotAnOperator(UnitaryLattice, metaclass=SingletonMeta): | ||
| pass | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class NotUnitary(UnitaryLattice, metaclass=SingletonMeta): | ||
| 
     | 
||
| def is_subseteq(self, other: UnitaryLattice) -> bool: | ||
| return isinstance(other, NotUnitary) | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class Unitary(UnitaryLattice, metaclass=SingletonMeta): | ||
| 
     | 
||
| def is_subseteq(self, other: UnitaryLattice) -> bool: | ||
| return isinstance(other, Unitary) | ||
| 
     | 
||
| 
     | 
||
| @final | ||
| @dataclass | ||
| class PossiblyUnitary(UnitaryLattice, metaclass=SingletonMeta): | ||
| pass | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
              
      
      Oops, something went wrong.
        
    
  
  Add this suggestion to a batch that can be applied as a single commit.
  This suggestion is invalid because no changes were made to the code.
  Suggestions cannot be applied while the pull request is closed.
  Suggestions cannot be applied while viewing a subset of changes.
  Only one suggestion per line can be applied in a batch.
  Add this suggestion to a batch that can be applied as a single commit.
  Applying suggestions on deleted lines is not supported.
  You must change the existing code in this line in order to create a valid suggestion.
  Outdated suggestions cannot be applied.
  This suggestion has been applied or marked resolved.
  Suggestions cannot be applied from pending reviews.
  Suggestions cannot be applied on multi-line comments.
  Suggestions cannot be applied while the pull request is queued to merge.
  Suggestion cannot be applied right now. Please check back later.
  
    
  
    
Uh oh!
There was an error while loading. Please reload this page.