Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.22.0",
"scipy>=1.13.1",
"kirin-toolchain~=0.17.23",
"kirin-toolchain~=0.17.26",
"rich>=13.9.4",
"pydantic>=1.3.0,<2.11.0",
"pandas>=2.2.3",
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/pyqrack/squin/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def reset(
self,
interp: PyQrackInterpreter,
frame: interp.Frame,
stmt: op.stmts.Reset | op.stmts.ResetToOne,
stmt: op.stmts.Reset,
) -> tuple[OperatorRuntimeABC]:
target_state = isinstance(stmt, op.stmts.ResetToOne)
return (ResetRuntime(target_state=target_state),)
Expand Down
1 change: 0 additions & 1 deletion src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
noise as noise,
qubit as qubit,
analysis as analysis,
lowering as lowering,
_typeinfer as _typeinfer,
)
from .groups import wired as wired, kernel as kernel
Expand Down
15 changes: 7 additions & 8 deletions src/bloqade/squin/_typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
from kirin.analysis import TypeInference, const
from kirin.dialects import ilist

from bloqade import squin
from bloqade.squin import qubit


@squin.qubit.dialect.register(key="typeinfer")
@qubit.dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
@interp.impl(squin.qubit.New)
def _call(self, interp: TypeInference, frame: interp.Frame, stmt: squin.qubit.New):
@interp.impl(qubit.New)
def _call(self, interp: TypeInference, frame: interp.Frame, stmt: qubit.New):
# based on Xiu-zhe (Roger) Luo's get_const_value function

if (hint := stmt.n_qubits.hints.get("const")) is None:
return (ilist.IListType[squin.qubit.QubitType, types.Any],)

return (ilist.IListType[qubit.QubitType, types.Any],)
if isinstance(hint, const.Value) and isinstance(hint.data, int):
return (ilist.IListType[squin.qubit.QubitType, types.Literal(hint.data)],)
return (ilist.IListType[qubit.QubitType, types.Literal(hint.data)],)

return (ilist.IListType[squin.qubit.QubitType, types.Any],)
return (ilist.IListType[qubit.QubitType, types.Any],)
6 changes: 6 additions & 0 deletions src/bloqade/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ class MeasurementId(ir.Statement):
result: ir.ResultValue = info.result(types.Int)


@statement(dialect=dialect)
class Reset(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


# NOTE: no dependent types in Python, so we have to mark it Any...
@wraps(New)
def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
Expand Down
112 changes: 60 additions & 52 deletions src/bloqade/squin/rewrite/U3_to_clifford.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,51 @@
# create rewrite rule name SquinMeasureToStim using kirin
import math
from typing import List, Tuple, Callable

import numpy as np
from kirin import ir
from kirin.dialects import py
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.squin import op, qubit
from bloqade.squin import gate


def sdag() -> list[ir.Statement]:
return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)]
# Placeholder type, swap in an actual S statement with adjoint=True
# during the rewrite method
class Sdag(ir.Statement):
pass


# Identity gate doesn't exist, use this as a place holder
class DummyIdentity(ir.Statement):
pass


