Skip to content

Commit ce2420f

Browse files
committed
blackified everything
1 parent 14f6591 commit ce2420f

File tree

14 files changed

+100
-76
lines changed

14 files changed

+100
-76
lines changed

examples/example_mlinterp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def vec_eval(u):
6262

6363
x1 = np.linspace(0, 1, 100) ** 2 # non-uniform points
6464
x2 = np.linspace(0, 1, 100) ** 2 # non-uniform points
65-
y = np.array([[np.sqrt(u1 ** 2 + u2 ** 2) for u2 in x2] for u1 in x1])
65+
y = np.array([[np.sqrt(u1**2 + u2**2) for u2 in x2] for u1 in x1])
6666
# (y[i,j] = sqrt(x1[i]**2+x2[j]**2)
6767

6868

@@ -112,7 +112,7 @@ def vec_eval(p):
112112
x1 = np.linspace(0, 1, 100) ** 2 # non-uniform points for first dimensoin
113113
x2 = (0.0, 1.0, 100) # uniform points for second dimension
114114
grid = (x1, x2)
115-
y = np.array([[np.sqrt(u1 ** 2 + u2 ** 2) for u2 in x2] for u1 in x1])
115+
y = np.array([[np.sqrt(u1**2 + u2**2) for u2 in x2] for u1 in x1])
116116

117117

118118
points = np.random.random((1000, 2))

interpolation/smolyak/grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def __init__(self, d, mu, lb=None, ub=None):
786786
def __repr__(self):
787787
npoints = self.cube_grid.shape[0]
788788
nz_pts = np.count_nonzero(self.B)
789-
pct_nz = nz_pts / (npoints ** 2.0)
789+
pct_nz = nz_pts / (npoints**2.0)
790790

791791
if isinstance(self.mu, int):
792792
msg = "Smolyak Grid:\n\td: {0} \n\tmu: {1} \n\tnpoints: {2}"

interpolation/smolyak/tests/test_interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
import numpy as np
66

77
# func = lambda x, y: np.exp(x**2 - y**2)
8-
func = lambda x, y: x ** 2 - y ** 2
8+
func = lambda x, y: x**2 - y**2
99

1010

1111
func1 = lambda points: func(points[:, 0], points[:, 1])
1212
func1_prime = lambda x: np.column_stack([2 * x[:, 0], -2 * x[:, 1]])
1313

14-
func2 = lambda x: np.sum(x ** 2, axis=1)
14+
func2 = lambda x: np.sum(x**2, axis=1)
1515
func2_prime = lambda x: 2 * x
1616

1717

interpolation/splines/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@
1212
def UCGrid(*args):
1313
tt = numba.typeof((10.0, 1.0, 1))
1414
for a in args:
15-
assert(numba.typeof(a) == tt)
15+
assert numba.typeof(a) == tt
1616
min, max, n = a
17-
assert(min<max)
18-
assert(n>1)
19-
17+
assert min < max
18+
assert n > 1
19+
2020
return tuple(args)
2121

2222

2323
def CGrid(*args):
2424
tt = numba.typeof((10.0, 1.0, 1))
2525
for a in args:
2626
if isinstance(a, np.ndarray):
27-
assert(a.ndim==1)
28-
assert(a.shape[0]>2)
29-
elif (numba.typeof(a) == tt):
27+
assert a.ndim == 1
28+
assert a.shape[0] > 2
29+
elif numba.typeof(a) == tt:
3030
min, max, n = a
31-
assert(min<max)
32-
assert(n>1)
31+
assert min < max
32+
assert n > 1
3333
else:
3434
raise Exception(f"Unknown dimension specification: {a}")
3535

interpolation/splines/hermite.py

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,27 @@
33
import numpy.typing as npt
44
from typing import Tuple
55

6+
67
class Vector:
78
pass
89

