🌐 AI搜索 & 代理 主页
Skip to content

Commit 42de904

Browse files
committed
Updated lib
1 parent 297a3cd commit 42de904

File tree

7 files changed

+84
-37
lines changed

7 files changed

+84
-37
lines changed

shared/lib/estimators/sgd.linear.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ export class SGD extends PredictorMixin {
135135
} else {
136136
const yTwoD = y.reshape([-1, 1])
137137
const yTwoDOneHotEncoded = this.oneHot.fitTransform(yTwoD)
138-
if (this.oneHot.$labels[0].size > 2) {
138+
if (this.oneHot.categories[0].length > 2) {
139139
this.modelCompileArgs.loss = losses.softmaxCrossEntropy
140140
} else {
141141
this.modelCompileArgs.loss = losses.sigmoidCrossEntropy

shared/lib/preprocessing/encoders/one.hot.encoder.ts

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import { tf } from '../../../globals'
2121

2222
/*
2323
Todo:
24-
1. Change $labels to categories
24+
1. Implement inverseTransform for 2D array
2525
2. Pass the next 5 scikit-learn tests
2626
*/
2727

@@ -34,11 +34,20 @@ Todo:
3434
* ```
3535
*/
3636
export default class OneHotEncoder extends TransformerMixin {
37-
$labels: Map<string | number | boolean, number>[]
38-
37+
categories: (number | string | boolean)[][]
3938
constructor() {
4039
super()
41-
this.$labels = []
40+
this.categories = []
41+
}
42+
43+
classesToMapping(
44+
classes: Array<string | number | boolean>
45+
): Map<string | number | boolean, number> {
46+
const labels = new Map<string | number | boolean, number>()
47+
classes.forEach((value, index) => {
48+
labels.set(value, index)
49+
})
50+
return labels
4251
}
4352

4453
loopOver2DArrayToSetLabels(array2D: any) {
@@ -48,11 +57,11 @@ export default class OneHotEncoder extends TransformerMixin {
4857
curSet.add(array2D[i][j])
4958
}
5059
let results = Array.from(curSet)
51-
let newMap = new Map<string | number | boolean, number>()
52-
results.forEach((el, i) => {
53-
newMap.set(el as number, i)
54-
})
55-
this.$labels.push(newMap)
60+
// let newMap = new Map<string | number | boolean, number>()
61+
// results.forEach((el, i) => {
62+
// newMap.set(el as number, i)
63+
// })
64+
this.categories.push(results as number[])
5665
}
5766
}
5867

@@ -74,12 +83,13 @@ export default class OneHotEncoder extends TransformerMixin {
7483
}
7584

7685
loopOver2DArrayToUseLabels(array2D: any) {
86+
let labels = this.categories.map((el) => this.classesToMapping(el))
7787
let finalArray = []
7888
for (let i = 0; i < array2D.length; i++) {
7989
let curArray = []
8090
for (let j = 0; j < array2D[0].length; j++) {
8191
let curElem = array2D[i][j]
82-
let val = this.$labels[j].get(curElem)
92+
let val = labels[j].get(curElem)
8393
let actualIndex = val === undefined ? -1 : val
8494
curArray.push(actualIndex)
8595
}
@@ -103,16 +113,17 @@ export default class OneHotEncoder extends TransformerMixin {
103113
const result2D = this.loopOver2DArrayToUseLabels(array2D)
104114
const newTensor = tf.tensor2d(result2D, undefined, 'int32')
105115
return tf.concat(
106-
newTensor.unstack(1).map((el, i) => tf.oneHot(el, this.$labels[i].size)),
116+
newTensor
117+
.unstack(1)
118+
.map((el, i) => tf.oneHot(el, this.categories[i].length)),
107119
1
108120
) as tf.Tensor2D
109121
}
110122
// Only works for single column OneHotEncoding
111123
inverseTransform(X: tf.Tensor2D): any[] {
124+
let labels = this.classesToMapping(this.categories[0])
112125
const tensorLabels = X.argMax(1) as tf.Tensor1D
113-
const invMap = new Map(
114-
Array.from(this.$labels[0], (a) => a.reverse()) as any
115-
)
126+
const invMap = new Map(Array.from(labels, (a) => a.reverse()) as any)
116127

117128
const tempData = tensorLabels.arraySync().map((value) => {
118129
return invMap.get(value) === undefined ? null : invMap.get(value)

shared/lib/preprocessing/encoders/ordinal.encoder.ts

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ import { tf } from '../../../globals'
2121

2222
/*
2323
Todo:
24-
1. Change $labels to categories
25-
2. Pass the next 5 tests
24+
1. Pass the next 5 tests
2625
*/
2726

2827
/**
@@ -34,11 +33,20 @@ Todo:
3433
* ```
3534
*/
3635
export default class OrdinalEncoder extends TransformerMixin {
37-
$labels: Map<string | number | boolean, number>[]
38-
36+
categories: (number | string | boolean)[][]
3937
constructor() {
4038
super()
41-
this.$labels = []
39+
this.categories = []
40+
}
41+
42+
classesToMapping(
43+
classes: Array<string | number | boolean>
44+
): Map<string | number | boolean, number> {
45+
const labels = new Map<string | number | boolean, number>()
46+
classes.forEach((value, index) => {
47+
labels.set(value, index)
48+
})
49+
return labels
4250
}
4351

4452
loopOver2DArrayToSetLabels(array2D: any) {
@@ -48,11 +56,7 @@ export default class OrdinalEncoder extends TransformerMixin {
4856
curSet.add(array2D[i][j])
4957
}
5058
let results = Array.from(curSet)
51-
let newMap = new Map<string | number | boolean, number>()
52-
results.forEach((el, i) => {
53-
newMap.set(el as number, i)
54-
})
55-
this.$labels.push(newMap)
59+
this.categories.push(results as number[])
5660
}
5761
}
5862

@@ -74,12 +78,13 @@ export default class OrdinalEncoder extends TransformerMixin {
7478
}
7579

7680
loopOver2DArrayToUseLabels(array2D: any) {
81+
let labels = this.categories.map((el) => this.classesToMapping(el))
7782
let finalArray = []
7883
for (let i = 0; i < array2D.length; i++) {
7984
let curArray = []
8085
for (let j = 0; j < array2D[0].length; j++) {
8186
let curElem = array2D[i][j]
82-
let val = this.$labels[j].get(curElem)
87+
let val = labels[j].get(curElem)
8388
let actualIndex = val === undefined ? -1 : val
8489
curArray.push(actualIndex)
8590
}

shared/lib/utils.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ export function convertToTensor1D(
7373
if (data instanceof dfd.Series) {
7474
// Do type inference if no dtype is passed, otherwise try to parse as that dtype
7575
return dtype
76-
? (data.tensor.asType(dtype) as tf.Tensor1D)
77-
: (data.tensor as tf.Tensor1D)
76+
? (data.tensor.asType(dtype) as unknown as tf.Tensor1D)
77+
: (data.tensor as unknown as tf.Tensor1D)
7878
}
7979
if (data instanceof tf.Tensor) {
8080
if (data.shape.length === 1) {
@@ -107,8 +107,8 @@ export function convertToTensor2D(
107107
): tf.Tensor2D {
108108
if (data instanceof dfd.DataFrame) {
109109
return dtype
110-
? (data.tensor.asType(dtype) as tf.Tensor2D)
111-
: (data.tensor as tf.Tensor2D)
110+
? (data.tensor.asType(dtype) as unknown as tf.Tensor2D)
111+
: (data.tensor as unknown as tf.Tensor2D)
112112
}
113113
if (data instanceof tf.Tensor) {
114114
if (data.shape.length === 2) {
@@ -181,10 +181,10 @@ export function convertToTensor(
181181
dtype?: keyof tf.DataTypeMap
182182
): tf.Tensor {
183183
if (data instanceof dfd.DataFrame) {
184-
return data.tensor
184+
return data.tensor as unknown as tf.Tensor2D
185185
}
186186
if (data instanceof dfd.Series) {
187-
return data.tensor
187+
return data.tensor as unknown as tf.Tensor2D
188188
}
189189
if (data instanceof tf.Tensor) {
190190
let newData = data

shared/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727
"mathjs": "^10.0.0",
2828
"seedrandom": "^3.0.5",
2929
"simple-statistics": "^7.7.0",
30-
"typescript": "^4.4.4"
30+
"typescript": "^4.5.2"
3131
},
3232
"devDependencies": {
33+
"@types/lodash": "^4.14.177",
3334
"@types/mocha": "^9.0.0",
3435
"chai": "^4.3.4",
3536
"mocha": "^9.1.3"

shared/tsconfig.json

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
{
2+
"compilerOptions": {
3+
"target": "es2015" /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019', 'ES2020', or 'ESNEXT'. */,
4+
"module": "commonjs" /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', 'es2020', or 'ESNext'. */,
5+
"lib": [
6+
"es6"
7+
] /* Specify library files to be included in the compilation. */,
8+
"moduleResolution": "node",
9+
"allowJs": true /* Allow javascript files to be compiled. */,
10+
"outDir": "./dist" /* Redirect output structure to the directory. */,
11+
"strict": true /* Enable all strict type-checking options. */,
12+
"noImplicitAny": true /* Raise error on expressions and declarations with an implied 'any' type. */,
13+
// "noUnusedLocals": true /* Report errors on unused locals. */,
14+
// "noUnusedParameters": true /* Report errors on unused parameters. */,
15+
"noFallthroughCasesInSwitch": true /* Report errors for fallthrough cases in switch statement. */,
16+
"esModuleInterop": true /* Enables emit interoperability between CommonJS and ES Modules via creation of namespace objects for all imports. Implies 'allowSyntheticDefaultImports'. */,
17+
"resolveJsonModule": true /* Include modules imported with '.json' extension */,
18+
"skipLibCheck": true /* Skip type checking of declaration files. */,
19+
"forceConsistentCasingInFileNames": true /* Disallow inconsistently-cased references to the same file. */,
20+
"declaration": true,
21+
"baseUrl": "./"
22+
},
23+
"include": ["./**/*"],
24+
"exclude": ["./dist", "./build", "./**/*.test.ts", "node_modules"]
25+
}

shared/yarn.lock

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@
246246
resolved "https://registry.yarnpkg.com/@types/json-schema/-/json-schema-7.0.9.tgz#97edc9037ea0c38585320b28964dde3b39e4660d"
247247
integrity sha512-qcUXuemtEu+E5wZSJHNxUXeCZhAfXKQ41D+duX+VYPde7xyEVZci+/oXKJL13tnRs9lR2pr4fod59GT6/X1/yQ==
248248

249+
"@types/lodash@^4.14.177":
250+
version "4.14.177"
251+
resolved "https://registry.yarnpkg.com/@types/lodash/-/lodash-4.14.177.tgz#f70c0d19c30fab101cad46b52be60363c43c4578"
252+
integrity sha512-0fDwydE2clKe9MNfvXHBHF9WEahRuj+msTuQqOmAApNORFvhMYZKNGGJdCzuhheVjMps/ti0Ak/iJPACMaevvw==
253+
249254
"@types/long@^4.0.1":
250255
version "4.0.1"
251256
resolved "https://registry.yarnpkg.com/@types/long/-/long-4.0.1.tgz#459c65fa1867dafe6a8f322c4c51695663cc55e9"
@@ -2402,10 +2407,10 @@ typed-function@^2.0.0:
24022407
resolved "https://registry.yarnpkg.com/typed-function/-/typed-function-2.0.0.tgz#15ab3825845138a8b1113bd89e60cd6a435739e8"
24032408
integrity sha512-Hhy1Iwo/e4AtLZNK10ewVVcP2UEs408DS35ubP825w/YgSBK1KVLwALvvIG4yX75QJrxjCpcWkzkVRB0BwwYlA==
24042409

2405-
typescript@^4.4.4:
2406-
version "4.4.4"
2407-
resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.4.4.tgz#2cd01a1a1f160704d3101fd5a58ff0f9fcb8030c"
2408-
integrity sha512-DqGhF5IKoBl8WNf8C1gu8q0xZSInh9j1kJJMqT3a94w1JzVaBU4EXOSMrz9yDqMT0xt3selp83fuFMQ0uzv6qA==
2410+
typescript@^4.5.2:
2411+
version "4.5.2"
2412+
resolved "https://registry.yarnpkg.com/typescript/-/typescript-4.5.2.tgz#8ac1fba9f52256fdb06fb89e4122fa6a346c2998"
2413+
integrity sha512-5BlMof9H1yGt0P8/WF+wPNw6GfctgGjXp5hkblpyT+8rkASSmkUKMXrxR0Xg8ThVCi/JnHQiKXeBaEwCeQwMFw==
24092414

24102415
uri-js@^4.2.2:
24112416
version "4.4.1"

0 commit comments

Comments
 (0)