Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/bloqade/pyqrack/squin/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def rot(
def identity(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: op.stmts.Identity
) -> tuple[OperatorRuntimeABC]:
return (IdentityRuntime(sites=stmt.sites),)
sites = frame.get(stmt.sites)
return (IdentityRuntime(sites=sites),)

@interp.impl(op.stmts.PhaseOp)
@interp.impl(op.stmts.ShiftOp)
Expand Down
21 changes: 19 additions & 2 deletions src/bloqade/squin/analysis/nsites/impls.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from kirin import interp
from kirin.dialects import scf, func, ilist
from kirin import ir, interp
from kirin.dialects import py, scf, func, ilist
from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer

from bloqade.squin import op, wire, noise
Expand Down Expand Up @@ -88,6 +88,23 @@ def pauli_string(
s = stmt.string
return (NumberSites(sites=len(s)),)

@interp.impl(op.stmts.Identity)
def identity(
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Identity
):
sites = stmt.sites

if not isinstance(sites, ir.ResultValue):
return (interp.lattice.top(),)

if not isinstance(site_stmt := sites.stmt, py.Constant):
return (interp.lattice.top(),)

if not isinstance(value := site_stmt.value, ir.PyAttr):
return (interp.lattice.top(),)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should return bottom instead of top tho

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also this does not deal with alias, so double check if the analysis are run before alias inline pass

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or, assuming wrap const and cont prop is done before this analysis, then you can get the constant it from hint["const"]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's not run, which is why there were no hints. But I'll just include it in the run then.

this should return bottom instead of top tho

Can you explain why? I thought since there is no error, but you just can't know how many sites there are since sites is not a constant that top would be the way to go as it's AnySites.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah right... nvm

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just change the get constant as @weinbe58 mentioned below then


return (NumberSites(sites=value.data),)


@ilist.dialect.register(key="op.nsites")
class IListMethods(interp.MethodTable):
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/squin/cirq/emit/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):

@impl(op.stmts.Identity)
def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
op = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites))
sites = frame.get(stmt.sites)
op = HermitianRuntime(cirq.IdentityGate(num_qubits=sites))
return (op,)