10+
911
@jit(nopython=True)
10-
def hermite_splines(lambda0: float)->Tuple[float, float, float, float]:
12+
def hermite_splines(lambda0: float) -> Tuple[float, float, float, float]:
1113
"""Computes the cubic Hermite splines in lambda0
1214
Inputs: - float: lambda0
1315
Output: - tuple: cubic Hermite splines evaluated in lambda0"""
14-
h00 = 2*(lambda0**3) - 3*(lambda0**2) + 1
15-
h10 = (lambda0**3) - 2*(lambda0**2) + lambda0
16-
h01 = -2*(lambda0**3) + 3*(lambda0**2)
16+
h00 = 2 * (lambda0**3) - 3 * (lambda0**2) + 1
17+
h10 = (lambda0**3) - 2 * (lambda0**2) + lambda0
18+
h01 = -2 * (lambda0**3) + 3 * (lambda0**2)
1719
h11 = (lambda0**3) - (lambda0**2)
1820
return (h00, h10, h01, h11)
1921

2022

2123
@jit(nopython=True)
22-
def hermite_interp(x0: float, xk: float, xkn: float, pk: float, pkn: float, mk: float, mkn: float)->float:
24+
def hermite_interp(
25+
x0: float, xk: float, xkn: float, pk: float, pkn: float, mk: float, mkn: float
26+
) -> float:
2327
"""Returns the interpolated value for x0.
2428
Inputs: - float: x0, abscissa of the point to interpolate
2529
- float: xk, abscissa of the nearest lowest point to x0 on the grid
@@ -30,9 +34,14 @@ def hermite_interp(x0: float, xk: float, xkn: float, pk: float, pkn: float, mk:
3034
- float: mkn, tangent in xkn
3135
Output: - float: interpolated value for x0
3236
"""
33-
t = (x0-xk)/(xkn-xk)
37+
t = (x0 - xk) / (xkn - xk)
3438
hsplines = hermite_splines(t)
35-
return (pk*hsplines[0] + mk*(xkn-xk)*hsplines[1] + pkn*hsplines[2] + mkn*(xkn-xk)*hsplines[3])
39+
return (
40+
pk * hsplines[0]
41+
+ mk * (xkn - xk) * hsplines[1]
42+
+ pkn * hsplines[2]
43+
+ mkn * (xkn - xk) * hsplines[3]
44+
)
3645

3746

3847
@jit(nopython=True)
@@ -48,12 +57,12 @@ def HermiteInterpolation(x0: float, x, y, yp):
4857
return y[0]
4958
elif x0 >= np.max(x):
5059
return y[-1]
51-
60+
5261
###### Interpolation case ######
5362
indx = np.searchsorted(x, x0)
54-
xk, xkn = x[indx-1], x[indx]
55-
pk, pkn = y[indx-1], y[indx]
56-
mk, mkn = yp[indx-1], yp[indx]
63+
xk, xkn = x[indx - 1], x[indx]
64+
pk, pkn = y[indx - 1], y[indx]
65+
mk, mkn = yp[indx - 1], yp[indx]
5766
return hermite_interp(x0, xk, xkn, pk, pkn, mk, mkn)
5867

5968

@@ -72,43 +81,54 @@ def HermiteInterpolationVect(xvect, x: Vector, y: Vector, yp: Vector):
7281
out[i] = HermiteInterpolation(x0, x, y, yp)
7382
return out
7483

84+
7585
from numba import njit, types
7686
from numba.extending import overload, register_jitable
7787
from numba import generated_jit
7888

7989

80-
def _hermite(x0,x,y,yp,out=None):
90+
def _hermite(x0, x, y, yp, out=None):
8191
pass
8292

93+
8394
@overload(_hermite)
84-
def _hermite(x0,x,y,yp,out=None):
85-
def _hermite(x0,x,y,yp,out=None):
86-
return HermiteInterpolation(x0,x,y,yp)
95+
def _hermite(x0, x, y, yp, out=None):
96+
def _hermite(x0, x, y, yp, out=None):
97+
return HermiteInterpolation(x0, x, y, yp)
98+
8799
return _hermite
88100

101+
89102
from numba.core.types.misc import NoneType as none
90103

