Skip to content

Commit f9f6f64

Browse files
committed
Replace generated_jit with overload
1 parent 440d064 commit f9f6f64

File tree

9 files changed

+84
-32
lines changed

9 files changed

+84
-32
lines changed

.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
runs-on: ubuntu-latest
1010
strategy:
1111
matrix:
12-
python-version: [ '3.8', '3.9', '3.10', '3.11']
12+
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12']
1313

1414
name: Test Interpolation.py (Python ${{ matrix.python-version }})
1515
steps:

examples/example_mlinterp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22

3-
from numba import generated_jit
43
import ast
54

65
C = ((0.1, 0.2), (0.1, 0.2))

interpolation/multilinear/fungen.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numba
22
import numpy as np
33
from numba import float64, int64
4-
from numba import generated_jit, njit
4+
from numba import njit
55
import ast
66

77
from numba.extending import overload
@@ -25,8 +25,12 @@ def clamp(x, a, b):
2525

2626

2727
# returns the index of a 1d point along a 1d dimension
28-
@generated_jit(nopython=True)
2928
def get_index(gc, x):
29+
pass
30+
31+
32+
@overload(get_index)
33+
def ol_get_index(gc, x):
3034
if gc == t_coord:
3135
# regular coordinate
3236
def fun(gc, x):
@@ -53,8 +57,12 @@ def fun(gc, x):
5357

5458

5559
# returns number of dimension of a dimension
56-
@generated_jit(nopython=True)
5760
def get_size(gc):
61+
pass
62+
63+
64+
@overload(get_size)
65+
def ol_get_size(gc):
5866
if gc == t_coord:
5967
# regular coordinate
6068
def fun(gc):
@@ -145,8 +153,12 @@ def _map(*args):
145153
# funzip(((1,2), (2,3), (4,3))) -> ((1,2,4),(2,3,3))
146154

147155

148-
@generated_jit(nopython=True)
149156
def funzip(t):
157+
pass
158+
159+
160+
@overload(funzip)
161+
def ol_funzip(t):
150162
k = t.count
151163
assert len(set([e.count for e in t.types])) == 1
152164
l = t.types[0].count
@@ -169,8 +181,12 @@ def print_tuple(t):
169181
#####
170182

171183

172-
@generated_jit(nopython=True)
173184
def get_coeffs(X, I):
185+
pass
186+
187+
188+
@overload(get_coeffs)
189+
def ol_get_coeffs(X, I):
174190
if X.ndim > len(I):
175191
print("not implemented yet")
176192
else:
@@ -218,8 +234,12 @@ def gen_tensor_reduction(X, symbs, inds=[]):
218234
return str.join(" + ", exprs)
219235

220236

221-
@generated_jit(nopython=True)
222237
def tensor_reduction(C, l):
238+
pass
239+
240+
241+
@overload(tensor_reduction)
242+
def ol_tensor_reduction(C, l):
223243
d = len(l.types)
224244
ex = gen_tensor_reduction("C", ["l[{}]".format(i) for i in range(d)])
225245
dd = dict()
@@ -228,8 +248,12 @@ def tensor_reduction(C, l):
228248
return dd["tensor_reduction"]
229249

230250

231-
@generated_jit(nopython=True)
232251
def extract_row(a, n, tup):
252+
pass
253+
254+
255+
@overload(extract_row)
256+
def ol_extract_row(a, n, tup):
233257
d = len(tup.types)
234258
dd = {}
235259
s = "def extract_row(a, n, tup): return ({},)".format(
@@ -240,8 +264,12 @@ def extract_row(a, n, tup):
240264

241265

242266
# find closest point inside the grid domain
243-
@generated_jit
244267
def project(grid, point):
268+
pass
269+
270+
271+
@overload(project)
272+
def ol_project(grid, point):
245273
s = "def __project(grid, point):\n"
246274
d = len(grid.types)
247275
for i in range(d):

interpolation/multilinear/mlinterp.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,24 @@
2929
)
3030

3131
from numba import njit
32+
from numba.extending import overload
3233
from typing import Tuple
3334

3435
from ..compat import UniTuple, Tuple, Float, Integer, Array
3536

3637
Scalar = (Float, Integer)
3738

3839
import numpy as np
39-
from numba import generated_jit
4040

4141
# logic of multilinear interpolation
4242

4343

44-
@generated_jit
4544
def mlinterp(grid, c, u):
45+
pass
46+
47+
48+
@overload(mlinterp)
49+
def ol_mlinterp(grid, c, u):
4650
if isinstance(u, UniTuple):
4751

4852
def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:
@@ -213,11 +217,12 @@ def {funname}(*args):
213217
return source
214218

215219

216-
from numba import generated_jit
220+
def interp(*args):
221+
pass
217222

218223