@impl(op.stmts.Control)
Expand Down
9 changes: 8 additions & 1 deletion src/bloqade/squin/cirq/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,21 @@ def visit_SingleQubitPauliStringGateOperation(
case cirq.Z:
op_ = op.stmts.Z()
case cirq.I:
op_ = op.stmts.Identity(sites=1)
site = state.current_frame.push(py.Constant(1)).result
op_ = op.stmts.Identity(sites=site)
case _:
raise lowering.BuildError(f"Unexpected Pauli operation {node.pauli}")

state.current_frame.push(op_)
qargs = self.lower_qubit_getindices(state, [node.qubit])
return state.current_frame.push(qubit.Apply(op_.result, qargs))

def visit_IdentityGate(
self, state: lowering.State[CirqNode], node: cirq.IdentityGate
):
sites = state.current_frame.push(py.Constant(node.num_qubits()))
return state.current_frame.push(op.stmts.Identity(sites=sites.result))

def visit_HPowGate(self, state: lowering.State[CirqNode], node: cirq.HPowGate):
if abs(node.exponent) == 1:
return state.current_frame.push(op.stmts.H())
Expand Down
3 changes: 2 additions & 1 deletion src/bloqade/squin/noise/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def rewrite_two_qubit_pauli_channel(
def _insert_two_qubit_paulis_before_node(
node: TwoQubitPauliChannel | Depolarize2,
) -> ir.ResultValue:
paulis = (Identity(sites=1), X(), Y(), Z())
(site := py.Constant(1)).insert_before(node)
paulis = (Identity(sites=site.result), X(), Y(), Z())
for op in paulis:
op.insert_before(node)

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/squin/op/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def reset_to_one() -> types.Op: ...


@wraps(stmts.Identity)
def identity(*, sites: int) -> types.Op: ...
def identity(sites: int) -> types.Op: ...


@wraps(stmts.Rot)
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/squin/op/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class Rot(CompositeOp):
@statement(dialect=dialect)
class Identity(CompositeOp):
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
sites: int = info.attribute()
sites: ir.SSAValue = info.argument(types.Int)


@statement
Expand Down
57 changes: 31 additions & 26 deletions src/bloqade/squin/rewrite/U3_to_clifford.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,39 @@ def sdag() -> list[ir.Statement]:
return [_op := op.stmts.S(), op.stmts.Adjoint(op=_op.result, is_unitary=True)]


def single_site_identity(node: ir.Statement) -> Tuple[List[ir.Statement], ...]:
(site := py.Constant(1)).insert_before(node)
return ([op.stmts.Identity(sites=site.result)],)


# (theta, phi, lam)
U3_HALF_PI_ANGLE_TO_GATES: dict[
tuple[int, int, int], Callable[[], Tuple[List[ir.Statement], ...]]
tuple[int, int, int], Callable[[ir.Statement], Tuple[List[ir.Statement], ...]]
] = {
(0, 0, 0): lambda: ([op.stmts.Identity(sites=1)],),
(0, 0, 1): lambda: ([op.stmts.S()],),
(0, 0, 2): lambda: ([op.stmts.Z()],),
(0, 0, 3): lambda: (sdag(),),
(1, 0, 0): lambda: ([op.stmts.SqrtY()],),
(1, 0, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()]),
(1, 0, 2): lambda: ([op.stmts.H()],),
(1, 0, 3): lambda: (sdag(), [op.stmts.SqrtY()]),
(1, 1, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 2, 0): lambda: ([op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 3): lambda: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 3, 0): lambda: ([op.stmts.SqrtY()], sdag()),
(1, 3, 1): lambda: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
(1, 3, 2): lambda: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
(1, 3, 3): lambda: (sdag(), [op.stmts.SqrtY()], sdag()),
(2, 0, 0): lambda: ([op.stmts.Y()],),
(2, 0, 1): lambda: ([op.stmts.S()], [op.stmts.Y()]),
(2, 0, 2): lambda: ([op.stmts.Z()], [op.stmts.Y()]),
(2, 0, 3): lambda: (sdag(), [op.stmts.Y()]),
(0, 0, 0): single_site_identity,
(0, 0, 1): lambda node: ([op.stmts.S()],),
(0, 0, 2): lambda node: ([op.stmts.Z()],),
(0, 0, 3): lambda node: (sdag(),),
(1, 0, 0): lambda node: ([op.stmts.SqrtY()],),
(1, 0, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()]),
(1, 0, 2): lambda node: ([op.stmts.H()],),
(1, 0, 3): lambda node: (sdag(), [op.stmts.SqrtY()]),
(1, 1, 0): lambda node: ([op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 1, 3): lambda node: (sdag(), [op.stmts.SqrtY()], [op.stmts.S()]),
(1, 2, 0): lambda node: ([op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 2, 3): lambda node: (sdag(), [op.stmts.SqrtY()], [op.stmts.Z()]),
(1, 3, 0): lambda node: ([op.stmts.SqrtY()], sdag()),
(1, 3, 1): lambda node: ([op.stmts.S()], [op.stmts.SqrtY()], sdag()),
(1, 3, 2): lambda node: ([op.stmts.Z()], [op.stmts.SqrtY()], sdag()),
(1, 3, 3): lambda node: (sdag(), [op.stmts.SqrtY()], sdag()),
(2, 0, 0): lambda node: ([op.stmts.Y()],),
(2, 0, 1): lambda node: ([op.stmts.S()], [op.stmts.Y()]),
(2, 0, 2): lambda node: ([op.stmts.Z()], [op.stmts.Y()]),
(2, 0, 3): lambda node: (sdag(), [op.stmts.Y()]),
}


Expand Down Expand Up @@ -154,4 +159,4 @@ def decompose_U3_gates(self, node: op.stmts.U3) -> Tuple[List[ir.Statement], ...
gates_stmts is not None
), "internal error, U3 gates not found for angles: {}".format(angles_key)

return gates_stmts()
return gates_stmts(node)
8 changes: 5 additions & 3 deletions test/squin/analysis/test_nsites_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,11 @@ def test():
qubit.apply(p_3q_err, q[0], q[3], q[5])
qubit.apply(noise.depolarize2(p=0.1), q[0], q[1])

RewriteNoiseStmts(dialects=test.dialects)(test)
RewriteNoiseStmts(dialects=test.dialects, no_raise=False)(test)

nsites_frame, _ = nsites.NSitesAnalysis(test.dialects).run_analysis(test)
nsites_frame, _ = nsites.NSitesAnalysis(test.dialects).run_analysis(
test, no_raise=False
)

test.print(analysis=nsites_frame.entries)

Expand All @@ -332,6 +334,6 @@ def test():
assert [nsites_frame.entries[result] for result in results_at(test, 0, 11)] == [
nsites.NumberSites(sites=3)
]
assert [nsites_frame.entries[result] for result in results_at(test, 0, 46)] == [
assert [nsites_frame.entries[result] for result in results_at(test, 0, 47)] == [
nsites.NumberSites(sites=2)
]
1 change: 1 addition & 0 deletions test/squin/cirq/test_cirq_to_squin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def basic_circuit():

# Create a circuit.
return cirq.Circuit(
cirq.I(qubit),
cirq.X(qubit),
cirq.Y(qubit2),
cirq.Z(qubit),
Expand Down
2 changes: 1 addition & 1 deletion test/squin/rewrite/test_U3_to_clifford.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test():

SquinToCliffordTestPass(test.dialects)(test)

assert isinstance(get_stmt_at_idx(test, 4), op.stmts.Identity)
assert isinstance(get_stmt_at_idx(test, 5), op.stmts.Identity)


def test_s():
Expand Down
4 changes: 2 additions & 2 deletions test/stim/passes/test_squin_wire_to_stim.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def test_wire_multiple_apply():
# pass the wires through some 1 Qubit operators
(op1 := squin.op.stmts.S()),
(op2 := squin.op.stmts.H()),
(op3 := squin.op.stmts.Identity(sites=1)),
(op4 := squin.op.stmts.Identity(sites=1)),
(op3 := squin.op.stmts.Identity(sites=n_qubits.result)),
(op4 := squin.op.stmts.Identity(sites=n_qubits.result)),
(v0 := squin.wire.Apply(op1.result, w0.result)),
(v1 := squin.wire.Apply(op2.result, v0.results[0])),
(v2 := squin.wire.Apply(op3.result, v1.results[0])),
Expand Down