Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3fabd5a
refactor squin to stim pass to be simpler + more performant with rep …
johnzl-777 Sep 29, 2025
012ce87
Merge branch 'main' into john/squin-to-stim-performance-refactor
johnzl-777 Sep 30, 2025
f21059d
Merge branch 'main' into john/squin-to-stim-performance-refactor
johnzl-777 Oct 2, 2025
95e651a
add reset op, get a decent set of tests to work properly
johnzl-777 Oct 5, 2025
9d86c00
undo unecessary qubit refactor
johnzl-777 Oct 7, 2025
79c919b
Shuffle reset into qubit, drop ResetToOne. Also try to make cirq runt…
johnzl-777 Oct 7, 2025
38a9ce6
remove reset from these tests for now considering different semantics
johnzl-777 Oct 7, 2025
462a8b8
bring back the old reset to keep tests happy, this is driving me nuts
johnzl-777 Oct 7, 2025
acb3364
add measurement tests back
johnzl-777 Oct 7, 2025
f83b68a
merge main in
johnzl-777 Oct 8, 2025
d1b798b
Merge branch 'main' into john/squin-to-stim-performance-refactor
johnzl-777 Oct 8, 2025
275ea82
move away from old clifford dialect
johnzl-777 Oct 10, 2025
2ad43be
complete U3 to clifford rewrite + test refactor
johnzl-777 Oct 10, 2025
6e2c997
bring back the SquinToStim U3 to Clifford test
johnzl-777 Oct 10, 2025
4f71a51
update to newer version of kirin with FlattenAdd to get CI to be happy
johnzl-777 Oct 10, 2025
9fec057
Merge branch 'main' into john/squin-to-stim-performance-refactor
david-pl Oct 14, 2025
ce77998
get rid of duplicate aggressive unroll, reuse the one from QASM2
johnzl-777 Oct 14, 2025
26116f0
fix types in SquinU3ToClifford
johnzl-777 Oct 14, 2025
afd901f
make lifting stricter
johnzl-777 Oct 14, 2025
676c257
remove leftover tinkering tests
johnzl-777 Oct 14, 2025
626bf3a
just reroute SSA values instead of generating new constants
johnzl-777 Oct 14, 2025
2821a33
get rid of is_measurement_used considering only one place uses it
johnzl-777 Oct 14, 2025
cb89601
get out of qubit to stim rewrite sooner if address attr not available…
johnzl-777 Oct 14, 2025
1315d89
get rid of DummyIdentity, can use None
johnzl-777 Oct 14, 2025
b8a034f
fix types in SquinU3ToClifford (again)
johnzl-777 Oct 14, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.22.0",
"scipy>=1.13.1",
"kirin-toolchain~=0.17.23",
"kirin-toolchain~=0.17.26",
"rich>=13.9.4",
"pydantic>=1.3.0,<2.11.0",
"pandas>=2.2.3",
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/pyqrack/squin/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def reset(
self,
interp: PyQrackInterpreter,
frame: interp.Frame,
stmt: op.stmts.Reset | op.stmts.ResetToOne,
stmt: op.stmts.Reset,
) -> tuple[OperatorRuntimeABC]:
target_state = isinstance(stmt, op.stmts.ResetToOne)
return (ResetRuntime(target_state=target_state),)
Expand Down
1 change: 0 additions & 1 deletion src/bloqade/squin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
noise as noise,
qubit as qubit,
analysis as analysis,
lowering as lowering,
_typeinfer as _typeinfer,
)
from .groups import wired as wired, kernel as kernel
Expand Down
15 changes: 7 additions & 8 deletions src/bloqade/squin/_typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
from kirin.analysis import TypeInference, const
from kirin.dialects import ilist

from bloqade import squin
from bloqade.squin import qubit


@squin.qubit.dialect.register(key="typeinfer")
@qubit.dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
@interp.impl(squin.qubit.New)
def _call(self, interp: TypeInference, frame: interp.Frame, stmt: squin.qubit.New):
@interp.impl(qubit.New)
def _call(self, interp: TypeInference, frame: interp.Frame, stmt: qubit.New):
# based on Xiu-zhe (Roger) Luo's get_const_value function

if (hint := stmt.n_qubits.hints.get("const")) is None:
return (ilist.IListType[squin.qubit.QubitType, types.Any],)

return (ilist.IListType[qubit.QubitType, types.Any],)
if isinstance(hint, const.Value) and isinstance(hint.data, int):
return (ilist.IListType[squin.qubit.QubitType, types.Literal(hint.data)],)
return (ilist.IListType[qubit.QubitType, types.Literal(hint.data)],)

return (ilist.IListType[squin.qubit.QubitType, types.Any],)
return (ilist.IListType[qubit.QubitType, types.Any],)
6 changes: 6 additions & 0 deletions src/bloqade/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ class MeasurementId(ir.Statement):
result: ir.ResultValue = info.result(types.Int)


@statement(dialect=dialect)
class Reset(ir.Statement):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This statement needs a pyqrack method and a cirq emit method impl. I can also add those, if you like. So either add them here, or a create a new issue and assign it to me, please.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll make an issue, thank you! (See #546 )

traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])


