diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index f770f9a..568c80f 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -1,2 +1,2 @@ -from . import op as op, wire as wire, qubit as qubit +from . import op as op, wire as wire, noise as noise, qubit as qubit from .groups import wired as wired, kernel as kernel diff --git a/src/bloqade/squin/noise/__init__.py b/src/bloqade/squin/noise/__init__.py new file mode 100644 index 0000000..f553b4a --- /dev/null +++ b/src/bloqade/squin/noise/__init__.py @@ -0,0 +1,27 @@ +# Put all the proper wrappers here + +from kirin.lowering import wraps as _wraps + +from bloqade.squin.op.types import Op + +from . import stmts as stmts + + +@_wraps(stmts.PauliError) +def pauli_error(basis: Op, p: float) -> Op: ... + + +@_wraps(stmts.PPError) +def pp_error(op: Op, p: float) -> Op: ... + + +@_wraps(stmts.Depolarize) +def depolarize(n_qubits: int, p: float) -> Op: ... + + +@_wraps(stmts.PauliChannel) +def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ... + + +@_wraps(stmts.QubitLoss) +def qubit_loss(p: float) -> Op: ... diff --git a/src/bloqade/squin/noise/_dialect.py b/src/bloqade/squin/noise/_dialect.py new file mode 100644 index 0000000..025b2df --- /dev/null +++ b/src/bloqade/squin/noise/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect(name="squin.noise") diff --git a/src/bloqade/squin/noise/stmts.py b/src/bloqade/squin/noise/stmts.py new file mode 100644 index 0000000..58b59ce --- /dev/null +++ b/src/bloqade/squin/noise/stmts.py @@ -0,0 +1,59 @@ +from kirin import ir, types +from kirin.decl import info, statement + +from bloqade.squin.op.types import OpType + +from ._dialect import dialect + + +@statement +class NoiseChannel(ir.Statement): + pass + + +@statement(dialect=dialect) +class PauliError(NoiseChannel): + basis: ir.SSAValue = info.argument(OpType) + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class PPError(NoiseChannel): + """ + Pauli Product Error + """ + + op: ir.SSAValue = info.argument(OpType) + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class Depolarize(NoiseChannel): + """ + Apply n-qubit depolaize error to qubits + NOTE For Stim, this can only accept 1 or 2 qubits + """ + + n_qubits: int = info.attribute(types.Int) + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class PauliChannel(NoiseChannel): + # NOTE: + # 1-qubit 3 params px, py, pz + # 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz + # TODO add validation for params (maybe during lowering via custom lower?) + n_qubits: int = info.attribute() + params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)]) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class QubitLoss(NoiseChannel): + # NOTE: qubit loss error (not supported by Stim) + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType)