Skip to content

Commit 4a81428

Browse files
authored
Merge pull request #58 from EconForge/albop/fix_numba_import
FIX: numba import for numba >=0.49
2 parents c650cba + bd5973b commit 4a81428

File tree

4 files changed

+32
-24
lines changed

4 files changed

+32
-24
lines changed

interpolation/compat.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
from distutils.version import LooseVersion
3+
4+
5+
from numba import __version__
6+
7+
if LooseVersion(__version__)>='0.43':
8+
overload_options = {'strict': False}
9+
else:
10+
overload_options = {}
11+
12+
if LooseVersion(__version__)>='0.49':
13+
from numba.types import Tuple, UniTuple
14+
else:
15+
from numba.types.containers import Tuple, UniTuple

interpolation/multilinear/fungen.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,10 @@
55
import ast
66

77
from numba.extending import overload
8-
from numba.types.containers import Tuple, UniTuple
8+
from numba.types import Array
9+
from ..compat import overload_options
10+
from ..compat import Tuple, UniTuple
911

10-
from distutils.version import LooseVersion
11-
from numba import __version__
12-
if LooseVersion(__version__)>='0.43':
13-
overload_options = {'strict': False}
14-
else:
15-
overload_options = {}
1612

1713
# from math import max, min
1814

@@ -215,7 +211,7 @@ def project(grid, point):
215211
s = "def __project(grid, point):\n"
216212
d = len(grid.types)
217213
for i in range(d):
218-
if isinstance(grid.types[i], numba.types.Array):
214+
if isinstance(grid.types[i], Array):
219215
s += f" x_{i} = min(max(point[{i}], grid[{i}][0]), grid[{i}][grid[{i}].shape[0]-1])\n"
220216
else:
221217
s += f" x_{i} = min(max(point[{i}], grid[{i}][0]), grid[{i}][1])\n"

interpolation/multilinear/mlinterp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from numba import njit
2424
from typing import Tuple
2525

26-
from numba.types import UniTuple, Array, float64
27-
from numba.types import Float, Integer
26+
from ..compat import UniTuple, Tuple
27+
from numba.types import Float, Integer, Array
2828
Scalar = (Float, Integer)
2929

3030
import numpy as np

interpolation/splines/eval_splines.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,8 @@
66
from numba import prange
77
from .codegen import get_code_linear, get_code_cubic, source_to_function
88

9-
from distutils.version import LooseVersion
10-
from numba import __version__
11-
if LooseVersion(__version__)>='0.43':
12-
overload_options = {'strict': False}
13-
else:
14-
overload_options = {}
9+
from ..compat import Tuple, UniTuple
10+
from ..compat import overload_options
1511

1612
#
1713

@@ -82,6 +78,7 @@
8278
import numpy as np
8379
from numba import njit
8480
from numba.extending import overload
81+
from numba.types import Array
8582

8683
def _eval_linear():
8784
pass
@@ -97,7 +94,7 @@ def __eval_linear(grid,C,points):
9794
vec_eval = (points.ndim==2)
9895
from math import floor
9996
from numpy import zeros
100-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
97+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
10198
context = {'floor': floor, 'zeros': zeros, 'np': np} #, 'Cd': Ad, 'dCd': dAd}
10299
code = get_code_linear(d, vector_valued=vector_valued, vectorized=vec_eval, allocate=True, grid_types=grid_types)
103100
# print(code)
@@ -113,7 +110,7 @@ def __eval_linear(grid,C,points,extrap_mode):
113110
vec_eval = (points.ndim==2)
114111
from math import floor
115112
from numpy import zeros
116-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
113+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
117114
context = {'floor': floor, 'zeros': zeros, 'np': np} #, 'Cd': Ad, 'dCd': dAd}
118115
# print(f"We are going to extrapolate in {extrap_mode} mode.")
119116
if extrap_mode == t_NEAREST:
@@ -138,7 +135,7 @@ def __eval_linear(grid,C,points,out,extrap_mode):
138135
n_x = len(grid.types)
139136
vector_valued = (C.ndim==d+1)
140137
vec_eval = (points.ndim==2)
141-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
138+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
142139
context = {'floor': floor, 'zeros': zeros, 'np': np} #, 'Cd': Ad, 'dCd': dAd}
143140
if extrap_mode == t_NEAREST:
144141
extrap_ = 'nearest'
@@ -162,7 +159,7 @@ def __eval_linear(grid,C,points,out):
162159
n_x = len(grid.types)
163160
vector_valued = (C.ndim==d+1)
164161
vec_eval = (points.ndim==2)
165-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
162+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
166163
context = {'floor': floor, 'zeros': zeros, 'np': np} #, 'Cd': Ad, 'dCd': dAd}
167164
code = get_code_linear(d, vector_valued=vector_valued, vectorized=vec_eval, allocate=False, grid_types=grid_types)
168165
# print(code)
@@ -193,7 +190,7 @@ def __eval_cubic(grid,C,points):
193190
vec_eval = (points.ndim==2)
194191
from math import floor
195192
from numpy import zeros
196-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
193+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
197194
context = {'floor': floor, 'zeros': zeros, 'Cd': Ad, 'dCd': dAd}
198195
code = get_code_cubic(d, vector_valued=vector_valued, vectorized=vec_eval, allocate=True, grid_types=grid_types)
199196
# print(code)
@@ -209,7 +206,7 @@ def __eval_cubic(grid,C,points,extrap_mode):
209206
vec_eval = (points.ndim==2)
210207
from math import floor
211208
from numpy import zeros
212-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
209+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
213210
context = {'floor': floor, 'zeros': zeros, 'Cd': Ad, 'dCd': dAd}
214211

215212
# print(f"We are going to extrapolate in {extrap_mode} mode.")
@@ -236,7 +233,7 @@ def __eval_cubic(grid,C,points,out,extrap_mode):
236233
n_x = len(grid.types)
237234
vector_valued = (C.ndim==d+1)
238235
vec_eval = (points.ndim==2)
239-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
236+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
240237
context = {'floor': floor, 'zeros': zeros, 'Cd': Ad, 'dCd': dAd}
241238
if extrap_mode == t_NEAREST:
242239
extrap_ = 'nearest'
@@ -260,7 +257,7 @@ def __eval_cubic(grid,C,points,out):
260257
n_x = len(grid.types)
261258
vector_valued = (C.ndim==d+1)
262259
vec_eval = (points.ndim==2)
263-
grid_types = ['nonuniform' if isinstance(tt, numba.types.Array) else 'uniform' for tt in grid.types]
260+
grid_types = ['nonuniform' if isinstance(tt, Array) else 'uniform' for tt in grid.types]
264261
context = {'floor': floor, 'zeros': zeros, 'Cd': Ad, 'dCd': dAd}
265262
code = get_code_cubic(d, vector_valued=vector_valued, vectorized=vec_eval, allocate=False, grid_types=grid_types)
266263
# print(code)

0 commit comments

Comments
 (0)