Skip to content

Commit 5fcc1d8

Browse files
committed
Fix #885: type coercion due to slightly incorrect default value (#1112)
When adding a new term to a `SymbolicOperator` with a sympy coefficient, the coefficient was being cast to a float. This was because the default value for a new term was `0.0`, which is a float. This commit changes the default value to `0`, which is an integer. This ensures that the type of the coefficient is preserved when adding new terms.
1 parent 61493eb commit 5fcc1d8

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

src/openfermion/ops/operators/fermion_operator_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12-
"""Tests fermion_operator.py."""
12+
13+
"""Tests fermion_operator.py."""
14+
1315
import unittest
1416

15-
from openfermion.ops.operators.fermion_operator import FermionOperator
17+
import sympy
18+
1619
from openfermion.hamiltonians import number_operator
20+
from openfermion.ops.operators.fermion_operator import FermionOperator
1721

1822

1923
class FermionOperatorTest(unittest.TestCase):
@@ -78,3 +82,10 @@ def test_is_two_body_number_conserving_three(self):
7882
def test_is_two_body_number_conserving_out_of_order(self):
7983
op = FermionOperator(((0, 1), (2, 0), (1, 1), (3, 0)))
8084
self.assertTrue(op.is_two_body_number_conserving())
85+
86+
def test_add_sympy_rational(self):
87+
"""Test adding operators with sympy.Rational coefficients."""
88+
a = FermionOperator('0^ 0', sympy.Rational(1, 2))
89+
b = FermionOperator('1^ 1', sympy.Rational(1, 2))
90+
c = a + b
91+
self.assertIsInstance(c.terms[((0, 1), (0, 0))], sympy.Rational)

src/openfermion/ops/operators/symbolic_operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def __iadd__(self, addend):
433433
"""
434434
if isinstance(addend, type(self)):
435435
for term in addend.terms:
436-
self.terms[term] = self.terms.get(term, 0.0) + addend.terms[term]
436+
self.terms[term] = self.terms.get(term, 0) + addend.terms[term]
437437
if self._issmall(self.terms[term]):
438438
del self.terms[term]
439439
elif isinstance(addend, COEFFICIENT_TYPES):
@@ -480,7 +480,7 @@ def __isub__(self, subtrahend):
480480
"""
481481
if isinstance(subtrahend, type(self)):
482482
for term in subtrahend.terms:
483-
self.terms[term] = self.terms.get(term, 0.0) - subtrahend.terms[term]
483+
self.terms[term] = self.terms.get(term, 0) - subtrahend.terms[term]
484484
if self._issmall(self.terms[term]):
485485
del self.terms[term]
486486
elif isinstance(subtrahend, COEFFICIENT_TYPES):

src/openfermion/ops/operators/symbolic_operator_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,20 @@
99
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
12+
1213
"""Tests symbolic_operator.py."""
14+
1315
import copy
1416
import unittest
1517
import warnings
1618

1719
import numpy
1820
import sympy
19-
from openfermion.config import EQ_TOLERANCE
20-
from openfermion.testing.testing_utils import EqualsTester
2121

22+
from openfermion.config import EQ_TOLERANCE
23+
from openfermion.ops.operators.fermion_operator import FermionOperator
2224
from openfermion.ops.operators.symbolic_operator import SymbolicOperator
25+
from openfermion.testing.testing_utils import EqualsTester
2326

2427

2528
class DummyOperator1(SymbolicOperator):
@@ -687,6 +690,14 @@ def test_add_sympy(self):
687690
self.assertTrue(a.terms[term_a] - coeff_a == 0)
688691
self.assertTrue(a.terms[term_b] - coeff_b - 0.5 == 0)
689692

693+
def test_add_sympy_new_term(self):
694+
"""Test adding a new term with a sympy coefficient."""
695+
x = sympy.Symbol('x')
696+
op = FermionOperator('1^', x)
697+
op += FermionOperator('2', 2 * x)
698+
self.assertEqual(op.terms[((1, 1),)], x)
699+
self.assertEqual(op.terms[((2, 0),)], 2 * x)
700+
690701
def test_radd(self):
691702
term_a = ((1, 1), (3, 0), (8, 1))
692703
coeff_a = 1

0 commit comments

Comments
 (0)