Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Generate var_names from the data and partial predict #98

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions patsy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def test__subterm_column_names_iter_and__build_subterm():
mat3)
assert np.allclose(mat3, 1)

def _factors_memorize(factors, data_iter_maker, eval_env):
def _factors_memorize(factors, data_iter_maker, eval_env, var_names):
# First, start off the memorization process by setting up each factor's
# state and finding out how many passes it will need:
factor_states = {}
Expand All @@ -362,7 +362,7 @@ def _factors_memorize(factors, data_iter_maker, eval_env):
memorize_needed.add(factor)
which_pass = 0
while memorize_needed:
for data in data_iter_maker():
for data in safe_data_maker(data_iter_maker, var_names):
for factor in memorize_needed:
state = factor_states[factor]
factor.memorize_chunk(state, which_pass, data)
Expand All @@ -373,6 +373,18 @@ def _factors_memorize(factors, data_iter_maker, eval_env):
which_pass += 1
return factor_states


def safe_data_maker(data_iter_maker, var_names):
    """Call `data_iter_maker`, passing `var_names` only if it accepts them.

    :arg data_iter_maker: A callable that returns an iterator over data
      chunks. It may take no arguments, or a single positional argument
      holding the list of variable names that will be needed.
    :arg var_names: An iterable of variable names; it is converted to a
      list before being handed to `data_iter_maker`.
    :returns: Whatever `data_iter_maker` returns.

    The decision is made by inspecting the callable's signature instead of
    calling it and catching :exc:`TypeError`. A blanket ``except TypeError``
    around the call would also swallow TypeErrors raised *inside*
    `data_iter_maker`, and would then silently invoke it a second time
    without arguments.
    """
    import inspect

    var_names = list(var_names)
    try:
        signature = inspect.signature(data_iter_maker)
    except (AttributeError, TypeError, ValueError):
        # AttributeError: inspect.signature does not exist (old Pythons).
        # TypeError/ValueError: the callable cannot be introspected.
        signature = None

    if signature is None:
        # No way to look before we leap; fall back to try-then-retry.
        try:
            return data_iter_maker(var_names)
        except TypeError:
            return data_iter_maker()

    try:
        signature.bind(var_names)
    except TypeError:
        # The callable does not take a single positional argument.
        return data_iter_maker()
    return data_iter_maker(var_names)


def test__factors_memorize():
class MockFactor(object):
def __init__(self, requested_passes, token):
Expand Down Expand Up @@ -408,7 +420,7 @@ def __call__(self):
f1 = MockFactor(1, "f1")
f2a = MockFactor(2, "f2a")
f2b = MockFactor(2, "f2b")
factor_states = _factors_memorize(set([f0, f1, f2a, f2b]), data, {})
factor_states = _factors_memorize(set([f0, f1, f2a, f2b]), data, {}, [])
assert data.calls == 2
mem_chunks0 = [("memorize_chunk", 0)] * data.CHUNKS
mem_chunks1 = [("memorize_chunk", 1)] * data.CHUNKS
Expand All @@ -434,11 +446,12 @@ def __call__(self):
}
assert factor_states == expected

def _examine_factor_types(factors, factor_states, data_iter_maker, NA_action):
def _examine_factor_types(factors, factor_states, data_iter_maker, NA_action,
var_names):
num_column_counts = {}
cat_sniffers = {}
examine_needed = set(factors)
for data in data_iter_maker():
for data in safe_data_maker(data_iter_maker, var_names):
for factor in list(examine_needed):
value = factor.eval(factor_states[factor], data)
if factor in cat_sniffers or guess_categorical(value):
Expand Down Expand Up @@ -519,9 +532,10 @@ def next(self):
}

