Skip to content

Commit 736fef7

Browse files
authored
Merge pull request #16 from EconForge/complete_ders
WIP: Add derivatives of complete polynomials
2 parents b98dc9a + 07408a8 commit 736fef7

File tree

2 files changed

+293
-3
lines changed

2 files changed

+293
-3
lines changed

interpolation/complete_poly.py

Lines changed: 265 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def n_complete(n, d):
7373
return out
7474

7575

76+
#
77+
# Complete Polynomials Basis
78+
#
7679
def complete_polynomial(z, d):
7780
"""
7881
Construct basis matrix for complete polynomial of degree `d`, given
@@ -111,7 +114,10 @@ def complete_polynomial(z, d):
111114
return out
112115

113116

114-
@jit(nopython=True, cache=True)
117+
# TODO: Currently turning off all cache arguments so that
118+
# code works. This will be fixed in numba 0.32
119+
# @jit(nopython=True, cache=True)
120+
@jit(nopython=True)
115121
def _complete_poly_impl_vec(z, d, out):
116122
"out and z should be vectors"
117123
nvar = z.shape[0]
@@ -185,7 +191,8 @@ def _complete_poly_impl_vec(z, d, out):
185191
return
186192

187193

188-
@jit(nopython=True, cache=True)
194+
# @jit(nopython=True, cache=True)
195+
@jit(nopython=True)
189196
def _complete_poly_impl(z, d, out):
190197
nvar = z.shape[0]
191198
nobs = z.shape[1]
@@ -274,6 +281,257 @@ def _complete_poly_impl(z, d, out):
274281
return
275282

276283

284+
#
285+
# Complete Polynomials Derivative Basis
286+
#
287+
def complete_polynomial_der(z, d, der):
288+
"""
289+
Construct basis matrix for complete polynomial of degree `d`, given
290+
input data `z`.
291+
292+
Parameters
293+
----------
294+
z : np.ndarray(size=(nvariables, nobservations))
295+
The degree 1 realization of each variable. For example, if
296+
variables are `q`, `r`, and `s`, then `z` should be
297+
`z = np.row_stack([q, r, s])`
298+
299+
d : int
300+
An integer specifying the degree of the complete polynomial
301+
302+
der : int
303+
An integer specifying which variable to take derivative wrt --
304+
a 0 means take derivative wrt first variable in z etc...
305+
306+
Returns
307+
-------
308+
out : np.ndarray(size=(ncomplete(nvariables, d), nobservations))
309+
The basis matrix for the derivative of a complete polynomial
310+
of degree d with respect to variable der
311+
312+
"""
313+
# check inputs
314+
assert d >= 0, "d must be non-negative"
315+
assert der >= 0, "derivative must be non-negative"
316+
z = np.asarray(z)
317+
318+
# compute inds allocate space for output
319+
nvar, nobs = z.shape
320+
assert der < nvar, "derivative integer must be smaller than nobs in z"
321+
out = np.zeros((n_complete(nvar, d), nobs))
322+
323+
if d > 5:
324+
raise ValueError("Complete polynomial only implemeted up to degree 5")
325+
326+
# populate out with jitted function
327+
_complete_poly_der_impl(z, d, der, out)
328+
329+
return out
330+
331+
332+
# @jit(nopython=True, cache=True)
333+
@jit(nopython=True)
334+
def _complete_poly_der_impl_vec(z, d, der, out):
335+
"out and z should be vectors"
336+
nvar = z.shape[0]
337+
338+
out[0] = 0.0
339+
340+
# fill first order stuff
341+
if d >= 1:
342+
# All linear terms except for one (the variable itself) are 0
343+
for i in range(nvar):
344+
out[i+1] = 0.0
345+
out[der+1] = 1.0
346+
347+
if d == 1:
348+
return
349+
350+
# now we need to fill in row nvar and beyond
351+
ix = nvar
352+
if d == 2:
353+
for i1 in range(nvar):
354+
# Get coefficients and values
355+
(c1, t1) = (1, 1.0) if i1==der else (0, z[i1])
356+
for i2 in range(i1, nvar):
357+
(c2, t2) = (c1+1, 1.0) if i2==der else (c1, z[i2])
358+
359+
# Update index and out
360+
ix += 1
361+
out[ix] = c2 * t1*t2 * z[der]**(c2-1) if c2>0 else 0.0
362+
363+
return
364+
365+
if d == 3:
366+
for i1 in range(nvar):
367+
# Get coefficients and values
368+
(c1, t1) = (1, 1.0) if i1==der else (0, z[i1])
369+
for i2 in range(i1, nvar):
370+
(c2, t2) = (c1+1, 1.0) if i2==der else (c1, z[i2])
371+
ix += 1
372+
out[ix] = c2 * t1*t2 * z[der]**(c2-1) if c2>0 else 0.0
373+
374+
for i3 in range(i2, nvar):
375+
(c3, t3) = (c2+1, 1.0) if i3==der else (c2, z[i3])
376+
ix += 1
377+
out[ix] = c3 * t1*t2*t3 * z[der]**(c3-1) if c3>0 else 0.0
378+
379+
return
380+
381+
if d == 4:
382+
for i1 in range(nvar):
383+
# Get coefficients and values
384+
(c1, t1) = (1, 1.0) if i1==der else (0, z[i1])
385+
for i2 in range(i1, nvar):
386+
(c2, t2) = (c1+1, 1.0) if i2==der else (c1, z[i2])
387+
ix += 1
388+
out[ix] = c2 * t1*t2* z[der]**(c2-1) if c2>0 else 0.0
389+
390+
for i3 in range(i2, nvar):
391+
(c3, t3) = (c2+1, 1.0) if i3==der else (c2, z[i3])
392+
ix += 1
393+
out[ix] = c3*t1*t2*t3*z[der]**(c3-1) if c3>0 else 0.0
394+
395+
for i4 in range(i3, nvar):
396+
(c4, t4) = (c3+1, 1.0) if i4==der else (c3, z[i4])
397+
ix += 1
398+
out[ix] = c4*t1*t2*t3*t4*z[der]**(c4-1) if c4>0 else 0.0
399+
400+
return
401+
402+
if d == 5:
403+
for i1 in range(nvar):
404+
# Get coefficients and values
405+
(c1, t1) = (1, 1.0) if i1==der else (0, z[i1])
406+
for i2 in range(i1, nvar):
407+
(c2, t2) = (c1+1, 1.0) if i2==der else (c1, z[i2])
408+
ix += 1
409+
out[ix] = c2 * t1*t2* z[der]**(c2-1) if c2>0 else 0.0
410+
411+
for i3 in range(i2, nvar):
412+
(c3, t3) = (c2+1, 1.0) if i3==der else (c2, z[i3])
413+
ix += 1
414+
out[ix] = c3 * t1*t2*t3* z[der]**(c3-1) if c3>0 else 0.0
415+
416+
for i4 in range(i3, nvar):
417+
(c4, t4) = (c3+1, 1.0) if i4==der else (c3, z[i4])
418+
ix += 1
419+
out[ix] = c4*t1*t2*t3*t4*z[der]**(c4-1) if c4>0 else 0.0
420+
421+
for i5 in range(i4, nvar):
422+
(c5, t5) = (c4+1, 1.0) if i5==der else (c4, z[i5])
423+
ix += 1
424+
out[ix] = c5*t1*t2*t3*t4*t5*z[der]**(c5-1) if c5>0 else 0.0
425+
426+
return
427+
428+
429+
# @jit(nopython=True, cache=True)
430+
@jit(nopython=True)
431+
def _complete_poly_der_impl(z, d, der, out):
432+
nvar = z.shape[0]
433+
nobs = z.shape[1]
434+
435+
for k in range(nobs):
436+
out[0, k] = 0.0
437+
438+
# fill first order stuff
439+
if d >= 1:
440+
# Make sure everything has zeros in it
441+
for i in range(nvar):
442+
for k in range(nobs):
443+
out[i+1, k] = 0.0
444+
445+
# Then place ones where they belong in variable
446+
for k in range(nobs):
447+
out[der+1, k] = 1.0
448+
449+
if d == 1:
450+
return
451+
452+
# now we need to fill in row nvar and beyond
453+
ix = nvar
454+
if d == 2:
455+
for i1 in range(nvar):
456+
for i2 in range(i1, nvar):
457+
ix += 1
458+
for k in range(nobs):
459+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
460+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
461+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
462+
463+
return
464+
465+
if d == 3:
466+
for i1 in range(nvar):
467+
for i2 in range(i1, nvar):
468+
ix += 1
469+
for k in range(nobs):
470+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
471+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
472+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
473+
474+
for i3 in range(i2, nvar):
475+
ix += 1
476+
for k in range(nobs):
477+
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
478+
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
479+
480+
return
481+
482+
if d == 4:
483+
for i1 in range(nvar):
484+
for i2 in range(i1, nvar):
485+
ix += 1
486+
for k in range(nobs):
487+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
488+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
489+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
490+
491+
for i3 in range(i2, nvar):
492+
ix += 1
493+
for k in range(nobs):
494+
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
495+
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
496+
497+
for i4 in range(i3, nvar):
498+
ix += 1
499+
for k in range(nobs):
500+
c4, t4 = (c3+1, 1.0) if i4==der else (c3, z[i4, k])
501+
out[ix, k] = c4*t1*t2*t3*t4*z[der, k]**(c4-1) if c4>0 else 0.0
502+
503+
return
504+
505+
if d == 5:
506+
for i1 in range(nvar):
507+
for i2 in range(i1, nvar):
508+
ix += 1
509+
for k in range(nobs):
510+
c1, t1 = (1, 1.0) if i1==der else (0, z[i1, k])
511+
c2, t2 = (c1+1, 1.0) if i2==der else (c1, z[i2, k])
512+
out[ix, k] = c2*t1*t2*z[der, k]**(c2-1) if c2>0 else 0.0
513+
514+
for i3 in range(i2, nvar):
515+
ix += 1
516+
for k in range(nobs):
517+
c3, t3 = (c2+1, 1.0) if i3==der else (c2, z[i3, k])
518+
out[ix, k] = c3*t1*t2*t3*z[der, k]**(c3-1) if c3>0 else 0.0
519+
520+
for i4 in range(i3, nvar):
521+
ix += 1
522+
for k in range(nobs):
523+
c4, t4 = (c3+1, 1.0) if i4==der else (c3, z[i4, k])
524+
out[ix, k] = c4*t1*t2*t3*t4*z[der, k]**(c4-1) if c4>0 else 0.0
525+
526+
for i5 in range(i4, nvar):
527+
ix += 1
528+
for k in range(nobs):
529+
c5, t5 = (c4+1, 1.0) if i5==der else (c4, z[i5, k])
530+
out[ix, k] = c5*t1*t2*t3*t4*t5*z[der, k]**(c5-1) if c5>0 else 0.0
531+
532+
return
533+
534+
277535
class CompletePolynomial:
278536

279537
def __init__(self, n, d):
@@ -289,7 +547,12 @@ def fit_values(self, s, x, damp=0.0):
289547
new_coefs = np.ascontiguousarray(lstsq(Phi, x)[0])
290548
self.coefs = (1 - damp) * new_coefs + damp * self.coefs
291549

550+
def der(self, s, der):
551+
dPhi = complete_polynomial_der(s.T, self.d, der).T
552+
return np.dot(dPhi, self.coefs)
553+
292554
def __call__(self, s):
293555

294556
Phi = complete_polynomial(s.T, self.d).T
295557
return np.dot(Phi, self.coefs)
558+

interpolation/tests/test_complete.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import numpy as np
2-
import numpy as np
32

43
from interpolation.complete_poly import CompletePolynomial
4+
from interpolation.complete_poly import (n_complete, complete_polynomial,
5+
complete_polynomial_der,
6+
_complete_poly_impl,
7+
_complete_poly_impl_vec,
8+
_complete_poly_der_impl,
9+
_complete_poly_der_impl_vec)
510

611
def test_complete_scalar():
712

@@ -43,7 +48,29 @@ def f2(x, y): return x**3 - y
4348

4449
cp.fit_values(points, vals, damp=0.5)
4550

51+
def test_complete_derivative():
52+
53+
# TODO: Currently if z has a 0 value then it breaks because occasionally
54+
# tries to raise 0 to a negative power -- This can be fixed by
55+
# checking whether coefficient is 0 before trying to do anything...
56+
57+
# Test derivative vector
58+
z = np.array([1, 2, 3])
59+
sol_vec = np.array([0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 3.0, 0.0, 0.0, 0.0])
60+
out_vec = np.empty(n_complete(3, 2))
61+
_complete_poly_der_impl_vec(z, 2, 0, out_vec)
62+
assert(abs(out_vec - sol_vec).max() < 1e-10)
63+
64+
# Test derivative matrix
65+
z = np.arange(1, 7).reshape(3, 2)
66+
out_mat = complete_polynomial_der(z, 2, 1)
67+
assert(abs(out_mat[0, :]).max() < 1e-10)
68+
assert(abs(out_mat[2, :] - np.ones(2)).max() < 1e-10)
69+
assert(abs(out_mat[-2, :] - np.array([5.0, 6.0])).max() < 1e-10)
70+
4671

4772
if __name__ == '__main__':
4873
test_complete_scalar()
4974
test_complete_vector()
75+
test_complete_derivative()
76+

0 commit comments

Comments
 (0)