|
15 | 15 | #include "dtypemeta.h" |
16 | 16 | #include "convert_datatype.h" |
17 | 17 | #include "gil_utils.h" |
| 18 | +#include "templ_common.h" /* for npy_mul_size_with_overflow_size_t */ |
18 | 19 |
|
19 | 20 | #include "string_ufuncs.h" |
20 | 21 | #include "string_fastsearch.h" |
@@ -166,26 +167,44 @@ string_add(Buffer<enc> buf1, Buffer<enc> buf2, Buffer<enc> out) |
166 | 167 |
|
167 | 168 |
|
168 | 169 | template <ENCODING enc> |
169 | | -static inline void |
| 170 | +static inline int |
170 | 171 | string_multiply(Buffer<enc> buf1, npy_int64 reps, Buffer<enc> out) |
171 | 172 | { |
172 | 173 | size_t len1 = buf1.num_codepoints(); |
173 | 174 | if (reps < 1 || len1 == 0) { |
174 | 175 | out.buffer_fill_with_zeros_after_index(0); |
175 | | - return; |
| 176 | + return 0; |
176 | 177 | } |
177 | 178 |
|
178 | 179 | if (len1 == 1) { |
179 | 180 | out.buffer_memset(*buf1, reps); |
180 | 181 | out.buffer_fill_with_zeros_after_index(reps); |
| 182 | + return 0; |
181 | 183 | } |
182 | | - else { |
183 | | - for (npy_int64 i = 0; i < reps; i++) { |
184 | | - buf1.buffer_memcpy(out, len1); |
185 | | - out += len1; |
186 | | - } |
187 | | - out.buffer_fill_with_zeros_after_index(0); |
| 184 | + |
| 185 | + size_t newlen; |
| 186 | + if (NPY_UNLIKELY(npy_mul_with_overflow_size_t(&newlen, reps, len1) != 0) || newlen > PY_SSIZE_T_MAX) { |
| 187 | + return -1; |
| 188 | + } |
| 189 | + |
| 190 | + size_t pad = 0; |
| 191 | + size_t width = out.buffer_width(); |
| 192 | + if (width < newlen) { |
| 193 | + reps = width / len1; |
| 194 | + pad = width % len1; |
188 | 195 | } |
| 196 | + |
| 197 | + for (npy_int64 i = 0; i < reps; i++) { |
| 198 | + buf1.buffer_memcpy(out, len1); |
| 199 | + out += len1; |
| 200 | + } |
| 201 | + |
| 202 | + buf1.buffer_memcpy(out, pad); |
| 203 | + out += pad; |
| 204 | + |
| 205 | + out.buffer_fill_with_zeros_after_index(0); |
| 206 | + |
| 207 | + return 0; |
189 | 208 | } |
190 | 209 |
|
191 | 210 |
|
@@ -238,7 +257,9 @@ string_multiply_strint_loop(PyArrayMethod_Context *context, |
238 | 257 | while (N--) { |
239 | 258 | Buffer<enc> buf(in1, elsize); |
240 | 259 | Buffer<enc> outbuf(out, outsize); |
241 | | - string_multiply<enc>(buf, *(npy_int64 *)in2, outbuf); |
| 260 | + if (NPY_UNLIKELY(string_multiply<enc>(buf, *(npy_int64 *)in2, outbuf) < 0)) { |
| 261 | + npy_gil_error(PyExc_OverflowError, "Overflow detected in string multiply"); |
| 262 | + } |
242 | 263 |
|
243 | 264 | in1 += strides[0]; |
244 | 265 | in2 += strides[1]; |
@@ -267,7 +288,9 @@ string_multiply_intstr_loop(PyArrayMethod_Context *context, |
267 | 288 | while (N--) { |
268 | 289 | Buffer<enc> buf(in2, elsize); |
269 | 290 | Buffer<enc> outbuf(out, outsize); |
270 | | - string_multiply<enc>(buf, *(npy_int64 *)in1, outbuf); |
| 291 | + if (NPY_UNLIKELY(string_multiply<enc>(buf, *(npy_int64 *)in1, outbuf) < 0)) { |
| 292 | + npy_gil_error(PyExc_OverflowError, "Overflow detected in string multiply"); |
| 293 | + } |
271 | 294 |
|
272 | 295 | in1 += strides[0]; |
273 | 296 | in2 += strides[1]; |
@@ -752,10 +775,11 @@ string_multiply_resolve_descriptors( |
752 | 775 | if (given_descrs[2] == NULL) { |
753 | 776 | PyErr_SetString( |
754 | 777 | PyExc_TypeError, |
755 | | - "The 'out' kwarg is necessary. Use numpy.strings.multiply without it."); |
| 778 | + "The 'out' kwarg is necessary when using the string multiply ufunc " |
| 779 | + "directly. Use numpy.strings.multiply to multiply strings without " |
| 780 | + "specifying 'out'."); |
756 | 781 | return _NPY_ERROR_OCCURRED_IN_CAST; |
757 | 782 | } |
758 | | - |
759 | 783 | loop_descrs[0] = NPY_DT_CALL_ensure_canonical(given_descrs[0]); |
760 | 784 | if (loop_descrs[0] == NULL) { |
761 | 785 | return _NPY_ERROR_OCCURRED_IN_CAST; |
|
0 commit comments