# NOTE: no dependent types in Python, so we have to mark it Any...
@wraps(New)
def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
Expand Down
112 changes: 60 additions & 52 deletions src/bloqade/squin/rewrite/U3_to_clifford.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,51 @@
# create rewrite rule name SquinMeasureToStim using kirin
import math
from typing import List, Tuple, Callable

import numpy as np
from kirin import ir
from kirin.dialects import py
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.squin import op, qubit
from bloqade.squin import gate


def sdag() -> list[ir.Statement]:
return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)]
# Placeholder type, swap in an actual S statement with adjoint=True
# during the rewrite method
class Sdag(ir.Statement):
pass


# Identity gate doesn't exist, use this as a place holder
class DummyIdentity(ir.Statement):
pass


# (theta, phi, lam)
U3_HALF_PI_ANGLE_TO_GATES: dict[
tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]]
] = {
(0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],),
(0, 0, 1): lambda: ([op.stmts.S()],),
(0, 0, 2): lambda: ([op.stmts.Z()],),
(0, 0, 3): lambda: (sdag(),),
(1, 0, 0): lambda: ([op.stmts.SqrtY()],),
(1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]),
(1, 0, 2): lambda: ([op.stmts.H()],),
(1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]),
(1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()),
(1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
(1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
(1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()),
(2, 0, 0): lambda: ([op.stmts.Y()],),
(2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]),
(2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]),
(2, 0, 3): lambda: (sdag(), [op.stmts.Y()]),
U3_HALF_PI_ANGLE_TO_GATES: dict[tuple[int, int, int], list[type[ir.Statement]]] = {
(0, 0, 0): [DummyIdentity],
(0, 0, 1): [gate.stmts.S],
(0, 0, 2): [gate.stmts.Z],
(0, 0, 3): [Sdag],
(1, 0, 0): [gate.stmts.SqrtY],
(1, 0, 1): [gate.stmts.S, gate.stmts.SqrtY],
(1, 0, 2): [gate.stmts.H],
(1, 0, 3): [Sdag, gate.stmts.SqrtY],
(1, 1, 0): [gate.stmts.SqrtY, gate.stmts.S],
(1, 1, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.S],
(1, 1, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S],
(1, 1, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.S],
(1, 2, 0): [gate.stmts.SqrtY, gate.stmts.Z],
(1, 2, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z],
(1, 2, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.Z],
(1, 2, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.Z],
(1, 3, 0): [gate.stmts.SqrtY, Sdag],
(1, 3, 1): [gate.stmts.S, gate.stmts.SqrtY, Sdag],
(1, 3, 2): [gate.stmts.Z, gate.stmts.SqrtY, Sdag],
(1, 3, 3): [Sdag, gate.stmts.SqrtY, Sdag],
(2, 0, 0): [gate.stmts.Y],
(2, 0, 1): [gate.stmts.S, gate.stmts.Y],
(2, 0, 2): [gate.stmts.Z, gate.stmts.Y],
(2, 0, 3): [Sdag, gate.stmts.Y],
}


Expand All @@ -61,8 +65,8 @@ class SquinU3ToClifford(RewriteRule):
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if isinstance(node, (qubit.Apply, qubit.Broadcast)):
return self.rewrite_ApplyOrBroadcast_onU3(node)
if isinstance(node, gate.stmts.U3):
return self.rewrite_U3(node)
else:
return RewriteResult()

Expand All @@ -87,35 +91,33 @@ def resolve_angle(self, angle: float) -> int | None:
else:
return round((angle / math.tau) % 1 * 4) % 4

def rewrite_ApplyOrBroadcast_onU3(
self, node: qubit.Apply | qubit.Broadcast
) -> RewriteResult:
def rewrite_U3(self, node: gate.stmts.U3) -> RewriteResult:
"""
Rewrite Apply and Broadcast nodes to their clifford equivalent statements.
"""
if not isinstance(node.operator.owner, op.stmts.U3):
return RewriteResult()

gates = self.decompose_U3_gates(node.operator.owner)
gates = self.decompose_U3_gates(node)

if len(gates) == 0:
return RewriteResult()

for stmt_list in gates:
for gate_stmt in stmt_list[:-1]:
gate_stmt.insert_before(node)
# Get rid of the U3 gate altogether if it's identity
if len(gates) == 1 and gates[0] is DummyIdentity:
node.delete()
return RewriteResult(has_done_something=True)

oper = stmt_list[-1]
oper.insert_before(node)
new_node = node.__class__(operator=oper.result, qubits=node.qubits)
new_node.insert_before(node)
for gate_stmt in gates:
if gate_stmt is Sdag:
new_stmt = gate.stmts.S(adjoint=True, qubits=node.qubits)
else:
new_stmt = gate_stmt(qubits=node.qubits)
new_stmt.insert_before(node)

node.delete()

# rewrite U3 to clifford gates
return RewriteResult(has_done_something=True)

