Skip to content

Commit c6fc320

Browse files
authored
✨ return specialized numpy.dtypes instances from dtype.__new__ (#483)
1 parent e7382bc commit c6fc320

File tree

2 files changed

+63
-64
lines changed

2 files changed

+63
-64
lines changed

src/numpy-stubs/@test/static/accept/dtype.pyi

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@ import ctypes as ct
22
from typing_extensions import LiteralString, assert_type
33

44
import numpy as np
5-
from numpy.dtypes import StringDType
65

76
dtype_U: np.dtype[np.str_]
87
dtype_V: np.dtype[np.void]
9-
dtype_i8: np.dtype[np.int64]
8+
dtype_i8: np.dtypes.Int64DType
109

1110
py_int_co: type[int]
1211
py_float_co: type[float]
@@ -19,30 +18,30 @@ ct_floating: type[ct.c_float | ct.c_double | ct.c_longdouble]
1918
ct_number: type[ct.c_uint8 | ct.c_float]
2019
ct_generic: type[ct.c_bool | ct.c_char]
2120

22-
dt_string: StringDType
21+
dt_string: np.dtypes.StringDType
2322

24-
assert_type(np.dtype(np.float64), np.dtype[np.float64])
25-
assert_type(np.dtype(np.float64, metadata={"test": "test"}), np.dtype[np.float64])
26-
assert_type(np.dtype(np.int64), np.dtype[np.int64])
23+
assert_type(np.dtype(np.float64), np.dtypes.Float64DType)
24+
assert_type(np.dtype(np.float64, metadata={"test": "test"}), np.dtypes.Float64DType)
25+
assert_type(np.dtype(np.int64), np.dtypes.Int64DType)
2726

2827
# String aliases
29-
assert_type(np.dtype("bool"), np.dtype[np.bool])
30-
assert_type(np.dtype("int32"), np.dtype[np.int32])
31-
assert_type(np.dtype("int64"), np.dtype[np.int64])
32-
assert_type(np.dtype("float32"), np.dtype[np.float32])
33-
assert_type(np.dtype("float64"), np.dtype[np.float64])
34-
assert_type(np.dtype("bytes"), np.dtype[np.bytes_])
35-
assert_type(np.dtype("str"), np.dtype[np.str_])
28+
assert_type(np.dtype("bool"), np.dtypes.BoolDType)
29+
assert_type(np.dtype("int32"), np.dtypes.Int32DType)
30+
assert_type(np.dtype("int64"), np.dtypes.Int64DType)
31+
assert_type(np.dtype("float32"), np.dtypes.Float32DType)
32+
assert_type(np.dtype("float64"), np.dtypes.Float64DType)
33+
assert_type(np.dtype("bytes"), np.dtypes.BytesDType)
34+
assert_type(np.dtype("str"), np.dtypes.StrDType)
3635

3736
# Python types
38-
assert_type(np.dtype(bool), np.dtype[np.bool])
39-
assert_type(np.dtype(int), np.dtype[np.int_])
40-
assert_type(np.dtype(float), np.dtype[np.float64])
41-
assert_type(np.dtype(complex), np.dtype[np.complex128])
42-
assert_type(np.dtype(object), np.dtype[np.object_])
43-
assert_type(np.dtype(str), np.dtype[np.str_])
44-
assert_type(np.dtype(bytes), np.dtype[np.bytes_])
45-
assert_type(np.dtype(memoryview), np.dtype[np.void])
37+
assert_type(np.dtype(bool), np.dtypes.BoolDType)
38+
assert_type(np.dtype(int), np.dtypes.Int64DType)
39+
assert_type(np.dtype(float), np.dtypes.Float64DType)
40+
assert_type(np.dtype(complex), np.dtypes.Complex128DType)
41+
assert_type(np.dtype(object), np.dtypes.ObjectDType)
42+
assert_type(np.dtype(str), np.dtypes.StrDType)
43+
assert_type(np.dtype(bytes), np.dtypes.BytesDType)
44+
assert_type(np.dtype(memoryview), np.dtypes.VoidDType)
4645

4746
assert_type(np.dtype(np.signedinteger), np.dtype[np.signedinteger])
4847
assert_type(np.dtype(np.unsignedinteger), np.dtype[np.unsignedinteger])
@@ -55,37 +54,37 @@ assert_type(np.dtype(np.number), np.dtype[np.number])
5554
assert_type(np.dtype(np.generic), np.dtype[np.generic])
5655

5756
# char-codes
58-
assert_type(np.dtype("u1"), np.dtype[np.uint8])
59-
assert_type(np.dtype("int_"), np.dtype[np.intp])
60-
assert_type(np.dtype("longlong"), np.dtype[np.longlong])
61-
assert_type(np.dtype(">g"), np.dtype[np.longdouble])
57+
assert_type(np.dtype("u1"), np.dtypes.UInt8DType)
58+
assert_type(np.dtype("int_"), np.dtypes.Int64DType)
59+
assert_type(np.dtype("longlong"), np.dtypes.Int64DType)
60+
assert_type(np.dtype(">g"), np.dtypes.LongDoubleDType)
6261

6362
# ctypes
64-
assert_type(np.dtype(ct.c_bool), np.dtype[np.bool])
65-
assert_type(np.dtype(ct.c_uint32), np.dtype[np.uint32])
66-
assert_type(np.dtype(ct.c_ssize_t), np.dtype[np.intp])
67-
assert_type(np.dtype(ct.c_longlong), np.dtype[np.longlong])
68-
assert_type(np.dtype(ct.c_double), np.dtype[np.double])
69-
assert_type(np.dtype(ct.py_object), np.dtype[np.object_])
70-
assert_type(np.dtype(ct.c_char), np.dtype[np.bytes_])
63+
assert_type(np.dtype(ct.c_bool), np.dtypes.BoolDType)
64+
assert_type(np.dtype(ct.c_uint32), np.dtypes.UInt32DType)
65+
assert_type(np.dtype(ct.c_ssize_t), np.dtypes.Int64DType)
66+
assert_type(np.dtype(ct.c_longlong), np.dtypes.Int64DType)
67+
assert_type(np.dtype(ct.c_double), np.dtypes.Float64DType)
68+
assert_type(np.dtype(ct.py_object), np.dtypes.ObjectDType)
69+
assert_type(np.dtype(ct.c_char), np.dtypes.BytesDType)
7170

7271
# Special case for None
73-
assert_type(np.dtype(None), np.dtype[np.float64])
72+
assert_type(np.dtype(None), np.dtypes.Float64DType)
7473

7574
# Dypes of dtypes
76-
assert_type(np.dtype(np.dtype(np.float64)), np.dtype[np.float64])
75+
assert_type(np.dtype(np.dtype(np.float64)), np.dtypes.Float64DType)
7776

7877
# Parameterized dtypes
7978
assert_type(np.dtype("S8"), np.dtype)
8079

8180
# Void
82-
assert_type(np.dtype(("U", 10)), np.dtype[np.void])
81+
assert_type(np.dtype(("U", 10)), np.dtypes.VoidDType)
8382

8483
# StringDType
85-
assert_type(np.dtype(dt_string), StringDType)
86-
assert_type(np.dtype("T"), StringDType)
87-
assert_type(np.dtype("=T"), StringDType)
88-
assert_type(np.dtype("|T"), StringDType)
84+
assert_type(np.dtype(dt_string), np.dtypes.StringDType)
85+
assert_type(np.dtype("T"), np.dtypes.StringDType)
86+
assert_type(np.dtype("=T"), np.dtypes.StringDType)
87+
assert_type(np.dtype("|T"), np.dtypes.StringDType)
8988

9089
# Methods and attributes
9190
assert_type(dtype_U.base, np.dtype)
@@ -100,7 +99,7 @@ assert_type(dtype_U * 1, np.dtype[np.str_])
10099
assert_type(dtype_U * 2, np.dtype[np.str_])
101100

102101
assert_type(dtype_i8 * 0, np.dtype[np.void])
103-
assert_type(dtype_i8 * 1, np.dtype[np.int64])
102+
assert_type(dtype_i8 * 1, np.dtypes.Int64DType)
104103
assert_type(dtype_i8 * 2, np.dtype[np.void])
105104

106105
assert_type(0 * dtype_U, np.dtype[np.str_])

src/numpy-stubs/__init__.pyi

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,99 +1144,99 @@ class dtype(Generic[_ScalarT_co], metaclass=_DTypeMeta):
11441144
@overload
11451145
def __new__(
11461146
cls, dtype: _nt.ToDTypeBool, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1147-
) -> dtype[bool_]: ...
1147+
) -> dtypes.BoolDType: ...
11481148
@overload
11491149
def __new__(
11501150
cls, dtype: _nt.ToDTypeInt8, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1151-
) -> dtype[int8]: ...
1151+
) -> dtypes.Int8DType: ...
11521152
@overload
11531153
def __new__(
11541154
cls, dtype: _nt.ToDTypeUInt8, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1155-
) -> dtype[uint8]: ...
1155+
) -> dtypes.UByteDType: ...
11561156
@overload
11571157
def __new__(
11581158
cls, dtype: _nt.ToDTypeInt16, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1159-
) -> dtype[int16]: ...
1159+
) -> dtypes.Int16DType: ...
11601160
@overload
11611161
def __new__(
11621162
cls, dtype: _nt.ToDTypeUInt16, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1163-
) -> dtype[uint16]: ...
1163+
) -> dtypes.UInt16DType: ...
11641164
@overload
11651165
def __new__(
11661166
cls, dtype: _nt.ToDTypeInt32, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1167-
) -> dtype[int32]: ...
1167+
) -> dtypes.Int32DType: ...
11681168
@overload
11691169
def __new__(
11701170
cls, dtype: _nt.ToDTypeUInt32, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1171-
) -> dtype[uint32]: ...
1171+
) -> dtypes.UInt32DType: ...
11721172
@overload
11731173
def __new__(
11741174
cls, dtype: type[ct.c_long] | _LongCodes, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1175-
) -> dtype[long]: ...
1175+
) -> dtypes.LongDType: ...
11761176
@overload
11771177
def __new__(
11781178
cls, dtype: type[ct.c_ulong] | _ULongCodes, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1179-
) -> dtype[ulong]: ...
1179+
) -> dtypes.ULongDType: ...
11801180
@overload
11811181
def __new__(
11821182
cls, dtype: _nt.ToDTypeInt64, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1183-
) -> dtype[int64]: ...
1183+
) -> dtypes.Int64DType: ...
11841184
@overload
11851185
def __new__(
11861186
cls, dtype: _nt.ToDTypeUInt64, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1187-
) -> dtype[uint64]: ...
1187+
) -> dtypes.UInt64DType: ...
11881188
@overload
11891189
def __new__(
11901190
cls, dtype: _nt.ToDTypeFloat16, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1191-
) -> dtype[float16]: ...
1191+
) -> dtypes.Float16DType: ...
11921192
@overload
11931193
def __new__(
11941194
cls, dtype: _nt.ToDTypeFloat32, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1195-
) -> dtype[float32]: ...
1195+
) -> dtypes.Float32DType: ...
11961196
@overload
11971197
def __new__(
11981198
cls, dtype: _nt.ToDTypeFloat64 | None, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1199-
) -> dtype[float64]: ...
1199+
) -> dtypes.Float64DType: ...
12001200
@overload
12011201
def __new__(
12021202
cls, dtype: _nt.ToDTypeLongDouble, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1203-
) -> dtype[longdouble]: ...
1203+
) -> dtypes.LongDoubleDType: ...
12041204
@overload
12051205
def __new__(
12061206
cls, dtype: _nt.ToDTypeComplex64, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1207-
) -> dtype[complex64]: ...
1207+
) -> dtypes.Complex64DType: ...
12081208
@overload
12091209
def __new__(
12101210
cls, dtype: _nt.ToDTypeComplex128, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1211-
) -> dtype[complex128]: ...
1211+
) -> dtypes.Complex128DType: ...
12121212
@overload
12131213
def __new__(
12141214
cls, dtype: _nt.ToDTypeCLongDouble, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1215-
) -> dtype[clongdouble]: ...
1215+
) -> dtypes.CLongDoubleDType: ...
12161216
@overload
12171217
def __new__(
12181218
cls, dtype: _nt.ToDTypeObject, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1219-
) -> dtype[object_]: ...
1219+
) -> dtypes.ObjectDType: ...
12201220
@overload
12211221
def __new__(
12221222
cls, dtype: _nt.ToDTypeBytes, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1223-
) -> dtype[bytes_]: ...
1223+
) -> dtypes.BytesDType: ...
12241224
@overload
12251225
def __new__( # type: ignore[overload-overlap]
12261226
cls, dtype: _nt.ToDTypeStr, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1227-
) -> dtype[str_]: ...
1227+
) -> dtypes.StrDType: ...
12281228
@overload
12291229
def __new__(
12301230
cls, dtype: _nt.ToDTypeVoid, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1231-
) -> dtype[void]: ...
1231+
) -> dtypes.VoidDType: ...
12321232
@overload
12331233
def __new__(
12341234
cls, dtype: _nt.ToDTypeDateTime64, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1235-
) -> dtype[datetime64]: ...
1235+
) -> dtypes.DateTime64DType: ...
12361236
@overload
12371237
def __new__(
12381238
cls, dtype: _nt.ToDTypeTimeDelta64, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...
1239-
) -> dtype[timedelta64]: ...
1239+
) -> dtypes.TimeDelta64DType: ...
12401240
@overload
12411241
def __new__(
12421242
cls, dtype: _nt.ToDTypeString, align: py_bool = False, copy: py_bool = False, metadata: _MetaData = ...

0 commit comments

Comments
 (0)