@@ -11,7 +11,7 @@ function roughlyEqual(a: number, b: number, tol = 0.1) {
1111
1212describe ( 'LinearRegression' , function ( ) {
1313 it ( 'Works on arrays (small example)' , async function ( ) {
14- const lr = new LinearRegression ( )
14+ const lr = new LinearRegression ( { randomState : 42 } )
1515 await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
1616 expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
1717 expect ( roughlyEqual ( lr . intercept as number , 0 ) ) . toBe ( true )
@@ -24,6 +24,7 @@ describe('LinearRegression', function () {
2424 console . log ( 'training begins' )
2525 }
2626 const lr = new LinearRegression ( {
27+ randomState : 42 ,
2728 modelFitOptions : { callbacks : [ new tf . CustomCallback ( { onTrainBegin } ) ] }
2829 } )
2930 await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
@@ -39,6 +40,7 @@ describe('LinearRegression', function () {
3940 console . log ( 'training begins' )
4041 }
4142 const lr = new LinearRegression ( {
43+ randomState : 42 ,
4244 modelFitOptions : { callbacks : [ new tf . CustomCallback ( { onTrainBegin } ) ] }
4345 } )
4446 await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
@@ -50,7 +52,7 @@ describe('LinearRegression', function () {
5052 } , 30000 )
5153
5254 it ( 'Works on small multi-output example (small example)' , async function ( ) {
53- const lr = new LinearRegression ( )
55+ const lr = new LinearRegression ( { randomState : 42 } )
5456 await lr . fit (
5557 [ [ 1 ] , [ 2 ] ] ,
5658 [
@@ -63,14 +65,14 @@ describe('LinearRegression', function () {
6365 } , 30000 )
6466
6567 it ( 'Works on arrays with no intercept (small example)' , async function ( ) {
66- const lr = new LinearRegression ( { fitIntercept : false } )
68+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
6769 await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
6870 expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
6971 expect ( roughlyEqual ( lr . intercept as number , 0 ) ) . toBe ( true )
7072 } , 30000 )
7173
7274 it ( 'Works on arrays with none zero intercept (small example)' , async function ( ) {
73- const lr = new LinearRegression ( { fitIntercept : true } )
75+ const lr = new LinearRegression ( { fitIntercept : true , randomState : 42 } )
7476 await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 3 , 5 ] )
7577 expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
7678 expect ( roughlyEqual ( lr . intercept as number , 1 ) ) . toBe ( true )
@@ -95,7 +97,7 @@ describe('LinearRegression', function () {
9597 const yPlusJitter = y . add (
9698 tf . randomNormal ( [ sizeOfMatrix ] , 0 , 1 , 'float32' , seed )
9799 ) as tf . Tensor1D
98- const lr = new LinearRegression ( { fitIntercept : false } )
100+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
99101 await lr . fit ( mediumX , yPlusJitter )
100102
101103 expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2.5 , 1 ] ) , 0.1 ) ) . toBe ( true )
@@ -121,7 +123,7 @@ describe('LinearRegression', function () {
121123 const yPlusJitter = y . add (
122124 tf . randomNormal ( [ sizeOfMatrix ] , 0 , 1 , 'float32' , seed )
123125 ) as tf . Tensor1D
124- const lr = new LinearRegression ( { fitIntercept : false } )
126+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
125127 await lr . fit ( mediumX , yPlusJitter )
126128
127129 expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2.5 , 1 ] ) , 0.1 ) ) . toBe ( true )
@@ -158,7 +160,7 @@ describe('LinearRegression', function () {
158160 let score = 1.0
159161 /*[[[end]]]*/
160162
161- const lr = new LinearRegression ( )
163+ const lr = new LinearRegression ( { randomState : 42 } )
162164 await lr . fit ( X , y )
163165 expect ( lr . score ( X , y ) ) . toBeCloseTo ( score )
164166 } , 30000 )
@@ -180,7 +182,7 @@ describe('LinearRegression', function () {
180182 const yPlusJitter = y . add (
181183 tf . randomNormal ( [ sizeOfMatrix ] , 0 , 1 , 'float32' , seed )
182184 ) as tf . Tensor1D
183- const lr = new LinearRegression ( { fitIntercept : false } )
185+ const lr = new LinearRegression ( { fitIntercept : false , randomState : 42 } )
184186 await lr . fit ( mediumX , yPlusJitter )
185187
186188 const serialized = await lr . toObject ( )
0 commit comments