# (theta, phi, lam)
U3_HALF_PI_ANGLE_TO_GATES: dict[
tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]]
] = {
(0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],),
(0, 0, 1): lambda: ([op.stmts.S()],),
(0, 0, 2): lambda: ([op.stmts.Z()],),
(0, 0, 3): lambda: (sdag(),),
(1, 0, 0): lambda: ([op.stmts.SqrtY()],),
(1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]),
(1, 0, 2): lambda: ([op.stmts.H()],),
(1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]),
(1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()),
(1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
(1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
(1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()),
(2, 0, 0): lambda: ([op.stmts.Y()],),
(2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]),
(2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]),
(2, 0, 3): lambda: (sdag(), [op.stmts.Y()]),
U3_HALF_PI_ANGLE_TO_GATES: dict[tuple[int, int, int], list[type[ir.Statement]]] = {
(0, 0, 0): [DummyIdentity],
(0, 0, 1): [gate.stmts.S],
(0, 0, 2): [gate.stmts.Z],
(0, 0, 3): [Sdag],
(1, 0, 0): [gate.stmts.SqrtY],
(1, 0, 1): [gate.stmts.S, gate.stmts.SqrtY],
(1, 0, 2): [gate.stmts.H],
(1, 0, 3): [Sdag, gate.stmts.SqrtY],
(1, 1, 0): [gate.stmts.SqrtY, gate.stmts.S],
(1, 1, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.S],
(1, 1, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S],
(1, 1, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.S],
(1, 2, 0): [gate.stmts.SqrtY, gate.stmts.Z],
(1, 2, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z],
(1, 2, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.Z],
(1, 2, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.Z],
(1, 3, 0): [gate.stmts.SqrtY, Sdag],
(1, 3, 1): [gate.stmts.S, gate.stmts.SqrtY, Sdag],
(1, 3, 2): [gate.stmts.Z, gate.stmts.SqrtY, Sdag],
(1, 3, 3): [Sdag, gate.stmts.SqrtY, Sdag],
(2, 0, 0): [gate.stmts.Y],
(2, 0, 1): [gate.stmts.S, gate.stmts.Y],
(2, 0, 2): [gate.stmts.Z, gate.stmts.Y],
(2, 0, 3): [Sdag, gate.stmts.Y],
}


Expand All @@ -61,8 +65,8 @@ class SquinU3ToClifford(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if isinstance(node, (qubit.Apply, qubit.Broadcast)):
return self.rewrite_ApplyOrBroadcast_onU3(node)
if isinstance(node, gate.stmts.U3):
return self.rewrite_U3(node)
else:
return RewriteResult()

Expand All @@ -87,35 +91,33 @@ def resolve_angle(self, angle: float) -> int | None:
else:
return round((angle / math.tau) % 1 * 4) % 4

def rewrite_ApplyOrBroadcast_onU3(
self, node: qubit.Apply | qubit.Broadcast
) -> RewriteResult:
def rewrite_U3(self, node: gate.stmts.U3) -> RewriteResult:
"""
Rewrite Apply and Broadcast nodes to their clifford equivalent statements.
"""
if not isinstance(node.operator.owner, op.stmts.U3):
return RewriteResult()

gates = self.decompose_U3_gates(node.operator.owner)
gates = self.decompose_U3_gates(node)

if len(gates) == 0:
return RewriteResult()

for stmt_list in gates:
for gate_stmt in stmt_list[:-1]:
gate_stmt.insert_before(node)
# Get rid of the U3 gate altogether if it's identity
if len(gates) == 1 and gates[0] is DummyIdentity:
node.delete()
return RewriteResult(has_done_something=True)

oper = stmt_list[-1]
oper.insert_before(node)
new_node = node.__class__(operator=oper.result, qubits=node.qubits)
new_node.insert_before(node)
for gate_stmt in gates:
if gate_stmt is Sdag:
new_stmt = gate.stmts.S(adjoint=True, qubits=node.qubits)
else:
new_stmt = gate_stmt(qubits=node.qubits)
new_stmt.insert_before(node)

node.delete()

# rewrite U3 to clifford gates
return RewriteResult(has_done_something=True)

def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...]:
def decompose_U3_gates(self, node: gate.stmts.U3) -> list[ir.Statement]:
"""
Rewrite U3 statements to clifford gates if possible.
"""
Expand All @@ -124,7 +126,13 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
lam = self.get_constant(node.lam)

if theta is None or phi is None or lam is None:
return ()
return []

# Angles will be in units of turns, we convert to radians
# to allow for the old logic to work
theta = theta * math.tau
phi = phi * math.tau
lam = lam * math.tau

# For U3(2*pi*n, phi, lam) = U3(0, 0, lam + phi) which is a Z rotation.
if np.isclose(np.mod(theta, math.tau), 0):
Expand All @@ -139,13 +147,13 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
lam_half_pi: int | None = self.resolve_angle(lam)

if theta_half_pi is None or phi_half_pi is None or lam_half_pi is None:
return ()
return []

angles_key = (theta_half_pi, phi_half_pi, lam_half_pi)
if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
angles_key = equivalent_u3_para(*angles_key)
if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
return ()
return []

