From 49115769eb465545a2ba4763ce503ceddf331928 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 29 Jul 2025 23:21:30 +0200 Subject: [PATCH 1/4] Implement squin moment statements --- src/bloqade/squin/cirq/emit/qubit.py | 5 +++++ src/bloqade/squin/cirq/lowering.py | 7 +++++-- src/bloqade/squin/qubit.py | 6 ++++++ test/squin/cirq/test_moment.py | 23 +++++++++++++++++++++++ 4 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 test/squin/cirq/test_moment.py diff --git a/src/bloqade/squin/cirq/emit/qubit.py b/src/bloqade/squin/cirq/emit/qubit.py index 0c4df791..5c46298d 100644 --- a/src/bloqade/squin/cirq/emit/qubit.py +++ b/src/bloqade/squin/cirq/emit/qubit.py @@ -58,3 +58,8 @@ def measure_qubit_list( qbits = frame.get(stmt.qubits) frame.circuit.append(cirq.measure(qbits)) return () + + @impl(qubit.Moment) + def moment(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Moment): + emit.run_ssacfg_region(frame, stmt.body, args=()) + return () diff --git a/src/bloqade/squin/cirq/lowering.py b/src/bloqade/squin/cirq/lowering.py index 349dd768..30ee9279 100644 --- a/src/bloqade/squin/cirq/lowering.py +++ b/src/bloqade/squin/cirq/lowering.py @@ -122,8 +122,11 @@ def visit_Circuit( def visit_Moment( self, state: lowering.State[CirqNode], node: cirq.Moment ) -> lowering.Result: - for op_ in node.operations: - state.lower(op_) + with state.frame(node.operations) as body_frame: + body_frame.exhaust() + body_region = body_frame.curr_region + + state.current_frame.push(qubit.Moment(body=body_region)) def visit_GateOperation( self, state: lowering.State[CirqNode], node: cirq.GateOperation diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 407b3152..3e09dc46 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -79,6 +79,12 @@ class MeasureQubitList(ir.Statement): result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType]) +@statement(dialect=dialect) +class Moment(ir.Statement): + traits = frozenset({ir.SSACFGRegion()}) + body: ir.Region = info.region() + + # NOTE: no dependent types in Python, so we have to mark it Any... @wraps(New) def new(n_qubits: int) -> ilist.IList[Qubit, Any]: diff --git a/test/squin/cirq/test_moment.py b/test/squin/cirq/test_moment.py new file mode 100644 index 00000000..474fad0f --- /dev/null +++ b/test/squin/cirq/test_moment.py @@ -0,0 +1,23 @@ +import cirq + +from bloqade import squin + +q = cirq.LineQubit.range(3) +circuit = cirq.Circuit( + cirq.Moment( + cirq.X(q[0]), + cirq.H(q[1]), + ), + cirq.Moment( + cirq.X(q[2]), + ), +) + +kernel = squin.cirq.load_circuit(circuit) + +kernel.print() + + +circuit2 = squin.cirq.emit_circuit(kernel) + +print(circuit2) From ed89ce8dcbfc69c3d6767d68587ef4d34d515992 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 30 Jul 2025 15:46:06 +0200 Subject: [PATCH 2/4] Implement pyqrack method for squin moments --- src/bloqade/pyqrack/squin/qubit.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/bloqade/pyqrack/squin/qubit.py b/src/bloqade/pyqrack/squin/qubit.py index af4dd061..bd0952b3 100644 --- a/src/bloqade/pyqrack/squin/qubit.py +++ b/src/bloqade/pyqrack/squin/qubit.py @@ -63,3 +63,9 @@ def measure_qubit( qbit: PyQrackQubit = frame.get(stmt.qubit) result = self._measure_qubit(qbit, interp) return (result,) + + @interp.impl(qubit.Moment) + def moment( + self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Moment + ): + return interp.run_ssacfg_region(frame, stmt.body, args=()) From a1d3bc79b5fb03ced409e87afb46185974c90012 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 31 Jul 2025 15:08:56 +0200 Subject: [PATCH 3/4] Fix lowering with classical controls --- src/bloqade/squin/cirq/lowering.py | 55 ++++++++++++++++++--------- test/squin/cirq/test_cirq_to_squin.py | 29 ++++++++++++++ 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/src/bloqade/squin/cirq/lowering.py b/src/bloqade/squin/cirq/lowering.py index 30ee9279..69d05013 100644 --- a/src/bloqade/squin/cirq/lowering.py +++ b/src/bloqade/squin/cirq/lowering.py @@ -122,10 +122,17 @@ def visit_Circuit( def visit_Moment( self, state: lowering.State[CirqNode], node: cirq.Moment ) -> lowering.Result: + frame = state.current_frame with state.frame(node.operations) as body_frame: + body_frame.defs.update(frame.defs) + body_frame.globals.update(frame.globals) + body_frame.exhaust() body_region = body_frame.curr_region + state.current_frame.defs.update(body_frame.defs) + state.current_frame.globals.update(body_frame.globals) + state.current_frame.push(qubit.Moment(body=body_region)) def visit_GateOperation( @@ -176,24 +183,8 @@ def visit_ClassicallyControlledOperation( measurement_outcome = state.current_frame.defs[key] if measurement_outcome.type.is_subseteq(ilist.IListType): - # NOTE: there is currently no convenient ilist.any method, so we need to use foldl - # with a simple function that just does an or - - def bool_op_or(x: bool, y: bool) -> bool: - return x or y - - f_code = state.current_frame.push( - lowering.Python(self.dialects).python_function(bool_op_or) - ) - fn = ir.Method( - mod=None, - py_func=bool_op_or, - sym_name="bool_op_or", - arg_names=[], - dialects=self.dialects, - code=f_code, - ) - f_const = state.current_frame.push(py.constant.Constant(fn)) + # TODO: replace by ilist.Any + f_const = self._get_or_func(state) init_val = state.current_frame.push(py.Constant(False)).result condition = state.current_frame.push( ilist.Foldl(f_const.result, measurement_outcome, init=init_val) @@ -227,6 +218,34 @@ def bool_op_or(x: bool, y: bool) -> bool: return state.current_frame.push(scf.IfElse(condition, then_body=then_body)) + def _get_or_func(self, state: lowering.State[CirqNode]): + # NOTE: there is currently no convenient ilist.any method, so we need to use foldl + # with a simple function that just does an or + + # NOTE: check if we already defined that function + f_prev = state.current_frame.globals.get("__BOOL_OR_FUNC") + if f_prev is not None: + return f_prev + + def bool_op_or(x: bool, y: bool) -> bool: + return x or y + + f_code = lowering.Python(self.dialects).python_function(bool_op_or) + + if f_code.parent is None: + state.current_frame.push(f_code) + fn = ir.Method( + mod=None, + py_func=bool_op_or, + sym_name="bool_op_or", + arg_names=[], + dialects=self.dialects, + code=f_code, + ) + f = state.current_frame.push(py.constant.Constant(fn)) + state.current_frame.globals["__BOOL_OR_FUNC"] = f + return f + def visit_SingleQubitPauliStringGateOperation( self, state: lowering.State[CirqNode], diff --git a/test/squin/cirq/test_cirq_to_squin.py b/test/squin/cirq/test_cirq_to_squin.py index 5b7bf0d0..fecd7a16 100644 --- a/test/squin/cirq/test_cirq_to_squin.py +++ b/test/squin/cirq/test_cirq_to_squin.py @@ -315,6 +315,35 @@ def test_multiple_classical_controls(run_sim: bool = False): kernel = squin.cirq.load_circuit(circuit) kernel.print() + if run_sim: + from bloqade.pyqrack import StackMemorySimulator + + simq = StackMemorySimulator(min_qubits=3) + simq.run(kernel) + + +def test_multiple_classical_control_registers(run_sim: bool = False): + q = cirq.LineQubit.range(2) + q2 = cirq.GridQubit.rect(1, 2) + circuit = cirq.Circuit( + cirq.H(q[0]), + cirq.H(q2[0]), + cirq.measure(q, key="test"), + cirq.measure(q2, key="test2"), + cirq.X(q[1]).with_classical_controls("test", "test2"), + cirq.X(q[1]).with_classical_controls("test"), + cirq.measure(q[1]), + ) + + print(circuit) + + if run_sim: + sim = cirq.Simulator() + sim.run(circuit) + + kernel = squin.cirq.load_circuit(circuit) + kernel.print() + def test_ghz_simulation(): q = cirq.LineQubit.range(2) From 2a343846192ab25b7fd791c14b0a4995cce55430 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Thu, 31 Jul 2025 15:16:40 +0200 Subject: [PATCH 4/4] Fix classical control lowering IfElse statement --- src/bloqade/squin/cirq/lowering.py | 11 +++++++++-- test/squin/cirq/test_cirq_to_squin.py | 8 ++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/bloqade/squin/cirq/lowering.py b/src/bloqade/squin/cirq/lowering.py index 69d05013..612cc504 100644 --- a/src/bloqade/squin/cirq/lowering.py +++ b/src/bloqade/squin/cirq/lowering.py @@ -214,9 +214,16 @@ def visit_ClassicallyControlledOperation( # NOTE: remove stmt from parent block then_stmt.detach() - then_body = ir.Block((then_stmt,)) + then_body = ir.Block((then_stmt, scf.Yield())) + then_body.args.append_from(types.Bool) - return state.current_frame.push(scf.IfElse(condition, then_body=then_body)) + # NOTE: create empty else body + else_body = ir.Block(stmts=[scf.Yield()]) + else_body.args.append_from(types.Bool) + + return state.current_frame.push( + scf.IfElse(condition, then_body=then_body, else_body=else_body) + ) def _get_or_func(self, state: lowering.State[CirqNode]): # NOTE: there is currently no convenient ilist.any method, so we need to use foldl diff --git a/test/squin/cirq/test_cirq_to_squin.py b/test/squin/cirq/test_cirq_to_squin.py index fecd7a16..65fdb668 100644 --- a/test/squin/cirq/test_cirq_to_squin.py +++ b/test/squin/cirq/test_cirq_to_squin.py @@ -191,6 +191,7 @@ def test_circuit(circuit_f, run_sim: bool = False): kernel = squin.load_circuit(circuit) kernel.print() + kernel.verify() rewrite_noise_pass(kernel) @@ -204,6 +205,7 @@ def test_return_register(): circuit = basic_circuit() kernel = squin.load_circuit(circuit, return_register=True) kernel.print() + kernel.verify() assert isinstance(kernel.return_type, types.Generic) assert kernel.return_type.body.is_subseteq(ilist.IListType) @@ -220,6 +222,7 @@ def test_passing_in_register(): print(circuit) kernel = squin.cirq.load_circuit(circuit, register_as_argument=True) kernel.print() + kernel.verify() def test_passing_and_returning_register(): @@ -229,6 +232,7 @@ def test_passing_and_returning_register(): circuit, register_as_argument=True, return_register=True ) kernel.print() + kernel.verify() def test_nesting_lowered_circuit(): @@ -277,6 +281,7 @@ def test_classical_control(run_sim: bool = False): kernel = squin.cirq.load_circuit(circuit) kernel.print() + kernel.verify() def test_classical_control_register(): @@ -292,6 +297,7 @@ def test_classical_control_register(): kernel = squin.cirq.load_circuit(circuit) kernel.print() + kernel.verify() def test_multiple_classical_controls(run_sim: bool = False): @@ -314,6 +320,7 @@ def test_multiple_classical_controls(run_sim: bool = False): kernel = squin.cirq.load_circuit(circuit) kernel.print() + kernel.verify() if run_sim: from bloqade.pyqrack import StackMemorySimulator @@ -343,6 +350,7 @@ def test_multiple_classical_control_registers(run_sim: bool = False): kernel = squin.cirq.load_circuit(circuit) kernel.print() + kernel.verify() def test_ghz_simulation():