Skip to content

Commit

Permalink
Added __getstate__ more appropriately, tests for pickling objects, an…
Browse files Browse the repository at this point in the history
…d special tests for transform equivalence, a new assert statement that makes the pickle a dict (not the most robust... but it works), and added PickleError.
  • Loading branch information
thequackdaddy committed Nov 4, 2018
1 parent 08cb4f1 commit 0301817
Show file tree
Hide file tree
Showing 45 changed files with 420 additions and 156 deletions.
29 changes: 12 additions & 17 deletions patsy/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
pandas_Categorical_codes,
safe_issubdtype,
no_pickling, assert_no_pickling, check_pickle_version)
from patsy.state import StatefulTransform

if have_pandas:
import pandas
Expand All @@ -65,17 +64,14 @@ def __getstate__(self):
data = getattr(self, 'data')
contrast = getattr(self, 'contrast')
levels = getattr(self, 'levels')
return (0, data, contrast, levels)
return {'version': 0, 'data': data, 'contrast': contrast,
'levels': levels}

def __setstate__(self, pickle):
version, data, contrast, levels = pickle
check_pickle_version(version, 0, name=self.__class__.__name__)
self.data = data
self.contrast = contrast
self.levels = levels

def __eq__(self, other):
return self.__dict__ == other.__dict__
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.data = pickle['data']
self.contrast = pickle['contrast']
self.levels = pickle['levels']


def C(data, contrast=None, levels=None):
Expand Down Expand Up @@ -137,19 +133,18 @@ def test_C():
assert c4.contrast == "NEW CONTRAST"
assert c4.levels == "LEVELS"

# assert_no_pickling(c4)


def test_C_pickle():
from six.moves import cPickle as pickle
from patsy.util import assert_pickled_equals
c1 = C("asdf")
assert c1 == pickle.loads(pickle.dumps(c1))
assert_pickled_equals(c1, pickle.loads(pickle.dumps(c1)))
c2 = C("DATA", "CONTRAST", "LEVELS")
assert c2 == pickle.loads(pickle.dumps(c2))
assert_pickled_equals(c2, pickle.loads(pickle.dumps(c2)))
c3 = C(c2, levels="NEW LEVELS")
assert c3 == pickle.loads(pickle.dumps(c3))
assert_pickled_equals(c3, pickle.loads(pickle.dumps(c3)))
c4 = C(c2, "NEW CONTRAST")
assert c4 == pickle.loads(pickle.dumps(c4))
assert_pickled_equals(c4, pickle.loads(pickle.dumps(c4)))


def guess_categorical(data):
Expand Down Expand Up @@ -247,7 +242,7 @@ def sniff(self, data):
# would be too. Otherwise we need to keep looking.
return self._level_set == set([True, False])

# __getstate__ = no_pickling
__getstate__ = no_pickling

def test_CategoricalSniffer():
from patsy.missing import NAAction
Expand Down
16 changes: 10 additions & 6 deletions patsy/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from patsy.origin import Origin
from patsy.util import (atleast_2d_column_default,
repr_pretty_delegate, repr_pretty_impl,
no_pickling, assert_no_pickling)
no_pickling, assert_no_pickling, check_pickle_version)
from patsy.infix_parser import Token, Operator, infix_parse
from patsy.parse_formula import _parsing_error_test

Expand Down Expand Up @@ -69,10 +69,16 @@ def _repr_pretty_(self, p, cycle):
return repr_pretty_impl(p, self,
[self.variable_names, self.coefs, self.constants])

def __eq__(self, other):
return self.__dict__ == other.__dict__
def __getstate__(self):
return {'version': 0, 'variable_names': self.variable_names,
'coefs': self.coefs, 'constants': self.constants}

def __setstate__(self, pickle):
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.variable_names = pickle['variable_names']
self.coefs = pickle['coefs']
self.constants = pickle['constants']

# __getstate__ = no_pickling

@classmethod
def combine(cls, constraints):
Expand Down Expand Up @@ -128,8 +134,6 @@ def test_LinearConstraint():
assert_raises(ValueError, LinearConstraint, ["a", "b"],
np.zeros((0, 2)))

# assert_no_pickling(lc)

def test_LinearConstraint_combine():
comb = LinearConstraint.combine([LinearConstraint(["a", "b"], [1, 0]),
LinearConstraint(["a", "b"], [0, 1], [1])])
Expand Down
1 change: 0 additions & 1 deletion patsy/contrasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def test_ContrastMatrix():
from nose.tools import assert_raises
assert_raises(PatsyError, ContrastMatrix, [[1], [0]], ["a", "b"])

# assert_no_pickling(cm)

# This always produces an object of the type that Python calls 'str' (whether
# that be a Python 2 string-of-bytes or a Python 3 string-of-unicode). It does
Expand Down
45 changes: 25 additions & 20 deletions patsy/desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@ def name(self):
return "Intercept"

