Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/openfermion/utils/operator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides generic tools for classes in ops/"""

from builtins import map, zip
import marshal
import os
Expand Down Expand Up @@ -280,10 +281,23 @@ def load_operator(file_name=None, data_directory=None, plain_text=False):
raise TypeError('Operator of invalid type.')
else:
with open(file_path, 'rb') as f:
data = marshal.load(f)
try:
data = marshal.load(f)
except Exception as e:
raise TypeError('The file content is not a valid marshal format.') from e

if not isinstance(data, (tuple, list)) or len(data) != 2:
raise TypeError('Invalid data format in file: expected a sequence of length 2.')

operator_type = data[0]
operator_terms = data[1]

if not isinstance(operator_type, str):
raise TypeError('Invalid operator type: expected a string.')

if not isinstance(operator_terms, dict):
raise TypeError('Invalid operator terms: expected a dictionary.')

if operator_type == 'FermionOperator':
operator = FermionOperator()
for term in operator_terms:
Expand Down
76 changes: 51 additions & 25 deletions src/openfermion/utils/operator_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,26 @@
# limitations under the License.
"""Tests for operator_utils."""

import os

import itertools

import marshal
import os
import unittest

import numpy

import sympy

from scipy.sparse import csc_matrix

from openfermion.config import DATA_DIRECTORY
from openfermion.hamiltonians import fermi_hubbard
from openfermion.ops.operators import (
FermionOperator,
MajoranaOperator,
BosonOperator,
QubitOperator,
QuadOperator,
IsingOperator,
)
from openfermion.ops.operators import (BosonOperator, FermionOperator, IsingOperator,
MajoranaOperator, QuadOperator, QubitOperator)
from openfermion.ops.representations import InteractionOperator
from openfermion.transforms.opconversions import jordan_wigner, bravyi_kitaev
from openfermion.transforms.repconversions import get_interaction_operator
from openfermion.testing.testing_utils import random_interaction_operator
from openfermion.utils.operator_utils import (
count_qubits,
hermitian_conjugated,
is_identity,
save_operator,
OperatorUtilsError,
is_hermitian,
load_operator,
get_file_path,
)
from openfermion.transforms.opconversions import bravyi_kitaev, jordan_wigner
from openfermion.transforms.repconversions import get_interaction_operator
from openfermion.utils.operator_utils import (OperatorUtilsError, count_qubits, get_file_path,
hermitian_conjugated, is_hermitian, is_identity,
load_operator, save_operator)


class OperatorUtilsTest(unittest.TestCase):
Expand Down Expand Up @@ -595,6 +579,48 @@ def test_save_bad_type(self):
save_operator('ping', 'somewhere')


class LoadOperatorSecurityTest(unittest.TestCase):
def setUp(self):
self.bad_file_name = 'bad_marshal'
self.bad_file_path = os.path.join(DATA_DIRECTORY, self.bad_file_name + '.data')

def tearDown(self):
if os.path.exists(self.bad_file_path):
os.remove(self.bad_file_path)

def test_load_invalid_structure_not_tuple_or_list(self):
# Save an integer instead of a tuple
with open(self.bad_file_path, 'wb') as f:
marshal.dump(123, f)

with self.assertRaisesRegex(TypeError, 'Invalid data format in file'):
load_operator(self.bad_file_name)

def test_load_invalid_structure_wrong_length(self):
# Save a tuple of length 1
with open(self.bad_file_path, 'wb') as f:
marshal.dump(('FermionOperator',), f)

with self.assertRaisesRegex(TypeError, 'Invalid data format in file'):
load_operator(self.bad_file_name)

def test_load_invalid_operator_type_type(self):
# Operator type is not a string
with open(self.bad_file_path, 'wb') as f:
marshal.dump((123, {}), f)

with self.assertRaisesRegex(TypeError, 'Invalid operator type'):
load_operator(self.bad_file_name)

def test_load_invalid_terms_type(self):
# Terms is not a dict
with open(self.bad_file_path, 'wb') as f:
marshal.dump(('FermionOperator', 123), f)

with self.assertRaisesRegex(TypeError, 'Invalid operator terms'):
load_operator(self.bad_file_name)


class GetFileDirTest(unittest.TestCase):
def setUp(self):
self.filename = 'foo'
Expand Down