def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...]:
def decompose_U3_gates(self, node: gate.stmts.U3) -> list[ir.Statement]:
"""
Rewrite U3 statements to clifford gates if possible.
"""
Expand All @@ -124,7 +126,13 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
lam = self.get_constant(node.lam)

if theta is None or phi is None or lam is None:
return ()
return []

# Angles will be in units of turns, we convert to radians
# to allow for the old logic to work
theta = theta * math.tau
phi = phi * math.tau
lam = lam * math.tau

# For U3(2*pi*n, phi, lam) = U3(0, 0, lam + phi) which is a Z rotation.
if np.isclose(np.mod(theta, math.tau), 0):
Expand All @@ -139,13 +147,13 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
lam_half_pi: int | None = self.resolve_angle(lam)

if theta_half_pi is None or phi_half_pi is None or lam_half_pi is None:
return ()
return []

angles_key = (theta_half_pi, phi_half_pi, lam_half_pi)
if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
angles_key = equivalent_u3_para(*angles_key)
if angles_key not in U3_HALF_PI_ANGLE_TO_GATES:
return ()
return []

gates_stmts = U3_HALF_PI_ANGLE_TO_GATES.get(angles_key)

Expand All @@ -154,4 +162,4 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
gates_stmts is not None
), "internal error, U3 gates not found for angles: {}".format(angles_key)

return gates_stmts()
return gates_stmts
1 change: 0 additions & 1 deletion src/bloqade/stim/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .squin_to_stim import (
SquinToStimPass as SquinToStimPass,
StimSimplifyIfs as StimSimplifyIfs,
)
101 changes: 101 additions & 0 deletions src/bloqade/stim/passes/flatten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Taken from Phillip Weinberg's bloqade-shuttle implementation
from dataclasses import field, dataclass

from kirin import ir
from kirin.passes import Pass, HintConst, TypeInfer
from kirin.rewrite import (
Walk,
Chain,
Inline,
Fixpoint,
Call2Invoke,
ConstantFold,
CFGCompactify,
InlineGetItem,
InlineGetField,
DeadCodeElimination,
)
from kirin.dialects import scf, ilist
from kirin.ir.method import Method
from kirin.rewrite.abc import RewriteResult
from kirin.rewrite.cse import CommonSubexpressionElimination
from kirin.passes.inline import InlinePass
from kirin.passes.aggressive import UnrollScf

from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs


@dataclass
class Fold(Pass):
hint_const: HintConst = field(init=False)

def __post_init__(self):
self.hint_const = HintConst(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()
result = self.hint_const.unsafe_run(mt).join(result)
rule = Chain(
ConstantFold(),
Call2Invoke(),
InlineGetField(),
InlineGetItem(),
ilist.rewrite.InlineGetItem(),
ilist.rewrite.HintLen(),
DeadCodeElimination(),
CommonSubexpressionElimination(),
)
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)

return result


@dataclass
class AggressiveUnroll(Pass):
"""Fold pass to fold control flow"""

fold: Fold = field(init=False)
typeinfer: TypeInfer = field(init=False)
scf_unroll: UnrollScf = field(init=False)

def __post_init__(self):
self.fold = Fold(self.dialects, no_raise=self.no_raise)
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)

def unsafe_run(self, mt: Method) -> RewriteResult:
result = RewriteResult()
result = self.scf_unroll.unsafe_run(mt).join(result)
result = (
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
.rewrite(mt.code)
.join(result)
)
result = self.typeinfer.unsafe_run(mt).join(result)
result = self.fold.unsafe_run(mt).join(result)
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
return result

@classmethod
def inline_heuristic(cls, node: ir.Statement) -> bool:
"""The heuristic to decide whether to inline a function call or not.
inside loops and if-else, only inline simple functions, i.e.
functions with a single block
"""
return not isinstance(
node.parent_stmt, (scf.For, scf.IfElse)
) # always inline calls outside of loops and if-else


class Flatten(Pass):
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
rewrite_result = InlinePass(dialects=mt.dialects, no_raise=self.no_raise)(mt)
rewrite_result = AggressiveUnroll(dialects=mt.dialects, no_raise=self.no_raise)(
mt
).join(rewrite_result)
rewrite_result = StimSimplifyIfs(dialects=mt.dialects, no_raise=self.no_raise)(
mt
).join(rewrite_result)

return rewrite_result
7 changes: 6 additions & 1 deletion src/bloqade/stim/passes/simplify_ifs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
Chain,
Fixpoint,
ConstantFold,
DeadCodeElimination,
CommonSubexpressionElimination,
)
from kirin.dialects.scf.trim import UnusedYield
from kirin.dialects.ilist.passes import ConstList2IList

from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts
Expand All @@ -20,7 +22,10 @@ class StimSimplifyIfs(Pass):
def unsafe_run(self, mt: ir.Method):

result = Chain(
Fixpoint(Walk(StimLiftThenBody())),
Walk(UnusedYield()),
Walk(StimLiftThenBody()),
# remove yields (if possible), then lift out as much stuff as possible
Walk(DeadCodeElimination()),
Walk(StimSplitIfStmts()),
).rewrite(mt.code)

Expand Down
Loading