219-
@generated_jit(nopython=True)
220-
def interp(*args):
224+
@overload(interp)
225+
def ol_interp(*args):
221226
aa = args[0].types
222227

223228
it = detect_types(aa)

interpolation/multilinear/tests/test_multilinear.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
from numpy import linspace, array
22
from numpy.random import random
33
from numba import typeof
4+
from numba import njit
45

56
import numpy as np
67
from ..fungen import get_index
78

89

10+
@njit
11+
def get_index_njit(gc, x):
12+
return get_index(gc, x)
13+
14+
915
def test_barycentric_indexes():
1016
# irregular grid
1117
gg = np.array([0.0, 1.0])
12-
assert get_index(gg, -0.1) == (0, -0.1)
13-
assert get_index(gg, 0.5) == (0, 0.5)
14-
assert get_index(gg, 1.1) == (0, 1.1)
18+
assert get_index_njit(gg, -0.1) == (0, -0.1)
19+
assert get_index_njit(gg, 0.5) == (0, 0.5)
20+
assert get_index_njit(gg, 1.1) == (0, 1.1)
1521

1622
# regular grid
1723
gg = (0.0, 1.0, 2)
18-
assert get_index(gg, -0.1) == (0, -0.1)
19-
assert get_index(gg, 0.5) == (0, 0.5)
20-
assert get_index(gg, 1.1) == (0, 1.1)
24+
assert get_index_njit(gg, -0.1) == (0, -0.1)
25+
assert get_index_njit(gg, 0.5) == (0, 0.5)
26+
assert get_index_njit(gg, 1.1) == (0, 1.1)
2127

2228

2329
# 2d-vecev-scalar

interpolation/splines/eval_cubic.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy
22

3+
from numba import njit
4+
from numba.extending import overload
35
from .eval_splines import eval_cubic
46

57
## the functions in this file provide backward compatibility calls
@@ -11,19 +13,27 @@
1113
# Compatibility calls #
1214
#######################
1315

14-
from numba import generated_jit
1516
from .codegen import source_to_function
1617

1718

18-
@generated_jit
19-
def get_grid(a, b, n, C):
19+
def _get_grid(a, b, n, C):
20+
pass
21+
22+
23+
@overload(_get_grid)
24+
def ol_get_grid(a, b, n, C):
2025
d = C.ndim
2126
s = "({},)".format(str.join(", ", [f"(a[{k}],b[{k}],n[{k}])" for k in range(d)]))
2227
txt = "def get_grid(a,b,n,C): return {}".format(s)
2328
f = source_to_function(txt)
2429
return f
2530

2631

32+
@njit
33+
def get_grid(a, b, n, C):
34+
return _get_grid(a, b, n, C)
35+
36+
2737
def eval_cubic_spline(a, b, orders, coefs, point):
2838
"""Evaluates a cubic spline at one point
2939

interpolation/splines/eval_splines.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from numba import jit, generated_jit
32
from numpy import zeros
43
from numpy import floor
54

@@ -19,7 +18,6 @@
1918
from interpolation.splines.codegen import get_code_spline, source_to_function
2019
from numba.types import UniTuple, float64, Array
2120
from interpolation.splines.codegen import source_to_function
22-
from numba import generated_jit
2321

2422

2523
from ..compat import Tuple, UniTuple
@@ -50,9 +48,12 @@
5048
### eval spline (main function)
5149

5250

53-
# @generated_jit(inline='always', nopython=True) # doens't work
54-
@generated_jit(nopython=True)
5551
def allocate_output(G, C, P, O):
52+
pass
53+
54+
55+
@overload(allocate_output)
56+
def ol_allocate_output(G, C, P, O):
5657
if C.ndim == len(G) + 1:
5758
# vector valued
5859
if P.ndim == 2:

interpolation/splines/hermite.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def HermiteInterpolationVect(xvect, x: Vector, y: Vector, yp: Vector):
8484

8585
from numba import njit, types
8686
from numba.extending import overload, register_jitable
87-
from numba import generated_jit
8887

8988

9089
def _hermite(x0, x, y, yp, out=None):
@@ -102,8 +101,12 @@ def _hermite(x0, x, y, yp, out=None):
102101
from numba.core.types.misc import NoneType as none
103102

104103

105-
@generated_jit
106104
def hermite(x0, x, y, yp, out=None):
105+
pass
106+
107+
108+
@overload(hermite)
109+
def ol_hermite(x0, x, y, yp, out=None):
107110
try:
108111
n = x0.ndim
109112
if n == 1:

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ maintainers = [
1414
license = "BSD-2-Clause"
1515

1616
[tool.poetry.dependencies]
17-
python = ">=3.9, <=3.12"
18-
numba = "^0.57"
17+
python = ">=3.9"
18+
numba = ">=0.57"
1919
scipy = "^1.10"
2020

2121
[tool.poetry.dev-dependencies]

0 commit comments

Comments
 (0)