11import numba
22import numpy as np
33from numba import float64 , int64
4- from numba import generated_jit , njit
4+ from numba import njit
55import ast
66
77from numba .extending import overload
@@ -25,8 +25,12 @@ def clamp(x, a, b):
2525
2626
2727# returns the index of a 1d point along a 1d dimension
28- @generated_jit (nopython = True )
2928def get_index (gc , x ):
29+ pass
30+
31+
32+ @overload (get_index )
33+ def ol_get_index (gc , x ):
3034 if gc == t_coord :
3135 # regular coordinate
3236 def fun (gc , x ):
@@ -53,8 +57,12 @@ def fun(gc, x):
5357
5458
5559# returns number of dimension of a dimension
56- @generated_jit (nopython = True )
5760def get_size (gc ):
61+ pass
62+
63+
64+ @overload (get_size )
65+ def ol_get_size (gc ):
5866 if gc == t_coord :
5967 # regular coordinate
6068 def fun (gc ):
@@ -145,8 +153,12 @@ def _map(*args):
145153# funzip(((1,2), (2,3), (4,3))) -> ((1,2,4),(2,3,3))
146154
147155
148- @generated_jit (nopython = True )
149156def funzip (t ):
157+ pass
158+
159+
160+ @overload (funzip )
161+ def ol_funzip (t ):
150162 k = t .count
151163 assert len (set ([e .count for e in t .types ])) == 1
152164 l = t .types [0 ].count
@@ -169,8 +181,12 @@ def print_tuple(t):
169181#####
170182
171183
172- @generated_jit (nopython = True )
173184def get_coeffs (X , I ):
185+ pass
186+
187+
188+ @overload (get_coeffs )
189+ def ol_get_coeffs (X , I ):
174190 if X .ndim > len (I ):
175191 print ("not implemented yet" )
176192 else :
@@ -218,8 +234,12 @@ def gen_tensor_reduction(X, symbs, inds=[]):
218234 return str .join (" + " , exprs )
219235
220236
221- @generated_jit (nopython = True )
222237def tensor_reduction (C , l ):
238+ pass
239+
240+
241+ @overload (tensor_reduction )
242+ def ol_tensor_reduction (C , l ):
223243 d = len (l .types )
224244 ex = gen_tensor_reduction ("C" , ["l[{}]" .format (i ) for i in range (d )])
225245 dd = dict ()
@@ -228,8 +248,12 @@ def tensor_reduction(C, l):
228248 return dd ["tensor_reduction" ]
229249
230250
231- @generated_jit (nopython = True )
232251def extract_row (a , n , tup ):
252+ pass
253+
254+
255+ @overload (extract_row )
256+ def ol_extract_row (a , n , tup ):
233257 d = len (tup .types )
234258 dd = {}
235259 s = "def extract_row(a, n, tup): return ({},)" .format (
@@ -240,8 +264,12 @@ def extract_row(a, n, tup):
240264
241265
242266# find closest point inside the grid domain
243- @generated_jit
244267def project (grid , point ):
268+ pass
269+
270+
271+ @overload (project )
272+ def ol_project (grid , point ):
245273 s = "def __project(grid, point):\n "
246274 d = len (grid .types )
247275 for i in range (d ):
0 commit comments