diff --git a/patsy/categorical.py b/patsy/categorical.py index 5812b39..131f932 100644 --- a/patsy/categorical.py +++ b/patsy/categorical.py @@ -47,7 +47,6 @@ pandas_Categorical_codes, safe_issubdtype, no_pickling, assert_no_pickling, check_pickle_version) -from patsy.state import StatefulTransform if have_pandas: import pandas @@ -65,17 +64,14 @@ def __getstate__(self): data = getattr(self, 'data') contrast = getattr(self, 'contrast') levels = getattr(self, 'levels') - return (0, data, contrast, levels) + return {'version': 0, 'data': data, 'contrast': contrast, + 'levels': levels} def __setstate__(self, pickle): - version, data, contrast, levels = pickle - check_pickle_version(version, 0, name=self.__class__.__name__) - self.data = data - self.contrast = contrast - self.levels = levels - - def __eq__(self, other): - return self.__dict__ == other.__dict__ + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.data = pickle['data'] + self.contrast = pickle['contrast'] + self.levels = pickle['levels'] def C(data, contrast=None, levels=None): @@ -137,19 +133,18 @@ def test_C(): assert c4.contrast == "NEW CONTRAST" assert c4.levels == "LEVELS" - # assert_no_pickling(c4) - def test_C_pickle(): from six.moves import cPickle as pickle + from patsy.util import assert_pickled_equals c1 = C("asdf") - assert c1 == pickle.loads(pickle.dumps(c1)) + assert_pickled_equals(c1, pickle.loads(pickle.dumps(c1))) c2 = C("DATA", "CONTRAST", "LEVELS") - assert c2 == pickle.loads(pickle.dumps(c2)) + assert_pickled_equals(c2, pickle.loads(pickle.dumps(c2))) c3 = C(c2, levels="NEW LEVELS") - assert c3 == pickle.loads(pickle.dumps(c3)) + assert_pickled_equals(c3, pickle.loads(pickle.dumps(c3))) c4 = C(c2, "NEW CONTRAST") - assert c4 == pickle.loads(pickle.dumps(c4)) + assert_pickled_equals(c4, pickle.loads(pickle.dumps(c4))) def guess_categorical(data): @@ -247,7 +242,7 @@ def sniff(self, data): # would be too. Otherwise we need to keep looking. return self._level_set == set([True, False]) - # __getstate__ = no_pickling + __getstate__ = no_pickling def test_CategoricalSniffer(): from patsy.missing import NAAction diff --git a/patsy/constraint.py b/patsy/constraint.py index 9c36a87..8b1625b 100644 --- a/patsy/constraint.py +++ b/patsy/constraint.py @@ -18,7 +18,7 @@ from patsy.util import (atleast_2d_column_default, repr_pretty_delegate, repr_pretty_impl, SortAnythingKey, - no_pickling, assert_no_pickling) + no_pickling, assert_no_pickling, check_pickle_version) from patsy.infix_parser import Token, Operator, ParseNode, infix_parse class LinearConstraint(object): @@ -65,10 +65,16 @@ def _repr_pretty_(self, p, cycle): return repr_pretty_impl(p, self, [self.variable_names, self.coefs, self.constants]) - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __getstate__(self): + return {'version': 0, 'variable_names': self.variable_names, + 'coefs': self.coefs, 'constants': self.constants} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.variable_names = pickle['variable_names'] + self.coefs = pickle['coefs'] + self.constants = pickle['constants'] - # __getstate__ = no_pickling @classmethod def combine(cls, constraints): @@ -121,8 +127,6 @@ def test_LinearConstraint(): assert_raises(ValueError, LinearConstraint, ["a", "b"], np.zeros((0, 2))) - # assert_no_pickling(lc) - def test_LinearConstraint_combine(): comb = LinearConstraint.combine([LinearConstraint(["a", "b"], [1, 0]), LinearConstraint(["a", "b"], [0, 1], [1])]) diff --git a/patsy/contrasts.py b/patsy/contrasts.py index ea3d0bc..4173f2e 100644 --- a/patsy/contrasts.py +++ b/patsy/contrasts.py @@ -75,7 +75,6 @@ def test_ContrastMatrix(): from nose.tools import assert_raises assert_raises(PatsyError, ContrastMatrix, [[1], [0]], ["a", "b"]) - # assert_no_pickling(cm) # This always produces an object of the type that Python calls 'str' (whether # that be a Python 2 string-of-bytes or a Python 3 string-of-unicode). It does diff --git a/patsy/desc.py b/patsy/desc.py index de0263c..10b9844 100644 --- a/patsy/desc.py +++ b/patsy/desc.py @@ -66,14 +66,12 @@ def name(self): return "Intercept" def __getstate__(self): - return (0, self.factors) + return {'version': 0, 'factors': self.factors} def __setstate__(self, pickle): - version, factors = pickle - check_pickle_version(version, 0, name=self.__class__.__name__) - self.factors = factors + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.factors = pickle['factors'] - # __getstate__ = no_pickling INTERCEPT = Term([]) @@ -85,12 +83,6 @@ def __init__(self, name): def name(self): return self._name - def __eq__(self, other): - return self.__dict__ == other.__dict__ - - def __hash__(self): - return hash((_MockFactor, str(self._name))) - def test_Term(): assert Term([1, 2, 1]).factors == (1, 2) @@ -102,11 +94,12 @@ def test_Term(): assert Term([f2, f1]).name() == "b:a" assert Term([]).name() == "Intercept" - # assert_no_pickling(Term([])) - from six.moves import cPickle as pickle + from patsy.util import assert_pickled_equals t = Term([f1, f2]) - assert t == pickle.loads(pickle.dumps(t, pickle.HIGHEST_PROTOCOL)) + t2 = pickle.loads(pickle.dumps(t, pickle.HIGHEST_PROTOCOL)) + assert_pickled_equals(t, t2) + class ModelDesc(object): """A simple container representing the termlists parsed from a formula. @@ -168,7 +161,7 @@ def term_code(term): if term != INTERCEPT] result += " + ".join(term_names) return result - + @classmethod def from_formula(cls, tree_or_string): """Construct a :class:`ModelDesc` from a formula string. @@ -186,10 +179,15 @@ def from_formula(cls, tree_or_string): assert isinstance(value, cls) return value - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __getstate__(self): + return {'version': 0, 'lhs_termlist': self.lhs_termlist, + 'rhs_termlist': self.rhs_termlist} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.lhs_termlist = pickle['lhs_termlist'] + self.rhs_termlist = pickle['rhs_termlist'] - # __getstate__ = no_pickling def test_ModelDesc(): f1 = _MockFactor("a") @@ -202,7 +200,9 @@ def test_ModelDesc(): # assert_no_pickling(m) from six.moves import cPickle as pickle - assert m == pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL)) + from patsy.util import assert_pickled_equals + m2 = pickle.loads(pickle.dumps(m, pickle.HIGHEST_PROTOCOL)) + assert_pickled_equals(m, m2) assert ModelDesc([], []).describe() == "~ 0" assert ModelDesc([INTERCEPT], []).describe() == "1 ~ 0" @@ -234,7 +234,12 @@ def _pretty_repr_(self, p, cycle): # pragma: no cover [self.intercept, self.intercept_origin, self.intercept_removed, self.terms]) - # __getstate__ = no_pickling + __getstate__ = no_pickling + + +def test_IntermediateExpr_smoke(): + assert_no_pickling(IntermediateExpr(False, None, True, [])) + def _maybe_add_intercept(doit, terms): if doit: diff --git a/patsy/design_info.py b/patsy/design_info.py index 5a755af..885503f 100644 --- a/patsy/design_info.py +++ b/patsy/design_info.py @@ -121,16 +121,18 @@ def __repr__(self): kwlist.append(("categories", self.categories)) repr_pretty_impl(p, self, [], kwlist) - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __getstate__(self): + return {'version': 0, 'factor': self.factor, 'type': self.type, + 'state': self.state, 'num_columns': self.num_columns, + 'categories': self.categories} - def __hash__(self): - if not self.categories: - categories = 'NoCategories' - else: - categories = frozenset(self.categories) - return hash((FactorInfo, str(self.factor), str(self.type), - str(self.state), str(self.num_columns), categories)) + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.factor = pickle['factor'] + self.type = pickle['type'] + self.state = pickle['state'] + self.num_columns = pickle['num_columns'] + self.categories = pickle['categories'] def test_FactorInfo(): @@ -245,10 +247,17 @@ def _repr_pretty_(self, p, cycle): ("contrast_matrices", self.contrast_matrices), ("num_columns", self.num_columns)]) - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __getstate__(self): + return {'version': 0, 'factors': self.factors, + 'contrast_matrices': self.contrast_matrices, + 'num_columns': self.num_columns} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.factors = pickle['factors'] + self.contrast_matrices = pickle['contrast_matrices'] + self.num_columns = pickle['num_columns'] - # __getstate__ = no_pickling def test_SubtermInfo(): cm = ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"]) @@ -706,21 +715,19 @@ def from_array(cls, array_like, default_column_prefix="column"): return DesignInfo(column_names) def __getstate__(self): - return (0, self.column_name_indexes, self.factor_infos, - self.term_codings, self.term_slices, self.term_name_slices) + return {'version': 0, 'column_name_indexes': self.column_name_indexes, + 'factor_infos': self.factor_infos, + 'term_codings': self.term_codings, + 'term_slices': self.term_slices, + 'term_name_slices': self.term_name_slices} def __setstate__(self, pickle): - (version, column_name_indexes, factor_infos, term_codings, - term_slices, term_name_slices) = pickle - check_pickle_version(version, 0, self.__class__.__name__) - self.column_name_indexes = column_name_indexes - self.factor_infos = factor_infos - self.term_codings = term_codings - self.term_slices = term_slices - self.term_name_slices = term_name_slices - - def __eq__(self, other): - return self.__dict__ == other.__dict__ + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.column_name_indexes = pickle['column_name_indexes'] + self.factor_infos = pickle['factor_infos'] + self.term_codings = pickle['term_codings'] + self.term_slices = pickle['term_slices'] + self.term_name_slices = pickle['term_name_slices'] class _MockFactor(object): @@ -772,9 +779,12 @@ def test_DesignInfo(): # smoke test repr(di) - from six.moves import cPickle as pickle - assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL)) + # Pickling check + from six.moves import cPickle as pickle + from patsy.util import assert_pickled_equals + di2 = pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL)) + assert_pickled_equals(di, di2) # One without term objects di = DesignInfo(["a1", "a2", "a3", "b"]) @@ -795,7 +805,8 @@ def test_DesignInfo(): assert di.slice("a3") == slice(2, 3) assert di.slice("b") == slice(3, 4) - assert di == pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL)) + di2 = pickle.loads(pickle.dumps(di, pickle.HIGHEST_PROTOCOL)) + assert_pickled_equals(di, di2) # Check intercept handling in describe() assert DesignInfo(["Intercept", "a", "b"]).describe() == "1 + a + b" diff --git a/patsy/eval.py b/patsy/eval.py index aa7a576..a0d5913 100644 --- a/patsy/eval.py +++ b/patsy/eval.py @@ -10,7 +10,7 @@ # for __future__ flags! # These are made available in the patsy.* namespace -__all__ = ["EvalEnvironment", "EvalFactor", "VarLookupDict"] +__all__ = ["EvalEnvironment", "EvalFactor"] import sys import __future__ @@ -62,9 +62,6 @@ def __contains__(self, key): else: return True - def __eq__(self, other): - return self.__dict__ == other.__dict__ - def get(self, key, default=None): try: return self[key] @@ -98,7 +95,6 @@ def test_VarLookupDict(): assert ds.get("c") is None assert isinstance(repr(ds), six.string_types) - # assert_no_pickling(ds) def ast_names(code): """Iterator that yields all the (ast) names in a Python expression. @@ -255,8 +251,7 @@ def _namespace_ids(self): def __eq__(self, other): return (isinstance(other, EvalEnvironment) and self.flags == other.flags - and self.namespace == other.namespace) - # and self._namespace_ids() == other._namespace_ids()) + and self._namespace_ids() == other._namespace_ids()) def __ne__(self, other): return not self == other @@ -382,7 +377,6 @@ def test_EvalEnvironment_capture_namespace(): assert_raises(TypeError, EvalEnvironment.capture, 1.2) - # assert_no_pickling(EvalEnvironment.capture()) def test_EvalEnvironment_capture_flags(): if sys.version_info >= (3,): @@ -649,15 +643,15 @@ def eval(self, memorize_state, data): data) def __getstate__(self): - return (0, self.code, self.origin) + return {'version': 0, 'code': self.code, 'origin': self.origin} - def __setstate__(self, state): - (version, code, origin) = state - check_pickle_version(version, 0, self.__class__.__name__) - self.code = code - self.origin = origin + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.code = pickle['code'] + self.origin = pickle['origin'] def test_EvalFactor_pickle_saves_origin(): + from patsy.util import assert_pickled_equals # The pickling tests use object equality before and after pickling # to test that pickling worked correctly. But EvalFactor's origin field # is not used in equality comparisons, so we need a separate test to @@ -667,7 +661,7 @@ def test_EvalFactor_pickle_saves_origin(): new_f = pickle.loads(pickle.dumps(f)) assert f.origin is not None - assert f.origin == new_f.origin + assert_pickled_equals(f, new_f) def test_EvalFactor_basics(): e = EvalFactor("a+b") diff --git a/patsy/infix_parser.py b/patsy/infix_parser.py index f6ac31e..bdf395d 100644 --- a/patsy/infix_parser.py +++ b/patsy/infix_parser.py @@ -44,7 +44,7 @@ def __init__(self, print_as): def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self._print_as) - # __getstate__ = no_pickling + __getstate__ = no_pickling class Token(object): """A token with possible payload. @@ -70,7 +70,7 @@ def _repr_pretty_(self, p, cycle): kwargs = [("extra", self.extra)] return repr_pretty_impl(p, self, [self.type, self.origin], kwargs) - # __getstate__ = no_pickling + __getstate__ = no_pickling class ParseNode(object): def __init__(self, type, token, args, origin): @@ -83,7 +83,7 @@ def __init__(self, type, token, args, origin): def _repr_pretty_(self, p, cycle): return repr_pretty_impl(p, self, [self.type, self.token, self.args]) - # __getstate__ = no_pickling + __getstate__ = no_pickling class Operator(object): def __init__(self, token_type, arity, precedence): @@ -95,14 +95,14 @@ def __repr__(self): return "%s(%r, %r, %r)" % (self.__class__.__name__, self.token_type, self.arity, self.precedence) - # __getstate__ = no_pickling + __getstate__ = no_pickling class _StackOperator(object): def __init__(self, op, token): self.op = op self.token = token - # __getstate__ = no_pickling + __getstate__ = no_pickling _open_paren = Operator(Token.LPAREN, -1, -9999999) @@ -115,7 +115,7 @@ def __init__(self, unary_ops, binary_ops, atomic_types, trace): self.atomic_types = atomic_types self.trace = trace - # __getstate__ = no_pickling + __getstate__ = no_pickling def _read_noun_context(token, c): if token.type == Token.LPAREN: diff --git a/patsy/mgcv_cubic_splines.py b/patsy/mgcv_cubic_splines.py index 7770cf7..11ec055 100644 --- a/patsy/mgcv_cubic_splines.py +++ b/patsy/mgcv_cubic_splines.py @@ -10,8 +10,9 @@ import numpy as np from patsy.util import (have_pandas, atleast_2d_column_default, - no_pickling, assert_no_pickling, safe_string_eq) -from patsy.state import stateful_transform, StatefulTransform + no_pickling, assert_no_pickling, safe_string_eq, + check_pickle_version) +from patsy.state import stateful_transform if have_pandas: import pandas @@ -541,7 +542,7 @@ def _get_centering_constraint_from_dmatrix(design_matrix): return design_matrix.mean(axis=0).reshape((1, design_matrix.shape[1])) -class CubicRegressionSpline(StatefulTransform): +class CubicRegressionSpline(object): """Base class for cubic regression spline stateful transforms This class contains all the functionality for the following stateful @@ -685,7 +686,7 @@ def transform(self, x, df=None, knots=None, dm.index = x_orig.index return dm - # __getstate__ = no_pickling + __getstate__ = no_pickling class CR(CubicRegressionSpline): @@ -716,6 +717,18 @@ class CR(CubicRegressionSpline): def __init__(self): CubicRegressionSpline.__init__(self, name='cr', cyclic=False) + def __getstate__(self): + return {'version': 0, 'name': self._name, 'cyclic': self._cyclic, + 'all_knots': self._all_knots, 'constraints': self._constraints} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self._name = pickle['name'] + self._cyclic = pickle['cyclic'] + self._all_knots = pickle['all_knots'] + self._constraints = pickle['constraints'] + + cr = stateful_transform(CR) cr.__qualname__ = 'cr' cr.__name__ = 'cr' @@ -748,6 +761,18 @@ class CC(CubicRegressionSpline): def __init__(self): CubicRegressionSpline.__init__(self, name='cc', cyclic=True) + def __getstate__(self): + return {'version': 0, 'name': self._name, 'cyclic': self._cyclic, + 'all_knots': self._all_knots, 'constraints': self._constraints} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self._name = pickle['name'] + self._cyclic = pickle['cyclic'] + self._all_knots = pickle['all_knots'] + self._constraints = pickle['constraints'] + + cc = stateful_transform(CC) cc.__qualname__ = 'cc' cc.__name__ = 'cc' @@ -855,7 +880,7 @@ def test_crs_with_specific_constraint(): assert np.allclose(result1, result2, rtol=1e-12, atol=0.) -class TE(StatefulTransform): +class TE(object): """te(s1, .., sn, constraints=None) Generates smooth of several covariates as a tensor product of the bases @@ -944,7 +969,13 @@ def transform(self, *args, **kwargs): return _get_te_dmatrix(args_2d, self._constraints) - # __getstate__ = no_pickling + def __getstate__(self): + return {'version': 0, 'constraints': self._constraints} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self._constraints = pickle['constraints'] + te = stateful_transform(TE) te.__qualname__ = 'te' diff --git a/patsy/missing.py b/patsy/missing.py index 2bbd1c2..7481c57 100644 --- a/patsy/missing.py +++ b/patsy/missing.py @@ -39,7 +39,7 @@ import numpy as np from patsy import PatsyError from patsy.util import (safe_isnan, safe_scalar_isnan, - no_pickling, assert_no_pickling) + no_pickling, assert_no_pickling, check_pickle_version) # These are made available in the patsy.* namespace __all__ = ["NAAction"] @@ -180,10 +180,15 @@ def _handle_NA_drop(self, values, is_NAs, origins): # "..." to handle 1- versus 2-dim indexing return [v[good_mask, ...] for v in values] - def __eq__(self, other): - return self.__dict__ == other.__dict__ + def __getstate__(self): + return {'version': 0, 'NA_types': self.NA_types, + 'on_NA': self.on_NA} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.NA_types = pickle['NA_types'] + self.on_NA = pickle['on_NA'] - # __getstate__ = no_pickling def test_NAAction_basic(): from nose.tools import assert_raises @@ -191,7 +196,6 @@ def test_NAAction_basic(): assert_raises(ValueError, NAAction, NA_types=("NaN", "asdf")) assert_raises(ValueError, NAAction, NA_types="NaN") - # assert_no_pickling(NAAction()) def test_NAAction_NA_types_numerical(): for NA_types in [[], ["NaN"], ["None"], ["NaN", "None"]]: @@ -236,7 +240,7 @@ def test_NAAction_drop(): assert np.array_equal(out_values[0], [2, 4]) assert np.array_equal(out_values[1], [20.0, 40.0]) assert np.array_equal(out_values[2], [[3.0, 4.0], [6.0, 7.0]]) - + def test_NAAction_raise(): action = NAAction(on_NA="raise") diff --git a/patsy/origin.py b/patsy/origin.py index 2859313..20be583 100644 --- a/patsy/origin.py +++ b/patsy/origin.py @@ -10,6 +10,7 @@ # These are made available in the patsy.* namespace __all__ = ["Origin"] + class Origin(object): """This represents the origin of some object in some string. @@ -118,6 +119,18 @@ def __getstate__(self): raise NotImplementedError """ + def __getstate__(self): + return {'version': 0, 'code': self.code, 'start': self.start, + 'end': self.end} + + def __setstate__(self, pickle): + from patsy.util import check_pickle_version + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.code = pickle['code'] + self.start = pickle['start'] + self.end = pickle['end'] + + def test_Origin(): o1 = Origin("012345", 2, 4) o2 = Origin("012345", 4, 5) @@ -140,5 +153,4 @@ def __init__(self, origin=None): assert Origin.combine([ObjWithOrigin(), ObjWithOrigin()]) is None - # from patsy.util import assert_no_pickling - # assert_no_pickling(Origin("", 0, 0)) + from patsy.util import assert_no_pickling diff --git a/patsy/redundancy.py b/patsy/redundancy.py index 9fa78a8..415fa96 100644 --- a/patsy/redundancy.py +++ b/patsy/redundancy.py @@ -73,7 +73,7 @@ def __repr__(self): suffix = "-" return "%r%s" % (self.factor, suffix) - # __getstate__ = no_pickling + __getstate__ = no_pickling class _Subterm(object): "Also immutable." diff --git a/patsy/splines.py b/patsy/splines.py index 8a8abee..96668fc 100644 --- a/patsy/splines.py +++ b/patsy/splines.py @@ -9,8 +9,9 @@ import numpy as np -from patsy.util import have_pandas, no_pickling, assert_no_pickling -from patsy.state import stateful_transform, StatefulTransform +from patsy.util import (have_pandas, no_pickling, assert_no_pickling, + check_pickle_version) +from patsy.state import stateful_transform if have_pandas: import pandas @@ -74,7 +75,7 @@ def t(x, prob, expected): t([10, 20], [0.3, 0.7], [13, 17]) t(list(range(10)), [0.3, 0.7], [2.7, 6.3]) -class BS(StatefulTransform): +class BS(object): """bs(x, df=None, knots=None, degree=3, include_intercept=False, lower_bound=None, upper_bound=None) Generates a B-spline basis for ``x``, allowing non-linear fits. The usual @@ -245,12 +246,21 @@ def transform(self, x, df=None, knots=None, degree=3, basis.index = x.index return basis - # __getstate__ = no_pickling + def __getstate__(self): + return {'version': 0, 'degree': self._degree, + 'all_knots': self._all_knots} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self._degree = pickle['degree'] + self._all_knots = pickle['all_knots'] + bs = stateful_transform(BS) bs.__qualname__ = 'bs' bs.__name__ = 'bs' + def test_bs_compat(): from patsy.test_state import check_stateful from patsy.test_splines_bs_data import (R_bs_test_x, diff --git a/patsy/state.py b/patsy/state.py index 7d759a6..a45ed72 100644 --- a/patsy/state.py +++ b/patsy/state.py @@ -32,8 +32,7 @@ no_pickling, assert_no_pickling, check_pickle_version) # These are made available in the patsy.* namespace -__all__ = ["stateful_transform", "StatefulTransform", - "center", "standardize", "scale", +__all__ = ["stateful_transform", "center", "standardize", "scale", ] def stateful_transform(class_): @@ -76,18 +75,8 @@ def stateful_transform_wrapper(*args, **kwargs): # class QuantileEstimatingTransform(NonIncrementalStatefulTransform): # def memorize_all(self, input_data, *args, **kwargs): - -class StatefulTransform(object): - def __getstate__(self): - return (0, self.__dict__) - - def __setstate__(self, pickle): - version, dicts = pickle - check_pickle_version(version, 0, name=self.__class__.__name__) - self.__dict__ = dicts - - -class Center(StatefulTransform): + +class Center(object): """center(x) A stateful transform that centers input data, i.e., subtracts the mean. @@ -127,7 +116,14 @@ def transform(self, x): centered = atleast_2d_column_default(x, preserve_pandas=True) - mean_val return pandas_friendly_reshape(centered, x.shape) - # __getstate__ = no_pickling + def __getstate__(self): + return {'version': 0, 'sum': self._sum, 'count': self._count} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self._sum = pickle['sum'] + self._count = pickle['count'] + center = stateful_transform(Center) center.__qualname__ = 'center' @@ -136,7 +132,7 @@ def transform(self, x): # See: # http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm # or page 232 of Knuth vol. 3 (3rd ed.). -class Standardize(StatefulTransform): +class Standardize(object): """standardize(x, center=True, rescale=True, ddof=0) A stateful transform that standardizes input data, i.e. it subtracts the @@ -187,7 +183,16 @@ def transform(self, x, center=True, rescale=True, ddof=0): x_2d /= np.sqrt(self.current_M2 / (self.current_n - ddof)) return pandas_friendly_reshape(x_2d, x.shape) - # __getstate__ = no_pickling + def __getstate__(self): + return {'version': 0, 'current_n': self.current_n, + 'current_mean': self.current_mean, + 'current_M2': self.current_M2} + + def __setstate__(self, pickle): + check_pickle_version(pickle['version'], 0, self.__class__.__name__) + self.current_M2 = pickle['current_M2'] + self.current_mean = pickle['current_mean'] + self.current_n = pickle['current_n'] standardize = stateful_transform(Standardize) diff --git a/patsy/test_pickling.py b/patsy/test_pickling.py index e00d081..fb34667 100644 --- a/patsy/test_pickling.py +++ b/patsy/test_pickling.py @@ -6,38 +6,39 @@ import os import shutil -from patsy import EvalFactor, EvalEnvironment, VarLookupDict +from patsy import EvalFactor, EvalEnvironment import numpy as np +from patsy.eval import VarLookupDict from patsy.state import center, scale, standardize from patsy.categorical import C from patsy.splines import bs -from patsy.desc import Term, ModelDesc +from patsy.desc import Term, ModelDesc, _MockFactor from patsy.mgcv_cubic_splines import cc, te, cr from patsy.contrasts import ContrastMatrix from patsy.constraint import LinearConstraint from patsy.missing import NAAction from patsy.origin import Origin +from patsy.design_info import SubtermInfo +from patsy.util import assert_pickled_equals -PICKE_TESTCASES_ROOTDIR = os.path.join(os.path.dirname(__file__), '..', 'pickle_testcases') +PICKLE_TESTCASES_ROOTDIR = os.path.join(os.path.dirname(__file__), '..', + 'pickle_testcases') +f1 = _MockFactor("a") +f2 = _MockFactor("b") -class _MockFactor(object): - def __init__(self, name): - self._name = name - - def name(self): - return self._name +cm = ContrastMatrix(np.ones((2, 2)), ["[1]", "[2]"]) +si = SubtermInfo(["a", "x"], {"a": cm}, 4) - def __eq__(self, other): - return self.__dict__ == other.__dict__ - def __hash__(self): - return hash((_MockFactor, str(self._name))) +def _unwrap_stateful_function(function, *args, **kwargs): + obj = function.__patsy_stateful_transform__() + obj.memorize_chunk(*args, **kwargs) + obj.memorize_finish() + return (obj, args, kwargs) -f1 = _MockFactor("a") -f2 = _MockFactor("b") pickling_testcases = { "evalfactor_simple": EvalFactor("a+b"), @@ -48,7 +49,7 @@ def __hash__(self): "evalenv_transform_standardize": EvalEnvironment([{ 'standardize': standardize }]), - "evalenv_transform_catgorical": EvalEnvironment([{'C': C}]), + "evalenv_transform_categorical": EvalEnvironment([{'C': C}]), "evalenv_transform_bs": EvalEnvironment([{'cs': bs}]), "evalenv_transform_te": EvalEnvironment([{'te': te}]), "evalenv_transform_cr": EvalEnvironment([{'cs': cr}]), @@ -56,25 +57,80 @@ def __hash__(self): "evalenv_pickle": EvalEnvironment([{'np': np}]), "term": Term([1, 2, 1]), "contrast_matrix": ContrastMatrix([[1, 0], [0, 1]], ["a", "b"]), + "subterm_info": si, "linear_constraint": LinearConstraint(["a"], [[0]]), "model_desc": ModelDesc([Term([]), Term([f1])], [Term([f1]), Term([f1, f2])]), "na_action": NAAction(NA_types=["NaN", "None"]), - "origin": Origin("012345", 2, 5) + "origin": Origin("012345", 2, 5), + "transform_center": _unwrap_stateful_function(center, + np.arange(10, 20, 0.1)), + "transform_standardize_norescale": _unwrap_stateful_function( + standardize, + np.arange(10, 20, 0.1), + ), + "transform_standardize_rescale": _unwrap_stateful_function( + standardize, + np.arange(10, 20, 0.1), + rescale=True + ), + "transform_bs_df3": _unwrap_stateful_function( + bs, + np.arange(10, 20, 0.1), + df=3 + ), + "transform_bs_knots_13_15_17": _unwrap_stateful_function( + bs, + np.arange(10, 20, 0.1), + knots=[13, 15, 17] + ), + "transform_cc_df3": _unwrap_stateful_function( + cc, + np.arange(10, 20, 0.1), + df=3 + ), + "transform_cc_knots_13_15_17": _unwrap_stateful_function( + cc, + np.arange(10, 20, 0.1), + knots=[13, 15, 17] + ), + "transform_cr_df3": _unwrap_stateful_function( + cr, + np.arange(10, 20, 0.1), + df=3 + ), + "transform_cr_knots_13_15_17": _unwrap_stateful_function( + cr, + np.arange(10, 20, 0.1), + knots=[13, 15, 17] + ), + "transform_te_cr5": _unwrap_stateful_function( + te, + cr(np.arange(10, 20, 0.1), df=5) + ), + "transform_te_cr5_center": _unwrap_stateful_function( + te, + cr(np.arange(10, 20, 0.1), df=5), + constraint='center' + ), } def test_pickling_same_version_roundtrips(): for obj in six.itervalues(pickling_testcases): - yield (check_pickling_same_version_roundtrips, obj) + if isinstance(obj, tuple): + yield (check_pickling_same_version_roundtrips, obj[0]) + else: + yield (check_pickling_same_version_roundtrips, obj) def check_pickling_same_version_roundtrips(obj): - assert obj == pickle.loads(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)) + pickled_obj = pickle.loads(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)) + assert_pickled_equals(obj, pickled_obj) def test_pickling_old_versions_still_work(): - for (dirpath, dirnames, filenames) in os.walk(PICKE_TESTCASES_ROOTDIR): + for (dirpath, dirnames, filenames) in os.walk(PICKLE_TESTCASES_ROOTDIR): for fname in filenames: if os.path.splitext(fname)[1] == '.pickle': yield check_pickling_old_versions_still_work, os.path.join(dirpath, fname) @@ -88,7 +144,26 @@ def check_pickling_old_versions_still_work(pickle_filename): # equal to any instance of a previous version. How do we handle # that? # Maybe adding a minimum version requirement to each test? - assert pickling_testcases[testcase_name] == pickle.load(f) + obj = pickling_testcases[testcase_name] + if isinstance(obj, tuple): + assert_pickled_equals(pickling_testcases[testcase_name][0], + pickle.load(f)) + else: + assert_pickled_equals(pickling_testcases[testcase_name], + pickle.load(f)) + + +def test_pickling_transforms(): + for obj in six.itervalues(pickling_testcases): + if isinstance(obj, tuple): + obj, args, kwargs = obj + yield (check_pickling_transforms, obj, args, kwargs) + + +def check_pickling_transforms(obj, args, kwargs): + pickled_obj = pickle.loads(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)) + np.testing.assert_allclose(obj.transform(*args, **kwargs), + pickled_obj.transform(*args, **kwargs)) def test_unpickling_future_gives_sensible_error_msg(): @@ -102,7 +177,7 @@ def create_pickles(version): # TODO Add safety check that said force=True option will still give an # error when trying to remove pickles for a released version, by # comparing the version argument here with patsy.__version__. - pickle_testcases_dir = os.path.join(PICKE_TESTCASES_ROOTDIR, version) + pickle_testcases_dir = os.path.join(PICKLE_TESTCASES_ROOTDIR, version) if os.path.exists(pickle_testcases_dir): raise OSError("{} already exists. Aborting.".format(pickle_testcases_dir)) pickle_testcases_tempdir = pickle_testcases_dir + "_inprogress" @@ -110,6 +185,8 @@ def create_pickles(version): try: for name, obj in six.iteritems(pickling_testcases): + if isinstance(obj, tuple): + obj = obj[0] with open(os.path.join(pickle_testcases_tempdir, "{}.pickle".format(name)), "wb") as f: pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) except Exception: diff --git a/patsy/util.py b/patsy/util.py index 8834ee4..a12b47e 100644 --- a/patsy/util.py +++ b/patsy/util.py @@ -20,7 +20,8 @@ "safe_issubdtype", "no_pickling", "assert_no_pickling", - "safe_string_eq", + "assert_pickled_equals", + "safe_string_eq" ] import sys @@ -28,7 +29,7 @@ import numpy as np import six from six.moves import cStringIO as StringIO -from .compat import optional_dep_ok +from patsy.compat import optional_dep_ok try: import pandas @@ -746,9 +747,125 @@ def check_pickle_version(version, required_version, name=""): error_msg += "." # TODO Use a better exception than ValueError. - raise ValueError(error_msg) + raise PickleError(error_msg) def test_check_pickle_version(): - assert_raises(ValueError, check_pickle_version, 0, 1) - assert_raises(ValueError, check_pickle_version, 1, 0) + assert_raises(PickleError, check_pickle_version, 0, 1) + assert_raises(PickleError, check_pickle_version, 1, 0) check_pickle_version(0, 0) + + +def assert_pickled_equals(obj1, obj2): + def _walk_dict(obj): + _dict = {key: obj.__dict__[key] for key in six.iterkeys(obj.__dict__)} + for key in six.iterkeys(_dict): + if isinstance(_dict[key], dict): + newdict = {} + for key2 in six.iterkeys(_dict[key]): + if hasattr(key2, '__dict__'): + newkey = [str(key2.__dict__[i])[:6] for i in + sorted(six.iterkeys(key2.__dict__))] + newkey = str(newkey) + newdict[newkey] = _dict[key][key2] + else: + newdict[key2] = _dict[key][key2] + for key2 in six.iterkeys(newdict): + if isinstance(newdict[key2], (list, tuple)): + newdict[key2] = [_walk_dict(i) if + hasattr(i, '__dict__') else i for i in + newdict[key2]] + _dict[key] = {key2: _walk_dict(newdict[key2]) if + hasattr(newdict[key2], '__dict__') else + newdict[key2] for key2 in six.iterkeys(newdict)} + elif hasattr(_dict[key], '__dict__'): + _dict[key] = _walk_dict(_dict[key]) + if isinstance(_dict[key], (list, tuple)): + _dict[key] = [_walk_dict(i) if hasattr(i, '__dict__') else i + for i in _dict[key]] + return _dict + + if hasattr(obj1, '__dict__') and hasattr(obj2, '__dict__'): + obj1 = _walk_dict(obj1) + obj2 = _walk_dict(obj2) + + def _walk_dict_numpy_equals(_obj1, _obj2): + for key in six.iterkeys(_obj1): + if isinstance(_obj1[key], np.ndarray): + np.testing.assert_allclose(_obj1[key], _obj2[key]) + elif isinstance(_obj1[key], dict): + _walk_dict_numpy_equals(_obj1[key], _obj2[key]) + + def _walk_dict_remove_numpy(_obj1): + for key in six.iterkeys(_obj1): + if isinstance(_obj1[key], np.ndarray): + _obj1[key] = 0 + elif isinstance(_obj1[key], dict): + _walk_dict_remove_numpy(_obj1[key]) + + _walk_dict_numpy_equals(obj1, obj2) + _walk_dict_numpy_equals(obj2, obj1) + _walk_dict_remove_numpy(obj1) + _walk_dict_remove_numpy(obj2) + + assert obj1 == obj2 + + +def test_assert_pickled_equals(): + class _MockObject(object): + def __init__(self, foo): + self.foo = foo + + obj1 = _MockObject('bar') + obj2 = _MockObject('bar') + + assert_pickled_equals(obj1, obj2) + + obj3 = _MockObject('baz') + + assert_raises(AssertionError, assert_pickled_equals, obj1, obj3) + + obj4 = _MockObject(obj1) + obj5 = _MockObject(obj2) + + assert_pickled_equals(obj4, obj5) + + obj6 = _MockObject(_MockObject(np.array([[1, 2], [3, 4]]))) + obj7 = _MockObject(_MockObject(np.array([[1, 2], [3, 4]]))) + + assert_pickled_equals(obj6, obj7) + + obj8 = _MockObject(_MockObject(np.array([[1, 2], [3, 5]]))) + + assert_raises(AssertionError, assert_pickled_equals, obj6, obj8) + + obj9 = _MockObject([_MockObject('a'), _MockObject('b')]) + obj10 = _MockObject([_MockObject('a'), _MockObject('b')]) + + assert_pickled_equals(obj9, obj10) + + obj11 = _MockObject({_MockObject('a'): _MockObject('c')}) + obj12 = _MockObject({_MockObject('a'): _MockObject('c')}) + + assert_pickled_equals(obj11, obj12) + + obj13 = _MockObject({_MockObject('a'): _MockObject('d')}) + + assert_raises(AssertionError, assert_pickled_equals, obj11, obj13) + + +class PickleError(Exception): + """This is the error type for pickle problems. + + For ordinary display to the user with default formatting, use + ``str(exc)``. If you want to do something cleverer, you can use the + ``.message`` attribute directly. + """ + def __init__(self, message): + Exception.__init__(self, message) + self.message = message + + def __str__(self): + if self.origin is None: + return self.message + else: + return self.message diff --git a/pickle_testcases/0.5/contrast_matrix.pickle b/pickle_testcases/0.5/contrast_matrix.pickle new file mode 100644 index 0000000..0d63acd Binary files /dev/null and b/pickle_testcases/0.5/contrast_matrix.pickle differ diff --git a/pickle_testcases/0.5/evalenv_pickle.pickle b/pickle_testcases/0.5/evalenv_pickle.pickle new file mode 100644 index 0000000..8e62cdc Binary files /dev/null and b/pickle_testcases/0.5/evalenv_pickle.pickle differ diff --git a/pickle_testcases/0.5/evalenv_simple.pickle b/pickle_testcases/0.5/evalenv_simple.pickle new file mode 100644 index 0000000..7b8258a Binary files /dev/null and b/pickle_testcases/0.5/evalenv_simple.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_bs.pickle b/pickle_testcases/0.5/evalenv_transform_bs.pickle new file mode 100644 index 0000000..547608f Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_bs.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_categorical.pickle b/pickle_testcases/0.5/evalenv_transform_categorical.pickle new file mode 100644 index 0000000..33d403c Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_categorical.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_cc.pickle b/pickle_testcases/0.5/evalenv_transform_cc.pickle new file mode 100644 index 0000000..217977d Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_cc.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_center.pickle b/pickle_testcases/0.5/evalenv_transform_center.pickle new file mode 100644 index 0000000..60416ed Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_center.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_cr.pickle b/pickle_testcases/0.5/evalenv_transform_cr.pickle new file mode 100644 index 0000000..550190c Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_cr.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_scale.pickle b/pickle_testcases/0.5/evalenv_transform_scale.pickle new file mode 100644 index 0000000..bd2ea17 Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_scale.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_standardize.pickle b/pickle_testcases/0.5/evalenv_transform_standardize.pickle new file mode 100644 index 0000000..95f9030 Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_standardize.pickle differ diff --git a/pickle_testcases/0.5/evalenv_transform_te.pickle b/pickle_testcases/0.5/evalenv_transform_te.pickle new file mode 100644 index 0000000..bfece81 Binary files /dev/null and b/pickle_testcases/0.5/evalenv_transform_te.pickle differ diff --git a/pickle_testcases/0.5/evalfactor_simple.pickle b/pickle_testcases/0.5/evalfactor_simple.pickle index 1be6d51..e7d42fa 100644 Binary files a/pickle_testcases/0.5/evalfactor_simple.pickle and b/pickle_testcases/0.5/evalfactor_simple.pickle differ diff --git a/pickle_testcases/0.5/linear_constraint.pickle b/pickle_testcases/0.5/linear_constraint.pickle new file mode 100644 index 0000000..957aa39 Binary files /dev/null and b/pickle_testcases/0.5/linear_constraint.pickle differ diff --git a/pickle_testcases/0.5/model_desc.pickle b/pickle_testcases/0.5/model_desc.pickle new file mode 100644 index 0000000..5286dde Binary files /dev/null and b/pickle_testcases/0.5/model_desc.pickle differ diff --git a/pickle_testcases/0.5/na_action.pickle b/pickle_testcases/0.5/na_action.pickle new file mode 100644 index 0000000..dc5b370 Binary files /dev/null and b/pickle_testcases/0.5/na_action.pickle differ diff --git a/pickle_testcases/0.5/origin.pickle b/pickle_testcases/0.5/origin.pickle new file mode 100644 index 0000000..4759d5e Binary files /dev/null and b/pickle_testcases/0.5/origin.pickle differ diff --git a/pickle_testcases/0.5/subterm_info.pickle b/pickle_testcases/0.5/subterm_info.pickle new file mode 100644 index 0000000..3b301ff Binary files /dev/null and b/pickle_testcases/0.5/subterm_info.pickle differ diff --git a/pickle_testcases/0.5/term.pickle b/pickle_testcases/0.5/term.pickle new file mode 100644 index 0000000..3c4e48e Binary files /dev/null and b/pickle_testcases/0.5/term.pickle differ diff --git a/pickle_testcases/0.5/transform_bs_df3.pickle b/pickle_testcases/0.5/transform_bs_df3.pickle new file mode 100644 index 0000000..a9d9c35 Binary files /dev/null and b/pickle_testcases/0.5/transform_bs_df3.pickle differ diff --git a/pickle_testcases/0.5/transform_bs_knots_13_15_17.pickle b/pickle_testcases/0.5/transform_bs_knots_13_15_17.pickle new file mode 100644 index 0000000..2125231 Binary files /dev/null and b/pickle_testcases/0.5/transform_bs_knots_13_15_17.pickle differ diff --git a/pickle_testcases/0.5/transform_cc_df3.pickle b/pickle_testcases/0.5/transform_cc_df3.pickle new file mode 100644 index 0000000..064b17b Binary files /dev/null and b/pickle_testcases/0.5/transform_cc_df3.pickle differ diff --git a/pickle_testcases/0.5/transform_cc_knots_13_15_17.pickle b/pickle_testcases/0.5/transform_cc_knots_13_15_17.pickle new file mode 100644 index 0000000..15af146 Binary files /dev/null and b/pickle_testcases/0.5/transform_cc_knots_13_15_17.pickle differ diff --git a/pickle_testcases/0.5/transform_center.pickle b/pickle_testcases/0.5/transform_center.pickle new file mode 100644 index 0000000..f98f8d2 Binary files /dev/null and b/pickle_testcases/0.5/transform_center.pickle differ diff --git a/pickle_testcases/0.5/transform_cr_df3.pickle b/pickle_testcases/0.5/transform_cr_df3.pickle new file mode 100644 index 0000000..3f7e1d3 Binary files /dev/null and b/pickle_testcases/0.5/transform_cr_df3.pickle differ diff --git a/pickle_testcases/0.5/transform_cr_knots_13_15_17.pickle b/pickle_testcases/0.5/transform_cr_knots_13_15_17.pickle new file mode 100644 index 0000000..512b8ab Binary files /dev/null and b/pickle_testcases/0.5/transform_cr_knots_13_15_17.pickle differ diff --git a/pickle_testcases/0.5/transform_standardize_norescale.pickle b/pickle_testcases/0.5/transform_standardize_norescale.pickle new file mode 100644 index 0000000..7218243 Binary files /dev/null and b/pickle_testcases/0.5/transform_standardize_norescale.pickle differ diff --git a/pickle_testcases/0.5/transform_standardize_rescale.pickle b/pickle_testcases/0.5/transform_standardize_rescale.pickle new file mode 100644 index 0000000..7218243 Binary files /dev/null and b/pickle_testcases/0.5/transform_standardize_rescale.pickle differ diff --git a/pickle_testcases/0.5/transform_te_cr5.pickle b/pickle_testcases/0.5/transform_te_cr5.pickle new file mode 100644 index 0000000..3c641bc Binary files /dev/null and b/pickle_testcases/0.5/transform_te_cr5.pickle differ diff --git a/pickle_testcases/0.5/transform_te_cr5_center.pickle b/pickle_testcases/0.5/transform_te_cr5_center.pickle new file mode 100644 index 0000000..3c641bc Binary files /dev/null and b/pickle_testcases/0.5/transform_te_cr5_center.pickle differ diff --git a/pickle_testcases/0.5/varlookupdict_simple.pickle b/pickle_testcases/0.5/varlookupdict_simple.pickle new file mode 100644 index 0000000..853cfcd Binary files /dev/null and b/pickle_testcases/0.5/varlookupdict_simple.pickle differ