diff --git a/src/bloqade/squin/cirq/__init__.py b/src/bloqade/squin/cirq/__init__.py index 421bfadc..1cf31602 100644 --- a/src/bloqade/squin/cirq/__init__.py +++ b/src/bloqade/squin/cirq/__init__.py @@ -18,6 +18,8 @@ def load_circuit( circuit: cirq.Circuit, kernel_name: str = "main", dialects: ir.DialectGroup = kernel, + register_as_argument: bool = False, + return_register: bool = False, globals: dict[str, Any] | None = None, file: str | None = None, lineno_offset: int = 0, @@ -32,13 +34,21 @@ def load_circuit( Keyword Args: kernel_name (str): The name of the kernel to load. Defaults to "main". dialects (ir.DialectGroup | None): The dialects to use. Defaults to `squin.kernel`. + register_as_argument (bool): Determine whether the resulting kernel function should accept + a single `ilist.IList[Qubit, Any]` argument that is a list of qubits used within the + function. This allows you to compose kernel functions generated from circuits. + Defaults to `False`. + return_register (bool): Determine whether the resulting kernel functionr returns a + single value of type `ilist.IList[Qubit, Any]` that is the list of qubits used + in the kernel function. Useful when you want to compose multiple kernel functions + generated from circuits. Defaults to `False`. globals (dict[str, Any] | None): The global variables to use. Defaults to None. file (str | None): The file name for error reporting. Defaults to None. lineno_offset (int): The line number offset for error reporting. Defaults to 0. col_offset (int): The column number offset for error reporting. Defaults to 0. compactify (bool): Whether to compactify the output. Defaults to True. - Example: + ## Usage Examples: ```python # from cirq's "hello qubit" example @@ -60,6 +70,30 @@ def load_circuit( # print the resulting IR main.print() ``` + + You can also compose kernel functions generated from circuits by passing in + and / or returning the respective quantum registers: + + ```python + q = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q)) + + get_entangled_qubits = squin.cirq.load_circuit( + circuit, return_register=True, kernel_name="get_entangled_qubits" + ) + get_entangled_qubits.print() + + entangle_qubits = squin.cirq.load_circuit( + circuit, register_as_argument=True, kernel_name="entangle_qubits" + ) + + @squin.kernel + def main(): + qreg = get_entangled_qubits() + qreg2 = squin.qubit.new(1) + entangle_qubits([qreg[1], qreg2[0]]) + return squin.qubit.measure(qreg2) + ``` """ target = Squin(dialects=dialects, circuit=circuit) @@ -71,16 +105,34 @@ def load_circuit( lineno_offset=lineno_offset, col_offset=col_offset, compactify=compactify, + register_as_argument=register_as_argument, ) - # NOTE: no return value - return_value = func.ConstantNone() - body.blocks[0].stmts.append(return_value) - body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value)) + if return_register: + return_value = target.qreg + else: + return_value = func.ConstantNone() + body.blocks[0].stmts.append(return_value) + + return_node = func.Return(value_or_stmt=return_value) + body.blocks[0].stmts.append(return_node) + + if register_as_argument: + args = (target.qreg.type,) + else: + args = () + + # NOTE: add _self as argument; need to know signature before so do it after lowering + signature = func.Signature(args, return_node.value.type) + body.blocks[0].args.insert_from( + 0, + types.Generic(ir.Method, types.Tuple.where(signature.inputs), signature.output), + kernel_name + "_self", + ) code = func.Function( sym_name=kernel_name, - signature=func.Signature((), types.NoneType), + signature=signature, body=body, ) @@ -88,7 +140,7 @@ def load_circuit( mod=None, py_func=None, sym_name=kernel_name, - arg_names=[], + arg_names=[arg.name for arg in body.blocks[0].args if arg.name is not None], dialects=dialects, code=code, ) diff --git a/src/bloqade/squin/cirq/lowering.py b/src/bloqade/squin/cirq/lowering.py index e06a666c..0b4ab720 100644 --- a/src/bloqade/squin/cirq/lowering.py +++ b/src/bloqade/squin/cirq/lowering.py @@ -3,7 +3,7 @@ from dataclasses import field, dataclass import cirq -from kirin import ir, lowering +from kirin import ir, types, lowering from kirin.rewrite import Walk, CFGCompactify from kirin.dialects import py, scf, ilist @@ -25,20 +25,19 @@ class Squin(lowering.LoweringABC[CirqNode]): """Lower a cirq.Circuit object to a squin kernel""" circuit: cirq.Circuit - qreg: qubit.New = field(init=False) + qreg: ir.SSAValue = field(init=False) qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict) next_qreg_index: int = field(init=False, default=0) - def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid): - index = self.qreg_index.get(qid) - - if index is None: - index = self.next_qreg_index - self.qreg_index[qid] = index - self.next_qreg_index += 1 + def __post_init__(self): + # TODO: sort by cirq ordering + qbits = sorted(self.circuit.all_qubits()) + self.qreg_index = {qid: idx for (idx, qid) in enumerate(qbits)} + def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid): + index = self.qreg_index[qid] index_ssa = state.current_frame.push(py.Constant(index)).result - qbit_getitem = state.current_frame.push(py.GetItem(self.qreg.result, index_ssa)) + qbit_getitem = state.current_frame.push(py.GetItem(self.qreg, index_ssa)) return qbit_getitem.result def lower_qubit_getindices( @@ -64,6 +63,7 @@ def run( lineno_offset: int = 0, col_offset: int = 0, compactify: bool = True, + register_as_argument: bool = False, ) -> ir.Region: state = lowering.State( @@ -73,16 +73,20 @@ def run( col_offset=col_offset, ) - with state.frame( - [stmt], - globals=globals, - finalize_next=False, - ) as frame: - # NOTE: create a global register of qubits first - # TODO: can there be a circuit without qubits? - n_qubits = cirq.num_qubits(self.circuit) - n = frame.push(py.Constant(n_qubits)) - self.qreg = frame.push(qubit.New(n_qubits=n.result)) + with state.frame([stmt], globals=globals, finalize_next=False) as frame: + + # NOTE: need a register of qubits before lowering statements + if register_as_argument: + # NOTE: register as argument to the kernel; we have freedom of choice for the name here + frame.curr_block.args.append_from( + ilist.IListType[qubit.QubitType, types.Any], name="q" + ) + self.qreg = frame.curr_block.args[0] + else: + # NOTE: create a new register of appropriate size + n_qubits = len(self.qreg_index) + n = frame.push(py.Constant(n_qubits)) + self.qreg = frame.push(qubit.New(n_qubits=n.result)).result self.visit(state, stmt) diff --git a/test/squin/cirq/test_cirq_to_squin.py b/test/squin/cirq/test_cirq_to_squin.py index 0e7162d8..2bc63e00 100644 --- a/test/squin/cirq/test_cirq_to_squin.py +++ b/test/squin/cirq/test_cirq_to_squin.py @@ -2,6 +2,9 @@ import cirq import pytest +from kirin import types +from kirin.passes import inline +from kirin.dialects import ilist from bloqade import squin from bloqade.pyqrack import DynamicMemorySimulator @@ -124,6 +127,19 @@ def noise_channels(): ) +def nested_circuit(): + q = cirq.LineQubit.range(3) + + return cirq.Circuit( + cirq.H(q[0]), + cirq.CircuitOperation( + cirq.Circuit(cirq.H(q[1]), cirq.CX(q[1], q[2])).freeze(), + use_repetition_ids=False, + ).controlled_by(q[0]), + cirq.measure(*q), + ) + + @pytest.mark.parametrize( "circuit_f", [ @@ -161,6 +177,66 @@ def test_circuit(circuit_f, run_sim: bool = False): print(ket) +def test_return_register(): + circuit = basic_circuit() + kernel = squin.load_circuit(circuit, return_register=True) + kernel.print() + + assert isinstance(kernel.return_type, types.Generic) + assert kernel.return_type.body.is_subseteq(ilist.IListType) + + +@pytest.mark.xfail +def test_nested_circuit(): + # TODO: lowering for CircuitOperation + test_circuit(nested_circuit) + + +def test_passing_in_register(): + circuit = pow_gate_circuit() + print(circuit) + kernel = squin.cirq.load_circuit(circuit, register_as_argument=True) + kernel.print() + + +def test_passing_and_returning_register(): + circuit = pow_gate_circuit() + print(circuit) + kernel = squin.cirq.load_circuit( + circuit, register_as_argument=True, return_register=True + ) + kernel.print() + + +def test_nesting_lowered_circuit(): + q = cirq.LineQubit.range(2) + circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q)) + + get_entangled_qubits = squin.cirq.load_circuit( + circuit, return_register=True, kernel_name="get_entangled_qubits" + ) + get_entangled_qubits.print() + + entangle_qubits = squin.cirq.load_circuit( + circuit, register_as_argument=True, kernel_name="entangle_qubits" + ) + + @squin.kernel + def main(): + qreg = get_entangled_qubits() + qreg2 = squin.qubit.new(1) + entangle_qubits([qreg[1], qreg2[0]]) + return squin.qubit.measure(qreg2) + + # if you get up to here, the validation works + main.print() + + # inline to see if the IR is correct + inline.InlinePass(main.dialects)(main) + + main.print() + + def test_classical_control(run_sim: bool = False): q = cirq.LineQubit.range(2) circuit = cirq.Circuit(