Skip to content

Commit b310f03

Browse files
johnzl-777weinbe58
andauthored
Adding noise statements to squin (#211)
Some initial friction points/anticipated changes I can immediately see in translating this to stim - The original `PauliChannel` definition from @kaihsin can take a variable number of float arguments, but in the stim dialect there are explicit fields for each probability of a 1Q/2Q pauli operator being applied. My feeling is that this should be reflected on the squin side as just something like `1QPauliChannel` and `2QPauliChannel` with explicit fields. I don't think we'd lose any generality here along with the added benefit that it would be much harder for a user to misinput information (the default definition assumes you plug in either 3 or 15 values via varargs but it seems to easy to make a mistake here) - ~~If we DO want varargs I'm not quite sure how to nicely feed this to the kirin `@wraps` decorator~~ - From #200 , I see it could be possible to not have to provide an operator for `basis` and instead use a `CliffordString` (although I'd have to restrict CliffordString even further, to just the Pauli Operators. Would it be worth having something like a `PauliString`? Or is that silly?) cc: @weinbe58 @Roger-luo @kaihsin --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 6a35f85 commit b310f03

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

src/bloqade/squin/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from . import op as op, wire as wire, qubit as qubit
1+
from . import op as op, wire as wire, noise as noise, qubit as qubit
22
from .groups import wired as wired, kernel as kernel

src/bloqade/squin/noise/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Put all the proper wrappers here
2+
3+
from kirin.lowering import wraps as _wraps
4+
5+
from bloqade.squin.op.types import Op
6+
7+
from . import stmts as stmts
8+
9+
10+
@_wraps(stmts.PauliError)
11+
def pauli_error(basis: Op, p: float) -> Op: ...
12+
13+
14+
@_wraps(stmts.PPError)
15+
def pp_error(op: Op, p: float) -> Op: ...
16+
17+
18+
@_wraps(stmts.Depolarize)
19+
def depolarize(n_qubits: int, p: float) -> Op: ...
20+
21+
22+
@_wraps(stmts.PauliChannel)
23+
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
24+
25+
26+
@_wraps(stmts.QubitLoss)
27+
def qubit_loss(p: float) -> Op: ...

src/bloqade/squin/noise/_dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from kirin import ir
2+
3+
dialect = ir.Dialect(name="squin.noise")

src/bloqade/squin/noise/stmts.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from kirin import ir, types
2+
from kirin.decl import info, statement
3+
4+
from bloqade.squin.op.types import OpType
5+
6+
from ._dialect import dialect
7+
8+
9+
@statement
10+
class NoiseChannel(ir.Statement):
11+
pass
12+
13+
14+
@statement(dialect=dialect)
15+
class PauliError(NoiseChannel):
16+
basis: ir.SSAValue = info.argument(OpType)
17+
p: ir.SSAValue = info.argument(types.Float)
18+
result: ir.ResultValue = info.result(OpType)
19+
20+
21+
@statement(dialect=dialect)
22+
class PPError(NoiseChannel):
23+
"""
24+
Pauli Product Error
25+
"""
26+
27+
op: ir.SSAValue = info.argument(OpType)
28+
p: ir.SSAValue = info.argument(types.Float)
29+
result: ir.ResultValue = info.result(OpType)
30+
31+
32+
@statement(dialect=dialect)
33+
class Depolarize(NoiseChannel):
34+
"""
35+
Apply n-qubit depolaize error to qubits
36+
NOTE For Stim, this can only accept 1 or 2 qubits
37+
"""
38+
39+
n_qubits: int = info.attribute(types.Int)
40+
p: ir.SSAValue = info.argument(types.Float)
41+
result: ir.ResultValue = info.result(OpType)
42+
43+
44+
@statement(dialect=dialect)
45+
class PauliChannel(NoiseChannel):
46+
# NOTE:
47+
# 1-qubit 3 params px, py, pz
48+
# 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz
49+
# TODO add validation for params (maybe during lowering via custom lower?)
50+
n_qubits: int = info.attribute()
51+
params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)])
52+
result: ir.ResultValue = info.result(OpType)
53+
54+
55+
@statement(dialect=dialect)
56+
class QubitLoss(NoiseChannel):
57+
# NOTE: qubit loss error (not supported by Stim)
58+
p: ir.SSAValue = info.argument(types.Float)
59+
result: ir.ResultValue = info.result(OpType)

0 commit comments

Comments
 (0)