🌐 AI搜索 & 代理 主页
Skip to content

Commit a046296

Browse files
authored
TYP: Type MaskedArray.{argmin, argmax} and np.ma.{argmin, argmax} (#28638)
* TYP: Type ``MaskedArray.argmin`` and ``MaskedArray.argmax`` * type `axis: SupportsIndex | None`, add test which would have failed * 🎨 * fixup * 🎨 * align sigs
1 parent dfbf9ed commit a046296

File tree

3 files changed

+211
-4
lines changed

3 files changed

+211
-4
lines changed

numpy/ma/core.pyi

Lines changed: 153 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ from _typeshed import Incomplete
77
from typing_extensions import deprecated
88

99
from numpy import (
10+
intp,
1011
_OrderKACF,
1112
amax,
1213
amin,
@@ -465,8 +466,84 @@ class MaskedArray(ndarray[_ShapeType_co, _DType_co]):
465466
def std(self, axis=..., dtype=..., out=..., ddof=..., keepdims=...): ...
466467
def round(self, decimals=..., out=...): ...
467468
def argsort(self, axis=..., kind=..., order=..., endwith=..., fill_value=..., *, stable=...): ...
468-
def argmin(self, axis=..., fill_value=..., out=..., *, keepdims=...): ...
469-
def argmax(self, axis=..., fill_value=..., out=..., *, keepdims=...): ...
469+
470+
# Keep in-sync with np.ma.argmin
471+
@overload
472+
def argmin( # type: ignore[override]
473+
self,
474+
axis: None = None,
475+
fill_value: _ScalarLike_co | None = None,
476+
out: None = None,
477+
*,
478+
keepdims: Literal[False] | _NoValueType = ...,
479+
) -> intp: ...
480+
@overload
481+
def argmin( # type: ignore[override]
482+
self,
483+
axis: SupportsIndex | None = None,
484+
fill_value: _ScalarLike_co | None = None,
485+
out: None = None,
486+
*,
487+
keepdims: bool | _NoValueType = ...,
488+
) -> Any: ...
489+
@overload
490+
def argmin( # type: ignore[override]
491+
self,
492+
axis: SupportsIndex | None = None,
493+
fill_value: _ScalarLike_co | None = None,
494+
*,
495+
out: _ArrayType,
496+
keepdims: bool | _NoValueType = ...,
497+
) -> _ArrayType: ...
498+
@overload
499+
def argmin( # type: ignore[override]
500+
self,
501+
axis: SupportsIndex | None,
502+
fill_value: _ScalarLike_co | None,
503+
out: _ArrayType,
504+
*,
505+
keepdims: bool | _NoValueType = ...,
506+
) -> _ArrayType: ...
507+
508+
# Keep in-sync with np.ma.argmax
509+
@overload
510+
def argmax( # type: ignore[override]
511+
self,
512+
axis: None = None,
513+
fill_value: _ScalarLike_co | None = None,
514+
out: None = None,
515+
*,
516+
keepdims: Literal[False] | _NoValueType = ...,
517+
) -> intp: ...
518+
@overload
519+
def argmax( # type: ignore[override]
520+
self,
521+
axis: SupportsIndex | None = None,
522+
fill_value: _ScalarLike_co | None = None,
523+
out: None = None,
524+
*,
525+
keepdims: bool | _NoValueType = ...,
526+
) -> Any: ...
527+
@overload
528+
def argmax( # type: ignore[override]
529+
self,
530+
axis: SupportsIndex | None = None,
531+
fill_value: _ScalarLike_co | None = None,
532+
*,
533+
out: _ArrayType,
534+
keepdims: bool | _NoValueType = ...,
535+
) -> _ArrayType: ...
536+
@overload
537+
def argmax( # type: ignore[override]
538+
self,
539+
axis: SupportsIndex | None,
540+
fill_value: _ScalarLike_co | None,
541+
out: _ArrayType,
542+
*,
543+
keepdims: bool | _NoValueType = ...,
544+
) -> _ArrayType: ...
545+
546+
#
470547
def sort(self, axis=..., kind=..., order=..., endwith=..., fill_value=..., *, stable=...): ...
471548
@overload
472549
def min( # type: ignore[override]
@@ -801,8 +878,80 @@ swapaxes: _frommethod
801878
trace: _frommethod
802879
var: _frommethod
803880
count: _frommethod
804-
argmin: _frommethod
805-
argmax: _frommethod
881+
882+
@overload
883+
def argmin(
884+
self: ArrayLike,
885+
axis: None = None,
886+
fill_value: _ScalarLike_co | None = None,
887+
out: None = None,
888+
*,
889+
keepdims: Literal[False] | _NoValueType = ...,
890+
) -> intp: ...
891+
@overload
892+
def argmin(
893+
self: ArrayLike,
894+
axis: SupportsIndex | None = None,
895+
fill_value: _ScalarLike_co | None = None,
896+
out: None = None,
897+
*,
898+
keepdims: bool | _NoValueType = ...,
899+
) -> Any: ...
900+
@overload
901+
def argmin(
902+
self: ArrayLike,
903+
axis: SupportsIndex | None = None,
904+
fill_value: _ScalarLike_co | None = None,
905+
*,
906+
out: _ArrayType,
907+
keepdims: bool | _NoValueType = ...,
908+
) -> _ArrayType: ...
909+
@overload
910+
def argmin(
911+
self: ArrayLike,
912+
axis: SupportsIndex | None,
913+
fill_value: _ScalarLike_co | None,
914+
out: _ArrayType,
915+
*,
916+
keepdims: bool | _NoValueType = ...,
917+
) -> _ArrayType: ...
918+
919+
@overload
920+
def argmax(
921+
self: ArrayLike,
922+
axis: None = None,
923+
fill_value: _ScalarLike_co | None = None,
924+
out: None = None,
925+
*,
926+
keepdims: Literal[False] | _NoValueType = ...,
927+
) -> intp: ...
928+
@overload
929+
def argmax(
930+
self: ArrayLike,
931+
axis: SupportsIndex | None = None,
932+
fill_value: _ScalarLike_co | None = None,
933+
out: None = None,
934+
*,
935+
keepdims: bool | _NoValueType = ...,
936+
) -> Any: ...
937+
@overload
938+
def argmax(
939+
self: ArrayLike,
940+
axis: SupportsIndex | None = None,
941+
fill_value: _ScalarLike_co | None = None,
942+
*,
943+
out: _ArrayType,
944+
keepdims: bool | _NoValueType = ...,
945+
) -> _ArrayType: ...
946+
@overload
947+
def argmax(
948+
self: ArrayLike,
949+
axis: SupportsIndex | None,
950+
fill_value: _ScalarLike_co | None,
951+
out: _ArrayType,
952+
*,
953+
keepdims: bool | _NoValueType = ...,
954+
) -> _ArrayType: ...
806955

807956
minimum: _extrema_operation
808957
maximum: _extrema_operation

numpy/typing/tests/data/fail/ma.pyi

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,25 @@ m.ptp(axis=1.0) # E: No overload variant
3737
m.ptp(keepdims=1.0) # E: No overload variant
3838
m.ptp(out=1.0) # E: No overload variant
3939
m.ptp(fill_value=lambda x: 27) # E: No overload variant
40+
41+
m.argmin(axis=1.0) # E: No overload variant
42+
m.argmin(keepdims=1.0) # E: No overload variant
43+
m.argmin(out=1.0) # E: No overload variant
44+
m.argmin(fill_value=lambda x: 27) # E: No overload variant
45+
46+
np.ma.argmin(m, axis=1.0) # E: No overload variant
47+
np.ma.argmin(m, axis=(1,)) # E: No overload variant
48+
np.ma.argmin(m, keepdims=1.0) # E: No overload variant
49+
np.ma.argmin(m, out=1.0) # E: No overload variant
50+
np.ma.argmin(m, fill_value=lambda x: 27) # E: No overload variant
51+
52+
m.argmax(axis=1.0) # E: No overload variant
53+
m.argmax(keepdims=1.0) # E: No overload variant
54+
m.argmax(out=1.0) # E: No overload variant
55+
m.argmax(fill_value=lambda x: 27) # E: No overload variant
56+
57+
np.ma.argmax(m, axis=1.0) # E: No overload variant
58+
np.ma.argmax(m, axis=(0,)) # E: No overload variant
59+
np.ma.argmax(m, keepdims=1.0) # E: No overload variant
60+
np.ma.argmax(m, out=1.0) # E: No overload variant
61+
np.ma.argmax(m, fill_value=lambda x: 27) # E: No overload variant

numpy/typing/tests/data/reveal/ma.pyi

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,39 @@ assert_type(MAR_f4.ptp(keepdims=True), Any)
8282
assert_type(MAR_f4.ptp(out=MAR_subclass), MaskedNDArraySubclass)
8383
assert_type(MAR_f4.ptp(0, MAR_subclass), MaskedNDArraySubclass)
8484
assert_type(MAR_f4.ptp(None, MAR_subclass), MaskedNDArraySubclass)
85+
86+
assert_type(MAR_b.argmin(), np.intp)
87+
assert_type(MAR_f4.argmin(), np.intp)
88+
assert_type(MAR_f4.argmax(fill_value=6.28318, keepdims=False), np.intp)
89+
assert_type(MAR_b.argmin(axis=0), Any)
90+
assert_type(MAR_f4.argmin(axis=0), Any)
91+
assert_type(MAR_b.argmin(keepdims=True), Any)
92+
assert_type(MAR_f4.argmin(out=MAR_subclass), MaskedNDArraySubclass)
93+
assert_type(MAR_f4.argmin(None, None, out=MAR_subclass), MaskedNDArraySubclass)
94+
95+
assert_type(np.ma.argmin(MAR_b), np.intp)
96+
assert_type(np.ma.argmin(MAR_f4), np.intp)
97+
assert_type(np.ma.argmin(MAR_f4, fill_value=6.28318, keepdims=False), np.intp)
98+
assert_type(np.ma.argmin(MAR_b, axis=0), Any)
99+
assert_type(np.ma.argmin(MAR_f4, axis=0), Any)
100+
assert_type(np.ma.argmin(MAR_b, keepdims=True), Any)
101+
assert_type(np.ma.argmin(MAR_f4, out=MAR_subclass), MaskedNDArraySubclass)
102+
assert_type(np.ma.argmin(MAR_f4, None, None, out=MAR_subclass), MaskedNDArraySubclass)
103+
104+
assert_type(MAR_b.argmax(), np.intp)
105+
assert_type(MAR_f4.argmax(), np.intp)
106+
assert_type(MAR_f4.argmax(fill_value=6.28318, keepdims=False), np.intp)
107+
assert_type(MAR_b.argmax(axis=0), Any)
108+
assert_type(MAR_f4.argmax(axis=0), Any)
109+
assert_type(MAR_b.argmax(keepdims=True), Any)
110+
assert_type(MAR_f4.argmax(out=MAR_subclass), MaskedNDArraySubclass)
111+
assert_type(MAR_f4.argmax(None, None, out=MAR_subclass), MaskedNDArraySubclass)
112+
113+
assert_type(np.ma.argmax(MAR_b), np.intp)
114+
assert_type(np.ma.argmax(MAR_f4), np.intp)
115+
assert_type(np.ma.argmax(MAR_f4, fill_value=6.28318, keepdims=False), np.intp)
116+
assert_type(np.ma.argmax(MAR_b, axis=0), Any)
117+
assert_type(np.ma.argmax(MAR_f4, axis=0), Any)
118+
assert_type(np.ma.argmax(MAR_b, keepdims=True), Any)
119+
assert_type(np.ma.argmax(MAR_f4, out=MAR_subclass), MaskedNDArraySubclass)
120+
assert_type(np.ma.argmax(MAR_f4, None, None, out=MAR_subclass), MaskedNDArraySubclass)

0 commit comments

Comments
 (0)