🌐 AI搜索 & 代理 主页
Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
BUG: add bounds-checking to in-place string multiply
  • Loading branch information
ngoldbaum committed May 27, 2025
commit db879703b7edaf5d1200d8bc3784fa1460edc7a3
12 changes: 12 additions & 0 deletions numpy/_core/src/umath/string_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,18 @@ struct Buffer {
return num_codepoints;
}

inline size_t
buffer_width()
{
switch (enc) {
case ENCODING::ASCII:
case ENCODING::UTF8:
return after - buf;
case ENCODING::UTF32:
return (after - buf) / sizeof(npy_ucs4);
}
}

inline Buffer<enc>&
operator+=(npy_int64 rhs)
{
Expand Down
13 changes: 11 additions & 2 deletions numpy/_core/src/umath/string_ufuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ string_multiply(Buffer<enc> buf1, npy_int64 reps, Buffer<enc> out)
return;
}

size_t width = out.buffer_width();
size_t pad = 0;
if (width < len1*reps) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (width < len1*reps) {
if (reps > NPY_MAX_SSIZE_T // ssize_t is enough here
|| npy_mul_sizes_with_overflow(len1, reps, &width)
|| width > out.buffer_width()) {
width = out.buffer_width;

Admittedly this is tedious and pedantic, but maybe we should do overflow checks? npy_mul_sizes_with_overflow uses ssize_t inputs, which is safe here but may require casting to tell the compiler about that. (Or just using the size_t version.)
I suppose NPY_UNLIKELY() will be irrelevant here.

reps = width / len1;
pad = width % len1;
}
if (len1 == 1) {
out.buffer_memset(*buf1, reps);
out.buffer_fill_with_zeros_after_index(reps);
Expand All @@ -184,6 +190,8 @@ string_multiply(Buffer<enc> buf1, npy_int64 reps, Buffer<enc> out)
buf1.buffer_memcpy(out, len1);
out += len1;
}
buf1.buffer_memcpy(out, pad);
out += pad;
out.buffer_fill_with_zeros_after_index(0);
}
}
Expand Down Expand Up @@ -752,10 +760,11 @@ string_multiply_resolve_descriptors(
if (given_descrs[2] == NULL) {
PyErr_SetString(
PyExc_TypeError,
"The 'out' kwarg is necessary. Use numpy.strings.multiply without it.");
"The 'out' kwarg is necessary when using the string multiply ufunc "
"directly. Use numpy.strings.multiply to multiply strings without "
"specifying 'out'.");
return _NPY_ERROR_OCCURRED_IN_CAST;
}

loop_descrs[0] = NPY_DT_CALL_ensure_canonical(given_descrs[0]);
if (loop_descrs[0] == NULL) {
return _NPY_ERROR_OCCURRED_IN_CAST;
Expand Down
8 changes: 8 additions & 0 deletions numpy/_core/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ def test_multiply_raises(self, dt):
with pytest.raises(MemoryError):
np.strings.multiply(np.array("abc", dtype=dt), sys.maxsize)

def test_inplace_multiply(self, dt):
arr = np.array(['foo ', 'bar'], dtype=dt)
arr *= 2
if dt != "T":
assert_array_equal(arr, np.array(['foo ', 'barb'], dtype=dt))
else:
assert_array_equal(arr, ['foo foo ', 'barbar'])

@pytest.mark.parametrize("i_dt", [np.int8, np.int16, np.int32,
np.int64, np.int_])
def test_multiply_integer_dtypes(self, i_dt, dt):
Expand Down
3 changes: 2 additions & 1 deletion numpy/typing/tests/data/pass/ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
MAR_M_dt64: MaskedArray[np.datetime64] = np.ma.MaskedArray([np.datetime64(1, "D")])
MAR_S: MaskedArray[np.bytes_] = np.ma.MaskedArray([b'foo'], dtype=np.bytes_)
MAR_U: MaskedArray[np.str_] = np.ma.MaskedArray(['foo'], dtype=np.str_)
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType], np.ma.MaskedArray(["a"], "T"))
MAR_T = cast(np.ma.MaskedArray[Any, np.dtypes.StringDType],
np.ma.MaskedArray(["a"], dtype="T"))

AR_b: npt.NDArray[np.bool] = np.array([True, False, True])

Expand Down
Loading