Skip to content

Commit 1b87e1a

Browse files
speed up np_coerce
1 parent 2606642 commit 1b87e1a

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

src/anguilla/interpolate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __call__(self, targets:List[Output], scores:List[float], eps:float=1e-6):
122122
return Nearest()(targets, scores)
123123

124124
targets, scores = np_coerce(targets, scores)
125-
125+
126126
scores = scores**0.5
127127

128128
scores = scores + eps

src/anguilla/types.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ class SearchResult(NamedTuple):
2828
scores:Scores
2929

3030
def _np_coerce(x):
31-
if x is None:
32-
return None
31+
if x is None or isinstance(x, np.ndarray):
32+
# no conversion
33+
return x
3334
if hasattr(x, 'numpy'):
3435
# torch tensor, etc
3536
return x.numpy()

0 commit comments

Comments
 (0)