Skip to content

Commit 10dff20

Browse files
johnzl-777kaihsin
andauthored
implement translation of squin qubit loss to stim dialect (#322)
Translate the squin qubit loss statement into special stim form. The necessity of this was brought to my attention by @ChenZhao44 --------- Co-authored-by: Kai-Hsin Wu <[email protected]>
1 parent a70a49e commit 10dff20

24 files changed

+185
-29
lines changed

src/bloqade/stim/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
z_error as z_error,
3232
detector as detector,
3333
identity as identity,
34+
qubit_loss as qubit_loss,
3435
depolarize1 as depolarize1,
3536
depolarize2 as depolarize2,
3637
pauli_string as pauli_string,

src/bloqade/stim/_wrappers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,7 @@ def y_error(p: float, targets: tuple[int, ...]) -> None: ...
190190

191191
@wraps(noise.ZError)
192192
def z_error(p: float, targets: tuple[int, ...]) -> None: ...
193+
194+
195+
@wraps(noise.QubitLoss)
196+
def qubit_loss(probs: tuple[float, ...], targets: tuple[int, ...]) -> None: ...

src/bloqade/stim/dialects/noise/emit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def pauli_channel2(
6666
return ()
6767

6868
@impl(stmts.TrivialError)
69+
@impl(stmts.QubitLoss)
6970
def non_stim_error(
7071
self, emit: EmitStimMain, frame: EmitStrFrame, stmt: stmts.TrivialError
7172
):

src/bloqade/stim/dialects/noise/stmts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,8 @@ class TrivialCorrelatedError(NonStimCorrelatedError):
104104
@statement(dialect=dialect)
105105
class TrivialError(NonStimError):
106106
name = "TRIV_ERROR"
107+
108+
109+
@statement(dialect=dialect)
110+
class QubitLoss(NonStimError):
111+
name = "loss"

src/bloqade/stim/rewrite/qubit_to_stim.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from kirin import ir
22
from kirin.rewrite.abc import RewriteRule, RewriteResult
33

4-
from bloqade.squin import op, qubit
4+
from bloqade.squin import op, noise, qubit
55
from bloqade.squin.rewrite import AddressAttribute
66
from bloqade.stim.rewrite.util import (
7-
SQUIN_STIM_GATE_MAPPING,
7+
SQUIN_STIM_OP_MAPPING,
88
rewrite_Control,
9+
rewrite_QubitLoss,
910
insert_qubit_idx_from_address,
1011
)
1112

@@ -29,14 +30,18 @@ def rewrite_Apply_and_Broadcast(
2930

3031
# this is an SSAValue, need it to be the actual operator
3132
applied_op = stmt.operator.owner
33+
34+
if isinstance(applied_op, noise.stmts.QubitLoss):
35+
return rewrite_QubitLoss(stmt)
36+
3237
assert isinstance(applied_op, op.stmts.Operator)
3338

3439
if isinstance(applied_op, op.stmts.Control):
3540
return rewrite_Control(stmt)
3641

3742
# need to handle Control through separate means
3843
# but we can handle X, Y, Z, H, and S here just fine
39-
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
44+
stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op))
4045
if stim_1q_op is None:
4146
return RewriteResult()
4247

@@ -45,7 +50,6 @@ def rewrite_Apply_and_Broadcast(
4550
if address_attr is None:
4651
return RewriteResult()
4752

48-
# sometimes you can get a whole AddressReg...
4953
assert isinstance(address_attr, AddressAttribute)
5054
qubit_idx_ssas = insert_qubit_idx_from_address(
5155
address=address_attr, stmt_to_insert_before=stmt

src/bloqade/stim/rewrite/util.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,20 @@
22
from kirin.dialects import py
33
from kirin.rewrite.abc import RewriteResult
44

5-
from bloqade.squin import op, wire, qubit
5+
from bloqade.squin import op, wire, noise as squin_noise, qubit
66
from bloqade.squin.rewrite import AddressAttribute
7-
from bloqade.stim.dialects import gate, collapse
7+
from bloqade.stim.dialects import gate, noise as stim_noise, collapse
88
from bloqade.analysis.address import AddressReg, AddressWire, AddressQubit, AddressTuple
99

10-
SQUIN_STIM_GATE_MAPPING = {
10+
SQUIN_STIM_OP_MAPPING = {
1111
op.stmts.X: gate.X,
1212
op.stmts.Y: gate.Y,
1313
op.stmts.Z: gate.Z,
1414
op.stmts.H: gate.H,
1515
op.stmts.S: gate.S,
1616
op.stmts.Identity: gate.Identity,
1717
op.stmts.Reset: collapse.RZ,
18+
squin_noise.stmts.QubitLoss: stim_noise.QubitLoss,
1819
}
1920

2021
# Squin allows creation of control gates where the gate can be any operator,
@@ -149,18 +150,52 @@ def rewrite_Control(
149150
stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits)
150151

151152
if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)):
152-
# have to "reroute" the input of these statements to directly plug in
153-
# to subsequent statements, remove dependency on the current statement
154-
for input_wire, output_wire in zip(
155-
stmt_with_ctrl.inputs, stmt_with_ctrl.results
156-
):
157-
output_wire.replace_by(input_wire)
153+
create_wire_passthrough(stmt_with_ctrl)
158154

159155
stmt_with_ctrl.replace_by(stim_stmt)
160156

161157
return RewriteResult(has_done_something=True)
162158

163159

160+
def rewrite_QubitLoss(
161+
stmt: qubit.Apply | qubit.Broadcast | wire.Broadcast | wire.Apply,
162+
) -> RewriteResult:
163+
"""
164+
Rewrite QubitLoss statements to Stim's TrivialError.
165+
"""
166+
167+
squin_loss_op = stmt.operator.owner
168+
assert isinstance(squin_loss_op, squin_noise.stmts.QubitLoss)
169+
170+
qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt)
171+
if qubit_idx_ssas is None:
172+
return RewriteResult()
173+
174+
stim_loss_stmt = stim_noise.QubitLoss(
175+
targets=qubit_idx_ssas,
176+
probs=(squin_loss_op.p,),
177+
)
178+
179+
if isinstance(stmt, (wire.Apply, wire.Broadcast)):
180+
create_wire_passthrough(stmt)
181+
182+
stmt.replace_by(stim_loss_stmt)
183+
# NoiseChannels are not pure,
184+
# need to manually delete because
185+
# DCE won't touch them
186+
stmt.operator.owner.delete()
187+
188+
return RewriteResult(has_done_something=True)
189+
190+
191+
def create_wire_passthrough(stmt: wire.Apply | wire.Broadcast) -> None:
192+
193+
for input_wire, output_wire in zip(stmt.inputs, stmt.results):
194+
# have to "reroute" the input of these statements to directly plug in
195+
# to subsequent statements, remove dependency on the current statement
196+
output_wire.replace_by(input_wire)
197+
198+
164199
def is_measure_result_used(
165200
stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure,
166201
) -> bool:

src/bloqade/stim/rewrite/wire_to_stim.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from kirin import ir
22
from kirin.rewrite.abc import RewriteRule, RewriteResult
33

4-
from bloqade.squin import op, wire
4+
from bloqade.squin import op, wire, noise
55
from bloqade.stim.rewrite.util import (
6-
SQUIN_STIM_GATE_MAPPING,
6+
SQUIN_STIM_OP_MAPPING,
77
rewrite_Control,
8+
rewrite_QubitLoss,
89
insert_qubit_idx_from_wire_ssa,
910
)
1011

@@ -24,12 +25,16 @@ def rewrite_Apply_and_Broadcast(
2425

2526
# this is an SSAValue, need it to be the actual operator
2627
applied_op = stmt.operator.owner
28+
29+
if isinstance(applied_op, noise.stmts.QubitLoss):
30+
return rewrite_QubitLoss(stmt)
31+
2732
assert isinstance(applied_op, op.stmts.Operator)
2833

2934
if isinstance(applied_op, op.stmts.Control):
3035
return rewrite_Control(stmt)
3136

32-
stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op))
37+
stim_1q_op = SQUIN_STIM_OP_MAPPING.get(type(applied_op))
3338
if stim_1q_op is None:
3439
return RewriteResult()
3540

