33import numpy .typing as npt
44from typing import Tuple
55
6+
67class Vector :
78 pass
89
10+
911@jit (nopython = True )
10- def hermite_splines (lambda0 : float )-> Tuple [float , float , float , float ]:
12+ def hermite_splines (lambda0 : float ) -> Tuple [float , float , float , float ]:
1113 """Computes the cubic Hermite splines in lambda0
1214 Inputs: - float: lambda0
1315 Output: - tuple: cubic Hermite splines evaluated in lambda0"""
14- h00 = 2 * (lambda0 ** 3 ) - 3 * (lambda0 ** 2 ) + 1
15- h10 = (lambda0 ** 3 ) - 2 * (lambda0 ** 2 ) + lambda0
16- h01 = - 2 * (lambda0 ** 3 ) + 3 * (lambda0 ** 2 )
16+ h00 = 2 * (lambda0 ** 3 ) - 3 * (lambda0 ** 2 ) + 1
17+ h10 = (lambda0 ** 3 ) - 2 * (lambda0 ** 2 ) + lambda0
18+ h01 = - 2 * (lambda0 ** 3 ) + 3 * (lambda0 ** 2 )
1719 h11 = (lambda0 ** 3 ) - (lambda0 ** 2 )
1820 return (h00 , h10 , h01 , h11 )
1921
2022
2123@jit (nopython = True )
22- def hermite_interp (x0 : float , xk : float , xkn : float , pk : float , pkn : float , mk : float , mkn : float )-> float :
24+ def hermite_interp (
25+ x0 : float , xk : float , xkn : float , pk : float , pkn : float , mk : float , mkn : float
26+ ) -> float :
2327 """Returns the interpolated value for x0.
2428 Inputs: - float: x0, abscissa of the point to interpolate
2529 - float: xk, abscissa of the nearest lowest point to x0 on the grid
@@ -30,9 +34,14 @@ def hermite_interp(x0: float, xk: float, xkn: float, pk: float, pkn: float, mk:
3034 - float: mkn, tangent in xkn
3135 Output: - float: interpolated value for x0
3236 """
33- t = (x0 - xk )/ (xkn - xk )
37+ t = (x0 - xk ) / (xkn - xk )
3438 hsplines = hermite_splines (t )
35- return (pk * hsplines [0 ] + mk * (xkn - xk )* hsplines [1 ] + pkn * hsplines [2 ] + mkn * (xkn - xk )* hsplines [3 ])
39+ return (
40+ pk * hsplines [0 ]
41+ + mk * (xkn - xk ) * hsplines [1 ]
42+ + pkn * hsplines [2 ]
43+ + mkn * (xkn - xk ) * hsplines [3 ]
44+ )
3645
3746
3847@jit (nopython = True )
@@ -48,12 +57,12 @@ def HermiteInterpolation(x0: float, x, y, yp):
4857 return y [0 ]
4958 elif x0 >= np .max (x ):
5059 return y [- 1 ]
51-
60+
5261 ###### Interpolation case ######
5362 indx = np .searchsorted (x , x0 )
54- xk , xkn = x [indx - 1 ], x [indx ]
55- pk , pkn = y [indx - 1 ], y [indx ]
56- mk , mkn = yp [indx - 1 ], yp [indx ]
63+ xk , xkn = x [indx - 1 ], x [indx ]
64+ pk , pkn = y [indx - 1 ], y [indx ]
65+ mk , mkn = yp [indx - 1 ], yp [indx ]
5766 return hermite_interp (x0 , xk , xkn , pk , pkn , mk , mkn )
5867
5968
@@ -72,43 +81,54 @@ def HermiteInterpolationVect(xvect, x: Vector, y: Vector, yp: Vector):
7281 out [i ] = HermiteInterpolation (x0 , x , y , yp )
7382 return out
7483
84+
7585from numba import njit , types
7686from numba .extending import overload , register_jitable
7787from numba import generated_jit
7888
7989
80- def _hermite (x0 ,x , y , yp ,out = None ):
90+ def _hermite (x0 , x , y , yp , out = None ):
8191 pass
8292
93+
8394@overload (_hermite )
84- def _hermite (x0 ,x ,y ,yp ,out = None ):
85- def _hermite (x0 ,x ,y ,yp ,out = None ):
86- return HermiteInterpolation (x0 ,x ,y ,yp )
95+ def _hermite (x0 , x , y , yp , out = None ):
96+ def _hermite (x0 , x , y , yp , out = None ):
97+ return HermiteInterpolation (x0 , x , y , yp )
98+
8799 return _hermite
88100
101+
89102from numba .core .types .misc import NoneType as none
90103
104+
91105@generated_jit
92- def hermite (x0 ,x , y , yp ,out = None ):
106+ def hermite (x0 , x , y , yp , out = None ):
93107 try :
94108 n = x0 .ndim
95- if n == 1 :
96- input_type = ' vector'
97- elif n == 2 :
98- input_type = ' matrix'
109+ if n == 1 :
110+ input_type = " vector"
111+ elif n == 2 :
112+ input_type = " matrix"
99113 else :
100114 raise Exception ("Invalid input type" )
101115 except :
102116 # n must be a scalar
103- input_type = 'scalar'
104-
105- if input_type == 'scalar' :
106- def _hermite (x0 ,x ,y ,yp ,out = None ):
107- return HermiteInterpolation (x0 ,x ,y ,yp )
108- elif input_type == 'vector' :
109- def _hermite (x0 ,x ,y ,yp ,out = None ):
110- return HermiteInterpolationVect (x0 ,x ,y ,yp )
111- elif input_type == 'matrix' :
112- def _hermite (x0 ,x ,y ,yp ,out = None ):
113- return HermiteInterpolationVect (x0 [:,0 ],x ,y ,yp )
114- return _hermite
117+ input_type = "scalar"
118+
119+ if input_type == "scalar" :
120+
121+ def _hermite (x0 , x , y , yp , out = None ):
122+ return HermiteInterpolation (x0 , x , y , yp )
123+
124+ elif input_type == "vector" :
125+
126+ def _hermite (x0 , x , y , yp , out = None ):
127+ return HermiteInterpolationVect (x0 , x , y , yp )
128+
129+ elif input_type == "matrix" :
130+
131+ def _hermite (x0 , x , y , yp , out = None ):
132+ return HermiteInterpolationVect (x0 [:, 0 ], x , y , yp )
133+
134+ return _hermite
0 commit comments