🌐 AI搜索 & 代理 主页
Skip to content

Commit 107d419

Browse files
committed
Consider all possible code paths
1 parent 2214ad2 commit 107d419

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

numpy/_core/src/multiarray/methods.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -962,17 +962,19 @@ array_getarray(PyArrayObject *self, PyObject *args, PyObject *kwds)
962962
if (copy == NPY_COPY_ALWAYS) {
963963
if (newtype == NULL) {
964964
newtype = PyArray_DESCR(self);
965+
Py_INCREF(newtype);
965966
}
966-
Py_INCREF(newtype);
967967
ret = PyArray_CastToType(self, newtype, 0);
968968
Py_DECREF(self);
969969
return ret;
970970
} else { // copy == NPY_COPY_IF_NEEDED || copy == NPY_COPY_NEVER
971-
if (newtype == NULL || PyArray_EquivTypes(PyArray_DESCR(self), newtype)) {
971+
if (newtype == NULL) {
972+
return (PyObject *)self;
973+
} else if (PyArray_EquivTypes(PyArray_DESCR(self), newtype)) {
974+
Py_DECREF(newtype);
972975
return (PyObject *)self;
973976
}
974977
if (copy == NPY_COPY_IF_NEEDED) {
975-
Py_INCREF(newtype);
976978
ret = PyArray_CastToType(self, newtype, 0);
977979
Py_DECREF(self);
978980
return ret;

numpy/_core/tests/test_api.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,28 @@ def custom__array__(self, dtype=None, copy=None):
9999
assert_equal(np.array(o, dtype=np.float64), np.array(100.0, np.float64))
100100
if HAS_REFCOUNT:
101101
class MyArray:
102-
def __init__(self):
103-
self.val = np.array(-1, dtype=dt)
102+
def __init__(self, dtype):
103+
self.val = np.array(-1, dtype=dtype)
104104

105105
def __array__(self, dtype=None, copy=None):
106106
return self.val.__array__(dtype=dtype, copy=copy)
107107

108108
dt = np.dtype(np.int32)
109109
old_refcount = sys.getrefcount(dt)
110-
np.array(MyArray())
110+
np.array(MyArray(dt))
111111
assert_equal(old_refcount, sys.getrefcount(dt))
112+
np.array(MyArray(dt), dtype=dt)
113+
assert_equal(old_refcount, sys.getrefcount(dt))
114+
np.array(MyArray(dt), copy=None)
115+
assert_equal(old_refcount, sys.getrefcount(dt))
116+
np.array(MyArray(dt), dtype=dt, copy=None)
117+
assert_equal(old_refcount, sys.getrefcount(dt))
118+
dt2 = np.dtype(np.int16)
119+
old_refcount2 = sys.getrefcount(dt2)
120+
np.array(MyArray(dt), dtype=dt2)
121+
assert_equal(old_refcount2, sys.getrefcount(dt2))
122+
np.array(MyArray(dt), dtype=dt2, copy=None)
123+
assert_equal(old_refcount2, sys.getrefcount(dt2))
112124

113125
# test recursion
114126
nested = 1.5

0 commit comments

Comments
 (0)