it = DataIterMaker()
var_names = []
(num_column_counts, cat_levels_contrasts,
) = _examine_factor_types(factor_states.keys(), factor_states, it,
NAAction())
NAAction(), var_names)
assert it.i == 2
iterations = 0
assert num_column_counts == {num_1dim: 1, num_1col: 1, num_4col: 4}
Expand All @@ -537,7 +551,7 @@ def next(self):
no_read_necessary = [num_1dim, num_1col, num_4col, categ_1col, bool_1col]
(num_column_counts, cat_levels_contrasts,
) = _examine_factor_types(no_read_necessary, factor_states, it,
NAAction())
NAAction(), var_names)
assert it.i == 0
assert num_column_counts == {num_1dim: 1, num_1col: 1, num_4col: 4}
assert cat_levels_contrasts == {
Expand All @@ -562,7 +576,7 @@ def next(self):
it = DataIterMaker()
try:
_examine_factor_types([illegal_factor], illegal_factor_states, it,
NAAction())
NAAction(), var_names)
except PatsyError as e:
assert e.origin is illegal_factor.origin
else:
Expand Down Expand Up @@ -686,14 +700,18 @@ def design_matrix_builders(termlists, data_iter_maker, eval_env,
for termlist in termlists:
for term in termlist:
all_factors.update(term.factors)
factor_states = _factors_memorize(all_factors, data_iter_maker, eval_env)
var_names = {i for f in all_factors
for i in f.var_names(eval_env=eval_env)}
factor_states = _factors_memorize(all_factors, data_iter_maker, eval_env,
var_names)
# Now all the factors have working eval methods, so we can evaluate them
# on some data to find out what type of data they return.
(num_column_counts,
cat_levels_contrasts) = _examine_factor_types(all_factors,
factor_states,
data_iter_maker,
NA_action)
NA_action,
var_names)
# Now we need the factor infos, which encapsulate the knowledge of
# how to turn any given factor into a chunk of data:
factor_infos = {}
Expand Down
32 changes: 31 additions & 1 deletion patsy/desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,30 @@ def name(self):
else:
return "Intercept"

def var_names(self, eval_env=0):
    """Returns a set of variable names that are used in the :class:`Term`,
    but not available in the current evaluation environment. These are
    likely to be provided by data.

    :arg eval_env: Either a :class:`EvalEnvironment` which will be used to
      look up the variables referenced in the :class:`Term`, or else a
      depth represented as an integer which will be passed to
      :meth:`EvalEnvironment.capture`. ``eval_env=0`` means to use the
      context of the function calling :meth:`var_names` for lookups. If
      calling this function from a library, you probably want
      ``eval_env=1``, which means that variables should be resolved in
      *your* caller's namespace.

    :returns: A set of strings of the potential variable names. For a term
      with no factors an empty ``{}`` is returned.
    """
    # Always normalise through capture(): a truthiness test on eval_env
    # would only capture for depth 0 and hand any other integer depth
    # straight to the factors. capture() is documented to return an
    # existing EvalEnvironment unchanged, so this is safe for both kinds
    # of argument.
    eval_env = EvalEnvironment.capture(eval_env, reference=1)
    if not self.factors:
        # NOTE(review): ``{}`` is an empty dict, not an empty set; it is
        # kept because test_Term compares the result against ``{}``.
        return {}
    return {name
            for factor in self.factors
            for name in factor.var_names(eval_env=eval_env)}

__getstate__ = no_pickling

INTERCEPT = Term([])
Expand All @@ -76,6 +100,9 @@ def __init__(self, name):
def name(self):
return self._name

def var_names(self, eval_env=0):
    """Return a one-element set holding this mock factor's name with a
    ``_var`` suffix; `eval_env` is accepted but not used."""
    mock_name = '{}_var'.format(self._name)
    return {mock_name}

def test_Term():
assert Term([1, 2, 1]).factors == (1, 2)
assert Term([1, 2]) == Term([2, 1])
Expand All @@ -85,6 +112,9 @@ def test_Term():
assert Term([f1, f2]).name() == "a:b"
assert Term([f2, f1]).name() == "b:a"
assert Term([]).name() == "Intercept"
assert Term([f1]).var_names() == {'a_var'}
assert Term([f1, f2]).var_names() == {'a_var', 'b_var'}
assert Term([]).var_names() == {}

assert_no_pickling(Term([]))

Expand Down Expand Up @@ -148,7 +178,7 @@ def term_code(term):
if term != INTERCEPT]
result += " + ".join(term_names)
return result

@classmethod
def from_formula(cls, tree_or_string):
"""Construct a :class:`ModelDesc` from a formula string.
Expand Down
Loading