Skip to content

Commit 7effb68

Browse files
committed
verbose Python scalar types
1 parent faca616 commit 7effb68

8 files changed

+69
-56
lines changed

array_api_strict/_array_object.py

+39-39
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def __array__(
199199
# NumPy behavior
200200

201201
def _check_allowed_dtypes(
202-
self, other: Array | complex, dtype_category: str, op: str
202+
self, other: Array | bool | int | float | complex, dtype_category: str, op: str
203203
) -> Array:
204204
"""
205205
Helper function for operators to only allow specific input dtypes
@@ -241,7 +241,7 @@ def _check_allowed_dtypes(
241241

242242
return other
243243

244-
def _check_device(self, other: Array | complex) -> None:
244+
def _check_device(self, other: Array | bool | int | float | complex) -> None:
245245
"""Check that other is on a device compatible with the current array"""
246246
if isinstance(other, (bool, int, float, complex)):
247247
return
@@ -252,7 +252,7 @@ def _check_device(self, other: Array | complex) -> None:
252252
raise TypeError(f"Expected Array | python scalar; got {type(other)}")
253253

254254
# Helper function to match the type promotion rules in the spec
255-
def _promote_scalar(self, scalar: complex) -> Array:
255+
def _promote_scalar(self, scalar: bool | int | float | complex) -> Array:
256256
"""
257257
Returns a promoted version of a Python scalar appropriate for use with
258258
operations on self.
@@ -546,7 +546,7 @@ def __abs__(self) -> Array:
546546
res = self._array.__abs__()
547547
return self.__class__._new(res, device=self.device)
548548

549-
def __add__(self, other: Array | complex, /) -> Array:
549+
def __add__(self, other: Array | int | float | complex, /) -> Array:
550550
"""
551551
Performs the operation __add__.
552552
"""
@@ -558,7 +558,7 @@ def __add__(self, other: Array | complex, /) -> Array:
558558
res = self._array.__add__(other._array)
559559
return self.__class__._new(res, device=self.device)
560560

561-
def __and__(self, other: Array | int, /) -> Array:
561+
def __and__(self, other: Array | bool | int, /) -> Array:
562562
"""
563563
Performs the operation __and__.
564564
"""
@@ -655,7 +655,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
655655
# Note: device support is required for this
656656
return self._array.__dlpack_device__()
657657

658-
def __eq__(self, other: Array | complex, /) -> Array: # type: ignore[override]
658+
def __eq__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
659659
"""
660660
Performs the operation __eq__.
661661
"""
@@ -681,7 +681,7 @@ def __float__(self) -> float:
681681
res = self._array.__float__()
682682
return res
683683

684-
def __floordiv__(self, other: Array | float, /) -> Array:
684+
def __floordiv__(self, other: Array | int | float, /) -> Array:
685685
"""
686686
Performs the operation __floordiv__.
687687
"""
@@ -693,7 +693,7 @@ def __floordiv__(self, other: Array | float, /) -> Array:
693693
res = self._array.__floordiv__(other._array)
694694
return self.__class__._new(res, device=self.device)
695695

696-
def __ge__(self, other: Array | float, /) -> Array:
696+
def __ge__(self, other: Array | int | float, /) -> Array:
697697
"""
698698
Performs the operation __ge__.
699699
"""
@@ -728,7 +728,7 @@ def __getitem__(
728728
res = self._array.__getitem__(np_key)
729729
return self._new(res, device=self.device)
730730

731-
def __gt__(self, other: Array | float, /) -> Array:
731+
def __gt__(self, other: Array | int | float, /) -> Array:
732732
"""
733733
Performs the operation __gt__.
734734
"""
@@ -783,7 +783,7 @@ def __iter__(self) -> Iterator[Array]:
783783
# implemented, which implies iteration on 1-D arrays.
784784
return (Array._new(i, device=self.device) for i in self._array)
785785

786-
def __le__(self, other: Array | float, /) -> Array:
786+
def __le__(self, other: Array | int | float, /) -> Array:
787787
"""
788788
Performs the operation __le__.
789789
"""
@@ -807,7 +807,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
807807
res = self._array.__lshift__(other._array)
808808
return self.__class__._new(res, device=self.device)
809809

810-
def __lt__(self, other: Array | float, /) -> Array:
810+
def __lt__(self, other: Array | int | float, /) -> Array:
811811
"""
812812
Performs the operation __lt__.
813813
"""
@@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array:
832832
res = self._array.__matmul__(other._array)
833833
return self.__class__._new(res, device=self.device)
834834

835-
def __mod__(self, other: Array | float, /) -> Array:
835+
def __mod__(self, other: Array | int | float, /) -> Array:
836836
"""
837837
Performs the operation __mod__.
838838
"""
@@ -844,7 +844,7 @@ def __mod__(self, other: Array | float, /) -> Array:
844844
res = self._array.__mod__(other._array)
845845
return self.__class__._new(res, device=self.device)
846846

847-
def __mul__(self, other: Array | complex, /) -> Array:
847+
def __mul__(self, other: Array | int | float | complex, /) -> Array:
848848
"""
849849
Performs the operation __mul__.
850850
"""
@@ -856,7 +856,7 @@ def __mul__(self, other: Array | complex, /) -> Array:
856856
res = self._array.__mul__(other._array)
857857
return self.__class__._new(res, device=self.device)
858858

859-
def __ne__(self, other: Array | complex, /) -> Array: # type: ignore[override]
859+
def __ne__(self, other: Array | bool | int | float | complex, /) -> Array: # type: ignore[override]
860860
"""
861861
Performs the operation __ne__.
862862
"""
@@ -877,7 +877,7 @@ def __neg__(self) -> Array:
877877
res = self._array.__neg__()
878878
return self.__class__._new(res, device=self.device)
879879

880-
def __or__(self, other: Array | int, /) -> Array:
880+
def __or__(self, other: Array | bool | int, /) -> Array:
881881
"""
882882
Performs the operation __or__.
883883
"""
@@ -898,7 +898,7 @@ def __pos__(self) -> Array:
898898
res = self._array.__pos__()
899899
return self.__class__._new(res, device=self.device)
900900

901-
def __pow__(self, other: Array | complex, /) -> Array:
901+
def __pow__(self, other: Array | int | float | complex, /) -> Array:
902902
"""
903903
Performs the operation __pow__.
904904
"""
@@ -942,7 +942,7 @@ def __setitem__(
942942
np_key = key._array if isinstance(key, Array) else key
943943
self._array.__setitem__(np_key, asarray(value)._array)
944944

945-
def __sub__(self, other: Array | complex, /) -> Array:
945+
def __sub__(self, other: Array | int | float | complex, /) -> Array:
946946
"""
947947
Performs the operation __sub__.
948948
"""
@@ -956,7 +956,7 @@ def __sub__(self, other: Array | complex, /) -> Array:
956956

957957
# PEP 484 requires int to be a subtype of float, but __truediv__ should
958958
# not accept int.
959-
def __truediv__(self, other: Array | complex, /) -> Array:
959+
def __truediv__(self, other: Array | int | float | complex, /) -> Array:
960960
"""
961961
Performs the operation __truediv__.
962962
"""
@@ -968,7 +968,7 @@ def __truediv__(self, other: Array | complex, /) -> Array:
968968
res = self._array.__truediv__(other._array)
969969
return self.__class__._new(res, device=self.device)
970970

971-
def __xor__(self, other: Array | int, /) -> Array:
971+
def __xor__(self, other: Array | bool | int, /) -> Array:
972972
"""
973973
Performs the operation __xor__.
974974
"""
@@ -980,7 +980,7 @@ def __xor__(self, other: Array | int, /) -> Array:
980980
res = self._array.__xor__(other._array)
981981
return self.__class__._new(res, device=self.device)
982982

983-
def __iadd__(self, other: Array | complex, /) -> Array:
983+
def __iadd__(self, other: Array | int | float | complex, /) -> Array:
984984
"""
985985
Performs the operation __iadd__.
986986
"""
@@ -991,7 +991,7 @@ def __iadd__(self, other: Array | complex, /) -> Array:
991991
self._array.__iadd__(other._array)
992992
return self
993993

994-
def __radd__(self, other: Array | complex, /) -> Array:
994+
def __radd__(self, other: Array | int | float | complex, /) -> Array:
995995
"""
996996
Performs the operation __radd__.
997997
"""
@@ -1003,7 +1003,7 @@ def __radd__(self, other: Array | complex, /) -> Array:
10031003
res = self._array.__radd__(other._array)
10041004
return self.__class__._new(res, device=self.device)
10051005

1006-
def __iand__(self, other: Array | int, /) -> Array:
1006+
def __iand__(self, other: Array | bool | int, /) -> Array:
10071007
"""
10081008
Performs the operation __iand__.
10091009
"""
@@ -1014,7 +1014,7 @@ def __iand__(self, other: Array | int, /) -> Array:
10141014
self._array.__iand__(other._array)
10151015
return self
10161016

1017-
def __rand__(self, other: Array | int, /) -> Array:
1017+
def __rand__(self, other: Array | bool | int, /) -> Array:
10181018
"""
10191019
Performs the operation __rand__.
10201020
"""
@@ -1026,7 +1026,7 @@ def __rand__(self, other: Array | int, /) -> Array:
10261026
res = self._array.__rand__(other._array)
10271027
return self.__class__._new(res, device=self.device)
10281028

1029-
def __ifloordiv__(self, other: Array | float, /) -> Array:
1029+
def __ifloordiv__(self, other: Array | int | float, /) -> Array:
10301030
"""
10311031
Performs the operation __ifloordiv__.
10321032
"""
@@ -1037,7 +1037,7 @@ def __ifloordiv__(self, other: Array | float, /) -> Array:
10371037
self._array.__ifloordiv__(other._array)
10381038
return self
10391039

1040-
def __rfloordiv__(self, other: Array | float, /) -> Array:
1040+
def __rfloordiv__(self, other: Array | int | float, /) -> Array:
10411041
"""
10421042
Performs the operation __rfloordiv__.
10431043
"""
@@ -1098,7 +1098,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
10981098
res = self._array.__rmatmul__(other._array)
10991099
return self.__class__._new(res, device=self.device)
11001100

1101-
def __imod__(self, other: Array | float, /) -> Array:
1101+
def __imod__(self, other: Array | int | float, /) -> Array:
11021102
"""
11031103
Performs the operation __imod__.
11041104
"""
@@ -1108,7 +1108,7 @@ def __imod__(self, other: Array | float, /) -> Array:
11081108
self._array.__imod__(other._array)
11091109
return self
11101110

1111-
def __rmod__(self, other: Array | float, /) -> Array:
1111+
def __rmod__(self, other: Array | int | float, /) -> Array:
11121112
"""
11131113
Performs the operation __rmod__.
11141114
"""
@@ -1120,7 +1120,7 @@ def __rmod__(self, other: Array | float, /) -> Array:
11201120
res = self._array.__rmod__(other._array)
11211121
return self.__class__._new(res, device=self.device)
11221122

1123-
def __imul__(self, other: Array | complex, /) -> Array:
1123+
def __imul__(self, other: Array | int | float | complex, /) -> Array:
11241124
"""
11251125
Performs the operation __imul__.
11261126
"""
@@ -1130,7 +1130,7 @@ def __imul__(self, other: Array | complex, /) -> Array:
11301130
self._array.__imul__(other._array)
11311131
return self
11321132

1133-
def __rmul__(self, other: Array | complex, /) -> Array:
1133+
def __rmul__(self, other: Array | int | float | complex, /) -> Array:
11341134
"""
11351135
Performs the operation __rmul__.
11361136
"""
@@ -1142,7 +1142,7 @@ def __rmul__(self, other: Array | complex, /) -> Array:
11421142
res = self._array.__rmul__(other._array)
11431143
return self.__class__._new(res, device=self.device)
11441144

1145-
def __ior__(self, other: Array | int, /) -> Array:
1145+
def __ior__(self, other: Array | bool | int, /) -> Array:
11461146
"""
11471147
Performs the operation __ior__.
11481148
"""
@@ -1152,7 +1152,7 @@ def __ior__(self, other: Array | int, /) -> Array:
11521152
self._array.__ior__(other._array)
11531153
return self
11541154

1155-
def __ror__(self, other: Array | int, /) -> Array:
1155+
def __ror__(self, other: Array | bool | int, /) -> Array:
11561156
"""
11571157
Performs the operation __ror__.
11581158
"""
@@ -1164,7 +1164,7 @@ def __ror__(self, other: Array | int, /) -> Array:
11641164
res = self._array.__ror__(other._array)
11651165
return self.__class__._new(res, device=self.device)
11661166

1167-
def __ipow__(self, other: Array | complex, /) -> Array:
1167+
def __ipow__(self, other: Array | int | float | complex, /) -> Array:
11681168
"""
11691169
Performs the operation __ipow__.
11701170
"""
@@ -1174,7 +1174,7 @@ def __ipow__(self, other: Array | complex, /) -> Array:
11741174
self._array.__ipow__(other._array)
11751175
return self
11761176

1177-
def __rpow__(self, other: Array | complex, /) -> Array:
1177+
def __rpow__(self, other: Array | int | float | complex, /) -> Array:
11781178
"""
11791179
Performs the operation __rpow__.
11801180
"""
@@ -1209,7 +1209,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
12091209
res = self._array.__rrshift__(other._array)
12101210
return self.__class__._new(res, device=self.device)
12111211

1212-
def __isub__(self, other: Array | complex, /) -> Array:
1212+
def __isub__(self, other: Array | int | float | complex, /) -> Array:
12131213
"""
12141214
Performs the operation __isub__.
12151215
"""
@@ -1219,7 +1219,7 @@ def __isub__(self, other: Array | complex, /) -> Array:
12191219
self._array.__isub__(other._array)
12201220
return self
12211221

1222-
def __rsub__(self, other: Array | complex, /) -> Array:
1222+
def __rsub__(self, other: Array | int | float | complex, /) -> Array:
12231223
"""
12241224
Performs the operation __rsub__.
12251225
"""
@@ -1231,7 +1231,7 @@ def __rsub__(self, other: Array | complex, /) -> Array:
12311231
res = self._array.__rsub__(other._array)
12321232
return self.__class__._new(res, device=self.device)
12331233

1234-
def __itruediv__(self, other: Array | complex, /) -> Array:
1234+
def __itruediv__(self, other: Array | int | float | complex, /) -> Array:
12351235
"""
12361236
Performs the operation __itruediv__.
12371237
"""
@@ -1241,7 +1241,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array:
12411241
self._array.__itruediv__(other._array)
12421242
return self
12431243

1244-
def __rtruediv__(self, other: Array | complex, /) -> Array:
1244+
def __rtruediv__(self, other: Array | int | float | complex, /) -> Array:
12451245
"""
12461246
Performs the operation __rtruediv__.
12471247
"""
@@ -1253,7 +1253,7 @@ def __rtruediv__(self, other: Array | complex, /) -> Array:
12531253
res = self._array.__rtruediv__(other._array)
12541254
return self.__class__._new(res, device=self.device)
12551255

1256-
def __ixor__(self, other: Array | int, /) -> Array:
1256+
def __ixor__(self, other: Array | bool | int, /) -> Array:
12571257
"""
12581258
Performs the operation __ixor__.
12591259
"""
@@ -1263,7 +1263,7 @@ def __ixor__(self, other: Array | int, /) -> Array:
12631263
self._array.__ixor__(other._array)
12641264
return self
12651265

1266-
def __rxor__(self, other: Array | int, /) -> Array:
1266+
def __rxor__(self, other: Array | bool | int, /) -> Array:
12671267
"""
12681268
Performs the operation __rxor__.
12691269
"""

0 commit comments

Comments
 (0)