🌐 AI搜索 & 代理 主页
Skip to content

Commit 341d34c

Browse files
authored
Merge pull request #194 from DirkToewe/k_neighbors_kd_tree_fix
[KdTree] fixes + test speedups
2 parents 1403a4a + 5f681ea commit 341d34c

File tree

7 files changed

+363
-111
lines changed

7 files changed

+363
-111
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/**
2+
* @license
3+
* Copyright 2022, JsData. All rights reserved.
4+
*
5+
* This source code is licensed under the MIT license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ==========================================================================
14+
*/
15+
16+
import { neighborhoodGenericTests } from './neighborhoodGenericTests'
17+
import { BruteNeighborhood } from './bruteNeighborhood'
18+
19+
neighborhoodGenericTests(
20+
'BruteNeighborhood',
21+
async (params) => new BruteNeighborhood(params)
22+
)

src/neighbors/kNeighborsBase.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ export interface KNeighborsParams {
107107
* Handles common constructor parameters and fitting.
108108
*/
109109
export class KNeighborsBase implements KNeighborsParams {
110+
static readonly SUPPORTED_ALGORITHMS = Object.freeze(
111+
Object.keys(ALGORITHMS)
112+
) as (keyof typeof ALGORITHMS)[]
113+
110114
private _neighborhood: Neighborhood | undefined
111115
private _y: Tensor1D | undefined
112116

@@ -117,7 +121,7 @@ export class KNeighborsBase implements KNeighborsParams {
117121
metric: KNeighborsParams['metric']
118122
nNeighbors: KNeighborsParams['nNeighbors']
119123

120-
constructor(params: KNeighborsParams) {
124+
constructor(params: KNeighborsParams = {}) {
121125
Object.assign(this, params)
122126
}
123127

src/neighbors/kNeighborsClassifier.test.ts

Lines changed: 76 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -34,83 +34,88 @@ function testWithDataset(
3434
params: KNeighborsParams,
3535
referenceAccuracy: number
3636
) {
37-
it(`KNeighborsClassifier(${JSON.stringify(params)}) fits ${
38-
loadData.name
39-
} as well as sklearn`, async () => {
40-
const df = await loadData()
41-
42-
const Xy = df.tensor as unknown as Tensor2D
43-
let [nSamples, nFeatures] = Xy.shape
44-
--nFeatures
45-
46-
const X = Xy.slice([0, 0], [nSamples, nFeatures])
47-
const y = Xy.slice([0, nFeatures]).reshape([nSamples]) as Tensor1D
48-
49-
const accuracies = await crossValScore(
50-
new KNeighborsClassifier(params),
51-
X,
52-
y,
53-
{
54-
cv: new KFold({ nSplits: 3 })
55-
}
56-
)
37+
it(
38+
`matches sklearn fitting ${loadData.name}`.padEnd(48) +
39+
JSON.stringify(params),
40+
async () => {
41+
const df = await loadData()
42+
43+
const Xy = df.tensor as unknown as Tensor2D
44+
let [nSamples, nFeatures] = Xy.shape
45+
--nFeatures
46+
47+
const X = Xy.slice([0, 0], [nSamples, nFeatures])
48+
const y = Xy.slice([0, nFeatures]).reshape([nSamples]) as Tensor1D
49+
50+
const accuracies = await crossValScore(
51+
new KNeighborsClassifier(params),
52+
X,
53+
y,
54+
{
55+
cv: new KFold({ nSplits: 3 })
56+
}
57+
)
5758

58-
expect(accuracies.mean()).toBeAllCloseTo(referenceAccuracy, {
59-
atol: 0,
60-
rtol: 0.005
61-
})
62-
}, 600_000)
59+
expect(accuracies.mean()).toBeAllCloseTo(referenceAccuracy, {
60+
atol: 0,
61+
rtol: 0.005
62+
})
63+
},
64+
60_000
65+
)
6366
}
6467

6568
for (const algorithm of [
66-
'kdTree',
67-
'brute',
68-
undefined,
69-
'auto'
70-
] as KNeighborsParams['algorithm'][]) {
71-
testWithDataset(
72-
loadDigits,
73-
{ nNeighbors: 5, weights: 'distance', algorithm },
74-
0.963
75-
)
76-
testWithDataset(
77-
loadIris,
78-
{ nNeighbors: 5, weights: 'distance', algorithm },
79-
0.0
80-
)
81-
testWithDataset(
82-
loadWine,
83-
{ nNeighbors: 5, weights: 'distance', algorithm },
84-
0.135
85-
)
86-
testWithDataset(
87-
loadBreastCancer,
88-
{ nNeighbors: 5, weights: 'distance', algorithm },
89-
0.92
90-
)
69+
...KNeighborsClassifier.SUPPORTED_ALGORITHMS,
70+
undefined
71+
]) {
72+
describe(`KNeighborsClassifier({ algorithm: ${algorithm} })`, () => {
73+
testWithDataset(
74+
loadIris,
75+
{ nNeighbors: 5, weights: 'distance', algorithm },
76+
0.0
77+
)
78+
testWithDataset(
79+
loadIris,
80+
{ nNeighbors: 3, weights: 'uniform', algorithm },
81+
0.0
82+
)
9183

92-
testWithDataset(
93-
loadDigits,
94-
{ nNeighbors: 3, weights: 'uniform', algorithm },
95-
0.967
96-
)
97-
testWithDataset(
98-
loadIris,
99-
{ nNeighbors: 3, weights: 'uniform', algorithm },
100-
0.0
101-
)
102-
testWithDataset(
103-
loadWine,
104-
{ nNeighbors: 3, weights: 'uniform', algorithm },
105-
0.158
106-
)
107-
testWithDataset(
108-
loadBreastCancer,
109-
{ nNeighbors: 3, weights: 'uniform', algorithm },
110-
0.916
111-
)
84+
testWithDataset(
85+
loadWine,
86+
{ nNeighbors: 5, weights: 'distance', algorithm },
87+
0.135
88+
)
89+
testWithDataset(
90+
loadWine,
91+
{ nNeighbors: 3, weights: 'uniform', algorithm },
92+
0.158
93+
)
94+
95+
testWithDataset(
96+
loadBreastCancer,
97+
{ nNeighbors: 5, weights: 'distance', algorithm },
98+
0.92
99+
)
100+
testWithDataset(
101+
loadBreastCancer,
102+
{ nNeighbors: 3, weights: 'uniform', algorithm },
103+
0.916
104+
)
105+
106+
if ('brute' !== algorithm) {
107+
testWithDataset(
108+
loadDigits,
109+
{ nNeighbors: 5, weights: 'distance', algorithm, leafSize: 256 },
110+
0.963
111+
)
112+
testWithDataset(
113+
loadDigits,
114+
{ nNeighbors: 3, weights: 'uniform', algorithm, leafSize: 256 },
115+
0.967
116+
)
117+
}
112118

113-
describe(`KNeighborsClassifier({ algorithm: ${algorithm} })`, () => {
114119
it('correctly predicts sklearn example', async () => {
115120
const X_train = [[0], [1], [2], [3]]
116121
const y_train = [0, 0, 1, 1]

src/neighbors/kNeighborsRegressor.test.ts

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,36 +30,42 @@ function testWithDataset(
3030
params: KNeighborsParams,
3131
referenceError: number
3232
) {
33-
it(`KNeighborsRegressor(${JSON.stringify(params)}) fits ${
34-
loadData.name
35-
} as well as sklearn`, async () => {
36-
const df = await loadData()
37-
38-
const Xy = df.tensor as unknown as Tensor2D
39-
let [nSamples, nFeatures] = Xy.shape
40-
--nFeatures
41-
42-
const X = Xy.slice([0, 0], [nSamples, nFeatures])
43-
const y = Xy.slice([0, nFeatures]).reshape([nSamples]) as Tensor1D
44-
45-
const scores = await crossValScore(new KNeighborsRegressor(params), X, y, {
46-
cv: new KFold({ nSplits: 3 }),
47-
scoring: negMeanSquaredError
48-
})
49-
50-
expect(scores.mean()).toBeAllCloseTo(-referenceError, {
51-
atol: 0,
52-
rtol: 0.01
53-
})
54-
}, 600_000)
33+
it(
34+
`matches sklearn fitting ${loadData.name}`.padEnd(48) +
35+
JSON.stringify(params),
36+
async () => {
37+
const df = await loadData()
38+
39+
const Xy = df.tensor as unknown as Tensor2D
40+
let [nSamples, nFeatures] = Xy.shape
41+
--nFeatures
42+
43+
const X = Xy.slice([0, 0], [nSamples, nFeatures])
44+
const y = Xy.slice([0, nFeatures]).reshape([nSamples]) as Tensor1D
45+
46+
const scores = await crossValScore(
47+
new KNeighborsRegressor(params),
48+
X,
49+
y,
50+
{
51+
cv: new KFold({ nSplits: 3 }),
52+
scoring: negMeanSquaredError
53+
}
54+
)
55+
56+
expect(scores.mean()).toBeAllCloseTo(-referenceError, {
57+
atol: 0,
58+
rtol: 0.005
59+
})
60+
},
61+
60_000
62+
)
5563
}
5664

5765
for (const algorithm of [
58-
'kdTree',
59-
'brute',
60-
undefined,
61-
'auto'
62-
] as KNeighborsParams['algorithm'][]) {
66+
...KNeighborsRegressor.SUPPORTED_ALGORITHMS,
67+
undefined
68+
]) {
6369
describe(`KNeighborsRegressor({ algorithm: ${algorithm} })`, function () {
6470
testWithDataset(
6571
loadDiabetes,
@@ -71,22 +77,21 @@ for (const algorithm of [
7177
{ nNeighbors: 3, weights: 'uniform', algorithm },
7278
3833
7379
)
74-
if ('brute' !== algorithm) {
80+
if ('kdTree' === algorithm) {
7581
testWithDataset(
7682
fetchCaliforniaHousing,
7783
{ nNeighbors: 3, weights: 'distance', algorithm },
7884
1.31
7985
)
80-
testWithDataset(
81-
fetchCaliforniaHousing,
82-
{ nNeighbors: 4, weights: 'uniform', algorithm },
83-
1.28
84-
)
86+
}
87+
if ('auto' === algorithm) {
8588
testWithDataset(
8689
fetchCaliforniaHousing,
8790
{ nNeighbors: 4, weights: 'uniform', algorithm, p: 1 },
8891
1.19
8992
)
93+
}
94+
if (undefined === algorithm) {
9095
testWithDataset(
9196
fetchCaliforniaHousing,
9297
{ nNeighbors: 4, weights: 'uniform', algorithm, p: Infinity },

src/neighbors/kdTree.test.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/**
2+
* @license
3+
* Copyright 2022, JsData. All rights reserved.
4+
*
5+
* This source code is licensed under the MIT license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
* ==========================================================================
14+
*/
15+
16+
import { neighborhoodGenericTests } from './neighborhoodGenericTests'
17+
import { KdTree } from './kdTree'
18+
19+
neighborhoodGenericTests('KdTree', KdTree.build)

src/neighbors/kdTree.ts

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,12 @@ export class KdTree implements Neighborhood {
130130
indices[i] = i
131131
}
132132

133-
const data = await entries.data()
133+
// TFJS may or may not return the underlying data array here.
134+
// Changes to the array may or may not cause the content of
135+
// `entries` to change. `entries.data()` may also be a small
136+
// subarray of a much much larger array. To avoid any issue
137+
// a protection copy needs to be made.
138+
const data = (await entries.data()).slice()
134139

135140
const points: Vec[] = Array.from(indices, (_, i) =>
136141
data.subarray(nFeatures * i, nFeatures * ++i)
@@ -170,9 +175,9 @@ export class KdTree implements Neighborhood {
170175
}
171176

172177
for (let i = from; i < until; i++) {
173-
const j = indices[i]
178+
const j = nFeatures * indices[i]
174179
for (let k = 0; k < bBox.length; ) {
175-
const djk = data[nFeatures * j + (k >>> 1)]
180+
const djk = data[j + (k >>> 1)]
176181
bBox[k] = Math.min(bBox[k++], djk)
177182
bBox[k] = Math.max(bBox[k++], djk)
178183
}
@@ -192,10 +197,11 @@ export class KdTree implements Neighborhood {
192197

193198
// 2.1: Determine Split Axis
194199
// -------------------------
200+
// Choose largest side of bounding box as axis to split.
195201
const axis = (function () {
196202
let axis = 0
197203
let dMax = -Infinity
198-
for (let i = bBox.length; i >= 0; ) {
204+
for (let i = bBox.length; i > 0; ) {
199205
const di = bBox[--i] - bBox[--i]
200206
if (di > dMax) {
201207
dMax = di
@@ -358,8 +364,8 @@ export class KdTree implements Neighborhood {
358364
// KNeighborsBaseParams and add backpropagation support
359365
// to KdTree.
360366
return {
361-
distances: tf.tensor(dists, [nQueries, k], 'float32'),
362-
indices: tf.tensor(indxs, [nQueries, k], 'int32')
367+
distances: tf.tensor2d(dists, [nQueries, k], 'float32'),
368+
indices: tf.tensor2d(indxs, [nQueries, k], 'int32')
363369
}
364370
}
365371
}

0 commit comments

Comments
 (0)