diff --git a/src/bloqade/cirq_utils/emit/op.py b/src/bloqade/cirq_utils/emit/op.py index 3289cbac..99dcb915 100644 --- a/src/bloqade/cirq_utils/emit/op.py +++ b/src/bloqade/cirq_utils/emit/op.py @@ -64,7 +64,8 @@ def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp): @impl(op.stmts.Identity) def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity): - op = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites)) + sites = frame.get(stmt.sites) + op = HermitianRuntime(cirq.IdentityGate(num_qubits=sites)) return (op,) @impl(op.stmts.Control) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 165ddce5..e32d1b79 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -378,7 +378,8 @@ def visit_SingleQubitPauliStringGateOperation( case cirq.Z: op_ = op.stmts.Z() case cirq.I: - op_ = op.stmts.Identity(sites=1) + site = state.current_frame.push(py.Constant(1)).result + op_ = op.stmts.Identity(sites=site) case _: raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}") @@ -386,6 +387,12 @@ def visit_SingleQubitPauliStringGateOperation( qargs = self.lower_qubit_getindices(state, [node.qubit]) return state.current_frame.push(qubit.Apply(op_.result, qargs)) + def visit_IdentityGate( + self, state: lowering.State[CirqNode], node: cirq.IdentityGate + ): + sites = state.current_frame.push(py.Constant(node.num_qubits())) + return state.current_frame.push(op.stmts.Identity(sites=sites.result)) + def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate): if abs(node.exponent) == 1: return state.current_frame.push(op.stmts.H()) diff --git a/src/bloqade/pyqrack/squin/op.py b/src/bloqade/pyqrack/squin/op.py index 1a603aaa..a670423e 100644 --- a/src/bloqade/pyqrack/squin/op.py +++ b/src/bloqade/pyqrack/squin/op.py @@ -83,7 +83,8 @@ def rot( def identity( self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity ) -> tuple[OperatorRuntimeABC]: - return (IdentityRuntime(sites=stmt.sites),) + sites = frame.get(stmt.sites) + return (IdentityRuntime(sites=sites),) @interp.impl(op.stmts.PhaseOp) @interp.impl(op.stmts.ShiftOp) diff --git a/src/bloqade/squin/analysis/nsites/analysis.py b/src/bloqade/squin/analysis/nsites/analysis.py index 52bd3091..0768084c 100644 --- a/src/bloqade/squin/analysis/nsites/analysis.py +++ b/src/bloqade/squin/analysis/nsites/analysis.py @@ -11,6 +11,10 @@ class NSitesAnalysis(Forward[Sites]): + """Analysis pass to infer number of sites an operator applies to. + + **NOTE**: run kirin.passes.HintConst prior to using this analysis pass. + """ keys = ["op.nsites"] lattice = Sites diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index 3e47b607..9433e569 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -1,4 +1,5 @@ from kirin import interp +from kirin.analysis import const from kirin.dialects import scf, func, ilist from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer @@ -88,6 +89,21 @@ def pauli_string( s = stmt.string return (NumberSites(sites=len(s)),) + @interp.impl(op.stmts.Identity) + def identity( + self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Identity + ): + sites = stmt.sites + const_hint = sites.hints.get("const") + + if not isinstance(const_hint, const.Value): + return (interp.lattice.top(),) + + if not isinstance(n_sites := const_hint.data, int): + return (interp.lattice.top(),) + + return (NumberSites(sites=n_sites),) + @ilist.dialect.register(key="op.nsites") class IListMethods(interp.MethodTable): diff --git a/src/bloqade/squin/op/_wrapper.py b/src/bloqade/squin/op/_wrapper.py index 86703ad6..45540ecf 100644 --- a/src/bloqade/squin/op/_wrapper.py +++ b/src/bloqade/squin/op/_wrapper.py @@ -46,7 +46,7 @@ def reset_to_one() -> types.Op: ... @wraps(stmts.Identity) -def identity(*, sites: int) -> types.Op: ... +def identity(sites: int) -> types.Op: ... @wraps(stmts.Rot) diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index c4e512df..923d5dfa 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -96,7 +96,7 @@ class Rot(CompositeOp): @statement(dialect=dialect) class Identity(CompositeOp): traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()}) - sites: int = info.attribute() + sites: ir.SSAValue = info.argument(types.Int) @statement diff --git a/src/bloqade/squin/rewrite/U3_to_clifford.py b/src/bloqade/squin/rewrite/U3_to_clifford.py index bc3c4f63..c51350c7 100644 --- a/src/bloqade/squin/rewrite/U3_to_clifford.py +++ b/src/bloqade/squin/rewrite/U3_to_clifford.py @@ -14,34 +14,39 @@ def sdag() -> list[ir.Statement]: return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)] +def single_site_identity(node: ir.Statement) -> Tuple[List[ir.Statement], ...]: + (site := py.Constant(1)).insert_before(node) + return ([op.stmts.Identity(sites=site.result)],) + + # (theta, phi, lam) U3_HALF_PI_ANGLE_TO_GATES: dict[ - tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]] + tuple[int, int, int], Callable[[ir.Statement], Tuple[List[ir.Statement], ...]] ] = { - (0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],), - (0, 0, 1): lambda: ([op.stmts.S()],), - (0, 0, 2): lambda: ([op.stmts.Z()],), - (0, 0, 3): lambda: (sdag(),), - (1, 0, 0): lambda: ([op.stmts.SqrtY()],), - (1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]), - (1, 0, 2): lambda: ([op.stmts.H()],), - (1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]), - (1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]), - (1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]), - (1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()), - (1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()), - (1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()), - (1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()), - (2, 0, 0): lambda: ([op.stmts.Y()],), - (2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]), - (2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]), - (2, 0, 3): lambda: (sdag(), [op.stmts.Y()]), + (0, 0, 0): single_site_identity, + (0, 0, 1): lambda node: ([op.stmts.S()],), + (0, 0, 2): lambda node: ([op.stmts.Z()],), + (0, 0, 3): lambda node: (sdag(),), + (1, 0, 0): lambda node: ([op.stmts.SqrtY()],), + (1, 0, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()]), + (1, 0, 2): lambda node: ([op.stmts.H()],), + (1, 0, 3): lambda node: (sdag(), [op.stmts.SqrtY()]), + (1, 1, 0): lambda node: ([op.stmts.SqrtY()], [op.stmts.S()]), + (1, 1, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]), + (1, 1, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]), + (1, 1, 3): lambda node: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]), + (1, 2, 0): lambda node: ([op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 2, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 2, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 2, 3): lambda node: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]), + (1, 3, 0): lambda node: ([op.stmts.SqrtY()], sdag()), + (1, 3, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()), + (1, 3, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()), + (1, 3, 3): lambda node: (sdag(), [op.stmts.SqrtY()], sdag()), + (2, 0, 0): lambda node: ([op.stmts.Y()],), + (2, 0, 1): lambda node: ([op.stmts.S()], [op.stmts.Y()]), + (2, 0, 2): lambda node: ([op.stmts.Z()], [op.stmts.Y()]), + (2, 0, 3): lambda node: (sdag(), [op.stmts.Y()]), } @@ -154,4 +159,4 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ... gates_stmts is not None ), "internal error, U3 gates not found for angles: {}".format(angles_key) - return gates_stmts() + return gates_stmts(node) diff --git a/test/cirq_utils/test_cirq_to_squin.py b/test/cirq_utils/test_cirq_to_squin.py index 9aa5c235..79063c51 100644 --- a/test/cirq_utils/test_cirq_to_squin.py +++ b/test/cirq_utils/test_cirq_to_squin.py @@ -17,6 +17,7 @@ def basic_circuit(): # Create a circuit. return cirq.Circuit( + cirq.I(qubit), cirq.X(qubit), cirq.Y(qubit2), cirq.Z(qubit), diff --git a/test/squin/analysis/test_nsites_analysis.py b/test/squin/analysis/test_nsites_analysis.py index bba0402b..187db028 100644 --- a/test/squin/analysis/test_nsites_analysis.py +++ b/test/squin/analysis/test_nsites_analysis.py @@ -1,6 +1,6 @@ from kirin import ir, types from kirin.passes import Fold -from kirin.dialects import py, func +from kirin.dialects import py, scf, func from bloqade import squin from bloqade.squin import op, noise, qubit @@ -304,3 +304,72 @@ def test(): assert has_n_sites[6].sites == 1 # single_qubit_pauli_channel assert has_n_sites[7].sites == 2 # two_qubit_pauli_channel assert has_n_sites[8].sites == 1 # qubit_loss + + +def test_identity(): + + @squin.kernel(fold=False) + def main1(): + squin.op.identity(2) + + nsites_frame, _ = nsites.NSitesAnalysis(main1.dialects).run_analysis( + main1, no_raise=False + ) + + main1.print(analysis=nsites_frame.entries) + + assert [nsites_frame.entries[result] for result in results_at(main1, 0, 1)] == [ + nsites.NumberSites(sites=2) + ] + + @squin.kernel(fold=False) + def main2(n: int): + squin.op.identity(n) + + nsites_frame, _ = nsites.NSitesAnalysis(main2.dialects).run_analysis( + main2, no_raise=False + ) + + main2.print(analysis=nsites_frame.entries) + + assert [nsites_frame.entries[result] for result in results_at(main2, 0, 0)] == [ + nsites.AnySites() + ] + + @squin.kernel(fold=False) + def main3(): + n = 3 + squin.op.identity(n) + + nsites_frame, _ = nsites.NSitesAnalysis(main3.dialects).run_analysis( + main3, no_raise=False + ) + + main3.print(analysis=nsites_frame.entries) + + assert [nsites_frame.entries[result] for result in results_at(main3, 0, 1)] == [ + nsites.NumberSites(3) + ] + + @squin.kernel(fold=False) + def main4(): + x = 0 + for i in range(3): + squin.op.identity(i) + + # NOTE: need a workaround yield here, else kirin fails verification complaining about missing yield + x += 1 + + nsites_frame, _ = nsites.NSitesAnalysis(main4.dialects).run_analysis( + main4, no_raise=False + ) + + main4.print(analysis=nsites_frame.entries) + + for stmt in main4.callable_region.stmts(): + if not isinstance(stmt, scf.For): + continue + + for body_stmt in stmt.body.stmts(): + if isinstance(body_stmt, squin.op.stmts.Identity): + assert nsites_frame.get(body_stmt.result) == nsites.AnySites() diff --git a/test/squin/rewrite/test_U3_to_clifford.py b/test/squin/rewrite/test_U3_to_clifford.py index 6076c1cf..37e6a286 100644 --- a/test/squin/rewrite/test_U3_to_clifford.py +++ b/test/squin/rewrite/test_U3_to_clifford.py @@ -40,7 +40,7 @@ def test(): SquinToCliffordTestPass(test.dialects)(test) - assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Identity) + assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Identity) def test_s(): diff --git a/test/stim/passes/test_squin_wire_to_stim.py b/test/stim/passes/test_squin_wire_to_stim.py index 72f4b26f..d90b41cc 100644 --- a/test/stim/passes/test_squin_wire_to_stim.py +++ b/test/stim/passes/test_squin_wire_to_stim.py @@ -157,8 +157,8 @@ def test_wire_multiple_apply(): # pass the wires through some 1 Qubit operators (op1 := squin.op.stmts.S()), (op2 := squin.op.stmts.H()), - (op3 := squin.op.stmts.Identity(sites=1)), - (op4 := squin.op.stmts.Identity(sites=1)), + (op3 := squin.op.stmts.Identity(sites=n_qubits.result)), + (op4 := squin.op.stmts.Identity(sites=n_qubits.result)), (v0 := squin.wire.Apply(op1.result, w0.result)), (v1 := squin.wire.Apply(op2.result, v0.results[0])), (v2 := squin.wire.Apply(op3.result, v1.results[0])),