-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
5 changed files
with
286 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
# This file is part of Patsy | ||
# Copyright (C) 2012-2013 Nathaniel Smith <[email protected]> | ||
# See file LICENSE.txt for license information. | ||
|
||
# R-compatible poly function | ||
|
||
# These are made available in the patsy.* namespace | ||
__all__ = ["poly"] | ||
|
||
import numpy as np | ||
|
||
from patsy.util import have_pandas, no_pickling, assert_no_pickling | ||
from patsy.state import stateful_transform | ||
|
||
if have_pandas: | ||
import pandas | ||
|
||
class Poly(object): | ||
"""poly(x, degree=1, raw=False) | ||
Generates an orthogonal polynomial transformation of x of degree. | ||
Generic usage is something along the lines of:: | ||
y ~ 1 + poly(x, 4) | ||
to fit ``y`` as a function of ``x``, with a 4th degree polynomial. | ||
:arg degree: The number of degrees for the polynomial expansion. | ||
:arg raw: When raw is False (the default), will return orthogonal | ||
polynomials. | ||
.. versionadded:: 0.4.1 | ||
""" | ||
def __init__(self): | ||
self._tmp = {} | ||
self._degree = None | ||
self._raw = None | ||
|
||
def memorize_chunk(self, x, degree=3, raw=False): | ||
args = {"degree": degree, | ||
"raw": raw | ||
} | ||
self._tmp["args"] = args | ||
# XX: check whether we need x values before saving them | ||
x = np.atleast_1d(x) | ||
if x.ndim == 2 and x.shape[1] == 1: | ||
x = x[:, 0] | ||
if x.ndim > 1: | ||
raise ValueError("input to 'poly' must be 1-d, " | ||
"or a 2-d column vector") | ||
# There's no better way to compute exact quantiles than memorizing | ||
# all data. | ||
x = np.array(x, dtype=float) | ||
self._tmp.setdefault("xs", []).append(x) | ||
|
||
def memorize_finish(self): | ||
tmp = self._tmp | ||
args = tmp["args"] | ||
del self._tmp | ||
|
||
if args["degree"] < 1: | ||
raise ValueError("degree must be greater than 0 (not %r)" | ||
% (args["degree"],)) | ||
if int(args["degree"]) != args["degree"]: | ||
raise ValueError("degree must be an integer (not %r)" | ||
% (self._degree,)) | ||
|
||
# These are guaranteed to all be 1d vectors by the code above | ||
scores = np.concatenate(tmp["xs"]) | ||
scores_mean = scores.mean() | ||
# scores -= scores_mean | ||
self.scores_mean = scores_mean | ||
n = args['degree'] | ||
self.degree = n | ||
raw_poly = scores.reshape((-1, 1)) ** np.arange(n + 1).reshape((1, -1)) | ||
raw = args['raw'] | ||
self.raw = raw | ||
if not raw: | ||
q, r = np.linalg.qr(raw_poly) | ||
# Q is now orthognoal of degree n. To match what R is doing, we | ||
# need to use the three-term recurrence technique to calculate | ||
# new alpha, beta, and norm. | ||
|
||
self.alpha = (np.sum(scores.reshape((-1, 1)) * q[:, :n] ** 2, | ||
axis=0) / | ||
np.sum(q[:, :n] ** 2, axis=0)) | ||
|
||
# For reasons I don't understand, the norms R uses are based off | ||
# of the diagonal of the r upper triangular matrix. | ||
|
||
self.norm = np.linalg.norm(q * np.diag(r), axis=0) | ||
self.beta = (self.norm[1:] / self.norm[:n]) ** 2 | ||
|
||
def transform(self, x, degree=3, raw=False): | ||
if have_pandas: | ||
if isinstance(x, (pandas.Series, pandas.DataFrame)): | ||
to_pandas = True | ||
idx = x.index | ||
else: | ||
to_pandas = False | ||
else: | ||
to_pandas = False | ||
x = np.array(x, ndmin=1).flatten() | ||
|
||
if self.raw: | ||
n = self.degree | ||
p = x.reshape((-1, 1)) ** np.arange(n + 1).reshape((1, -1)) | ||
else: | ||
# This is where the three-term recurrance technique is unwound. | ||
|
||
p = np.empty((x.shape[0], self.degree + 1)) | ||
p[:, 0] = 1 | ||
|
||
for i in np.arange(self.degree): | ||
p[:, i + 1] = (x - self.alpha[i]) * p[:, i] | ||
if i > 0: | ||
p[:, i + 1] = (p[:, i + 1] - | ||
(self.beta[i - 1] * p[:, i - 1])) | ||
p /= self.norm | ||
|
||
p = p[:, 1:] | ||
if to_pandas: | ||
p = pandas.DataFrame(p) | ||
p.index = idx | ||
return p | ||
|
||
__getstate__ = no_pickling | ||
|
||
poly = stateful_transform(Poly) | ||
|
||
|
||
def test_poly_compat(): | ||
from patsy.test_state import check_stateful | ||
from patsy.test_poly_data import (R_poly_test_x, | ||
R_poly_test_data, | ||
R_poly_num_tests) | ||
lines = R_poly_test_data.split("\n") | ||
tests_ran = 0 | ||
start_idx = lines.index("--BEGIN TEST CASE--") | ||
while True: | ||
if not lines[start_idx] == "--BEGIN TEST CASE--": | ||
break | ||
start_idx += 1 | ||
stop_idx = lines.index("--END TEST CASE--", start_idx) | ||
block = lines[start_idx:stop_idx] | ||
test_data = {} | ||
for line in block: | ||
key, value = line.split("=", 1) | ||
test_data[key] = value | ||
# Translate the R output into Python calling conventions | ||
kwargs = { | ||
# integer | ||
"degree": int(test_data["degree"]), | ||
# boolen | ||
"raw": (test_data["raw"] == 'TRUE') | ||
} | ||
# Special case: in R, setting intercept=TRUE increases the effective | ||
# dof by 1. Adjust our arguments to match. | ||
# if kwargs["df"] is not None and kwargs["include_intercept"]: | ||
# kwargs["df"] += 1 | ||
output = np.asarray(eval(test_data["output"])) | ||
# Do the actual test | ||
check_stateful(Poly, False, R_poly_test_x, output, **kwargs) | ||
tests_ran += 1 | ||
# Set up for the next one | ||
start_idx = stop_idx + 1 | ||
assert tests_ran == R_poly_num_tests | ||
|
||
|
||
def test_poly_errors(): | ||
from nose.tools import assert_raises | ||
x = np.arange(27) | ||
# Invalid input shape | ||
assert_raises(ValueError, poly, x.reshape((3, 3, 3))) | ||
assert_raises(ValueError, poly, x.reshape((3, 3, 3)), raw=True) | ||
# Invalid degree | ||
assert_raises(ValueError, poly, x, degree=-1) | ||
assert_raises(ValueError, poly, x, degree=0) | ||
assert_raises(ValueError, poly, x, degree=3.5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# This file auto-generated by tools/get-R-bs-test-vectors.R | ||
# Using: R version 3.2.4 Revised (2016-03-16 r70336) | ||
import numpy as np | ||
R_poly_test_x = np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ]) | ||
R_poly_test_data = """ | ||
--BEGIN TEST CASE-- | ||
degree=1 | ||
raw=TRUE | ||
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, ]).reshape((20, 1, ), order="F") | ||
--END TEST CASE-- | ||
--BEGIN TEST CASE-- | ||
degree=1 | ||
raw=FALSE | ||
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, ]).reshape((20, 1, ), order="F") | ||
--END TEST CASE-- | ||
--BEGIN TEST CASE-- | ||
degree=3 | ||
raw=TRUE | ||
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, ]).reshape((20, 3, ), order="F") | ||
--END TEST CASE-- | ||
--BEGIN TEST CASE-- | ||
degree=3 | ||
raw=FALSE | ||
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, ]).reshape((20, 3, ), order="F") | ||
--END TEST CASE-- | ||
--BEGIN TEST CASE-- | ||
degree=5 | ||
raw=TRUE | ||
output=np.array([1, 1.5, 2.25, 3.375, 5.0625, 7.59375, 11.390625, 17.0859375, 25.62890625, 38.443359375, 57.6650390625, 86.49755859375, 129.746337890625, 194.6195068359375, 291.92926025390625, 437.89389038085938, 656.84083557128906, 985.26125335693359, 1477.8918800354004, 2216.8378200531006, 1, 2.25, 5.0625, 11.390625, 25.62890625, 57.6650390625, 129.746337890625, 291.92926025390625, 656.84083557128906, 1477.8918800354004, 3325.2567300796509, 7481.8276426792145, 16834.112196028233, 37876.752441063523, 85222.692992392927, 191751.05923288409, 431439.8832739892, 970739.73736647563, 2184164.4090745705, 4914369.9204177829, 1, 3.375, 11.390625, 38.443359375, 129.746337890625, 437.89389038085938, 1477.8918800354004, 4987.8850951194763, 16834.112196028233, 56815.128661595285, 191751.05923288409, 647159.82491098379, 2184164.4090745705, 7371554.8806266747, 24878997.722115029, 83966617.312138215, 283387333.4284665, 956432250.32107437, 3227958844.8336263, 10894361101.313488, 1, 5.0625, 25.62890625, 129.746337890625, 656.84083557128906, 3325.2567300796509, 16834.112196028233, 85222.692992392927, 431439.8832739892, 2184164.4090745705, 11057332.320940012, 55977744.87475881, 283387333.4284665, 1434648375.4816115, 7262907400.875659, 36768468716.933022, 186140372879.47342, 942335637702.33411, 4770574165868.0674, 24151031714707.086, 1, 7.59375, 57.6650390625, 437.89389038085938, 3325.2567300796509, 25251.168294042349, 191751.05923288409, 1456109.6060497134, 11057332.320940012, 83966617.31213823, 637621500.21404958, 4841938267.2504387, 36768468716.933022, 279210559319.21014, 2120255184830.252, 16100687809804.727, 122264598055704.64, 928446791485507, 7050392822843070, 53538920498464552, ]).reshape((20, 5, ), order="F") | ||
--END TEST CASE-- | ||
--BEGIN TEST CASE-- | ||
degree=5 | ||
raw=FALSE | ||
output=np.array([-0.12865949508274149, -0.12846539500908838, -0.12817424489860868, -0.12773751973288924, -0.12708243198431005, -0.12609980036144131, -0.12462585292713815, -0.12241493177568342, -0.11909855004850137, -0.11412397745772825, -0.10666211857156857, -0.095469330242329037, -0.07868014774846975, -0.053496374007680828, -0.015720713396497447, 0.040942777520277619, 0.12593801389544024, 0.25343086845818413, 0.4446701503023, 0.73152907306847381, 0.11682670564764953, 0.11622774987820758, 0.1153299112445243, 0.11398449209008393, 0.11196937564961051, 0.10895347864407183, 0.10444488285989936, 0.097716301062945765, 0.087700630095951776, 0.072850827534442664, 0.05096695744238839, 0.019020528242278005, -0.026920519697452645, -0.091380250921070119, -0.1780532062130448, -0.28552519567824058, -0.39602393206231051, -0.44767622905753701, -0.26843910749340033, 0.57802660073100254, -0.11560888340228653, -0.11436481217184656, -0.11250218782662975, -0.10971608089390825, -0.10555451667328646, -0.099351692934324679, -0.090136150925525155, -0.076511614727544461, -0.05651941299388475, -0.027522538371457093, 0.013772191900731716, 0.070864671671751547, 0.14593497036033168, 0.23591981919395397, 0.32391016867398448, 0.36336942185480259, 0.25890497941187346, -0.11572025100301592, -0.66076386903314166, 0.27159578788942196, 0.11925766326375063, 0.11701962699862156, 0.11367531238125347, 0.10868744714732725, 0.10126981942884175, 0.090287103769210786, 0.074134201646975206, 0.050620044131431986, 0.016933017097416861, -0.030116712154368355, -0.093138533517390085, -0.17160263551697441, -0.25618209006285081, -0.3183631162695052, -0.29707753517866498, -0.10102478727647804, 0.30185248746535442, 0.55289166632880227, -0.46108564710186972, 0.081962667419115426, -0.12626707822019206, -0.12250155553682644, -0.11689136915447108, -0.10856147160045609, -0.096257598068575617, -0.078227654788373013, -0.052128116579684983, -0.015063001240831148, 0.035988153544508683, 0.10280803884977513, 0.18263307034840112, 0.26144732880503613, 0.30325203347309243, 0.24116709207723347, -0.00082575540196283526, -0.37830141983168153, -0.42887161757203512, 0.55207091753656046, -0.17171017635275559, 0.016240179713238136, ]).reshape((20, 5, ), order="F") | ||
--END TEST CASE-- | ||
""" | ||
R_poly_num_tests = 6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
cat("# This file auto-generated by tools/get-R-bs-test-vectors.R\n") | ||
cat(sprintf("# Using: %s\n", R.Version()$version.string)) | ||
cat("import numpy as np\n") | ||
|
||
options(digits=20) | ||
library(splines) | ||
x <- (1.5)^(0:19) | ||
|
||
MISSING <- "MISSING" | ||
|
||
is.missing <- function(obj) { | ||
length(obj) == 1 && obj == MISSING | ||
} | ||
|
||
pyprint <- function(arr) { | ||
if (is.missing(arr)) { | ||
cat("None\n") | ||
} else { | ||
cat("np.array([") | ||
for (val in arr) { | ||
cat(val) | ||
cat(", ") | ||
} | ||
cat("])") | ||
if (!is.null(dim(arr))) { | ||
cat(".reshape((") | ||
for (size in dim(arr)) { | ||
cat(sprintf("%s, ", size)) | ||
} | ||
cat("), order=\"F\")") | ||
} | ||
cat("\n") | ||
} | ||
} | ||
|
||
num.tests <- 0 | ||
dump.poly <- function(degree, raw) { | ||
cat("--BEGIN TEST CASE--\n") | ||
cat(sprintf("degree=%s\n", degree)) | ||
cat(sprintf("raw=%s\n", raw)) | ||
|
||
args <- list(x=x, degree=degree, raw=raw) | ||
|
||
result <- do.call(poly, args) | ||
|
||
cat("output=") | ||
pyprint(result) | ||
cat("--END TEST CASE--\n") | ||
assign("num.tests", num.tests + 1, envir=.GlobalEnv) | ||
} | ||
|
||
cat("R_poly_test_x = ") | ||
pyprint(x) | ||
cat("R_poly_test_data = \"\"\"\n") | ||
|
||
for (degree in c(1, 3, 5)) { | ||
for (raw in c(TRUE, FALSE)) { | ||
dump.poly(degree, raw) | ||
} | ||
} | ||
cat("\"\"\"\n") | ||
cat(sprintf("R_poly_num_tests = %s\n", num.tests)) |