🌐 AI搜索 & 代理 主页
Skip to content

Commit c9e7592

Browse files
authored
BUG: Fix dtype refcount in __array__ (#29715)
* BUG: Fix `dtype` refcount in `__array__` * Consider all possible code paths * Remove else-if branch * Move refcount checks to a separate test * Add code comments * Add missing `Py_DECREF` for error path * Apply review comments
1 parent 61a6b61 commit c9e7592

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

numpy/_core/src/multiarray/methods.c

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,11 @@ array_getarray(PyArrayObject *self, PyObject *args, PyObject *kwds)
934934
return NULL;
935935
}
936936

937+
if (newtype == NULL) {
938+
newtype = PyArray_DESCR(self);
939+
Py_INCREF(newtype); // newtype is owned.
940+
}
941+
937942
/* convert to PyArray_Type */
938943
if (!PyArray_CheckExact(self)) {
939944
PyArrayObject *new;
@@ -951,6 +956,7 @@ array_getarray(PyArrayObject *self, PyObject *args, PyObject *kwds)
951956
(PyObject *)self
952957
);
953958
if (new == NULL) {
959+
Py_DECREF(newtype);
954960
return NULL;
955961
}
956962
self = new;
@@ -960,22 +966,21 @@ array_getarray(PyArrayObject *self, PyObject *args, PyObject *kwds)
960966
}
961967

962968
if (copy == NPY_COPY_ALWAYS) {
963-
if (newtype == NULL) {
964-
newtype = PyArray_DESCR(self);
965-
}
966-
ret = PyArray_CastToType(self, newtype, 0);
969+
ret = PyArray_CastToType(self, newtype, 0); // steals newtype reference
967970
Py_DECREF(self);
968971
return ret;
969972
} else { // copy == NPY_COPY_IF_NEEDED || copy == NPY_COPY_NEVER
970-
if (newtype == NULL || PyArray_EquivTypes(PyArray_DESCR(self), newtype)) {
973+
if (PyArray_EquivTypes(PyArray_DESCR(self), newtype)) {
974+
Py_DECREF(newtype);
971975
return (PyObject *)self;
972976
}
973977
if (copy == NPY_COPY_IF_NEEDED) {
974-
ret = PyArray_CastToType(self, newtype, 0);
978+
ret = PyArray_CastToType(self, newtype, 0); // steals newtype reference.
975979
Py_DECREF(self);
976980
return ret;
977981
} else { // copy == NPY_COPY_NEVER
978982
PyErr_SetString(PyExc_ValueError, npy_no_copy_err_msg);
983+
Py_DECREF(newtype);
979984
Py_DECREF(self);
980985
return NULL;
981986
}

numpy/_core/tests/test_api.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_array_array():
9191
# instead we get a array([...], dtype=">V18")
9292
assert_equal(bytes(np.array(o).data), bytes(a.data))
9393

94-
# test array
94+
# test __array__
9595
def custom__array__(self, dtype=None, copy=None):
9696
return np.array(100.0, dtype=dtype, copy=copy)
9797

@@ -157,6 +157,39 @@ def custom__array__(self, dtype=None, copy=None):
157157
assert_equal(np.array([(1.0,) * 10] * 10, dtype=np.float64),
158158
np.ones((10, 10), dtype=np.float64))
159159

160+
161+
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
162+
def test___array___refcount():
163+
class MyArray:
164+
def __init__(self, dtype):
165+
self.val = np.array(-1, dtype=dtype)
166+
167+
def __array__(self, dtype=None, copy=None):
168+
return self.val.__array__(dtype=dtype, copy=copy)
169+
170+
# test all possible scenarios:
171+
# dtype(none | same | different) x copy(true | false | none)
172+
dt = np.dtype(np.int32)
173+
old_refcount = sys.getrefcount(dt)
174+
np.array(MyArray(dt))
175+
assert_equal(old_refcount, sys.getrefcount(dt))
176+
np.array(MyArray(dt), dtype=dt)
177+
assert_equal(old_refcount, sys.getrefcount(dt))
178+
np.array(MyArray(dt), copy=None)
179+
assert_equal(old_refcount, sys.getrefcount(dt))
180+
np.array(MyArray(dt), dtype=dt, copy=None)
181+
assert_equal(old_refcount, sys.getrefcount(dt))
182+
dt2 = np.dtype(np.int16)
183+
old_refcount2 = sys.getrefcount(dt2)
184+
np.array(MyArray(dt), dtype=dt2)
185+
assert_equal(old_refcount2, sys.getrefcount(dt2))
186+
np.array(MyArray(dt), dtype=dt2, copy=None)
187+
assert_equal(old_refcount2, sys.getrefcount(dt2))
188+
with pytest.raises(ValueError):
189+
np.array(MyArray(dt), dtype=dt2, copy=False)
190+
assert_equal(old_refcount2, sys.getrefcount(dt2))
191+
192+
160193
@pytest.mark.parametrize("array", [True, False])
161194
def test_array_impossible_casts(array):
162195
# All builtin types can be forcibly cast, at least theoretically,

0 commit comments

Comments
 (0)