Skip to content

Commit 2d20da5

Browse files
authored
Merge pull request #114 from oyamad/overload
Apply `@njit` to `interp` and `mlinterp`
2 parents 705cbce + 7003f0f commit 2d20da5

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

interpolation/multilinear/mlinterp.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@
4141
# logic of multilinear interpolation
4242

4343

44-
def mlinterp(grid, c, u):
44+
def _mlinterp(grid, c, u):
4545
pass
4646

4747

48-
@overload(mlinterp)
48+
@overload(_mlinterp)
4949
def ol_mlinterp(grid, c, u):
5050
if isinstance(u, UniTuple):
5151

52-
def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:
52+
def mlininterp(grid, c, u):
5353
# get indices and barycentric coordinates
5454
tmp = fmap(get_index, grid, u)
5555
indices, barycenters = funzip(tmp)
@@ -59,7 +59,7 @@ def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:
5959

6060
elif isinstance(u, Array) and u.ndim == 2:
6161

62-
def mlininterp(grid: Tuple, c: Array, u: Array) -> float:
62+
def mlininterp(grid, c, u):
6363
N = u.shape[0]
6464
res = np.zeros(N)
6565
for n in range(N):
@@ -76,6 +76,11 @@ def mlininterp(grid: Tuple, c: Array, u: Array) -> float:
7676
return mlininterp
7777

7878

79+
@njit
80+
def mlinterp(grid, c, u):
81+
return _mlinterp(grid, c, u)
82+
83+
7984
### The rest of this file constrcts function `interp`
8085

8186
from collections import namedtuple
@@ -217,15 +222,13 @@ def {funname}(*args):
217222
return source
218223

219224

220-
def interp(*args):
225+
def _interp(*args):
221226
pass
222227

223228

224-
@overload(interp)
229+
@overload(_interp)
225230
def ol_interp(*args):
226-
aa = args[0].types
227-
228-
it = detect_types(aa)
231+
it = detect_types(args)
229232
if it.d == 1 and it.eval == "point":
230233
it = itt(it.d, it.values, "cartesian")
231234
source = make_mlinterp(it, "__mlinterp")
@@ -235,3 +238,8 @@ def ol_interp(*args):
235238
code = compile(tree, "<string>", "exec")
236239
eval(code, globals())
237240
return __mlinterp
241+
242+
243+
@njit
244+
def interp(*args):
245+
return _interp(*args)

interpolation/multilinear/tests/test_multilinear.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ def test_mlinterp():
115115
pp = np.random.random((2000, 2))
116116

117117
res0 = mlinterp((x1, x2), y, pp)
118+
assert res0 is not None
119+
118120
res0 = mlinterp((x1, x2), y, (0.1, 0.2))
121+
assert res0 is not None
119122

120123

121124
def test_multilinear():
@@ -125,6 +128,8 @@ def test_multilinear():
125128
tt = [typeof(e) for e in t]
126129
rr = interp(*t)
127130

131+
assert rr is not None
132+
128133
try:
129134
print(f"{tt}: {rr.shape}")
130135
except:

0 commit comments

Comments
 (0)