🌐 AI搜索 & 代理 主页
Skip to content

Commit 2d6a4d7

Browse files
authored
TYP: Type MaskedArray.__{mul,rmul}__ (#29265)
1 parent 1fefc5c commit 2d6a4d7

File tree

3 files changed

+183
-2
lines changed

3 files changed

+183
-2
lines changed

numpy/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3007,6 +3007,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
30073007
@overload
30083008
def __rsub__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> Any: ...
30093009

3010+
# Keep in sync with `MaskedArray.__mul__`
30103011
@overload
30113012
def __mul__(self: NDArray[_NumberT], other: int | np.bool, /) -> ndarray[_ShapeT_co, dtype[_NumberT]]: ...
30123013
@overload
@@ -3048,6 +3049,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
30483049
@overload
30493050
def __mul__(self: NDArray[Any], other: _ArrayLikeObject_co, /) -> Any: ...
30503051

3052+
# Keep in sync with `MaskedArray.__rmul__`
30513053
@overload # signature equivalent to __mul__
30523054
def __rmul__(self: NDArray[_NumberT], other: int | np.bool, /) -> ndarray[_ShapeT_co, dtype[_NumberT]]: ...
30533055
@overload

numpy/ma/core.pyi

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,8 +658,90 @@ class MaskedArray(ndarray[_ShapeT_co, _DTypeT_co]):
658658
@overload
659659
def __rsub__(self: _MaskedArray[Any], other: _ArrayLikeObject_co, /) -> Any: ...
660660

661-
def __mul__(self, other): ...
662-
def __rmul__(self, other): ...
661+
# Keep in sync with `ndarray.__mul__`
662+
@overload
663+
def __mul__(self: _MaskedArray[_NumberT], other: int | np.bool, /) -> MaskedArray[_ShapeT_co, dtype[_NumberT]]: ...
664+
@overload
665+
def __mul__(self: _MaskedArray[_NumberT], other: _ArrayLikeBool_co, /) -> _MaskedArray[_NumberT]: ... # type: ignore[overload-overlap]
666+
@overload
667+
def __mul__(self: _MaskedArray[np.bool], other: _ArrayLikeBool_co, /) -> _MaskedArray[np.bool]: ... # type: ignore[overload-overlap]
668+
@overload
669+
def __mul__(self: _MaskedArray[np.bool], other: _ArrayLike[_NumberT], /) -> _MaskedArray[_NumberT]: ... # type: ignore[overload-overlap]
670+
@overload
671+
def __mul__(self: _MaskedArray[float64], other: _ArrayLikeFloat64_co, /) -> _MaskedArray[float64]: ...
672+
@overload
673+
def __mul__(self: _MaskedArrayFloat64_co, other: _ArrayLike[floating[_64Bit]], /) -> _MaskedArray[float64]: ...
674+
@overload
675+
def __mul__(self: _MaskedArray[complex128], other: _ArrayLikeComplex128_co, /) -> _MaskedArray[complex128]: ...
676+
@overload
677+
def __mul__(self: _MaskedArrayComplex128_co, other: _ArrayLike[complexfloating[_64Bit]], /) -> _MaskedArray[complex128]: ...
678+
@overload
679+
def __mul__(self: _MaskedArrayUInt_co, other: _ArrayLikeUInt_co, /) -> _MaskedArray[unsignedinteger]: ... # type: ignore[overload-overlap]
680+
@overload
681+
def __mul__(self: _MaskedArrayInt_co, other: _ArrayLikeInt_co, /) -> _MaskedArray[signedinteger]: ... # type: ignore[overload-overlap]
682+
@overload
683+
def __mul__(self: _MaskedArrayFloat_co, other: _ArrayLikeFloat_co, /) -> _MaskedArray[floating]: ... # type: ignore[overload-overlap]
684+
@overload
685+
def __mul__(self: _MaskedArrayComplex_co, other: _ArrayLikeComplex_co, /) -> _MaskedArray[complexfloating]: ... # type: ignore[overload-overlap]
686+
@overload
687+
def __mul__(self: _MaskedArray[number], other: _ArrayLikeNumber_co, /) -> _MaskedArray[number]: ...
688+
@overload
689+
def __mul__(self: _MaskedArray[timedelta64], other: _ArrayLikeFloat_co, /) -> _MaskedArray[timedelta64]: ...
690+
@overload
691+
def __mul__(self: _MaskedArrayFloat_co, other: _ArrayLike[timedelta64], /) -> _MaskedArray[timedelta64]: ...
692+
@overload
693+
def __mul__(
694+
self: MaskedArray[Any, dtype[character] | dtypes.StringDType],
695+
other: _ArrayLikeInt,
696+
/,
697+
) -> MaskedArray[tuple[Any, ...], _DTypeT_co]: ...
698+
@overload
699+
def __mul__(self: _MaskedArray[object_], other: Any, /) -> Any: ...
700+
@overload
701+
def __mul__(self: _MaskedArray[Any], other: _ArrayLikeObject_co, /) -> Any: ...
702+
703+
# Keep in sync with `ndarray.__rmul__`
704+
@overload # signature equivalent to __mul__
705+
def __rmul__(self: _MaskedArray[_NumberT], other: int | np.bool, /) -> MaskedArray[_ShapeT_co, dtype[_NumberT]]: ...
706+
@overload
707+
def __rmul__(self: _MaskedArray[_NumberT], other: _ArrayLikeBool_co, /) -> _MaskedArray[_NumberT]: ... # type: ignore[overload-overlap]
708+
@overload
709+
def __rmul__(self: _MaskedArray[np.bool], other: _ArrayLikeBool_co, /) -> _MaskedArray[np.bool]: ... # type: ignore[overload-overlap]
710+
@overload
711+
def __rmul__(self: _MaskedArray[np.bool], other: _ArrayLike[_NumberT], /) -> _MaskedArray[_NumberT]: ... # type: ignore[overload-overlap]
712+
@overload
713+
def __rmul__(self: _MaskedArray[float64], other: _ArrayLikeFloat64_co, /) -> _MaskedArray[float64]: ...
714+
@overload
715+
def __rmul__(self: _MaskedArrayFloat64_co, other: _ArrayLike[floating[_64Bit]], /) -> _MaskedArray[float64]: ...
716+
@overload
717+
def __rmul__(self: _MaskedArray[complex128], other: _ArrayLikeComplex128_co, /) -> _MaskedArray[complex128]: ...
718+
@overload
719+
def __rmul__(self: _MaskedArrayComplex128_co, other: _ArrayLike[complexfloating[_64Bit]], /) -> _MaskedArray[complex128]: ...
720+
@overload
721+
def __rmul__(self: _MaskedArrayUInt_co, other: _ArrayLikeUInt_co, /) -> _MaskedArray[unsignedinteger]: ... # type: ignore[overload-overlap]
722+
@overload
723+
def __rmul__(self: _MaskedArrayInt_co, other: _ArrayLikeInt_co, /) -> _MaskedArray[signedinteger]: ... # type: ignore[overload-overlap]
724+
@overload
725+
def __rmul__(self: _MaskedArrayFloat_co, other: _ArrayLikeFloat_co, /) -> _MaskedArray[floating]: ... # type: ignore[overload-overlap]
726+
@overload
727+
def __rmul__(self: _MaskedArrayComplex_co, other: _ArrayLikeComplex_co, /) -> _MaskedArray[complexfloating]: ... # type: ignore[overload-overlap]
728+
@overload
729+
def __rmul__(self: _MaskedArray[number], other: _ArrayLikeNumber_co, /) -> _MaskedArray[number]: ...
730+
@overload
731+
def __rmul__(self: _MaskedArray[timedelta64], other: _ArrayLikeFloat_co, /) -> _MaskedArray[timedelta64]: ...
732+
@overload
733+
def __rmul__(self: _MaskedArrayFloat_co, other: _ArrayLike[timedelta64], /) -> _MaskedArray[timedelta64]: ...
734+
@overload
735+
def __rmul__(
736+
self: MaskedArray[Any, dtype[character] | dtypes.StringDType],
737+
other: _ArrayLikeInt,
738+
/,
739+
) -> MaskedArray[tuple[Any, ...], _DTypeT_co]: ...
740+
@overload
741+
def __rmul__(self: _MaskedArray[object_], other: Any, /) -> Any: ...
742+
@overload
743+
def __rmul__(self: _MaskedArray[Any], other: _ArrayLikeObject_co, /) -> Any: ...
744+
663745
def __truediv__(self, other): ...
664746
def __rtruediv__(self, other): ...
665747
def __floordiv__(self, other): ...

numpy/typing/tests/data/reveal/ma.pyi

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,3 +623,100 @@ assert_type(AR_LIKE_c - MAR_o, Any)
623623
assert_type(AR_LIKE_td64 - MAR_o, Any)
624624
assert_type(AR_LIKE_dt64 - MAR_o, Any)
625625
assert_type(AR_LIKE_o - MAR_o, Any)
626+
627+
# Masked Array multiplication
628+
629+
assert_type(MAR_b * AR_LIKE_u, MaskedArray[np.uint32])
630+
assert_type(MAR_b * AR_LIKE_i, MaskedArray[np.signedinteger])
631+
assert_type(MAR_b * AR_LIKE_f, MaskedArray[np.floating])
632+
assert_type(MAR_b * AR_LIKE_c, MaskedArray[np.complexfloating])
633+
assert_type(MAR_b * AR_LIKE_td64, MaskedArray[np.timedelta64])
634+
assert_type(MAR_b * AR_LIKE_o, Any)
635+
636+
# Ignore due to https://github.com/python/mypy/issues/19341
637+
assert_type(AR_LIKE_u * MAR_b, MaskedArray[np.uint32]) # type: ignore[assert-type]
638+
assert_type(AR_LIKE_i * MAR_b, MaskedArray[np.signedinteger]) # type: ignore[assert-type]
639+
assert_type(AR_LIKE_f * MAR_b, MaskedArray[np.floating]) # type: ignore[assert-type]
640+
assert_type(AR_LIKE_c * MAR_b, MaskedArray[np.complexfloating]) # type: ignore[assert-type]
641+
assert_type(AR_LIKE_td64 * MAR_b, MaskedArray[np.timedelta64]) # type: ignore[assert-type]
642+
assert_type(AR_LIKE_o * MAR_b, Any) # type: ignore[assert-type]
643+
644+
assert_type(MAR_u4 * AR_LIKE_b, MaskedArray[np.uint32])
645+
assert_type(MAR_u4 * AR_LIKE_u, MaskedArray[np.unsignedinteger])
646+
assert_type(MAR_u4 * AR_LIKE_i, MaskedArray[np.signedinteger])
647+
assert_type(MAR_u4 * AR_LIKE_f, MaskedArray[np.floating])
648+
assert_type(MAR_u4 * AR_LIKE_c, MaskedArray[np.complexfloating])
649+
assert_type(MAR_u4 * AR_LIKE_td64, MaskedArray[np.timedelta64])
650+
assert_type(MAR_u4 * AR_LIKE_o, Any)
651+
652+
assert_type(MAR_i8 * AR_LIKE_b, MaskedArray[np.int64])
653+
assert_type(MAR_i8 * AR_LIKE_u, MaskedArray[np.signedinteger])
654+
assert_type(MAR_i8 * AR_LIKE_i, MaskedArray[np.signedinteger])
655+
assert_type(MAR_i8 * AR_LIKE_f, MaskedArray[np.floating])
656+
assert_type(MAR_i8 * AR_LIKE_c, MaskedArray[np.complexfloating])
657+
assert_type(MAR_i8 * AR_LIKE_td64, MaskedArray[np.timedelta64])
658+
assert_type(MAR_i8 * AR_LIKE_o, Any)
659+
660+
assert_type(MAR_f8 * AR_LIKE_b, MaskedArray[np.float64])
661+
assert_type(MAR_f8 * AR_LIKE_u, MaskedArray[np.float64])
662+
assert_type(MAR_f8 * AR_LIKE_i, MaskedArray[np.float64])
663+
assert_type(MAR_f8 * AR_LIKE_f, MaskedArray[np.float64])
664+
assert_type(MAR_f8 * AR_LIKE_c, MaskedArray[np.complexfloating])
665+
assert_type(MAR_f8 * AR_LIKE_o, Any)
666+
667+
# Ignore due to https://github.com/python/mypy/issues/19341
668+
assert_type(AR_LIKE_b * MAR_f8, MaskedArray[np.float64]) # type: ignore[assert-type]
669+
assert_type(AR_LIKE_u * MAR_f8, MaskedArray[np.float64]) # type: ignore[assert-type]
670+
assert_type(AR_LIKE_i * MAR_f8, MaskedArray[np.float64]) # type: ignore[assert-type]
671+
assert_type(AR_LIKE_f * MAR_f8, MaskedArray[np.float64]) # type: ignore[assert-type]
672+
assert_type(AR_LIKE_c * MAR_f8, MaskedArray[np.complexfloating]) # type: ignore[assert-type]
673+
assert_type(AR_LIKE_o * MAR_f8, Any) # type: ignore[assert-type]
674+
675+
assert_type(MAR_c16 * AR_LIKE_b, MaskedArray[np.complex128])
676+
assert_type(MAR_c16 * AR_LIKE_u, MaskedArray[np.complex128])
677+
assert_type(MAR_c16 * AR_LIKE_i, MaskedArray[np.complex128])
678+
assert_type(MAR_c16 * AR_LIKE_f, MaskedArray[np.complex128])
679+
assert_type(MAR_c16 * AR_LIKE_c, MaskedArray[np.complex128])
680+
assert_type(MAR_c16 * AR_LIKE_o, Any)
681+
682+
# Ignore due to https://github.com/python/mypy/issues/19341
683+
assert_type(AR_LIKE_b * MAR_c16, MaskedArray[np.complex128]) # type: ignore[assert-type]
684+
assert_type(AR_LIKE_u * MAR_c16, MaskedArray[np.complex128]) # type: ignore[assert-type]
685+
assert_type(AR_LIKE_i * MAR_c16, MaskedArray[np.complex128]) # type: ignore[assert-type]
686+
assert_type(AR_LIKE_f * MAR_c16, MaskedArray[np.complex128]) # type: ignore[assert-type]
687+
assert_type(AR_LIKE_c * MAR_c16, MaskedArray[np.complex128]) # type: ignore[assert-type]
688+
assert_type(AR_LIKE_o * MAR_c16, Any) # type: ignore[assert-type]
689+
690+
assert_type(MAR_td64 * AR_LIKE_b, MaskedArray[np.timedelta64])
691+
assert_type(MAR_td64 * AR_LIKE_u, MaskedArray[np.timedelta64])
692+
assert_type(MAR_td64 * AR_LIKE_i, MaskedArray[np.timedelta64])
693+
assert_type(MAR_td64 * AR_LIKE_o, Any)
694+
695+
# Ignore due to https://github.com/python/mypy/issues/19341
696+
assert_type(AR_LIKE_b * MAR_td64, MaskedArray[np.timedelta64]) # type: ignore[assert-type]
697+
assert_type(AR_LIKE_u * MAR_td64, MaskedArray[np.timedelta64]) # type: ignore[assert-type]
698+
assert_type(AR_LIKE_i * MAR_td64, MaskedArray[np.timedelta64]) # type: ignore[assert-type]
699+
assert_type(AR_LIKE_td64 * MAR_td64, MaskedArray[np.timedelta64]) # type: ignore[assert-type]
700+
assert_type(AR_LIKE_dt64 * MAR_td64, MaskedArray[np.datetime64]) # type: ignore[assert-type]
701+
assert_type(AR_LIKE_o * MAR_td64, Any) # type: ignore[assert-type]
702+
703+
assert_type(AR_LIKE_o * MAR_dt64, Any) # type: ignore[assert-type]
704+
705+
assert_type(MAR_o * AR_LIKE_b, Any)
706+
assert_type(MAR_o * AR_LIKE_u, Any)
707+
assert_type(MAR_o * AR_LIKE_i, Any)
708+
assert_type(MAR_o * AR_LIKE_f, Any)
709+
assert_type(MAR_o * AR_LIKE_c, Any)
710+
assert_type(MAR_o * AR_LIKE_td64, Any)
711+
assert_type(MAR_o * AR_LIKE_dt64, Any)
712+
assert_type(MAR_o * AR_LIKE_o, Any)
713+
714+
# Ignore due to https://github.com/python/mypy/issues/19341
715+
assert_type(AR_LIKE_b * MAR_o, Any) # type: ignore[assert-type]
716+
assert_type(AR_LIKE_u * MAR_o, Any) # type: ignore[assert-type]
717+
assert_type(AR_LIKE_i * MAR_o, Any) # type: ignore[assert-type]
718+
assert_type(AR_LIKE_f * MAR_o, Any) # type: ignore[assert-type]
719+
assert_type(AR_LIKE_c * MAR_o, Any) # type: ignore[assert-type]
720+
assert_type(AR_LIKE_td64 * MAR_o, Any) # type: ignore[assert-type]
721+
assert_type(AR_LIKE_dt64 * MAR_o, Any) # type: ignore[assert-type]
722+
assert_type(AR_LIKE_o * MAR_o, Any) # type: ignore[assert-type]

0 commit comments

Comments
 (0)