/**
*  @license
* Copyright 2022, JsData. All rights reserved.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.

* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* ==========================================================================
*/

import { assert } from '../typesUtils'
import { CrossValidator } from './CrossValidator'
import { KFold } from './KFold'
import { Scikit1D, Scikit2D } from '../types'
import { isScikit1D } from '../typesUtils'
import { convertToTensor1D, convertToTensor2D } from '../utils'
import { tf } from '../shared/globals'
type Scalar = tf.Scalar
type Tensor1D = tf.Tensor1D
type Tensor2D = tf.Tensor2D

/**
 * Evaluates a score by cross-validation. This particular overload
 * of the function uses the given scorer function to cross validate
 * a supervised estimator.
 *
 * The estimator is fitted anew on the training part of every split and
 * then scored on the corresponding test part.
 *
 * @param estimator A supervised estimator that has an async `fit(X,y)` function.
 *
 * @param X        The features of the cross validation dataset.
 *
 * @param y        The targets of the cross validation dataset.
 *
 * @param params   Cross validation parameters. `cv` is the {@link CrossValidator}
 *                 responsible for splitting the dataset into training and test
 *                 data in one or more folds; when omitted, a default-constructed
 *                 {@link KFold} is used. `groups` is additional grouping
 *                 information which is used by some {@link CrossValidator}
 *                 instances to determine which groups of datapoints may not
 *                 appear both in test and training data at the same time.
 *                 `scoring` is the scorer function which is being called
 *                 with `estimator` as `this` and the test features `X_test`
 *                 and targets `y_test` as arguments.
 *
 * @returns        A 1D tensor holding the test score for each split/fold
 *                 that was generated by the {@link CrossValidator} `cv`
 */
export async function crossValScore<
  T extends {
    fit(X: Tensor2D, y: Tensor1D): Promise<unknown>
  }
>(
  estimator: T,
  X: Scikit2D,
  y: Scikit1D,
  params: {
    cv?: CrossValidator
    groups?: Scikit1D
    scoring: (this: T, X: Tensor2D, y: Tensor1D) => Scalar
  }
): Promise<Tensor1D>

/**
 * Evaluates a score by cross-validation. This particular overload
 * of the function uses the given scorer to cross validate an
 * unsupervised estimator.
 *
 * The estimator is fitted anew on the training part of every split and
 * then scored on the corresponding test part.
 *
 * @param estimator An unsupervised estimator that has an async `fit(X)` function.
 *
 * @param X        The features of the cross validation dataset.
 *
 * @param params   Cross validation parameters. `cv` is the {@link CrossValidator}
 *                 responsible for splitting the dataset into training and test
 *                 data in one or more folds; when omitted, a default-constructed
 *                 {@link KFold} is used. `groups` is additional grouping
 *                 information which is used by some {@link CrossValidator}
 *                 instances to determine which groups of datapoints may not
 *                 appear both in test and training data at the same time.
 *                 `scoring` is the scorer function which is being called
 *                 with `estimator` as `this` and the test features `X_test`
 *                 as argument.
 *
 * @returns        A 1D tensor holding the test score for each split/fold
 *                 that was generated by the {@link CrossValidator} `cv`
 */
export async function crossValScore<
  T extends {
    fit(X: Tensor2D): Promise<unknown>
  }
>(
  estimator: T,
  X: Scikit2D,
  params: {
    cv?: CrossValidator
    groups?: Scikit1D
    scoring: (this: T, X: Tensor2D) => Scalar
  }
): Promise<Tensor1D>

/**
 * Evaluates a score by cross-validation. This particular overload
 * of the function uses the default score of a supervised estimator
 * for scoring.
 *
 * The estimator is fitted anew on the training part of every split and
 * then scored on the corresponding test part via `estimator.score`.
 *
 * @param estimator A supervised estimator that has an async `fit(X,y)`
 *                 function and a `score(X,y)` function.
 *
 * @param X        The features of the cross validation dataset.
 *
 * @param y        The targets of the cross validation dataset.
 *
 * @param params   Cross validation parameters. `cv` is the {@link CrossValidator}
 *                 responsible for splitting the dataset into training and test
 *                 data in one or more folds; when omitted, a default-constructed
 *                 {@link KFold} is used. `groups` is additional grouping
 *                 information which is used by some {@link CrossValidator}
 *                 instances to determine which groups of datapoints may not
 *                 appear both in test and training data at the same time.
 *
 * @returns        A 1D tensor holding the test score for each split/fold
 *                 that was generated by the {@link CrossValidator} `cv`
 */
export async function crossValScore(
  estimator: {
    fit(X: Tensor2D, y: Tensor1D): Promise<unknown>
    score(X: Tensor2D, y: Tensor1D): Scalar
  },
  X: Scikit2D,
  y: Scikit1D,
  params: {
    cv?: CrossValidator
    groups?: Scikit1D
  }
): Promise<Tensor1D>