test/stim/dialects/stim/test_stim_circuits.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ def test_pauli_error():
8686
interp.run(test_pauli_error, args=())
8787
print(interp.get_output())
8888

89+
@stim.main
90+
def test_qubit_loss():
91+
stim.qubit_loss(probs=(0.1, 0.2), targets=(0, 1, 2))
92+
93+
interp.run(test_qubit_loss, args=())
94+
assert interp.get_output() == "\nI_ERROR[loss](0.10000000, 0.20000000) 0 1 2"
95+
8996

9097
def test_collapse():
9198
@stim.main

test/stim/parse/test_parse_custom.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,15 @@ def test_parse_trivial_correlated():
2828
# test roundtrip
2929
out = codegen(mt)
3030
assert out.strip() == "I_ERROR[TRIV_CORR_ERROR:3](0.20000000, 0.30000000) 5 0 1 2"
31+
32+
33+
def test_qubit_loss():
34+
mt = loads(
35+
"I_ERROR[loss](0.1, 0.2) 0 1", nonstim_noise_ops={"loss": noise.QubitLoss}
36+
)
37+
38+
mt.print()
39+
40+
# test roundtrip
41+
out = codegen(mt)
42+
assert out.strip() == "I_ERROR[loss](0.10000000, 0.20000000) 0 1"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
2+
H 0 1 2 3 4
3+
I_ERROR[loss](0.10000000) 3
4+
I_ERROR[loss](0.05000000) 0 1 2 3 4
5+
MZ(0.00000000) 0 1 2 3 4
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
I_ERROR[loss](0.10000000) 0 1 2 3 4
3+
I_ERROR[loss](0.90000000) 0
4+
MZ(0.00000000) 0
5+
MZ(0.00000000) 1
6+
MZ(0.00000000) 2
7+
MZ(0.00000000) 3
8+
MZ(0.00000000) 4

