🌐 AI搜索 & 代理 主页
Skip to content

Commit a64d310

Browse files
authored
ENH,REF: New-style generic sorting ArrayMethod loops (#30328)
* ENH,REF: New arraymethod-style default sorting loops * ENH: Implement single default sorting loops in npy_sort.c * STYLE: Add missing newline
1 parent 0faa919 commit a64d310

File tree

5 files changed

+89
-14
lines changed

5 files changed

+89
-14
lines changed

numpy/_core/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,7 @@ src_multiarray_umath_common = [
11211121
'src/common/numpyos.c',
11221122
'src/common/npy_cpu_features.c',
11231123
'src/common/npy_cpu_dispatch.c',
1124+
'src/common/npy_sort.c',
11241125
src_file.process('src/common/templ_common.h.src')
11251126
]
11261127
if have_blas

numpy/_core/src/common/npy_sort.c

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#include <numpy/ndarraytypes.h>
2+
#include <stdlib.h>
3+
#include <numpy/npy_math.h>
4+
#include "npy_sort.h"
5+
#include "dtypemeta.h"
6+
7+
#ifdef __cplusplus
8+
extern "C" {
9+
#endif
10+
11+
NPY_NO_EXPORT int
12+
npy_default_sort_loop(PyArrayMethod_Context *context,
13+
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
14+
NpyAuxData *transferdata)
15+
{
16+
PyArray_CompareFunc *cmp = (PyArray_CompareFunc *)context->method->static_data;
17+
18+
PyArrayMethod_SortParameters *sort_params =
19+
(PyArrayMethod_SortParameters *)context->parameters;
20+
PyArray_SortImpl *sort_func = NULL;
21+
22+
switch (sort_params->flags) {
23+
case NPY_SORT_DEFAULT:
24+
sort_func = npy_quicksort_impl;
25+
break;
26+
case NPY_SORT_STABLE:
27+
sort_func = npy_mergesort_impl;
28+
break;
29+
default:
30+
PyErr_SetString(PyExc_ValueError, "Invalid sort kind");
31+
return -1;
32+
}
33+
34+
return sort_func(data[0], dimensions[0], context,
35+
context->descriptors[0]->elsize, cmp);
36+
}
37+
38+
NPY_NO_EXPORT int
39+
npy_default_argsort_loop(PyArrayMethod_Context *context,
40+
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
41+
NpyAuxData *transferdata)
42+
{
43+
PyArray_CompareFunc *cmp = (PyArray_CompareFunc *)context->method->static_data;
44+
45+
PyArrayMethod_SortParameters *sort_params =
46+
(PyArrayMethod_SortParameters *)context->parameters;
47+
PyArray_ArgSortImpl *argsort_func = NULL;
48+
49+
switch (sort_params->flags) {
50+
case NPY_SORT_DEFAULT:
51+
argsort_func = npy_aquicksort_impl;
52+
break;
53+
case NPY_SORT_STABLE:
54+
argsort_func = npy_amergesort_impl;
55+
break;
56+
default:
57+
PyErr_SetString(PyExc_ValueError, "Invalid sort kind");
58+
return -1;
59+
}
60+
61+
return argsort_func(data[0], (npy_intp *)data[1], dimensions[0], context,
62+
context->descriptors[0]->elsize, cmp);
63+
}
64+
65+
#ifdef __cplusplus
66+
}
67+
#endif

numpy/_core/src/common/npy_sort.h.src

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <Python.h>
66
#include <numpy/npy_common.h>
77
#include <numpy/ndarraytypes.h>
8+
#include <dtypemeta.h>
89

910
#define NPY_ENOMEM 1
1011
#define NPY_ECOMP 2
@@ -107,6 +108,18 @@ NPY_NO_EXPORT int npy_aheapsort(void *vec, npy_intp *ind, npy_intp cnt, void *ar
107108
NPY_NO_EXPORT int npy_amergesort(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
108109
NPY_NO_EXPORT int npy_atimsort(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
109110

111+
/*
112+
*****************************************************************************
113+
** NEW-STYLE GENERIC SORT **
114+
*****************************************************************************
115+
*/
116+
117+
NPY_NO_EXPORT int npy_default_sort_loop(PyArrayMethod_Context *context,
118+
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
119+
NpyAuxData *transferdata);
120+
NPY_NO_EXPORT int npy_default_argsort_loop(PyArrayMethod_Context *context,
121+
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
122+
NpyAuxData *transferdata);
110123

111124
/*
112125
*****************************************************************************

numpy/_core/src/multiarray/item_selection.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3179,6 +3179,7 @@ PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
31793179
}
31803180
context.descriptors = loop_descrs;
31813181
context.parameters = &sort_params;
3182+
context.method = sort_method;
31823183

31833184
// Arrays are always contiguous for sorting
31843185
npy_intp strides[2] = {loop_descrs[0]->elsize, loop_descrs[1]->elsize};
@@ -3290,6 +3291,7 @@ PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND flags)
32903291
}
32913292
context.descriptors = loop_descrs;
32923293
context.parameters = &sort_params;
3294+
context.method = argsort_method;
32933295

32943296
// Arrays are always contiguous for sorting
32953297
npy_intp strides[2] = {loop_descrs[0]->elsize, loop_descrs[1]->elsize};

numpy/_core/src/multiarray/stringdtype/dtype.c

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -692,14 +692,9 @@ stringdtype_wrap_sort_loop(
692692
{
693693
PyArray_StringDTypeObject *sdescr =
694694
(PyArray_StringDTypeObject *)context->descriptors[0];
695-
PyArray_SortImpl *sort_loop =
696-
((PyArrayMethod_SortParameters *)context->parameters)->flags
697-
== NPY_SORT_STABLE ? &npy_mergesort_impl : &npy_quicksort_impl;
698695

699696
npy_string_allocator *allocator = NpyString_acquire_allocator(sdescr);
700-
int ret = sort_loop(
701-
data[0], dimensions[0], context,
702-
context->descriptors[0]->elsize, &_sort_compare);
697+
int ret = npy_default_sort_loop(context, data, dimensions, strides, transferdata);
703698
NpyString_release_allocator(allocator);
704699
return ret;
705700
}
@@ -737,14 +732,9 @@ stringdtype_wrap_argsort_loop(
737732
{
738733
PyArray_StringDTypeObject *sdescr =
739734
(PyArray_StringDTypeObject *)context->descriptors[0];
740-
PyArray_ArgSortImpl *argsort_loop =
741-
((PyArrayMethod_SortParameters *)context->parameters)
742-
->flags == NPY_SORT_STABLE ? &npy_amergesort_impl : &npy_aquicksort_impl;
743735

744736
npy_string_allocator *allocator = NpyString_acquire_allocator(sdescr);
745-
int ret = argsort_loop(
746-
data[0], (npy_intp *)data[1], dimensions[0], context,
747-
context->descriptors[0]->elsize, &_sort_compare);
737+
int ret = npy_default_argsort_loop(context, data, dimensions, strides, transferdata);
748738
NpyString_release_allocator(allocator);
749739
return ret;
750740
}
@@ -965,9 +955,10 @@ init_stringdtype_sorts(void)
965955
PyArray_DTypeMeta *stringdtype = &PyArray_StringDType;
966956

967957
PyArray_DTypeMeta *sort_dtypes[2] = {stringdtype, stringdtype};
968-
PyType_Slot sort_slots[3] = {
958+
PyType_Slot sort_slots[4] = {
969959
{NPY_METH_resolve_descriptors, &stringdtype_sort_resolve_descriptors},
970960
{NPY_METH_get_loop, &stringdtype_get_sort_loop},
961+
{_NPY_METH_static_data, &_sort_compare},
971962
{0, NULL}
972963
};
973964
PyArrayMethod_Spec sort_spec = {
@@ -989,8 +980,9 @@ init_stringdtype_sorts(void)
989980
Py_DECREF(sort_method);
990981

991982
PyArray_DTypeMeta *argsort_dtypes[2] = {stringdtype, &PyArray_IntpDType};
992-
PyType_Slot argsort_slots[2] = {
983+
PyType_Slot argsort_slots[3] = {
993984
{NPY_METH_get_loop, &stringdtype_get_argsort_loop},
985+
{_NPY_METH_static_data, &_sort_compare},
994986
{0, NULL}
995987
};
996988
PyArrayMethod_Spec argsort_spec = {

0 commit comments

Comments
 (0)