def __getstate__(self):
return (0, self.factors)
return {'version': 0, 'factors': self.factors}

def __setstate__(self, pickle):
version, factors = pickle
check_pickle_version(version, 0, name=self.__class__.__name__)
self.factors = factors
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.factors = pickle['factors']

# __getstate__ = no_pickling

INTERCEPT = Term([])

Expand All @@ -85,12 +83,6 @@ def __init__(self, name):
def name(self):
return self._name

def __eq__(self, other):
return self.__dict__ == other.__dict__

def __hash__(self):
return hash((_MockFactor, str(self._name)))


def test_Term():
assert Term([1, 2, 1]).factors == (1, 2)
Expand All @@ -102,11 +94,12 @@ def test_Term():
assert Term([f2, f1]).name() == "b:a"
assert Term([]).name() == "Intercept"

# assert_no_pickling(Term([]))

from six.moves import cPickle as pickle
from patsy.util import assert_pickled_equals
t = Term([f1, f2])
assert t == pickle.loads(pickle.dumps(t, pickle.HIGHEST_PROTOCOL))
t2 = pickle.loads(pickle.dumps(t, pickle.HIGHEST_PROTOCOL))
assert_pickled_equals(t, t2)


class ModelDesc(object):
"""A simple container representing the termlists parsed from a formula.
Expand Down Expand Up @@ -168,7 +161,7 @@ def term_code(term):
if term != INTERCEPT]
result += " + ".join(term_names)
return result

@classmethod
def from_formula(cls, tree_or_string):
"""Construct a :class:`ModelDesc` from a formula string.
Expand All @@ -186,10 +179,15 @@ def from_formula(cls, tree_or_string):
assert isinstance(value, cls)
return value

def __eq__(self, other):
return self.__dict__ == other.__dict__
def __getstate__(self):
return {'version': 0, 'lhs_termlist': self.lhs_termlist,
'rhs_termlist': self.rhs_termlist}

def __setstate__(self, pickle):
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.lhs_termlist = pickle['lhs_termlist']
self.rhs_termlist = pickle['rhs_termlist']

# __getstate__ = no_pickling

def test_ModelDesc():
f1 = _MockFactor("a")
Expand All @@ -202,7 +200,9 @@ def test_ModelDesc():

# assert_no_pickling(m)
from six.moves import cPickle as pickle
assert m == pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL))
from patsy.util import assert_pickled_equals
m2 = pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL))
assert_pickled_equals(m, m2)

assert ModelDesc([], []).describe() == "~ 0"
assert ModelDesc([INTERCEPT], []).describe() == "1 ~ 0"
Expand Down Expand Up @@ -234,7 +234,12 @@ def _pretty_repr_(self, p, cycle): # pragma: no cover
[self.intercept, self.intercept_origin,
self.intercept_removed, self.terms])

# __getstate__ = no_pickling
__getstate__ = no_pickling


def test_IntermediateExpr_smoke():
assert_no_pickling(IntermediateExpr(False, None, True, []))


def _maybe_add_intercept(doit, terms):
if doit:
Expand Down
67 changes: 39 additions & 28 deletions patsy/design_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,18 @@ def __repr__(self):
kwlist.append(("categories", self.categories))
repr_pretty_impl(p, self, [], kwlist)

def __eq__(self, other):
return self.__dict__ == other.__dict__
def __getstate__(self):
return {'version': 0, 'factor': self.factor, 'type': self.type,
'state': self.state, 'num_columns': self.num_columns,
'categories': self.categories}

def __hash__(self):
if not self.categories:
categories = 'NoCategories'
else:
categories = frozenset(self.categories)
return hash((FactorInfo, str(self.factor), str(self.type),
str(self.state), str(self.num_columns), categories))
def __setstate__(self, pickle):
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.factor = pickle['factor']
self.type = pickle['type']
self.state = pickle['state']
self.num_columns = pickle['num_columns']
self.categories = pickle['categories']


def test_FactorInfo():
Expand Down Expand Up @@ -245,10 +247,17 @@ def _repr_pretty_(self, p, cycle):
("contrast_matrices", self.contrast_matrices),
("num_columns", self.num_columns)])

def __eq__(self, other):
return self.__dict__ == other.__dict__
def __getstate__(self):
return {'version': 0, 'factors': self.factors,
'contrast_matrices': self.contrast_matrices,
'num_columns': self.num_columns}

def __setstate__(self, pickle):
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.factors = pickle['factors']
self.contrast_matrices = pickle['contrast_matrices']
self.num_columns = pickle['num_columns']

# __getstate__ = no_pickling