gates_stmts = U3_HALF_PI_ANGLE_TO_GATES.get(angles_key)

Expand All @@ -154,4 +162,4 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
gates_stmts is not None
), "internal error, U3 gates not found for angles: {}".format(angles_key)

return gates_stmts()
return gates_stmts
1 change: 0 additions & 1 deletion src/bloqade/stim/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .squin_to_stim import (
SquinToStimPass as SquinToStimPass,
StimSimplifyIfs as StimSimplifyIfs,
)
101 changes: 101 additions & 0 deletions src/bloqade/stim/passes/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Taken from Phillip Weinberg's bloqade-shuttle implementation
from dataclasses import field, dataclass

from kirin import ir
from kirin.passes import Pass, HintConst, TypeInfer
from kirin.rewrite import (
Walk,
Chain,
Inline,
Fixpoint,
Call2Invoke,
ConstantFold,
CFGCompactify,
InlineGetItem,
InlineGetField,
DeadCodeElimination,
)
from kirin.dialects import scf, ilist
from kirin.ir.method import Method
from kirin.rewrite.abc import RewriteResult
from kirin.rewrite.cse import CommonSubexpressionElimination
from kirin.passes.inline import InlinePass
from kirin.passes.aggressive import UnrollScf

from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs


@dataclass
class Fold(Pass):
hint_const: HintConst = field(init=False)

def __post_init__(self):
self.hint_const = HintConst(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()
result = self.hint_const.unsafe_run(mt).join(result)
rule = Chain(
ConstantFold(),
Call2Invoke(),
InlineGetField(),
InlineGetItem(),
ilist.rewrite.InlineGetItem(),
ilist.rewrite.HintLen(),
DeadCodeElimination(),
CommonSubexpressionElimination(),
)
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)

return result


@dataclass
class AggressiveUnroll(Pass):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pass was also added to QASM2 in #536. You can just use that and remove it here.

"""Fold pass to fold control flow"""

fold: Fold = field(init=False)
typeinfer: TypeInfer = field(init=False)
scf_unroll: UnrollScf = field(init=False)

def __post_init__(self):
self.fold = Fold(self.dialects, no_raise=self.no_raise)
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()
result = self.scf_unroll.unsafe_run(mt).join(result)
result = (
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
.rewrite(mt.code)
.join(result)
)
result = self.typeinfer.unsafe_run(mt).join(result)
result = self.fold.unsafe_run(mt).join(result)
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
return result

@classmethod
def inline_heuristic(cls, node: ir.Statement) -> bool:
"""The heuristic to decide whether to inline a function call or not.
inside loops and if-else, only inline simple functions, i.e.
functions with a single block
"""
return not isinstance(
node.parent_stmt, (scf.For, scf.IfElse)
) # always inline calls outside of loops and if-else


class Flatten(Pass):
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
rewrite_result = InlinePass(dialects=mt.dialects, no_raise=self.no_raise)(mt)
rewrite_result = AggressiveUnroll(dialects=mt.dialects, no_raise=self.no_raise)(
mt
).join(rewrite_result)
rewrite_result = StimSimplifyIfs(dialects=mt.dialects, no_raise=self.no_raise)(
mt
).join(rewrite_result)

return rewrite_result
7 changes: 6 additions & 1 deletion src/bloqade/stim/passes/simplify_ifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
Chain,
Fixpoint,
ConstantFold,
DeadCodeElimination,
CommonSubexpressionElimination,
)
from kirin.dialects.scf.trim import UnusedYield
from kirin.dialects.ilist.passes import ConstList2IList

from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
Expand All @@ -20,7 +22,10 @@ class StimSimplifyIfs(Pass):
def unsafe_run(self, mt: ir.Method):

result = Chain(
Fixpoint(Walk(StimLiftThenBody())),
Walk(UnusedYield()),
Walk(StimLiftThenBody()),
# remove yields (if possible), then lift out as much stuff as possible
Walk(DeadCodeElimination()),
Walk(StimSplitIfStmts()),
).rewrite(mt.code)

Expand Down
Loading