Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions src/qldpc/codes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ def get_logical_error_rate_func(
num_samples: int,
max_error_rate: float = 0.3,
*,
discard_weights: Collection[int] = (),
discard_weight: int | Collection[int] = (),
**decoder_kwargs: Any,
) -> ErrorRateFunc:
"""Construct a function from physical --> logical error rate in a code capacity model.
Expand Down Expand Up @@ -760,6 +760,9 @@ def get_logical_error_rate_func(
F(p) = q_0(p) + sum_(k>0) q_k(p) F_k.
We thereby only need to sample errors of weight k > 0.
"""
if not isinstance(discard_weight, Collection):
discard_weight = [discard_weight]

decoder = decoders.get_decoder(self.matrix, **decoder_kwargs)

# sample errors of fixed weight and record failure/discard counts
Expand All @@ -769,7 +772,7 @@ def get_logical_error_rate_func(
for weight in range(1, len(sample_allocation)):
num_failures[weight], num_discards[weight] = (
self._estimate_decoding_infidelity_and_variance(
weight, sample_allocation[weight], decoder, discard_weights
weight, sample_allocation[weight], decoder, discard_weight
)
)
return ErrorRateFunc(
Expand All @@ -781,7 +784,7 @@ def _estimate_decoding_infidelity_and_variance(
error_weight: int,
num_samples: int,
decoder: decoders.Decoder,
discard_weights: Collection[int],
discard_weight: Collection[int],
) -> tuple[int, int]:
"""Sample and correct errors of a fixed weight. Return logical error and discard counts."""
num_failures = 0
Expand All @@ -795,7 +798,7 @@ def _estimate_decoding_infidelity_and_variance(
# decode the error
syndrome = self.matrix @ error
decoded_error = decoder.decode(syndrome.view(np.ndarray)).view(self.field)
if discard_weights and np.count_nonzero(decoded_error) in discard_weights:
if discard_weight and np.count_nonzero(decoded_error) in discard_weight:
num_discards += 1
elif np.any(decoded_error - error):
num_failures += 1
Expand Down Expand Up @@ -1942,7 +1945,7 @@ def get_logical_error_rate_func(
max_error_rate: float = 0.3,
pauli_bias: Sequence[float] | None = None,
*,
discard_weights: Collection[int] = (),
discard_weight: int | Collection[int] = (),
**decoder_kwargs: Any,
) -> ErrorRateFunc:
"""Construct a function from physical --> logical error rate in a code capacity model.
Expand All @@ -1963,6 +1966,9 @@ def get_logical_error_rate_func(
See help(qldpc.codes.ClassicalCode.get_logical_error_rate_func) for more details about how
this method works.
"""
if not isinstance(discard_weight, Collection):
discard_weight = [discard_weight]

# collect relative probabilities of Z, X, and Y errors
pauli_bias_zxy: npt.NDArray[np.floating] | None
if pauli_bias is not None:
Expand Down Expand Up @@ -1992,7 +1998,7 @@ def get_logical_error_rate_func(
decoder,
logical_ops,
pauli_bias_zxy,
discard_weights,
discard_weight,
)
)
return ErrorRateFunc(
Expand All @@ -2006,7 +2012,7 @@ def _estimate_decoding_fidelity_and_variance(
decoder: decoders.Decoder,
logical_ops: npt.NDArray[np.int_],
pauli_bias_zxy: npt.NDArray[np.floating] | None,
discard_weights: Collection[int],
discard_weight: Collection[int],
) -> tuple[int, int]:
"""Sample and correct errors of a fixed weight. Return logical error and discard counts."""
num_failures = 0
Expand All @@ -2032,7 +2038,7 @@ def _estimate_decoding_fidelity_and_variance(
error = np.concatenate([error_x, error_z]).view(self.field)
syndrome = syndrome_matrix @ error
decoded_error = decoder.decode(syndrome.view(np.ndarray)).view(self.field)
if discard_weights and math.symplectic_weight(decoded_error) in discard_weights:
if discard_weight and math.symplectic_weight(decoded_error) in discard_weight:
num_discards += 1 # pragma: no cover
elif np.any(logical_ops @ math.symplectic_conjugate(decoded_error - error)):
num_failures += 1
Expand Down Expand Up @@ -3022,7 +3028,7 @@ def get_logical_error_rate_func(
*,
decoder_x_kwargs: dict[str, Any] | None = None,
decoder_z_kwargs: dict[str, Any] | None = None,
discard_weights: Collection[int] = (),
discard_weight: int | Collection[int] = (),
**decoder_kwargs: Any,
) -> ErrorRateFunc:
"""Construct a function from physical --> logical error rate in a code capacity model.
Expand All @@ -3043,6 +3049,9 @@ def get_logical_error_rate_func(
See help(qldpc.codes.ClassicalCode.get_logical_error_rate_func) for more details about how
this method works.
"""
if not isinstance(discard_weight, Collection):
discard_weight = [discard_weight]

# collect relative probabilities of Z, X, and Y errors
pauli_bias_zxy: npt.NDArray[np.floating] | None
if pauli_bias is not None:
Expand Down Expand Up @@ -3087,7 +3096,7 @@ def get_logical_error_rate_func(
logicals_x,
logicals_z,
pauli_bias_zxy,
discard_weights,
discard_weight,
)
)
return ErrorRateFunc(
Expand All @@ -3103,7 +3112,7 @@ def _estimate_css_decoding_fidelity_and_variance(
logicals_x: npt.NDArray[np.int_],
logicals_z: npt.NDArray[np.int_],
pauli_bias_zxy: npt.NDArray[np.floating] | None,
discard_weights: Collection[int],
discard_weight: Collection[int],
) -> tuple[int, int]:
"""Sample and correct errors of a fixed weight. Return logical error and discard counts."""
num_failures = 0
Expand All @@ -3122,12 +3131,12 @@ def _estimate_css_decoding_fidelity_and_variance(
syndrome_z = self.matrix_x @ error_z
decoded_error_z = decoder_z.decode(syndrome_z.view(np.ndarray)).view(self.field)

if discard_weights and np.count_nonzero(decoded_error_z) in discard_weights:
if discard_weight and np.count_nonzero(decoded_error_z) in discard_weight:
num_discards += 1
continue

failure_z = np.any(logicals_x @ (decoded_error_z - error_z))
if not discard_weights and failure_z:
if not discard_weight and failure_z:
# If we are _not_ post-selecting and there _was_ a decoding failure, then there is
# no need to consider X-type errors, because we will record one failure either way.
num_failures += 1
Expand All @@ -3142,7 +3151,7 @@ def _estimate_css_decoding_fidelity_and_variance(
syndrome_x = self.matrix_z @ error_x
decoded_error_x = decoder_x.decode(syndrome_x.view(np.ndarray)).view(self.field)
if (
discard_weights and np.count_nonzero(decoded_error_x) in discard_weights
discard_weight and np.count_nonzero(decoded_error_x) in discard_weight
): # pragma: no cover
num_discards += 1
continue
Expand Down
6 changes: 3 additions & 3 deletions src/qldpc/codes/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_classical_capacity() -> None:

# compute discard rates
logical_error_rate_func = code.get_logical_error_rate_func(
num_samples=4, max_error_rate=0.5, discard_weights=[1]
num_samples=4, max_error_rate=0.5, discard_weight=1
)
assert logical_error_rate_func(0, discard_rate=True) == (0, 0)
assert logical_error_rate_func(0.5, discard_rate=True) == (0.5, 0)
Expand Down Expand Up @@ -539,7 +539,7 @@ def test_quantum_capacity() -> None:

# compute discard rates (trivial deterministic example)
logical_error_rate_func = code.get_logical_error_rate_func(
num_samples=1, max_error_rate=1, discard_weights=[1]
num_samples=1, max_error_rate=1, discard_weight=1
)
assert logical_error_rate_func(0, discard_rate=True) == (0, 0)

Expand Down Expand Up @@ -746,6 +746,6 @@ def test_css_capacity() -> None:

# compute discard rates (trivial deterministic example)
logical_error_rate_func = code.get_logical_error_rate_func(
num_samples=1, max_error_rate=1, discard_weights=[1]
num_samples=1, max_error_rate=1, discard_weight=1
)
assert logical_error_rate_func(0, discard_rate=True) == (0, 0)