/**
 * Evaluates a score by cross-validation. This particular overload
 * of the function uses the default score of an unsupervised estimator
 * for scoring.
 *
 * The estimator is fitted anew on the training part of every split and
 * then scored on the corresponding test part via `estimator.score`.
 *
 * @param estimator An unsupervised estimator that has an async `fit(X)`
 *                 function and a `score(X)` function.
 *
 * @param X        The features of the cross validation dataset.
 *
 * @param params   Cross validation parameters. `cv` is the {@link CrossValidator}
 *                 responsible for splitting the dataset into training and test
 *                 data in one or more folds; when omitted, a default-constructed
 *                 {@link KFold} is used. `groups` is additional grouping
 *                 information which is used by some {@link CrossValidator}
 *                 instances to determine which groups of datapoints may not
 *                 appear both in test and training data at the same time.
 *
 * @returns        A 1D tensor holding the test score for each split/fold
 *                 that was generated by the {@link CrossValidator} `cv`
 */
export async function crossValScore(
  estimator: {
    fit(X: Tensor2D): Promise<unknown>
    score(X: Tensor2D): Scalar
  },
  X: Scikit2D,
  params: {
    cv?: CrossValidator
    groups?: Scikit1D
  }
): Promise<Tensor1D>

/**
 * Implementation shared by every `crossValScore` overload.
 *
 * The call shape is detected at runtime: the call is treated as unsupervised
 * when `y` is nullish, or when no fourth argument was given and the third
 * argument is not 1D data (it is then taken to be the params object).
 *
 * For each split produced by `cv` the estimator is fitted on the training
 * rows and scored on the test rows; the per-split scores are stacked into
 * the returned tensor.
 *
 * @param estimator Object with an async `fit` method and, unless
 *                  `params.scoring` is given, a `score` method.
 * @param X         Features of the cross validation dataset.
 * @param y         Targets (supervised call) or the params object
 *                  (three-argument unsupervised call).
 * @param params    Optional `cv` (defaults to `new KFold()`), `groups`
 *                  (forwarded to `cv.split`) and `scoring` (called with
 *                  `estimator` as `this`).
 *
 * @returns         One score per split, stacked into a single tensor.
 *
 * @throws          Whatever `assert` throws when neither `params.scoring`
 *                  nor a function-valued `estimator.score` is available.
 */
export async function crossValScore(
  estimator: any,
  X: Scikit2D,
  y?: any,
  params?: {
    cv?: CrossValidator
    groups?: Scikit1D
    scoring?: any
  }
): Promise<Tensor1D> {
  const unsupervised = y == null || (params == null && !isScikit1D(y))
  if (unsupervised) {
    // Three-argument unsupervised call: the params object sits in the `y` slot.
    params = params ?? y
    // There are no targets in this mode. Clear `y` so the params object is
    // never mistaken for target data further down (e.g. by `cv.split`).
    y = undefined
  }

  const { cv = new KFold(), groups, scoring } = params ?? {}

  let scorer = scoring
  if (scorer == null) {
    assert(
      'function' === typeof estimator.score,
      'crossValScore(estimator,[X,y],params): Either params.scoring or estimator.score(X,y) must be defined.'
    )
    scorer = estimator.score
  }
  // The scorer is always invoked with the estimator as `this`.
  scorer = scorer.bind(estimator)

  const scores: Scalar[] = []
  let result: Tensor1D | undefined = undefined

  // A manual scope is used instead of `tf.tidy` because the loop awaits `fit`.
  // NOTE(review): tensors allocated while this scope is open (presumably
  // including those created inside `estimator.fit`) are released by
  // `endScope` unless they are the returned `result` — confirm that fitted
  // estimator state does not need to outlive this call.
  tf.engine().startScope()
  try {
    const features = convertToTensor2D(X)
    const targets: Tensor1D | undefined = unsupervised
      ? undefined
      : convertToTensor1D(y)

    for (const { trainIndex, testIndex } of cv.split(
      features,
      targets,
      groups
    )) {
      const X_train = features.gather(trainIndex)
      const X_test = features.gather(testIndex)

      let score: Scalar
      if (targets === undefined) {
        await estimator.fit(X_train)

        score = scorer(X_test) as Scalar
      } else {
        const y_train = targets.gather(trainIndex)
        const y_test = targets.gather(testIndex)

        await estimator.fit(X_train, y_train)

        score = scorer(X_test, y_test) as Scalar

        y_train.dispose()
        y_test.dispose()
      }

      scores.push(score)

      // Free per-fold tensors eagerly rather than waiting for `endScope`.
      trainIndex.dispose()
      testIndex.dispose()
      X_train.dispose()
      X_test.dispose()
    }

    return (result = tf.stack(scores) as Tensor1D)
  } finally {
    // `result` stays undefined when an error occurred, so nothing is kept.
    tf.engine().endScope(result)
  }
}
