🌐 AI搜索 & 代理 主页
Skip to content

Commit b1f646f

Browse files
authored
Merge pull request #224 from javascriptdata/fix-serialization
Different API for serialization
2 parents fc7c97b + 260c134 commit b1f646f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+576
-564
lines changed

docs/convert.js

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,6 @@ function getTypeName(val, bigObj) {
189189
}
190190

191191
function generateProperties(jsonClass, bigObj) {
192-
// console.log(jsonClass.children)
193192
let interface = getInterfaceForClass(jsonClass, bigObj)
194193
let allConstructorArgs = []
195194
if (interface && interface.children) {

package-lock.json

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
"dependencies": {
5252
"@tensorflow/tfjs": "^3.16.0",
5353
"@tensorflow/tfjs-node": "^3.16.0",
54+
"base64-arraybuffer": "^1.0.2",
5455
"lodash": "^4.17.21",
5556
"mathjs": "^10.0.0",
5657
"simple-statistics": "^7.7.0"

src/cluster/KMeans.test.ts

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import { KMeans } from './KMeans'
2-
1+
import { fromJSON, KMeans } from '../index'
32
// Next steps: Improve on kmeans cluster testing
43
describe('KMeans', () => {
54
const X = [
@@ -38,7 +37,7 @@ describe('KMeans', () => {
3837
)
3938
})
4039

41-
it('should save kmeans model', () => {
40+
it('should save kmeans model', async () => {
4241
const expectedResult = {
4342
name: 'KMeans',
4443
nClusters: 2,
@@ -48,7 +47,7 @@ describe('KMeans', () => {
4847
randomState: 0,
4948
nInit: 10,
5049
clusterCenters: {
51-
type: 'Tensor',
50+
name: 'Tensor',
5251
value: [
5352
[2.5, 1],
5453
[2.5, 4]
@@ -57,20 +56,20 @@ describe('KMeans', () => {
5756
}
5857
const kmean = new KMeans({ nClusters: 2, randomState: 0 })
5958
kmean.fit(X)
60-
const ksave = kmean.toJson() as string
59+
const ksave = await kmean.toObject()
6160

62-
expect(expectedResult).toEqual(JSON.parse(ksave))
61+
expect(expectedResult).toEqual(ksave)
6362
})
6463

65-
it('should load serialized kmeans model', () => {
64+
it('should load serialized kmeans model', async () => {
6665
const centroids = [
6766
[2.5, 1],
6867
[2.5, 4]
6968
]
7069
const kmean = new KMeans({ nClusters: 2, randomState: 0 })
7170
kmean.fit(X)
72-
const ksave = kmean.toJson() as string
73-
const ksaveModel = new KMeans().fromJson(ksave)
71+
const ksave = await kmean.toJSON()
72+
const ksaveModel = await fromJSON(ksave)
7473
expect(centroids).toEqual(ksaveModel.clusterCenters.arraySync())
7574
})
7675

src/cluster/KMeans.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Scikit2D } from '../types'
22
import { convertToNumericTensor2D, sampleWithoutReplacement } from '../utils'
3-
import Serialize from '../serialize'
3+
import { Serialize } from '../simpleSerializer'
44
import { tf } from '../shared/globals'
55

66
/*

src/compose/ColumnTransformer.test.ts

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
import { ColumnTransformer } from './ColumnTransformer'
2-
import { MinMaxScaler } from '../preprocessing/MinMaxScaler'
3-
import { SimpleImputer } from '../impute/SimpleImputer'
1+
import {
2+
fromJSON,
3+
SimpleImputer,
4+
MinMaxScaler,
5+
ColumnTransformer
6+
} from '../index'
47
import * as dfd from 'danfojs-node'
58

69
describe('ColumnTransformer', function () {
@@ -30,4 +33,26 @@ describe('ColumnTransformer', function () {
3033

3134
expect(result.arraySync()).toEqual(expected)
3235
})
36+
it('ColumnTransformer serialize/deserialize test', async function () {
37+
const X = [
38+
[2, 2], // [1, .5]
39+
[2, 3], // [1, .75]
40+
[0, NaN], // [0, 1]
41+
[2, 0] // [.5, 0]
42+
]
43+
let newDf = new dfd.DataFrame(X)
44+
45+
const transformer = new ColumnTransformer({
46+
transformers: [
47+
['minmax', new MinMaxScaler(), [0]],
48+
['simpleImpute', new SimpleImputer({ strategy: 'median' }), [1]]
49+
]
50+
})
51+
52+
transformer.fitTransform(newDf)
53+
let obj = await transformer.toJSON()
54+
let myResult = await fromJSON(obj)
55+
56+
expect(myResult.transformers.length).toEqual(2)
57+
})
3358
})

src/compose/ColumnTransformer.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import { DataFrameInterface, Scikit1D, Scikit2D, Transformer } from '../types'
2-
import { isDataFrameInterface, isScikitLike2D } from '../typesUtils'
1+
import { DataFrameInterface, Scikit1D, Transformer } from '../types'
2+
import { isDataFrameInterface } from '../typesUtils'
3+
import { Serialize } from '../simpleSerializer'
34
import { tf } from '../shared/globals'
45
/*
56
Next steps:
@@ -64,7 +65,7 @@ export interface ColumnTransformerParams {
6465
]
6566
* ```
6667
*/
67-
export class ColumnTransformer {
68+
export class ColumnTransformer extends Serialize {
6869
transformers: TransformerTriple
6970
remainder: Transformer | 'drop' | 'passthrough'
7071

@@ -75,6 +76,7 @@ export class ColumnTransformer {
7576
transformers = [],
7677
remainder = 'drop'
7778
}: ColumnTransformerParams = {}) {
79+
super()
7880
this.transformers = transformers
7981
this.remainder = remainder
8082
}

src/dummy/DummyClassifier.test.ts

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import { DummyClassifier } from './DummyClassifier'
2-
1+
import { DummyClassifier, fromJSON } from '../index'
32
describe('DummyClassifier', function () {
43
it('Use DummyClassifier on simple example (mostFrequent)', function () {
54
const clf = new DummyClassifier()
@@ -51,7 +50,7 @@ describe('DummyClassifier', function () {
5150

5251
expect(scaler.classes).toEqual([1, 2, 3])
5352
})
54-
it('should serialize DummyClassifier', function () {
53+
it('should serialize DummyClassifier', async function () {
5554
const clf = new DummyClassifier()
5655

5756
const X = [
@@ -70,10 +69,10 @@ describe('DummyClassifier', function () {
7069
}
7170

7271
clf.fit(X, y)
73-
const clfSave = clf.toJson() as string
74-
expect(expectedResult).toEqual(JSON.parse(clfSave))
72+
const clfSave = await clf.toObject()
73+
expect(expectedResult).toEqual(clfSave)
7574
})
76-
it('should load DummyClassifier', function () {
75+
it('should load DummyClassifier', async function () {
7776
const clf = new DummyClassifier()
7877

7978
const X = [
@@ -85,8 +84,8 @@ describe('DummyClassifier', function () {
8584
const y = [10, 20, 20, 30]
8685

8786
clf.fit(X, y)
88-
const clfSave = clf.toJson() as string
89-
const newClf = new DummyClassifier().fromJson(clfSave)
87+
const clfSave = await clf.toJSON()
88+
const newClf = await fromJSON(clfSave)
9089
expect(clf).toEqual(newClf)
9190
})
9291
})

src/dummy/DummyRegressor.test.ts

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { DummyRegressor } from './DummyRegressor'
1+
import { DummyRegressor, fromJSON } from '../index'
22

33
describe('DummyRegressor', function () {
44
it('Use DummyRegressor on simple example (mean)', function () {
@@ -55,7 +55,7 @@ describe('DummyRegressor', function () {
5555
reg.fit(X, y)
5656
expect(reg.predict(predictX).arraySync()).toEqual([10, 10, 10])
5757
})
58-
it('Should save DummyRegressor', function () {
58+
it('Should save DummyRegressor', async function () {
5959
const reg = new DummyRegressor({ strategy: 'constant', constant: 10 })
6060

6161
const X = [
@@ -68,15 +68,16 @@ describe('DummyRegressor', function () {
6868
name: 'DummyRegressor',
6969
EstimatorType: 'regressor',
7070
strategy: 'constant',
71-
constant: 10
71+
constant: 10,
72+
quantile: undefined
7273
}
7374

7475
reg.fit(X, y)
7576

76-
expect(saveResult).toEqual(JSON.parse(reg.toJson() as string))
77+
expect(saveResult).toEqual(await reg.toObject())
7778
})
7879

79-
it('Should load serialized DummyRegressor', function () {
80+
it('Should load serialized DummyRegressor', async function () {
8081
const reg = new DummyRegressor({ strategy: 'constant', constant: 10 })
8182

8283
const X = [
@@ -92,8 +93,8 @@ describe('DummyRegressor', function () {
9293
]
9394

9495
reg.fit(X, y)
95-
const saveReg = reg.toJson() as string
96-
const newReg = new DummyRegressor().fromJson(saveReg)
96+
const saveReg = await reg.toJSON()
97+
const newReg = await fromJSON(saveReg)
9798

9899
expect(newReg.predict(predictX).arraySync()).toEqual([10, 10, 10])
99100
})

src/ensemble/VotingClassifier.test.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
import { makeVotingClassifier, VotingClassifier } from './VotingClassifier'
2-
import { DummyClassifier } from '../dummy/DummyClassifier'
3-
4-
import { LogisticRegression } from '../linear_model/LogisticRegression'
1+
import {
2+
makeVotingClassifier,
3+
VotingClassifier,
4+
DummyClassifier,
5+
LogisticRegression,
6+
fromJSON
7+
} from '../index'
58

69
describe('VotingClassifier', function () {
710
it('Use VotingClassifier on simple example (voting = hard)', async function () {
@@ -118,8 +121,8 @@ describe('VotingClassifier', function () {
118121

119122
await voter.fit(X, y)
120123

121-
const savedModel = (await voter.toJson()) as string
122-
const newModel = new VotingClassifier({}).fromJson(savedModel)
124+
const savedModel = await voter.toJSON()
125+
const newModel = await fromJSON(savedModel)
123126

124127
expect(newModel.predict(X).arraySync()).toEqual([1, 1, 1, 1, 1])
125128
}, 30000)

0 commit comments

Comments
 (0)