104+
91105
@generated_jit
92-
def hermite(x0,x,y,yp,out=None):
106+
def hermite(x0, x, y, yp, out=None):
93107
try:
94108
n = x0.ndim
95-
if n==1:
96-
input_type = 'vector'
97-
elif n==2:
98-
input_type = 'matrix'
109+
if n == 1:
110+
input_type = "vector"
111+
elif n == 2:
112+
input_type = "matrix"
99113
else:
100114
raise Exception("Invalid input type")
101115
except:
102116
# n must be a scalar
103-
input_type = 'scalar'
104-
105-
if input_type == 'scalar':
106-
def _hermite(x0,x,y,yp,out=None):
107-
return HermiteInterpolation(x0,x,y,yp)
108-
elif input_type == 'vector':
109-
def _hermite(x0,x,y,yp,out=None):
110-
return HermiteInterpolationVect(x0,x,y,yp)
111-
elif input_type == 'matrix':
112-
def _hermite(x0,x,y,yp,out=None):
113-
return HermiteInterpolationVect(x0[:,0],x,y,yp)
114-
return _hermite
117+
input_type = "scalar"
118+
119+
if input_type == "scalar":
120+
121+
def _hermite(x0, x, y, yp, out=None):
122+
return HermiteInterpolation(x0, x, y, yp)
123+
124+
elif input_type == "vector":
125+
126+
def _hermite(x0, x, y, yp, out=None):
127+
return HermiteInterpolationVect(x0, x, y, yp)
128+
129+
elif input_type == "matrix":
130+
131+
def _hermite(x0, x, y, yp, out=None):
132+
return HermiteInterpolationVect(x0[:, 0], x, y, yp)
133+
134+
return _hermite

interpolation/splines/tests/test_derivatives.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ def test_derivatives():
44

55
import numpy as np
66

7-
87
grid = ((0.0, 1.0, 10),)
98

109
# grid = ((0.0, 1.0, 0.1),(0.0, 1.0, 0.1))
@@ -13,7 +12,6 @@ def test_derivatives():
1312
Cx = np.concatenate([C[:, None], C[:, None] * 2])
1413
points = np.random.random((10, 1))
1514

16-
1715
eval_spline(
1816
grid, C, (-0.1,), out=None, order=1, diff="None", extrap_mode="nearest"
1917
) # no alloc
@@ -24,7 +22,6 @@ def test_derivatives():
2422
grid, C, (-0.1,), out=None, order=1, diff="None", extrap_mode="linear"
2523
) # no alloc
2624

27-
2825
eval_spline(
2926
grid, C, (1.1,), out=None, order=1, diff="None", extrap_mode="nearest"
3027
) # no alloc
@@ -35,16 +32,13 @@ def test_derivatives():
3532
grid, C, (1.1,), out=None, order=1, diff="None", extrap_mode="linear"
3633
) # no alloc
3734

38-
3935
eval_spline(
4036
grid, Cx, points[0, :], out=None, order=1, diff="None", extrap_mode="linear"
4137
)
4238

43-
4439
eval_spline(grid, C, points, out=None, order=1, diff="None", extrap_mode="linear")
4540
eval_spline(grid, Cx, points, out=None, order=1, diff="None", extrap_mode="linear")
4641

47-
4842
orders = str(((0,), (1,)))
4943

