diff --git a/src/openfermion/utils/operator_utils.py b/src/openfermion/utils/operator_utils.py index 82d2e345e..2a98d6d2b 100644 --- a/src/openfermion/utils/operator_utils.py +++ b/src/openfermion/utils/operator_utils.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """This module provides generic tools for classes in ops/""" + from builtins import map, zip import marshal import os @@ -280,10 +281,23 @@ def load_operator(file_name=None, data_directory=None, plain_text=False): raise TypeError('Operator of invalid type.') else: with open(file_path, 'rb') as f: - data = marshal.load(f) + try: + data = marshal.load(f) + except Exception as e: + raise TypeError('The file content is not a valid marshal format.') from e + + if not isinstance(data, (tuple, list)) or len(data) != 2: + raise TypeError('Invalid data format in file: expected a sequence of length 2.') + operator_type = data[0] operator_terms = data[1] + if not isinstance(operator_type, str): + raise TypeError('Invalid operator type: expected a string.') + + if not isinstance(operator_terms, dict): + raise TypeError('Invalid operator terms: expected a dictionary.') + if operator_type == 'FermionOperator': operator = FermionOperator() for term in operator_terms: diff --git a/src/openfermion/utils/operator_utils_test.py b/src/openfermion/utils/operator_utils_test.py index 95b7d32cb..7c52936e6 100644 --- a/src/openfermion/utils/operator_utils_test.py +++ b/src/openfermion/utils/operator_utils_test.py @@ -11,42 +11,26 @@ # limitations under the License. """Tests for operator_utils.""" -import os - import itertools - +import marshal +import os import unittest import numpy - import sympy - from scipy.sparse import csc_matrix from openfermion.config import DATA_DIRECTORY from openfermion.hamiltonians import fermi_hubbard -from openfermion.ops.operators import ( - FermionOperator, - MajoranaOperator, - BosonOperator, - QubitOperator, - QuadOperator, - IsingOperator, -) +from openfermion.ops.operators import (BosonOperator, FermionOperator, IsingOperator, + MajoranaOperator, QuadOperator, QubitOperator) from openfermion.ops.representations import InteractionOperator -from openfermion.transforms.opconversions import jordan_wigner, bravyi_kitaev -from openfermion.transforms.repconversions import get_interaction_operator from openfermion.testing.testing_utils import random_interaction_operator -from openfermion.utils.operator_utils import ( - count_qubits, - hermitian_conjugated, - is_identity, - save_operator, - OperatorUtilsError, - is_hermitian, - load_operator, - get_file_path, -) +from openfermion.transforms.opconversions import bravyi_kitaev, jordan_wigner +from openfermion.transforms.repconversions import get_interaction_operator +from openfermion.utils.operator_utils import (OperatorUtilsError, count_qubits, get_file_path, + hermitian_conjugated, is_hermitian, is_identity, + load_operator, save_operator) class OperatorUtilsTest(unittest.TestCase): @@ -595,6 +579,48 @@ def test_save_bad_type(self): save_operator('ping', 'somewhere') +class LoadOperatorSecurityTest(unittest.TestCase): + def setUp(self): + self.bad_file_name = 'bad_marshal' + self.bad_file_path = os.path.join(DATA_DIRECTORY, self.bad_file_name + '.data') + + def tearDown(self): + if os.path.exists(self.bad_file_path): + os.remove(self.bad_file_path) + + def test_load_invalid_structure_not_tuple_or_list(self): + # Save an integer instead of a tuple + with open(self.bad_file_path, 'wb') as f: + marshal.dump(123, f) + + with self.assertRaisesRegex(TypeError, 'Invalid data format in file'): + load_operator(self.bad_file_name) + + def test_load_invalid_structure_wrong_length(self): + # Save a tuple of length 1 + with open(self.bad_file_path, 'wb') as f: + marshal.dump(('FermionOperator',), f) + + with self.assertRaisesRegex(TypeError, 'Invalid data format in file'): + load_operator(self.bad_file_name) + + def test_load_invalid_operator_type_type(self): + # Operator type is not a string + with open(self.bad_file_path, 'wb') as f: + marshal.dump((123, {}), f) + + with self.assertRaisesRegex(TypeError, 'Invalid operator type'): + load_operator(self.bad_file_name) + + def test_load_invalid_terms_type(self): + # Terms is not a dict + with open(self.bad_file_path, 'wb') as f: + marshal.dump(('FermionOperator', 123), f) + + with self.assertRaisesRegex(TypeError, 'Invalid operator terms'): + load_operator(self.bad_file_name) + + class GetFileDirTest(unittest.TestCase): def setUp(self): self.filename = 'foo'