diff --git a/src/qldpc/codes/common.py b/src/qldpc/codes/common.py index d7cbbabe..5d559b9c 100644 --- a/src/qldpc/codes/common.py +++ b/src/qldpc/codes/common.py @@ -729,7 +729,7 @@ def get_logical_error_rate_func( num_samples: int, max_error_rate: float = 0.3, *, - discard_weights: Collection[int] = (), + discard_weight: int | Collection[int] = (), **decoder_kwargs: Any, ) -> ErrorRateFunc: """Construct a function from physical --> logical error rate in a code capacity model. @@ -760,6 +760,9 @@ def get_logical_error_rate_func( F(p) = q_0(p) + sum_(k>0) q_k(p) F_k. We thereby only need to sample errors of weight k > 0. """ + if not isinstance(discard_weight, Collection): + discard_weight = [discard_weight] + decoder = decoders.get_decoder(self.matrix, **decoder_kwargs) # sample errors of fixed weight and record failure/discard counts @@ -769,7 +772,7 @@ def get_logical_error_rate_func( for weight in range(1, len(sample_allocation)): num_failures[weight], num_discards[weight] = ( self._estimate_decoding_infidelity_and_variance( - weight, sample_allocation[weight], decoder, discard_weights + weight, sample_allocation[weight], decoder, discard_weight ) ) return ErrorRateFunc( @@ -781,7 +784,7 @@ def _estimate_decoding_infidelity_and_variance( error_weight: int, num_samples: int, decoder: decoders.Decoder, - discard_weights: Collection[int], + discard_weight: Collection[int], ) -> tuple[int, int]: """Sample and correct errors of a fixed weight. Return logical error and discard counts.""" num_failures = 0 @@ -795,7 +798,7 @@ def _estimate_decoding_infidelity_and_variance( # decode the error syndrome = self.matrix @ error decoded_error = decoder.decode(syndrome.view(np.ndarray)).view(self.field) - if discard_weights and np.count_nonzero(decoded_error) in discard_weights: + if discard_weight and np.count_nonzero(decoded_error) in discard_weight: num_discards += 1 elif np.any(decoded_error - error): num_failures += 1 @@ -1942,7 +1945,7 @@ def get_logical_error_rate_func( max_error_rate: float = 0.3, pauli_bias: Sequence[float] | None = None, *, - discard_weights: Collection[int] = (), + discard_weight: int | Collection[int] = (), **decoder_kwargs: Any, ) -> ErrorRateFunc: """Construct a function from physical --> logical error rate in a code capacity model. @@ -1963,6 +1966,9 @@ def get_logical_error_rate_func( See help(qldpc.codes.ClassicalCode.get_logical_error_rate_func) for more details about how this method works. """ + if not isinstance(discard_weight, Collection): + discard_weight = [discard_weight] + # collect relative probabilities of Z, X, and Y errors pauli_bias_zxy: npt.NDArray[np.floating] | None if pauli_bias is not None: @@ -1992,7 +1998,7 @@ def get_logical_error_rate_func( decoder, logical_ops, pauli_bias_zxy, - discard_weights, + discard_weight, ) ) return ErrorRateFunc( @@ -2006,7 +2012,7 @@ def _estimate_decoding_fidelity_and_variance( decoder: decoders.Decoder, logical_ops: npt.NDArray[np.int_], pauli_bias_zxy: npt.NDArray[np.floating] | None, - discard_weights: Collection[int], + discard_weight: Collection[int], ) -> tuple[int, int]: """Sample and correct errors of a fixed weight. Return logical error and discard counts.""" num_failures = 0 @@ -2032,7 +2038,7 @@ def _estimate_decoding_fidelity_and_variance( error = np.concatenate([error_x, error_z]).view(self.field) syndrome = syndrome_matrix @ error decoded_error = decoder.decode(syndrome.view(np.ndarray)).view(self.field) - if discard_weights and math.symplectic_weight(decoded_error) in discard_weights: + if discard_weight and math.symplectic_weight(decoded_error) in discard_weight: num_discards += 1 # pragma: no cover elif np.any(logical_ops @ math.symplectic_conjugate(decoded_error - error)): num_failures += 1 @@ -3022,7 +3028,7 @@ def get_logical_error_rate_func( *, decoder_x_kwargs: dict[str, Any] | None = None, decoder_z_kwargs: dict[str, Any] | None = None, - discard_weights: Collection[int] = (), + discard_weight: int | Collection[int] = (), **decoder_kwargs: Any, ) -> ErrorRateFunc: """Construct a function from physical --> logical error rate in a code capacity model. @@ -3043,6 +3049,9 @@ def get_logical_error_rate_func( See help(qldpc.codes.ClassicalCode.get_logical_error_rate_func) for more details about how this method works. """ + if not isinstance(discard_weight, Collection): + discard_weight = [discard_weight] + # collect relative probabilities of Z, X, and Y errors pauli_bias_zxy: npt.NDArray[np.floating] | None if pauli_bias is not None: @@ -3087,7 +3096,7 @@ def get_logical_error_rate_func( logicals_x, logicals_z, pauli_bias_zxy, - discard_weights, + discard_weight, ) ) return ErrorRateFunc( @@ -3103,7 +3112,7 @@ def _estimate_css_decoding_fidelity_and_variance( logicals_x: npt.NDArray[np.int_], logicals_z: npt.NDArray[np.int_], pauli_bias_zxy: npt.NDArray[np.floating] | None, - discard_weights: Collection[int], + discard_weight: Collection[int], ) -> tuple[int, int]: """Sample and correct errors of a fixed weight. Return logical error and discard counts.""" num_failures = 0 @@ -3122,12 +3131,12 @@ def _estimate_css_decoding_fidelity_and_variance( syndrome_z = self.matrix_x @ error_z decoded_error_z = decoder_z.decode(syndrome_z.view(np.ndarray)).view(self.field) - if discard_weights and np.count_nonzero(decoded_error_z) in discard_weights: + if discard_weight and np.count_nonzero(decoded_error_z) in discard_weight: num_discards += 1 continue failure_z = np.any(logicals_x @ (decoded_error_z - error_z)) - if not discard_weights and failure_z: + if not discard_weight and failure_z: # If we are _not_ post-selecting and there _was_ a decoding failure, then there is # no need to consider X-type errors, because we will record one failure either way. num_failures += 1 @@ -3142,7 +3151,7 @@ def _estimate_css_decoding_fidelity_and_variance( syndrome_x = self.matrix_z @ error_x decoded_error_x = decoder_x.decode(syndrome_x.view(np.ndarray)).view(self.field) if ( - discard_weights and np.count_nonzero(decoded_error_x) in discard_weights + discard_weight and np.count_nonzero(decoded_error_x) in discard_weight ): # pragma: no cover num_discards += 1 continue diff --git a/src/qldpc/codes/common_test.py b/src/qldpc/codes/common_test.py index 752f9ebb..7e53b676 100644 --- a/src/qldpc/codes/common_test.py +++ b/src/qldpc/codes/common_test.py @@ -231,7 +231,7 @@ def test_classical_capacity() -> None: # compute discard rates logical_error_rate_func = code.get_logical_error_rate_func( - num_samples=4, max_error_rate=0.5, discard_weights=[1] + num_samples=4, max_error_rate=0.5, discard_weight=1 ) assert logical_error_rate_func(0, discard_rate=True) == (0, 0) assert logical_error_rate_func(0.5, discard_rate=True) == (0.5, 0) @@ -539,7 +539,7 @@ def test_quantum_capacity() -> None: # compute discard rates (trivial deterministic example) logical_error_rate_func = code.get_logical_error_rate_func( - num_samples=1, max_error_rate=1, discard_weights=[1] + num_samples=1, max_error_rate=1, discard_weight=1 ) assert logical_error_rate_func(0, discard_rate=True) == (0, 0) @@ -746,6 +746,6 @@ def test_css_capacity() -> None: # compute discard rates (trivial deterministic example) logical_error_rate_func = code.get_logical_error_rate_func( - num_samples=1, max_error_rate=1, discard_weights=[1] + num_samples=1, max_error_rate=1, discard_weight=1 ) assert logical_error_rate_func(0, discard_rate=True) == (0, 0)