diff --git a/numpy/ma/core.pyi b/numpy/ma/core.pyi index 8461c54aa5f1..cccb76fd61da 100644 --- a/numpy/ma/core.pyi +++ b/numpy/ma/core.pyi @@ -1205,10 +1205,25 @@ identity: _convert2ma indices: _convert2ma ones: _convert2ma ones_like: _convert2ma -squeeze: _convert2ma + +@overload +def squeeze( + a: _ScalarT, + axis: _ShapeLike | None = ..., +) -> _ScalarT: ... +@overload +def squeeze( + a: _ArrayLike[_ScalarT], + axis: _ShapeLike | None = ..., +) -> _MaskedArray[_ScalarT]: ... +@overload +def squeeze( + a: ArrayLike, + axis: _ShapeLike | None = ..., +) -> _MaskedArray[Any]: ... + zeros: _convert2ma zeros_like: _convert2ma def append(a, b, axis=...): ... def dot(a, b, strict=..., out=...): ... -def mask_rowcols(a, axis=...): ... diff --git a/numpy/ma/extras.pyi b/numpy/ma/extras.pyi index ba76f3517526..3b7fec060ead 100644 --- a/numpy/ma/extras.pyi +++ b/numpy/ma/extras.pyi @@ -1,10 +1,18 @@ +from typing import Any, SupportsIndex, overload + from _typeshed import Incomplete import numpy as np from numpy.lib._function_base_impl import average from numpy.lib._index_tricks_impl import AxisConcatenator +from numpy.typing import ArrayLike +from numpy._typing import _ArrayLike -from .core import MaskedArray, dot +from .core import ( + MaskedArray, + dot, + _ScalarT_co, +) __all__ = [ "apply_along_axis", @@ -96,8 +104,17 @@ def compress_nd(x, axis=...): ... def compress_rowcols(x, axis=...): ... def compress_rows(a): ... def compress_cols(a): ... -def mask_rows(a, axis = ...): ... -def mask_cols(a, axis = ...): ... + +@overload +def mask_rows(a: _ArrayLike[_ScalarT_co]) -> MaskedArray[tuple[int, int], np.dtype[_ScalarT_co]]: ... +@overload +def mask_rows(a: ArrayLike) -> MaskedArray[tuple[int, int], np.dtype[Any]]: ... + +@overload +def mask_cols(a: _ArrayLike[_ScalarT_co]) -> MaskedArray[tuple[int, int], np.dtype[_ScalarT_co]]: ... +@overload +def mask_cols(a: ArrayLike) -> MaskedArray[tuple[int, int], np.dtype[Any]]: ... + def ediff1d(arr, to_end=..., to_begin=...): ... def unique(ar1, return_index=..., return_inverse=...): ... def intersect1d(ar1, ar2, assume_unique=...): ... @@ -130,5 +147,7 @@ def clump_masked(a): ... def vander(x, n=...): ... def polyfit(x, y, deg, rcond=..., full=..., w=..., cov=...): ... -# -def mask_rowcols(a: Incomplete, axis: Incomplete | None = None) -> MaskedArray[Incomplete, np.dtype[Incomplete]]: ... +@overload +def mask_rowcols(a: _ArrayLike[_ScalarT_co], axis: SupportsIndex | None = None) -> MaskedArray[tuple[int, int], np.dtype[_ScalarT_co]]: ... +@overload +def mask_rowcols(a: ArrayLike, axis: SupportsIndex | None = None) -> MaskedArray[tuple[int, int], np.dtype]: ... diff --git a/numpy/typing/tests/data/fail/ma.pyi b/numpy/typing/tests/data/fail/ma.pyi index e0b7df5b6ab7..04360e6a25b9 100644 --- a/numpy/typing/tests/data/fail/ma.pyi +++ b/numpy/typing/tests/data/fail/ma.pyi @@ -5,8 +5,8 @@ import numpy.ma import numpy.typing as npt m: np.ma.MaskedArray[tuple[int], np.dtype[np.float64]] - AR_b: npt.NDArray[np.bool] +MAR_2d_f4: np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]] m.shape = (3, 1) # E: Incompatible types in assignment m.dtype = np.bool # E: Incompatible types in assignment @@ -115,4 +115,12 @@ np.ma.put(m, 4, 999, mode='flip') # E: No overload variant np.ma.put([1,1,3], 0, 999) # E: No overload variant +np.ma.squeeze(m, 1.0) # E: No overload variant + +np.ma.mask_rows(MAR_2d_f4, axis=0) # E: No overload variant + +np.ma.mask_cols(MAR_2d_f4, axis=1) # E: No overload variant + +np.ma.mask_rowcols(MAR_2d_f4, axis='broccoli') # E: No overload variant + np.ma.compressed(lambda: 'compress me') # E: No overload variant diff --git a/numpy/typing/tests/data/reveal/ma.pyi b/numpy/typing/tests/data/reveal/ma.pyi index 0a2fe6d593ec..757f78038240 100644 --- a/numpy/typing/tests/data/reveal/ma.pyi +++ b/numpy/typing/tests/data/reveal/ma.pyi @@ -29,6 +29,8 @@ MAR_V: MaskedNDArray[np.void] MAR_subclass: MaskedNDArraySubclass MAR_1d: np.ma.MaskedArray[tuple[int], np.dtype] +MAR_2d: np.ma.MaskedArray[tuple[int, int], np.dtype[Any]] +MAR_2d_f4: np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]] b: np.bool f4: np.float32 @@ -264,12 +266,38 @@ assert_type(np.ma.put(MAR_f4, 4, 999, mode='clip'), None) assert_type(np.ma.putmask(MAR_f4, [True, False], [0, 1]), None) +assert_type(np.ma.squeeze(b), np.bool) +assert_type(np.ma.squeeze(f4), np.float32) +assert_type(np.ma.squeeze(f), MaskedNDArray[Any]) +assert_type(np.ma.squeeze(MAR_b), MaskedNDArray[np.bool]) +assert_type(np.ma.squeeze(AR_f4), MaskedNDArray[np.float32]) + +assert_type(np.ma.mask_rows(MAR_2d_f4), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.ma.mask_rows(MAR_f4), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.ma.mask_rows([[1,2,3]]), np.ma.MaskedArray[tuple[int, int], np.dtype]) +# PyRight detects this one correctly, but mypy doesn't. +assert_type(np.ma.mask_rows(MAR_2d), np.ma.MaskedArray[tuple[int, int], np.dtype]) # type: ignore[assert-type] + +assert_type(np.ma.mask_cols(MAR_2d_f4), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.ma.mask_cols(MAR_f4), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.ma.mask_cols([[1,2,3]]), np.ma.MaskedArray[tuple[int, int], np.dtype]) +# PyRight detects this one correctly, but mypy doesn't. +assert_type(np.ma.mask_cols(MAR_2d), np.ma.MaskedArray[tuple[int, int], np.dtype]) # type: ignore[assert-type] + +assert_type(np.ma.mask_rowcols(MAR_2d_f4), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.ma.mask_rowcols(MAR_f4), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.ma.mask_rowcols(MAR_2d_f4, axis=0), np.ma.MaskedArray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.ma.mask_rowcols([[1,2,3]]), np.ma.MaskedArray[tuple[int, int], np.dtype]) +# PyRight detects this one correctly, but mypy doesn't. +assert_type(np.ma.mask_rowcols(MAR_2d), np.ma.MaskedArray[tuple[int, int], np.dtype]) # type: ignore[assert-type] + assert_type(MAR_f4.filled(float('nan')), NDArray[np.float32]) assert_type(MAR_i8.filled(), NDArray[np.int64]) assert_type(MAR_1d.filled(), np.ndarray[tuple[int], np.dtype]) assert_type(np.ma.filled(MAR_f4, float('nan')), NDArray[np.float32]) assert_type(np.ma.filled([[1,2,3]]), NDArray[Any]) +assert_type(np.ma.filled(MAR_2d_f4), np.ndarray[tuple[int, int], np.dtype[np.float32]]) # PyRight detects this one correctly, but mypy doesn't. # https://github.com/numpy/numpy/pull/28742#discussion_r2048968375 assert_type(np.ma.filled(MAR_1d), np.ndarray[tuple[int], np.dtype]) # type: ignore[assert-type]