Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pickling #86

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions patsy/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from patsy.tokens import (pretty_untokenize, normalize_token_spacing,
python_tokenize)
from patsy.compat import call_and_wrap_exc
from patsy.version import __version__

def _all_future_flags():
flags = 0
Expand Down Expand Up @@ -565,7 +566,41 @@ def eval(self, memorize_state, data):
memorize_state,
data)

__getstate__ = no_pickling
def __getstate__(self):
return {
'version': __version__,
'code': self.code,
'origin': self.origin
}

def __setstate__(self, state):
expected_fields = {
'code': 'REQUIRED',
'origin': 'OPTIONAL'
}

for field in expected_fields:
if field in state:
self.__setattr__(field, state[field])
continue
else:
pickling_version = state['version']
unpickling_newer_version = pickling_version.split('+')[0] > __version__.split('+')[0]
if expected_fields[field] == 'REQUIRED' and unpickling_newer_version:
msg = "This EvalFactor was pickled with patsy version %s," \
"and cannot be unpickled with version %s" % \
(pickling_version, __version__)
raise KeyError, msg
elif expected_fields[field] == 'OPTIONAL' and unpickling_newer_version:
msg = "This EvalFactor was pickled with patsy version %s," \
"and cannot be unpickled with full fidelity by version %s." \
"In particular, you have access to `code` but not to `origin`" % \
(pickling_version, __version__)
raise FutureWarning, msg
else:
msg = "Unable to unpickle EvalFactor field %s." % field
raise KeyError, msg


def test_EvalFactor_basics():
e = EvalFactor("a+b")
Expand All @@ -577,8 +612,6 @@ def test_EvalFactor_basics():
assert e.origin is None
assert e2.origin == "asdf"

assert_no_pickling(e)

def test_EvalFactor_memorize_passes_needed():
from patsy.state import stateful_transform
foo = stateful_transform(lambda: "FOO-OBJ")
Expand Down
21 changes: 21 additions & 0 deletions patsy/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Exhaustive end-to-end tests of the top-level API.

import sys
from six.moves import cPickle as pickle
import __future__
import six
import numpy as np
Expand Down Expand Up @@ -758,3 +759,23 @@ def test_C_and_pandas_categorical():
[[1, 0],
[1, 1],
[1, 0]])

def test_pickle_builder_roundtrips():
import numpy as np
# TODO Add center(x) and categorical interaction, and call to np.log to patsy formula.
design_matrix = dmatrix("x + a", {"x": [1, 2, 3],
"a": ["a1", "a2", "a3"]})
# TODO Remove builder, pass design_info to dmatrix() instead.
builder = design_matrix.design_info.builder
del np

new_data = {"x": [10, 20, 30],
"a": ["a3", "a1", "a2"]}
m1 = dmatrix(builder, new_data)

builder2 = pickle.loads(pickle.dumps(design_matrix.design_info.builder))
m2 = dmatrix(builder2, new_data)

assert np.allclose(m1, m2)


19 changes: 19 additions & 0 deletions patsy/test_pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from six.moves import cPickle as pickle

from patsy.eval import EvalFactor
from patsy.version import __version__


objects_to_test = [
("EvalFactor('a+b', 'mars')", {
"0.4.1+dev": "ccopy_reg\n_reconstructor\np1\n(cpatsy.eval\nEvalFactor\np2\nc__builtin__\nobject\np3\nNtRp4\n(dp5\nS\'code\'\np6\nS\'a + b\'\np7\nsS\'origin\'\np8\nS\'mars\'\np9\nsS\'version\'\np10\nS\'0.4.1+dev\'\np11\nsb."
})
]

def test_pickling_roundtrips():
for obj_code, pickled_history in objects_to_test:
obj = eval(obj_code)
print pickle.dumps(obj).encode('string-escape')
assert obj == pickle.loads(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))
for version, pickled in pickled_history.items():
assert pickle.dumps(obj) == pickled