def test_SubtermInfo():
cm = ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"])
Expand Down Expand Up @@ -706,21 +715,19 @@ def from_array(cls, array_like, default_column_prefix="column"):
return DesignInfo(column_names)

def __getstate__(self):
return (0, self.column_name_indexes, self.factor_infos,
self.term_codings, self.term_slices, self.term_name_slices)
return {'version': 0, 'column_name_indexes': self.column_name_indexes,
'factor_infos': self.factor_infos,
'term_codings': self.term_codings,
'term_slices': self.term_slices,
'term_name_slices': self.term_name_slices}

def __setstate__(self, pickle):
(version, column_name_indexes, factor_infos, term_codings,
term_slices, term_name_slices) = pickle
check_pickle_version(version, 0, self.__class__.__name__)
self.column_name_indexes = column_name_indexes
self.factor_infos = factor_infos
self.term_codings = term_codings
self.term_slices = term_slices
self.term_name_slices = term_name_slices

def __eq__(self, other):
return self.__dict__ == other.__dict__
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.column_name_indexes = pickle['column_name_indexes']
self.factor_infos = pickle['factor_infos']
self.term_codings = pickle['term_codings']
self.term_slices = pickle['term_slices']
self.term_name_slices = pickle['term_name_slices']


class _MockFactor(object):
Expand Down Expand Up @@ -772,9 +779,12 @@ def test_DesignInfo():

# smoke test
repr(di)
from six.moves import cPickle as pickle

assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
# Pickling check
from six.moves import cPickle as pickle
from patsy.util import assert_pickled_equals
di2 = pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
assert_pickled_equals(di, di2)

# One without term objects
di = DesignInfo(["a1", "a2", "a3", "b"])
Expand All @@ -795,7 +805,8 @@ def test_DesignInfo():
assert di.slice("a3") == slice(2, 3)
assert di.slice("b") == slice(3, 4)

assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
di2 = pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL))
assert_pickled_equals(di, di2)

# Check intercept handling in describe()
assert DesignInfo(["Intercept", "a", "b"]).describe() == "1 + a + b"
Expand Down
24 changes: 9 additions & 15 deletions patsy/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# for __future__ flags!

# These are made available in the patsy.* namespace
__all__ = ["EvalEnvironment", "EvalFactor", "VarLookupDict"]
__all__ = ["EvalEnvironment", "EvalFactor"]

import sys
import __future__
Expand Down Expand Up @@ -62,9 +62,6 @@ def __contains__(self, key):
else:
return True

def __eq__(self, other):
return self.__dict__ == other.__dict__

def get(self, key, default=None):
try:
return self[key]
Expand Down Expand Up @@ -98,7 +95,6 @@ def test_VarLookupDict():
assert ds.get("c") is None
assert isinstance(repr(ds), six.string_types)

# assert_no_pickling(ds)

def ast_names(code):
"""Iterator that yields all the (ast) names in a Python expression.
Expand Down Expand Up @@ -255,8 +251,7 @@ def _namespace_ids(self):
def __eq__(self, other):
return (isinstance(other, EvalEnvironment)
and self.flags == other.flags
and self.namespace == other.namespace)
# and self._namespace_ids() == other._namespace_ids())
and self._namespace_ids() == other._namespace_ids())

def __ne__(self, other):
return not self == other
Expand Down Expand Up @@ -382,7 +377,6 @@ def test_EvalEnvironment_capture_namespace():

assert_raises(TypeError, EvalEnvironment.capture, 1.2)

# assert_no_pickling(EvalEnvironment.capture())

def test_EvalEnvironment_capture_flags():
if sys.version_info >= (3,):
Expand Down Expand Up @@ -649,15 +643,15 @@ def eval(self, memorize_state, data):
data)

def __getstate__(self):
return (0, self.code, self.origin)
return {'version': 0, 'code': self.code, 'origin': self.origin}

def __setstate__(self, state):
(version, code, origin) = state
check_pickle_version(version, 0, self.__class__.__name__)
self.code = code
self.origin = origin
def __setstate__(self, pickle):
check_pickle_version(pickle['version'], 0, self.__class__.__name__)
self.code = pickle['code']
self.origin = pickle['origin']

def test_EvalFactor_pickle_saves_origin():
from patsy.util import assert_pickled_equals
# The pickling tests use object equality before and after pickling
# to test that pickling worked correctly. But EvalFactor's origin field
# is not used in equality comparisons, so we need a separate test to
Expand All @@ -667,7 +661,7 @@ def test_EvalFactor_pickle_saves_origin():
new_f = pickle.loads(pickle.dumps(f))

assert f.origin is not None
assert f.origin == new_f.origin
assert_pickled_equals(f, new_f)

def test_EvalFactor_basics():
e = EvalFactor("a+b")
Expand Down
Loading

0 comments on commit 0301817

Please sign in to comment.