Skip to content

Commit 0a955cc

Browse files
richrines1mhucka
andauthored
correct noise channel approximate equality (#6632)
This fixes two related bugs related to equality checks for noise channels: - approximate comparison between cirq noise channels and other gates raises an error - exact comparison ignores DepolarizingChannel.n_qubits The first issue isfixed by just adding approximate=True to the @value_equality decorator for each class and removing their explicit implementations of _approx_eq_. The second issue just required the inclusion of n_qubits in DepolarizingChannel._value_equality_values_ Fixes: #6631 Co-authored-by: Michael Hucka <[email protected]>
1 parent 2febe7f commit 0a955cc

File tree

2 files changed

+18
-34
lines changed

2 files changed

+18
-34
lines changed

cirq-core/cirq/ops/common_channels.py

+8-34
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
import cirq
2828

2929

30-
@value.value_equality
30+
@value.value_equality(approximate=True)
3131
class AsymmetricDepolarizingChannel(raw_types.Gate):
3232
r"""A channel that depolarizes asymmetrically along different directions.
3333
@@ -196,11 +196,6 @@ def error_probabilities(self) -> Dict[str, float]:
196196
def _json_dict_(self) -> Dict[str, Any]:
197197
return protocols.obj_to_dict_helper(self, ['error_probabilities'])
198198

199-
def _approx_eq_(self, other: Any, atol: float) -> bool:
200-
self_keys, self_values = zip(*sorted(self.error_probabilities.items()))
201-
other_keys, other_values = zip(*sorted(other.error_probabilities.items()))
202-
return self_keys == other_keys and protocols.approx_eq(self_values, other_values, atol=atol)
203-
204199

205200
def asymmetric_depolarize(
206201
p_x: Optional[float] = None,
@@ -246,7 +241,7 @@ def asymmetric_depolarize(
246241
return AsymmetricDepolarizingChannel(p_x, p_y, p_z, error_probabilities, tol)
247242

248243

249-
@value.value_equality
244+
@value.value_equality(approximate=True)
250245
class DepolarizingChannel(raw_types.Gate):
251246
r"""A channel that depolarizes one or several qubits.
252247
@@ -306,7 +301,7 @@ def _has_mixture_(self) -> bool:
306301
return True
307302

308303
def _value_equality_values_(self):
309-
return self._p
304+
return self._p, self._n_qubits
310305

311306
def __repr__(self) -> str:
312307
if self._n_qubits == 1:
@@ -347,9 +342,6 @@ def _json_dict_(self) -> Dict[str, Any]:
347342
return protocols.obj_to_dict_helper(self, ['p'])
348343
return protocols.obj_to_dict_helper(self, ['p', 'n_qubits'])
349344

350-
def _approx_eq_(self, other: Any, atol: float) -> bool:
351-
return np.isclose(self.p, other.p, atol=atol).item() and self.n_qubits == other.n_qubits
352-
353345

354346
def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel:
355347
r"""Returns a DepolarizingChannel with given probability of error.
@@ -381,7 +373,7 @@ def depolarize(p: float, n_qubits: int = 1) -> DepolarizingChannel:
381373
return DepolarizingChannel(p, n_qubits)
382374

383375

384-
@value.value_equality
376+
@value.value_equality(approximate=True)
385377
class GeneralizedAmplitudeDampingChannel(raw_types.Gate):
386378
r"""Dampen qubit amplitudes through non ideal dissipation.
387379
@@ -489,12 +481,6 @@ def gamma(self) -> float:
489481
def _json_dict_(self) -> Dict[str, Any]:
490482
return protocols.obj_to_dict_helper(self, ['p', 'gamma'])
491483

492-
def _approx_eq_(self, other: Any, atol: float) -> bool:
493-
return (
494-
np.isclose(self.gamma, other.gamma, atol=atol).item()
495-
and np.isclose(self.p, other.p, atol=atol).item()
496-
)
497-
498484

499485
def generalized_amplitude_damp(p: float, gamma: float) -> GeneralizedAmplitudeDampingChannel:
500486
r"""Returns a GeneralizedAmplitudeDampingChannel with probabilities gamma and p.
@@ -542,7 +528,7 @@ def generalized_amplitude_damp(p: float, gamma: float) -> GeneralizedAmplitudeDa
542528
return GeneralizedAmplitudeDampingChannel(p, gamma)
543529

544530

545-
@value.value_equality
531+
@value.value_equality(approximate=True)
546532
class AmplitudeDampingChannel(raw_types.Gate):
547533
r"""Dampen qubit amplitudes through dissipation.
548534
@@ -619,9 +605,6 @@ def gamma(self) -> float:
619605
def _json_dict_(self) -> Dict[str, Any]:
620606
return protocols.obj_to_dict_helper(self, ['gamma'])
621607

622-
def _approx_eq_(self, other: Any, atol: float) -> bool:
623-
return np.isclose(self.gamma, other.gamma, atol=atol).item()
624-
625608

626609
def amplitude_damp(gamma: float) -> AmplitudeDampingChannel:
627610
r"""Returns an AmplitudeDampingChannel with the given probability gamma.
@@ -787,7 +770,7 @@ def reset_each(*qubits: 'cirq.Qid') -> List[raw_types.Operation]:
787770
return [ResetChannel(q.dimension).on(q) for q in qubits]
788771

789772

790-
@value.value_equality
773+
@value.value_equality(approximate=True)
791774
class PhaseDampingChannel(raw_types.Gate):
792775
r"""Dampen qubit phase.
793776
@@ -881,9 +864,6 @@ def gamma(self) -> float:
881864
def _json_dict_(self) -> Dict[str, Any]:
882865
return protocols.obj_to_dict_helper(self, ['gamma'])
883866

884-
def _approx_eq_(self, other: Any, atol: float) -> bool:
885-
return np.isclose(self._gamma, other._gamma, atol=atol).item()
886-
887867

888868
def phase_damp(gamma: float) -> PhaseDampingChannel:
889869
r"""Creates a PhaseDampingChannel with damping constant gamma.
@@ -919,7 +899,7 @@ def phase_damp(gamma: float) -> PhaseDampingChannel:
919899
return PhaseDampingChannel(gamma)
920900

921901

922-
@value.value_equality
902+
@value.value_equality(approximate=True)
923903
class PhaseFlipChannel(raw_types.Gate):
924904
r"""Probabilistically flip the sign of the phase of a qubit.
925905
@@ -991,9 +971,6 @@ def p(self) -> float:
991971
def _json_dict_(self) -> Dict[str, Any]:
992972
return protocols.obj_to_dict_helper(self, ['p'])
993973

994-
def _approx_eq_(self, other: Any, atol: float) -> bool:
995-
return np.isclose(self.p, other.p, atol=atol).item()
996-
997974

998975
def _phase_flip_Z() -> common_gates.ZPowGate:
999976
"""Returns a cirq.Z which corresponds to a guaranteed phase flip."""
@@ -1073,7 +1050,7 @@ def phase_flip(p: Optional[float] = None) -> Union[common_gates.ZPowGate, PhaseF
10731050
return _phase_flip(p)
10741051

10751052

1076-
@value.value_equality
1053+
@value.value_equality(approximate=True)
10771054
class BitFlipChannel(raw_types.Gate):
10781055
r"""Probabilistically flip a qubit from 1 to 0 state or vice versa.
10791056
@@ -1148,9 +1125,6 @@ def p(self) -> float:
11481125
def _json_dict_(self) -> Dict[str, Any]:
11491126
return protocols.obj_to_dict_helper(self, ['p'])
11501127

1151-
def _approx_eq_(self, other: Any, atol: float) -> bool:
1152-
return np.isclose(self._p, other._p, atol=atol).item()
1153-
11541128

11551129
def _bit_flip(p: float) -> BitFlipChannel:
11561130
r"""Construct a BitFlipChannel that flips a qubit state with probability of a flip given by p.

cirq-core/cirq/ops/common_channels_test.py

+10
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def test_asymmetric_depolarizing_channel_eq():
9292
c = cirq.asymmetric_depolarize(0.0, 0.0, 0.0)
9393

9494
assert cirq.approx_eq(a, b, atol=1e-2)
95+
assert not cirq.approx_eq(a, cirq.X)
9596

9697
et = cirq.testing.EqualsTester()
9798
et.make_equality_group(lambda: c)
@@ -276,13 +277,15 @@ def test_depolarizing_channel_eq():
276277
c = cirq.depolarize(0.0)
277278

278279
assert cirq.approx_eq(a, b, atol=1e-2)
280+
assert not cirq.approx_eq(a, cirq.X)
279281

280282
et = cirq.testing.EqualsTester()
281283

282284
et.make_equality_group(lambda: c)
283285
et.add_equality_group(cirq.depolarize(0.1))
284286
et.add_equality_group(cirq.depolarize(0.9))
285287
et.add_equality_group(cirq.depolarize(1.0))
288+
et.add_equality_group(cirq.depolarize(1.0, n_qubits=2))
286289

287290

288291
def test_depolarizing_channel_invalid_probability():
@@ -349,6 +352,7 @@ def test_generalized_amplitude_damping_channel_eq():
349352
b = cirq.generalized_amplitude_damp(0.01, 0.0099999)
350353

351354
assert cirq.approx_eq(a, b, atol=1e-2)
355+
assert not cirq.approx_eq(a, cirq.X)
352356

353357
et = cirq.testing.EqualsTester()
354358
c = cirq.generalized_amplitude_damp(0.0, 0.0)
@@ -411,6 +415,7 @@ def test_amplitude_damping_channel_eq():
411415
c = cirq.amplitude_damp(0.0)
412416

413417
assert cirq.approx_eq(a, b, atol=1e-2)
418+
assert not cirq.approx_eq(a, cirq.X)
414419

415420
et = cirq.testing.EqualsTester()
416421
et.make_equality_group(lambda: c)
@@ -562,6 +567,7 @@ def test_phase_damping_channel_eq():
562567
c = cirq.phase_damp(0.0)
563568

564569
assert cirq.approx_eq(a, b, atol=1e-2)
570+
assert not cirq.approx_eq(a, cirq.X)
565571

566572
et = cirq.testing.EqualsTester()
567573
et.make_equality_group(lambda: c)
@@ -636,6 +642,7 @@ def test_phase_flip_channel_eq():
636642
c = cirq.phase_flip(0.0)
637643

638644
assert cirq.approx_eq(a, b, atol=1e-2)
645+
assert not cirq.approx_eq(a, cirq.X)
639646

640647
et = cirq.testing.EqualsTester()
641648
et.make_equality_group(lambda: c)
@@ -701,6 +708,7 @@ def test_bit_flip_channel_eq():
701708
c = cirq.bit_flip(0.0)
702709

703710
assert cirq.approx_eq(a, b, atol=1e-2)
711+
assert not cirq.approx_eq(a, cirq.X)
704712

705713
et = cirq.testing.EqualsTester()
706714
et.make_equality_group(lambda: c)
@@ -834,6 +842,8 @@ def test_multi_asymmetric_depolarizing_eq():
834842

835843
assert cirq.approx_eq(a, b, atol=1e-3)
836844

845+
assert not cirq.approx_eq(a, cirq.X)
846+
837847

838848
def test_multi_asymmetric_depolarizing_channel_str():
839849
assert str(cirq.asymmetric_depolarize(error_probabilities={'II': 0.8, 'XX': 0.2})) == (

0 commit comments

Comments
 (0)