Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/bloqade/cirq_utils/emit/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):

@impl(op.stmts.Identity)
def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
op = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites))
sites = frame.get(stmt.sites)
op = HermitianRuntime(cirq.IdentityGate(num_qubits=sites))
return (op,)

@impl(op.stmts.Control)
Expand Down
9 changes: 8 additions & 1 deletion src/bloqade/cirq_utils/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,14 +378,21 @@ def visit_SingleQubitPauliStringGateOperation(
case cirq.Z:
op_ = op.stmts.Z()
case cirq.I:
op_ = op.stmts.Identity(sites=1)
site = state.current_frame.push(py.Constant(1)).result
op_ = op.stmts.Identity(sites=site)
case _:
raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}")

state.current_frame.push(op_)
qargs = self.lower_qubit_getindices(state, [node.qubit])
return state.current_frame.push(qubit.Apply(op_.result, qargs))

def visit_IdentityGate(
self, state: lowering.State[CirqNode], node: cirq.IdentityGate
):
sites = state.current_frame.push(py.Constant(node.num_qubits()))
return state.current_frame.push(op.stmts.Identity(sites=sites.result))

def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate):
if abs(node.exponent) == 1:
return state.current_frame.push(op.stmts.H())
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/pyqrack/squin/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def rot(
def identity(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity
) -> tuple[OperatorRuntimeABC]:
return (IdentityRuntime(sites=stmt.sites),)
sites = frame.get(stmt.sites)
return (IdentityRuntime(sites=sites),)

@interp.impl(op.stmts.PhaseOp)
@interp.impl(op.stmts.ShiftOp)
Expand Down
4 changes: 4 additions & 0 deletions src/bloqade/squin/analysis/nsites/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@


class NSitesAnalysis(Forward[Sites]):
"""Analysis pass to infer number of sites an operator applies to.

**NOTE**: run kirin.passes.HintConst prior to using this analysis pass.
"""

keys = ["op.nsites"]
lattice = Sites
Expand Down
16 changes: 16 additions & 0 deletions src/bloqade/squin/analysis/nsites/impls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from kirin import interp
from kirin.analysis import const
from kirin.dialects import scf, func, ilist
from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer

Expand Down Expand Up @@ -88,6 +89,21 @@ def pauli_string(
s = stmt.string
return (NumberSites(sites=len(s)),)

@interp.impl(op.stmts.Identity)
def identity(
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Identity
):
sites = stmt.sites
const_hint = sites.hints.get("const")

if not isinstance(const_hint, const.Value):
return (interp.lattice.top(),)

if not isinstance(n_sites := const_hint.data, int):
return (interp.lattice.top(),)

return (NumberSites(sites=n_sites),)


@ilist.dialect.register(key="op.nsites")
class IListMethods(interp.MethodTable):
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/squin/op/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def reset_to_one() -> types.Op: ...


@wraps(stmts.Identity)
def identity(*, sites: int) -> types.Op: ...
def identity(sites: int) -> types.Op: ...


@wraps(stmts.Rot)
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/squin/op/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class Rot(CompositeOp):
@statement(dialect=dialect)
class Identity(CompositeOp):
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
sites: int = info.attribute()
sites: ir.SSAValue = info.argument(types.Int)


@statement
Expand Down
57 changes: 31 additions & 26 deletions src/bloqade/squin/rewrite/U3_to_clifford.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,39 @@ def sdag() -> list[ir.Statement]:
return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)]


def single_site_identity(node: ir.Statement) -> Tuple[List[ir.Statement], ...]:
(site := py.Constant(1)).insert_before(node)
return ([op.stmts.Identity(sites=site.result)],)


# (theta, phi, lam)
U3_HALF_PI_ANGLE_TO_GATES: dict[
tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]]
tuple[int, int, int], Callable[[ir.Statement], Tuple[List[ir.Statement], ...]]
] = {
(0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],),
(0, 0, 1): lambda: ([op.stmts.S()],),
(0, 0, 2): lambda: ([op.stmts.Z()],),
(0, 0, 3): lambda: (sdag(),),
(1, 0, 0): lambda: ([op.stmts.SqrtY()],),
(1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]),
(1, 0, 2): lambda: ([op.stmts.H()],),
(1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]),
(1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()),
(1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
(1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
(1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()),
(2, 0, 0): lambda: ([op.stmts.Y()],),
(2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]),
(2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]),
(2, 0, 3): lambda: (sdag(), [op.stmts.Y()]),
(0, 0, 0): single_site_identity,
(0, 0, 1): lambda node: ([op.stmts.S()],),
(0, 0, 2): lambda node: ([op.stmts.Z()],),
(0, 0, 3): lambda node: (sdag(),),
(1, 0, 0): lambda node: ([op.stmts.SqrtY()],),
(1, 0, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()]),
(1, 0, 2): lambda node: ([op.stmts.H()],),
(1, 0, 3): lambda node: (sdag(), [op.stmts.SqrtY()]),
(1, 1, 0): lambda node: ([op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 3): lambda node: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 2, 0): lambda node: ([op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 3): lambda node: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 3, 0): lambda node: ([op.stmts.SqrtY()], sdag()),
(1, 3, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
(1, 3, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
(1, 3, 3): lambda node: (sdag(), [op.stmts.SqrtY()], sdag()),
(2, 0, 0): lambda node: ([op.stmts.Y()],),
(2, 0, 1): lambda node: ([op.stmts.S()], [op.stmts.Y()]),
(2, 0, 2): lambda node: ([op.stmts.Z()], [op.stmts.Y()]),
(2, 0, 3): lambda node: (sdag(), [op.stmts.Y()]),
}


Expand Down Expand Up @@ -154,4 +159,4 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
gates_stmts is not None
), "internal error, U3 gates not found for angles: {}".format(angles_key)

