Skip to content

Allow passing in and returning the quantum register when loading cirq.Circuits into squin #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
66 changes: 59 additions & 7 deletions src/bloqade/squin/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def load_circuit(
circuit: cirq.Circuit,
kernel_name: str = "main",
dialects: ir.DialectGroup = kernel,
register_as_argument: bool = False,
return_register: bool = False,
globals: dict[str, Any] | None = None,
file: str | None = None,
lineno_offset: int = 0,
Expand All @@ -32,13 +34,21 @@ def load_circuit(
Keyword Args:
kernel_name (str): The name of the kernel to load. Defaults to "main".
dialects (ir.DialectGroup | None): The dialects to use. Defaults to `squin.kernel`.
register_as_argument (bool): Determine whether the resulting kernel function should accept
a single `ilist.IList[Qubit, Any]` argument that is a list of qubits used within the
function. This allows you to compose kernel functions generated from circuits.
Defaults to `False`.
return_register (bool): Determine whether the resulting kernel functionr returns a
single value of type `ilist.IList[Qubit, Any]` that is the list of qubits used
in the kernel function. Useful when you want to compose multiple kernel functions
generated from circuits. Defaults to `False`.
globals (dict[str, Any] | None): The global variables to use. Defaults to None.
file (str | None): The file name for error reporting. Defaults to None.
lineno_offset (int): The line number offset for error reporting. Defaults to 0.
col_offset (int): The column number offset for error reporting. Defaults to 0.
compactify (bool): Whether to compactify the output. Defaults to True.

Example:
## Usage Examples:

```python
# from cirq's "hello qubit" example
Expand All @@ -60,6 +70,30 @@ def load_circuit(
# print the resulting IR
main.print()
```

You can also compose kernel functions generated from circuits by passing in
and / or returning the respective quantum registers:

```python
q = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q))

get_entangled_qubits = squin.cirq.load_circuit(
circuit, return_register=True, kernel_name="get_entangled_qubits"
)
get_entangled_qubits.print()

entangle_qubits = squin.cirq.load_circuit(
circuit, register_as_argument=True, kernel_name="entangle_qubits"
)

@squin.kernel
def main():
qreg = get_entangled_qubits()
qreg2 = squin.qubit.new(1)
entangle_qubits([qreg[1], qreg2[0]])
return squin.qubit.measure(qreg2)
```
"""

target = Squin(dialects=dialects, circuit=circuit)
Expand All @@ -71,24 +105,42 @@ def load_circuit(
lineno_offset=lineno_offset,
col_offset=col_offset,
compactify=compactify,
register_as_argument=register_as_argument,
)

# NOTE: no return value
return_value = func.ConstantNone()
body.blocks[0].stmts.append(return_value)
body.blocks[0].stmts.append(func.Return(value_or_stmt=return_value))
if return_register:
return_value = target.qreg
else:
return_value = func.ConstantNone()
body.blocks[0].stmts.append(return_value)

return_node = func.Return(value_or_stmt=return_value)
body.blocks[0].stmts.append(return_node)

if register_as_argument:
args = (target.qreg.type,)
else:
args = ()

# NOTE: add _self as argument; need to know signature before so do it after lowering
signature = func.Signature(args, return_node.value.type)
body.blocks[0].args.insert_from(
0,
types.Generic(ir.Method, types.Tuple.where(signature.inputs), signature.output),
kernel_name + "_self",
)

code = func.Function(
sym_name=kernel_name,
signature=func.Signature((), types.NoneType),
signature=signature,
body=body,
)

return ir.Method(
mod=None,
py_func=None,
sym_name=kernel_name,
arg_names=[],
arg_names=[arg.name for arg in body.blocks[0].args if arg.name is not None],
dialects=dialects,
code=code,
)
Expand Down
44 changes: 24 additions & 20 deletions src/bloqade/squin/cirq/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import field, dataclass

import cirq
from kirin import ir, lowering
from kirin import ir, types, lowering
from kirin.rewrite import Walk, CFGCompactify
from kirin.dialects import py, scf, ilist

Expand All @@ -25,20 +25,19 @@ class Squin(lowering.LoweringABC[CirqNode]):
"""Lower a cirq.Circuit object to a squin kernel"""

circuit: cirq.Circuit
qreg: qubit.New = field(init=False)
qreg: ir.SSAValue = field(init=False)
qreg_index: dict[cirq.Qid, int] = field(init=False, default_factory=dict)
next_qreg_index: int = field(init=False, default=0)

def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
index = self.qreg_index.get(qid)

if index is None:
index = self.next_qreg_index
self.qreg_index[qid] = index
self.next_qreg_index += 1
def __post_init__(self):
# TODO: sort by cirq ordering
qbits = sorted(self.circuit.all_qubits())
self.qreg_index = {qid: idx for (idx, qid) in enumerate(qbits)}

def lower_qubit_getindex(self, state: lowering.State[CirqNode], qid: cirq.Qid):
index = self.qreg_index[qid]
index_ssa = state.current_frame.push(py.Constant(index)).result
qbit_getitem = state.current_frame.push(py.GetItem(self.qreg.result, index_ssa))
qbit_getitem = state.current_frame.push(py.GetItem(self.qreg, index_ssa))
return qbit_getitem.result

def lower_qubit_getindices(
Expand All @@ -64,6 +63,7 @@ def run(
lineno_offset: int = 0,
col_offset: int = 0,
compactify: bool = True,
register_as_argument: bool = False,
) -> ir.Region:

state = lowering.State(
Expand All @@ -73,16 +73,20 @@ def run(
col_offset=col_offset,
)

with state.frame(
[stmt],
globals=globals,
finalize_next=False,
) as frame:
# NOTE: create a global register of qubits first
# TODO: can there be a circuit without qubits?
n_qubits = cirq.num_qubits(self.circuit)
n = frame.push(py.Constant(n_qubits))
self.qreg = frame.push(qubit.New(n_qubits=n.result))
with state.frame([stmt], globals=globals, finalize_next=False) as frame:

# NOTE: need a register of qubits before lowering statements
if register_as_argument:
# NOTE: register as argument to the kernel; we have freedom of choice for the name here
frame.curr_block.args.append_from(
ilist.IListType[qubit.QubitType, types.Any], name="q"
)
self.qreg = frame.curr_block.args[0]
else:
# NOTE: create a new register of appropriate size
n_qubits = len(self.qreg_index)
n = frame.push(py.Constant(n_qubits))
self.qreg = frame.push(qubit.New(n_qubits=n.result)).result

self.visit(state, stmt)

Expand Down
76 changes: 76 additions & 0 deletions test/squin/cirq/test_cirq_to_squin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import cirq
import pytest
from kirin import types
from kirin.passes import inline
from kirin.dialects import ilist

from bloqade import squin
from bloqade.pyqrack import DynamicMemorySimulator
Expand Down Expand Up @@ -124,6 +127,19 @@ def noise_channels():
)


def nested_circuit():
q = cirq.LineQubit.range(3)

return cirq.Circuit(
cirq.H(q[0]),
cirq.CircuitOperation(
cirq.Circuit(cirq.H(q[1]), cirq.CX(q[1], q[2])).freeze(),
use_repetition_ids=False,
).controlled_by(q[0]),
cirq.measure(*q),
)


@pytest.mark.parametrize(
"circuit_f",
[
Expand Down Expand Up @@ -161,6 +177,66 @@ def test_circuit(circuit_f, run_sim: bool = False):
print(ket)


def test_return_register():
circuit = basic_circuit()
kernel = squin.load_circuit(circuit, return_register=True)
kernel.print()

assert isinstance(kernel.return_type, types.Generic)
assert kernel.return_type.body.is_subseteq(ilist.IListType)


@pytest.mark.xfail
def test_nested_circuit():
# TODO: lowering for CircuitOperation
test_circuit(nested_circuit)


def test_passing_in_register():
circuit = pow_gate_circuit()
print(circuit)
kernel = squin.cirq.load_circuit(circuit, register_as_argument=True)
kernel.print()


def test_passing_and_returning_register():
circuit = pow_gate_circuit()
print(circuit)
kernel = squin.cirq.load_circuit(
circuit, register_as_argument=True, return_register=True
)
kernel.print()


def test_nesting_lowered_circuit():
q = cirq.LineQubit.range(2)
circuit = cirq.Circuit(cirq.H(q[0]), cirq.CX(*q))

get_entangled_qubits = squin.cirq.load_circuit(
circuit, return_register=True, kernel_name="get_entangled_qubits"
)
get_entangled_qubits.print()

entangle_qubits = squin.cirq.load_circuit(
circuit, register_as_argument=True, kernel_name="entangle_qubits"
)

@squin.kernel
def main():
qreg = get_entangled_qubits()
qreg2 = squin.qubit.new(1)
entangle_qubits([qreg[1], qreg2[0]])
return squin.qubit.measure(qreg2)

# if you get up to here, the validation works
main.print()

# inline to see if the IR is correct
inline.InlinePass(main.dialects)(main)

main.print()


def test_classical_control(run_sim: bool = False):
q = cirq.LineQubit.range(2)
circuit = cirq.Circuit(
Expand Down