5044
eval_spline(
@@ -56,18 +50,18 @@ def test_derivatives():
5650
eval_spline(grid, C, points, out=None, order=1, diff=orders, extrap_mode="linear")
5751
eval_spline(grid, Cx, points, out=None, order=1, diff=orders, extrap_mode="linear")
5852

59-
6053
out = eval_spline(
6154
grid, Cx, points, out=None, order=1, diff=orders, extrap_mode="linear"
6255
)
6356
out2 = np.zeros_like(out)
6457
eval_spline(grid, Cx, points, out=out2, order=1, diff=orders, extrap_mode="linear")
6558
print(abs(out - out2).max())
6659

67-
6860
k = 3
6961

70-
eval_spline(grid, C, points[0, :], out=None, order=3, diff="None", extrap_mode="linear")
62+
eval_spline(
63+
grid, C, points[0, :], out=None, order=3, diff="None", extrap_mode="linear"
64+
)
7165

7266
eval_spline(grid, C, points, out=None, order=k, diff="None", extrap_mode="linear")
7367
eval_spline(
@@ -77,7 +71,9 @@ def test_derivatives():
7771

7872
orders = str(((0,), (1,)))
7973

80-
eval_spline(grid, C, points[0, :], out=None, order=k, diff=orders, extrap_mode="linear")
74+
eval_spline(
75+
grid, C, points[0, :], out=None, order=k, diff=orders, extrap_mode="linear"
76+
)
8177
eval_spline(grid, C, points, out=None, order=k, diff=orders, extrap_mode="linear")
8278
eval_spline(
8379
grid, Cx, points[0, :], out=None, order=k, diff=orders, extrap_mode="linear"
Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,21 @@
11
def test_hermite_splines():
2-
2+
33
from interpolation.splines.hermite import HermiteInterpolationVect
44
import numpy as np
5-
6-
N = 10000 # Number of points in the initial dataset
5+
6+
N = 10000 # Number of points in the initial dataset
77
K = 1000 # Number of new points to interpolate
8-
8+
99
# Initial dataset
1010
# grid = ((0.0, 1.0, K),) # Creation of an x-axis grid (xi)
1111
grid = np.linspace(0.0, 1.0, N) # Creation of an x-axis grid (xi)
12-
points = np.random.random((N)) # Random values for f(xi)
12+
points = np.random.random((N)) # Random values for f(xi)
1313
dpoints = np.random.random((N)) # Random derivatives for f'(xi)
14-
14+
1515
# Generate new points
1616
newgrid = np.random.random((K))
17-
17+
1818
# Interpolation
1919
out = HermiteInterpolationVect(newgrid, grid, points, dpoints)
20-
20+
2121
print("OK")
22-

interpolation/splines/tests/test_multilinear_extrap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def test_multilinear_extrap():
1414

1515
s = mlinspace(a, b, n)
1616

17-
f = lambda x: (x ** 2).sum(axis=1)
17+
f = lambda x: (x**2).sum(axis=1)
1818

1919
x = f(s)
2020
v = x.reshape(n)

interpolation/tests/test_complete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def f(x, y):
3535
return x
3636

3737
def f2(x, y):
38-
return x ** 3 - y
38+
return x**3 - y
3939

4040
points = np.random.random((1000, 2))
4141
vals = np.column_stack(

interpolation/tests/test_derivs.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55

66
class Check1DDerivatives(unittest.TestCase):
7-
"""
7+
"""
88
Checks derivatives in a 1D interpolator
99
"""
1010

@@ -35,7 +35,10 @@ def test_linear(self):
3535
# 0-order must be the function
3636
# 1-order must be the slope
3737
result = np.vstack(
38-
[y0 + slope * eval_points, np.ones_like(eval_points) * slope,]
38+
[
39+
y0 + slope * eval_points,
40+
np.ones_like(eval_points) * slope,
41+
]
3942
).T
4043

4144
self.assertTrue(np.allclose(grad, result))
@@ -61,7 +64,10 @@ def test_nonlinear(self):
6164
# 0-order must be the function
6265
# 1-order must be + or - pi/2
6366
result = np.vstack(
64-
[np.array([0, -1, 0, 1, 0]), np.array([-1, -1, 1, 1, -1]) * 2 / np.pi,]
67+
[
68+
np.array([0, -1, 0, 1, 0]),
69+
np.array([-1, -1, 1, 1, -1]) * 2 / np.pi,
70+
]
6571
).T
6672

6773
self.assertTrue(np.allclose(grad, result))
@@ -87,14 +93,17 @@ def test_nonlinear_approx(self):
8793
# 0-order must be x^3
8894
# 1-order must be close to 3x^2
8995
result = np.vstack(
90-
[np.power(eval_points, 3), np.power(eval_points, 2) * 3.0,]
96+
[
97+
np.power(eval_points, 3),
98+
np.power(eval_points, 2) * 3.0,
99+
]
91100
).T
92101

93102
self.assertTrue(np.allclose(grad, result, atol=0.02))
94103

95104

96105
class Check2DDerivatives(unittest.TestCase):
97-
"""
106+
"""
98107
Checks derivatives in a 2D interpolator
99108
"""
100109

0 commit comments

Comments
 (0)