Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 43 additions & 30 deletions src/amuse/test/suite/core_tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,30 @@
from amuse.test import amusetest

import subprocess
import pickle
import sys
import os

from amuse.support.exceptions import AmuseException
from amuse.test import amusetest

from amuse.units import core
from amuse.units import si
from amuse.units import nbody_system
from amuse.units import generic_unit_system
from amuse.units.quantities import zero
from amuse.units.units import *
from amuse.units.constants import *
from amuse.units.units import m, km, kg, parsec, stellar_type

from amuse.datamodel import Particles, parameters

import subprocess
import pickle
import sys
import os


class TestPicklingOfUnitsAndQuantities(amusetest.TestCase):

def test1(self):
km = 1000 * m
self.assertEqual(1000, km.value_in(m))
pickled_km = pickle.dumps(km)
unpickled_km = pickle.loads(pickled_km)
self.assertEqual(1000, unpickled_km.value_in(m))
kilometer = 1000 * m
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renaming it here since 'km' is already defined... could probably just as easily use the pre-defined 'km' unit, but maybe there was a reason not to?

self.assertEqual(1000, kilometer.value_in(m))
pickled_kilometer = pickle.dumps(kilometer)
unpickled_kilometer = pickle.loads(pickled_kilometer)
self.assertEqual(1000, unpickled_kilometer.value_in(m))

def test2(self):
km = 1000 * m
quantity = 12.0 | km
kilometer = 1000 * m
quantity = 12.0 | kilometer
self.assertEqual(12000, quantity.value_in(m))
pickled_quantity = pickle.dumps(quantity)
unpickled_quantity = pickle.loads(pickled_quantity)
Expand Down Expand Up @@ -81,15 +74,22 @@ def test8(self):

def test9(self):
quantity = 1.3 | nbody_system.time
path = os.path.abspath(os.path.join(self.get_path_to_results(), "test9.pickle"))
path = os.path.abspath(
os.path.join(self.get_path_to_results(), "test9.pickle")
)

with open(path, "wb") as stream:
pickle.dump(quantity, stream)

pythonpath = os.pathsep.join(sys.path)
env = os.environ.copy()
env['PYTHONPATH'] = pythonpath
code = "import pickle;stream = open('{0}', 'rb'); print(str(pickle.load(stream)));stream.close()".format(path)
code = (
f"import pickle;"
f"stream = open('{path}', 'rb');"
f"print(str(pickle.load(stream)));"
f"stream.close()"
)

process = subprocess.Popen([
sys.executable,
Expand All @@ -100,18 +100,28 @@ def test9(self):
)
unpickled_quantity_string, error_string = process.communicate()
self.assertEqual(process.returncode, 0)
self.assertEqual(str(quantity), unpickled_quantity_string.strip().decode('utf-8'))
self.assertEqual(
str(quantity),
unpickled_quantity_string.strip().decode('utf-8')
)

def test10(self):
quantity = 1 | parsec
path = os.path.abspath(os.path.join(self.get_path_to_results(), "test10.pickle"))
path = os.path.abspath(
os.path.join(self.get_path_to_results(), "test10.pickle")
)
with open(path, "wb") as stream:
pickle.dump(quantity, stream)

pythonpath = os.pathsep.join(sys.path)
env = os.environ.copy()
env['PYTHONPATH'] = pythonpath
code = "import pickle;stream = open('{0}', 'rb'); print(str(pickle.load(stream)));stream.close()".format(path)
code = (
f"import pickle;"
f"stream = open('{path}', 'rb');"
f"print(str(pickle.load(stream)));"
f"stream.close()"
)

process = subprocess.Popen([
sys.executable,
Expand All @@ -122,7 +132,10 @@ def test10(self):
)
unpickled_quantity_string, error_string = process.communicate()
self.assertEqual(process.returncode, 0)
self.assertEqual(str(quantity), unpickled_quantity_string.strip().decode('utf-8'))
self.assertEqual(
str(quantity),
unpickled_quantity_string.strip().decode('utf-8')
)

def test11(self):
value = 1 | stellar_type
Expand Down Expand Up @@ -175,7 +188,7 @@ def test4(self):
self.assertEqual(unpickled_particles.center_of_mass(), [2, 3, 0] | m)


class BaseTestModule(object):
class BaseTestModule:
def before_get_parameter(self):
return

Expand All @@ -202,13 +215,13 @@ def set_test(self, value):
self.x = value

o = TestModule()
set = parameters.Parameters([definition,], o)
set.test_name = 10 | m
paramset = parameters.Parameters([definition,], o)
paramset.test_name = 10 | m

self.assertEqual(o.x, 10 | m)
self.assertEqual(set.test_name, 10 | m)
self.assertEqual(paramset.test_name, 10 | m)

memento = set.copy()
memento = paramset.copy()
self.assertEqual(memento.test_name, 10 | m)

pickled_memento = pickle.dumps(memento)
Expand Down