Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions src/openfermion/ops/operators/symbolic_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def __pow__(self, exponent):

def __eq__(self, other):
"""Approximate numerical equality (not true equality)."""
return self.isclose(other)
return self.isclose(other, rel_tol=EQ_TOLERANCE, abs_tol=EQ_TOLERANCE)

def __ne__(self, other):
return not (self == other)
Expand All @@ -618,15 +618,17 @@ def __next__(self):
term, coefficient = next(self._iter)
return self.__class__(term=term, coefficient=coefficient)

def isclose(self, other, tol=EQ_TOLERANCE):
def isclose(self, other, rel_tol=EQ_TOLERANCE, abs_tol=EQ_TOLERANCE):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is technically a breaking API change. are we ok with that? weigh confusion vs. code compatibility to keeping one of them named tol

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, maybe we could keep 'tol' as a deprecated parameter and set atol=tol if it exists.

"""Check if other (SymbolicOperator) is close to self.

Comparison is done for each term individually. Return True
if the difference between each term in self and other is
less than EQ_TOLERANCE
less than the specified tolerance.

Args:
other(SymbolicOperator): SymbolicOperator to compare against.
rel_tol(float): Relative tolerance.
abs_tol(float): Absolute tolerance.
Comment on lines +630 to +631
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider rtol and atol for consistency with numpy/scipy

"""
if not isinstance(self, type(other)):
return NotImplemented
Expand All @@ -635,17 +637,26 @@ def isclose(self, other, tol=EQ_TOLERANCE):
for term in set(self.terms).intersection(set(other.terms)):
a = self.terms[term]
b = other.terms[term]
if not (isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr)):
tol *= max(1, abs(a), abs(b))
if self._issmall(a - b, tol) is False:
if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
if self._issmall(a - b, abs_tol) is False:
return False
elif not abs(a - b) <= abs_tol + rel_tol * max(abs(a), abs(b)):
return False
# terms only in one (compare to 0.0 so only abs_tol)
for term in set(self.terms).symmetric_difference(set(other.terms)):
if term in self.terms:
if self._issmall(self.terms[term], tol) is False:
coeff = self.terms[term]
if isinstance(coeff, sympy.Expr):
if self._issmall(coeff, abs_tol) is False:
return False
elif not abs(coeff) <= abs_tol:
return False
else:
if self._issmall(other.terms[term], tol) is False:
coeff = other.terms[term]
if isinstance(coeff, sympy.Expr):
if self._issmall(coeff, abs_tol) is False:
return False
elif not abs(coeff) <= abs_tol:
return False
return True

Expand Down
37 changes: 34 additions & 3 deletions src/openfermion/ops/operators/symbolic_operator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests symbolic_operator.py."""

import copy
import numpy
import sympy
import unittest
import warnings

import numpy
import sympy
from openfermion.config import EQ_TOLERANCE
from openfermion.testing.testing_utils import EqualsTester

from openfermion.ops.operators.fermion_operator import FermionOperator
from openfermion.ops.operators.symbolic_operator import SymbolicOperator


Expand Down Expand Up @@ -857,7 +860,35 @@ def test_pow_high_term(self):
term = DummyOperator1(ops, coeff)
high = term**10
expected = DummyOperator1(ops * 10, coeff**10)
self.assertTrue(expected == high)
self.assertTrue(high.isclose(expected, rel_tol=1e-12, abs_tol=1e-12))

def test_isclose(self):
op1 = DummyOperator1()
op2 = DummyOperator1()
op1 += DummyOperator1('0^ 1', 1000000)
op1 += DummyOperator1('2^ 3', 1)
op2 += DummyOperator1('0^ 1', 1000000)
op2 += DummyOperator1('2^ 3', 1.001)
self.assertFalse(op1.isclose(op2, abs_tol=1e-4))
self.assertTrue(op1.isclose(op2, abs_tol=1e-2))

# Case from https://github.com/quantumlib/OpenFermion/issues/764
x = FermionOperator("0^ 0")
y = FermionOperator("0^ 0")

# construct two identical operators up to some number of terms
num_terms_before_ineq = 30
for i in range(num_terms_before_ineq):
x += FermionOperator(f" (10+0j) [0^ {i}]")
y += FermionOperator(f" (10+0j) [0^ {i}]")

xfinal = FermionOperator(f" (1+0j) [0^ {num_terms_before_ineq + 1}]")
yfinal = FermionOperator(f" (2+0j) [0^ {num_terms_before_ineq + 1}]")
assert xfinal != yfinal

x += xfinal
y += yfinal
assert x != y

def test_pow_neg_error(self):
with self.assertRaises(ValueError):
Expand Down
Loading