Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/bloqade/pyqrack/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,9 @@ def measure_qubit(
qbit: PyQrackQubit = frame.get(stmt.qubit)
result = self._measure_qubit(qbit, interp)
return (result,)

@interp.impl(qubit.Moment)
def moment(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Moment
):
return interp.run_ssacfg_region(frame, stmt.body, args=())
5 changes: 5 additions & 0 deletions src/bloqade/squin/cirq/emit/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,8 @@ def measure_qubit_list(
qbits = frame.get(stmt.qubits)
frame.circuit.append(cirq.measure(qbits))
return (emit.void,)

@impl(qubit.Moment)
def moment(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Moment):
emit.run_ssacfg_region(frame, stmt.body, args=())
return (emit.void,)
73 changes: 51 additions & 22 deletions src/bloqade/squin/cirq/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,18 @@ def visit_Circuit(
def visit_Moment(
self, state: lowering.State[CirqNode], node: cirq.Moment
) -> lowering.Result:
for op_ in node.operations:
state.lower(op_)
frame = state.current_frame
with state.frame(node.operations) as body_frame:
body_frame.defs.update(frame.defs)
body_frame.globals.update(frame.globals)

body_frame.exhaust()
body_region = body_frame.curr_region

state.current_frame.defs.update(body_frame.defs)
state.current_frame.globals.update(body_frame.globals)

state.current_frame.push(qubit.Moment(body=body_region))

def visit_GateOperation(
self, state: lowering.State[CirqNode], node: cirq.GateOperation
Expand Down Expand Up @@ -173,24 +183,8 @@ def visit_ClassicallyControlledOperation(
measurement_outcome = state.current_frame.defs[key]

if measurement_outcome.type.is_subseteq(ilist.IListType):
# NOTE: there is currently no convenient ilist.any method, so we need to use foldl
# with a simple function that just does an or

def bool_op_or(x: bool, y: bool) -> bool:
return x or y

f_code = state.current_frame.push(
lowering.Python(self.dialects).python_function(bool_op_or)
)
fn = ir.Method(
mod=None,
py_func=bool_op_or,
sym_name="bool_op_or",
arg_names=[],
dialects=self.dialects,
code=f_code,
)
f_const = state.current_frame.push(py.constant.Constant(fn))
# TODO: replace by ilist.Any
f_const = self._get_or_func(state)
init_val = state.current_frame.push(py.Constant(False)).result
condition = state.current_frame.push(
ilist.Foldl(f_const.result, measurement_outcome, init=init_val)
Expand Down Expand Up @@ -220,9 +214,44 @@ def bool_op_or(x: bool, y: bool) -> bool:

# NOTE: remove stmt from parent block
then_stmt.detach()
then_body = ir.Block((then_stmt,))
then_body = ir.Block((then_stmt, scf.Yield()))
then_body.args.append_from(types.Bool)

# NOTE: create empty else body
else_body = ir.Block(stmts=[scf.Yield()])
else_body.args.append_from(types.Bool)

return state.current_frame.push(
scf.IfElse(condition, then_body=then_body, else_body=else_body)
)

return state.current_frame.push(scf.IfElse(condition, then_body=then_body))
def _get_or_func(self, state: lowering.State[CirqNode]):
# NOTE: there is currently no convenient ilist.any method, so we need to use foldl
# with a simple function that just does an or

# NOTE: check if we already defined that function
f_prev = state.current_frame.globals.get("__BOOL_OR_FUNC")
if f_prev is not None:
return f_prev

def bool_op_or(x: bool, y: bool) -> bool:
return x or y

f_code = lowering.Python(self.dialects).python_function(bool_op_or)

if f_code.parent is None:
state.current_frame.push(f_code)
fn = ir.Method(
mod=None,
py_func=bool_op_or,
sym_name="bool_op_or",
arg_names=[],
dialects=self.dialects,
code=f_code,
)
f = state.current_frame.push(py.constant.Constant(fn))
state.current_frame.globals["__BOOL_OR_FUNC"] = f
return f

def visit_SingleQubitPauliStringGateOperation(
self,
Expand Down
6 changes: 6 additions & 0 deletions src/bloqade/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class MeasureQubitList(ir.Statement):
result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType])


@statement(dialect=dialect)
class Moment(ir.Statement):
traits = frozenset({ir.SSACFGRegion()})
body: ir.Region = info.region()


# NOTE: no dependent types in Python, so we have to mark it Any...
@wraps(New)
def new(n_qubits: int) -> ilist.IList[Qubit, Any]:
Expand Down
37 changes: 37 additions & 0 deletions test/squin/cirq/test_cirq_to_squin.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def test_circuit(circuit_f, run_sim: bool = False):
kernel = squin.load_circuit(circuit)

kernel.print()
kernel.verify()

rewrite_noise_pass(kernel)

Expand All @@ -204,6 +205,7 @@ def test_return_register():
circuit = basic_circuit()
kernel = squin.load_circuit(circuit, return_register=True)
kernel.print()
kernel.verify()

assert isinstance(kernel.return_type, types.Generic)
assert kernel.return_type.body.is_subseteq(ilist.IListType)
Expand All @@ -220,6 +222,7 @@ def test_passing_in_register():
print(circuit)
kernel = squin.cirq.load_circuit(circuit, register_as_argument=True)
kernel.print()
kernel.verify()


def test_passing_and_returning_register():
Expand All @@ -229,6 +232,7 @@ def test_passing_and_returning_register():
circuit, register_as_argument=True, return_register=True
)
kernel.print()
kernel.verify()


def test_nesting_lowered_circuit():
Expand Down Expand Up @@ -277,6 +281,7 @@ def test_classical_control(run_sim: bool = False):

kernel = squin.cirq.load_circuit(circuit)
kernel.print()
kernel.verify()


def test_classical_control_register():
Expand All @@ -292,6 +297,7 @@ def test_classical_control_register():

kernel = squin.cirq.load_circuit(circuit)
kernel.print()
kernel.verify()


def test_multiple_classical_controls(run_sim: bool = False):
Expand All @@ -314,6 +320,37 @@ def test_multiple_classical_controls(run_sim: bool = False):

kernel = squin.cirq.load_circuit(circuit)
kernel.print()
kernel.verify()

if run_sim:
from bloqade.pyqrack import StackMemorySimulator

simq = StackMemorySimulator(min_qubits=3)
simq.run(kernel)


def test_multiple_classical_control_registers(run_sim: bool = False):
q = cirq.LineQubit.range(2)
q2 = cirq.GridQubit.rect(1, 2)
circuit = cirq.Circuit(
cirq.H(q[0]),
cirq.H(q2[0]),
cirq.measure(q, key="test"),
cirq.measure(q2, key="test2"),
cirq.X(q[1]).with_classical_controls("test", "test2"),
cirq.X(q[1]).with_classical_controls("test"),
cirq.measure(q[1]),
)

print(circuit)

if run_sim:
sim = cirq.Simulator()
sim.run(circuit)

kernel = squin.cirq.load_circuit(circuit)
kernel.print()
kernel.verify()


def test_ghz_simulation():
Expand Down
23 changes: 23 additions & 0 deletions test/squin/cirq/test_moment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import cirq

from bloqade import squin

q = cirq.LineQubit.range(3)
circuit = cirq.Circuit(
cirq.Moment(
cirq.X(q[0]),
cirq.H(q[1]),
),
cirq.Moment(
cirq.X(q[2]),
),
)

kernel = squin.cirq.load_circuit(circuit)

kernel.print()


circuit2 = squin.cirq.emit_circuit(kernel)

print(circuit2)