🌐 AI搜索 & 代理 主页
Skip to content

Commit ba27d95

Browse files
authored
BUG: add bounds-checking to in-place string multiply (#29060)
* BUG: add bounds-checking to in-place string multiply
* MNT: check for overflow and raise OverflowError
* MNT: respond to review suggestion
* MNT: handle overflow in one more spot
* MNT: make test behave the same on all architectures
* MNT: reorder to avoid work in some cases
1 parent f1e7527 commit ba27d95

File tree

8 files changed

+74
-23
lines changed

8 files changed

+74
-23
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
* Multiplication between a string and integer now raises OverflowError instead
2+
of MemoryError if the result of the multiplication would create a string that
3+
is too large to be represented. This follows Python's behavior.

numpy/_core/src/umath/string_buffer.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,18 @@ struct Buffer {
297297
return num_codepoints;
298298
}
299299

300+
inline size_t
301+
buffer_width()
302+
{
303+
switch (enc) {
304+
case ENCODING::ASCII:
305+
case ENCODING::UTF8:
306+
return after - buf;
307+
case ENCODING::UTF32:
308+
return (after - buf) / sizeof(npy_ucs4);
309+
}
310+
}
311+
300312
inline Buffer<enc>&
301313
operator+=(npy_int64 rhs)
302314
{

numpy/_core/src/umath/string_ufuncs.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "dtypemeta.h"
1616
#include "convert_datatype.h"
1717
#include "gil_utils.h"
18+
#include "templ_common.h" /* for npy_mul_with_overflow_size_t */
1819

1920
#include "string_ufuncs.h"
2021
#include "string_fastsearch.h"
@@ -166,26 +167,44 @@ string_add(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out)
166167

167168

168169
template <ENCODING enc>
169-
static inline void
170+
static inline int
170171
string_multiply(Buffer<enc> buf1, npy_int64 reps, Buffer<enc> out)
171172
{
172173
size_t len1 = buf1.num_codepoints();
173174
if (reps < 1 || len1 == 0) {
174175
out.buffer_fill_with_zeros_after_index(0);
175-
return;
176+
return 0;
176177
}
177178

178179
if (len1 == 1) {
179180
out.buffer_memset(*buf1, reps);
180181
out.buffer_fill_with_zeros_after_index(reps);
182+
return 0;
181183
}
182-
else {
183-
for (npy_int64 i = 0; i < reps; i++) {
184-
buf1.buffer_memcpy(out, len1);
185-
out += len1;
186-
}
187-
out.buffer_fill_with_zeros_after_index(0);
184+
185+
size_t newlen;
186+
if (NPY_UNLIKELY(npy_mul_with_overflow_size_t(&newlen, reps, len1) != 0) || newlen > PY_SSIZE_T_MAX) {
187+
return -1;
188+
}
189+
190+
size_t pad = 0;
191+
size_t width = out.buffer_width();
192+
if (width < newlen) {
193+
reps = width / len1;
194+
pad = width % len1;
188195
}
196+
197+
for (npy_int64 i = 0; i < reps; i++) {
198+
buf1.buffer_memcpy(out, len1);
199+
out += len1;
200+
}
201+
202+
buf1.buffer_memcpy(out, pad);
203+
out += pad;
204+
205+
out.buffer_fill_with_zeros_after_index(0);
206+
207+
return 0;
189208
}
190209

191210

@@ -238,7 +257,9 @@ string_multiply_strint_loop(PyArrayMethod_Context *context,
238257
while (N--) {
239258
Buffer<enc> buf(in1, elsize);
240259
Buffer<enc> outbuf(out, outsize);
241-
string_multiply<enc>(buf, *(npy_int64 *)in2, outbuf);
260+
if (NPY_UNLIKELY(string_multiply<enc>(buf, *(npy_int64 *)in2, outbuf) < 0)) {
261+
npy_gil_error(PyExc_OverflowError, "Overflow detected in string multiply");
262+
}
242263

243264
in1 += strides[0];
244265
in2 += strides[1];
@@ -267,7 +288,9 @@ string_multiply_intstr_loop(PyArrayMethod_Context *context,
267288
while (N--) {
268289
Buffer<enc> buf(in2, elsize);
269290
Buffer<enc> outbuf(out, outsize);
270-
string_multiply<enc>(buf, *(npy_int64 *)in1, outbuf);
291+
if (NPY_UNLIKELY(string_multiply<enc>(buf, *(npy_int64 *)in1, outbuf) < 0)) {
292+
npy_gil_error(PyExc_OverflowError, "Overflow detected in string multiply");
293+
}
271294

272295
in1 += strides[0];
273296
in2 += strides[1];
@@ -752,10 +775,11 @@ string_multiply_resolve_descriptors(
752775
if (given_descrs[2] == NULL) {
753776
PyErr_SetString(
754777
PyExc_TypeError,
755-
"The 'out' kwarg is necessary. Use numpy.strings.multiply without it.");
778+
"The 'out' kwarg is necessary when using the string multiply ufunc "
779+
"directly. Use numpy.strings.multiply to multiply strings without "
780+
"specifying 'out'.");
756781
return _NPY_ERROR_OCCURRED_IN_CAST;
757782
}
758-
759783
loop_descrs[0] = NPY_DT_CALL_ensure_canonical(given_descrs[0]);
760784
if (loop_descrs[0] == NULL) {
761785
return _NPY_ERROR_OCCURRED_IN_CAST;

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,9 @@ static int multiply_loop_core(
137137
size_t newsize;
138138
int overflowed = npy_mul_with_overflow_size_t(
139139
&newsize, cursize, factor);
140-
if (overflowed) {
141-
npy_gil_error(PyExc_MemoryError,
142-
"Failed to allocate string in string multiply");
140+
if (overflowed || newsize > PY_SSIZE_T_MAX) {
141+
npy_gil_error(PyExc_OverflowError,
142+
"Overflow encountered in string multiply");
143143
goto fail;
144144
}
145145

@@ -1748,9 +1748,9 @@ center_ljust_rjust_strided_loop(PyArrayMethod_Context *context,
17481748
width - num_codepoints);
17491749
newsize += s1.size;
17501750

1751-
if (overflowed) {
1752-
npy_gil_error(PyExc_MemoryError,
1753-
"Failed to allocate string in %s", ufunc_name);
1751+
if (overflowed || newsize > PY_SSIZE_T_MAX) {
1752+
npy_gil_error(PyExc_OverflowError,
1753+
"Overflow encountered in %s", ufunc_name);
17541754
goto fail;
17551755
}
17561756

numpy/_core/strings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def multiply(a, i):
218218

219219
# Ensure we can do a_len * i without overflow.
220220
if np.any(a_len > sys.maxsize / np.maximum(i, 1)):
221-
raise MemoryError("repeated string is too long")
221+
raise OverflowError("Overflow encountered in string multiply")
222222

223223
buffersizes = a_len * i
224224
out_dtype = f"{a.dtype.char}{buffersizes.max()}"

numpy/_core/tests/test_stringdtype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def test_null_roundtripping():
128128

129129
def test_string_too_large_error():
130130
arr = np.array(["a", "b", "c"], dtype=StringDType())
131-
with pytest.raises(MemoryError):
132-
arr * (2**63 - 2)
131+
with pytest.raises(OverflowError):
132+
arr * (sys.maxsize + 1)
133133

134134

135135
@pytest.mark.parametrize(

numpy/_core/tests/test_strings.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,20 @@ def test_multiply_raises(self, dt):
224224
with pytest.raises(TypeError, match="unsupported type"):
225225
np.strings.multiply(np.array("abc", dtype=dt), 3.14)
226226

227-
with pytest.raises(MemoryError):
227+
with pytest.raises(OverflowError):
228228
np.strings.multiply(np.array("abc", dtype=dt), sys.maxsize)
229229

230+
def test_inplace_multiply(self, dt):
231+
arr = np.array(['foo ', 'bar'], dtype=dt)
232+
arr *= 2
233+
if dt != "T":
234+
assert_array_equal(arr, np.array(['foo ', 'barb'], dtype=dt))
235+
else:
236+
assert_array_equal(arr, ['foo foo ', 'barbar'])
237+
238+
with pytest.raises(OverflowError):
239+
arr *= sys.maxsize
240+
230241
@pytest.mark.parametrize("i_dt", [np.int8, np.int16, np.int32,
231242
np.int64, np.int_])
232243
def test_multiply_integer_dtypes(self, i_dt, dt):

numpy/typing/tests/data/pass/ma.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
MAR_M_dt64: MaskedArray[np.datetime64] = np.ma.MaskedArray([np.datetime64(1, "D")])
1717
MAR_S: MaskedArray[np.bytes_] = np.ma.MaskedArray([b'foo'], dtype=np.bytes_)
1818
MAR_U: MaskedArray[np.str_] = np.ma.MaskedArray(['foo'], dtype=np.str_)
19-
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType], np.ma.MaskedArray(["a"], "T"))
19+
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType],
20+
np.ma.MaskedArray(["a"], dtype="T"))
2021

2122
AR_b: npt.NDArray[np.bool] = np.array([True, False, True])
2223

0 commit comments

Comments
 (0)