Skip to content

Commit d44d025

Browse files
committed
re-add sorting
1 parent a876ce9 commit d44d025

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,18 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple
140140
return _result_type(x, y)
141141

142142
else:
143-
if _builtin_all(isinstance(x, _py_scalars) for x in arrays_and_dtypes):
143+
# sort scalars so that they are treated last
144+
scalars, others = [], []
145+
for x in arrays_and_dtypes:
146+
if isinstance(x, _py_scalars):
147+
scalars.append(x)
148+
else:
149+
others.append(x)
150+
if not others:
144151
raise ValueError("At least one array or dtype must be provided")
145152

146-
return _reduce(_result_type, arrays_and_dtypes)
153+
# combine left-to-right
154+
return _reduce(_result_type, others + scalars)
147155

148156

149157
def _result_type(x, y):

tests/test_torch.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Test "unspecified" behavior which we cannot easily test in the Array API test suite.
22
"""
3+
import itertools
4+
35
import pytest
46
import torch
57

@@ -51,7 +53,10 @@ def test_two_args(self):
5153
def test_multi_arg(self):
5254
torch.set_default_dtype(torch.float32)
5355

54-
args = [1, 2, 3j, xp.arange(3), 4, 5, 6]
56+
args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.]
57+
assert xp.result_type(*args) == torch.float16
58+
59+
args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6]
5560
assert xp.result_type(*args) == xp.complex64
5661

5762
args = [1, 2, 3j, xp.float64, 4, 5, 6]
@@ -60,5 +65,10 @@ def test_multi_arg(self):
6065
args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False]
6166
assert xp.result_type(*args) == xp.complex128
6267

68+
i64 = xp.ones(1, dtype=xp.int64)
69+
f16 = xp.ones(1, dtype=xp.float16)
70+
for i in itertools.permutations([i64, f16, 1.0, 1.0]):
71+
assert xp.result_type(*i) == xp.float16, f"{i}"
72+
6373
with pytest.raises(ValueError):
6474
xp.result_type(1, 2, 3, 4)

0 commit comments

Comments
 (0)