Skip to content

Commit 3efa4b6

Browse files
authored
Merge pull request #260 from crusaderky/torch_import
MAINT: torch: tweak imports
2 parents 6cf679c + d1d4216 commit 3efa4b6

File tree

1 file changed

+7
-13
lines changed

1 file changed

+7
-13
lines changed

array_api_compat/torch/_aliases.py

+7-13
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@
33
from functools import wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
55

6-
from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
7-
vecdot as _aliases_vecdot,
8-
clip as _aliases_clip,
9-
unstack as _aliases_unstack,
10-
cumulative_sum as _aliases_cumulative_sum,
11-
cumulative_prod as _aliases_cumulative_prod,
12-
)
6+
from ..common import _aliases
137
from .._internal import get_xp
148

159
from ._info import __array_namespace_info__
@@ -215,10 +209,10 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
215209
return torch.clone(x)
216210
return torch.amin(x, axis, keepdims=keepdims)
217211

218-
clip = get_xp(torch)(_aliases_clip)
219-
unstack = get_xp(torch)(_aliases_unstack)
220-
cumulative_sum = get_xp(torch)(_aliases_cumulative_sum)
221-
cumulative_prod = get_xp(torch)(_aliases_cumulative_prod)
212+
clip = get_xp(torch)(_aliases.clip)
213+
unstack = get_xp(torch)(_aliases.unstack)
214+
cumulative_sum = get_xp(torch)(_aliases.cumulative_sum)
215+
cumulative_prod = get_xp(torch)(_aliases.cumulative_prod)
222216

223217
# torch.sort also returns a tuple
224218
# https://github.com/pytorch/pytorch/issues/70921
@@ -710,8 +704,8 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array:
710704
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
711705
return torch.matmul(x1, x2, **kwargs)
712706

713-
matrix_transpose = get_xp(torch)(_aliases_matrix_transpose)
714-
_vecdot = get_xp(torch)(_aliases_vecdot)
707+
matrix_transpose = get_xp(torch)(_aliases.matrix_transpose)
708+
_vecdot = get_xp(torch)(_aliases.vecdot)
715709

716710
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
717711
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

0 commit comments

Comments
 (0)