diff --git a/pyproject.toml b/pyproject.toml index 8886e76c..f338b047 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ requires-python = ">=3.10" dependencies = [ "numpy>=1.22.0", "scipy>=1.13.1", - "kirin-toolchain~=0.17.23", + "kirin-toolchain~=0.17.26", "rich>=13.9.4", "pydantic>=1.3.0,<2.11.0", "pandas>=2.2.3", diff --git a/src/bloqade/pyqrack/squin/op.py b/src/bloqade/pyqrack/squin/op.py index 1a603aaa..a9f1fb1a 100644 --- a/src/bloqade/pyqrack/squin/op.py +++ b/src/bloqade/pyqrack/squin/op.py @@ -103,7 +103,7 @@ def reset( self, interp: PyQrackInterpreter, frame: interp.Frame, - stmt: op.stmts.Reset | op.stmts.ResetToOne, + stmt: op.stmts.Reset, ) -> tuple[OperatorRuntimeABC]: target_state = isinstance(stmt, op.stmts.ResetToOne) return (ResetRuntime(target_state=target_state),) diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index c1febe37..89d51a90 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -5,7 +5,6 @@ noise as noise, qubit as qubit, analysis as analysis, - lowering as lowering, _typeinfer as _typeinfer, ) from .groups import wired as wired, kernel as kernel diff --git a/src/bloqade/squin/_typeinfer.py b/src/bloqade/squin/_typeinfer.py index b1055ef3..82bbd84c 100644 --- a/src/bloqade/squin/_typeinfer.py +++ b/src/bloqade/squin/_typeinfer.py @@ -2,19 +2,18 @@ from kirin.analysis import TypeInference, const from kirin.dialects import ilist -from bloqade import squin +from bloqade.squin import qubit -@squin.qubit.dialect.register(key="typeinfer") +@qubit.dialect.register(key="typeinfer") class TypeInfer(interp.MethodTable): - @interp.impl(squin.qubit.New) - def _call(self, interp: TypeInference, frame: interp.Frame, stmt: squin.qubit.New): + @interp.impl(qubit.New) + def _call(self, interp: TypeInference, frame: interp.Frame, stmt: qubit.New): # based on Xiu-zhe (Roger) Luo's get_const_value function if (hint := stmt.n_qubits.hints.get("const")) is None: - return (ilist.IListType[squin.qubit.QubitType, types.Any],) - + return (ilist.IListType[qubit.QubitType, types.Any],) if isinstance(hint, const.Value) and isinstance(hint.data, int): - return (ilist.IListType[squin.qubit.QubitType, types.Literal(hint.data)],) + return (ilist.IListType[qubit.QubitType, types.Literal(hint.data)],) - return (ilist.IListType[squin.qubit.QubitType, types.Any],) + return (ilist.IListType[qubit.QubitType, types.Any],) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 1faa74fd..857ff944 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -92,6 +92,12 @@ class MeasurementId(ir.Statement): result: ir.ResultValue = info.result(types.Int) +@statement(dialect=dialect) +class Reset(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) + + # NOTE: no dependent types in Python, so we have to mark it Any... @wraps(New) def new(n_qubits: int) -> ilist.IList[Qubit, Any]: diff --git a/src/bloqade/squin/rewrite/U3_to_clifford.py b/src/bloqade/squin/rewrite/U3_to_clifford.py index bc3c4f63..7e5ac1cc 100644 --- a/src/bloqade/squin/rewrite/U3_to_clifford.py +++ b/src/bloqade/squin/rewrite/U3_to_clifford.py @@ -1,47 +1,48 @@ # create rewrite rule name SquinMeasureToStim using kirin import math -from typing import List, Tuple, Callable import numpy as np from kirin import ir from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import op, qubit +from bloqade.squin import gate -def sdag() -> list[ir.Statement]: - return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)] +# Placeholder type, swap in an actual S statement with adjoint=True +# during the rewrite method +class Sdag(ir.Statement): + pass # (theta, phi, lam) U3_HALF_PI_ANGLE_TO_GATES: dict[ - tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]] + tuple[int, int, int], list[type[ir.Statement]] | list[None] ] = { - (0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],), - (0, 0, 1): lambda: ([op.stmts.S()],), - (0, 0, 2): lambda: ([op.stmts.Z()],), - (0, 0, 3): lambda: (sdag(),), - (1, 0, 0): lambda: ([op.stmts.SqrtY()],), - (1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]), - (1, 0, 2): lambda: ([op.stmts.H()],), - (1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]), - (1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()), - (1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()), - (1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()), - (1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()), - (2, 0, 0): lambda: ([op.stmts.Y()],), - (2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]), - (2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]), - (2, 0, 3): lambda: (sdag(), [op.stmts.Y()]), + (0, 0, 0): [None], + (0, 0, 1): [gate.stmts.S], + (0, 0, 2): [gate.stmts.Z], + (0, 0, 3): [Sdag], + (1, 0, 0): [gate.stmts.SqrtY], + (1, 0, 1): [gate.stmts.S, gate.stmts.SqrtY], + (1, 0, 2): [gate.stmts.H], + (1, 0, 3): [Sdag, gate.stmts.SqrtY], + (1, 1, 0): [gate.stmts.SqrtY, gate.stmts.S], + (1, 1, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.S], + (1, 1, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S], + (1, 1, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.S], + (1, 2, 0): [gate.stmts.SqrtY, gate.stmts.Z], + (1, 2, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z], + (1, 2, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.Z], + (1, 2, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.Z], + (1, 3, 0): [gate.stmts.SqrtY, Sdag], + (1, 3, 1): [gate.stmts.S, gate.stmts.SqrtY, Sdag], + (1, 3, 2): [gate.stmts.Z, gate.stmts.SqrtY, Sdag], + (1, 3, 3): [Sdag, gate.stmts.SqrtY, Sdag], + (2, 0, 0): [gate.stmts.Y], + (2, 0, 1): [gate.stmts.S, gate.stmts.Y], + (2, 0, 2): [gate.stmts.Z, gate.stmts.Y], + (2, 0, 3): [Sdag, gate.stmts.Y], } @@ -61,8 +62,8 @@ class SquinU3ToClifford(RewriteRule): """ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - if isinstance(node, (qubit.Apply, qubit.Broadcast)): - return self.rewrite_ApplyOrBroadcast_onU3(node) + if isinstance(node, gate.stmts.U3): + return self.rewrite_U3(node) else: return RewriteResult() @@ -87,35 +88,35 @@ def resolve_angle(self, angle: float) -> int | None: else: return round((angle / math.tau) % 1 * 4) % 4 - def rewrite_ApplyOrBroadcast_onU3( - self, node: qubit.Apply | qubit.Broadcast - ) -> RewriteResult: + def rewrite_U3(self, node: gate.stmts.U3) -> RewriteResult: """ Rewrite Apply and Broadcast nodes to their clifford equivalent statements. """ - if not isinstance(node.operator.owner, op.stmts.U3): - return RewriteResult() - gates = self.decompose_U3_gates(node.operator.owner) + gates = self.decompose_U3_gates(node) if len(gates) == 0: return RewriteResult() - for stmt_list in gates: - for gate_stmt in stmt_list[:-1]: - gate_stmt.insert_before(node) + # Get rid of the U3 gate altogether if it's identity + if len(gates) == 1 and gates[0] is None: + node.delete() + return RewriteResult(has_done_something=True) - oper = stmt_list[-1] - oper.insert_before(node) - new_node = node.__class__(operator=oper.result, qubits=node.qubits) - new_node.insert_before(node) + for gate_stmt in gates: + if gate_stmt is Sdag: + new_stmt = gate.stmts.S(adjoint=True, qubits=node.qubits) + else: + new_stmt = gate_stmt(qubits=node.qubits) + new_stmt.insert_before(node) node.delete() - # rewrite U3 to clifford gates return RewriteResult(has_done_something=True) - def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...]: + def decompose_U3_gates( + self, node: gate.stmts.U3 + ) -> list[type[ir.Statement]] | list[None]: """ Rewrite U3 statements to clifford gates if possible. """ @@ -124,7 +125,13 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ... lam = self.get_constant(node.lam) if theta is None or phi is None or lam is None: - return () + return [] + + # Angles will be in units of turns, we convert to radians + # to allow for the old logic to work + theta = theta * math.tau + phi = phi * math.tau + lam = lam * math.tau # For U3(2*pi*n, phi, lam) = U3(0, 0, lam + phi) which is a Z rotation. if np.isclose(np.mod(theta, math.tau), 0): @@ -139,13 +146,13 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ... lam_half_pi: int | None = self.resolve_angle(lam) if theta_half_pi is None or phi_half_pi is None or lam_half_pi is None: - return () + return [] angles_key = (theta_half_pi, phi_half_pi, lam_half_pi) if angles_key not in U3_HALF_PI_ANGLE_TO_GATES: angles_key = equivalent_u3_para(*angles_key) if angles_key not in U3_HALF_PI_ANGLE_TO_GATES: - return () + return [] gates_stmts = U3_HALF_PI_ANGLE_TO_GATES.get(angles_key) @@ -154,4 +161,4 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ... gates_stmts is not None ), "internal error, U3 gates not found for angles: {}".format(angles_key) - return gates_stmts() + return gates_stmts diff --git a/src/bloqade/stim/passes/__init__.py b/src/bloqade/stim/passes/__init__.py index b68a45dd..01d9987e 100644 --- a/src/bloqade/stim/passes/__init__.py +++ b/src/bloqade/stim/passes/__init__.py @@ -1,4 +1,3 @@ from .squin_to_stim import ( SquinToStimPass as SquinToStimPass, - StimSimplifyIfs as StimSimplifyIfs, ) diff --git a/src/bloqade/stim/passes/flatten.py b/src/bloqade/stim/passes/flatten.py new file mode 100644 index 00000000..26f72702 --- /dev/null +++ b/src/bloqade/stim/passes/flatten.py @@ -0,0 +1,61 @@ +# Taken from Phillip Weinberg's bloqade-shuttle implementation +from dataclasses import field, dataclass + +from kirin import ir +from kirin.passes import Pass, HintConst +from kirin.rewrite import ( + Walk, + Chain, + Fixpoint, + Call2Invoke, + ConstantFold, + InlineGetItem, + InlineGetField, + DeadCodeElimination, +) +from kirin.dialects import ilist +from kirin.ir.method import Method +from kirin.rewrite.abc import RewriteResult +from kirin.rewrite.cse import CommonSubexpressionElimination +from kirin.passes.inline import InlinePass + +from bloqade.qasm2.passes.fold import AggressiveUnroll +from bloqade.stim.passes.simplify_ifs import StimSimplifyIfs + + +@dataclass +class Fold(Pass): + hint_const: HintConst = field(init=False) + + def __post_init__(self): + self.hint_const = HintConst(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: Method) -> RewriteResult: + result = RewriteResult() + result = self.hint_const.unsafe_run(mt).join(result) + rule = Chain( + ConstantFold(), + Call2Invoke(), + InlineGetField(), + InlineGetItem(), + ilist.rewrite.InlineGetItem(), + ilist.rewrite.HintLen(), + DeadCodeElimination(), + CommonSubexpressionElimination(), + ) + result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result) + + return result + + +class Flatten(Pass): + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + rewrite_result = InlinePass(dialects=mt.dialects, no_raise=self.no_raise)(mt) + rewrite_result = AggressiveUnroll(dialects=mt.dialects, no_raise=self.no_raise)( + mt + ).join(rewrite_result) + rewrite_result = StimSimplifyIfs(dialects=mt.dialects, no_raise=self.no_raise)( + mt + ).join(rewrite_result) + + return rewrite_result diff --git a/src/bloqade/stim/passes/simplify_ifs.py b/src/bloqade/stim/passes/simplify_ifs.py index b2a691b3..4db85d23 100644 --- a/src/bloqade/stim/passes/simplify_ifs.py +++ b/src/bloqade/stim/passes/simplify_ifs.py @@ -7,8 +7,10 @@ Chain, Fixpoint, ConstantFold, + DeadCodeElimination, CommonSubexpressionElimination, ) +from kirin.dialects.scf.trim import UnusedYield from kirin.dialects.ilist.passes import ConstList2IList from ..rewrite.ifs_to_stim import StimLiftThenBody, StimSplitIfStmts @@ -20,7 +22,10 @@ class StimSimplifyIfs(Pass): def unsafe_run(self, mt: ir.Method): result = Chain( - Fixpoint(Walk(StimLiftThenBody())), + Walk(UnusedYield()), + Walk(StimLiftThenBody()), + # remove yields (if possible), then lift out as much stuff as possible + Walk(DeadCodeElimination()), Walk(StimSplitIfStmts()), ).rewrite(mt.code) diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 2f9889ab..b1776dc1 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -1,29 +1,21 @@ from dataclasses import dataclass -from kirin.passes import Fold, TypeInfer from kirin.rewrite import ( Walk, Chain, Fixpoint, - CFGCompactify, DeadCodeElimination, CommonSubexpressionElimination, ) -from kirin.dialects import ilist from kirin.ir.method import Method from kirin.passes.abc import Pass from kirin.rewrite.abc import RewriteResult -from kirin.passes.inline import InlinePass -from kirin.rewrite.alias import InlineAlias -from kirin.passes.aggressive import UnrollScf from bloqade.stim.rewrite import ( - SquinWireToStim, PyConstantToStim, SquinNoiseToStim, SquinQubitToStim, SquinMeasureToStim, - SquinWireIdentityElimination, ) from bloqade.squin.rewrite import ( SquinU3ToClifford, @@ -33,9 +25,9 @@ from bloqade.rewrite.passes import CanonicalizeIList from bloqade.analysis.address import AddressAnalysis from bloqade.analysis.measure_id import MeasurementIDAnalysis +from bloqade.stim.passes.flatten import Flatten from bloqade.squin.rewrite.desugar import ApplyDesugarRule, MeasureDesugarRule -from .simplify_ifs import StimSimplifyIfs from ..rewrite.ifs_to_stim import IfToStim @@ -45,61 +37,12 @@ class SquinToStimPass(Pass): def unsafe_run(self, mt: Method) -> RewriteResult: # inline aggressively: - rewrite_result = InlinePass( - dialects=mt.dialects, no_raise=self.no_raise - ).unsafe_run(mt) - - rewrite_result = Walk(ilist.rewrite.HintLen()).rewrite(mt.code) - rewrite_result = Fold(self.dialects).unsafe_run(mt).join(rewrite_result) - - rewrite_result = ( - UnrollScf(dialects=mt.dialects, no_raise=self.no_raise) - .fixpoint(mt) - .join(rewrite_result) + rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint( + mt ) rewrite_result = ( - Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(rewrite_result) - ) - - rewrite_result = Walk(InlineAlias()).rewrite(mt.code).join(rewrite_result) - - rewrite_result = ( - StimSimplifyIfs(mt.dialects, no_raise=self.no_raise) - .unsafe_run(mt) - .join(rewrite_result) - ) - - rewrite_result = ( - Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll())) - .rewrite(mt.code) - .join(rewrite_result) - ) - rewrite_result = Fold(mt.dialects, no_raise=self.no_raise)(mt) - - rewrite_result = ( - UnrollScf(mt.dialects, no_raise=self.no_raise) - .fixpoint(mt) - .join(rewrite_result) - ) - - rewrite_result = ( - CanonicalizeIList(dialects=mt.dialects, no_raise=self.no_raise) - .unsafe_run(mt) - .join(rewrite_result) - ) - - rewrite_result = TypeInfer( - dialects=mt.dialects, no_raise=self.no_raise - ).unsafe_run(mt) - - rewrite_result = ( - Walk( - Chain( - ApplyDesugarRule(), - MeasureDesugarRule(), - ) - ) + Walk(Chain(ApplyDesugarRule(), MeasureDesugarRule())) .rewrite(mt.code) .join(rewrite_result) ) @@ -145,8 +88,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult: Chain( SquinQubitToStim(), SquinMeasureToStim(), - SquinWireToStim(), - SquinWireIdentityElimination(), ) ) .rewrite(mt.code) diff --git a/src/bloqade/stim/rewrite/__init__.py b/src/bloqade/stim/rewrite/__init__.py index d5cbe4b2..4b0eb8fe 100644 --- a/src/bloqade/stim/rewrite/__init__.py +++ b/src/bloqade/stim/rewrite/__init__.py @@ -1,9 +1,5 @@ from .ifs_to_stim import IfToStim as IfToStim from .squin_noise import SquinNoiseToStim as SquinNoiseToStim -from .wire_to_stim import SquinWireToStim as SquinWireToStim from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim from .squin_measure import SquinMeasureToStim as SquinMeasureToStim from .py_constant_to_stim import PyConstantToStim as PyConstantToStim -from .wire_identity_elimination import ( - SquinWireIdentityElimination as SquinWireIdentityElimination, -) diff --git a/src/bloqade/stim/rewrite/ifs_to_stim.py b/src/bloqade/stim/rewrite/ifs_to_stim.py index c72fc83d..6a15253b 100644 --- a/src/bloqade/stim/rewrite/ifs_to_stim.py +++ b/src/bloqade/stim/rewrite/ifs_to_stim.py @@ -4,13 +4,13 @@ from kirin.dialects import py, scf, func from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import op, qubit +from bloqade.squin import gate from bloqade.rewrite.rules import LiftThenBody, SplitIfStmts from bloqade.squin.rewrite import AddressAttribute from bloqade.stim.rewrite.util import ( - SQUIN_STIM_CONTROL_GATE_MAPPING, insert_qubit_idx_from_address, ) +from bloqade.stim.dialects.gate import CX as stim_CX, CY as stim_CY, CZ as stim_CZ from bloqade.analysis.measure_id import MeasureIDFrame from bloqade.stim.dialects.auxiliary import GetRecord from bloqade.analysis.measure_id.lattice import ( @@ -58,8 +58,7 @@ def has_else_body(self, stmt: scf.IfElse) -> bool: """Check if the IfElse statement has an else body.""" if stmt.else_body.blocks and not ( len(stmt.else_body.blocks[0].stmts) == 1 - and isinstance(else_term := stmt.else_body.blocks[0].last_stmt, scf.Yield) - and not else_term.values # empty yield + and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield) ): return True @@ -67,12 +66,13 @@ def has_else_body(self, stmt: scf.IfElse) -> bool: DontLiftType = ( - qubit.Apply, - qubit.Broadcast, - scf.Yield, + gate.stmts.SingleQubitGate, + gate.stmts.RotationGate, + gate.stmts.ControlledGate, func.Return, func.Invoke, scf.IfElse, + scf.Yield, ) @@ -99,16 +99,16 @@ class StimSplitIfStmts(IfElseSimplification, SplitIfStmts): Given an IfElse with multiple valid statements in the then-body: if measure_result: - squin.qubit.apply(op.X, q0) - squin.qubit.apply(op.Y, q1) + squin.x(q0) + squin.y(q1) this should be rewritten to: if measure_result: - squin.qubit.apply(op.X, q0) + squin.x(q0) if measure_result: - squin.qubit.apply(op.Y, q1) + squin.y(q1) """ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @@ -139,24 +139,23 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: + # Check the condition is a singular MeasurementIdBool if not isinstance(self.measure_frame.entries[stmt.cond], MeasureIdBool): return RewriteResult() - # check that there is only qubit.Apply in the then-body, - # if there's more than that, we can't do a valid rewrite. - # Can reuse logic from SplitIf + # Reusing code from SplitIf, + # there should only be one statement in the body and it should be a pauli X, Y, or Z *stmts, _ = stmt.then_body.stmts() - if len(stmts) != 1 or not isinstance(stmts[0], (qubit.Apply, qubit.Broadcast)): + if len(stmts) != 1: return RewriteResult() - apply_or_broadcast = stmts[0] - # Check that the gate being applied/broadcasted can be converted to a stim - # controlled gate. - ctrl_op_target_gate = apply_or_broadcast.operator.owner - assert isinstance(ctrl_op_target_gate, op.stmts.Operator) - - stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate)) - if stim_gate is None: + if isinstance(stmts[0], gate.stmts.X): + stim_gate = stim_CX + elif isinstance(stmts[0], gate.stmts.Y): + stim_gate = stim_CY + elif isinstance(stmts[0], gate.stmts.Z): + stim_gate = stim_CZ + else: return RewriteResult() # get necessary measurement ID type from analysis @@ -169,12 +168,7 @@ def rewrite_IfElse(self, stmt: scf.IfElse) -> RewriteResult: ) get_record_stmt = GetRecord(id=measure_id_idx_stmt.result) # noqa: F841 - # get address attribute and generate qubit idx statements - if len(apply_or_broadcast.qubits) != 1: - # NOTE: this is actually invalid since we are dealing with single-qubit operators here - return RewriteResult() - - address_attr = apply_or_broadcast.qubits[0].hints.get("address") + address_attr = stmts[0].qubits.hints.get("address") if address_attr is None: return RewriteResult() diff --git a/src/bloqade/stim/rewrite/qubit_to_stim.py b/src/bloqade/stim/rewrite/qubit_to_stim.py index 1c4302c8..176ed841 100644 --- a/src/bloqade/stim/rewrite/qubit_to_stim.py +++ b/src/bloqade/stim/rewrite/qubit_to_stim.py @@ -1,13 +1,10 @@ from kirin import ir from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import op, noise, qubit +from bloqade.squin import gate, qubit from bloqade.squin.rewrite import AddressAttribute -from bloqade.stim.dialects import gate +from bloqade.stim.dialects import gate as stim_gate, collapse as stim_collapse from bloqade.stim.rewrite.util import ( - SQUIN_STIM_OP_MAPPING, - rewrite_Control, - rewrite_QubitLoss, insert_qubit_idx_from_address, ) @@ -20,64 +17,110 @@ class SquinQubitToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: - case qubit.Apply() | qubit.Broadcast(): - return self.rewrite_Apply_and_Broadcast(node) + # not supported by Stim + case gate.stmts.T() | gate.stmts.RotationGate(): + return RewriteResult() + # If you've reached this point all gates have stim equivalents + case qubit.Reset(): + return self.rewrite_Reset(node) + case gate.stmts.SingleQubitGate(): + return self.rewrite_SingleQubitGate(node) + case gate.stmts.ControlledGate(): + return self.rewrite_ControlledGate(node) case _: return RewriteResult() - def rewrite_Apply_and_Broadcast( - self, stmt: qubit.Apply | qubit.Broadcast - ) -> RewriteResult: - """ - Rewrite Apply and Broadcast nodes to their stim equivalent statements. - """ + def rewrite_Reset(self, stmt: qubit.Reset) -> RewriteResult: - # this is an SSAValue, need it to be the actual operator - applied_op = stmt.operator.owner + qubit_addr_attr = stmt.qubits.hints.get("address", None) - if isinstance(applied_op, noise.stmts.QubitLoss): - return rewrite_QubitLoss(stmt) + if qubit_addr_attr is None: + return RewriteResult() - assert isinstance(applied_op, op.stmts.Operator) + assert isinstance(qubit_addr_attr, AddressAttribute) - if isinstance(applied_op, op.stmts.Control): - return rewrite_Control(stmt) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=qubit_addr_attr, stmt_to_insert_before=stmt + ) - # need to handle Control through separate means + if qubit_idx_ssas is None: + return RewriteResult() - # check if its adjoint, assume its canonicalized so no nested adjoints. - is_conj = False - if isinstance(applied_op, op.stmts.Adjoint): - if not applied_op.is_unitary: - return RewriteResult() + stim_stmt = stim_collapse.RZ(targets=tuple(qubit_idx_ssas)) + stmt.replace_by(stim_stmt) - is_conj = True - applied_op = applied_op.op.owner + return RewriteResult(has_done_something=True) - stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op)) - if stim_1q_op is None: - return RewriteResult() + def rewrite_SingleQubitGate( + self, stmt: gate.stmts.SingleQubitGate + ) -> RewriteResult: + """ + Rewrite single qubit gate nodes to their stim equivalent statements. + Address Analysis should have been run along with Wrap Analysis before this rewrite is applied. + """ - address_attr = stmt.qubits[0].hints.get("address") + qubit_addr_attr = stmt.qubits.hints.get("address", None) - if address_attr is None: + if qubit_addr_attr is None: return RewriteResult() - assert isinstance(address_attr, AddressAttribute) + assert isinstance(qubit_addr_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=stmt + address=qubit_addr_attr, stmt_to_insert_before=stmt ) if qubit_idx_ssas is None: return RewriteResult() - if isinstance(stim_1q_op, gate.stmts.Gate): - stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas), dagger=is_conj) - else: - stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) - stmt.replace_by(stim_1q_stmt) + # Get the name of the inputted stmt and see if there is an + # equivalently named statement in stim, + # then create an instance of that stim statement + stmt_name = type(stmt).__name__ + stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None) + if stim_stmt_cls is None: + return RewriteResult() + stim_stmt = stim_stmt_cls(targets=tuple(qubit_idx_ssas)) + stmt.replace_by(stim_stmt) return RewriteResult(has_done_something=True) + def rewrite_ControlledGate(self, stmt: gate.stmts.ControlledGate) -> RewriteResult: + """ + Rewrite controlled gate nodes to their stim equivalent statements. + Address Analysis should have been run along with Wrap Analysis before this rewrite is applied. + """ + + controls_addr_attr = stmt.controls.hints.get("address", None) + targets_addr_attr = stmt.targets.hints.get("address", None) + + if controls_addr_attr is None or targets_addr_attr is None: + return RewriteResult() + + assert isinstance(controls_addr_attr, AddressAttribute) + assert isinstance(targets_addr_attr, AddressAttribute) + + controls_idx_ssas = insert_qubit_idx_from_address( + address=controls_addr_attr, stmt_to_insert_before=stmt + ) + targets_idx_ssas = insert_qubit_idx_from_address( + address=targets_addr_attr, stmt_to_insert_before=stmt + ) + + if controls_idx_ssas is None or targets_idx_ssas is None: + return RewriteResult() + + # Get the name of the inputted stmt and see if there is an + # equivalently named statement in stim, + # then create an instance of that stim statement + stmt_name = type(stmt).__name__ + stim_stmt_cls = getattr(stim_gate.stmts, stmt_name, None) + if stim_stmt_cls is None: + return RewriteResult() + + stim_stmt = stim_stmt_cls( + targets=tuple(targets_idx_ssas), controls=tuple(controls_idx_ssas) + ) + stmt.replace_by(stim_stmt) -# put rewrites for measure statements in separate rule, then just have to dispatch + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/stim/rewrite/squin_measure.py b/src/bloqade/stim/rewrite/squin_measure.py index c3255339..8926d5ac 100644 --- a/src/bloqade/stim/rewrite/squin_measure.py +++ b/src/bloqade/stim/rewrite/squin_measure.py @@ -5,11 +5,10 @@ from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import wire, qubit +from bloqade.squin import qubit from bloqade.squin.rewrite import AddressAttribute from bloqade.stim.dialects import collapse from bloqade.stim.rewrite.util import ( - is_measure_result_used, insert_qubit_idx_from_address, ) @@ -23,13 +22,13 @@ class SquinMeasureToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: - case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure(): + case qubit.MeasureQubit() | qubit.MeasureQubitList(): return self.rewrite_Measure(node) case _: return RewriteResult() def rewrite_Measure( - self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure + self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList ) -> RewriteResult: qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt) @@ -44,13 +43,16 @@ def rewrite_Measure( prob_noise_stmt.insert_before(measure_stmt) stim_measure_stmt.insert_before(measure_stmt) - if not is_measure_result_used(measure_stmt): + # if the measurement is not being used anywhere + # we can safely get rid of it. Measure cannot be DCE'd because + # it is not pure. + if not bool(measure_stmt.result.uses): measure_stmt.delete() return RewriteResult(has_done_something=True) def get_qubit_idx_ssas( - self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure + self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList ) -> tuple[ir.SSAValue, ...] | None: """ Extract the address attribute and insert qubit indices for the given measure statement. @@ -60,8 +62,6 @@ def get_qubit_idx_ssas( address_attr = measure_stmt.qubit.hints.get("address") case qubit.MeasureQubitList(): address_attr = measure_stmt.qubits.hints.get("address") - case wire.Measure(): - address_attr = measure_stmt.wire.hints.get("address") case _: return None diff --git a/src/bloqade/stim/rewrite/squin_noise.py b/src/bloqade/stim/rewrite/squin_noise.py index 8952792a..c0bb9c50 100644 --- a/src/bloqade/stim/rewrite/squin_noise.py +++ b/src/bloqade/stim/rewrite/squin_noise.py @@ -1,17 +1,14 @@ +import itertools from typing import Tuple from dataclasses import dataclass from kirin.ir import SSAValue, Statement -from kirin.dialects import py, ilist +from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import op, wire, noise as squin_noise, qubit +from bloqade.squin import noise as squin_noise from bloqade.stim.dialects import noise as stim_noise -from bloqade.stim.rewrite.util import ( - get_const_value, - create_wire_passthrough, - insert_qubit_idx_after_apply, -) +from bloqade.stim.rewrite.util import insert_qubit_idx_from_address @dataclass @@ -19,159 +16,142 @@ class SquinNoiseToStim(RewriteRule): def rewrite_Statement(self, node: Statement) -> RewriteResult: match node: - case qubit.Apply() | qubit.Broadcast() | wire.Apply() | wire.Broadcast(): - return self.rewrite_Apply_and_Broadcast(node) + case squin_noise.stmts.NoiseChannel(): + return self.rewrite_NoiseChannel(node) case _: return RewriteResult() - def rewrite_Apply_and_Broadcast( - self, stmt: qubit.Apply | qubit.Broadcast | wire.Apply | wire.Broadcast + def rewrite_NoiseChannel( + self, stmt: squin_noise.stmts.NoiseChannel ) -> RewriteResult: - """Rewrite Apply and Broadcast to their stim statements.""" + """Rewrite NoiseChannel statements to their stim equivalents.""" - # this is an SSAValue, need it to be the actual operator - applied_op = stmt.operator.owner - - if isinstance(applied_op, squin_noise.stmts.QubitLoss): + rewrite_method = getattr(self, f"rewrite_{type(stmt).__name__}", None) + # No rewrite method exists and the rewrite should stop + if rewrite_method is None: return RewriteResult() - if isinstance(applied_op, squin_noise.stmts.NoiseChannel): - - rewrite_method = getattr(self, f"rewrite_{type(applied_op).__name__}", None) - # No rewrite method exists and the rewrite should stop - if rewrite_method is None: + if isinstance(stmt, squin_noise.stmts.SingleQubitNoiseChannel): + qubit_address_attr = stmt.qubits.hints.get("address", None) + if qubit_address_attr is None: return RewriteResult() + qubit_idx_ssas = insert_qubit_idx_from_address(qubit_address_attr, stmt) - qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt) - if qubit_idx_ssas is None: + elif isinstance(stmt, squin_noise.stmts.TwoQubitNoiseChannel): + control_address_attr = stmt.controls.hints.get("address", None) + target_address_attr = stmt.targets.hints.get("address", None) + if control_address_attr is None or target_address_attr is None: + return RewriteResult() + control_qubit_idx_ssas = insert_qubit_idx_from_address( + control_address_attr, stmt + ) + target_qubit_idx_ssas = insert_qubit_idx_from_address( + target_address_attr, stmt + ) + if control_qubit_idx_ssas is None or target_qubit_idx_ssas is None: return RewriteResult() - stim_stmt = rewrite_method(stmt, qubit_idx_ssas) - - if isinstance(stmt, (wire.Apply, wire.Broadcast)): - create_wire_passthrough(stmt) - - # guaranteed that you have a valid stim_stmt to plug in - stmt.replace_by(stim_stmt) + # For stim statements you want to interleave the control and target qubit indices: + # ex: CX controls = (0,1) targets = (2,3) in stim is: CX 0 2 1 3 + qubit_idx_ssas = list( + itertools.chain.from_iterable( + zip(control_qubit_idx_ssas, target_qubit_idx_ssas) + ) + ) + else: + return RewriteResult() - return RewriteResult(has_done_something=True) - return RewriteResult() + # guaranteed that you have a valid stim_stmt to plug in + stim_stmt = rewrite_method(stmt, tuple(qubit_idx_ssas)) + stmt.replace_by(stim_stmt) - def rewrite_PauliError( - self, - stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply, - qubit_idx_ssas: Tuple[SSAValue], - ) -> Statement: - """Rewrite squin.noise.PauliError to XError, YError, ZError.""" - squin_channel = stmt.operator.owner - assert isinstance(squin_channel, squin_noise.stmts.PauliError) - basis = squin_channel.basis.owner - assert isinstance(basis, op.stmts.PauliOp) - p = get_const_value(float, squin_channel.p) - - p_stmt = py.Constant(p) - p_stmt.insert_before(stmt) - - if isinstance(basis, op.stmts.X): - stim_stmt = stim_noise.XError(targets=qubit_idx_ssas, p=p_stmt.result) - elif isinstance(basis, op.stmts.Y): - stim_stmt = stim_noise.YError(targets=qubit_idx_ssas, p=p_stmt.result) - else: - stim_stmt = stim_noise.ZError(targets=qubit_idx_ssas, p=p_stmt.result) - return stim_stmt + return RewriteResult(has_done_something=True) def rewrite_SingleQubitPauliChannel( self, - stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply, + stmt: squin_noise.stmts.SingleQubitPauliChannel, qubit_idx_ssas: Tuple[SSAValue], ) -> Statement: """Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1.""" - squin_channel = stmt.operator.owner - assert isinstance(squin_channel, squin_noise.stmts.SingleQubitPauliChannel) - - params = get_const_value(ilist.IList, squin_channel.params) - new_stmts = [ - p_x := py.Constant(params[0]), - p_y := py.Constant(params[1]), - p_z := py.Constant(params[2]), - ] - for new_stmt in new_stmts: - new_stmt.insert_before(stmt) - stim_stmt = stim_noise.PauliChannel1( targets=qubit_idx_ssas, - px=p_x.result, - py=p_y.result, - pz=p_z.result, + px=stmt.px, + py=stmt.py, + pz=stmt.pz, ) return stim_stmt - def rewrite_TwoQubitPauliChannel( + def rewrite_QubitLoss( self, - stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply, + stmt: squin_noise.stmts.QubitLoss, qubit_idx_ssas: Tuple[SSAValue], ) -> Statement: - """Rewrite squin.noise.SingleQubitPauliChannel to stim.PauliChannel1.""" - - squin_channel = stmt.operator.owner - assert isinstance(squin_channel, squin_noise.stmts.TwoQubitPauliChannel) - - params = get_const_value(ilist.IList, squin_channel.params) - param_stmts = [py.Constant(p) for p in params] - for param_stmt in param_stmts: - param_stmt.insert_before(stmt) + """Rewrite squin.noise.QubitLoss to stim.TrivialError.""" - stim_stmt = stim_noise.PauliChannel2( + stim_stmt = stim_noise.QubitLoss( targets=qubit_idx_ssas, - pix=param_stmts[0].result, - piy=param_stmts[1].result, - piz=param_stmts[2].result, - pxi=param_stmts[3].result, - pxx=param_stmts[4].result, - pxy=param_stmts[5].result, - pxz=param_stmts[6].result, - pyi=param_stmts[7].result, - pyx=param_stmts[8].result, - pyy=param_stmts[9].result, - pyz=param_stmts[10].result, - pzi=param_stmts[11].result, - pzx=param_stmts[12].result, - pzy=param_stmts[13].result, - pzz=param_stmts[14].result, + probs=(stmt.p,), ) + return stim_stmt - def rewrite_Depolarize2( + def rewrite_Depolarize( self, - stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply, + stmt: squin_noise.stmts.Depolarize, qubit_idx_ssas: Tuple[SSAValue], ) -> Statement: - """Rewrite squin.noise.Depolarize2 to stim.Depolarize2.""" - - squin_channel = stmt.operator.owner - assert isinstance(squin_channel, squin_noise.stmts.Depolarize2) + """Rewrite squin.noise.Depolarize to stim.Depolarize1.""" - p = get_const_value(float, squin_channel.p) - p_stmt = py.Constant(p) - p_stmt.insert_before(stmt) + stim_stmt = stim_noise.Depolarize1( + targets=qubit_idx_ssas, + p=stmt.p, + ) - stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=p_stmt.result) return stim_stmt - def rewrite_Depolarize( + def rewrite_TwoQubitPauliChannel( self, - stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply, + stmt: squin_noise.stmts.TwoQubitPauliChannel, qubit_idx_ssas: Tuple[SSAValue], ) -> Statement: - """Rewrite squin.noise.Depolarize to stim.Depolarize1.""" + """Rewrite squin.noise.TwoQubitPauliChannel to stim.PauliChannel2.""" - squin_channel = stmt.operator.owner - assert isinstance(squin_channel, squin_noise.stmts.Depolarize) + params = stmt.probabilities + prob_ssas = [] + for idx in range(15): + idx_stmt = py.Constant(value=idx) + idx_stmt.insert_before(stmt) + getitem_stmt = py.GetItem(obj=params, index=idx_stmt.result) + getitem_stmt.insert_before(stmt) + prob_ssas.append(getitem_stmt.result) - p = get_const_value(float, squin_channel.p) - p_stmt = py.Constant(p) - p_stmt.insert_before(stmt) + stim_stmt = stim_noise.PauliChannel2( + targets=qubit_idx_ssas, + pix=prob_ssas[0], + piy=prob_ssas[1], + piz=prob_ssas[2], + pxi=prob_ssas[3], + pxx=prob_ssas[4], + pxy=prob_ssas[5], + pxz=prob_ssas[6], + pyi=prob_ssas[7], + pyx=prob_ssas[8], + pyy=prob_ssas[9], + pyz=prob_ssas[10], + pzi=prob_ssas[11], + pzx=prob_ssas[12], + pzy=prob_ssas[13], + pzz=prob_ssas[14], + ) + return stim_stmt + + def rewrite_Depolarize2( + self, + stmt: squin_noise.stmts.Depolarize2, + qubit_idx_ssas: Tuple[SSAValue], + ) -> Statement: + """Rewrite squin.noise.Depolarize2 to stim.Depolarize2.""" - stim_stmt = stim_noise.Depolarize1(targets=qubit_idx_ssas, p=p_stmt.result) + stim_stmt = stim_noise.Depolarize2(targets=qubit_idx_ssas, p=stmt.p) return stim_stmt diff --git a/src/bloqade/stim/rewrite/util.py b/src/bloqade/stim/rewrite/util.py index 1efab606..82829f4b 100644 --- a/src/bloqade/stim/rewrite/util.py +++ b/src/bloqade/stim/rewrite/util.py @@ -1,36 +1,9 @@ -from typing import TypeVar - -from kirin import ir, interp -from kirin.analysis import const +from kirin import ir from kirin.dialects import py -from kirin.rewrite.abc import RewriteResult -from bloqade.squin import op, wire, noise as squin_noise, qubit from bloqade.squin.rewrite import AddressAttribute -from bloqade.stim.dialects import gate, noise as stim_noise, collapse from bloqade.analysis.address import AddressReg, AddressWire, AddressQubit, AddressTuple -SQUIN_STIM_OP_MAPPING = { - op.stmts.X: gate.X, - op.stmts.Y: gate.Y, - op.stmts.Z: gate.Z, - op.stmts.H: gate.H, - op.stmts.S: gate.S, - op.stmts.SqrtX: gate.SqrtX, - op.stmts.SqrtY: gate.SqrtY, - op.stmts.Identity: gate.Identity, - op.stmts.Reset: collapse.RZ, - squin_noise.stmts.QubitLoss: stim_noise.QubitLoss, -} - -# Squin allows creation of control gates where the gate can be any operator, -# but Stim only supports CX, CY, and CZ as control gates. -SQUIN_STIM_CONTROL_GATE_MAPPING = { - op.stmts.X: gate.CX, - op.stmts.Y: gate.CY, - op.stmts.Z: gate.CZ, -} - def create_and_insert_qubit_idx_stmt( qubit_idx, stmt_to_insert_before: ir.Statement, qubit_idx_ssas: list @@ -74,161 +47,3 @@ def insert_qubit_idx_from_address( return return tuple(qubit_idx_ssas) - - -def insert_qubit_idx_from_wire_ssa( - wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement -) -> tuple[ir.SSAValue, ...] | None: - """ - Extract qubit indices from wire SSA values and insert them into the SSA form. - """ - qubit_idx_ssas = [] - for wire_ssa in wire_ssas: - address_attribute = wire_ssa.hints.get("address") - if address_attribute is None: - return - assert isinstance(address_attribute, AddressAttribute) - wire_address = address_attribute.address - assert isinstance(wire_address, AddressWire) - qubit_idx = wire_address.origin_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - qubit_idx_ssas.append(qubit_idx_stmt.result) - qubit_idx_stmt.insert_before(stmt_to_insert_before) - - return tuple(qubit_idx_ssas) - - -def insert_qubit_idx_after_apply( - stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast, -) -> tuple[ir.SSAValue, ...] | None: - """ - Extract qubit indices from Apply or Broadcast statements. - """ - if isinstance(stmt, (qubit.Apply, qubit.Broadcast)): - qubits = stmt.qubits - if len(qubits) == 1: - address_attribute = qubits[0].hints.get("address") - if address_attribute is None: - return - else: - address_attribute_data = [] - for qbit in qubits: - address_attribute = qbit.hints.get("address") - if not isinstance(address_attribute, AddressAttribute): - return - address_attribute_data.append(address_attribute.address) - address_attribute = AddressAttribute( - AddressTuple(data=tuple(address_attribute_data)) - ) - - assert isinstance(address_attribute, AddressAttribute) - return insert_qubit_idx_from_address( - address=address_attribute, stmt_to_insert_before=stmt - ) - elif isinstance(stmt, (wire.Apply, wire.Broadcast)): - wire_ssas = stmt.inputs - return insert_qubit_idx_from_wire_ssa( - wire_ssas=wire_ssas, stmt_to_insert_before=stmt - ) - - -def rewrite_Control( - stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast, -) -> RewriteResult: - """ - Handle control gates for Apply and Broadcast statements. - """ - ctrl_op = stmt_with_ctrl.operator.owner - assert isinstance(ctrl_op, op.stmts.Control) - - ctrl_op_target_gate = ctrl_op.op.owner - assert isinstance(ctrl_op_target_gate, op.stmts.Operator) - - qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl) - if qubit_idx_ssas is None: - return RewriteResult() - - # Separate control and target qubits - target_qubits = [] - ctrl_qubits = [] - for i in range(len(qubit_idx_ssas)): - if (i % 2) == 0: - ctrl_qubits.append(qubit_idx_ssas[i]) - else: - target_qubits.append(qubit_idx_ssas[i]) - - target_qubits = tuple(target_qubits) - ctrl_qubits = tuple(ctrl_qubits) - - stim_gate = SQUIN_STIM_CONTROL_GATE_MAPPING.get(type(ctrl_op_target_gate)) - if stim_gate is None: - return RewriteResult() - - stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits) - - if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)): - create_wire_passthrough(stmt_with_ctrl) - - stmt_with_ctrl.replace_by(stim_stmt) - - return RewriteResult(has_done_something=True) - - -def rewrite_QubitLoss( - stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply, -) -> RewriteResult: - """ - Rewrite QubitLoss statements to Stim's TrivialError. - """ - - squin_loss_op = stmt.operator.owner - assert isinstance(squin_loss_op, squin_noise.stmts.QubitLoss) - - qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt) - if qubit_idx_ssas is None: - return RewriteResult() - - stim_loss_stmt = stim_noise.QubitLoss( - targets=qubit_idx_ssas, - probs=(squin_loss_op.p,), - ) - - if isinstance(stmt, (wire.Apply, wire.Broadcast)): - create_wire_passthrough(stmt) - - stmt.replace_by(stim_loss_stmt) - - return RewriteResult(has_done_something=True) - - -def create_wire_passthrough(stmt: wire.Apply | wire.Broadcast) -> None: - - for input_wire, output_wire in zip(stmt.inputs, stmt.results): - # have to "reroute" the input of these statements to directly plug in - # to subsequent statements, remove dependency on the current statement - output_wire.replace_by(input_wire) - - -def is_measure_result_used( - stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure, -) -> bool: - """ - Check if the result of a measure statement is used in the program. - """ - return bool(stmt.result.uses) - - -T = TypeVar("T") - - -def get_const_value(typ: type[T], value: ir.SSAValue) -> T: - if isinstance(hint := value.hints.get("const"), const.Value): - data = hint.data - if isinstance(data, typ): - return hint.data - raise interp.InterpreterError( - f"Expected constant value , got {data}" - ) - raise interp.InterpreterError( - f"Expected constant value , got {value}" - ) diff --git a/src/bloqade/stim/rewrite/wire_identity_elimination.py b/src/bloqade/stim/rewrite/wire_identity_elimination.py deleted file mode 100644 index a9dcc837..00000000 --- a/src/bloqade/stim/rewrite/wire_identity_elimination.py +++ /dev/null @@ -1,24 +0,0 @@ -from kirin import ir -from kirin.rewrite.abc import RewriteRule, RewriteResult - -from bloqade.squin import wire - - -class SquinWireIdentityElimination(RewriteRule): - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - """ - Handle the case where an unwrap feeds a wire directly into a wrap, - equivalent to nothing happening/identity operation - - w = unwrap(qubit) - wrap(qubit, w) - """ - if isinstance(node, wire.Wrap): - wire_origin_stmt = node.wire.owner - if isinstance(wire_origin_stmt, wire.Unwrap): - node.delete() # get rid of wrap - wire_origin_stmt.delete() # get rid of the unwrap - return RewriteResult(has_done_something=True) - - return RewriteResult() diff --git a/src/bloqade/stim/rewrite/wire_to_stim.py b/src/bloqade/stim/rewrite/wire_to_stim.py deleted file mode 100644 index 94640ebd..00000000 --- a/src/bloqade/stim/rewrite/wire_to_stim.py +++ /dev/null @@ -1,57 +0,0 @@ -from kirin import ir -from kirin.rewrite.abc import RewriteRule, RewriteResult - -from bloqade.squin import op, wire, noise -from bloqade.stim.rewrite.util import ( - SQUIN_STIM_OP_MAPPING, - rewrite_Control, - rewrite_QubitLoss, - insert_qubit_idx_from_wire_ssa, -) - - -class SquinWireToStim(RewriteRule): - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - match node: - case wire.Apply() | wire.Broadcast(): - return self.rewrite_Apply_and_Broadcast(node) - case _: - return RewriteResult() - - def rewrite_Apply_and_Broadcast( - self, stmt: wire.Apply | wire.Broadcast - ) -> RewriteResult: - - # this is an SSAValue, need it to be the actual operator - applied_op = stmt.operator.owner - - if isinstance(applied_op, noise.stmts.QubitLoss): - return rewrite_QubitLoss(stmt) - - assert isinstance(applied_op, op.stmts.Operator) - - if isinstance(applied_op, op.stmts.Control): - return rewrite_Control(stmt) - - stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op)) - if stim_1q_op is None: - return RewriteResult() - - qubit_idx_ssas = insert_qubit_idx_from_wire_ssa( - wire_ssas=stmt.inputs, stmt_to_insert_before=stmt - ) - if qubit_idx_ssas is None: - return RewriteResult() - - stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) - - # Get the wires from the inputs of Apply or Broadcast, - # then put those as the result of the current stmt - # before replacing it entirely - for input_wire, output_wire in zip(stmt.inputs, stmt.results): - output_wire.replace_by(input_wire) - - stmt.replace_by(stim_1q_stmt) - - return RewriteResult(has_done_something=True) diff --git a/test/squin/rewrite/test_U3_to_clifford.py b/test/squin/rewrite/test_U3_to_clifford.py index 6076c1cf..327c5138 100644 --- a/test/squin/rewrite/test_U3_to_clifford.py +++ b/test/squin/rewrite/test_U3_to_clifford.py @@ -3,27 +3,41 @@ from kirin import ir from kirin.rewrite import Walk, Chain from kirin.passes.abc import Pass +from kirin.passes.fold import Fold from kirin.rewrite.dce import DeadCodeElimination +from kirin.passes.inline import InlinePass -from bloqade.squin import op, qubit, kernel +from bloqade import squin as sq +from bloqade.squin import gate from bloqade.squin.rewrite.U3_to_clifford import SquinU3ToClifford class SquinToCliffordTestPass(Pass): def unsafe_run(self, mt: ir.Method): - return Walk( - Chain(Walk(SquinU3ToClifford()), Walk(DeadCodeElimination())) - ).rewrite(mt.code) + + rewrite_result = InlinePass(dialects=mt.dialects).fixpoint(mt) + rewrite_result = Fold(dialects=mt.dialects)(mt).join(rewrite_result) + + print("after inline and fold") + mt.print() + + return ( + Walk(Chain(Walk(SquinU3ToClifford()), Walk(DeadCodeElimination()))) + .rewrite(mt.code) + .join(rewrite_result) + ) def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement: return method.callable_region.blocks[0].stmts.at(idx) -def filter_statements_by_type(method: ir.Method, types: tuple[type, ...]) -> list[type]: +def filter_statements_by_type( + method: ir.Method, types: tuple[type, ...] +) -> list[ir.Statement]: return [ - type(stmt) + stmt for stmt in method.callable_region.blocks[0].stmts if isinstance(stmt, types) ] @@ -31,581 +45,617 @@ def filter_statements_by_type(method: ir.Method, types: tuple[type, ...]) -> lis def test_identity(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) - oper = op.u(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau) - qubit.apply(oper, q[0]) + q = sq.qubit.new(4) + sq.u3(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Identity) + # Should be no U3 statements left, they are eliminated if they're equivalent to Identity + no_stmt = filter_statements_by_type(test, (gate.stmts.U3,)) + assert len(no_stmt) == 0 def test_s(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) - oper = op.u(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=0.25 * math.tau) - qubit.apply(oper, q[0]) + q = sq.qubit.new(4) + # S gate + sq.u3(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=0.25 * math.tau, qubit=q[0]) + # Equivalent S gate (different parameters) + sq.u3(theta=math.tau, phi=0.5 * math.tau, lam=0.75 * math.tau, qubit=q[1]) + # S gate alternative form + sq.u3(theta=0.0, phi=0.25 * math.tau, lam=0.0, qubit=q[2]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - - # exercise equivalent_u3_para check - ## assumes it's already in units of half pi and normalized to [0, 1) - @kernel - def test_equiv(): - q = qubit.new(4) - oper = op.u(theta=math.tau, phi=0.5 * math.tau, lam=0.75 * math.tau) - qubit.apply(oper, q[0]) - - SquinToCliffordTestPass(test_equiv.dialects)(test_equiv) - - assert isinstance(get_stmt_at_idx(test_equiv, 4), op.stmts.S) - - -def test_s_alternative(): - - @kernel - def test(): - q = qubit.new(4) - oper = op.u(theta=0.0, phi=0.25 * math.tau, lam=0.0) - qubit.apply(oper, q[0]) - - SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 13), gate.stmts.S) + S_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + # Should be normal S gates, not adjoint/dagger + assert not S_stmts[0].adjoint + assert not S_stmts[1].adjoint + assert not S_stmts[2].adjoint def test_z(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # nice positive representation - op0 = op.u(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau) + sq.u3(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau, qubit=q[0]) # wrap around - op1 = op.u(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=1.5 * math.tau) + sq.u3(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=1.5 * math.tau, qubit=q[1]) # go backwards - op2 = op.u(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=-0.5 * math.tau) - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) - qubit.apply(op2, q[2]) + sq.u3(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=-0.5 * math.tau, qubit=q[2]) + # alternative form + sq.u3(theta=0.0, phi=0.5 * math.tau, lam=0.0, qubit=q[3]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 8), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Z) - - -def test_z_alternative(): - - @kernel - def test(): - q = qubit.new(4) - oper = op.u(theta=0.0, phi=0.5 * math.tau, lam=0.0) - qubit.apply(oper, q[0]) - - SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 13), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 17), gate.stmts.Z) def test_sdag(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) - oper = op.u(theta=0.0 * math.tau, phi=0.0 * math.tau, lam=-0.25 * math.tau) - qubit.apply(oper, q[0]) + q = sq.qubit.new(4) + sq.u3( + theta=0.0 * math.tau, phi=0.0 * math.tau, lam=-0.25 * math.tau, qubit=q[0] + ) + sq.u3(theta=0.0 * math.tau, phi=0.5 * math.tau, lam=0.25 * math.tau, qubit=q[1]) + sq.u3(theta=0.0, phi=-0.25 * math.tau, lam=0.0, qubit=q[2]) + sq.u3(theta=0.0, phi=0.75 * math.tau, lam=0.0, qubit=q[3]) + sq.u3(theta=2 * math.tau, phi=0.7 * math.tau, lam=0.05 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) - - @kernel - def test_equiv(): - q = qubit.new(4) - oper = op.u(theta=0.0 * math.tau, phi=0.5 * math.tau, lam=0.25 * math.tau) - qubit.apply(oper, q[0]) + test.print() - SquinToCliffordTestPass(test_equiv.dialects)(test_equiv) - assert isinstance(get_stmt_at_idx(test_equiv, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test_equiv, 5), op.stmts.Adjoint) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 13), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 17), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 21), gate.stmts.S) - -def test_sdag_alternative_negative(): - - @kernel - def test(): - q = qubit.new(4) - oper = op.u(theta=0.0, phi=-0.25 * math.tau, lam=0.0) - qubit.apply(oper, q[0]) - - SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) - - -def test_sdag_alternative(): - - @kernel - def test(): - q = qubit.new(4) - oper = op.u(theta=0.0, phi=0.75 * math.tau, lam=0.0) - qubit.apply(oper, q[0]) - - SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) - - -def test_sdag_weird_case(): - - @kernel - def test(): - q = qubit.new(4) - oper = op.u(theta=2 * math.tau, phi=0.7 * math.tau, lam=0.05 * math.tau) - qubit.apply(oper, q[0]) - - SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + sdag_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + for sdag_stmt in sdag_stmts: + assert sdag_stmt.adjoint +# Checks that Sdag is the first gate that gets generated, +# There is a Y that gets appended afterwards but is not checked def test_sdag_weirder_case(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) - oper = op.u(theta=0.5 * math.tau, phi=0.05 * math.tau, lam=0.8 * math.tau) - qubit.apply(oper, q[0]) + q = sq.qubit.new(4) + sq.u3(theta=0.5 * math.tau, phi=0.05 * math.tau, lam=0.8 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S) + [S_stmt] = filter_statements_by_type(test, (gate.stmts.S,)) + assert S_stmt.adjoint def test_sqrt_y(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) - op0 = op.u(theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau) + q = sq.qubit.new(4) # equivalent to sqrt(y) gate - op1 = op.u(theta=1.25 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau) - - qubit.apply(op0, q[0]) - qubit.apply(op1, q[0]) + sq.u3(theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau, qubit=q[0]) + sq.u3(theta=1.25 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 8), op.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtY) + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + assert not sqrt_y_stmts[0].adjoint + assert not sqrt_y_stmts[1].adjoint def test_s_sqrt_y(): - @kernel + @sq.kernel def test(): + q = sq.qubit.new(4) + sq.u3( + theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.25 * math.tau, qubit=q[0] + ) + sq.u3( + theta=1.25 * math.tau, phi=1.0 * math.tau, lam=1.25 * math.tau, qubit=q[1] + ) - q = qubit.new(4) - op0 = op.u(theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.25 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=1.0 * math.tau, lam=1.25 * math.tau) + SquinToCliffordTestPass(test.dialects)(test) - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.SqrtY) - SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 10), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.SqrtY) + s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + + for s_stmt in s_stmts: + assert not s_stmt.adjoint + + for sqrt_y_stmt in sqrt_y_stmts: + assert not sqrt_y_stmt.adjoint def test_h(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 0, 1) - op0 = op.u(theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=0.0 * math.tau, lam=1.5 * math.tau) - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3(theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau, qubit=q[0]) + sq.u3(theta=1.25 * math.tau, phi=0.0 * math.tau, lam=1.5 * math.tau, qubit=q[1]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.H) - assert isinstance(get_stmt_at_idx(test, 8), op.stmts.H) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.H) + assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.H) def test_sdg_sqrt_y(): - @kernel() + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 0, 3) - op0 = op.u(theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.75 * math.tau) - op1 = op.u(theta=-1.75 * math.tau, phi=0.0 * math.tau, lam=-1.25 * math.tau) - - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3( + theta=0.25 * math.tau, phi=0.0 * math.tau, lam=0.75 * math.tau, qubit=q[0] + ) + sq.u3( + theta=-1.75 * math.tau, phi=0.0 * math.tau, lam=-1.25 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Adjoint) - assert isinstance(get_stmt_at_idx(test, 7), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 11), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Adjoint) - assert isinstance(get_stmt_at_idx(test, 14), op.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.SqrtY) + + s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + + for s_stmt in s_stmts: + assert s_stmt.adjoint + + for sqrt_y_stmt in sqrt_y_stmts: + assert not sqrt_y_stmt.adjoint def test_sqrt_y_s(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 1, 0) - op0 = op.u(theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.0 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=-1.75 * math.tau, lam=0.0 * math.tau) - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3( + theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.0 * math.tau, qubit=q[0] + ) + sq.u3( + theta=1.25 * math.tau, phi=-1.75 * math.tau, lam=0.0 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) + test.print() + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.S) + + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + + for sqrt_y_stmt in sqrt_y_stmts: + assert not sqrt_y_stmt.adjoint + + for s_stmt in s_stmts: + assert not s_stmt.adjoint def test_s_sqrt_y_s(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 1, 1) - op0 = op.u(theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.25 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=1.25 * math.tau, lam=1.25 * math.tau) - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3( + theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.25 * math.tau, qubit=q[0] + ) + sq.u3( + theta=1.25 * math.tau, phi=1.25 * math.tau, lam=1.25 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (op.stmts.S, op.stmts.SqrtY)) + s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.S, - op.stmts.SqrtY, - op.stmts.S, + # Should be S, SqrtY, S for each op + assert [ + type(stmt) + for stmt in filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) + ] == [ + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.S, + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.S, ] + # Check adjoint property + for s_stmt in s_stmts: + assert not s_stmt.adjoint + for sqrt_y_stmt in sqrt_y_stmts: + assert not sqrt_y_stmt.adjoint + def test_z_sqrt_y_s(): - @kernel + @sq.kernel def test(): - q = qubit.new(1) + q = sq.qubit.new(1) # (1, 1, 2) - op0 = op.u(theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.5 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=1.25 * math.tau, lam=1.5 * math.tau) - qubit.apply(op0, q[0]) - qubit.apply(op1, q[0]) + sq.u3( + theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.5 * math.tau, qubit=q[0] + ) + sq.u3( + theta=1.25 * math.tau, phi=1.25 * math.tau, lam=1.5 * math.tau, qubit=q[0] + ) SquinToCliffordTestPass(test.dialects)(test) + test.print() relevant_stmts = filter_statements_by_type( - test, (op.stmts.Z, op.stmts.SqrtY, op.stmts.S) + test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S) ) - assert relevant_stmts == [ - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.S, + + expected_types = [ + gate.stmts.Z, + gate.stmts.SqrtY, + gate.stmts.S, + gate.stmts.Z, + gate.stmts.SqrtY, + gate.stmts.S, ] + assert [type(stmt) for stmt in relevant_stmts] == expected_types + + for relevant_stmt in relevant_stmts: + if type(relevant_stmt) is not gate.stmts.Z: + assert not relevant_stmt.adjoint def test_sdg_sqrt_y_s(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 1, 3) - op0 = op.u(theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.75 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=1.25 * math.tau, lam=1.75 * math.tau) - - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3( + theta=0.25 * math.tau, phi=0.25 * math.tau, lam=0.75 * math.tau, qubit=q[0] + ) + sq.u3( + theta=1.25 * math.tau, phi=1.25 * math.tau, lam=1.75 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY) - ) + relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.S, + # Should be Sdg, SqrtY, S for each op + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.S, + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.S, ] + # Check adjoint property: the first S in each group should be adjoint + s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + + assert s_stmts[0].adjoint + assert s_stmts[2].adjoint + for sqrt_y_stmt in sqrt_y_stmts: + assert not sqrt_y_stmt.adjoint + def test_sqrt_y_z(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 2, 0) - op0 = op.u(theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.0 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=-1.5 * math.tau, lam=0.0 * math.tau) - - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3(theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.0 * math.tau, qubit=q[0]) + sq.u3( + theta=1.25 * math.tau, phi=-1.5 * math.tau, lam=0.0 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) + test.print() - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 10), op.stmts.SqrtY) - assert isinstance(get_stmt_at_idx(test, 12), op.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.SqrtY) + assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.Z) + + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + for sqrt_y_stmt in sqrt_y_stmts: + assert not sqrt_y_stmt.adjoint def test_s_sqrt_y_z(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 2, 1) - op0 = op.u(theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.25 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=1.5 * math.tau, lam=-1.75 * math.tau) - - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3( + theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.25 * math.tau, qubit=q[0] + ) + sq.u3( + theta=1.25 * math.tau, phi=1.5 * math.tau, lam=-1.75 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.SqrtY, op.stmts.Z) + test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z) ) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.SqrtY, - op.stmts.Z, - op.stmts.S, - op.stmts.SqrtY, - op.stmts.Z, + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.Z, + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.Z, ] + for stmt in relevant_stmts: + if type(stmt) is not gate.stmts.Z: + assert not stmt.adjoint + def test_z_sqrt_y_z(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 2, 2) - op0 = op.u(theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.5 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=-0.5 * math.tau, lam=-1.5 * math.tau) - - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3(theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.5 * math.tau, qubit=q[0]) + sq.u3( + theta=1.25 * math.tau, phi=-0.5 * math.tau, lam=-1.5 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type(test, (op.stmts.Z, op.stmts.SqrtY)) + relevant_stmts = filter_statements_by_type(test, (gate.stmts.Z, gate.stmts.SqrtY)) - assert relevant_stmts == [ - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.Z, - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.Z, + expected_types = [ + gate.stmts.Z, + gate.stmts.SqrtY, + gate.stmts.Z, + gate.stmts.Z, + gate.stmts.SqrtY, + gate.stmts.Z, ] + assert [type(stmt) for stmt in relevant_stmts] == expected_types + + for stmt in relevant_stmts: + if type(stmt) is gate.stmts.SqrtY: + assert not stmt.adjoint def test_sdg_sqrt_y_z(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 2, 3) - op0 = op.u(theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.75 * math.tau) - op1 = op.u(theta=1.25 * math.tau, phi=1.5 * math.tau, lam=-1.25 * math.tau) - - qubit.apply(op0, q[0]) - qubit.apply(op1, q[1]) + sq.u3( + theta=0.25 * math.tau, phi=0.5 * math.tau, lam=0.75 * math.tau, qubit=q[0] + ) + sq.u3( + theta=1.25 * math.tau, phi=1.5 * math.tau, lam=-1.25 * math.tau, qubit=q[1] + ) SquinToCliffordTestPass(test.dialects)(test) relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY, op.stmts.Z) + test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z) ) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.Z, - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.Z, + + # Should be Sdag, SqrtY, Z for each op + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.Z, + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.Z, ] + # Check adjoint property: Sdag should be adjoint, SqrtY and Z should not + s_stmts = filter_statements_by_type(test, (gate.stmts.S,)) + sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,)) + + for s_stmt in s_stmts: + assert s_stmt.adjoint + + for sqrt_y_stmt in sqrt_y_stmts: + assert not sqrt_y_stmt.adjoint + def test_sqrt_y_sdg(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 3, 0) - op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.0 * math.tau) - qubit.apply(op0, q[0]) + sq.u3( + theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.0 * math.tau, qubit=q[0] + ) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type( - test, (op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) - ) - assert relevant_stmts == [ - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, + relevant_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY, gate.stmts.S)) + # Check for SqrtY followed by S (adjoint property can be checked if needed) + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.SqrtY, + gate.stmts.S, ] + assert not relevant_stmts[0].adjoint + assert relevant_stmts[1].adjoint + def test_s_sqrt_y_sdg(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 3, 1) - op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.25 * math.tau) - qubit.apply(op0, q[0]) + sq.u3( + theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.25 * math.tau, qubit=q[0] + ) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type( - test, (op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) - ) + relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.S, ] + # The last S should be adjoint + assert not relevant_stmts[0].adjoint + assert not relevant_stmts[1].adjoint + assert relevant_stmts[2].adjoint def test_z_sqrt_y_sdg(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 3, 2) - op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.5 * math.tau) - qubit.apply(op0, q[0]) + sq.u3( + theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.5 * math.tau, qubit=q[0] + ) SquinToCliffordTestPass(test.dialects)(test) relevant_stmts = filter_statements_by_type( - test, (op.stmts.Z, op.stmts.SqrtY, op.stmts.S, op.stmts.Adjoint) + test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S) ) - assert relevant_stmts == [ - op.stmts.Z, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, + # Should be Z, SqrtY, S (adjoint) + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.Z, + gate.stmts.SqrtY, + gate.stmts.S, ] + assert not relevant_stmts[1].adjoint + assert relevant_stmts[2].adjoint def test_sdg_sqrt_y_sdg(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (1, 3, 3) - op0 = op.u(theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.75 * math.tau) - qubit.apply(op0, q[0]) + sq.u3( + theta=0.25 * math.tau, phi=0.75 * math.tau, lam=0.75 * math.tau, qubit=q[0] + ) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.SqrtY) - ) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.SqrtY, - op.stmts.S, - op.stmts.Adjoint, + relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY)) + + # Should be Sdag, SqrtY, Sdag for the op + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.S, + gate.stmts.SqrtY, + gate.stmts.S, ] + # The first and last S should be adjoint, SqrtY should not + assert relevant_stmts[0].adjoint + assert not relevant_stmts[1].adjoint + assert relevant_stmts[2].adjoint def test_y(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (2, 0, 0) - op0 = op.u(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau) - qubit.apply(op0, q[0]) + sq.u3(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.0 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Y) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Y) def test_s_y(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (2, 0, 1) - op0 = op.u(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.25 * math.tau) - qubit.apply(op0, q[0]) + sq.u3(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.25 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.S) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Y) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Y) + + [s_stmt] = filter_statements_by_type(test, (gate.stmts.S,)) + + assert not s_stmt.adjoint def test_z_y(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (2, 0, 2) - op0 = op.u(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau) - qubit.apply(op0, q[0]) + sq.u3(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Z) - assert isinstance(get_stmt_at_idx(test, 6), op.stmts.Y) + assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Z) + assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Y) def test_sdg_y(): - @kernel + @sq.kernel def test(): - q = qubit.new(4) + q = sq.qubit.new(4) # (2, 0, 3) - op0 = op.u(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.75 * math.tau) - qubit.apply(op0, q[0]) + sq.u3(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.75 * math.tau, qubit=q[0]) SquinToCliffordTestPass(test.dialects)(test) - relevant_stmts = filter_statements_by_type( - test, (op.stmts.S, op.stmts.Adjoint, op.stmts.Y) - ) - assert relevant_stmts == [ - op.stmts.S, - op.stmts.Adjoint, - op.stmts.Y, + relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.Y)) + # Should be Sdag, Y for the op + assert [type(stmt) for stmt in relevant_stmts] == [ + gate.stmts.S, + gate.stmts.Y, ] + # The S should be adjoint + assert relevant_stmts[0].adjoint diff --git a/test/stim/passes/squin_noise_to_stim.py b/test/stim/passes/squin_noise_to_stim.py deleted file mode 100644 index 057b10d6..00000000 --- a/test/stim/passes/squin_noise_to_stim.py +++ /dev/null @@ -1,408 +0,0 @@ -import os - -import kirin.types as kirin_types -from kirin import ir, types -from kirin.decl import statement -from kirin.rewrite import Walk -from kirin.dialects import py, func, ilist - -import bloqade.types as bloqade_types -from bloqade.squin import op, wire, noise, qubit, kernel -from bloqade.stim.emit import EmitStimMain -from bloqade.stim.passes import SquinToStimPass -from bloqade.stim.rewrite import SquinNoiseToStim -from bloqade.squin.rewrite import WrapAddressAnalysis -from bloqade.analysis.address import AddressAnalysis - -extended_kernel = kernel.add(wire) - - -def gen_func_from_stmts(stmts, output_type=types.NoneType): - - block = ir.Block(stmts) - block.args.append_from(types.MethodType[[], types.NoneType], "main") - func_wrapper = func.Function( - sym_name="main", - signature=func.Signature(inputs=(), output=output_type), - body=ir.Region(blocks=block), - ) - - constructed_method = ir.Method( - mod=None, - py_func=None, - sym_name="main", - dialects=extended_kernel, - code=func_wrapper, - arg_names=[], - ) - - return constructed_method - - -def as_int(value: int): - return py.constant.Constant(value=value) - - -def as_float(value: float): - return py.constant.Constant(value=value) - - -def codegen(mt: ir.Method): - # method should not have any arguments! - emit = EmitStimMain() - emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output().strip() - - -def load_reference_program(filename): - """Load stim file.""" - path = os.path.join( - os.path.dirname(__file__), "stim_reference_programs", "noise", filename - ) - with open(path, "r") as f: - return f.read().strip() - - -def test_apply_pauli_channel_1(): - - @kernel - def test(): - q = qubit.new(1) - channel = noise.single_qubit_pauli_channel(params=[0.01, 0.02, 0.03]) - qubit.apply(channel, q[0]) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program("apply_pauli_channel_1.stim") - assert codegen(test) == expected_stim_program - - -def test_broadcast_pauli_channel_1(): - - @kernel - def test(): - q = qubit.new(1) - channel = noise.single_qubit_pauli_channel(params=[0.01, 0.02, 0.03]) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program("broadcast_pauli_channel_1.stim") - assert codegen(test) == expected_stim_program - - -def test_broadcast_pauli_channel_1_many_qubits(): - - @kernel - def test(): - q = qubit.new(2) - channel = noise.single_qubit_pauli_channel(params=[0.01, 0.02, 0.03]) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program( - "broadcast_pauli_channel_1_many_qubits.stim" - ) - assert codegen(test) == expected_stim_program - - -def test_broadcast_pauli_channel_1_reuse(): - - @kernel - def test(): - q = qubit.new(1) - channel = noise.single_qubit_pauli_channel(params=[0.01, 0.02, 0.03]) - qubit.broadcast(channel, q) - qubit.broadcast(channel, q) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program( - "broadcast_pauli_channel_1_reuse.stim" - ) - assert codegen(test) == expected_stim_program - - -def test_broadcast_pauli_channel_2(): - - @kernel - def test(): - q = qubit.new(2) - channel = noise.two_qubit_pauli_channel( - params=[ - 0.001, - 0.002, - 0.003, - 0.004, - 0.005, - 0.006, - 0.007, - 0.008, - 0.009, - 0.010, - 0.011, - 0.012, - 0.013, - 0.014, - 0.015, - ] - ) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program("broadcast_pauli_channel_2.stim") - assert codegen(test) == expected_stim_program - - -def test_broadcast_pauli_channel_2_reuse_on_4_qubits(): - - @kernel - def test(): - q = qubit.new(4) - channel = noise.two_qubit_pauli_channel( - params=[ - 0.001, - 0.002, - 0.003, - 0.004, - 0.005, - 0.006, - 0.007, - 0.008, - 0.009, - 0.010, - 0.011, - 0.012, - 0.013, - 0.014, - 0.015, - ] - ) - qubit.broadcast(channel, [q[0], q[1]]) - qubit.broadcast(channel, [q[2], q[3]]) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program( - "broadcast_pauli_channel_2_reuse_on_4_qubits.stim" - ) - assert codegen(test) == expected_stim_program - - -def test_broadcast_depolarize2(): - - @kernel - def test(): - q = qubit.new(2) - channel = noise.depolarize2(p=0.015) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program("broadcast_depolarize2.stim") - assert codegen(test) == expected_stim_program - - -def test_apply_depolarize1(): - - @kernel - def test(): - q = qubit.new(1) - channel = noise.depolarize(p=0.01) - qubit.apply(channel, q[0]) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program("apply_depolarize1.stim") - assert codegen(test) == expected_stim_program - - -def test_broadcast_depolarize1(): - - @kernel - def test(): - q = qubit.new(4) - channel = noise.depolarize(p=0.01) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program("broadcast_depolarize1.stim") - assert codegen(test) == expected_stim_program - - -def test_broadcast_iid_bit_flip_channel(): - - @kernel - def test(): - q = qubit.new(4) - x = op.x() - channel = noise.pauli_error(x, 0.01) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program( - "broadcast_iid_bit_flip_channel.stim" - ) - assert codegen(test) == expected_stim_program - - -def test_broadcast_iid_phase_flip_channel(): - - @kernel - def test(): - q = qubit.new(4) - z = op.z() - channel = noise.pauli_error(z, 0.01) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program( - "broadcast_iid_phase_flip_channel.stim" - ) - assert codegen(test) == expected_stim_program - - -def test_broadcast_iid_y_flip_channel(): - - @kernel - def test(): - q = qubit.new(4) - y = op.y() - channel = noise.pauli_error(y, 0.01) - qubit.broadcast(channel, q) - return - - SquinToStimPass(test.dialects)(test) - expected_stim_program = load_reference_program("broadcast_iid_y_flip_channel.stim") - assert codegen(test) == expected_stim_program - - -def test_apply_loss(): - - @kernel - def test(): - q = qubit.new(3) - loss = noise.qubit_loss(0.1) - qubit.apply(loss, q[0]) - qubit.apply(loss, q[1]) - qubit.apply(loss, q[2]) - - SquinToStimPass(test.dialects)(test) - - expected_stim_program = load_reference_program("apply_loss.stim") - assert codegen(test) == expected_stim_program - - -def test_wire_apply_pauli_channel_1(): - - stmts: list[ir.Statement] = [ - (n_qubits := as_int(1)), - (q := qubit.New(n_qubits=n_qubits.result)), - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(obj=q.result, index=idx0.result)), - (w0 := wire.Unwrap(qubit=q0.result)), - # apply noise other than qubit loss - (prob_x := as_float(0.01)), - (prob_y := as_float(0.01)), - (prob_z := as_float(0.01)), - ( - noise_params := ilist.New( - values=(prob_x.result, prob_y.result, prob_z.result) - ) - ), - ( - pauli_channel_1q := noise.stmts.SingleQubitPauliChannel( - params=noise_params.result - ) - ), - (app0 := wire.Apply(pauli_channel_1q.result, w0.result)), - (wire.Wrap(app0.results[0], q0.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - - SquinToStimPass(test_method.dialects)(test_method) - - expected_stim_program = load_reference_program("wire_apply_pauli_channel_1.stim") - assert codegen(test_method) == expected_stim_program - - -def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement: - return method.callable_region.blocks[0].stmts.at(idx) - - -# If there's no concrete qubit values from the address analysis then -# the rewrite rule should immediately return and not mutate the method. -def test_no_qubit_address_available(): - - @kernel - def test(q: ilist.IList[bloqade_types.Qubit, kirin_types.Literal]): - channel = noise.single_qubit_pauli_channel(params=[0.01, 0.02, 0.03]) - qubit.apply(channel, q[0]) - return - - Walk(SquinNoiseToStim()).rewrite(test.code) - - expected_noise_channel_stmt = get_stmt_at_idx(test, 1) - expected_qubit_apply_stmt = get_stmt_at_idx(test, 4) - - assert isinstance(expected_noise_channel_stmt, noise.stmts.SingleQubitPauliChannel) - assert isinstance(expected_qubit_apply_stmt, qubit.Apply) - - -def test_nonexistent_noise_channel(): - - @statement(dialect=noise.dialect) - class NonExistentNoiseChannel(noise.stmts.NoiseChannel): - """ - A non-existent noise channel for testing purposes. - """ - - pass - - @kernel - def test(): - q = qubit.new(1) - channel = NonExistentNoiseChannel() - qubit.apply(channel, q[0]) - return - - frame, _ = AddressAnalysis(test.dialects).run_analysis(test) - WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) - - rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) - - expected_noise_channel_stmt = get_stmt_at_idx(test, 2) - expected_qubit_apply_stmt = get_stmt_at_idx(test, 5) - - # The rewrite shouldn't have occurred at all because there is no rewrite logic for - # NonExistentNoiseChannel. - assert not rewrite_result.has_done_something - assert isinstance(expected_noise_channel_stmt, NonExistentNoiseChannel) - assert isinstance(expected_qubit_apply_stmt, qubit.Apply) - - -def test_standard_op_no_rewrite(): - - @kernel - def test(): - q = qubit.new(1) - qubit.apply(op.x(), q[0]) - return - - frame, _ = AddressAnalysis(test.dialects).run_analysis(test) - WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) - - rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) - - # Rewrite should not have done anything because target is not a noise channel - assert not rewrite_result.has_done_something diff --git a/test/stim/passes/squin_wire_to_stim.py b/test/stim/passes/squin_wire_to_stim.py deleted file mode 100644 index 89d32ea8..00000000 --- a/test/stim/passes/squin_wire_to_stim.py +++ /dev/null @@ -1,394 +0,0 @@ -import os - -import pytest -from kirin import ir, types -from kirin.passes import TypeInfer -from kirin.rewrite import Walk -from kirin.dialects import py, func - -from bloqade import squin -from bloqade.squin import wire, kernel -from bloqade.stim.emit import EmitStimMain -from bloqade.stim.passes import SquinToStimPass -from bloqade.squin.rewrite import WrapAddressAnalysis -from bloqade.analysis.address import AddressAnalysis - - -def gen_func_from_stmts(stmts, output_type=types.NoneType): - - extended_dialect = kernel.add(wire) - - block = ir.Block(stmts) - block.args.append_from(types.MethodType[[], types.NoneType], "main") - func_wrapper = func.Function( - sym_name="main", - signature=func.Signature(inputs=(), output=output_type), - body=ir.Region(blocks=block), - ) - - constructed_method = ir.Method( - mod=None, - py_func=None, - sym_name="main", - dialects=extended_dialect, - code=func_wrapper, - arg_names=[], - ) - - return constructed_method - - -def codegen(mt: ir.Method): - # method should not have any arguments! - emit = EmitStimMain() - emit.initialize() - emit.run(mt=mt, args=()) - return emit.get_output() - - -def as_int(value: int): - return py.constant.Constant(value=value) - - -def as_float(value: float): - return py.constant.Constant(value=value) - - -def get_stim_reference_file(filename: str) -> str: - path = os.path.join( - os.path.dirname(__file__), - "stim_reference_programs", - "wire", - filename, - ) - with open(path, "r") as f: - return f.read() - - -def run_passes(test_method): - TypeInfer(test_method.dialects)(test_method) - addr_frame, _ = AddressAnalysis(test_method.dialects).run_analysis(test_method) - Walk(WrapAddressAnalysis(address_analysis=addr_frame.entries)).rewrite( - test_method.code - ) - SquinToStimPass(test_method.dialects)(test_method) - - -@pytest.mark.xfail -def test_wire(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(4)), - # returns an ilist - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubits out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - (idx1 := as_int(1)), - (q1 := py.indexing.GetItem(q.result, idx1.result)), - (idx2 := as_int(2)), - (q2 := py.indexing.GetItem(q.result, idx2.result)), - (idx3 := as_int(3)), - (q3 := py.indexing.GetItem(q.result, idx3.result)), - # get wires from qubits - (w0 := squin.wire.Unwrap(qubit=q0.result)), - (w1 := squin.wire.Unwrap(qubit=q1.result)), - (w2 := squin.wire.Unwrap(qubit=q2.result)), - (w3 := squin.wire.Unwrap(qubit=q3.result)), - # try Apply - (op0 := squin.op.stmts.S()), - (app0 := squin.wire.Apply(op0.result, w0.result)), - # try Broadcast - (op1 := squin.op.stmts.H()), - ( - broad0 := squin.wire.Broadcast( - op1.result, app0.results[0], w1.result, w2.result, w3.result - ) - ), - # wrap everything back - (squin.wire.Wrap(broad0.results[0], q0.result)), - (squin.wire.Wrap(broad0.results[1], q1.result)), - (squin.wire.Wrap(broad0.results[2], q2.result)), - (squin.wire.Wrap(broad0.results[3], q3.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_apply(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(1)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubit out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - # Unwrap to get wires - (w0 := squin.wire.Unwrap(qubit=q0.result)), - # pass the wires through some 1 Qubit operators - (op1 := squin.op.stmts.S()), - (v0 := squin.wire.Apply(op1.result, w0.result)), - ( - squin.wire.Wrap(v0.results[0], q0.result) - ), # for wrap, just free a use for the result SSAval - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_apply.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_multiple_apply(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(1)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubit out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - # Unwrap to get wires - (w0 := squin.wire.Unwrap(qubit=q0.result)), - # pass the wires through some 1 Qubit operators - (op1 := squin.op.stmts.S()), - (op2 := squin.op.stmts.H()), - (op3 := squin.op.stmts.Identity(sites=1)), - (op4 := squin.op.stmts.Identity(sites=1)), - (v0 := squin.wire.Apply(op1.result, w0.result)), - (v1 := squin.wire.Apply(op2.result, v0.results[0])), - (v2 := squin.wire.Apply(op3.result, v1.results[0])), - (v3 := squin.wire.Apply(op4.result, v2.results[0])), - ( - squin.wire.Wrap(v3.results[0], q0.result) - ), # for wrap, just free a use for the result SSAval - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_multiple_apply.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_broadcast(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(4)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubits out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - (idx1 := as_int(1)), - (q1 := py.indexing.GetItem(q.result, idx1.result)), - (idx2 := as_int(2)), - (q2 := py.indexing.GetItem(q.result, idx2.result)), - (idx3 := as_int(3)), - (q3 := py.indexing.GetItem(q.result, idx3.result)), - # Unwrap to get wires - (w0 := squin.wire.Unwrap(qubit=q0.result)), - (w1 := squin.wire.Unwrap(qubit=q1.result)), - (w2 := squin.wire.Unwrap(qubit=q2.result)), - (w3 := squin.wire.Unwrap(qubit=q3.result)), - # Apply with stim semantics - (h_op := squin.op.stmts.H()), - ( - app_res := squin.wire.Broadcast( - h_op.result, w0.result, w1.result, w2.result, w3.result - ) - ), - # Wrap everything back - (squin.wire.Wrap(app_res.results[0], q0.result)), - (squin.wire.Wrap(app_res.results[1], q1.result)), - (squin.wire.Wrap(app_res.results[2], q2.result)), - (squin.wire.Wrap(app_res.results[3], q3.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_broadcast.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_broadcast_control(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(4)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubits out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - (idx1 := as_int(1)), - (q1 := py.indexing.GetItem(q.result, idx1.result)), - (idx2 := as_int(2)), - (q2 := py.indexing.GetItem(q.result, idx2.result)), - (idx3 := as_int(3)), - (q3 := py.indexing.GetItem(q.result, idx3.result)), - # Unwrap to get wires - (w0 := squin.wire.Unwrap(qubit=q0.result)), - (w1 := squin.wire.Unwrap(qubit=q1.result)), - (w2 := squin.wire.Unwrap(qubit=q2.result)), - (w3 := squin.wire.Unwrap(qubit=q3.result)), - # Create and apply CX gate - (x_op := squin.op.stmts.X()), - (ctrl_x_op := squin.op.stmts.Control(x_op.result, n_controls=1)), - ( - app_res := squin.wire.Broadcast( - ctrl_x_op.result, w0.result, w1.result, w2.result, w3.result - ) - ), - # measure it all out - (squin.wire.Measure(wire=app_res.results[0], qubit=q0.result)), - (squin.wire.Measure(wire=app_res.results[1], qubit=q1.result)), - (squin.wire.Measure(wire=app_res.results[2], qubit=q2.result)), - (squin.wire.Measure(wire=app_res.results[3], qubit=q3.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_broadcast_control.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_apply_control(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(2)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubis out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - (idx1 := as_int(1)), - (q1 := py.indexing.GetItem(q.result, idx1.result)), - # Unwrap to get wires - (w0 := squin.wire.Unwrap(qubit=q0.result)), - (w1 := squin.wire.Unwrap(qubit=q1.result)), - # set up control gate - (op1 := squin.op.stmts.X()), - (cx := squin.op.stmts.Control(op1.result, n_controls=1)), - (app := squin.wire.Apply(cx.result, w0.result, w1.result)), - # wrap things back - (squin.wire.Wrap(wire=app.results[0], qubit=q0.result)), - (squin.wire.Wrap(wire=app.results[1], qubit=q1.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_apply_control.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_measure(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(2)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubis out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - # Unwrap to get wires - (w0 := squin.wire.Unwrap(qubit=q0.result)), - # measure the wires out - (squin.wire.Measure(wire=w0.result, qubit=q0.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_measure.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_reset(): - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(1)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubits out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - # get wire - (w0 := squin.wire.Unwrap(q0.result)), - (res_op := squin.op.stmts.Reset()), - (app := squin.wire.Apply(res_op.result, w0.result)), - # wrap it back - (squin.wire.Measure(wire=app.results[0], qubit=q0.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_reset.stim") - assert codegen(test_method) == base_stim_prog.rstrip() - - -@pytest.mark.xfail -def test_wire_qubit_loss(): - - stmts: list[ir.Statement] = [ - (n_qubits := as_int(5)), - (q := squin.qubit.New(n_qubits=n_qubits.result)), - # Get qubits out - (idx0 := as_int(0)), - (q0 := py.indexing.GetItem(q.result, idx0.result)), - (idx1 := as_int(1)), - (q1 := py.indexing.GetItem(q.result, idx1.result)), - (idx2 := as_int(2)), - (q2 := py.indexing.GetItem(q.result, idx2.result)), - (idx3 := as_int(3)), - (q3 := py.indexing.GetItem(q.result, idx3.result)), - (idx4 := as_int(4)), - (q4 := py.indexing.GetItem(q.result, idx4.result)), - # get wires from qubits - (w0 := squin.wire.Unwrap(qubit=q0.result)), - (w1 := squin.wire.Unwrap(qubit=q1.result)), - (w2 := squin.wire.Unwrap(qubit=q2.result)), - (w3 := squin.wire.Unwrap(qubit=q3.result)), - (w4 := squin.wire.Unwrap(qubit=q4.result)), - (p_loss_0 := as_float(0.1)), - # apply and broadcast qubit loss - (ql_loss_0 := squin.noise.stmts.QubitLoss(p=p_loss_0.result)), - ( - app_0 := squin.wire.Broadcast( - ql_loss_0.result, w0.result, w1.result, w2.result, w3.result, w4.result - ) - ), - (p_loss_1 := as_float(0.9)), - (ql_loss_1 := squin.noise.stmts.QubitLoss(p=p_loss_1.result)), - (app_1 := squin.wire.Apply(ql_loss_1.result, app_0.results[0])), - # wrap everything back - (squin.wire.Measure(wire=app_1.results[0], qubit=q0.result)), - (squin.wire.Measure(wire=app_0.results[1], qubit=q1.result)), - (squin.wire.Measure(wire=app_0.results[2], qubit=q2.result)), - (squin.wire.Measure(wire=app_0.results[3], qubit=q3.result)), - (squin.wire.Measure(wire=app_0.results[4], qubit=q4.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - test_method = gen_func_from_stmts(stmts) - run_passes(test_method) - base_stim_prog = get_stim_reference_file("wire_qubit_loss.stim") - assert codegen(test_method) == base_stim_prog.rstrip() diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_depolarize2.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_depolarize2.stim index 598e7e54..0163ed9d 100644 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_depolarize2.stim +++ b/test/stim/passes/stim_reference_programs/noise/broadcast_depolarize2.stim @@ -1 +1 @@ -DEPOLARIZE2(0.01500000) 0 1 +DEPOLARIZE2(0.01500000) 0 2 1 3 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_iid_bit_flip_channel.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_iid_bit_flip_channel.stim deleted file mode 100644 index 009a748b..00000000 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_iid_bit_flip_channel.stim +++ /dev/null @@ -1 +0,0 @@ -X_ERROR(0.01000000) 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_iid_phase_flip_channel.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_iid_phase_flip_channel.stim deleted file mode 100644 index 9b5c27a2..00000000 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_iid_phase_flip_channel.stim +++ /dev/null @@ -1 +0,0 @@ -Z_ERROR(0.01000000) 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_iid_y_flip_channel.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_iid_y_flip_channel.stim deleted file mode 100644 index 8a1e3b05..00000000 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_iid_y_flip_channel.stim +++ /dev/null @@ -1 +0,0 @@ -Y_ERROR(0.01000000) 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1.stim index 88cd3e81..1bde457d 100644 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1.stim +++ b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1.stim @@ -1 +1 @@ -PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 +PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 1 2 3 4 5 6 7 8 9 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1_many_qubits.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1_many_qubits.stim deleted file mode 100644 index a74b3793..00000000 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1_many_qubits.stim +++ /dev/null @@ -1 +0,0 @@ -PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 1 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1_reuse.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1_reuse.stim index d06b1654..bb8fee21 100644 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1_reuse.stim +++ b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_1_reuse.stim @@ -1,3 +1,3 @@ -PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 -PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 -PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 +PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 1 +PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 1 +PAULI_CHANNEL_1(0.01000000, 0.02000000, 0.03000000) 0 1 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2.stim index c5ea983d..58228470 100644 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2.stim +++ b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2.stim @@ -1 +1 @@ -PAULI_CHANNEL_2(0.00100000, 0.00200000, 0.00300000, 0.00400000, 0.00500000, 0.00600000, 0.00700000, 0.00800000, 0.00900000, 0.01000000, 0.01100000, 0.01200000, 0.01300000, 0.01400000, 0.01500000) 0 1 +PAULI_CHANNEL_2(0.00100000, 0.00200000, 0.00300000, 0.00400000, 0.00500000, 0.00600000, 0.00700000, 0.00800000, 0.00900000, 0.01000000, 0.01100000, 0.01200000, 0.01300000, 0.01400000, 0.01500000) 0 4 1 5 2 6 3 7 diff --git a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2_reuse_on_4_qubits.stim b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2_reuse.stim similarity index 94% rename from test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2_reuse_on_4_qubits.stim rename to test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2_reuse.stim index 7f51e4d5..be4a9d97 100644 --- a/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2_reuse_on_4_qubits.stim +++ b/test/stim/passes/stim_reference_programs/noise/broadcast_pauli_channel_2_reuse.stim @@ -1,2 +1,2 @@ -PAULI_CHANNEL_2(0.00100000, 0.00200000, 0.00300000, 0.00400000, 0.00500000, 0.00600000, 0.00700000, 0.00800000, 0.00900000, 0.01000000, 0.01100000, 0.01200000, 0.01300000, 0.01400000, 0.01500000) 0 1 -PAULI_CHANNEL_2(0.00100000, 0.00200000, 0.00300000, 0.00400000, 0.00500000, 0.00600000, 0.00700000, 0.00800000, 0.00900000, 0.01000000, 0.01100000, 0.01200000, 0.01300000, 0.01400000, 0.01500000) 2 3 +PAULI_CHANNEL_2(0.00100000, 0.00200000, 0.00300000, 0.00400000, 0.00500000, 0.00600000, 0.00700000, 0.00800000, 0.00900000, 0.01000000, 0.01100000, 0.01200000, 0.01300000, 0.01400000, 0.01500000) 0 2 1 3 +PAULI_CHANNEL_2(0.00100000, 0.00200000, 0.00300000, 0.00400000, 0.00500000, 0.00600000, 0.00700000, 0.00800000, 0.00900000, 0.01000000, 0.01100000, 0.01200000, 0.01300000, 0.01400000, 0.01500000) 4 6 5 7 diff --git a/test/stim/passes/stim_reference_programs/noise/wire_apply_pauli_channel_1.stim b/test/stim/passes/stim_reference_programs/noise/wire_apply_pauli_channel_1.stim deleted file mode 100644 index 2667459d..00000000 --- a/test/stim/passes/stim_reference_programs/noise/wire_apply_pauli_channel_1.stim +++ /dev/null @@ -1 +0,0 @@ -PAULI_CHANNEL_1(0.01000000, 0.01000000, 0.01000000) 0 diff --git a/test/stim/passes/stim_reference_programs/qubit/qubit.stim b/test/stim/passes/stim_reference_programs/qubit/qubit.stim index 515e7c53..17873714 100644 --- a/test/stim/passes/stim_reference_programs/qubit/qubit.stim +++ b/test/stim/passes/stim_reference_programs/qubit/qubit.stim @@ -1,5 +1,5 @@ H 0 1 X 0 -CX 1 0 +CX 0 1 MZ(0.00000000) 0 1 diff --git a/test/stim/passes/stim_reference_programs/qubit/rep_code.stim b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim new file mode 100644 index 00000000..9105cf43 --- /dev/null +++ b/test/stim/passes/stim_reference_programs/qubit/rep_code.stim @@ -0,0 +1,13 @@ + +RZ 0 1 2 3 4 +CX 0 1 2 3 +CX 2 1 4 3 +MZ(0.00000000) 1 3 +CX 0 1 2 3 +CX 2 1 4 3 +MZ(0.00000000) 1 3 +CX 0 1 2 3 +CX 2 1 4 3 +DEPOLARIZE2(0.01000000) 0 1 2 3 +I_ERROR[loss](0.00100000) 0 1 2 3 4 +MZ(0.00000000) 1 3 diff --git a/test/stim/passes/stim_reference_programs/wire/wire.stim b/test/stim/passes/stim_reference_programs/wire/wire.stim deleted file mode 100644 index ce2cff8d..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire.stim +++ /dev/null @@ -1,3 +0,0 @@ - -S 0 -H 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_apply.stim b/test/stim/passes/stim_reference_programs/wire/wire_apply.stim deleted file mode 100644 index 4524a497..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_apply.stim +++ /dev/null @@ -1,2 +0,0 @@ - -S 0 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_apply_control.stim b/test/stim/passes/stim_reference_programs/wire/wire_apply_control.stim deleted file mode 100644 index fd17213d..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_apply_control.stim +++ /dev/null @@ -1,2 +0,0 @@ - -CX 0 1 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_broadcast.stim b/test/stim/passes/stim_reference_programs/wire/wire_broadcast.stim deleted file mode 100644 index 3d881c78..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_broadcast.stim +++ /dev/null @@ -1,2 +0,0 @@ - -H 0 1 2 3 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_broadcast_control.stim b/test/stim/passes/stim_reference_programs/wire/wire_broadcast_control.stim deleted file mode 100644 index f2005d9d..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_broadcast_control.stim +++ /dev/null @@ -1,6 +0,0 @@ - -CX 0 1 2 3 -MZ(0.00000000) 0 -MZ(0.00000000) 1 -MZ(0.00000000) 2 -MZ(0.00000000) 3 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_measure.stim b/test/stim/passes/stim_reference_programs/wire/wire_measure.stim deleted file mode 100644 index c5db67e3..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_measure.stim +++ /dev/null @@ -1,2 +0,0 @@ - -MZ(0.00000000) 0 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_multiple_apply.stim b/test/stim/passes/stim_reference_programs/wire/wire_multiple_apply.stim deleted file mode 100644 index 8f5eabf0..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_multiple_apply.stim +++ /dev/null @@ -1,5 +0,0 @@ - -S 0 -H 0 -I 0 -I 0 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_qubit_loss.stim b/test/stim/passes/stim_reference_programs/wire/wire_qubit_loss.stim deleted file mode 100644 index f84921f6..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_qubit_loss.stim +++ /dev/null @@ -1,8 +0,0 @@ - -I_ERROR[loss](0.10000000) 0 1 2 3 4 -I_ERROR[loss](0.90000000) 0 -MZ(0.00000000) 0 -MZ(0.00000000) 1 -MZ(0.00000000) 2 -MZ(0.00000000) 3 -MZ(0.00000000) 4 diff --git a/test/stim/passes/stim_reference_programs/wire/wire_reset.stim b/test/stim/passes/stim_reference_programs/wire/wire_reset.stim deleted file mode 100644 index 958e0cfe..00000000 --- a/test/stim/passes/stim_reference_programs/wire/wire_reset.stim +++ /dev/null @@ -1,3 +0,0 @@ - -RZ 0 -MZ(0.00000000) 0 diff --git a/test/stim/passes/squin_meas_to_stim.py b/test/stim/passes/test_squin_meas_to_stim.py similarity index 65% rename from test/stim/passes/squin_meas_to_stim.py rename to test/stim/passes/test_squin_meas_to_stim.py index dff954e0..c3d7f279 100644 --- a/test/stim/passes/squin_meas_to_stim.py +++ b/test/stim/passes/test_squin_meas_to_stim.py @@ -3,10 +3,9 @@ from kirin import ir from kirin.dialects.ilist import IList -from bloqade import squin -from bloqade.squin import op, qubit +from bloqade import squin as sq +from bloqade.types import MeasurementResult from bloqade.stim.emit import EmitStimMain -from bloqade.squin.qubit import MeasurementResult from bloqade.stim.passes import SquinToStimPass @@ -29,23 +28,23 @@ def load_reference_program(filename): def test_cond_on_measurement(): - @squin.kernel + @sq.kernel def main(): n_qubits = 4 - q = qubit.new(n_qubits) + q = sq.qubit.new(n_qubits) - ms = qubit.measure(q) + ms = sq.qubit.measure(q) if ms[0]: - qubit.apply(op.z(), q[0]) - qubit.broadcast(op.x(), [q[1], q[2], q[3]]) - qubit.broadcast(op.z(), q) + sq.z(q[0]) + sq.broadcast.x([q[1], q[2], q[3]]) + sq.broadcast.z(q) if ms[1]: - qubit.apply(op.x(), q[0]) - qubit.apply(op.y(), q[1]) + sq.x(q[0]) + sq.y(q[1]) - qubit.measure(q) + sq.qubit.measure(q) SquinToStimPass(main.dialects)(main) @@ -56,15 +55,15 @@ def main(): def test_alias_with_measure_list(): - @squin.kernel + @sq.kernel def main(): - q = qubit.new(4) - ms = qubit.measure(q) + q = sq.qubit.new(4) + ms = sq.qubit.measure(q) new_ms = ms if new_ms[0]: - qubit.apply(op.z(), q[0]) + sq.z(q[0]) SquinToStimPass(main.dialects)(main) @@ -75,30 +74,30 @@ def main(): def test_record_index_order(): - @squin.kernel + @sq.kernel def main(): n_qubits = 4 - q = qubit.new(n_qubits) + q = sq.qubit.new(n_qubits) - ms0 = qubit.measure(q) + ms0 = sq.qubit.measure(q) if ms0[0]: # should be rec[-4] - qubit.apply(op.z(), q[0]) + sq.z(q[0]) # another measurement - ms1 = qubit.measure(q) + ms1 = sq.qubit.measure(q) if ms1[0]: # should be rec[-4] - qubit.apply(op.x(), q[0]) + sq.x(q[0]) # second round of measurement - ms2 = qubit.measure(q) # noqa: F841 + ms2 = sq.qubit.measure(q) # noqa: F841 # try accessing measurements from the very first round ## There are now 12 total measurements, ms0[0] ## is the oldest measurement in the entire program if ms0[0]: - qubit.apply(op.y(), q[1]) + sq.y(q[1]) SquinToStimPass(main.dialects)(main) @@ -109,33 +108,33 @@ def main(): def test_complex_intermediate_storage_of_measurements(): - @squin.kernel + @sq.kernel def main(): n_qubits = 4 - q = qubit.new(n_qubits) + q = sq.qubit.new(n_qubits) - ms0 = qubit.measure(q) + ms0 = sq.qubit.measure(q) if ms0[0]: - qubit.apply(op.z(), q[0]) + sq.z(q[0]) - ms1 = qubit.measure(q) + ms1 = sq.qubit.measure(q) if ms1[0]: - qubit.apply(op.x(), q[1]) + sq.x(q[1]) # another measurement - ms2 = qubit.measure(q) + ms2 = sq.qubit.measure(q) if ms2[0]: - qubit.apply(op.y(), q[2]) + sq.y(q[2]) # Intentionally obnoxious mix of measurements mix = [ms0[0], ms1[2], ms2[3]] mix_again = (mix[2], mix[0]) if mix_again[0]: - qubit.apply(op.z(), q[3]) + sq.z(q[3]) SquinToStimPass(main.dialects)(main) @@ -146,14 +145,14 @@ def main(): def test_addition_assignment_on_measures_in_list(): - @squin.kernel(fold=False) + @sq.kernel(fold=False) def main(): - q = qubit.new(2) + q = sq.qubit.new(2) results = [] - result: MeasurementResult = qubit.measure(q[0]) + result: MeasurementResult = sq.qubit.measure(q[0]) results += [result] - result: MeasurementResult = qubit.measure(q[1]) + result: MeasurementResult = sq.qubit.measure(q[1]) results += [result] SquinToStimPass(main.dialects)(main) @@ -167,15 +166,15 @@ def test_measure_desugar(): pairs = IList([0, 1, 2, 3]) - @squin.kernel + @sq.kernel def main(): - q = qubit.new(10) - qubit.measure(q[pairs[0]]) + q = sq.qubit.new(10) + sq.qubit.measure(q[pairs[0]]) for i in range(1): - qubit.measure(q[0]) - qubit.measure(q[i]) - qubit.measure(q[pairs[0]]) - qubit.measure(q[pairs[i]]) + sq.qubit.measure(q[0]) + sq.qubit.measure(q[i]) + sq.qubit.measure(q[pairs[0]]) + sq.qubit.measure(q[pairs[i]]) SquinToStimPass(main.dialects)(main) diff --git a/test/stim/passes/test_squin_noise_to_stim.py b/test/stim/passes/test_squin_noise_to_stim.py new file mode 100644 index 00000000..ad425845 --- /dev/null +++ b/test/stim/passes/test_squin_noise_to_stim.py @@ -0,0 +1,285 @@ +import os + +import kirin.types as kirin_types +from kirin import ir, types +from kirin.decl import info, statement +from kirin.rewrite import Walk +from kirin.dialects import ilist + +from bloqade import squin as sq +from bloqade.squin import noise, qubit, kernel +from bloqade.types import Qubit, QubitType +from bloqade.stim.emit import EmitStimMain +from bloqade.stim.passes import SquinToStimPass, flatten +from bloqade.stim.rewrite import SquinNoiseToStim +from bloqade.squin.rewrite import WrapAddressAnalysis +from bloqade.analysis.address import AddressAnalysis + + +def codegen(mt: ir.Method): + # method should not have any arguments! + emit = EmitStimMain() + emit.initialize() + emit.run(mt=mt, args=()) + return emit.get_output().strip() + + +def load_reference_program(filename): + """Load stim file.""" + path = os.path.join( + os.path.dirname(__file__), "stim_reference_programs", "noise", filename + ) + with open(path, "r") as f: + return f.read().strip() + + +def test_apply_pauli_channel_1(): + + @kernel + def test(): + q = qubit.new(1) + sq.single_qubit_pauli_channel(px=0.01, py=0.02, pz=0.03, qubit=q[0]) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program("apply_pauli_channel_1.stim") + assert codegen(test) == expected_stim_program + + +def test_broadcast_pauli_channel_1(): + + @kernel + def test(): + q = qubit.new(10) + sq.broadcast.single_qubit_pauli_channel(px=0.01, py=0.02, pz=0.03, qubits=q) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program("broadcast_pauli_channel_1.stim") + assert codegen(test) == expected_stim_program + + +def test_broadcast_pauli_channel_1_reuse(): + + @kernel + def fixed_1q_pauli_channel(qubits): + sq.broadcast.single_qubit_pauli_channel( + px=0.01, py=0.02, pz=0.03, qubits=qubits + ) + + @kernel + def test(): + q = qubit.new(2) + fixed_1q_pauli_channel(q) + fixed_1q_pauli_channel(q) + fixed_1q_pauli_channel(q) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program( + "broadcast_pauli_channel_1_reuse.stim" + ) + assert codegen(test) == expected_stim_program + + +def test_broadcast_pauli_channel_2(): + + @kernel + def test(): + q = qubit.new(8) + sq.broadcast.two_qubit_pauli_channel( + probabilities=[ + 0.001, + 0.002, + 0.003, + 0.004, + 0.005, + 0.006, + 0.007, + 0.008, + 0.009, + 0.010, + 0.011, + 0.012, + 0.013, + 0.014, + 0.015, + ], + controls=q[:4], + targets=q[4:], + ) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program("broadcast_pauli_channel_2.stim") + assert codegen(test) == expected_stim_program + + +def test_broadcast_pauli_channel_2_reuse(): + + @kernel + def fixed_2q_pauli_channel(controls, targets): + sq.broadcast.two_qubit_pauli_channel( + probabilities=[ + 0.001, + 0.002, + 0.003, + 0.004, + 0.005, + 0.006, + 0.007, + 0.008, + 0.009, + 0.010, + 0.011, + 0.012, + 0.013, + 0.014, + 0.015, + ], + controls=controls, + targets=targets, + ) + + @kernel + def test(): + q = qubit.new(8) + + fixed_2q_pauli_channel([q[0], q[1]], [q[2], q[3]]) + fixed_2q_pauli_channel([q[4], q[5]], [q[6], q[7]]) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program( + "broadcast_pauli_channel_2_reuse.stim" + ) + assert codegen(test) == expected_stim_program + + +def test_broadcast_depolarize2(): + + @kernel + def test(): + q = qubit.new(4) + sq.broadcast.depolarize2(p=0.015, controls=q[:2], targets=q[2:]) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program("broadcast_depolarize2.stim") + assert codegen(test) == expected_stim_program + + +def test_apply_depolarize1(): + + @kernel + def test(): + q = qubit.new(1) + sq.depolarize(p=0.01, qubit=q[0]) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program("apply_depolarize1.stim") + assert codegen(test) == expected_stim_program + + +def test_broadcast_depolarize1(): + + @kernel + def test(): + q = qubit.new(4) + sq.broadcast.depolarize(p=0.01, qubits=q) + return + + SquinToStimPass(test.dialects)(test) + expected_stim_program = load_reference_program("broadcast_depolarize1.stim") + assert codegen(test) == expected_stim_program + + +def test_apply_loss(): + + @kernel + def apply_loss(qubit): + sq.qubit_loss(0.1, qubit=qubit) + + @kernel + def test(): + q = qubit.new(3) + apply_loss(q[0]) + apply_loss(q[1]) + apply_loss(q[2]) + + SquinToStimPass(test.dialects)(test) + + expected_stim_program = load_reference_program("apply_loss.stim") + assert codegen(test) == expected_stim_program + + +def get_stmt_at_idx(method: ir.Method, idx: int) -> ir.Statement: + return method.callable_region.blocks[0].stmts.at(idx) + + +# If there's no concrete qubit values from the address analysis then +# the rewrite rule should immediately return and not mutate the method. +def test_no_qubit_address_available(): + + @kernel + def test(q: ilist.IList[Qubit, kirin_types.Literal]): + sq.single_qubit_pauli_channel(px=0.01, py=0.02, pz=0.03, qubit=q[0]) + return + + flatten.Flatten(dialects=test.dialects).fixpoint(test) + Walk(SquinNoiseToStim()).rewrite(test.code) + + expected_1q_noise_pauli_channel = get_stmt_at_idx(test, 6) + + assert isinstance( + expected_1q_noise_pauli_channel, noise.stmts.SingleQubitPauliChannel + ) + + +def test_nonexistent_noise_channel(): + + @statement(dialect=noise.dialect) + class NonExistentNoiseChannel(noise.stmts.NoiseChannel): + """ + A non-existent noise channel for testing purposes. + """ + + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) + + pass + + @kernel + def test(): + q = qubit.new(1) + NonExistentNoiseChannel(qubits=q) + return + + frame, _ = AddressAnalysis(test.dialects).run_analysis(test) + WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) + + rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) + + expected_noise_channel_stmt = get_stmt_at_idx(test, 2) + + # The rewrite shouldn't have occurred at all because there is no rewrite logic for + # NonExistentNoiseChannel. + assert not rewrite_result.has_done_something + assert isinstance(expected_noise_channel_stmt, NonExistentNoiseChannel) + + +def test_standard_op_no_rewrite(): + + @kernel + def test(): + q = qubit.new(1) + sq.x(qubit=q[0]) + return + + frame, _ = AddressAnalysis(test.dialects).run_analysis(test) + WrapAddressAnalysis(address_analysis=frame.entries).rewrite(test.code) + + rewrite_result = Walk(SquinNoiseToStim()).rewrite(test.code) + + # Rewrite should not have done anything because target is not a noise channel + assert not rewrite_result.has_done_something diff --git a/test/stim/passes/squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py similarity index 63% rename from test/stim/passes/squin_qubit_to_stim.py rename to test/stim/passes/test_squin_qubit_to_stim.py index 6385d2c9..f5b76f0d 100644 --- a/test/stim/passes/squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -4,8 +4,8 @@ from kirin import ir from kirin.dialects import py -from bloqade import squin -from bloqade.squin import op, noise, qubit, kernel +from bloqade import squin as sq +from bloqade.squin import qubit, kernel from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass @@ -39,20 +39,15 @@ def test_qubit(): @kernel def test(): n_qubits = 2 - ql = qubit.new(n_qubits) - qubit.broadcast(op.h(), ql) - qubit.apply(op.x(), ql[0]) - ctrl = op.control(op.x(), n_controls=1) - qubit.apply(ctrl, ql[1], ql[0]) - # measure out - squin.qubit.measure(ql) + ql = sq.qubit.new(n_qubits) + sq.broadcast.h(ql) + sq.x(ql[0]) + sq.cx(ql[0], ql[1]) + sq.qubit.measure(ql) return - test.print() - SquinToStimPass(test.dialects)(test) base_stim_prog = load_reference_program("qubit.stim") - assert codegen(test) == base_stim_prog.rstrip() @@ -62,9 +57,9 @@ def test(): n_qubits = 1 q = qubit.new(n_qubits) # reset the qubit - squin.qubit.apply(op.reset(), q[0]) + qubit.Reset(q) # measure out - squin.qubit.measure(q[0]) + sq.qubit.measure(q[0]) return SquinToStimPass(test.dialects)(test) @@ -79,9 +74,9 @@ def test(): n_qubits = 4 ql = qubit.new(n_qubits) # apply Hadamard to all qubits - squin.qubit.broadcast(op.h(), ql) + sq.broadcast.h(ql) # measure out - squin.qubit.measure(ql) + sq.qubit.measure(ql) return SquinToStimPass(test.dialects)(test) @@ -90,18 +85,18 @@ def test(): assert codegen(test) == base_stim_prog.rstrip() -def test_qubit_loss(): +def test_gates_with_loss(): @kernel def test(): n_qubits = 5 ql = qubit.new(n_qubits) # apply Hadamard to all qubits - squin.qubit.broadcast(op.h(), ql) + sq.broadcast.h(ql) # apply and broadcast qubit loss - squin.qubit.apply(noise.qubit_loss(0.1), ql[3]) - squin.qubit.broadcast(noise.qubit_loss(0.05), ql) + sq.qubit_loss(p=0.1, qubit=ql[3]) + sq.broadcast.qubit_loss(p=0.05, qubits=ql) # measure out - squin.qubit.measure(ql) + sq.qubit.measure(ql) return SquinToStimPass(test.dialects)(test) @@ -117,9 +112,9 @@ def test(): n_qubits = 1 q = qubit.new(n_qubits) # apply U3 rotation that can be translated to a Clifford gate - squin.qubit.apply(op.u(0.25 * math.tau, 0.0 * math.tau, 0.5 * math.tau), q[0]) + sq.u3(0.25 * math.tau, 0.0 * math.tau, 0.5 * math.tau, qubit=q[0]) # measure out - squin.qubit.measure(q) + sq.qubit.measure(q) return SquinToStimPass(test.dialects)(test) @@ -131,10 +126,10 @@ def test(): def test_sqrt_x_rewrite(): - @squin.kernel + @sq.kernel def test(): q = qubit.new(1) - qubit.broadcast(op.sqrt_x(), q) + sq.broadcast.sqrt_x(q) return SquinToStimPass(test.dialects)(test) @@ -144,10 +139,10 @@ def test(): def test_sqrt_y_rewrite(): - @squin.kernel + @sq.kernel def test(): q = qubit.new(1) - qubit.broadcast(op.sqrt_y(), q) + sq.broadcast.sqrt_y(q) return SquinToStimPass(test.dialects)(test) @@ -157,13 +152,12 @@ def test(): def test_for_loop_nontrivial_index_rewrite(): - @squin.kernel + @sq.kernel def main(): - q = squin.qubit.new(3) - squin.qubit.apply(squin.op.h(), q[0]) - cx = squin.op.cx() + q = sq.qubit.new(3) + sq.h(q[0]) for i in range(2): - squin.qubit.apply(cx, q[i], q[i + 1]) + sq.cx(q[i], q[i + 1]) SquinToStimPass(main.dialects)(main) base_stim_prog = load_reference_program("for_loop_nontrivial_index.stim") @@ -173,14 +167,13 @@ def main(): def test_nested_for_loop_rewrite(): - @squin.kernel + @sq.kernel def main(): - q = squin.qubit.new(5) - squin.qubit.apply(squin.op.h(), q[0]) - cx = squin.op.cx() + q = sq.qubit.new(5) + sq.h(q[0]) for i in range(2): for j in range(2, 3): - squin.qubit.apply(cx, q[i], q[j]) + sq.cx(q[i], q[j]) SquinToStimPass(main.dialects)(main) base_stim_prog = load_reference_program("nested_for_loop.stim") @@ -199,12 +192,11 @@ def test_nested_list(): pairs = [[0, 1], [2, 3]] - @squin.kernel + @sq.kernel def main(): - q = qubit.new(10) - h = squin.op.h() + q = sq.qubit.new(10) for i in range(2): - squin.qubit.apply(h, q[pairs[i][0]]) + sq.h(q[pairs[i][0]]) SquinToStimPass(main.dialects)(main) @@ -215,14 +207,14 @@ def main(): def test_pick_if_else(): - @squin.kernel + @sq.kernel def main(): q = qubit.new(10) if False: - qubit.apply(squin.op.h(), q[0]) + sq.h(q[0]) if True: - qubit.apply(squin.op.h(), q[1]) + sq.h(q[1]) SquinToStimPass(main.dialects)(main) @@ -239,10 +231,64 @@ def test_squin_kernel(): outputs = [] for rnd in range(len(result)): # Non-pure loop iterator outputs += [] - qubit.apply(op.x(), q[rnd]) # make sure body does something + sq.x(q[rnd]) # make sure body does something return main = test_squin_kernel.similar() SquinToStimPass(main.dialects)(main) base_stim_prog = load_reference_program("non_pure_loop_iterator.stim") assert codegen(main) == base_stim_prog.rstrip() + + +def test_rep_code(): + + # NOTE: This is not a true repetition code in the sense there is no + # detector definition or final observables being defined. + + @sq.kernel + def entangle(cx_pairs): + sq.broadcast.cx(controls=cx_pairs[0][0], targets=cx_pairs[0][1]) + sq.broadcast.cx(controls=cx_pairs[1][0], targets=cx_pairs[1][1]) + + @sq.kernel + def rep_code(): + + q = qubit.new(5) + data = q[::2] + ancilla = q[1::2] + + # reset everything initially + qubit.Reset(q) + + ## Initial round, entangle data qubits with ancillas. + ## This entanglement will happen again so it's best we + ## save the qubit pairs for reuse. + cx_pair_1_controls = [data[0], data[1]] + cx_pair_1_targets = [ancilla[0], ancilla[1]] + cx_pair_1 = [cx_pair_1_controls, cx_pair_1_targets] + + cx_pair_2_controls = [data[1], data[2]] + cx_pair_2_targets = [ancilla[0], ancilla[1]] + cx_pair_2 = [cx_pair_2_controls, cx_pair_2_targets] + + cx_pairs = [cx_pair_1, cx_pair_2] + + entangle(cx_pairs) + + qubit.measure(ancilla) + + entangle(cx_pairs) + qubit.measure(ancilla) + + # Let's make this one a bit noisy + entangle(cx_pairs) + sq.broadcast.depolarize2( + 0.01, controls=cx_pair_1_controls, targets=cx_pair_1_targets + ) + sq.broadcast.qubit_loss(p=0.001, qubits=q) + + qubit.measure(ancilla) + + SquinToStimPass(rep_code.dialects)(rep_code) + base_stim_prog = load_reference_program("rep_code.stim") + assert codegen(rep_code) == base_stim_prog.rstrip()