@@ -199,7 +199,7 @@ def __array__(
199
199
# NumPy behavior
200
200
201
201
def _check_allowed_dtypes (
202
- self , other : Array | complex , dtype_category : str , op : str
202
+ self , other : Array | bool | int | float | complex , dtype_category : str , op : str
203
203
) -> Array :
204
204
"""
205
205
Helper function for operators to only allow specific input dtypes
@@ -241,7 +241,7 @@ def _check_allowed_dtypes(
241
241
242
242
return other
243
243
244
- def _check_device (self , other : Array | complex ) -> None :
244
+ def _check_device (self , other : Array | bool | int | float | complex ) -> None :
245
245
"""Check that other is on a device compatible with the current array"""
246
246
if isinstance (other , (bool , int , float , complex )):
247
247
return
@@ -252,7 +252,7 @@ def _check_device(self, other: Array | complex) -> None:
252
252
raise TypeError (f"Expected Array | python scalar; got { type (other )} " )
253
253
254
254
# Helper function to match the type promotion rules in the spec
255
- def _promote_scalar (self , scalar : complex ) -> Array :
255
+ def _promote_scalar (self , scalar : bool | int | float | complex ) -> Array :
256
256
"""
257
257
Returns a promoted version of a Python scalar appropriate for use with
258
258
operations on self.
@@ -546,7 +546,7 @@ def __abs__(self) -> Array:
546
546
res = self ._array .__abs__ ()
547
547
return self .__class__ ._new (res , device = self .device )
548
548
549
- def __add__ (self , other : Array | complex , / ) -> Array :
549
+ def __add__ (self , other : Array | int | float | complex , / ) -> Array :
550
550
"""
551
551
Performs the operation __add__.
552
552
"""
@@ -558,7 +558,7 @@ def __add__(self, other: Array | complex, /) -> Array:
558
558
res = self ._array .__add__ (other ._array )
559
559
return self .__class__ ._new (res , device = self .device )
560
560
561
- def __and__ (self , other : Array | int , / ) -> Array :
561
+ def __and__ (self , other : Array | bool | int , / ) -> Array :
562
562
"""
563
563
Performs the operation __and__.
564
564
"""
@@ -655,7 +655,7 @@ def __dlpack_device__(self) -> tuple[IntEnum, int]:
655
655
# Note: device support is required for this
656
656
return self ._array .__dlpack_device__ ()
657
657
658
- def __eq__ (self , other : Array | complex , / ) -> Array : # type: ignore[override]
658
+ def __eq__ (self , other : Array | bool | int | float | complex , / ) -> Array : # type: ignore[override]
659
659
"""
660
660
Performs the operation __eq__.
661
661
"""
@@ -681,7 +681,7 @@ def __float__(self) -> float:
681
681
res = self ._array .__float__ ()
682
682
return res
683
683
684
- def __floordiv__ (self , other : Array | float , / ) -> Array :
684
+ def __floordiv__ (self , other : Array | int | float , / ) -> Array :
685
685
"""
686
686
Performs the operation __floordiv__.
687
687
"""
@@ -693,7 +693,7 @@ def __floordiv__(self, other: Array | float, /) -> Array:
693
693
res = self ._array .__floordiv__ (other ._array )
694
694
return self .__class__ ._new (res , device = self .device )
695
695
696
- def __ge__ (self , other : Array | float , / ) -> Array :
696
+ def __ge__ (self , other : Array | int | float , / ) -> Array :
697
697
"""
698
698
Performs the operation __ge__.
699
699
"""
@@ -728,7 +728,7 @@ def __getitem__(
728
728
res = self ._array .__getitem__ (np_key )
729
729
return self ._new (res , device = self .device )
730
730
731
- def __gt__ (self , other : Array | float , / ) -> Array :
731
+ def __gt__ (self , other : Array | int | float , / ) -> Array :
732
732
"""
733
733
Performs the operation __gt__.
734
734
"""
@@ -783,7 +783,7 @@ def __iter__(self) -> Iterator[Array]:
783
783
# implemented, which implies iteration on 1-D arrays.
784
784
return (Array ._new (i , device = self .device ) for i in self ._array )
785
785
786
- def __le__ (self , other : Array | float , / ) -> Array :
786
+ def __le__ (self , other : Array | int | float , / ) -> Array :
787
787
"""
788
788
Performs the operation __le__.
789
789
"""
@@ -807,7 +807,7 @@ def __lshift__(self, other: Array | int, /) -> Array:
807
807
res = self ._array .__lshift__ (other ._array )
808
808
return self .__class__ ._new (res , device = self .device )
809
809
810
- def __lt__ (self , other : Array | float , / ) -> Array :
810
+ def __lt__ (self , other : Array | int | float , / ) -> Array :
811
811
"""
812
812
Performs the operation __lt__.
813
813
"""
@@ -832,7 +832,7 @@ def __matmul__(self, other: Array, /) -> Array:
832
832
res = self ._array .__matmul__ (other ._array )
833
833
return self .__class__ ._new (res , device = self .device )
834
834
835
- def __mod__ (self , other : Array | float , / ) -> Array :
835
+ def __mod__ (self , other : Array | int | float , / ) -> Array :
836
836
"""
837
837
Performs the operation __mod__.
838
838
"""
@@ -844,7 +844,7 @@ def __mod__(self, other: Array | float, /) -> Array:
844
844
res = self ._array .__mod__ (other ._array )
845
845
return self .__class__ ._new (res , device = self .device )
846
846
847
- def __mul__ (self , other : Array | complex , / ) -> Array :
847
+ def __mul__ (self , other : Array | int | float | complex , / ) -> Array :
848
848
"""
849
849
Performs the operation __mul__.
850
850
"""
@@ -856,7 +856,7 @@ def __mul__(self, other: Array | complex, /) -> Array:
856
856
res = self ._array .__mul__ (other ._array )
857
857
return self .__class__ ._new (res , device = self .device )
858
858
859
- def __ne__ (self , other : Array | complex , / ) -> Array : # type: ignore[override]
859
+ def __ne__ (self , other : Array | bool | int | float | complex , / ) -> Array : # type: ignore[override]
860
860
"""
861
861
Performs the operation __ne__.
862
862
"""
@@ -877,7 +877,7 @@ def __neg__(self) -> Array:
877
877
res = self ._array .__neg__ ()
878
878
return self .__class__ ._new (res , device = self .device )
879
879
880
- def __or__ (self , other : Array | int , / ) -> Array :
880
+ def __or__ (self , other : Array | bool | int , / ) -> Array :
881
881
"""
882
882
Performs the operation __or__.
883
883
"""
@@ -898,7 +898,7 @@ def __pos__(self) -> Array:
898
898
res = self ._array .__pos__ ()
899
899
return self .__class__ ._new (res , device = self .device )
900
900
901
- def __pow__ (self , other : Array | complex , / ) -> Array :
901
+ def __pow__ (self , other : Array | int | float | complex , / ) -> Array :
902
902
"""
903
903
Performs the operation __pow__.
904
904
"""
@@ -942,7 +942,7 @@ def __setitem__(
942
942
np_key = key ._array if isinstance (key , Array ) else key
943
943
self ._array .__setitem__ (np_key , asarray (value )._array )
944
944
945
- def __sub__ (self , other : Array | complex , / ) -> Array :
945
+ def __sub__ (self , other : Array | int | float | complex , / ) -> Array :
946
946
"""
947
947
Performs the operation __sub__.
948
948
"""
@@ -956,7 +956,7 @@ def __sub__(self, other: Array | complex, /) -> Array:
956
956
957
957
# PEP 484 requires int to be a subtype of float, but __truediv__ should
958
958
# not accept int.
959
- def __truediv__ (self , other : Array | complex , / ) -> Array :
959
+ def __truediv__ (self , other : Array | int | float | complex , / ) -> Array :
960
960
"""
961
961
Performs the operation __truediv__.
962
962
"""
@@ -968,7 +968,7 @@ def __truediv__(self, other: Array | complex, /) -> Array:
968
968
res = self ._array .__truediv__ (other ._array )
969
969
return self .__class__ ._new (res , device = self .device )
970
970
971
- def __xor__ (self , other : Array | int , / ) -> Array :
971
+ def __xor__ (self , other : Array | bool | int , / ) -> Array :
972
972
"""
973
973
Performs the operation __xor__.
974
974
"""
@@ -980,7 +980,7 @@ def __xor__(self, other: Array | int, /) -> Array:
980
980
res = self ._array .__xor__ (other ._array )
981
981
return self .__class__ ._new (res , device = self .device )
982
982
983
- def __iadd__ (self , other : Array | complex , / ) -> Array :
983
+ def __iadd__ (self , other : Array | int | float | complex , / ) -> Array :
984
984
"""
985
985
Performs the operation __iadd__.
986
986
"""
@@ -991,7 +991,7 @@ def __iadd__(self, other: Array | complex, /) -> Array:
991
991
self ._array .__iadd__ (other ._array )
992
992
return self
993
993
994
- def __radd__ (self , other : Array | complex , / ) -> Array :
994
+ def __radd__ (self , other : Array | int | float | complex , / ) -> Array :
995
995
"""
996
996
Performs the operation __radd__.
997
997
"""
@@ -1003,7 +1003,7 @@ def __radd__(self, other: Array | complex, /) -> Array:
1003
1003
res = self ._array .__radd__ (other ._array )
1004
1004
return self .__class__ ._new (res , device = self .device )
1005
1005
1006
- def __iand__ (self , other : Array | int , / ) -> Array :
1006
+ def __iand__ (self , other : Array | bool | int , / ) -> Array :
1007
1007
"""
1008
1008
Performs the operation __iand__.
1009
1009
"""
@@ -1014,7 +1014,7 @@ def __iand__(self, other: Array | int, /) -> Array:
1014
1014
self ._array .__iand__ (other ._array )
1015
1015
return self
1016
1016
1017
- def __rand__ (self , other : Array | int , / ) -> Array :
1017
+ def __rand__ (self , other : Array | bool | int , / ) -> Array :
1018
1018
"""
1019
1019
Performs the operation __rand__.
1020
1020
"""
@@ -1026,7 +1026,7 @@ def __rand__(self, other: Array | int, /) -> Array:
1026
1026
res = self ._array .__rand__ (other ._array )
1027
1027
return self .__class__ ._new (res , device = self .device )
1028
1028
1029
- def __ifloordiv__ (self , other : Array | float , / ) -> Array :
1029
+ def __ifloordiv__ (self , other : Array | int | float , / ) -> Array :
1030
1030
"""
1031
1031
Performs the operation __ifloordiv__.
1032
1032
"""
@@ -1037,7 +1037,7 @@ def __ifloordiv__(self, other: Array | float, /) -> Array:
1037
1037
self ._array .__ifloordiv__ (other ._array )
1038
1038
return self
1039
1039
1040
- def __rfloordiv__ (self , other : Array | float , / ) -> Array :
1040
+ def __rfloordiv__ (self , other : Array | int | float , / ) -> Array :
1041
1041
"""
1042
1042
Performs the operation __rfloordiv__.
1043
1043
"""
@@ -1098,7 +1098,7 @@ def __rmatmul__(self, other: Array, /) -> Array:
1098
1098
res = self ._array .__rmatmul__ (other ._array )
1099
1099
return self .__class__ ._new (res , device = self .device )
1100
1100
1101
- def __imod__ (self , other : Array | float , / ) -> Array :
1101
+ def __imod__ (self , other : Array | int | float , / ) -> Array :
1102
1102
"""
1103
1103
Performs the operation __imod__.
1104
1104
"""
@@ -1108,7 +1108,7 @@ def __imod__(self, other: Array | float, /) -> Array:
1108
1108
self ._array .__imod__ (other ._array )
1109
1109
return self
1110
1110
1111
- def __rmod__ (self , other : Array | float , / ) -> Array :
1111
+ def __rmod__ (self , other : Array | int | float , / ) -> Array :
1112
1112
"""
1113
1113
Performs the operation __rmod__.
1114
1114
"""
@@ -1120,7 +1120,7 @@ def __rmod__(self, other: Array | float, /) -> Array:
1120
1120
res = self ._array .__rmod__ (other ._array )
1121
1121
return self .__class__ ._new (res , device = self .device )
1122
1122
1123
- def __imul__ (self , other : Array | complex , / ) -> Array :
1123
+ def __imul__ (self , other : Array | int | float | complex , / ) -> Array :
1124
1124
"""
1125
1125
Performs the operation __imul__.
1126
1126
"""
@@ -1130,7 +1130,7 @@ def __imul__(self, other: Array | complex, /) -> Array:
1130
1130
self ._array .__imul__ (other ._array )
1131
1131
return self
1132
1132
1133
- def __rmul__ (self , other : Array | complex , / ) -> Array :
1133
+ def __rmul__ (self , other : Array | int | float | complex , / ) -> Array :
1134
1134
"""
1135
1135
Performs the operation __rmul__.
1136
1136
"""
@@ -1142,7 +1142,7 @@ def __rmul__(self, other: Array | complex, /) -> Array:
1142
1142
res = self ._array .__rmul__ (other ._array )
1143
1143
return self .__class__ ._new (res , device = self .device )
1144
1144
1145
- def __ior__ (self , other : Array | int , / ) -> Array :
1145
+ def __ior__ (self , other : Array | bool | int , / ) -> Array :
1146
1146
"""
1147
1147
Performs the operation __ior__.
1148
1148
"""
@@ -1152,7 +1152,7 @@ def __ior__(self, other: Array | int, /) -> Array:
1152
1152
self ._array .__ior__ (other ._array )
1153
1153
return self
1154
1154
1155
- def __ror__ (self , other : Array | int , / ) -> Array :
1155
+ def __ror__ (self , other : Array | bool | int , / ) -> Array :
1156
1156
"""
1157
1157
Performs the operation __ror__.
1158
1158
"""
@@ -1164,7 +1164,7 @@ def __ror__(self, other: Array | int, /) -> Array:
1164
1164
res = self ._array .__ror__ (other ._array )
1165
1165
return self .__class__ ._new (res , device = self .device )
1166
1166
1167
- def __ipow__ (self , other : Array | complex , / ) -> Array :
1167
+ def __ipow__ (self , other : Array | int | float | complex , / ) -> Array :
1168
1168
"""
1169
1169
Performs the operation __ipow__.
1170
1170
"""
@@ -1174,7 +1174,7 @@ def __ipow__(self, other: Array | complex, /) -> Array:
1174
1174
self ._array .__ipow__ (other ._array )
1175
1175
return self
1176
1176
1177
- def __rpow__ (self , other : Array | complex , / ) -> Array :
1177
+ def __rpow__ (self , other : Array | int | float | complex , / ) -> Array :
1178
1178
"""
1179
1179
Performs the operation __rpow__.
1180
1180
"""
@@ -1209,7 +1209,7 @@ def __rrshift__(self, other: Array | int, /) -> Array:
1209
1209
res = self ._array .__rrshift__ (other ._array )
1210
1210
return self .__class__ ._new (res , device = self .device )
1211
1211
1212
- def __isub__ (self , other : Array | complex , / ) -> Array :
1212
+ def __isub__ (self , other : Array | int | float | complex , / ) -> Array :
1213
1213
"""
1214
1214
Performs the operation __isub__.
1215
1215
"""
@@ -1219,7 +1219,7 @@ def __isub__(self, other: Array | complex, /) -> Array:
1219
1219
self ._array .__isub__ (other ._array )
1220
1220
return self
1221
1221
1222
- def __rsub__ (self , other : Array | complex , / ) -> Array :
1222
+ def __rsub__ (self , other : Array | int | float | complex , / ) -> Array :
1223
1223
"""
1224
1224
Performs the operation __rsub__.
1225
1225
"""
@@ -1231,7 +1231,7 @@ def __rsub__(self, other: Array | complex, /) -> Array:
1231
1231
res = self ._array .__rsub__ (other ._array )
1232
1232
return self .__class__ ._new (res , device = self .device )
1233
1233
1234
- def __itruediv__ (self , other : Array | complex , / ) -> Array :
1234
+ def __itruediv__ (self , other : Array | int | float | complex , / ) -> Array :
1235
1235
"""
1236
1236
Performs the operation __itruediv__.
1237
1237
"""
@@ -1241,7 +1241,7 @@ def __itruediv__(self, other: Array | complex, /) -> Array:
1241
1241
self ._array .__itruediv__ (other ._array )
1242
1242
return self
1243
1243
1244
- def __rtruediv__ (self , other : Array | complex , / ) -> Array :
1244
+ def __rtruediv__ (self , other : Array | int | float | complex , / ) -> Array :
1245
1245
"""
1246
1246
Performs the operation __rtruediv__.
1247
1247
"""
@@ -1253,7 +1253,7 @@ def __rtruediv__(self, other: Array | complex, /) -> Array:
1253
1253
res = self ._array .__rtruediv__ (other ._array )
1254
1254
return self .__class__ ._new (res , device = self .device )
1255
1255
1256
- def __ixor__ (self , other : Array | int , / ) -> Array :
1256
+ def __ixor__ (self , other : Array | bool | int , / ) -> Array :
1257
1257
"""
1258
1258
Performs the operation __ixor__.
1259
1259
"""
@@ -1263,7 +1263,7 @@ def __ixor__(self, other: Array | int, /) -> Array:
1263
1263
self ._array .__ixor__ (other ._array )
1264
1264
return self
1265
1265
1266
- def __rxor__ (self , other : Array | int , / ) -> Array :
1266
+ def __rxor__ (self , other : Array | bool | int , / ) -> Array :
1267
1267
"""
1268
1268
Performs the operation __rxor__.
1269
1269
"""
0 commit comments