@@ -99,17 +99,23 @@ def complete_polynomial(z, d):
9999 """
100100 # check inputs
101101 assert d >= 0 , "d must be non-negative"
102- z = np .asarray (z )
103-
104- # compute inds allocate space for output
105- nvar , nobs = z .shape
106- out = np .zeros ((n_complete (nvar , d ), nobs ))
107-
108102 if d > 5 :
109103 raise ValueError ("Complete polynomial only implemeted up to degree 5" )
110104
111- # populate out with jitted function
112- _complete_poly_impl (z , d , out )
105+ # Assure z is array
106+ z = np .asarray (z )
107+
108+ # compute inds allocate space for output
109+ if np .ndim (z ) == 1 :
110+ nvar = z .size
111+ out = np .zeros (n_complete (nvar , d ))
112+ # populate out with jitted function
113+ _complete_poly_impl_vec (z , d , out )
114+ else :
115+ nvar , nobs = z .shape
116+ out = np .zeros ((n_complete (nvar , d ), nobs ))
117+ # populate out with jitted function
118+ _complete_poly_impl (z , d , out )
113119
114120 return out
115121
@@ -313,18 +319,25 @@ def complete_polynomial_der(z, d, der):
313319 # check inputs
314320 assert d >= 0 , "d must be non-negative"
315321 assert der >= 0 , "derivative must be non-negative"
316- z = np .asarray (z )
317-
318- # compute inds allocate space for output
319- nvar , nobs = z .shape
320- assert der < nvar , "derivative integer must be smaller than nobs in z"
321- out = np .zeros ((n_complete (nvar , d ), nobs ))
322-
323322 if d > 5 :
324323 raise ValueError ("Complete polynomial only implemeted up to degree 5" )
325324
326- # populate out with jitted function
327- _complete_poly_der_impl (z , d , der , out )
325+ # Ensure z is a numpy array
326+ z = np .asarray (z )
327+
328+ # compute inds allocate space for output
329+ if np .ndim (z ) == 1 :
330+ nvar = z .size
331+ assert der < nvar , "derivative integer must be smaller than nobs in z"
332+ out = np .zeros (n_complete (nvar , d ))
333+ # populate with jitted function
334+ _complete_poly_der_impl_vec (z , d , der , out )
335+ else :
336+ nvar , nobs = z .shape
337+ assert der < nvar , "derivative integer must be smaller than nobs in z"
338+ out = np .zeros ((n_complete (nvar , d ), nobs ))
339+ # populate out with jitted function
340+ _complete_poly_der_impl (z , d , der , out )
328341
329342 return out
330343
@@ -474,6 +487,8 @@ def _complete_poly_der_impl(z, d, der, out):
474487 for i3 in range (i2 , nvar ):
475488 ix += 1
476489 for k in range (nobs ):
490+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
491+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
477492 c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
478493 out [ix , k ] = c3 * t1 * t2 * t3 * z [der , k ]** (c3 - 1 ) if c3 > 0 else 0.0
479494
@@ -491,12 +506,17 @@ def _complete_poly_der_impl(z, d, der, out):
491506 for i3 in range (i2 , nvar ):
492507 ix += 1
493508 for k in range (nobs ):
509+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
510+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
494511 c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
495512 out [ix , k ] = c3 * t1 * t2 * t3 * z [der , k ]** (c3 - 1 ) if c3 > 0 else 0.0
496513
497514 for i4 in range (i3 , nvar ):
498515 ix += 1
499516 for k in range (nobs ):
517+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
518+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
519+ c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
500520 c4 , t4 = (c3 + 1 , 1.0 ) if i4 == der else (c3 , z [i4 , k ])
501521 out [ix , k ] = c4 * t1 * t2 * t3 * t4 * z [der , k ]** (c4 - 1 ) if c4 > 0 else 0.0
502522
@@ -514,18 +534,27 @@ def _complete_poly_der_impl(z, d, der, out):
514534 for i3 in range (i2 , nvar ):
515535 ix += 1
516536 for k in range (nobs ):
537+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
538+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
517539 c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
518540 out [ix , k ] = c3 * t1 * t2 * t3 * z [der , k ]** (c3 - 1 ) if c3 > 0 else 0.0
519541
520542 for i4 in range (i3 , nvar ):
521543 ix += 1
522544 for k in range (nobs ):
545+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
546+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
547+ c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
523548 c4 , t4 = (c3 + 1 , 1.0 ) if i4 == der else (c3 , z [i4 , k ])
524549 out [ix , k ] = c4 * t1 * t2 * t3 * t4 * z [der , k ]** (c4 - 1 ) if c4 > 0 else 0.0
525550
526551 for i5 in range (i4 , nvar ):
527552 ix += 1
528553 for k in range (nobs ):
554+ c1 , t1 = (1 , 1.0 ) if i1 == der else (0 , z [i1 , k ])
555+ c2 , t2 = (c1 + 1 , 1.0 ) if i2 == der else (c1 , z [i2 , k ])
556+ c3 , t3 = (c2 + 1 , 1.0 ) if i3 == der else (c2 , z [i3 , k ])
557+ c4 , t4 = (c3 + 1 , 1.0 ) if i4 == der else (c3 , z [i4 , k ])
529558 c5 , t5 = (c4 + 1 , 1.0 ) if i5 == der else (c4 , z [i5 , k ])
530559 out [ix , k ] = c5 * t1 * t2 * t3 * t4 * t5 * z [der , k ]** (c5 - 1 ) if c5 > 0 else 0.0
531560
0 commit comments