-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnumba_opt.py
executable file
·81 lines (65 loc) · 2.17 KB
/
numba_opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import functools
__author__ = 'Alex Pyattaev'
import os
from copy import deepcopy
from dataclasses import dataclass, fields, replace
import numpy as np
@dataclass
class deepcopy_compat:
    """Mixin that makes it possible to deepcopy dataclasses with numba objects in them.

    Every dataclass field whose value has a ``_numba_type_`` attribute (i.e. looks like
    a numba jitclass instance) is copied by calling its ``copy()`` method; every other
    field is copied with ``copy.deepcopy``.

    Keep in mind that putting numba instances into containers will still crash on
    deepcopy: only direct field values get the special handling.
    Numba jitclass needs to specify the 'copy' method for this to work.
    """

    def __deepcopy__(self, memodict=None):
        """Return a deep copy of this dataclass instance.

        :param memodict: memo dict supplied by ``copy.deepcopy``; a fresh one is
            created when this method is called directly.
        :return: a new instance of the same dataclass with every field copied.
        :raises ValueError: if a field holds a numba jitclass instance that has no
            ``copy`` method.
        """
        # A fresh memo per direct call: a shared mutable default would leak memoized
        # copies (keyed by id()) between unrelated calls.
        if memodict is None:
            memodict = {}
        init_items = {}
        non_init_items = {}
        for f in fields(self):
            k = f.name
            t = f.type
            v = getattr(self, k)
            if hasattr(v, "_numba_type_"):  # Special handler for numba jitclasses
                if not hasattr(v, "copy"):
                    raise ValueError(f"Can not deepcopy {k} of type <{t}>: numba classes need 'copy' method defined to be copyable")
                copied = v.copy()
            else:
                copied = deepcopy(v, memodict)
            # dataclasses.replace() rejects init=False fields, so those are set afterwards.
            if f.init:
                init_items[k] = copied
            else:
                non_init_items[k] = copied
        result = replace(self, **init_items)
        for k, v in non_init_items.items():
            # object.__setattr__ bypasses frozen=True, mirroring what __init__ does.
            object.__setattr__(result, k, v)
        return result
try:
    # Setting NO_NUMBA (to any value) forces the pure-Python fallbacks below.
    if 'NO_NUMBA' in os.environ:
        print("Numba disabled by environment variable")
        raise ImportError("Numba is disabled")
    import numba
    import numba.experimental
    numba_available = True
    # numba.jit with nopython/nogil/cache preset.
    jit_hardcore = functools.partial(numba.jit, nopython=True, nogil=True, cache=True)
    # Called with options only, so these are decorators meant to be used bare: @njit
    njit = numba.njit(nogil=True, cache=True)
    njit_nocache = numba.njit(nogil=True, cache=False)
    jitclass = numba.experimental.jitclass
    int64 = numba.int64
    int32 = numba.int32
    int16 = numba.int16
    double = numba.double
    complex128 = numba.complex128
    from numba.typed import List as TypedList
    vectorize = numba.vectorize
except ImportError:
    # Pure-Python fallbacks: same public names, no compilation.
    TypedList = list
    numba = None
    numba_available = False
    int64 = int
    int32 = int
    int16 = int
    double = float
    complex128 = complex

    # define stub functions for Numba placeholders
    def njit(f, *args, **kwargs):
        """Return *f* unchanged (stand-in for the ``njit`` decorator)."""
        return f

    def njit_nocache(f, *args, **kwargs):
        """Return *f* unchanged (stand-in for the ``njit_nocache`` decorator)."""
        return f

    def jitclass(c=None, *args, **kwargs):
        """Stand-in for ``numba.experimental.jitclass`` that leaves classes unchanged.

        Supports ``@jitclass``, ``@jitclass(spec)`` and ``jitclass(cls, spec)``:
        a class argument is returned as-is, anything else (a spec or nothing)
        yields an identity decorator.
        """
        if isinstance(c, type):
            return c

        def x(cls):
            return cls
        return x

    def jit_hardcore(f, *args, **kwargs):
        """Return *f* unchanged (stand-in for the ``jit_hardcore`` decorator)."""
        return f

    def vectorize(ftylist_or_function=(), *args, **kwargs):
        """Stand-in for ``numba.vectorize`` built on ``np.vectorize``.

        ``vectorize(func, ...)`` forwards to ``np.vectorize`` unchanged; the
        decorator-factory form ``vectorize([signatures], **options)`` ignores the
        numba signatures/options and wraps the decorated function with
        ``np.vectorize``.
        """
        if callable(ftylist_or_function):
            return np.vectorize(ftylist_or_function, *args, **kwargs)

        def decorator(f):
            return np.vectorize(f)
        return decorator