test/stim/passes/test_squin_qubit_to_stim.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from kirin.dialects import py
66

77
from bloqade import squin
8-
from bloqade.squin import op, qubit, kernel
8+
from bloqade.squin import op, noise, qubit, kernel
99
from bloqade.stim.emit import EmitStimMain
1010
from bloqade.stim.passes import SquinToStim
1111
from bloqade.squin.rewrite import WrapAddressAnalysis
@@ -37,7 +37,7 @@ def load_reference_program(filename):
3737
return f.read()
3838

3939

40-
def run_address_and_stim_passes(test):
40+
def run_address_and_stim_passes(test: ir.Method):
4141
addr_frame, _ = AddressAnalysis(test.dialects).run_analysis(test)
4242
Walk(WrapAddressAnalysis(address_analysis=addr_frame.entries)).rewrite(test.code)
4343
SquinToStim(test.dialects)(test)
@@ -57,7 +57,7 @@ def test():
5757
return
5858

5959
run_address_and_stim_passes(test)
60-
base_stim_prog = load_reference_program("qubit.txt")
60+
base_stim_prog = load_reference_program("qubit.stim")
6161

6262
assert codegen(test) == base_stim_prog.rstrip()
6363

@@ -74,7 +74,7 @@ def test():
7474
return
7575

7676
run_address_and_stim_passes(test)
77-
base_stim_prog = load_reference_program("qubit_reset.txt")
77+
base_stim_prog = load_reference_program("qubit_reset.stim")
7878

7979
assert codegen(test) == base_stim_prog.rstrip()
8080

@@ -91,6 +91,26 @@ def test():
9191
return
9292

9393
run_address_and_stim_passes(test)
94-
base_stim_prog = load_reference_program("qubit_broadcast.txt")
94+
base_stim_prog = load_reference_program("qubit_broadcast.stim")
95+
96+
assert codegen(test) == base_stim_prog.rstrip()
97+
98+
99+
def test_qubit_loss():
100+
@kernel
101+
def test():
102+
n_qubits = 5
103+
ql = qubit.new(n_qubits)
104+
# apply Hadamard to all qubits
105+
squin.qubit.broadcast(op.h(), ql)
106+
# apply and broadcast qubit loss
107+
squin.qubit.apply(noise.qubit_loss(0.1), ql[3])
108+
squin.qubit.broadcast(noise.qubit_loss(0.05), ql)
109+
# measure out
110+
squin.qubit.measure(ql)
111+
return
112+
113+
run_address_and_stim_passes(test)
114+
base_stim_prog = load_reference_program("qubit_loss.stim")
95115

96116
assert codegen(test) == base_stim_prog.rstrip()

0 commit comments

Comments
 (0)