@@ -34,83 +34,88 @@ function testWithDataset(
3434 params : KNeighborsParams ,
3535 referenceAccuracy : number
3636) {
37- it ( `KNeighborsClassifier(${ JSON . stringify ( params ) } ) fits ${
38- loadData . name
39- } as well as sklearn`, async ( ) => {
40- const df = await loadData ( )
41-
42- const Xy = df . tensor as unknown as Tensor2D
43- let [ nSamples , nFeatures ] = Xy . shape
44- -- nFeatures
45-
46- const X = Xy . slice ( [ 0 , 0 ] , [ nSamples , nFeatures ] )
47- const y = Xy . slice ( [ 0 , nFeatures ] ) . reshape ( [ nSamples ] ) as Tensor1D
48-
49- const accuracies = await crossValScore (
50- new KNeighborsClassifier ( params ) ,
51- X ,
52- y ,
53- {
54- cv : new KFold ( { nSplits : 3 } )
55- }
56- )
37+ it (
38+ `matches sklearn fitting ${ loadData . name } ` . padEnd ( 48 ) +
39+ JSON . stringify ( params ) ,
40+ async ( ) => {
41+ const df = await loadData ( )
42+
43+ const Xy = df . tensor as unknown as Tensor2D
44+ let [ nSamples , nFeatures ] = Xy . shape
45+ -- nFeatures
46+
47+ const X = Xy . slice ( [ 0 , 0 ] , [ nSamples , nFeatures ] )
48+ const y = Xy . slice ( [ 0 , nFeatures ] ) . reshape ( [ nSamples ] ) as Tensor1D
49+
50+ const accuracies = await crossValScore (
51+ new KNeighborsClassifier ( params ) ,
52+ X ,
53+ y ,
54+ {
55+ cv : new KFold ( { nSplits : 3 } )
56+ }
57+ )
5758
58- expect ( accuracies . mean ( ) ) . toBeAllCloseTo ( referenceAccuracy , {
59- atol : 0 ,
60- rtol : 0.005
61- } )
62- } , 600_000 )
59+ expect ( accuracies . mean ( ) ) . toBeAllCloseTo ( referenceAccuracy , {
60+ atol : 0 ,
61+ rtol : 0.005
62+ } )
63+ } ,
64+ 60_000
65+ )
6366}
6467
6568for ( const algorithm of [
66- 'kdTree' ,
67- 'brute' ,
68- undefined ,
69- 'auto'
70- ] as KNeighborsParams [ 'algorithm' ] [ ] ) {
71- testWithDataset (
72- loadDigits ,
73- { nNeighbors : 5 , weights : 'distance' , algorithm } ,
74- 0.963
75- )
76- testWithDataset (
77- loadIris ,
78- { nNeighbors : 5 , weights : 'distance' , algorithm } ,
79- 0.0
80- )
81- testWithDataset (
82- loadWine ,
83- { nNeighbors : 5 , weights : 'distance' , algorithm } ,
84- 0.135
85- )
86- testWithDataset (
87- loadBreastCancer ,
88- { nNeighbors : 5 , weights : 'distance' , algorithm } ,
89- 0.92
90- )
69+ ...KNeighborsClassifier . SUPPORTED_ALGORITHMS ,
70+ undefined
71+ ] ) {
72+ describe ( `KNeighborsClassifier({ algorithm: ${ algorithm } })` , ( ) => {
73+ testWithDataset (
74+ loadIris ,
75+ { nNeighbors : 5 , weights : 'distance' , algorithm } ,
76+ 0.0
77+ )
78+ testWithDataset (
79+ loadIris ,
80+ { nNeighbors : 3 , weights : 'uniform' , algorithm } ,
81+ 0.0
82+ )
9183
92- testWithDataset (
93- loadDigits ,
94- { nNeighbors : 3 , weights : 'uniform' , algorithm } ,
95- 0.967
96- )
97- testWithDataset (
98- loadIris ,
99- { nNeighbors : 3 , weights : 'uniform' , algorithm } ,
100- 0.0
101- )
102- testWithDataset (
103- loadWine ,
104- { nNeighbors : 3 , weights : 'uniform' , algorithm } ,
105- 0.158
106- )
107- testWithDataset (
108- loadBreastCancer ,
109- { nNeighbors : 3 , weights : 'uniform' , algorithm } ,
110- 0.916
111- )
84+ testWithDataset (
85+ loadWine ,
86+ { nNeighbors : 5 , weights : 'distance' , algorithm } ,
87+ 0.135
88+ )
89+ testWithDataset (
90+ loadWine ,
91+ { nNeighbors : 3 , weights : 'uniform' , algorithm } ,
92+ 0.158
93+ )
94+
95+ testWithDataset (
96+ loadBreastCancer ,
97+ { nNeighbors : 5 , weights : 'distance' , algorithm } ,
98+ 0.92
99+ )
100+ testWithDataset (
101+ loadBreastCancer ,
102+ { nNeighbors : 3 , weights : 'uniform' , algorithm } ,
103+ 0.916
104+ )
105+
106+ if ( 'brute' !== algorithm ) {
107+ testWithDataset (
108+ loadDigits ,
109+ { nNeighbors : 5 , weights : 'distance' , algorithm, leafSize : 256 } ,
110+ 0.963
111+ )
112+ testWithDataset (
113+ loadDigits ,
114+ { nNeighbors : 3 , weights : 'uniform' , algorithm, leafSize : 256 } ,
115+ 0.967
116+ )
117+ }
112118
113- describe ( `KNeighborsClassifier({ algorithm: ${ algorithm } })` , ( ) => {
114119 it ( 'correctly predicts sklearn example' , async ( ) => {
115120 const X_train = [ [ 0 ] , [ 1 ] , [ 2 ] , [ 3 ] ]
116121 const y_train = [ 0 , 0 , 1 , 1 ]
0 commit comments