Skip to content

Commit 438ac56

Browse files
committed
🚚 port ma.arg{min,max} and MaskedArray.arg{min,max}
1 parent c6fc320 commit 438ac56

File tree

3 files changed

+201
-18
lines changed

3 files changed

+201
-18
lines changed

‎src/numpy-stubs/@test/static/accept/ma.pyi

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from typing import Any, TypeAlias, TypeVar
23
from typing_extensions import assert_type
34

@@ -23,3 +24,35 @@ MAR_f4: MaskedNDArray[np.float32]
2324
MAR_i8: MaskedNDArray[np.int64]
2425
MAR_subclass: MaskedNDArraySubclass
2526
MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype[Any]]
27+
28+
assert_type(MAR_b.argmin(), np.intp)
29+
assert_type(MAR_f4.argmin(), np.intp)
30+
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
31+
assert_type(MAR_b.argmin(axis=0), Any)
32+
assert_type(MAR_f4.argmin(axis=0), Any)
33+
assert_type(MAR_b.argmin(keepdims=True), Any)
34+
assert_type(MAR_f4.argmin(out=MAR_subclass), MaskedNDArraySubclass)
35+
assert_type(MAR_f4.argmin(None, None, out=MAR_subclass), MaskedNDArraySubclass)
36+
37+
assert_type(np.ma.argmin(MAR_b), np.intp)
38+
assert_type(np.ma.argmin(MAR_f4), np.intp)
39+
assert_type(np.ma.argmin(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
40+
assert_type(np.ma.argmin(MAR_b, axis=0), Any)
41+
assert_type(np.ma.argmin(MAR_f4, axis=0), Any)
42+
assert_type(np.ma.argmin(MAR_b, keepdims=True), Any)
43+
44+
assert_type(MAR_b.argmax(), np.intp)
45+
assert_type(MAR_f4.argmax(), np.intp)
46+
assert_type(MAR_f4.argmax(fill_value=math.tau, keepdims=False), np.intp)
47+
assert_type(MAR_b.argmax(axis=0), Any)
48+
assert_type(MAR_f4.argmax(axis=0), Any)
49+
assert_type(MAR_b.argmax(keepdims=True), Any)
50+
assert_type(MAR_f4.argmax(out=MAR_subclass), MaskedNDArraySubclass)
51+
assert_type(MAR_f4.argmax(None, None, out=MAR_subclass), MaskedNDArraySubclass)
52+
53+
assert_type(np.ma.argmax(MAR_b), np.intp)
54+
assert_type(np.ma.argmax(MAR_f4), np.intp)
55+
assert_type(np.ma.argmax(MAR_f4, fill_value=math.tau, keepdims=False), np.intp)
56+
assert_type(np.ma.argmax(MAR_b, axis=0), Any)
57+
assert_type(np.ma.argmax(MAR_f4, axis=0), Any)
58+
assert_type(np.ma.argmax(MAR_b, keepdims=True), Any)

‎src/numpy-stubs/@test/static/reject/ma.pyi

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,17 @@ np.amin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgu
99
np.amin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
1010
np.amin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
1111
np.amin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportCallIssue, reportUnknownLambdaType]
12+
13+
m.argmin(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
14+
m.argmin(keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
15+
m.argmin(out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
16+
m.argmin(fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]
17+
18+
np.ma.argmin(m, axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
19+
np.ma.argmin(m, axis=(1,)) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
20+
np.ma.argmin(m, keepdims=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
21+
np.ma.argmin(m, out=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
22+
np.ma.argmin(m, fill_value=lambda x: 27) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue, reportUnknownLambdaType]
23+
24+
m.argmax(axis=1.0) # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue]
25+

‎src/numpy-stubs/ma/core.pyi

Lines changed: 154 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from typing_extensions import Never, Self, TypeVar, deprecated, overload, overri
44

55
import numpy as np
66
from _numtype import Array, ToGeneric_0d, ToGeneric_1nd, ToGeneric_nd
7-
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims # noqa: ICN003
7+
from numpy import _OrderACF, _OrderKACF, amax, amin, bool_, expand_dims, intp # noqa: ICN003
88
from numpy._globals import _NoValueType
99
from numpy._typing import ArrayLike, _ArrayLike, _BoolCodes, _ScalarLike_co, _ShapeLike
1010

@@ -189,7 +189,12 @@ __all__ = [
189189
"zeros_like",
190190
]
191191

192+
###
193+
192194
_ArrayT = TypeVar("_ArrayT", bound=np.ndarray[Any, Any])
195+
196+
###
197+
193198
_UFuncT_co = TypeVar("_UFuncT_co", bound=np.ufunc, default=np.ufunc, covariant=True)
194199
_ScalarT = TypeVar("_ScalarT", bound=np.generic)
195200
_ShapeT = TypeVar("_ShapeT", bound=tuple[int, ...])
@@ -647,15 +652,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
647652
fill_value: Incomplete = ...,
648653
keepdims: Incomplete = ...,
649654
) -> Incomplete: ...
650-
@override
651-
def argmin( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
655+
656+
# Keep in-sync with np.ma.argmin
657+
@overload # type: ignore[override]
658+
def argmin(
652659
self,
653-
axis: Incomplete = ...,
654-
fill_value: Incomplete = ...,
655-
out: Incomplete = ...,
660+
axis: None = None,
661+
fill_value: _ScalarLike_co | None = None,
662+
out: None = None,
656663
*,
657-
keepdims: Incomplete = ...,
658-
) -> Incomplete: ...
664+
keepdims: L[False] | _NoValueType = ...,
665+
) -> intp: ...
666+
@overload
667+
def argmin(
668+
self,
669+
axis: CanIndex | None = None,
670+
fill_value: _ScalarLike_co | None = None,
671+
out: None = None,
672+
*,
673+
keepdims: bool | _NoValueType = ...,
674+
) -> Any: ...
675+
@overload
676+
def argmin(
677+
self,
678+
axis: CanIndex | None = None,
679+
fill_value: _ScalarLike_co | None = None,
680+
*,
681+
out: _ArrayT,
682+
keepdims: bool | _NoValueType = ...,
683+
) -> _ArrayT: ...
684+
@overload
685+
def argmin( # pyright: ignore[reportIncompatibleMethodOverride]
686+
self,
687+
axis: CanIndex | None,
688+
fill_value: _ScalarLike_co | None,
689+
out: _ArrayT,
690+
*,
691+
keepdims: bool | _NoValueType = ...,
692+
) -> _ArrayT: ...
659693

660694
#
661695
@override
@@ -666,15 +700,44 @@ class MaskedArray(np.ndarray[_ShapeT_co, _DTypeT_co]):
666700
fill_value: Incomplete = ...,
667701
keepdims: Incomplete = ...,
668702
) -> Incomplete: ...
669-
@override
670-
def argmax( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
703+
704+
# Keep in-sync with np.ma.argmax
705+
@overload # type: ignore[override]
706+
def argmax(
671707
self,
672-
axis: Incomplete = ...,
673-
fill_value: Incomplete = ...,
674-
out: Incomplete = ...,
708+
axis: None = None,
709+
fill_value: _ScalarLike_co | None = None,
710+
out: None = None,
675711
*,
676-
keepdims: Incomplete = ...,
677-
) -> Incomplete: ...
712+
keepdims: L[False] | _NoValueType = ...,
713+
) -> intp: ...
714+
@overload
715+
def argmax(
716+
self,
717+
axis: CanIndex | None = None,
718+
fill_value: _ScalarLike_co | None = None,
719+
out: None = None,
720+
*,
721+
keepdims: bool | _NoValueType = ...,
722+
) -> Any: ...
723+
@overload
724+
def argmax(
725+
self,
726+
axis: CanIndex | None = None,
727+
fill_value: _ScalarLike_co | None = None,
728+
*,
729+
out: _ArrayT,
730+
keepdims: bool | _NoValueType = ...,
731+
) -> _ArrayT: ...
732+
@overload
733+
def argmax( # pyright: ignore[reportIncompatibleMethodOverride]
734+
self,
735+
axis: CanIndex | None,
736+
fill_value: _ScalarLike_co | None,
737+
out: _ArrayT,
738+
*,
739+
keepdims: bool | _NoValueType = ...,
740+
) -> _ArrayT: ...
678741

679742
#
680743
@override
@@ -1091,8 +1154,81 @@ swapaxes: _frommethod
10911154
trace: _frommethod
10921155
var: _frommethod
10931156
count: _frommethod
1094-
argmin: _frommethod
1095-
argmax: _frommethod
1096-
10971157
minimum: _extrema_operation
10981158
maximum: _extrema_operation
1159+
1160+
#
1161+
@overload
1162+
def argmin(
1163+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1164+
axis: None = None,
1165+
fill_value: _ScalarLike_co | None = None,
1166+
out: None = None,
1167+
*,
1168+
keepdims: L[False] | _NoValueType = ...,
1169+
) -> intp: ...
1170+
@overload
1171+
def argmin(
1172+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1173+
axis: CanIndex | None = None,
1174+
fill_value: _ScalarLike_co | None = None,
1175+
out: None = None,
1176+
*,
1177+
keepdims: bool | _NoValueType = ...,
1178+
) -> Any: ...
1179+
@overload
1180+
def argmin(
1181+
a: _ArrayT,
1182+
axis: CanIndex | None = None,
1183+
fill_value: _ScalarLike_co | None = None,
1184+
*,
1185+
out: _ArrayT,
1186+
keepdims: bool | _NoValueType = ...,
1187+
) -> _ArrayT: ...
1188+
@overload
1189+
def argmin(
1190+
a: _ArrayT,
1191+
axis: CanIndex | None,
1192+
fill_value: _ScalarLike_co | None,
1193+
out: _ArrayT,
1194+
*,
1195+
keepdims: bool | _NoValueType = ...,
1196+
) -> _ArrayT: ...
1197+
1198+
#
1199+
@overload
1200+
def argmax(
1201+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1202+
axis: None = None,
1203+
fill_value: _ScalarLike_co | None = None,
1204+
out: None = None,
1205+
*,
1206+
keepdims: L[False] | _NoValueType = ...,
1207+
) -> intp: ...
1208+
@overload
1209+
def argmax(
1210+
a: _ArrayT, # pyright: ignore[reportInvalidTypeVarUse]
1211+
axis: CanIndex | None = None,
1212+
fill_value: _ScalarLike_co | None = None,
1213+
out: None = None,
1214+
*,
1215+
keepdims: bool | _NoValueType = ...,
1216+
) -> Any: ...
1217+
@overload
1218+
def argmax(
1219+
a: _ArrayT,
1220+
axis: CanIndex | None = None,
1221+
fill_value: _ScalarLike_co | None = None,
1222+
*,
1223+
out: _ArrayT,
1224+
keepdims: bool | _NoValueType = ...,
1225+
) -> _ArrayT: ...
1226+
@overload
1227+
def argmax(
1228+
a: _ArrayT,
1229+
axis: CanIndex | None,
1230+
fill_value: _ScalarLike_co | None,
1231+
out: _ArrayT,
1232+
*,
1233+
keepdims: bool | _NoValueType = ...,
1234+
) -> _ArrayT: ...

0 commit comments

Comments
 (0)