return gates_stmts()
return gates_stmts(node)
1 change: 1 addition & 0 deletions test/cirq_utils/test_cirq_to_squin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def basic_circuit():

# Create a circuit.
return cirq.Circuit(
cirq.I(qubit),
cirq.X(qubit),
cirq.Y(qubit2),
cirq.Z(qubit),
Expand Down
71 changes: 70 additions & 1 deletion test/squin/analysis/test_nsites_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from kirin import ir, types
from kirin.passes import Fold
from kirin.dialects import py, func
from kirin.dialects import py, scf, func

from bloqade import squin
from bloqade.squin import op, noise, qubit
Expand Down Expand Up @@ -304,3 +304,72 @@ def test():
assert has_n_sites[6].sites == 1 # single_qubit_pauli_channel
assert has_n_sites[7].sites == 2 # two_qubit_pauli_channel
assert has_n_sites[8].sites == 1 # qubit_loss


def test_identity():

@squin.kernel(fold=False)
def main1():
squin.op.identity(2)

nsites_frame, _ = nsites.NSitesAnalysis(main1.dialects).run_analysis(
main1, no_raise=False
)

main1.print(analysis=nsites_frame.entries)

assert [nsites_frame.entries[result] for result in results_at(main1, 0, 1)] == [
nsites.NumberSites(sites=2)
]

@squin.kernel(fold=False)
def main2(n: int):
squin.op.identity(n)

nsites_frame, _ = nsites.NSitesAnalysis(main2.dialects).run_analysis(
main2, no_raise=False
)

main2.print(analysis=nsites_frame.entries)

assert [nsites_frame.entries[result] for result in results_at(main2, 0, 0)] == [
nsites.AnySites()
]

@squin.kernel(fold=False)
def main3():
n = 3
squin.op.identity(n)

nsites_frame, _ = nsites.NSitesAnalysis(main3.dialects).run_analysis(
main3, no_raise=False
)

main3.print(analysis=nsites_frame.entries)

assert [nsites_frame.entries[result] for result in results_at(main3, 0, 1)] == [
nsites.NumberSites(3)
]

@squin.kernel(fold=False)
def main4():
x = 0
for i in range(3):
squin.op.identity(i)

# NOTE: need a workaround yield here, else kirin fails verification complaining about missing yield
x += 1

nsites_frame, _ = nsites.NSitesAnalysis(main4.dialects).run_analysis(
main4, no_raise=False
)

main4.print(analysis=nsites_frame.entries)

for stmt in main4.callable_region.stmts():
if not isinstance(stmt, scf.For):
continue

for body_stmt in stmt.body.stmts():
if isinstance(body_stmt, squin.op.stmts.Identity):
assert nsites_frame.get(body_stmt.result) == nsites.AnySites()
2 changes: 1 addition & 1 deletion test/squin/rewrite/test_U3_to_clifford.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test():

SquinToCliffordTestPass(test.dialects)(test)

assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Identity)
assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Identity)


def test_s():
Expand Down
4 changes: 2 additions & 2 deletions test/stim/passes/test_squin_wire_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def test_wire_multiple_apply():
# pass the wires through some 1 Qubit operators
(op1 := squin.op.stmts.S()),
(op2 := squin.op.stmts.H()),
(op3 := squin.op.stmts.Identity(sites=1)),
(op4 := squin.op.stmts.Identity(sites=1)),
(op3 := squin.op.stmts.Identity(sites=n_qubits.result)),
(op4 := squin.op.stmts.Identity(sites=n_qubits.result)),
(v0 := squin.wire.Apply(op1.result, w0.result)),
(v1 := squin.wire.Apply(op2.result, v0.results[0])),
(v2 := squin.wire.Apply(op3.result, v1.results[0])),
Expand Down