diff --git a/patsy/design_info.py b/patsy/design_info.py index f4b5822..83f725a 100644 --- a/patsy/design_info.py +++ b/patsy/design_info.py @@ -36,6 +36,7 @@ from patsy.constraint import linear_constraint from patsy.contrasts import ContrastMatrix from patsy.desc import ModelDesc, Term +from collections import OrderedDict class FactorInfo(object): """A FactorInfo object is a simple class that provides some metadata about @@ -684,6 +685,49 @@ def var_names(self, eval_env=0): else: return {} + def partial(self, columns, product=False): + """Returns a partial prediction array where only the variables in the + dict ``columns`` are tranformed per the :class:`DesignInfo` + transformations. The terms that are not influenced by ``columns`` + return as zero. + + This is useful to perform a partial prediction on unseen data and to + view marginal differences in factors. + + :arg columns: A dict with the keys as the column names for the marginal + predictions desired and values as the marginal values to be predicted. + + :arg product: When `True`, the resturned numpy array represents the + Cartesian product of the values ``columns``. + + :returns: A numpy array of the partial design matrix. + """ + from .highlevel import dmatrix + if product: + columns = _column_product(columns) + rows = None + for col in columns: + if rows and rows != len(columns[col]): + raise ValueError('all columns must be of same length') + rows = len(columns[col]) + parts = [] + for term, subterm in six.iteritems(self.term_codings): + term_vars = term.var_names() + present = True + for term_var in term_vars: + if term_var not in columns: + present = False + if present and (term.name() != 'Intercept'): + # This seems like an inelegent way to not having the Intercept + # in the output + di = self.subset('0 + {}'.format(term.name())) + parts.append(dmatrix(di, columns)) + else: + num_columns = np.sum(s.num_columns for s in subterm) + dm = np.zeros((rows, num_columns)) + parts.append(dm) + return np.hstack(parts) + @classmethod def from_array(cls, array_like, default_column_prefix="column"): """Find or construct a DesignInfo appropriate for a given array_like. @@ -1230,3 +1274,61 @@ def test_design_matrix(): repr(DesignMatrix(np.zeros((1, 0)))) repr(DesignMatrix(np.zeros((0, 1)))) repr(DesignMatrix(np.zeros((0, 0)))) + + +def test_DesignInfo_partial(): + from .highlevel import dmatrix + from numpy.testing import assert_allclose + a = np.array(['a', 'b', 'a', 'b', 'a', 'a', 'b', 'a']) + b = np.array([1, 3, 2, 4, 1, 3, 1, 1]) + c = np.array([4, 3, 2, 1, 6, 4, 2, 1]) + dm = dmatrix('a + bs(b, df=3, degree=3) + np.log(c)') + x = np.zeros((3, 6)) + x[1, 1] = 1 + y = dm.design_info.partial({'a': ['a', 'b', 'a']}) + assert_allclose(x, y) + + x = np.zeros((2, 6)) + x[1, 1] = 1 + x[1, 5] = np.log(3) + y = dm.design_info.partial({'a': ['a', 'b'], 'c': [1, 3]}) + assert_allclose(x, y) + + x = np.zeros((4, 6)) + x[2, 1] = 1 + x[3, 1] = 1 + x[1, 5] = np.log(3) + x[3, 5] = np.log(3) + y = dm.design_info.partial({'a': ['a', 'b'], 'c': [1, 3]}, product=True) + assert_allclose(x, y) + + dm = dmatrix('a * b') + y = dm.design_info.partial({'a': ['a', 'b'], 'b': [1, 3]}) + x = np.array([[0, 0, 1, 0], [0, 1, 3, 3]]) + assert_allclose(x, y) + + from nose.tools import assert_raises + assert_raises(ValueError, dm.design_info.partial, {'a': ['a', 'b'], + 'b': [1, 2, 3]}) + + +def _column_product(columns): + from itertools import product + cols = [] + values = [] + for col, value in six.iteritems(columns): + cols.append(col) + values.append(value) + values = [value for value in product(*values)] + values = [value for value in zip(*values)] + return OrderedDict([(col, list(value)) + for col, value in zip(cols, values)]) + + +def test_column_product(): + x = OrderedDict([('a', [1, 2, 3]), ('b', ['a', 'b'])]) + y = OrderedDict([('a', [1, 1, 2, 2, 3, 3]), + ('b', ['a', 'b', 'a', 'b', 'a', 'b'])]) + x = _column_product(x) + assert x['a'] == y['a'] + assert x['b'] == y['b']